mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-20 07:23:48 +02:00
Refactor
This commit is contained in:
parent
4e99b4e88f
commit
9c19a1b40a
8 changed files with 144 additions and 87 deletions
|
@ -1,4 +1,4 @@
|
||||||
package client
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
6
main.go
6
main.go
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"github.com/bolkedebruin/rdpgw/api"
|
"github.com/bolkedebruin/rdpgw/api"
|
||||||
"github.com/bolkedebruin/rdpgw/client"
|
"github.com/bolkedebruin/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/config"
|
"github.com/bolkedebruin/rdpgw/config"
|
||||||
"github.com/bolkedebruin/rdpgw/protocol"
|
"github.com/bolkedebruin/rdpgw/protocol"
|
||||||
"github.com/bolkedebruin/rdpgw/security"
|
"github.com/bolkedebruin/rdpgw/security"
|
||||||
|
@ -123,8 +123,8 @@ func main() {
|
||||||
ServerConf: &handlerConfig,
|
ServerConf: &handlerConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||||
http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
http.HandleFunc("/callback", api.HandleCallback)
|
http.HandleFunc("/callback", api.HandleCallback)
|
||||||
|
|
||||||
|
|
|
@ -4,8 +4,8 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bolkedebruin/rdpgw/transport"
|
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,10 +17,67 @@ const (
|
||||||
|
|
||||||
type ClientConfig struct {
|
type ClientConfig struct {
|
||||||
SmartCardAuth bool
|
SmartCardAuth bool
|
||||||
PAAToken string
|
PAAToken string
|
||||||
NTLMAuth bool
|
NTLMAuth bool
|
||||||
GatewayConn transport.Transport
|
Session *SessionInfo
|
||||||
LocalConn net.Conn
|
LocalConn net.Conn
|
||||||
|
Server string
|
||||||
|
Port int
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConfig) ConnectAndForward() error {
|
||||||
|
c.Session.TransportOut.WritePacket(c.handshakeRequest())
|
||||||
|
|
||||||
|
for {
|
||||||
|
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot read message from stream %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch pt {
|
||||||
|
case PKT_TYPE_HANDSHAKE_RESPONSE:
|
||||||
|
caps, err := c.handshakeResponse(pkt)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot connect to %s due to %s", c.Server, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("Handshake response received. Caps: %d", caps)
|
||||||
|
c.Session.TransportOut.WritePacket(c.tunnelRequest())
|
||||||
|
case PKT_TYPE_TUNNEL_RESPONSE:
|
||||||
|
tid, caps, err := c.tunnelResponse(pkt)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot setup tunnel due to %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
|
||||||
|
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
|
||||||
|
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
|
||||||
|
flags, timeout, err := c.tunnelAuthResponse(pkt)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot do tunnel auth due to %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
|
||||||
|
c.Session.TransportOut.WritePacket(c.channelRequest())
|
||||||
|
case PKT_TYPE_CHANNEL_RESPONSE:
|
||||||
|
cid, err := c.channelResponse(pkt)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot do tunnel auth due to %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cid < 1 {
|
||||||
|
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
|
||||||
|
}
|
||||||
|
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
||||||
|
go forward(c.LocalConn, c.Session.TransportOut)
|
||||||
|
case PKT_TYPE_DATA:
|
||||||
|
receive(pkt, c.LocalConn)
|
||||||
|
default:
|
||||||
|
log.Printf("Unknown packet type received: %d size %d", pt, sz)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientConfig) handshakeRequest() []byte {
|
func (c *ClientConfig) handshakeRequest() []byte {
|
||||||
|
@ -83,7 +140,7 @@ func (c *ClientConfig) tunnelRequest() []byte {
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, caps)
|
binary.Write(buf, binary.LittleEndian, caps)
|
||||||
binary.Write(buf, binary.LittleEndian, fields)
|
binary.Write(buf, binary.LittleEndian, fields)
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||||
|
|
||||||
if len(c.PAAToken) > 0 {
|
if len(c.PAAToken) > 0 {
|
||||||
utf16Token := EncodeUTF16(c.PAAToken)
|
utf16Token := EncodeUTF16(c.PAAToken)
|
||||||
|
@ -119,8 +176,8 @@ func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientConfig) tunnelAuthRequest(name string) []byte {
|
func (c *ClientConfig) tunnelAuthRequest() []byte {
|
||||||
utf16name := EncodeUTF16(name)
|
utf16name := EncodeUTF16(c.Name)
|
||||||
size := uint16(len(utf16name))
|
size := uint16(len(utf16name))
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
@ -153,14 +210,14 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientConfig) channelRequest(server string, port uint16) []byte {
|
func (c *ClientConfig) channelRequest() []byte {
|
||||||
utf16server := EncodeUTF16(server)
|
utf16server := EncodeUTF16(c.Server)
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
|
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
|
||||||
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
|
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(port))
|
binary.Write(buf, binary.LittleEndian, uint16(c.Port))
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
|
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
|
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
|
||||||
buf.Write(utf16server)
|
buf.Write(utf16server)
|
||||||
|
|
|
@ -10,6 +10,62 @@ import (
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RedirectFlags struct {
|
||||||
|
Clipboard bool
|
||||||
|
Port bool
|
||||||
|
Drive bool
|
||||||
|
Printer bool
|
||||||
|
Pnp bool
|
||||||
|
DisableAll bool
|
||||||
|
EnableAll bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionInfo struct {
|
||||||
|
ConnId string
|
||||||
|
TransportIn transport.Transport
|
||||||
|
TransportOut transport.Transport
|
||||||
|
RemoteServer string
|
||||||
|
ClientIp string
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
|
||||||
|
fragment := false
|
||||||
|
index := 0
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
|
||||||
|
for {
|
||||||
|
size, pkt, err := in.ReadPacket()
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, []byte{0, 0}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for fragments
|
||||||
|
var pt uint16
|
||||||
|
var sz uint32
|
||||||
|
var msg []byte
|
||||||
|
|
||||||
|
if !fragment {
|
||||||
|
pt, sz, msg, err = readHeader(pkt[:size])
|
||||||
|
if err != nil {
|
||||||
|
fragment = true
|
||||||
|
index = copy(buf, pkt[:size])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
index = 0
|
||||||
|
} else {
|
||||||
|
fragment = false
|
||||||
|
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
|
||||||
|
// header is corrupted even after defragmenting
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, []byte{0, 0}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !fragment {
|
||||||
|
return int(pt), int(sz), msg, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createPacket(pktType uint16, data []byte) (packet []byte) {
|
func createPacket(pktType uint16, data []byte) (packet []byte) {
|
||||||
size := len(data) + 8
|
size := len(data) + 8
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
|
|
@ -2,7 +2,7 @@ package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/bolkedebruin/rdpgw/client"
|
"github.com/bolkedebruin/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/transport"
|
"github.com/bolkedebruin/rdpgw/transport"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
|
@ -45,14 +45,6 @@ type Gateway struct {
|
||||||
ServerConf *ServerConf
|
ServerConf *ServerConf
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionInfo struct {
|
|
||||||
ConnId string
|
|
||||||
TransportIn transport.Transport
|
|
||||||
TransportOut transport.Transport
|
|
||||||
RemoteServer string
|
|
||||||
ClientIp string
|
|
||||||
}
|
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{}
|
var upgrader = websocket.Upgrader{}
|
||||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||||
|
|
||||||
|
@ -118,7 +110,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
||||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context()))
|
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
|
||||||
|
|
||||||
s.TransportOut = out
|
s.TransportOut = out
|
||||||
out.SendAccept(true)
|
out.SendAccept(true)
|
||||||
|
@ -139,13 +131,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
||||||
s.TransportIn = in
|
s.TransportIn = in
|
||||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||||
|
|
||||||
log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context()))
|
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
|
||||||
in.SendAccept(false)
|
in.SendAccept(false)
|
||||||
|
|
||||||
// read some initial data
|
// read some initial data
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshakeRequest done for client %s", client.GetClientIp(r.Context()))
|
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
||||||
handler := NewServer(s, g.ServerConf)
|
handler := NewServer(s, g.ServerConf)
|
||||||
handler.Process(r.Context())
|
handler.Process(r.Context())
|
||||||
}
|
}
|
||||||
|
|
|
@ -204,4 +204,4 @@ func TestChannelCreation(t *testing.T) {
|
||||||
if channelId < 1 {
|
if channelId < 1 {
|
||||||
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
|
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bolkedebruin/rdpgw/client"
|
"github.com/bolkedebruin/rdpgw/common"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
@ -17,16 +17,6 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error)
|
||||||
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
||||||
type VerifyServerFunc func(context.Context, string) (bool, error)
|
type VerifyServerFunc func(context.Context, string) (bool, error)
|
||||||
|
|
||||||
type RedirectFlags struct {
|
|
||||||
Clipboard bool
|
|
||||||
Port bool
|
|
||||||
Drive bool
|
|
||||||
Printer bool
|
|
||||||
Pnp bool
|
|
||||||
DisableAll bool
|
|
||||||
EnableAll bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Session *SessionInfo
|
Session *SessionInfo
|
||||||
VerifyTunnelCreate VerifyTunnelCreate
|
VerifyTunnelCreate VerifyTunnelCreate
|
||||||
|
@ -70,7 +60,7 @@ const tunnelId = 10
|
||||||
|
|
||||||
func (s *Server) Process(ctx context.Context) error {
|
func (s *Server) Process(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
pt, sz, pkt, err := s.ReadMessage()
|
pt, sz, pkt, err := readMessage(s.Session.TransportIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot read message from stream %s", err)
|
log.Printf("Cannot read message from stream %s", err)
|
||||||
return err
|
return err
|
||||||
|
@ -78,7 +68,7 @@ func (s *Server) Process(ctx context.Context) error {
|
||||||
|
|
||||||
switch pt {
|
switch pt {
|
||||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||||
log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx))
|
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
|
||||||
if s.State != SERVER_STATE_INITIAL {
|
if s.State != SERVER_STATE_INITIAL {
|
||||||
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
|
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
|
@ -97,7 +87,7 @@ func (s *Server) Process(ctx context.Context) error {
|
||||||
_, cookie := s.tunnelRequest(pkt)
|
_, cookie := s.tunnelRequest(pkt)
|
||||||
if s.VerifyTunnelCreate != nil {
|
if s.VerifyTunnelCreate != nil {
|
||||||
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
|
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
|
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
|
||||||
return errors.New("invalid PAA cookie")
|
return errors.New("invalid PAA cookie")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,44 +171,6 @@ func (s *Server) Process(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) {
|
|
||||||
fragment := false
|
|
||||||
index := 0
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
|
|
||||||
for {
|
|
||||||
size, pkt, err := s.Session.TransportIn.ReadPacket()
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, []byte{0, 0}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// check for fragments
|
|
||||||
var pt uint16
|
|
||||||
var sz uint32
|
|
||||||
var msg []byte
|
|
||||||
|
|
||||||
if !fragment {
|
|
||||||
pt, sz, msg, err = readHeader(pkt[:size])
|
|
||||||
if err != nil {
|
|
||||||
fragment = true
|
|
||||||
index = copy(buf, pkt[:size])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
index = 0
|
|
||||||
} else {
|
|
||||||
fragment = false
|
|
||||||
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
|
|
||||||
// header is corrupted even after defragmenting
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, []byte{0, 0}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !fragment {
|
|
||||||
return int(pt), int(sz), msg, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a packet the is a response to a handshakeRequest request
|
// Creates a packet the is a response to a handshakeRequest request
|
||||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||||
// but could be in Windows. However the NTLM protocol is insecure
|
// but could be in Windows. However the NTLM protocol is insecure
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bolkedebruin/rdpgw/client"
|
"github.com/bolkedebruin/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/protocol"
|
"github.com/bolkedebruin/rdpgw/protocol"
|
||||||
"github.com/dgrijalva/jwt-go/v4"
|
"github.com/dgrijalva/jwt-go/v4"
|
||||||
"log"
|
"log"
|
||||||
|
@ -55,9 +55,9 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ClientIp != client.GetClientIp(ctx) {
|
if s.ClientIp != common.GetClientIp(ctx) {
|
||||||
log.Printf("Current client ip address %s does not match token client ip %s",
|
log.Printf("Current client ip address %s does not match token client ip %s",
|
||||||
client.GetClientIp(ctx), s.ClientIp)
|
common.GetClientIp(ctx), s.ClientIp)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
|
||||||
|
|
||||||
c := customClaims{
|
c := customClaims{
|
||||||
RemoteServer: server,
|
RemoteServer: server,
|
||||||
ClientIP: client.GetClientIp(ctx),
|
ClientIP: common.GetClientIp(ctx),
|
||||||
StandardClaims: jwt.StandardClaims{
|
StandardClaims: jwt.StandardClaims{
|
||||||
ExpiresAt: exp,
|
ExpiresAt: exp,
|
||||||
IssuedAt: now,
|
IssuedAt: now,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue