mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-19 06:53:49 +02:00
Rework tunnels to support statistics
This commit is contained in:
parent
eb1b287751
commit
94d7cddc4b
14 changed files with 74 additions and 45 deletions
|
@ -12,6 +12,8 @@ const (
|
|||
ClientIPCtx = "ClientIP"
|
||||
ProxyAddressesCtx = "ProxyAddresses"
|
||||
RemoteAddressCtx = "RemoteAddress"
|
||||
TunnelCtx = "TUNNEL"
|
||||
UsernameCtx = "preferred_username"
|
||||
)
|
||||
|
||||
func EnrichContext(next http.Handler) http.Handler {
|
||||
|
|
|
@ -231,6 +231,8 @@ var gwserver *protocol.Gateway
|
|||
func List(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
for k, v := range protocol.Connections {
|
||||
fmt.Fprintf(w, "RDGId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
|
||||
fmt.Fprintf(w, "Id: %s Rdg-Id: %s User: %s From: %s Connected Since: %s Bytes Sent: %d Bytes Received: %d Last Seen: %s Target: %s\n",
|
||||
k, v.Tunnel.RDGId, v.Tunnel.UserName, v.Tunnel.RemoteAddr, v.Tunnel.ConnectedOn, v.Tunnel.BytesSent, v.Tunnel.BytesReceived,
|
||||
v.Tunnel.LastSeen, v.Tunnel.TargetServer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
|||
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
|
||||
}
|
||||
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
||||
go forward(c.LocalConn, c.Session.TransportOut)
|
||||
//go forward(c.LocalConn, c.Session.TransportOut)
|
||||
case PKT_TYPE_DATA:
|
||||
receive(pkt, c.LocalConn)
|
||||
default:
|
||||
|
|
|
@ -92,7 +92,7 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err
|
|||
}
|
||||
|
||||
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
|
||||
func forward(in net.Conn, out transport.Transport) {
|
||||
func forward(in net.Conn, tunnel *Tunnel) {
|
||||
defer in.Close()
|
||||
|
||||
b1 := new(bytes.Buffer)
|
||||
|
@ -106,7 +106,7 @@ func forward(in net.Conn, out transport.Transport) {
|
|||
}
|
||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||
b1.Write(buf[:n])
|
||||
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
tunnel.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"log"
|
||||
|
@ -16,7 +17,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
rdgConnectionIdKey = "Rdg-Connection-RDGId"
|
||||
rdgConnectionIdKey = "Rdg-Connection-Id"
|
||||
MethodRDGIN = "RDG_IN_DATA"
|
||||
MethodRDGOUT = "RDG_OUT_DATA"
|
||||
)
|
||||
|
@ -59,14 +60,19 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
var t *Tunnel
|
||||
|
||||
ctx := context.WithValue(r.Context(), common.TunnelCtx, t)
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
t = &Tunnel{RDGId: connId}
|
||||
t = &Tunnel{
|
||||
RDGId: connId,
|
||||
RemoteAddr: ctx.Value(common.ClientIPCtx).(string),
|
||||
UserName: ctx.Value(common.UsernameCtx).(string),
|
||||
}
|
||||
} else {
|
||||
t = x.(*Tunnel)
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), "Tunnel", t)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
|
@ -158,13 +164,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
|||
inout, _ := transport.NewWS(c)
|
||||
defer inout.Close()
|
||||
|
||||
t.Id = uuid.New().String()
|
||||
t.TransportOut = inout
|
||||
t.TransportIn = inout
|
||||
t.ConnectedOn = time.Now()
|
||||
|
||||
handler := NewProcessor(g, t)
|
||||
RegisterConnection(handler, t)
|
||||
defer RemoveConnection(t.RDGId)
|
||||
RegisterTunnel(t, handler)
|
||||
defer RemoveTunnel(t)
|
||||
handler.Process(ctx)
|
||||
}
|
||||
|
||||
|
@ -198,6 +205,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
|||
defer in.Close()
|
||||
|
||||
if t.TransportIn == nil {
|
||||
t.Id = uuid.New().String()
|
||||
t.TransportIn = in
|
||||
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||
|
||||
|
@ -209,8 +217,8 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
|||
|
||||
log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
|
||||
handler := NewProcessor(g, t)
|
||||
RegisterConnection(handler, t)
|
||||
defer RemoveConnection(t.RDGId)
|
||||
RegisterTunnel(t, handler)
|
||||
defer RemoveTunnel(t)
|
||||
handler.Process(r.Context())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
|
||||
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
|
||||
if p.state != SERVER_STATE_INITIALIZED {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||
|
@ -77,7 +77,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
_, cookie := p.tunnelRequest(pkt)
|
||||
if p.gw.CheckPAACookie != nil {
|
||||
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
|
||||
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
|
||||
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
p.tunnel.Write(msg)
|
||||
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
|
@ -98,7 +98,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
client := p.tunnelAuthRequest(pkt)
|
||||
if p.gw.CheckClientName != nil {
|
||||
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
|
||||
log.Printf("Invalid client name: %p", client)
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
|
||||
p.tunnel.Write(msg)
|
||||
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
|
||||
|
@ -119,18 +119,18 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
server, port := p.channelRequest(pkt)
|
||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||
if p.gw.CheckHost != nil {
|
||||
log.Printf("Verifying %p host connection", host)
|
||||
log.Printf("Verifying %s host connection", host)
|
||||
if ok, _ := p.gw.CheckHost(ctx, host); !ok {
|
||||
log.Printf("Not allowed to connect to %p by policy handler", host)
|
||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
|
||||
p.tunnel.Write(msg)
|
||||
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
|
||||
}
|
||||
}
|
||||
log.Printf("Establishing connection to RDP server: %p", host)
|
||||
log.Printf("Establishing connection to RDP server: %s", host)
|
||||
p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %p, %p", host, err)
|
||||
log.Printf("Error connecting to %s, %s", host, err)
|
||||
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
||||
p.tunnel.Write(msg)
|
||||
return err
|
||||
|
@ -142,7 +142,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
|
||||
// Make sure to start the flow from the RDP server first otherwise connections
|
||||
// might hang eventually
|
||||
go forward(p.tunnel.rwc, p.tunnel.TransportOut)
|
||||
go forward(p.tunnel.rwc, p.tunnel)
|
||||
p.state = SERVER_STATE_CHANNEL_CREATE
|
||||
case PKT_TYPE_DATA:
|
||||
if p.state < SERVER_STATE_CHANNEL_CREATE {
|
||||
|
|
|
@ -7,19 +7,19 @@ type Monitor struct {
|
|||
Tunnel *Tunnel
|
||||
}
|
||||
|
||||
func RegisterConnection(h *Processor, t *Tunnel) {
|
||||
func RegisterTunnel(t *Tunnel, p *Processor) {
|
||||
if Connections == nil {
|
||||
Connections = make(map[string]*Monitor)
|
||||
}
|
||||
|
||||
Connections[t.RDGId] = &Monitor{
|
||||
Processor: h,
|
||||
Connections[t.Id] = &Monitor{
|
||||
Processor: p,
|
||||
Tunnel: t,
|
||||
}
|
||||
}
|
||||
|
||||
func RemoveConnection(connId string) {
|
||||
delete(Connections, connId)
|
||||
func RemoveTunnel(t *Tunnel) {
|
||||
delete(Connections, t.Id)
|
||||
}
|
||||
|
||||
// CalculateSpeedPerSecond calculate moving average.
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
)
|
||||
|
||||
type Tunnel struct {
|
||||
// Id identifies the connection in the server
|
||||
Id string
|
||||
// The connection-id (RDG-ConnID) as reported by the client
|
||||
RDGId string
|
||||
// The underlying incoming transport being either websocket or legacy http
|
||||
|
@ -26,18 +28,28 @@ type Tunnel struct {
|
|||
// It is of the type *net.TCPConn
|
||||
rwc net.Conn
|
||||
|
||||
ByteSent int64
|
||||
// BytesSent is the total amount of bytes sent by the server to the client minus tunnel overhead
|
||||
BytesSent int64
|
||||
|
||||
// BytesReceived is the total amount of bytes received by the server from the client minus tunnel overhad
|
||||
BytesReceived int64
|
||||
|
||||
// ConnectedOn is when the client connected to the server
|
||||
ConnectedOn time.Time
|
||||
LastSeen time.Time
|
||||
|
||||
// LastSeen is when the server received the last packet from the client
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// Write puts the packet on the transport and updates the statistics for bytes sent
|
||||
func (t *Tunnel) Write(pkt []byte) {
|
||||
n, _ := t.TransportOut.WritePacket(pkt)
|
||||
t.ByteSent += int64(n)
|
||||
t.BytesSent += int64(n)
|
||||
}
|
||||
|
||||
// Read picks up a packet from the transport and returns the packet type
|
||||
// packet, with the header removed, and the packet size. It updates the
|
||||
// statistics for bytes received
|
||||
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
|
||||
pt, size, pkt, err = readMessage(t.TransportIn)
|
||||
t.BytesReceived += int64(size)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
@ -23,10 +24,10 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
|
|||
case "roundrobin", "unsigned":
|
||||
var username string
|
||||
|
||||
s := getSessionInfo(ctx)
|
||||
s := getTunnel(ctx)
|
||||
if s == nil || s.UserName == "" {
|
||||
var ok bool
|
||||
username, ok = ctx.Value("preferred_username").(string)
|
||||
username, ok = ctx.Value(common.UsernameCtx).(string)
|
||||
if !ok {
|
||||
return false, errors.New("no valid session info or username found in context")
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package security
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||
"testing"
|
||||
)
|
||||
|
@ -20,7 +21,7 @@ var (
|
|||
)
|
||||
|
||||
func TestCheckHost(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "Tunnel", &info)
|
||||
ctx := context.WithValue(context.Background(), common.TunnelCtx, &info)
|
||||
|
||||
Hosts = hosts
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ type customClaims struct {
|
|||
|
||||
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
||||
return func(ctx context.Context, host string) (bool, error) {
|
||||
s := getSessionInfo(ctx)
|
||||
s := getTunnel(ctx)
|
||||
if s == nil {
|
||||
return false, errors.New("no valid session info found in context")
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
|||
|
||||
token, err := jwt.ParseSigned(tokenString)
|
||||
if err != nil {
|
||||
log.Printf("cannot parse token due to: %s", err)
|
||||
log.Printf("cannot parse token due to: %tunnel", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
|||
// Claims automagically checks the signature...
|
||||
err = token.Claims(SigningKey, &standard, &custom)
|
||||
if err != nil {
|
||||
log.Printf("token signature validation failed due to %s", err)
|
||||
log.Printf("token signature validation failed due to %tunnel", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
|
@ -90,7 +90,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
|||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("token validation failed due to %s", err)
|
||||
log.Printf("token validation failed due to %tunnel", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
|
@ -98,15 +98,15 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
|||
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
|
||||
user, err := OIDCProvider.UserInfo(ctx, tokenSource)
|
||||
if err != nil {
|
||||
log.Printf("Cannot get user info for access token: %s", err)
|
||||
log.Printf("Cannot get user info for access token: %tunnel", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
s := getSessionInfo(ctx)
|
||||
tunnel := getTunnel(ctx)
|
||||
|
||||
s.TargetServer = custom.RemoteServer
|
||||
s.RemoteAddr = custom.ClientIP
|
||||
s.UserName = user.Subject
|
||||
tunnel.TargetServer = custom.RemoteServer
|
||||
tunnel.RemoteAddr = custom.ClientIP
|
||||
tunnel.UserName = user.Subject
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
@ -288,7 +288,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
|
|||
return token, err
|
||||
}
|
||||
|
||||
func getSessionInfo(ctx context.Context) *protocol.Tunnel {
|
||||
func getTunnel(ctx context.Context) *protocol.Tunnel {
|
||||
s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
|
||||
if !ok {
|
||||
log.Printf("cannot get session info from context")
|
||||
|
|
|
@ -2,6 +2,7 @@ package web
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
@ -12,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
protocol = "unix"
|
||||
protocolGrpc = "unix"
|
||||
)
|
||||
|
||||
type BasicAuthHandler struct {
|
||||
|
@ -27,7 +28,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
|||
|
||||
conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return net.Dial(protocol, addr)
|
||||
return net.Dial(protocolGrpc, addr)
|
||||
}))
|
||||
if err != nil {
|
||||
log.Printf("Cannot reach authentication provider: %s", err)
|
||||
|
@ -51,7 +52,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
|||
if !res.Authenticated {
|
||||
log.Printf("User %s is not authenticated for this service", username)
|
||||
} else {
|
||||
ctx := context.WithValue(r.Context(), "preferred_username", username)
|
||||
ctx := context.WithValue(r.Context(), common.UsernameCtx, username)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/patrickmn/go-cache"
|
||||
|
@ -119,7 +120,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
||||
ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"])
|
||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
|
|
1
go.mod
1
go.mod
|
@ -6,6 +6,7 @@ require (
|
|||
github.com/coreos/go-oidc/v3 v3.2.0
|
||||
github.com/fatih/structs v1.1.0
|
||||
github.com/go-jose/go-jose/v3 v3.0.0
|
||||
github.com/google/uuid v1.1.2
|
||||
github.com/gorilla/sessions v1.2.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/knadh/koanf v1.4.2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue