mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-25 20:08:20 +02:00
Refactor config
This commit is contained in:
parent
24c884f00f
commit
76e30ffa98
4 changed files with 64 additions and 50 deletions
37
config/configuration.go
Normal file
37
config/configuration.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
type Configuration struct {
|
||||||
|
Server ServerConfig
|
||||||
|
OpenId OpenIDConfig
|
||||||
|
Caps RDGCapsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
GatewayAddress string
|
||||||
|
Port int
|
||||||
|
CertFile string
|
||||||
|
KeyFile string
|
||||||
|
FarmHosts []string
|
||||||
|
EnableOverride bool
|
||||||
|
HostTemplate string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenIDConfig struct {
|
||||||
|
ProviderUrl string
|
||||||
|
ClientId string
|
||||||
|
ClientSecret string
|
||||||
|
CallbackHost string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RDGCapsConfig struct {
|
||||||
|
SmartCardAuth bool
|
||||||
|
TokenAuth bool
|
||||||
|
IdleTimeout int
|
||||||
|
RedirectAll bool
|
||||||
|
DisableRedirect bool
|
||||||
|
DisableClipboard bool
|
||||||
|
DisablePrinter bool
|
||||||
|
DisablePort bool
|
||||||
|
DisablePnp bool
|
||||||
|
DisableDrive bool
|
||||||
|
}
|
|
@ -8,11 +8,14 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const state = "thisismystatebutshouldberandom"
|
||||||
|
|
||||||
func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
cookie, err := r.Cookie("RDPGWSESSIONV1")
|
cookie, err := r.Cookie("RDPGWSESSIONV1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -38,7 +41,7 @@ func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/x-rdp")
|
w.Header().Set("Content-Type", "application/x-rdp")
|
||||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(
|
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(
|
||||||
"full address:s:" + host + "\r\n"+
|
"full address:s:" + host + "\r\n"+
|
||||||
"gatewayhostname:s:" + gateway +"\r\n"+
|
"gatewayhostname:s:" + net.JoinHostPort(conf.Server.GatewayAddress, string(conf.Server.Port)) +"\r\n"+
|
||||||
"gatewaycredentialssource:i:5\r\n"+
|
"gatewaycredentialssource:i:5\r\n"+
|
||||||
"gatewayusagemethod:i:1\r\n"+
|
"gatewayusagemethod:i:1\r\n"+
|
||||||
"gatewayaccesstoken:s:" + cookie.Value + "\r\n"))
|
"gatewayaccesstoken:s:" + cookie.Value + "\r\n"))
|
||||||
|
@ -95,7 +98,8 @@ func handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens.Set(token, data[claim].(string), cache.DefaultExpiration)
|
// TODO: make dynamic
|
||||||
|
tokens.Set(token, data["preferred_username"].(string), cache.DefaultExpiration)
|
||||||
|
|
||||||
http.SetCookie(w, &cookie)
|
http.SetCookie(w, &cookie)
|
||||||
http.Redirect(w, r, "/connect", http.StatusFound)
|
http.Redirect(w, r, "/connect", http.StatusFound)
|
||||||
|
|
65
main.go
65
main.go
|
@ -3,10 +3,11 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"github.com/bolkedebruin/rdpgw/config"
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -23,69 +24,41 @@ var cmd = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
port int
|
|
||||||
certFile string
|
|
||||||
keyFile string
|
|
||||||
|
|
||||||
configFile string
|
configFile string
|
||||||
)
|
)
|
||||||
|
|
||||||
var tokens = cache.New(time.Minute *5, 10*time.Minute)
|
var tokens = cache.New(time.Minute *5, 10*time.Minute)
|
||||||
|
var conf config.Configuration
|
||||||
|
|
||||||
var state string
|
|
||||||
|
|
||||||
var oauthConfig oauth2.Config
|
|
||||||
var oidcConfig *oidc.Config
|
|
||||||
var verifier *oidc.IDTokenVerifier
|
var verifier *oidc.IDTokenVerifier
|
||||||
|
var oauthConfig oauth2.Config
|
||||||
var ctx context.Context
|
var ctx context.Context
|
||||||
|
|
||||||
var gateway string
|
|
||||||
var overrideHost bool
|
|
||||||
var hostTemplate string
|
|
||||||
var claim string
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// get config
|
// get config
|
||||||
cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
|
|
||||||
cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
|
|
||||||
cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
|
|
||||||
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
|
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
|
||||||
cmd.PersistentFlags().StringVarP(&gateway, "gateway", "g", "localhost", "gateway dns name")
|
|
||||||
cmd.PersistentFlags().BoolVarP(&overrideHost, "hostOverride", "", false, "weather the user can override the host to connect to")
|
|
||||||
cmd.PersistentFlags().StringVarP(&hostTemplate, "hostTemplate", "t", "", "host template")
|
|
||||||
cmd.PersistentFlags().StringVarP(&claim, "claim", "", "preferred_username", "openid claim to use for filling in template")
|
|
||||||
|
|
||||||
viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
|
|
||||||
viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
|
|
||||||
viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
|
|
||||||
viper.BindPFlag("gateway", cmd.PersistentFlags().Lookup("gateway"))
|
|
||||||
viper.BindPFlag("hostOverride", cmd.PersistentFlags().Lookup("hostOverride"))
|
|
||||||
viper.BindPFlag("hostTemplate", cmd.PersistentFlags().Lookup("hostTemplate"))
|
|
||||||
viper.BindPFlag("claim", cmd.PersistentFlags().Lookup("claim"))
|
|
||||||
|
|
||||||
viper.SetConfigName("rdpgw")
|
viper.SetConfigName("rdpgw")
|
||||||
//viper.SetConfigFile(configFile)
|
viper.SetConfigFile(configFile)
|
||||||
viper.AddConfigPath(".")
|
viper.AddConfigPath(".")
|
||||||
viper.SetEnvPrefix("RDPGW")
|
viper.SetEnvPrefix("RDPGW")
|
||||||
viper.AutomaticEnv()
|
viper.AutomaticEnv()
|
||||||
|
|
||||||
err := viper.ReadInConfig()
|
if err := viper.ReadInConfig(); err != nil {
|
||||||
if err != nil {
|
log.Printf("No config file found (%s). Using defaults", err)
|
||||||
log.Printf("No config file found. Using defaults")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// dont understand why I need to do this
|
if err := viper.Unmarshal(&conf); err != nil {
|
||||||
gateway = viper.GetString("gateway")
|
log.Fatalf("Cannot unmarshal the config file; %s", err)
|
||||||
hostTemplate = viper.GetString("hostTemplate")
|
}
|
||||||
overrideHost = viper.GetBool("hostOverride")
|
|
||||||
|
|
||||||
// set oidc config
|
// set oidc config
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
provider, err := oidc.NewProvider(ctx, viper.GetString("providerUrl"))
|
provider, err := oidc.NewProvider(ctx, conf.OpenId.ProviderUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Cannot get oidc provider: %s", err)
|
log.Fatalf("Cannot get oidc provider: %s", err)
|
||||||
}
|
}
|
||||||
oidcConfig = &oidc.Config{
|
oidcConfig := &oidc.Config{
|
||||||
ClientID: viper.GetString("clientId"),
|
ClientID: viper.GetString("clientId"),
|
||||||
}
|
}
|
||||||
verifier = provider.Verifier(oidcConfig)
|
verifier = provider.Verifier(oidcConfig)
|
||||||
|
@ -93,15 +66,12 @@ func main() {
|
||||||
oauthConfig = oauth2.Config{
|
oauthConfig = oauth2.Config{
|
||||||
ClientID: viper.GetString("clientId"),
|
ClientID: viper.GetString("clientId"),
|
||||||
ClientSecret: viper.GetString("clientSecret"),
|
ClientSecret: viper.GetString("clientSecret"),
|
||||||
RedirectURL: "https://" + gateway + "/callback",
|
RedirectURL: "https://" + conf.Server.GatewayAddress + "/callback",
|
||||||
Endpoint: provider.Endpoint(),
|
Endpoint: provider.Endpoint(),
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// check what is required
|
if conf.Server.CertFile == "" || conf.Server.KeyFile == "" {
|
||||||
state = "rdpstate"
|
|
||||||
|
|
||||||
if certFile == "" || keyFile == "" {
|
|
||||||
log.Fatal("Both certfile and keyfile need to be specified")
|
log.Fatal("Both certfile and keyfile need to be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,6 +79,7 @@ func main() {
|
||||||
//mux.HandleFunc("*", HelloServer)
|
//mux.HandleFunc("*", HelloServer)
|
||||||
|
|
||||||
log.Printf("Starting remote desktop gateway server")
|
log.Printf("Starting remote desktop gateway server")
|
||||||
|
|
||||||
cfg := &tls.Config{}
|
cfg := &tls.Config{}
|
||||||
tlsDebug := os.Getenv("SSLKEYLOGFILE")
|
tlsDebug := os.Getenv("SSLKEYLOGFILE")
|
||||||
if tlsDebug != "" {
|
if tlsDebug != "" {
|
||||||
|
@ -119,13 +90,15 @@ func main() {
|
||||||
log.Printf("Key log file set to: %s", tlsDebug)
|
log.Printf("Key log file set to: %s", tlsDebug)
|
||||||
cfg.KeyLogWriter = w
|
cfg.KeyLogWriter = w
|
||||||
}
|
}
|
||||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
||||||
|
|
||||||
|
cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
cfg.Certificates = append(cfg.Certificates, cert)
|
cfg.Certificates = append(cfg.Certificates, cert)
|
||||||
server := http.Server{
|
server := http.Server{
|
||||||
Addr: ":" + strconv.Itoa(port),
|
Addr: ":" + strconv.Itoa(conf.Server.Port),
|
||||||
TLSConfig: cfg,
|
TLSConfig: cfg,
|
||||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
||||||
}
|
}
|
||||||
|
|
4
rdg.go
4
rdg.go
|
@ -225,7 +225,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
host = strings.Replace(hostTemplate, "%%", data.(string), 1)
|
host = strings.Replace(conf.Server.HostTemplate, "%%", data.(string), 1)
|
||||||
msg := createTunnelResponse()
|
msg := createTunnelResponse()
|
||||||
log.Printf("Create tunnel response: %x", msg)
|
log.Printf("Create tunnel response: %x", msg)
|
||||||
conn.WriteMessage(mt, msg)
|
conn.WriteMessage(mt, msg)
|
||||||
|
@ -236,7 +236,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
conn.WriteMessage(mt, msg)
|
conn.WriteMessage(mt, msg)
|
||||||
case PKT_TYPE_CHANNEL_CREATE:
|
case PKT_TYPE_CHANNEL_CREATE:
|
||||||
server, port := readChannelCreateRequest(pkt)
|
server, port := readChannelCreateRequest(pkt)
|
||||||
if overrideHost == true {
|
if conf.Server.EnableOverride == true {
|
||||||
log.Printf("Override allowed")
|
log.Printf("Override allowed")
|
||||||
host = net.JoinHostPort(server, strconv.Itoa(int(port)))
|
host = net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue