mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-13 04:19:19 +02:00
Refactor some stuff
This commit is contained in:
parent
7264e7b92f
commit
3839058eb8
2 changed files with 26 additions and 24 deletions
|
@ -51,10 +51,10 @@ type HandlerConf struct {
|
||||||
TokenAuth bool
|
TokenAuth bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(in transport.Transport, out transport.Transport, conf *HandlerConf) *Handler {
|
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||||
h := &Handler{
|
h := &Handler{
|
||||||
TransportIn: in,
|
TransportIn: s.TransportIn,
|
||||||
TransportOut: out,
|
TransportOut: s.TransportOut,
|
||||||
State: SERVER_STATE_INITIAL,
|
State: SERVER_STATE_INITIAL,
|
||||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||||
IdleTimeout: conf.IdleTimeout,
|
IdleTimeout: conf.IdleTimeout,
|
||||||
|
|
|
@ -51,10 +51,9 @@ type SessionInfo struct {
|
||||||
TransportOut transport.Transport
|
TransportOut transport.Transport
|
||||||
RemoteAddress string
|
RemoteAddress string
|
||||||
ProxyAddresses string
|
ProxyAddresses string
|
||||||
|
UserName string
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultSession SessionInfo
|
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{}
|
var upgrader = websocket.Upgrader{}
|
||||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||||
|
|
||||||
|
@ -66,9 +65,20 @@ func init() {
|
||||||
|
|
||||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
connectionCache.Set(float64(c.ItemCount()))
|
connectionCache.Set(float64(c.ItemCount()))
|
||||||
|
|
||||||
|
var s *SessionInfo
|
||||||
|
|
||||||
|
connId := r.Header.Get(rdgConnectionIdKey)
|
||||||
|
x, found := c.Get(connId)
|
||||||
|
if !found {
|
||||||
|
s = &SessionInfo{ConnId: connId}
|
||||||
|
} else {
|
||||||
|
s = x.(*SessionInfo)
|
||||||
|
}
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
g.handleLegacyProtocol(w, r)
|
g.handleLegacyProtocol(w, r, s)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Method = "GET" // force
|
r.Method = "GET" // force
|
||||||
|
@ -79,35 +89,27 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
g.handleWebsocketProtocol(conn)
|
g.handleWebsocketProtocol(conn, s)
|
||||||
} else if r.Method == MethodRDGIN {
|
} else if r.Method == MethodRDGIN {
|
||||||
g.handleLegacyProtocol(w, r)
|
g.handleLegacyProtocol(w, r, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn) {
|
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
||||||
websocketConnections.Inc()
|
websocketConnections.Inc()
|
||||||
defer websocketConnections.Dec()
|
defer websocketConnections.Dec()
|
||||||
|
|
||||||
inout, _ := transport.NewWS(c)
|
inout, _ := transport.NewWS(c)
|
||||||
handler := NewHandler(inout, inout, g.HandlerConf)
|
s.TransportOut = inout
|
||||||
|
s.TransportIn = inout
|
||||||
|
handler := NewHandler(s, g.HandlerConf)
|
||||||
handler.Process()
|
handler.Process()
|
||||||
}
|
}
|
||||||
|
|
||||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||||
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
||||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
|
||||||
var s SessionInfo
|
|
||||||
|
|
||||||
connId := r.Header.Get(rdgConnectionIdKey)
|
|
||||||
x, found := c.Get(connId)
|
|
||||||
if !found {
|
|
||||||
s = SessionInfo{ConnId: connId}
|
|
||||||
} else {
|
|
||||||
s = x.(SessionInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
|
@ -121,7 +123,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
s.TransportOut = out
|
s.TransportOut = out
|
||||||
out.SendAccept(true)
|
out.SendAccept(true)
|
||||||
|
|
||||||
c.Set(connId, s, cache.DefaultExpiration)
|
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||||
} else if r.Method == MethodRDGIN {
|
} else if r.Method == MethodRDGIN {
|
||||||
legacyConnections.Inc()
|
legacyConnections.Inc()
|
||||||
defer legacyConnections.Dec()
|
defer legacyConnections.Dec()
|
||||||
|
@ -135,7 +137,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
if s.TransportIn == nil {
|
if s.TransportIn == nil {
|
||||||
s.TransportIn = in
|
s.TransportIn = in
|
||||||
c.Set(connId, s, cache.DefaultExpiration)
|
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||||
|
|
||||||
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
|
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
|
||||||
in.SendAccept(false)
|
in.SendAccept(false)
|
||||||
|
@ -144,7 +146,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
||||||
handler := NewHandler(in, s.TransportOut, g.HandlerConf)
|
handler := NewHandler(s, g.HandlerConf)
|
||||||
handler.Process()
|
handler.Process()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue