Use context

This commit is contained in:
Bolke de Bruin 2020-07-25 11:48:11 +02:00
parent 5de3767e70
commit 39c73fc8fc
3 changed files with 21 additions and 20 deletions

View file

@ -127,31 +127,27 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
return
}
next.ServeHTTP(w, r)
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
session, err := c.store.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string)
userName := session.Values["preferred_username"]
if userName == nil || userName.(string) == "" {
// This shouldnt happen if the Authenticated handler is used to wrap this func
log.Printf("Found expired or non existent session")
http.Error(w, errors.New("cannot find session").Error(), http.StatusInternalServerError)
if !ok {
log.Printf("preferred_username not found in context")
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
return
}
// do a round robin selection for now
rand.Seed(time.Now().Unix())
host := c.Hosts[rand.Intn(len(c.Hosts))]
host = strings.Replace(host, "{{ preferred_username }}", userName.(string), 1)
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
user := userName.(string)
user := userName
if c.UsernameTemplate != "" {
user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1)
if c.UsernameTemplate == user {

View file

@ -2,6 +2,7 @@ package protocol
import (
"bytes"
"context"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/transport"
@ -71,7 +72,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
const tunnelId = 10
func (h *Handler) Process() error {
func (h *Handler) Process(ctx context.Context) error {
for {
pt, sz, pkt, err := h.ReadMessage()
if err != nil {

View file

@ -1,6 +1,7 @@
package protocol
import (
"context"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
@ -64,6 +65,9 @@ func init() {
}
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
connectionCache.Set(float64(c.ItemCount()))
var s *SessionInfo
@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
g.handleLegacyProtocol(w, r, s)
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
return
}
r.Method = "GET" // force
@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
}
defer conn.Close()
g.handleWebsocketProtocol(conn, s)
g.handleWebsocketProtocol(ctx, conn, s)
} else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r, s)
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
}
}
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
websocketConnections.Inc()
defer websocketConnections.Dec()
@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
s.TransportOut = inout
s.TransportIn = inout
handler := NewHandler(s, g.HandlerConf)
handler.Process()
handler.Process(ctx)
}
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
handler := NewHandler(s, g.HandlerConf)
handler.Process()
handler.Process(r.Context())
}
}
}