Add server implementation of basic auth

This commit is contained in:
Bolke de Bruin 2022-08-24 13:47:26 +02:00
parent 390f6acbcd
commit fb58cb299e
8 changed files with 157 additions and 53 deletions

View file

@ -1,38 +1,66 @@
package main package main
import ( import (
"context"
"errors" "errors"
"github.com/golang/protobuf/proto" "github.com/bolkedebruin/rdpgw/shared/auth"
ipc "github.com/james-barrow/golang-ipc"
"github.com/msteinert/pam" "github.com/msteinert/pam"
"github.com/thought-machine/go-flags" "github.com/thought-machine/go-flags"
"google.golang.org/grpc"
"log" "log"
"net"
"os"
"syscall"
)
const (
protocol = "unix"
) )
var opts struct { var opts struct {
serviceName string `short:"s" long:"service" default:"rdpgw" description:"the PAM service name to use"` ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"`
SocketAddr string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"`
} }
func auth(service, user, passwd string) error { type AuthServiceImpl struct {
t, err := pam.StartFunc(service, user, func(s pam.Style, msg string) (string, error) { serviceName string
}
var _ auth.AuthenticateServer = (*AuthServiceImpl)(nil)
func NewAuthService(serviceName string) auth.AuthenticateServer {
s := &AuthServiceImpl{serviceName: serviceName}
return s
}
func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPass) (*auth.AuthResponse, error) {
t, err := pam.StartFunc(s.serviceName, message.Username, func(s pam.Style, msg string) (string, error) {
switch s { switch s {
case pam.PromptEchoOff: case pam.PromptEchoOff:
return passwd, nil return message.Password, nil
case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo: case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo:
return "", nil return "", nil
} }
return "", errors.New("unrecognized PAM message style") return "", errors.New("unrecognized PAM message style")
}) })
r := &auth.AuthResponse{}
r.Authenticated = false
if err != nil { if err != nil {
return err log.Printf("Error authenticating user: %s due to: %s", message.Username, err)
r.Error = err.Error()
return r, err
} }
if err = t.Authenticate(0); err != nil { if err = t.Authenticate(0); err != nil {
return err log.Printf("Authentication for user: %s failed due to: %s", message.Username, err)
r.Error = err.Error()
return r, nil
} }
return nil log.Printf("User: %s authenticated", message.Username)
r.Authenticated = true
return r, nil
} }
func main() { func main() {
@ -41,32 +69,24 @@ func main() {
panic(err) panic(err)
} }
config := &ipc.ServerConfig{UnmaskPermissions: true} log.Printf("Starting auth server on %s", opts.SocketAddr)
sc, err := ipc.StartServer("rdpgw-auth", config) cleanup := func() {
for { if _, err := os.Stat(opts.SocketAddr); err == nil {
msg, err := sc.Read() if err := os.RemoveAll(opts.SocketAddr); err != nil {
if err != nil { log.Fatal(err)
log.Printf("server error, %s", err)
continue
}
if msg.MsgType > 0 {
req := &UserPass{}
if err = proto.Unmarshal(msg.Data, req); err != nil {
log.Printf("cannot unmarshal request %s", string(msg.Data))
continue
}
err := auth(opts.serviceName, req.Username, req.Password)
if err != nil {
res := &Response{Status: "cannot authenticate"}
out, err := proto.Marshal(res)
if err != nil {
log.Fatalf("cannot marshal response due to %s", err)
}
sc.Write(1, out)
} }
} }
} }
cleanup()
oldUmask := syscall.Umask(0)
listener, err := net.Listen(protocol, opts.SocketAddr)
syscall.Umask(oldUmask)
if err != nil { if err != nil {
log.Printf("cannot authenticate due to %s", err) log.Fatal(err)
} }
server := grpc.NewServer()
service := NewAuthService(opts.ServiceName)
auth.RegisterAuthenticateServer(server, service)
server.Serve(listener)
} }

View file

@ -1,14 +0,0 @@
syntax = "proto3";
package main;
option go_package = "./auth;main";
message UserPass {
string username = 1;
string password = 2;
}
message Response {
string status = 1;
}

62
cmd/rdpgw/api/basic.go Normal file
View file

@ -0,0 +1,62 @@
package api
import (
"context"
"github.com/bolkedebruin/rdpgw/shared/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"log"
"net"
"net/http"
"time"
)
const (
protocol = "unix"
)
func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if ok {
ctx := r.Context()
conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return net.Dial(protocol, addr)
}))
if err != nil {
log.Printf("Cannot reach authentication provider: %s", err)
http.Error(w, "Server error", http.StatusInternalServerError)
return
}
defer conn.Close()
c := auth.NewAuthenticateClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
req := &auth.UserPass{Username: username, Password: password}
res, err := c.Authenticate(ctx, req)
if err != nil {
log.Printf("Error talking to authentication provider: %s", err)
http.Error(w, "Server error", http.StatusInternalServerError)
return
}
if !res.Authenticated {
log.Printf("User %s is not authenticated for this service", username)
} else {
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
// If the Authentication header is not present, is invalid, or the
// username or password is wrong, then set a WWW-Authenticate
// header to inform the client that we expect them to use basic
// authentication and send a 401 Unauthorized response.
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
}

View file

@ -51,6 +51,8 @@ type Config struct {
ConnectionType int ConnectionType int
SplitUserDomain bool SplitUserDomain bool
DefaultDomain string DefaultDomain string
SocketAddress string
Authentication string
} }
func (c *Config) NewApi() { func (c *Config) NewApi() {

View file

@ -30,8 +30,10 @@ type ServerConfig struct {
SessionEncryptionKey string `koanf:"sessionencryptionkey"` SessionEncryptionKey string `koanf:"sessionencryptionkey"`
SessionStore string `koanf:"sessionstore"` SessionStore string `koanf:"sessionstore"`
SendBuf int `koanf:"sendbuf"` SendBuf int `koanf:"sendbuf"`
ReceiveBuf int `koanf:"recievebuf"` ReceiveBuf int `koanf:"receivebuf"`
DisableTLS bool `koanf:"disabletls"` DisableTLS bool `koanf:"disabletls"`
Authentication string `koanf:"authentication"`
AuthSocket string `koanf:"authsocket"`
} }
type OpenIDConfig struct { type OpenIDConfig struct {
@ -121,6 +123,8 @@ func Load(configFile string) Configuration {
"Server.Port": 443, "Server.Port": 443,
"Server.SessionStore": "cookie", "Server.SessionStore": "cookie",
"Server.HostSelection": "roundrobin", "Server.HostSelection": "roundrobin",
"Server.Authentication": "openid",
"Server.AuthSocket": "/tmp/rdpgw-auth.sock",
"Client.NetworkAutoDetect": 1, "Client.NetworkAutoDetect": 1,
"Client.BandwidthAutoDetect": 1, "Client.BandwidthAutoDetect": 1,
"Security.VerifyClientIp": true, "Security.VerifyClientIp": true,
@ -182,6 +186,9 @@ func Load(configFile string) Configuration {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
} }
if Conf.Server.Authentication == "local" && Conf.Server.DisableTLS {
log.Fatalf("basicauth=local and disabletls are mutually exclusive")
}
return Conf return Conf
} }

View file

@ -89,6 +89,8 @@ func main() {
ConnectionType: conf.Client.ConnectionType, ConnectionType: conf.Client.ConnectionType,
SplitUserDomain: conf.Client.SplitUserDomain, SplitUserDomain: conf.Client.SplitUserDomain,
DefaultDomain: conf.Client.DefaultDomain, DefaultDomain: conf.Client.DefaultDomain,
SocketAddress: conf.Server.AuthSocket,
Authentication: conf.Server.Authentication,
} }
api.NewApi() api.NewApi()
@ -148,11 +150,16 @@ func main() {
ServerConf: &handlerConfig, ServerConf: &handlerConfig,
} }
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol))) if conf.Server.Authentication == "local" {
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload)))) http.Handle("/connect", common.EnrichContext(api.BasicAuth(api.HandleDownload)))
http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol)))
} else {
// openid
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.HandleFunc("/callback", api.HandleCallback)
}
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", api.TokenInfo) http.HandleFunc("/tokeninfo", api.TokenInfo)
http.HandleFunc("/callback", api.HandleCallback)
if conf.Server.DisableTLS { if conf.Server.DisableTLS {
err = server.ListenAndServe() err = server.ListenAndServe()

7
go.mod
View file

@ -7,17 +7,17 @@ require (
github.com/go-jose/go-jose/v3 v3.0.0 github.com/go-jose/go-jose/v3 v3.0.0
github.com/gorilla/sessions v1.2.1 github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/james-barrow/golang-ipc v1.0.0
github.com/knadh/koanf v1.4.2 github.com/knadh/koanf v1.4.2
github.com/msteinert/pam v1.0.0 github.com/msteinert/pam v1.0.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.12.1 github.com/prometheus/client_golang v1.12.1
github.com/thought-machine/go-flags v1.6.1 github.com/thought-machine/go-flags v1.6.1
golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c
google.golang.org/grpc v1.49.0
google.golang.org/protobuf v1.28.1
) )
require ( require (
github.com/Microsoft/go-winio v0.4.16 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect
@ -35,8 +35,9 @@ require (
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.28.0 // indirect google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v3 v3.0.0 // indirect gopkg.in/yaml.v3 v3.0.0 // indirect
) )

19
proto/auth.proto Normal file
View file

@ -0,0 +1,19 @@
syntax = "proto3";
package auth;
option go_package = "./auth";
message UserPass {
string username = 1;
string password = 2;
}
message AuthResponse {
bool authenticated = 1;
string error = 2;
}
service Authenticate {
rpc Authenticate (UserPass) returns (AuthResponse) {}
}