fix: handle multiple message frames inside packet (#143)

Running the gateway as non-tls, but using an external TLS gateway in
kubernetes+istio, I determined that the istio TLS gateway would join
messages frames into a single TCP packet. The packet read code assumed
that a single packet is a message. This is not the case for a TCP
stream, since you don't know how the frames are segmented via proxies,
etc.

The fix turned out more complex that I would have liked, but added a
number of unit tests to cover all the corner cases. Likely fragmentation
was not working correctly as well, as there was some cases that were
previously not handled.

Note that this might address issue #126 as well.
This commit is contained in:
Mike Marchetti 2025-05-06 11:38:16 -04:00 committed by GitHub
parent 6b4e6bdced
commit 20307b9a76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 579 additions and 197 deletions

2
.gitignore vendored
View file

@ -1 +1,3 @@
go.sum go.sum
bin
*.swp

View file

@ -30,52 +30,58 @@ func (c *ClientConfig) ConnectAndForward() error {
c.Session.transportOut.WritePacket(c.handshakeRequest()) c.Session.transportOut.WritePacket(c.handshakeRequest())
for { for {
pt, sz, pkt, err := readMessage(c.Session.transportIn) messages, err := readMessage(c.Session.transportIn)
if err != nil { if err != nil {
log.Printf("Cannot read message from stream %s", err) log.Printf("Cannot read message from stream %s", err)
return err return err
} }
switch pt { for _, message := range messages {
case PKT_TYPE_HANDSHAKE_RESPONSE: if message.err != nil {
caps, err := c.handshakeResponse(pkt) log.Printf("Cannot read message from stream %p", err)
if err != nil { continue
log.Printf("Cannot connect to %s due to %s", c.Server, err)
return err
} }
log.Printf("Handshake response received. Caps: %d", caps) switch message.packetType {
c.Session.transportOut.WritePacket(c.tunnelRequest()) case PKT_TYPE_HANDSHAKE_RESPONSE:
case PKT_TYPE_TUNNEL_RESPONSE: caps, err := c.handshakeResponse(message.msg)
tid, caps, err := c.tunnelResponse(pkt) if err != nil {
if err != nil { log.Printf("Cannot connect to %s due to %s", c.Server, err)
log.Printf("Cannot setup tunnel due to %s", err) return err
return err }
log.Printf("Handshake response received. Caps: %d", caps)
c.Session.transportOut.WritePacket(c.tunnelRequest())
case PKT_TYPE_TUNNEL_RESPONSE:
tid, caps, err := c.tunnelResponse(message.msg)
if err != nil {
log.Printf("Cannot setup tunnel due to %s", err)
return err
}
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
c.Session.transportOut.WritePacket(c.tunnelAuthRequest())
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
flags, timeout, err := c.tunnelAuthResponse(message.msg)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
c.Session.transportOut.WritePacket(c.channelRequest())
case PKT_TYPE_CHANNEL_RESPONSE:
cid, err := c.channelResponse(message.msg)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
if cid < 1 {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
}
log.Printf("Channel creation succesful. Channel id: %d", cid)
//go forward(c.LocalConn, c.Session.transportOut)
case PKT_TYPE_DATA:
receive(message.msg, c.LocalConn)
default:
log.Printf("Unknown packet type received: %d size %d", message.packetType, message.length)
} }
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
c.Session.transportOut.WritePacket(c.tunnelAuthRequest())
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
flags, timeout, err := c.tunnelAuthResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
c.Session.transportOut.WritePacket(c.channelRequest())
case PKT_TYPE_CHANNEL_RESPONSE:
cid, err := c.channelResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
if cid < 1 {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
}
log.Printf("Channel creation succesful. Channel id: %d", cid)
//go forward(c.LocalConn, c.Session.transportOut)
case PKT_TYPE_DATA:
receive(pkt, c.LocalConn)
default:
log.Printf("Unknown packet type received: %d size %d", pt, sz)
} }
} }
} }

View file

@ -4,12 +4,18 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "fmt"
"io" "io"
"log" "log"
"net" "net"
"os" "os"
"syscall" "syscall"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
)
const (
maxFragmentSize = 65536
) )
type RedirectFlags struct { type RedirectFlags struct {
@ -22,46 +28,57 @@ type RedirectFlags struct {
EnableAll bool EnableAll bool
} }
// readMessage parses and defragments a packet from a Transport. It returns func handleMsgFrame(packet *packetReader) *message {
// at most the bytes that have been reported by the packet pt, sz, msg, err := readHeader(packet.getPtr())
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) { if err == nil {
fragment := false packet.incrementPtr(int(sz))
return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil}
}
buf := make([]byte, maxFragmentSize)
index := 0 index := 0
buf := make([]byte, 4096)
for { for {
size, pkt, err := in.ReadPacket() // keep parsing thfragment
if len(packet.getPtr()) > len(buf[index:]) {
return &message{packetType: int(pt), length: int(sz), msg: msg, err: fmt.Errorf("fragment exceeded max fragment size")}
}
index += copy(buf[index:], packet.getPtr())
// Get a new frame
err := packet.read()
if err != nil { if err != nil {
return 0, 0, []byte{0, 0}, err // Failed to make a msg
return &message{packetType: int(pt), length: int(sz), msg: msg, err: err}
} }
pt, sz, msg, err = readHeader(append(buf[:index], packet.getPtr()...))
// check for fragments if err == nil {
var pt uint16 // the increment is based upon how much of the data we have used
var sz uint32 // in this packet. The index tells us how much is in the previous frame(s),
var msg []byte // So we remove that from the size of the message.
packet.incrementPtr(int(sz) - index)
if !fragment { return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil}
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
} }
} }
} }
// readMessage parses and defragments a packet from a Transport. It returns
// at most the bytes that have been reported by the packet.
func readMessage(in transport.Transport) ([]*message, error) {
messages := make([]*message, 0)
packet := newTransportPacket(in)
err := packet.read()
if err != nil {
return messages, err
}
var message *message
for packet.hasMoreData() {
message = handleMsgFrame(packet)
messages = append(messages, message)
}
return messages, nil
}
// createPacket wraps the data into the protocol packet // createPacket wraps the data into the protocol packet
func createPacket(pktType uint16, data []byte) (packet []byte) { func createPacket(pktType uint16, data []byte) (packet []byte) {
size := len(data) + 8 size := len(data) + 8

View file

@ -0,0 +1,297 @@
package protocol
import (
"fmt"
"math/rand"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
type messageMock struct {
buffer []byte
msgBuffer []byte
}
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func randBytes(message []byte) {
for index := range message {
message[index] = letterBytes[rand.Intn(len(letterBytes))]
}
}
func newMessageMock(packetType uint16, message []byte) *messageMock {
randBytes(message)
buf := createPacket(packetType, message)
return &messageMock{msgBuffer: buf[8:], buffer: buf}
}
type packetMock struct {
bytes []byte
err error
}
func newPacketMock() *packetMock {
return &packetMock{bytes: make([]byte, 0)}
}
func (p *packetMock) addBytes(b []byte) {
p.bytes = append(p.bytes, b...)
}
func (p *packetMock) GetPacket() (int, []byte, error) {
return len(p.bytes), p.bytes, p.err
}
type transportMock struct {
lock sync.Mutex
packets []*packetMock
packetPtr int
}
func newTransportMock() *transportMock {
return &transportMock{packets: make([]*packetMock, 0)}
}
func (t *transportMock) addPacket(p *packetMock) {
t.lock.Lock()
defer t.lock.Unlock()
t.packets = append(t.packets, p)
}
func (t *transportMock) ReadPacket() (n int, p []byte, err error) {
t.lock.Lock()
defer t.lock.Unlock()
if t.packetPtr >= len(t.packets) {
return 0, nil, fmt.Errorf("no packets available")
}
packet := t.packets[t.packetPtr]
t.packetPtr++
return packet.GetPacket()
}
func (t *transportMock) WritePacket(b []byte) (n int, err error) {
return 0, fmt.Errorf("not tested")
}
func (t *transportMock) Close() error {
return nil
}
func TestSimplePacket(t *testing.T) {
transport := newTransportMock()
m := newMessageMock(6, make([]byte, 10))
p := newPacketMock()
p.addBytes(m.buffer)
transport.addPacket(p)
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages)
assert.Len(t, messages, 1)
assert.Equal(t, 6, messages[0].packetType)
assert.Equal(t, 18, messages[0].length)
assert.Equal(t, m.msgBuffer, messages[0].msg)
}
func TestMultiMessageInPacket(t *testing.T) {
transport := newTransportMock()
p := newPacketMock()
m := newMessageMock(6, make([]byte, 10))
p.addBytes(m.buffer)
m2 := newMessageMock(8, make([]byte, 12))
p.addBytes(m2.buffer)
m3 := newMessageMock(8, make([]byte, 12))
p.addBytes(m3.buffer)
transport.addPacket(p)
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages)
assert.Len(t, messages, 3)
assert.Nil(t, messages[0].err)
assert.Equal(t, 6, messages[0].packetType)
assert.Equal(t, 18, messages[0].length)
assert.Equal(t, m.msgBuffer, messages[0].msg)
assert.Nil(t, messages[1].err)
assert.Equal(t, 8, messages[1].packetType)
assert.Equal(t, 20, messages[1].length)
assert.Equal(t, m2.msgBuffer, messages[1].msg)
assert.Nil(t, messages[2].err)
assert.Equal(t, 8, messages[2].packetType)
assert.Equal(t, 20, messages[2].length)
assert.Equal(t, m3.msgBuffer, messages[2].msg)
}
func TestFragment(t *testing.T) {
transport := newTransportMock()
p1 := newPacketMock()
p2 := newPacketMock()
m := newMessageMock(6, make([]byte, 100))
// split the message across 2 packets
p1.addBytes(m.buffer[0:50])
p2.addBytes(m.buffer[50:])
transport.addPacket(p1)
transport.addPacket(p2)
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages)
assert.Len(t, messages, 1)
assert.Equal(t, 6, messages[0].packetType)
assert.Equal(t, 108, messages[0].length)
assert.Equal(t, m.msgBuffer, messages[0].msg)
_, err = readMessage(transport)
// no more packets
assert.NotNil(t, err)
}
func TestDroppedBytes(t *testing.T) {
transport := newTransportMock()
p1 := newPacketMock()
m := newMessageMock(6, make([]byte, 100))
// add only partial bytes
p1.addBytes(m.buffer[0:50])
transport.addPacket(p1)
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.Len(t, messages, 1)
assert.NotNil(t, messages[0].err)
_, err = readMessage(transport)
// no more packets
assert.NotNil(t, err)
}
func TestTooMuchData(t *testing.T) {
transport := newTransportMock()
p1 := newPacketMock()
m := newMessageMock(6, make([]byte, 100))
// add only partial bytes
p1.addBytes(m.buffer)
p1.addBytes([]byte{0, 0, 0})
// add some junk bytes
transport.addPacket(p1)
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages)
assert.Len(t, messages, 2)
assert.Nil(t, messages[0].err)
assert.NotNil(t, messages[1].err)
_, err = readMessage(transport)
// no more packets
assert.NotNil(t, err)
}
func TestJumbo(t *testing.T) {
transport := newTransportMock()
p1 := newPacketMock()
p2 := newPacketMock()
m := newMessageMock(6, make([]byte, maxFragmentSize))
// add only partial bytes
p1.addBytes(m.buffer[0 : maxFragmentSize/2])
p2.addBytes(m.buffer[maxFragmentSize/2:])
// add some junk bytes
transport.addPacket(p1)
transport.addPacket(p2)
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages)
assert.Len(t, messages, 1)
assert.Equal(t, m.msgBuffer, messages[0].msg)
}
func TestManyFragments(t *testing.T) {
transport := newTransportMock()
m := newMessageMock(6, make([]byte, 256))
fragmentSize := len(m.buffer) / 5
bufferSize := len(m.buffer)
for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize {
p := newPacketMock()
p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)])
transport.addPacket(p)
}
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages)
assert.Len(t, messages, 1)
assert.Nil(t, messages[0].err)
assert.Equal(t, m.msgBuffer, messages[0].msg)
messages, err = readMessage(transport)
// no more packets
fmt.Println(messages)
assert.NotNil(t, err)
}
func TestFragmentTooLarge(t *testing.T) {
transport := newTransportMock()
m := newMessageMock(6, make([]byte, maxFragmentSize*2))
fragmentSize := len(m.buffer) / 5
bufferSize := len(m.buffer)
for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize {
p := newPacketMock()
p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)])
transport.addPacket(p)
}
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages[0].err)
assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error())
}
// TestFragmentWithMultiMessage the first message is fragmented,
// while the second message is found whole in the final packet
func TestFragmentWithMultiMessage(t *testing.T) {
transport := newTransportMock()
p1 := newPacketMock()
p2 := newPacketMock()
m1 := newMessageMock(6, make([]byte, 100))
m2 := newMessageMock(6, make([]byte, 10))
// split the message across 2 packets
p1.addBytes(m1.buffer[0:50])
p2.addBytes(m1.buffer[50:])
p2.addBytes(m2.buffer)
transport.addPacket(p1)
transport.addPacket(p2)
messages, err := readMessage(transport)
assert.Nil(t, err)
assert.NotNil(t, messages)
assert.Len(t, messages, 2)
assert.Equal(t, 6, messages[0].packetType)
assert.Equal(t, 108, messages[0].length)
assert.Equal(t, m1.msgBuffer, messages[0].msg)
assert.Equal(t, 6, messages[1].packetType)
assert.Equal(t, 18, messages[1].length)
assert.Equal(t, m2.msgBuffer, messages[1].msg)
_, err = readMessage(transport)
// no more packets
assert.NotNil(t, err)
}

View file

@ -0,0 +1,40 @@
package protocol
import "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
type packetReader struct {
in transport.Transport
size int
pkt []byte
err error
readPtr int
}
func newTransportPacket(in transport.Transport) *packetReader {
return &packetReader{in: in}
}
func (t *packetReader) hasMoreData() bool {
return t.readPtr < t.size
}
func (t *packetReader) getPtr() []byte {
return t.pkt[t.readPtr:]
}
func (t *packetReader) incrementPtr(size int) {
t.readPtr += size
}
func (t *packetReader) read() error {
size, pkt, err := t.in.ReadPacket()
if err != nil {
t.size = 0
} else {
t.size = size
}
t.pkt = pkt
t.err = err
t.readPtr = 0
return err
}

View file

@ -6,12 +6,13 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"io" "io"
"log" "log"
"net" "net"
"strconv" "strconv"
"time" "time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
) )
type Processor struct { type Processor struct {
@ -43,139 +44,146 @@ const tunnelId = 10
func (p *Processor) Process(ctx context.Context) error { func (p *Processor) Process(ctx context.Context) error {
for { for {
pt, sz, pkt, err := p.tunnel.Read() //pt, sz, pkt, err := p.tunnel.Read()
messages, err := p.tunnel.Read()
if err != nil { if err != nil {
log.Printf("Cannot read message from stream %p", err) log.Printf("Cannot read message from stream %p", err)
return err return err
} }
switch pt { for _, message := range messages {
case PKT_TYPE_HANDSHAKE_REQUEST: if message.err != nil {
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp)) log.Printf("Cannot read message from stream %p", err)
if p.state != SERVER_STATE_INITIALIZED { continue
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
p.tunnel.Write(msg)
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
} }
major, minor, _, reqAuth := p.handshakeRequest(pkt) switch message.packetType {
caps, err := p.matchAuth(reqAuth) case PKT_TYPE_HANDSHAKE_REQUEST:
if err != nil { log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
log.Println(err) if p.state != SERVER_STATE_INITIALIZED {
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
p.tunnel.Write(msg) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
return err
}
msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
p.tunnel.Write(msg)
p.state = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create")
if p.state != SERVER_STATE_HANDSHAKE {
log.Printf("Tunnel create attempted while in wrong state %d != %d",
p.state, SERVER_STATE_HANDSHAKE)
msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
p.tunnel.Write(msg)
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
}
_, cookie := p.tunnelRequest(pkt)
if p.gw.CheckPAACookie != nil {
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
} }
} major, minor, _, reqAuth := p.handshakeRequest(message.msg)
msg := p.tunnelResponse(ERROR_SUCCESS) caps, err := p.matchAuth(reqAuth)
p.tunnel.Write(msg) if err != nil {
p.state = SERVER_STATE_TUNNEL_CREATE log.Println(err)
case PKT_TYPE_TUNNEL_AUTH: msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
log.Printf("Tunnel auth")
if p.state != SERVER_STATE_TUNNEL_CREATE {
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
p.state, SERVER_STATE_TUNNEL_CREATE)
msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
p.tunnel.Write(msg)
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
}
client := p.tunnelAuthRequest(pkt)
if p.gw.CheckClientName != nil {
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
log.Printf("Invalid client name: %s", client)
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) return err
} }
} msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
msg := p.tunnelAuthResponse(ERROR_SUCCESS)
p.tunnel.Write(msg)
p.state = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create")
if p.state != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d",
p.state, SERVER_STATE_TUNNEL_AUTHORIZE)
msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) p.state = SERVER_STATE_HANDSHAKE
} case PKT_TYPE_TUNNEL_CREATE:
server, port := p.channelRequest(pkt) log.Printf("Tunnel create")
host := net.JoinHostPort(server, strconv.Itoa(int(port))) if p.state != SERVER_STATE_HANDSHAKE {
if p.gw.CheckHost != nil { log.Printf("Tunnel create attempted while in wrong state %d != %d",
log.Printf("Verifying %s host connection", host) p.state, SERVER_STATE_HANDSHAKE)
if ok, _ := p.gw.CheckHost(ctx, host); !ok { msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
log.Printf("Not allowed to connect to %s by policy handler", host)
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
} }
} _, cookie := p.tunnelRequest(message.msg)
log.Printf("Establishing connection to RDP server: %s", host) if p.gw.CheckPAACookie != nil {
p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
if err != nil { log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(identity.AttrClientIp))
log.Printf("Error connecting to %s, %s", host, err) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
msg := p.channelResponse(E_PROXY_INTERNALERROR) p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
}
}
msg := p.tunnelResponse(ERROR_SUCCESS)
p.tunnel.Write(msg)
p.state = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth")
if p.state != SERVER_STATE_TUNNEL_CREATE {
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
p.state, SERVER_STATE_TUNNEL_CREATE)
msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
p.tunnel.Write(msg)
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
}
client := p.tunnelAuthRequest(message.msg)
if p.gw.CheckClientName != nil {
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
log.Printf("Invalid client name: %s", client)
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
p.tunnel.Write(msg)
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
}
}
msg := p.tunnelAuthResponse(ERROR_SUCCESS)
p.tunnel.Write(msg)
p.state = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create")
if p.state != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d",
p.state, SERVER_STATE_TUNNEL_AUTHORIZE)
msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.tunnel.Write(msg)
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
}
server, port := p.channelRequest(message.msg)
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if p.gw.CheckHost != nil {
log.Printf("Verifying %s host connection", host)
if ok, _ := p.gw.CheckHost(ctx, host); !ok {
log.Printf("Not allowed to connect to %s by policy handler", host)
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
p.tunnel.Write(msg)
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
}
}
log.Printf("Establishing connection to RDP server: %s", host)
p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil {
log.Printf("Error connecting to %s, %s", host, err)
msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.tunnel.Write(msg)
return err
}
p.tunnel.TargetServer = host
log.Printf("Connection established")
msg := p.channelResponse(ERROR_SUCCESS)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return err
}
p.tunnel.TargetServer = host
log.Printf("Connection established")
msg := p.channelResponse(ERROR_SUCCESS)
p.tunnel.Write(msg)
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
go forward(p.tunnel.rwc, p.tunnel) go forward(p.tunnel.rwc, p.tunnel)
p.state = SERVER_STATE_CHANNEL_CREATE p.state = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
if p.state < SERVER_STATE_CHANNEL_CREATE { if p.state < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state") return errors.New("wrong state")
} }
p.state = SERVER_STATE_OPENED p.state = SERVER_STATE_OPENED
receive(pkt, p.tunnel.rwc) receive(message.msg, p.tunnel.rwc)
case PKT_TYPE_KEEPALIVE: case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet // keepalives can be received while the channel is not open yet
if p.state < SERVER_STATE_CHANNEL_CREATE { if p.state < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE) log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state") return errors.New("wrong state")
} }
// avoid concurrency issues // avoid concurrency issues
// p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
log.Printf("Close channel") log.Printf("Close channel")
if p.state != SERVER_STATE_OPENED { if p.state != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED) log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED)
return errors.New("wrong state") return errors.New("wrong state")
}
msg := p.channelCloseResponse(ERROR_SUCCESS)
p.tunnel.Write(msg)
p.state = SERVER_STATE_CLOSED
return nil
default:
log.Printf("Unknown packet (size %d): %x", message.length, message.msg)
} }
msg := p.channelCloseResponse(ERROR_SUCCESS)
p.tunnel.Write(msg)
p.state = SERVER_STATE_CLOSED
return nil
default:
log.Printf("Unknown packet (size %d): %x", sz, pkt)
} }
} }
} }

View file

@ -1,10 +1,11 @@
package protocol package protocol
import ( import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"net" "net"
"time" "time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
) )
const ( const (
@ -46,6 +47,13 @@ type Tunnel struct {
LastSeen time.Time LastSeen time.Time
} }
type message struct {
packetType int
length int
msg []byte
err error
}
// Write puts the packet on the transport and updates the statistics for bytes sent // Write puts the packet on the transport and updates the statistics for bytes sent
func (t *Tunnel) Write(pkt []byte) { func (t *Tunnel) Write(pkt []byte) {
n, _ := t.transportOut.WritePacket(pkt) n, _ := t.transportOut.WritePacket(pkt)
@ -55,10 +63,14 @@ func (t *Tunnel) Write(pkt []byte) {
// Read picks up a packet from the transport and returns the packet type // Read picks up a packet from the transport and returns the packet type
// packet, with the header removed, and the packet size. It updates the // packet, with the header removed, and the packet size. It updates the
// statistics for bytes received // statistics for bytes received
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) { func (t *Tunnel) Read() ([]*message, error) {
pt, size, pkt, err = readMessage(t.transportIn) messages, err := readMessage(t.transportIn)
t.BytesReceived += int64(size) if err != nil {
t.LastSeen = time.Now() return nil, err
}
return pt, size, pkt, err for _, message := range messages {
t.BytesReceived += int64(message.length)
t.LastSeen = time.Now()
}
return messages, err
} }