Refactor identity framework to be more robust

This commit is contained in:
Bolke de Bruin 2022-10-13 11:13:24 +02:00
parent bbd0735289
commit b42c3cd3cc
11 changed files with 245 additions and 92 deletions

View file

@ -0,0 +1,160 @@
package common
import (
"context"
"github.com/google/uuid"
"net/http"
"time"
)
const (
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
AttrRemoteAddr = "remoteAddr"
AttrClientIp = "clientIp"
AttrProxies = "proxyAddresses"
AttrAccessToken = "accessToken" // todo remove for security reasons
)
type Identity interface {
UserName() string
SetUserName(string)
DisplayName() string
SetDisplayName(string)
Domain() string
SetDomain(string)
Authenticated() bool
SetAuthenticated(bool)
AuthTime() time.Time
SetAuthTime(time2 time.Time)
SessionId() string
SetAttribute(string, interface{})
GetAttribute(string) interface{}
Attributes() map[string]interface{}
DelAttribute(string)
Email() string
SetEmail(string)
Expiry() time.Time
SetExpiry(time.Time)
}
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, CTXKey, id)
return r.WithContext(ctx)
}
func FromRequestCtx(r *http.Request) Identity {
return FromCtx(r.Context())
}
func FromCtx(ctx context.Context) Identity {
if id, ok := ctx.Value(CTXKey).(Identity); ok {
return id
}
return nil
}
type User struct {
authenticated bool
domain string
userName string
displayName string
email string
authTime time.Time
sessionId string
expiry time.Time
attributes map[string]interface{}
groupMembership map[string]bool
}
func NewUser() *User {
uuid := uuid.New().String()
return &User{
attributes: make(map[string]interface{}),
groupMembership: make(map[string]bool),
sessionId: uuid,
}
}
func (u *User) UserName() string {
return u.userName
}
func (u *User) SetUserName(s string) {
u.userName = s
}
func (u *User) DisplayName() string {
if u.displayName == "" {
return u.userName
}
return u.displayName
}
func (u *User) SetDisplayName(s string) {
u.displayName = s
}
func (u *User) Domain() string {
return u.domain
}
func (u *User) SetDomain(s string) {
u.domain = s
}
func (u *User) Authenticated() bool {
return u.authenticated
}
func (u *User) SetAuthenticated(b bool) {
u.authenticated = b
}
func (u *User) AuthTime() time.Time {
return u.authTime
}
func (u *User) SetAuthTime(t time.Time) {
u.authTime = t
}
func (u *User) SessionId() string {
return u.sessionId
}
func (u *User) SetAttribute(s string, i interface{}) {
u.attributes[s] = i
}
func (u *User) GetAttribute(s string) interface{} {
if found, ok := u.attributes[s]; ok {
return found
}
return nil
}
func (u *User) Attributes() map[string]interface{} {
return u.attributes
}
func (u *User) DelAttribute(s string) {
delete(u.attributes, s)
}
func (u *User) Email() string {
return u.email
}
func (u *User) SetEmail(s string) {
u.email = s
}
func (u *User) Expiry() time.Time {
return u.expiry
}
func (u *User) SetExpiry(t time.Time) {
u.expiry = t
}

View file

@ -10,16 +10,17 @@ import (
)
const (
ClientIPCtx = "ClientIP"
ProxyAddressesCtx = "ProxyAddresses"
RemoteAddressCtx = "RemoteAddress"
TunnelCtx = "TUNNEL"
UsernameCtx = "preferred_username"
CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token"
)
func EnrichContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := FromRequestCtx(r)
if id == nil {
id = NewUser()
}
log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t",
id.SessionId(), id.UserName(), id.Authenticated())
h := r.Header.Get("X-Forwarded-For")
if h != "" {
@ -32,41 +33,36 @@ func EnrichContext(next http.Handler) http.Handler {
if len(ips) > 1 {
proxies = ips[1:]
}
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
id.SetAttribute(AttrClientIp, clientIp)
id.SetAttribute(AttrProxies, proxies)
}
ctx = context.WithValue(ctx, RemoteAddressCtx, r.RemoteAddr)
id.SetAttribute(AttrRemoteAddr, r.RemoteAddr)
if h == "" {
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
id.SetAttribute(AttrClientIp, clientIp)
}
next.ServeHTTP(w, r.WithContext(ctx))
next.ServeHTTP(w, AddToRequestCtx(id, r))
})
}
func FixKerberosContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := goidentity.FromHTTPRequestContext(r)
if id != nil {
ctx = context.WithValue(ctx, UsernameCtx, id.UserName())
gid := goidentity.FromHTTPRequestContext(r)
if gid != nil {
id := FromRequestCtx(r)
id.SetUserName(gid.UserName())
id.SetAuthenticated(gid.Authenticated())
id.SetDomain(gid.Domain())
id.SetAuthTime(gid.AuthTime())
r = AddToRequestCtx(id, r)
}
next.ServeHTTP(w, r.WithContext(ctx))
next.ServeHTTP(w, r)
})
}
func GetClientIp(ctx context.Context) string {
s, ok := ctx.Value(ClientIPCtx).(string)
if !ok {
return ""
}
return s
}
func GetAccessToken(ctx context.Context) string {
token, ok := ctx.Value("access_token").(string)
token, ok := ctx.Value(CtxAccessToken).(string)
if !ok {
log.Printf("cannot get access token from context")
return ""

View file

@ -226,7 +226,7 @@ func main() {
oidc := initOIDC(url, store)
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.HandleFunc("/callback", oidc.HandleCallback)
http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback)))
}
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", web.TokenInfo)

View file

@ -61,24 +61,20 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
var t *Tunnel
ctx := r.Context()
id := common.FromRequestCtx(r)
connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId)
if !found {
t = &Tunnel{
RDGId: connId,
RemoteAddr: ctx.Value(common.ClientIPCtx).(string),
}
// username can be nil with openid & kerberos as it's only available later
// todo grab kerberos principal now?
username := ctx.Value(common.UsernameCtx)
if username != nil {
t.UserName = username.(string)
RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string),
User: id,
}
} else {
t = x.(*Tunnel)
}
ctx = context.WithValue(ctx, common.TunnelCtx, t)
ctx = context.WithValue(ctx, CtxTunnel, t)
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
@ -187,13 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
id := common.FromRequestCtx(r)
if r.Method == MethodRDGOUT {
out, err := transport.NewLegacy(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp))
t.transportOut = out
out.SendAccept(true)
@ -215,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
t.transportIn = in
c.Set(t.RDGId, t, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp))
in.SendAccept(false)
// read some initial data
in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp))
handler := NewProcessor(g, t)
RegisterTunnel(t, handler)
defer RemoveTunnel(t)

View file

@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error {
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
if p.state != SERVER_STATE_INITIALIZED {
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error {
_, cookie := p.tunnelRequest(pkt)
if p.gw.CheckPAACookie != nil {
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
p.tunnel.Write(msg)
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
@ -180,9 +180,9 @@ func (p *Processor) Process(ctx context.Context) error {
}
}
// Creates a packet the is a response to a handshakeRequest request
// Creates a packet and is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure
// but could be in Windows. However, the NTLM protocol is insecure
func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code

View file

@ -1,11 +1,16 @@
package protocol
import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"net"
"time"
)
const (
CtxTunnel = "github.com/bolkedebruin/rdpgw/tunnel"
)
type Tunnel struct {
// Id identifies the connection in the server
Id string
@ -22,7 +27,7 @@ type Tunnel struct {
// The obtained client ip address
RemoteAddr string
// User
UserName string
User common.Identity
// rwc is the underlying connection to the remote desktop server.
// It is of the type *net.TCPConn

View file

@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"log"
"strings"
)
@ -22,23 +21,14 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
// todo get from context?
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
case "roundrobin", "unsigned":
var username string
s := getTunnel(ctx)
if s == nil || s.UserName == "" {
var ok bool
username, ok = ctx.Value(common.UsernameCtx).(string)
if !ok {
return false, errors.New("no valid session info or username found in context")
}
} else {
username = s.UserName
if s.User.UserName() == "" {
return false, errors.New("no valid session info or username found in context")
}
log.Printf("Checking host for user %s", username)
log.Printf("Checking host for user %s", s.User.UserName())
for _, h := range Hosts {
if username != "" {
h = strings.Replace(h, "{{ preferred_username }}", username, 1)
}
h = strings.Replace(h, "{{ preferred_username }}", s.User.UserName(), 1)
if h == host {
return true, nil
}

View file

@ -12,14 +12,16 @@ var (
RDGId: "myid",
TargetServer: "my.remote.server",
RemoteAddr: "10.0.0.1",
UserName: "Frank",
}
hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"}
)
func TestCheckHost(t *testing.T) {
ctx := context.WithValue(context.Background(), common.TunnelCtx, &info)
info.User = common.NewUser()
info.User.SetUserName("MYNAME")
ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)
Hosts = hosts
@ -40,14 +42,7 @@ func TestCheckHost(t *testing.T) {
t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
}
host = "my-Frank-host:3389"
if ok, err := CheckHost(ctx, host); !ok {
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
}
info.UserName = ""
ctx = context.WithValue(ctx, "preferred_username", "dummy")
host = "my-dummy-host:3389"
host = "my-MYNAME-host:3389"
if ok, err := CheckHost(ctx, host); !ok {
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
}

View file

@ -35,19 +35,21 @@ type customClaims struct {
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
return func(ctx context.Context, host string) (bool, error) {
s := getTunnel(ctx)
if s == nil {
tunnel := getTunnel(ctx)
if tunnel == nil {
return false, errors.New("no valid session info found in context")
}
if s.TargetServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer)
if tunnel.TargetServer != host {
log.Printf("Client specified host %s does not match token host %s", host, tunnel.TargetServer)
return false, nil
}
if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) {
// use identity from context rather then set by tunnel
id := common.FromCtx(ctx)
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) {
log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.RemoteAddr)
id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr)
return false, nil
}
return next(ctx, host)
@ -106,7 +108,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
tunnel.TargetServer = custom.RemoteServer
tunnel.RemoteAddr = custom.ClientIP
tunnel.UserName = user.Subject
tunnel.User.SetUserName(user.Subject)
return true, nil
}
@ -127,10 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
Subject: username,
}
id := common.FromCtx(ctx)
private := customClaims{
RemoteServer: server,
ClientIP: common.GetClientIp(ctx),
AccessToken: common.GetAccessToken(ctx),
ClientIP: id.GetAttribute(common.AttrClientIp).(string),
AccessToken: id.GetAttribute(common.AttrAccessToken).(string),
}
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
@ -289,7 +292,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
}
func getTunnel(ctx context.Context) *protocol.Tunnel {
s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel)
s, ok := ctx.Value(protocol.CtxTunnel).(*protocol.Tunnel)
if !ok {
log.Printf("cannot get session info from context")
return nil

View file

@ -52,8 +52,12 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
if !res.Authenticated {
log.Printf("User %s is not authenticated for this service", username)
} else {
ctx := context.WithValue(r.Context(), common.UsernameCtx, username)
next.ServeHTTP(w, r.WithContext(ctx))
log.Printf("User %s authenticated", username)
id := common.FromRequestCtx(r)
id.SetUserName(username)
id.SetAuthenticated(true)
id.SetAuthTime(time.Now())
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
return
}

View file

@ -1,7 +1,6 @@
package web
import (
"context"
"encoding/hex"
"encoding/json"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
@ -15,8 +14,10 @@ import (
)
const (
CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5
CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5
sessionKeyAuthenticated = "authenticated"
oidcKeyUserName = "preferred_username"
)
type OIDC struct {
@ -90,10 +91,14 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
return
}
id := common.FromRequestCtx(r)
id.SetUserName(data[oidcKeyUserName].(string))
id.SetAuthenticated(true)
id.SetAuthTime(time.Now())
id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken)
session.Options.MaxAge = MaxAge
session.Values["preferred_username"] = data["preferred_username"]
session.Values["authenticated"] = true
session.Values["access_token"] = oauth2Token.AccessToken
session.Values[common.CTXKey] = id
if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -110,8 +115,8 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return
}
found := session.Values["authenticated"]
if found == nil || !found.(bool) {
id := session.Values[common.CTXKey].(common.Identity)
if id == nil {
seed := make([]byte, 16)
rand.Read(seed)
state := hex.EncodeToString(seed)
@ -120,9 +125,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return
}
ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"])
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
next.ServeHTTP(w, r.WithContext(ctx))
// replace the identity with the one from the sessions
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
})
}