This commit is contained in:
Bolke de Bruin 2020-08-01 21:23:34 +02:00
parent 4e99b4e88f
commit 9c19a1b40a
8 changed files with 144 additions and 87 deletions

View file

@ -1,4 +1,4 @@
package client package common
import ( import (
"context" "context"

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"github.com/bolkedebruin/rdpgw/api" "github.com/bolkedebruin/rdpgw/api"
"github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/config" "github.com/bolkedebruin/rdpgw/config"
"github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/security" "github.com/bolkedebruin/rdpgw/security"
@ -123,8 +123,8 @@ func main() {
ServerConf: &handlerConfig, ServerConf: &handlerConfig,
} }
http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/callback", api.HandleCallback) http.HandleFunc("/callback", api.HandleCallback)

View file

@ -4,8 +4,8 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/transport"
"io" "io"
"log"
"net" "net"
) )
@ -17,10 +17,67 @@ const (
type ClientConfig struct { type ClientConfig struct {
SmartCardAuth bool SmartCardAuth bool
PAAToken string PAAToken string
NTLMAuth bool NTLMAuth bool
GatewayConn transport.Transport Session *SessionInfo
LocalConn net.Conn LocalConn net.Conn
Server string
Port int
Name string
}
func (c *ClientConfig) ConnectAndForward() error {
c.Session.TransportOut.WritePacket(c.handshakeRequest())
for {
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
}
switch pt {
case PKT_TYPE_HANDSHAKE_RESPONSE:
caps, err := c.handshakeResponse(pkt)
if err != nil {
log.Printf("Cannot connect to %s due to %s", c.Server, err)
return err
}
log.Printf("Handshake response received. Caps: %d", caps)
c.Session.TransportOut.WritePacket(c.tunnelRequest())
case PKT_TYPE_TUNNEL_RESPONSE:
tid, caps, err := c.tunnelResponse(pkt)
if err != nil {
log.Printf("Cannot setup tunnel due to %s", err)
return err
}
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
flags, timeout, err := c.tunnelAuthResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
c.Session.TransportOut.WritePacket(c.channelRequest())
case PKT_TYPE_CHANNEL_RESPONSE:
cid, err := c.channelResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
if cid < 1 {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
}
log.Printf("Channel creation succesful. Channel id: %d", cid)
go forward(c.LocalConn, c.Session.TransportOut)
case PKT_TYPE_DATA:
receive(pkt, c.LocalConn)
default:
log.Printf("Unknown packet type received: %d size %d", pt, sz)
}
}
} }
func (c *ClientConfig) handshakeRequest() []byte { func (c *ClientConfig) handshakeRequest() []byte {
@ -83,7 +140,7 @@ func (c *ClientConfig) tunnelRequest() []byte {
binary.Write(buf, binary.LittleEndian, caps) binary.Write(buf, binary.LittleEndian, caps)
binary.Write(buf, binary.LittleEndian, fields) binary.Write(buf, binary.LittleEndian, fields)
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
if len(c.PAAToken) > 0 { if len(c.PAAToken) > 0 {
utf16Token := EncodeUTF16(c.PAAToken) utf16Token := EncodeUTF16(c.PAAToken)
@ -119,8 +176,8 @@ func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32
return return
} }
func (c *ClientConfig) tunnelAuthRequest(name string) []byte { func (c *ClientConfig) tunnelAuthRequest() []byte {
utf16name := EncodeUTF16(name) utf16name := EncodeUTF16(c.Name)
size := uint16(len(utf16name)) size := uint16(len(utf16name))
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
@ -153,14 +210,14 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui
return return
} }
func (c *ClientConfig) channelRequest(server string, port uint16) []byte { func (c *ClientConfig) channelRequest() []byte {
utf16server := EncodeUTF16(server) utf16server := EncodeUTF16(c.Server)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3) binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(port)) binary.Write(buf, binary.LittleEndian, uint16(c.Port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3 binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server))) binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
buf.Write(utf16server) buf.Write(utf16server)

View file

@ -10,6 +10,62 @@ import (
"net" "net"
) )
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
DisableAll bool
EnableAll bool
}
type SessionInfo struct {
ConnId string
TransportIn transport.Transport
TransportOut transport.Transport
RemoteServer string
ClientIp string
}
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := in.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
func createPacket(pktType uint16, data []byte) (packet []byte) { func createPacket(pktType uint16, data []byte) (packet []byte) {
size := len(data) + 8 size := len(data) + 8
buf := new(bytes.Buffer) buf := new(bytes.Buffer)

View file

@ -2,7 +2,7 @@ package protocol
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/transport" "github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
@ -45,14 +45,6 @@ type Gateway struct {
ServerConf *ServerConf ServerConf *ServerConf
} }
type SessionInfo struct {
ConnId string
TransportIn transport.Transport
TransportOut transport.Transport
RemoteServer string
ClientIp string
}
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute) var c = cache.New(5*time.Minute, 10*time.Minute)
@ -118,7 +110,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return return
} }
log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context())) log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
s.TransportOut = out s.TransportOut = out
out.SendAccept(true) out.SendAccept(true)
@ -139,13 +131,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
s.TransportIn = in s.TransportIn = in
c.Set(s.ConnId, s, cache.DefaultExpiration) c.Set(s.ConnId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context())) log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
in.SendAccept(false) in.SendAccept(false)
// read some initial data // read some initial data
in.Drain() in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", client.GetClientIp(r.Context())) log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
handler := NewServer(s, g.ServerConf) handler := NewServer(s, g.ServerConf)
handler.Process(r.Context()) handler.Process(r.Context())
} }

View file

@ -204,4 +204,4 @@ func TestChannelCreation(t *testing.T) {
if channelId < 1 { if channelId < 1 {
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId) t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
} }
} }

View file

@ -5,7 +5,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/common"
"io" "io"
"log" "log"
"net" "net"
@ -17,16 +17,6 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error) type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error) type VerifyServerFunc func(context.Context, string) (bool, error)
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
DisableAll bool
EnableAll bool
}
type Server struct { type Server struct {
Session *SessionInfo Session *SessionInfo
VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate
@ -70,7 +60,7 @@ const tunnelId = 10
func (s *Server) Process(ctx context.Context) error { func (s *Server) Process(ctx context.Context) error {
for { for {
pt, sz, pkt, err := s.ReadMessage() pt, sz, pkt, err := readMessage(s.Session.TransportIn)
if err != nil { if err != nil {
log.Printf("Cannot read message from stream %s", err) log.Printf("Cannot read message from stream %s", err)
return err return err
@ -78,7 +68,7 @@ func (s *Server) Process(ctx context.Context) error {
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx)) log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
if s.State != SERVER_STATE_INITIAL { if s.State != SERVER_STATE_INITIAL {
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL) log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
return errors.New("wrong state") return errors.New("wrong state")
@ -97,7 +87,7 @@ func (s *Server) Process(ctx context.Context) error {
_, cookie := s.tunnelRequest(pkt) _, cookie := s.tunnelRequest(pkt)
if s.VerifyTunnelCreate != nil { if s.VerifyTunnelCreate != nil {
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok { if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx)) log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
return errors.New("invalid PAA cookie") return errors.New("invalid PAA cookie")
} }
} }
@ -181,44 +171,6 @@ func (s *Server) Process(ctx context.Context) error {
} }
} }
func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := s.Session.TransportIn.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
// Creates a packet the is a response to a handshakeRequest request // Creates a packet the is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure // but could be in Windows. However the NTLM protocol is insecure

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/client" "github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/protocol"
"github.com/dgrijalva/jwt-go/v4" "github.com/dgrijalva/jwt-go/v4"
"log" "log"
@ -55,9 +55,9 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
return false, nil return false, nil
} }
if s.ClientIp != client.GetClientIp(ctx) { if s.ClientIp != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s", log.Printf("Current client ip address %s does not match token client ip %s",
client.GetClientIp(ctx), s.ClientIp) common.GetClientIp(ctx), s.ClientIp)
return false, nil return false, nil
} }
@ -78,7 +78,7 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
c := customClaims{ c := customClaims{
RemoteServer: server, RemoteServer: server,
ClientIP: client.GetClientIp(ctx), ClientIP: common.GetClientIp(ctx),
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
ExpiresAt: exp, ExpiresAt: exp,
IssuedAt: now, IssuedAt: now,