mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-20 23:34:22 +02:00
fix: handle multiple message frames inside packet (#143)
Running the gateway as non-tls, but using an external TLS gateway in kubernetes+istio, I determined that the istio TLS gateway would join messages frames into a single TCP packet. The packet read code assumed that a single packet is a message. This is not the case for a TCP stream, since you don't know how the frames are segmented via proxies, etc. The fix turned out more complex that I would have liked, but added a number of unit tests to cover all the corner cases. Likely fragmentation was not working correctly as well, as there was some cases that were previously not handled. Note that this might address issue #126 as well.
This commit is contained in:
parent
6b4e6bdced
commit
20307b9a76
7 changed files with 579 additions and 197 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1 +1,3 @@
|
||||||
go.sum
|
go.sum
|
||||||
|
bin
|
||||||
|
*.swp
|
||||||
|
|
|
@ -30,15 +30,20 @@ func (c *ClientConfig) ConnectAndForward() error {
|
||||||
c.Session.transportOut.WritePacket(c.handshakeRequest())
|
c.Session.transportOut.WritePacket(c.handshakeRequest())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
pt, sz, pkt, err := readMessage(c.Session.transportIn)
|
messages, err := readMessage(c.Session.transportIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot read message from stream %s", err)
|
log.Printf("Cannot read message from stream %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch pt {
|
for _, message := range messages {
|
||||||
|
if message.err != nil {
|
||||||
|
log.Printf("Cannot read message from stream %p", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch message.packetType {
|
||||||
case PKT_TYPE_HANDSHAKE_RESPONSE:
|
case PKT_TYPE_HANDSHAKE_RESPONSE:
|
||||||
caps, err := c.handshakeResponse(pkt)
|
caps, err := c.handshakeResponse(message.msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot connect to %s due to %s", c.Server, err)
|
log.Printf("Cannot connect to %s due to %s", c.Server, err)
|
||||||
return err
|
return err
|
||||||
|
@ -46,7 +51,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
||||||
log.Printf("Handshake response received. Caps: %d", caps)
|
log.Printf("Handshake response received. Caps: %d", caps)
|
||||||
c.Session.transportOut.WritePacket(c.tunnelRequest())
|
c.Session.transportOut.WritePacket(c.tunnelRequest())
|
||||||
case PKT_TYPE_TUNNEL_RESPONSE:
|
case PKT_TYPE_TUNNEL_RESPONSE:
|
||||||
tid, caps, err := c.tunnelResponse(pkt)
|
tid, caps, err := c.tunnelResponse(message.msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot setup tunnel due to %s", err)
|
log.Printf("Cannot setup tunnel due to %s", err)
|
||||||
return err
|
return err
|
||||||
|
@ -54,7 +59,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
||||||
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
|
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
|
||||||
c.Session.transportOut.WritePacket(c.tunnelAuthRequest())
|
c.Session.transportOut.WritePacket(c.tunnelAuthRequest())
|
||||||
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
|
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
|
||||||
flags, timeout, err := c.tunnelAuthResponse(pkt)
|
flags, timeout, err := c.tunnelAuthResponse(message.msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot do tunnel auth due to %s", err)
|
log.Printf("Cannot do tunnel auth due to %s", err)
|
||||||
return err
|
return err
|
||||||
|
@ -62,7 +67,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
||||||
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
|
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
|
||||||
c.Session.transportOut.WritePacket(c.channelRequest())
|
c.Session.transportOut.WritePacket(c.channelRequest())
|
||||||
case PKT_TYPE_CHANNEL_RESPONSE:
|
case PKT_TYPE_CHANNEL_RESPONSE:
|
||||||
cid, err := c.channelResponse(pkt)
|
cid, err := c.channelResponse(message.msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot do tunnel auth due to %s", err)
|
log.Printf("Cannot do tunnel auth due to %s", err)
|
||||||
return err
|
return err
|
||||||
|
@ -73,9 +78,10 @@ func (c *ClientConfig) ConnectAndForward() error {
|
||||||
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
||||||
//go forward(c.LocalConn, c.Session.transportOut)
|
//go forward(c.LocalConn, c.Session.transportOut)
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
receive(pkt, c.LocalConn)
|
receive(message.msg, c.LocalConn)
|
||||||
default:
|
default:
|
||||||
log.Printf("Unknown packet type received: %d size %d", pt, sz)
|
log.Printf("Unknown packet type received: %d size %d", message.packetType, message.length)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,12 +4,18 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxFragmentSize = 65536
|
||||||
)
|
)
|
||||||
|
|
||||||
type RedirectFlags struct {
|
type RedirectFlags struct {
|
||||||
|
@ -22,44 +28,55 @@ type RedirectFlags struct {
|
||||||
EnableAll bool
|
EnableAll bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// readMessage parses and defragments a packet from a Transport. It returns
|
func handleMsgFrame(packet *packetReader) *message {
|
||||||
// at most the bytes that have been reported by the packet
|
pt, sz, msg, err := readHeader(packet.getPtr())
|
||||||
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
|
if err == nil {
|
||||||
fragment := false
|
packet.incrementPtr(int(sz))
|
||||||
|
return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, maxFragmentSize)
|
||||||
index := 0
|
index := 0
|
||||||
buf := make([]byte, 4096)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
size, pkt, err := in.ReadPacket()
|
// keep parsing thfragment
|
||||||
|
if len(packet.getPtr()) > len(buf[index:]) {
|
||||||
|
return &message{packetType: int(pt), length: int(sz), msg: msg, err: fmt.Errorf("fragment exceeded max fragment size")}
|
||||||
|
}
|
||||||
|
index += copy(buf[index:], packet.getPtr())
|
||||||
|
// Get a new frame
|
||||||
|
err := packet.read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, []byte{0, 0}, err
|
// Failed to make a msg
|
||||||
|
return &message{packetType: int(pt), length: int(sz), msg: msg, err: err}
|
||||||
|
}
|
||||||
|
pt, sz, msg, err = readHeader(append(buf[:index], packet.getPtr()...))
|
||||||
|
if err == nil {
|
||||||
|
// the increment is based upon how much of the data we have used
|
||||||
|
// in this packet. The index tells us how much is in the previous frame(s),
|
||||||
|
// So we remove that from the size of the message.
|
||||||
|
packet.incrementPtr(int(sz) - index)
|
||||||
|
return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for fragments
|
// readMessage parses and defragments a packet from a Transport. It returns
|
||||||
var pt uint16
|
// at most the bytes that have been reported by the packet.
|
||||||
var sz uint32
|
func readMessage(in transport.Transport) ([]*message, error) {
|
||||||
var msg []byte
|
messages := make([]*message, 0)
|
||||||
|
|
||||||
if !fragment {
|
packet := newTransportPacket(in)
|
||||||
pt, sz, msg, err = readHeader(pkt[:size])
|
err := packet.read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fragment = true
|
return messages, err
|
||||||
index = copy(buf, pkt[:size])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
index = 0
|
|
||||||
} else {
|
|
||||||
fragment = false
|
|
||||||
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
|
|
||||||
// header is corrupted even after defragmenting
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, []byte{0, 0}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !fragment {
|
|
||||||
return int(pt), int(sz), msg, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var message *message
|
||||||
|
for packet.hasMoreData() {
|
||||||
|
message = handleMsgFrame(packet)
|
||||||
|
messages = append(messages, message)
|
||||||
}
|
}
|
||||||
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createPacket wraps the data into the protocol packet
|
// createPacket wraps the data into the protocol packet
|
||||||
|
|
297
cmd/rdpgw/protocol/common_test.go
Normal file
297
cmd/rdpgw/protocol/common_test.go
Normal file
|
@ -0,0 +1,297 @@
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type messageMock struct {
|
||||||
|
buffer []byte
|
||||||
|
msgBuffer []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
|
||||||
|
func randBytes(message []byte) {
|
||||||
|
for index := range message {
|
||||||
|
message[index] = letterBytes[rand.Intn(len(letterBytes))]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMessageMock(packetType uint16, message []byte) *messageMock {
|
||||||
|
randBytes(message)
|
||||||
|
buf := createPacket(packetType, message)
|
||||||
|
return &messageMock{msgBuffer: buf[8:], buffer: buf}
|
||||||
|
}
|
||||||
|
|
||||||
|
type packetMock struct {
|
||||||
|
bytes []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPacketMock() *packetMock {
|
||||||
|
return &packetMock{bytes: make([]byte, 0)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packetMock) addBytes(b []byte) {
|
||||||
|
p.bytes = append(p.bytes, b...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packetMock) GetPacket() (int, []byte, error) {
|
||||||
|
return len(p.bytes), p.bytes, p.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type transportMock struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
packets []*packetMock
|
||||||
|
packetPtr int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransportMock() *transportMock {
|
||||||
|
return &transportMock{packets: make([]*packetMock, 0)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transportMock) addPacket(p *packetMock) {
|
||||||
|
t.lock.Lock()
|
||||||
|
defer t.lock.Unlock()
|
||||||
|
|
||||||
|
t.packets = append(t.packets, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transportMock) ReadPacket() (n int, p []byte, err error) {
|
||||||
|
t.lock.Lock()
|
||||||
|
defer t.lock.Unlock()
|
||||||
|
|
||||||
|
if t.packetPtr >= len(t.packets) {
|
||||||
|
return 0, nil, fmt.Errorf("no packets available")
|
||||||
|
}
|
||||||
|
packet := t.packets[t.packetPtr]
|
||||||
|
t.packetPtr++
|
||||||
|
return packet.GetPacket()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transportMock) WritePacket(b []byte) (n int, err error) {
|
||||||
|
return 0, fmt.Errorf("not tested")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transportMock) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSimplePacket(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
m := newMessageMock(6, make([]byte, 10))
|
||||||
|
p := newPacketMock()
|
||||||
|
p.addBytes(m.buffer)
|
||||||
|
transport.addPacket(p)
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages)
|
||||||
|
assert.Len(t, messages, 1)
|
||||||
|
assert.Equal(t, 6, messages[0].packetType)
|
||||||
|
assert.Equal(t, 18, messages[0].length)
|
||||||
|
assert.Equal(t, m.msgBuffer, messages[0].msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiMessageInPacket(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
p := newPacketMock()
|
||||||
|
|
||||||
|
m := newMessageMock(6, make([]byte, 10))
|
||||||
|
p.addBytes(m.buffer)
|
||||||
|
|
||||||
|
m2 := newMessageMock(8, make([]byte, 12))
|
||||||
|
p.addBytes(m2.buffer)
|
||||||
|
|
||||||
|
m3 := newMessageMock(8, make([]byte, 12))
|
||||||
|
p.addBytes(m3.buffer)
|
||||||
|
|
||||||
|
transport.addPacket(p)
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages)
|
||||||
|
assert.Len(t, messages, 3)
|
||||||
|
assert.Nil(t, messages[0].err)
|
||||||
|
assert.Equal(t, 6, messages[0].packetType)
|
||||||
|
assert.Equal(t, 18, messages[0].length)
|
||||||
|
assert.Equal(t, m.msgBuffer, messages[0].msg)
|
||||||
|
|
||||||
|
assert.Nil(t, messages[1].err)
|
||||||
|
assert.Equal(t, 8, messages[1].packetType)
|
||||||
|
assert.Equal(t, 20, messages[1].length)
|
||||||
|
assert.Equal(t, m2.msgBuffer, messages[1].msg)
|
||||||
|
|
||||||
|
assert.Nil(t, messages[2].err)
|
||||||
|
assert.Equal(t, 8, messages[2].packetType)
|
||||||
|
assert.Equal(t, 20, messages[2].length)
|
||||||
|
assert.Equal(t, m3.msgBuffer, messages[2].msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFragment(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
p1 := newPacketMock()
|
||||||
|
p2 := newPacketMock()
|
||||||
|
|
||||||
|
m := newMessageMock(6, make([]byte, 100))
|
||||||
|
// split the message across 2 packets
|
||||||
|
p1.addBytes(m.buffer[0:50])
|
||||||
|
p2.addBytes(m.buffer[50:])
|
||||||
|
transport.addPacket(p1)
|
||||||
|
transport.addPacket(p2)
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages)
|
||||||
|
assert.Len(t, messages, 1)
|
||||||
|
assert.Equal(t, 6, messages[0].packetType)
|
||||||
|
assert.Equal(t, 108, messages[0].length)
|
||||||
|
assert.Equal(t, m.msgBuffer, messages[0].msg)
|
||||||
|
|
||||||
|
_, err = readMessage(transport)
|
||||||
|
// no more packets
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDroppedBytes(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
p1 := newPacketMock()
|
||||||
|
|
||||||
|
m := newMessageMock(6, make([]byte, 100))
|
||||||
|
// add only partial bytes
|
||||||
|
p1.addBytes(m.buffer[0:50])
|
||||||
|
transport.addPacket(p1)
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Len(t, messages, 1)
|
||||||
|
assert.NotNil(t, messages[0].err)
|
||||||
|
|
||||||
|
_, err = readMessage(transport)
|
||||||
|
// no more packets
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTooMuchData(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
p1 := newPacketMock()
|
||||||
|
|
||||||
|
m := newMessageMock(6, make([]byte, 100))
|
||||||
|
// add only partial bytes
|
||||||
|
p1.addBytes(m.buffer)
|
||||||
|
p1.addBytes([]byte{0, 0, 0})
|
||||||
|
// add some junk bytes
|
||||||
|
transport.addPacket(p1)
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages)
|
||||||
|
assert.Len(t, messages, 2)
|
||||||
|
assert.Nil(t, messages[0].err)
|
||||||
|
assert.NotNil(t, messages[1].err)
|
||||||
|
|
||||||
|
_, err = readMessage(transport)
|
||||||
|
// no more packets
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJumbo(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
p1 := newPacketMock()
|
||||||
|
p2 := newPacketMock()
|
||||||
|
|
||||||
|
m := newMessageMock(6, make([]byte, maxFragmentSize))
|
||||||
|
// add only partial bytes
|
||||||
|
p1.addBytes(m.buffer[0 : maxFragmentSize/2])
|
||||||
|
p2.addBytes(m.buffer[maxFragmentSize/2:])
|
||||||
|
// add some junk bytes
|
||||||
|
transport.addPacket(p1)
|
||||||
|
transport.addPacket(p2)
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages)
|
||||||
|
assert.Len(t, messages, 1)
|
||||||
|
assert.Equal(t, m.msgBuffer, messages[0].msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManyFragments(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
|
||||||
|
m := newMessageMock(6, make([]byte, 256))
|
||||||
|
fragmentSize := len(m.buffer) / 5
|
||||||
|
bufferSize := len(m.buffer)
|
||||||
|
for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize {
|
||||||
|
p := newPacketMock()
|
||||||
|
p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)])
|
||||||
|
transport.addPacket(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages)
|
||||||
|
assert.Len(t, messages, 1)
|
||||||
|
assert.Nil(t, messages[0].err)
|
||||||
|
assert.Equal(t, m.msgBuffer, messages[0].msg)
|
||||||
|
|
||||||
|
messages, err = readMessage(transport)
|
||||||
|
// no more packets
|
||||||
|
fmt.Println(messages)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFragmentTooLarge(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
|
||||||
|
m := newMessageMock(6, make([]byte, maxFragmentSize*2))
|
||||||
|
fragmentSize := len(m.buffer) / 5
|
||||||
|
bufferSize := len(m.buffer)
|
||||||
|
for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize {
|
||||||
|
p := newPacketMock()
|
||||||
|
p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)])
|
||||||
|
transport.addPacket(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages[0].err)
|
||||||
|
assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFragmentWithMultiMessage the first message is fragmented,
|
||||||
|
// while the second message is found whole in the final packet
|
||||||
|
func TestFragmentWithMultiMessage(t *testing.T) {
|
||||||
|
transport := newTransportMock()
|
||||||
|
p1 := newPacketMock()
|
||||||
|
p2 := newPacketMock()
|
||||||
|
|
||||||
|
m1 := newMessageMock(6, make([]byte, 100))
|
||||||
|
m2 := newMessageMock(6, make([]byte, 10))
|
||||||
|
// split the message across 2 packets
|
||||||
|
p1.addBytes(m1.buffer[0:50])
|
||||||
|
p2.addBytes(m1.buffer[50:])
|
||||||
|
p2.addBytes(m2.buffer)
|
||||||
|
transport.addPacket(p1)
|
||||||
|
transport.addPacket(p2)
|
||||||
|
|
||||||
|
messages, err := readMessage(transport)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, messages)
|
||||||
|
assert.Len(t, messages, 2)
|
||||||
|
assert.Equal(t, 6, messages[0].packetType)
|
||||||
|
assert.Equal(t, 108, messages[0].length)
|
||||||
|
assert.Equal(t, m1.msgBuffer, messages[0].msg)
|
||||||
|
|
||||||
|
assert.Equal(t, 6, messages[1].packetType)
|
||||||
|
assert.Equal(t, 18, messages[1].length)
|
||||||
|
assert.Equal(t, m2.msgBuffer, messages[1].msg)
|
||||||
|
|
||||||
|
_, err = readMessage(transport)
|
||||||
|
// no more packets
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
40
cmd/rdpgw/protocol/packet_reader.go
Normal file
40
cmd/rdpgw/protocol/packet_reader.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
import "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||||
|
|
||||||
|
type packetReader struct {
|
||||||
|
in transport.Transport
|
||||||
|
size int
|
||||||
|
pkt []byte
|
||||||
|
err error
|
||||||
|
readPtr int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransportPacket(in transport.Transport) *packetReader {
|
||||||
|
return &packetReader{in: in}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *packetReader) hasMoreData() bool {
|
||||||
|
return t.readPtr < t.size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *packetReader) getPtr() []byte {
|
||||||
|
return t.pkt[t.readPtr:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *packetReader) incrementPtr(size int) {
|
||||||
|
t.readPtr += size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *packetReader) read() error {
|
||||||
|
size, pkt, err := t.in.ReadPacket()
|
||||||
|
if err != nil {
|
||||||
|
t.size = 0
|
||||||
|
} else {
|
||||||
|
t.size = size
|
||||||
|
}
|
||||||
|
t.pkt = pkt
|
||||||
|
t.err = err
|
||||||
|
t.readPtr = 0
|
||||||
|
return err
|
||||||
|
}
|
|
@ -6,12 +6,13 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Processor struct {
|
type Processor struct {
|
||||||
|
@ -43,13 +44,19 @@ const tunnelId = 10
|
||||||
|
|
||||||
func (p *Processor) Process(ctx context.Context) error {
|
func (p *Processor) Process(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
pt, sz, pkt, err := p.tunnel.Read()
|
//pt, sz, pkt, err := p.tunnel.Read()
|
||||||
|
messages, err := p.tunnel.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot read message from stream %p", err)
|
log.Printf("Cannot read message from stream %p", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch pt {
|
for _, message := range messages {
|
||||||
|
if message.err != nil {
|
||||||
|
log.Printf("Cannot read message from stream %p", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch message.packetType {
|
||||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||||
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
|
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
|
||||||
if p.state != SERVER_STATE_INITIALIZED {
|
if p.state != SERVER_STATE_INITIALIZED {
|
||||||
|
@ -58,7 +65,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
major, minor, _, reqAuth := p.handshakeRequest(pkt)
|
major, minor, _, reqAuth := p.handshakeRequest(message.msg)
|
||||||
caps, err := p.matchAuth(reqAuth)
|
caps, err := p.matchAuth(reqAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
|
@ -78,7 +85,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
_, cookie := p.tunnelRequest(pkt)
|
_, cookie := p.tunnelRequest(message.msg)
|
||||||
if p.gw.CheckPAACookie != nil {
|
if p.gw.CheckPAACookie != nil {
|
||||||
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
|
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
|
||||||
|
@ -99,7 +106,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
client := p.tunnelAuthRequest(pkt)
|
client := p.tunnelAuthRequest(message.msg)
|
||||||
if p.gw.CheckClientName != nil {
|
if p.gw.CheckClientName != nil {
|
||||||
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
|
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
|
||||||
log.Printf("Invalid client name: %s", client)
|
log.Printf("Invalid client name: %s", client)
|
||||||
|
@ -120,7 +127,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
server, port := p.channelRequest(pkt)
|
server, port := p.channelRequest(message.msg)
|
||||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
if p.gw.CheckHost != nil {
|
if p.gw.CheckHost != nil {
|
||||||
log.Printf("Verifying %s host connection", host)
|
log.Printf("Verifying %s host connection", host)
|
||||||
|
@ -154,7 +161,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
p.state = SERVER_STATE_OPENED
|
p.state = SERVER_STATE_OPENED
|
||||||
receive(pkt, p.tunnel.rwc)
|
receive(message.msg, p.tunnel.rwc)
|
||||||
case PKT_TYPE_KEEPALIVE:
|
case PKT_TYPE_KEEPALIVE:
|
||||||
// keepalives can be received while the channel is not open yet
|
// keepalives can be received while the channel is not open yet
|
||||||
if p.state < SERVER_STATE_CHANNEL_CREATE {
|
if p.state < SERVER_STATE_CHANNEL_CREATE {
|
||||||
|
@ -175,7 +182,8 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
p.state = SERVER_STATE_CLOSED
|
p.state = SERVER_STATE_CLOSED
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
log.Printf("Unknown packet (size %d): %x", message.length, message.msg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -46,6 +47,13 @@ type Tunnel struct {
|
||||||
LastSeen time.Time
|
LastSeen time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type message struct {
|
||||||
|
packetType int
|
||||||
|
length int
|
||||||
|
msg []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
// Write puts the packet on the transport and updates the statistics for bytes sent
|
// Write puts the packet on the transport and updates the statistics for bytes sent
|
||||||
func (t *Tunnel) Write(pkt []byte) {
|
func (t *Tunnel) Write(pkt []byte) {
|
||||||
n, _ := t.transportOut.WritePacket(pkt)
|
n, _ := t.transportOut.WritePacket(pkt)
|
||||||
|
@ -55,10 +63,14 @@ func (t *Tunnel) Write(pkt []byte) {
|
||||||
// Read picks up a packet from the transport and returns the packet type
|
// Read picks up a packet from the transport and returns the packet type
|
||||||
// packet, with the header removed, and the packet size. It updates the
|
// packet, with the header removed, and the packet size. It updates the
|
||||||
// statistics for bytes received
|
// statistics for bytes received
|
||||||
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
|
func (t *Tunnel) Read() ([]*message, error) {
|
||||||
pt, size, pkt, err = readMessage(t.transportIn)
|
messages, err := readMessage(t.transportIn)
|
||||||
t.BytesReceived += int64(size)
|
if err != nil {
|
||||||
t.LastSeen = time.Now()
|
return nil, err
|
||||||
|
}
|
||||||
return pt, size, pkt, err
|
for _, message := range messages {
|
||||||
|
t.BytesReceived += int64(message.length)
|
||||||
|
t.LastSeen = time.Now()
|
||||||
|
}
|
||||||
|
return messages, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue