Use standard HandleFunc pattern

This commit is contained in:
Bolke de Bruin 2020-07-13 15:38:25 +02:00
parent 3ab375314a
commit 1f191a5e41
2 changed files with 20 additions and 26 deletions

View file

@ -41,9 +41,9 @@ func main() {
cfg.Certificates = append(cfg.Certificates, cert) cfg.Certificates = append(cfg.Certificates, cert)
server := http.Server{ server := http.Server{
Addr: ":" + strconv.Itoa(*port), Addr: ":" + strconv.Itoa(*port),
Handler: Upgrade(nil),
TLSConfig: cfg, TLSConfig: cfg,
} }
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
err = server.ListenAndServeTLS("", "") err = server.ListenAndServeTLS("", "")
if err != nil { if err != nil {

8
rdg.go
View file

@ -110,10 +110,6 @@ var ErrNotHijacker = RejectConnectionError(
var DefaultSession RdgSession var DefaultSession RdgSession
func Upgrade(next http.Handler) http.Handler {
return handleGatewayProtocol(next)
}
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) { func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
log.Print("Accept connection") log.Print("Accept connection")
hj, ok := w.(http.Hijacker) hj, ok := w.(http.Hijacker)
@ -132,8 +128,7 @@ func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err err
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute) var c = cache.New(5*time.Minute, 10*time.Minute)
func handleGatewayProtocol(next http.Handler) http.Handler { func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r) handleLegacyProtocol(w, r)
@ -151,7 +146,6 @@ func handleGatewayProtocol(next http.Handler) http.Handler {
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
handleLegacyProtocol(w, r) handleLegacyProtocol(w, r)
} }
})
} }
func handleWebsocketProtocol(conn *websocket.Conn) { func handleWebsocketProtocol(conn *websocket.Conn) {