Factor websocket handling out

This commit is contained in:
Bolke de Bruin 2020-07-12 21:15:38 +02:00
parent 80d11598ec
commit 60c5a76350

168
rdg.go
View file

@ -6,18 +6,18 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"io" "io"
"log" "log"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
//"net/http/httputil"
"strconv" "strconv"
"time" "time"
"unicode/utf16" "unicode/utf16"
"unicode/utf8" "unicode/utf8"
"github.com/gorilla/websocket"
) )
const ( const (
@ -110,7 +110,7 @@ var ErrNotHijacker = RejectConnectionError(
var DefaultSession RdgSession var DefaultSession RdgSession
func Upgrade(next http.Handler) http.Handler { func Upgrade(next http.Handler) http.Handler {
return RdgHandshake(next) return handleGatewayProtocol(next)
} }
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) { func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
@ -129,10 +129,9 @@ func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err err
} }
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute)
func RdgHandshake(next http.Handler) http.Handler { func handleGatewayProtocol(next http.Handler) http.Handler {
c := cache.New(5*time.Minute, 10*time.Minute)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var s RdgSession var s RdgSession
@ -150,88 +149,21 @@ func RdgHandshake(next http.Handler) http.Handler {
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
r.Method = "GET" // force r.Method = "GET" // force
c, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %s", err) log.Printf("Cannot upgrade falling back to old protocol: %s", err)
return return
} }
defer c.Close() defer conn.Close()
fragment := false handleWebsocketProtocol(conn)
buf := make([]byte, 4096)
index := 0
for {
mt, msg, err := c.ReadMessage()
if err != nil {
log.Printf("Error read: %s", err)
break
}
log.Printf("Message type: %d, message: %x", mt, msg)
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment { //conn, rw, _ := Accept(w)
pt, sz, pkt, err = readHeader(msg) //log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
if err != nil {
// fragment received
log.Printf("Received non websocket fragment")
fragment = true
index = copy(buf, msg)
continue
}
index = 0
} else {
log.Printf("Dealing with fragment")
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
}
switch pt { //s.ConnOut = conn
case PKT_TYPE_HANDSHAKE_REQUEST: //WriteAcceptSeed(rw.Writer, true)
major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth)
log.Printf("Handshake response: %x", msg)
c.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(pkt)
msg := createTunnelResponse()
log.Printf("Create tunnel response: %x", msg)
c.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse()
log.Printf("Create tunnel auth response: %x", msg)
c.WriteMessage(mt, msg)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt)
s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
return
}
msg := createChannelCreateResponse()
log.Printf("Create channel create response: %x", msg)
c.WriteMessage(mt, msg)
go handleWebsocketData(s.Remote, mt, c)
case PKT_TYPE_DATA:
forwardDataPacket(s.Remote, pkt)
case PKT_TYPE_KEEPALIVE:
c.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
s.Remote.Close()
return
default:
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz)
}
}
conn, rw, _ := Accept(w)
log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
s.ConnOut = conn
WriteAcceptSeed(rw.Writer, true)
//c.Set(connId, s, cache.DefaultExpiration) //c.Set(connId, s, cache.DefaultExpiration)
} /*else if r.Method == MethodRDGIN { } /*else if r.Method == MethodRDGIN {
@ -303,6 +235,82 @@ func RdgHandshake(next http.Handler) http.Handler {
}) })
} }
func handleWebsocketProtocol(conn *websocket.Conn) {
fragment := false
buf := make([]byte, 4096)
index := 0
var remote net.Conn
for {
mt, msg, err := conn.ReadMessage()
if err != nil {
log.Printf("Error read: %s", err)
break
}
log.Printf("Message type: %d, message: %x", mt, msg)
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg)
if err != nil {
// fragment received
log.Printf("Received non websocket fragment")
fragment = true
index = copy(buf, msg)
continue
}
index = 0
} else {
log.Printf("Dealing with fragment")
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth)
log.Printf("Handshake response: %x", msg)
conn.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(pkt)
msg := createTunnelResponse()
log.Printf("Create tunnel response: %x", msg)
conn.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse()
log.Printf("Create tunnel auth response: %x", msg)
conn.WriteMessage(mt, msg)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt)
remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
return
}
msg := createChannelCreateResponse()
log.Printf("Create channel create response: %x", msg)
conn.WriteMessage(mt, msg)
go handleWebsocketData(remote, mt, conn)
case PKT_TYPE_DATA:
forwardDataPacket(remote, pkt)
case PKT_TYPE_KEEPALIVE:
conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
remote.Close()
return
default:
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz)
}
}
}
// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0 // [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes). // The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does // This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does