mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-24 17:20:48 +02:00
Allow host query parameter
the host query parameter can now be used dependent on the `hostselection` config.
This commit is contained in:
parent
45a57f44ff
commit
8bc3e25f83
5 changed files with 142 additions and 31 deletions
|
@ -59,10 +59,13 @@ Server:
|
|||
Hosts:
|
||||
- localhost:3389
|
||||
- my-{{ preferred_username }}-host:3389
|
||||
# Allow the user to connect to any host (insecure)
|
||||
- any
|
||||
# if true the server randomly selects a host to connect to
|
||||
RoundRobin: false
|
||||
# valid options are:
|
||||
# - roundrobin, which selects a random host from the list (default)
|
||||
# - signed, a listed host specified in the signed query parameter
|
||||
# - unsigned, a listed host specified in the query parameter
|
||||
# - any, insecurely allow any host specified in the query parameter
|
||||
HostSelection: roundrobin
|
||||
# a random strings of at least 32 characters to secure cookies on the client
|
||||
# make sure to share this across the different pods
|
||||
SessionKey: thisisasessionkeyreplacethisjetzt
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -21,7 +22,7 @@ import (
|
|||
|
||||
const (
|
||||
RdpGwSession = "RDPGWSESSION"
|
||||
MaxAge = 120
|
||||
MaxAge = 120
|
||||
)
|
||||
|
||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||
|
@ -30,7 +31,7 @@ type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
|||
type Config struct {
|
||||
SessionKey []byte
|
||||
SessionEncryptionKey []byte
|
||||
SessionStore string
|
||||
SessionStore string
|
||||
PAATokenGenerator TokenGeneratorFunc
|
||||
UserTokenGenerator UserTokenGeneratorFunc
|
||||
EnableUserToken bool
|
||||
|
@ -39,13 +40,14 @@ type Config struct {
|
|||
OIDCTokenVerifier *oidc.IDTokenVerifier
|
||||
stateStore *cache.Cache
|
||||
Hosts []string
|
||||
HostSelection string
|
||||
GatewayAddress string
|
||||
UsernameTemplate string
|
||||
NetworkAutoDetect int
|
||||
BandwidthAutoDetect int
|
||||
ConnectionType int
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
}
|
||||
|
||||
func (c *Config) NewApi() {
|
||||
|
@ -151,6 +153,47 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func (c *Config) selectRandomHost() string {
|
||||
rand.Seed(time.Now().Unix())
|
||||
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
||||
return host
|
||||
}
|
||||
|
||||
func (c *Config) getHost(u *url.URL) (string, error) {
|
||||
var host string
|
||||
switch c.HostSelection {
|
||||
case "roundrobin":
|
||||
host = c.selectRandomHost()
|
||||
case "signed":
|
||||
case "unsigned":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
found := false
|
||||
for _, check := range c.Hosts {
|
||||
if check == hosts[0] {
|
||||
host = hosts[0]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||
return "", errors.New("invalid host specified in query parameter")
|
||||
}
|
||||
case "any":
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
host = hosts[0]
|
||||
default:
|
||||
host = c.selectRandomHost()
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userName, ok := ctx.Value("preferred_username").(string)
|
||||
|
@ -161,9 +204,12 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// do a round robin selection for now
|
||||
rand.Seed(time.Now().Unix())
|
||||
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
||||
// determine host to connect to
|
||||
host, err := c.getHost(r.URL)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||
|
||||
// split the username into user and domain
|
||||
|
@ -210,19 +256,19 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
||||
w.Header().Set("Content-Type", "application/x-rdp")
|
||||
data := "full address:s:"+host+"\r\n"+
|
||||
"gatewayhostname:s:"+c.GatewayAddress+"\r\n"+
|
||||
"gatewaycredentialssource:i:5\r\n"+
|
||||
"gatewayusagemethod:i:1\r\n"+
|
||||
"gatewayprofileusagemethod:i:1\r\n"+
|
||||
"gatewayaccesstoken:s:"+token+"\r\n"+
|
||||
"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+
|
||||
"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+
|
||||
"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+
|
||||
"username:s:"+render+"\r\n"+
|
||||
"domain:s:"+domain+"\r\n"+
|
||||
"bitmapcachesize:i:32000\r\n"+
|
||||
"smart sizing:i:1\r\n"
|
||||
data := "full address:s:" + host + "\r\n" +
|
||||
"gatewayhostname:s:" + c.GatewayAddress + "\r\n" +
|
||||
"gatewaycredentialssource:i:5\r\n" +
|
||||
"gatewayusagemethod:i:1\r\n" +
|
||||
"gatewayprofileusagemethod:i:1\r\n" +
|
||||
"gatewayaccesstoken:s:" + token + "\r\n" +
|
||||
"networkautodetect:i:" + strconv.Itoa(c.NetworkAutoDetect) + "\r\n" +
|
||||
"bandwidthautodetect:i:" + strconv.Itoa(c.BandwidthAutoDetect) + "\r\n" +
|
||||
"connection type:i:" + strconv.Itoa(c.ConnectionType) + "\r\n" +
|
||||
"username:s:" + render + "\r\n" +
|
||||
"domain:s:" + domain + "\r\n" +
|
||||
"bitmapcachesize:i:32000\r\n" +
|
||||
"smart sizing:i:1\r\n"
|
||||
|
||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
|
||||
}
|
||||
|
|
58
cmd/rdpgw/api/web_test.go
Normal file
58
cmd/rdpgw/api/web_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"}
|
||||
)
|
||||
|
||||
func contains(needle string, haystack []string) bool {
|
||||
for _, val := range haystack {
|
||||
if val == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestGetHost(t *testing.T) {
|
||||
c := Config{
|
||||
HostSelection: "roundrobin",
|
||||
Hosts: hosts,
|
||||
}
|
||||
u := &url.URL{
|
||||
Host: "example.com",
|
||||
}
|
||||
vals := u.Query()
|
||||
|
||||
host, err := c.getHost(u)
|
||||
if err != nil {
|
||||
t.Fatalf("#{err}")
|
||||
}
|
||||
if !contains(host, hosts) {
|
||||
t.Fatalf("host %s is not in hosts list", host)
|
||||
}
|
||||
|
||||
// check unsigned
|
||||
c.HostSelection = "unsigned"
|
||||
vals.Set("host", "in.valid.host")
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(u)
|
||||
if err == nil {
|
||||
t.Fatalf("Accepted host %s is not in hosts list", host)
|
||||
}
|
||||
|
||||
vals.Set("host", hosts[0])
|
||||
u.RawQuery = vals.Encode()
|
||||
host, err = c.getHost(u)
|
||||
if err != nil {
|
||||
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
|
||||
}
|
||||
if host != hosts[0] {
|
||||
t.Fatalf("host %s is not equal to input %s", host, hosts[0])
|
||||
}
|
||||
|
||||
}
|
|
@ -25,7 +25,7 @@ type ServerConfig struct {
|
|||
CertFile string `koanf:"certfile"`
|
||||
KeyFile string `koanf:"keyfile"`
|
||||
Hosts []string `koanf:"hosts"`
|
||||
RoundRobin bool `koanf:"roundrobin"`
|
||||
HostSelection string `koanf:"hostselection"`
|
||||
SessionKey string `koanf:"sessionkey"`
|
||||
SessionEncryptionKey string `koanf:"sessionencryptionkey"`
|
||||
SessionStore string `koanf:"sessionstore"`
|
||||
|
@ -118,6 +118,7 @@ func Load(configFile string) Configuration {
|
|||
"Server.TlsDisabled": false,
|
||||
"Server.Port": 443,
|
||||
"Server.SessionStore": "cookie",
|
||||
"Server.HostSelection": "roundrobin",
|
||||
"Client.NetworkAutoDetect": 1,
|
||||
"Client.BandwidthAutoDetect": 1,
|
||||
"Security.VerifyClientIp": true,
|
||||
|
@ -153,14 +154,16 @@ func Load(configFile string) Configuration {
|
|||
log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random")
|
||||
}
|
||||
|
||||
if len(Conf.Security.UserTokenEncryptionKey) != 32 {
|
||||
Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
|
||||
log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
|
||||
}
|
||||
if Conf.Security.EnableUserToken {
|
||||
if len(Conf.Security.UserTokenEncryptionKey) != 32 {
|
||||
Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
|
||||
log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
|
||||
}
|
||||
|
||||
if len(Conf.Security.UserTokenSigningKey) != 32 {
|
||||
Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
|
||||
log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
|
||||
if len(Conf.Security.UserTokenSigningKey) != 32 {
|
||||
Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
|
||||
log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
|
||||
}
|
||||
}
|
||||
|
||||
if len(Conf.Server.SessionKey) != 32 {
|
||||
|
|
|
@ -79,6 +79,7 @@ func main() {
|
|||
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
||||
SessionStore: conf.Server.SessionStore,
|
||||
Hosts: conf.Server.Hosts,
|
||||
HostSelection: conf.Server.HostSelection,
|
||||
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
|
||||
UsernameTemplate: conf.Client.UsernameTemplate,
|
||||
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue