Refactor add bit of tracking

This commit is contained in:
Bolke de Bruin 2022-09-22 17:21:16 +02:00
parent 8aa7c8cbb7
commit ce6692d22f
5 changed files with 158 additions and 113 deletions

View file

@ -177,7 +177,7 @@ func main() {
} }
// create the gateway // create the gateway
gwConfig := protocol.ServerConf{ gwConfig := protocol.ProcessorConf{
IdleTimeout: conf.Caps.IdleTimeout, IdleTimeout: conf.Caps.IdleTimeout,
TokenAuth: conf.Caps.TokenAuth, TokenAuth: conf.Caps.TokenAuth,
SmartCardAuth: conf.Caps.SmartCardAuth, SmartCardAuth: conf.Caps.SmartCardAuth,
@ -202,6 +202,7 @@ func main() {
gw := protocol.Gateway{ gw := protocol.Gateway{
ServerConf: &gwConfig, ServerConf: &gwConfig,
} }
gwserver = &gw
if conf.Server.Authentication == config.AuthenticationBasic { if conf.Server.Authentication == config.AuthenticationBasic {
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket} h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
@ -215,6 +216,7 @@ func main() {
} }
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", web.TokenInfo) http.HandleFunc("/tokeninfo", web.TokenInfo)
http.HandleFunc("/list", List)
if conf.Server.Tls == config.TlsDisable { if conf.Server.Tls == config.TlsDisable {
err = server.ListenAndServe() err = server.ListenAndServe()
@ -225,3 +227,12 @@ func main() {
log.Fatal("ListenAndServe: ", err) log.Fatal("ListenAndServe: ", err)
} }
} }
var gwserver *protocol.Gateway
func List(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
for k, v := range protocol.Connections {
fmt.Fprintf(w, "ConnId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
}
}

View file

@ -46,7 +46,7 @@ var (
) )
type Gateway struct { type Gateway struct {
ServerConf *ServerConf ServerConf *ProcessorConf
} }
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
@ -164,7 +164,9 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
s.TransportOut = inout s.TransportOut = inout
s.TransportIn = inout s.TransportIn = inout
handler := NewServer(s, g.ServerConf) handler := NewProcessor(s, g.ServerConf)
RegisterConnection(s.ConnId, handler, s)
defer CloseConnection(s.ConnId)
handler.Process(ctx) handler.Process(ctx)
} }
@ -208,7 +210,9 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
in.Drain() in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
handler := NewServer(s, g.ServerConf) handler := NewProcessor(s, g.ServerConf)
RegisterConnection(s.ConnId, handler, s)
defer CloseConnection(s.ConnId)
handler.Process(r.Context()) handler.Process(r.Context())
} }
} }

View file

@ -18,7 +18,7 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error) type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error) type VerifyServerFunc func(context.Context, string) (bool, error)
type Server struct { type Processor struct {
Session *SessionInfo Session *SessionInfo
VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc
@ -32,7 +32,7 @@ type Server struct {
State int State int
} }
type ServerConf struct { type ProcessorConf struct {
VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc VerifyServerFunc VerifyServerFunc
@ -44,8 +44,8 @@ type ServerConf struct {
SendBuf int SendBuf int
} }
func NewServer(s *SessionInfo, conf *ServerConf) *Server { func NewProcessor(s *SessionInfo, conf *ProcessorConf) *Processor {
h := &Server{ h := &Processor{
State: SERVER_STATE_INITIALIZED, State: SERVER_STATE_INITIALIZED,
Session: s, Session: s,
RedirectFlags: makeRedirectFlags(conf.RedirectFlags), RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
@ -61,123 +61,123 @@ func NewServer(s *SessionInfo, conf *ServerConf) *Server {
const tunnelId = 10 const tunnelId = 10
func (s *Server) Process(ctx context.Context) error { func (p *Processor) Process(ctx context.Context) error {
for { for {
pt, sz, pkt, err := readMessage(s.Session.TransportIn) pt, sz, pkt, err := readMessage(p.Session.TransportIn)
if err != nil { if err != nil {
log.Printf("Cannot read message from stream %s", err) log.Printf("Cannot read message from stream %p", err)
return err return err
} }
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx)) log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
if s.State != SERVER_STATE_INITIALIZED { if p.State != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIALIZED) log.Printf("Handshake attempted while in wrong state %d != %d", p.State, SERVER_STATE_INITIALIZED)
msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
} }
major, minor, _, reqAuth := s.handshakeRequest(pkt) major, minor, _, reqAuth := p.handshakeRequest(pkt)
caps, err := s.matchAuth(reqAuth) caps, err := p.matchAuth(reqAuth)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return err return err
} }
msg := s.handshakeResponse(major, minor, caps, ERROR_SUCCESS) msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_HANDSHAKE p.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE: case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create") log.Printf("Tunnel create")
if s.State != SERVER_STATE_HANDSHAKE { if p.State != SERVER_STATE_HANDSHAKE {
log.Printf("Tunnel create attempted while in wrong state %d != %d", log.Printf("Tunnel create attempted while in wrong state %d != %d",
s.State, SERVER_STATE_HANDSHAKE) p.State, SERVER_STATE_HANDSHAKE)
msg := s.tunnelResponse(E_PROXY_INTERNALERROR) msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
} }
_, cookie := s.tunnelRequest(pkt) _, cookie := p.tunnelRequest(pkt)
if s.VerifyTunnelCreate != nil { if p.VerifyTunnelCreate != nil {
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok { if ok, _ := p.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx)) log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
msg := s.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
} }
} }
msg := s.tunnelResponse(ERROR_SUCCESS) msg := p.tunnelResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_TUNNEL_CREATE p.State = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH: case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth") log.Printf("Tunnel auth")
if s.State != SERVER_STATE_TUNNEL_CREATE { if p.State != SERVER_STATE_TUNNEL_CREATE {
log.Printf("Tunnel auth attempted while in wrong state %d != %d", log.Printf("Tunnel auth attempted while in wrong state %d != %d",
s.State, SERVER_STATE_TUNNEL_CREATE) p.State, SERVER_STATE_TUNNEL_CREATE)
msg := s.tunnelAuthResponse(E_PROXY_INTERNALERROR) msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
} }
client := s.tunnelAuthRequest(pkt) client := p.tunnelAuthRequest(pkt)
if s.VerifyTunnelAuthFunc != nil { if p.VerifyTunnelAuthFunc != nil {
if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok { if ok, _ := p.VerifyTunnelAuthFunc(ctx, client); !ok {
log.Printf("Invalid client name: %s", client) log.Printf("Invalid client name: %p", client)
msg := s.tunnelAuthResponse(ERROR_ACCESS_DENIED) msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
} }
} }
msg := s.tunnelAuthResponse(ERROR_SUCCESS) msg := p.tunnelAuthResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_TUNNEL_AUTHORIZE p.State = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE: case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create") log.Printf("Channel create")
if s.State != SERVER_STATE_TUNNEL_AUTHORIZE { if p.State != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d", log.Printf("Channel create attempted while in wrong state %d != %d",
s.State, SERVER_STATE_TUNNEL_AUTHORIZE) p.State, SERVER_STATE_TUNNEL_AUTHORIZE)
msg := s.channelResponse(E_PROXY_INTERNALERROR) msg := p.channelResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR) return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
} }
server, port := s.channelRequest(pkt) server, port := p.channelRequest(pkt)
host := net.JoinHostPort(server, strconv.Itoa(int(port))) host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if s.VerifyServerFunc != nil { if p.VerifyServerFunc != nil {
log.Printf("Verifying %s host connection", host) log.Printf("Verifying %p host connection", host)
if ok, _ := s.VerifyServerFunc(ctx, host); !ok { if ok, _ := p.VerifyServerFunc(ctx, host); !ok {
log.Printf("Not allowed to connect to %s by policy handler", host) log.Printf("Not allowed to connect to %p by policy handler", host)
msg := s.channelResponse(E_PROXY_RAP_ACCESSDENIED) msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
} }
} }
log.Printf("Establishing connection to RDP server: %s", host) log.Printf("Establishing connection to RDP server: %p", host)
s.Remote, err = net.DialTimeout("tcp", host, time.Second*15) p.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil { if err != nil {
log.Printf("Error connecting to %s, %s", host, err) log.Printf("Error connecting to %p, %p", host, err)
msg := s.channelResponse(E_PROXY_INTERNALERROR) msg := p.channelResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
return err return err
} }
log.Printf("Connection established") log.Printf("Connection established")
msg := s.channelResponse(ERROR_SUCCESS) msg := p.channelResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
go forward(s.Remote, s.Session.TransportOut) go forward(p.Remote, p.Session.TransportOut)
s.State = SERVER_STATE_CHANNEL_CREATE p.State = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
if s.State < SERVER_STATE_CHANNEL_CREATE { if p.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE) log.Printf("Data received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state") return errors.New("wrong state")
} }
s.State = SERVER_STATE_OPENED p.State = SERVER_STATE_OPENED
receive(pkt, s.Remote) receive(pkt, p.Remote)
case PKT_TYPE_KEEPALIVE: case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet // keepalives can be received while the channel is not open yet
if s.State < SERVER_STATE_CHANNEL_CREATE { if p.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE) log.Printf("Keepalive received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state") return errors.New("wrong state")
} }
@ -185,15 +185,15 @@ func (s *Server) Process(ctx context.Context) error {
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
log.Printf("Close channel") log.Printf("Close channel")
if s.State != SERVER_STATE_OPENED { if p.State != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED) log.Printf("Channel closed while in wrong state %d != %d", p.State, SERVER_STATE_OPENED)
return errors.New("wrong state") return errors.New("wrong state")
} }
msg := s.channelCloseResponse(ERROR_SUCCESS) msg := p.channelCloseResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg) p.Session.TransportOut.WritePacket(msg)
//s.Session.TransportIn.Close() //p.Session.TransportIn.Close()
//s.Session.TransportOut.Close() //p.Session.TransportOut.Close()
s.State = SERVER_STATE_CLOSED p.State = SERVER_STATE_CLOSED
return nil return nil
default: default:
log.Printf("Unknown packet (size %d): %x", sz, pkt) log.Printf("Unknown packet (size %d): %x", sz, pkt)
@ -204,7 +204,7 @@ func (s *Server) Process(ctx context.Context) error {
// Creates a packet the is a response to a handshakeRequest request // Creates a packet the is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure // but could be in Windows. However the NTLM protocol is insecure
func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte { func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code
buf.Write([]byte{major, minor}) buf.Write([]byte{major, minor})
@ -214,7 +214,7 @@ func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCod
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
} }
func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) { func (p *Processor) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &major) binary.Read(r, binary.LittleEndian, &major)
binary.Read(r, binary.LittleEndian, &minor) binary.Read(r, binary.LittleEndian, &minor)
@ -225,11 +225,11 @@ func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version
return return
} }
func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) { func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
if s.SmartCardAuth { if p.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_SC caps = caps | HTTP_EXTENDED_AUTH_SC
} }
if s.TokenAuth { if p.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA caps = caps | HTTP_EXTENDED_AUTH_PAA
} }
@ -243,7 +243,7 @@ func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
return caps, nil return caps, nil
} }
func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) { func (p *Processor) tunnelRequest(data []byte) (caps uint32, cookie string) {
var fields uint16 var fields uint16
r := bytes.NewReader(data) r := bytes.NewReader(data)
@ -262,7 +262,7 @@ func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
return return
} }
func (s *Server) tunnelResponse(errorCode int) []byte { func (p *Processor) tunnelResponse(errorCode int) []byte {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
@ -278,7 +278,7 @@ func (s *Server) tunnelResponse(errorCode int) []byte {
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
} }
func (s *Server) tunnelAuthRequest(data []byte) string { func (p *Processor) tunnelAuthRequest(data []byte) string {
buf := bytes.NewReader(data) buf := bytes.NewReader(data)
var size uint16 var size uint16
@ -290,7 +290,7 @@ func (s *Server) tunnelAuthRequest(data []byte) string {
return clientName return clientName
} }
func (s *Server) tunnelAuthResponse(errorCode int) []byte { func (p *Processor) tunnelAuthResponse(errorCode int) []byte {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code
@ -298,17 +298,17 @@ func (s *Server) tunnelAuthResponse(errorCode int) []byte {
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// idle timeout // idle timeout
if s.IdleTimeout < 0 { if p.IdleTimeout < 0 {
s.IdleTimeout = 0 p.IdleTimeout = 0
} }
binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags binary.Write(buf, binary.LittleEndian, uint32(p.RedirectFlags)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes binary.Write(buf, binary.LittleEndian, uint32(p.IdleTimeout)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
} }
func (s *Server) channelRequest(data []byte) (server string, port uint16) { func (p *Processor) channelRequest(data []byte) (server string, port uint16) {
buf := bytes.NewReader(data) buf := bytes.NewReader(data)
var resourcesSize byte var resourcesSize byte
@ -330,7 +330,7 @@ func (s *Server) channelRequest(data []byte) (server string, port uint16) {
return return
} }
func (s *Server) channelResponse(errorCode int) []byte { func (p *Processor) channelResponse(errorCode int) []byte {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code
@ -349,7 +349,7 @@ func (s *Server) channelResponse(errorCode int) []byte {
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
} }
func (s *Server) channelCloseResponse(errorCode int) []byte { func (p *Processor) channelCloseResponse(errorCode int) []byte {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code

View file

@ -41,10 +41,10 @@ func TestHandshake(t *testing.T) {
PAAToken: "abab", PAAToken: "abab",
} }
s := &SessionInfo{} s := &SessionInfo{}
hc := &ServerConf{ hc := &ProcessorConf{
TokenAuth: true, TokenAuth: true,
} }
h := NewServer(s, hc) h := NewProcessor(s, hc)
data := client.handshakeRequest() data := client.handshakeRequest()
@ -79,7 +79,7 @@ func TestHandshake(t *testing.T) {
} }
} }
func capsHelper(h Server) uint16 { func capsHelper(h Processor) uint16 {
var caps uint16 var caps uint16
if h.TokenAuth { if h.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA caps = caps | HTTP_EXTENDED_AUTH_PAA
@ -92,12 +92,12 @@ func capsHelper(h Server) uint16 {
func TestMatchAuth(t *testing.T) { func TestMatchAuth(t *testing.T) {
s := &SessionInfo{} s := &SessionInfo{}
hc := &ServerConf{ hc := &ProcessorConf{
TokenAuth: false, TokenAuth: false,
SmartCardAuth: false, SmartCardAuth: false,
} }
h:= NewServer(s, hc) h := NewProcessor(s, hc)
in := uint16(0) in := uint16(0)
caps, err := h.matchAuth(in) caps, err := h.matchAuth(in)
@ -136,10 +136,10 @@ func TestTunnelCreation(t *testing.T) {
PAAToken: "abab", PAAToken: "abab",
} }
s := &SessionInfo{} s := &SessionInfo{}
hc := &ServerConf{ hc := &ProcessorConf{
TokenAuth: true, TokenAuth: true,
} }
h := NewServer(s, hc) h := NewProcessor(s, hc)
data := client.tunnelRequest() data := client.tunnelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE, _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
@ -180,14 +180,14 @@ func TestTunnelAuth(t *testing.T) {
Name: name, Name: name,
} }
s := &SessionInfo{} s := &SessionInfo{}
hc := &ServerConf{ hc := &ProcessorConf{
TokenAuth: true, TokenAuth: true,
IdleTimeout: 10, IdleTimeout: 10,
RedirectFlags: RedirectFlags{ RedirectFlags: RedirectFlags{
Clipboard: true, Clipboard: true,
}, },
} }
h := NewServer(s, hc) h := NewProcessor(s, hc)
data := client.tunnelAuthRequest() data := client.tunnelAuthRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2)) _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
@ -226,14 +226,14 @@ func TestChannelCreation(t *testing.T) {
Port: 3389, Port: 3389,
} }
s := &SessionInfo{} s := &SessionInfo{}
hc := &ServerConf{ hc := &ProcessorConf{
TokenAuth: true, TokenAuth: true,
IdleTimeout: 10, IdleTimeout: 10,
RedirectFlags: RedirectFlags{ RedirectFlags: RedirectFlags{
Clipboard: true, Clipboard: true,
}, },
} }
h := NewServer(s, hc) h := NewProcessor(s, hc)
data := client.channelRequest() data := client.channelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2)) _, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))

View file

@ -0,0 +1,30 @@
package protocol
import (
"time"
)
var Connections map[string]*GatewayConnection
type GatewayConnection struct {
PacketHandler *Processor
SessionInfo *SessionInfo
Since time.Time
IsWebsocket bool
}
func RegisterConnection(connId string, h *Processor, s *SessionInfo) {
if Connections == nil {
Connections = make(map[string]*GatewayConnection)
}
Connections[connId] = &GatewayConnection{
PacketHandler: h,
SessionInfo: s,
Since: time.Now(),
}
}
func CloseConnection(connId string) {
delete(Connections, connId)
}