Refactor add bit of tracking

This commit is contained in:
Bolke de Bruin 2022-09-22 17:21:16 +02:00
parent 8aa7c8cbb7
commit ce6692d22f
5 changed files with 158 additions and 113 deletions

View file

@ -177,7 +177,7 @@ func main() {
}
// create the gateway
gwConfig := protocol.ServerConf{
gwConfig := protocol.ProcessorConf{
IdleTimeout: conf.Caps.IdleTimeout,
TokenAuth: conf.Caps.TokenAuth,
SmartCardAuth: conf.Caps.SmartCardAuth,
@ -202,6 +202,7 @@ func main() {
gw := protocol.Gateway{
ServerConf: &gwConfig,
}
gwserver = &gw
if conf.Server.Authentication == config.AuthenticationBasic {
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
@ -215,6 +216,7 @@ func main() {
}
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", web.TokenInfo)
http.HandleFunc("/list", List)
if conf.Server.Tls == config.TlsDisable {
err = server.ListenAndServe()
@ -225,3 +227,12 @@ func main() {
log.Fatal("ListenAndServe: ", err)
}
}
var gwserver *protocol.Gateway
func List(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
for k, v := range protocol.Connections {
fmt.Fprintf(w, "ConnId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
}
}

View file

@ -46,7 +46,7 @@ var (
)
type Gateway struct {
ServerConf *ServerConf
ServerConf *ProcessorConf
}
var upgrader = websocket.Upgrader{}
@ -164,7 +164,9 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
s.TransportOut = inout
s.TransportIn = inout
handler := NewServer(s, g.ServerConf)
handler := NewProcessor(s, g.ServerConf)
RegisterConnection(s.ConnId, handler, s)
defer CloseConnection(s.ConnId)
handler.Process(ctx)
}
@ -208,7 +210,9 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
handler := NewServer(s, g.ServerConf)
handler := NewProcessor(s, g.ServerConf)
RegisterConnection(s.ConnId, handler, s)
defer CloseConnection(s.ConnId)
handler.Process(r.Context())
}
}

View file

@ -18,7 +18,7 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error)
type Server struct {
type Processor struct {
Session *SessionInfo
VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
@ -32,7 +32,7 @@ type Server struct {
State int
}
type ServerConf struct {
type ProcessorConf struct {
VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
@ -44,8 +44,8 @@ type ServerConf struct {
SendBuf int
}
func NewServer(s *SessionInfo, conf *ServerConf) *Server {
h := &Server{
func NewProcessor(s *SessionInfo, conf *ProcessorConf) *Processor {
h := &Processor{
State: SERVER_STATE_INITIALIZED,
Session: s,
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
@ -61,123 +61,123 @@ func NewServer(s *SessionInfo, conf *ServerConf) *Server {
const tunnelId = 10
func (s *Server) Process(ctx context.Context) error {
func (p *Processor) Process(ctx context.Context) error {
for {
pt, sz, pkt, err := readMessage(s.Session.TransportIn)
pt, sz, pkt, err := readMessage(p.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
log.Printf("Cannot read message from stream %p", err)
return err
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
if s.State != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIALIZED)
msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg)
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
if p.State != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", p.State, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: wrong state", E_PROXY_INTERNALERROR)
}
major, minor, _, reqAuth := s.handshakeRequest(pkt)
caps, err := s.matchAuth(reqAuth)
major, minor, _, reqAuth := p.handshakeRequest(pkt)
caps, err := p.matchAuth(reqAuth)
if err != nil {
log.Println(err)
msg := s.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
s.Session.TransportOut.WritePacket(msg)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_CAPABILITYMISMATCH)
p.Session.TransportOut.WritePacket(msg)
return err
}
msg := s.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_HANDSHAKE
msg := p.handshakeResponse(major, minor, caps, ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg)
p.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create")
if s.State != SERVER_STATE_HANDSHAKE {
if p.State != SERVER_STATE_HANDSHAKE {
log.Printf("Tunnel create attempted while in wrong state %d != %d",
s.State, SERVER_STATE_HANDSHAKE)
msg := s.tunnelResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg)
p.State, SERVER_STATE_HANDSHAKE)
msg := p.tunnelResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: PAA cookie rejected, wrong state", E_PROXY_INTERNALERROR)
}
_, cookie := s.tunnelRequest(pkt)
if s.VerifyTunnelCreate != nil {
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
msg := s.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
s.Session.TransportOut.WritePacket(msg)
_, cookie := p.tunnelRequest(pkt)
if p.VerifyTunnelCreate != nil {
if ok, _ := p.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
}
}
msg := s.tunnelResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_TUNNEL_CREATE
msg := p.tunnelResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg)
p.State = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth")
if s.State != SERVER_STATE_TUNNEL_CREATE {
if p.State != SERVER_STATE_TUNNEL_CREATE {
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
s.State, SERVER_STATE_TUNNEL_CREATE)
msg := s.tunnelAuthResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg)
p.State, SERVER_STATE_TUNNEL_CREATE)
msg := p.tunnelAuthResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: Tunnel auth rejected, wrong state", E_PROXY_INTERNALERROR)
}
client := s.tunnelAuthRequest(pkt)
if s.VerifyTunnelAuthFunc != nil {
if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok {
log.Printf("Invalid client name: %s", client)
msg := s.tunnelAuthResponse(ERROR_ACCESS_DENIED)
s.Session.TransportOut.WritePacket(msg)
client := p.tunnelAuthRequest(pkt)
if p.VerifyTunnelAuthFunc != nil {
if ok, _ := p.VerifyTunnelAuthFunc(ctx, client); !ok {
log.Printf("Invalid client name: %p", client)
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
}
}
msg := s.tunnelAuthResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_TUNNEL_AUTHORIZE
msg := p.tunnelAuthResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg)
p.State = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create")
if s.State != SERVER_STATE_TUNNEL_AUTHORIZE {
if p.State != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d",
s.State, SERVER_STATE_TUNNEL_AUTHORIZE)
msg := s.channelResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg)
p.State, SERVER_STATE_TUNNEL_AUTHORIZE)
msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: Channel create rejected, wrong state", E_PROXY_INTERNALERROR)
}
server, port := s.channelRequest(pkt)
server, port := p.channelRequest(pkt)
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if s.VerifyServerFunc != nil {
log.Printf("Verifying %s host connection", host)
if ok, _ := s.VerifyServerFunc(ctx, host); !ok {
log.Printf("Not allowed to connect to %s by policy handler", host)
msg := s.channelResponse(E_PROXY_RAP_ACCESSDENIED)
s.Session.TransportOut.WritePacket(msg)
if p.VerifyServerFunc != nil {
log.Printf("Verifying %p host connection", host)
if ok, _ := p.VerifyServerFunc(ctx, host); !ok {
log.Printf("Not allowed to connect to %p by policy handler", host)
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
p.Session.TransportOut.WritePacket(msg)
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
}
}
log.Printf("Establishing connection to RDP server: %s", host)
s.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
log.Printf("Establishing connection to RDP server: %p", host)
p.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil {
log.Printf("Error connecting to %s, %s", host, err)
msg := s.channelResponse(E_PROXY_INTERNALERROR)
s.Session.TransportOut.WritePacket(msg)
log.Printf("Error connecting to %p, %p", host, err)
msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.Session.TransportOut.WritePacket(msg)
return err
}
log.Printf("Connection established")
msg := s.channelResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg)
msg := p.channelResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg)
// Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually
go forward(s.Remote, s.Session.TransportOut)
s.State = SERVER_STATE_CHANNEL_CREATE
go forward(p.Remote, p.Session.TransportOut)
p.State = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA:
if s.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
if p.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
s.State = SERVER_STATE_OPENED
receive(pkt, s.Remote)
p.State = SERVER_STATE_OPENED
receive(pkt, p.Remote)
case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet
if s.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
if p.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", p.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
@ -185,15 +185,15 @@ func (s *Server) Process(ctx context.Context) error {
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
log.Printf("Close channel")
if s.State != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED)
if p.State != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", p.State, SERVER_STATE_OPENED)
return errors.New("wrong state")
}
msg := s.channelCloseResponse(ERROR_SUCCESS)
s.Session.TransportOut.WritePacket(msg)
//s.Session.TransportIn.Close()
//s.Session.TransportOut.Close()
s.State = SERVER_STATE_CLOSED
msg := p.channelCloseResponse(ERROR_SUCCESS)
p.Session.TransportOut.WritePacket(msg)
//p.Session.TransportIn.Close()
//p.Session.TransportOut.Close()
p.State = SERVER_STATE_CLOSED
return nil
default:
log.Printf("Unknown packet (size %d): %x", sz, pkt)
@ -204,7 +204,7 @@ func (s *Server) Process(ctx context.Context) error {
// Creates a packet the is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure
func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code
buf.Write([]byte{major, minor})
@ -214,7 +214,7 @@ func (s *Server) handshakeResponse(major byte, minor byte, caps uint16, errorCod
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
}
func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
func (p *Processor) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &major)
binary.Read(r, binary.LittleEndian, &minor)
@ -225,11 +225,11 @@ func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version
return
}
func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
if s.SmartCardAuth {
func (p *Processor) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
if p.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_SC
}
if s.TokenAuth {
if p.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA
}
@ -243,7 +243,7 @@ func (s *Server) matchAuth(clientAuthCaps uint16) (caps uint16, err error) {
return caps, nil
}
func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
func (p *Processor) tunnelRequest(data []byte) (caps uint32, cookie string) {
var fields uint16
r := bytes.NewReader(data)
@ -262,7 +262,7 @@ func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
return
}
func (s *Server) tunnelResponse(errorCode int) []byte {
func (p *Processor) tunnelResponse(errorCode int) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
@ -278,7 +278,7 @@ func (s *Server) tunnelResponse(errorCode int) []byte {
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
}
func (s *Server) tunnelAuthRequest(data []byte) string {
func (p *Processor) tunnelAuthRequest(data []byte) string {
buf := bytes.NewReader(data)
var size uint16
@ -290,7 +290,7 @@ func (s *Server) tunnelAuthRequest(data []byte) string {
return clientName
}
func (s *Server) tunnelAuthResponse(errorCode int) []byte {
func (p *Processor) tunnelAuthResponse(errorCode int) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code
@ -298,17 +298,17 @@ func (s *Server) tunnelAuthResponse(errorCode int) []byte {
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// idle timeout
if s.IdleTimeout < 0 {
s.IdleTimeout = 0
if p.IdleTimeout < 0 {
p.IdleTimeout = 0
}
binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes
binary.Write(buf, binary.LittleEndian, uint32(p.RedirectFlags)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(p.IdleTimeout)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
}
func (s *Server) channelRequest(data []byte) (server string, port uint16) {
func (p *Processor) channelRequest(data []byte) (server string, port uint16) {
buf := bytes.NewReader(data)
var resourcesSize byte
@ -330,7 +330,7 @@ func (s *Server) channelRequest(data []byte) (server string, port uint16) {
return
}
func (s *Server) channelResponse(errorCode int) []byte {
func (p *Processor) channelResponse(errorCode int) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code
@ -349,7 +349,7 @@ func (s *Server) channelResponse(errorCode int) []byte {
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
}
func (s *Server) channelCloseResponse(errorCode int) []byte {
func (p *Processor) channelCloseResponse(errorCode int) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error code

View file

@ -14,8 +14,8 @@ const (
TunnelCreateResponseLen = HeaderLen + 18
TunnelAuthLen = HeaderLen + 2 // + dynamic
TunnelAuthResponseLen = HeaderLen + 16
ChannelCreateLen = HeaderLen + 8 // + dynamic
ChannelResponseLen = HeaderLen + 12
ChannelCreateLen = HeaderLen + 8 // + dynamic
ChannelResponseLen = HeaderLen + 12
)
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
@ -41,10 +41,10 @@ func TestHandshake(t *testing.T) {
PAAToken: "abab",
}
s := &SessionInfo{}
hc := &ServerConf{
hc := &ProcessorConf{
TokenAuth: true,
}
h := NewServer(s, hc)
h := NewProcessor(s, hc)
data := client.handshakeRequest()
@ -79,7 +79,7 @@ func TestHandshake(t *testing.T) {
}
}
func capsHelper(h Server) uint16 {
func capsHelper(h Processor) uint16 {
var caps uint16
if h.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA
@ -92,12 +92,12 @@ func capsHelper(h Server) uint16 {
func TestMatchAuth(t *testing.T) {
s := &SessionInfo{}
hc := &ServerConf{
TokenAuth: false,
hc := &ProcessorConf{
TokenAuth: false,
SmartCardAuth: false,
}
h:= NewServer(s, hc)
h := NewProcessor(s, hc)
in := uint16(0)
caps, err := h.matchAuth(in)
@ -136,10 +136,10 @@ func TestTunnelCreation(t *testing.T) {
PAAToken: "abab",
}
s := &SessionInfo{}
hc := &ServerConf{
hc := &ProcessorConf{
TokenAuth: true,
}
h := NewServer(s, hc)
h := NewProcessor(s, hc)
data := client.tunnelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
@ -180,14 +180,14 @@ func TestTunnelAuth(t *testing.T) {
Name: name,
}
s := &SessionInfo{}
hc := &ServerConf{
hc := &ProcessorConf{
TokenAuth: true,
IdleTimeout: 10,
RedirectFlags: RedirectFlags{
Clipboard: true,
},
}
h := NewServer(s, hc)
h := NewProcessor(s, hc)
data := client.tunnelAuthRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
@ -223,17 +223,17 @@ func TestChannelCreation(t *testing.T) {
server := "test_server"
client := ClientConfig{
Server: server,
Port: 3389,
Port: 3389,
}
s := &SessionInfo{}
hc := &ServerConf{
hc := &ProcessorConf{
TokenAuth: true,
IdleTimeout: 10,
RedirectFlags: RedirectFlags{
Clipboard: true,
},
}
h := NewServer(s, hc)
h := NewProcessor(s, hc)
data := client.channelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))

View file

@ -0,0 +1,30 @@
package protocol
import (
"time"
)
var Connections map[string]*GatewayConnection
type GatewayConnection struct {
PacketHandler *Processor
SessionInfo *SessionInfo
Since time.Time
IsWebsocket bool
}
func RegisterConnection(connId string, h *Processor, s *SessionInfo) {
if Connections == nil {
Connections = make(map[string]*GatewayConnection)
}
Connections[connId] = &GatewayConnection{
PacketHandler: h,
SessionInfo: s,
Since: time.Now(),
}
}
func CloseConnection(connId string) {
delete(Connections, connId)
}