Use standard HandleFunc pattern

This commit is contained in:
Bolke de Bruin 2020-07-13 15:38:25 +02:00
parent 3ab375314a
commit 1f191a5e41
2 changed files with 20 additions and 26 deletions

View file

@ -41,10 +41,10 @@ func main() {
cfg.Certificates = append(cfg.Certificates, cert)
server := http.Server{
Addr: ":" + strconv.Itoa(*port),
Handler: Upgrade(nil),
TLSConfig: cfg,
}
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
err = server.ListenAndServeTLS("", "")
if err != nil {
log.Fatal("ListenAndServe: ", err)

42
rdg.go
View file

@ -110,10 +110,6 @@ var ErrNotHijacker = RejectConnectionError(
var DefaultSession RdgSession
func Upgrade(next http.Handler) http.Handler {
return handleGatewayProtocol(next)
}
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
log.Print("Accept connection")
hj, ok := w.(http.Hijacker)
@ -132,29 +128,27 @@ func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err err
var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute)
func handleGatewayProtocol(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r)
return
}
r.Method = "GET" // force
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
return
}
defer conn.Close()
handleWebsocketProtocol(conn)
} else if r.Method == MethodRDGIN {
func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r)
return
}
})
r.Method = "GET" // force
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
return
}
defer conn.Close()
handleWebsocketProtocol(conn)
} else if r.Method == MethodRDGIN {
handleLegacyProtocol(w, r)
}
}
func handleWebsocketProtocol(conn *websocket.Conn) {
func handleWebsocketProtocol(conn *websocket.Conn) {
fragment := false
buf := make([]byte, 4096)
index := 0
@ -375,7 +369,7 @@ func handshakeResponse(major byte, minor byte, auth uint16) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
buf.Write([]byte{major, minor})
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint16(HTTP_EXTENDED_AUTH_PAA|HTTP_EXTENDED_AUTH_SC)) // extended auth
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())