Add viper & download rdp file

This commit is contained in:
Bolke de Bruin 2020-07-16 16:04:04 +02:00
parent 31f09af9d0
commit 3797e279c2
4 changed files with 77 additions and 9 deletions

48
main.go
View file

@ -2,23 +2,54 @@ package main
import (
"crypto/tls"
"flag"
"github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"log"
"net/http"
"os"
"strconv"
"time"
)
var cmd = &cobra.Command{
Use: "rdpgw",
Long: "Remote Desktop Gateway",
}
var (
port int
certFile string
keyFile string
configFile string
)
var tokens = cache.New(time.Minute *5, 10*time.Minute)
func main() {
port := flag.Int("port", 443, "port to listen on for incoming connections")
certFile := flag.String("certfile", "server.pem", "public key certificate file")
keyFile := flag.String("keyfile", "key.pem", "private key file")
cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
flag.Parse()
viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
if *certFile == "" || *keyFile == "" {
viper.SetConfigFile(configFile)
viper.AddConfigPath(".")
viper.SetEnvPrefix("RDPGW")
viper.AutomaticEnv()
err := viper.ReadInConfig()
if err != nil {
log.Printf("No config file found. Using defaults")
}
if certFile == "" || keyFile == "" {
log.Fatal("Both certfile and keyfile need to be specified")
}
@ -36,18 +67,19 @@ func main() {
log.Printf("Key log file set to: %s", tlsDebug)
cfg.KeyLogWriter = w
}
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
log.Fatal(err)
}
cfg.Certificates = append(cfg.Certificates, cert)
server := http.Server{
Addr: ":" + strconv.Itoa(*port),
Addr: ":" + strconv.Itoa(port),
TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
}
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
http.HandleFunc("/connect", handleRdpDownload)
http.Handle("/metrics", promhttp.Handler())
prometheus.MustRegister(connectionCache)