mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 05:53:50 +02:00
Refactor and add tests
This commit is contained in:
parent
ecfa9e6cf4
commit
4e99b4e88f
4 changed files with 119 additions and 31 deletions
|
@ -4,7 +4,9 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -17,6 +19,8 @@ type ClientConfig struct {
|
|||
SmartCardAuth bool
|
||||
PAAToken string
|
||||
NTLMAuth bool
|
||||
GatewayConn transport.Transport
|
||||
LocalConn net.Conn
|
||||
}
|
||||
|
||||
func (c *ClientConfig) handshakeRequest() []byte {
|
||||
|
@ -147,4 +151,39 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui
|
|||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConfig) channelRequest(server string, port uint16) []byte {
|
||||
utf16server := EncodeUTF16(server)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
|
||||
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
|
||||
binary.Write(buf, binary.LittleEndian, uint16(port))
|
||||
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
|
||||
buf.Write(utf16server)
|
||||
|
||||
return createPacket(PKT_TYPE_CHANNEL_CREATE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (c *ClientConfig) channelResponse(data []byte) (channelId uint32, err error) {
|
||||
var errorCode uint32
|
||||
var fields uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &errorCode)
|
||||
binary.Read(r, binary.LittleEndian, &fields)
|
||||
r.Seek(2, io.SeekCurrent)
|
||||
|
||||
if (fields & HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID) == HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID {
|
||||
binary.Read(r, binary.LittleEndian, &channelId)
|
||||
}
|
||||
|
||||
if errorCode > 0 {
|
||||
return 0, fmt.Errorf("channel response error %d", errorCode)
|
||||
}
|
||||
|
||||
return channelId, nil
|
||||
}
|
||||
|
|
|
@ -4,7 +4,10 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
func createPacket(pktType uint16, data []byte) (packet []byte) {
|
||||
|
@ -34,4 +37,35 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err
|
|||
return packetType, size, data[8:], nil
|
||||
}
|
||||
|
||||
// sends data wrapped inside the rdpgw protocol
|
||||
func forward(in net.Conn, out transport.Transport) {
|
||||
defer in.Close()
|
||||
|
||||
b1 := new(bytes.Buffer)
|
||||
buf := make([]byte, 4086)
|
||||
|
||||
for {
|
||||
n, err := in.Read(buf)
|
||||
if err != nil {
|
||||
log.Printf("Error reading from local conn %s", err)
|
||||
break
|
||||
}
|
||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||
b1.Write(buf[:n])
|
||||
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// receive data from the wire, unwrap and forward to the client
|
||||
func receive(data []byte, out net.Conn) {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var cblen uint16
|
||||
binary.Read(buf, binary.LittleEndian, &cblen)
|
||||
pkt := make([]byte, cblen)
|
||||
binary.Read(buf, binary.LittleEndian, &pkt)
|
||||
|
||||
out.Write(pkt)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@ const (
|
|||
TunnelCreateResponseLen = HeaderLen + 18
|
||||
TunnelAuthLen = HeaderLen + 2 // + dynamic
|
||||
TunnelAuthResponseLen = HeaderLen + 16
|
||||
ChannelCreateLen = HeaderLen + 8 // + dynamic
|
||||
ChannelResponseLen = HeaderLen + 12
|
||||
)
|
||||
|
||||
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
|
||||
|
@ -162,3 +164,44 @@ func TestTunnelAuth(t *testing.T) {
|
|||
timeout, hc.IdleTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelCreation(t *testing.T) {
|
||||
client := ClientConfig{}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
TokenAuth: true,
|
||||
IdleTimeout: 10,
|
||||
RedirectFlags: RedirectFlags{
|
||||
Clipboard: true,
|
||||
},
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
server := "test_server"
|
||||
port := uint16(3389)
|
||||
|
||||
data := client.channelRequest(server, port)
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
hServer, hPort := h.channelRequest(pkt)
|
||||
if hServer != server {
|
||||
t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server)
|
||||
}
|
||||
if hPort != port {
|
||||
t.Fatalf("channelRequest failed got port %d, expected %d", hPort, port)
|
||||
}
|
||||
|
||||
data = h.channelResponse()
|
||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen))
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
channelId, err := client.channelResponse(pkt)
|
||||
if err != nil {
|
||||
t.Fatalf("channelResponse failed: %s", err)
|
||||
}
|
||||
if channelId < 1 {
|
||||
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
|
||||
}
|
||||
}
|
|
@ -148,7 +148,7 @@ func (s *Server) Process(ctx context.Context) error {
|
|||
|
||||
// Make sure to start the flow from the RDP server first otherwise connections
|
||||
// might hang eventually
|
||||
go s.sendDataPacket()
|
||||
go forward(s.Remote, s.Session.TransportOut)
|
||||
s.State = SERVER_STATE_CHANNEL_CREATE
|
||||
case PKT_TYPE_DATA:
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
|
@ -156,7 +156,7 @@ func (s *Server) Process(ctx context.Context) error {
|
|||
return errors.New("wrong state")
|
||||
}
|
||||
s.State = SERVER_STATE_OPENED
|
||||
s.forwardDataPacket(pkt)
|
||||
receive(pkt, s.Remote)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// keepalives can be received while the channel is not open yet
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
|
@ -357,34 +357,6 @@ func (s *Server) channelResponse() []byte {
|
|||
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) forwardDataPacket(data []byte) {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var cblen uint16
|
||||
binary.Read(buf, binary.LittleEndian, &cblen)
|
||||
pkt := make([]byte, cblen)
|
||||
binary.Read(buf, binary.LittleEndian, &pkt)
|
||||
|
||||
s.Remote.Write(pkt)
|
||||
}
|
||||
|
||||
func (s *Server) sendDataPacket() {
|
||||
defer s.Remote.Close()
|
||||
b1 := new(bytes.Buffer)
|
||||
buf := make([]byte, 4086)
|
||||
for {
|
||||
n, err := s.Remote.Read(buf)
|
||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||
if err != nil {
|
||||
log.Printf("Error reading from conn %s", err)
|
||||
break
|
||||
}
|
||||
b1.Write(buf[:n])
|
||||
s.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func makeRedirectFlags(flags RedirectFlags) int {
|
||||
var redir = 0
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue