Use context

This commit is contained in:
Bolke de Bruin 2020-07-25 11:48:11 +02:00
parent 5de3767e70
commit 39c73fc8fc
3 changed files with 21 additions and 20 deletions

View file

@ -1,6 +1,7 @@
package protocol
import (
"context"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
@ -64,6 +65,9 @@ func init() {
}
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
connectionCache.Set(float64(c.ItemCount()))
var s *SessionInfo
@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
g.handleLegacyProtocol(w, r, s)
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
return
}
r.Method = "GET" // force
@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
}
defer conn.Close()
g.handleWebsocketProtocol(conn, s)
g.handleWebsocketProtocol(ctx, conn, s)
} else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r, s)
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
}
}
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
websocketConnections.Inc()
defer websocketConnections.Dec()
@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
s.TransportOut = inout
s.TransportIn = inout
handler := NewHandler(s, g.HandlerConf)
handler.Process()
handler.Process(ctx)
}
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
handler := NewHandler(s, g.HandlerConf)
handler.Process()
handler.Process(r.Context())
}
}
}