mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 14:03:50 +02:00
Use context
This commit is contained in:
parent
5de3767e70
commit
39c73fc8fc
3 changed files with 21 additions and 20 deletions
22
api/web.go
22
api/web.go
|
@ -127,31 +127,27 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
session, err := c.store.Get(r, RdpGwSession)
|
ctx := r.Context()
|
||||||
if err != nil {
|
userName, ok := ctx.Value("preferred_username").(string)
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userName := session.Values["preferred_username"]
|
if !ok {
|
||||||
if userName == nil || userName.(string) == "" {
|
log.Printf("preferred_username not found in context")
|
||||||
// This shouldnt happen if the Authenticated handler is used to wrap this func
|
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
||||||
log.Printf("Found expired or non existent session")
|
|
||||||
http.Error(w, errors.New("cannot find session").Error(), http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// do a round robin selection for now
|
// do a round robin selection for now
|
||||||
rand.Seed(time.Now().Unix())
|
rand.Seed(time.Now().Unix())
|
||||||
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
||||||
host = strings.Replace(host, "{{ preferred_username }}", userName.(string), 1)
|
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||||
|
|
||||||
user := userName.(string)
|
user := userName
|
||||||
if c.UsernameTemplate != "" {
|
if c.UsernameTemplate != "" {
|
||||||
user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1)
|
user = strings.Replace(c.UsernameTemplate, "{{ username }}", user, 1)
|
||||||
if c.UsernameTemplate == user {
|
if c.UsernameTemplate == user {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bolkedebruin/rdpgw/transport"
|
"github.com/bolkedebruin/rdpgw/transport"
|
||||||
|
@ -71,7 +72,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||||
|
|
||||||
const tunnelId = 10
|
const tunnelId = 10
|
||||||
|
|
||||||
func (h *Handler) Process() error {
|
func (h *Handler) Process(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
pt, sz, pkt, err := h.ReadMessage()
|
pt, sz, pkt, err := h.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"github.com/bolkedebruin/rdpgw/transport"
|
"github.com/bolkedebruin/rdpgw/transport"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
|
@ -64,6 +65,9 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
connectionCache.Set(float64(c.ItemCount()))
|
connectionCache.Set(float64(c.ItemCount()))
|
||||||
|
|
||||||
var s *SessionInfo
|
var s *SessionInfo
|
||||||
|
@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
g.handleLegacyProtocol(w, r, s)
|
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Method = "GET" // force
|
r.Method = "GET" // force
|
||||||
|
@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
g.handleWebsocketProtocol(conn, s)
|
g.handleWebsocketProtocol(ctx, conn, s)
|
||||||
} else if r.Method == MethodRDGIN {
|
} else if r.Method == MethodRDGIN {
|
||||||
g.handleLegacyProtocol(w, r, s)
|
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
|
||||||
websocketConnections.Inc()
|
websocketConnections.Inc()
|
||||||
defer websocketConnections.Dec()
|
defer websocketConnections.Dec()
|
||||||
|
|
||||||
|
@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
||||||
s.TransportOut = inout
|
s.TransportOut = inout
|
||||||
s.TransportIn = inout
|
s.TransportIn = inout
|
||||||
handler := NewHandler(s, g.HandlerConf)
|
handler := NewHandler(s, g.HandlerConf)
|
||||||
handler.Process()
|
handler.Process(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||||
|
@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
||||||
|
|
||||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
||||||
handler := NewHandler(s, g.HandlerConf)
|
handler := NewHandler(s, g.HandlerConf)
|
||||||
handler.Process()
|
handler.Process(r.Context())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue