Split web api so it becomes more testable and maintainable

This commit is contained in:
Bolke de Bruin 2022-09-06 12:14:08 +02:00
parent 2a2edaa21c
commit 0c5f93e810
9 changed files with 471 additions and 367 deletions

View file

@ -1,293 +0,0 @@
package api
import (
"context"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"log"
"math/rand"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
)
const (
RdpGwSession = "RDPGWSESSION"
MaxAge = 120
)
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type QueryInfoFunc func(context.Context, string, string) (string, error)
type Config struct {
SessionKey []byte
SessionEncryptionKey []byte
SessionStore string
PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc
QueryInfo QueryInfoFunc
QueryTokenIssuer string
EnableUserToken bool
OAuth2Config *oauth2.Config
store sessions.Store
OIDCTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache
Hosts []string
HostSelection string
GatewayAddress *url.URL
UsernameTemplate string
NetworkAutoDetect int
BandwidthAutoDetect int
ConnectionType int
SplitUserDomain bool
DefaultDomain string
SocketAddress string
Authentication string
}
func (c *Config) NewApi() {
if len(c.SessionKey) < 32 {
log.Fatal("Session key too small")
}
if len(c.Hosts) < 1 {
log.Fatal("Not enough hosts to connect to specified")
}
if c.SessionStore == "file" {
log.Println("Filesystem is used as session storage")
c.store = sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
} else {
log.Println("Cookies are used as session storage")
c.store = sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
}
c.stateStore = cache.New(time.Minute*2, 5*time.Minute)
}
func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
state := r.URL.Query().Get("state")
s, found := c.stateStore.Get(state)
if !found {
http.Error(w, "unknown state", http.StatusBadRequest)
return
}
url := s.(string)
ctx := context.Background()
oauth2Token, err := c.OAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
return
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
return
}
idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken)
if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return
}
resp := struct {
OAuth2Token *oauth2.Token
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
}{oauth2Token, new(json.RawMessage)}
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var data map[string]interface{}
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session, err := c.store.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session.Options.MaxAge = MaxAge
session.Values["preferred_username"] = data["preferred_username"]
session.Values["authenticated"] = true
session.Values["access_token"] = oauth2Token.AccessToken
if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
http.Redirect(w, r, url, http.StatusFound)
}
func (c *Config) Authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := c.store.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
found := session.Values["authenticated"]
if found == nil || !found.(bool) {
seed := make([]byte, 16)
rand.Read(seed)
state := hex.EncodeToString(seed)
c.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
http.Redirect(w, r, c.OAuth2Config.AuthCodeURL(state), http.StatusFound)
return
}
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (c *Config) selectRandomHost() string {
rand.Seed(time.Now().Unix())
host := c.Hosts[rand.Intn(len(c.Hosts))]
return host
}
func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) {
switch c.HostSelection {
case "roundrobin":
return c.selectRandomHost(), nil
case "signed":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer)
if err != nil {
return "", err
}
found := false
for _, check := range c.Hosts {
if check == host {
found = true
break
}
}
if !found {
log.Printf("Invalid host %s specified in token", hosts[0])
return "", errors.New("invalid host specified in query token")
}
return host, nil
case "unsigned":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
for _, check := range c.Hosts {
if check == hosts[0] {
return hosts[0], nil
}
}
// not found
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
case "any":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
return hosts[0], nil
default:
return c.selectRandomHost(), nil
}
}
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string)
if !ok {
log.Printf("preferred_username not found in context")
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
return
}
// determine host to connect to
host, err := c.getHost(ctx, r.URL)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
// split the username into user and domain
var user = userName
var domain = c.DefaultDomain
if c.SplitUserDomain {
creds := strings.SplitN(userName, "@", 2)
user = creds[0]
if len(creds) > 1 {
domain = creds[1]
}
}
render := user
if c.UsernameTemplate != "" {
render = fmt.Sprintf(c.UsernameTemplate)
render = strings.Replace(render, "{{ username }}", user, 1)
if c.UsernameTemplate == render {
log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user)
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
return
}
}
token, err := c.PAATokenGenerator(ctx, user, host)
if err != nil {
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
}
if c.EnableUserToken {
userToken, err := c.UserTokenGenerator(ctx, user)
if err != nil {
log.Printf("Cannot generate token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
}
render = strings.Replace(render, "{{ token }}", userToken, 1)
}
// authenticated
seed := make([]byte, 16)
rand.Read(seed)
fn := hex.EncodeToString(seed) + ".rdp"
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
w.Header().Set("Content-Type", "application/x-rdp")
data := "full address:s:" + host + "\r\n" +
"gatewayhostname:s:" + c.GatewayAddress.Host + "\r\n" +
"gatewaycredentialssource:i:5\r\n" +
"gatewayusagemethod:i:1\r\n" +
"gatewayprofileusagemethod:i:1\r\n" +
"gatewayaccesstoken:s:" + token + "\r\n" +
"networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" +
"bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" +
"connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" +
"username:s:" + render + "\r\n" +
"domain:s:" + domain + "\r\n" +
"bitmapcachesize:i:32000\r\n" +
"smart sizing:i:1\r\n"
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
}

View file

@ -31,7 +31,7 @@ type ServerConfig struct {
SessionStore string `koanf:"sessionstore"` SessionStore string `koanf:"sessionstore"`
SendBuf int `koanf:"sendbuf"` SendBuf int `koanf:"sendbuf"`
ReceiveBuf int `koanf:"receivebuf"` ReceiveBuf int `koanf:"receivebuf"`
Tls string `koanf:"disabletls"` Tls string `koanf:"tls"`
Authentication string `koanf:"authentication"` Authentication string `koanf:"authentication"`
AuthSocket string `koanf:"authsocket"` AuthSocket string `koanf:"authsocket"`
} }

View file

@ -4,12 +4,13 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/api"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/web"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/thought-machine/go-flags" "github.com/thought-machine/go-flags"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
@ -27,60 +28,7 @@ var opts struct {
var conf config.Configuration var conf config.Configuration
func main() { func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
// get config
_, err := flags.Parse(&opts)
if err != nil {
panic(err)
}
conf = config.Load(opts.ConfigFile)
security.VerifyClientIP = conf.Security.VerifyClientIp
// set security keys
security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey)
security.HostSelection = conf.Server.HostSelection
security.Hosts = conf.Server.Hosts
// configure api
api := &api.Config{
QueryInfo: security.QueryInfo,
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
EnableUserToken: conf.Security.EnableUserToken,
SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
SessionStore: conf.Server.SessionStore,
Hosts: conf.Server.Hosts,
HostSelection: conf.Server.HostSelection,
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
UsernameTemplate: conf.Client.UsernameTemplate,
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
ConnectionType: conf.Client.ConnectionType,
SplitUserDomain: conf.Client.SplitUserDomain,
DefaultDomain: conf.Client.DefaultDomain,
SocketAddress: conf.Server.AuthSocket,
Authentication: conf.Server.Authentication,
}
if conf.Caps.TokenAuth {
api.PAATokenGenerator = security.GeneratePAAToken
}
if conf.Security.EnableUserToken {
api.UserTokenGenerator = security.GenerateUserToken
}
// get callback url and external advertised gateway address
url, err := url.Parse(conf.Server.GatewayAddress)
if url.Scheme == "" {
url.Scheme = "https"
}
url.Path = "callback"
if conf.Server.Authentication == "openid" {
// set oidc config // set oidc config
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
if err != nil { if err != nil {
@ -91,21 +39,86 @@ func main() {
} }
verifier := provider.Verifier(oidcConfig) verifier := provider.Verifier(oidcConfig)
api.GatewayAddress = url
oauthConfig := oauth2.Config{ oauthConfig := oauth2.Config{
ClientID: conf.OpenId.ClientId, ClientID: conf.OpenId.ClientId,
ClientSecret: conf.OpenId.ClientSecret, ClientSecret: conf.OpenId.ClientSecret,
RedirectURL: url.String(), RedirectURL: callbackUrl.String(),
Endpoint: provider.Endpoint(), Endpoint: provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
} }
security.OIDCProvider = provider security.OIDCProvider = provider
security.Oauth2Config = oauthConfig security.Oauth2Config = oauthConfig
api.OAuth2Config = &oauthConfig
api.OIDCTokenVerifier = verifier o := web.OIDCConfig{
OAuth2Config: &oauthConfig,
OIDCTokenVerifier: verifier,
SessionStore: store,
} }
api.NewApi()
return o.New()
}
func main() {
// load config
_, err := flags.Parse(&opts)
if err != nil {
panic(err)
}
conf = config.Load(opts.ConfigFile)
// set callback url and external advertised gateway address
url, err := url.Parse(conf.Server.GatewayAddress)
if err != nil {
log.Printf("Cannot parse server gateway address %s due to %s", url, err)
}
if url.Scheme == "" {
url.Scheme = "https"
}
url.Path = "callback"
// set security options
security.VerifyClientIP = conf.Security.VerifyClientIp
security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey)
security.HostSelection = conf.Server.HostSelection
security.Hosts = conf.Server.Hosts
// init session store
sessionConf := web.SessionManagerConf{
SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
StoreType: conf.Server.SessionStore,
}
store := sessionConf.Init()
// configure web backend
w := &web.Config{
QueryInfo: security.QueryInfo,
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
EnableUserToken: conf.Security.EnableUserToken,
SessionStore: store,
Hosts: conf.Server.Hosts,
HostSelection: conf.Server.HostSelection,
RdpOpts: web.RdpOpts{
UsernameTemplate: conf.Client.UsernameTemplate,
SplitUserDomain: conf.Client.SplitUserDomain,
DefaultDomain: conf.Client.DefaultDomain,
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
ConnectionType: conf.Client.ConnectionType,
},
}
if conf.Caps.TokenAuth {
w.PAATokenGenerator = security.GeneratePAAToken
}
if conf.Security.EnableUserToken {
w.UserTokenGenerator = security.GenerateUserToken
}
h := w.NewHandler()
log.Printf("Starting remote desktop gateway server") log.Printf("Starting remote desktop gateway server")
cfg := &tls.Config{} cfg := &tls.Config{}
@ -151,7 +164,7 @@ func main() {
cfg.GetCertificate = certMgr.GetCertificate cfg.GetCertificate = certMgr.GetCertificate
go func() { go func() {
http.ListenAndServe(":http", certMgr.HTTPHandler(nil)) http.ListenAndServe(":80", certMgr.HTTPHandler(nil))
}() }()
} }
} }
@ -190,15 +203,17 @@ func main() {
} }
if conf.Server.Authentication == "local" { if conf.Server.Authentication == "local" {
http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol))) h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol)))
} else { } else {
// openid // openid
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) oidc := initOIDC(url, store)
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.HandleFunc("/callback", api.HandleCallback) http.HandleFunc("/callback", oidc.HandleCallback)
} }
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", api.TokenInfo) http.HandleFunc("/tokeninfo", web.TokenInfo)
if conf.Server.Tls == "disabled" { if conf.Server.Tls == "disabled" {
err = server.ListenAndServe() err = server.ListenAndServe()

View file

@ -1,4 +1,4 @@
package api package web
import ( import (
"context" "context"
@ -15,13 +15,17 @@ const (
protocol = "unix" protocol = "unix"
) )
func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc { type BasicAuthHandler struct {
SocketAddress string
}
func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
if ok { if ok {
ctx := r.Context() ctx := r.Context()
conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return net.Dial(protocol, addr) return net.Dial(protocol, addr)
})) }))

127
cmd/rdpgw/web/oidc.go Normal file
View file

@ -0,0 +1,127 @@
package web
import (
"context"
"encoding/hex"
"encoding/json"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"math/rand"
"net/http"
"time"
)
const (
CacheExpiration = time.Minute * 2
CleanupInterval = time.Minute * 5
)
type OIDC struct {
oAuth2Config *oauth2.Config
oidcTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache
sessionStore sessions.Store
}
type OIDCConfig struct {
OAuth2Config *oauth2.Config
OIDCTokenVerifier *oidc.IDTokenVerifier
SessionStore sessions.Store
}
func (c *OIDCConfig) New() *OIDC {
return &OIDC{
oAuth2Config: c.OAuth2Config,
oidcTokenVerifier: c.OIDCTokenVerifier,
stateStore: cache.New(CacheExpiration, CleanupInterval),
sessionStore: c.SessionStore,
}
}
func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
state := r.URL.Query().Get("state")
s, found := h.stateStore.Get(state)
if !found {
http.Error(w, "unknown state", http.StatusBadRequest)
return
}
url := s.(string)
ctx := r.Context()
oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
return
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
return
}
idToken, err := h.oidcTokenVerifier.Verify(ctx, rawIDToken)
if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return
}
resp := struct {
OAuth2Token *oauth2.Token
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
}{oauth2Token, new(json.RawMessage)}
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var data map[string]interface{}
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session, err := h.sessionStore.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session.Options.MaxAge = MaxAge
session.Values["preferred_username"] = data["preferred_username"]
session.Values["authenticated"] = true
session.Values["access_token"] = oauth2Token.AccessToken
if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
http.Redirect(w, r, url, http.StatusFound)
}
func (h *OIDC) Authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := h.sessionStore.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
found := session.Values["authenticated"]
if found == nil || !found.(bool) {
seed := make([]byte, 16)
rand.Read(seed)
state := hex.EncodeToString(seed)
h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound)
return
}
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
next.ServeHTTP(w, r.WithContext(ctx))
})
}

30
cmd/rdpgw/web/session.go Normal file
View file

@ -0,0 +1,30 @@
package web
import (
"github.com/gorilla/sessions"
"log"
"os"
)
type SessionManagerConf struct {
SessionKey []byte
SessionEncryptionKey []byte
StoreType string
}
func (c *SessionManagerConf) Init() sessions.Store {
if len(c.SessionKey) < 32 {
log.Fatal("Session key too small")
}
if len(c.SessionEncryptionKey) < 32 {
log.Fatal("Session key too small")
}
if c.StoreType == "file" {
log.Println("Filesystem is used as session storage")
return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
} else {
log.Println("Cookies are used as session storage")
return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
}
}

View file

@ -1,4 +1,4 @@
package api package web
import ( import (
"context" "context"
@ -9,7 +9,7 @@ import (
"net/http" "net/http"
) )
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) { func TokenInfo(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
http.Error(w, "Invalid request", http.StatusMethodNotAllowed) http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
return return

215
cmd/rdpgw/web/web.go Normal file
View file

@ -0,0 +1,215 @@
package web
import (
"context"
"encoding/hex"
"errors"
"fmt"
"github.com/gorilla/sessions"
"log"
"math/rand"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
const (
RdpGwSession = "RDPGWSESSION"
MaxAge = 120
)
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type QueryInfoFunc func(context.Context, string, string) (string, error)
type Config struct {
SessionStore sessions.Store
PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc
QueryInfo QueryInfoFunc
QueryTokenIssuer string
EnableUserToken bool
Hosts []string
HostSelection string
GatewayAddress *url.URL
RdpOpts RdpOpts
}
type RdpOpts struct {
UsernameTemplate string
SplitUserDomain bool
DefaultDomain string
NetworkAutoDetect int
BandwidthAutoDetect int
ConnectionType int
}
type Handler struct {
sessionStore sessions.Store
paaTokenGenerator TokenGeneratorFunc
enableUserToken bool
userTokenGenerator UserTokenGeneratorFunc
queryInfo QueryInfoFunc
queryTokenIssuer string
gatewayAddress *url.URL
hosts []string
hostSelection string
rdpOpts RdpOpts
}
func (c *Config) NewHandler() *Handler {
if len(c.Hosts) < 1 {
log.Fatal("Not enough hosts to connect to specified")
}
return &Handler{
sessionStore: c.SessionStore,
paaTokenGenerator: c.PAATokenGenerator,
enableUserToken: c.EnableUserToken,
userTokenGenerator: c.UserTokenGenerator,
queryInfo: c.QueryInfo,
queryTokenIssuer: c.QueryTokenIssuer,
gatewayAddress: c.GatewayAddress,
hosts: c.Hosts,
hostSelection: c.HostSelection,
rdpOpts: c.RdpOpts,
}
}
func (h *Handler) selectRandomHost() string {
rand.Seed(time.Now().Unix())
host := h.hosts[rand.Intn(len(h.hosts))]
return host
}
func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
switch h.hostSelection {
case "roundrobin":
return h.selectRandomHost(), nil
case "signed":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
host, err := h.queryInfo(ctx, hosts[0], h.queryTokenIssuer)
if err != nil {
return "", err
}
found := false
for _, check := range h.hosts {
if check == host {
found = true
break
}
}
if !found {
log.Printf("Invalid host %s specified in token", hosts[0])
return "", errors.New("invalid host specified in query token")
}
return host, nil
case "unsigned":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
for _, check := range h.hosts {
if check == hosts[0] {
return hosts[0], nil
}
}
// not found
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
case "any":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
return hosts[0], nil
default:
return h.selectRandomHost(), nil
}
}
func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string)
opts := h.rdpOpts
if !ok {
log.Printf("preferred_username not found in context")
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
return
}
// determine host to connect to
host, err := h.getHost(ctx, r.URL)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
// split the username into user and domain
var user = userName
var domain = opts.DefaultDomain
if opts.SplitUserDomain {
creds := strings.SplitN(userName, "@", 2)
user = creds[0]
if len(creds) > 1 {
domain = creds[1]
}
}
render := user
if opts.UsernameTemplate != "" {
render = fmt.Sprintf(h.rdpOpts.UsernameTemplate)
render = strings.Replace(render, "{{ username }}", user, 1)
if h.rdpOpts.UsernameTemplate == render {
log.Printf("Invalid username template. %s == %s", h.rdpOpts.UsernameTemplate, user)
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
return
}
}
token, err := h.paaTokenGenerator(ctx, user, host)
if err != nil {
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
}
if h.enableUserToken {
userToken, err := h.userTokenGenerator(ctx, user)
if err != nil {
log.Printf("Cannot generate token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
}
render = strings.Replace(render, "{{ token }}", userToken, 1)
}
// authenticated
seed := make([]byte, 16)
rand.Read(seed)
fn := hex.EncodeToString(seed) + ".rdp"
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
w.Header().Set("Content-Type", "application/x-rdp")
data := "full address:s:" + host + "\r\n" +
"gatewayhostname:s:" + h.gatewayAddress.Host + "\r\n" +
"gatewaycredentialssource:i:5\r\n" +
"gatewayusagemethod:i:1\r\n" +
"gatewayprofileusagemethod:i:1\r\n" +
"gatewayaccesstoken:s:" + token + "\r\n" +
"networkautodetect:i:" + strconv.Itoa(opts.NetworkAutoDetect) + "\r\n" +
"bandwidthautodetect:i:" + strconv.Itoa(opts.BandwidthAutoDetect) + "\r\n" +
"connection type:i:" + strconv.Itoa(opts.ConnectionType) + "\r\n" +
"username:s:" + render + "\r\n" +
"domain:s:" + domain + "\r\n" +
"bitmapcachesize:i:32000\r\n" +
"smart sizing:i:1\r\n"
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
}

View file

@ -1,4 +1,4 @@
package api package web
import ( import (
"context" "context"
@ -27,12 +27,14 @@ func TestGetHost(t *testing.T) {
HostSelection: "roundrobin", HostSelection: "roundrobin",
Hosts: hosts, Hosts: hosts,
} }
h := c.NewHandler()
u := &url.URL{ u := &url.URL{
Host: "example.com", Host: "example.com",
} }
vals := u.Query() vals := u.Query()
host, err := c.getHost(ctx, u) host, err := h.getHost(ctx, u)
if err != nil { if err != nil {
t.Fatalf("#{err}") t.Fatalf("#{err}")
} }
@ -44,14 +46,16 @@ func TestGetHost(t *testing.T) {
c.HostSelection = "unsigned" c.HostSelection = "unsigned"
vals.Set("host", "in.valid.host") vals.Set("host", "in.valid.host")
u.RawQuery = vals.Encode() u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u) h = c.NewHandler()
host, err = h.getHost(ctx, u)
if err == nil { if err == nil {
t.Fatalf("Accepted host %s is not in hosts list", host) t.Fatalf("Accepted host %s is not in hosts list", host)
} }
vals.Set("host", hosts[0]) vals.Set("host", hosts[0])
u.RawQuery = vals.Encode() u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u) h = c.NewHandler()
host, err = h.getHost(ctx, u)
if err != nil { if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
} }
@ -64,7 +68,8 @@ func TestGetHost(t *testing.T) {
test := "bla.bla.com" test := "bla.bla.com"
vals.Set("host", test) vals.Set("host", test)
u.RawQuery = vals.Encode() u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u) h = c.NewHandler()
host, err = h.getHost(ctx, u)
if err != nil { if err != nil {
t.Fatalf("%s is not accepted", host) t.Fatalf("%s is not accepted", host)
} }
@ -83,7 +88,8 @@ func TestGetHost(t *testing.T) {
} }
vals.Set("host", queryToken) vals.Set("host", queryToken)
u.RawQuery = vals.Encode() u.RawQuery = vals.Encode()
host, err = c.getHost(ctx, u) h = c.NewHandler()
host, err = h.getHost(ctx, u)
if err != nil { if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
} }