mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-31 14:56:09 +02:00
Support fragmented packets for legacy protocol
This commit is contained in:
parent
47975e8f8c
commit
83aa49ad3b
1 changed files with 48 additions and 18 deletions
66
rdg.go
66
rdg.go
|
@ -154,10 +154,10 @@ var c = cache.New(5*time.Minute, 10*time.Minute)
|
|||
func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
//if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
handleLegacyProtocol(w, r)
|
||||
return
|
||||
}
|
||||
//}
|
||||
r.Method = "GET" // force
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
|
@ -270,7 +270,11 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
conn, rw, _ := Accept(w)
|
||||
conn, rw, err := Accept(w)
|
||||
if err != nil {
|
||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
|
||||
|
||||
s.ConnOut = conn
|
||||
|
@ -283,10 +287,18 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var remote net.Conn
|
||||
|
||||
conn, rw, _ := Accept(w)
|
||||
conn, rw, err := Accept(w)
|
||||
if err != nil {
|
||||
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if s.ConnIn == nil {
|
||||
defer conn.Close()
|
||||
fragment := false
|
||||
index := 0
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
s.ConnIn = conn
|
||||
c.Set(connId, s, cache.DefaultExpiration)
|
||||
log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
|
||||
|
@ -296,33 +308,51 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
|
||||
chunkScanner := httputil.NewChunkedReader(rw.Reader)
|
||||
packet := make([]byte, 4096) // bufio.defaultBufSize
|
||||
msg := make([]byte, 4096) // bufio.defaultBufSize
|
||||
|
||||
for {
|
||||
n, err := chunkScanner.Read(packet)
|
||||
n, err := chunkScanner.Read(msg)
|
||||
if err == io.EOF || n == 0 {
|
||||
break
|
||||
}
|
||||
packetType, size, packet, err := readHeader(packet)
|
||||
if err != nil {
|
||||
log.Printf("Need to deal with fragment %s", err)
|
||||
|
||||
// check for fragments
|
||||
var pt uint16
|
||||
var sz uint32
|
||||
var pkt []byte
|
||||
|
||||
if !fragment {
|
||||
pt, sz, pkt, err = readHeader(msg[:n])
|
||||
if err != nil {
|
||||
// fragment received
|
||||
log.Printf("Received non websocket fragment")
|
||||
fragment = true
|
||||
index = copy(buf, msg[:n])
|
||||
continue
|
||||
}
|
||||
index = 0
|
||||
} else {
|
||||
log.Printf("Dealing with fragment")
|
||||
fragment = false
|
||||
pt, sz, pkt, _ = readHeader(append(buf[:index], msg[:n]...))
|
||||
}
|
||||
log.Printf("Scanned packet got packet type %x size %d", packetType, size)
|
||||
switch packetType {
|
||||
|
||||
log.Printf("Scanned packet got packet type %x size %d", pt, sz)
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
major, minor, _, auth := readHandshake(packet)
|
||||
major, minor, _, auth := readHandshake(pkt)
|
||||
msg := handshakeResponse(major, minor, auth)
|
||||
s.ConnOut.Write(msg)
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
readCreateTunnelRequest(packet)
|
||||
readCreateTunnelRequest(pkt)
|
||||
msg := createTunnelResponse()
|
||||
s.ConnOut.Write(msg)
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
readTunnelAuthRequest(packet)
|
||||
readTunnelAuthRequest(pkt)
|
||||
msg := createTunnelAuthResponse()
|
||||
s.ConnOut.Write(msg)
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
server, port := readChannelCreateRequest(packet)
|
||||
server, port := readChannelCreateRequest(pkt)
|
||||
var err error
|
||||
remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
|
||||
if err != nil {
|
||||
|
@ -336,7 +366,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
// might hang eventually
|
||||
go sendDataPacket(remote, s.ConnOut)
|
||||
case PKT_TYPE_DATA:
|
||||
forwardDataPacket(remote, packet)
|
||||
forwardDataPacket(remote, pkt)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// avoid concurrency issues
|
||||
// s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
|
@ -345,7 +375,7 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
s.ConnOut.Close()
|
||||
break
|
||||
default:
|
||||
log.Printf("Unknown packet (size %d): %x", n, packet)
|
||||
log.Printf("Unknown packet (size %d): %x", sz, pkt[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue