Working websockets

This commit is contained in:
Bolke de Bruin 2020-07-09 10:15:27 +02:00
parent 858159632a
commit 80d11598ec
3 changed files with 345 additions and 344 deletions

458
rdg.go
View file

@ -4,22 +4,25 @@ import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/patrickmn/go-cache"
"io"
"log"
"math/rand"
"net"
"net/http"
//"net/http/httputil"
"strconv"
"time"
"unicode/utf16"
"unicode/utf8"
"github.com/gorilla/websocket"
)
const (
crlf = "\r\n"
crlf = "\r\n"
rdgConnectionIdKey = "Rdg-Connection-Id"
HANDSHAKE = 1
)
const (
@ -40,6 +43,13 @@ const (
PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
)
const (
HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID = 0x01
HTTP_TUNNEL_RESPONSE_FIELD_CAPS = 0x02
HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ = 0x04
HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
)
const (
HTTP_EXTENDED_AUTH_NONE = 0x0
HTTP_EXTENDED_AUTH_SC = 0x1 /* Smart card authentication. */
@ -47,6 +57,28 @@ const (
HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
)
const (
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS = 0x01
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
)
const (
HTTP_TUNNEL_REDIR_ENABLE_ALL = 0x80000000
HTTP_TUNNEL_REDIR_DISABLE_ALL = 0x40000000
HTTP_TUNNEL_REDIR_DISABLE_DRIVE = 0x01
HTTP_TUNNEL_REDIR_DISABLE_PRINTER = 0x02
HTTP_TUNNEL_REDIR_DISABLE_PORT = 0x03
HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
HTTP_TUNNEL_REDIR_DISABLE_PNP = 0x10
)
const (
HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID = 0x01
HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT = 0x04
)
const (
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
)
@ -63,10 +95,9 @@ type RdgSession struct {
UserId string
ConnIn net.Conn
ConnOut net.Conn
BufOut *bufio.Writer
BufIn *bufio.Reader
State int
Remote net.Conn
StateIn int
StateOut int
Remote net.Conn
}
// ErrNotHijacker is an error returned when http.ResponseWriter does not
@ -79,89 +110,196 @@ var ErrNotHijacker = RejectConnectionError(
var DefaultSession RdgSession
func Upgrade(next http.Handler) http.Handler {
return DefaultSession.RdgHandshake(next)
return RdgHandshake(next)
}
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
log.Print("Accept connection")
hj, ok := w.(http.Hijacker)
if ok {
return hj.Hijack()
} else {
err = ErrNotHijacker
}
if err != nil {
httpError(w, err.Error(), http.StatusInternalServerError)
return nil, nil, err
}
return
log.Print("Accept connection")
hj, ok := w.(http.Hijacker)
if ok {
return hj.Hijack()
} else {
err = ErrNotHijacker
}
if err != nil {
httpError(w, err.Error(), http.StatusInternalServerError)
return nil, nil, err
}
return
}
func (s RdgSession) RdgHandshake(next http.Handler) http.Handler {
var upgrader = websocket.Upgrader{}
func RdgHandshake(next http.Handler) http.Handler {
c := cache.New(5*time.Minute, 10*time.Minute)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
/*_, _, ok := r.BasicAuth()
var s RdgSession
if !ok && s.ConnIn == nil {
w.Header().Set("WWW-Authenticate", `Basic realm="rdpgw"`)
w.WriteHeader(401)
w.Write([]byte("Unauthorized.\n"))
fmt.Println("Unauthorized")
return
}*/
connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId)
if !found {
log.Printf("No cached session found")
s = RdgSession{ConnId: connId, StateIn: 0, StateOut: 0}
} else {
log.Printf("Found cached session")
s = x.(RdgSession)
}
log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil)
conn, rw, _ := Accept(w)
if r.Method == MethodRDGOUT {
log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
s.ConnId = r.Header.Get(rdgConnectionIdKey)
s.ConnOut = conn
s.BufOut = rw.Writer
WriteAcceptSeed(rw.Writer)
rw.Writer.Flush()
} else if r.Method == MethodRDGIN {
if s.ConnIn == nil {
defer conn.Close()
s.ConnIn = conn
s.BufIn = rw.Reader
log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
WriteAcceptSeed(rw.Writer)
rw.Writer.Flush()
p := make([]byte, 4096)
rw.Reader.Read(p)
//log.Printf("Read %q", p)
r.Method = "GET" // force
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
return
}
defer c.Close()
log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
scanner := bufio.NewScanner(rw.Reader)
scanner.Split(ReadPacket)
for scanner.Scan() {
packet := scanner.Bytes()
packetType, size, _, packet := readHeader(packet)
log.Printf("Scanned packet got packet type %x size %d", packetType, size)
switch packetType {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(packet)
sendHandshakeResponse(s.BufOut, major, minor, auth)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(packet)
sendCreateTunnelResponse(s.BufOut)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(packet)
sendTunnelAuthResponse(s.BufOut)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(packet)
var err error
s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
return
}
sendChannelCreateResponse(s.BufOut)
go sendDataPacket(s.Remote, s.BufOut)
case PKT_TYPE_DATA:
receiveDataPacket(s.Remote, packet)
fragment := false
buf := make([]byte, 4096)
index := 0
for {
mt, msg, err := c.ReadMessage()
if err != nil {
log.Printf("Error read: %s", err)
break
}
log.Printf("Message type: %d, message: %x", mt, msg)
// check for fragments
var pt uint16
var sz uint32
var pkt []byte
if !fragment {
pt, sz, pkt, err = readHeader(msg)
if err != nil {
// fragment received
log.Printf("Received non websocket fragment")
fragment = true
index = copy(buf, msg)
continue
}
index = 0
} else {
log.Printf("Dealing with fragment")
fragment = false
pt, sz, pkt, _ = readHeader(append(buf[:index], msg...))
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth)
log.Printf("Handshake response: %x", msg)
c.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(pkt)
msg := createTunnelResponse()
log.Printf("Create tunnel response: %x", msg)
c.WriteMessage(mt, msg)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse()
log.Printf("Create tunnel auth response: %x", msg)
c.WriteMessage(mt, msg)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt)
s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
return
}
msg := createChannelCreateResponse()
log.Printf("Create channel create response: %x", msg)
c.WriteMessage(mt, msg)
go handleWebsocketData(s.Remote, mt, c)
case PKT_TYPE_DATA:
forwardDataPacket(s.Remote, pkt)
case PKT_TYPE_KEEPALIVE:
c.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
s.Remote.Close()
return
default:
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz)
}
}
conn, rw, _ := Accept(w)
log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String())
s.ConnOut = conn
WriteAcceptSeed(rw.Writer, true)
//c.Set(connId, s, cache.DefaultExpiration)
} /*else if r.Method == MethodRDGIN {
if !checkNTLMAuth(w, &s, "IN") {
c.Set(connId, s, cache.DefaultExpiration)
return
}
conn, rw, _ := Accept(w)
if s.ConnIn == nil {
defer conn.Close()
s.ConnIn = conn
c.Set(connId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
WriteAcceptSeed(rw.Writer, false)
p := make([]byte, 32767)
rw.Reader.Read(p)
//log.Printf("Read %q", p)
log.Printf("Reading packet from client %s", conn.RemoteAddr().String())
chunkScanner := httputil.NewChunkedReader(rw.Reader)
packet := make([]byte, 4096) // bufio.defaultBufSize
for {
n, err := chunkScanner.Read(packet)
if err == io.EOF || n == 0 {
break
}
old_packet := packet
packetType, size, _, packet := readHeader(packet)
log.Printf("Scanned packet got packet type %x size %d", packetType, size)
switch packetType {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(packet)
sendHandshakeResponse(s.ConnOut, major, minor, auth)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(packet)
sendCreateTunnelResponse(s.ConnOut)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(packet)
sendTunnelAuthResponse(s.ConnOut)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(packet)
var err error
s.Remote, err = net.Dial("tcp", net.JoinHostPort(server, strconv.Itoa(int(port))))
if err != nil {
log.Printf("Error connecting to %s, %d, %s", server, port, err)
return
}
sendChannelCreateResponse(s.ConnOut)
// Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually
go sendDataPacket(s.Remote, s.ConnOut)
case PKT_TYPE_DATA:
receiveDataPacket(s.Remote, packet)
case PKT_TYPE_KEEPALIVE:
s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
s.ConnIn.Close()
s.ConnOut.Close()
break
default:
log.Printf("UNKNOWN PACKET (%d): %x", n, old_packet[:n])
//receiveDataPacket(s.Remote, old_packet)
receiveUnknownPacket(s.Remote, old_packet, n)
}
}
}*/
})
}
@ -170,58 +308,50 @@ func (s RdgSession) RdgHandshake(next http.Handler) http.Handler {
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
// not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
// connection after the last byte is sent.
func WriteAcceptSeed(bw *bufio.Writer) {
func WriteAcceptSeed(bw *bufio.Writer, doSeed bool) {
bw.WriteString(HttpOK)
bw.WriteString("Server: Microsoft-HTTPAPI/2.0\r\n")
bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n")
bw.WriteString("Content-Type: application/octet-stream\r\n")
bw.WriteString("Content-Length: 0\r\n")
if !doSeed {
bw.WriteString("Content-Length: 0\r\n")
}
bw.WriteString(crlf)
seed := make([]byte, 10)
rand.Read(seed)
bw.Write(seed)
if doSeed {
seed := make([]byte, 10)
rand.Read(seed)
// docs say it's a seed but 2019 responds with ab cd * 5
bw.Write(seed)
}
bw.Flush()
}
func ReadPacket(data []byte, atEOF bool) (advance int, packet []byte, err error) {
log.Printf("Reading data len = %d", len(data))
if atEOF && len(data) == 0 {
return 0, nil, nil
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
if i := bytes.Index(data, []byte{'\r', '\n'}); i >= 0 {
//log.Printf("Got rn at %d ", i)
chunkSize, err := strconv.ParseInt(string(data[0:i]), 16, 0)
log.Printf("chunkSize %d", chunkSize)
if err != nil {
return i + 2, data[0:i], err
}
//log.Printf("Return %d", i+2+int(chunkSize)+2)
return i + 2 + int(chunkSize) + 2, data[i+2 : i+2+int(chunkSize)+2], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
}
func readHeader(data []byte) (packetType uint16, size uint32, advance int, remain []byte) {
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
return packetType, size, 8, data[8:]
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[8:], nil
}
func sendHandshakeResponse(w *bufio.Writer, major byte, minor byte, auth uint16) {
// Creates a packet the is a response to a handshake request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure
func handshakeResponse(major byte, minor byte, auth uint16) []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
buf.Write([]byte{major, minor})
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint16(2)) // PAA
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint16(HTTP_EXTENDED_AUTH_PAA|HTTP_EXTENDED_AUTH_SC)) // extended auth
w.Write(createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes()))
w.Flush()
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
}
func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
@ -235,7 +365,7 @@ func readHandshake(data []byte) (major byte, minor byte, version uint16, extAuth
return
}
func readCreateTunnelRequest(data []byte) (caps uint32, cookie string){
func readCreateTunnelRequest(data []byte) (caps uint32, cookie string) {
var fields uint16
r := bytes.NewReader(data)
@ -255,16 +385,21 @@ func readCreateTunnelRequest(data []byte) (caps uint32, cookie string){
return
}
func sendCreateTunnelResponse(w *bufio.Writer) {
func createTunnelResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
w.Write(createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes()))
w.Flush()
// tunnel id ?
binary.Write(buf, binary.LittleEndian, uint32(15))
// caps ?
binary.Write(buf, binary.LittleEndian, uint32(2))
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
}
func readTunnelAuthRequest(data []byte) {
@ -278,18 +413,21 @@ func readTunnelAuthRequest(data []byte) {
log.Printf("Client: %s", clientName)
}
func sendTunnelAuthResponse(w *bufio.Writer) {
func createTunnelAuthResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
w.Write(createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes()))
w.Flush()
// flags
binary.Write(buf, binary.LittleEndian, uint32(HTTP_TUNNEL_REDIR_ENABLE_ALL)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(0)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
}
func readChannelCreateRequest(data []byte) (server string, port uint16){
func readChannelCreateRequest(data []byte) (server string, port uint16) {
buf := bytes.NewReader(data)
var resourcesSize byte
@ -313,55 +451,55 @@ func readChannelCreateRequest(data []byte) (server string, port uint16){
return
}
func sendChannelCreateResponse(w *bufio.Writer) {
func createChannelCreateResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(0)) // fields present
//binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID | HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE | HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // fields
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
w.Write(createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes()))
w.Flush()
// optional fields
// channel id uint32 (4)
// udp port uint16 (2)
// udp auth cookie 1 byte for side channel
// length uint16
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
}
func createPacket(pktType uint16, data []byte) (packet []byte){
func createPacket(pktType uint16, data []byte) (packet []byte) {
size := len(data) + 8
buf := new(bytes.Buffer)
log.Printf("Data sent Size: %d", size)
// http chunk size in hex string
// fmt.Fprintf(buf,"%x\r\n", size)
binary.Write(buf, binary.LittleEndian, uint16(pktType))
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint32(size))
buf.Write(data)
// http close crlf
// buf.Write([]byte(crlf))
// log.Printf("data sent: %q", buf.Bytes())
return buf.Bytes()
}
func receiveDataPacket(conn net.Conn, data []byte) {
func forwardDataPacket(conn net.Conn, data []byte) {
buf := bytes.NewReader(data)
var cblen uint16
binary.Read(buf, binary.LittleEndian, &cblen)
log.Printf("Received PKT_DATA %d", cblen)
//log.Printf("Received PKT_DATA %d", cblen)
pkt := make([]byte, cblen)
//binary.Read(buf, binary.LittleEndian, &pkt)
buf.Read(pkt)
binary.Read(buf, binary.LittleEndian, &pkt)
//n, _ := buf.Read(pkt)
//log.Printf("CBLEN: %d, N: %d", cblen, n)
//log.Printf("DATA FROM CLIENT %q", pkt)
conn.Write(pkt)
}
func sendDataPacket(conn net.Conn, w *bufio.Writer) {
defer conn.Close()
func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
defer rdp.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 32767)
buf := make([]byte, 4086)
for {
n, err := conn.Read(buf)
n, err := rdp.Read(buf)
binary.Write(b1, binary.LittleEndian, uint16(n))
log.Printf("RDP SIZE: %d", n)
if err != nil {
@ -369,16 +507,32 @@ func sendDataPacket(conn net.Conn, w *bufio.Writer) {
break
}
b1.Write(buf[:n])
w.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
w.Flush()
conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
func sendDataPacket(connIn net.Conn, connOut net.Conn) {
defer connIn.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := connIn.Read(buf)
binary.Write(b1, binary.LittleEndian, uint16(n))
log.Printf("RDP SIZE: %d", n)
if err != nil {
log.Printf("Error reading from conn %s", err)
break
}
b1.Write(buf[:n])
connOut.Write(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
func DecodeUTF16(b []byte) (string, error) {
if len(b)%2 != 0 {
log.Printf("Error decoding utf16")
return "", fmt.Errorf("Must have even length byte slice")
return "", fmt.Errorf("must have even length byte slice")
}
u16s := make([]uint16, 1)
@ -394,4 +548,4 @@ func DecodeUTF16(b []byte) (string, error) {
}
return ret.String(), nil
}
}