This commit is contained in:
Bolke de Bruin 2020-08-01 21:23:34 +02:00
parent 4e99b4e88f
commit 9c19a1b40a
8 changed files with 144 additions and 87 deletions

View file

@ -1,4 +1,4 @@
package client
package common
import (
"context"

View file

@ -4,7 +4,7 @@ import (
"context"
"crypto/tls"
"github.com/bolkedebruin/rdpgw/api"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/config"
"github.com/bolkedebruin/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/security"
@ -123,8 +123,8 @@ func main() {
ServerConf: &handlerConfig,
}
http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/callback", api.HandleCallback)

View file

@ -4,8 +4,8 @@ import (
"bytes"
"encoding/binary"
"fmt"
"github.com/bolkedebruin/rdpgw/transport"
"io"
"log"
"net"
)
@ -17,10 +17,67 @@ const (
type ClientConfig struct {
SmartCardAuth bool
PAAToken string
NTLMAuth bool
GatewayConn transport.Transport
LocalConn net.Conn
PAAToken string
NTLMAuth bool
Session *SessionInfo
LocalConn net.Conn
Server string
Port int
Name string
}
func (c *ClientConfig) ConnectAndForward() error {
c.Session.TransportOut.WritePacket(c.handshakeRequest())
for {
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
}
switch pt {
case PKT_TYPE_HANDSHAKE_RESPONSE:
caps, err := c.handshakeResponse(pkt)
if err != nil {
log.Printf("Cannot connect to %s due to %s", c.Server, err)
return err
}
log.Printf("Handshake response received. Caps: %d", caps)
c.Session.TransportOut.WritePacket(c.tunnelRequest())
case PKT_TYPE_TUNNEL_RESPONSE:
tid, caps, err := c.tunnelResponse(pkt)
if err != nil {
log.Printf("Cannot setup tunnel due to %s", err)
return err
}
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
flags, timeout, err := c.tunnelAuthResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
c.Session.TransportOut.WritePacket(c.channelRequest())
case PKT_TYPE_CHANNEL_RESPONSE:
cid, err := c.channelResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
if cid < 1 {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
}
log.Printf("Channel creation succesful. Channel id: %d", cid)
go forward(c.LocalConn, c.Session.TransportOut)
case PKT_TYPE_DATA:
receive(pkt, c.LocalConn)
default:
log.Printf("Unknown packet type received: %d size %d", pt, sz)
}
}
}
func (c *ClientConfig) handshakeRequest() []byte {
@ -83,7 +140,7 @@ func (c *ClientConfig) tunnelRequest() []byte {
binary.Write(buf, binary.LittleEndian, caps)
binary.Write(buf, binary.LittleEndian, fields)
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
if len(c.PAAToken) > 0 {
utf16Token := EncodeUTF16(c.PAAToken)
@ -119,8 +176,8 @@ func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32
return
}
func (c *ClientConfig) tunnelAuthRequest(name string) []byte {
utf16name := EncodeUTF16(name)
func (c *ClientConfig) tunnelAuthRequest() []byte {
utf16name := EncodeUTF16(c.Name)
size := uint16(len(utf16name))
buf := new(bytes.Buffer)
@ -153,14 +210,14 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui
return
}
func (c *ClientConfig) channelRequest(server string, port uint16) []byte {
utf16server := EncodeUTF16(server)
func (c *ClientConfig) channelRequest() []byte {
utf16server := EncodeUTF16(c.Server)
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(c.Port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
buf.Write(utf16server)

View file

@ -10,6 +10,62 @@ import (
"net"
)
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
DisableAll bool
EnableAll bool
}
type SessionInfo struct {
ConnId string
TransportIn transport.Transport
TransportOut transport.Transport
RemoteServer string
ClientIp string
}
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := in.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
func createPacket(pktType uint16, data []byte) (packet []byte) {
size := len(data) + 8
buf := new(bytes.Buffer)

View file

@ -2,7 +2,7 @@ package protocol
import (
"context"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
@ -45,14 +45,6 @@ type Gateway struct {
ServerConf *ServerConf
}
type SessionInfo struct {
ConnId string
TransportIn transport.Transport
TransportOut transport.Transport
RemoteServer string
ClientIp string
}
var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute)
@ -118,7 +110,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context()))
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
s.TransportOut = out
out.SendAccept(true)
@ -139,13 +131,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
s.TransportIn = in
c.Set(s.ConnId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context()))
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
in.SendAccept(false)
// read some initial data
in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", client.GetClientIp(r.Context()))
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
handler := NewServer(s, g.ServerConf)
handler.Process(r.Context())
}

View file

@ -5,7 +5,7 @@ import (
"context"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/common"
"io"
"log"
"net"
@ -17,16 +17,6 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error)
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
DisableAll bool
EnableAll bool
}
type Server struct {
Session *SessionInfo
VerifyTunnelCreate VerifyTunnelCreate
@ -70,7 +60,7 @@ const tunnelId = 10
func (s *Server) Process(ctx context.Context) error {
for {
pt, sz, pkt, err := s.ReadMessage()
pt, sz, pkt, err := readMessage(s.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
@ -78,7 +68,7 @@ func (s *Server) Process(ctx context.Context) error {
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx))
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
if s.State != SERVER_STATE_INITIAL {
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
return errors.New("wrong state")
@ -97,7 +87,7 @@ func (s *Server) Process(ctx context.Context) error {
_, cookie := s.tunnelRequest(pkt)
if s.VerifyTunnelCreate != nil {
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
return errors.New("invalid PAA cookie")
}
}
@ -181,44 +171,6 @@ func (s *Server) Process(ctx context.Context) error {
}
}
func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := s.Session.TransportIn.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
// Creates a packet the is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure

View file

@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/protocol"
"github.com/dgrijalva/jwt-go/v4"
"log"
@ -55,9 +55,9 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
return false, nil
}
if s.ClientIp != client.GetClientIp(ctx) {
if s.ClientIp != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
client.GetClientIp(ctx), s.ClientIp)
common.GetClientIp(ctx), s.ClientIp)
return false, nil
}
@ -78,7 +78,7 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
c := customClaims{
RemoteServer: server,
ClientIP: client.GetClientIp(ctx),
ClientIP: common.GetClientIp(ctx),
StandardClaims: jwt.StandardClaims{
ExpiresAt: exp,
IssuedAt: now,