mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 14:03:50 +02:00
Add test cases and start client
This commit is contained in:
parent
dadaeb611b
commit
5618294f10
4 changed files with 162 additions and 12 deletions
|
@ -3,6 +3,8 @@ package protocol
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -43,4 +45,72 @@ func (c *ClientConfig) handshakeRequest() []byte {
|
||||||
return createPacket(PKT_TYPE_HANDSHAKE_REQUEST, buf.Bytes())
|
return createPacket(PKT_TYPE_HANDSHAKE_REQUEST, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientConfig) readServerHandshakeResponse(data []byte) ()
|
func (c *ClientConfig) handshakeResponse(data []byte) (caps uint16, err error) {
|
||||||
|
var errorCode int32
|
||||||
|
var major byte
|
||||||
|
var minor byte
|
||||||
|
var version uint16
|
||||||
|
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
binary.Read(r, binary.LittleEndian, &errorCode)
|
||||||
|
binary.Read(r, binary.LittleEndian, &major)
|
||||||
|
binary.Read(r, binary.LittleEndian, &minor)
|
||||||
|
binary.Read(r, binary.LittleEndian, &version)
|
||||||
|
binary.Read(r, binary.LittleEndian, &caps)
|
||||||
|
|
||||||
|
if errorCode > 0 {
|
||||||
|
return 0, fmt.Errorf("error code: %d", errorCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return caps, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConfig) tunnelRequest() []byte {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
var caps uint32
|
||||||
|
var size uint16
|
||||||
|
var fields uint16
|
||||||
|
|
||||||
|
if len(c.PAAToken) > 0 {
|
||||||
|
fields = fields | HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE
|
||||||
|
}
|
||||||
|
|
||||||
|
caps = caps | HTTP_CAPABILITY_IDLE_TIMEOUT
|
||||||
|
|
||||||
|
binary.Write(buf, binary.LittleEndian, caps)
|
||||||
|
binary.Write(buf, binary.LittleEndian, fields)
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||||
|
|
||||||
|
if len(c.PAAToken) > 0 {
|
||||||
|
utf16Token := EncodeUTF16(c.PAAToken)
|
||||||
|
size = uint16(len(utf16Token))
|
||||||
|
binary.Write(buf, binary.LittleEndian, size)
|
||||||
|
buf.Write(utf16Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
return createPacket(PKT_TYPE_TUNNEL_CREATE, buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32, err error) {
|
||||||
|
var version uint16
|
||||||
|
var errorCode uint32
|
||||||
|
var fields uint16
|
||||||
|
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
binary.Read(r, binary.LittleEndian, &version)
|
||||||
|
binary.Read(r, binary.LittleEndian, &errorCode)
|
||||||
|
binary.Read(r, binary.LittleEndian, &fields)
|
||||||
|
r.Seek(2, io.SeekCurrent)
|
||||||
|
if (fields & HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID) == HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID {
|
||||||
|
binary.Read(r, binary.LittleEndian, &tunnelId)
|
||||||
|
}
|
||||||
|
if (fields & HTTP_TUNNEL_RESPONSE_FIELD_CAPS) == HTTP_TUNNEL_RESPONSE_FIELD_CAPS {
|
||||||
|
binary.Read(r, binary.LittleEndian, &caps)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errorCode != 0 {
|
||||||
|
err = fmt.Errorf("tunnel error %d", errorCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -8,26 +9,40 @@ import (
|
||||||
const (
|
const (
|
||||||
HeaderLen = 8
|
HeaderLen = 8
|
||||||
HandshakeRequestLen = HeaderLen + 6
|
HandshakeRequestLen = HeaderLen + 6
|
||||||
|
HandshakeResponseLen = HeaderLen + 10
|
||||||
|
TunnelCreateRequestLen = HeaderLen + 8 // + dynamic
|
||||||
|
TunnelCreateResponseLen = HeaderLen + 18
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func verifyPacketHeader(data []byte , expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
|
||||||
|
pt, size, pkt, err := readHeader(data)
|
||||||
|
|
||||||
|
if pt != expPt {
|
||||||
|
return 0,0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if size != expSize {
|
||||||
|
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected size %d, got %d", expSize, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, []byte{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return pt, size, pkt, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandshake(t *testing.T) {
|
func TestHandshake(t *testing.T) {
|
||||||
client := ClientConfig{
|
client := ClientConfig{
|
||||||
PAAToken: "abab",
|
PAAToken: "abab",
|
||||||
}
|
}
|
||||||
|
|
||||||
data := client.handshakeRequest()
|
data := client.handshakeRequest()
|
||||||
pt, size, pkt, err := readHeader(data)
|
|
||||||
|
|
||||||
if pt != PKT_TYPE_HANDSHAKE_REQUEST {
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_REQUEST, HandshakeRequestLen)
|
||||||
t.Fatalf("readHeader failed, expected packet type %d got %d", PKT_TYPE_HANDSHAKE_REQUEST, pt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if size != HandshakeRequestLen {
|
|
||||||
t.Fatalf("readHeader failed, expected size %d, got %d", HandshakeRequestLen, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("readHeader failed got error %s", err)
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("pkt: %x", pkt)
|
log.Printf("pkt: %x", pkt)
|
||||||
|
@ -41,4 +56,61 @@ func TestHandshake(t *testing.T) {
|
||||||
if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
||||||
t.Fatalf("readHandshake failed got ext auth %d, expected %d", extAuth, extAuth | HTTP_EXTENDED_AUTH_PAA)
|
t.Fatalf("readHandshake failed got ext auth %d, expected %d", extAuth, extAuth | HTTP_EXTENDED_AUTH_PAA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s := &SessionInfo{}
|
||||||
|
hc := &HandlerConf{
|
||||||
|
TokenAuth: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewHandler(s, hc)
|
||||||
|
|
||||||
|
data = h.handshakeResponse(0x0, 0x0)
|
||||||
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
|
}
|
||||||
|
log.Printf("pkt: %x", pkt)
|
||||||
|
|
||||||
|
caps, err := client.handshakeResponse(pkt)
|
||||||
|
if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
||||||
|
t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps | HTTP_EXTENDED_AUTH_PAA)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTunnelCreation(t *testing.T) {
|
||||||
|
client := ClientConfig{
|
||||||
|
PAAToken: "abab",
|
||||||
|
}
|
||||||
|
|
||||||
|
data := client.tunnelRequest()
|
||||||
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
||||||
|
uint32(TunnelCreateRequestLen + 2 + len(client.PAAToken)*2))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caps, token := readCreateTunnelRequest(pkt)
|
||||||
|
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
||||||
|
t.Fatalf("readCreateTunnelRequest failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT)
|
||||||
|
}
|
||||||
|
if token != client.PAAToken {
|
||||||
|
t.Fatalf("readCreateTunnelRequest failed got token %s, expected %s", token, client.PAAToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
data = createTunnelResponse()
|
||||||
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tid, caps, err := client.tunnelResponse(pkt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error %s", err)
|
||||||
|
}
|
||||||
|
if tid != tunnelId {
|
||||||
|
t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId)
|
||||||
|
}
|
||||||
|
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
||||||
|
t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,2 +0,0 @@
|
||||||
package protocol
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -30,3 +31,12 @@ func DecodeUTF16(b []byte) (string, error) {
|
||||||
}
|
}
|
||||||
return string(bret), nil
|
return string(bret), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EncodeUTF16(s string) []byte {
|
||||||
|
ret := new(bytes.Buffer)
|
||||||
|
enc := utf16.Encode([]rune(s))
|
||||||
|
for c := range enc {
|
||||||
|
binary.Write(ret, binary.LittleEndian, enc[c])
|
||||||
|
}
|
||||||
|
return ret.Bytes()
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue