Make sure to validate the client's ip address

This commit is contained in:
Bolke de Bruin 2020-07-25 21:00:58 +02:00
parent 5fc75ef877
commit 0b299619ff
6 changed files with 70 additions and 10 deletions

View file

@ -25,7 +25,7 @@ Connect integration enabled by default. Cookies are encrypted and signed on the
on [Gorilla Sessions](https://www.gorillatoolkit.org/pkg/sessions). PAA tokens (gateway access tokens) on [Gorilla Sessions](https://www.gorillatoolkit.org/pkg/sessions). PAA tokens (gateway access tokens)
are generated and signed according to the JWT spec by using [jwt-go](https://github.com/dgrijalva/jwt-go) are generated and signed according to the JWT spec by using [jwt-go](https://github.com/dgrijalva/jwt-go)
signed with a 512 bit HMAC. Hosts provided by the user are verified against what was provided by signed with a 512 bit HMAC. Hosts provided by the user are verified against what was provided by
the server. the server. Finally, the client's ip address needs to match the one it obtained the token with.
## How to build ## How to build
```bash ```bash

49
client/remote.go Normal file
View file

@ -0,0 +1,49 @@
package client
import (
"context"
"net/http"
"strings"
)
const (
ClientIPCtx = "ClientIP"
ProxyAddressesCtx = "ProxyAddresses"
RemoteAddressCtx = "RemoteAddress"
)
func EnrichContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h := r.Header.Get("X-Forwarded-For")
if h != "" {
var proxies []string
ips := strings.Split(h, ",")
for i := range ips {
ips[i] = strings.TrimSpace(ips[i])
}
clientIp := ips[0]
if len(ips) > 1 {
proxies = ips[1:]
}
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
}
remote := r.Header.Get("REMOTE_ADDR")
ctx = context.WithValue(ctx, RemoteAddressCtx, remote)
if h == "" {
ctx = context.WithValue(ctx, ClientIPCtx, remote)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func GetClientIp(ctx context.Context) string {
s, ok := ctx.Value(ClientIPCtx).(string)
if !ok {
return ""
}
return s
}

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"github.com/bolkedebruin/rdpgw/api" "github.com/bolkedebruin/rdpgw/api"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/config" "github.com/bolkedebruin/rdpgw/config"
"github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/security" "github.com/bolkedebruin/rdpgw/security"
@ -122,8 +123,8 @@ func main() {
HandlerConf: &handlerConfig, HandlerConf: &handlerConfig,
} }
http.HandleFunc("/remoteDesktopGateway/", gw.HandleGatewayProtocol) http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", api.Authenticated(http.HandlerFunc(api.HandleDownload))) http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/callback", api.HandleCallback) http.HandleFunc("/callback", api.HandleCallback)

View file

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/client"
"io" "io"
"log" "log"
"net" "net"
@ -96,7 +97,7 @@ func (h *Handler) Process(ctx context.Context) error {
_, cookie := readCreateTunnelRequest(pkt) _, cookie := readCreateTunnelRequest(pkt)
if h.VerifyTunnelCreate != nil { if h.VerifyTunnelCreate != nil {
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok { if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received") log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
return errors.New("invalid PAA cookie") return errors.New("invalid PAA cookie")
} }
} }

View file

@ -2,6 +2,7 @@ package protocol
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/transport" "github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
@ -48,9 +49,8 @@ type SessionInfo struct {
ConnId string ConnId string
TransportIn transport.Transport TransportIn transport.Transport
TransportOut transport.Transport TransportOut transport.Transport
RemoteAddress string
ProxyAddress string
RemoteServer string RemoteServer string
ClientIp string
} }
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
@ -118,7 +118,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return return
} }
log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String()) log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context()))
s.TransportOut = out s.TransportOut = out
out.SendAccept(true) out.SendAccept(true)
@ -139,13 +139,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
s.TransportIn = in s.TransportIn = in
c.Set(s.ConnId, s, cache.DefaultExpiration) c.Set(s.ConnId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String()) log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context()))
in.SendAccept(false) in.SendAccept(false)
// read some initial data // read some initial data
in.Drain() in.Drain()
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String()) log.Printf("Legacy handshake done for client %s", client.GetClientIp(r.Context()))
handler := NewHandler(s, g.HandlerConf) handler := NewHandler(s, g.HandlerConf)
handler.Process(r.Context()) handler.Process(r.Context())
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/protocol" "github.com/bolkedebruin/rdpgw/protocol"
"github.com/dgrijalva/jwt-go/v4" "github.com/dgrijalva/jwt-go/v4"
"log" "log"
@ -15,6 +16,7 @@ var ExpiryTime time.Duration = 5
type customClaims struct { type customClaims struct {
RemoteServer string `json:"remoteServer"` RemoteServer string `json:"remoteServer"`
ClientIP string `json:"clientIp"`
jwt.StandardClaims jwt.StandardClaims
} }
@ -34,6 +36,7 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
if c, ok := token.Claims.(*customClaims); ok && token.Valid { if c, ok := token.Claims.(*customClaims); ok && token.Valid {
s := getSessionInfo(ctx) s := getSessionInfo(ctx)
s.RemoteServer = c.RemoteServer s.RemoteServer = c.RemoteServer
s.ClientIp = client.GetClientIp(ctx)
return true, nil return true, nil
} }
@ -48,7 +51,13 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
} }
if s.RemoteServer != host { if s.RemoteServer != host {
log.Printf("Client host %s does not match token host %s", host, s.RemoteServer) log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
if s.ClientIp != client.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
client.GetClientIp(ctx), s.ClientIp)
return false, nil return false, nil
} }