Support setting send/receive buffers for the IP sockets

This commit is contained in:
Bolke de Bruin 2020-09-11 22:01:40 +02:00
parent 8876b04466
commit 505eafdc1e
6 changed files with 89 additions and 0 deletions

View file

@ -59,6 +59,11 @@ server:
# make sure to share this across the different pods
sessionKey: thisisasessionkeyreplacethisjetzt
sessionEncryptionKey: thisisasessionkeyreplacethisnunu!
# tries to set the receive / send buffer of the connections to the client
# in case of high latency high bandwidth the defaults set by the OS might
# be to low for a good experience
# receiveBuf: 12582912
# sendBuf: 12582912
# Open ID Connect specific settings
openId:
providerUrl: http://keycloak/auth/realms/test

View file

@ -22,6 +22,8 @@ type ServerConfig struct {
RoundRobin bool
SessionKey string
SessionEncryptionKey string
SendBuf int
ReceiveBuf int
}
type OpenIDConfig struct {

View file

@ -128,6 +128,8 @@ func main() {
},
VerifyTunnelCreate: security.VerifyPAAToken,
VerifyServerFunc: security.VerifyServerFunc,
SendBuf: conf.Server.SendBuf,
ReceiveBuf: conf.Server.ReceiveBuf,
}
gw := protocol.Gateway{
ServerConf: &handlerConfig,

View file

@ -8,6 +8,8 @@ import (
"io"
"log"
"net"
"os"
"syscall"
)
type RedirectFlags struct {
@ -136,3 +138,11 @@ func receive(data []byte, out net.Conn) {
out.Write(pkt)
}
// wrapSyscallError takes an error and a syscall name. If the error is
// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name.
func wrapSyscallError(name string, err error) error {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError(name, err)
}
return err
}

View file

@ -2,13 +2,17 @@ package protocol
import (
"context"
"errors"
"github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"log"
"net"
"net/http"
"reflect"
"syscall"
"time"
)
@ -81,12 +85,76 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
}
defer conn.Close()
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
if err != nil {
log.Printf("Cannot set send/receive buffers: %s", err)
}
g.handleWebsocketProtocol(ctx, conn, s)
} else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
}
}
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 {
return nil
}
// conn == tls.Conn
ptr := reflect.ValueOf(conn)
val := reflect.Indirect(ptr)
if val.Kind() != reflect.Struct {
return errors.New("didn't get a struct from conn")
}
// this gets net.Conn -> *net.TCPConn -> net.TCPConn
ptrConn := val.FieldByName("conn")
valConn := reflect.Indirect(ptrConn)
if !valConn.IsValid() {
return errors.New("cannot find conn field")
}
valConn = valConn.Elem().Elem()
// net.FD
ptrNetFd := valConn.FieldByName("fd")
valNetFd := reflect.Indirect(ptrNetFd)
if !valNetFd.IsValid() {
return errors.New("cannot find fd field")
}
// pfd member
ptrPfd := valNetFd.FieldByName("pfd")
valPfd := reflect.Indirect(ptrPfd)
if !valPfd.IsValid() {
return errors.New("cannot find pfd field")
}
// finally the exported Sysfd
ptrSysFd := valPfd.FieldByName("Sysfd")
if !ptrSysFd.IsValid() {
return errors.New("cannot find Sysfd field")
}
fd := int(ptrSysFd.Int())
if g.ServerConf.ReceiveBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf)
if err != nil {
return wrapSyscallError("setsockopt", err)
}
}
if g.ServerConf.SendBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf)
if err != nil {
return wrapSyscallError("setsockopt", err)
}
}
return nil
}
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
websocketConnections.Inc()
defer websocketConnections.Dec()

View file

@ -39,6 +39,8 @@ type ServerConf struct {
IdleTimeout int
SmartCardAuth bool
TokenAuth bool
ReceiveBuf int
SendBuf int
}
func NewServer(s *SessionInfo, conf *ServerConf) *Server {