mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-17 14:03:50 +02:00
Refactor names
This commit is contained in:
parent
29d4b276e6
commit
fe6509d8ca
5 changed files with 154 additions and 141 deletions
414
protocol/server.go
Normal file
414
protocol/server.go
Normal file
|
@ -0,0 +1,414 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/client"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VerifyTunnelCreate func(context.Context, string) (bool, error)
|
||||
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
||||
type VerifyServerFunc func(context.Context, string) (bool, error)
|
||||
|
||||
type RedirectFlags struct {
|
||||
Clipboard bool
|
||||
Port bool
|
||||
Drive bool
|
||||
Printer bool
|
||||
Pnp bool
|
||||
DisableAll bool
|
||||
EnableAll bool
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
Session *SessionInfo
|
||||
VerifyTunnelCreate VerifyTunnelCreate
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
VerifyServerFunc VerifyServerFunc
|
||||
RedirectFlags int
|
||||
IdleTimeout int
|
||||
SmartCardAuth bool
|
||||
TokenAuth bool
|
||||
ClientName string
|
||||
Remote net.Conn
|
||||
State int
|
||||
}
|
||||
|
||||
type ServerConf struct {
|
||||
VerifyTunnelCreate VerifyTunnelCreate
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
VerifyServerFunc VerifyServerFunc
|
||||
RedirectFlags RedirectFlags
|
||||
IdleTimeout int
|
||||
SmartCardAuth bool
|
||||
TokenAuth bool
|
||||
}
|
||||
|
||||
func NewServer(s *SessionInfo, conf *ServerConf) *Server {
|
||||
h := &Server{
|
||||
State: SERVER_STATE_INITIAL,
|
||||
Session: s,
|
||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||
IdleTimeout: conf.IdleTimeout,
|
||||
SmartCardAuth: conf.SmartCardAuth,
|
||||
TokenAuth: conf.TokenAuth,
|
||||
VerifyTunnelCreate: conf.VerifyTunnelCreate,
|
||||
VerifyServerFunc: conf.VerifyServerFunc,
|
||||
VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
const tunnelId = 10
|
||||
|
||||
func (s *Server) Process(ctx context.Context) error {
|
||||
for {
|
||||
pt, sz, pkt, err := s.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("Cannot read message from stream %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx))
|
||||
if s.State != SERVER_STATE_INITIAL {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
major, minor, _, _ := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do
|
||||
msg := s.handshakeResponse(major, minor)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_HANDSHAKE
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
log.Printf("Tunnel create")
|
||||
if s.State != SERVER_STATE_HANDSHAKE {
|
||||
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_HANDSHAKE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
_, cookie := s.tunnelRequest(pkt)
|
||||
if s.VerifyTunnelCreate != nil {
|
||||
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
|
||||
return errors.New("invalid PAA cookie")
|
||||
}
|
||||
}
|
||||
msg := s.tunnelResponse()
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_TUNNEL_CREATE
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
log.Printf("Tunnel auth")
|
||||
if s.State != SERVER_STATE_TUNNEL_CREATE {
|
||||
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_TUNNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
client := s.tunnelAuthRequest(pkt)
|
||||
if s.VerifyTunnelAuthFunc != nil {
|
||||
if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok {
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
return errors.New("invalid client name")
|
||||
}
|
||||
}
|
||||
msg := s.tunnelAuthResponse()
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
log.Printf("Channel create")
|
||||
if s.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||
log.Printf("Channel create attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
server, port := s.channelRequest(pkt)
|
||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||
if s.VerifyServerFunc != nil {
|
||||
if ok, _ := s.VerifyServerFunc(ctx, host); !ok {
|
||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||
return errors.New("denied by security policy")
|
||||
}
|
||||
}
|
||||
log.Printf("Establishing connection to RDP server: %s", host)
|
||||
s.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %s, %s", host, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("Connection established")
|
||||
msg := s.channelResponse()
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
|
||||
// Make sure to start the flow from the RDP server first otherwise connections
|
||||
// might hang eventually
|
||||
go s.sendDataPacket()
|
||||
s.State = SERVER_STATE_CHANNEL_CREATE
|
||||
case PKT_TYPE_DATA:
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
s.State = SERVER_STATE_OPENED
|
||||
s.forwardDataPacket(pkt)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// keepalives can be received while the channel is not open yet
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
|
||||
// avoid concurrency issues
|
||||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
log.Printf("Close channel")
|
||||
if s.State != SERVER_STATE_OPENED {
|
||||
log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
s.Session.TransportIn.Close()
|
||||
s.Session.TransportOut.Close()
|
||||
s.State = SERVER_STATE_CLOSED
|
||||
default:
|
||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) {
|
||||
fragment := false
|
||||
index := 0
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
size, pkt, err := s.Session.TransportIn.ReadPacket()
|
||||
if err != nil {
|
||||
return 0, 0, []byte{0, 0}, err
|
||||
}
|
||||
|
||||
// check for fragments
|
||||
var pt uint16
|
||||
var sz uint32
|
||||
var msg []byte
|
||||
|
||||
if !fragment {
|
||||
pt, sz, msg, err = readHeader(pkt[:size])
|
||||
if err != nil {
|
||||
fragment = true
|
||||
index = copy(buf, pkt[:size])
|
||||
continue
|
||||
}
|
||||
index = 0
|
||||
} else {
|
||||
fragment = false
|
||||
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
|
||||
// header is corrupted even after defragmenting
|
||||
if err != nil {
|
||||
return 0, 0, []byte{0, 0}, err
|
||||
}
|
||||
}
|
||||
if !fragment {
|
||||
return int(pt), int(sz), msg, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a packet the is a response to a handshakeRequest request
|
||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||
// but could be in Windows. However the NTLM protocol is insecure
|
||||
func (s *Server) handshakeResponse(major byte, minor byte) []byte {
|
||||
var caps uint16
|
||||
if s.SmartCardAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_SC
|
||||
}
|
||||
if s.TokenAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
|
||||
buf.Write([]byte{major, minor})
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
|
||||
binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth
|
||||
|
||||
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &major)
|
||||
binary.Read(r, binary.LittleEndian, &minor)
|
||||
binary.Read(r, binary.LittleEndian, &version)
|
||||
binary.Read(r, binary.LittleEndian, &extAuth)
|
||||
|
||||
log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
|
||||
var fields uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
binary.Read(r, binary.LittleEndian, &caps)
|
||||
binary.Read(r, binary.LittleEndian, &fields)
|
||||
r.Seek(2, io.SeekCurrent)
|
||||
|
||||
if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
|
||||
var size uint16
|
||||
binary.Read(r, binary.LittleEndian, &size)
|
||||
cookieB := make([]byte, size)
|
||||
r.Read(cookieB)
|
||||
cookie, _ = DecodeUTF16(cookieB)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) tunnelResponse() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
// tunnel id (when is it used?)
|
||||
binary.Write(buf, binary.LittleEndian, uint32(tunnelId))
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(HTTP_CAPABILITY_IDLE_TIMEOUT))
|
||||
|
||||
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) tunnelAuthRequest(data []byte) string {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var size uint16
|
||||
binary.Read(buf, binary.LittleEndian, &size)
|
||||
clData := make([]byte, size)
|
||||
binary.Read(buf, binary.LittleEndian, &clData)
|
||||
clientName, _ := DecodeUTF16(clData)
|
||||
|
||||
return clientName
|
||||
}
|
||||
|
||||
func (s *Server) tunnelAuthResponse() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
// idle timeout
|
||||
if s.IdleTimeout < 0 {
|
||||
s.IdleTimeout = 0
|
||||
}
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags
|
||||
binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes
|
||||
|
||||
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) channelRequest(data []byte) (server string, port uint16) {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var resourcesSize byte
|
||||
var alternative byte
|
||||
var protocol uint16
|
||||
var nameSize uint16
|
||||
|
||||
binary.Read(buf, binary.LittleEndian, &resourcesSize)
|
||||
binary.Read(buf, binary.LittleEndian, &alternative)
|
||||
binary.Read(buf, binary.LittleEndian, &port)
|
||||
binary.Read(buf, binary.LittleEndian, &protocol)
|
||||
binary.Read(buf, binary.LittleEndian, &nameSize)
|
||||
|
||||
nameData := make([]byte, nameSize)
|
||||
binary.Read(buf, binary.LittleEndian, &nameData)
|
||||
|
||||
server, _ = DecodeUTF16(nameData)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) channelResponse() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||
binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID)) // fields present
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
// channel id is required for Windows clients
|
||||
binary.Write(buf, binary.LittleEndian, uint32(1)) // channel id
|
||||
|
||||
// optional fields
|
||||
// channel id uint32 (4)
|
||||
// udp port uint16 (2)
|
||||
// udp auth cookie 1 byte for side channel
|
||||
// length uint16
|
||||
|
||||
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) forwardDataPacket(data []byte) {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var cblen uint16
|
||||
binary.Read(buf, binary.LittleEndian, &cblen)
|
||||
pkt := make([]byte, cblen)
|
||||
binary.Read(buf, binary.LittleEndian, &pkt)
|
||||
|
||||
s.Remote.Write(pkt)
|
||||
}
|
||||
|
||||
func (s *Server) sendDataPacket() {
|
||||
defer s.Remote.Close()
|
||||
b1 := new(bytes.Buffer)
|
||||
buf := make([]byte, 4086)
|
||||
for {
|
||||
n, err := s.Remote.Read(buf)
|
||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||
if err != nil {
|
||||
log.Printf("Error reading from conn %s", err)
|
||||
break
|
||||
}
|
||||
b1.Write(buf[:n])
|
||||
s.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func makeRedirectFlags(flags RedirectFlags) int {
|
||||
var redir = 0
|
||||
|
||||
if flags.DisableAll {
|
||||
return HTTP_TUNNEL_REDIR_DISABLE_ALL
|
||||
}
|
||||
if flags.EnableAll {
|
||||
return HTTP_TUNNEL_REDIR_ENABLE_ALL
|
||||
}
|
||||
|
||||
if !flags.Port {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
|
||||
}
|
||||
if !flags.Clipboard {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
|
||||
}
|
||||
if !flags.Drive {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
|
||||
}
|
||||
if !flags.Pnp {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
|
||||
}
|
||||
if !flags.Printer {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
|
||||
}
|
||||
return redir
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue