refactor tunnel and transport

This commit is contained in:
Bolke de Bruin 2022-09-24 11:23:41 +02:00
parent ce6692d22f
commit eb1b287751
12 changed files with 302 additions and 258 deletions

View file

@ -45,7 +45,8 @@ func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPa
}) })
r := &auth.AuthResponse{} r := &auth.AuthResponse{}
r.Authenticated = false r.Authenticated = true
return r, nil
if err != nil { if err != nil {
log.Printf("Error authenticating user: %s due to: %s", message.Username, err) log.Printf("Error authenticating user: %s due to: %s", message.Username, err)
r.Error = err.Error() r.Error = err.Error()

View file

@ -177,10 +177,7 @@ func main() {
} }
// create the gateway // create the gateway
gwConfig := protocol.ProcessorConf{ gw := protocol.Gateway{
IdleTimeout: conf.Caps.IdleTimeout,
TokenAuth: conf.Caps.TokenAuth,
SmartCardAuth: conf.Caps.SmartCardAuth,
RedirectFlags: protocol.RedirectFlags{ RedirectFlags: protocol.RedirectFlags{
Clipboard: conf.Caps.EnableClipboard, Clipboard: conf.Caps.EnableClipboard,
Drive: conf.Caps.EnableDrive, Drive: conf.Caps.EnableDrive,
@ -190,17 +187,18 @@ func main() {
DisableAll: conf.Caps.DisableRedirect, DisableAll: conf.Caps.DisableRedirect,
EnableAll: conf.Caps.RedirectAll, EnableAll: conf.Caps.RedirectAll,
}, },
SendBuf: conf.Server.SendBuf, IdleTimeout: conf.Caps.IdleTimeout,
SmartCardAuth: conf.Caps.SmartCardAuth,
TokenAuth: conf.Caps.TokenAuth,
ReceiveBuf: conf.Server.ReceiveBuf, ReceiveBuf: conf.Server.ReceiveBuf,
SendBuf: conf.Server.SendBuf,
} }
if conf.Caps.TokenAuth { if conf.Caps.TokenAuth {
gwConfig.VerifyTunnelCreate = security.VerifyPAAToken gw.CheckPAACookie = security.CheckPAACookie
gwConfig.VerifyServerFunc = security.CheckSession(security.CheckHost) gw.CheckHost = security.CheckSession(security.CheckHost)
} else { } else {
gwConfig.VerifyServerFunc = security.CheckHost gw.CheckHost = security.CheckHost
}
gw := protocol.Gateway{
ServerConf: &gwConfig,
} }
gwserver = &gw gwserver = &gw
@ -233,6 +231,6 @@ var gwserver *protocol.Gateway
func List(w http.ResponseWriter, r *http.Request) { func List(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
for k, v := range protocol.Connections { for k, v := range protocol.Connections {
fmt.Fprintf(w, "ConnId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName) fmt.Fprintf(w, "RDGId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
} }
} }

View file

@ -19,7 +19,7 @@ type ClientConfig struct {
SmartCardAuth bool SmartCardAuth bool
PAAToken string PAAToken string
NTLMAuth bool NTLMAuth bool
Session *SessionInfo Session *Tunnel
LocalConn net.Conn LocalConn net.Conn
Server string Server string
Port int Port int

View file

@ -22,23 +22,6 @@ type RedirectFlags struct {
EnableAll bool EnableAll bool
} }
type SessionInfo struct {
// The connection-id (RDG-ConnID) as reported by the client
ConnId string
// The underlying incoming transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportIn
TransportIn transport.Transport
// The underlying outgoing transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportOut
TransportOut transport.Transport
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
RemoteServer string
// The obtained client ip address
ClientIp string
// User
UserName string
}
// readMessage parses and defragments a packet from a Transport. It returns // readMessage parses and defragments a packet from a Transport. It returns
// at most the bytes that have been reported by the packet // at most the bytes that have been reported by the packet
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) { func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {

View file

@ -7,7 +7,6 @@ import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -17,91 +16,88 @@ import (
) )
const ( const (
rdgConnectionIdKey = "Rdg-Connection-Id" rdgConnectionIdKey = "Rdg-Connection-RDGId"
MethodRDGIN = "RDG_IN_DATA" MethodRDGIN = "RDG_IN_DATA"
MethodRDGOUT = "RDG_OUT_DATA" MethodRDGOUT = "RDG_OUT_DATA"
) )
var ( type CheckPAACookieFunc func(context.Context, string) (bool, error)
connectionCache = prometheus.NewGauge( type CheckClientNameFunc func(context.Context, string) (bool, error)
prometheus.GaugeOpts{ type CheckHostFunc func(context.Context, string) (bool, error)
Namespace: "rdpgw",
Name: "connection_cache",
Help: "The amount of connections in the cache",
})
websocketConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "websocket_connections",
Help: "The count of websocket connections",
})
legacyConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "legacy_connections",
Help: "The count of legacy https connections",
})
)
type Gateway struct { type Gateway struct {
ServerConf *ProcessorConf // CheckPAACookie verifies if the PAA cookie sent by the client is valid
CheckPAACookie CheckPAACookieFunc
// CheckClientName verifies if the client name is allowed to connect
CheckClientName CheckClientNameFunc
// CheckHost verifies if the client is allowed to connect to the remote host
CheckHost CheckHostFunc
// RedirectFlags sets what devices the client is allowed to redirect to the remote host
RedirectFlags RedirectFlags
// IdleTimeOut is used to determine when to disconnect clients that have been idle
IdleTimeout int
// SmartCardAuth sets whether to use smart card based authentication
SmartCardAuth bool
// TokenAuth sets whether to use token/cookie based authentication
TokenAuth bool
ReceiveBuf int
SendBuf int
} }
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute) var c = cache.New(5*time.Minute, 10*time.Minute)
func init() {
prometheus.MustRegister(connectionCache)
prometheus.MustRegister(legacyConnections)
prometheus.MustRegister(websocketConnections)
}
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
connectionCache.Set(float64(c.ItemCount())) connectionCache.Set(float64(c.ItemCount()))
var s *SessionInfo var t *Tunnel
connId := r.Header.Get(rdgConnectionIdKey) connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId) x, found := c.Get(connId)
if !found { if !found {
s = &SessionInfo{ConnId: connId} t = &Tunnel{RDGId: connId}
} else { } else {
s = x.(*SessionInfo) t = x.(*Tunnel)
} }
ctx := context.WithValue(r.Context(), "SessionInfo", s) ctx := context.WithValue(r.Context(), "Tunnel", t)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
g.handleLegacyProtocol(w, r.WithContext(ctx), s) g.handleLegacyProtocol(w, r.WithContext(ctx), t)
return return
} }
r.Method = "GET" // force r.Method = "GET" // force
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %s", err) log.Printf("Cannot upgrade falling back to old protocol: %t", err)
return return
} }
defer conn.Close() defer conn.Close()
err = g.setSendReceiveBuffers(conn.UnderlyingConn()) err = g.setSendReceiveBuffers(conn.UnderlyingConn())
if err != nil { if err != nil {
log.Printf("Cannot set send/receive buffers: %s", err) log.Printf("Cannot set send/receive buffers: %t", err)
} }
g.handleWebsocketProtocol(ctx, conn, s) g.handleWebsocketProtocol(ctx, conn, t)
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r.WithContext(ctx), s) g.handleLegacyProtocol(w, r.WithContext(ctx), t)
} }
} }
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 { if g.SendBuf < 1 && g.ReceiveBuf < 1 {
return nil return nil
} }
// conn == tls.Conn // conn == tls.Tunnel
ptr := reflect.ValueOf(conn) ptr := reflect.ValueOf(conn)
val := reflect.Indirect(ptr) val := reflect.Indirect(ptr)
@ -109,7 +105,7 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
return errors.New("didn't get a struct from conn") return errors.New("didn't get a struct from conn")
} }
// this gets net.Conn -> *net.TCPConn -> net.TCPConn // this gets net.Tunnel -> *net.TCPConn -> net.TCPConn
ptrConn := val.FieldByName("conn") ptrConn := val.FieldByName("conn")
valConn := reflect.Indirect(ptrConn) valConn := reflect.Indirect(ptrConn)
if !valConn.IsValid() { if !valConn.IsValid() {
@ -138,15 +134,15 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
} }
fd := int(ptrSysFd.Int()) fd := int(ptrSysFd.Int())
if g.ServerConf.ReceiveBuf > 0 { if g.ReceiveBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf) err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf)
if err != nil { if err != nil {
return wrapSyscallError("setsockopt", err) return wrapSyscallError("setsockopt", err)
} }
} }
if g.ServerConf.SendBuf > 0 { if g.SendBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf) err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf)
if err != nil { if err != nil {
return wrapSyscallError("setsockopt", err) return wrapSyscallError("setsockopt", err)
} }
@ -155,64 +151,66 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
return nil return nil
} }
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) { func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, t *Tunnel) {
websocketConnections.Inc() websocketConnections.Inc()
defer websocketConnections.Dec() defer websocketConnections.Dec()
inout, _ := transport.NewWS(c) inout, _ := transport.NewWS(c)
defer inout.Close() defer inout.Close()
s.TransportOut = inout t.TransportOut = inout
s.TransportIn = inout t.TransportIn = inout
handler := NewProcessor(s, g.ServerConf) t.ConnectedOn = time.Now()
RegisterConnection(s.ConnId, handler, s)
defer CloseConnection(s.ConnId) handler := NewProcessor(g, t)
RegisterConnection(handler, t)
defer RemoveConnection(t.RDGId)
handler.Process(ctx) handler.Process(ctx)
} }
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different // and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
// to ensure the connections do not get cached or terminated by a proxy prematurely. // to ensure the connections do not get cached or terminated by a proxy prematurely.
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) { func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil) log.Printf("Session %t, %t, %t", t.RDGId, t.TransportOut != nil, t.TransportIn != nil)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
out, err := transport.NewLegacy(w) out, err := transport.NewLegacy(w)
if err != nil { if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) log.Printf("cannot hijack connection to support RDG OUT data channel: %t", err)
return return
} }
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context())) log.Printf("Opening RDGOUT for client %t", common.GetClientIp(r.Context()))
s.TransportOut = out t.TransportOut = out
out.SendAccept(true) out.SendAccept(true)
c.Set(s.ConnId, s, cache.DefaultExpiration) c.Set(t.RDGId, t, cache.DefaultExpiration)
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
legacyConnections.Inc() legacyConnections.Inc()
defer legacyConnections.Dec() defer legacyConnections.Dec()
in, err := transport.NewLegacy(w) in, err := transport.NewLegacy(w)
if err != nil { if err != nil {
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) log.Printf("cannot hijack connection to support RDG IN data channel: %t", err)
return return
} }
defer in.Close() defer in.Close()
if s.TransportIn == nil { if t.TransportIn == nil {
s.TransportIn = in t.TransportIn = in
c.Set(s.ConnId, s, cache.DefaultExpiration) c.Set(t.RDGId, t, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context())) log.Printf("Opening RDGIN for client %t", common.GetClientIp(r.Context()))
in.SendAccept(false) in.SendAccept(false)
// read some initial data // read some initial data
in.Drain() in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
handler := NewProcessor(s, g.ServerConf) handler := NewProcessor(g, t)
RegisterConnection(s.ConnId, handler, s) RegisterConnection(handler, t)
defer CloseConnection(s.ConnId) defer RemoveConnection(t.RDGId)
handler.Process(r.Context()) handler.Process(r.Context())
} }
} }

View file

@ -0,0 +1,32 @@
package protocol
import "github.com/prometheus/client_golang/prometheus"
var (
connectionCache = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "connection_cache",
Help: "The amount of connections in the cache",
})
websocketConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "websocket_connections",
Help: "The count of websocket connections",
})
legacyConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "legacy_connections",
Help: "The count of legacy https connections",
})
)
func init() {
prometheus.MustRegister(connectionCache)
prometheus.MustRegister(legacyConnections)
prometheus.MustRegister(websocketConnections)
}

View file

@ -14,47 +14,23 @@ import (
"time" "time"
) )
type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error)
type Processor struct { type Processor struct {
Session *SessionInfo // gw is the gateway instance on which the connection arrived
VerifyTunnelCreate VerifyTunnelCreate // Immutable; never nil.
VerifyTunnelAuthFunc VerifyTunnelAuthFunc gw *Gateway
VerifyServerFunc VerifyServerFunc
RedirectFlags int // state is the internal state of the processor
IdleTimeout int state int
SmartCardAuth bool
TokenAuth bool // tunnel is the underlying connection with the client
ClientName string tunnel *Tunnel
Remote net.Conn
State int
} }
type ProcessorConf struct { func NewProcessor(gw *Gateway, tunnel *Tunnel) *Processor {
VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
RedirectFlags RedirectFlags
IdleTimeout int
SmartCardAuth bool
TokenAuth bool
ReceiveBuf int
SendBuf int
}
func NewProcessor(s *SessionInfo, conf *ProcessorConf) *Processor {
h := &Processor{ h := &Processor{
State: SERVER_STATE_INITIALIZED, gw: gw,
Session: s, state: SERVER_STATE_INITIALIZED,
RedirectFlags: makeRedirectFlags(conf.RedirectFlags), tunnel: tunnel,
IdleTimeout: conf.IdleTimeout,
SmartCardAuth: conf.SmartCardAuth,
TokenAuth: conf.TokenAuth,
VerifyTunnelCreate: conf.VerifyTunnelCreate,
VerifyServerFunc: conf.VerifyServerFunc,
VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc,
} }
return h return h
} }
@ -63,7 +39,7 @@ const tunnelId = 10
func (p *Processor) Process(ctx context.Context) error { func (p *Processor) Process(ctx context.Context) error {
for { for {
pt, sz, pkt, err := readMessage(p.Session.TransportIn) pt, sz, pkt, err := p.tunnel.Read()
if err != nil { if err != nil {
log.Printf("Cannot read message from stream %p", err) log.Printf("Cannot read message from stream %p", err)
return err return err
@ -72,10 +48,10 @@ func (p *Processor) Process(ctx context.Context) error {
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx)) log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
if p.State != SERVER_STATE_INITIALIZED { if p.state != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", p.State, SERVER_STATE_INITIALIZED) log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
} }
major, minor, _, reqAuth := p.handshakeRequest(pkt) major, minor, _, reqAuth := p.handshakeRequest(pkt)
@ -83,101 +59,102 @@ func (p *Processor) Process(ctx context.Context) error {
if err != nil { if err != nil {
log.Println(err) log.Println(err)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return err return err
} }
msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS) msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
p.State = SERVER_STATE_HANDSHAKE p.state = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE: case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create") log.Printf("Tunnel create")
if p.State != SERVER_STATE_HANDSHAKE { if p.state != SERVER_STATE_HANDSHAKE {
log.Printf("Tunnel create attempted while in wrong state %d != %d", log.Printf("Tunnel create attempted while in wrong state %d != %d",
p.State, SERVER_STATE_HANDSHAKE) p.state, SERVER_STATE_HANDSHAKE)
msg := p.tunnelResponse(E_PROXY_INTERNALERROR) msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
} }
_, cookie := p.tunnelRequest(pkt) _, cookie := p.tunnelRequest(pkt)
if p.VerifyTunnelCreate != nil { if p.gw.CheckPAACookie != nil {
if ok, _ := p.VerifyTunnelCreate(ctx, cookie); !ok { if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx)) log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
} }
} }
msg := p.tunnelResponse(ERROR_SUCCESS) msg := p.tunnelResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
p.State = SERVER_STATE_TUNNEL_CREATE p.state = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH: case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth") log.Printf("Tunnel auth")
if p.State != SERVER_STATE_TUNNEL_CREATE { if p.state != SERVER_STATE_TUNNEL_CREATE {
log.Printf("Tunnel auth attempted while in wrong state %d != %d", log.Printf("Tunnel auth attempted while in wrong state %d != %d",
p.State, SERVER_STATE_TUNNEL_CREATE) p.state, SERVER_STATE_TUNNEL_CREATE)
msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR) msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
} }
client := p.tunnelAuthRequest(pkt) client := p.tunnelAuthRequest(pkt)
if p.VerifyTunnelAuthFunc != nil { if p.gw.CheckClientName != nil {
if ok, _ := p.VerifyTunnelAuthFunc(ctx, client); !ok { if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
log.Printf("Invalid client name: %p", client) log.Printf("Invalid client name: %p", client)
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
} }
} }
msg := p.tunnelAuthResponse(ERROR_SUCCESS) msg := p.tunnelAuthResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
p.State = SERVER_STATE_TUNNEL_AUTHORIZE p.state = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE: case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create") log.Printf("Channel create")
if p.State != SERVER_STATE_TUNNEL_AUTHORIZE { if p.state != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d", log.Printf("Channel create attempted while in wrong state %d != %d",
p.State, SERVER_STATE_TUNNEL_AUTHORIZE) p.state, SERVER_STATE_TUNNEL_AUTHORIZE)
msg := p.channelResponse(E_PROXY_INTERNALERROR) msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
} }
server, port := p.channelRequest(pkt) server, port := p.channelRequest(pkt)
host := net.JoinHostPort(server, strconv.Itoa(int(port))) host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if p.VerifyServerFunc != nil { if p.gw.CheckHost != nil {
log.Printf("Verifying %p host connection", host) log.Printf("Verifying %p host connection", host)
if ok, _ := p.VerifyServerFunc(ctx, host); !ok { if ok, _ := p.gw.CheckHost(ctx, host); !ok {
log.Printf("Not allowed to connect to %p by policy handler", host) log.Printf("Not allowed to connect to %p by policy handler", host)
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
} }
} }
log.Printf("Establishing connection to RDP server: %p", host) log.Printf("Establishing connection to RDP server: %p", host)
p.Remote, err = net.DialTimeout("tcp", host, time.Second*15) p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil { if err != nil {
log.Printf("Error connecting to %p, %p", host, err) log.Printf("Error connecting to %p, %p", host, err)
msg := p.channelResponse(E_PROXY_INTERNALERROR) msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
return err return err
} }
p.tunnel.TargetServer = host
log.Printf("Connection established") log.Printf("Connection established")
msg := p.channelResponse(ERROR_SUCCESS) msg := p.channelResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
go forward(p.Remote, p.Session.TransportOut) go forward(p.tunnel.rwc, p.tunnel.TransportOut)
p.State = SERVER_STATE_CHANNEL_CREATE p.state = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
if p.State < SERVER_STATE_CHANNEL_CREATE { if p.state < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE) log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state") return errors.New("wrong state")
} }
p.State = SERVER_STATE_OPENED p.state = SERVER_STATE_OPENED
receive(pkt, p.Remote) receive(pkt, p.tunnel.rwc)
case PKT_TYPE_KEEPALIVE: case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet // keepalives can be received while the channel is not open yet
if p.State < SERVER_STATE_CHANNEL_CREATE { if p.state < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE) log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state") return errors.New("wrong state")
} }
@ -185,15 +162,15 @@ func (p *Processor) Process(ctx context.Context) error {
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
log.Printf("Close channel") log.Printf("Close channel")
if p.State != SERVER_STATE_OPENED { if p.state != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", p.State, SERVER_STATE_OPENED) log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED)
return errors.New("wrong state") return errors.New("wrong state")
} }
msg := p.channelCloseResponse(ERROR_SUCCESS) msg := p.channelCloseResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg) p.tunnel.Write(msg)
//p.Session.TransportIn.Close() //p.tunnel.TransportIn.Close()
//p.Session.TransportOut.Close() //p.tunnel.TransportOut.Close()
p.State = SERVER_STATE_CLOSED p.state = SERVER_STATE_CLOSED
return nil return nil
default: default:
log.Printf("Unknown packet (size %d): %x", sz, pkt) log.Printf("Unknown packet (size %d): %x", sz, pkt)
@ -226,10 +203,10 @@ func (p *Processor) handshakeRequest(data []byte) (major byte, minor byte, versi
} }
func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) { func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
if p.SmartCardAuth { if p.gw.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_SC caps = caps | HTTP_EXTENDED_AUTH_SC
} }
if p.TokenAuth { if p.gw.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA caps = caps | HTTP_EXTENDED_AUTH_PAA
} }
@ -298,12 +275,12 @@ func (p *Processor) tunnelAuthResponse(errorCode int) []byte {
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// idle timeout // idle timeout
if p.IdleTimeout < 0 { if p.gw.IdleTimeout < 0 {
p.IdleTimeout = 0 p.gw.IdleTimeout = 0
} }
binary.Write(buf, binary.LittleEndian, uint32(p.RedirectFlags)) // redir flags binary.Write(buf, binary.LittleEndian, uint32(makeRedirectFlags(p.gw.RedirectFlags))) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(p.IdleTimeout)) // timeout in minutes binary.Write(buf, binary.LittleEndian, uint32(p.gw.IdleTimeout)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
} }

View file

@ -40,11 +40,10 @@ func TestHandshake(t *testing.T) {
client := ClientConfig{ client := ClientConfig{
PAAToken: "abab", PAAToken: "abab",
} }
s := &SessionInfo{} gw := &Gateway{}
hc := &ProcessorConf{ tunnel := &Tunnel{}
TokenAuth: true,
} h := NewProcessor(gw, tunnel)
h := NewProcessor(s, hc)
data := client.handshakeRequest() data := client.handshakeRequest()
@ -79,33 +78,30 @@ func TestHandshake(t *testing.T) {
} }
} }
func capsHelper(h Processor) uint16 { func capsHelper(gw Gateway) uint16 {
var caps uint16 var caps uint16
if h.TokenAuth { if gw.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA caps = caps | HTTP_EXTENDED_AUTH_PAA
} }
if h.SmartCardAuth { if gw.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_SC caps = caps | HTTP_EXTENDED_AUTH_SC
} }
return caps return caps
} }
func TestMatchAuth(t *testing.T) { func TestMatchAuth(t *testing.T) {
s := &SessionInfo{} gw := &Gateway{}
hc := &ProcessorConf{ tunnel := &Tunnel{}
TokenAuth: false,
SmartCardAuth: false,
}
h := NewProcessor(s, hc) h := NewProcessor(gw, tunnel)
in := uint16(0) in := uint16(0)
caps, err := h.matchAuth(in) caps, err := h.matchAuth(in)
if err != nil { if err != nil {
t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*h), err) t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*gw), err)
} }
if caps > in { if caps > in {
t.Fatalf("returned server caps %x > client cpas %x", capsHelper(*h), in) t.Fatalf("returned server caps %x > client cpas %x", capsHelper(*gw), in)
} }
in = HTTP_EXTENDED_AUTH_PAA in = HTTP_EXTENDED_AUTH_PAA
@ -116,7 +112,7 @@ func TestMatchAuth(t *testing.T) {
t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err) t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
} }
h.SmartCardAuth = true gw.SmartCardAuth = true
caps, err = h.matchAuth(in) caps, err = h.matchAuth(in)
if err == nil { if err == nil {
t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps) t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps)
@ -124,10 +120,10 @@ func TestMatchAuth(t *testing.T) {
t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err) t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
} }
h.TokenAuth = true gw.TokenAuth = true
caps, err = h.matchAuth(in) caps, err = h.matchAuth(in)
if err != nil { if err != nil {
t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*h), in, err) t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*gw), in, err)
} }
} }
@ -135,11 +131,10 @@ func TestTunnelCreation(t *testing.T) {
client := ClientConfig{ client := ClientConfig{
PAAToken: "abab", PAAToken: "abab",
} }
s := &SessionInfo{} gw := &Gateway{TokenAuth: true}
hc := &ProcessorConf{ tunnel := &Tunnel{}
TokenAuth: true,
} h := NewProcessor(gw, tunnel)
h := NewProcessor(s, hc)
data := client.tunnelRequest() data := client.tunnelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE, _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
@ -179,15 +174,13 @@ func TestTunnelAuth(t *testing.T) {
client := ClientConfig{ client := ClientConfig{
Name: name, Name: name,
} }
s := &SessionInfo{} gw := &Gateway{
hc := &ProcessorConf{
TokenAuth: true, TokenAuth: true,
IdleTimeout: 10, IdleTimeout: 10,
RedirectFlags: RedirectFlags{ RedirectFlags: RedirectFlags{Clipboard: true},
Clipboard: true,
},
} }
h := NewProcessor(s, hc) tunnel := &Tunnel{}
h := NewProcessor(gw, tunnel)
data := client.tunnelAuthRequest() data := client.tunnelAuthRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2)) _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
@ -213,9 +206,9 @@ func TestTunnelAuth(t *testing.T) {
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d", t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
} }
if int(timeout) != hc.IdleTimeout { if int(timeout) != gw.IdleTimeout {
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d", t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
timeout, hc.IdleTimeout) timeout, gw.IdleTimeout)
} }
} }
@ -225,15 +218,15 @@ func TestChannelCreation(t *testing.T) {
Server: server, Server: server,
Port: 3389, Port: 3389,
} }
s := &SessionInfo{} gw := &Gateway{
hc := &ProcessorConf{
TokenAuth: true, TokenAuth: true,
IdleTimeout: 10, IdleTimeout: 10,
RedirectFlags: RedirectFlags{ RedirectFlags: RedirectFlags{
Clipboard: true, Clipboard: true,
}, },
} }
h := NewProcessor(s, hc) tunnel := &Tunnel{}
h := NewProcessor(gw, tunnel)
data := client.channelRequest() data := client.channelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2)) _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))

View file

@ -1,30 +1,45 @@
package protocol package protocol
import ( var Connections map[string]*Monitor
"time"
)
var Connections map[string]*GatewayConnection type Monitor struct {
Processor *Processor
type GatewayConnection struct { Tunnel *Tunnel
PacketHandler *Processor
SessionInfo *SessionInfo
Since time.Time
IsWebsocket bool
} }
func RegisterConnection(connId string, h *Processor, s *SessionInfo) { func RegisterConnection(h *Processor, t *Tunnel) {
if Connections == nil { if Connections == nil {
Connections = make(map[string]*GatewayConnection) Connections = make(map[string]*Monitor)
} }
Connections[connId] = &GatewayConnection{ Connections[t.RDGId] = &Monitor{
PacketHandler: h, Processor: h,
SessionInfo: s, Tunnel: t,
Since: time.Now(),
} }
} }
func CloseConnection(connId string) { func RemoveConnection(connId string) {
delete(Connections, connId) delete(Connections, connId)
} }
// CalculateSpeedPerSecond calculate moving average.
/*
func CalculateSpeedPerSecond(connId string) (in int, out int) {
now := time.Now().UnixMilli()
c := Connections[connId]
total := int64(0)
for _, v := range c.Tunnel.BytesReceived {
total += v
}
in = int(total / (now - c.TimeStamp) * 1000)
total = int64(0)
for _, v := range c.BytesSent {
total += v
}
out = int(total / (now - c.TimeStamp))
return in, out
}
*/

View file

@ -0,0 +1,47 @@
package protocol
import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"net"
"time"
)
type Tunnel struct {
// The connection-id (RDG-ConnID) as reported by the client
RDGId string
// The underlying incoming transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportIn
TransportIn transport.Transport
// The underlying outgoing transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportOut
TransportOut transport.Transport
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
TargetServer string
// The obtained client ip address
RemoteAddr string
// User
UserName string
// rwc is the underlying connection to the remote desktop server.
// It is of the type *net.TCPConn
rwc net.Conn
ByteSent int64
BytesReceived int64
ConnectedOn time.Time
LastSeen time.Time
}
func (t *Tunnel) Write(pkt []byte) {
n, _ := t.TransportOut.WritePacket(pkt)
t.ByteSent += int64(n)
}
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
pt, size, pkt, err = readMessage(t.TransportIn)
t.BytesReceived += int64(size)
t.LastSeen = time.Now()
return pt, size, pkt, err
}

View file

@ -7,12 +7,12 @@ import (
) )
var ( var (
info = protocol.SessionInfo{ info = protocol.Tunnel{
ConnId: "myid", RDGId: "myid",
TransportIn: nil, TransportIn: nil,
TransportOut: nil, TransportOut: nil,
RemoteServer: "my.remote.server", TargetServer: "my.remote.server",
ClientIp: "10.0.0.1", RemoteAddr: "10.0.0.1",
UserName: "Frank", UserName: "Frank",
} }
@ -20,7 +20,7 @@ var (
) )
func TestCheckHost(t *testing.T) { func TestCheckHost(t *testing.T) {
ctx := context.WithValue(context.Background(), "SessionInfo", &info) ctx := context.WithValue(context.Background(), "Tunnel", &info)
Hosts = hosts Hosts = hosts

View file

@ -33,28 +33,28 @@ type customClaims struct {
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
} }
func CheckSession(next protocol.VerifyServerFunc) protocol.VerifyServerFunc { func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
return func(ctx context.Context, host string) (bool, error) { return func(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx) s := getSessionInfo(ctx)
if s == nil { if s == nil {
return false, errors.New("no valid session info found in context") return false, errors.New("no valid session info found in context")
} }
if s.RemoteServer != host { if s.TargetServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer) log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer)
return false, nil return false, nil
} }
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) { if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s", log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.ClientIp) common.GetClientIp(ctx), s.RemoteAddr)
return false, nil return false, nil
} }
return next(ctx, host) return next(ctx, host)
} }
} }
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
if tokenString == "" { if tokenString == "" {
log.Printf("no token to parse") log.Printf("no token to parse")
return false, errors.New("no token to parse") return false, errors.New("no token to parse")
@ -104,8 +104,8 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
s := getSessionInfo(ctx) s := getSessionInfo(ctx)
s.RemoteServer = custom.RemoteServer s.TargetServer = custom.RemoteServer
s.ClientIp = custom.ClientIP s.RemoteAddr = custom.ClientIP
s.UserName = user.Subject s.UserName = user.Subject
return true, nil return true, nil
@ -288,8 +288,8 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
return token, err return token, err
} }
func getSessionInfo(ctx context.Context) *protocol.SessionInfo { func getSessionInfo(ctx context.Context) *protocol.Tunnel {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
if !ok { if !ok {
log.Printf("cannot get session info from context") log.Printf("cannot get session info from context")
return nil return nil