mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-24 11:28:24 +02:00
more cleanups
This commit is contained in:
parent
9209f9152d
commit
46be8de038
2 changed files with 22 additions and 36 deletions
56
rdg.go
56
rdg.go
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
@ -15,7 +14,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
|
@ -117,8 +115,8 @@ type RdgSession struct {
|
|||
ConnId string
|
||||
CorrelationId string
|
||||
UserId string
|
||||
TransportIn transport.HttpLayer
|
||||
TransportOut transport.HttpLayer
|
||||
TransportIn transport.Transport
|
||||
TransportOut transport.Transport
|
||||
StateIn int
|
||||
StateOut int
|
||||
Remote net.Conn
|
||||
|
@ -133,21 +131,6 @@ var ErrNotHijacker = RejectConnectionError(
|
|||
|
||||
var DefaultSession RdgSession
|
||||
|
||||
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
|
||||
log.Print("Accept connection")
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if ok {
|
||||
return hj.Hijack()
|
||||
} else {
|
||||
err = ErrNotHijacker
|
||||
}
|
||||
if err != nil {
|
||||
httpError(w, err.Error(), http.StatusInternalServerError)
|
||||
return nil, nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
|
@ -172,7 +155,7 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||
func handleWebsocketProtocol(c *websocket.Conn) {
|
||||
fragment := false
|
||||
buf := make([]byte, 4096)
|
||||
index := 0
|
||||
|
@ -182,9 +165,11 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
|||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
||||
inout, _ := transport.NewWS(c)
|
||||
|
||||
var host string
|
||||
for {
|
||||
mt, msg, err := conn.ReadMessage()
|
||||
_, msg, err := inout.ReadPacket()
|
||||
if err != nil {
|
||||
log.Printf("Error read: %s", err)
|
||||
break
|
||||
|
@ -216,28 +201,29 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
|||
major, minor, _, auth := readHandshake(pkt)
|
||||
msg := handshakeResponse(major, minor, auth)
|
||||
log.Printf("Handshake response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
inout.WritePacket(msg)
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
_, cookie := readCreateTunnelRequest(pkt)
|
||||
data, found := tokens.Get(cookie)
|
||||
readCreateTunnelRequest(pkt)
|
||||
/*data, found := tokens.Get(cookie)
|
||||
if found == false {
|
||||
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
||||
log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
}*/
|
||||
host = conf.Server.HostTemplate
|
||||
/*
|
||||
for k, v := range data.(map[string]interface{}) {
|
||||
if val, ok := v.(string); ok == true {
|
||||
host = strings.Replace(host, "{{ " + k + " }}", val, 1)
|
||||
}
|
||||
}
|
||||
}*/
|
||||
msg := createTunnelResponse()
|
||||
log.Printf("Create tunnel response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
inout.WritePacket(msg)
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
readTunnelAuthRequest(pkt)
|
||||
msg := createTunnelAuthResponse()
|
||||
log.Printf("Create tunnel auth response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
inout.WritePacket(msg)
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
server, port := readChannelCreateRequest(pkt)
|
||||
if conf.Server.EnableOverride == true {
|
||||
|
@ -256,13 +242,13 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
|||
log.Printf("Connection established")
|
||||
msg := createChannelCreateResponse()
|
||||
log.Printf("Create channel create response: %x", msg)
|
||||
conn.WriteMessage(mt, msg)
|
||||
go handleWebsocketData(remote, mt, conn)
|
||||
inout.WritePacket(msg)
|
||||
go sendDataPacket(remote, inout)
|
||||
case PKT_TYPE_DATA:
|
||||
forwardDataPacket(remote, pkt)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// do not write to make sure we do not create concurrency issues
|
||||
// conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
// inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
break
|
||||
default:
|
||||
|
@ -612,7 +598,7 @@ func forwardDataPacket(conn net.Conn, data []byte) {
|
|||
conn.Write(pkt)
|
||||
}
|
||||
|
||||
func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
|
||||
func handleWebsocketData(rdp net.Conn, conn transport.Transport) {
|
||||
defer rdp.Close()
|
||||
b1 := new(bytes.Buffer)
|
||||
buf := make([]byte, 4086)
|
||||
|
@ -625,12 +611,12 @@ func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
|
|||
break
|
||||
}
|
||||
b1.Write(buf[:n])
|
||||
conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func sendDataPacket(connIn net.Conn, connOut transport.HttpLayer) {
|
||||
func sendDataPacket(connIn net.Conn, connOut transport.Transport) {
|
||||
defer connIn.Close()
|
||||
b1 := new(bytes.Buffer)
|
||||
buf := make([]byte, 4086)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package transport
|
||||
|
||||
type HttpLayer interface {
|
||||
type Transport interface {
|
||||
ReadPacket() (n int, p []byte, err error)
|
||||
WritePacket(b []byte) (n int, err error)
|
||||
Close() error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue