mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-29 13:56:15 +02:00
Full hack mode: openid connect, host template and short term tokens
This commit is contained in:
parent
3797e279c2
commit
bc897f1011
4 changed files with 157 additions and 15 deletions
55
main.go
55
main.go
|
@ -1,12 +1,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/oauth2"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -29,17 +32,39 @@ var (
|
|||
|
||||
var tokens = cache.New(time.Minute *5, 10*time.Minute)
|
||||
|
||||
var state string
|
||||
|
||||
var oauthConfig oauth2.Config
|
||||
var oidcConfig *oidc.Config
|
||||
var verifier *oidc.IDTokenVerifier
|
||||
var ctx context.Context
|
||||
|
||||
var gateway string
|
||||
var overrideHost bool
|
||||
var hostTemplate string
|
||||
var claim string
|
||||
|
||||
func main() {
|
||||
// get config
|
||||
cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
|
||||
cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
|
||||
cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
|
||||
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
|
||||
cmd.PersistentFlags().StringVarP(&gateway, "gateway", "g", "localhost", "gateway dns name")
|
||||
cmd.PersistentFlags().BoolVarP(&overrideHost, "hostOverride", "", false, "weather the user can override the host to connect to")
|
||||
cmd.PersistentFlags().StringVarP(&hostTemplate, "hostTemplate", "t", "", "host template")
|
||||
cmd.PersistentFlags().StringVarP(&claim, "claim", "", "preferred_username", "openid claim to use for filling in template")
|
||||
|
||||
viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
|
||||
viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
|
||||
viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
|
||||
viper.BindPFlag("gateway", cmd.PersistentFlags().Lookup("gateway"))
|
||||
viper.BindPFlag("hostOverride", cmd.PersistentFlags().Lookup("hostOverride"))
|
||||
viper.BindPFlag("hostTemplate", cmd.PersistentFlags().Lookup("hostTemplate"))
|
||||
viper.BindPFlag("claim", cmd.PersistentFlags().Lookup("claim"))
|
||||
|
||||
viper.SetConfigFile(configFile)
|
||||
viper.SetConfigName("rdpgw")
|
||||
//viper.SetConfigFile(configFile)
|
||||
viper.AddConfigPath(".")
|
||||
viper.SetEnvPrefix("RDPGW")
|
||||
viper.AutomaticEnv()
|
||||
|
@ -49,6 +74,33 @@ func main() {
|
|||
log.Printf("No config file found. Using defaults")
|
||||
}
|
||||
|
||||
// dont understand why I need to do this
|
||||
gateway = viper.GetString("gateway")
|
||||
hostTemplate = viper.GetString("hostTemplate")
|
||||
overrideHost = viper.GetBool("hostOverride")
|
||||
|
||||
// set oidc config
|
||||
ctx = context.Background()
|
||||
provider, err := oidc.NewProvider(ctx, viper.GetString("providerUrl"))
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot get oidc provider: %s", err)
|
||||
}
|
||||
oidcConfig = &oidc.Config{
|
||||
ClientID: viper.GetString("clientId"),
|
||||
}
|
||||
verifier = provider.Verifier(oidcConfig)
|
||||
|
||||
oauthConfig = oauth2.Config{
|
||||
ClientID: viper.GetString("clientId"),
|
||||
ClientSecret: viper.GetString("clientSecret"),
|
||||
RedirectURL: "https://" + gateway + "/callback",
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
// check what is required
|
||||
state = "rdpstate"
|
||||
|
||||
if certFile == "" || keyFile == "" {
|
||||
log.Fatal("Both certfile and keyfile need to be specified")
|
||||
}
|
||||
|
@ -81,6 +133,7 @@ func main() {
|
|||
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
|
||||
http.HandleFunc("/connect", handleRdpDownload)
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/callback", handleCallback)
|
||||
|
||||
prometheus.MustRegister(connectionCache)
|
||||
prometheus.MustRegister(legacyConnections)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue