Refactor identity framework to be more robust

This commit is contained in:
Bolke de Bruin 2022-10-13 11:13:24 +02:00
parent bbd0735289
commit b42c3cd3cc
11 changed files with 245 additions and 92 deletions

View file

@ -0,0 +1,160 @@
package common
import (
"context"
"github.com/google/uuid"
"net/http"
"time"
)
const (
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
AttrRemoteAddr = "remoteAddr"
AttrClientIp = "clientIp"
AttrProxies = "proxyAddresses"
AttrAccessToken = "accessToken" // todo remove for security reasons
)
type Identity interface {
UserName() string
SetUserName(string)
DisplayName() string
SetDisplayName(string)
Domain() string
SetDomain(string)
Authenticated() bool
SetAuthenticated(bool)
AuthTime() time.Time
SetAuthTime(time2 time.Time)
SessionId() string
SetAttribute(string, interface{})
GetAttribute(string) interface{}
Attributes() map[string]interface{}
DelAttribute(string)
Email() string
SetEmail(string)
Expiry() time.Time
SetExpiry(time.Time)
}
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, CTXKey, id)
return r.WithContext(ctx)
}
func FromRequestCtx(r *http.Request) Identity {
return FromCtx(r.Context())
}
func FromCtx(ctx context.Context) Identity {
if id, ok := ctx.Value(CTXKey).(Identity); ok {
return id
}
return nil
}
type User struct {
authenticated bool
domain string
userName string
displayName string
email string
authTime time.Time
sessionId string
expiry time.Time
attributes map[string]interface{}
groupMembership map[string]bool
}
func NewUser() *User {
uuid := uuid.New().String()
return &User{
attributes: make(map[string]interface{}),
groupMembership: make(map[string]bool),
sessionId: uuid,
}
}
func (u *User) UserName() string {
return u.userName
}
func (u *User) SetUserName(s string) {
u.userName = s
}
func (u *User) DisplayName() string {
if u.displayName == "" {
return u.userName
}
return u.displayName
}
func (u *User) SetDisplayName(s string) {
u.displayName = s
}
func (u *User) Domain() string {
return u.domain
}
func (u *User) SetDomain(s string) {
u.domain = s
}
func (u *User) Authenticated() bool {
return u.authenticated
}
func (u *User) SetAuthenticated(b bool) {
u.authenticated = b
}
func (u *User) AuthTime() time.Time {
return u.authTime
}
func (u *User) SetAuthTime(t time.Time) {
u.authTime = t
}
func (u *User) SessionId() string {
return u.sessionId
}
func (u *User) SetAttribute(s string, i interface{}) {
u.attributes[s] = i
}
func (u *User) GetAttribute(s string) interface{} {
if found, ok := u.attributes[s]; ok {
return found
}
return nil
}
func (u *User) Attributes() map[string]interface{} {
return u.attributes
}
func (u *User) DelAttribute(s string) {
delete(u.attributes, s)
}
func (u *User) Email() string {
return u.email
}
func (u *User) SetEmail(s string) {
u.email = s
}
func (u *User) Expiry() time.Time {
return u.expiry
}
func (u *User) SetExpiry(t time.Time) {
u.expiry = t
}

View file

@ -10,16 +10,17 @@ import (
) )
const ( const (
ClientIPCtx = "ClientIP" CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token"
ProxyAddressesCtx = "ProxyAddresses"
RemoteAddressCtx = "RemoteAddress"
TunnelCtx = "TUNNEL"
UsernameCtx = "preferred_username"
) )
func EnrichContext(next http.Handler) http.Handler { func EnrichContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() id := FromRequestCtx(r)
if id == nil {
id = NewUser()
}
log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t",
id.SessionId(), id.UserName(), id.Authenticated())
h := r.Header.Get("X-Forwarded-For") h := r.Header.Get("X-Forwarded-For")
if h != "" { if h != "" {
@ -32,41 +33,36 @@ func EnrichContext(next http.Handler) http.Handler {
if len(ips) > 1 { if len(ips) > 1 {
proxies = ips[1:] proxies = ips[1:]
} }
ctx = context.WithValue(ctx, ClientIPCtx, clientIp) id.SetAttribute(AttrClientIp, clientIp)
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies) id.SetAttribute(AttrProxies, proxies)
} }
ctx = context.WithValue(ctx, RemoteAddressCtx, r.RemoteAddr) id.SetAttribute(AttrRemoteAddr, r.RemoteAddr)
if h == "" { if h == "" {
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr) clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
ctx = context.WithValue(ctx, ClientIPCtx, clientIp) id.SetAttribute(AttrClientIp, clientIp)
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, AddToRequestCtx(id, r))
}) })
} }
func FixKerberosContext(next http.Handler) http.Handler { func FixKerberosContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() gid := goidentity.FromHTTPRequestContext(r)
if gid != nil {
id := goidentity.FromHTTPRequestContext(r) id := FromRequestCtx(r)
if id != nil { id.SetUserName(gid.UserName())
ctx = context.WithValue(ctx, UsernameCtx, id.UserName()) id.SetAuthenticated(gid.Authenticated())
id.SetDomain(gid.Domain())
id.SetAuthTime(gid.AuthTime())
r = AddToRequestCtx(id, r)
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r)
}) })
} }
func GetClientIp(ctx context.Context) string {
s, ok := ctx.Value(ClientIPCtx).(string)
if !ok {
return ""
}
return s
}
func GetAccessToken(ctx context.Context) string { func GetAccessToken(ctx context.Context) string {
token, ok := ctx.Value("access_token").(string) token, ok := ctx.Value(CtxAccessToken).(string)
if !ok { if !ok {
log.Printf("cannot get access token from context") log.Printf("cannot get access token from context")
return "" return ""

View file

@ -226,7 +226,7 @@ func main() {
oidc := initOIDC(url, store) oidc := initOIDC(url, store)
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload)))) http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.HandleFunc("/callback", oidc.HandleCallback) http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback)))
} }
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", web.TokenInfo) http.HandleFunc("/tokeninfo", web.TokenInfo)

View file

@ -61,24 +61,20 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
var t *Tunnel var t *Tunnel
ctx := r.Context() ctx := r.Context()
id := common.FromRequestCtx(r)
connId := r.Header.Get(rdgConnectionIdKey) connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId) x, found := c.Get(connId)
if !found { if !found {
t = &Tunnel{ t = &Tunnel{
RDGId: connId, RDGId: connId,
RemoteAddr: ctx.Value(common.ClientIPCtx).(string), RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string),
} User: id,
// username can be nil with openid & kerberos as it's only available later
// todo grab kerberos principal now?
username := ctx.Value(common.UsernameCtx)
if username != nil {
t.UserName = username.(string)
} }
} else { } else {
t = x.(*Tunnel) t = x.(*Tunnel)
} }
ctx = context.WithValue(ctx, common.TunnelCtx, t) ctx = context.WithValue(ctx, CtxTunnel, t)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
@ -187,13 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) { func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil) log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
id := common.FromRequestCtx(r)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
out, err := transport.NewLegacy(w) out, err := transport.NewLegacy(w)
if err != nil { if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return return
} }
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context())) log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp))
t.transportOut = out t.transportOut = out
out.SendAccept(true) out.SendAccept(true)
@ -215,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
t.transportIn = in t.transportIn = in
c.Set(t.RDGId, t, cache.DefaultExpiration) c.Set(t.RDGId, t, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context())) log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp))
in.SendAccept(false) in.SendAccept(false)
// read some initial data // read some initial data
in.Drain() in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context())) log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp))
handler := NewProcessor(g, t) handler := NewProcessor(g, t)
RegisterTunnel(t, handler) RegisterTunnel(t, handler)
defer RemoveTunnel(t) defer RemoveTunnel(t)

View file

@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error {
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx)) log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
if p.state != SERVER_STATE_INITIALIZED { if p.state != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED) log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR) msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error {
_, cookie := p.tunnelRequest(pkt) _, cookie := p.tunnelRequest(pkt)
if p.gw.CheckPAACookie != nil { if p.gw.CheckPAACookie != nil {
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok { if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx)) log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.tunnel.Write(msg) p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED) return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
@ -180,9 +180,9 @@ func (p *Processor) Process(ctx context.Context) error {
} }
} }
// Creates a packet the is a response to a handshakeRequest request // Creates a packet and is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure // but could be in Windows. However, the NTLM protocol is insecure
func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte { func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code

View file

@ -1,11 +1,16 @@
package protocol package protocol
import ( import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"net" "net"
"time" "time"
) )
const (
CtxTunnel = "github.com/bolkedebruin/rdpgw/tunnel"
)
type Tunnel struct { type Tunnel struct {
// Id identifies the connection in the server // Id identifies the connection in the server
Id string Id string
@ -22,7 +27,7 @@ type Tunnel struct {
// The obtained client ip address // The obtained client ip address
RemoteAddr string RemoteAddr string
// User // User
UserName string User common.Identity
// rwc is the underlying connection to the remote desktop server. // rwc is the underlying connection to the remote desktop server.
// It is of the type *net.TCPConn // It is of the type *net.TCPConn

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"log" "log"
"strings" "strings"
) )
@ -22,23 +21,14 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
// todo get from context? // todo get from context?
return false, errors.New("cannot verify host in 'signed' mode as token data is missing") return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
case "roundrobin", "unsigned": case "roundrobin", "unsigned":
var username string
s := getTunnel(ctx) s := getTunnel(ctx)
if s == nil || s.UserName == "" { if s.User.UserName() == "" {
var ok bool return false, errors.New("no valid session info or username found in context")
username, ok = ctx.Value(common.UsernameCtx).(string)
if !ok {
return false, errors.New("no valid session info or username found in context")
}
} else {
username = s.UserName
} }
log.Printf("Checking host for user %s", username)
log.Printf("Checking host for user %s", s.User.UserName())
for _, h := range Hosts { for _, h := range Hosts {
if username != "" { h = strings.Replace(h, "{{ preferred_username }}", s.User.UserName(), 1)
h = strings.Replace(h, "{{ preferred_username }}", username, 1)
}
if h == host { if h == host {
return true, nil return true, nil
} }

View file

@ -12,14 +12,16 @@ var (
RDGId: "myid", RDGId: "myid",
TargetServer: "my.remote.server", TargetServer: "my.remote.server",
RemoteAddr: "10.0.0.1", RemoteAddr: "10.0.0.1",
UserName: "Frank",
} }
hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"} hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"}
) )
func TestCheckHost(t *testing.T) { func TestCheckHost(t *testing.T) {
ctx := context.WithValue(context.Background(), common.TunnelCtx, &info) info.User = common.NewUser()
info.User.SetUserName("MYNAME")
ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)
Hosts = hosts Hosts = hosts
@ -40,14 +42,7 @@ func TestCheckHost(t *testing.T) {
t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err) t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
} }
host = "my-Frank-host:3389" host = "my-MYNAME-host:3389"
if ok, err := CheckHost(ctx, host); !ok {
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
}
info.UserName = ""
ctx = context.WithValue(ctx, "preferred_username", "dummy")
host = "my-dummy-host:3389"
if ok, err := CheckHost(ctx, host); !ok { if ok, err := CheckHost(ctx, host); !ok {
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err) t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
} }

View file

@ -35,19 +35,21 @@ type customClaims struct {
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc { func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
return func(ctx context.Context, host string) (bool, error) { return func(ctx context.Context, host string) (bool, error) {
s := getTunnel(ctx) tunnel := getTunnel(ctx)
if s == nil { if tunnel == nil {
return false, errors.New("no valid session info found in context") return false, errors.New("no valid session info found in context")
} }
if s.TargetServer != host { if tunnel.TargetServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer) log.Printf("Client specified host %s does not match token host %s", host, tunnel.TargetServer)
return false, nil return false, nil
} }
if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) { // use identity from context rather then set by tunnel
id := common.FromCtx(ctx)
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) {
log.Printf("Current client ip address %s does not match token client ip %s", log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.RemoteAddr) id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr)
return false, nil return false, nil
} }
return next(ctx, host) return next(ctx, host)
@ -106,7 +108,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
tunnel.TargetServer = custom.RemoteServer tunnel.TargetServer = custom.RemoteServer
tunnel.RemoteAddr = custom.ClientIP tunnel.RemoteAddr = custom.ClientIP
tunnel.UserName = user.Subject tunnel.User.SetUserName(user.Subject)
return true, nil return true, nil
} }
@ -127,10 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
Subject: username, Subject: username,
} }
id := common.FromCtx(ctx)
private := customClaims{ private := customClaims{
RemoteServer: server, RemoteServer: server,
ClientIP: common.GetClientIp(ctx), ClientIP: id.GetAttribute(common.AttrClientIp).(string),
AccessToken: common.GetAccessToken(ctx), AccessToken: id.GetAttribute(common.AttrAccessToken).(string),
} }
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil { if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
@ -289,7 +292,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
} }
func getTunnel(ctx context.Context) *protocol.Tunnel { func getTunnel(ctx context.Context) *protocol.Tunnel {
s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel) s, ok := ctx.Value(protocol.CtxTunnel).(*protocol.Tunnel)
if !ok { if !ok {
log.Printf("cannot get session info from context") log.Printf("cannot get session info from context")
return nil return nil

View file

@ -52,8 +52,12 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
if !res.Authenticated { if !res.Authenticated {
log.Printf("User %s is not authenticated for this service", username) log.Printf("User %s is not authenticated for this service", username)
} else { } else {
ctx := context.WithValue(r.Context(), common.UsernameCtx, username) log.Printf("User %s authenticated", username)
next.ServeHTTP(w, r.WithContext(ctx)) id := common.FromRequestCtx(r)
id.SetUserName(username)
id.SetAuthenticated(true)
id.SetAuthTime(time.Now())
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
return return
} }

View file

@ -1,7 +1,6 @@
package web package web
import ( import (
"context"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
@ -15,8 +14,10 @@ import (
) )
const ( const (
CacheExpiration = time.Minute * 2 CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5 CleanupInterval = time.Minute * 5
sessionKeyAuthenticated = "authenticated"
oidcKeyUserName = "preferred_username"
) )
type OIDC struct { type OIDC struct {
@ -90,10 +91,14 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
return return
} }
id := common.FromRequestCtx(r)
id.SetUserName(data[oidcKeyUserName].(string))
id.SetAuthenticated(true)
id.SetAuthTime(time.Now())
id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken)
session.Options.MaxAge = MaxAge session.Options.MaxAge = MaxAge
session.Values["preferred_username"] = data["preferred_username"] session.Values[common.CTXKey] = id
session.Values["authenticated"] = true
session.Values["access_token"] = oauth2Token.AccessToken
if err = session.Save(r, w); err != nil { if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -110,8 +115,8 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return return
} }
found := session.Values["authenticated"] id := session.Values[common.CTXKey].(common.Identity)
if found == nil || !found.(bool) { if id == nil {
seed := make([]byte, 16) seed := make([]byte, 16)
rand.Read(seed) rand.Read(seed)
state := hex.EncodeToString(seed) state := hex.EncodeToString(seed)
@ -120,9 +125,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return return
} }
ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"]) // replace the identity with the one from the sessions
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"]) next.ServeHTTP(w, common.AddToRequestCtx(id, r))
next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }