mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-25 17:43:27 +02:00
Split web api so it becomes more testable and maintainable
This commit is contained in:
parent
2a2edaa21c
commit
0c5f93e810
9 changed files with 471 additions and 367 deletions
|
@ -1,293 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
|
||||||
"github.com/gorilla/sessions"
|
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"log"
|
|
||||||
"math/rand"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
RdpGwSession = "RDPGWSESSION"
|
|
||||||
MaxAge = 120
|
|
||||||
)
|
|
||||||
|
|
||||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
|
||||||
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
|
||||||
type QueryInfoFunc func(context.Context, string, string) (string, error)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
SessionKey []byte
|
|
||||||
SessionEncryptionKey []byte
|
|
||||||
SessionStore string
|
|
||||||
PAATokenGenerator TokenGeneratorFunc
|
|
||||||
UserTokenGenerator UserTokenGeneratorFunc
|
|
||||||
QueryInfo QueryInfoFunc
|
|
||||||
QueryTokenIssuer string
|
|
||||||
EnableUserToken bool
|
|
||||||
OAuth2Config *oauth2.Config
|
|
||||||
store sessions.Store
|
|
||||||
OIDCTokenVerifier *oidc.IDTokenVerifier
|
|
||||||
stateStore *cache.Cache
|
|
||||||
Hosts []string
|
|
||||||
HostSelection string
|
|
||||||
GatewayAddress *url.URL
|
|
||||||
UsernameTemplate string
|
|
||||||
NetworkAutoDetect int
|
|
||||||
BandwidthAutoDetect int
|
|
||||||
ConnectionType int
|
|
||||||
SplitUserDomain bool
|
|
||||||
DefaultDomain string
|
|
||||||
SocketAddress string
|
|
||||||
Authentication string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) NewApi() {
|
|
||||||
if len(c.SessionKey) < 32 {
|
|
||||||
log.Fatal("Session key too small")
|
|
||||||
}
|
|
||||||
if len(c.Hosts) < 1 {
|
|
||||||
log.Fatal("Not enough hosts to connect to specified")
|
|
||||||
}
|
|
||||||
if c.SessionStore == "file" {
|
|
||||||
log.Println("Filesystem is used as session storage")
|
|
||||||
c.store = sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
|
|
||||||
} else {
|
|
||||||
log.Println("Cookies are used as session storage")
|
|
||||||
c.store = sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
|
|
||||||
}
|
|
||||||
c.stateStore = cache.New(time.Minute*2, 5*time.Minute)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
|
||||||
state := r.URL.Query().Get("state")
|
|
||||||
s, found := c.stateStore.Get(state)
|
|
||||||
if !found {
|
|
||||||
http.Error(w, "unknown state", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
url := s.(string)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
oauth2Token, err := c.OAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := struct {
|
|
||||||
OAuth2Token *oauth2.Token
|
|
||||||
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
|
||||||
}{oauth2Token, new(json.RawMessage)}
|
|
||||||
|
|
||||||
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var data map[string]interface{}
|
|
||||||
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := c.store.Get(r, RdpGwSession)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session.Options.MaxAge = MaxAge
|
|
||||||
session.Values["preferred_username"] = data["preferred_username"]
|
|
||||||
session.Values["authenticated"] = true
|
|
||||||
session.Values["access_token"] = oauth2Token.AccessToken
|
|
||||||
|
|
||||||
if err = session.Save(r, w); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Redirect(w, r, url, http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) Authenticated(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
session, err := c.store.Get(r, RdpGwSession)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
found := session.Values["authenticated"]
|
|
||||||
if found == nil || !found.(bool) {
|
|
||||||
seed := make([]byte, 16)
|
|
||||||
rand.Read(seed)
|
|
||||||
state := hex.EncodeToString(seed)
|
|
||||||
c.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
|
|
||||||
http.Redirect(w, r, c.OAuth2Config.AuthCodeURL(state), http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
|
||||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) selectRandomHost() string {
|
|
||||||
rand.Seed(time.Now().Unix())
|
|
||||||
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) getHost(ctx context.Context, u *url.URL) (string, error) {
|
|
||||||
switch c.HostSelection {
|
|
||||||
case "roundrobin":
|
|
||||||
return c.selectRandomHost(), nil
|
|
||||||
case "signed":
|
|
||||||
hosts, ok := u.Query()["host"]
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("invalid query parameter")
|
|
||||||
}
|
|
||||||
host, err := c.QueryInfo(ctx, hosts[0], c.QueryTokenIssuer)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
found := false
|
|
||||||
for _, check := range c.Hosts {
|
|
||||||
if check == host {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
log.Printf("Invalid host %s specified in token", hosts[0])
|
|
||||||
return "", errors.New("invalid host specified in query token")
|
|
||||||
}
|
|
||||||
return host, nil
|
|
||||||
case "unsigned":
|
|
||||||
hosts, ok := u.Query()["host"]
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("invalid query parameter")
|
|
||||||
}
|
|
||||||
for _, check := range c.Hosts {
|
|
||||||
if check == hosts[0] {
|
|
||||||
return hosts[0], nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// not found
|
|
||||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
|
||||||
return "", errors.New("invalid host specified in query parameter")
|
|
||||||
case "any":
|
|
||||||
hosts, ok := u.Query()["host"]
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("invalid query parameter")
|
|
||||||
}
|
|
||||||
return hosts[0], nil
|
|
||||||
default:
|
|
||||||
return c.selectRandomHost(), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
userName, ok := ctx.Value("preferred_username").(string)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
log.Printf("preferred_username not found in context")
|
|
||||||
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// determine host to connect to
|
|
||||||
host, err := c.getHost(ctx, r.URL)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
|
||||||
|
|
||||||
// split the username into user and domain
|
|
||||||
var user = userName
|
|
||||||
var domain = c.DefaultDomain
|
|
||||||
if c.SplitUserDomain {
|
|
||||||
creds := strings.SplitN(userName, "@", 2)
|
|
||||||
user = creds[0]
|
|
||||||
if len(creds) > 1 {
|
|
||||||
domain = creds[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
render := user
|
|
||||||
if c.UsernameTemplate != "" {
|
|
||||||
render = fmt.Sprintf(c.UsernameTemplate)
|
|
||||||
render = strings.Replace(render, "{{ username }}", user, 1)
|
|
||||||
if c.UsernameTemplate == render {
|
|
||||||
log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user)
|
|
||||||
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := c.PAATokenGenerator(ctx, user, host)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
|
|
||||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.EnableUserToken {
|
|
||||||
userToken, err := c.UserTokenGenerator(ctx, user)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Cannot generate token for user %s due to %s", user, err)
|
|
||||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
render = strings.Replace(render, "{{ token }}", userToken, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// authenticated
|
|
||||||
seed := make([]byte, 16)
|
|
||||||
rand.Read(seed)
|
|
||||||
fn := hex.EncodeToString(seed) + ".rdp"
|
|
||||||
|
|
||||||
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
|
||||||
w.Header().Set("Content-Type", "application/x-rdp")
|
|
||||||
data := "full address:s:" + host + "\r\n" +
|
|
||||||
"gatewayhostname:s:" + c.GatewayAddress.Host + "\r\n" +
|
|
||||||
"gatewaycredentialssource:i:5\r\n" +
|
|
||||||
"gatewayusagemethod:i:1\r\n" +
|
|
||||||
"gatewayprofileusagemethod:i:1\r\n" +
|
|
||||||
"gatewayaccesstoken:s:" + token + "\r\n" +
|
|
||||||
"networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" +
|
|
||||||
"bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" +
|
|
||||||
"connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" +
|
|
||||||
"username:s:" + render + "\r\n" +
|
|
||||||
"domain:s:" + domain + "\r\n" +
|
|
||||||
"bitmapcachesize:i:32000\r\n" +
|
|
||||||
"smart sizing:i:1\r\n"
|
|
||||||
|
|
||||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
|
|
||||||
}
|
|
|
@ -31,7 +31,7 @@ type ServerConfig struct {
|
||||||
SessionStore string `koanf:"sessionstore"`
|
SessionStore string `koanf:"sessionstore"`
|
||||||
SendBuf int `koanf:"sendbuf"`
|
SendBuf int `koanf:"sendbuf"`
|
||||||
ReceiveBuf int `koanf:"receivebuf"`
|
ReceiveBuf int `koanf:"receivebuf"`
|
||||||
Tls string `koanf:"disabletls"`
|
Tls string `koanf:"tls"`
|
||||||
Authentication string `koanf:"authentication"`
|
Authentication string `koanf:"authentication"`
|
||||||
AuthSocket string `koanf:"authsocket"`
|
AuthSocket string `koanf:"authsocket"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,12 +4,13 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/api"
|
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/web"
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/thought-machine/go-flags"
|
"github.com/thought-machine/go-flags"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
@ -27,17 +28,56 @@ var opts struct {
|
||||||
|
|
||||||
var conf config.Configuration
|
var conf config.Configuration
|
||||||
|
|
||||||
|
func initOIDC(callbackUrl *url.URL, store sessions.Store) *web.OIDC {
|
||||||
|
// set oidc config
|
||||||
|
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Cannot get oidc provider: %s", err)
|
||||||
|
}
|
||||||
|
oidcConfig := &oidc.Config{
|
||||||
|
ClientID: conf.OpenId.ClientId,
|
||||||
|
}
|
||||||
|
verifier := provider.Verifier(oidcConfig)
|
||||||
|
|
||||||
|
oauthConfig := oauth2.Config{
|
||||||
|
ClientID: conf.OpenId.ClientId,
|
||||||
|
ClientSecret: conf.OpenId.ClientSecret,
|
||||||
|
RedirectURL: callbackUrl.String(),
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
}
|
||||||
|
security.OIDCProvider = provider
|
||||||
|
security.Oauth2Config = oauthConfig
|
||||||
|
|
||||||
|
o := web.OIDCConfig{
|
||||||
|
OAuth2Config: &oauthConfig,
|
||||||
|
OIDCTokenVerifier: verifier,
|
||||||
|
SessionStore: store,
|
||||||
|
}
|
||||||
|
|
||||||
|
return o.New()
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// get config
|
// load config
|
||||||
_, err := flags.Parse(&opts)
|
_, err := flags.Parse(&opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
conf = config.Load(opts.ConfigFile)
|
conf = config.Load(opts.ConfigFile)
|
||||||
|
|
||||||
security.VerifyClientIP = conf.Security.VerifyClientIp
|
// set callback url and external advertised gateway address
|
||||||
|
url, err := url.Parse(conf.Server.GatewayAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot parse server gateway address %s due to %s", url, err)
|
||||||
|
}
|
||||||
|
if url.Scheme == "" {
|
||||||
|
url.Scheme = "https"
|
||||||
|
}
|
||||||
|
url.Path = "callback"
|
||||||
|
|
||||||
// set security keys
|
// set security options
|
||||||
|
security.VerifyClientIP = conf.Security.VerifyClientIp
|
||||||
security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
|
security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
|
||||||
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
|
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
|
||||||
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
|
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
|
||||||
|
@ -46,66 +86,39 @@ func main() {
|
||||||
security.HostSelection = conf.Server.HostSelection
|
security.HostSelection = conf.Server.HostSelection
|
||||||
security.Hosts = conf.Server.Hosts
|
security.Hosts = conf.Server.Hosts
|
||||||
|
|
||||||
// configure api
|
// init session store
|
||||||
api := &api.Config{
|
sessionConf := web.SessionManagerConf{
|
||||||
QueryInfo: security.QueryInfo,
|
|
||||||
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
|
|
||||||
EnableUserToken: conf.Security.EnableUserToken,
|
|
||||||
SessionKey: []byte(conf.Server.SessionKey),
|
SessionKey: []byte(conf.Server.SessionKey),
|
||||||
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
||||||
SessionStore: conf.Server.SessionStore,
|
StoreType: conf.Server.SessionStore,
|
||||||
Hosts: conf.Server.Hosts,
|
}
|
||||||
HostSelection: conf.Server.HostSelection,
|
store := sessionConf.Init()
|
||||||
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
|
|
||||||
UsernameTemplate: conf.Client.UsernameTemplate,
|
// configure web backend
|
||||||
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
|
w := &web.Config{
|
||||||
ConnectionType: conf.Client.ConnectionType,
|
QueryInfo: security.QueryInfo,
|
||||||
SplitUserDomain: conf.Client.SplitUserDomain,
|
QueryTokenIssuer: conf.Security.QueryTokenIssuer,
|
||||||
DefaultDomain: conf.Client.DefaultDomain,
|
EnableUserToken: conf.Security.EnableUserToken,
|
||||||
SocketAddress: conf.Server.AuthSocket,
|
SessionStore: store,
|
||||||
Authentication: conf.Server.Authentication,
|
Hosts: conf.Server.Hosts,
|
||||||
|
HostSelection: conf.Server.HostSelection,
|
||||||
|
RdpOpts: web.RdpOpts{
|
||||||
|
UsernameTemplate: conf.Client.UsernameTemplate,
|
||||||
|
SplitUserDomain: conf.Client.SplitUserDomain,
|
||||||
|
DefaultDomain: conf.Client.DefaultDomain,
|
||||||
|
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
|
||||||
|
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
|
||||||
|
ConnectionType: conf.Client.ConnectionType,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.Caps.TokenAuth {
|
if conf.Caps.TokenAuth {
|
||||||
api.PAATokenGenerator = security.GeneratePAAToken
|
w.PAATokenGenerator = security.GeneratePAAToken
|
||||||
}
|
}
|
||||||
if conf.Security.EnableUserToken {
|
if conf.Security.EnableUserToken {
|
||||||
api.UserTokenGenerator = security.GenerateUserToken
|
w.UserTokenGenerator = security.GenerateUserToken
|
||||||
}
|
}
|
||||||
|
h := w.NewHandler()
|
||||||
// get callback url and external advertised gateway address
|
|
||||||
url, err := url.Parse(conf.Server.GatewayAddress)
|
|
||||||
if url.Scheme == "" {
|
|
||||||
url.Scheme = "https"
|
|
||||||
}
|
|
||||||
url.Path = "callback"
|
|
||||||
|
|
||||||
if conf.Server.Authentication == "openid" {
|
|
||||||
// set oidc config
|
|
||||||
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Cannot get oidc provider: %s", err)
|
|
||||||
}
|
|
||||||
oidcConfig := &oidc.Config{
|
|
||||||
ClientID: conf.OpenId.ClientId,
|
|
||||||
}
|
|
||||||
verifier := provider.Verifier(oidcConfig)
|
|
||||||
|
|
||||||
api.GatewayAddress = url
|
|
||||||
|
|
||||||
oauthConfig := oauth2.Config{
|
|
||||||
ClientID: conf.OpenId.ClientId,
|
|
||||||
ClientSecret: conf.OpenId.ClientSecret,
|
|
||||||
RedirectURL: url.String(),
|
|
||||||
Endpoint: provider.Endpoint(),
|
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
||||||
}
|
|
||||||
security.OIDCProvider = provider
|
|
||||||
security.Oauth2Config = oauthConfig
|
|
||||||
api.OAuth2Config = &oauthConfig
|
|
||||||
api.OIDCTokenVerifier = verifier
|
|
||||||
}
|
|
||||||
api.NewApi()
|
|
||||||
|
|
||||||
log.Printf("Starting remote desktop gateway server")
|
log.Printf("Starting remote desktop gateway server")
|
||||||
cfg := &tls.Config{}
|
cfg := &tls.Config{}
|
||||||
|
@ -151,7 +164,7 @@ func main() {
|
||||||
cfg.GetCertificate = certMgr.GetCertificate
|
cfg.GetCertificate = certMgr.GetCertificate
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
http.ListenAndServe(":http", certMgr.HTTPHandler(nil))
|
http.ListenAndServe(":80", certMgr.HTTPHandler(nil))
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -190,15 +203,17 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.Server.Authentication == "local" {
|
if conf.Server.Authentication == "local" {
|
||||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol)))
|
h := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket}
|
||||||
|
http.Handle("/remoteDesktopGateway/", common.EnrichContext(h.BasicAuth(gw.HandleGatewayProtocol)))
|
||||||
} else {
|
} else {
|
||||||
// openid
|
// openid
|
||||||
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
oidc := initOIDC(url, store)
|
||||||
|
http.Handle("/connect", common.EnrichContext(oidc.Authenticated(http.HandlerFunc(h.HandleDownload))))
|
||||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||||
http.HandleFunc("/callback", api.HandleCallback)
|
http.HandleFunc("/callback", oidc.HandleCallback)
|
||||||
}
|
}
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
http.HandleFunc("/tokeninfo", api.TokenInfo)
|
http.HandleFunc("/tokeninfo", web.TokenInfo)
|
||||||
|
|
||||||
if conf.Server.Tls == "disabled" {
|
if conf.Server.Tls == "disabled" {
|
||||||
err = server.ListenAndServe()
|
err = server.ListenAndServe()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -15,13 +15,17 @@ const (
|
||||||
protocol = "unix"
|
protocol = "unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
type BasicAuthHandler struct {
|
||||||
|
SocketAddress string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
username, password, ok := r.BasicAuth()
|
username, password, ok := r.BasicAuth()
|
||||||
if ok {
|
if ok {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
return net.Dial(protocol, addr)
|
return net.Dial(protocol, addr)
|
||||||
}))
|
}))
|
127
cmd/rdpgw/web/oidc.go
Normal file
127
cmd/rdpgw/web/oidc.go
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
"github.com/patrickmn/go-cache"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CacheExpiration = time.Minute * 2
|
||||||
|
CleanupInterval = time.Minute * 5
|
||||||
|
)
|
||||||
|
|
||||||
|
type OIDC struct {
|
||||||
|
oAuth2Config *oauth2.Config
|
||||||
|
oidcTokenVerifier *oidc.IDTokenVerifier
|
||||||
|
stateStore *cache.Cache
|
||||||
|
sessionStore sessions.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
type OIDCConfig struct {
|
||||||
|
OAuth2Config *oauth2.Config
|
||||||
|
OIDCTokenVerifier *oidc.IDTokenVerifier
|
||||||
|
SessionStore sessions.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OIDCConfig) New() *OIDC {
|
||||||
|
return &OIDC{
|
||||||
|
oAuth2Config: c.OAuth2Config,
|
||||||
|
oidcTokenVerifier: c.OIDCTokenVerifier,
|
||||||
|
stateStore: cache.New(CacheExpiration, CleanupInterval),
|
||||||
|
sessionStore: c.SessionStore,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
s, found := h.stateStore.Get(state)
|
||||||
|
if !found {
|
||||||
|
http.Error(w, "unknown state", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
url := s.(string)
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
idToken, err := h.oidcTokenVerifier.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := struct {
|
||||||
|
OAuth2Token *oauth2.Token
|
||||||
|
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
||||||
|
}{oauth2Token, new(json.RawMessage)}
|
||||||
|
|
||||||
|
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data map[string]interface{}
|
||||||
|
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := h.sessionStore.Get(r, RdpGwSession)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session.Options.MaxAge = MaxAge
|
||||||
|
session.Values["preferred_username"] = data["preferred_username"]
|
||||||
|
session.Values["authenticated"] = true
|
||||||
|
session.Values["access_token"] = oauth2Token.AccessToken
|
||||||
|
|
||||||
|
if err = session.Save(r, w); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(w, r, url, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
session, err := h.sessionStore.Get(r, RdpGwSession)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
found := session.Values["authenticated"]
|
||||||
|
if found == nil || !found.(bool) {
|
||||||
|
seed := make([]byte, 16)
|
||||||
|
rand.Read(seed)
|
||||||
|
state := hex.EncodeToString(seed)
|
||||||
|
h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
|
||||||
|
http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
||||||
|
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
30
cmd/rdpgw/web/session.go
Normal file
30
cmd/rdpgw/web/session.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionManagerConf struct {
|
||||||
|
SessionKey []byte
|
||||||
|
SessionEncryptionKey []byte
|
||||||
|
StoreType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SessionManagerConf) Init() sessions.Store {
|
||||||
|
if len(c.SessionKey) < 32 {
|
||||||
|
log.Fatal("Session key too small")
|
||||||
|
}
|
||||||
|
if len(c.SessionEncryptionKey) < 32 {
|
||||||
|
log.Fatal("Session key too small")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.StoreType == "file" {
|
||||||
|
log.Println("Filesystem is used as session storage")
|
||||||
|
return sessions.NewFilesystemStore(os.TempDir(), c.SessionKey, c.SessionEncryptionKey)
|
||||||
|
} else {
|
||||||
|
log.Println("Cookies are used as session storage")
|
||||||
|
return sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -9,7 +9,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
|
func TokenInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
|
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
215
cmd/rdpgw/web/web.go
Normal file
215
cmd/rdpgw/web/web.go
Normal file
|
@ -0,0 +1,215 @@
|
||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RdpGwSession = "RDPGWSESSION"
|
||||||
|
MaxAge = 120
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||||
|
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
||||||
|
type QueryInfoFunc func(context.Context, string, string) (string, error)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
SessionStore sessions.Store
|
||||||
|
PAATokenGenerator TokenGeneratorFunc
|
||||||
|
UserTokenGenerator UserTokenGeneratorFunc
|
||||||
|
QueryInfo QueryInfoFunc
|
||||||
|
QueryTokenIssuer string
|
||||||
|
EnableUserToken bool
|
||||||
|
Hosts []string
|
||||||
|
HostSelection string
|
||||||
|
GatewayAddress *url.URL
|
||||||
|
RdpOpts RdpOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
type RdpOpts struct {
|
||||||
|
UsernameTemplate string
|
||||||
|
SplitUserDomain bool
|
||||||
|
DefaultDomain string
|
||||||
|
NetworkAutoDetect int
|
||||||
|
BandwidthAutoDetect int
|
||||||
|
ConnectionType int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handler struct {
|
||||||
|
sessionStore sessions.Store
|
||||||
|
paaTokenGenerator TokenGeneratorFunc
|
||||||
|
enableUserToken bool
|
||||||
|
userTokenGenerator UserTokenGeneratorFunc
|
||||||
|
queryInfo QueryInfoFunc
|
||||||
|
queryTokenIssuer string
|
||||||
|
gatewayAddress *url.URL
|
||||||
|
hosts []string
|
||||||
|
hostSelection string
|
||||||
|
rdpOpts RdpOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) NewHandler() *Handler {
|
||||||
|
if len(c.Hosts) < 1 {
|
||||||
|
log.Fatal("Not enough hosts to connect to specified")
|
||||||
|
}
|
||||||
|
return &Handler{
|
||||||
|
sessionStore: c.SessionStore,
|
||||||
|
paaTokenGenerator: c.PAATokenGenerator,
|
||||||
|
enableUserToken: c.EnableUserToken,
|
||||||
|
userTokenGenerator: c.UserTokenGenerator,
|
||||||
|
queryInfo: c.QueryInfo,
|
||||||
|
queryTokenIssuer: c.QueryTokenIssuer,
|
||||||
|
gatewayAddress: c.GatewayAddress,
|
||||||
|
hosts: c.Hosts,
|
||||||
|
hostSelection: c.HostSelection,
|
||||||
|
rdpOpts: c.RdpOpts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) selectRandomHost() string {
|
||||||
|
rand.Seed(time.Now().Unix())
|
||||||
|
host := h.hosts[rand.Intn(len(h.hosts))]
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||||
|
switch h.hostSelection {
|
||||||
|
case "roundrobin":
|
||||||
|
return h.selectRandomHost(), nil
|
||||||
|
case "signed":
|
||||||
|
hosts, ok := u.Query()["host"]
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("invalid query parameter")
|
||||||
|
}
|
||||||
|
host, err := h.queryInfo(ctx, hosts[0], h.queryTokenIssuer)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, check := range h.hosts {
|
||||||
|
if check == host {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
log.Printf("Invalid host %s specified in token", hosts[0])
|
||||||
|
return "", errors.New("invalid host specified in query token")
|
||||||
|
}
|
||||||
|
return host, nil
|
||||||
|
case "unsigned":
|
||||||
|
hosts, ok := u.Query()["host"]
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("invalid query parameter")
|
||||||
|
}
|
||||||
|
for _, check := range h.hosts {
|
||||||
|
if check == hosts[0] {
|
||||||
|
return hosts[0], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// not found
|
||||||
|
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||||
|
return "", errors.New("invalid host specified in query parameter")
|
||||||
|
case "any":
|
||||||
|
hosts, ok := u.Query()["host"]
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("invalid query parameter")
|
||||||
|
}
|
||||||
|
return hosts[0], nil
|
||||||
|
default:
|
||||||
|
return h.selectRandomHost(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
userName, ok := ctx.Value("preferred_username").(string)
|
||||||
|
|
||||||
|
opts := h.rdpOpts
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
log.Printf("preferred_username not found in context")
|
||||||
|
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// determine host to connect to
|
||||||
|
host, err := h.getHost(ctx, r.URL)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||||
|
|
||||||
|
// split the username into user and domain
|
||||||
|
var user = userName
|
||||||
|
var domain = opts.DefaultDomain
|
||||||
|
if opts.SplitUserDomain {
|
||||||
|
creds := strings.SplitN(userName, "@", 2)
|
||||||
|
user = creds[0]
|
||||||
|
if len(creds) > 1 {
|
||||||
|
domain = creds[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
render := user
|
||||||
|
if opts.UsernameTemplate != "" {
|
||||||
|
render = fmt.Sprintf(h.rdpOpts.UsernameTemplate)
|
||||||
|
render = strings.Replace(render, "{{ username }}", user, 1)
|
||||||
|
if h.rdpOpts.UsernameTemplate == render {
|
||||||
|
log.Printf("Invalid username template. %s == %s", h.rdpOpts.UsernameTemplate, user)
|
||||||
|
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := h.paaTokenGenerator(ctx, user, host)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
|
||||||
|
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.enableUserToken {
|
||||||
|
userToken, err := h.userTokenGenerator(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot generate token for user %s due to %s", user, err)
|
||||||
|
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
render = strings.Replace(render, "{{ token }}", userToken, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticated
|
||||||
|
seed := make([]byte, 16)
|
||||||
|
rand.Read(seed)
|
||||||
|
fn := hex.EncodeToString(seed) + ".rdp"
|
||||||
|
|
||||||
|
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
||||||
|
w.Header().Set("Content-Type", "application/x-rdp")
|
||||||
|
|
||||||
|
data := "full address:s:" + host + "\r\n" +
|
||||||
|
"gatewayhostname:s:" + h.gatewayAddress.Host + "\r\n" +
|
||||||
|
"gatewaycredentialssource:i:5\r\n" +
|
||||||
|
"gatewayusagemethod:i:1\r\n" +
|
||||||
|
"gatewayprofileusagemethod:i:1\r\n" +
|
||||||
|
"gatewayaccesstoken:s:" + token + "\r\n" +
|
||||||
|
"networkautodetect:i:" + strconv.Itoa(opts.NetworkAutoDetect) + "\r\n" +
|
||||||
|
"bandwidthautodetect:i:" + strconv.Itoa(opts.BandwidthAutoDetect) + "\r\n" +
|
||||||
|
"connection type:i:" + strconv.Itoa(opts.ConnectionType) + "\r\n" +
|
||||||
|
"username:s:" + render + "\r\n" +
|
||||||
|
"domain:s:" + domain + "\r\n" +
|
||||||
|
"bitmapcachesize:i:32000\r\n" +
|
||||||
|
"smart sizing:i:1\r\n"
|
||||||
|
|
||||||
|
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -27,12 +27,14 @@ func TestGetHost(t *testing.T) {
|
||||||
HostSelection: "roundrobin",
|
HostSelection: "roundrobin",
|
||||||
Hosts: hosts,
|
Hosts: hosts,
|
||||||
}
|
}
|
||||||
|
h := c.NewHandler()
|
||||||
|
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Host: "example.com",
|
Host: "example.com",
|
||||||
}
|
}
|
||||||
vals := u.Query()
|
vals := u.Query()
|
||||||
|
|
||||||
host, err := c.getHost(ctx, u)
|
host, err := h.getHost(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("#{err}")
|
t.Fatalf("#{err}")
|
||||||
}
|
}
|
||||||
|
@ -44,14 +46,16 @@ func TestGetHost(t *testing.T) {
|
||||||
c.HostSelection = "unsigned"
|
c.HostSelection = "unsigned"
|
||||||
vals.Set("host", "in.valid.host")
|
vals.Set("host", "in.valid.host")
|
||||||
u.RawQuery = vals.Encode()
|
u.RawQuery = vals.Encode()
|
||||||
host, err = c.getHost(ctx, u)
|
h = c.NewHandler()
|
||||||
|
host, err = h.getHost(ctx, u)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Accepted host %s is not in hosts list", host)
|
t.Fatalf("Accepted host %s is not in hosts list", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
vals.Set("host", hosts[0])
|
vals.Set("host", hosts[0])
|
||||||
u.RawQuery = vals.Encode()
|
u.RawQuery = vals.Encode()
|
||||||
host, err = c.getHost(ctx, u)
|
h = c.NewHandler()
|
||||||
|
host, err = h.getHost(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||||
}
|
}
|
||||||
|
@ -64,7 +68,8 @@ func TestGetHost(t *testing.T) {
|
||||||
test := "bla.bla.com"
|
test := "bla.bla.com"
|
||||||
vals.Set("host", test)
|
vals.Set("host", test)
|
||||||
u.RawQuery = vals.Encode()
|
u.RawQuery = vals.Encode()
|
||||||
host, err = c.getHost(ctx, u)
|
h = c.NewHandler()
|
||||||
|
host, err = h.getHost(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%s is not accepted", host)
|
t.Fatalf("%s is not accepted", host)
|
||||||
}
|
}
|
||||||
|
@ -83,7 +88,8 @@ func TestGetHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
vals.Set("host", queryToken)
|
vals.Set("host", queryToken)
|
||||||
u.RawQuery = vals.Encode()
|
u.RawQuery = vals.Encode()
|
||||||
host, err = c.getHost(ctx, u)
|
h = c.NewHandler()
|
||||||
|
host, err = h.getHost(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||||
}
|
}
|
Loading…
Add table
Add a link
Reference in a new issue