Use context

This commit is contained in:
Bolke de Bruin 2020-07-25 11:48:11 +02:00
parent 5de3767e70
commit 39c73fc8fc
3 changed files with 21 additions and 20 deletions

View file

@ -127,31 +127,27 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
return return
} }
next.ServeHTTP(w, r) ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
session, err := c.store.Get(r, RdpGwSession) ctx := r.Context()
if err != nil { userName, ok := ctx.Value("preferred_username").(string)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
userName := session.Values["preferred_username"] if !ok {
if userName == nil || userName.(string) == "" { log.Printf("preferred_username not found in context")
// This shouldnt happen if the Authenticated handler is used to wrap this func http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
log.Printf("Found expired or non existent session")
http.Error(w, errors.New("cannot find session").Error(), http.StatusInternalServerError)
return return
} }
// do a round robin selection for now // do a round robin selection for now
rand.Seed(time.Now().Unix()) rand.Seed(time.Now().Unix())
host := c.Hosts[rand.Intn(len(c.Hosts))] host := c.Hosts[rand.Intn(len(c.Hosts))]
host = strings.Replace(host, "{{ preferred_username }}", userName.(string), 1) host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
user := userName.(string) user := userName
if c.UsernameTemplate != "" { if c.UsernameTemplate != "" {
user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1) user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1)
if c.UsernameTemplate == user { if c.UsernameTemplate == user {

View file

@ -2,6 +2,7 @@ package protocol
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/transport" "github.com/bolkedebruin/rdpgw/transport"
@ -71,7 +72,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
const tunnelId = 10 const tunnelId = 10
func (h *Handler) Process() error { func (h *Handler) Process(ctx context.Context) error {
for { for {
pt, sz, pkt, err := h.ReadMessage() pt, sz, pkt, err := h.ReadMessage()
if err != nil { if err != nil {

View file

@ -1,6 +1,7 @@
package protocol package protocol
import ( import (
"context"
"github.com/bolkedebruin/rdpgw/transport" "github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
@ -64,6 +65,9 @@ func init() {
} }
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
connectionCache.Set(float64(c.ItemCount())) connectionCache.Set(float64(c.ItemCount()))
var s *SessionInfo var s *SessionInfo
@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
g.handleLegacyProtocol(w, r, s) g.handleLegacyProtocol(w, r.WithContext(ctx), s)
return return
} }
r.Method = "GET" // force r.Method = "GET" // force
@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
} }
defer conn.Close() defer conn.Close()
g.handleWebsocketProtocol(conn, s) g.handleWebsocketProtocol(ctx, conn, s)
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r, s) g.handleLegacyProtocol(w, r.WithContext(ctx), s)
} }
} }
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) { func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
websocketConnections.Inc() websocketConnections.Inc()
defer websocketConnections.Dec() defer websocketConnections.Dec()
@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
s.TransportOut = inout s.TransportOut = inout
s.TransportIn = inout s.TransportIn = inout
handler := NewHandler(s, g.HandlerConf) handler := NewHandler(s, g.HandlerConf)
handler.Process() handler.Process(ctx)
} }
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
handler := NewHandler(s, g.HandlerConf) handler := NewHandler(s, g.HandlerConf)
handler.Process() handler.Process(r.Context())
} }
} }
} }