Merge branch 'multiple_oidc'

This commit is contained in:
Bolke de Bruin 2023-04-16 10:03:46 +02:00
commit acd98367db
2 changed files with 68 additions and 2 deletions

View file

@ -15,7 +15,6 @@ import (
const (
CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5
oidcKeyUserName = "preferred_username"
)
type OIDC struct {
@ -81,7 +80,13 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
}
id := identity.FromRequestCtx(r)
id.SetUserName(data[oidcKeyUserName].(string))
userName := findUsernameInClaims(data)
if userName == "" {
http.Error(w, "no oidc claim for username found", http.StatusInternalServerError)
}
id.SetUserName(userName)
id.SetAuthenticated(true)
id.SetAuthTime(time.Now())
id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken)
@ -93,6 +98,18 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, url, http.StatusFound)
}
func findUsernameInClaims(data map[string]interface{}) string {
candidates := []string{"preferred_username", "unique_name", "upn"}
for _, claim := range candidates {
userName, found := data[claim].(string)
if found {
return userName
}
}
return ""
}
func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := identity.FromRequestCtx(r)

View file

@ -0,0 +1,49 @@
package web
import "testing"
func TestFindUserNameInClaims(t *testing.T) {
cases := []struct {
data map[string]interface{}
ret string
name string
}{
{
data: map[string]interface{}{
"preferred_username": "exists",
},
ret: "exists",
name: "preferred_username",
},
{
data: map[string]interface{}{
"upn": "exists",
},
ret: "exists",
name: "upn",
},
{
data: map[string]interface{}{
"unique_name": "exists",
},
ret: "exists",
name: "unique_name",
},
{
data: map[string]interface{}{
"fail": "exists",
},
ret: "",
name: "fail",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
s := findUsernameInClaims(tc.data)
if s != tc.ret {
t.Fatalf("expected return: %v, got: %v", tc.ret, s)
}
})
}
}