mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-21 10:06:01 +02:00
Add viper & download rdp file
This commit is contained in:
parent
31f09af9d0
commit
3797e279c2
4 changed files with 77 additions and 9 deletions
30
download.go
Normal file
30
download.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
fn := hex.EncodeToString(seed) + ".rdp"
|
||||
|
||||
rand.Read(seed)
|
||||
token := hex.EncodeToString(seed)
|
||||
|
||||
tokens.Set(token, token, cache.DefaultExpiration)
|
||||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
||||
w.Header().Set("Content-Type", "application/x-rdp")
|
||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(
|
||||
"full address:s:localhost\r\n"+
|
||||
"gatewayhostname:s:localhost\r\n"+
|
||||
"gatewaycredentialssource:i:5\r\n"+
|
||||
"gatewayusagemethod:i:1\r\n"+
|
||||
"gatewayaccesstoken:s:" + token + "\r\n"))
|
||||
}
|
2
go.mod
2
go.mod
|
@ -6,4 +6,6 @@ require (
|
|||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/prometheus/client_golang v1.7.1
|
||||
github.com/spf13/cobra v1.0.0
|
||||
github.com/spf13/viper v1.7.0
|
||||
)
|
||||
|
|
48
main.go
48
main.go
|
@ -2,23 +2,54 @@ package main
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
var cmd = &cobra.Command{
|
||||
Use: "rdpgw",
|
||||
Long: "Remote Desktop Gateway",
|
||||
}
|
||||
|
||||
var (
|
||||
port int
|
||||
certFile string
|
||||
keyFile string
|
||||
|
||||
configFile string
|
||||
)
|
||||
|
||||
var tokens = cache.New(time.Minute *5, 10*time.Minute)
|
||||
|
||||
func main() {
|
||||
port := flag.Int("port", 443, "port to listen on for incoming connections")
|
||||
certFile := flag.String("certfile", "server.pem", "public key certificate file")
|
||||
keyFile := flag.String("keyfile", "key.pem", "private key file")
|
||||
cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
|
||||
cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
|
||||
cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
|
||||
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
|
||||
|
||||
flag.Parse()
|
||||
viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
|
||||
viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
|
||||
viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
|
||||
|
||||
if *certFile == "" || *keyFile == "" {
|
||||
viper.SetConfigFile(configFile)
|
||||
viper.AddConfigPath(".")
|
||||
viper.SetEnvPrefix("RDPGW")
|
||||
viper.AutomaticEnv()
|
||||
|
||||
err := viper.ReadInConfig()
|
||||
if err != nil {
|
||||
log.Printf("No config file found. Using defaults")
|
||||
}
|
||||
|
||||
if certFile == "" || keyFile == "" {
|
||||
log.Fatal("Both certfile and keyfile need to be specified")
|
||||
}
|
||||
|
||||
|
@ -36,18 +67,19 @@ func main() {
|
|||
log.Printf("Key log file set to: %s", tlsDebug)
|
||||
cfg.KeyLogWriter = w
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
cfg.Certificates = append(cfg.Certificates, cert)
|
||||
server := http.Server{
|
||||
Addr: ":" + strconv.Itoa(*port),
|
||||
Addr: ":" + strconv.Itoa(port),
|
||||
TLSConfig: cfg,
|
||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
||||
}
|
||||
|
||||
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
|
||||
http.HandleFunc("/connect", handleRdpDownload)
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
prometheus.MustRegister(connectionCache)
|
||||
|
|
6
rdg.go
6
rdg.go
|
@ -217,7 +217,11 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
|||
log.Printf("Handshake response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
readCreateTunnelRequest(pkt)
|
||||
_, cookie := readCreateTunnelRequest(pkt)
|
||||
if _, found := tokens.Get(cookie); found == false {
|
||||
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
msg := createTunnelResponse()
|
||||
log.Printf("Create tunnel response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue