Refactor and add tests

This commit is contained in:
Bolke de Bruin 2020-08-01 15:50:13 +02:00
parent ecfa9e6cf4
commit 4e99b4e88f
4 changed files with 119 additions and 31 deletions

View file

@ -4,7 +4,9 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/transport"
"io" "io"
"net"
) )
const ( const (
@ -17,6 +19,8 @@ type ClientConfig struct {
SmartCardAuth bool SmartCardAuth bool
PAAToken string PAAToken string
NTLMAuth bool NTLMAuth bool
GatewayConn transport.Transport
LocalConn net.Conn
} }
func (c *ClientConfig) handshakeRequest() []byte { func (c *ClientConfig) handshakeRequest() []byte {
@ -148,3 +152,38 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui
return return
} }
func (c *ClientConfig) channelRequest(server string, port uint16) []byte {
utf16server := EncodeUTF16(server)
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
buf.Write(utf16server)
return createPacket(PKT_TYPE_CHANNEL_CREATE, buf.Bytes())
}
func (c *ClientConfig) channelResponse(data []byte) (channelId uint32, err error) {
var errorCode uint32
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &errorCode)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if (fields & HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID) == HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID {
binary.Read(r, binary.LittleEndian, &channelId)
}
if errorCode > 0 {
return 0, fmt.Errorf("channel response error %d", errorCode)
}
return channelId, nil
}

View file

@ -4,7 +4,10 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/transport"
"io" "io"
"log"
"net"
) )
func createPacket(pktType uint16, data []byte) (packet []byte) { func createPacket(pktType uint16, data []byte) (packet []byte) {
@ -34,4 +37,35 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err
return packetType, size, data[8:], nil return packetType, size, data[8:], nil
} }
// sends data wrapped inside the rdpgw protocol
func forward(in net.Conn, out transport.Transport) {
defer in.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := in.Read(buf)
if err != nil {
log.Printf("Error reading from local conn %s", err)
break
}
binary.Write(b1, binary.LittleEndian, uint16(n))
b1.Write(buf[:n])
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
// receive data from the wire, unwrap and forward to the client
func receive(data []byte, out net.Conn) {
buf := bytes.NewReader(data)
var cblen uint16
binary.Read(buf, binary.LittleEndian, &cblen)
pkt := make([]byte, cblen)
binary.Read(buf, binary.LittleEndian, &pkt)
out.Write(pkt)
}

View file

@ -14,6 +14,8 @@ const (
TunnelCreateResponseLen = HeaderLen + 18 TunnelCreateResponseLen = HeaderLen + 18
TunnelAuthLen = HeaderLen + 2 // + dynamic TunnelAuthLen = HeaderLen + 2 // + dynamic
TunnelAuthResponseLen = HeaderLen + 16 TunnelAuthResponseLen = HeaderLen + 16
ChannelCreateLen = HeaderLen + 8 // + dynamic
ChannelResponseLen = HeaderLen + 12
) )
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) { func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
@ -162,3 +164,44 @@ func TestTunnelAuth(t *testing.T) {
timeout, hc.IdleTimeout) timeout, hc.IdleTimeout)
} }
} }
func TestChannelCreation(t *testing.T) {
client := ClientConfig{}
s := &SessionInfo{}
hc := &ServerConf{
TokenAuth: true,
IdleTimeout: 10,
RedirectFlags: RedirectFlags{
Clipboard: true,
},
}
h := NewServer(s, hc)
server := "test_server"
port := uint16(3389)
data := client.channelRequest(server, port)
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
hServer, hPort := h.channelRequest(pkt)
if hServer != server {
t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server)
}
if hPort != port {
t.Fatalf("channelRequest failed got port %d, expected %d", hPort, port)
}
data = h.channelResponse()
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
channelId, err := client.channelResponse(pkt)
if err != nil {
t.Fatalf("channelResponse failed: %s", err)
}
if channelId < 1 {
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
}
}

View file

@ -148,7 +148,7 @@ func (s *Server) Process(ctx context.Context) error {
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
go s.sendDataPacket() go forward(s.Remote, s.Session.TransportOut)
s.State = SERVER_STATE_CHANNEL_CREATE s.State = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
if s.State < SERVER_STATE_CHANNEL_CREATE { if s.State < SERVER_STATE_CHANNEL_CREATE {
@ -156,7 +156,7 @@ func (s *Server) Process(ctx context.Context) error {
return errors.New("wrong state") return errors.New("wrong state")
} }
s.State = SERVER_STATE_OPENED s.State = SERVER_STATE_OPENED
s.forwardDataPacket(pkt) receive(pkt, s.Remote)
case PKT_TYPE_KEEPALIVE: case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet // keepalives can be received while the channel is not open yet
if s.State < SERVER_STATE_CHANNEL_CREATE { if s.State < SERVER_STATE_CHANNEL_CREATE {
@ -357,34 +357,6 @@ func (s *Server) channelResponse() []byte {
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
} }
func (s *Server) forwardDataPacket(data []byte) {
buf := bytes.NewReader(data)
var cblen uint16
binary.Read(buf, binary.LittleEndian, &cblen)
pkt := make([]byte, cblen)
binary.Read(buf, binary.LittleEndian, &pkt)
s.Remote.Write(pkt)
}
func (s *Server) sendDataPacket() {
defer s.Remote.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := s.Remote.Read(buf)
binary.Write(b1, binary.LittleEndian, uint16(n))
if err != nil {
log.Printf("Error reading from conn %s", err)
break
}
b1.Write(buf[:n])
s.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
func makeRedirectFlags(flags RedirectFlags) int { func makeRedirectFlags(flags RedirectFlags) int {
var redir = 0 var redir = 0