Normalize packet handling

This commit is contained in:
Bolke de Bruin 2020-07-20 14:29:24 +02:00
parent 46be8de038
commit 2f78a7fd8e
2 changed files with 84 additions and 72 deletions

73
protocol/handler.go Normal file
View file

@ -0,0 +1,73 @@
package protocol
import (
"bytes"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/transport"
"io"
)
type Handler struct {
Transport transport.Transport
}
func NewHandler(t transport.Transport) *Handler {
h := &Handler{
Transport: t,
}
return h
}
func (p *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := p.Transport.ReadPacket()
if err != nil {
return 0, 0, []byte{0,0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0,0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[8:], nil
}

83
rdg.go
View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/transport" "github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
@ -156,46 +157,21 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
} }
func handleWebsocketProtocol(c *websocket.Conn) { func handleWebsocketProtocol(c *websocket.Conn) {
fragment := false
buf := make([]byte, 4096)
index := 0
var remote net.Conn var remote net.Conn
websocketConnections.Inc() websocketConnections.Inc()
defer websocketConnections.Dec() defer websocketConnections.Dec()
inout, _ := transport.NewWS(c) inout, _ := transport.NewWS(c)
handler := protocol.NewHandler(inout)
var host string var host string
for { for {
_, msg, err := inout.ReadPacket() pt, sz, pkt, err := handler.ReadMessage()
if err != nil { if err != nil {
log.Printf("Error read: %s", err) log.Printf("Cannot read message from stream %s", err)
break return
} }
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg)
if err != nil {
// fragment received
// log.Printf("Received non websocket fragment")
fragment = true
index = copy(buf, msg)
continue
}
index = 0
} else {
//log.Printf("Dealing with fragment")
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
}
switch pt { switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt) major, minor, _, auth := readHandshake(pkt)
@ -301,10 +277,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
defer in.Close() defer in.Close()
if s.TransportIn == nil { if s.TransportIn == nil {
fragment := false
index := 0
buf := make([]byte, 4096)
s.TransportIn = in s.TransportIn = in
c.Set(connId, s, cache.DefaultExpiration) c.Set(connId, s, cache.DefaultExpiration)
@ -315,30 +287,12 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
in.Drain() in.Drain()
log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String()) log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String())
handler := protocol.NewHandler(in)
for { for {
n, msg, err := in.ReadPacket() pt, sz, pkt, err := handler.ReadMessage()
if err == io.EOF || n == 0 { if err != nil {
break log.Printf("Cannot read message from stream %s", err)
} return
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg[:n])
if err != nil {
// fragment received
fragment = true
index = copy(buf, msg[:n])
continue
}
index = 0
} else {
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...))
} }
switch pt { switch pt {
@ -386,28 +340,13 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
s.TransportOut.Close() s.TransportOut.Close()
break break
default: default:
log.Printf("Unknown packet (size %d): %x", sz, pkt[:n]) log.Printf("Unknown packet (size %d): %x", sz, pkt)
} }
} }
} }
} }
} }
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[8:], nil
}
// Creates a packet the is a response to a handshake request // Creates a packet the is a response to a handshake request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux // HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure // but could be in Windows. However the NTLM protocol is insecure