mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-25 09:33:28 +02:00
Refactor identity framework to be more robust
This commit is contained in:
parent
bbd0735289
commit
b42c3cd3cc
11 changed files with 245 additions and 92 deletions
160
cmd/rdpgw/common/identity.go
Normal file
160
cmd/rdpgw/common/identity.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/google/uuid"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
|
||||
|
||||
AttrRemoteAddr = "remoteAddr"
|
||||
AttrClientIp = "clientIp"
|
||||
AttrProxies = "proxyAddresses"
|
||||
AttrAccessToken = "accessToken" // todo remove for security reasons
|
||||
)
|
||||
|
||||
type Identity interface {
|
||||
UserName() string
|
||||
SetUserName(string)
|
||||
DisplayName() string
|
||||
SetDisplayName(string)
|
||||
Domain() string
|
||||
SetDomain(string)
|
||||
Authenticated() bool
|
||||
SetAuthenticated(bool)
|
||||
AuthTime() time.Time
|
||||
SetAuthTime(time2 time.Time)
|
||||
SessionId() string
|
||||
SetAttribute(string, interface{})
|
||||
GetAttribute(string) interface{}
|
||||
Attributes() map[string]interface{}
|
||||
DelAttribute(string)
|
||||
Email() string
|
||||
SetEmail(string)
|
||||
Expiry() time.Time
|
||||
SetExpiry(time.Time)
|
||||
}
|
||||
|
||||
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, CTXKey, id)
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
|
||||
func FromRequestCtx(r *http.Request) Identity {
|
||||
return FromCtx(r.Context())
|
||||
}
|
||||
|
||||
func FromCtx(ctx context.Context) Identity {
|
||||
if id, ok := ctx.Value(CTXKey).(Identity); ok {
|
||||
return id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type User struct {
|
||||
authenticated bool
|
||||
domain string
|
||||
userName string
|
||||
displayName string
|
||||
email string
|
||||
authTime time.Time
|
||||
sessionId string
|
||||
expiry time.Time
|
||||
attributes map[string]interface{}
|
||||
groupMembership map[string]bool
|
||||
}
|
||||
|
||||
func NewUser() *User {
|
||||
uuid := uuid.New().String()
|
||||
return &User{
|
||||
attributes: make(map[string]interface{}),
|
||||
groupMembership: make(map[string]bool),
|
||||
sessionId: uuid,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) UserName() string {
|
||||
return u.userName
|
||||
}
|
||||
|
||||
func (u *User) SetUserName(s string) {
|
||||
u.userName = s
|
||||
}
|
||||
|
||||
func (u *User) DisplayName() string {
|
||||
if u.displayName == "" {
|
||||
return u.userName
|
||||
}
|
||||
return u.displayName
|
||||
}
|
||||
|
||||
func (u *User) SetDisplayName(s string) {
|
||||
u.displayName = s
|
||||
}
|
||||
|
||||
func (u *User) Domain() string {
|
||||
return u.domain
|
||||
}
|
||||
|
||||
func (u *User) SetDomain(s string) {
|
||||
u.domain = s
|
||||
}
|
||||
|
||||
func (u *User) Authenticated() bool {
|
||||
return u.authenticated
|
||||
}
|
||||
|
||||
func (u *User) SetAuthenticated(b bool) {
|
||||
u.authenticated = b
|
||||
}
|
||||
|
||||
func (u *User) AuthTime() time.Time {
|
||||
return u.authTime
|
||||
}
|
||||
|
||||
func (u *User) SetAuthTime(t time.Time) {
|
||||
u.authTime = t
|
||||
}
|
||||
|
||||
func (u *User) SessionId() string {
|
||||
return u.sessionId
|
||||
}
|
||||
|
||||
func (u *User) SetAttribute(s string, i interface{}) {
|
||||
u.attributes[s] = i
|
||||
}
|
||||
|
||||
func (u *User) GetAttribute(s string) interface{} {
|
||||
if found, ok := u.attributes[s]; ok {
|
||||
return found
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) Attributes() map[string]interface{} {
|
||||
return u.attributes
|
||||
}
|
||||
|
||||
func (u *User) DelAttribute(s string) {
|
||||
delete(u.attributes, s)
|
||||
}
|
||||
|
||||
func (u *User) Email() string {
|
||||
return u.email
|
||||
}
|
||||
|
||||
func (u *User) SetEmail(s string) {
|
||||
u.email = s
|
||||
}
|
||||
|
||||
func (u *User) Expiry() time.Time {
|
||||
return u.expiry
|
||||
}
|
||||
|
||||
func (u *User) SetExpiry(t time.Time) {
|
||||
u.expiry = t
|
||||
}
|
|
@ -10,16 +10,17 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
ClientIPCtx = "ClientIP"
|
||||
ProxyAddressesCtx = "ProxyAddresses"
|
||||
RemoteAddressCtx = "RemoteAddress"
|
||||
TunnelCtx = "TUNNEL"
|
||||
UsernameCtx = "preferred_username"
|
||||
CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token"
|
||||
)
|
||||
|
||||
func EnrichContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
id := FromRequestCtx(r)
|
||||
if id == nil {
|
||||
id = NewUser()
|
||||
}
|
||||
log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t",
|
||||
id.SessionId(), id.UserName(), id.Authenticated())
|
||||
|
||||
h := r.Header.Get("X-Forwarded-For")
|
||||
if h != "" {
|
||||
|
@ -32,41 +33,36 @@ func EnrichContext(next http.Handler) http.Handler {
|
|||
if len(ips) > 1 {
|
||||
proxies = ips[1:]
|
||||
}
|
||||
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
|
||||
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
|
||||
id.SetAttribute(AttrClientIp, clientIp)
|
||||
id.SetAttribute(AttrProxies, proxies)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, RemoteAddressCtx, r.RemoteAddr)
|
||||
id.SetAttribute(AttrRemoteAddr, r.RemoteAddr)
|
||||
if h == "" {
|
||||
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
|
||||
id.SetAttribute(AttrClientIp, clientIp)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
next.ServeHTTP(w, AddToRequestCtx(id, r))
|
||||
})
|
||||
}
|
||||
|
||||
func FixKerberosContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := goidentity.FromHTTPRequestContext(r)
|
||||
if id != nil {
|
||||
ctx = context.WithValue(ctx, UsernameCtx, id.UserName())
|
||||
gid := goidentity.FromHTTPRequestContext(r)
|
||||
if gid != nil {
|
||||
id := FromRequestCtx(r)
|
||||
id.SetUserName(gid.UserName())
|
||||
id.SetAuthenticated(gid.Authenticated())
|
||||
id.SetDomain(gid.Domain())
|
||||
id.SetAuthTime(gid.AuthTime())
|
||||
r = AddToRequestCtx(id, r)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func GetClientIp(ctx context.Context) string {
|
||||
s, ok := ctx.Value(ClientIPCtx).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func GetAccessToken(ctx context.Context) string {
|
||||
token, ok := ctx.Value("access_token").(string)
|
||||
token, ok := ctx.Value(CtxAccessToken).(string)
|
||||
if !ok {
|
||||
log.Printf("cannot get access token from context")
|
||||
return ""
|
||||
|
|
|
@ -226,7 +226,7 @@ func main() {
|
|||
oidc := initOIDC(url, store)
|
||||
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||
http.HandleFunc("/callback", oidc.HandleCallback)
|
||||
http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback)))
|
||||
}
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/tokeninfo", web.TokenInfo)
|
||||
|
|
|
@ -61,24 +61,20 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
|||
var t *Tunnel
|
||||
|
||||
ctx := r.Context()
|
||||
id := common.FromRequestCtx(r)
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
t = &Tunnel{
|
||||
RDGId: connId,
|
||||
RemoteAddr: ctx.Value(common.ClientIPCtx).(string),
|
||||
}
|
||||
// username can be nil with openid & kerberos as it's only available later
|
||||
// todo grab kerberos principal now?
|
||||
username := ctx.Value(common.UsernameCtx)
|
||||
if username != nil {
|
||||
t.UserName = username.(string)
|
||||
RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string),
|
||||
User: id,
|
||||
}
|
||||
} else {
|
||||
t = x.(*Tunnel)
|
||||
}
|
||||
ctx = context.WithValue(ctx, common.TunnelCtx, t)
|
||||
ctx = context.WithValue(ctx, CtxTunnel, t)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
|
@ -187,13 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
|||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
||||
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
|
||||
|
||||
id := common.FromRequestCtx(r)
|
||||
if r.Method == MethodRDGOUT {
|
||||
out, err := transport.NewLegacy(w)
|
||||
if err != nil {
|
||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
|
||||
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp))
|
||||
|
||||
t.transportOut = out
|
||||
out.SendAccept(true)
|
||||
|
@ -215,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
|||
t.transportIn = in
|
||||
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||
|
||||
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
|
||||
log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp))
|
||||
in.SendAccept(false)
|
||||
|
||||
// read some initial data
|
||||
in.Drain()
|
||||
|
||||
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
||||
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp))
|
||||
handler := NewProcessor(g, t)
|
||||
RegisterTunnel(t, handler)
|
||||
defer RemoveTunnel(t)
|
||||
|
|
|
@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
|
||||
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
|
||||
if p.state != SERVER_STATE_INITIALIZED {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||
|
@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
_, cookie := p.tunnelRequest(pkt)
|
||||
if p.gw.CheckPAACookie != nil {
|
||||
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
|
||||
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
|
||||
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
p.tunnel.Write(msg)
|
||||
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||
|
@ -180,9 +180,9 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Creates a packet the is a response to a handshakeRequest request
|
||||
// Creates a packet and is a response to a handshakeRequest request
|
||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||
// but could be in Windows. However the NTLM protocol is insecure
|
||||
// but could be in Windows. However, the NTLM protocol is insecure
|
||||
func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CtxTunnel = "github.com/bolkedebruin/rdpgw/tunnel"
|
||||
)
|
||||
|
||||
type Tunnel struct {
|
||||
// Id identifies the connection in the server
|
||||
Id string
|
||||
|
@ -22,7 +27,7 @@ type Tunnel struct {
|
|||
// The obtained client ip address
|
||||
RemoteAddr string
|
||||
// User
|
||||
UserName string
|
||||
User common.Identity
|
||||
|
||||
// rwc is the underlying connection to the remote desktop server.
|
||||
// It is of the type *net.TCPConn
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
@ -22,23 +21,14 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
|
|||
// todo get from context?
|
||||
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
|
||||
case "roundrobin", "unsigned":
|
||||
var username string
|
||||
|
||||
s := getTunnel(ctx)
|
||||
if s == nil || s.UserName == "" {
|
||||
var ok bool
|
||||
username, ok = ctx.Value(common.UsernameCtx).(string)
|
||||
if !ok {
|
||||
return false, errors.New("no valid session info or username found in context")
|
||||
}
|
||||
} else {
|
||||
username = s.UserName
|
||||
if s.User.UserName() == "" {
|
||||
return false, errors.New("no valid session info or username found in context")
|
||||
}
|
||||
log.Printf("Checking host for user %s", username)
|
||||
|
||||
log.Printf("Checking host for user %s", s.User.UserName())
|
||||
for _, h := range Hosts {
|
||||
if username != "" {
|
||||
h = strings.Replace(h, "{{ preferred_username }}", username, 1)
|
||||
}
|
||||
h = strings.Replace(h, "{{ preferred_username }}", s.User.UserName(), 1)
|
||||
if h == host {
|
||||
return true, nil
|
||||
}
|
||||
|
|
|
@ -12,14 +12,16 @@ var (
|
|||
RDGId: "myid",
|
||||
TargetServer: "my.remote.server",
|
||||
RemoteAddr: "10.0.0.1",
|
||||
UserName: "Frank",
|
||||
}
|
||||
|
||||
hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"}
|
||||
)
|
||||
|
||||
func TestCheckHost(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), common.TunnelCtx, &info)
|
||||
info.User = common.NewUser()
|
||||
info.User.SetUserName("MYNAME")
|
||||
|
||||
ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)
|
||||
|
||||
Hosts = hosts
|
||||
|
||||
|
@ -40,14 +42,7 @@ func TestCheckHost(t *testing.T) {
|
|||
t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
||||
}
|
||||
|
||||
host = "my-Frank-host:3389"
|
||||
if ok, err := CheckHost(ctx, host); !ok {
|
||||
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
||||
}
|
||||
|
||||
info.UserName = ""
|
||||
ctx = context.WithValue(ctx, "preferred_username", "dummy")
|
||||
host = "my-dummy-host:3389"
|
||||
host = "my-MYNAME-host:3389"
|
||||
if ok, err := CheckHost(ctx, host); !ok {
|
||||
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
||||
}
|
||||
|
|
|
@ -35,19 +35,21 @@ type customClaims struct {
|
|||
|
||||
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
||||
return func(ctx context.Context, host string) (bool, error) {
|
||||
s := getTunnel(ctx)
|
||||
if s == nil {
|
||||
tunnel := getTunnel(ctx)
|
||||
if tunnel == nil {
|
||||
return false, errors.New("no valid session info found in context")
|
||||
}
|
||||
|
||||
if s.TargetServer != host {
|
||||
log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer)
|
||||
if tunnel.TargetServer != host {
|
||||
log.Printf("Client specified host %s does not match token host %s", host, tunnel.TargetServer)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) {
|
||||
// use identity from context rather then set by tunnel
|
||||
id := common.FromCtx(ctx)
|
||||
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) {
|
||||
log.Printf("Current client ip address %s does not match token client ip %s",
|
||||
common.GetClientIp(ctx), s.RemoteAddr)
|
||||
id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr)
|
||||
return false, nil
|
||||
}
|
||||
return next(ctx, host)
|
||||
|
@ -106,7 +108,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
|||
|
||||
tunnel.TargetServer = custom.RemoteServer
|
||||
tunnel.RemoteAddr = custom.ClientIP
|
||||
tunnel.UserName = user.Subject
|
||||
tunnel.User.SetUserName(user.Subject)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
@ -127,10 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
|
|||
Subject: username,
|
||||
}
|
||||
|
||||
id := common.FromCtx(ctx)
|
||||
private := customClaims{
|
||||
RemoteServer: server,
|
||||
ClientIP: common.GetClientIp(ctx),
|
||||
AccessToken: common.GetAccessToken(ctx),
|
||||
ClientIP: id.GetAttribute(common.AttrClientIp).(string),
|
||||
AccessToken: id.GetAttribute(common.AttrAccessToken).(string),
|
||||
}
|
||||
|
||||
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
|
||||
|
@ -289,7 +292,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
|
|||
}
|
||||
|
||||
func getTunnel(ctx context.Context) *protocol.Tunnel {
|
||||
s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel)
|
||||
s, ok := ctx.Value(protocol.CtxTunnel).(*protocol.Tunnel)
|
||||
if !ok {
|
||||
log.Printf("cannot get session info from context")
|
||||
return nil
|
||||
|
|
|
@ -52,8 +52,12 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
|||
if !res.Authenticated {
|
||||
log.Printf("User %s is not authenticated for this service", username)
|
||||
} else {
|
||||
ctx := context.WithValue(r.Context(), common.UsernameCtx, username)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
log.Printf("User %s authenticated", username)
|
||||
id := common.FromRequestCtx(r)
|
||||
id.SetUserName(username)
|
||||
id.SetAuthenticated(true)
|
||||
id.SetAuthTime(time.Now())
|
||||
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
|
@ -15,8 +14,10 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
CacheExpiration = time.Minute * 2
|
||||
CleanupInterval = time.Minute * 5
|
||||
CacheExpiration = time.Minute * 2
|
||||
CleanupInterval = time.Minute * 5
|
||||
sessionKeyAuthenticated = "authenticated"
|
||||
oidcKeyUserName = "preferred_username"
|
||||
)
|
||||
|
||||
type OIDC struct {
|
||||
|
@ -90,10 +91,14 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
id := common.FromRequestCtx(r)
|
||||
id.SetUserName(data[oidcKeyUserName].(string))
|
||||
id.SetAuthenticated(true)
|
||||
id.SetAuthTime(time.Now())
|
||||
id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken)
|
||||
|
||||
session.Options.MaxAge = MaxAge
|
||||
session.Values["preferred_username"] = data["preferred_username"]
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["access_token"] = oauth2Token.AccessToken
|
||||
session.Values[common.CTXKey] = id
|
||||
|
||||
if err = session.Save(r, w); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
@ -110,8 +115,8 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
found := session.Values["authenticated"]
|
||||
if found == nil || !found.(bool) {
|
||||
id := session.Values[common.CTXKey].(common.Identity)
|
||||
if id == nil {
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
state := hex.EncodeToString(seed)
|
||||
|
@ -120,9 +125,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"])
|
||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
// replace the identity with the one from the sessions
|
||||
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue