Allow host query parameter

the host query parameter can now be used
dependent on the `hostselection` config.
This commit is contained in:
Bolke de Bruin 2022-08-17 10:48:14 +02:00
parent 45a57f44ff
commit 8bc3e25f83
5 changed files with 142 additions and 31 deletions

View file

@ -59,10 +59,13 @@ Server:
Hosts:
- localhost:3389
- my-{{ preferred_username }}-host:3389
# Allow the user to connect to any host (insecure)
- any
# if true the server randomly selects a host to connect to
RoundRobin: false
# valid options are:
# - roundrobin, which selects a random host from the list (default)
# - signed, a listed host specified in the signed query parameter
# - unsigned, a listed host specified in the query parameter
# - any, insecurely allow any host specified in the query parameter
HostSelection: roundrobin
# a random strings of at least 32 characters to secure cookies on the client
# make sure to share this across the different pods
SessionKey: thisisasessionkeyreplacethisjetzt

View file

@ -13,6 +13,7 @@ import (
"log"
"math/rand"
"net/http"
"net/url"
"os"
"strconv"
"strings"
@ -21,7 +22,7 @@ import (
const (
RdpGwSession = "RDPGWSESSION"
MaxAge = 120
MaxAge = 120
)
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
@ -30,7 +31,7 @@ type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type Config struct {
SessionKey []byte
SessionEncryptionKey []byte
SessionStore string
SessionStore string
PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc
EnableUserToken bool
@ -39,13 +40,14 @@ type Config struct {
OIDCTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache
Hosts []string
HostSelection string
GatewayAddress string
UsernameTemplate string
NetworkAutoDetect int
BandwidthAutoDetect int
ConnectionType int
SplitUserDomain bool
DefaultDomain string
SplitUserDomain bool
DefaultDomain string
}
func (c *Config) NewApi() {
@ -151,6 +153,47 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
})
}
func (c *Config) selectRandomHost() string {
rand.Seed(time.Now().Unix())
host := c.Hosts[rand.Intn(len(c.Hosts))]
return host
}
func (c *Config) getHost(u *url.URL) (string, error) {
var host string
switch c.HostSelection {
case "roundrobin":
host = c.selectRandomHost()
case "signed":
case "unsigned":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
found := false
for _, check := range c.Hosts {
if check == hosts[0] {
host = hosts[0]
found = true
break
}
}
if !found {
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
}
case "any":
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
host = hosts[0]
default:
host = c.selectRandomHost()
}
return host, nil
}
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string)
@ -161,9 +204,12 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
return
}
// do a round robin selection for now
rand.Seed(time.Now().Unix())
host := c.Hosts[rand.Intn(len(c.Hosts))]
// determine host to connect to
host, err := c.getHost(r.URL)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
// split the username into user and domain
@ -210,19 +256,19 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
w.Header().Set("Content-Type", "application/x-rdp")
data := "full address:s:"+host+"\r\n"+
"gatewayhostname:s:"+c.GatewayAddress+"\r\n"+
"gatewaycredentialssource:i:5\r\n"+
"gatewayusagemethod:i:1\r\n"+
"gatewayprofileusagemethod:i:1\r\n"+
"gatewayaccesstoken:s:"+token+"\r\n"+
"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+
"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+
"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+
"username:s:"+render+"\r\n"+
"domain:s:"+domain+"\r\n"+
"bitmapcachesize:i:32000\r\n"+
"smart sizing:i:1\r\n"
data := "full address:s:" + host + "\r\n" +
"gatewayhostname:s:" + c.GatewayAddress + "\r\n" +
"gatewaycredentialssource:i:5\r\n" +
"gatewayusagemethod:i:1\r\n" +
"gatewayprofileusagemethod:i:1\r\n" +
"gatewayaccesstoken:s:" + token + "\r\n" +
"networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" +
"bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" +
"connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" +
"username:s:" + render + "\r\n" +
"domain:s:" + domain + "\r\n" +
"bitmapcachesize:i:32000\r\n" +
"smart sizing:i:1\r\n"
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
}

58
cmd/rdpgw/api/web_test.go Normal file
View file

@ -0,0 +1,58 @@
package api
import (
"net/url"
"testing"
)
var (
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
)
func contains(needle string, haystack []string) bool {
for _, val := range haystack {
if val == needle {
return true
}
}
return false
}
func TestGetHost(t *testing.T) {
c := Config{
HostSelection: "roundrobin",
Hosts: hosts,
}
u := &url.URL{
Host: "example.com",
}
vals := u.Query()
host, err := c.getHost(u)
if err != nil {
t.Fatalf("#{err}")
}
if !contains(host, hosts) {
t.Fatalf("host %s is not in hosts list", host)
}
// check unsigned
c.HostSelection = "unsigned"
vals.Set("host", "in.valid.host")
u.RawQuery = vals.Encode()
host, err = c.getHost(u)
if err == nil {
t.Fatalf("Accepted host %s is not in hosts list", host)
}
vals.Set("host", hosts[0])
u.RawQuery = vals.Encode()
host, err = c.getHost(u)
if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
}
if host != hosts[0] {
t.Fatalf("host %s is not equal to input %s", host, hosts[0])
}
}

View file

@ -25,7 +25,7 @@ type ServerConfig struct {
CertFile string `koanf:"certfile"`
KeyFile string `koanf:"keyfile"`
Hosts []string `koanf:"hosts"`
RoundRobin bool `koanf:"roundrobin"`
HostSelection string `koanf:"hostselection"`
SessionKey string `koanf:"sessionkey"`
SessionEncryptionKey string `koanf:"sessionencryptionkey"`
SessionStore string `koanf:"sessionstore"`
@ -118,6 +118,7 @@ func Load(configFile string) Configuration {
"Server.TlsDisabled": false,
"Server.Port": 443,
"Server.SessionStore": "cookie",
"Server.HostSelection": "roundrobin",
"Client.NetworkAutoDetect": 1,
"Client.BandwidthAutoDetect": 1,
"Security.VerifyClientIp": true,
@ -153,14 +154,16 @@ func Load(configFile string) Configuration {
log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random")
}
if len(Conf.Security.UserTokenEncryptionKey) != 32 {
Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if Conf.Security.EnableUserToken {
if len(Conf.Security.UserTokenEncryptionKey) != 32 {
Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if len(Conf.Security.UserTokenSigningKey) != 32 {
Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
if len(Conf.Security.UserTokenSigningKey) != 32 {
Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
}
}
if len(Conf.Server.SessionKey) != 32 {

View file

@ -79,6 +79,7 @@ func main() {
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
SessionStore: conf.Server.SessionStore,
Hosts: conf.Server.Hosts,
HostSelection: conf.Server.HostSelection,
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
UsernameTemplate: conf.Client.UsernameTemplate,
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,