mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-27 12:56:09 +02:00
Some improvements
This commit is contained in:
parent
6fd7b047cd
commit
354355e6b0
3 changed files with 25 additions and 23 deletions
|
@ -47,6 +47,7 @@ func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
|||
"gatewayhostname:s:" + conf.Server.GatewayAddress +"\r\n"+
|
||||
"gatewaycredentialssource:i:5\r\n"+
|
||||
"gatewayusagemethod:i:1\r\n"+
|
||||
"gatewayprofileusagemethod:i:1\r\n"+
|
||||
"gatewayaccesstoken:s:" + cookie.Value + "\r\n"))
|
||||
}
|
||||
|
||||
|
|
|
@ -82,9 +82,11 @@ func (h *Handler) Process() error {
|
|||
return errors.New("wrong state")
|
||||
}
|
||||
client := h.readTunnelAuthRequest(pkt)
|
||||
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
return errors.New("invalid client name")
|
||||
if h.VerifyTunnelAuthFunc != nil {
|
||||
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
return errors.New("invalid client name")
|
||||
}
|
||||
}
|
||||
msg := h.createTunnelAuthResponse()
|
||||
h.TransportOut.WritePacket(msg)
|
||||
|
@ -117,7 +119,7 @@ func (h *Handler) Process() error {
|
|||
go h.sendDataPacket()
|
||||
h.State = SERVER_STATE_CHANNEL_CREATE
|
||||
case PKT_TYPE_DATA:
|
||||
if h.State != SERVER_STATE_CHANNEL_CREATE {
|
||||
if h.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
|
@ -342,10 +344,8 @@ func readChannelCreateRequest(data []byte) (server string, port uint16) {
|
|||
nameData := make([]byte, nameSize)
|
||||
binary.Read(buf, binary.LittleEndian, &nameData)
|
||||
|
||||
log.Printf("Name data %q", nameData)
|
||||
server, _ = DecodeUTF16(nameData)
|
||||
|
||||
log.Printf("Should connect to %s on port %d", server, port)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -7,15 +7,14 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
rdgConnectionIdKey = "Rdg-Connection-Id"
|
||||
MethodRDGIN = "RDG_IN_DATA"
|
||||
MethodRDGOUT = "RDG_OUT_DATA"
|
||||
MethodRDGIN = "RDG_IN_DATA"
|
||||
MethodRDGOUT = "RDG_OUT_DATA"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -47,18 +46,17 @@ type HandshakeHeader interface {
|
|||
io.WriterTo
|
||||
}
|
||||
|
||||
type RdgSession struct {
|
||||
ConnId string
|
||||
CorrelationId string
|
||||
UserId string
|
||||
TransportIn transport.Transport
|
||||
TransportOut transport.Transport
|
||||
StateIn int
|
||||
StateOut int
|
||||
Remote net.Conn
|
||||
type SessionInfo struct {
|
||||
ConnId string
|
||||
CorrelationId string
|
||||
ClientGeneration string
|
||||
TransportIn transport.Transport
|
||||
TransportOut transport.Transport
|
||||
RemoteAddress string
|
||||
ProxyAddresses string
|
||||
}
|
||||
|
||||
var DefaultSession RdgSession
|
||||
var DefaultSession SessionInfo
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
@ -72,6 +70,9 @@ func init() {
|
|||
func HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
if r.Method == MethodRDGOUT {
|
||||
for name, value := range r.Header {
|
||||
log.Printf("Header Name: %s Value: %s", name, value)
|
||||
}
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
handleLegacyProtocol(w, r)
|
||||
return
|
||||
|
@ -103,14 +104,14 @@ func handleWebsocketProtocol(c *websocket.Conn) {
|
|||
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||
func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
var s RdgSession
|
||||
var s SessionInfo
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0}
|
||||
s = SessionInfo{ConnId: connId}
|
||||
} else {
|
||||
s = x.(RdgSession)
|
||||
s = x.(SessionInfo)
|
||||
}
|
||||
|
||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue