Make sure to validate the client's ip address

This commit is contained in:
Bolke de Bruin 2020-07-25 21:00:58 +02:00
parent 5fc75ef877
commit 0b299619ff
6 changed files with 70 additions and 10 deletions

View file

@ -25,7 +25,7 @@ Connect integration enabled by default. Cookies are encrypted and signed on the
on [Gorilla Sessions](https://www.gorillatoolkit.org/pkg/sessions). PAA tokens (gateway access tokens)
are generated and signed according to the JWT spec by using [jwt-go](https://github.com/dgrijalva/jwt-go)
signed with a 512 bit HMAC. Hosts provided by the user are verified against what was provided by
the server.
the server. Finally, the client's ip address needs to match the one it obtained the token with.
## How to build
```bash

49
client/remote.go Normal file
View file

@ -0,0 +1,49 @@
package client
import (
"context"
"net/http"
"strings"
)
const (
ClientIPCtx = "ClientIP"
ProxyAddressesCtx = "ProxyAddresses"
RemoteAddressCtx = "RemoteAddress"
)
func EnrichContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h := r.Header.Get("X-Forwarded-For")
if h != "" {
var proxies []string
ips := strings.Split(h, ",")
for i := range ips {
ips[i] = strings.TrimSpace(ips[i])
}
clientIp := ips[0]
if len(ips) > 1 {
proxies = ips[1:]
}
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
}
remote := r.Header.Get("REMOTE_ADDR")
ctx = context.WithValue(ctx, RemoteAddressCtx, remote)
if h == "" {
ctx = context.WithValue(ctx, ClientIPCtx, remote)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func GetClientIp(ctx context.Context) string {
s, ok := ctx.Value(ClientIPCtx).(string)
if !ok {
return ""
}
return s
}

View file

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"github.com/bolkedebruin/rdpgw/api"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/config"
"github.com/bolkedebruin/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/security"
@ -122,8 +123,8 @@ func main() {
HandlerConf: &handlerConfig,
}
http.HandleFunc("/remoteDesktopGateway/", gw.HandleGatewayProtocol)
http.Handle("/connect", api.Authenticated(http.HandlerFunc(api.HandleDownload)))
http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/callback", api.HandleCallback)

View file

@ -5,6 +5,7 @@ import (
"context"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/client"
"io"
"log"
"net"
@ -96,7 +97,7 @@ func (h *Handler) Process(ctx context.Context) error {
_, cookie := readCreateTunnelRequest(pkt)
if h.VerifyTunnelCreate != nil {
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received")
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
return errors.New("invalid PAA cookie")
}
}

View file

@ -2,6 +2,7 @@ package protocol
import (
"context"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
@ -48,9 +49,8 @@ type SessionInfo struct {
ConnId string
TransportIn transport.Transport
TransportOut transport.Transport
RemoteAddress string
ProxyAddress string
RemoteServer string
ClientIp string
}
var upgrader = websocket.Upgrader{}
@ -118,7 +118,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String())
log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context()))
s.TransportOut = out
out.SendAccept(true)
@ -139,13 +139,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
s.TransportIn = in
c.Set(s.ConnId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context()))
in.SendAccept(false)
// read some initial data
in.Drain()
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
log.Printf("Legacy handshake done for client %s", client.GetClientIp(r.Context()))
handler := NewHandler(s, g.HandlerConf)
handler.Process(r.Context())
}

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/protocol"
"github.com/dgrijalva/jwt-go/v4"
"log"
@ -15,6 +16,7 @@ var ExpiryTime time.Duration = 5
type customClaims struct {
RemoteServer string `json:"remoteServer"`
ClientIP string `json:"clientIp"`
jwt.StandardClaims
}
@ -34,6 +36,7 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
if c, ok := token.Claims.(*customClaims); ok && token.Valid {
s := getSessionInfo(ctx)
s.RemoteServer = c.RemoteServer
s.ClientIp = client.GetClientIp(ctx)
return true, nil
}
@ -48,7 +51,13 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
}
if s.RemoteServer != host {
log.Printf("Client host %s does not match token host %s", host, s.RemoteServer)
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
if s.ClientIp != client.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
client.GetClientIp(ctx), s.ClientIp)
return false, nil
}