mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-27 04:48:17 +02:00
150 lines
3.8 KiB
Go
150 lines
3.8 KiB
Go
package protocol
|
|
|
|
import (
|
|
"github.com/bolkedebruin/rdpgw/transport"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/patrickmn/go-cache"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"log"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
// rdgConnectionIdKey is the HTTP header whose value is used as the cache key
// tying the two requests of a legacy session together.
const rdgConnectionIdKey = "Rdg-Connection-Id"

// Custom HTTP request methods this package dispatches on.
const (
	// MethodRDGIN is the method of the request carrying client -> server data.
	MethodRDGIN = "RDG_IN_DATA"
	// MethodRDGOUT is the method of the request carrying server -> client data.
	MethodRDGOUT = "RDG_OUT_DATA"
)
|
|
|
|
var (
|
|
connectionCache = prometheus.NewGauge(
|
|
prometheus.GaugeOpts{
|
|
Namespace: "rdpgw",
|
|
Name: "connection_cache",
|
|
Help: "The amount of connections in the cache",
|
|
})
|
|
|
|
websocketConnections = prometheus.NewGauge(
|
|
prometheus.GaugeOpts{
|
|
Namespace: "rdpgw",
|
|
Name: "websocket_connections",
|
|
Help: "The count of websocket connections",
|
|
})
|
|
|
|
legacyConnections = prometheus.NewGauge(
|
|
prometheus.GaugeOpts{
|
|
Namespace: "rdpgw",
|
|
Name: "legacy_connections",
|
|
Help: "The count of legacy https connections",
|
|
})
|
|
)
|
|
|
|
type SessionInfo struct {
|
|
ConnId string
|
|
CorrelationId string
|
|
ClientGeneration string
|
|
TransportIn transport.Transport
|
|
TransportOut transport.Transport
|
|
RemoteAddress string
|
|
ProxyAddresses string
|
|
}
|
|
|
|
var DefaultSession SessionInfo
|
|
|
|
var upgrader = websocket.Upgrader{}
|
|
var c = cache.New(5*time.Minute, 10*time.Minute)
|
|
|
|
func init() {
|
|
prometheus.MustRegister(connectionCache)
|
|
prometheus.MustRegister(legacyConnections)
|
|
prometheus.MustRegister(websocketConnections)
|
|
}
|
|
|
|
func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
|
connectionCache.Set(float64(c.ItemCount()))
|
|
if r.Method == MethodRDGOUT {
|
|
for name, value := range r.Header {
|
|
log.Printf("Header Name: %s Value: %s", name, value)
|
|
}
|
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
|
handleLegacyProtocol(w, r)
|
|
return
|
|
}
|
|
r.Method = "GET" // force
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
handleWebsocketProtocol(conn)
|
|
} else if r.Method == MethodRDGIN {
|
|
handleLegacyProtocol(w, r)
|
|
}
|
|
}
|
|
|
|
func handleWebsocketProtocol(c *websocket.Conn) {
|
|
websocketConnections.Inc()
|
|
defer websocketConnections.Dec()
|
|
|
|
inout, _ := transport.NewWS(c)
|
|
handler := NewHandler(inout, inout)
|
|
handler.Process()
|
|
}
|
|
|
|
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
|
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
|
func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|
var s SessionInfo
|
|
|
|
connId := r.Header.Get(rdgConnectionIdKey)
|
|
x, found := c.Get(connId)
|
|
if !found {
|
|
s = SessionInfo{ConnId: connId}
|
|
} else {
|
|
s = x.(SessionInfo)
|
|
}
|
|
|
|
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
|
|
|
if r.Method == MethodRDGOUT {
|
|
out, err := transport.NewLegacy(w)
|
|
if err != nil {
|
|
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
|
return
|
|
}
|
|
log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String())
|
|
|
|
s.TransportOut = out
|
|
out.SendAccept(true)
|
|
|
|
c.Set(connId, s, cache.DefaultExpiration)
|
|
} else if r.Method == MethodRDGIN {
|
|
legacyConnections.Inc()
|
|
defer legacyConnections.Dec()
|
|
|
|
in, err := transport.NewLegacy(w)
|
|
if err != nil {
|
|
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
|
|
return
|
|
}
|
|
defer in.Close()
|
|
|
|
if s.TransportIn == nil {
|
|
s.TransportIn = in
|
|
c.Set(connId, s, cache.DefaultExpiration)
|
|
|
|
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
|
|
in.SendAccept(false)
|
|
|
|
// read some initial data
|
|
in.Drain()
|
|
|
|
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
|
handler := NewHandler(in, s.TransportOut)
|
|
handler.Process()
|
|
}
|
|
}
|
|
}
|