mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 14:03:50 +02:00
Use context
This commit is contained in:
parent
5de3767e70
commit
39c73fc8fc
3 changed files with 21 additions and 20 deletions
22
api/web.go
22
api/web.go
|
@ -127,31 +127,27 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := c.store.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
userName, ok := ctx.Value("preferred_username").(string)
|
||||
|
||||
userName := session.Values["preferred_username"]
|
||||
if userName == nil || userName.(string) == "" {
|
||||
// This shouldnt happen if the Authenticated handler is used to wrap this func
|
||||
log.Printf("Found expired or non existent session")
|
||||
http.Error(w, errors.New("cannot find session").Error(), http.StatusInternalServerError)
|
||||
if !ok {
|
||||
log.Printf("preferred_username not found in context")
|
||||
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// do a round robin selection for now
|
||||
rand.Seed(time.Now().Unix())
|
||||
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
||||
host = strings.Replace(host, "{{ preferred_username }}", userName.(string), 1)
|
||||
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||
|
||||
user := userName.(string)
|
||||
user := userName
|
||||
if c.UsernameTemplate != "" {
|
||||
user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1)
|
||||
if c.UsernameTemplate == user {
|
||||
|
|
|
@ -2,6 +2,7 @@ package protocol
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
|
@ -71,7 +72,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
|||
|
||||
const tunnelId = 10
|
||||
|
||||
func (h *Handler) Process() error {
|
||||
func (h *Handler) Process(ctx context.Context) error {
|
||||
for {
|
||||
pt, sz, pkt, err := h.ReadMessage()
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/patrickmn/go-cache"
|
||||
|
@ -64,6 +65,9 @@ func init() {
|
|||
}
|
||||
|
||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
|
||||
var s *SessionInfo
|
||||
|
@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||
return
|
||||
}
|
||||
r.Method = "GET" // force
|
||||
|
@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
g.handleWebsocketProtocol(conn, s)
|
||||
g.handleWebsocketProtocol(ctx, conn, s)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
||||
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
|
||||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
||||
|
@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
|||
s.TransportOut = inout
|
||||
s.TransportIn = inout
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
handler.Process(ctx)
|
||||
}
|
||||
|
||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||
|
@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
|||
|
||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
handler.Process(r.Context())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue