Check hostname specified by client against the token

This commit is contained in:
Bolke de Bruin 2020-07-25 19:37:33 +02:00
parent 39c73fc8fc
commit 5f3c7d07e2
5 changed files with 53 additions and 32 deletions

View file

@ -5,7 +5,6 @@ import (
"context"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/transport"
"io"
"log"
"net"
@ -13,9 +12,9 @@ import (
"time"
)
type VerifyTunnelCreate func(*SessionInfo, string) (bool, error)
type VerifyTunnelAuthFunc func(*SessionInfo, string) (bool, error)
type VerifyServerFunc func(*SessionInfo, string) (bool, error)
type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error)
type RedirectFlags struct {
Clipboard bool
@ -29,8 +28,6 @@ type RedirectFlags struct {
type Handler struct {
Session *SessionInfo
TransportIn transport.Transport
TransportOut transport.Transport
VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
@ -55,10 +52,8 @@ type HandlerConf struct {
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
h := &Handler{
State: SERVER_STATE_INITIAL,
Session: s,
TransportIn: s.TransportIn,
TransportOut: s.TransportOut,
State: SERVER_STATE_INITIAL,
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
IdleTimeout: conf.IdleTimeout,
SmartCardAuth: conf.SmartCardAuth,
@ -89,7 +84,7 @@ func (h *Handler) Process(ctx context.Context) error {
}
major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
msg := h.handshakeResponse(major, minor)
h.TransportOut.WritePacket(msg)
h.Session.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create")
@ -100,13 +95,13 @@ func (h *Handler) Process(ctx context.Context) error {
}
_, cookie := readCreateTunnelRequest(pkt)
if h.VerifyTunnelCreate != nil {
if ok, _ := h.VerifyTunnelCreate(h.Session, cookie); !ok {
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received")
return errors.New("invalid PAA cookie")
}
}
msg := createTunnelResponse()
h.TransportOut.WritePacket(msg)
h.Session.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth")
@ -117,13 +112,13 @@ func (h *Handler) Process(ctx context.Context) error {
}
client := h.readTunnelAuthRequest(pkt)
if h.VerifyTunnelAuthFunc != nil {
if ok, _ := h.VerifyTunnelAuthFunc(h.Session, client); !ok {
if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok {
log.Printf("Invalid client name: %s", client)
return errors.New("invalid client name")
}
}
msg := h.createTunnelAuthResponse()
h.TransportOut.WritePacket(msg)
h.Session.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create")
@ -135,8 +130,9 @@ func (h *Handler) Process(ctx context.Context) error {
server, port := readChannelCreateRequest(pkt)
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if h.VerifyServerFunc != nil {
if ok, _ := h.VerifyServerFunc(h.Session, host); !ok {
if ok, _ := h.VerifyServerFunc(ctx, host); !ok {
log.Printf("Not allowed to connect to %s by policy handler", host)
return errors.New("denied by security policy")
}
}
log.Printf("Establishing connection to RDP server: %s", host)
@ -147,7 +143,7 @@ func (h *Handler) Process(ctx context.Context) error {
}
log.Printf("Connection established")
msg := createChannelCreateResponse()
h.TransportOut.WritePacket(msg)
h.Session.TransportOut.WritePacket(msg)
// Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually
@ -175,8 +171,8 @@ func (h *Handler) Process(ctx context.Context) error {
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
return errors.New("wrong state")
}
h.TransportIn.Close()
h.TransportOut.Close()
h.Session.TransportIn.Close()
h.Session.TransportOut.Close()
h.State = SERVER_STATE_CLOSED
default:
log.Printf("Unknown packet (size %d): %x", sz, pkt)
@ -190,7 +186,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
buf := make([]byte, 4096)
for {
size, pkt, err := h.TransportIn.ReadPacket()
size, pkt, err := h.Session.TransportIn.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
@ -398,7 +394,7 @@ func (h *Handler) sendDataPacket() {
break
}
b1.Write(buf[:n])
h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}