mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-21 01:55:57 +02:00
Add states and more verifications
This commit is contained in:
parent
f50dc2c82d
commit
6fd7b047cd
3 changed files with 71 additions and 16 deletions
|
@ -29,12 +29,14 @@ type Handler struct {
|
||||||
TokenAuth bool
|
TokenAuth bool
|
||||||
ClientName string
|
ClientName string
|
||||||
Remote net.Conn
|
Remote net.Conn
|
||||||
|
State int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(in transport.Transport, out transport.Transport) *Handler {
|
func NewHandler(in transport.Transport, out transport.Transport) *Handler {
|
||||||
h := &Handler{
|
h := &Handler{
|
||||||
TransportIn: in,
|
TransportIn: in,
|
||||||
TransportOut: out,
|
TransportOut: out,
|
||||||
|
State: SERVER_STATE_INITIAL,
|
||||||
}
|
}
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
@ -49,35 +51,61 @@ func (h *Handler) Process() error {
|
||||||
|
|
||||||
switch pt {
|
switch pt {
|
||||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||||
|
if h.State != SERVER_STATE_INITIAL {
|
||||||
|
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
|
||||||
|
return errors.New("wrong state")
|
||||||
|
}
|
||||||
major, minor, _, auth := readHandshake(pkt)
|
major, minor, _, auth := readHandshake(pkt)
|
||||||
msg := h.handshakeResponse(major, minor, auth)
|
msg := h.handshakeResponse(major, minor, auth)
|
||||||
h.TransportOut.WritePacket(msg)
|
h.TransportOut.WritePacket(msg)
|
||||||
|
h.State = SERVER_STATE_HANDSHAKE
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
log.Printf("Tunnel create")
|
if h.State != SERVER_STATE_HANDSHAKE {
|
||||||
|
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
||||||
|
h.State, SERVER_STATE_HANDSHAKE)
|
||||||
|
return errors.New("wrong state")
|
||||||
|
}
|
||||||
_, cookie := readCreateTunnelRequest(pkt)
|
_, cookie := readCreateTunnelRequest(pkt)
|
||||||
if h.VerifyPAACookieFunc != nil {
|
if h.VerifyPAACookieFunc != nil {
|
||||||
if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false {
|
if ok, _ := h.VerifyPAACookieFunc(cookie); !ok {
|
||||||
log.Printf("Invalid PAA cookie: %s", cookie)
|
log.Printf("Invalid PAA cookie: %s", cookie)
|
||||||
return errors.New("invalid PAA cookie")
|
return errors.New("invalid PAA cookie")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg := createTunnelResponse()
|
msg := createTunnelResponse()
|
||||||
h.TransportOut.WritePacket(msg)
|
h.TransportOut.WritePacket(msg)
|
||||||
log.Printf("Tunnel done")
|
h.State = SERVER_STATE_TUNNEL_CREATE
|
||||||
case PKT_TYPE_TUNNEL_AUTH:
|
case PKT_TYPE_TUNNEL_AUTH:
|
||||||
log.Printf("Tunnel auth")
|
if h.State != SERVER_STATE_TUNNEL_CREATE {
|
||||||
h.readTunnelAuthRequest(pkt)
|
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
||||||
|
h.State, SERVER_STATE_TUNNEL_CREATE)
|
||||||
|
return errors.New("wrong state")
|
||||||
|
}
|
||||||
|
client := h.readTunnelAuthRequest(pkt)
|
||||||
|
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
|
||||||
|
log.Printf("Invalid client name: %s", client)
|
||||||
|
return errors.New("invalid client name")
|
||||||
|
}
|
||||||
msg := h.createTunnelAuthResponse()
|
msg := h.createTunnelAuthResponse()
|
||||||
h.TransportOut.WritePacket(msg)
|
h.TransportOut.WritePacket(msg)
|
||||||
|
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||||
case PKT_TYPE_CHANNEL_CREATE:
|
case PKT_TYPE_CHANNEL_CREATE:
|
||||||
|
if h.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||||
|
log.Printf("Channel create attempted while in wrong state %d != %d",
|
||||||
|
h.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||||
|
return errors.New("wrong state")
|
||||||
|
}
|
||||||
server, port := readChannelCreateRequest(pkt)
|
server, port := readChannelCreateRequest(pkt)
|
||||||
log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
|
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||||
h.Remote, err = net.DialTimeout(
|
if h.VerifyServerFunc != nil {
|
||||||
"tcp",
|
if ok, _ := h.VerifyServerFunc(host); !ok {
|
||||||
net.JoinHostPort(server, strconv.Itoa(int(port))),
|
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||||
time.Second*15)
|
}
|
||||||
|
}
|
||||||
|
log.Printf("Establishing connection to RDP server: %s", host)
|
||||||
|
h.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error connecting to %s, %d, %s", server, port, err)
|
log.Printf("Error connecting to %s, %s", host, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Printf("Connection established")
|
log.Printf("Connection established")
|
||||||
|
@ -87,14 +115,31 @@ func (h *Handler) Process() error {
|
||||||
// Make sure to start the flow from the RDP server first otherwise connections
|
// Make sure to start the flow from the RDP server first otherwise connections
|
||||||
// might hang eventually
|
// might hang eventually
|
||||||
go h.sendDataPacket()
|
go h.sendDataPacket()
|
||||||
|
h.State = SERVER_STATE_CHANNEL_CREATE
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
|
if h.State != SERVER_STATE_CHANNEL_CREATE {
|
||||||
|
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
||||||
|
return errors.New("wrong state")
|
||||||
|
}
|
||||||
|
h.State = SERVER_STATE_OPENED
|
||||||
h.forwardDataPacket(pkt)
|
h.forwardDataPacket(pkt)
|
||||||
case PKT_TYPE_KEEPALIVE:
|
case PKT_TYPE_KEEPALIVE:
|
||||||
|
// keepalives can be received while the channel is not open yet
|
||||||
|
if h.State < SERVER_STATE_CHANNEL_CREATE {
|
||||||
|
log.Printf("Keepalive received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
||||||
|
return errors.New("wrong state")
|
||||||
|
}
|
||||||
|
|
||||||
// avoid concurrency issues
|
// avoid concurrency issues
|
||||||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||||
case PKT_TYPE_CLOSE_CHANNEL:
|
case PKT_TYPE_CLOSE_CHANNEL:
|
||||||
|
if h.State != SERVER_STATE_OPENED {
|
||||||
|
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
|
||||||
|
return errors.New("wrong state")
|
||||||
|
}
|
||||||
h.TransportIn.Close()
|
h.TransportIn.Close()
|
||||||
h.TransportOut.Close()
|
h.TransportOut.Close()
|
||||||
|
h.State = SERVER_STATE_CLOSED
|
||||||
default:
|
default:
|
||||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||||
}
|
}
|
||||||
|
@ -223,7 +268,7 @@ func createTunnelResponse() []byte {
|
||||||
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) readTunnelAuthRequest(data []byte) {
|
func (h *Handler) readTunnelAuthRequest(data []byte) string {
|
||||||
buf := bytes.NewReader(data)
|
buf := bytes.NewReader(data)
|
||||||
|
|
||||||
var size uint16
|
var size uint16
|
||||||
|
@ -231,7 +276,8 @@ func (h *Handler) readTunnelAuthRequest(data []byte) {
|
||||||
clData := make([]byte, size)
|
clData := make([]byte, size)
|
||||||
binary.Read(buf, binary.LittleEndian, &clData)
|
binary.Read(buf, binary.LittleEndian, &clData)
|
||||||
clientName, _ := DecodeUTF16(clData)
|
clientName, _ := DecodeUTF16(clData)
|
||||||
log.Printf("Client: %s", clientName)
|
|
||||||
|
return clientName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) createTunnelAuthResponse() []byte {
|
func (h *Handler) createTunnelAuthResponse() []byte {
|
||||||
|
|
|
@ -58,3 +58,12 @@ const (
|
||||||
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
|
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SERVER_STATE_INITIAL = 0x0
|
||||||
|
SERVER_STATE_HANDSHAKE = 0x1
|
||||||
|
SERVER_STATE_TUNNEL_CREATE = 0x2
|
||||||
|
SERVER_STATE_TUNNEL_AUTHORIZE = 0x3
|
||||||
|
SERVER_STATE_CHANNEL_CREATE = 0x4
|
||||||
|
SERVER_STATE_OPENED = 0x5
|
||||||
|
SERVER_STATE_CLOSED = 0x6
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue