Add states and more verifications

This commit is contained in:
Bolke de Bruin 2020-07-20 18:51:00 +02:00
parent f50dc2c82d
commit 6fd7b047cd
3 changed files with 71 additions and 16 deletions

View file

@ -29,12 +29,14 @@ type Handler struct {
TokenAuth bool TokenAuth bool
ClientName string ClientName string
Remote net.Conn Remote net.Conn
State int
} }
func NewHandler(in transport.Transport, out transport.Transport) *Handler { func NewHandler(in transport.Transport, out transport.Transport) *Handler {
h := &Handler{ h := &Handler{
TransportIn: in, TransportIn: in,
TransportOut: out, TransportOut: out,
State: SERVER_STATE_INITIAL,
} }
return h return h
} }
@ -49,35 +51,61 @@ func (h *Handler) Process() error {
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
if h.State != SERVER_STATE_INITIAL {
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
return errors.New("wrong state")
}
major, minor, _, auth := readHandshake(pkt) major, minor, _, auth := readHandshake(pkt)
msg := h.handshakeResponse(major, minor, auth) msg := h.handshakeResponse(major, minor, auth)
h.TransportOut.WritePacket(msg) h.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE: case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create") if h.State != SERVER_STATE_HANDSHAKE {
log.Printf("Tunnel create attempted while in wrong state %d != %d",
h.State, SERVER_STATE_HANDSHAKE)
return errors.New("wrong state")
}
_, cookie := readCreateTunnelRequest(pkt) _, cookie := readCreateTunnelRequest(pkt)
if h.VerifyPAACookieFunc != nil { if h.VerifyPAACookieFunc != nil {
if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false { if ok, _ := h.VerifyPAACookieFunc(cookie); !ok {
log.Printf("Invalid PAA cookie: %s", cookie) log.Printf("Invalid PAA cookie: %s", cookie)
return errors.New("invalid PAA cookie") return errors.New("invalid PAA cookie")
} }
} }
msg := createTunnelResponse() msg := createTunnelResponse()
h.TransportOut.WritePacket(msg) h.TransportOut.WritePacket(msg)
log.Printf("Tunnel done") h.State = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH: case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth") if h.State != SERVER_STATE_TUNNEL_CREATE {
h.readTunnelAuthRequest(pkt) log.Printf("Tunnel auth attempted while in wrong state %d != %d",
h.State, SERVER_STATE_TUNNEL_CREATE)
return errors.New("wrong state")
}
client := h.readTunnelAuthRequest(pkt)
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
log.Printf("Invalid client name: %s", client)
return errors.New("invalid client name")
}
msg := h.createTunnelAuthResponse() msg := h.createTunnelAuthResponse()
h.TransportOut.WritePacket(msg) h.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE: case PKT_TYPE_CHANNEL_CREATE:
if h.State != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d",
h.State, SERVER_STATE_TUNNEL_AUTHORIZE)
return errors.New("wrong state")
}
server, port := readChannelCreateRequest(pkt) server, port := readChannelCreateRequest(pkt)
log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server) host := net.JoinHostPort(server, strconv.Itoa(int(port)))
h.Remote, err = net.DialTimeout( if h.VerifyServerFunc != nil {
"tcp", if ok, _ := h.VerifyServerFunc(host); !ok {
net.JoinHostPort(server, strconv.Itoa(int(port))), log.Printf("Not allowed to connect to %s by policy handler", host)
time.Second*15) }
}
log.Printf("Establishing connection to RDP server: %s", host)
h.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil { if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err) log.Printf("Error connecting to %s, %s", host, err)
return err return err
} }
log.Printf("Connection established") log.Printf("Connection established")
@ -87,14 +115,31 @@ func (h *Handler) Process() error {
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
go h.sendDataPacket() go h.sendDataPacket()
h.State = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
if h.State != SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
h.State = SERVER_STATE_OPENED
h.forwardDataPacket(pkt) h.forwardDataPacket(pkt)
case PKT_TYPE_KEEPALIVE: case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet
if h.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
// avoid concurrency issues // avoid concurrency issues
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
if h.State != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
return errors.New("wrong state")
}
h.TransportIn.Close() h.TransportIn.Close()
h.TransportOut.Close() h.TransportOut.Close()
h.State = SERVER_STATE_CLOSED
default: default:
log.Printf("Unknown packet (size %d): %x", sz, pkt) log.Printf("Unknown packet (size %d): %x", sz, pkt)
} }
@ -223,7 +268,7 @@ func createTunnelResponse() []byte {
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()) return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
} }
func (h *Handler) readTunnelAuthRequest(data []byte) { func (h *Handler) readTunnelAuthRequest(data []byte) string {
buf := bytes.NewReader(data) buf := bytes.NewReader(data)
var size uint16 var size uint16
@ -231,7 +276,8 @@ func (h *Handler) readTunnelAuthRequest(data []byte) {
clData := make([]byte, size) clData := make([]byte, size)
binary.Read(buf, binary.LittleEndian, &clData) binary.Read(buf, binary.LittleEndian, &clData)
clientName, _ := DecodeUTF16(clData) clientName, _ := DecodeUTF16(clData)
log.Printf("Client: %s", clientName)
return clientName
} }
func (h *Handler) createTunnelAuthResponse() []byte { func (h *Handler) createTunnelAuthResponse() []byte {

View file

@ -58,3 +58,12 @@ const (
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
) )
const (
SERVER_STATE_INITIAL = 0x0
SERVER_STATE_HANDSHAKE = 0x1
SERVER_STATE_TUNNEL_CREATE = 0x2
SERVER_STATE_TUNNEL_AUTHORIZE = 0x3
SERVER_STATE_CHANNEL_CREATE = 0x4
SERVER_STATE_OPENED = 0x5
SERVER_STATE_CLOSED = 0x6
)