mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-19 15:03:48 +02:00
Rework tunnels to support statistics
This commit is contained in:
parent
eb1b287751
commit
94d7cddc4b
14 changed files with 74 additions and 45 deletions
|
@ -12,6 +12,8 @@ const (
|
||||||
ClientIPCtx = "ClientIP"
|
ClientIPCtx = "ClientIP"
|
||||||
ProxyAddressesCtx = "ProxyAddresses"
|
ProxyAddressesCtx = "ProxyAddresses"
|
||||||
RemoteAddressCtx = "RemoteAddress"
|
RemoteAddressCtx = "RemoteAddress"
|
||||||
|
TunnelCtx = "TUNNEL"
|
||||||
|
UsernameCtx = "preferred_username"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EnrichContext(next http.Handler) http.Handler {
|
func EnrichContext(next http.Handler) http.Handler {
|
||||||
|
|
|
@ -231,6 +231,8 @@ var gwserver *protocol.Gateway
|
||||||
func List(w http.ResponseWriter, r *http.Request) {
|
func List(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
for k, v := range protocol.Connections {
|
for k, v := range protocol.Connections {
|
||||||
fmt.Fprintf(w, "RDGId: %s Connected Since: %s User: %s \n", k, v.Since, v.SessionInfo.UserName)
|
fmt.Fprintf(w, "Id: %s Rdg-Id: %s User: %s From: %s Connected Since: %s Bytes Sent: %d Bytes Received: %d Last Seen: %s Target: %s\n",
|
||||||
|
k, v.Tunnel.RDGId, v.Tunnel.UserName, v.Tunnel.RemoteAddr, v.Tunnel.ConnectedOn, v.Tunnel.BytesSent, v.Tunnel.BytesReceived,
|
||||||
|
v.Tunnel.LastSeen, v.Tunnel.TargetServer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
||||||
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
|
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
|
||||||
}
|
}
|
||||||
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
||||||
go forward(c.LocalConn, c.Session.TransportOut)
|
//go forward(c.LocalConn, c.Session.TransportOut)
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
receive(pkt, c.LocalConn)
|
receive(pkt, c.LocalConn)
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -92,7 +92,7 @@ func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
|
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
|
||||||
func forward(in net.Conn, out transport.Transport) {
|
func forward(in net.Conn, tunnel *Tunnel) {
|
||||||
defer in.Close()
|
defer in.Close()
|
||||||
|
|
||||||
b1 := new(bytes.Buffer)
|
b1 := new(bytes.Buffer)
|
||||||
|
@ -106,7 +106,7 @@ func forward(in net.Conn, out transport.Transport) {
|
||||||
}
|
}
|
||||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||||
b1.Write(buf[:n])
|
b1.Write(buf[:n])
|
||||||
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
tunnel.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||||
b1.Reset()
|
b1.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"log"
|
"log"
|
||||||
|
@ -16,7 +17,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
rdgConnectionIdKey = "Rdg-Connection-RDGId"
|
rdgConnectionIdKey = "Rdg-Connection-Id"
|
||||||
MethodRDGIN = "RDG_IN_DATA"
|
MethodRDGIN = "RDG_IN_DATA"
|
||||||
MethodRDGOUT = "RDG_OUT_DATA"
|
MethodRDGOUT = "RDG_OUT_DATA"
|
||||||
)
|
)
|
||||||
|
@ -59,14 +60,19 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||||
|
|
||||||
var t *Tunnel
|
var t *Tunnel
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), common.TunnelCtx, t)
|
||||||
|
|
||||||
connId := r.Header.Get(rdgConnectionIdKey)
|
connId := r.Header.Get(rdgConnectionIdKey)
|
||||||
x, found := c.Get(connId)
|
x, found := c.Get(connId)
|
||||||
if !found {
|
if !found {
|
||||||
t = &Tunnel{RDGId: connId}
|
t = &Tunnel{
|
||||||
|
RDGId: connId,
|
||||||
|
RemoteAddr: ctx.Value(common.ClientIPCtx).(string),
|
||||||
|
UserName: ctx.Value(common.UsernameCtx).(string),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
t = x.(*Tunnel)
|
t = x.(*Tunnel)
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(r.Context(), "Tunnel", t)
|
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||||
|
@ -158,13 +164,14 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
||||||
inout, _ := transport.NewWS(c)
|
inout, _ := transport.NewWS(c)
|
||||||
defer inout.Close()
|
defer inout.Close()
|
||||||
|
|
||||||
|
t.Id = uuid.New().String()
|
||||||
t.TransportOut = inout
|
t.TransportOut = inout
|
||||||
t.TransportIn = inout
|
t.TransportIn = inout
|
||||||
t.ConnectedOn = time.Now()
|
t.ConnectedOn = time.Now()
|
||||||
|
|
||||||
handler := NewProcessor(g, t)
|
handler := NewProcessor(g, t)
|
||||||
RegisterConnection(handler, t)
|
RegisterTunnel(t, handler)
|
||||||
defer RemoveConnection(t.RDGId)
|
defer RemoveTunnel(t)
|
||||||
handler.Process(ctx)
|
handler.Process(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,6 +205,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
||||||
defer in.Close()
|
defer in.Close()
|
||||||
|
|
||||||
if t.TransportIn == nil {
|
if t.TransportIn == nil {
|
||||||
|
t.Id = uuid.New().String()
|
||||||
t.TransportIn = in
|
t.TransportIn = in
|
||||||
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||||
|
|
||||||
|
@ -209,8 +217,8 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
||||||
|
|
||||||
log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
|
log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
|
||||||
handler := NewProcessor(g, t)
|
handler := NewProcessor(g, t)
|
||||||
RegisterConnection(handler, t)
|
RegisterTunnel(t, handler)
|
||||||
defer RemoveConnection(t.RDGId)
|
defer RemoveTunnel(t)
|
||||||
handler.Process(r.Context())
|
handler.Process(r.Context())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
|
|
||||||
switch pt {
|
switch pt {
|
||||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||||
log.Printf("Client handshakeRequest from %p", common.GetClientIp(ctx))
|
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
|
||||||
if p.state != SERVER_STATE_INITIALIZED {
|
if p.state != SERVER_STATE_INITIALIZED {
|
||||||
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
log.Printf("Handshake attempted while in wrong state %d != %d", p.state, SERVER_STATE_INITIALIZED)
|
||||||
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
msg := p.handshakeResponse(0x0, 0x0, 0, E_PROXY_INTERNALERROR)
|
||||||
|
@ -77,7 +77,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
_, cookie := p.tunnelRequest(pkt)
|
_, cookie := p.tunnelRequest(pkt)
|
||||||
if p.gw.CheckPAACookie != nil {
|
if p.gw.CheckPAACookie != nil {
|
||||||
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
if ok, _ := p.gw.CheckPAACookie(ctx, cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie received from client %p", common.GetClientIp(ctx))
|
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
|
||||||
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
msg := p.tunnelResponse(E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
return fmt.Errorf("%x: invalid PAA cookie", E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED)
|
||||||
|
@ -98,7 +98,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
client := p.tunnelAuthRequest(pkt)
|
client := p.tunnelAuthRequest(pkt)
|
||||||
if p.gw.CheckClientName != nil {
|
if p.gw.CheckClientName != nil {
|
||||||
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
|
if ok, _ := p.gw.CheckClientName(ctx, client); !ok {
|
||||||
log.Printf("Invalid client name: %p", client)
|
log.Printf("Invalid client name: %s", client)
|
||||||
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
|
msg := p.tunnelAuthResponse(ERROR_ACCESS_DENIED)
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
|
return fmt.Errorf("%x: Tunnel auth rejected, invalid client name", ERROR_ACCESS_DENIED)
|
||||||
|
@ -119,18 +119,18 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
server, port := p.channelRequest(pkt)
|
server, port := p.channelRequest(pkt)
|
||||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
if p.gw.CheckHost != nil {
|
if p.gw.CheckHost != nil {
|
||||||
log.Printf("Verifying %p host connection", host)
|
log.Printf("Verifying %s host connection", host)
|
||||||
if ok, _ := p.gw.CheckHost(ctx, host); !ok {
|
if ok, _ := p.gw.CheckHost(ctx, host); !ok {
|
||||||
log.Printf("Not allowed to connect to %p by policy handler", host)
|
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||||
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
|
msg := p.channelResponse(E_PROXY_RAP_ACCESSDENIED)
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
|
return fmt.Errorf("%x: denied by security policy", E_PROXY_RAP_ACCESSDENIED)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("Establishing connection to RDP server: %p", host)
|
log.Printf("Establishing connection to RDP server: %s", host)
|
||||||
p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15)
|
p.tunnel.rwc, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error connecting to %p, %p", host, err)
|
log.Printf("Error connecting to %s, %s", host, err)
|
||||||
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
msg := p.channelResponse(E_PROXY_INTERNALERROR)
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
return err
|
return err
|
||||||
|
@ -142,7 +142,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
||||||
|
|
||||||
// Make sure to start the flow from the RDP server first otherwise connections
|
// Make sure to start the flow from the RDP server first otherwise connections
|
||||||
// might hang eventually
|
// might hang eventually
|
||||||
go forward(p.tunnel.rwc, p.tunnel.TransportOut)
|
go forward(p.tunnel.rwc, p.tunnel)
|
||||||
p.state = SERVER_STATE_CHANNEL_CREATE
|
p.state = SERVER_STATE_CHANNEL_CREATE
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
if p.state < SERVER_STATE_CHANNEL_CREATE {
|
if p.state < SERVER_STATE_CHANNEL_CREATE {
|
||||||
|
|
|
@ -7,19 +7,19 @@ type Monitor struct {
|
||||||
Tunnel *Tunnel
|
Tunnel *Tunnel
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterConnection(h *Processor, t *Tunnel) {
|
func RegisterTunnel(t *Tunnel, p *Processor) {
|
||||||
if Connections == nil {
|
if Connections == nil {
|
||||||
Connections = make(map[string]*Monitor)
|
Connections = make(map[string]*Monitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
Connections[t.RDGId] = &Monitor{
|
Connections[t.Id] = &Monitor{
|
||||||
Processor: h,
|
Processor: p,
|
||||||
Tunnel: t,
|
Tunnel: t,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RemoveConnection(connId string) {
|
func RemoveTunnel(t *Tunnel) {
|
||||||
delete(Connections, connId)
|
delete(Connections, t.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateSpeedPerSecond calculate moving average.
|
// CalculateSpeedPerSecond calculate moving average.
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Tunnel struct {
|
type Tunnel struct {
|
||||||
|
// Id identifies the connection in the server
|
||||||
|
Id string
|
||||||
// The connection-id (RDG-ConnID) as reported by the client
|
// The connection-id (RDG-ConnID) as reported by the client
|
||||||
RDGId string
|
RDGId string
|
||||||
// The underlying incoming transport being either websocket or legacy http
|
// The underlying incoming transport being either websocket or legacy http
|
||||||
|
@ -26,18 +28,28 @@ type Tunnel struct {
|
||||||
// It is of the type *net.TCPConn
|
// It is of the type *net.TCPConn
|
||||||
rwc net.Conn
|
rwc net.Conn
|
||||||
|
|
||||||
ByteSent int64
|
// BytesSent is the total amount of bytes sent by the server to the client minus tunnel overhead
|
||||||
|
BytesSent int64
|
||||||
|
|
||||||
|
// BytesReceived is the total amount of bytes received by the server from the client minus tunnel overhad
|
||||||
BytesReceived int64
|
BytesReceived int64
|
||||||
|
|
||||||
|
// ConnectedOn is when the client connected to the server
|
||||||
ConnectedOn time.Time
|
ConnectedOn time.Time
|
||||||
|
|
||||||
|
// LastSeen is when the server received the last packet from the client
|
||||||
LastSeen time.Time
|
LastSeen time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write puts the packet on the transport and updates the statistics for bytes sent
|
||||||
func (t *Tunnel) Write(pkt []byte) {
|
func (t *Tunnel) Write(pkt []byte) {
|
||||||
n, _ := t.TransportOut.WritePacket(pkt)
|
n, _ := t.TransportOut.WritePacket(pkt)
|
||||||
t.ByteSent += int64(n)
|
t.BytesSent += int64(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read picks up a packet from the transport and returns the packet type
|
||||||
|
// packet, with the header removed, and the packet size. It updates the
|
||||||
|
// statistics for bytes received
|
||||||
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
|
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
|
||||||
pt, size, pkt, err = readMessage(t.TransportIn)
|
pt, size, pkt, err = readMessage(t.TransportIn)
|
||||||
t.BytesReceived += int64(size)
|
t.BytesReceived += int64(size)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -23,10 +24,10 @@ func CheckHost(ctx context.Context, host string) (bool, error) {
|
||||||
case "roundrobin", "unsigned":
|
case "roundrobin", "unsigned":
|
||||||
var username string
|
var username string
|
||||||
|
|
||||||
s := getSessionInfo(ctx)
|
s := getTunnel(ctx)
|
||||||
if s == nil || s.UserName == "" {
|
if s == nil || s.UserName == "" {
|
||||||
var ok bool
|
var ok bool
|
||||||
username, ok = ctx.Value("preferred_username").(string)
|
username, ok = ctx.Value(common.UsernameCtx).(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, errors.New("no valid session info or username found in context")
|
return false, errors.New("no valid session info or username found in context")
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package security
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -20,7 +21,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCheckHost(t *testing.T) {
|
func TestCheckHost(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), "Tunnel", &info)
|
ctx := context.WithValue(context.Background(), common.TunnelCtx, &info)
|
||||||
|
|
||||||
Hosts = hosts
|
Hosts = hosts
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ type customClaims struct {
|
||||||
|
|
||||||
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
|
||||||
return func(ctx context.Context, host string) (bool, error) {
|
return func(ctx context.Context, host string) (bool, error) {
|
||||||
s := getSessionInfo(ctx)
|
s := getTunnel(ctx)
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return false, errors.New("no valid session info found in context")
|
return false, errors.New("no valid session info found in context")
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
||||||
|
|
||||||
token, err := jwt.ParseSigned(tokenString)
|
token, err := jwt.ParseSigned(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot parse token due to: %s", err)
|
log.Printf("cannot parse token due to: %tunnel", err)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
||||||
// Claims automagically checks the signature...
|
// Claims automagically checks the signature...
|
||||||
err = token.Claims(SigningKey, &standard, &custom)
|
err = token.Claims(SigningKey, &standard, &custom)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("token signature validation failed due to %s", err)
|
log.Printf("token signature validation failed due to %tunnel", err)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("token validation failed due to %s", err)
|
log.Printf("token validation failed due to %tunnel", err)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,15 +98,15 @@ func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
|
||||||
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
|
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
|
||||||
user, err := OIDCProvider.UserInfo(ctx, tokenSource)
|
user, err := OIDCProvider.UserInfo(ctx, tokenSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot get user info for access token: %s", err)
|
log.Printf("Cannot get user info for access token: %tunnel", err)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s := getSessionInfo(ctx)
|
tunnel := getTunnel(ctx)
|
||||||
|
|
||||||
s.TargetServer = custom.RemoteServer
|
tunnel.TargetServer = custom.RemoteServer
|
||||||
s.RemoteAddr = custom.ClientIP
|
tunnel.RemoteAddr = custom.ClientIP
|
||||||
s.UserName = user.Subject
|
tunnel.UserName = user.Subject
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -288,7 +288,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
|
||||||
return token, err
|
return token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSessionInfo(ctx context.Context) *protocol.Tunnel {
|
func getTunnel(ctx context.Context) *protocol.Tunnel {
|
||||||
s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
|
s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("cannot get session info from context")
|
log.Printf("cannot get session info from context")
|
||||||
|
|
|
@ -2,6 +2,7 @@ package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
@ -12,7 +13,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
protocol = "unix"
|
protocolGrpc = "unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BasicAuthHandler struct {
|
type BasicAuthHandler struct {
|
||||||
|
@ -27,7 +28,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
|
||||||
conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
return net.Dial(protocol, addr)
|
return net.Dial(protocolGrpc, addr)
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot reach authentication provider: %s", err)
|
log.Printf("Cannot reach authentication provider: %s", err)
|
||||||
|
@ -51,7 +52,7 @@ func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
if !res.Authenticated {
|
if !res.Authenticated {
|
||||||
log.Printf("User %s is not authenticated for this service", username)
|
log.Printf("User %s is not authenticated for this service", username)
|
||||||
} else {
|
} else {
|
||||||
ctx := context.WithValue(r.Context(), "preferred_username", username)
|
ctx := context.WithValue(r.Context(), common.UsernameCtx, username)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
|
@ -119,7 +120,7 @@ func (h *OIDC) Authenticated(next http.Handler) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
ctx := context.WithValue(r.Context(), common.UsernameCtx, session.Values["preferred_username"])
|
||||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -6,6 +6,7 @@ require (
|
||||||
github.com/coreos/go-oidc/v3 v3.2.0
|
github.com/coreos/go-oidc/v3 v3.2.0
|
||||||
github.com/fatih/structs v1.1.0
|
github.com/fatih/structs v1.1.0
|
||||||
github.com/go-jose/go-jose/v3 v3.0.0
|
github.com/go-jose/go-jose/v3 v3.0.0
|
||||||
|
github.com/google/uuid v1.1.2
|
||||||
github.com/gorilla/sessions v1.2.1
|
github.com/gorilla/sessions v1.2.1
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/knadh/koanf v1.4.2
|
github.com/knadh/koanf v1.4.2
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue