Allow host query parameter

the host query parameter can now be used
dependent on the `hostselection` config.
This commit is contained in:
Bolke de Bruin 2022-08-17 10:48:14 +02:00
parent 45a57f44ff
commit 8bc3e25f83
5 changed files with 142 additions and 31 deletions

View file

@ -59,10 +59,13 @@ Server:
Hosts: Hosts:
- localhost:3389 - localhost:3389
- my-{{ preferred_username }}-host:3389 - my-{{ preferred_username }}-host:3389
# Allow the user to connect to any host (insecure)
- any
# if true the server randomly selects a host to connect to # if true the server randomly selects a host to connect to
RoundRobin: false # valid options are:
# - roundrobin, which selects a random host from the list (default)
# - signed, a listed host specified in the signed query parameter
# - unsigned, a listed host specified in the query parameter
# - any, insecurely allow any host specified in the query parameter
HostSelection: roundrobin
# a random strings of at least 32 characters to secure cookies on the client # a random strings of at least 32 characters to secure cookies on the client
# make sure to share this across the different pods # make sure to share this across the different pods
SessionKey: thisisasessionkeyreplacethisjetzt SessionKey: thisisasessionkeyreplacethisjetzt

View file

@ -13,6 +13,7 @@ import (
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
"net/url"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -21,7 +22,7 @@ import (
const ( const (
RdpGwSession = "RDPGWSESSION" RdpGwSession = "RDPGWSESSION"
MaxAge = 120 MaxAge = 120
) )
type TokenGeneratorFunc func(context.Context, string, string) (string, error) type TokenGeneratorFunc func(context.Context, string, string) (string, error)
@ -30,7 +31,7 @@ type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type Config struct { type Config struct {
SessionKey []byte SessionKey []byte
SessionEncryptionKey []byte SessionEncryptionKey []byte
SessionStore string SessionStore string
PAATokenGenerator TokenGeneratorFunc PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc UserTokenGenerator UserTokenGeneratorFunc
EnableUserToken bool EnableUserToken bool
@ -39,13 +40,14 @@ type Config struct {
OIDCTokenVerifier *oidc.IDTokenVerifier OIDCTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache stateStore *cache.Cache
Hosts []string Hosts []string
HostSelection string
GatewayAddress string GatewayAddress string
UsernameTemplate string UsernameTemplate string
NetworkAutoDetect int NetworkAutoDetect int
BandwidthAutoDetect int BandwidthAutoDetect int
ConnectionType int ConnectionType int
SplitUserDomain bool SplitUserDomain bool
DefaultDomain string DefaultDomain string
} }
func (c *Config) NewApi() { func (c *Config) NewApi() {
@ -151,6 +153,47 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
}) })
} }
func (c *Config) selectRandomHost() string {
rand.Seed(time.Now().Unix())
host := c.Hosts[rand.Intn(len(c.Hosts))]
return host
}
func (c *Config) getHost(u *url.URL) (string, error) {
var host string
switch c.HostSelection {
case "roundrobin":
host = c.selectRandomHost()
case "signed":
case "unsigned":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
found := false
for _, check := range c.Hosts {
if check == hosts[0] {
host = hosts[0]
found = true
break
}
}
if !found {
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
}
case "any":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
host = hosts[0]
default:
host = c.selectRandomHost()
}
return host, nil
}
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) { func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string) userName, ok := ctx.Value("preferred_username").(string)
@ -161,9 +204,12 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
return return
} }
// do a round robin selection for now // determine host to connect to
rand.Seed(time.Now().Unix()) host, err := c.getHost(r.URL)
host := c.Hosts[rand.Intn(len(c.Hosts))] if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
host = strings.Replace(host, "{{ preferred_username }}", userName, 1) host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
// split the username into user and domain // split the username into user and domain
@ -210,19 +256,19 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Disposition", "attachment; filename="+fn) w.Header().Set("Content-Disposition", "attachment; filename="+fn)
w.Header().Set("Content-Type", "application/x-rdp") w.Header().Set("Content-Type", "application/x-rdp")
data := "full address:s:"+host+"\r\n"+ data := "full address:s:" + host + "\r\n" +
"gatewayhostname:s:"+c.GatewayAddress+"\r\n"+ "gatewayhostname:s:" + c.GatewayAddress + "\r\n" +
"gatewaycredentialssource:i:5\r\n"+ "gatewaycredentialssource:i:5\r\n" +
"gatewayusagemethod:i:1\r\n"+ "gatewayusagemethod:i:1\r\n" +
"gatewayprofileusagemethod:i:1\r\n"+ "gatewayprofileusagemethod:i:1\r\n" +
"gatewayaccesstoken:s:"+token+"\r\n"+ "gatewayaccesstoken:s:" + token + "\r\n" +
"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+ "networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" +
"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+ "bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" +
"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+ "connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" +
"username:s:"+render+"\r\n"+ "username:s:" + render + "\r\n" +
"domain:s:"+domain+"\r\n"+ "domain:s:" + domain + "\r\n" +
"bitmapcachesize:i:32000\r\n"+ "bitmapcachesize:i:32000\r\n" +
"smart sizing:i:1\r\n" "smart sizing:i:1\r\n"
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data)) http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
} }

58
cmd/rdpgw/api/web_test.go Normal file
View file

@ -0,0 +1,58 @@
package api
import (
"net/url"
"testing"
)
var (
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
)
func contains(needle string, haystack []string) bool {
for _, val := range haystack {
if val == needle {
return true
}
}
return false
}
func TestGetHost(t *testing.T) {
c := Config{
HostSelection: "roundrobin",
Hosts: hosts,
}
u := &url.URL{
Host: "example.com",
}
vals := u.Query()
host, err := c.getHost(u)
if err != nil {
t.Fatalf("#{err}")
}
if !contains(host, hosts) {
t.Fatalf("host %s is not in hosts list", host)
}
// check unsigned
c.HostSelection = "unsigned"
vals.Set("host", "in.valid.host")
u.RawQuery = vals.Encode()
host, err = c.getHost(u)
if err == nil {
t.Fatalf("Accepted host %s is not in hosts list", host)
}
vals.Set("host", hosts[0])
u.RawQuery = vals.Encode()
host, err = c.getHost(u)
if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
}
if host != hosts[0] {
t.Fatalf("host %s is not equal to input %s", host, hosts[0])
}
}

View file

@ -25,7 +25,7 @@ type ServerConfig struct {
CertFile string `koanf:"certfile"` CertFile string `koanf:"certfile"`
KeyFile string `koanf:"keyfile"` KeyFile string `koanf:"keyfile"`
Hosts []string `koanf:"hosts"` Hosts []string `koanf:"hosts"`
RoundRobin bool `koanf:"roundrobin"` HostSelection string `koanf:"hostselection"`
SessionKey string `koanf:"sessionkey"` SessionKey string `koanf:"sessionkey"`
SessionEncryptionKey string `koanf:"sessionencryptionkey"` SessionEncryptionKey string `koanf:"sessionencryptionkey"`
SessionStore string `koanf:"sessionstore"` SessionStore string `koanf:"sessionstore"`
@ -118,6 +118,7 @@ func Load(configFile string) Configuration {
"Server.TlsDisabled": false, "Server.TlsDisabled": false,
"Server.Port": 443, "Server.Port": 443,
"Server.SessionStore": "cookie", "Server.SessionStore": "cookie",
"Server.HostSelection": "roundrobin",
"Client.NetworkAutoDetect": 1, "Client.NetworkAutoDetect": 1,
"Client.BandwidthAutoDetect": 1, "Client.BandwidthAutoDetect": 1,
"Security.VerifyClientIp": true, "Security.VerifyClientIp": true,
@ -153,14 +154,16 @@ func Load(configFile string) Configuration {
log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random") log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random")
} }
if len(Conf.Security.UserTokenEncryptionKey) != 32 { if Conf.Security.EnableUserToken {
Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32) if len(Conf.Security.UserTokenEncryptionKey) != 32 {
log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random") Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
} log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if len(Conf.Security.UserTokenSigningKey) != 32 { if len(Conf.Security.UserTokenSigningKey) != 32 {
Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32) Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random") log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
}
} }
if len(Conf.Server.SessionKey) != 32 { if len(Conf.Server.SessionKey) != 32 {

View file

@ -79,6 +79,7 @@ func main() {
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
SessionStore: conf.Server.SessionStore, SessionStore: conf.Server.SessionStore,
Hosts: conf.Server.Hosts, Hosts: conf.Server.Hosts,
HostSelection: conf.Server.HostSelection,
NetworkAutoDetect: conf.Client.NetworkAutoDetect, NetworkAutoDetect: conf.Client.NetworkAutoDetect,
UsernameTemplate: conf.Client.UsernameTemplate, UsernameTemplate: conf.Client.UsernameTemplate,
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,