From 20307b9a76d3277af9fe582f703e8ba25985e5f0 Mon Sep 17 00:00:00 2001 From: Mike Marchetti Date: Tue, 6 May 2025 11:38:16 -0400 Subject: [PATCH] fix: handle multiple message frames inside packet (#143) Running the gateway as non-tls, but using an external TLS gateway in kubernetes+istio, I determined that the istio TLS gateway would join messages frames into a single TCP packet. The packet read code assumed that a single packet is a message. This is not the case for a TCP stream, since you don't know how the frames are segmented via proxies, etc. The fix turned out more complex that I would have liked, but added a number of unit tests to cover all the corner cases. Likely fragmentation was not working correctly as well, as there was some cases that were previously not handled. Note that this might address issue #126 as well. --- .gitignore | 2 + cmd/rdpgw/protocol/client.go | 84 ++++---- cmd/rdpgw/protocol/common.go | 83 ++++---- cmd/rdpgw/protocol/common_test.go | 297 ++++++++++++++++++++++++++++ cmd/rdpgw/protocol/packet_reader.go | 40 ++++ cmd/rdpgw/protocol/process.go | 242 ++++++++++++----------- cmd/rdpgw/protocol/tunnel.go | 28 ++- 7 files changed, 579 insertions(+), 197 deletions(-) create mode 100644 cmd/rdpgw/protocol/common_test.go create mode 100644 cmd/rdpgw/protocol/packet_reader.go diff --git a/.gitignore b/.gitignore index 08cb523..ee50204 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ go.sum +bin +*.swp diff --git a/cmd/rdpgw/protocol/client.go b/cmd/rdpgw/protocol/client.go index c1c27ac..950eace 100644 --- a/cmd/rdpgw/protocol/client.go +++ b/cmd/rdpgw/protocol/client.go @@ -30,52 +30,58 @@ func (c *ClientConfig) ConnectAndForward() error { c.Session.transportOut.WritePacket(c.handshakeRequest()) for { - pt, sz, pkt, err := readMessage(c.Session.transportIn) + messages, err := readMessage(c.Session.transportIn) if err != nil { log.Printf("Cannot read message from stream %s", err) return err } - switch pt { - case PKT_TYPE_HANDSHAKE_RESPONSE: - caps, err := c.handshakeResponse(pkt) - if err != nil { - log.Printf("Cannot connect to %s due to %s", c.Server, err) - return err + for _, message := range messages { + if message.err != nil { + log.Printf("Cannot read message from stream %p", err) + continue } - log.Printf("Handshake response received. Caps: %d", caps) - c.Session.transportOut.WritePacket(c.tunnelRequest()) - case PKT_TYPE_TUNNEL_RESPONSE: - tid, caps, err := c.tunnelResponse(pkt) - if err != nil { - log.Printf("Cannot setup tunnel due to %s", err) - return err + switch message.packetType { + case PKT_TYPE_HANDSHAKE_RESPONSE: + caps, err := c.handshakeResponse(message.msg) + if err != nil { + log.Printf("Cannot connect to %s due to %s", c.Server, err) + return err + } + log.Printf("Handshake response received. Caps: %d", caps) + c.Session.transportOut.WritePacket(c.tunnelRequest()) + case PKT_TYPE_TUNNEL_RESPONSE: + tid, caps, err := c.tunnelResponse(message.msg) + if err != nil { + log.Printf("Cannot setup tunnel due to %s", err) + return err + } + log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps) + c.Session.transportOut.WritePacket(c.tunnelAuthRequest()) + case PKT_TYPE_TUNNEL_AUTH_RESPONSE: + flags, timeout, err := c.tunnelAuthResponse(message.msg) + if err != nil { + log.Printf("Cannot do tunnel auth due to %s", err) + return err + } + log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout) + c.Session.transportOut.WritePacket(c.channelRequest()) + case PKT_TYPE_CHANNEL_RESPONSE: + cid, err := c.channelResponse(message.msg) + if err != nil { + log.Printf("Cannot do tunnel auth due to %s", err) + return err + } + if cid < 1 { + log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) + } + log.Printf("Channel creation succesful. Channel id: %d", cid) + //go forward(c.LocalConn, c.Session.transportOut) + case PKT_TYPE_DATA: + receive(message.msg, c.LocalConn) + default: + log.Printf("Unknown packet type received: %d size %d", message.packetType, message.length) } - log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps) - c.Session.transportOut.WritePacket(c.tunnelAuthRequest()) - case PKT_TYPE_TUNNEL_AUTH_RESPONSE: - flags, timeout, err := c.tunnelAuthResponse(pkt) - if err != nil { - log.Printf("Cannot do tunnel auth due to %s", err) - return err - } - log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout) - c.Session.transportOut.WritePacket(c.channelRequest()) - case PKT_TYPE_CHANNEL_RESPONSE: - cid, err := c.channelResponse(pkt) - if err != nil { - log.Printf("Cannot do tunnel auth due to %s", err) - return err - } - if cid < 1 { - log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) - } - log.Printf("Channel creation succesful. Channel id: %d", cid) - //go forward(c.LocalConn, c.Session.transportOut) - case PKT_TYPE_DATA: - receive(pkt, c.LocalConn) - default: - log.Printf("Unknown packet type received: %d size %d", pt, sz) } } } diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go index a8875b2..4509f5c 100644 --- a/cmd/rdpgw/protocol/common.go +++ b/cmd/rdpgw/protocol/common.go @@ -4,12 +4,18 @@ import ( "bytes" "encoding/binary" "errors" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" + "fmt" "io" "log" "net" "os" "syscall" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" +) + +const ( + maxFragmentSize = 65536 ) type RedirectFlags struct { @@ -22,46 +28,57 @@ type RedirectFlags struct { EnableAll bool } -// readMessage parses and defragments a packet from a Transport. It returns -// at most the bytes that have been reported by the packet -func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) { - fragment := false +func handleMsgFrame(packet *packetReader) *message { + pt, sz, msg, err := readHeader(packet.getPtr()) + if err == nil { + packet.incrementPtr(int(sz)) + return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil} + } + + buf := make([]byte, maxFragmentSize) index := 0 - buf := make([]byte, 4096) - for { - size, pkt, err := in.ReadPacket() + // keep parsing thfragment + if len(packet.getPtr()) > len(buf[index:]) { + return &message{packetType: int(pt), length: int(sz), msg: msg, err: fmt.Errorf("fragment exceeded max fragment size")} + } + index += copy(buf[index:], packet.getPtr()) + // Get a new frame + err := packet.read() if err != nil { - return 0, 0, []byte{0, 0}, err + // Failed to make a msg + return &message{packetType: int(pt), length: int(sz), msg: msg, err: err} } - - // check for fragments - var pt uint16 - var sz uint32 - var msg []byte - - if !fragment { - pt, sz, msg, err = readHeader(pkt[:size]) - if err != nil { - fragment = true - index = copy(buf, pkt[:size]) - continue - } - index = 0 - } else { - fragment = false - pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...)) - // header is corrupted even after defragmenting - if err != nil { - return 0, 0, []byte{0, 0}, err - } - } - if !fragment { - return int(pt), int(sz), msg, nil + pt, sz, msg, err = readHeader(append(buf[:index], packet.getPtr()...)) + if err == nil { + // the increment is based upon how much of the data we have used + // in this packet. The index tells us how much is in the previous frame(s), + // So we remove that from the size of the message. + packet.incrementPtr(int(sz) - index) + return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil} } } } +// readMessage parses and defragments a packet from a Transport. It returns +// at most the bytes that have been reported by the packet. +func readMessage(in transport.Transport) ([]*message, error) { + messages := make([]*message, 0) + + packet := newTransportPacket(in) + err := packet.read() + if err != nil { + return messages, err + } + + var message *message + for packet.hasMoreData() { + message = handleMsgFrame(packet) + messages = append(messages, message) + } + return messages, nil +} + // createPacket wraps the data into the protocol packet func createPacket(pktType uint16, data []byte) (packet []byte) { size := len(data) + 8 diff --git a/cmd/rdpgw/protocol/common_test.go b/cmd/rdpgw/protocol/common_test.go new file mode 100644 index 0000000..da97275 --- /dev/null +++ b/cmd/rdpgw/protocol/common_test.go @@ -0,0 +1,297 @@ +package protocol + +import ( + "fmt" + "math/rand" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +type messageMock struct { + buffer []byte + msgBuffer []byte +} + +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randBytes(message []byte) { + for index := range message { + message[index] = letterBytes[rand.Intn(len(letterBytes))] + } +} + +func newMessageMock(packetType uint16, message []byte) *messageMock { + randBytes(message) + buf := createPacket(packetType, message) + return &messageMock{msgBuffer: buf[8:], buffer: buf} +} + +type packetMock struct { + bytes []byte + err error +} + +func newPacketMock() *packetMock { + return &packetMock{bytes: make([]byte, 0)} +} + +func (p *packetMock) addBytes(b []byte) { + p.bytes = append(p.bytes, b...) +} + +func (p *packetMock) GetPacket() (int, []byte, error) { + return len(p.bytes), p.bytes, p.err +} + +type transportMock struct { + lock sync.Mutex + packets []*packetMock + packetPtr int +} + +func newTransportMock() *transportMock { + return &transportMock{packets: make([]*packetMock, 0)} +} + +func (t *transportMock) addPacket(p *packetMock) { + t.lock.Lock() + defer t.lock.Unlock() + + t.packets = append(t.packets, p) +} + +func (t *transportMock) ReadPacket() (n int, p []byte, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.packetPtr >= len(t.packets) { + return 0, nil, fmt.Errorf("no packets available") + } + packet := t.packets[t.packetPtr] + t.packetPtr++ + return packet.GetPacket() +} + +func (t *transportMock) WritePacket(b []byte) (n int, err error) { + return 0, fmt.Errorf("not tested") +} + +func (t *transportMock) Close() error { + return nil +} + +func TestSimplePacket(t *testing.T) { + transport := newTransportMock() + m := newMessageMock(6, make([]byte, 10)) + p := newPacketMock() + p.addBytes(m.buffer) + transport.addPacket(p) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 18, messages[0].length) + assert.Equal(t, m.msgBuffer, messages[0].msg) +} + +func TestMultiMessageInPacket(t *testing.T) { + transport := newTransportMock() + p := newPacketMock() + + m := newMessageMock(6, make([]byte, 10)) + p.addBytes(m.buffer) + + m2 := newMessageMock(8, make([]byte, 12)) + p.addBytes(m2.buffer) + + m3 := newMessageMock(8, make([]byte, 12)) + p.addBytes(m3.buffer) + + transport.addPacket(p) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 3) + assert.Nil(t, messages[0].err) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 18, messages[0].length) + assert.Equal(t, m.msgBuffer, messages[0].msg) + + assert.Nil(t, messages[1].err) + assert.Equal(t, 8, messages[1].packetType) + assert.Equal(t, 20, messages[1].length) + assert.Equal(t, m2.msgBuffer, messages[1].msg) + + assert.Nil(t, messages[2].err) + assert.Equal(t, 8, messages[2].packetType) + assert.Equal(t, 20, messages[2].length) + assert.Equal(t, m3.msgBuffer, messages[2].msg) +} + +func TestFragment(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + p2 := newPacketMock() + + m := newMessageMock(6, make([]byte, 100)) + // split the message across 2 packets + p1.addBytes(m.buffer[0:50]) + p2.addBytes(m.buffer[50:]) + transport.addPacket(p1) + transport.addPacket(p2) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 108, messages[0].length) + assert.Equal(t, m.msgBuffer, messages[0].msg) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} + +func TestDroppedBytes(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + + m := newMessageMock(6, make([]byte, 100)) + // add only partial bytes + p1.addBytes(m.buffer[0:50]) + transport.addPacket(p1) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.Len(t, messages, 1) + assert.NotNil(t, messages[0].err) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} + +func TestTooMuchData(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + + m := newMessageMock(6, make([]byte, 100)) + // add only partial bytes + p1.addBytes(m.buffer) + p1.addBytes([]byte{0, 0, 0}) + // add some junk bytes + transport.addPacket(p1) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 2) + assert.Nil(t, messages[0].err) + assert.NotNil(t, messages[1].err) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} + +func TestJumbo(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + p2 := newPacketMock() + + m := newMessageMock(6, make([]byte, maxFragmentSize)) + // add only partial bytes + p1.addBytes(m.buffer[0 : maxFragmentSize/2]) + p2.addBytes(m.buffer[maxFragmentSize/2:]) + // add some junk bytes + transport.addPacket(p1) + transport.addPacket(p2) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Equal(t, m.msgBuffer, messages[0].msg) +} + +func TestManyFragments(t *testing.T) { + transport := newTransportMock() + + m := newMessageMock(6, make([]byte, 256)) + fragmentSize := len(m.buffer) / 5 + bufferSize := len(m.buffer) + for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize { + p := newPacketMock() + p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)]) + transport.addPacket(p) + } + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Nil(t, messages[0].err) + assert.Equal(t, m.msgBuffer, messages[0].msg) + + messages, err = readMessage(transport) + // no more packets + fmt.Println(messages) + assert.NotNil(t, err) +} + +func TestFragmentTooLarge(t *testing.T) { + transport := newTransportMock() + + m := newMessageMock(6, make([]byte, maxFragmentSize*2)) + fragmentSize := len(m.buffer) / 5 + bufferSize := len(m.buffer) + for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize { + p := newPacketMock() + p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)]) + transport.addPacket(p) + } + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages[0].err) + assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error()) +} + +// TestFragmentWithMultiMessage the first message is fragmented, +// while the second message is found whole in the final packet +func TestFragmentWithMultiMessage(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + p2 := newPacketMock() + + m1 := newMessageMock(6, make([]byte, 100)) + m2 := newMessageMock(6, make([]byte, 10)) + // split the message across 2 packets + p1.addBytes(m1.buffer[0:50]) + p2.addBytes(m1.buffer[50:]) + p2.addBytes(m2.buffer) + transport.addPacket(p1) + transport.addPacket(p2) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 2) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 108, messages[0].length) + assert.Equal(t, m1.msgBuffer, messages[0].msg) + + assert.Equal(t, 6, messages[1].packetType) + assert.Equal(t, 18, messages[1].length) + assert.Equal(t, m2.msgBuffer, messages[1].msg) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} diff --git a/cmd/rdpgw/protocol/packet_reader.go b/cmd/rdpgw/protocol/packet_reader.go new file mode 100644 index 0000000..1375538 --- /dev/null +++ b/cmd/rdpgw/protocol/packet_reader.go @@ -0,0 +1,40 @@ +package protocol + +import "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" + +type packetReader struct { + in transport.Transport + size int + pkt []byte + err error + readPtr int +} + +func newTransportPacket(in transport.Transport) *packetReader { + return &packetReader{in: in} +} + +func (t *packetReader) hasMoreData() bool { + return t.readPtr < t.size +} + +func (t *packetReader) getPtr() []byte { + return t.pkt[t.readPtr:] +} + +func (t *packetReader) incrementPtr(size int) { + t.readPtr += size +} + +func (t *packetReader) read() error { + size, pkt, err := t.in.ReadPacket() + if err != nil { + t.size = 0 + } else { + t.size = size + } + t.pkt = pkt + t.err = err + t.readPtr = 0 + return err +} diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go index bf90c07..ceaf16e 100644 --- a/cmd/rdpgw/protocol/process.go +++ b/cmd/rdpgw/protocol/process.go @@ -6,12 +6,13 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "io" "log" "net" "strconv" "time" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" ) type Processor struct { @@ -43,139 +44,146 @@ const tunnelId = 10 func (p *Processor) Process(ctx context.Context) error { for { - pt, sz, pkt, err := p.tunnel.Read() + //pt, sz, pkt, err := p.tunnel.Read() + messages, err := p.tunnel.Read() if err != nil { log.Printf("Cannot read message from stream %p", err) return err } - switch pt { - case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) - if p.state != SERVER_STATE_INITIALIZED { - log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) - msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) - p.tunnel.Write(msg) - return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) + for _, message := range messages { + if message.err != nil { + log.Printf("Cannot read message from stream %p", err) + continue } - major, minor, _, reqAuth := p.handshakeRequest(pkt) - caps, err := p.matchAuth(reqAuth) - if err != nil { - log.Println(err) - msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) - p.tunnel.Write(msg) - return err - } - msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_HANDSHAKE - case PKT_TYPE_TUNNEL_CREATE: - log.Printf("Tunnel create") - if p.state != SERVER_STATE_HANDSHAKE { - log.Printf("Tunnel create attempted while in wrong state %d != %d", - p.state, SERVER_STATE_HANDSHAKE) - msg := p.tunnelResponse(E_PROXY_INTERNALERROR) - p.tunnel.Write(msg) - return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) - } - _, cookie := p.tunnelRequest(pkt) - if p.gw.CheckPAACookie != nil { - if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) - msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + switch message.packetType { + case PKT_TYPE_HANDSHAKE_REQUEST: + log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) + if p.state != SERVER_STATE_INITIALIZED { + log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) + msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) p.tunnel.Write(msg) - return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) } - } - msg := p.tunnelResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_TUNNEL_CREATE - case PKT_TYPE_TUNNEL_AUTH: - log.Printf("Tunnel auth") - if p.state != SERVER_STATE_TUNNEL_CREATE { - log.Printf("Tunnel auth attempted while in wrong state %d != %d", - p.state, SERVER_STATE_TUNNEL_CREATE) - msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR) - p.tunnel.Write(msg) - return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) - } - client := p.tunnelAuthRequest(pkt) - if p.gw.CheckClientName != nil { - if ok, _ := p.gw.CheckClientName(ctx, client); !ok { - log.Printf("Invalid client name: %s", client) - msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) + major, minor, _, reqAuth := p.handshakeRequest(message.msg) + caps, err := p.matchAuth(reqAuth) + if err != nil { + log.Println(err) + msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) p.tunnel.Write(msg) - return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) + return err } - } - msg := p.tunnelAuthResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_TUNNEL_AUTHORIZE - case PKT_TYPE_CHANNEL_CREATE: - log.Printf("Channel create") - if p.state != SERVER_STATE_TUNNEL_AUTHORIZE { - log.Printf("Channel create attempted while in wrong state %d != %d", - p.state, SERVER_STATE_TUNNEL_AUTHORIZE) - msg := p.channelResponse(E_PROXY_INTERNALERROR) + msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS) p.tunnel.Write(msg) - return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) - } - server, port := p.channelRequest(pkt) - host := net.JoinHostPort(server, strconv.Itoa(int(port))) - if p.gw.CheckHost != nil { - log.Printf("Verifying %s host connection", host) - if ok, _ := p.gw.CheckHost(ctx, host); !ok { - log.Printf("Not allowed to connect to %s by policy handler", host) - msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) + p.state = SERVER_STATE_HANDSHAKE + case PKT_TYPE_TUNNEL_CREATE: + log.Printf("Tunnel create") + if p.state != SERVER_STATE_HANDSHAKE { + log.Printf("Tunnel create attempted while in wrong state %d != %d", + p.state, SERVER_STATE_HANDSHAKE) + msg := p.tunnelResponse(E_PROXY_INTERNALERROR) p.tunnel.Write(msg) - return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) + return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) } - } - log.Printf("Establishing connection to RDP server: %s", host) - p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) - if err != nil { - log.Printf("Error connecting to %s, %s", host, err) - msg := p.channelResponse(E_PROXY_INTERNALERROR) + _, cookie := p.tunnelRequest(message.msg) + if p.gw.CheckPAACookie != nil { + if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { + log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) + msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + p.tunnel.Write(msg) + return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + } + } + msg := p.tunnelResponse(ERROR_SUCCESS) + p.tunnel.Write(msg) + p.state = SERVER_STATE_TUNNEL_CREATE + case PKT_TYPE_TUNNEL_AUTH: + log.Printf("Tunnel auth") + if p.state != SERVER_STATE_TUNNEL_CREATE { + log.Printf("Tunnel auth attempted while in wrong state %d != %d", + p.state, SERVER_STATE_TUNNEL_CREATE) + msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR) + p.tunnel.Write(msg) + return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) + } + client := p.tunnelAuthRequest(message.msg) + if p.gw.CheckClientName != nil { + if ok, _ := p.gw.CheckClientName(ctx, client); !ok { + log.Printf("Invalid client name: %s", client) + msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) + p.tunnel.Write(msg) + return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) + } + } + msg := p.tunnelAuthResponse(ERROR_SUCCESS) + p.tunnel.Write(msg) + p.state = SERVER_STATE_TUNNEL_AUTHORIZE + case PKT_TYPE_CHANNEL_CREATE: + log.Printf("Channel create") + if p.state != SERVER_STATE_TUNNEL_AUTHORIZE { + log.Printf("Channel create attempted while in wrong state %d != %d", + p.state, SERVER_STATE_TUNNEL_AUTHORIZE) + msg := p.channelResponse(E_PROXY_INTERNALERROR) + p.tunnel.Write(msg) + return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) + } + server, port := p.channelRequest(message.msg) + host := net.JoinHostPort(server, strconv.Itoa(int(port))) + if p.gw.CheckHost != nil { + log.Printf("Verifying %s host connection", host) + if ok, _ := p.gw.CheckHost(ctx, host); !ok { + log.Printf("Not allowed to connect to %s by policy handler", host) + msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) + p.tunnel.Write(msg) + return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) + } + } + log.Printf("Establishing connection to RDP server: %s", host) + p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) + if err != nil { + log.Printf("Error connecting to %s, %s", host, err) + msg := p.channelResponse(E_PROXY_INTERNALERROR) + p.tunnel.Write(msg) + return err + } + p.tunnel.TargetServer = host + log.Printf("Connection established") + msg := p.channelResponse(ERROR_SUCCESS) p.tunnel.Write(msg) - return err - } - p.tunnel.TargetServer = host - log.Printf("Connection established") - msg := p.channelResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - // Make sure to start the flow from the RDP server first otherwise connections - // might hang eventually - go forward(p.tunnel.rwc, p.tunnel) - p.state = SERVER_STATE_CHANNEL_CREATE - case PKT_TYPE_DATA: - if p.state < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) - return errors.New("wrong state") - } - p.state = SERVER_STATE_OPENED - receive(pkt, p.tunnel.rwc) - case PKT_TYPE_KEEPALIVE: - // keepalives can be received while the channel is not open yet - if p.state < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) - return errors.New("wrong state") - } + // Make sure to start the flow from the RDP server first otherwise connections + // might hang eventually + go forward(p.tunnel.rwc, p.tunnel) + p.state = SERVER_STATE_CHANNEL_CREATE + case PKT_TYPE_DATA: + if p.state < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) + return errors.New("wrong state") + } + p.state = SERVER_STATE_OPENED + receive(message.msg, p.tunnel.rwc) + case PKT_TYPE_KEEPALIVE: + // keepalives can be received while the channel is not open yet + if p.state < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) + return errors.New("wrong state") + } - // avoid concurrency issues - // p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) - case PKT_TYPE_CLOSE_CHANNEL: - log.Printf("Close channel") - if p.state != SERVER_STATE_OPENED { - log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED) - return errors.New("wrong state") + // avoid concurrency issues + // p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + case PKT_TYPE_CLOSE_CHANNEL: + log.Printf("Close channel") + if p.state != SERVER_STATE_OPENED { + log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED) + return errors.New("wrong state") + } + msg := p.channelCloseResponse(ERROR_SUCCESS) + p.tunnel.Write(msg) + p.state = SERVER_STATE_CLOSED + return nil + default: + log.Printf("Unknown packet (size %d): %x", message.length, message.msg) } - msg := p.channelCloseResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_CLOSED - return nil - default: - log.Printf("Unknown packet (size %d): %x", sz, pkt) } } } diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go index dc3b1d5..0f4a350 100644 --- a/cmd/rdpgw/protocol/tunnel.go +++ b/cmd/rdpgw/protocol/tunnel.go @@ -1,10 +1,11 @@ package protocol import ( - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "net" "time" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" ) const ( @@ -46,6 +47,13 @@ type Tunnel struct { LastSeen time.Time } +type message struct { + packetType int + length int + msg []byte + err error +} + // Write puts the packet on the transport and updates the statistics for bytes sent func (t *Tunnel) Write(pkt []byte) { n, _ := t.transportOut.WritePacket(pkt) @@ -55,10 +63,14 @@ func (t *Tunnel) Write(pkt []byte) { // Read picks up a packet from the transport and returns the packet type // packet, with the header removed, and the packet size. It updates the // statistics for bytes received -func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) { - pt, size, pkt, err = readMessage(t.transportIn) - t.BytesReceived += int64(size) - t.LastSeen = time.Now() - - return pt, size, pkt, err +func (t *Tunnel) Read() ([]*message, error) { + messages, err := readMessage(t.transportIn) + if err != nil { + return nil, err + } + for _, message := range messages { + t.BytesReceived += int64(message.length) + t.LastSeen = time.Now() + } + return messages, err }