mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-25 09:33:28 +02:00
Split web api so it becomes more testable and maintainable
This commit is contained in:
parent
2a2edaa21c
commit
0c5f93e810
9 changed files with 471 additions and 367 deletions
|
@ -1,293 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
RdpGwSession = "RDPGWSESSION"
|
||||
MaxAge = 120
|
||||
)
|
||||
|
||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
||||
type QueryInfoFunc func(context.Context, string, string) (string, error)
|
||||
|
||||
type Config struct {
|
||||
SessionKey []byte
|
||||
SessionEncryptionKey []byte
|
||||
SessionStore string
|
||||
PAATokenGenerator TokenGeneratorFunc
|
||||
UserTokenGenerator UserTokenGeneratorFunc
|
||||
QueryInfo QueryInfoFunc
|
||||
QueryTokenIssuer string
|
||||
EnableUserToken bool
|
||||
OAuth2Config *oauth2.Config
|
||||
store sessions.Store
|
||||
OIDCTokenVerifier *oidc.IDTokenVerifier
|
||||
stateStore *cache.Cache
|
||||
Hosts []string
|
||||
HostSelection string
|
||||
GatewayAddress *url.URL
|
||||
UsernameTemplate string
|
||||
NetworkAutoDetect int
|
||||
BandwidthAutoDetect int
|
||||
ConnectionType int
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
SocketAddress string
|
||||
Authentication string
|
||||
}
|
||||
|
||||
func (c *Config) NewApi() {
|
||||
if len(c.SessionKey) < 32 {
|
||||
log.Fatal("Session key too small")
|
||||
}
|
||||
if len(c.Hosts) < 1 {
|
||||
log.Fatal("Not enough hosts to connect to specified")
|
||||
}
|
||||
if c.SessionStore == "file" {
|
||||
log.Println("Filesystem is used as session storage")
|
||||
c.store = sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
|
||||
} else {
|
||||
log.Println("Cookies are used as session storage")
|
||||
c.store = sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
|
||||
}
|
||||
c.stateStore = cache.New(time.Minute*2, 5*time.Minute)
|
||||
}
|
||||
|
||||
func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
state := r.URL.Query().Get("state")
|
||||
s, found := c.stateStore.Get(state)
|
||||
if !found {
|
||||
http.Error(w, "unknown state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := s.(string)
|
||||
|
||||
ctx := context.Background()
|
||||
oauth2Token, err := c.OAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
OAuth2Token *oauth2.Token
|
||||
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
||||
}{oauth2Token, new(json.RawMessage)}
|
||||
|
||||
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := c.store.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session.Options.MaxAge = MaxAge
|
||||
session.Values["preferred_username"] = data["preferred_username"]
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["access_token"] = oauth2Token.AccessToken
|
||||
|
||||
if err = session.Save(r, w); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func (c *Config) Authenticated(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := c.store.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
found := session.Values["authenticated"]
|
||||
if found == nil || !found.(bool) {
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
state := hex.EncodeToString(seed)
|
||||
c.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
|
||||
http.Redirect(w, r, c.OAuth2Config.AuthCodeURL(state), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) selectRandomHost() string {
|
||||
rand.Seed(time.Now().Unix())
|
||||
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
||||
return host
|
||||
}
|
||||
|
||||
func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
switch c.HostSelection {
|
||||
case "roundrobin":
|
||||
return c.selectRandomHost(), nil
|
||||
case "signed":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
found := false
|
||||
for _, check := range c.Hosts {
|
||||
if check == host {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Printf("Invalid host %s specified in token", hosts[0])
|
||||
return "", errors.New("invalid host specified in query token")
|
||||
}
|
||||
return host, nil
|
||||
case "unsigned":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
for _, check := range c.Hosts {
|
||||
if check == hosts[0] {
|
||||
return hosts[0], nil
|
||||
}
|
||||
}
|
||||
// not found
|
||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||
return "", errors.New("invalid host specified in query parameter")
|
||||
case "any":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
return hosts[0], nil
|
||||
default:
|
||||
return c.selectRandomHost(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userName, ok := ctx.Value("preferred_username").(string)
|
||||
|
||||
if !ok {
|
||||
log.Printf("preferred_username not found in context")
|
||||
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// determine host to connect to
|
||||
host, err := c.getHost(ctx, r.URL)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||
|
||||
// split the username into user and domain
|
||||
var user = userName
|
||||
var domain = c.DefaultDomain
|
||||
if c.SplitUserDomain {
|
||||
creds := strings.SplitN(userName, "@", 2)
|
||||
user = creds[0]
|
||||
if len(creds) > 1 {
|
||||
domain = creds[1]
|
||||
}
|
||||
}
|
||||
|
||||
render := user
|
||||
if c.UsernameTemplate != "" {
|
||||
render = fmt.Sprintf(c.UsernameTemplate)
|
||||
render = strings.Replace(render, "{{ username }}", user, 1)
|
||||
if c.UsernameTemplate == render {
|
||||
log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user)
|
||||
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, err := c.PAATokenGenerator(ctx, user, host)
|
||||
if err != nil {
|
||||
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
|
||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if c.EnableUserToken {
|
||||
userToken, err := c.UserTokenGenerator(ctx, user)
|
||||
if err != nil {
|
||||
log.Printf("Cannot generate token for user %s due to %s", user, err)
|
||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||
}
|
||||
render = strings.Replace(render, "{{ token }}", userToken, 1)
|
||||
}
|
||||
|
||||
// authenticated
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
fn := hex.EncodeToString(seed) + ".rdp"
|
||||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
||||
w.Header().Set("Content-Type", "application/x-rdp")
|
||||
data := "full address:s:" + host + "\r\n" +
|
||||
"gatewayhostname:s:" + c.GatewayAddress.Host + "\r\n" +
|
||||
"gatewaycredentialssource:i:5\r\n" +
|
||||
"gatewayusagemethod:i:1\r\n" +
|
||||
"gatewayprofileusagemethod:i:1\r\n" +
|
||||
"gatewayaccesstoken:s:" + token + "\r\n" +
|
||||
"networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" +
|
||||
"bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" +
|
||||
"connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" +
|
||||
"username:s:" + render + "\r\n" +
|
||||
"domain:s:" + domain + "\r\n" +
|
||||
"bitmapcachesize:i:32000\r\n" +
|
||||
"smart sizing:i:1\r\n"
|
||||
|
||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
|
||||
}
|
|
@ -31,7 +31,7 @@ type ServerConfig struct {
|
|||
SessionStore string `koanf:"sessionstore"`
|
||||
SendBuf int `koanf:"sendbuf"`
|
||||
ReceiveBuf int `koanf:"receivebuf"`
|
||||
Tls string `koanf:"disabletls"`
|
||||
Tls string `koanf:"tls"`
|
||||
Authentication string `koanf:"authentication"`
|
||||
AuthSocket string `koanf:"authsocket"`
|
||||
}
|
||||
|
|
|
@ -4,12 +4,13 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/api"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/web"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/thought-machine/go-flags"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
|
@ -27,17 +28,56 @@ var opts struct {
|
|||
|
||||
var conf config.Configuration
|
||||
|
||||
func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
|
||||
// set oidc config
|
||||
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot get oidc provider: %s", err)
|
||||
}
|
||||
oidcConfig := &oidc.Config{
|
||||
ClientID: conf.OpenId.ClientId,
|
||||
}
|
||||
verifier := provider.Verifier(oidcConfig)
|
||||
|
||||
oauthConfig := oauth2.Config{
|
||||
ClientID: conf.OpenId.ClientId,
|
||||
ClientSecret: conf.OpenId.ClientSecret,
|
||||
RedirectURL: callbackUrl.String(),
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
security.OIDCProvider = provider
|
||||
security.Oauth2Config = oauthConfig
|
||||
|
||||
o := web.OIDCConfig{
|
||||
OAuth2Config: &oauthConfig,
|
||||
OIDCTokenVerifier: verifier,
|
||||
SessionStore: store,
|
||||
}
|
||||
|
||||
return o.New()
|
||||
}
|
||||
|
||||
func main() {
|
||||
// get config
|
||||
// load config
|
||||
_, err := flags.Parse(&opts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
conf = config.Load(opts.ConfigFile)
|
||||
|
||||
security.VerifyClientIP = conf.Security.VerifyClientIp
|
||||
// set callback url and external advertised gateway address
|
||||
url, err := url.Parse(conf.Server.GatewayAddress)
|
||||
if err != nil {
|
||||
log.Printf("Cannot parse server gateway address %s due to %s", url, err)
|
||||
}
|
||||
if url.Scheme == "" {
|
||||
url.Scheme = "https"
|
||||
}
|
||||
url.Path = "callback"
|
||||
|
||||
// set security keys
|
||||
// set security options
|
||||
security.VerifyClientIP = conf.Security.VerifyClientIp
|
||||
security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
|
||||
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
|
||||
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
|
||||
|
@ -46,66 +86,39 @@ func main() {
|
|||
security.HostSelection = conf.Server.HostSelection
|
||||
security.Hosts = conf.Server.Hosts
|
||||
|
||||
// configure api
|
||||
api := &api.Config{
|
||||
QueryInfo: security.QueryInfo,
|
||||
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
|
||||
EnableUserToken: conf.Security.EnableUserToken,
|
||||
// init session store
|
||||
sessionConf := web.SessionManagerConf{
|
||||
SessionKey: []byte(conf.Server.SessionKey),
|
||||
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
||||
SessionStore: conf.Server.SessionStore,
|
||||
Hosts: conf.Server.Hosts,
|
||||
HostSelection: conf.Server.HostSelection,
|
||||
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
|
||||
UsernameTemplate: conf.Client.UsernameTemplate,
|
||||
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
|
||||
ConnectionType: conf.Client.ConnectionType,
|
||||
SplitUserDomain: conf.Client.SplitUserDomain,
|
||||
DefaultDomain: conf.Client.DefaultDomain,
|
||||
SocketAddress: conf.Server.AuthSocket,
|
||||
Authentication: conf.Server.Authentication,
|
||||
StoreType: conf.Server.SessionStore,
|
||||
}
|
||||
store := sessionConf.Init()
|
||||
|
||||
// configure web backend
|
||||
w := &web.Config{
|
||||
QueryInfo: security.QueryInfo,
|
||||
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
|
||||
EnableUserToken: conf.Security.EnableUserToken,
|
||||
SessionStore: store,
|
||||
Hosts: conf.Server.Hosts,
|
||||
HostSelection: conf.Server.HostSelection,
|
||||
RdpOpts: web.RdpOpts{
|
||||
UsernameTemplate: conf.Client.UsernameTemplate,
|
||||
SplitUserDomain: conf.Client.SplitUserDomain,
|
||||
DefaultDomain: conf.Client.DefaultDomain,
|
||||
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
|
||||
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
|
||||
ConnectionType: conf.Client.ConnectionType,
|
||||
},
|
||||
}
|
||||
|
||||
if conf.Caps.TokenAuth {
|
||||
api.PAATokenGenerator = security.GeneratePAAToken
|
||||
w.PAATokenGenerator = security.GeneratePAAToken
|
||||
}
|
||||
if conf.Security.EnableUserToken {
|
||||
api.UserTokenGenerator = security.GenerateUserToken
|
||||
w.UserTokenGenerator = security.GenerateUserToken
|
||||
}
|
||||
|
||||
// get callback url and external advertised gateway address
|
||||
url, err := url.Parse(conf.Server.GatewayAddress)
|
||||
if url.Scheme == "" {
|
||||
url.Scheme = "https"
|
||||
}
|
||||
url.Path = "callback"
|
||||
|
||||
if conf.Server.Authentication == "openid" {
|
||||
// set oidc config
|
||||
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot get oidc provider: %s", err)
|
||||
}
|
||||
oidcConfig := &oidc.Config{
|
||||
ClientID: conf.OpenId.ClientId,
|
||||
}
|
||||
verifier := provider.Verifier(oidcConfig)
|
||||
|
||||
api.GatewayAddress = url
|
||||
|
||||
oauthConfig := oauth2.Config{
|
||||
ClientID: conf.OpenId.ClientId,
|
||||
ClientSecret: conf.OpenId.ClientSecret,
|
||||
RedirectURL: url.String(),
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
security.OIDCProvider = provider
|
||||
security.Oauth2Config = oauthConfig
|
||||
api.OAuth2Config = &oauthConfig
|
||||
api.OIDCTokenVerifier = verifier
|
||||
}
|
||||
api.NewApi()
|
||||
h := w.NewHandler()
|
||||
|
||||
log.Printf("Starting remote desktop gateway server")
|
||||
cfg := &tls.Config{}
|
||||
|
@ -151,7 +164,7 @@ func main() {
|
|||
cfg.GetCertificate = certMgr.GetCertificate
|
||||
|
||||
go func() {
|
||||
http.ListenAndServe(":http", certMgr.HTTPHandler(nil))
|
||||
http.ListenAndServe(":80", certMgr.HTTPHandler(nil))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
@ -190,15 +203,17 @@ func main() {
|
|||
}
|
||||
|
||||
if conf.Server.Authentication == "local" {
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol)))
|
||||
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol)))
|
||||
} else {
|
||||
// openid
|
||||
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||
oidc := initOIDC(url, store)
|
||||
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||
http.HandleFunc("/callback", api.HandleCallback)
|
||||
http.HandleFunc("/callback", oidc.HandleCallback)
|
||||
}
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/tokeninfo", api.TokenInfo)
|
||||
http.HandleFunc("/tokeninfo", web.TokenInfo)
|
||||
|
||||
if conf.Server.Tls == "disabled" {
|
||||
err = server.ListenAndServe()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package api
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -15,13 +15,17 @@ const (
|
|||
protocol = "unix"
|
||||
)
|
||||
|
||||
func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
type BasicAuthHandler struct {
|
||||
SocketAddress string
|
||||
}
|
||||
|
||||
func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if ok {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return net.Dial(protocol, addr)
|
||||
}))
|
127
cmd/rdpgw/web/oidc.go
Normal file
127
cmd/rdpgw/web/oidc.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CacheExpiration = time.Minute * 2
|
||||
CleanupInterval = time.Minute * 5
|
||||
)
|
||||
|
||||
type OIDC struct {
|
||||
oAuth2Config *oauth2.Config
|
||||
oidcTokenVerifier *oidc.IDTokenVerifier
|
||||
stateStore *cache.Cache
|
||||
sessionStore sessions.Store
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
OAuth2Config *oauth2.Config
|
||||
OIDCTokenVerifier *oidc.IDTokenVerifier
|
||||
SessionStore sessions.Store
|
||||
}
|
||||
|
||||
func (c *OIDCConfig) New() *OIDC {
|
||||
return &OIDC{
|
||||
oAuth2Config: c.OAuth2Config,
|
||||
oidcTokenVerifier: c.OIDCTokenVerifier,
|
||||
stateStore: cache.New(CacheExpiration, CleanupInterval),
|
||||
sessionStore: c.SessionStore,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
state := r.URL.Query().Get("state")
|
||||
s, found := h.stateStore.Get(state)
|
||||
if !found {
|
||||
http.Error(w, "unknown state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := s.(string)
|
||||
|
||||
ctx := r.Context()
|
||||
oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
idToken, err := h.oidcTokenVerifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
OAuth2Token *oauth2.Token
|
||||
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
||||
}{oauth2Token, new(json.RawMessage)}
|
||||
|
||||
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.sessionStore.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session.Options.MaxAge = MaxAge
|
||||
session.Values["preferred_username"] = data["preferred_username"]
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["access_token"] = oauth2Token.AccessToken
|
||||
|
||||
if err = session.Save(r, w); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := h.sessionStore.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
found := session.Values["authenticated"]
|
||||
if found == nil || !found.(bool) {
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
state := hex.EncodeToString(seed)
|
||||
h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
|
||||
http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
30
cmd/rdpgw/web/session.go
Normal file
30
cmd/rdpgw/web/session.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package web
|
||||
|
||||
import (
|
||||
"github.com/gorilla/sessions"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
type SessionManagerConf struct {
|
||||
SessionKey []byte
|
||||
SessionEncryptionKey []byte
|
||||
StoreType string
|
||||
}
|
||||
|
||||
func (c *SessionManagerConf) Init() sessions.Store {
|
||||
if len(c.SessionKey) < 32 {
|
||||
log.Fatal("Session key too small")
|
||||
}
|
||||
if len(c.SessionEncryptionKey) < 32 {
|
||||
log.Fatal("Session key too small")
|
||||
}
|
||||
|
||||
if c.StoreType == "file" {
|
||||
log.Println("Filesystem is used as session storage")
|
||||
return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
|
||||
} else {
|
||||
log.Println("Cookies are used as session storage")
|
||||
return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package api
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
"net/http"
|
||||
)
|
||||
|
||||
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
|
||||
func TokenInfo(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
|
||||
return
|
||||
|
@ -37,4 +37,4 @@ func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "cannot encode json", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
215
cmd/rdpgw/web/web.go
Normal file
215
cmd/rdpgw/web/web.go
Normal file
|
@ -0,0 +1,215 @@
|
|||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gorilla/sessions"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
RdpGwSession = "RDPGWSESSION"
|
||||
MaxAge = 120
|
||||
)
|
||||
|
||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
||||
type QueryInfoFunc func(context.Context, string, string) (string, error)
|
||||
|
||||
type Config struct {
|
||||
SessionStore sessions.Store
|
||||
PAATokenGenerator TokenGeneratorFunc
|
||||
UserTokenGenerator UserTokenGeneratorFunc
|
||||
QueryInfo QueryInfoFunc
|
||||
QueryTokenIssuer string
|
||||
EnableUserToken bool
|
||||
Hosts []string
|
||||
HostSelection string
|
||||
GatewayAddress *url.URL
|
||||
RdpOpts RdpOpts
|
||||
}
|
||||
|
||||
type RdpOpts struct {
|
||||
UsernameTemplate string
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
NetworkAutoDetect int
|
||||
BandwidthAutoDetect int
|
||||
ConnectionType int
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
sessionStore sessions.Store
|
||||
paaTokenGenerator TokenGeneratorFunc
|
||||
enableUserToken bool
|
||||
userTokenGenerator UserTokenGeneratorFunc
|
||||
queryInfo QueryInfoFunc
|
||||
queryTokenIssuer string
|
||||
gatewayAddress *url.URL
|
||||
hosts []string
|
||||
hostSelection string
|
||||
rdpOpts RdpOpts
|
||||
}
|
||||
|
||||
func (c *Config) NewHandler() *Handler {
|
||||
if len(c.Hosts) < 1 {
|
||||
log.Fatal("Not enough hosts to connect to specified")
|
||||
}
|
||||
return &Handler{
|
||||
sessionStore: c.SessionStore,
|
||||
paaTokenGenerator: c.PAATokenGenerator,
|
||||
enableUserToken: c.EnableUserToken,
|
||||
userTokenGenerator: c.UserTokenGenerator,
|
||||
queryInfo: c.QueryInfo,
|
||||
queryTokenIssuer: c.QueryTokenIssuer,
|
||||
gatewayAddress: c.GatewayAddress,
|
||||
hosts: c.Hosts,
|
||||
hostSelection: c.HostSelection,
|
||||
rdpOpts: c.RdpOpts,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) selectRandomHost() string {
|
||||
rand.Seed(time.Now().Unix())
|
||||
host := h.hosts[rand.Intn(len(h.hosts))]
|
||||
return host
|
||||
}
|
||||
|
||||
func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
switch h.hostSelection {
|
||||
case "roundrobin":
|
||||
return h.selectRandomHost(), nil
|
||||
case "signed":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
host, err := h.queryInfo(ctx, hosts[0], h.queryTokenIssuer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
found := false
|
||||
for _, check := range h.hosts {
|
||||
if check == host {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Printf("Invalid host %s specified in token", hosts[0])
|
||||
return "", errors.New("invalid host specified in query token")
|
||||
}
|
||||
return host, nil
|
||||
case "unsigned":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
for _, check := range h.hosts {
|
||||
if check == hosts[0] {
|
||||
return hosts[0], nil
|
||||
}
|
||||
}
|
||||
// not found
|
||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||
return "", errors.New("invalid host specified in query parameter")
|
||||
case "any":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
return hosts[0], nil
|
||||
default:
|
||||
return h.selectRandomHost(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userName, ok := ctx.Value("preferred_username").(string)
|
||||
|
||||
opts := h.rdpOpts
|
||||
|
||||
if !ok {
|
||||
log.Printf("preferred_username not found in context")
|
||||
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// determine host to connect to
|
||||
host, err := h.getHost(ctx, r.URL)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||
|
||||
// split the username into user and domain
|
||||
var user = userName
|
||||
var domain = opts.DefaultDomain
|
||||
if opts.SplitUserDomain {
|
||||
creds := strings.SplitN(userName, "@", 2)
|
||||
user = creds[0]
|
||||
if len(creds) > 1 {
|
||||
domain = creds[1]
|
||||
}
|
||||
}
|
||||
|
||||
render := user
|
||||
if opts.UsernameTemplate != "" {
|
||||
render = fmt.Sprintf(h.rdpOpts.UsernameTemplate)
|
||||
render = strings.Replace(render, "{{ username }}", user, 1)
|
||||
if h.rdpOpts.UsernameTemplate == render {
|
||||
log.Printf("Invalid username template. %s == %s", h.rdpOpts.UsernameTemplate, user)
|
||||
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, err := h.paaTokenGenerator(ctx, user, host)
|
||||
if err != nil {
|
||||
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
|
||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if h.enableUserToken {
|
||||
userToken, err := h.userTokenGenerator(ctx, user)
|
||||
if err != nil {
|
||||
log.Printf("Cannot generate token for user %s due to %s", user, err)
|
||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||
}
|
||||
render = strings.Replace(render, "{{ token }}", userToken, 1)
|
||||
}
|
||||
|
||||
// authenticated
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
fn := hex.EncodeToString(seed) + ".rdp"
|
||||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
||||
w.Header().Set("Content-Type", "application/x-rdp")
|
||||
|
||||
data := "full address:s:" + host + "\r\n" +
|
||||
"gatewayhostname:s:" + h.gatewayAddress.Host + "\r\n" +
|
||||
"gatewaycredentialssource:i:5\r\n" +
|
||||
"gatewayusagemethod:i:1\r\n" +
|
||||
"gatewayprofileusagemethod:i:1\r\n" +
|
||||
"gatewayaccesstoken:s:" + token + "\r\n" +
|
||||
"networkautodetect:i:" + strconv.Itoa(opts.NetworkAutoDetect) + "\r\n" +
|
||||
"bandwidthautodetect:i:" + strconv.Itoa(opts.BandwidthAutoDetect) + "\r\n" +
|
||||
"connection type:i:" + strconv.Itoa(opts.ConnectionType) + "\r\n" +
|
||||
"username:s:" + render + "\r\n" +
|
||||
"domain:s:" + domain + "\r\n" +
|
||||
"bitmapcachesize:i:32000\r\n" +
|
||||
"smart sizing:i:1\r\n"
|
||||
|
||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package api
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -27,12 +27,14 @@ func TestGetHost(t *testing.T) {
|
|||
HostSelection: "roundrobin",
|
||||
Hosts: hosts,
|
||||
}
|
||||
h := c.NewHandler()
|
||||
|
||||
u := &url.URL{
|
||||
Host: "example.com",
|
||||
}
|
||||
vals := u.Query()
|
||||
|
||||
host, err := c.getHost(ctx, u)
|
||||
host, err := h.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("#{err}")
|
||||
}
|
||||
|
@ -44,14 +46,16 @@ func TestGetHost(t *testing.T) {
|
|||
c.HostSelection = "unsigned"
|
||||
vals.Set("host", "in.valid.host")
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(ctx, u)
|
||||
h = c.NewHandler()
|
||||
host, err = h.getHost(ctx, u)
|
||||
if err == nil {
|
||||
t.Fatalf("Accepted host %s is not in hosts list", host)
|
||||
}
|
||||
|
||||
vals.Set("host", hosts[0])
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(ctx, u)
|
||||
h = c.NewHandler()
|
||||
host, err = h.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||
}
|
||||
|
@ -64,7 +68,8 @@ func TestGetHost(t *testing.T) {
|
|||
test := "bla.bla.com"
|
||||
vals.Set("host", test)
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(ctx, u)
|
||||
h = c.NewHandler()
|
||||
host, err = h.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("%s is not accepted", host)
|
||||
}
|
||||
|
@ -83,7 +88,8 @@ func TestGetHost(t *testing.T) {
|
|||
}
|
||||
vals.Set("host", queryToken)
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(ctx, u)
|
||||
h = c.NewHandler()
|
||||
host, err = h.getHost(ctx, u)
|
||||
if err != nil {
|
||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue