mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-16 21:53:45 +02:00
Support setting send/receive buffers for the IP sockets
This commit is contained in:
parent
8876b04466
commit
505eafdc1e
6 changed files with 89 additions and 0 deletions
|
@ -59,6 +59,11 @@ server:
|
|||
# make sure to share this across the different pods
|
||||
sessionKey: thisisasessionkeyreplacethisjetzt
|
||||
sessionEncryptionKey: thisisasessionkeyreplacethisnunu!
|
||||
# tries to set the receive / send buffer of the connections to the client
|
||||
# in case of high latency high bandwidth the defaults set by the OS might
|
||||
# be to low for a good experience
|
||||
# receiveBuf: 12582912
|
||||
# sendBuf: 12582912
|
||||
# Open ID Connect specific settings
|
||||
openId:
|
||||
providerUrl: http://keycloak/auth/realms/test
|
||||
|
|
|
@ -22,6 +22,8 @@ type ServerConfig struct {
|
|||
RoundRobin bool
|
||||
SessionKey string
|
||||
SessionEncryptionKey string
|
||||
SendBuf int
|
||||
ReceiveBuf int
|
||||
}
|
||||
|
||||
type OpenIDConfig struct {
|
||||
|
|
2
main.go
2
main.go
|
@ -128,6 +128,8 @@ func main() {
|
|||
},
|
||||
VerifyTunnelCreate: security.VerifyPAAToken,
|
||||
VerifyServerFunc: security.VerifyServerFunc,
|
||||
SendBuf: conf.Server.SendBuf,
|
||||
ReceiveBuf: conf.Server.ReceiveBuf,
|
||||
}
|
||||
gw := protocol.Gateway{
|
||||
ServerConf: &handlerConfig,
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type RedirectFlags struct {
|
||||
|
@ -136,3 +138,11 @@ func receive(data []byte, out net.Conn) {
|
|||
out.Write(pkt)
|
||||
}
|
||||
|
||||
// wrapSyscallError takes an error and a syscall name. If the error is
|
||||
// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name.
|
||||
func wrapSyscallError(name string, err error) error {
|
||||
if _, ok := err.(syscall.Errno); ok {
|
||||
err = os.NewSyscallError(name, err)
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -2,13 +2,17 @@ package protocol
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -81,12 +85,76 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
|
||||
if err != nil {
|
||||
log.Printf("Cannot set send/receive buffers: %s", err)
|
||||
}
|
||||
|
||||
g.handleWebsocketProtocol(ctx, conn, s)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
||||
if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// conn == tls.Conn
|
||||
ptr := reflect.ValueOf(conn)
|
||||
val := reflect.Indirect(ptr)
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return errors.New("didn't get a struct from conn")
|
||||
}
|
||||
|
||||
// this gets net.Conn -> *net.TCPConn -> net.TCPConn
|
||||
ptrConn := val.FieldByName("conn")
|
||||
valConn := reflect.Indirect(ptrConn)
|
||||
if !valConn.IsValid() {
|
||||
return errors.New("cannot find conn field")
|
||||
}
|
||||
valConn = valConn.Elem().Elem()
|
||||
|
||||
// net.FD
|
||||
ptrNetFd := valConn.FieldByName("fd")
|
||||
valNetFd := reflect.Indirect(ptrNetFd)
|
||||
if !valNetFd.IsValid() {
|
||||
return errors.New("cannot find fd field")
|
||||
}
|
||||
|
||||
// pfd member
|
||||
ptrPfd := valNetFd.FieldByName("pfd")
|
||||
valPfd := reflect.Indirect(ptrPfd)
|
||||
if !valPfd.IsValid() {
|
||||
return errors.New("cannot find pfd field")
|
||||
}
|
||||
|
||||
// finally the exported Sysfd
|
||||
ptrSysFd := valPfd.FieldByName("Sysfd")
|
||||
if !ptrSysFd.IsValid() {
|
||||
return errors.New("cannot find Sysfd field")
|
||||
}
|
||||
fd := int(ptrSysFd.Int())
|
||||
|
||||
if g.ServerConf.ReceiveBuf > 0 {
|
||||
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf)
|
||||
if err != nil {
|
||||
return wrapSyscallError("setsockopt", err)
|
||||
}
|
||||
}
|
||||
|
||||
if g.ServerConf.SendBuf > 0 {
|
||||
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf)
|
||||
if err != nil {
|
||||
return wrapSyscallError("setsockopt", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
|
||||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
|
|
@ -39,6 +39,8 @@ type ServerConf struct {
|
|||
IdleTimeout int
|
||||
SmartCardAuth bool
|
||||
TokenAuth bool
|
||||
ReceiveBuf int
|
||||
SendBuf int
|
||||
}
|
||||
|
||||
func NewServer(s *SessionInfo, conf *ServerConf) *Server {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue