mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-03 16:21:54 +02:00
Factor websocket handling out
This commit is contained in:
parent
80d11598ec
commit
60c5a76350
1 changed files with 88 additions and 80 deletions
168
rdg.go
168
rdg.go
|
@ -6,18 +6,18 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
//"net/http/httputil"
|
||||
"strconv"
|
||||
|
||||
"time"
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -110,7 +110,7 @@ var ErrNotHijacker = RejectConnectionError(
|
|||
var DefaultSession RdgSession
|
||||
|
||||
func Upgrade(next http.Handler) http.Handler {
|
||||
return RdgHandshake(next)
|
||||
return handleGatewayProtocol(next)
|
||||
}
|
||||
|
||||
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
|
||||
|
@ -129,10 +129,9 @@ func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err err
|
|||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
func RdgHandshake(next http.Handler) http.Handler {
|
||||
c := cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
func handleGatewayProtocol(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var s RdgSession
|
||||
|
||||
|
@ -150,88 +149,21 @@ func RdgHandshake(next http.Handler) http.Handler {
|
|||
|
||||
if r.Method == MethodRDGOUT {
|
||||
r.Method = "GET" // force
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
defer conn.Close()
|
||||
|
||||
fragment := false
|
||||
buf := make([]byte, 4096)
|
||||
index := 0
|
||||
for {
|
||||
mt, msg, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("Error read: %s", err)
|
||||
break
|
||||
}
|
||||
log.Printf("Message type: %d, message: %x", mt, msg)
|
||||
handleWebsocketProtocol(conn)
|
||||
|
||||
// check for fragments
|
||||
var pt uint16
|
||||
var sz uint32
|
||||
var pkt []byte
|
||||
|
||||
if !fragment {
|
||||
pt, sz, pkt, err = readHeader(msg)
|
||||
if err != nil {
|
||||
// fragment received
|
||||
log.Printf("Received non websocket fragment")
|
||||
fragment = true
|
||||
index = copy(buf, msg)
|
||||
continue
|
||||
}
|
||||
index = 0
|
||||
} else {
|
||||
log.Printf("Dealing with fragment")
|
||||
fragment = false
|
||||
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
|
||||
}
|
||||
//conn, rw, _ := Accept(w)
|
||||
//log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
|
||||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
major, minor, _, auth := readHandshake(pkt)
|
||||
msg := handshakeResponse(major, minor, auth)
|
||||
log.Printf("Handshake response: %x", msg)
|
||||
c.WriteMessage(mt, msg)
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
readCreateTunnelRequest(pkt)
|
||||
msg := createTunnelResponse()
|
||||
log.Printf("Create tunnel response: %x", msg)
|
||||
c.WriteMessage(mt, msg)
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
readTunnelAuthRequest(pkt)
|
||||
msg := createTunnelAuthResponse()
|
||||
log.Printf("Create tunnel auth response: %x", msg)
|
||||
c.WriteMessage(mt, msg)
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
server, port := readChannelCreateRequest(pkt)
|
||||
s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %s, %d, %s", server, port, err)
|
||||
return
|
||||
}
|
||||
msg := createChannelCreateResponse()
|
||||
log.Printf("Create channel create response: %x", msg)
|
||||
c.WriteMessage(mt, msg)
|
||||
go handleWebsocketData(s.Remote, mt, c)
|
||||
case PKT_TYPE_DATA:
|
||||
forwardDataPacket(s.Remote, pkt)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
c.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
s.Remote.Close()
|
||||
return
|
||||
default:
|
||||
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz)
|
||||
}
|
||||
}
|
||||
conn, rw, _ := Accept(w)
|
||||
log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
|
||||
|
||||
s.ConnOut = conn
|
||||
WriteAcceptSeed(rw.Writer, true)
|
||||
//s.ConnOut = conn
|
||||
//WriteAcceptSeed(rw.Writer, true)
|
||||
|
||||
//c.Set(connId, s, cache.DefaultExpiration)
|
||||
} /*else if r.Method == MethodRDGIN {
|
||||
|
@ -303,6 +235,82 @@ func RdgHandshake(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||
fragment := false
|
||||
buf := make([]byte, 4096)
|
||||
index := 0
|
||||
|
||||
var remote net.Conn
|
||||
|
||||
for {
|
||||
mt, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("Error read: %s", err)
|
||||
break
|
||||
}
|
||||
log.Printf("Message type: %d, message: %x", mt, msg)
|
||||
|
||||
// check for fragments
|
||||
var pt uint16
|
||||
var sz uint32
|
||||
var pkt []byte
|
||||
|
||||
if !fragment {
|
||||
pt, sz, pkt, err = readHeader(msg)
|
||||
if err != nil {
|
||||
// fragment received
|
||||
log.Printf("Received non websocket fragment")
|
||||
fragment = true
|
||||
index = copy(buf, msg)
|
||||
continue
|
||||
}
|
||||
index = 0
|
||||
} else {
|
||||
log.Printf("Dealing with fragment")
|
||||
fragment = false
|
||||
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
|
||||
}
|
||||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
major, minor, _, auth := readHandshake(pkt)
|
||||
msg := handshakeResponse(major, minor, auth)
|
||||
log.Printf("Handshake response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
readCreateTunnelRequest(pkt)
|
||||
msg := createTunnelResponse()
|
||||
log.Printf("Create tunnel response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
readTunnelAuthRequest(pkt)
|
||||
msg := createTunnelAuthResponse()
|
||||
log.Printf("Create tunnel auth response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
server, port := readChannelCreateRequest(pkt)
|
||||
remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %s, %d, %s", server, port, err)
|
||||
return
|
||||
}
|
||||
msg := createChannelCreateResponse()
|
||||
log.Printf("Create channel create response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
go handleWebsocketData(remote, mt, conn)
|
||||
case PKT_TYPE_DATA:
|
||||
forwardDataPacket(remote, pkt)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
remote.Close()
|
||||
return
|
||||
default:
|
||||
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
|
||||
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
|
||||
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue