Support fragmented packets for legacy protocol

This commit is contained in:
Bolke de Bruin 2020-07-14 09:44:24 +02:00
parent 47975e8f8c
commit 83aa49ad3b

66
rdg.go
View file

@ -154,10 +154,10 @@ var c = cache.New(5*time.Minute, 10*time.Minute)
func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
connectionCache.Set(float64(c.ItemCount()))
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
//if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r)
return
}
//}
r.Method = "GET" // force
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
@ -270,7 +270,11 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil)
if r.Method == MethodRDGOUT {
conn, rw, _ := Accept(w)
conn, rw, err := Accept(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
s.ConnOut = conn
@ -283,10 +287,18 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
var remote net.Conn
conn, rw, _ := Accept(w)
conn, rw, err := Accept(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
return
}
defer conn.Close()
if s.ConnIn == nil {
defer conn.Close()
fragment := false
index := 0
buf := make([]byte, 4096)
s.ConnIn = conn
c.Set(connId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
@ -296,33 +308,51 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
chunkScanner := httputil.NewChunkedReader(rw.Reader)
packet := make([]byte, 4096) // bufio.defaultBufSize
msg := make([]byte, 4096) // bufio.defaultBufSize
for {
n, err := chunkScanner.Read(packet)
n, err := chunkScanner.Read(msg)
if err == io.EOF || n == 0 {
break
}
packetType, size, packet, err := readHeader(packet)
if err != nil {
log.Printf("Need to deal with fragment %s", err)
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg[:n])
if err != nil {
// fragment received
log.Printf("Received non websocket fragment")
fragment = true
index = copy(buf, msg[:n])
continue
}
index = 0
} else {
log.Printf("Dealing with fragment")
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...))
}
log.Printf("Scanned packet got packet type %x size %d", packetType, size)
switch packetType {
log.Printf("Scanned packet got packet type %x size %d", pt, sz)
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(packet)
major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth)
s.ConnOut.Write(msg)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(packet)
readCreateTunnelRequest(pkt)
msg := createTunnelResponse()
s.ConnOut.Write(msg)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(packet)
readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse()
s.ConnOut.Write(msg)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(packet)
server, port := readChannelCreateRequest(pkt)
var err error
remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
if err != nil {
@ -336,7 +366,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
// might hang eventually
go sendDataPacket(remote, s.ConnOut)
case PKT_TYPE_DATA:
forwardDataPacket(remote, packet)
forwardDataPacket(remote, pkt)
case PKT_TYPE_KEEPALIVE:
// avoid concurrency issues
// s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
@ -345,7 +375,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
s.ConnOut.Close()
break
default:
log.Printf("Unknown packet (size %d): %x", n, packet)
log.Printf("Unknown packet (size %d): %x", sz, pkt[:n])
}
}
}