mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-04 00:32:03 +02:00
Refactor some stuff
This commit is contained in:
parent
7264e7b92f
commit
3839058eb8
2 changed files with 26 additions and 24 deletions
|
@ -51,10 +51,10 @@ type HandlerConf struct {
|
|||
TokenAuth bool
|
||||
}
|
||||
|
||||
func NewHandler(in transport.Transport, out transport.Transport, conf *HandlerConf) *Handler {
|
||||
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||
h := &Handler{
|
||||
TransportIn: in,
|
||||
TransportOut: out,
|
||||
TransportIn: s.TransportIn,
|
||||
TransportOut: s.TransportOut,
|
||||
State: SERVER_STATE_INITIAL,
|
||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||
IdleTimeout: conf.IdleTimeout,
|
||||
|
|
|
@ -51,10 +51,9 @@ type SessionInfo struct {
|
|||
TransportOut transport.Transport
|
||||
RemoteAddress string
|
||||
ProxyAddresses string
|
||||
UserName string
|
||||
}
|
||||
|
||||
var DefaultSession SessionInfo
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
|
@ -66,9 +65,20 @@ func init() {
|
|||
|
||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
|
||||
var s *SessionInfo
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
s = &SessionInfo{ConnId: connId}
|
||||
} else {
|
||||
s = x.(*SessionInfo)
|
||||
}
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
g.handleLegacyProtocol(w, r)
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
return
|
||||
}
|
||||
r.Method = "GET" // force
|
||||
|
@ -79,35 +89,27 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
g.handleWebsocketProtocol(conn)
|
||||
g.handleWebsocketProtocol(conn, s)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
g.handleLegacyProtocol(w, r)
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn) {
|
||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
||||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
||||
inout, _ := transport.NewWS(c)
|
||||
handler := NewHandler(inout, inout, g.HandlerConf)
|
||||
s.TransportOut = inout
|
||||
s.TransportIn = inout
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
}
|
||||
|
||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
var s SessionInfo
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
s = SessionInfo{ConnId: connId}
|
||||
} else {
|
||||
s = x.(SessionInfo)
|
||||
}
|
||||
|
||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
|
||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
|
@ -121,7 +123,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
s.TransportOut = out
|
||||
out.SendAccept(true)
|
||||
|
||||
c.Set(connId, s, cache.DefaultExpiration)
|
||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
legacyConnections.Inc()
|
||||
defer legacyConnections.Dec()
|
||||
|
@ -135,7 +137,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
if s.TransportIn == nil {
|
||||
s.TransportIn = in
|
||||
c.Set(connId, s, cache.DefaultExpiration)
|
||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||
|
||||
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
|
||||
in.SendAccept(false)
|
||||
|
@ -144,7 +146,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
in.Drain()
|
||||
|
||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
||||
handler := NewHandler(in, s.TransportOut, g.HandlerConf)
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue