rdpgw/cmd/rdpgw/web/oidc.go
2024-02-06 19:07:51 +08:00

133 lines
3.4 KiB
Go

package web
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"net/http"
"time"
)
// Timing parameters handed to cache.New for the store of pending login states.
const (
	// CacheExpiration is the first argument to cache.New; entries written with
	// cache.DefaultExpiration (the pending OIDC states) use this lifetime.
	CacheExpiration = 2 * time.Minute
	// CleanupInterval is the second argument to cache.New, i.e. the period of
	// the cache's background cleanup.
	CleanupInterval = 5 * time.Minute
)
// OIDC bundles what the OpenID Connect login handlers in this file need:
// the OAuth2 client configuration, the ID token verifier, and a short-lived
// store of pending login states.
type OIDC struct {
// oAuth2Config builds the authorization URL (AuthCodeURL) and exchanges the
// authorization code for a token (Exchange).
oAuth2Config *oauth2.Config
// oidcTokenVerifier verifies the raw ID token returned by the code exchange.
oidcTokenVerifier *oidc.IDTokenVerifier
// stateStore maps the random "state" value sent to the provider to the
// request URI the user originally asked for, so the callback can redirect
// back to it.
stateStore *cache.Cache
}
// OIDCConfig carries the dependencies that New copies into an OIDC handler.
type OIDCConfig struct {
// OAuth2Config is the OAuth2 client configuration handed to the OIDC handler.
OAuth2Config *oauth2.Config
// OIDCTokenVerifier is the verifier the OIDC handler uses for ID tokens.
OIDCTokenVerifier *oidc.IDTokenVerifier
}
// New returns an OIDC handler wired with the OAuth2 configuration and token
// verifier from c, plus a fresh state cache created with CacheExpiration and
// CleanupInterval.
func (c *OIDCConfig) New() *OIDC {
	handler := new(OIDC)
	handler.oAuth2Config = c.OAuth2Config
	handler.oidcTokenVerifier = c.OIDCTokenVerifier
	handler.stateStore = cache.New(CacheExpiration, CleanupInterval)
	return handler
}
// HandleCallback finishes the OIDC authorization-code flow.
//
// It looks up the "state" query parameter in the state store (rejecting
// unknown values with 400), exchanges the "code" query parameter for an
// OAuth2 token, verifies the embedded ID token, and extracts a user name from
// its claims. On success the request identity is marked authenticated, stored
// via SaveSessionIdentity, and the client is redirected (302) to the URI that
// was recorded together with the state. Every failure writes an error
// response and stops; no identity is saved on failure.
func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()

	state := query.Get("state")
	s, found := h.stateStore.Get(state)
	if !found {
		http.Error(w, "unknown state", http.StatusBadRequest)
		return
	}
	// A state value is single-use: remove it so the same callback cannot be
	// replayed for the remainder of the cache lifetime.
	h.stateStore.Delete(state)

	redirectURL, ok := s.(string)
	if !ok {
		// Only strings are stored by Authenticated; treat anything else as an
		// unusable state instead of panicking on the type assertion.
		http.Error(w, "unknown state", http.StatusBadRequest)
		return
	}

	ctx := r.Context()
	oauth2Token, err := h.oAuth2Config.Exchange(ctx, query.Get("code"))
	if err != nil {
		http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
		return
	}

	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
	if !ok {
		http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
		return
	}

	idToken, err := h.oidcTokenVerifier.Verify(ctx, rawIDToken)
	if err != nil {
		http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// The ID token payload is plain JSON: capture it raw, then decode it into
	// a generic map so the user-name claim can be searched for by key.
	var rawClaims json.RawMessage
	if err := idToken.Claims(&rawClaims); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	var data map[string]interface{}
	if err := json.Unmarshal(rawClaims, &data); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	userName := findUsernameInClaims(data)
	if userName == "" {
		http.Error(w, "no oidc claim for username found", http.StatusInternalServerError)
		// Stop here: continuing would authenticate an identity without a
		// user name and write a redirect on top of the error response.
		return
	}

	id := identity.FromRequestCtx(r)
	id.SetUserName(userName)
	id.SetAuthenticated(true)
	id.SetAuthTime(time.Now())
	id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken)

	if err := SaveSessionIdentity(r, w, id); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		// The error response is already written; do not also redirect.
		return
	}

	http.Redirect(w, r, redirectURL, http.StatusFound)
}
// findUsernameInClaims returns the user name found in the decoded ID token
// claims. It checks "preferred_username", "unique_name", "upn" and
// "username", in that order, and returns the first one whose value is a
// non-empty string. Claims that are missing, not strings, or empty are
// skipped so that a usable later candidate is not hidden by an empty earlier
// one. It returns "" when no candidate qualifies (including for a nil map).
func findUsernameInClaims(data map[string]interface{}) string {
	candidates := []string{"preferred_username", "unique_name", "upn", "username"}
	for _, claim := range candidates {
		if userName, ok := data[claim].(string); ok && userName != "" {
			return userName
		}
	}
	return ""
}
// Authenticated wraps next so that it only runs for requests whose identity
// is already authenticated. Any other request gets a fresh random state
// value; the request URI is stored under that state and the client is
// redirected (302) to the provider's authorization URL. A failure to read
// random bytes is reported as a 500.
func (h *OIDC) Authenticated(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Already logged in: hand the request straight to the wrapped handler.
		if identity.FromRequestCtx(r).Authenticated() {
			next.ServeHTTP(w, r)
			return
		}

		// 16 random bytes, hex-encoded, become the state parameter.
		nonce := make([]byte, 16)
		if _, err := rand.Read(nonce); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		state := hex.EncodeToString(nonce)

		// Remember the requested URI under this state so the callback can
		// look it up later.
		h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)

		authURL := h.oAuth2Config.AuthCodeURL(state)
		http.Redirect(w, r, authURL, http.StatusFound)
	})
}