Add states and more verifications

This commit is contained in:
Bolke de Bruin 2020-07-20 18:51:00 +02:00
parent f50dc2c82d
commit 6fd7b047cd
3 changed files with 71 additions and 16 deletions

View file

@ -21,7 +21,7 @@ type VerifyServerFunc func(string) (bool, error)
type Handler struct {
TransportIn transport.Transport
TransportOut transport.Transport
TransportOut transport.Transport
VerifyPAACookieFunc VerifyPAACookieFunc
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
@ -29,12 +29,14 @@ type Handler struct {
TokenAuth bool
ClientName string
Remote net.Conn
State int
}
func NewHandler(in transport.Transport, out transport.Transport) *Handler {
h := &Handler{
TransportIn: in,
TransportIn: in,
TransportOut: out,
State: SERVER_STATE_INITIAL,
}
return h
}
@ -49,35 +51,61 @@ func (h *Handler) Process() error {
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
if h.State != SERVER_STATE_INITIAL {
log.Printf("Handshake attempted while in wrong state %d != %d", h.State, SERVER_STATE_INITIAL)
return errors.New("wrong state")
}
major, minor, _, auth := readHandshake(pkt)
msg := h.handshakeResponse(major, minor, auth)
h.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create")
if h.State != SERVER_STATE_HANDSHAKE {
log.Printf("Tunnel create attempted while in wrong state %d != %d",
h.State, SERVER_STATE_HANDSHAKE)
return errors.New("wrong state")
}
_, cookie := readCreateTunnelRequest(pkt)
if h.VerifyPAACookieFunc != nil {
if ok, _ := h.VerifyPAACookieFunc(cookie); ok == false {
if ok, _ := h.VerifyPAACookieFunc(cookie); !ok {
log.Printf("Invalid PAA cookie: %s", cookie)
return errors.New("invalid PAA cookie")
}
}
msg := createTunnelResponse()
h.TransportOut.WritePacket(msg)
log.Printf("Tunnel done")
h.State = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth")
h.readTunnelAuthRequest(pkt)
if h.State != SERVER_STATE_TUNNEL_CREATE {
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
h.State, SERVER_STATE_TUNNEL_CREATE)
return errors.New("wrong state")
}
client := h.readTunnelAuthRequest(pkt)
if ok, _ := h.VerifyTunnelAuthFunc(client); !ok {
log.Printf("Invalid client name: %s", client)
return errors.New("invalid client name")
}
msg := h.createTunnelAuthResponse()
h.TransportOut.WritePacket(msg)
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE:
if h.State != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d",
h.State, SERVER_STATE_TUNNEL_AUTHORIZE)
return errors.New("wrong state")
}
server, port := readChannelCreateRequest(pkt)
log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
h.Remote, err = net.DialTimeout(
"tcp",
net.JoinHostPort(server, strconv.Itoa(int(port))),
time.Second*15)
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if h.VerifyServerFunc != nil {
if ok, _ := h.VerifyServerFunc(host); !ok {
log.Printf("Not allowed to connect to %s by policy handler", host)
}
}
log.Printf("Establishing connection to RDP server: %s", host)
h.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
log.Printf("Error connecting to %s, %s", host, err)
return err
}
log.Printf("Connection established")
@ -87,14 +115,31 @@ func (h *Handler) Process() error {
// Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually
go h.sendDataPacket()
h.State = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA:
if h.State != SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
h.State = SERVER_STATE_OPENED
h.forwardDataPacket(pkt)
case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet
if h.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", h.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
// avoid concurrency issues
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
if h.State != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
return errors.New("wrong state")
}
h.TransportIn.Close()
h.TransportOut.Close()
h.State = SERVER_STATE_CLOSED
default:
log.Printf("Unknown packet (size %d): %x", sz, pkt)
}
@ -223,7 +268,7 @@ func createTunnelResponse() []byte {
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
}
func (h *Handler) readTunnelAuthRequest(data []byte) {
func (h *Handler) readTunnelAuthRequest(data []byte) string {
buf := bytes.NewReader(data)
var size uint16
@ -231,7 +276,8 @@ func (h *Handler) readTunnelAuthRequest(data []byte) {
clData := make([]byte, size)
binary.Read(buf, binary.LittleEndian, &clData)
clientName, _ := DecodeUTF16(clData)
log.Printf("Client: %s", clientName)
return clientName
}
func (h *Handler) createTunnelAuthResponse() []byte {

View file

@ -58,3 +58,12 @@ const (
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
)
const (
SERVER_STATE_INITIAL = 0x0
SERVER_STATE_HANDSHAKE = 0x1
SERVER_STATE_TUNNEL_CREATE = 0x2
SERVER_STATE_TUNNEL_AUTHORIZE = 0x3
SERVER_STATE_CHANNEL_CREATE = 0x4
SERVER_STATE_OPENED = 0x5
SERVER_STATE_CLOSED = 0x6
)

View file

@ -79,4 +79,4 @@ func (t *LegacyPKT) SendAccept(doSeed bool) {
func (t *LegacyPKT) Drain() {
p := make([]byte, 32767)
t.Conn.Read(p)
}
}