Check hostname specified by client against the token

This commit is contained in:
Bolke de Bruin 2020-07-25 19:37:33 +02:00
parent 39c73fc8fc
commit 5f3c7d07e2
5 changed files with 53 additions and 32 deletions

View file

@ -19,9 +19,10 @@ import (
const ( const (
RdpGwSession = "RDPGWSESSION" RdpGwSession = "RDPGWSESSION"
MaxAge = 120
) )
type TokenGeneratorFunc func(string, string) (string, error) type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type Config struct { type Config struct {
SessionKey []byte SessionKey []byte
@ -99,6 +100,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
return return
} }
session.Options.MaxAge = MaxAge
session.Values["preferred_username"] = data["preferred_username"] session.Values["preferred_username"] = data["preferred_username"]
session.Values["authenticated"] = true session.Values["authenticated"] = true
@ -157,7 +159,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
} }
} }
token, err := c.TokenGenerator(user, host) token, err := c.TokenGenerator(ctx, user, host)
if err != nil { if err != nil {
log.Printf("Cannot generate token for user %s due to %s", user, err) log.Printf("Cannot generate token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)

View file

@ -116,6 +116,7 @@ func main() {
EnableAll: conf.Caps.RedirectAll, EnableAll: conf.Caps.RedirectAll,
}, },
VerifyTunnelCreate: security.VerifyPAAToken, VerifyTunnelCreate: security.VerifyPAAToken,
VerifyServerFunc: security.VerifyServerFunc,
} }
gw := protocol.Gateway{ gw := protocol.Gateway{
HandlerConf: &handlerConfig, HandlerConf: &handlerConfig,

View file

@ -5,7 +5,6 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/transport"
"io" "io"
"log" "log"
"net" "net"
@ -13,9 +12,9 @@ import (
"time" "time"
) )
type VerifyTunnelCreate func(*SessionInfo, string) (bool, error) type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(*SessionInfo, string) (bool, error) type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(*SessionInfo, string) (bool, error) type VerifyServerFunc func(context.Context, string) (bool, error)
type RedirectFlags struct { type RedirectFlags struct {
Clipboard bool Clipboard bool
@ -29,8 +28,6 @@ type RedirectFlags struct {
type Handler struct { type Handler struct {
Session *SessionInfo Session *SessionInfo
TransportIn transport.Transport
TransportOut transport.Transport
VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc VerifyServerFunc VerifyServerFunc
@ -55,10 +52,8 @@ type HandlerConf struct {
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler { func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
h := &Handler{ h := &Handler{
State: SERVER_STATE_INITIAL,
Session: s, Session: s,
TransportIn: s.TransportIn,
TransportOut: s.TransportOut,
State: SERVER_STATE_INITIAL,
RedirectFlags: makeRedirectFlags(conf.RedirectFlags), RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
IdleTimeout: conf.IdleTimeout, IdleTimeout: conf.IdleTimeout,
SmartCardAuth: conf.SmartCardAuth, SmartCardAuth: conf.SmartCardAuth,
@ -89,7 +84,7 @@ func (h *Handler) Process(ctx context.Context) error {
} }
major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
msg := h.handshakeResponse(major, minor) msg := h.handshakeResponse(major, minor)
h.TransportOut.WritePacket(msg) h.Session.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_HANDSHAKE h.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE: case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create") log.Printf("Tunnel create")
@ -100,13 +95,13 @@ func (h *Handler) Process(ctx context.Context) error {
} }
_, cookie := readCreateTunnelRequest(pkt) _, cookie := readCreateTunnelRequest(pkt)
if h.VerifyTunnelCreate != nil { if h.VerifyTunnelCreate != nil {
if ok, _ := h.VerifyTunnelCreate(h.Session, cookie); !ok { if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received") log.Printf("Invalid PAA cookie received")
return errors.New("invalid PAA cookie") return errors.New("invalid PAA cookie")
} }
} }
msg := createTunnelResponse() msg := createTunnelResponse()
h.TransportOut.WritePacket(msg) h.Session.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_TUNNEL_CREATE h.State = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH: case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth") log.Printf("Tunnel auth")
@ -117,13 +112,13 @@ func (h *Handler) Process(ctx context.Context) error {
} }
client := h.readTunnelAuthRequest(pkt) client := h.readTunnelAuthRequest(pkt)
if h.VerifyTunnelAuthFunc != nil { if h.VerifyTunnelAuthFunc != nil {
if ok, _ := h.VerifyTunnelAuthFunc(h.Session, client); !ok { if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok {
log.Printf("Invalid client name: %s", client) log.Printf("Invalid client name: %s", client)
return errors.New("invalid client name") return errors.New("invalid client name")
} }
} }
msg := h.createTunnelAuthResponse() msg := h.createTunnelAuthResponse()
h.TransportOut.WritePacket(msg) h.Session.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_TUNNEL_AUTHORIZE h.State = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE: case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create") log.Printf("Channel create")
@ -135,8 +130,9 @@ func (h *Handler) Process(ctx context.Context) error {
server, port := readChannelCreateRequest(pkt) server, port := readChannelCreateRequest(pkt)
host := net.JoinHostPort(server, strconv.Itoa(int(port))) host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if h.VerifyServerFunc != nil { if h.VerifyServerFunc != nil {
if ok, _ := h.VerifyServerFunc(h.Session, host); !ok { if ok, _ := h.VerifyServerFunc(ctx, host); !ok {
log.Printf("Not allowed to connect to %s by policy handler", host) log.Printf("Not allowed to connect to %s by policy handler", host)
return errors.New("denied by security policy")
} }
} }
log.Printf("Establishing connection to RDP server: %s", host) log.Printf("Establishing connection to RDP server: %s", host)
@ -147,7 +143,7 @@ func (h *Handler) Process(ctx context.Context) error {
} }
log.Printf("Connection established") log.Printf("Connection established")
msg := createChannelCreateResponse() msg := createChannelCreateResponse()
h.TransportOut.WritePacket(msg) h.Session.TransportOut.WritePacket(msg)
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
@ -175,8 +171,8 @@ func (h *Handler) Process(ctx context.Context) error {
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED) log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
return errors.New("wrong state") return errors.New("wrong state")
} }
h.TransportIn.Close() h.Session.TransportIn.Close()
h.TransportOut.Close() h.Session.TransportOut.Close()
h.State = SERVER_STATE_CLOSED h.State = SERVER_STATE_CLOSED
default: default:
log.Printf("Unknown packet (size %d): %x", sz, pkt) log.Printf("Unknown packet (size %d): %x", sz, pkt)
@ -190,7 +186,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
size, pkt, err := h.TransportIn.ReadPacket() size, pkt, err := h.Session.TransportIn.ReadPacket()
if err != nil { if err != nil {
return 0, 0, []byte{0, 0}, err return 0, 0, []byte{0, 0}, err
} }
@ -398,7 +394,7 @@ func (h *Handler) sendDataPacket() {
break break
} }
b1.Write(buf[:n]) b1.Write(buf[:n])
h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes())) h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset() b1.Reset()
} }
} }

View file

@ -46,13 +46,11 @@ type Gateway struct {
type SessionInfo struct { type SessionInfo struct {
ConnId string ConnId string
CorrelationId string
ClientGeneration string
TransportIn transport.Transport TransportIn transport.Transport
TransportOut transport.Transport TransportOut transport.Transport
RemoteAddress string RemoteAddress string
ProxyAddresses string ProxyAddress string
UserName string RemoteServer string
} }
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
@ -65,9 +63,6 @@ func init() {
} }
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
connectionCache.Set(float64(c.ItemCount())) connectionCache.Set(float64(c.ItemCount()))
var s *SessionInfo var s *SessionInfo
@ -79,6 +74,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
} else { } else {
s = x.(*SessionInfo) s = x.(*SessionInfo)
} }
ctx := context.WithValue(r.Context(), "SessionInfo", s)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {

View file

@ -1,6 +1,7 @@
package security package security
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/protocol"
@ -17,7 +18,7 @@ type customClaims struct {
jwt.StandardClaims jwt.StandardClaims
} }
func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) { func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@ -30,7 +31,9 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
return false, err return false, err
} }
if _, ok := token.Claims.(*customClaims); ok && token.Valid { if c, ok := token.Claims.(*customClaims); ok && token.Valid {
s := getSessionInfo(ctx)
s.RemoteServer = c.RemoteServer
return true, nil return true, nil
} }
@ -38,7 +41,21 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
return false, err return false, err
} }
func GeneratePAAToken(username string, server string) (string, error) { func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
if s.RemoteServer != host {
log.Printf("Client host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
return true, nil
}
func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
if len(SigningKey) < 32 { if len(SigningKey) < 32 {
return "", errors.New("token signing key not long enough or not specified") return "", errors.New("token signing key not long enough or not specified")
} }
@ -67,4 +84,13 @@ func GeneratePAAToken(username string, server string) (string, error) {
} else { } else {
return ss, nil return ss, nil
} }
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok {
log.Printf("cannot get session info from context")
return nil
}
return s
} }