Add prometheus and disable http2

This commit is contained in:
Bolke de Bruin 2020-07-14 08:43:25 +02:00
parent 1f191a5e41
commit 33a5e0e03c
3 changed files with 43 additions and 5 deletions

1
go.mod
View file

@ -5,4 +5,5 @@ go 1.14
require ( require (
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.7.1
) )

11
main.go
View file

@ -3,6 +3,8 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"flag" "flag"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/prometheus/client_golang/prometheus"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -42,9 +44,16 @@ func main() {
server := http.Server{ server := http.Server{
Addr: ":" + strconv.Itoa(*port), Addr: ":" + strconv.Itoa(*port),
TLSConfig: cfg, TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
} }
http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol) http.HandleFunc("/remoteDesktopGateway/", handleGatewayProtocol)
http.Handle("/metrics", promhttp.Handler())
prometheus.MustRegister(connectionCache)
prometheus.MustRegister(legacyConnections)
prometheus.MustRegister(websocketConnections)
err = server.ListenAndServeTLS("", "") err = server.ListenAndServeTLS("", "")
if err != nil { if err != nil {
log.Fatal("ListenAndServe: ", err) log.Fatal("ListenAndServe: ", err)

36
rdg.go
View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"io" "io"
"log" "log"
"math/rand" "math/rand"
@ -15,7 +16,6 @@ import (
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"strconv" "strconv"
"time" "time"
"unicode/utf16" "unicode/utf16"
"unicode/utf8" "unicode/utf8"
@ -84,6 +84,29 @@ const (
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1 HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
) )
var (
connectionCache = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "connection_cache",
Help: "The amount of connections in the cache",
})
websocketConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "websocket_connections",
Help: "The count of websocket connections",
})
legacyConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "legacy_connections",
Help: "The count of legacy https connections",
})
)
// HandshakeHeader is the interface that writes both upgrade request or // HandshakeHeader is the interface that writes both upgrade request or
// response headers into a given io.Writer. // response headers into a given io.Writer.
type HandshakeHeader interface { type HandshakeHeader interface {
@ -129,6 +152,7 @@ var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute) var c = cache.New(5*time.Minute, 10*time.Minute)
func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) { func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
connectionCache.Set(float64(c.ItemCount()))
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
handleLegacyProtocol(w, r) handleLegacyProtocol(w, r)
@ -155,6 +179,9 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
var remote net.Conn var remote net.Conn
websocketConnections.Inc()
defer websocketConnections.Dec()
for { for {
mt, msg, err := conn.ReadMessage() mt, msg, err := conn.ReadMessage()
if err != nil { if err != nil {
@ -217,8 +244,7 @@ func handleWebsocketProtocol(conn *websocket.Conn) {
// do not write to make sure we do not create concurrency issues // do not write to make sure we do not create concurrency issues
// conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // conn.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
remote.Close() break
return
default: default:
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt) log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt)
} }
@ -252,6 +278,9 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
c.Set(connId, s, cache.DefaultExpiration) c.Set(connId, s, cache.DefaultExpiration)
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
legacyConnections.Inc()
defer legacyConnections.Dec()
var remote net.Conn var remote net.Conn
conn, rw, _ := Accept(w) conn, rw, _ := Accept(w)
@ -314,7 +343,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
s.ConnIn.Close() s.ConnIn.Close()
s.ConnOut.Close() s.ConnOut.Close()
remote.Close()
break break
default: default:
log.Printf("Unknown packet (size %d): %x", n, packet) log.Printf("Unknown packet (size %d): %x", n, packet)