Use jose v4 and make clearer and fix signing/encryption

This commit is contained in:
Bolke de Bruin 2024-04-12 12:33:46 +02:00
parent bc36b2b0cb
commit 9c6d056d69
4 changed files with 110 additions and 47 deletions

View file

@ -299,6 +299,8 @@ Security:
# PAATokenEncryptionKey: thisisasessionkeyreplacethisjetzt
# a random string of 32 characters to secure cookies on the client
UserTokenEncryptionKey: thisisasessionkeyreplacethisjetzt
# Signing makes the token bigger and we are limited to 511 characters
# UserTokenSigningKey: thisisasessionkeyreplacethisjetzt
# if you want to enable token generation for the user
# if true the username will be set to a jwt with the username embedded into it
EnableUserToken: true

View file

@ -7,8 +7,8 @@ import (
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"golang.org/x/oauth2"
"log"
"time"
@ -62,9 +62,9 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
return false, errors.New("no token to parse")
}
token, err := jwt.ParseSigned(tokenString)
token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256})
if err != nil {
log.Printf("cannot parse token due to: %tunnel", err)
log.Printf("cannot parse token due to: %t", err)
return false, err
}
@ -136,7 +136,7 @@ func GeneratePAAToken(ctx context.Context, username string, server string) (stri
AccessToken: id.GetAttribute(identity.AttrAccessToken).(string),
}
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).Serialize(); err != nil {
log.Printf("Cannot sign PAA token %s", err)
return "", err
} else {
@ -157,7 +157,10 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) {
enc, err := jose.NewEncrypter(
jose.A128CBC_HS256,
jose.Recipient{Algorithm: jose.DIRECT, Key: UserEncryptionKey},
jose.Recipient{
Algorithm: jose.DIRECT,
Key: UserEncryptionKey,
},
(&jose.EncrypterOptions{Compression: jose.DEFLATE}).WithContentType("JWT"),
)
@ -167,16 +170,29 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) {
}
// this makes the token bigger and we deal with a limited space of 511 characters
// sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
// token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).CompactSerialize()
token, err := jwt.Encrypted(enc).Claims(claims).CompactSerialize()
if len(UserSigningKey) > 0 {
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: UserSigningKey}, nil)
token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).Serialize()
if len(token) > 511 {
log.Printf("WARNING: token too long: len %d > 511", len(token))
}
return token, err
}
// no signature
token, err := jwt.Encrypted(enc).Claims(claims).Serialize()
return token, err
}
func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
standard := jwt.Claims{}
if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 {
enc, err := jwt.ParseSignedAndEncrypted(token)
enc, err := jwt.ParseSignedAndEncrypted(
token,
[]jose.KeyAlgorithm{jose.DIRECT},
[]jose.ContentEncryption{jose.A128CBC_HS256},
[]jose.SignatureAlgorithm{jose.HS256},
)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
@ -186,16 +202,12 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
} else if len(UserSigningKey) == 0 {
token, err := jwt.ParseEncrypted(token)
token, err := jwt.ParseEncrypted(token, []jose.KeyAlgorithm{jose.DIRECT}, []jose.ContentEncryption{jose.A128CBC_HS256})
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
@ -205,21 +217,6 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
} else {
token, err := jwt.ParseSigned(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
err = token.Claims(UserSigningKey, &standard)
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
}
// go-jose doesnt verify the expiry
@ -238,15 +235,11 @@ func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) {
standard := jwt.Claims{}
token, err := jwt.ParseSigned(tokenString)
token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256})
if err != nil {
log.Printf("Cannot get token %s", err)
return "", errors.New("cannot get token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return "", errors.New("signature validation failure")
}
err = token.Claims(QuerySigningKey, &standard)
if err = token.Claims(QuerySigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
@ -287,7 +280,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
return "", err
}
token, err := jwt.Signed(sig).Claims(claims).CompactSerialize()
token, err := jwt.Signed(sig).Claims(claims).Serialize()
return token, err
}
@ -299,12 +292,3 @@ func getTunnel(ctx context.Context) *protocol.Tunnel {
}
return s
}
func verifyAlg(headers []jose.Header, alg string) (bool, error) {
for _, header := range headers {
if header.Algorithm != alg {
return false, fmt.Errorf("invalid signing method %s", header.Algorithm)
}
}
return true, nil
}

View file

@ -0,0 +1,76 @@
package security
import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"testing"
)
func TestGenerateUserToken(t *testing.T) {
cases := []struct {
SigningKey []byte
EncryptionKey []byte
name string
username string
}{
{
SigningKey: []byte("5aa3a1568fe8421cd7e127d5ace28d2d"),
EncryptionKey: []byte("d3ecd7e565e56e37e2f2e95b584d8c0c"),
name: "sign_and_encrypt",
username: "test_sign_and_encrypt",
},
{
SigningKey: nil,
EncryptionKey: []byte("d3ecd7e565e56e37e2f2e95b584d8c0c"),
name: "encrypt_only",
username: "test_encrypt_only",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
SigningKey = tc.SigningKey
UserEncryptionKey = tc.EncryptionKey
token, err := GenerateUserToken(context.Background(), tc.username)
if err != nil {
t.Fatalf("GenerateUserToken failed: %s", err)
}
claims, err := UserInfo(context.Background(), token)
if err != nil {
t.Fatalf("UserInfo failed: %s", err)
}
if claims.Subject != tc.username {
t.Fatalf("Expected %s, got %s", tc.username, claims.Subject)
}
})
}
}
func TestPAACookie(t *testing.T) {
SigningKey = []byte("5aa3a1568fe8421cd7e127d5ace28d2d")
EncryptionKey = []byte("d3ecd7e565e56e37e2f2e95b584d8c0c")
username := "test_paa_cookie"
attr_client_ip := "127.0.0.1"
attr_access_token := "aabbcc"
id := identity.NewUser()
id.SetUserName(username)
id.SetAttribute(identity.AttrClientIp, attr_client_ip)
id.SetAttribute(identity.AttrAccessToken, attr_access_token)
ctx := context.Background()
ctx = context.WithValue(ctx, identity.CTXKey, id)
_, err := GeneratePAAToken(ctx, "test_paa_cookie", "host.does.not.exist")
if err != nil {
t.Fatalf("GeneratePAAToken failed: %s", err)
}
/*ok, err := CheckPAACookie(ctx, token)
if err != nil {
t.Fatalf("CheckPAACookie failed: %s", err)
}
if !ok {
t.Fatalf("CheckPAACookie failed")
}*/
}

5
go.mod
View file

@ -6,7 +6,8 @@ require (
github.com/bolkedebruin/gokrb5/v8 v8.5.0
github.com/coreos/go-oidc/v3 v3.9.0
github.com/fatih/structs v1.1.0
github.com/go-jose/go-jose/v3 v3.0.3
github.com/go-jose/go-jose/v4 v4.0.1
github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/gorilla/sessions v1.2.2
@ -34,7 +35,7 @@ require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect
github.com/go-jose/go-jose/v3 v3.0.1 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect