Improve configurability

This commit is contained in:
Bolke de Bruin 2020-07-21 10:16:31 +02:00
parent 93b31ec9b6
commit 01345b9416
4 changed files with 97 additions and 53 deletions

13
main.go
View file

@ -89,7 +89,18 @@ func main() {
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
} }
http.HandleFunc("/remoteDesktopGateway/", protocol.HandleGatewayProtocol) // create the gateway
handlerConfig := protocol.HandlerConf{
TokenAuth: true,
RedirectFlags: protocol.RedirectFlags{
Clipboard: true,
},
}
gw := protocol.Gateway{
HandlerConf: &handlerConfig,
}
http.HandleFunc("/remoteDesktopGateway/", gw.HandleGatewayProtocol)
http.HandleFunc("/connect", handleRdpDownload) http.HandleFunc("/connect", handleRdpDownload)
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/callback", handleCallback) http.HandleFunc("/callback", handleCallback)

View file

@ -12,19 +12,28 @@ import (
"time" "time"
) )
// When should the client disconnect when idle in minutes
var IdleTimeout = 0
type VerifyPAACookieFunc func(string) (bool, error) type VerifyPAACookieFunc func(string) (bool, error)
type VerifyTunnelAuthFunc func(string) (bool, error) type VerifyTunnelAuthFunc func(string) (bool, error)
type VerifyServerFunc func(string) (bool, error) type VerifyServerFunc func(string) (bool, error)
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
disableAll bool
enableAll bool
}
type Handler struct { type Handler struct {
TransportIn transport.Transport TransportIn transport.Transport
TransportOut transport.Transport TransportOut transport.Transport
VerifyPAACookieFunc VerifyPAACookieFunc VerifyPAACookieFunc VerifyPAACookieFunc
VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc VerifyServerFunc VerifyServerFunc
RedirectFlags int
IdleTimeout int
SmartCardAuth bool SmartCardAuth bool
TokenAuth bool TokenAuth bool
ClientName string ClientName string
@ -32,11 +41,28 @@ type Handler struct {
State int State int
} }
func NewHandler(in transport.Transport, out transport.Transport) *Handler { type HandlerConf struct {
VerifyPAACookieFunc VerifyPAACookieFunc
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
RedirectFlags RedirectFlags
IdleTimeout int
SmartCardAuth bool
TokenAuth bool
}
func NewHandler(in transport.Transport, out transport.Transport, conf *HandlerConf) *Handler {
h := &Handler{ h := &Handler{
TransportIn: in, TransportIn: in,
TransportOut: out, TransportOut: out,
State: SERVER_STATE_INITIAL, State: SERVER_STATE_INITIAL,
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
IdleTimeout: conf.IdleTimeout,
SmartCardAuth: conf.SmartCardAuth,
TokenAuth: conf.TokenAuth,
VerifyPAACookieFunc: conf.VerifyPAACookieFunc,
VerifyServerFunc: conf.VerifyServerFunc,
VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc,
} }
return h return h
} }
@ -55,8 +81,8 @@ func (h *Handler) Process() error {
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL) log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
return errors.New("wrong state") return errors.New("wrong state")
} }
major, minor, _, auth := readHandshake(pkt) major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
msg := h.handshakeResponse(major, minor, auth) msg := h.handshakeResponse(major, minor)
h.TransportOut.WritePacket(msg) h.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_HANDSHAKE h.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE: case PKT_TYPE_TUNNEL_CREATE:
@ -189,7 +215,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
// Creates a packet the is a response to a handshake request // Creates a packet the is a response to a handshake request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure // but could be in Windows. However the NTLM protocol is insecure
func (h *Handler) handshakeResponse(major byte, minor byte, auth uint16) []byte { func (h *Handler) handshakeResponse(major byte, minor byte) []byte {
var caps uint16 var caps uint16
if h.SmartCardAuth { if h.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA caps = caps | HTTP_EXTENDED_AUTH_PAA
@ -289,40 +315,13 @@ func (h *Handler) createTunnelAuthResponse() []byte {
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// flags
var redir uint32
/*
if conf.Caps.RedirectAll {
redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
} else if conf.Caps.DisableRedirect {
redir = HTTP_TUNNEL_REDIR_DISABLE_ALL
} else {
if conf.Caps.DisableClipboard {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
}
if conf.Caps.DisableDrive {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
}
if conf.Caps.DisablePnp {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
}
if conf.Caps.DisablePrinter {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
}
if conf.Caps.DisablePort {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
}
}
*/
redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
// idle timeout // idle timeout
if IdleTimeout < 0 { if h.IdleTimeout < 0 {
IdleTimeout = 0 h.IdleTimeout = 0
} }
binary.Write(buf, binary.LittleEndian, uint32(redir)) // redir flags binary.Write(buf, binary.LittleEndian, uint32(h.RedirectFlags)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(IdleTimeout)) // timeout in minutes binary.Write(buf, binary.LittleEndian, uint32(h.IdleTimeout)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
} }
@ -405,3 +404,31 @@ func createPacket(pktType uint16, data []byte) (packet []byte) {
return buf.Bytes() return buf.Bytes()
} }
func makeRedirectFlags(flags RedirectFlags) int {
var redir = 0
if flags.disableAll {
return HTTP_TUNNEL_REDIR_DISABLE_ALL
}
if flags.enableAll {
return HTTP_TUNNEL_REDIR_ENABLE_ALL
}
if !flags.Port {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
}
if !flags.Clipboard {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
}
if !flags.Drive {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
}
if !flags.Pnp {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
}
if !flags.Printer {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
}
return redir
}

View file

@ -39,6 +39,10 @@ var (
}) })
) )
type Gateway struct {
HandlerConf *HandlerConf
}
type SessionInfo struct { type SessionInfo struct {
ConnId string ConnId string
CorrelationId string CorrelationId string
@ -60,14 +64,11 @@ func init() {
prometheus.MustRegister(websocketConnections) prometheus.MustRegister(websocketConnections)
} }
func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
connectionCache.Set(float64(c.ItemCount())) connectionCache.Set(float64(c.ItemCount()))
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
for name, value := range r.Header {
log.Printf("Header Name: %s Value: %s", name, value)
}
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r) g.handleLegacyProtocol(w, r)
return return
} }
r.Method = "GET" // force r.Method = "GET" // force
@ -78,25 +79,25 @@ func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
} }
defer conn.Close() defer conn.Close()
handleWebsocketProtocol(conn) g.handleWebsocketProtocol(conn)
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
handleLegacyProtocol(w, r) g.handleLegacyProtocol(w, r)
} }
} }
func handleWebsocketProtocol(c *websocket.Conn) { func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn) {
websocketConnections.Inc() websocketConnections.Inc()
defer websocketConnections.Dec() defer websocketConnections.Dec()
inout, _ := transport.NewWS(c) inout, _ := transport.NewWS(c)
handler := NewHandler(inout, inout) handler := NewHandler(inout, inout, g.HandlerConf)
handler.Process() handler.Process()
} }
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different // and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
// to ensure the connections do not get cached or terminated by a proxy prematurely. // to ensure the connections do not get cached or terminated by a proxy prematurely.
func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) { func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
var s SessionInfo var s SessionInfo
connId := r.Header.Get(rdgConnectionIdKey) connId := r.Header.Get(rdgConnectionIdKey)
@ -143,7 +144,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
in.Drain() in.Drain()
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
handler := NewHandler(in, s.TransportOut) handler := NewHandler(in, s.TransportOut, g.HandlerConf)
handler.Process() handler.Process()
} }
} }

5
security/simple.go Normal file
View file

@ -0,0 +1,5 @@
package security
func VerifyServerTemplate(server string) (bool, err) {
}