mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-22 18:45:56 +02:00
Full hack mode: openid connect, host template and short term tokens
This commit is contained in:
parent
3797e279c2
commit
bc897f1011
4 changed files with 157 additions and 15 deletions
92
download.go
92
download.go
|
@ -2,7 +2,11 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -10,21 +14,93 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
func handleRdpDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cookie, err := r.Cookie("RDPGWSESSIONV1")
|
||||||
|
if err != nil {
|
||||||
|
http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, found := tokens.Get(cookie.Value)
|
||||||
|
if found == false {
|
||||||
|
log.Printf("Found expired or non existent session: %s", cookie.Value)
|
||||||
|
http.Redirect(w, r, oauthConfig.AuthCodeURL(state), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticated
|
||||||
seed := make([]byte, 16)
|
seed := make([]byte, 16)
|
||||||
rand.Read(seed)
|
rand.Read(seed)
|
||||||
fn := hex.EncodeToString(seed) + ".rdp"
|
fn := hex.EncodeToString(seed) + ".rdp"
|
||||||
|
|
||||||
rand.Read(seed)
|
|
||||||
token := hex.EncodeToString(seed)
|
|
||||||
|
|
||||||
tokens.Set(token, token, cache.DefaultExpiration)
|
|
||||||
|
|
||||||
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
||||||
w.Header().Set("Content-Type", "application/x-rdp")
|
w.Header().Set("Content-Type", "application/x-rdp")
|
||||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(
|
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(
|
||||||
"full address:s:localhost\r\n"+
|
"full address:s:" + host + "\r\n"+
|
||||||
"gatewayhostname:s:localhost\r\n"+
|
"gatewayhostname:s:" + gateway +"\r\n"+
|
||||||
"gatewaycredentialssource:i:5\r\n"+
|
"gatewaycredentialssource:i:5\r\n"+
|
||||||
"gatewayusagemethod:i:1\r\n"+
|
"gatewayusagemethod:i:1\r\n"+
|
||||||
"gatewayaccesstoken:s:" + token + "\r\n"))
|
"gatewayaccesstoken:s:" + cookie.Value + "\r\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("state") != state {
|
||||||
|
http.Error(w, "state did not match", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthToken, err := oauthConfig.Exchange(ctx, r.URL.Query().Get("code"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawIDToken, ok := oauthToken.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := struct {
|
||||||
|
OAuth2Token *oauth2.Token
|
||||||
|
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
||||||
|
}{oauthToken, new(json.RawMessage)}
|
||||||
|
|
||||||
|
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q, err := json.MarshalIndent(resp, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data map[string]interface{}
|
||||||
|
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
seed := make([]byte, 16)
|
||||||
|
rand.Read(seed)
|
||||||
|
token := hex.EncodeToString(seed)
|
||||||
|
|
||||||
|
cookie := http.Cookie{
|
||||||
|
Name: "RDPGWSESSIONV1",
|
||||||
|
Value: token,
|
||||||
|
Path: "/",
|
||||||
|
Secure: true,
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens.Set(token, data[claim].(string), cache.DefaultExpiration)
|
||||||
|
|
||||||
|
http.SetCookie(w, &cookie)
|
||||||
|
http.Redirect(w, r, "/connect", http.StatusFound)
|
||||||
|
}
|
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module github.com/bolkedebruin/rdpgw
|
||||||
go 1.14
|
go 1.14
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/coreos/go-oidc/v3 v3.0.0-alpha.1
|
||||||
github.com/gorilla/websocket v1.4.2
|
github.com/gorilla/websocket v1.4.2
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/prometheus/client_golang v1.7.1
|
github.com/prometheus/client_golang v1.7.1
|
||||||
|
|
55
main.go
55
main.go
|
@ -1,12 +1,15 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -29,17 +32,39 @@ var (
|
||||||
|
|
||||||
var tokens = cache.New(time.Minute *5, 10*time.Minute)
|
var tokens = cache.New(time.Minute *5, 10*time.Minute)
|
||||||
|
|
||||||
|
var state string
|
||||||
|
|
||||||
|
var oauthConfig oauth2.Config
|
||||||
|
var oidcConfig *oidc.Config
|
||||||
|
var verifier *oidc.IDTokenVerifier
|
||||||
|
var ctx context.Context
|
||||||
|
|
||||||
|
var gateway string
|
||||||
|
var overrideHost bool
|
||||||
|
var hostTemplate string
|
||||||
|
var claim string
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
// get config
|
||||||
cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
|
cmd.PersistentFlags().IntVarP(&port, "port", "p", 443, "port to listen on for incoming connection")
|
||||||
cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
|
cmd.PersistentFlags().StringVarP(&certFile, "certfile", "", "server.pem", "public key certificate file")
|
||||||
cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
|
cmd.PersistentFlags().StringVarP(&keyFile, "keyfile", "", "key.pem", "private key file")
|
||||||
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
|
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
|
||||||
|
cmd.PersistentFlags().StringVarP(&gateway, "gateway", "g", "localhost", "gateway dns name")
|
||||||
|
cmd.PersistentFlags().BoolVarP(&overrideHost, "hostOverride", "", false, "weather the user can override the host to connect to")
|
||||||
|
cmd.PersistentFlags().StringVarP(&hostTemplate, "hostTemplate", "t", "", "host template")
|
||||||
|
cmd.PersistentFlags().StringVarP(&claim, "claim", "", "preferred_username", "openid claim to use for filling in template")
|
||||||
|
|
||||||
viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
|
viper.BindPFlag("port", cmd.PersistentFlags().Lookup("port"))
|
||||||
viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
|
viper.BindPFlag("certfile", cmd.PersistentFlags().Lookup("certfile"))
|
||||||
viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
|
viper.BindPFlag("keyfile", cmd.PersistentFlags().Lookup("keyfile"))
|
||||||
|
viper.BindPFlag("gateway", cmd.PersistentFlags().Lookup("gateway"))
|
||||||
|
viper.BindPFlag("hostOverride", cmd.PersistentFlags().Lookup("hostOverride"))
|
||||||
|
viper.BindPFlag("hostTemplate", cmd.PersistentFlags().Lookup("hostTemplate"))
|
||||||
|
viper.BindPFlag("claim", cmd.PersistentFlags().Lookup("claim"))
|
||||||
|
|
||||||
viper.SetConfigFile(configFile)
|
viper.SetConfigName("rdpgw")
|
||||||
|
//viper.SetConfigFile(configFile)
|
||||||
viper.AddConfigPath(".")
|
viper.AddConfigPath(".")
|
||||||
viper.SetEnvPrefix("RDPGW")
|
viper.SetEnvPrefix("RDPGW")
|
||||||
viper.AutomaticEnv()
|
viper.AutomaticEnv()
|
||||||
|
@ -49,6 +74,33 @@ func main() {
|
||||||
log.Printf("No config file found. Using defaults")
|
log.Printf("No config file found. Using defaults")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dont understand why I need to do this
|
||||||
|
gateway = viper.GetString("gateway")
|
||||||
|
hostTemplate = viper.GetString("hostTemplate")
|
||||||
|
overrideHost = viper.GetBool("hostOverride")
|
||||||
|
|
||||||
|
// set oidc config
|
||||||
|
ctx = context.Background()
|
||||||
|
provider, err := oidc.NewProvider(ctx, viper.GetString("providerUrl"))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Cannot get oidc provider: %s", err)
|
||||||
|
}
|
||||||
|
oidcConfig = &oidc.Config{
|
||||||
|
ClientID: viper.GetString("clientId"),
|
||||||
|
}
|
||||||
|
verifier = provider.Verifier(oidcConfig)
|
||||||
|
|
||||||
|
oauthConfig = oauth2.Config{
|
||||||
|
ClientID: viper.GetString("clientId"),
|
||||||
|
ClientSecret: viper.GetString("clientSecret"),
|
||||||
|
RedirectURL: "https://" + gateway + "/callback",
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// check what is required
|
||||||
|
state = "rdpstate"
|
||||||
|
|
||||||
if certFile == "" || keyFile == "" {
|
if certFile == "" || keyFile == "" {
|
||||||
log.Fatal("Both certfile and keyfile need to be specified")
|
log.Fatal("Both certfile and keyfile need to be specified")
|
||||||
}
|
}
|
||||||
|
@ -81,6 +133,7 @@ func main() {
|
||||||
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
|
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
|
||||||
http.HandleFunc("/connect", handleRdpDownload)
|
http.HandleFunc("/connect", handleRdpDownload)
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
|
http.HandleFunc("/callback", handleCallback)
|
||||||
|
|
||||||
prometheus.MustRegister(connectionCache)
|
prometheus.MustRegister(connectionCache)
|
||||||
prometheus.MustRegister(legacyConnections)
|
prometheus.MustRegister(legacyConnections)
|
||||||
|
|
24
rdg.go
24
rdg.go
|
@ -16,6 +16,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -182,6 +183,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
websocketConnections.Inc()
|
websocketConnections.Inc()
|
||||||
defer websocketConnections.Dec()
|
defer websocketConnections.Dec()
|
||||||
|
|
||||||
|
var host string
|
||||||
for {
|
for {
|
||||||
mt, msg, err := conn.ReadMessage()
|
mt, msg, err := conn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -218,10 +220,12 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
conn.WriteMessage(mt, msg)
|
conn.WriteMessage(mt, msg)
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
_, cookie := readCreateTunnelRequest(pkt)
|
_, cookie := readCreateTunnelRequest(pkt)
|
||||||
if _, found := tokens.Get(cookie); found == false {
|
data, found := tokens.Get(cookie)
|
||||||
|
if found == false {
|
||||||
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
host = strings.Replace(hostTemplate, "%%", data.(string), 1)
|
||||||
msg := createTunnelResponse()
|
msg := createTunnelResponse()
|
||||||
log.Printf("Create tunnel response: %x", msg)
|
log.Printf("Create tunnel response: %x", msg)
|
||||||
conn.WriteMessage(mt, msg)
|
conn.WriteMessage(mt, msg)
|
||||||
|
@ -232,13 +236,17 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
conn.WriteMessage(mt, msg)
|
conn.WriteMessage(mt, msg)
|
||||||
case PKT_TYPE_CHANNEL_CREATE:
|
case PKT_TYPE_CHANNEL_CREATE:
|
||||||
server, port := readChannelCreateRequest(pkt)
|
server, port := readChannelCreateRequest(pkt)
|
||||||
log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
|
if overrideHost == true {
|
||||||
|
log.Printf("Override allowed")
|
||||||
|
host = net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
|
}
|
||||||
|
log.Printf("Establishing connection to RDP server: %s", host)
|
||||||
remote, err = net.DialTimeout(
|
remote, err = net.DialTimeout(
|
||||||
"tcp",
|
"tcp",
|
||||||
net.JoinHostPort(server, strconv.Itoa(int(port))),
|
host,
|
||||||
time.Second * 15)
|
time.Second * 30)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error connecting to %s, %d, %s", server, port, err)
|
log.Printf("Error connecting to %s", host)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Connection established")
|
log.Printf("Connection established")
|
||||||
|
@ -349,7 +357,11 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
msg := handshakeResponse(major, minor, auth)
|
msg := handshakeResponse(major, minor, auth)
|
||||||
s.ConnOut.Write(msg)
|
s.ConnOut.Write(msg)
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
readCreateTunnelRequest(pkt)
|
_, cookie := readCreateTunnelRequest(pkt)
|
||||||
|
if _, found := tokens.Get(cookie); found == false {
|
||||||
|
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
||||||
|
return
|
||||||
|
}
|
||||||
msg := createTunnelResponse()
|
msg := createTunnelResponse()
|
||||||
s.ConnOut.Write(msg)
|
s.ConnOut.Write(msg)
|
||||||
case PKT_TYPE_TUNNEL_AUTH:
|
case PKT_TYPE_TUNNEL_AUTH:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue