rdpgw/protocol/handler.go
2020-07-20 15:51:23 +02:00

361 lines
10 KiB
Go

package protocol
import (
"bytes"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/transport"
"io"
"log"
"net"
"strconv"
"time"
)
// When should the client disconnect when idle in minutes
var IdleTimeout = 0
type VerifyPAACookieFunc func(string) (bool, error)
type VerifyTunnelAuthFunc func(string) (bool, error)
type VerifyServerFunc func(string) (bool, error)
type Handler struct {
TransportIn transport.Transport
TransportOut transport.Transport
VerifyPAACookieFunc VerifyPAACookieFunc
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
SmartCardAuth bool
TokenAuth bool
ClientName string
Remote net.Conn
}
func NewHandler(in transport.Transport, out transport.Transport) *Handler {
h := &Handler{
TransportIn: in,
TransportOut: out,
}
return h
}
func (h *Handler) Process() error {
for {
pt, sz, pkt, err := h.ReadMessage()
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt)
msg := h.handshakeResponse(major, minor, auth)
h.TransportOut.WritePacket(msg)
case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create")
_, cookie := readCreateTunnelRequest(pkt)
if h.VerifyPAACookieFunc != nil {
if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false {
log.Printf("Invalid PAA cookie: %s", cookie)
return errors.New("invalid PAA cookie")
}
}
msg := createTunnelResponse()
h.TransportOut.WritePacket(msg)
log.Printf("Tunnel done")
case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth")
h.readTunnelAuthRequest(pkt)
msg := h.createTunnelAuthResponse()
h.TransportOut.WritePacket(msg)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt)
log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
h.Remote, err = net.DialTimeout(
"tcp",
net.JoinHostPort(server, strconv.Itoa(int(port))),
time.Second*15)
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
return err
}
log.Printf("Connection established")
msg := createChannelCreateResponse()
h.TransportOut.WritePacket(msg)
// Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually
go h.sendDataPacket()
case PKT_TYPE_DATA:
h.forwardDataPacket(pkt)
case PKT_TYPE_KEEPALIVE:
// avoid concurrency issues
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
h.TransportIn.Close()
h.TransportOut.Close()
default:
log.Printf("Unknown packet (size %d): %x", sz, pkt)
}
}
}
func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := h.TransportIn.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
// Creates a packet the is a response to a handshake request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure
func (h *Handler) handshakeResponse(major byte, minor byte, auth uint16) []byte {
var caps uint16
if h.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA
}
if h.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA
}
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
buf.Write([]byte{major, minor})
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
}
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[8:], nil
}
func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &major)
binary.Read(r, binary.LittleEndian, &minor)
binary.Read(r, binary.LittleEndian, &version)
binary.Read(r, binary.LittleEndian, &extAuth)
log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
return
}
func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &caps)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
var size uint16
binary.Read(r, binary.LittleEndian, &size)
cookieB := make([]byte, size)
r.Read(cookieB)
cookie, _ = DecodeUTF16(cookieB)
}
log.Printf("Create tunnel caps: %d, cookie: %s", caps, cookie)
return
}
func createTunnelResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// tunnel id ?
binary.Write(buf, binary.LittleEndian, uint32(15))
// caps ?
binary.Write(buf, binary.LittleEndian, uint32(2))
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
}
func (h *Handler) readTunnelAuthRequest(data []byte) {
buf := bytes.NewReader(data)
var size uint16
binary.Read(buf, binary.LittleEndian, &size)
clData := make([]byte, size)
binary.Read(buf, binary.LittleEndian, &clData)
clientName, _ := DecodeUTF16(clData)
log.Printf("Client: %s", clientName)
}
func (h *Handler) createTunnelAuthResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// flags
var redir uint32
/*
if conf.Caps.RedirectAll {
redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
} else if conf.Caps.DisableRedirect {
redir = HTTP_TUNNEL_REDIR_DISABLE_ALL
} else {
if conf.Caps.DisableClipboard {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
}
if conf.Caps.DisableDrive {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
}
if conf.Caps.DisablePnp {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
}
if conf.Caps.DisablePrinter {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
}
if conf.Caps.DisablePort {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
}
}
*/
redir = HTTP_TUNNEL_REDIR_ENABLE_ALL
// idle timeout
if IdleTimeout < 0 {
IdleTimeout = 0
}
binary.Write(buf, binary.LittleEndian, uint32(redir)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(IdleTimeout)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
}
func readChannelCreateRequest(data []byte) (server string, port uint16) {
buf := bytes.NewReader(data)
var resourcesSize byte
var alternative byte
var protocol uint16
var nameSize uint16
binary.Read(buf, binary.LittleEndian, &resourcesSize)
binary.Read(buf, binary.LittleEndian, &alternative)
binary.Read(buf, binary.LittleEndian, &port)
binary.Read(buf, binary.LittleEndian, &protocol)
binary.Read(buf, binary.LittleEndian, &nameSize)
nameData := make([]byte, nameSize)
binary.Read(buf, binary.LittleEndian, &nameData)
log.Printf("Name data %q", nameData)
server, _ = DecodeUTF16(nameData)
log.Printf("Should connect to %s on port %d", server, port)
return
}
func createChannelCreateResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
//binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // fields
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// optional fields
// channel id uint32 (4)
// udp port uint16 (2)
// udp auth cookie 1 byte for side channel
// length uint16
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
}
func (h *Handler) forwardDataPacket(data []byte) {
buf := bytes.NewReader(data)
var cblen uint16
binary.Read(buf, binary.LittleEndian, &cblen)
pkt := make([]byte, cblen)
binary.Read(buf, binary.LittleEndian, &pkt)
h.Remote.Write(pkt)
}
func (h *Handler) sendDataPacket() {
defer h.Remote.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := h.Remote.Read(buf)
binary.Write(b1, binary.LittleEndian, uint16(n))
if err != nil {
log.Printf("Error reading from conn %s", err)
break
}
b1.Write(buf[:n])
h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
func createPacket(pktType uint16, data []byte) (packet []byte) {
size := len(data) + 8
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(pktType))
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint32(size))
buf.Write(data)
return buf.Bytes()
}