diff --git a/.gitignore b/.gitignore index 08cb523..ee50204 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ go.sum +bin +*.swp diff --git a/cmd/rdpgw/protocol/client.go b/cmd/rdpgw/protocol/client.go index c1c27ac..950eace 100644 --- a/cmd/rdpgw/protocol/client.go +++ b/cmd/rdpgw/protocol/client.go @@ -30,52 +30,58 @@ func (c *ClientConfig) ConnectAndForward() error { c.Session.transportOut.WritePacket(c.handshakeRequest()) for { - pt, sz, pkt, err := readMessage(c.Session.transportIn) + messages, err := readMessage(c.Session.transportIn) if err != nil { log.Printf("Cannot read message from stream %s", err) return err } - switch pt { - case PKT_TYPE_HANDSHAKE_RESPONSE: - caps, err := c.handshakeResponse(pkt) - if err != nil { - log.Printf("Cannot connect to %s due to %s", c.Server, err) - return err + for _, message := range messages { + if message.err != nil { + log.Printf("Cannot read message from stream %p", err) + continue } - log.Printf("Handshake response received. Caps: %d", caps) - c.Session.transportOut.WritePacket(c.tunnelRequest()) - case PKT_TYPE_TUNNEL_RESPONSE: - tid, caps, err := c.tunnelResponse(pkt) - if err != nil { - log.Printf("Cannot setup tunnel due to %s", err) - return err + switch message.packetType { + case PKT_TYPE_HANDSHAKE_RESPONSE: + caps, err := c.handshakeResponse(message.msg) + if err != nil { + log.Printf("Cannot connect to %s due to %s", c.Server, err) + return err + } + log.Printf("Handshake response received. Caps: %d", caps) + c.Session.transportOut.WritePacket(c.tunnelRequest()) + case PKT_TYPE_TUNNEL_RESPONSE: + tid, caps, err := c.tunnelResponse(message.msg) + if err != nil { + log.Printf("Cannot setup tunnel due to %s", err) + return err + } + log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps) + c.Session.transportOut.WritePacket(c.tunnelAuthRequest()) + case PKT_TYPE_TUNNEL_AUTH_RESPONSE: + flags, timeout, err := c.tunnelAuthResponse(message.msg) + if err != nil { + log.Printf("Cannot do tunnel auth due to %s", err) + return err + } + log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout) + c.Session.transportOut.WritePacket(c.channelRequest()) + case PKT_TYPE_CHANNEL_RESPONSE: + cid, err := c.channelResponse(message.msg) + if err != nil { + log.Printf("Cannot do tunnel auth due to %s", err) + return err + } + if cid < 1 { + log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) + } + log.Printf("Channel creation succesful. Channel id: %d", cid) + //go forward(c.LocalConn, c.Session.transportOut) + case PKT_TYPE_DATA: + receive(message.msg, c.LocalConn) + default: + log.Printf("Unknown packet type received: %d size %d", message.packetType, message.length) } - log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps) - c.Session.transportOut.WritePacket(c.tunnelAuthRequest()) - case PKT_TYPE_TUNNEL_AUTH_RESPONSE: - flags, timeout, err := c.tunnelAuthResponse(pkt) - if err != nil { - log.Printf("Cannot do tunnel auth due to %s", err) - return err - } - log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout) - c.Session.transportOut.WritePacket(c.channelRequest()) - case PKT_TYPE_CHANNEL_RESPONSE: - cid, err := c.channelResponse(pkt) - if err != nil { - log.Printf("Cannot do tunnel auth due to %s", err) - return err - } - if cid < 1 { - log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) - } - log.Printf("Channel creation succesful. Channel id: %d", cid) - //go forward(c.LocalConn, c.Session.transportOut) - case PKT_TYPE_DATA: - receive(pkt, c.LocalConn) - default: - log.Printf("Unknown packet type received: %d size %d", pt, sz) } } } diff --git a/cmd/rdpgw/protocol/common.go b/cmd/rdpgw/protocol/common.go index a8875b2..4509f5c 100644 --- a/cmd/rdpgw/protocol/common.go +++ b/cmd/rdpgw/protocol/common.go @@ -4,12 +4,18 @@ import ( "bytes" "encoding/binary" "errors" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" + "fmt" "io" "log" "net" "os" "syscall" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" +) + +const ( + maxFragmentSize = 65536 ) type RedirectFlags struct { @@ -22,46 +28,57 @@ type RedirectFlags struct { EnableAll bool } -// readMessage parses and defragments a packet from a Transport. It returns -// at most the bytes that have been reported by the packet -func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) { - fragment := false +func handleMsgFrame(packet *packetReader) *message { + pt, sz, msg, err := readHeader(packet.getPtr()) + if err == nil { + packet.incrementPtr(int(sz)) + return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil} + } + + buf := make([]byte, maxFragmentSize) index := 0 - buf := make([]byte, 4096) - for { - size, pkt, err := in.ReadPacket() + // keep parsing thfragment + if len(packet.getPtr()) > len(buf[index:]) { + return &message{packetType: int(pt), length: int(sz), msg: msg, err: fmt.Errorf("fragment exceeded max fragment size")} + } + index += copy(buf[index:], packet.getPtr()) + // Get a new frame + err := packet.read() if err != nil { - return 0, 0, []byte{0, 0}, err + // Failed to make a msg + return &message{packetType: int(pt), length: int(sz), msg: msg, err: err} } - - // check for fragments - var pt uint16 - var sz uint32 - var msg []byte - - if !fragment { - pt, sz, msg, err = readHeader(pkt[:size]) - if err != nil { - fragment = true - index = copy(buf, pkt[:size]) - continue - } - index = 0 - } else { - fragment = false - pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...)) - // header is corrupted even after defragmenting - if err != nil { - return 0, 0, []byte{0, 0}, err - } - } - if !fragment { - return int(pt), int(sz), msg, nil + pt, sz, msg, err = readHeader(append(buf[:index], packet.getPtr()...)) + if err == nil { + // the increment is based upon how much of the data we have used + // in this packet. The index tells us how much is in the previous frame(s), + // So we remove that from the size of the message. + packet.incrementPtr(int(sz) - index) + return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil} } } } +// readMessage parses and defragments a packet from a Transport. It returns +// at most the bytes that have been reported by the packet. +func readMessage(in transport.Transport) ([]*message, error) { + messages := make([]*message, 0) + + packet := newTransportPacket(in) + err := packet.read() + if err != nil { + return messages, err + } + + var message *message + for packet.hasMoreData() { + message = handleMsgFrame(packet) + messages = append(messages, message) + } + return messages, nil +} + // createPacket wraps the data into the protocol packet func createPacket(pktType uint16, data []byte) (packet []byte) { size := len(data) + 8 diff --git a/cmd/rdpgw/protocol/common_test.go b/cmd/rdpgw/protocol/common_test.go new file mode 100644 index 0000000..da97275 --- /dev/null +++ b/cmd/rdpgw/protocol/common_test.go @@ -0,0 +1,297 @@ +package protocol + +import ( + "fmt" + "math/rand" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +type messageMock struct { + buffer []byte + msgBuffer []byte +} + +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randBytes(message []byte) { + for index := range message { + message[index] = letterBytes[rand.Intn(len(letterBytes))] + } +} + +func newMessageMock(packetType uint16, message []byte) *messageMock { + randBytes(message) + buf := createPacket(packetType, message) + return &messageMock{msgBuffer: buf[8:], buffer: buf} +} + +type packetMock struct { + bytes []byte + err error +} + +func newPacketMock() *packetMock { + return &packetMock{bytes: make([]byte, 0)} +} + +func (p *packetMock) addBytes(b []byte) { + p.bytes = append(p.bytes, b...) +} + +func (p *packetMock) GetPacket() (int, []byte, error) { + return len(p.bytes), p.bytes, p.err +} + +type transportMock struct { + lock sync.Mutex + packets []*packetMock + packetPtr int +} + +func newTransportMock() *transportMock { + return &transportMock{packets: make([]*packetMock, 0)} +} + +func (t *transportMock) addPacket(p *packetMock) { + t.lock.Lock() + defer t.lock.Unlock() + + t.packets = append(t.packets, p) +} + +func (t *transportMock) ReadPacket() (n int, p []byte, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.packetPtr >= len(t.packets) { + return 0, nil, fmt.Errorf("no packets available") + } + packet := t.packets[t.packetPtr] + t.packetPtr++ + return packet.GetPacket() +} + +func (t *transportMock) WritePacket(b []byte) (n int, err error) { + return 0, fmt.Errorf("not tested") +} + +func (t *transportMock) Close() error { + return nil +} + +func TestSimplePacket(t *testing.T) { + transport := newTransportMock() + m := newMessageMock(6, make([]byte, 10)) + p := newPacketMock() + p.addBytes(m.buffer) + transport.addPacket(p) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 18, messages[0].length) + assert.Equal(t, m.msgBuffer, messages[0].msg) +} + +func TestMultiMessageInPacket(t *testing.T) { + transport := newTransportMock() + p := newPacketMock() + + m := newMessageMock(6, make([]byte, 10)) + p.addBytes(m.buffer) + + m2 := newMessageMock(8, make([]byte, 12)) + p.addBytes(m2.buffer) + + m3 := newMessageMock(8, make([]byte, 12)) + p.addBytes(m3.buffer) + + transport.addPacket(p) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 3) + assert.Nil(t, messages[0].err) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 18, messages[0].length) + assert.Equal(t, m.msgBuffer, messages[0].msg) + + assert.Nil(t, messages[1].err) + assert.Equal(t, 8, messages[1].packetType) + assert.Equal(t, 20, messages[1].length) + assert.Equal(t, m2.msgBuffer, messages[1].msg) + + assert.Nil(t, messages[2].err) + assert.Equal(t, 8, messages[2].packetType) + assert.Equal(t, 20, messages[2].length) + assert.Equal(t, m3.msgBuffer, messages[2].msg) +} + +func TestFragment(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + p2 := newPacketMock() + + m := newMessageMock(6, make([]byte, 100)) + // split the message across 2 packets + p1.addBytes(m.buffer[0:50]) + p2.addBytes(m.buffer[50:]) + transport.addPacket(p1) + transport.addPacket(p2) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 108, messages[0].length) + assert.Equal(t, m.msgBuffer, messages[0].msg) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} + +func TestDroppedBytes(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + + m := newMessageMock(6, make([]byte, 100)) + // add only partial bytes + p1.addBytes(m.buffer[0:50]) + transport.addPacket(p1) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.Len(t, messages, 1) + assert.NotNil(t, messages[0].err) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} + +func TestTooMuchData(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + + m := newMessageMock(6, make([]byte, 100)) + // add only partial bytes + p1.addBytes(m.buffer) + p1.addBytes([]byte{0, 0, 0}) + // add some junk bytes + transport.addPacket(p1) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 2) + assert.Nil(t, messages[0].err) + assert.NotNil(t, messages[1].err) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} + +func TestJumbo(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + p2 := newPacketMock() + + m := newMessageMock(6, make([]byte, maxFragmentSize)) + // add only partial bytes + p1.addBytes(m.buffer[0 : maxFragmentSize/2]) + p2.addBytes(m.buffer[maxFragmentSize/2:]) + // add some junk bytes + transport.addPacket(p1) + transport.addPacket(p2) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Equal(t, m.msgBuffer, messages[0].msg) +} + +func TestManyFragments(t *testing.T) { + transport := newTransportMock() + + m := newMessageMock(6, make([]byte, 256)) + fragmentSize := len(m.buffer) / 5 + bufferSize := len(m.buffer) + for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize { + p := newPacketMock() + p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)]) + transport.addPacket(p) + } + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 1) + assert.Nil(t, messages[0].err) + assert.Equal(t, m.msgBuffer, messages[0].msg) + + messages, err = readMessage(transport) + // no more packets + fmt.Println(messages) + assert.NotNil(t, err) +} + +func TestFragmentTooLarge(t *testing.T) { + transport := newTransportMock() + + m := newMessageMock(6, make([]byte, maxFragmentSize*2)) + fragmentSize := len(m.buffer) / 5 + bufferSize := len(m.buffer) + for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize { + p := newPacketMock() + p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)]) + transport.addPacket(p) + } + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages[0].err) + assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error()) +} + +// TestFragmentWithMultiMessage the first message is fragmented, +// while the second message is found whole in the final packet +func TestFragmentWithMultiMessage(t *testing.T) { + transport := newTransportMock() + p1 := newPacketMock() + p2 := newPacketMock() + + m1 := newMessageMock(6, make([]byte, 100)) + m2 := newMessageMock(6, make([]byte, 10)) + // split the message across 2 packets + p1.addBytes(m1.buffer[0:50]) + p2.addBytes(m1.buffer[50:]) + p2.addBytes(m2.buffer) + transport.addPacket(p1) + transport.addPacket(p2) + + messages, err := readMessage(transport) + assert.Nil(t, err) + assert.NotNil(t, messages) + assert.Len(t, messages, 2) + assert.Equal(t, 6, messages[0].packetType) + assert.Equal(t, 108, messages[0].length) + assert.Equal(t, m1.msgBuffer, messages[0].msg) + + assert.Equal(t, 6, messages[1].packetType) + assert.Equal(t, 18, messages[1].length) + assert.Equal(t, m2.msgBuffer, messages[1].msg) + + _, err = readMessage(transport) + // no more packets + assert.NotNil(t, err) +} diff --git a/cmd/rdpgw/protocol/packet_reader.go b/cmd/rdpgw/protocol/packet_reader.go new file mode 100644 index 0000000..1375538 --- /dev/null +++ b/cmd/rdpgw/protocol/packet_reader.go @@ -0,0 +1,40 @@ +package protocol + +import "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" + +type packetReader struct { + in transport.Transport + size int + pkt []byte + err error + readPtr int +} + +func newTransportPacket(in transport.Transport) *packetReader { + return &packetReader{in: in} +} + +func (t *packetReader) hasMoreData() bool { + return t.readPtr < t.size +} + +func (t *packetReader) getPtr() []byte { + return t.pkt[t.readPtr:] +} + +func (t *packetReader) incrementPtr(size int) { + t.readPtr += size +} + +func (t *packetReader) read() error { + size, pkt, err := t.in.ReadPacket() + if err != nil { + t.size = 0 + } else { + t.size = size + } + t.pkt = pkt + t.err = err + t.readPtr = 0 + return err +} diff --git a/cmd/rdpgw/protocol/process.go b/cmd/rdpgw/protocol/process.go index bf90c07..ceaf16e 100644 --- a/cmd/rdpgw/protocol/process.go +++ b/cmd/rdpgw/protocol/process.go @@ -6,12 +6,13 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "io" "log" "net" "strconv" "time" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" ) type Processor struct { @@ -43,139 +44,146 @@ const tunnelId = 10 func (p *Processor) Process(ctx context.Context) error { for { - pt, sz, pkt, err := p.tunnel.Read() + //pt, sz, pkt, err := p.tunnel.Read() + messages, err := p.tunnel.Read() if err != nil { log.Printf("Cannot read message from stream %p", err) return err } - switch pt { - case PKT_TYPE_HANDSHAKE_REQUEST: - log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) - if p.state != SERVER_STATE_INITIALIZED { - log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) - msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) - p.tunnel.Write(msg) - return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) + for _, message := range messages { + if message.err != nil { + log.Printf("Cannot read message from stream %p", err) + continue } - major, minor, _, reqAuth := p.handshakeRequest(pkt) - caps, err := p.matchAuth(reqAuth) - if err != nil { - log.Println(err) - msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) - p.tunnel.Write(msg) - return err - } - msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_HANDSHAKE - case PKT_TYPE_TUNNEL_CREATE: - log.Printf("Tunnel create") - if p.state != SERVER_STATE_HANDSHAKE { - log.Printf("Tunnel create attempted while in wrong state %d != %d", - p.state, SERVER_STATE_HANDSHAKE) - msg := p.tunnelResponse(E_PROXY_INTERNALERROR) - p.tunnel.Write(msg) - return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) - } - _, cookie := p.tunnelRequest(pkt) - if p.gw.CheckPAACookie != nil { - if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { - log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) - msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + switch message.packetType { + case PKT_TYPE_HANDSHAKE_REQUEST: + log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) + if p.state != SERVER_STATE_INITIALIZED { + log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) + msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) p.tunnel.Write(msg) - return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) } - } - msg := p.tunnelResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_TUNNEL_CREATE - case PKT_TYPE_TUNNEL_AUTH: - log.Printf("Tunnel auth") - if p.state != SERVER_STATE_TUNNEL_CREATE { - log.Printf("Tunnel auth attempted while in wrong state %d != %d", - p.state, SERVER_STATE_TUNNEL_CREATE) - msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR) - p.tunnel.Write(msg) - return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) - } - client := p.tunnelAuthRequest(pkt) - if p.gw.CheckClientName != nil { - if ok, _ := p.gw.CheckClientName(ctx, client); !ok { - log.Printf("Invalid client name: %s", client) - msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) + major, minor, _, reqAuth := p.handshakeRequest(message.msg) + caps, err := p.matchAuth(reqAuth) + if err != nil { + log.Println(err) + msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) p.tunnel.Write(msg) - return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) + return err } - } - msg := p.tunnelAuthResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_TUNNEL_AUTHORIZE - case PKT_TYPE_CHANNEL_CREATE: - log.Printf("Channel create") - if p.state != SERVER_STATE_TUNNEL_AUTHORIZE { - log.Printf("Channel create attempted while in wrong state %d != %d", - p.state, SERVER_STATE_TUNNEL_AUTHORIZE) - msg := p.channelResponse(E_PROXY_INTERNALERROR) + msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS) p.tunnel.Write(msg) - return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) - } - server, port := p.channelRequest(pkt) - host := net.JoinHostPort(server, strconv.Itoa(int(port))) - if p.gw.CheckHost != nil { - log.Printf("Verifying %s host connection", host) - if ok, _ := p.gw.CheckHost(ctx, host); !ok { - log.Printf("Not allowed to connect to %s by policy handler", host) - msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) + p.state = SERVER_STATE_HANDSHAKE + case PKT_TYPE_TUNNEL_CREATE: + log.Printf("Tunnel create") + if p.state != SERVER_STATE_HANDSHAKE { + log.Printf("Tunnel create attempted while in wrong state %d != %d", + p.state, SERVER_STATE_HANDSHAKE) + msg := p.tunnelResponse(E_PROXY_INTERNALERROR) p.tunnel.Write(msg) - return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) + return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) } - } - log.Printf("Establishing connection to RDP server: %s", host) - p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) - if err != nil { - log.Printf("Error connecting to %s, %s", host, err) - msg := p.channelResponse(E_PROXY_INTERNALERROR) + _, cookie := p.tunnelRequest(message.msg) + if p.gw.CheckPAACookie != nil { + if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { + log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) + msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + p.tunnel.Write(msg) + return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) + } + } + msg := p.tunnelResponse(ERROR_SUCCESS) + p.tunnel.Write(msg) + p.state = SERVER_STATE_TUNNEL_CREATE + case PKT_TYPE_TUNNEL_AUTH: + log.Printf("Tunnel auth") + if p.state != SERVER_STATE_TUNNEL_CREATE { + log.Printf("Tunnel auth attempted while in wrong state %d != %d", + p.state, SERVER_STATE_TUNNEL_CREATE) + msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR) + p.tunnel.Write(msg) + return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) + } + client := p.tunnelAuthRequest(message.msg) + if p.gw.CheckClientName != nil { + if ok, _ := p.gw.CheckClientName(ctx, client); !ok { + log.Printf("Invalid client name: %s", client) + msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) + p.tunnel.Write(msg) + return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) + } + } + msg := p.tunnelAuthResponse(ERROR_SUCCESS) + p.tunnel.Write(msg) + p.state = SERVER_STATE_TUNNEL_AUTHORIZE + case PKT_TYPE_CHANNEL_CREATE: + log.Printf("Channel create") + if p.state != SERVER_STATE_TUNNEL_AUTHORIZE { + log.Printf("Channel create attempted while in wrong state %d != %d", + p.state, SERVER_STATE_TUNNEL_AUTHORIZE) + msg := p.channelResponse(E_PROXY_INTERNALERROR) + p.tunnel.Write(msg) + return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) + } + server, port := p.channelRequest(message.msg) + host := net.JoinHostPort(server, strconv.Itoa(int(port))) + if p.gw.CheckHost != nil { + log.Printf("Verifying %s host connection", host) + if ok, _ := p.gw.CheckHost(ctx, host); !ok { + log.Printf("Not allowed to connect to %s by policy handler", host) + msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) + p.tunnel.Write(msg) + return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) + } + } + log.Printf("Establishing connection to RDP server: %s", host) + p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) + if err != nil { + log.Printf("Error connecting to %s, %s", host, err) + msg := p.channelResponse(E_PROXY_INTERNALERROR) + p.tunnel.Write(msg) + return err + } + p.tunnel.TargetServer = host + log.Printf("Connection established") + msg := p.channelResponse(ERROR_SUCCESS) p.tunnel.Write(msg) - return err - } - p.tunnel.TargetServer = host - log.Printf("Connection established") - msg := p.channelResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - // Make sure to start the flow from the RDP server first otherwise connections - // might hang eventually - go forward(p.tunnel.rwc, p.tunnel) - p.state = SERVER_STATE_CHANNEL_CREATE - case PKT_TYPE_DATA: - if p.state < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) - return errors.New("wrong state") - } - p.state = SERVER_STATE_OPENED - receive(pkt, p.tunnel.rwc) - case PKT_TYPE_KEEPALIVE: - // keepalives can be received while the channel is not open yet - if p.state < SERVER_STATE_CHANNEL_CREATE { - log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) - return errors.New("wrong state") - } + // Make sure to start the flow from the RDP server first otherwise connections + // might hang eventually + go forward(p.tunnel.rwc, p.tunnel) + p.state = SERVER_STATE_CHANNEL_CREATE + case PKT_TYPE_DATA: + if p.state < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) + return errors.New("wrong state") + } + p.state = SERVER_STATE_OPENED + receive(message.msg, p.tunnel.rwc) + case PKT_TYPE_KEEPALIVE: + // keepalives can be received while the channel is not open yet + if p.state < SERVER_STATE_CHANNEL_CREATE { + log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) + return errors.New("wrong state") + } - // avoid concurrency issues - // p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) - case PKT_TYPE_CLOSE_CHANNEL: - log.Printf("Close channel") - if p.state != SERVER_STATE_OPENED { - log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED) - return errors.New("wrong state") + // avoid concurrency issues + // p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) + case PKT_TYPE_CLOSE_CHANNEL: + log.Printf("Close channel") + if p.state != SERVER_STATE_OPENED { + log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED) + return errors.New("wrong state") + } + msg := p.channelCloseResponse(ERROR_SUCCESS) + p.tunnel.Write(msg) + p.state = SERVER_STATE_CLOSED + return nil + default: + log.Printf("Unknown packet (size %d): %x", message.length, message.msg) } - msg := p.channelCloseResponse(ERROR_SUCCESS) - p.tunnel.Write(msg) - p.state = SERVER_STATE_CLOSED - return nil - default: - log.Printf("Unknown packet (size %d): %x", sz, pkt) } } } diff --git a/cmd/rdpgw/protocol/tunnel.go b/cmd/rdpgw/protocol/tunnel.go index dc3b1d5..0f4a350 100644 --- a/cmd/rdpgw/protocol/tunnel.go +++ b/cmd/rdpgw/protocol/tunnel.go @@ -1,10 +1,11 @@ package protocol import ( - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "net" "time" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" ) const ( @@ -46,6 +47,13 @@ type Tunnel struct { LastSeen time.Time } +type message struct { + packetType int + length int + msg []byte + err error +} + // Write puts the packet on the transport and updates the statistics for bytes sent func (t *Tunnel) Write(pkt []byte) { n, _ := t.transportOut.WritePacket(pkt) @@ -55,10 +63,14 @@ func (t *Tunnel) Write(pkt []byte) { // Read picks up a packet from the transport and returns the packet type // packet, with the header removed, and the packet size. It updates the // statistics for bytes received -func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) { - pt, size, pkt, err = readMessage(t.transportIn) - t.BytesReceived += int64(size) - t.LastSeen = time.Now() - - return pt, size, pkt, err +func (t *Tunnel) Read() ([]*message, error) { + messages, err := readMessage(t.transportIn) + if err != nil { + return nil, err + } + for _, message := range messages { + t.BytesReceived += int64(message.length) + t.LastSeen = time.Now() + } + return messages, err }