mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 14:03:50 +02:00
Make sure to validate the client's ip address
This commit is contained in:
parent
5fc75ef877
commit
0b299619ff
6 changed files with 70 additions and 10 deletions
|
@ -25,7 +25,7 @@ Connect integration enabled by default. Cookies are encrypted and signed on the
|
||||||
on [Gorilla Sessions](https://www.gorillatoolkit.org/pkg/sessions). PAA tokens (gateway access tokens)
|
on [Gorilla Sessions](https://www.gorillatoolkit.org/pkg/sessions). PAA tokens (gateway access tokens)
|
||||||
are generated and signed according to the JWT spec by using [jwt-go](https://github.com/dgrijalva/jwt-go)
|
are generated and signed according to the JWT spec by using [jwt-go](https://github.com/dgrijalva/jwt-go)
|
||||||
signed with a 512 bit HMAC. Hosts provided by the user are verified against what was provided by
|
signed with a 512 bit HMAC. Hosts provided by the user are verified against what was provided by
|
||||||
the server.
|
the server. Finally, the client's ip address needs to match the one it obtained the token with.
|
||||||
|
|
||||||
## How to build
|
## How to build
|
||||||
```bash
|
```bash
|
||||||
|
|
49
client/remote.go
Normal file
49
client/remote.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ClientIPCtx = "ClientIP"
|
||||||
|
ProxyAddressesCtx = "ProxyAddresses"
|
||||||
|
RemoteAddressCtx = "RemoteAddress"
|
||||||
|
)
|
||||||
|
|
||||||
|
func EnrichContext(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
h := r.Header.Get("X-Forwarded-For")
|
||||||
|
if h != "" {
|
||||||
|
var proxies []string
|
||||||
|
ips := strings.Split(h, ",")
|
||||||
|
for i := range ips {
|
||||||
|
ips[i] = strings.TrimSpace(ips[i])
|
||||||
|
}
|
||||||
|
clientIp := ips[0]
|
||||||
|
if len(ips) > 1 {
|
||||||
|
proxies = ips[1:]
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
|
||||||
|
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
|
||||||
|
}
|
||||||
|
|
||||||
|
remote := r.Header.Get("REMOTE_ADDR")
|
||||||
|
ctx = context.WithValue(ctx, RemoteAddressCtx, remote)
|
||||||
|
if h == "" {
|
||||||
|
ctx = context.WithValue(ctx, ClientIPCtx, remote)
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientIp(ctx context.Context) string {
|
||||||
|
s, ok := ctx.Value(ClientIPCtx).(string)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
5
main.go
5
main.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"github.com/bolkedebruin/rdpgw/api"
|
"github.com/bolkedebruin/rdpgw/api"
|
||||||
|
"github.com/bolkedebruin/rdpgw/client"
|
||||||
"github.com/bolkedebruin/rdpgw/config"
|
"github.com/bolkedebruin/rdpgw/config"
|
||||||
"github.com/bolkedebruin/rdpgw/protocol"
|
"github.com/bolkedebruin/rdpgw/protocol"
|
||||||
"github.com/bolkedebruin/rdpgw/security"
|
"github.com/bolkedebruin/rdpgw/security"
|
||||||
|
@ -122,8 +123,8 @@ func main() {
|
||||||
HandlerConf: &handlerConfig,
|
HandlerConf: &handlerConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
http.HandleFunc("/remoteDesktopGateway/", gw.HandleGatewayProtocol)
|
http.Handle("/remoteDesktopGateway/", client.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||||
http.Handle("/connect", api.Authenticated(http.HandlerFunc(api.HandleDownload)))
|
http.Handle("/connect", client.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
http.HandleFunc("/callback", api.HandleCallback)
|
http.HandleFunc("/callback", api.HandleCallback)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/bolkedebruin/rdpgw/client"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
@ -96,7 +97,7 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||||
_, cookie := readCreateTunnelRequest(pkt)
|
_, cookie := readCreateTunnelRequest(pkt)
|
||||||
if h.VerifyTunnelCreate != nil {
|
if h.VerifyTunnelCreate != nil {
|
||||||
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
|
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received")
|
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
|
||||||
return errors.New("invalid PAA cookie")
|
return errors.New("invalid PAA cookie")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/bolkedebruin/rdpgw/client"
|
||||||
"github.com/bolkedebruin/rdpgw/transport"
|
"github.com/bolkedebruin/rdpgw/transport"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
|
@ -48,9 +49,8 @@ type SessionInfo struct {
|
||||||
ConnId string
|
ConnId string
|
||||||
TransportIn transport.Transport
|
TransportIn transport.Transport
|
||||||
TransportOut transport.Transport
|
TransportOut transport.Transport
|
||||||
RemoteAddress string
|
|
||||||
ProxyAddress string
|
|
||||||
RemoteServer string
|
RemoteServer string
|
||||||
|
ClientIp string
|
||||||
}
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{}
|
var upgrader = websocket.Upgrader{}
|
||||||
|
@ -118,7 +118,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
||||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String())
|
log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context()))
|
||||||
|
|
||||||
s.TransportOut = out
|
s.TransportOut = out
|
||||||
out.SendAccept(true)
|
out.SendAccept(true)
|
||||||
|
@ -139,13 +139,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
||||||
s.TransportIn = in
|
s.TransportIn = in
|
||||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||||
|
|
||||||
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
|
log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context()))
|
||||||
in.SendAccept(false)
|
in.SendAccept(false)
|
||||||
|
|
||||||
// read some initial data
|
// read some initial data
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
log.Printf("Legacy handshake done for client %s", client.GetClientIp(r.Context()))
|
||||||
handler := NewHandler(s, g.HandlerConf)
|
handler := NewHandler(s, g.HandlerConf)
|
||||||
handler.Process(r.Context())
|
handler.Process(r.Context())
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/bolkedebruin/rdpgw/client"
|
||||||
"github.com/bolkedebruin/rdpgw/protocol"
|
"github.com/bolkedebruin/rdpgw/protocol"
|
||||||
"github.com/dgrijalva/jwt-go/v4"
|
"github.com/dgrijalva/jwt-go/v4"
|
||||||
"log"
|
"log"
|
||||||
|
@ -15,6 +16,7 @@ var ExpiryTime time.Duration = 5
|
||||||
|
|
||||||
type customClaims struct {
|
type customClaims struct {
|
||||||
RemoteServer string `json:"remoteServer"`
|
RemoteServer string `json:"remoteServer"`
|
||||||
|
ClientIP string `json:"clientIp"`
|
||||||
jwt.StandardClaims
|
jwt.StandardClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,6 +36,7 @@ func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
|
||||||
if c, ok := token.Claims.(*customClaims); ok && token.Valid {
|
if c, ok := token.Claims.(*customClaims); ok && token.Valid {
|
||||||
s := getSessionInfo(ctx)
|
s := getSessionInfo(ctx)
|
||||||
s.RemoteServer = c.RemoteServer
|
s.RemoteServer = c.RemoteServer
|
||||||
|
s.ClientIp = client.GetClientIp(ctx)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +51,13 @@ func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.RemoteServer != host {
|
if s.RemoteServer != host {
|
||||||
log.Printf("Client host %s does not match token host %s", host, s.RemoteServer)
|
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.ClientIp != client.GetClientIp(ctx) {
|
||||||
|
log.Printf("Current client ip address %s does not match token client ip %s",
|
||||||
|
client.GetClientIp(ctx), s.ClientIp)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue