more cleanups

This commit is contained in:
Bolke de Bruin 2020-07-20 12:52:46 +02:00
parent 9209f9152d
commit 46be8de038
2 changed files with 22 additions and 36 deletions

56
rdg.go
View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -15,7 +14,6 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"unicode/utf16" "unicode/utf16"
"unicode/utf8" "unicode/utf8"
@ -117,8 +115,8 @@ type RdgSession struct {
ConnId string ConnId string
CorrelationId string CorrelationId string
UserId string UserId string
TransportIn transport.HttpLayer TransportIn transport.Transport
TransportOut transport.HttpLayer TransportOut transport.Transport
StateIn int StateIn int
StateOut int StateOut int
Remote net.Conn Remote net.Conn
@ -133,21 +131,6 @@ var ErrNotHijacker = RejectConnectionError(
var DefaultSession RdgSession var DefaultSession RdgSession
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
log.Print("Accept connection")
hj, ok := w.(http.Hijacker)
if ok {
return hj.Hijack()
} else {
err = ErrNotHijacker
}
if err != nil {
httpError(w, err.Error(), http.StatusInternalServerError)
return nil, nil, err
}
return
}
var upgrader = websocket.Upgrader{} var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute) var c = cache.New(5*time.Minute, 10*time.Minute)
@ -172,7 +155,7 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
} }
} }
func handleWebsocketProtocol(conn *websocket.Conn) { func handleWebsocketProtocol(c *websocket.Conn) {
fragment := false fragment := false
buf := make([]byte, 4096) buf := make([]byte, 4096)
index := 0 index := 0
@ -182,9 +165,11 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
websocketConnections.Inc() websocketConnections.Inc()
defer websocketConnections.Dec() defer websocketConnections.Dec()
inout, _ := transport.NewWS(c)
var host string var host string
for { for {
mt, msg, err := conn.ReadMessage() _, msg, err := inout.ReadPacket()
if err != nil { if err != nil {
log.Printf("Error read: %s", err) log.Printf("Error read: %s", err)
break break
@ -216,28 +201,29 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
major, minor, _, auth := readHandshake(pkt) major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth) msg := handshakeResponse(major, minor, auth)
log.Printf("Handshake response: %x", msg) log.Printf("Handshake response: %x", msg)
conn.WriteMessage(mt, msg) inout.WritePacket(msg)
case PKT_TYPE_TUNNEL_CREATE: case PKT_TYPE_TUNNEL_CREATE:
_, cookie := readCreateTunnelRequest(pkt) readCreateTunnelRequest(pkt)
data, found := tokens.Get(cookie) /*data, found := tokens.Get(cookie)
if found == false { if found == false {
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr())
return return
} }*/
host = conf.Server.HostTemplate host = conf.Server.HostTemplate
/*
for k, v := range data.(map[string]interface{}) { for k, v := range data.(map[string]interface{}) {
if val, ok := v.(string); ok == true { if val, ok := v.(string); ok == true {
host = strings.Replace(host, "{{ " + k + " }}", val, 1) host = strings.Replace(host, "{{ " + k + " }}", val, 1)
} }
} }*/
msg := createTunnelResponse() msg := createTunnelResponse()
log.Printf("Create tunnel response: %x", msg) log.Printf("Create tunnel response: %x", msg)
conn.WriteMessage(mt, msg) inout.WritePacket(msg)
case PKT_TYPE_TUNNEL_AUTH: case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(pkt) readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse() msg := createTunnelAuthResponse()
log.Printf("Create tunnel auth response: %x", msg) log.Printf("Create tunnel auth response: %x", msg)
conn.WriteMessage(mt, msg) inout.WritePacket(msg)
case PKT_TYPE_CHANNEL_CREATE: case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt) server, port := readChannelCreateRequest(pkt)
if conf.Server.EnableOverride == true { if conf.Server.EnableOverride == true {
@ -256,13 +242,13 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
log.Printf("Connection established") log.Printf("Connection established")
msg := createChannelCreateResponse() msg := createChannelCreateResponse()
log.Printf("Create channel create response: %x", msg) log.Printf("Create channel create response: %x", msg)
conn.WriteMessage(mt, msg) inout.WritePacket(msg)
go handleWebsocketData(remote, mt, conn) go sendDataPacket(remote, inout)
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
forwardDataPacket(remote, pkt) forwardDataPacket(remote, pkt)
case PKT_TYPE_KEEPALIVE: case PKT_TYPE_KEEPALIVE:
// do not write to make sure we do not create concurrency issues // do not write to make sure we do not create concurrency issues
// conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
break break
default: default:
@ -612,7 +598,7 @@ func forwardDataPacket(conn net.Conn, data []byte) {
conn.Write(pkt) conn.Write(pkt)
} }
func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) { func handleWebsocketData(rdp net.Conn, conn transport.Transport) {
defer rdp.Close() defer rdp.Close()
b1 := new(bytes.Buffer) b1 := new(bytes.Buffer)
buf := make([]byte, 4086) buf := make([]byte, 4086)
@ -625,12 +611,12 @@ func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
break break
} }
b1.Write(buf[:n]) b1.Write(buf[:n])
conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes())) conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset() b1.Reset()
} }
} }
func sendDataPacket(connIn net.Conn, connOut transport.HttpLayer) { func sendDataPacket(connIn net.Conn, connOut transport.Transport) {
defer connIn.Close() defer connIn.Close()
b1 := new(bytes.Buffer) b1 := new(bytes.Buffer)
buf := make([]byte, 4086) buf := make([]byte, 4086)

View file

@ -1,6 +1,6 @@
package transport package transport
type HttpLayer interface { type Transport interface {
ReadPacket() (n int, p []byte, err error) ReadPacket() (n int, p []byte, err error)
WritePacket(b []byte) (n int, err error) WritePacket(b []byte) (n int, err error)
Close() error Close() error