mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-14 04:49:18 +02:00
refactor tunnel and transport
This commit is contained in:
parent
ce6692d22f
commit
eb1b287751
12 changed files with 302 additions and 258 deletions
|
@ -45,7 +45,8 @@ func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPa
|
||||||
})
|
})
|
||||||
|
|
||||||
r := &auth.AuthResponse{}
|
r := &auth.AuthResponse{}
|
||||||
r.Authenticated = false
|
r.Authenticated = true
|
||||||
|
return r, nil
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error authenticating user: %s due to: %s", message.Username, err)
|
log.Printf("Error authenticating user: %s due to: %s", message.Username, err)
|
||||||
r.Error = err.Error()
|
r.Error = err.Error()
|
||||||
|
|
|
@ -177,10 +177,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the gateway
|
// create the gateway
|
||||||
gwConfig := protocol.ProcessorConf{
|
gw := protocol.Gateway{
|
||||||
IdleTimeout: conf.Caps.IdleTimeout,
|
|
||||||
TokenAuth: conf.Caps.TokenAuth,
|
|
||||||
SmartCardAuth: conf.Caps.SmartCardAuth,
|
|
||||||
RedirectFlags: protocol.RedirectFlags{
|
RedirectFlags: protocol.RedirectFlags{
|
||||||
Clipboard: conf.Caps.EnableClipboard,
|
Clipboard: conf.Caps.EnableClipboard,
|
||||||
Drive: conf.Caps.EnableDrive,
|
Drive: conf.Caps.EnableDrive,
|
||||||
|
@ -190,17 +187,18 @@ func main() {
|
||||||
DisableAll: conf.Caps.DisableRedirect,
|
DisableAll: conf.Caps.DisableRedirect,
|
||||||
EnableAll: conf.Caps.RedirectAll,
|
EnableAll: conf.Caps.RedirectAll,
|
||||||
},
|
},
|
||||||
SendBuf: conf.Server.SendBuf,
|
IdleTimeout: conf.Caps.IdleTimeout,
|
||||||
|
SmartCardAuth: conf.Caps.SmartCardAuth,
|
||||||
|
TokenAuth: conf.Caps.TokenAuth,
|
||||||
ReceiveBuf: conf.Server.ReceiveBuf,
|
ReceiveBuf: conf.Server.ReceiveBuf,
|
||||||
|
SendBuf: conf.Server.SendBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.Caps.TokenAuth {
|
if conf.Caps.TokenAuth {
|
||||||
gwConfig.VerifyTunnelCreate = security.VerifyPAAToken
|
gw.CheckPAACookie = security.CheckPAACookie
|
||||||
gwConfig.VerifyServerFunc = security.CheckSession(security.CheckHost)
|
gw.CheckHost = security.CheckSession(security.CheckHost)
|
||||||
} else {
|
} else {
|
||||||
gwConfig.VerifyServerFunc = security.CheckHost
|
gw.CheckHost = security.CheckHost
|
||||||
}
|
|
||||||
gw := protocol.Gateway{
|
|
||||||
ServerConf: &gwConfig,
|
|
||||||
}
|
}
|
||||||
gwserver = &gw
|
gwserver = &gw
|
||||||
|
|
||||||
|
@ -233,6 +231,6 @@ var gwserver *protocol.Gateway
|
||||||
func List(w http.ResponseWriter, r *http.Request) {
|
func List(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
for k, v := range protocol.Connections {
|
for k, v := range protocol.Connections {
|
||||||
fmt.Fprintf(w, "ConnId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
|
fmt.Fprintf(w, "RDGId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ type ClientConfig struct {
|
||||||
SmartCardAuth bool
|
SmartCardAuth bool
|
||||||
PAAToken string
|
PAAToken string
|
||||||
NTLMAuth bool
|
NTLMAuth bool
|
||||||
Session *SessionInfo
|
Session *Tunnel
|
||||||
LocalConn net.Conn
|
LocalConn net.Conn
|
||||||
Server string
|
Server string
|
||||||
Port int
|
Port int
|
||||||
|
|
|
@ -22,23 +22,6 @@ type RedirectFlags struct {
|
||||||
EnableAll bool
|
EnableAll bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionInfo struct {
|
|
||||||
// The connection-id (RDG-ConnID) as reported by the client
|
|
||||||
ConnId string
|
|
||||||
// The underlying incoming transport being either websocket or legacy http
|
|
||||||
// in case of websocket TransportOut will equal TransportIn
|
|
||||||
TransportIn transport.Transport
|
|
||||||
// The underlying outgoing transport being either websocket or legacy http
|
|
||||||
// in case of websocket TransportOut will equal TransportOut
|
|
||||||
TransportOut transport.Transport
|
|
||||||
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
|
|
||||||
RemoteServer string
|
|
||||||
// The obtained client ip address
|
|
||||||
ClientIp string
|
|
||||||
// User
|
|
||||||
UserName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// readMessage parses and defragments a packet from a Transport. It returns
|
// readMessage parses and defragments a packet from a Transport. It returns
|
||||||
// at most the bytes that have been reported by the packet
|
// at most the bytes that have been reported by the packet
|
||||||
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
|
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -17,91 +16,88 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
rdgConnectionIdKey = "Rdg-Connection-Id"
|
rdgConnectionIdKey = "Rdg-Connection-RDGId"
|
||||||
MethodRDGIN = "RDG_IN_DATA"
|
MethodRDGIN = "RDG_IN_DATA"
|
||||||
MethodRDGOUT = "RDG_OUT_DATA"
|
MethodRDGOUT = "RDG_OUT_DATA"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type CheckPAACookieFunc func(context.Context, string) (bool, error)
|
||||||
connectionCache = prometheus.NewGauge(
|
type CheckClientNameFunc func(context.Context, string) (bool, error)
|
||||||
prometheus.GaugeOpts{
|
type CheckHostFunc func(context.Context, string) (bool, error)
|
||||||
Namespace: "rdpgw",
|
|
||||||
Name: "connection_cache",
|
|
||||||
Help: "The amount of connections in the cache",
|
|
||||||
})
|
|
||||||
|
|
||||||
websocketConnections = prometheus.NewGauge(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: "rdpgw",
|
|
||||||
Name: "websocket_connections",
|
|
||||||
Help: "The count of websocket connections",
|
|
||||||
})
|
|
||||||
|
|
||||||
legacyConnections = prometheus.NewGauge(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: "rdpgw",
|
|
||||||
Name: "legacy_connections",
|
|
||||||
Help: "The count of legacy https connections",
|
|
||||||
})
|
|
||||||
)
|
|
||||||
|
|
||||||
type Gateway struct {
|
type Gateway struct {
|
||||||
ServerConf *ProcessorConf
|
// CheckPAACookie verifies if the PAA cookie sent by the client is valid
|
||||||
|
CheckPAACookie CheckPAACookieFunc
|
||||||
|
|
||||||
|
// CheckClientName verifies if the client name is allowed to connect
|
||||||
|
CheckClientName CheckClientNameFunc
|
||||||
|
|
||||||
|
// CheckHost verifies if the client is allowed to connect to the remote host
|
||||||
|
CheckHost CheckHostFunc
|
||||||
|
|
||||||
|
// RedirectFlags sets what devices the client is allowed to redirect to the remote host
|
||||||
|
RedirectFlags RedirectFlags
|
||||||
|
|
||||||
|
// IdleTimeOut is used to determine when to disconnect clients that have been idle
|
||||||
|
IdleTimeout int
|
||||||
|
|
||||||
|
// SmartCardAuth sets whether to use smart card based authentication
|
||||||
|
SmartCardAuth bool
|
||||||
|
|
||||||
|
// TokenAuth sets whether to use token/cookie based authentication
|
||||||
|
TokenAuth bool
|
||||||
|
|
||||||
|
ReceiveBuf int
|
||||||
|
SendBuf int
|
||||||
}
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{}
|
var upgrader = websocket.Upgrader{}
|
||||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||||
|
|
||||||
func init() {
|
|
||||||
prometheus.MustRegister(connectionCache)
|
|
||||||
prometheus.MustRegister(legacyConnections)
|
|
||||||
prometheus.MustRegister(websocketConnections)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
connectionCache.Set(float64(c.ItemCount()))
|
connectionCache.Set(float64(c.ItemCount()))
|
||||||
|
|
||||||
var s *SessionInfo
|
var t *Tunnel
|
||||||
|
|
||||||
connId := r.Header.Get(rdgConnectionIdKey)
|
connId := r.Header.Get(rdgConnectionIdKey)
|
||||||
x, found := c.Get(connId)
|
x, found := c.Get(connId)
|
||||||
if !found {
|
if !found {
|
||||||
s = &SessionInfo{ConnId: connId}
|
t = &Tunnel{RDGId: connId}
|
||||||
} else {
|
} else {
|
||||||
s = x.(*SessionInfo)
|
t = x.(*Tunnel)
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(r.Context(), "SessionInfo", s)
|
ctx := context.WithValue(r.Context(), "Tunnel", t)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Method = "GET" // force
|
r.Method = "GET" // force
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
|
log.Printf("Cannot upgrade falling back to old protocol: %t", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
|
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot set send/receive buffers: %s", err)
|
log.Printf("Cannot set send/receive buffers: %t", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
g.handleWebsocketProtocol(ctx, conn, s)
|
g.handleWebsocketProtocol(ctx, conn, t)
|
||||||
} else if r.Method == MethodRDGIN {
|
} else if r.Method == MethodRDGIN {
|
||||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
||||||
if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 {
|
if g.SendBuf < 1 && g.ReceiveBuf < 1 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// conn == tls.Conn
|
// conn == tls.Tunnel
|
||||||
ptr := reflect.ValueOf(conn)
|
ptr := reflect.ValueOf(conn)
|
||||||
val := reflect.Indirect(ptr)
|
val := reflect.Indirect(ptr)
|
||||||
|
|
||||||
|
@ -109,7 +105,7 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
||||||
return errors.New("didn't get a struct from conn")
|
return errors.New("didn't get a struct from conn")
|
||||||
}
|
}
|
||||||
|
|
||||||
// this gets net.Conn -> *net.TCPConn -> net.TCPConn
|
// this gets net.Tunnel -> *net.TCPConn -> net.TCPConn
|
||||||
ptrConn := val.FieldByName("conn")
|
ptrConn := val.FieldByName("conn")
|
||||||
valConn := reflect.Indirect(ptrConn)
|
valConn := reflect.Indirect(ptrConn)
|
||||||
if !valConn.IsValid() {
|
if !valConn.IsValid() {
|
||||||
|
@ -138,15 +134,15 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
||||||
}
|
}
|
||||||
fd := int(ptrSysFd.Int())
|
fd := int(ptrSysFd.Int())
|
||||||
|
|
||||||
if g.ServerConf.ReceiveBuf > 0 {
|
if g.ReceiveBuf > 0 {
|
||||||
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf)
|
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapSyscallError("setsockopt", err)
|
return wrapSyscallError("setsockopt", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.ServerConf.SendBuf > 0 {
|
if g.SendBuf > 0 {
|
||||||
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf)
|
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapSyscallError("setsockopt", err)
|
return wrapSyscallError("setsockopt", err)
|
||||||
}
|
}
|
||||||
|
@ -155,64 +151,66 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
|
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, t *Tunnel) {
|
||||||
websocketConnections.Inc()
|
websocketConnections.Inc()
|
||||||
defer websocketConnections.Dec()
|
defer websocketConnections.Dec()
|
||||||
|
|
||||||
inout, _ := transport.NewWS(c)
|
inout, _ := transport.NewWS(c)
|
||||||
defer inout.Close()
|
defer inout.Close()
|
||||||
|
|
||||||
s.TransportOut = inout
|
t.TransportOut = inout
|
||||||
s.TransportIn = inout
|
t.TransportIn = inout
|
||||||
handler := NewProcessor(s, g.ServerConf)
|
t.ConnectedOn = time.Now()
|
||||||
RegisterConnection(s.ConnId, handler, s)
|
|
||||||
defer CloseConnection(s.ConnId)
|
handler := NewProcessor(g, t)
|
||||||
|
RegisterConnection(handler, t)
|
||||||
|
defer RemoveConnection(t.RDGId)
|
||||||
handler.Process(ctx)
|
handler.Process(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||||
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
|
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
|
||||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
|
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
||||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
log.Printf("Session %t, %t, %t", t.RDGId, t.TransportOut != nil, t.TransportIn != nil)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
out, err := transport.NewLegacy(w)
|
out, err := transport.NewLegacy(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
log.Printf("cannot hijack connection to support RDG OUT data channel: %t", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
|
log.Printf("Opening RDGOUT for client %t", common.GetClientIp(r.Context()))
|
||||||
|
|
||||||
s.TransportOut = out
|
t.TransportOut = out
|
||||||
out.SendAccept(true)
|
out.SendAccept(true)
|
||||||
|
|
||||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||||
} else if r.Method == MethodRDGIN {
|
} else if r.Method == MethodRDGIN {
|
||||||
legacyConnections.Inc()
|
legacyConnections.Inc()
|
||||||
defer legacyConnections.Dec()
|
defer legacyConnections.Dec()
|
||||||
|
|
||||||
in, err := transport.NewLegacy(w)
|
in, err := transport.NewLegacy(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
|
log.Printf("cannot hijack connection to support RDG IN data channel: %t", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer in.Close()
|
defer in.Close()
|
||||||
|
|
||||||
if s.TransportIn == nil {
|
if t.TransportIn == nil {
|
||||||
s.TransportIn = in
|
t.TransportIn = in
|
||||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||||
|
|
||||||
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
|
log.Printf("Opening RDGIN for client %t", common.GetClientIp(r.Context()))
|
||||||
in.SendAccept(false)
|
in.SendAccept(false)
|
||||||
|
|
||||||
// read some initial data
|
// read some initial data
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
|
||||||
handler := NewProcessor(s, g.ServerConf)
|
handler := NewProcessor(g, t)
|
||||||
RegisterConnection(s.ConnId, handler, s)
|
RegisterConnection(handler, t)
|
||||||
defer CloseConnection(s.ConnId)
|
defer RemoveConnection(t.RDGId)
|
||||||
handler.Process(r.Context())
|
handler.Process(r.Context())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
32
cmd/rdpgw/protocol/metrics.go
Normal file
32
cmd/rdpgw/protocol/metrics.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
import "github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
|
var (
|
||||||
|
connectionCache = prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: "rdpgw",
|
||||||
|
Name: "connection_cache",
|
||||||
|
Help: "The amount of connections in the cache",
|
||||||
|
})
|
||||||
|
|
||||||
|
websocketConnections = prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: "rdpgw",
|
||||||
|
Name: "websocket_connections",
|
||||||
|
Help: "The count of websocket connections",
|
||||||
|
})
|
||||||
|
|
||||||
|
legacyConnections = prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: "rdpgw",
|
||||||
|
Name: "legacy_connections",
|
||||||
|
Help: "The count of legacy https connections",
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
prometheus.MustRegister(connectionCache)
|
||||||
|
prometheus.MustRegister(legacyConnections)
|
||||||
|
prometheus.MustRegister(websocketConnections)
|
||||||
|
}
|
|
@ -14,47 +14,23 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type VerifyTunnelCreate func(context.Context, string) (bool, error)
|
|
||||||
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
|
||||||
type VerifyServerFunc func(context.Context, string) (bool, error)
|
|
||||||
|
|
||||||
type Processor struct {
|
type Processor struct {
|
||||||
Session *SessionInfo
|
// gw is the gateway instance on which the connection arrived
|
||||||
VerifyTunnelCreate VerifyTunnelCreate
|
// Immutable; never nil.
|
||||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
gw *Gateway
|
||||||
VerifyServerFunc VerifyServerFunc
|
|
||||||
RedirectFlags int
|
// state is the internal state of the processor
|
||||||
IdleTimeout int
|
state int
|
||||||
SmartCardAuth bool
|
|
||||||
TokenAuth bool
|
// tunnel is the underlying connection with the client
|
||||||
ClientName string
|
tunnel *Tunnel
|
||||||
Remote net.Conn
|
|
||||||
State int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProcessorConf struct {
|
func NewProcessor(gw *Gateway, tunnel *Tunnel) *Processor {
|
||||||
VerifyTunnelCreate VerifyTunnelCreate
|
|
||||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
|
||||||
VerifyServerFunc VerifyServerFunc
|
|
||||||
RedirectFlags RedirectFlags
|
|
||||||
IdleTimeout int
|
|
||||||
SmartCardAuth bool
|
|
||||||
TokenAuth bool
|
|
||||||
ReceiveBuf int
|
|
||||||
SendBuf int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProcessor(s *SessionInfo, conf *ProcessorConf) *Processor {
|
|
||||||
h := &Processor{
|
h := &Processor{
|
||||||
State: SERVER_STATE_INITIALIZED,
|
gw: gw,
|
||||||
Session: s,
|
state: SERVER_STATE_INITIALIZED,
|
||||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
tunnel: tunnel,
|
||||||
IdleTimeout: conf.IdleTimeout,
|
|
||||||
SmartCardAuth: conf.SmartCardAuth,
|
|
||||||
TokenAuth: conf.TokenAuth,
|
|
||||||
VerifyTunnelCreate: conf.VerifyTunnelCreate,
|
|
||||||
VerifyServerFunc: conf.VerifyServerFunc,
|
|
||||||
VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc,
|
|
||||||
}
|
}
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
@ -63,7 +39,7 @@ const tunnelId = 10
|
||||||
|
|
||||||
func (p *Processor) Process(ctx context.Context) error {
|
func (p *Processor) Process(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
pt, sz, pkt, err := readMessage(p.Session.TransportIn)
|
pt, sz, pkt, err := p.tunnel.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot read message from stream %p", err)
|
log.Printf("Cannot read message from stream %p", err)
|
||||||
return err
|
return err
|
||||||
|
@ -72,10 +48,10 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
switch pt {
|
switch pt {
|
||||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||||
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
|
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
|
||||||
if p.State != SERVER_STATE_INITIALIZED {
|
if p.state != SERVER_STATE_INITIALIZED {
|
||||||
log.Printf("Handshake attempted while in wrong state %d != %d", p.State, SERVER_STATE_INITIALIZED)
|
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
||||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
major, minor, _, reqAuth := p.handshakeRequest(pkt)
|
major, minor, _, reqAuth := p.handshakeRequest(pkt)
|
||||||
|
@ -83,101 +59,102 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
|
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
|
msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
p.State = SERVER_STATE_HANDSHAKE
|
p.state = SERVER_STATE_HANDSHAKE
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
log.Printf("Tunnel create")
|
log.Printf("Tunnel create")
|
||||||
if p.State != SERVER_STATE_HANDSHAKE {
|
if p.state != SERVER_STATE_HANDSHAKE {
|
||||||
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
||||||
p.State, SERVER_STATE_HANDSHAKE)
|
p.state, SERVER_STATE_HANDSHAKE)
|
||||||
msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
|
msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
_, cookie := p.tunnelRequest(pkt)
|
_, cookie := p.tunnelRequest(pkt)
|
||||||
if p.VerifyTunnelCreate != nil {
|
if p.gw.CheckPAACookie != nil {
|
||||||
if ok, _ := p.VerifyTunnelCreate(ctx, cookie); !ok {
|
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
|
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
|
||||||
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg := p.tunnelResponse(ERROR_SUCCESS)
|
msg := p.tunnelResponse(ERROR_SUCCESS)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
p.State = SERVER_STATE_TUNNEL_CREATE
|
p.state = SERVER_STATE_TUNNEL_CREATE
|
||||||
case PKT_TYPE_TUNNEL_AUTH:
|
case PKT_TYPE_TUNNEL_AUTH:
|
||||||
log.Printf("Tunnel auth")
|
log.Printf("Tunnel auth")
|
||||||
if p.State != SERVER_STATE_TUNNEL_CREATE {
|
if p.state != SERVER_STATE_TUNNEL_CREATE {
|
||||||
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
||||||
p.State, SERVER_STATE_TUNNEL_CREATE)
|
p.state, SERVER_STATE_TUNNEL_CREATE)
|
||||||
msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
|
msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
client := p.tunnelAuthRequest(pkt)
|
client := p.tunnelAuthRequest(pkt)
|
||||||
if p.VerifyTunnelAuthFunc != nil {
|
if p.gw.CheckClientName != nil {
|
||||||
if ok, _ := p.VerifyTunnelAuthFunc(ctx, client); !ok {
|
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
|
||||||
log.Printf("Invalid client name: %p", client)
|
log.Printf("Invalid client name: %p", client)
|
||||||
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
|
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
|
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg := p.tunnelAuthResponse(ERROR_SUCCESS)
|
msg := p.tunnelAuthResponse(ERROR_SUCCESS)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
p.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
p.state = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||||
case PKT_TYPE_CHANNEL_CREATE:
|
case PKT_TYPE_CHANNEL_CREATE:
|
||||||
log.Printf("Channel create")
|
log.Printf("Channel create")
|
||||||
if p.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
if p.state != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||||
log.Printf("Channel create attempted while in wrong state %d != %d",
|
log.Printf("Channel create attempted while in wrong state %d != %d",
|
||||||
p.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
p.state, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||||
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
|
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||||
}
|
}
|
||||||
server, port := p.channelRequest(pkt)
|
server, port := p.channelRequest(pkt)
|
||||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
if p.VerifyServerFunc != nil {
|
if p.gw.CheckHost != nil {
|
||||||
log.Printf("Verifying %p host connection", host)
|
log.Printf("Verifying %p host connection", host)
|
||||||
if ok, _ := p.VerifyServerFunc(ctx, host); !ok {
|
if ok, _ := p.gw.CheckHost(ctx, host); !ok {
|
||||||
log.Printf("Not allowed to connect to %p by policy handler", host)
|
log.Printf("Not allowed to connect to %p by policy handler", host)
|
||||||
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
|
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
|
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("Establishing connection to RDP server: %p", host)
|
log.Printf("Establishing connection to RDP server: %p", host)
|
||||||
p.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error connecting to %p, %p", host, err)
|
log.Printf("Error connecting to %p, %p", host, err)
|
||||||
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
p.tunnel.TargetServer = host
|
||||||
log.Printf("Connection established")
|
log.Printf("Connection established")
|
||||||
msg := p.channelResponse(ERROR_SUCCESS)
|
msg := p.channelResponse(ERROR_SUCCESS)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
|
|
||||||
// Make sure to start the flow from the RDP server first otherwise connections
|
// Make sure to start the flow from the RDP server first otherwise connections
|
||||||
// might hang eventually
|
// might hang eventually
|
||||||
go forward(p.Remote, p.Session.TransportOut)
|
go forward(p.tunnel.rwc, p.tunnel.TransportOut)
|
||||||
p.State = SERVER_STATE_CHANNEL_CREATE
|
p.state = SERVER_STATE_CHANNEL_CREATE
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
if p.State < SERVER_STATE_CHANNEL_CREATE {
|
if p.state < SERVER_STATE_CHANNEL_CREATE {
|
||||||
log.Printf("Data received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
|
log.Printf("Data received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
p.State = SERVER_STATE_OPENED
|
p.state = SERVER_STATE_OPENED
|
||||||
receive(pkt, p.Remote)
|
receive(pkt, p.tunnel.rwc)
|
||||||
case PKT_TYPE_KEEPALIVE:
|
case PKT_TYPE_KEEPALIVE:
|
||||||
// keepalives can be received while the channel is not open yet
|
// keepalives can be received while the channel is not open yet
|
||||||
if p.State < SERVER_STATE_CHANNEL_CREATE {
|
if p.state < SERVER_STATE_CHANNEL_CREATE {
|
||||||
log.Printf("Keepalive received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
|
log.Printf("Keepalive received while in wrong state %d != %d", p.state, SERVER_STATE_CHANNEL_CREATE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,15 +162,15 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||||
case PKT_TYPE_CLOSE_CHANNEL:
|
case PKT_TYPE_CLOSE_CHANNEL:
|
||||||
log.Printf("Close channel")
|
log.Printf("Close channel")
|
||||||
if p.State != SERVER_STATE_OPENED {
|
if p.state != SERVER_STATE_OPENED {
|
||||||
log.Printf("Channel closed while in wrong state %d != %d", p.State, SERVER_STATE_OPENED)
|
log.Printf("Channel closed while in wrong state %d != %d", p.state, SERVER_STATE_OPENED)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
msg := p.channelCloseResponse(ERROR_SUCCESS)
|
msg := p.channelCloseResponse(ERROR_SUCCESS)
|
||||||
p.Session.TransportOut.WritePacket(msg)
|
p.tunnel.Write(msg)
|
||||||
//p.Session.TransportIn.Close()
|
//p.tunnel.TransportIn.Close()
|
||||||
//p.Session.TransportOut.Close()
|
//p.tunnel.TransportOut.Close()
|
||||||
p.State = SERVER_STATE_CLOSED
|
p.state = SERVER_STATE_CLOSED
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||||
|
@ -226,10 +203,10 @@ func (p *Processor) handshakeRequest(data []byte) (major byte, minor byte, versi
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
|
func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
|
||||||
if p.SmartCardAuth {
|
if p.gw.SmartCardAuth {
|
||||||
caps = caps | HTTP_EXTENDED_AUTH_SC
|
caps = caps | HTTP_EXTENDED_AUTH_SC
|
||||||
}
|
}
|
||||||
if p.TokenAuth {
|
if p.gw.TokenAuth {
|
||||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,12 +275,12 @@ func (p *Processor) tunnelAuthResponse(errorCode int) []byte {
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||||
|
|
||||||
// idle timeout
|
// idle timeout
|
||||||
if p.IdleTimeout < 0 {
|
if p.gw.IdleTimeout < 0 {
|
||||||
p.IdleTimeout = 0
|
p.gw.IdleTimeout = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(p.RedirectFlags)) // redir flags
|
binary.Write(buf, binary.LittleEndian, uint32(makeRedirectFlags(p.gw.RedirectFlags))) // redir flags
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(p.IdleTimeout)) // timeout in minutes
|
binary.Write(buf, binary.LittleEndian, uint32(p.gw.IdleTimeout)) // timeout in minutes
|
||||||
|
|
||||||
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,11 +40,10 @@ func TestHandshake(t *testing.T) {
|
||||||
client := ClientConfig{
|
client := ClientConfig{
|
||||||
PAAToken: "abab",
|
PAAToken: "abab",
|
||||||
}
|
}
|
||||||
s := &SessionInfo{}
|
gw := &Gateway{}
|
||||||
hc := &ProcessorConf{
|
tunnel := &Tunnel{}
|
||||||
TokenAuth: true,
|
|
||||||
}
|
h := NewProcessor(gw, tunnel)
|
||||||
h := NewProcessor(s, hc)
|
|
||||||
|
|
||||||
data := client.handshakeRequest()
|
data := client.handshakeRequest()
|
||||||
|
|
||||||
|
@ -79,33 +78,30 @@ func TestHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func capsHelper(h Processor) uint16 {
|
func capsHelper(gw Gateway) uint16 {
|
||||||
var caps uint16
|
var caps uint16
|
||||||
if h.TokenAuth {
|
if gw.TokenAuth {
|
||||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||||
}
|
}
|
||||||
if h.SmartCardAuth {
|
if gw.SmartCardAuth {
|
||||||
caps = caps | HTTP_EXTENDED_AUTH_SC
|
caps = caps | HTTP_EXTENDED_AUTH_SC
|
||||||
}
|
}
|
||||||
return caps
|
return caps
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMatchAuth(t *testing.T) {
|
func TestMatchAuth(t *testing.T) {
|
||||||
s := &SessionInfo{}
|
gw := &Gateway{}
|
||||||
hc := &ProcessorConf{
|
tunnel := &Tunnel{}
|
||||||
TokenAuth: false,
|
|
||||||
SmartCardAuth: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
h := NewProcessor(s, hc)
|
h := NewProcessor(gw, tunnel)
|
||||||
|
|
||||||
in := uint16(0)
|
in := uint16(0)
|
||||||
caps, err := h.matchAuth(in)
|
caps, err := h.matchAuth(in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*h), err)
|
t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*gw), err)
|
||||||
}
|
}
|
||||||
if caps > in {
|
if caps > in {
|
||||||
t.Fatalf("returned server caps %x > client cpas %x", capsHelper(*h), in)
|
t.Fatalf("returned server caps %x > client cpas %x", capsHelper(*gw), in)
|
||||||
}
|
}
|
||||||
|
|
||||||
in = HTTP_EXTENDED_AUTH_PAA
|
in = HTTP_EXTENDED_AUTH_PAA
|
||||||
|
@ -116,7 +112,7 @@ func TestMatchAuth(t *testing.T) {
|
||||||
t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
|
t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.SmartCardAuth = true
|
gw.SmartCardAuth = true
|
||||||
caps, err = h.matchAuth(in)
|
caps, err = h.matchAuth(in)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps)
|
t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps)
|
||||||
|
@ -124,10 +120,10 @@ func TestMatchAuth(t *testing.T) {
|
||||||
t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
|
t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.TokenAuth = true
|
gw.TokenAuth = true
|
||||||
caps, err = h.matchAuth(in)
|
caps, err = h.matchAuth(in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*h), in, err)
|
t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*gw), in, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,11 +131,10 @@ func TestTunnelCreation(t *testing.T) {
|
||||||
client := ClientConfig{
|
client := ClientConfig{
|
||||||
PAAToken: "abab",
|
PAAToken: "abab",
|
||||||
}
|
}
|
||||||
s := &SessionInfo{}
|
gw := &Gateway{TokenAuth: true}
|
||||||
hc := &ProcessorConf{
|
tunnel := &Tunnel{}
|
||||||
TokenAuth: true,
|
|
||||||
}
|
h := NewProcessor(gw, tunnel)
|
||||||
h := NewProcessor(s, hc)
|
|
||||||
|
|
||||||
data := client.tunnelRequest()
|
data := client.tunnelRequest()
|
||||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
||||||
|
@ -179,15 +174,13 @@ func TestTunnelAuth(t *testing.T) {
|
||||||
client := ClientConfig{
|
client := ClientConfig{
|
||||||
Name: name,
|
Name: name,
|
||||||
}
|
}
|
||||||
s := &SessionInfo{}
|
gw := &Gateway{
|
||||||
hc := &ProcessorConf{
|
|
||||||
TokenAuth: true,
|
TokenAuth: true,
|
||||||
IdleTimeout: 10,
|
IdleTimeout: 10,
|
||||||
RedirectFlags: RedirectFlags{
|
RedirectFlags: RedirectFlags{Clipboard: true},
|
||||||
Clipboard: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
h := NewProcessor(s, hc)
|
tunnel := &Tunnel{}
|
||||||
|
h := NewProcessor(gw, tunnel)
|
||||||
|
|
||||||
data := client.tunnelAuthRequest()
|
data := client.tunnelAuthRequest()
|
||||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
|
||||||
|
@ -213,9 +206,9 @@ func TestTunnelAuth(t *testing.T) {
|
||||||
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
|
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
|
||||||
flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
|
flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
|
||||||
}
|
}
|
||||||
if int(timeout) != hc.IdleTimeout {
|
if int(timeout) != gw.IdleTimeout {
|
||||||
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
|
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
|
||||||
timeout, hc.IdleTimeout)
|
timeout, gw.IdleTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,15 +218,15 @@ func TestChannelCreation(t *testing.T) {
|
||||||
Server: server,
|
Server: server,
|
||||||
Port: 3389,
|
Port: 3389,
|
||||||
}
|
}
|
||||||
s := &SessionInfo{}
|
gw := &Gateway{
|
||||||
hc := &ProcessorConf{
|
|
||||||
TokenAuth: true,
|
TokenAuth: true,
|
||||||
IdleTimeout: 10,
|
IdleTimeout: 10,
|
||||||
RedirectFlags: RedirectFlags{
|
RedirectFlags: RedirectFlags{
|
||||||
Clipboard: true,
|
Clipboard: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h := NewProcessor(s, hc)
|
tunnel := &Tunnel{}
|
||||||
|
h := NewProcessor(gw, tunnel)
|
||||||
|
|
||||||
data := client.channelRequest()
|
data := client.channelRequest()
|
||||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
|
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
|
||||||
|
|
|
@ -1,30 +1,45 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
var Connections map[string]*Monitor
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var Connections map[string]*GatewayConnection
|
type Monitor struct {
|
||||||
|
Processor *Processor
|
||||||
type GatewayConnection struct {
|
Tunnel *Tunnel
|
||||||
PacketHandler *Processor
|
|
||||||
SessionInfo *SessionInfo
|
|
||||||
Since time.Time
|
|
||||||
IsWebsocket bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterConnection(connId string, h *Processor, s *SessionInfo) {
|
func RegisterConnection(h *Processor, t *Tunnel) {
|
||||||
if Connections == nil {
|
if Connections == nil {
|
||||||
Connections = make(map[string]*GatewayConnection)
|
Connections = make(map[string]*Monitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
Connections[connId] = &GatewayConnection{
|
Connections[t.RDGId] = &Monitor{
|
||||||
PacketHandler: h,
|
Processor: h,
|
||||||
SessionInfo: s,
|
Tunnel: t,
|
||||||
Since: time.Now(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CloseConnection(connId string) {
|
func RemoveConnection(connId string) {
|
||||||
delete(Connections, connId)
|
delete(Connections, connId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CalculateSpeedPerSecond calculate moving average.
|
||||||
|
/*
|
||||||
|
func CalculateSpeedPerSecond(connId string) (in int, out int) {
|
||||||
|
now := time.Now().UnixMilli()
|
||||||
|
|
||||||
|
c := Connections[connId]
|
||||||
|
total := int64(0)
|
||||||
|
for _, v := range c.Tunnel.BytesReceived {
|
||||||
|
total += v
|
||||||
|
}
|
||||||
|
in = int(total / (now - c.TimeStamp) * 1000)
|
||||||
|
|
||||||
|
total = int64(0)
|
||||||
|
for _, v := range c.BytesSent {
|
||||||
|
total += v
|
||||||
|
}
|
||||||
|
out = int(total / (now - c.TimeStamp))
|
||||||
|
|
||||||
|
return in, out
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
47
cmd/rdpgw/protocol/tunnel.go
Normal file
47
cmd/rdpgw/protocol/tunnel.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tunnel struct {
|
||||||
|
// The connection-id (RDG-ConnID) as reported by the client
|
||||||
|
RDGId string
|
||||||
|
// The underlying incoming transport being either websocket or legacy http
|
||||||
|
// in case of websocket TransportOut will equal TransportIn
|
||||||
|
TransportIn transport.Transport
|
||||||
|
// The underlying outgoing transport being either websocket or legacy http
|
||||||
|
// in case of websocket TransportOut will equal TransportOut
|
||||||
|
TransportOut transport.Transport
|
||||||
|
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
|
||||||
|
TargetServer string
|
||||||
|
// The obtained client ip address
|
||||||
|
RemoteAddr string
|
||||||
|
// User
|
||||||
|
UserName string
|
||||||
|
|
||||||
|
// rwc is the underlying connection to the remote desktop server.
|
||||||
|
// It is of the type *net.TCPConn
|
||||||
|
rwc net.Conn
|
||||||
|
|
||||||
|
ByteSent int64
|
||||||
|
BytesReceived int64
|
||||||
|
|
||||||
|
ConnectedOn time.Time
|
||||||
|
LastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) Write(pkt []byte) {
|
||||||
|
n, _ := t.TransportOut.WritePacket(pkt)
|
||||||
|
t.ByteSent += int64(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
|
||||||
|
pt, size, pkt, err = readMessage(t.TransportIn)
|
||||||
|
t.BytesReceived += int64(size)
|
||||||
|
t.LastSeen = time.Now()
|
||||||
|
|
||||||
|
return pt, size, pkt, err
|
||||||
|
}
|
|
@ -7,12 +7,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
info = protocol.SessionInfo{
|
info = protocol.Tunnel{
|
||||||
ConnId: "myid",
|
RDGId: "myid",
|
||||||
TransportIn: nil,
|
TransportIn: nil,
|
||||||
TransportOut: nil,
|
TransportOut: nil,
|
||||||
RemoteServer: "my.remote.server",
|
TargetServer: "my.remote.server",
|
||||||
ClientIp: "10.0.0.1",
|
RemoteAddr: "10.0.0.1",
|
||||||
UserName: "Frank",
|
UserName: "Frank",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCheckHost(t *testing.T) {
|
func TestCheckHost(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), "SessionInfo", &info)
|
ctx := context.WithValue(context.Background(), "Tunnel", &info)
|
||||||
|
|
||||||
Hosts = hosts
|
Hosts = hosts
|
||||||
|
|
||||||
|
|
|
@ -33,28 +33,28 @@ type customClaims struct {
|
||||||
AccessToken string `json:"accessToken"`
|
AccessToken string `json:"accessToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckSession(next protocol.VerifyServerFunc) protocol.VerifyServerFunc {
|
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
||||||
return func(ctx context.Context, host string) (bool, error) {
|
return func(ctx context.Context, host string) (bool, error) {
|
||||||
s := getSessionInfo(ctx)
|
s := getSessionInfo(ctx)
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return false, errors.New("no valid session info found in context")
|
return false, errors.New("no valid session info found in context")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.RemoteServer != host {
|
if s.TargetServer != host {
|
||||||
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
|
log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) {
|
if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) {
|
||||||
log.Printf("Current client ip address %s does not match token client ip %s",
|
log.Printf("Current client ip address %s does not match token client ip %s",
|
||||||
common.GetClientIp(ctx), s.ClientIp)
|
common.GetClientIp(ctx), s.RemoteAddr)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return next(ctx, host)
|
return next(ctx, host)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
|
func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
log.Printf("no token to parse")
|
log.Printf("no token to parse")
|
||||||
return false, errors.New("no token to parse")
|
return false, errors.New("no token to parse")
|
||||||
|
@ -104,8 +104,8 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
|
||||||
|
|
||||||
s := getSessionInfo(ctx)
|
s := getSessionInfo(ctx)
|
||||||
|
|
||||||
s.RemoteServer = custom.RemoteServer
|
s.TargetServer = custom.RemoteServer
|
||||||
s.ClientIp = custom.ClientIP
|
s.RemoteAddr = custom.ClientIP
|
||||||
s.UserName = user.Subject
|
s.UserName = user.Subject
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
|
@ -288,8 +288,8 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
|
||||||
return token, err
|
return token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
func getSessionInfo(ctx context.Context) *protocol.Tunnel {
|
||||||
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("cannot get session info from context")
|
log.Printf("cannot get session info from context")
|
||||||
return nil
|
return nil
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue