mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-15 13:23:46 +02:00
Refactor add bit of tracking
This commit is contained in:
parent
8aa7c8cbb7
commit
ce6692d22f
5 changed files with 158 additions and 113 deletions
|
@ -177,7 +177,7 @@ func main() {
|
|||
}
|
||||
|
||||
// create the gateway
|
||||
gwConfig := protocol.ServerConf{
|
||||
gwConfig := protocol.ProcessorConf{
|
||||
IdleTimeout: conf.Caps.IdleTimeout,
|
||||
TokenAuth: conf.Caps.TokenAuth,
|
||||
SmartCardAuth: conf.Caps.SmartCardAuth,
|
||||
|
@ -202,6 +202,7 @@ func main() {
|
|||
gw := protocol.Gateway{
|
||||
ServerConf: &gwConfig,
|
||||
}
|
||||
gwserver = &gw
|
||||
|
||||
if conf.Server.Authentication == config.AuthenticationBasic {
|
||||
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
|
||||
|
@ -215,6 +216,7 @@ func main() {
|
|||
}
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/tokeninfo", web.TokenInfo)
|
||||
http.HandleFunc("/list", List)
|
||||
|
||||
if conf.Server.Tls == config.TlsDisable {
|
||||
err = server.ListenAndServe()
|
||||
|
@ -225,3 +227,12 @@ func main() {
|
|||
log.Fatal("ListenAndServe: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
var gwserver *protocol.Gateway
|
||||
|
||||
func List(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
for k, v := range protocol.Connections {
|
||||
fmt.Fprintf(w, "ConnId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ var (
|
|||
)
|
||||
|
||||
type Gateway struct {
|
||||
ServerConf *ServerConf
|
||||
ServerConf *ProcessorConf
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
@ -164,7 +164,9 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
|||
|
||||
s.TransportOut = inout
|
||||
s.TransportIn = inout
|
||||
handler := NewServer(s, g.ServerConf)
|
||||
handler := NewProcessor(s, g.ServerConf)
|
||||
RegisterConnection(s.ConnId, handler, s)
|
||||
defer CloseConnection(s.ConnId)
|
||||
handler.Process(ctx)
|
||||
}
|
||||
|
||||
|
@ -208,7 +210,9 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
|||
in.Drain()
|
||||
|
||||
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
||||
handler := NewServer(s, g.ServerConf)
|
||||
handler := NewProcessor(s, g.ServerConf)
|
||||
RegisterConnection(s.ConnId, handler, s)
|
||||
defer CloseConnection(s.ConnId)
|
||||
handler.Process(r.Context())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error)
|
|||
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
||||
type VerifyServerFunc func(context.Context, string) (bool, error)
|
||||
|
||||
type Server struct {
|
||||
type Processor struct {
|
||||
Session *SessionInfo
|
||||
VerifyTunnelCreate VerifyTunnelCreate
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
|
@ -32,7 +32,7 @@ type Server struct {
|
|||
State int
|
||||
}
|
||||
|
||||
type ServerConf struct {
|
||||
type ProcessorConf struct {
|
||||
VerifyTunnelCreate VerifyTunnelCreate
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
VerifyServerFunc VerifyServerFunc
|
||||
|
@ -44,8 +44,8 @@ type ServerConf struct {
|
|||
SendBuf int
|
||||
}
|
||||
|
||||
func NewServer(s *SessionInfo, conf *ServerConf) *Server {
|
||||
h := &Server{
|
||||
func NewProcessor(s *SessionInfo, conf *ProcessorConf) *Processor {
|
||||
h := &Processor{
|
||||
State: SERVER_STATE_INITIALIZED,
|
||||
Session: s,
|
||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||
|
@ -61,123 +61,123 @@ func NewServer(s *SessionInfo, conf *ServerConf) *Server {
|
|||
|
||||
const tunnelId = 10
|
||||
|
||||
func (s *Server) Process(ctx context.Context) error {
|
||||
func (p *Processor) Process(ctx context.Context) error {
|
||||
for {
|
||||
pt, sz, pkt, err := readMessage(s.Session.TransportIn)
|
||||
pt, sz, pkt, err := readMessage(p.Session.TransportIn)
|
||||
if err != nil {
|
||||
log.Printf("Cannot read message from stream %s", err)
|
||||
log.Printf("Cannot read message from stream %p", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
|
||||
if s.State != SERVER_STATE_INITIALIZED {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIALIZED)
|
||||
msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
|
||||
if p.State != SERVER_STATE_INITIALIZED {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", p.State, SERVER_STATE_INITIALIZED)
|
||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
|
||||
}
|
||||
major, minor, _, reqAuth := s.handshakeRequest(pkt)
|
||||
caps, err := s.matchAuth(reqAuth)
|
||||
major, minor, _, reqAuth := p.handshakeRequest(pkt)
|
||||
caps, err := p.matchAuth(reqAuth)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return err
|
||||
}
|
||||
msg := s.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_HANDSHAKE
|
||||
msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
p.State = SERVER_STATE_HANDSHAKE
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
log.Printf("Tunnel create")
|
||||
if s.State != SERVER_STATE_HANDSHAKE {
|
||||
if p.State != SERVER_STATE_HANDSHAKE {
|
||||
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_HANDSHAKE)
|
||||
msg := s.tunnelResponse(E_PROXY_INTERNALERROR)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
p.State, SERVER_STATE_HANDSHAKE)
|
||||
msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||
}
|
||||
_, cookie := s.tunnelRequest(pkt)
|
||||
if s.VerifyTunnelCreate != nil {
|
||||
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
|
||||
msg := s.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
_, cookie := p.tunnelRequest(pkt)
|
||||
if p.VerifyTunnelCreate != nil {
|
||||
if ok, _ := p.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
|
||||
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
}
|
||||
}
|
||||
msg := s.tunnelResponse(ERROR_SUCCESS)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_TUNNEL_CREATE
|
||||
msg := p.tunnelResponse(ERROR_SUCCESS)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
p.State = SERVER_STATE_TUNNEL_CREATE
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
log.Printf("Tunnel auth")
|
||||
if s.State != SERVER_STATE_TUNNEL_CREATE {
|
||||
if p.State != SERVER_STATE_TUNNEL_CREATE {
|
||||
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_TUNNEL_CREATE)
|
||||
msg := s.tunnelAuthResponse(E_PROXY_INTERNALERROR)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
p.State, SERVER_STATE_TUNNEL_CREATE)
|
||||
msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||
}
|
||||
client := s.tunnelAuthRequest(pkt)
|
||||
if s.VerifyTunnelAuthFunc != nil {
|
||||
if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok {
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
msg := s.tunnelAuthResponse(ERROR_ACCESS_DENIED)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
client := p.tunnelAuthRequest(pkt)
|
||||
if p.VerifyTunnelAuthFunc != nil {
|
||||
if ok, _ := p.VerifyTunnelAuthFunc(ctx, client); !ok {
|
||||
log.Printf("Invalid client name: %p", client)
|
||||
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
|
||||
}
|
||||
}
|
||||
msg := s.tunnelAuthResponse(ERROR_SUCCESS)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||
msg := p.tunnelAuthResponse(ERROR_SUCCESS)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
p.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
log.Printf("Channel create")
|
||||
if s.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||
if p.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||
log.Printf("Channel create attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||
msg := s.channelResponse(E_PROXY_INTERNALERROR)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
p.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
|
||||
}
|
||||
server, port := s.channelRequest(pkt)
|
||||
server, port := p.channelRequest(pkt)
|
||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||
if s.VerifyServerFunc != nil {
|
||||
log.Printf("Verifying %s host connection", host)
|
||||
if ok, _ := s.VerifyServerFunc(ctx, host); !ok {
|
||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||
msg := s.channelResponse(E_PROXY_RAP_ACCESSDENIED)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
if p.VerifyServerFunc != nil {
|
||||
log.Printf("Verifying %p host connection", host)
|
||||
if ok, _ := p.VerifyServerFunc(ctx, host); !ok {
|
||||
log.Printf("Not allowed to connect to %p by policy handler", host)
|
||||
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
|
||||
}
|
||||
}
|
||||
log.Printf("Establishing connection to RDP server: %s", host)
|
||||
s.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||
log.Printf("Establishing connection to RDP server: %p", host)
|
||||
p.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %s, %s", host, err)
|
||||
msg := s.channelResponse(E_PROXY_INTERNALERROR)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
log.Printf("Error connecting to %p, %p", host, err)
|
||||
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
return err
|
||||
}
|
||||
log.Printf("Connection established")
|
||||
msg := s.channelResponse(ERROR_SUCCESS)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
msg := p.channelResponse(ERROR_SUCCESS)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
|
||||
// Make sure to start the flow from the RDP server first otherwise connections
|
||||
// might hang eventually
|
||||
go forward(s.Remote, s.Session.TransportOut)
|
||||
s.State = SERVER_STATE_CHANNEL_CREATE
|
||||
go forward(p.Remote, p.Session.TransportOut)
|
||||
p.State = SERVER_STATE_CHANNEL_CREATE
|
||||
case PKT_TYPE_DATA:
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
if p.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Data received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
s.State = SERVER_STATE_OPENED
|
||||
receive(pkt, s.Remote)
|
||||
p.State = SERVER_STATE_OPENED
|
||||
receive(pkt, p.Remote)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// keepalives can be received while the channel is not open yet
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
if p.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Keepalive received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
|
||||
|
@ -185,15 +185,15 @@ func (s *Server) Process(ctx context.Context) error {
|
|||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
log.Printf("Close channel")
|
||||
if s.State != SERVER_STATE_OPENED {
|
||||
log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED)
|
||||
if p.State != SERVER_STATE_OPENED {
|
||||
log.Printf("Channel closed while in wrong state %d != %d", p.State, SERVER_STATE_OPENED)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
msg := s.channelCloseResponse(ERROR_SUCCESS)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
//s.Session.TransportIn.Close()
|
||||
//s.Session.TransportOut.Close()
|
||||
s.State = SERVER_STATE_CLOSED
|
||||
msg := p.channelCloseResponse(ERROR_SUCCESS)
|
||||
p.Session.TransportOut.WritePacket(msg)
|
||||
//p.Session.TransportIn.Close()
|
||||
//p.Session.TransportOut.Close()
|
||||
p.State = SERVER_STATE_CLOSED
|
||||
return nil
|
||||
default:
|
||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||
|
@ -204,7 +204,7 @@ func (s *Server) Process(ctx context.Context) error {
|
|||
// Creates a packet the is a response to a handshakeRequest request
|
||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||
// but could be in Windows. However the NTLM protocol is insecure
|
||||
func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
|
||||
func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code
|
||||
buf.Write([]byte{major, minor})
|
||||
|
@ -214,7 +214,7 @@ func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCod
|
|||
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
|
||||
func (p *Processor) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &major)
|
||||
binary.Read(r, binary.LittleEndian, &minor)
|
||||
|
@ -225,11 +225,11 @@ func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
|
||||
if s.SmartCardAuth {
|
||||
func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
|
||||
if p.SmartCardAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_SC
|
||||
}
|
||||
if s.TokenAuth {
|
||||
if p.TokenAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||
}
|
||||
|
||||
|
@ -243,7 +243,7 @@ func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
|
|||
return caps, nil
|
||||
}
|
||||
|
||||
func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
|
||||
func (p *Processor) tunnelRequest(data []byte) (caps uint32, cookie string) {
|
||||
var fields uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
|
@ -262,7 +262,7 @@ func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) tunnelResponse(errorCode int) []byte {
|
||||
func (p *Processor) tunnelResponse(errorCode int) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
|
||||
|
@ -278,7 +278,7 @@ func (s *Server) tunnelResponse(errorCode int) []byte {
|
|||
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) tunnelAuthRequest(data []byte) string {
|
||||
func (p *Processor) tunnelAuthRequest(data []byte) string {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var size uint16
|
||||
|
@ -290,7 +290,7 @@ func (s *Server) tunnelAuthRequest(data []byte) string {
|
|||
return clientName
|
||||
}
|
||||
|
||||
func (s *Server) tunnelAuthResponse(errorCode int) []byte {
|
||||
func (p *Processor) tunnelAuthResponse(errorCode int) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code
|
||||
|
@ -298,17 +298,17 @@ func (s *Server) tunnelAuthResponse(errorCode int) []byte {
|
|||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
// idle timeout
|
||||
if s.IdleTimeout < 0 {
|
||||
s.IdleTimeout = 0
|
||||
if p.IdleTimeout < 0 {
|
||||
p.IdleTimeout = 0
|
||||
}
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags
|
||||
binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes
|
||||
binary.Write(buf, binary.LittleEndian, uint32(p.RedirectFlags)) // redir flags
|
||||
binary.Write(buf, binary.LittleEndian, uint32(p.IdleTimeout)) // timeout in minutes
|
||||
|
||||
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) channelRequest(data []byte) (server string, port uint16) {
|
||||
func (p *Processor) channelRequest(data []byte) (server string, port uint16) {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var resourcesSize byte
|
||||
|
@ -330,7 +330,7 @@ func (s *Server) channelRequest(data []byte) (server string, port uint16) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) channelResponse(errorCode int) []byte {
|
||||
func (p *Processor) channelResponse(errorCode int) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code
|
||||
|
@ -349,7 +349,7 @@ func (s *Server) channelResponse(errorCode int) []byte {
|
|||
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) channelCloseResponse(errorCode int) []byte {
|
||||
func (p *Processor) channelCloseResponse(errorCode int) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code
|
|
@ -14,8 +14,8 @@ const (
|
|||
TunnelCreateResponseLen = HeaderLen + 18
|
||||
TunnelAuthLen = HeaderLen + 2 // + dynamic
|
||||
TunnelAuthResponseLen = HeaderLen + 16
|
||||
ChannelCreateLen = HeaderLen + 8 // + dynamic
|
||||
ChannelResponseLen = HeaderLen + 12
|
||||
ChannelCreateLen = HeaderLen + 8 // + dynamic
|
||||
ChannelResponseLen = HeaderLen + 12
|
||||
)
|
||||
|
||||
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
|
||||
|
@ -41,10 +41,10 @@ func TestHandshake(t *testing.T) {
|
|||
PAAToken: "abab",
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
hc := &ProcessorConf{
|
||||
TokenAuth: true,
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
h := NewProcessor(s, hc)
|
||||
|
||||
data := client.handshakeRequest()
|
||||
|
||||
|
@ -79,7 +79,7 @@ func TestHandshake(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func capsHelper(h Server) uint16 {
|
||||
func capsHelper(h Processor) uint16 {
|
||||
var caps uint16
|
||||
if h.TokenAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||
|
@ -92,12 +92,12 @@ func capsHelper(h Server) uint16 {
|
|||
|
||||
func TestMatchAuth(t *testing.T) {
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
TokenAuth: false,
|
||||
hc := &ProcessorConf{
|
||||
TokenAuth: false,
|
||||
SmartCardAuth: false,
|
||||
}
|
||||
|
||||
h:= NewServer(s, hc)
|
||||
h := NewProcessor(s, hc)
|
||||
|
||||
in := uint16(0)
|
||||
caps, err := h.matchAuth(in)
|
||||
|
@ -136,10 +136,10 @@ func TestTunnelCreation(t *testing.T) {
|
|||
PAAToken: "abab",
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
hc := &ProcessorConf{
|
||||
TokenAuth: true,
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
h := NewProcessor(s, hc)
|
||||
|
||||
data := client.tunnelRequest()
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
||||
|
@ -180,14 +180,14 @@ func TestTunnelAuth(t *testing.T) {
|
|||
Name: name,
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
hc := &ProcessorConf{
|
||||
TokenAuth: true,
|
||||
IdleTimeout: 10,
|
||||
RedirectFlags: RedirectFlags{
|
||||
Clipboard: true,
|
||||
},
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
h := NewProcessor(s, hc)
|
||||
|
||||
data := client.tunnelAuthRequest()
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
|
||||
|
@ -223,17 +223,17 @@ func TestChannelCreation(t *testing.T) {
|
|||
server := "test_server"
|
||||
client := ClientConfig{
|
||||
Server: server,
|
||||
Port: 3389,
|
||||
Port: 3389,
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
hc := &ProcessorConf{
|
||||
TokenAuth: true,
|
||||
IdleTimeout: 10,
|
||||
RedirectFlags: RedirectFlags{
|
||||
Clipboard: true,
|
||||
},
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
h := NewProcessor(s, hc)
|
||||
|
||||
data := client.channelRequest()
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
|
||||
|
|
30
cmd/rdpgw/protocol/track.go
Normal file
30
cmd/rdpgw/protocol/track.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var Connections map[string]*GatewayConnection
|
||||
|
||||
type GatewayConnection struct {
|
||||
PacketHandler *Processor
|
||||
SessionInfo *SessionInfo
|
||||
Since time.Time
|
||||
IsWebsocket bool
|
||||
}
|
||||
|
||||
func RegisterConnection(connId string, h *Processor, s *SessionInfo) {
|
||||
if Connections == nil {
|
||||
Connections = make(map[string]*GatewayConnection)
|
||||
}
|
||||
|
||||
Connections[connId] = &GatewayConnection{
|
||||
PacketHandler: h,
|
||||
SessionInfo: s,
|
||||
Since: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func CloseConnection(connId string) {
|
||||
delete(Connections, connId)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue