mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 22:13:50 +02:00
Refactor names
This commit is contained in:
parent
29d4b276e6
commit
fe6509d8ca
5 changed files with 154 additions and 141 deletions
4
main.go
4
main.go
|
@ -103,7 +103,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the gateway
|
// create the gateway
|
||||||
handlerConfig := protocol.HandlerConf{
|
handlerConfig := protocol.ServerConf{
|
||||||
IdleTimeout: conf.Caps.IdleTimeout,
|
IdleTimeout: conf.Caps.IdleTimeout,
|
||||||
TokenAuth: conf.Caps.TokenAuth,
|
TokenAuth: conf.Caps.TokenAuth,
|
||||||
SmartCardAuth: conf.Caps.SmartCardAuth,
|
SmartCardAuth: conf.Caps.SmartCardAuth,
|
||||||
|
@ -120,7 +120,7 @@ func main() {
|
||||||
VerifyServerFunc: security.VerifyServerFunc,
|
VerifyServerFunc: security.VerifyServerFunc,
|
||||||
}
|
}
|
||||||
gw := protocol.Gateway{
|
gw := protocol.Gateway{
|
||||||
HandlerConf: &handlerConfig,
|
ServerConf: &handlerConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||||
|
|
37
protocol/common.go
Normal file
37
protocol/common.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createPacket(pktType uint16, data []byte) (packet []byte) {
|
||||||
|
size := len(data) + 8
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(pktType))
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint32(size))
|
||||||
|
buf.Write(data)
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
|
||||||
|
// header needs to be 8 min
|
||||||
|
if len(data) < 8 {
|
||||||
|
return 0, 0, nil, errors.New("header too short, fragment likely")
|
||||||
|
}
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
binary.Read(r, binary.LittleEndian, &packetType)
|
||||||
|
r.Seek(4, io.SeekStart)
|
||||||
|
binary.Read(r, binary.LittleEndian, &size)
|
||||||
|
if len(data) < int(size) {
|
||||||
|
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
|
||||||
|
}
|
||||||
|
return packetType, size, data[8:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,11 +16,11 @@ const (
|
||||||
TunnelAuthResponseLen = HeaderLen + 16
|
TunnelAuthResponseLen = HeaderLen + 16
|
||||||
)
|
)
|
||||||
|
|
||||||
func verifyPacketHeader(data []byte , expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
|
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
|
||||||
pt, size, pkt, err := readHeader(data)
|
pt, size, pkt, err := readHeader(data)
|
||||||
|
|
||||||
if pt != expPt {
|
if pt != expPt {
|
||||||
return 0,0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt)
|
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt)
|
||||||
}
|
}
|
||||||
|
|
||||||
if size != expSize {
|
if size != expSize {
|
||||||
|
@ -38,6 +38,11 @@ func TestHandshake(t *testing.T) {
|
||||||
client := ClientConfig{
|
client := ClientConfig{
|
||||||
PAAToken: "abab",
|
PAAToken: "abab",
|
||||||
}
|
}
|
||||||
|
s := &SessionInfo{}
|
||||||
|
hc := &ServerConf{
|
||||||
|
TokenAuth: true,
|
||||||
|
}
|
||||||
|
h := NewServer(s, hc)
|
||||||
|
|
||||||
data := client.handshakeRequest()
|
data := client.handshakeRequest()
|
||||||
|
|
||||||
|
@ -49,23 +54,16 @@ func TestHandshake(t *testing.T) {
|
||||||
|
|
||||||
log.Printf("pkt: %x", pkt)
|
log.Printf("pkt: %x", pkt)
|
||||||
|
|
||||||
major, minor, version, extAuth := readHandshake(pkt)
|
major, minor, version, extAuth := h.handshakeRequest(pkt)
|
||||||
if major != MajorVersion || minor != MinorVersion || version != Version {
|
if major != MajorVersion || minor != MinorVersion || version != Version {
|
||||||
t.Fatalf("readHandshake failed got version %d.%d protocol %d, expected %d.%d protocol %d",
|
t.Fatalf("handshakeRequest failed got version %d.%d protocol %d, expected %d.%d protocol %d",
|
||||||
major, minor, version, MajorVersion, MinorVersion, Version)
|
major, minor, version, MajorVersion, MinorVersion, Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
||||||
t.Fatalf("readHandshake failed got ext auth %d, expected %d", extAuth, extAuth | HTTP_EXTENDED_AUTH_PAA)
|
t.Fatalf("handshakeRequest failed got ext auth %d, expected %d", extAuth, extAuth|HTTP_EXTENDED_AUTH_PAA)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &SessionInfo{}
|
|
||||||
hc := &HandlerConf{
|
|
||||||
TokenAuth: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
h := NewHandler(s, hc)
|
|
||||||
|
|
||||||
data = h.handshakeResponse(0x0, 0x0)
|
data = h.handshakeResponse(0x0, 0x0)
|
||||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen)
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -75,7 +73,7 @@ func TestHandshake(t *testing.T) {
|
||||||
|
|
||||||
caps, err := client.handshakeResponse(pkt)
|
caps, err := client.handshakeResponse(pkt)
|
||||||
if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
||||||
t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps | HTTP_EXTENDED_AUTH_PAA)
|
t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps|HTTP_EXTENDED_AUTH_PAA)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,23 +81,28 @@ func TestTunnelCreation(t *testing.T) {
|
||||||
client := ClientConfig{
|
client := ClientConfig{
|
||||||
PAAToken: "abab",
|
PAAToken: "abab",
|
||||||
}
|
}
|
||||||
|
s := &SessionInfo{}
|
||||||
|
hc := &ServerConf{
|
||||||
|
TokenAuth: true,
|
||||||
|
}
|
||||||
|
h := NewServer(s, hc)
|
||||||
|
|
||||||
data := client.tunnelRequest()
|
data := client.tunnelRequest()
|
||||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
||||||
uint32(TunnelCreateRequestLen + 2 + len(client.PAAToken)*2))
|
uint32(TunnelCreateRequestLen+2+len(client.PAAToken)*2))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("verifyHeader failed: %s", err)
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
caps, token := readCreateTunnelRequest(pkt)
|
caps, token := h.tunnelRequest(pkt)
|
||||||
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
||||||
t.Fatalf("readCreateTunnelRequest failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT)
|
t.Fatalf("tunnelRequest failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
|
||||||
}
|
}
|
||||||
if token != client.PAAToken {
|
if token != client.PAAToken {
|
||||||
t.Fatalf("readCreateTunnelRequest failed got token %s, expected %s", token, client.PAAToken)
|
t.Fatalf("tunnelRequest failed got token %s, expected %s", token, client.PAAToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
data = createTunnelResponse()
|
data = h.tunnelResponse()
|
||||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen)
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("verifyHeader failed: %s", err)
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
|
@ -113,35 +116,35 @@ func TestTunnelCreation(t *testing.T) {
|
||||||
t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId)
|
t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId)
|
||||||
}
|
}
|
||||||
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
||||||
t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps | HTTP_CAPABILITY_IDLE_TIMEOUT)
|
t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTunnelAuth(t *testing.T) {
|
func TestTunnelAuth(t *testing.T) {
|
||||||
client := ClientConfig{}
|
client := ClientConfig{}
|
||||||
s := &SessionInfo{}
|
s := &SessionInfo{}
|
||||||
hc := &HandlerConf{
|
hc := &ServerConf{
|
||||||
TokenAuth: true,
|
TokenAuth: true,
|
||||||
IdleTimeout: 10,
|
IdleTimeout: 10,
|
||||||
RedirectFlags: RedirectFlags{
|
RedirectFlags: RedirectFlags{
|
||||||
Clipboard: true,
|
Clipboard: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h := NewHandler(s, hc)
|
h := NewServer(s, hc)
|
||||||
name := "test_name"
|
name := "test_name"
|
||||||
|
|
||||||
data := client.tunnelAuthRequest(name)
|
data := client.tunnelAuthRequest(name)
|
||||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen + len(name) * 2))
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("verifyHeader failed: %s", err)
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n := h.readTunnelAuthRequest(pkt)
|
n := h.tunnelAuthRequest(pkt)
|
||||||
if n != name {
|
if n != name {
|
||||||
t.Fatalf("readTunnelAuthRequest failed got name %s, expected %s", n, name)
|
t.Fatalf("tunnelAuthRequest failed got name %s, expected %s", n, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
data = h.createTunnelAuthResponse()
|
data = h.tunnelAuthResponse()
|
||||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH_RESPONSE, TunnelAuthResponseLen)
|
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH_RESPONSE, TunnelAuthResponseLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("verifyHeader failed: %s", err)
|
t.Fatalf("verifyHeader failed: %s", err)
|
||||||
|
@ -152,7 +155,7 @@ func TestTunnelAuth(t *testing.T) {
|
||||||
}
|
}
|
||||||
if (flags & HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) == HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD {
|
if (flags & HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) == HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD {
|
||||||
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
|
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
|
||||||
flags, flags | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
|
flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
|
||||||
}
|
}
|
||||||
if int(timeout) != hc.IdleTimeout {
|
if int(timeout) != hc.IdleTimeout {
|
||||||
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
|
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
|
|
@ -42,7 +42,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Gateway struct {
|
type Gateway struct {
|
||||||
HandlerConf *HandlerConf
|
ServerConf *ServerConf
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
|
@ -102,12 +102,12 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
||||||
inout, _ := transport.NewWS(c)
|
inout, _ := transport.NewWS(c)
|
||||||
s.TransportOut = inout
|
s.TransportOut = inout
|
||||||
s.TransportIn = inout
|
s.TransportIn = inout
|
||||||
handler := NewHandler(s, g.HandlerConf)
|
handler := NewServer(s, g.ServerConf)
|
||||||
handler.Process(ctx)
|
handler.Process(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||||
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
|
||||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
|
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
|
||||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
||||||
|
@ -145,8 +145,8 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
||||||
// read some initial data
|
// read some initial data
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshake done for client %s", client.GetClientIp(r.Context()))
|
log.Printf("Legacy handshakeRequest done for client %s", client.GetClientIp(r.Context()))
|
||||||
handler := NewHandler(s, g.HandlerConf)
|
handler := NewServer(s, g.ServerConf)
|
||||||
handler.Process(r.Context())
|
handler.Process(r.Context())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ type RedirectFlags struct {
|
||||||
EnableAll bool
|
EnableAll bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Handler struct {
|
type Server struct {
|
||||||
Session *SessionInfo
|
Session *SessionInfo
|
||||||
VerifyTunnelCreate VerifyTunnelCreate
|
VerifyTunnelCreate VerifyTunnelCreate
|
||||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||||
|
@ -41,7 +41,7 @@ type Handler struct {
|
||||||
State int
|
State int
|
||||||
}
|
}
|
||||||
|
|
||||||
type HandlerConf struct {
|
type ServerConf struct {
|
||||||
VerifyTunnelCreate VerifyTunnelCreate
|
VerifyTunnelCreate VerifyTunnelCreate
|
||||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||||
VerifyServerFunc VerifyServerFunc
|
VerifyServerFunc VerifyServerFunc
|
||||||
|
@ -51,8 +51,8 @@ type HandlerConf struct {
|
||||||
TokenAuth bool
|
TokenAuth bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
func NewServer(s *SessionInfo, conf *ServerConf) *Server {
|
||||||
h := &Handler{
|
h := &Server{
|
||||||
State: SERVER_STATE_INITIAL,
|
State: SERVER_STATE_INITIAL,
|
||||||
Session: s,
|
Session: s,
|
||||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||||
|
@ -68,9 +68,9 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||||
|
|
||||||
const tunnelId = 10
|
const tunnelId = 10
|
||||||
|
|
||||||
func (h *Handler) Process(ctx context.Context) error {
|
func (s *Server) Process(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
pt, sz, pkt, err := h.ReadMessage()
|
pt, sz, pkt, err := s.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot read message from stream %s", err)
|
log.Printf("Cannot read message from stream %s", err)
|
||||||
return err
|
return err
|
||||||
|
@ -78,89 +78,89 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
|
|
||||||
switch pt {
|
switch pt {
|
||||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||||
log.Printf("Client handshake from %s", client.GetClientIp(ctx))
|
log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx))
|
||||||
if h.State != SERVER_STATE_INITIAL {
|
if s.State != SERVER_STATE_INITIAL {
|
||||||
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
|
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
|
major, minor, _, _ := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do
|
||||||
msg := h.handshakeResponse(major, minor)
|
msg := s.handshakeResponse(major, minor)
|
||||||
h.Session.TransportOut.WritePacket(msg)
|
s.Session.TransportOut.WritePacket(msg)
|
||||||
h.State = SERVER_STATE_HANDSHAKE
|
s.State = SERVER_STATE_HANDSHAKE
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
log.Printf("Tunnel create")
|
log.Printf("Tunnel create")
|
||||||
if h.State != SERVER_STATE_HANDSHAKE {
|
if s.State != SERVER_STATE_HANDSHAKE {
|
||||||
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
||||||
h.State, SERVER_STATE_HANDSHAKE)
|
s.State, SERVER_STATE_HANDSHAKE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
_, cookie := readCreateTunnelRequest(pkt)
|
_, cookie := s.tunnelRequest(pkt)
|
||||||
if h.VerifyTunnelCreate != nil {
|
if s.VerifyTunnelCreate != nil {
|
||||||
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
|
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
|
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
|
||||||
return errors.New("invalid PAA cookie")
|
return errors.New("invalid PAA cookie")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg := createTunnelResponse()
|
msg := s.tunnelResponse()
|
||||||
h.Session.TransportOut.WritePacket(msg)
|
s.Session.TransportOut.WritePacket(msg)
|
||||||
h.State = SERVER_STATE_TUNNEL_CREATE
|
s.State = SERVER_STATE_TUNNEL_CREATE
|
||||||
case PKT_TYPE_TUNNEL_AUTH:
|
case PKT_TYPE_TUNNEL_AUTH:
|
||||||
log.Printf("Tunnel auth")
|
log.Printf("Tunnel auth")
|
||||||
if h.State != SERVER_STATE_TUNNEL_CREATE {
|
if s.State != SERVER_STATE_TUNNEL_CREATE {
|
||||||
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
||||||
h.State, SERVER_STATE_TUNNEL_CREATE)
|
s.State, SERVER_STATE_TUNNEL_CREATE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
client := h.readTunnelAuthRequest(pkt)
|
client := s.tunnelAuthRequest(pkt)
|
||||||
if h.VerifyTunnelAuthFunc != nil {
|
if s.VerifyTunnelAuthFunc != nil {
|
||||||
if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok {
|
if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok {
|
||||||
log.Printf("Invalid client name: %s", client)
|
log.Printf("Invalid client name: %s", client)
|
||||||
return errors.New("invalid client name")
|
return errors.New("invalid client name")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg := h.createTunnelAuthResponse()
|
msg := s.tunnelAuthResponse()
|
||||||
h.Session.TransportOut.WritePacket(msg)
|
s.Session.TransportOut.WritePacket(msg)
|
||||||
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
s.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||||
case PKT_TYPE_CHANNEL_CREATE:
|
case PKT_TYPE_CHANNEL_CREATE:
|
||||||
log.Printf("Channel create")
|
log.Printf("Channel create")
|
||||||
if h.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
if s.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||||
log.Printf("Channel create attempted while in wrong state %d != %d",
|
log.Printf("Channel create attempted while in wrong state %d != %d",
|
||||||
h.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
s.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
server, port := readChannelCreateRequest(pkt)
|
server, port := s.channelRequest(pkt)
|
||||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
if h.VerifyServerFunc != nil {
|
if s.VerifyServerFunc != nil {
|
||||||
if ok, _ := h.VerifyServerFunc(ctx, host); !ok {
|
if ok, _ := s.VerifyServerFunc(ctx, host); !ok {
|
||||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||||
return errors.New("denied by security policy")
|
return errors.New("denied by security policy")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("Establishing connection to RDP server: %s", host)
|
log.Printf("Establishing connection to RDP server: %s", host)
|
||||||
h.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
s.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error connecting to %s, %s", host, err)
|
log.Printf("Error connecting to %s, %s", host, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Printf("Connection established")
|
log.Printf("Connection established")
|
||||||
msg := createChannelCreateResponse()
|
msg := s.channelResponse()
|
||||||
h.Session.TransportOut.WritePacket(msg)
|
s.Session.TransportOut.WritePacket(msg)
|
||||||
|
|
||||||
// Make sure to start the flow from the RDP server first otherwise connections
|
// Make sure to start the flow from the RDP server first otherwise connections
|
||||||
// might hang eventually
|
// might hang eventually
|
||||||
go h.sendDataPacket()
|
go s.sendDataPacket()
|
||||||
h.State = SERVER_STATE_CHANNEL_CREATE
|
s.State = SERVER_STATE_CHANNEL_CREATE
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
if h.State < SERVER_STATE_CHANNEL_CREATE {
|
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||||
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
h.State = SERVER_STATE_OPENED
|
s.State = SERVER_STATE_OPENED
|
||||||
h.forwardDataPacket(pkt)
|
s.forwardDataPacket(pkt)
|
||||||
case PKT_TYPE_KEEPALIVE:
|
case PKT_TYPE_KEEPALIVE:
|
||||||
// keepalives can be received while the channel is not open yet
|
// keepalives can be received while the channel is not open yet
|
||||||
if h.State < SERVER_STATE_CHANNEL_CREATE {
|
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||||
log.Printf("Keepalive received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,26 +168,26 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||||
case PKT_TYPE_CLOSE_CHANNEL:
|
case PKT_TYPE_CLOSE_CHANNEL:
|
||||||
log.Printf("Close channel")
|
log.Printf("Close channel")
|
||||||
if h.State != SERVER_STATE_OPENED {
|
if s.State != SERVER_STATE_OPENED {
|
||||||
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
|
log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
h.Session.TransportIn.Close()
|
s.Session.TransportIn.Close()
|
||||||
h.Session.TransportOut.Close()
|
s.Session.TransportOut.Close()
|
||||||
h.State = SERVER_STATE_CLOSED
|
s.State = SERVER_STATE_CLOSED
|
||||||
default:
|
default:
|
||||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
|
func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) {
|
||||||
fragment := false
|
fragment := false
|
||||||
index := 0
|
index := 0
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
size, pkt, err := h.Session.TransportIn.ReadPacket()
|
size, pkt, err := s.Session.TransportIn.ReadPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, []byte{0, 0}, err
|
return 0, 0, []byte{0, 0}, err
|
||||||
}
|
}
|
||||||
|
@ -219,15 +219,15 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a packet the is a response to a handshake request
|
// Creates a packet the is a response to a handshakeRequest request
|
||||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||||
// but could be in Windows. However the NTLM protocol is insecure
|
// but could be in Windows. However the NTLM protocol is insecure
|
||||||
func (h *Handler) handshakeResponse(major byte, minor byte) []byte {
|
func (s *Server) handshakeResponse(major byte, minor byte) []byte {
|
||||||
var caps uint16
|
var caps uint16
|
||||||
if h.SmartCardAuth {
|
if s.SmartCardAuth {
|
||||||
caps = caps | HTTP_EXTENDED_AUTH_SC
|
caps = caps | HTTP_EXTENDED_AUTH_SC
|
||||||
}
|
}
|
||||||
if h.TokenAuth {
|
if s.TokenAuth {
|
||||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -240,22 +240,7 @@ func (h *Handler) handshakeResponse(major byte, minor byte) []byte {
|
||||||
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
|
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
|
func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
|
||||||
// header needs to be 8 min
|
|
||||||
if len(data) < 8 {
|
|
||||||
return 0, 0, nil, errors.New("header too short, fragment likely")
|
|
||||||
}
|
|
||||||
r := bytes.NewReader(data)
|
|
||||||
binary.Read(r, binary.LittleEndian, &packetType)
|
|
||||||
r.Seek(4, io.SeekStart)
|
|
||||||
binary.Read(r, binary.LittleEndian, &size)
|
|
||||||
if len(data) < int(size) {
|
|
||||||
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
|
|
||||||
}
|
|
||||||
return packetType, size, data[8:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
|
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
binary.Read(r, binary.LittleEndian, &major)
|
binary.Read(r, binary.LittleEndian, &major)
|
||||||
binary.Read(r, binary.LittleEndian, &minor)
|
binary.Read(r, binary.LittleEndian, &minor)
|
||||||
|
@ -266,7 +251,7 @@ func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
|
func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
|
||||||
var fields uint16
|
var fields uint16
|
||||||
|
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
|
@ -285,7 +270,7 @@ func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTunnelResponse() []byte {
|
func (s *Server) tunnelResponse() []byte {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
|
||||||
|
@ -301,7 +286,7 @@ func createTunnelResponse() []byte {
|
||||||
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) readTunnelAuthRequest(data []byte) string {
|
func (s *Server) tunnelAuthRequest(data []byte) string {
|
||||||
buf := bytes.NewReader(data)
|
buf := bytes.NewReader(data)
|
||||||
|
|
||||||
var size uint16
|
var size uint16
|
||||||
|
@ -313,7 +298,7 @@ func (h *Handler) readTunnelAuthRequest(data []byte) string {
|
||||||
return clientName
|
return clientName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) createTunnelAuthResponse() []byte {
|
func (s *Server) tunnelAuthResponse() []byte {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||||
|
@ -321,17 +306,17 @@ func (h *Handler) createTunnelAuthResponse() []byte {
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||||
|
|
||||||
// idle timeout
|
// idle timeout
|
||||||
if h.IdleTimeout < 0 {
|
if s.IdleTimeout < 0 {
|
||||||
h.IdleTimeout = 0
|
s.IdleTimeout = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(h.RedirectFlags)) // redir flags
|
binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(h.IdleTimeout)) // timeout in minutes
|
binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes
|
||||||
|
|
||||||
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func readChannelCreateRequest(data []byte) (server string, port uint16) {
|
func (s *Server) channelRequest(data []byte) (server string, port uint16) {
|
||||||
buf := bytes.NewReader(data)
|
buf := bytes.NewReader(data)
|
||||||
|
|
||||||
var resourcesSize byte
|
var resourcesSize byte
|
||||||
|
@ -353,7 +338,7 @@ func readChannelCreateRequest(data []byte) (server string, port uint16) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createChannelCreateResponse() []byte {
|
func (s *Server) channelResponse() []byte {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||||
|
@ -372,7 +357,7 @@ func createChannelCreateResponse() []byte {
|
||||||
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
|
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) forwardDataPacket(data []byte) {
|
func (s *Server) forwardDataPacket(data []byte) {
|
||||||
buf := bytes.NewReader(data)
|
buf := bytes.NewReader(data)
|
||||||
|
|
||||||
var cblen uint16
|
var cblen uint16
|
||||||
|
@ -380,38 +365,26 @@ func (h *Handler) forwardDataPacket(data []byte) {
|
||||||
pkt := make([]byte, cblen)
|
pkt := make([]byte, cblen)
|
||||||
binary.Read(buf, binary.LittleEndian, &pkt)
|
binary.Read(buf, binary.LittleEndian, &pkt)
|
||||||
|
|
||||||
h.Remote.Write(pkt)
|
s.Remote.Write(pkt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) sendDataPacket() {
|
func (s *Server) sendDataPacket() {
|
||||||
defer h.Remote.Close()
|
defer s.Remote.Close()
|
||||||
b1 := new(bytes.Buffer)
|
b1 := new(bytes.Buffer)
|
||||||
buf := make([]byte, 4086)
|
buf := make([]byte, 4086)
|
||||||
for {
|
for {
|
||||||
n, err := h.Remote.Read(buf)
|
n, err := s.Remote.Read(buf)
|
||||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error reading from conn %s", err)
|
log.Printf("Error reading from conn %s", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
b1.Write(buf[:n])
|
b1.Write(buf[:n])
|
||||||
h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
s.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||||
b1.Reset()
|
b1.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createPacket(pktType uint16, data []byte) (packet []byte) {
|
|
||||||
size := len(data) + 8
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(pktType))
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(size))
|
|
||||||
buf.Write(data)
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeRedirectFlags(flags RedirectFlags) int {
|
func makeRedirectFlags(flags RedirectFlags) int {
|
||||||
var redir = 0
|
var redir = 0
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue