Support setting send/receive buffers for the IP sockets

This commit is contained in:
Bolke de Bruin 2020-09-11 22:01:40 +02:00
parent 8876b04466
commit 505eafdc1e
6 changed files with 89 additions and 0 deletions

View file

@ -2,13 +2,17 @@ package protocol
import (
"context"
"errors"
"github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"log"
"net"
"net/http"
"reflect"
"syscall"
"time"
)
@ -81,12 +85,76 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
}
defer conn.Close()
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
if err != nil {
log.Printf("Cannot set send/receive buffers: %s", err)
}
g.handleWebsocketProtocol(ctx, conn, s)
} else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
}
}
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 {
return nil
}
// conn == tls.Conn
ptr := reflect.ValueOf(conn)
val := reflect.Indirect(ptr)
if val.Kind() != reflect.Struct {
return errors.New("didn't get a struct from conn")
}
// this gets net.Conn -> *net.TCPConn -> net.TCPConn
ptrConn := val.FieldByName("conn")
valConn := reflect.Indirect(ptrConn)
if !valConn.IsValid() {
return errors.New("cannot find conn field")
}
valConn = valConn.Elem().Elem()
// net.FD
ptrNetFd := valConn.FieldByName("fd")
valNetFd := reflect.Indirect(ptrNetFd)
if !valNetFd.IsValid() {
return errors.New("cannot find fd field")
}
// pfd member
ptrPfd := valNetFd.FieldByName("pfd")
valPfd := reflect.Indirect(ptrPfd)
if !valPfd.IsValid() {
return errors.New("cannot find pfd field")
}
// finally the exported Sysfd
ptrSysFd := valPfd.FieldByName("Sysfd")
if !ptrSysFd.IsValid() {
return errors.New("cannot find Sysfd field")
}
fd := int(ptrSysFd.Int())
if g.ServerConf.ReceiveBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf)
if err != nil {
return wrapSyscallError("setsockopt", err)
}
}
if g.ServerConf.SendBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf)
if err != nil {
return wrapSyscallError("setsockopt", err)
}
}
return nil
}
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
websocketConnections.Inc()
defer websocketConnections.Dec()