Add tokeninfo endpoint

This commit is contained in:
Bolke de Bruin 2020-08-19 12:12:52 +02:00
parent 188f077da1
commit 22d796c5cf
3 changed files with 113 additions and 0 deletions

40
api/token.go Normal file
View file

@ -0,0 +1,40 @@
package api
import (
"context"
"encoding/json"
"fmt"
"github.com/bolkedebruin/rdpgw/security"
"log"
"net/http"
)
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
return
}
tokens, ok := r.URL.Query()["access_token"]
if !ok || len(tokens[0]) < 1 {
log.Printf("Missing access_token in request")
http.Error(w, "access_token missing in request", http.StatusBadRequest)
return
}
token := tokens[0]
info, err := security.UserInfo(context.Background(), token)
if err != nil {
log.Printf("Token validation failed due to %s", err)
http.Error(w, fmt.Sprintf("token validation failed due to %s", err), http.StatusForbidden)
return
}
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
if err = json.NewEncoder(w).Encode(info); err != nil {
log.Printf("Cannot encode json due to %s", err)
http.Error(w, "cannot encode json", http.StatusInternalServerError)
return
}
}

View file

@ -130,6 +130,7 @@ func main() {
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", api.TokenInfo)
http.HandleFunc("/callback", api.HandleCallback) http.HandleFunc("/callback", api.HandleCallback)
err = server.ListenAndServeTLS("", "") err = server.ListenAndServeTLS("", "")

View file

@ -145,6 +145,69 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) {
return token, err return token, err
} }
func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
standard := jwt.Claims{}
if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 {
enc, err := jwt.ParseSignedAndEncrypted(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
token, err := enc.Decrypt(UserEncryptionKey)
if err != nil {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
} else if len(UserSigningKey) == 0 {
token, err := jwt.ParseEncrypted(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
err = token.Claims(UserEncryptionKey, &standard)
if err != nil {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
} else {
token, err := jwt.ParseSigned(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
err = token.Claims(UserSigningKey, &standard)
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
}
// go-jose doesnt verify the expiry
err := standard.Validate(jwt.Expected{
Issuer: "rdpgw",
Time: time.Now(),
})
if err != nil {
log.Printf("token validation failed due to %s", err)
return standard, fmt.Errorf("token validation failed due to %s", err)
}
return standard, nil
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo { func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo) s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok { if !ok {
@ -153,3 +216,12 @@ func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
} }
return s return s
} }
func verifyAlg(headers []jose.Header, alg string) (bool, error) {
for _, header := range headers {
if header.Algorithm != alg {
return false, fmt.Errorf("invalid signing method %s", header.Algorithm)
}
}
return true, nil
}