mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 14:03:50 +02:00
211 lines
5.8 KiB
Go
211 lines
5.8 KiB
Go
package protocol
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"testing"
|
|
)
|
|
|
|
const (
|
|
HeaderLen = 8
|
|
HandshakeRequestLen = HeaderLen + 6
|
|
HandshakeResponseLen = HeaderLen + 10
|
|
TunnelCreateRequestLen = HeaderLen + 8 // + dynamic
|
|
TunnelCreateResponseLen = HeaderLen + 18
|
|
TunnelAuthLen = HeaderLen + 2 // + dynamic
|
|
TunnelAuthResponseLen = HeaderLen + 16
|
|
ChannelCreateLen = HeaderLen + 8 // + dynamic
|
|
ChannelResponseLen = HeaderLen + 12
|
|
)
|
|
|
|
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
|
|
pt, size, pkt, err := readHeader(data)
|
|
|
|
if pt != expPt {
|
|
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt)
|
|
}
|
|
|
|
if size != expSize {
|
|
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected size %d, got %d", expSize, size)
|
|
}
|
|
|
|
if err != nil {
|
|
return 0, 0, []byte{}, err
|
|
}
|
|
|
|
return pt, size, pkt, nil
|
|
}
|
|
|
|
func TestHandshake(t *testing.T) {
|
|
client := ClientConfig{
|
|
PAAToken: "abab",
|
|
}
|
|
s := &SessionInfo{}
|
|
hc := &ServerConf{
|
|
TokenAuth: true,
|
|
}
|
|
h := NewServer(s, hc)
|
|
|
|
data := client.handshakeRequest()
|
|
|
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_REQUEST, HandshakeRequestLen)
|
|
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
|
|
log.Printf("pkt: %x", pkt)
|
|
|
|
major, minor, version, extAuth := h.handshakeRequest(pkt)
|
|
if major != MajorVersion || minor != MinorVersion || version != Version {
|
|
t.Fatalf("handshakeRequest failed got version %d.%d protocol %d, expected %d.%d protocol %d",
|
|
major, minor, version, MajorVersion, MinorVersion, Version)
|
|
}
|
|
|
|
if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
|
t.Fatalf("handshakeRequest failed got ext auth %d, expected %d", extAuth, extAuth|HTTP_EXTENDED_AUTH_PAA)
|
|
}
|
|
|
|
data = h.handshakeResponse(0x0, 0x0)
|
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen)
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
log.Printf("pkt: %x", pkt)
|
|
|
|
caps, err := client.handshakeResponse(pkt)
|
|
if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
|
t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps|HTTP_EXTENDED_AUTH_PAA)
|
|
}
|
|
}
|
|
|
|
func TestTunnelCreation(t *testing.T) {
|
|
client := ClientConfig{
|
|
PAAToken: "abab",
|
|
}
|
|
s := &SessionInfo{}
|
|
hc := &ServerConf{
|
|
TokenAuth: true,
|
|
}
|
|
h := NewServer(s, hc)
|
|
|
|
data := client.tunnelRequest()
|
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
|
uint32(TunnelCreateRequestLen+2+len(client.PAAToken)*2))
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
|
|
caps, token := h.tunnelRequest(pkt)
|
|
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
|
t.Fatalf("tunnelRequest failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
|
|
}
|
|
if token != client.PAAToken {
|
|
t.Fatalf("tunnelRequest failed got token %s, expected %s", token, client.PAAToken)
|
|
}
|
|
|
|
data = h.tunnelResponse()
|
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen)
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
|
|
tid, caps, err := client.tunnelResponse(pkt)
|
|
if err != nil {
|
|
t.Fatalf("Error %s", err)
|
|
}
|
|
if tid != tunnelId {
|
|
t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId)
|
|
}
|
|
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
|
t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
|
|
}
|
|
}
|
|
|
|
func TestTunnelAuth(t *testing.T) {
|
|
name := "test_name"
|
|
client := ClientConfig{
|
|
Name: name,
|
|
}
|
|
s := &SessionInfo{}
|
|
hc := &ServerConf{
|
|
TokenAuth: true,
|
|
IdleTimeout: 10,
|
|
RedirectFlags: RedirectFlags{
|
|
Clipboard: true,
|
|
},
|
|
}
|
|
h := NewServer(s, hc)
|
|
|
|
data := client.tunnelAuthRequest()
|
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
|
|
n := h.tunnelAuthRequest(pkt)
|
|
if n != name {
|
|
t.Fatalf("tunnelAuthRequest failed got name %s, expected %s", n, name)
|
|
}
|
|
|
|
data = h.tunnelAuthResponse()
|
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH_RESPONSE, TunnelAuthResponseLen)
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
flags, timeout, err := client.tunnelAuthResponse(pkt)
|
|
if err != nil {
|
|
t.Fatalf("tunnel auth error %s", err)
|
|
}
|
|
if (flags & HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) == HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD {
|
|
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
|
|
flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
|
|
}
|
|
if int(timeout) != hc.IdleTimeout {
|
|
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
|
|
timeout, hc.IdleTimeout)
|
|
}
|
|
}
|
|
|
|
func TestChannelCreation(t *testing.T) {
|
|
server := "test_server"
|
|
client := ClientConfig{
|
|
Server: server,
|
|
Port: 3389,
|
|
}
|
|
s := &SessionInfo{}
|
|
hc := &ServerConf{
|
|
TokenAuth: true,
|
|
IdleTimeout: 10,
|
|
RedirectFlags: RedirectFlags{
|
|
Clipboard: true,
|
|
},
|
|
}
|
|
h := NewServer(s, hc)
|
|
|
|
data := client.channelRequest()
|
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
hServer, hPort := h.channelRequest(pkt)
|
|
if hServer != server {
|
|
t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server)
|
|
}
|
|
if int(hPort) != client.Port {
|
|
t.Fatalf("channelRequest failed got port %d, expected %d", hPort, client.Port)
|
|
}
|
|
|
|
data = h.channelResponse()
|
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen))
|
|
if err != nil {
|
|
t.Fatalf("verifyHeader failed: %s", err)
|
|
}
|
|
channelId, err := client.channelResponse(pkt)
|
|
if err != nil {
|
|
t.Fatalf("channelResponse failed: %s", err)
|
|
}
|
|
if channelId < 1 {
|
|
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
|
|
}
|
|
}
|