mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-18 06:23:49 +02:00
Add tokeninfo endpoint
This commit is contained in:
parent
188f077da1
commit
22d796c5cf
3 changed files with 113 additions and 0 deletions
40
api/token.go
Normal file
40
api/token.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/bolkedebruin/rdpgw/security"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, ok := r.URL.Query()["access_token"]
|
||||||
|
if !ok || len(tokens[0]) < 1 {
|
||||||
|
log.Printf("Missing access_token in request")
|
||||||
|
http.Error(w, "access_token missing in request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := tokens[0]
|
||||||
|
|
||||||
|
info, err := security.UserInfo(context.Background(), token)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Token validation failed due to %s", err)
|
||||||
|
http.Error(w, fmt.Sprintf("token validation failed due to %s", err), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||||
|
if err = json.NewEncoder(w).Encode(info); err != nil {
|
||||||
|
log.Printf("Cannot encode json due to %s", err)
|
||||||
|
http.Error(w, "cannot encode json", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
1
main.go
1
main.go
|
@ -130,6 +130,7 @@ func main() {
|
||||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||||
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
|
http.HandleFunc("/tokeninfo", api.TokenInfo)
|
||||||
http.HandleFunc("/callback", api.HandleCallback)
|
http.HandleFunc("/callback", api.HandleCallback)
|
||||||
|
|
||||||
err = server.ListenAndServeTLS("", "")
|
err = server.ListenAndServeTLS("", "")
|
||||||
|
|
|
@ -145,6 +145,69 @@ func GenerateUserToken(ctx context.Context, userName string) (string, error) {
|
||||||
return token, err
|
return token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
|
||||||
|
standard := jwt.Claims{}
|
||||||
|
if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 {
|
||||||
|
enc, err := jwt.ParseSignedAndEncrypted(token)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot get token %s", err)
|
||||||
|
return standard, errors.New("cannot get token")
|
||||||
|
}
|
||||||
|
token, err := enc.Decrypt(UserEncryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot decrypt token %s", err)
|
||||||
|
return standard, errors.New("cannot decrypt token")
|
||||||
|
}
|
||||||
|
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
|
||||||
|
log.Printf("signature validation failure: %s", err)
|
||||||
|
return standard, errors.New("signature validation failure")
|
||||||
|
}
|
||||||
|
if err = token.Claims(UserSigningKey, &standard); err != nil {
|
||||||
|
log.Printf("cannot verify signature %s", err)
|
||||||
|
return standard, errors.New("cannot verify signature")
|
||||||
|
}
|
||||||
|
} else if len(UserSigningKey) == 0 {
|
||||||
|
token, err := jwt.ParseEncrypted(token)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot get token %s", err)
|
||||||
|
return standard, errors.New("cannot get token")
|
||||||
|
}
|
||||||
|
err = token.Claims(UserEncryptionKey, &standard)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot decrypt token %s", err)
|
||||||
|
return standard, errors.New("cannot decrypt token")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
token, err := jwt.ParseSigned(token)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot get token %s", err)
|
||||||
|
return standard, errors.New("cannot get token")
|
||||||
|
}
|
||||||
|
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
|
||||||
|
log.Printf("signature validation failure: %s", err)
|
||||||
|
return standard, errors.New("signature validation failure")
|
||||||
|
}
|
||||||
|
err = token.Claims(UserSigningKey, &standard)
|
||||||
|
if err = token.Claims(UserSigningKey, &standard); err != nil {
|
||||||
|
log.Printf("cannot verify signature %s", err)
|
||||||
|
return standard, errors.New("cannot verify signature")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// go-jose doesnt verify the expiry
|
||||||
|
err := standard.Validate(jwt.Expected{
|
||||||
|
Issuer: "rdpgw",
|
||||||
|
Time: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("token validation failed due to %s", err)
|
||||||
|
return standard, fmt.Errorf("token validation failed due to %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return standard, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
||||||
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -153,3 +216,12 @@ func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func verifyAlg(headers []jose.Header, alg string) (bool, error) {
|
||||||
|
for _, header := range headers {
|
||||||
|
if header.Algorithm != alg {
|
||||||
|
return false, fmt.Errorf("invalid signing method %s", header.Algorithm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue