Work with go-jose to have encryption

This commit is contained in:
Bolke de Bruin 2020-08-19 11:33:26 +02:00
parent 2822dc8dd1
commit 188f077da1
7 changed files with 131 additions and 55 deletions

View file

@ -85,7 +85,8 @@ client:
security: security:
# a random string of at least 32 characters to secure cookies on the client # a random string of at least 32 characters to secure cookies on the client
# make sure to share this amongst different pods # make sure to share this amongst different pods
tokenSigningKey: thisisasessionkeyreplacethisjetzt PAATokenSigningKey: thisisasessionkeyreplacethisjetzt
PAATokenEncryptionKey: thisisasessionkeyreplacethisjetzt
``` ```
## Testing locally ## Testing locally
A convenience docker-compose allows you to test the RDPGW locally. It uses [Keycloak](http://www.keycloak.org) A convenience docker-compose allows you to test the RDPGW locally. It uses [Keycloak](http://www.keycloak.org)

View file

@ -23,14 +23,16 @@ const (
) )
type TokenGeneratorFunc func(context.Context, string, string) (string, error) type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type Config struct { type Config struct {
SessionKey []byte SessionKey []byte
SessionEncryptionKey []byte SessionEncryptionKey []byte
TokenGenerator TokenGeneratorFunc PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc
OAuth2Config *oauth2.Config OAuth2Config *oauth2.Config
store *sessions.CookieStore store *sessions.CookieStore
TokenVerifier *oidc.IDTokenVerifier OIDCTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache stateStore *cache.Cache
Hosts []string Hosts []string
GatewayAddress string GatewayAddress string
@ -72,7 +74,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
return return
} }
idToken, err := c.TokenVerifier.Verify(ctx, rawIDToken) idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return return
@ -103,6 +105,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
session.Options.MaxAge = MaxAge session.Options.MaxAge = MaxAge
session.Values["preferred_username"] = data["preferred_username"] session.Values["preferred_username"] = data["preferred_username"]
session.Values["authenticated"] = true session.Values["authenticated"] = true
session.Values["access_token"] = oauth2Token.AccessToken
if err = session.Save(r, w); err != nil { if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -130,6 +133,8 @@ func (c *Config) Authenticated(next http.Handler) http.Handler {
} }
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"]) ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
@ -159,7 +164,13 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
} }
} }
token, err := c.TokenGenerator(ctx, user, host) token, err := c.PAATokenGenerator(ctx, user, host)
if err != nil {
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
}
userToken, err := c.UserTokenGenerator(ctx, user)
if err != nil { if err != nil {
log.Printf("Cannot generate token for user %s due to %s", user, err) log.Printf("Cannot generate token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
@ -182,6 +193,6 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+ "networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+
"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+ "bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+
"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+ "connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+
"username:s:"+token+"\r\n"+ "username:s:"+userToken+"\r\n"+
"bitmapcachesize:i:32000\r\n")) "bitmapcachesize:i:32000\r\n"))
} }

View file

@ -2,6 +2,7 @@ package common
import ( import (
"context" "context"
"log"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -48,3 +49,12 @@ func GetClientIp(ctx context.Context) string {
} }
return s return s
} }
func GetAccessToken(ctx context.Context) string {
token, ok := ctx.Value("access_token").(string)
if !ok {
log.Printf("cannot get access token from context")
return ""
}
return token
}

View file

@ -44,9 +44,10 @@ type RDGCapsConfig struct {
} }
type SecurityConfig struct { type SecurityConfig struct {
EnableOpenId bool PAATokenEncryptionKey string
TokenSigningKey string PAATokenSigningKey string
PassTokenAsPassword bool UserTokenEncryptionKey string
UserTokenSigningKey string
} }
type ClientConfig struct { type ClientConfig struct {
@ -82,7 +83,7 @@ func Load(configFile string) Configuration {
log.Fatalf("Cannot unmarshal the config file; %s", err) log.Fatalf("Cannot unmarshal the config file; %s", err)
} }
if len(conf.Security.TokenSigningKey) < 32 { if len(conf.Security.PAATokenSigningKey) < 32 {
log.Fatalf("Token signing key not long enough") log.Fatalf("Token signing key not long enough")
} }

2
go.mod
View file

@ -4,12 +4,12 @@ go 1.14
require ( require (
github.com/coreos/go-oidc/v3 v3.0.0-alpha.1 github.com/coreos/go-oidc/v3 v3.0.0-alpha.1
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1
github.com/gorilla/sessions v1.2.0 github.com/gorilla/sessions v1.2.0
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.7.1 github.com/prometheus/client_golang v1.7.1
github.com/spf13/cobra v1.0.0 github.com/spf13/cobra v1.0.0
github.com/spf13/viper v1.7.0 github.com/spf13/viper v1.7.0
github.com/square/go-jose/v3 v3.0.0-20200630053402-0a67ce9b0693
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
) )

26
main.go
View file

@ -35,7 +35,10 @@ func main() {
conf = config.Load(configFile) conf = config.Load(configFile)
// set security keys // set security keys
security.SigningKey = []byte(conf.Security.TokenSigningKey) security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
// set oidc config // set oidc config
ctx := context.Background() ctx := context.Background()
@ -57,17 +60,18 @@ func main() {
} }
api := &api.Config{ api := &api.Config{
GatewayAddress: conf.Server.GatewayAddress, GatewayAddress: conf.Server.GatewayAddress,
OAuth2Config: &oauthConfig, OAuth2Config: &oauthConfig,
TokenVerifier: verifier, OIDCTokenVerifier: verifier,
TokenGenerator: security.GeneratePAAToken, PAATokenGenerator: security.GeneratePAAToken,
SessionKey: []byte(conf.Server.SessionKey), UserTokenGenerator: security.GenerateUserToken,
SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey), SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
Hosts: conf.Server.Hosts, Hosts: conf.Server.Hosts,
NetworkAutoDetect: conf.Client.NetworkAutoDetect, NetworkAutoDetect: conf.Client.NetworkAutoDetect,
UsernameTemplate: conf.Client.UsernameTemplate, UsernameTemplate: conf.Client.UsernameTemplate,
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect, BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
ConnectionType: conf.Client.ConnectionType, ConnectionType: conf.Client.ConnectionType,
} }
api.NewApi() api.NewApi()

View file

@ -6,42 +6,64 @@ import (
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/common" "github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/protocol"
"github.com/dgrijalva/jwt-go/v4" "github.com/square/go-jose/v3"
"github.com/square/go-jose/v3/jwt"
"log" "log"
"time" "time"
) )
var SigningKey []byte var (
SigningKey []byte
EncryptionKey []byte
UserSigningKey []byte
UserEncryptionKey []byte
)
var ExpiryTime time.Duration = 5 var ExpiryTime time.Duration = 5
type customClaims struct { type customClaims struct {
RemoteServer string `json:"remoteServer"` RemoteServer string `json:"remoteServer"`
ClientIP string `json:"clientIp"` ClientIP string `json:"clientIp"`
jwt.StandardClaims AccessToken string `json:"accessToken"`
} }
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) { func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseSigned(tokenString)
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) // check if the signing algo matches what we expect
for _, header := range token.Headers {
if header.Algorithm != string(jose.HS256) {
return false, fmt.Errorf("unexpected signing method: %v", header.Algorithm)
} }
}
return SigningKey, nil standard := jwt.Claims{}
}) custom := customClaims{}
// Claims automagically checks the signature...
err = token.Claims(SigningKey, &standard, &custom)
if err != nil { if err != nil {
log.Printf("token signature validation failed due to %s", err)
return false, err return false, err
} }
if c, ok := token.Claims.(*customClaims); ok && token.Valid { // ...but doesn't check the expiry claim :/
s := getSessionInfo(ctx) err = standard.Validate(jwt.Expected{
s.RemoteServer = c.RemoteServer Issuer: "rdpgw",
s.ClientIp = c.ClientIP Time: time.Now(),
return true, nil })
if err != nil {
log.Printf("token validation failed due to %s", err)
return false, err
} }
log.Printf("token validation failed: %s", err) s := getSessionInfo(ctx)
return false, err
s.RemoteServer = custom.RemoteServer
s.ClientIp = custom.ClientIP
return true, nil
} }
func VerifyServerFunc(ctx context.Context, host string) (bool, error) { func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
@ -68,34 +90,61 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
if len(SigningKey) < 32 { if len(SigningKey) < 32 {
return "", errors.New("token signing key not long enough or not specified") return "", errors.New("token signing key not long enough or not specified")
} }
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
exp := &jwt.Time{ if err != nil {
Time: time.Now().Add(time.Minute * 5), log.Printf("Cannot obtain signer %s", err)
} return "", err
now := &jwt.Time{
Time: time.Now(),
} }
c := customClaims{ standard := jwt.Claims{
Issuer: "rdpgw",
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
Subject: username,
}
private := customClaims{
RemoteServer: server, RemoteServer: server,
ClientIP: common.GetClientIp(ctx), ClientIP: common.GetClientIp(ctx),
StandardClaims: jwt.StandardClaims{ AccessToken: common.GetAccessToken(ctx),
ExpiresAt: exp,
IssuedAt: now,
Issuer: "rdpgw",
Subject: username,
},
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, c) if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
if ss, err := token.SignedString(SigningKey); err != nil {
log.Printf("Cannot sign PAA token %s", err) log.Printf("Cannot sign PAA token %s", err)
return "", err return "", err
} else { } else {
return ss, nil return token, nil
} }
} }
func GenerateUserToken(ctx context.Context, userName string) (string, error) {
if len(UserEncryptionKey) < 32 {
return "", errors.New("user token encryption key not long enough or not specified")
}
claims := jwt.Claims{
Subject: userName,
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
Issuer: "rdpgw",
}
enc, err := jose.NewEncrypter(
jose.A128CBC_HS256,
jose.Recipient{Algorithm: jose.DIRECT, Key: UserEncryptionKey},
(&jose.EncrypterOptions{Compression: jose.DEFLATE}).WithContentType("JWT"),
)
if err != nil {
log.Printf("Cannot encrypt user token due to %s", err)
return "", err
}
// this makes the token bigger and we deal with a limited space of 511 characters
// sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
// token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).CompactSerialize()
token, err := jwt.Encrypted(enc).Claims(claims).CompactSerialize()
return token, err
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo { func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok { if !ok {