mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-20 17:45:55 +02:00
Add states and more verifications
This commit is contained in:
parent
f50dc2c82d
commit
6fd7b047cd
3 changed files with 71 additions and 16 deletions
|
@ -21,7 +21,7 @@ type VerifyServerFunc func(string) (bool, error)
|
|||
|
||||
type Handler struct {
|
||||
TransportIn transport.Transport
|
||||
TransportOut transport.Transport
|
||||
TransportOut transport.Transport
|
||||
VerifyPAACookieFunc VerifyPAACookieFunc
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
VerifyServerFunc VerifyServerFunc
|
||||
|
@ -29,12 +29,14 @@ type Handler struct {
|
|||
TokenAuth bool
|
||||
ClientName string
|
||||
Remote net.Conn
|
||||
State int
|
||||
}
|
||||
|
||||
func NewHandler(in transport.Transport, out transport.Transport) *Handler {
|
||||
h := &Handler{
|
||||
TransportIn: in,
|
||||
TransportIn: in,
|
||||
TransportOut: out,
|
||||
State: SERVER_STATE_INITIAL,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
@ -49,35 +51,61 @@ func (h *Handler) Process() error {
|
|||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
if h.State != SERVER_STATE_INITIAL {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
major, minor, _, auth := readHandshake(pkt)
|
||||
msg := h.handshakeResponse(major, minor, auth)
|
||||
h.TransportOut.WritePacket(msg)
|
||||
h.State = SERVER_STATE_HANDSHAKE
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
log.Printf("Tunnel create")
|
||||
if h.State != SERVER_STATE_HANDSHAKE {
|
||||
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
||||
h.State, SERVER_STATE_HANDSHAKE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
_, cookie := readCreateTunnelRequest(pkt)
|
||||
if h.VerifyPAACookieFunc != nil {
|
||||
if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false {
|
||||
if ok, _ := h.VerifyPAACookieFunc(cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie: %s", cookie)
|
||||
return errors.New("invalid PAA cookie")
|
||||
}
|
||||
}
|
||||
msg := createTunnelResponse()
|
||||
h.TransportOut.WritePacket(msg)
|
||||
log.Printf("Tunnel done")
|
||||
h.State = SERVER_STATE_TUNNEL_CREATE
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
log.Printf("Tunnel auth")
|
||||
h.readTunnelAuthRequest(pkt)
|
||||
if h.State != SERVER_STATE_TUNNEL_CREATE {
|
||||
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
||||
h.State, SERVER_STATE_TUNNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
client := h.readTunnelAuthRequest(pkt)
|
||||
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
return errors.New("invalid client name")
|
||||
}
|
||||
msg := h.createTunnelAuthResponse()
|
||||
h.TransportOut.WritePacket(msg)
|
||||
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
if h.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||
log.Printf("Channel create attempted while in wrong state %d != %d",
|
||||
h.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
server, port := readChannelCreateRequest(pkt)
|
||||
log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
|
||||
h.Remote, err = net.DialTimeout(
|
||||
"tcp",
|
||||
net.JoinHostPort(server, strconv.Itoa(int(port))),
|
||||
time.Second*15)
|
||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||
if h.VerifyServerFunc != nil {
|
||||
if ok, _ := h.VerifyServerFunc(host); !ok {
|
||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||
}
|
||||
}
|
||||
log.Printf("Establishing connection to RDP server: %s", host)
|
||||
h.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %s, %d, %s", server, port, err)
|
||||
log.Printf("Error connecting to %s, %s", host, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("Connection established")
|
||||
|
@ -87,14 +115,31 @@ func (h *Handler) Process() error {
|
|||
// Make sure to start the flow from the RDP server first otherwise connections
|
||||
// might hang eventually
|
||||
go h.sendDataPacket()
|
||||
h.State = SERVER_STATE_CHANNEL_CREATE
|
||||
case PKT_TYPE_DATA:
|
||||
if h.State != SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
h.State = SERVER_STATE_OPENED
|
||||
h.forwardDataPacket(pkt)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// keepalives can be received while the channel is not open yet
|
||||
if h.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Keepalive received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
|
||||
// avoid concurrency issues
|
||||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
if h.State != SERVER_STATE_OPENED {
|
||||
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
h.TransportIn.Close()
|
||||
h.TransportOut.Close()
|
||||
h.State = SERVER_STATE_CLOSED
|
||||
default:
|
||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||
}
|
||||
|
@ -223,7 +268,7 @@ func createTunnelResponse() []byte {
|
|||
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (h *Handler) readTunnelAuthRequest(data []byte) {
|
||||
func (h *Handler) readTunnelAuthRequest(data []byte) string {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var size uint16
|
||||
|
@ -231,7 +276,8 @@ func (h *Handler) readTunnelAuthRequest(data []byte) {
|
|||
clData := make([]byte, size)
|
||||
binary.Read(buf, binary.LittleEndian, &clData)
|
||||
clientName, _ := DecodeUTF16(clData)
|
||||
log.Printf("Client: %s", clientName)
|
||||
|
||||
return clientName
|
||||
}
|
||||
|
||||
func (h *Handler) createTunnelAuthResponse() []byte {
|
||||
|
|
|
@ -58,3 +58,12 @@ const (
|
|||
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
|
||||
)
|
||||
|
||||
const (
|
||||
SERVER_STATE_INITIAL = 0x0
|
||||
SERVER_STATE_HANDSHAKE = 0x1
|
||||
SERVER_STATE_TUNNEL_CREATE = 0x2
|
||||
SERVER_STATE_TUNNEL_AUTHORIZE = 0x3
|
||||
SERVER_STATE_CHANNEL_CREATE = 0x4
|
||||
SERVER_STATE_OPENED = 0x5
|
||||
SERVER_STATE_CLOSED = 0x6
|
||||
)
|
||||
|
|
|
@ -79,4 +79,4 @@ func (t *LegacyPKT) SendAccept(doSeed bool) {
|
|||
func (t *LegacyPKT) Drain() {
|
||||
p := make([]byte, 32767)
|
||||
t.Conn.Read(p)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue