rdpgw/cmd/rdpgw/config/configuration_unix.go
2024-08-11 23:28:40 +00:00

126 lines
4.3 KiB
Go

// +build !windows
package config
import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/confmap"
"github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/v2"
"log"
"os"
"strings"
)
func Load(configFile string) Configuration {
var k = koanf.New(".")
k.Load(confmap.Provider(map[string]interface{}{
"Server.Tls": "auto",
"Server.Port": 443,
"Server.SessionStore": "cookie",
"Server.HostSelection": "roundrobin",
"Server.Authentication": "openid",
"Server.AuthSocket": "/tmp/rdpgw-auth.sock",
"Server.BasicAuthTimeout": 5,
"Client.NetworkAutoDetect": 1,
"Client.BandwidthAutoDetect": 1,
"Security.VerifyClientIp": true,
"Caps.TokenAuth": true,
}, "."), nil)
if _, err := os.Stat(configFile); os.IsNotExist(err) {
log.Printf("Config file %s not found, using defaults and environment", configFile)
} else {
if err := k.Load(file.Provider(configFile), yaml.Parser()); err != nil {
log.Fatalf("Error loading config from file: %v", err)
}
}
if err := k.Load(env.ProviderWithValue("RDPGW_", ".", func(s string, v string) (string, interface{}) {
key := strings.Replace(strings.ToLower(strings.TrimPrefix(s, "RDPGW_")), "__", ".", -1)
key = ToCamel(key)
v = strings.Trim(v, " ")
// handle lists
if strings.Contains(v, " ") {
return key, strings.Split(v, " ")
}
return key, v
}), nil); err != nil {
log.Fatalf("Error loading config from environment: %v", err)
}
koanfTag := koanf.UnmarshalConf{Tag: "koanf"}
k.UnmarshalWithConf("Server", &Conf.Server, koanfTag)
k.UnmarshalWithConf("OpenId", &Conf.OpenId, koanfTag)
k.UnmarshalWithConf("Caps", &Conf.Caps, koanfTag)
k.UnmarshalWithConf("Security", &Conf.Security, koanfTag)
k.UnmarshalWithConf("Client", &Conf.Client, koanfTag)
k.UnmarshalWithConf("Kerberos", &Conf.Kerberos, koanfTag)
if len(Conf.Security.PAATokenEncryptionKey) != 32 {
Conf.Security.PAATokenEncryptionKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.paatokenencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if len(Conf.Security.PAATokenSigningKey) != 32 {
Conf.Security.PAATokenSigningKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random")
}
if Conf.Security.EnableUserToken {
if len(Conf.Security.UserTokenEncryptionKey) != 32 {
Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if len(Conf.Security.UserTokenSigningKey) != 32 {
Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random")
}
}
if len(Conf.Server.SessionKey) != 32 {
Conf.Server.SessionKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `server.sessionkey` specified (empty or not 32 characters). Setting to random")
}
if len(Conf.Server.SessionEncryptionKey) != 32 {
Conf.Server.SessionEncryptionKey, _ = security.GenerateRandomString(32)
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
}
if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" {
log.Fatalf("basicauth=local and tls=disable are mutually exclusive")
}
if Conf.Server.NtlmEnabled() && Conf.Server.KerberosEnabled() {
log.Fatalf("ntlm and kerberos authentication are not stackable")
}
if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() {
log.Fatalf("openid is configured but tokenauth disabled")
}
if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" {
log.Fatalf("kerberos is configured but no keytab was specified")
}
// prepend '//' if required for URL parsing
if !strings.Contains(Conf.Server.GatewayAddress, "//") {
Conf.Server.GatewayAddress = "//" + Conf.Server.GatewayAddress
}
return Conf
}