mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-27 12:56:09 +02:00
Some improvements
This commit is contained in:
parent
6fd7b047cd
commit
354355e6b0
3 changed files with 25 additions and 23 deletions
|
@ -47,6 +47,7 @@ func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
"gatewayhostname:s:" + conf.Server.GatewayAddress +"\r\n"+
|
"gatewayhostname:s:" + conf.Server.GatewayAddress +"\r\n"+
|
||||||
"gatewaycredentialssource:i:5\r\n"+
|
"gatewaycredentialssource:i:5\r\n"+
|
||||||
"gatewayusagemethod:i:1\r\n"+
|
"gatewayusagemethod:i:1\r\n"+
|
||||||
|
"gatewayprofileusagemethod:i:1\r\n"+
|
||||||
"gatewayaccesstoken:s:" + cookie.Value + "\r\n"))
|
"gatewayaccesstoken:s:" + cookie.Value + "\r\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -82,9 +82,11 @@ func (h *Handler) Process() error {
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
client := h.readTunnelAuthRequest(pkt)
|
client := h.readTunnelAuthRequest(pkt)
|
||||||
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
|
if h.VerifyTunnelAuthFunc != nil {
|
||||||
log.Printf("Invalid client name: %s", client)
|
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
|
||||||
return errors.New("invalid client name")
|
log.Printf("Invalid client name: %s", client)
|
||||||
|
return errors.New("invalid client name")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
msg := h.createTunnelAuthResponse()
|
msg := h.createTunnelAuthResponse()
|
||||||
h.TransportOut.WritePacket(msg)
|
h.TransportOut.WritePacket(msg)
|
||||||
|
@ -117,7 +119,7 @@ func (h *Handler) Process() error {
|
||||||
go h.sendDataPacket()
|
go h.sendDataPacket()
|
||||||
h.State = SERVER_STATE_CHANNEL_CREATE
|
h.State = SERVER_STATE_CHANNEL_CREATE
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
if h.State != SERVER_STATE_CHANNEL_CREATE {
|
if h.State < SERVER_STATE_CHANNEL_CREATE {
|
||||||
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
|
@ -342,10 +344,8 @@ func readChannelCreateRequest(data []byte) (server string, port uint16) {
|
||||||
nameData := make([]byte, nameSize)
|
nameData := make([]byte, nameSize)
|
||||||
binary.Read(buf, binary.LittleEndian, &nameData)
|
binary.Read(buf, binary.LittleEndian, &nameData)
|
||||||
|
|
||||||
log.Printf("Name data %q", nameData)
|
|
||||||
server, _ = DecodeUTF16(nameData)
|
server, _ = DecodeUTF16(nameData)
|
||||||
|
|
||||||
log.Printf("Should connect to %s on port %d", server, port)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,15 +7,14 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
rdgConnectionIdKey = "Rdg-Connection-Id"
|
rdgConnectionIdKey = "Rdg-Connection-Id"
|
||||||
MethodRDGIN = "RDG_IN_DATA"
|
MethodRDGIN = "RDG_IN_DATA"
|
||||||
MethodRDGOUT = "RDG_OUT_DATA"
|
MethodRDGOUT = "RDG_OUT_DATA"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -47,18 +46,17 @@ type HandshakeHeader interface {
|
||||||
io.WriterTo
|
io.WriterTo
|
||||||
}
|
}
|
||||||
|
|
||||||
type RdgSession struct {
|
type SessionInfo struct {
|
||||||
ConnId string
|
ConnId string
|
||||||
CorrelationId string
|
CorrelationId string
|
||||||
UserId string
|
ClientGeneration string
|
||||||
TransportIn transport.Transport
|
TransportIn transport.Transport
|
||||||
TransportOut transport.Transport
|
TransportOut transport.Transport
|
||||||
StateIn int
|
RemoteAddress string
|
||||||
StateOut int
|
ProxyAddresses string
|
||||||
Remote net.Conn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultSession RdgSession
|
var DefaultSession SessionInfo
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{}
|
var upgrader = websocket.Upgrader{}
|
||||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||||
|
@ -72,6 +70,9 @@ func init() {
|
||||||
func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
connectionCache.Set(float64(c.ItemCount()))
|
connectionCache.Set(float64(c.ItemCount()))
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
|
for name, value := range r.Header {
|
||||||
|
log.Printf("Header Name: %s Value: %s", name, value)
|
||||||
|
}
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
handleLegacyProtocol(w, r)
|
handleLegacyProtocol(w, r)
|
||||||
return
|
return
|
||||||
|
@ -103,14 +104,14 @@ func handleWebsocketProtocol(c *websocket.Conn) {
|
||||||
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
||||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||||
func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
var s RdgSession
|
var s SessionInfo
|
||||||
|
|
||||||
connId := r.Header.Get(rdgConnectionIdKey)
|
connId := r.Header.Get(rdgConnectionIdKey)
|
||||||
x, found := c.Get(connId)
|
x, found := c.Get(connId)
|
||||||
if !found {
|
if !found {
|
||||||
s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0}
|
s = SessionInfo{ConnId: connId}
|
||||||
} else {
|
} else {
|
||||||
s = x.(RdgSession)
|
s = x.(SessionInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
||||||
|
@ -153,4 +154,4 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
handler.Process()
|
handler.Process()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue