mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-25 09:33:28 +02:00
Refactor identity framework to be more robust
This commit is contained in:
parent
bbd0735289
commit
b42c3cd3cc
11 changed files with 245 additions and 92 deletions
160
cmd/rdpgw/common/identity.go
Normal file
160
cmd/rdpgw/common/identity.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CTXKey = "github.com/bolkedebruin/rdpgw/common/identity"
|
||||||
|
|
||||||
|
AttrRemoteAddr = "remoteAddr"
|
||||||
|
AttrClientIp = "clientIp"
|
||||||
|
AttrProxies = "proxyAddresses"
|
||||||
|
AttrAccessToken = "accessToken" // todo remove for security reasons
|
||||||
|
)
|
||||||
|
|
||||||
|
type Identity interface {
|
||||||
|
UserName() string
|
||||||
|
SetUserName(string)
|
||||||
|
DisplayName() string
|
||||||
|
SetDisplayName(string)
|
||||||
|
Domain() string
|
||||||
|
SetDomain(string)
|
||||||
|
Authenticated() bool
|
||||||
|
SetAuthenticated(bool)
|
||||||
|
AuthTime() time.Time
|
||||||
|
SetAuthTime(time2 time.Time)
|
||||||
|
SessionId() string
|
||||||
|
SetAttribute(string, interface{})
|
||||||
|
GetAttribute(string) interface{}
|
||||||
|
Attributes() map[string]interface{}
|
||||||
|
DelAttribute(string)
|
||||||
|
Email() string
|
||||||
|
SetEmail(string)
|
||||||
|
Expiry() time.Time
|
||||||
|
SetExpiry(time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddToRequestCtx(id Identity, r *http.Request) *http.Request {
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, CTXKey, id)
|
||||||
|
return r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FromRequestCtx(r *http.Request) Identity {
|
||||||
|
return FromCtx(r.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
func FromCtx(ctx context.Context) Identity {
|
||||||
|
if id, ok := ctx.Value(CTXKey).(Identity); ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
authenticated bool
|
||||||
|
domain string
|
||||||
|
userName string
|
||||||
|
displayName string
|
||||||
|
email string
|
||||||
|
authTime time.Time
|
||||||
|
sessionId string
|
||||||
|
expiry time.Time
|
||||||
|
attributes map[string]interface{}
|
||||||
|
groupMembership map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUser() *User {
|
||||||
|
uuid := uuid.New().String()
|
||||||
|
return &User{
|
||||||
|
attributes: make(map[string]interface{}),
|
||||||
|
groupMembership: make(map[string]bool),
|
||||||
|
sessionId: uuid,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) UserName() string {
|
||||||
|
return u.userName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetUserName(s string) {
|
||||||
|
u.userName = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) DisplayName() string {
|
||||||
|
if u.displayName == "" {
|
||||||
|
return u.userName
|
||||||
|
}
|
||||||
|
return u.displayName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetDisplayName(s string) {
|
||||||
|
u.displayName = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Domain() string {
|
||||||
|
return u.domain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetDomain(s string) {
|
||||||
|
u.domain = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Authenticated() bool {
|
||||||
|
return u.authenticated
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetAuthenticated(b bool) {
|
||||||
|
u.authenticated = b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) AuthTime() time.Time {
|
||||||
|
return u.authTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetAuthTime(t time.Time) {
|
||||||
|
u.authTime = t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SessionId() string {
|
||||||
|
return u.sessionId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetAttribute(s string, i interface{}) {
|
||||||
|
u.attributes[s] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) GetAttribute(s string) interface{} {
|
||||||
|
if found, ok := u.attributes[s]; ok {
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Attributes() map[string]interface{} {
|
||||||
|
return u.attributes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) DelAttribute(s string) {
|
||||||
|
delete(u.attributes, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Email() string {
|
||||||
|
return u.email
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetEmail(s string) {
|
||||||
|
u.email = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Expiry() time.Time {
|
||||||
|
return u.expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) SetExpiry(t time.Time) {
|
||||||
|
u.expiry = t
|
||||||
|
}
|
|
@ -10,16 +10,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ClientIPCtx = "ClientIP"
|
CtxAccessToken = "github.com/bolkedebruin/rdpgw/oidc/access_token"
|
||||||
ProxyAddressesCtx = "ProxyAddresses"
|
|
||||||
RemoteAddressCtx = "RemoteAddress"
|
|
||||||
TunnelCtx = "TUNNEL"
|
|
||||||
UsernameCtx = "preferred_username"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func EnrichContext(next http.Handler) http.Handler {
|
func EnrichContext(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
id := FromRequestCtx(r)
|
||||||
|
if id == nil {
|
||||||
|
id = NewUser()
|
||||||
|
}
|
||||||
|
log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t",
|
||||||
|
id.SessionId(), id.UserName(), id.Authenticated())
|
||||||
|
|
||||||
h := r.Header.Get("X-Forwarded-For")
|
h := r.Header.Get("X-Forwarded-For")
|
||||||
if h != "" {
|
if h != "" {
|
||||||
|
@ -32,41 +33,36 @@ func EnrichContext(next http.Handler) http.Handler {
|
||||||
if len(ips) > 1 {
|
if len(ips) > 1 {
|
||||||
proxies = ips[1:]
|
proxies = ips[1:]
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
|
id.SetAttribute(AttrClientIp, clientIp)
|
||||||
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
|
id.SetAttribute(AttrProxies, proxies)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, RemoteAddressCtx, r.RemoteAddr)
|
id.SetAttribute(AttrRemoteAddr, r.RemoteAddr)
|
||||||
if h == "" {
|
if h == "" {
|
||||||
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
|
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||||
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
|
id.SetAttribute(AttrClientIp, clientIp)
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, AddToRequestCtx(id, r))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func FixKerberosContext(next http.Handler) http.Handler {
|
func FixKerberosContext(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
gid := goidentity.FromHTTPRequestContext(r)
|
||||||
|
if gid != nil {
|
||||||
id := goidentity.FromHTTPRequestContext(r)
|
id := FromRequestCtx(r)
|
||||||
if id != nil {
|
id.SetUserName(gid.UserName())
|
||||||
ctx = context.WithValue(ctx, UsernameCtx, id.UserName())
|
id.SetAuthenticated(gid.Authenticated())
|
||||||
|
id.SetDomain(gid.Domain())
|
||||||
|
id.SetAuthTime(gid.AuthTime())
|
||||||
|
r = AddToRequestCtx(id, r)
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClientIp(ctx context.Context) string {
|
|
||||||
s, ok := ctx.Value(ClientIPCtx).(string)
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAccessToken(ctx context.Context) string {
|
func GetAccessToken(ctx context.Context) string {
|
||||||
token, ok := ctx.Value("access_token").(string)
|
token, ok := ctx.Value(CtxAccessToken).(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("cannot get access token from context")
|
log.Printf("cannot get access token from context")
|
||||||
return ""
|
return ""
|
||||||
|
|
|
@ -226,7 +226,7 @@ func main() {
|
||||||
oidc := initOIDC(url, store)
|
oidc := initOIDC(url, store)
|
||||||
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
|
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
|
||||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||||
http.HandleFunc("/callback", oidc.HandleCallback)
|
http.Handle("/callback", common.EnrichContext(http.HandlerFunc(oidc.HandleCallback)))
|
||||||
}
|
}
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
http.HandleFunc("/tokeninfo", web.TokenInfo)
|
http.HandleFunc("/tokeninfo", web.TokenInfo)
|
||||||
|
|
|
@ -61,24 +61,20 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||||
var t *Tunnel
|
var t *Tunnel
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
id := common.FromRequestCtx(r)
|
||||||
|
|
||||||
connId := r.Header.Get(rdgConnectionIdKey)
|
connId := r.Header.Get(rdgConnectionIdKey)
|
||||||
x, found := c.Get(connId)
|
x, found := c.Get(connId)
|
||||||
if !found {
|
if !found {
|
||||||
t = &Tunnel{
|
t = &Tunnel{
|
||||||
RDGId: connId,
|
RDGId: connId,
|
||||||
RemoteAddr: ctx.Value(common.ClientIPCtx).(string),
|
RemoteAddr: id.GetAttribute(common.AttrRemoteAddr).(string),
|
||||||
}
|
User: id,
|
||||||
// username can be nil with openid & kerberos as it's only available later
|
|
||||||
// todo grab kerberos principal now?
|
|
||||||
username := ctx.Value(common.UsernameCtx)
|
|
||||||
if username != nil {
|
|
||||||
t.UserName = username.(string)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t = x.(*Tunnel)
|
t = x.(*Tunnel)
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, common.TunnelCtx, t)
|
ctx = context.WithValue(ctx, CtxTunnel, t)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
|
@ -187,13 +183,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
||||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
||||||
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
|
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
|
||||||
|
|
||||||
|
id := common.FromRequestCtx(r)
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
out, err := transport.NewLegacy(w)
|
out, err := transport.NewLegacy(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
|
log.Printf("Opening RDGOUT for client %s", id.GetAttribute(common.AttrClientIp))
|
||||||
|
|
||||||
t.transportOut = out
|
t.transportOut = out
|
||||||
out.SendAccept(true)
|
out.SendAccept(true)
|
||||||
|
@ -215,13 +212,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
||||||
t.transportIn = in
|
t.transportIn = in
|
||||||
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||||
|
|
||||||
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
|
log.Printf("Opening RDGIN for client %s", id.GetAttribute(common.AttrClientIp))
|
||||||
in.SendAccept(false)
|
in.SendAccept(false)
|
||||||
|
|
||||||
// read some initial data
|
// read some initial data
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(common.AttrClientIp))
|
||||||
handler := NewProcessor(g, t)
|
handler := NewProcessor(g, t)
|
||||||
RegisterTunnel(t, handler)
|
RegisterTunnel(t, handler)
|
||||||
defer RemoveTunnel(t)
|
defer RemoveTunnel(t)
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
|
|
||||||
switch pt {
|
switch pt {
|
||||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||||
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
|
log.Printf("Client handshakeRequest from %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
|
||||||
if p.state != SERVER_STATE_INITIALIZED {
|
if p.state != SERVER_STATE_INITIALIZED {
|
||||||
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
||||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||||
|
@ -81,7 +81,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
_, cookie := p.tunnelRequest(pkt)
|
_, cookie := p.tunnelRequest(pkt)
|
||||||
if p.gw.CheckPAACookie != nil {
|
if p.gw.CheckPAACookie != nil {
|
||||||
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
|
log.Printf("Invalid PAA cookie received from client %s", p.tunnel.User.GetAttribute(common.AttrClientIp))
|
||||||
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||||
|
@ -180,9 +180,9 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a packet the is a response to a handshakeRequest request
|
// Creates a packet and is a response to a handshakeRequest request
|
||||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||||
// but could be in Windows. However the NTLM protocol is insecure
|
// but could be in Windows. However, the NTLM protocol is insecure
|
||||||
func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
|
func (p *Processor) handshakeResponse(major byte, minor byte, caps uint16, errorCode int) []byte {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code
|
binary.Write(buf, binary.LittleEndian, uint32(errorCode)) // error_code
|
||||||
|
|
|
@ -1,11 +1,16 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CtxTunnel = "github.com/bolkedebruin/rdpgw/tunnel"
|
||||||
|
)
|
||||||
|
|
||||||
type Tunnel struct {
|
type Tunnel struct {
|
||||||
// Id identifies the connection in the server
|
// Id identifies the connection in the server
|
||||||
Id string
|
Id string
|
||||||
|
@ -22,7 +27,7 @@ type Tunnel struct {
|
||||||
// The obtained client ip address
|
// The obtained client ip address
|
||||||
RemoteAddr string
|
RemoteAddr string
|
||||||
// User
|
// User
|
||||||
UserName string
|
User common.Identity
|
||||||
|
|
||||||
// rwc is the underlying connection to the remote desktop server.
|
// rwc is the underlying connection to the remote desktop server.
|
||||||
// It is of the type *net.TCPConn
|
// It is of the type *net.TCPConn
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -22,23 +21,14 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
|
||||||
// todo get from context?
|
// todo get from context?
|
||||||
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
|
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
|
||||||
case "roundrobin", "unsigned":
|
case "roundrobin", "unsigned":
|
||||||
var username string
|
|
||||||
|
|
||||||
s := getTunnel(ctx)
|
s := getTunnel(ctx)
|
||||||
if s == nil || s.UserName == "" {
|
if s.User.UserName() == "" {
|
||||||
var ok bool
|
|
||||||
username, ok = ctx.Value(common.UsernameCtx).(string)
|
|
||||||
if !ok {
|
|
||||||
return false, errors.New("no valid session info or username found in context")
|
return false, errors.New("no valid session info or username found in context")
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
username = s.UserName
|
log.Printf("Checking host for user %s", s.User.UserName())
|
||||||
}
|
|
||||||
log.Printf("Checking host for user %s", username)
|
|
||||||
for _, h := range Hosts {
|
for _, h := range Hosts {
|
||||||
if username != "" {
|
h = strings.Replace(h, "{{ preferred_username }}", s.User.UserName(), 1)
|
||||||
h = strings.Replace(h, "{{ preferred_username }}", username, 1)
|
|
||||||
}
|
|
||||||
if h == host {
|
if h == host {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,14 +12,16 @@ var (
|
||||||
RDGId: "myid",
|
RDGId: "myid",
|
||||||
TargetServer: "my.remote.server",
|
TargetServer: "my.remote.server",
|
||||||
RemoteAddr: "10.0.0.1",
|
RemoteAddr: "10.0.0.1",
|
||||||
UserName: "Frank",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"}
|
hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"}
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCheckHost(t *testing.T) {
|
func TestCheckHost(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), common.TunnelCtx, &info)
|
info.User = common.NewUser()
|
||||||
|
info.User.SetUserName("MYNAME")
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)
|
||||||
|
|
||||||
Hosts = hosts
|
Hosts = hosts
|
||||||
|
|
||||||
|
@ -40,14 +42,7 @@ func TestCheckHost(t *testing.T) {
|
||||||
t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
host = "my-Frank-host:3389"
|
host = "my-MYNAME-host:3389"
|
||||||
if ok, err := CheckHost(ctx, host); !ok {
|
|
||||||
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
info.UserName = ""
|
|
||||||
ctx = context.WithValue(ctx, "preferred_username", "dummy")
|
|
||||||
host = "my-dummy-host:3389"
|
|
||||||
if ok, err := CheckHost(ctx, host); !ok {
|
if ok, err := CheckHost(ctx, host); !ok {
|
||||||
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,19 +35,21 @@ type customClaims struct {
|
||||||
|
|
||||||
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
||||||
return func(ctx context.Context, host string) (bool, error) {
|
return func(ctx context.Context, host string) (bool, error) {
|
||||||
s := getTunnel(ctx)
|
tunnel := getTunnel(ctx)
|
||||||
if s == nil {
|
if tunnel == nil {
|
||||||
return false, errors.New("no valid session info found in context")
|
return false, errors.New("no valid session info found in context")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.TargetServer != host {
|
if tunnel.TargetServer != host {
|
||||||
log.Printf("Client specified host %s does not match token host %s", host, s.TargetServer)
|
log.Printf("Client specified host %s does not match token host %s", host, tunnel.TargetServer)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if VerifyClientIP && s.RemoteAddr != common.GetClientIp(ctx) {
|
// use identity from context rather then set by tunnel
|
||||||
|
id := common.FromCtx(ctx)
|
||||||
|
if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(common.AttrClientIp) {
|
||||||
log.Printf("Current client ip address %s does not match token client ip %s",
|
log.Printf("Current client ip address %s does not match token client ip %s",
|
||||||
common.GetClientIp(ctx), s.RemoteAddr)
|
id.GetAttribute(common.AttrClientIp), tunnel.RemoteAddr)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return next(ctx, host)
|
return next(ctx, host)
|
||||||
|
@ -106,7 +108,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
||||||
|
|
||||||
tunnel.TargetServer = custom.RemoteServer
|
tunnel.TargetServer = custom.RemoteServer
|
||||||
tunnel.RemoteAddr = custom.ClientIP
|
tunnel.RemoteAddr = custom.ClientIP
|
||||||
tunnel.UserName = user.Subject
|
tunnel.User.SetUserName(user.Subject)
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -127,10 +129,11 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
|
||||||
Subject: username,
|
Subject: username,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
id := common.FromCtx(ctx)
|
||||||
private := customClaims{
|
private := customClaims{
|
||||||
RemoteServer: server,
|
RemoteServer: server,
|
||||||
ClientIP: common.GetClientIp(ctx),
|
ClientIP: id.GetAttribute(common.AttrClientIp).(string),
|
||||||
AccessToken: common.GetAccessToken(ctx),
|
AccessToken: id.GetAttribute(common.AttrAccessToken).(string),
|
||||||
}
|
}
|
||||||
|
|
||||||
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
|
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
|
||||||
|
@ -289,7 +292,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTunnel(ctx context.Context) *protocol.Tunnel {
|
func getTunnel(ctx context.Context) *protocol.Tunnel {
|
||||||
s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel)
|
s, ok := ctx.Value(protocol.CtxTunnel).(*protocol.Tunnel)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("cannot get session info from context")
|
log.Printf("cannot get session info from context")
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -52,8 +52,12 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
if !res.Authenticated {
|
if !res.Authenticated {
|
||||||
log.Printf("User %s is not authenticated for this service", username)
|
log.Printf("User %s is not authenticated for this service", username)
|
||||||
} else {
|
} else {
|
||||||
ctx := context.WithValue(r.Context(), common.UsernameCtx, username)
|
log.Printf("User %s authenticated", username)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
id := common.FromRequestCtx(r)
|
||||||
|
id.SetUserName(username)
|
||||||
|
id.SetAuthenticated(true)
|
||||||
|
id.SetAuthTime(time.Now())
|
||||||
|
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
|
@ -17,6 +16,8 @@ import (
|
||||||
const (
|
const (
|
||||||
CacheExpiration = time.Minute * 2
|
CacheExpiration = time.Minute * 2
|
||||||
CleanupInterval = time.Minute * 5
|
CleanupInterval = time.Minute * 5
|
||||||
|
sessionKeyAuthenticated = "authenticated"
|
||||||
|
oidcKeyUserName = "preferred_username"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OIDC struct {
|
type OIDC struct {
|
||||||
|
@ -90,10 +91,14 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
id := common.FromRequestCtx(r)
|
||||||
|
id.SetUserName(data[oidcKeyUserName].(string))
|
||||||
|
id.SetAuthenticated(true)
|
||||||
|
id.SetAuthTime(time.Now())
|
||||||
|
id.SetAttribute(common.AttrAccessToken, oauth2Token.AccessToken)
|
||||||
|
|
||||||
session.Options.MaxAge = MaxAge
|
session.Options.MaxAge = MaxAge
|
||||||
session.Values["preferred_username"] = data["preferred_username"]
|
session.Values[common.CTXKey] = id
|
||||||
session.Values["authenticated"] = true
|
|
||||||
session.Values["access_token"] = oauth2Token.AccessToken
|
|
||||||
|
|
||||||
if err = session.Save(r, w); err != nil {
|
if err = session.Save(r, w); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
@ -110,8 +115,8 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
found := session.Values["authenticated"]
|
id := session.Values[common.CTXKey].(common.Identity)
|
||||||
if found == nil || !found.(bool) {
|
if id == nil {
|
||||||
seed := make([]byte, 16)
|
seed := make([]byte, 16)
|
||||||
rand.Read(seed)
|
rand.Read(seed)
|
||||||
state := hex.EncodeToString(seed)
|
state := hex.EncodeToString(seed)
|
||||||
|
@ -120,9 +125,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"])
|
// replace the identity with the one from the sessions
|
||||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
next.ServeHTTP(w, common.AddToRequestCtx(id, r))
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue