Rework tunnels to support statistics

This commit is contained in:
Bolke de Bruin 2022-09-24 13:21:01 +02:00
parent eb1b287751
commit 94d7cddc4b
14 changed files with 74 additions and 45 deletions

View file

@ -12,6 +12,8 @@ const (
ClientIPCtx = "ClientIP" ClientIPCtx = "ClientIP"
ProxyAddressesCtx = "ProxyAddresses" ProxyAddressesCtx = "ProxyAddresses"
RemoteAddressCtx = "RemoteAddress" RemoteAddressCtx = "RemoteAddress"
TunnelCtx = "TUNNEL"
UsernameCtx = "preferred_username"
) )
func EnrichContext(next http.Handler) http.Handler { func EnrichContext(next http.Handler) http.Handler {

View file

@ -231,6 +231,8 @@ var gwserver *protocol.Gateway
func List(w http.ResponseWriter, r *http.Request) { func List(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
for k, v := range protocol.Connections { for k, v := range protocol.Connections {
fmt.Fprintf(w, "RDGId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName) fmt.Fprintf(w, "Id: %s Rdg-Id: %s User: %s From: %s Connected Since: %s Bytes Sent: %d Bytes Received: %d Last Seen: %s Target: %s\n",
k, v.Tunnel.RDGId, v.Tunnel.UserName, v.Tunnel.RemoteAddr, v.Tunnel.ConnectedOn, v.Tunnel.BytesSent, v.Tunnel.BytesReceived,
v.Tunnel.LastSeen, v.Tunnel.TargetServer)
} }
} }

View file

@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
} }
log.Printf("Channel creation succesful. Channel id: %d", cid) log.Printf("Channel creation succesful. Channel id: %d", cid)
go forward(c.LocalConn, c.Session.TransportOut) //go forward(c.LocalConn, c.Session.TransportOut)
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
receive(pkt, c.LocalConn) receive(pkt, c.LocalConn)
default: default:

View file

@ -92,7 +92,7 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err
} }
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol // forwards data from a Connection to Transport and wraps it in the rdpgw protocol
func forward(in net.Conn, out transport.Transport) { func forward(in net.Conn, tunnel *Tunnel) {
defer in.Close() defer in.Close()
b1 := new(bytes.Buffer) b1 := new(bytes.Buffer)
@ -106,7 +106,7 @@ func forward(in net.Conn, out transport.Transport) {
} }
binary.Write(b1, binary.LittleEndian, uint16(n)) binary.Write(b1, binary.LittleEndian, uint16(n))
b1.Write(buf[:n]) b1.Write(buf[:n])
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) tunnel.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset() b1.Reset()
} }
} }

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/google/uuid"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"log" "log"
@ -16,7 +17,7 @@ import (
) )
const ( const (
rdgConnectionIdKey = "Rdg-Connection-RDGId" rdgConnectionIdKey = "Rdg-Connection-Id"
MethodRDGIN = "RDG_IN_DATA" MethodRDGIN = "RDG_IN_DATA"
MethodRDGOUT = "RDG_OUT_DATA" MethodRDGOUT = "RDG_OUT_DATA"
) )
@ -59,14 +60,19 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
var t *Tunnel var t *Tunnel
ctx := context.WithValue(r.Context(), common.TunnelCtx, t)
connId := r.Header.Get(rdgConnectionIdKey) connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId) x, found := c.Get(connId)
if !found { if !found {
t = &Tunnel{RDGId: connId} t = &Tunnel{
RDGId: connId,
RemoteAddr: ctx.Value(common.ClientIPCtx).(string),
UserName: ctx.Value(common.UsernameCtx).(string),
}
} else { } else {
t = x.(*Tunnel) t = x.(*Tunnel)
} }
ctx := context.WithValue(r.Context(), "Tunnel", t)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
@ -158,13 +164,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
inout, _ := transport.NewWS(c) inout, _ := transport.NewWS(c)
defer inout.Close() defer inout.Close()
t.Id = uuid.New().String()
t.TransportOut = inout t.TransportOut = inout
t.TransportIn = inout t.TransportIn = inout
t.ConnectedOn = time.Now() t.ConnectedOn = time.Now()
handler := NewProcessor(g, t) handler := NewProcessor(g, t)
RegisterConnection(handler, t) RegisterTunnel(t, handler)
defer RemoveConnection(t.RDGId) defer RemoveTunnel(t)
handler.Process(ctx) handler.Process(ctx)
} }
@ -198,6 +205,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
defer in.Close() defer in.Close()
if t.TransportIn == nil { if t.TransportIn == nil {
t.Id = uuid.New().String()
t.TransportIn = in t.TransportIn = in
c.Set(t.RDGId, t, cache.DefaultExpiration) c.Set(t.RDGId, t, cache.DefaultExpiration)
@ -209,8 +217,8 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context())) log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
handler := NewProcessor(g, t) handler := NewProcessor(g, t)
RegisterConnection(handler, t) RegisterTunnel(t, handler)
defer RemoveConnection(t.RDGId) defer RemoveTunnel(t)
handler.Process(r.Context()) handler.Process(r.Context())
} }
} }

View file

@ -47,7 +47,7 @@ func (p *Processor) Process(ctx context.Context) error {
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx)) log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
if p.state != SERVER_STATE_INITIALIZED { if p.state != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
@ -77,7 +77,7 @@ func (p *Processor) Process(ctx context.Context) error {
_, cookie := p.tunnelRequest(pkt) _, cookie := p.tunnelRequest(pkt)
if p.gw.CheckPAACookie != nil { if p.gw.CheckPAACookie != nil {
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx)) log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
@ -98,7 +98,7 @@ func (p *Processor) Process(ctx context.Context) error {
client := p.tunnelAuthRequest(pkt) client := p.tunnelAuthRequest(pkt)
if p.gw.CheckClientName != nil { if p.gw.CheckClientName != nil {
if ok, _ := p.gw.CheckClientName(ctx, client); !ok { if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
log.Printf("Invalid client name: %p", client) log.Printf("Invalid client name: %s", client)
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED) msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED) return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
@ -119,18 +119,18 @@ func (p *Processor) Process(ctx context.Context) error {
server, port := p.channelRequest(pkt) server, port := p.channelRequest(pkt)
host := net.JoinHostPort(server, strconv.Itoa(int(port))) host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if p.gw.CheckHost != nil { if p.gw.CheckHost != nil {
log.Printf("Verifying %p host connection", host) log.Printf("Verifying %s host connection", host)
if ok, _ := p.gw.CheckHost(ctx, host); !ok { if ok, _ := p.gw.CheckHost(ctx, host); !ok {
log.Printf("Not allowed to connect to %p by policy handler", host) log.Printf("Not allowed to connect to %s by policy handler", host)
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED) msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED) return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
} }
} }
log.Printf("Establishing connection to RDP server: %p", host) log.Printf("Establishing connection to RDP server: %s", host)
p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15) p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil { if err != nil {
log.Printf("Error connecting to %p, %p", host, err) log.Printf("Error connecting to %s, %s", host, err)
msg := p.channelResponse(E_PROXY_INTERNALERROR) msg := p.channelResponse(E_PROXY_INTERNALERROR)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return err return err
@ -142,7 +142,7 @@ func (p *Processor) Process(ctx context.Context) error {
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
go forward(p.tunnel.rwc, p.tunnel.TransportOut) go forward(p.tunnel.rwc, p.tunnel)
p.state = SERVER_STATE_CHANNEL_CREATE p.state = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
if p.state < SERVER_STATE_CHANNEL_CREATE { if p.state < SERVER_STATE_CHANNEL_CREATE {

View file

@ -7,19 +7,19 @@ type Monitor struct {
Tunnel *Tunnel Tunnel *Tunnel
} }
func RegisterConnection(h *Processor, t *Tunnel) { func RegisterTunnel(t *Tunnel, p *Processor) {
if Connections == nil { if Connections == nil {
Connections = make(map[string]*Monitor) Connections = make(map[string]*Monitor)
} }
Connections[t.RDGId] = &Monitor{ Connections[t.Id] = &Monitor{
Processor: h, Processor: p,
Tunnel: t, Tunnel: t,
} }
} }
func RemoveConnection(connId string) { func RemoveTunnel(t *Tunnel) {
delete(Connections, connId) delete(Connections, t.Id)
} }
// CalculateSpeedPerSecond calculate moving average. // CalculateSpeedPerSecond calculate moving average.

View file

@ -7,6 +7,8 @@ import (
) )
type Tunnel struct { type Tunnel struct {
// Id identifies the connection in the server
Id string
// The connection-id (RDG-ConnID) as reported by the client // The connection-id (RDG-ConnID) as reported by the client
RDGId string RDGId string
// The underlying incoming transport being either websocket or legacy http // The underlying incoming transport being either websocket or legacy http
@ -26,18 +28,28 @@ type Tunnel struct {
// It is of the type *net.TCPConn // It is of the type *net.TCPConn
rwc net.Conn rwc net.Conn
ByteSent int64 // BytesSent is the total amount of bytes sent by the server to the client minus tunnel overhead
BytesSent int64
// BytesReceived is the total amount of bytes received by the server from the client minus tunnel overhad
BytesReceived int64 BytesReceived int64
// ConnectedOn is when the client connected to the server
ConnectedOn time.Time ConnectedOn time.Time
// LastSeen is when the server received the last packet from the client
LastSeen time.Time LastSeen time.Time
} }
// Write puts the packet on the transport and updates the statistics for bytes sent
func (t *Tunnel) Write(pkt []byte) { func (t *Tunnel) Write(pkt []byte) {
n, _ := t.TransportOut.WritePacket(pkt) n, _ := t.TransportOut.WritePacket(pkt)
t.ByteSent += int64(n) t.BytesSent += int64(n)
} }
// Read picks up a packet from the transport and returns the packet type
// packet, with the header removed, and the packet size. It updates the
// statistics for bytes received
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) { func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
pt, size, pkt, err = readMessage(t.TransportIn) pt, size, pkt, err = readMessage(t.TransportIn)
t.BytesReceived += int64(size) t.BytesReceived += int64(size)

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"log" "log"
"strings" "strings"
) )
@ -23,10 +24,10 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
case "roundrobin", "unsigned": case "roundrobin", "unsigned":
var username string var username string
s := getSessionInfo(ctx) s := getTunnel(ctx)
if s == nil || s.UserName == "" { if s == nil || s.UserName == "" {
var ok bool var ok bool
username, ok = ctx.Value("preferred_username").(string) username, ok = ctx.Value(common.UsernameCtx).(string)
if !ok { if !ok {
return false, errors.New("no valid session info or username found in context") return false, errors.New("no valid session info or username found in context")
} }

View file

@ -2,6 +2,7 @@ package security
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"testing" "testing"
) )
@ -20,7 +21,7 @@ var (
) )
func TestCheckHost(t *testing.T) { func TestCheckHost(t *testing.T) {
ctx := context.WithValue(context.Background(), "Tunnel", &info) ctx := context.WithValue(context.Background(), common.TunnelCtx, &info)
Hosts = hosts Hosts = hosts

View file

@ -35,7 +35,7 @@ type customClaims struct {
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc { func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
return func(ctx context.Context, host string) (bool, error) { return func(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx) s := getTunnel(ctx)
if s == nil { if s == nil {
return false, errors.New("no valid session info found in context") return false, errors.New("no valid session info found in context")
} }
@ -62,7 +62,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
token, err := jwt.ParseSigned(tokenString) token, err := jwt.ParseSigned(tokenString)
if err != nil { if err != nil {
log.Printf("cannot parse token due to: %s", err) log.Printf("cannot parse token due to: %tunnel", err)
return false, err return false, err
} }
@ -79,7 +79,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
// Claims automagically checks the signature... // Claims automagically checks the signature...
err = token.Claims(SigningKey, &standard, &custom) err = token.Claims(SigningKey, &standard, &custom)
if err != nil { if err != nil {
log.Printf("token signature validation failed due to %s", err) log.Printf("token signature validation failed due to %tunnel", err)
return false, err return false, err
} }
@ -90,7 +90,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
}) })
if err != nil { if err != nil {
log.Printf("token validation failed due to %s", err) log.Printf("token validation failed due to %tunnel", err)
return false, err return false, err
} }
@ -98,15 +98,15 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken}) tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
user, err := OIDCProvider.UserInfo(ctx, tokenSource) user, err := OIDCProvider.UserInfo(ctx, tokenSource)
if err != nil { if err != nil {
log.Printf("Cannot get user info for access token: %s", err) log.Printf("Cannot get user info for access token: %tunnel", err)
return false, err return false, err
} }
s := getSessionInfo(ctx) tunnel := getTunnel(ctx)
s.TargetServer = custom.RemoteServer tunnel.TargetServer = custom.RemoteServer
s.RemoteAddr = custom.ClientIP tunnel.RemoteAddr = custom.ClientIP
s.UserName = user.Subject tunnel.UserName = user.Subject
return true, nil return true, nil
} }
@ -288,7 +288,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
return token, err return token, err
} }
func getSessionInfo(ctx context.Context) *protocol.Tunnel { func getTunnel(ctx context.Context) *protocol.Tunnel {
s, ok := ctx.Value("Tunnel").(*protocol.Tunnel) s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
if !ok { if !ok {
log.Printf("cannot get session info from context") log.Printf("cannot get session info from context")

View file

@ -2,6 +2,7 @@ package web
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/shared/auth" "github.com/bolkedebruin/rdpgw/shared/auth"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@ -12,7 +13,7 @@ import (
) )
const ( const (
protocol = "unix" protocolGrpc = "unix"
) )
type BasicAuthHandler struct { type BasicAuthHandler struct {
@ -27,7 +28,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return net.Dial(protocol, addr) return net.Dial(protocolGrpc, addr)
})) }))
if err != nil { if err != nil {
log.Printf("Cannot reach authentication provider: %s", err) log.Printf("Cannot reach authentication provider: %s", err)
@ -51,7 +52,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
if !res.Authenticated { if !res.Authenticated {
log.Printf("User %s is not authenticated for this service", username) log.Printf("User %s is not authenticated for this service", username)
} else { } else {
ctx := context.WithValue(r.Context(), "preferred_username", username) ctx := context.WithValue(r.Context(), common.UsernameCtx, username)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
return return
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
@ -119,7 +120,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return return
} }
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"]) ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"])
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"]) ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))

1
go.mod
View file

@ -6,6 +6,7 @@ require (
github.com/coreos/go-oidc/v3 v3.2.0 github.com/coreos/go-oidc/v3 v3.2.0
github.com/fatih/structs v1.1.0 github.com/fatih/structs v1.1.0
github.com/go-jose/go-jose/v3 v3.0.0 github.com/go-jose/go-jose/v3 v3.0.0
github.com/google/uuid v1.1.2
github.com/gorilla/sessions v1.2.1 github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/knadh/koanf v1.4.2 github.com/knadh/koanf v1.4.2