mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-19 15:03:48 +02:00
Check hostname specified by client against the token
This commit is contained in:
parent
39c73fc8fc
commit
5f3c7d07e2
5 changed files with 53 additions and 32 deletions
|
@ -19,9 +19,10 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RdpGwSession = "RDPGWSESSION"
|
RdpGwSession = "RDPGWSESSION"
|
||||||
|
MaxAge = 120
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenGeneratorFunc func(string, string) (string, error)
|
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
SessionKey []byte
|
SessionKey []byte
|
||||||
|
@ -99,6 +100,7 @@ func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
session.Options.MaxAge = MaxAge
|
||||||
session.Values["preferred_username"] = data["preferred_username"]
|
session.Values["preferred_username"] = data["preferred_username"]
|
||||||
session.Values["authenticated"] = true
|
session.Values["authenticated"] = true
|
||||||
|
|
||||||
|
@ -157,7 +159,7 @@ func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := c.TokenGenerator(user, host)
|
token, err := c.TokenGenerator(ctx, user, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot generate token for user %s due to %s", user, err)
|
log.Printf("Cannot generate token for user %s due to %s", user, err)
|
||||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||||
|
|
1
main.go
1
main.go
|
@ -116,6 +116,7 @@ func main() {
|
||||||
EnableAll: conf.Caps.RedirectAll,
|
EnableAll: conf.Caps.RedirectAll,
|
||||||
},
|
},
|
||||||
VerifyTunnelCreate: security.VerifyPAAToken,
|
VerifyTunnelCreate: security.VerifyPAAToken,
|
||||||
|
VerifyServerFunc: security.VerifyServerFunc,
|
||||||
}
|
}
|
||||||
gw := protocol.Gateway{
|
gw := protocol.Gateway{
|
||||||
HandlerConf: &handlerConfig,
|
HandlerConf: &handlerConfig,
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bolkedebruin/rdpgw/transport"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
@ -13,9 +12,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type VerifyTunnelCreate func(*SessionInfo, string) (bool, error)
|
type VerifyTunnelCreate func(context.Context, string) (bool, error)
|
||||||
type VerifyTunnelAuthFunc func(*SessionInfo, string) (bool, error)
|
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
||||||
type VerifyServerFunc func(*SessionInfo, string) (bool, error)
|
type VerifyServerFunc func(context.Context, string) (bool, error)
|
||||||
|
|
||||||
type RedirectFlags struct {
|
type RedirectFlags struct {
|
||||||
Clipboard bool
|
Clipboard bool
|
||||||
|
@ -29,8 +28,6 @@ type RedirectFlags struct {
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
Session *SessionInfo
|
Session *SessionInfo
|
||||||
TransportIn transport.Transport
|
|
||||||
TransportOut transport.Transport
|
|
||||||
VerifyTunnelCreate VerifyTunnelCreate
|
VerifyTunnelCreate VerifyTunnelCreate
|
||||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||||
VerifyServerFunc VerifyServerFunc
|
VerifyServerFunc VerifyServerFunc
|
||||||
|
@ -55,10 +52,8 @@ type HandlerConf struct {
|
||||||
|
|
||||||
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||||
h := &Handler{
|
h := &Handler{
|
||||||
Session: s,
|
|
||||||
TransportIn: s.TransportIn,
|
|
||||||
TransportOut: s.TransportOut,
|
|
||||||
State: SERVER_STATE_INITIAL,
|
State: SERVER_STATE_INITIAL,
|
||||||
|
Session: s,
|
||||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||||
IdleTimeout: conf.IdleTimeout,
|
IdleTimeout: conf.IdleTimeout,
|
||||||
SmartCardAuth: conf.SmartCardAuth,
|
SmartCardAuth: conf.SmartCardAuth,
|
||||||
|
@ -89,7 +84,7 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
|
major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
|
||||||
msg := h.handshakeResponse(major, minor)
|
msg := h.handshakeResponse(major, minor)
|
||||||
h.TransportOut.WritePacket(msg)
|
h.Session.TransportOut.WritePacket(msg)
|
||||||
h.State = SERVER_STATE_HANDSHAKE
|
h.State = SERVER_STATE_HANDSHAKE
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
log.Printf("Tunnel create")
|
log.Printf("Tunnel create")
|
||||||
|
@ -100,13 +95,13 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
_, cookie := readCreateTunnelRequest(pkt)
|
_, cookie := readCreateTunnelRequest(pkt)
|
||||||
if h.VerifyTunnelCreate != nil {
|
if h.VerifyTunnelCreate != nil {
|
||||||
if ok, _ := h.VerifyTunnelCreate(h.Session, cookie); !ok {
|
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received")
|
log.Printf("Invalid PAA cookie received")
|
||||||
return errors.New("invalid PAA cookie")
|
return errors.New("invalid PAA cookie")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg := createTunnelResponse()
|
msg := createTunnelResponse()
|
||||||
h.TransportOut.WritePacket(msg)
|
h.Session.TransportOut.WritePacket(msg)
|
||||||
h.State = SERVER_STATE_TUNNEL_CREATE
|
h.State = SERVER_STATE_TUNNEL_CREATE
|
||||||
case PKT_TYPE_TUNNEL_AUTH:
|
case PKT_TYPE_TUNNEL_AUTH:
|
||||||
log.Printf("Tunnel auth")
|
log.Printf("Tunnel auth")
|
||||||
|
@ -117,13 +112,13 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
client := h.readTunnelAuthRequest(pkt)
|
client := h.readTunnelAuthRequest(pkt)
|
||||||
if h.VerifyTunnelAuthFunc != nil {
|
if h.VerifyTunnelAuthFunc != nil {
|
||||||
if ok, _ := h.VerifyTunnelAuthFunc(h.Session, client); !ok {
|
if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok {
|
||||||
log.Printf("Invalid client name: %s", client)
|
log.Printf("Invalid client name: %s", client)
|
||||||
return errors.New("invalid client name")
|
return errors.New("invalid client name")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg := h.createTunnelAuthResponse()
|
msg := h.createTunnelAuthResponse()
|
||||||
h.TransportOut.WritePacket(msg)
|
h.Session.TransportOut.WritePacket(msg)
|
||||||
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||||
case PKT_TYPE_CHANNEL_CREATE:
|
case PKT_TYPE_CHANNEL_CREATE:
|
||||||
log.Printf("Channel create")
|
log.Printf("Channel create")
|
||||||
|
@ -135,8 +130,9 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
server, port := readChannelCreateRequest(pkt)
|
server, port := readChannelCreateRequest(pkt)
|
||||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
if h.VerifyServerFunc != nil {
|
if h.VerifyServerFunc != nil {
|
||||||
if ok, _ := h.VerifyServerFunc(h.Session, host); !ok {
|
if ok, _ := h.VerifyServerFunc(ctx, host); !ok {
|
||||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||||
|
return errors.New("denied by security policy")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("Establishing connection to RDP server: %s", host)
|
log.Printf("Establishing connection to RDP server: %s", host)
|
||||||
|
@ -147,7 +143,7 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
log.Printf("Connection established")
|
log.Printf("Connection established")
|
||||||
msg := createChannelCreateResponse()
|
msg := createChannelCreateResponse()
|
||||||
h.TransportOut.WritePacket(msg)
|
h.Session.TransportOut.WritePacket(msg)
|
||||||
|
|
||||||
// Make sure to start the flow from the RDP server first otherwise connections
|
// Make sure to start the flow from the RDP server first otherwise connections
|
||||||
// might hang eventually
|
// might hang eventually
|
||||||
|
@ -175,8 +171,8 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
|
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
|
||||||
return errors.New("wrong state")
|
return errors.New("wrong state")
|
||||||
}
|
}
|
||||||
h.TransportIn.Close()
|
h.Session.TransportIn.Close()
|
||||||
h.TransportOut.Close()
|
h.Session.TransportOut.Close()
|
||||||
h.State = SERVER_STATE_CLOSED
|
h.State = SERVER_STATE_CLOSED
|
||||||
default:
|
default:
|
||||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||||
|
@ -190,7 +186,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
size, pkt, err := h.TransportIn.ReadPacket()
|
size, pkt, err := h.Session.TransportIn.ReadPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, []byte{0, 0}, err
|
return 0, 0, []byte{0, 0}, err
|
||||||
}
|
}
|
||||||
|
@ -398,7 +394,7 @@ func (h *Handler) sendDataPacket() {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
b1.Write(buf[:n])
|
b1.Write(buf[:n])
|
||||||
h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||||
b1.Reset()
|
b1.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,13 +46,11 @@ type Gateway struct {
|
||||||
|
|
||||||
type SessionInfo struct {
|
type SessionInfo struct {
|
||||||
ConnId string
|
ConnId string
|
||||||
CorrelationId string
|
|
||||||
ClientGeneration string
|
|
||||||
TransportIn transport.Transport
|
TransportIn transport.Transport
|
||||||
TransportOut transport.Transport
|
TransportOut transport.Transport
|
||||||
RemoteAddress string
|
RemoteAddress string
|
||||||
ProxyAddresses string
|
ProxyAddress string
|
||||||
UserName string
|
RemoteServer string
|
||||||
}
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{}
|
var upgrader = websocket.Upgrader{}
|
||||||
|
@ -65,9 +63,6 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
connectionCache.Set(float64(c.ItemCount()))
|
connectionCache.Set(float64(c.ItemCount()))
|
||||||
|
|
||||||
var s *SessionInfo
|
var s *SessionInfo
|
||||||
|
@ -79,6 +74,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||||
} else {
|
} else {
|
||||||
s = x.(*SessionInfo)
|
s = x.(*SessionInfo)
|
||||||
}
|
}
|
||||||
|
ctx := context.WithValue(r.Context(), "SessionInfo", s)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package security
|
package security
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bolkedebruin/rdpgw/protocol"
|
"github.com/bolkedebruin/rdpgw/protocol"
|
||||||
|
@ -17,7 +18,7 @@ type customClaims struct {
|
||||||
jwt.StandardClaims
|
jwt.StandardClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
|
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &customClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
@ -30,7 +31,9 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := token.Claims.(*customClaims); ok && token.Valid {
|
if c, ok := token.Claims.(*customClaims); ok && token.Valid {
|
||||||
|
s := getSessionInfo(ctx)
|
||||||
|
s.RemoteServer = c.RemoteServer
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,7 +41,21 @@ func VerifyPAAToken(s *protocol.SessionInfo, tokenString string) (bool, error) {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeneratePAAToken(username string, server string) (string, error) {
|
func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
|
||||||
|
s := getSessionInfo(ctx)
|
||||||
|
if s == nil {
|
||||||
|
return false, errors.New("no valid session info found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.RemoteServer != host {
|
||||||
|
log.Printf("Client host %s does not match token host %s", host, s.RemoteServer)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
|
||||||
if len(SigningKey) < 32 {
|
if len(SigningKey) < 32 {
|
||||||
return "", errors.New("token signing key not long enough or not specified")
|
return "", errors.New("token signing key not long enough or not specified")
|
||||||
}
|
}
|
||||||
|
@ -68,3 +85,12 @@ func GeneratePAAToken(username string, server string) (string, error) {
|
||||||
return ss, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
||||||
|
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("cannot get session info from context")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue