mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-25 20:08:20 +02:00
Improve configurability
This commit is contained in:
parent
93b31ec9b6
commit
01345b9416
4 changed files with 97 additions and 53 deletions
13
main.go
13
main.go
|
@ -89,7 +89,18 @@ func main() {
|
||||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
||||||
}
|
}
|
||||||
|
|
||||||
http.HandleFunc("/remoteDesktopGateway/", protocol.HandleGatewayProtocol)
|
// create the gateway
|
||||||
|
handlerConfig := protocol.HandlerConf{
|
||||||
|
TokenAuth: true,
|
||||||
|
RedirectFlags: protocol.RedirectFlags{
|
||||||
|
Clipboard: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
gw := protocol.Gateway{
|
||||||
|
HandlerConf: &handlerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
http.HandleFunc("/remoteDesktopGateway/", gw.HandleGatewayProtocol)
|
||||||
http.HandleFunc("/connect", handleRdpDownload)
|
http.HandleFunc("/connect", handleRdpDownload)
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
http.HandleFunc("/callback", handleCallback)
|
http.HandleFunc("/callback", handleCallback)
|
||||||
|
|
|
@ -12,19 +12,28 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// When should the client disconnect when idle in minutes
|
|
||||||
var IdleTimeout = 0
|
|
||||||
|
|
||||||
type VerifyPAACookieFunc func(string) (bool, error)
|
type VerifyPAACookieFunc func(string) (bool, error)
|
||||||
type VerifyTunnelAuthFunc func(string) (bool, error)
|
type VerifyTunnelAuthFunc func(string) (bool, error)
|
||||||
type VerifyServerFunc func(string) (bool, error)
|
type VerifyServerFunc func(string) (bool, error)
|
||||||
|
|
||||||
|
type RedirectFlags struct {
|
||||||
|
Clipboard bool
|
||||||
|
Port bool
|
||||||
|
Drive bool
|
||||||
|
Printer bool
|
||||||
|
Pnp bool
|
||||||
|
disableAll bool
|
||||||
|
enableAll bool
|
||||||
|
}
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
TransportIn transport.Transport
|
TransportIn transport.Transport
|
||||||
TransportOut transport.Transport
|
TransportOut transport.Transport
|
||||||
VerifyPAACookieFunc VerifyPAACookieFunc
|
VerifyPAACookieFunc VerifyPAACookieFunc
|
||||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||||
VerifyServerFunc VerifyServerFunc
|
VerifyServerFunc VerifyServerFunc
|
||||||
|
RedirectFlags int
|
||||||
|
IdleTimeout int
|
||||||
SmartCardAuth bool
|
SmartCardAuth bool
|
||||||
TokenAuth bool
|
TokenAuth bool
|
||||||
ClientName string
|
ClientName string
|
||||||
|
@ -32,11 +41,28 @@ type Handler struct {
|
||||||
State int
|
State int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(in transport.Transport, out transport.Transport) *Handler {
|
type HandlerConf struct {
|
||||||
|
VerifyPAACookieFunc VerifyPAACookieFunc
|
||||||
|
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||||
|
VerifyServerFunc VerifyServerFunc
|
||||||
|
RedirectFlags RedirectFlags
|
||||||
|
IdleTimeout int
|
||||||
|
SmartCardAuth bool
|
||||||
|
TokenAuth bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandler(in transport.Transport, out transport.Transport, conf *HandlerConf) *Handler {
|
||||||
h := &Handler{
|
h := &Handler{
|
||||||
TransportIn: in,
|
TransportIn: in,
|
||||||
TransportOut: out,
|
TransportOut: out,
|
||||||
State: SERVER_STATE_INITIAL,
|
State: SERVER_STATE_INITIAL,
|
||||||
|
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||||
|
IdleTimeout: conf.IdleTimeout,
|
||||||
|
SmartCardAuth: conf.SmartCardAuth,
|
||||||
|
TokenAuth: conf.TokenAuth,
|
||||||
|
VerifyPAACookieFunc: conf.VerifyPAACookieFunc,
|
||||||
|
VerifyServerFunc: conf.VerifyServerFunc,
|
||||||
|
VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc,
|
||||||
}
|
}
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
@ -55,8 +81,8 @@ func (h *Handler) Process() error {
|
||||||
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
|
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
major, minor, _, auth := readHandshake(pkt)
|
major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
|
||||||
msg := h.handshakeResponse(major, minor, auth)
|
msg := h.handshakeResponse(major, minor)
|
||||||
h.TransportOut.WritePacket(msg)
|
h.TransportOut.WritePacket(msg)
|
||||||
h.State = SERVER_STATE_HANDSHAKE
|
h.State = SERVER_STATE_HANDSHAKE
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
|
@ -189,7 +215,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
|
||||||
// Creates a packet the is a response to a handshake request
|
// Creates a packet the is a response to a handshake request
|
||||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||||
// but could be in Windows. However the NTLM protocol is insecure
|
// but could be in Windows. However the NTLM protocol is insecure
|
||||||
func (h *Handler) handshakeResponse(major byte, minor byte, auth uint16) []byte {
|
func (h *Handler) handshakeResponse(major byte, minor byte) []byte {
|
||||||
var caps uint16
|
var caps uint16
|
||||||
if h.SmartCardAuth {
|
if h.SmartCardAuth {
|
||||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||||
|
@ -289,40 +315,13 @@ func (h *Handler) createTunnelAuthResponse() []byte {
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
|
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
|
||||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||||
|
|
||||||
// flags
|
|
||||||
var redir uint32
|
|
||||||
/*
|
|
||||||
if conf.Caps.RedirectAll {
|
|
||||||
redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
|
|
||||||
} else if conf.Caps.DisableRedirect {
|
|
||||||
redir = HTTP_TUNNEL_REDIR_DISABLE_ALL
|
|
||||||
} else {
|
|
||||||
if conf.Caps.DisableClipboard {
|
|
||||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
|
|
||||||
}
|
|
||||||
if conf.Caps.DisableDrive {
|
|
||||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
|
|
||||||
}
|
|
||||||
if conf.Caps.DisablePnp {
|
|
||||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
|
|
||||||
}
|
|
||||||
if conf.Caps.DisablePrinter {
|
|
||||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
|
|
||||||
}
|
|
||||||
if conf.Caps.DisablePort {
|
|
||||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
|
|
||||||
|
|
||||||
// idle timeout
|
// idle timeout
|
||||||
if IdleTimeout < 0 {
|
if h.IdleTimeout < 0 {
|
||||||
IdleTimeout = 0
|
h.IdleTimeout = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(redir)) // redir flags
|
binary.Write(buf, binary.LittleEndian, uint32(h.RedirectFlags)) // redir flags
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(IdleTimeout)) // timeout in minutes
|
binary.Write(buf, binary.LittleEndian, uint32(h.IdleTimeout)) // timeout in minutes
|
||||||
|
|
||||||
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
@ -405,3 +404,31 @@ func createPacket(pktType uint16, data []byte) (packet []byte) {
|
||||||
|
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeRedirectFlags(flags RedirectFlags) int {
|
||||||
|
var redir = 0
|
||||||
|
|
||||||
|
if flags.disableAll {
|
||||||
|
return HTTP_TUNNEL_REDIR_DISABLE_ALL
|
||||||
|
}
|
||||||
|
if flags.enableAll {
|
||||||
|
return HTTP_TUNNEL_REDIR_ENABLE_ALL
|
||||||
|
}
|
||||||
|
|
||||||
|
if !flags.Port {
|
||||||
|
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
|
||||||
|
}
|
||||||
|
if !flags.Clipboard {
|
||||||
|
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
|
||||||
|
}
|
||||||
|
if !flags.Drive {
|
||||||
|
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
|
||||||
|
}
|
||||||
|
if !flags.Pnp {
|
||||||
|
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
|
||||||
|
}
|
||||||
|
if !flags.Printer {
|
||||||
|
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
|
||||||
|
}
|
||||||
|
return redir
|
||||||
|
}
|
||||||
|
|
|
@ -39,6 +39,10 @@ var (
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Gateway struct {
|
||||||
|
HandlerConf *HandlerConf
|
||||||
|
}
|
||||||
|
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
ConnId string
|
ConnId string
|
||||||
CorrelationId string
|
CorrelationId string
|
||||||
|
@ -60,14 +64,11 @@ func init() {
|
||||||
prometheus.MustRegister(websocketConnections)
|
prometheus.MustRegister(websocketConnections)
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
connectionCache.Set(float64(c.ItemCount()))
|
connectionCache.Set(float64(c.ItemCount()))
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
for name, value := range r.Header {
|
|
||||||
log.Printf("Header Name: %s Value: %s", name, value)
|
|
||||||
}
|
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
handleLegacyProtocol(w, r)
|
g.handleLegacyProtocol(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Method = "GET" // force
|
r.Method = "GET" // force
|
||||||
|
@ -78,25 +79,25 @@ func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
handleWebsocketProtocol(conn)
|
g.handleWebsocketProtocol(conn)
|
||||||
} else if r.Method == MethodRDGIN {
|
} else if r.Method == MethodRDGIN {
|
||||||
handleLegacyProtocol(w, r)
|
g.handleLegacyProtocol(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleWebsocketProtocol(c *websocket.Conn) {
|
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn) {
|
||||||
websocketConnections.Inc()
|
websocketConnections.Inc()
|
||||||
defer websocketConnections.Dec()
|
defer websocketConnections.Dec()
|
||||||
|
|
||||||
inout, _ := transport.NewWS(c)
|
inout, _ := transport.NewWS(c)
|
||||||
handler := NewHandler(inout, inout)
|
handler := NewHandler(inout, inout, g.HandlerConf)
|
||||||
handler.Process()
|
handler.Process()
|
||||||
}
|
}
|
||||||
|
|
||||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||||
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
||||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||||
func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
var s SessionInfo
|
var s SessionInfo
|
||||||
|
|
||||||
connId := r.Header.Get(rdgConnectionIdKey)
|
connId := r.Header.Get(rdgConnectionIdKey)
|
||||||
|
@ -143,7 +144,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
||||||
handler := NewHandler(in, s.TransportOut)
|
handler := NewHandler(in, s.TransportOut, g.HandlerConf)
|
||||||
handler.Process()
|
handler.Process()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
5
security/simple.go
Normal file
5
security/simple.go
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
package security
|
||||||
|
|
||||||
|
func VerifyServerTemplate(server string) (bool, err) {
|
||||||
|
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue