From cdf6e686842ab4563afe95df000037772f3096b3 Mon Sep 17 00:00:00 2001 From: totomz Date: Fri, 7 Apr 2023 12:15:06 +0200 Subject: [PATCH 1/2] Use multiple oidc claim to find the username The clim `preferred_username` is optional in Azure AD. Although is listed as preferred, in some enterprise environment it's not possible to add this additional claim. `unique_name` and `upn` are legacy alternatives --- cmd/rdpgw/web/oidc.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go index 1a41b01..927f855 100644 --- a/cmd/rdpgw/web/oidc.go +++ b/cmd/rdpgw/web/oidc.go @@ -3,10 +3,12 @@ package web import ( "encoding/hex" "encoding/json" + "errors" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/coreos/go-oidc/v3/oidc" "github.com/patrickmn/go-cache" "golang.org/x/oauth2" + "log" "math/rand" "net/http" "time" @@ -15,7 +17,6 @@ import ( const ( CacheExpiration = time.Minute * 2 CleanupInterval = time.Minute * 5 - oidcKeyUserName = "preferred_username" ) type OIDC struct { @@ -81,7 +82,15 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { } id := identity.FromRequestCtx(r) - id.SetUserName(data[oidcKeyUserName].(string)) + + userName := findUsernameInClaims(data) + if userName == "" { + err = errors.New("no odic claim for username found") + log.Print(err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } + + id.SetUserName(userName) id.SetAuthenticated(true) id.SetAuthTime(time.Now()) id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken) @@ -93,6 +102,18 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, url, http.StatusFound) } +func findUsernameInClaims(data map[string]interface{}) string { + candidates := []string{"preferred_username", "unique_name", "upn"} + for _, claim := range candidates { + userName, found := data[claim].(string) + if found { + return userName + } + } + + return "" +} + func (h *OIDC) Authenticated(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := identity.FromRequestCtx(r) From 9d9b7a9ab5983f5bd8f4beb7ca89bf8ad41cf75e Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Sun, 16 Apr 2023 10:02:47 +0200 Subject: [PATCH 2/2] Add test --- cmd/rdpgw/web/oidc.go | 6 +---- cmd/rdpgw/web/oidc_test.go | 49 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 cmd/rdpgw/web/oidc_test.go diff --git a/cmd/rdpgw/web/oidc.go b/cmd/rdpgw/web/oidc.go index 927f855..03cece1 100644 --- a/cmd/rdpgw/web/oidc.go +++ b/cmd/rdpgw/web/oidc.go @@ -3,12 +3,10 @@ package web import ( "encoding/hex" "encoding/json" - "errors" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/coreos/go-oidc/v3/oidc" "github.com/patrickmn/go-cache" "golang.org/x/oauth2" - "log" "math/rand" "net/http" "time" @@ -85,9 +83,7 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { userName := findUsernameInClaims(data) if userName == "" { - err = errors.New("no odic claim for username found") - log.Print(err) - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, "no oidc claim for username found", http.StatusInternalServerError) } id.SetUserName(userName) diff --git a/cmd/rdpgw/web/oidc_test.go b/cmd/rdpgw/web/oidc_test.go new file mode 100644 index 0000000..37eb908 --- /dev/null +++ b/cmd/rdpgw/web/oidc_test.go @@ -0,0 +1,49 @@ +package web + +import "testing" + +func TestFindUserNameInClaims(t *testing.T) { + cases := []struct { + data map[string]interface{} + ret string + name string + }{ + { + data: map[string]interface{}{ + "preferred_username": "exists", + }, + ret: "exists", + name: "preferred_username", + }, + { + data: map[string]interface{}{ + "upn": "exists", + }, + ret: "exists", + name: "upn", + }, + { + data: map[string]interface{}{ + "unique_name": "exists", + }, + ret: "exists", + name: "unique_name", + }, + { + data: map[string]interface{}{ + "fail": "exists", + }, + ret: "", + name: "fail", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := findUsernameInClaims(tc.data) + if s != tc.ret { + t.Fatalf("expected return: %v, got: %v", tc.ret, s) + } + }) + } +}