mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-07-25 20:08:20 +02:00
more cleanups
This commit is contained in:
parent
9209f9152d
commit
46be8de038
2 changed files with 22 additions and 36 deletions
56
rdg.go
56
rdg.go
|
@ -1,7 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -15,7 +14,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -117,8 +115,8 @@ type RdgSession struct {
|
||||||
ConnId string
|
ConnId string
|
||||||
CorrelationId string
|
CorrelationId string
|
||||||
UserId string
|
UserId string
|
||||||
TransportIn transport.HttpLayer
|
TransportIn transport.Transport
|
||||||
TransportOut transport.HttpLayer
|
TransportOut transport.Transport
|
||||||
StateIn int
|
StateIn int
|
||||||
StateOut int
|
StateOut int
|
||||||
Remote net.Conn
|
Remote net.Conn
|
||||||
|
@ -133,21 +131,6 @@ var ErrNotHijacker = RejectConnectionError(
|
||||||
|
|
||||||
var DefaultSession RdgSession
|
var DefaultSession RdgSession
|
||||||
|
|
||||||
func Accept(w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) {
|
|
||||||
log.Print("Accept connection")
|
|
||||||
hj, ok := w.(http.Hijacker)
|
|
||||||
if ok {
|
|
||||||
return hj.Hijack()
|
|
||||||
} else {
|
|
||||||
err = ErrNotHijacker
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
httpError(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{}
|
var upgrader = websocket.Upgrader{}
|
||||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||||
|
|
||||||
|
@ -172,7 +155,7 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleWebsocketProtocol(conn *websocket.Conn) {
|
func handleWebsocketProtocol(c *websocket.Conn) {
|
||||||
fragment := false
|
fragment := false
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
index := 0
|
index := 0
|
||||||
|
@ -182,9 +165,11 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
websocketConnections.Inc()
|
websocketConnections.Inc()
|
||||||
defer websocketConnections.Dec()
|
defer websocketConnections.Dec()
|
||||||
|
|
||||||
|
inout, _ := transport.NewWS(c)
|
||||||
|
|
||||||
var host string
|
var host string
|
||||||
for {
|
for {
|
||||||
mt, msg, err := conn.ReadMessage()
|
_, msg, err := inout.ReadPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error read: %s", err)
|
log.Printf("Error read: %s", err)
|
||||||
break
|
break
|
||||||
|
@ -216,28 +201,29 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
major, minor, _, auth := readHandshake(pkt)
|
major, minor, _, auth := readHandshake(pkt)
|
||||||
msg := handshakeResponse(major, minor, auth)
|
msg := handshakeResponse(major, minor, auth)
|
||||||
log.Printf("Handshake response: %x", msg)
|
log.Printf("Handshake response: %x", msg)
|
||||||
conn.WriteMessage(mt, msg)
|
inout.WritePacket(msg)
|
||||||
case PKT_TYPE_TUNNEL_CREATE:
|
case PKT_TYPE_TUNNEL_CREATE:
|
||||||
_, cookie := readCreateTunnelRequest(pkt)
|
readCreateTunnelRequest(pkt)
|
||||||
data, found := tokens.Get(cookie)
|
/*data, found := tokens.Get(cookie)
|
||||||
if found == false {
|
if found == false {
|
||||||
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr())
|
log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr())
|
||||||
return
|
return
|
||||||
}
|
}*/
|
||||||
host = conf.Server.HostTemplate
|
host = conf.Server.HostTemplate
|
||||||
|
/*
|
||||||
for k, v := range data.(map[string]interface{}) {
|
for k, v := range data.(map[string]interface{}) {
|
||||||
if val, ok := v.(string); ok == true {
|
if val, ok := v.(string); ok == true {
|
||||||
host = strings.Replace(host, "{{ " + k + " }}", val, 1)
|
host = strings.Replace(host, "{{ " + k + " }}", val, 1)
|
||||||
}
|
}
|
||||||
}
|
}*/
|
||||||
msg := createTunnelResponse()
|
msg := createTunnelResponse()
|
||||||
log.Printf("Create tunnel response: %x", msg)
|
log.Printf("Create tunnel response: %x", msg)
|
||||||
conn.WriteMessage(mt, msg)
|
inout.WritePacket(msg)
|
||||||
case PKT_TYPE_TUNNEL_AUTH:
|
case PKT_TYPE_TUNNEL_AUTH:
|
||||||
readTunnelAuthRequest(pkt)
|
readTunnelAuthRequest(pkt)
|
||||||
msg := createTunnelAuthResponse()
|
msg := createTunnelAuthResponse()
|
||||||
log.Printf("Create tunnel auth response: %x", msg)
|
log.Printf("Create tunnel auth response: %x", msg)
|
||||||
conn.WriteMessage(mt, msg)
|
inout.WritePacket(msg)
|
||||||
case PKT_TYPE_CHANNEL_CREATE:
|
case PKT_TYPE_CHANNEL_CREATE:
|
||||||
server, port := readChannelCreateRequest(pkt)
|
server, port := readChannelCreateRequest(pkt)
|
||||||
if conf.Server.EnableOverride == true {
|
if conf.Server.EnableOverride == true {
|
||||||
|
@ -256,13 +242,13 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
|
||||||
log.Printf("Connection established")
|
log.Printf("Connection established")
|
||||||
msg := createChannelCreateResponse()
|
msg := createChannelCreateResponse()
|
||||||
log.Printf("Create channel create response: %x", msg)
|
log.Printf("Create channel create response: %x", msg)
|
||||||
conn.WriteMessage(mt, msg)
|
inout.WritePacket(msg)
|
||||||
go handleWebsocketData(remote, mt, conn)
|
go sendDataPacket(remote, inout)
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
forwardDataPacket(remote, pkt)
|
forwardDataPacket(remote, pkt)
|
||||||
case PKT_TYPE_KEEPALIVE:
|
case PKT_TYPE_KEEPALIVE:
|
||||||
// do not write to make sure we do not create concurrency issues
|
// do not write to make sure we do not create concurrency issues
|
||||||
// conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
// inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||||
case PKT_TYPE_CLOSE_CHANNEL:
|
case PKT_TYPE_CLOSE_CHANNEL:
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
|
@ -612,7 +598,7 @@ func forwardDataPacket(conn net.Conn, data []byte) {
|
||||||
conn.Write(pkt)
|
conn.Write(pkt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
|
func handleWebsocketData(rdp net.Conn, conn transport.Transport) {
|
||||||
defer rdp.Close()
|
defer rdp.Close()
|
||||||
b1 := new(bytes.Buffer)
|
b1 := new(bytes.Buffer)
|
||||||
buf := make([]byte, 4086)
|
buf := make([]byte, 4086)
|
||||||
|
@ -625,12 +611,12 @@ func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
b1.Write(buf[:n])
|
b1.Write(buf[:n])
|
||||||
conn.WriteMessage(mt, createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||||
b1.Reset()
|
b1.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendDataPacket(connIn net.Conn, connOut transport.HttpLayer) {
|
func sendDataPacket(connIn net.Conn, connOut transport.Transport) {
|
||||||
defer connIn.Close()
|
defer connIn.Close()
|
||||||
b1 := new(bytes.Buffer)
|
b1 := new(bytes.Buffer)
|
||||||
buf := make([]byte, 4086)
|
buf := make([]byte, 4086)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package transport
|
package transport
|
||||||
|
|
||||||
type HttpLayer interface {
|
type Transport interface {
|
||||||
ReadPacket() (n int, p []byte, err error)
|
ReadPacket() (n int, p []byte, err error)
|
||||||
WritePacket(b []byte) (n int, err error)
|
WritePacket(b []byte) (n int, err error)
|
||||||
Close() error
|
Close() error
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue