mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-12 11:59:18 +02:00
Add server implementation of basic auth
This commit is contained in:
parent
390f6acbcd
commit
fb58cb299e
8 changed files with 157 additions and 53 deletions
|
@ -1,38 +1,66 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
ipc "github.com/james-barrow/golang-ipc"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"github.com/msteinert/pam"
|
||||
"github.com/thought-machine/go-flags"
|
||||
"google.golang.org/grpc"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
protocol = "unix"
|
||||
)
|
||||
|
||||
var opts struct {
|
||||
serviceName string `short:"s" long:"service" default:"rdpgw" description:"the PAM service name to use"`
|
||||
ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"`
|
||||
SocketAddr string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"`
|
||||
}
|
||||
|
||||
func auth(service, user, passwd string) error {
|
||||
t, err := pam.StartFunc(service, user, func(s pam.Style, msg string) (string, error) {
|
||||
type AuthServiceImpl struct {
|
||||
serviceName string
|
||||
}
|
||||
|
||||
var _ auth.AuthenticateServer = (*AuthServiceImpl)(nil)
|
||||
|
||||
func NewAuthService(serviceName string) auth.AuthenticateServer {
|
||||
s := &AuthServiceImpl{serviceName: serviceName}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPass) (*auth.AuthResponse, error) {
|
||||
t, err := pam.StartFunc(s.serviceName, message.Username, func(s pam.Style, msg string) (string, error) {
|
||||
switch s {
|
||||
case pam.PromptEchoOff:
|
||||
return passwd, nil
|
||||
return message.Password, nil
|
||||
case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo:
|
||||
return "", nil
|
||||
}
|
||||
return "", errors.New("unrecognized PAM message style")
|
||||
})
|
||||
|
||||
r := &auth.AuthResponse{}
|
||||
r.Authenticated = false
|
||||
if err != nil {
|
||||
return err
|
||||
log.Printf("Error authenticating user: %s due to: %s", message.Username, err)
|
||||
r.Error = err.Error()
|
||||
return r, err
|
||||
}
|
||||
|
||||
if err = t.Authenticate(0); err != nil {
|
||||
return err
|
||||
log.Printf("Authentication for user: %s failed due to: %s", message.Username, err)
|
||||
r.Error = err.Error()
|
||||
return r, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
log.Printf("User: %s authenticated", message.Username)
|
||||
r.Authenticated = true
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -41,32 +69,24 @@ func main() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
config := &ipc.ServerConfig{UnmaskPermissions: true}
|
||||
sc, err := ipc.StartServer("rdpgw-auth", config)
|
||||
for {
|
||||
msg, err := sc.Read()
|
||||
log.Printf("Starting auth server on %s", opts.SocketAddr)
|
||||
cleanup := func() {
|
||||
if _, err := os.Stat(opts.SocketAddr); err == nil {
|
||||
if err := os.RemoveAll(opts.SocketAddr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cleanup()
|
||||
|
||||
oldUmask := syscall.Umask(0)
|
||||
listener, err := net.Listen(protocol, opts.SocketAddr)
|
||||
syscall.Umask(oldUmask)
|
||||
if err != nil {
|
||||
log.Printf("server error, %s", err)
|
||||
continue
|
||||
}
|
||||
if msg.MsgType > 0 {
|
||||
req := &UserPass{}
|
||||
if err = proto.Unmarshal(msg.Data, req); err != nil {
|
||||
log.Printf("cannot unmarshal request %s", string(msg.Data))
|
||||
continue
|
||||
}
|
||||
err := auth(opts.serviceName, req.Username, req.Password)
|
||||
if err != nil {
|
||||
res := &Response{Status: "cannot authenticate"}
|
||||
out, err := proto.Marshal(res)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot marshal response due to %s", err)
|
||||
}
|
||||
sc.Write(1, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("cannot authenticate due to %s", err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
server := grpc.NewServer()
|
||||
service := NewAuthService(opts.ServiceName)
|
||||
auth.RegisterAuthenticateServer(server, service)
|
||||
server.Serve(listener)
|
||||
}
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package main;
|
||||
|
||||
option go_package = "./auth;main";
|
||||
|
||||
message UserPass {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
message Response {
|
||||
string status = 1;
|
||||
}
|
62
cmd/rdpgw/api/basic.go
Normal file
62
cmd/rdpgw/api/basic.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/shared/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
protocol = "unix"
|
||||
)
|
||||
|
||||
func (c *Config) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if ok {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := grpc.Dial(c.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return net.Dial(protocol, addr)
|
||||
}))
|
||||
if err != nil {
|
||||
log.Printf("Cannot reach authentication provider: %s", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
c := auth.NewAuthenticateClient(conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
req := &auth.UserPass{Username: username, Password: password}
|
||||
res, err := c.Authenticate(ctx, req)
|
||||
if err != nil {
|
||||
log.Printf("Error talking to authentication provider: %s", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !res.Authenticated {
|
||||
log.Printf("User %s is not authenticated for this service", username)
|
||||
} else {
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
// If the Authentication header is not present, is invalid, or the
|
||||
// username or password is wrong, then set a WWW-Authenticate
|
||||
// header to inform the client that we expect them to use basic
|
||||
// authentication and send a 401 Unauthorized response.
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
|
@ -51,6 +51,8 @@ type Config struct {
|
|||
ConnectionType int
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
SocketAddress string
|
||||
Authentication string
|
||||
}
|
||||
|
||||
func (c *Config) NewApi() {
|
||||
|
|
|
@ -30,8 +30,10 @@ type ServerConfig struct {
|
|||
SessionEncryptionKey string `koanf:"sessionencryptionkey"`
|
||||
SessionStore string `koanf:"sessionstore"`
|
||||
SendBuf int `koanf:"sendbuf"`
|
||||
ReceiveBuf int `koanf:"recievebuf"`
|
||||
ReceiveBuf int `koanf:"receivebuf"`
|
||||
DisableTLS bool `koanf:"disabletls"`
|
||||
Authentication string `koanf:"authentication"`
|
||||
AuthSocket string `koanf:"authsocket"`
|
||||
}
|
||||
|
||||
type OpenIDConfig struct {
|
||||
|
@ -121,6 +123,8 @@ func Load(configFile string) Configuration {
|
|||
"Server.Port": 443,
|
||||
"Server.SessionStore": "cookie",
|
||||
"Server.HostSelection": "roundrobin",
|
||||
"Server.Authentication": "openid",
|
||||
"Server.AuthSocket": "/tmp/rdpgw-auth.sock",
|
||||
"Client.NetworkAutoDetect": 1,
|
||||
"Client.BandwidthAutoDetect": 1,
|
||||
"Security.VerifyClientIp": true,
|
||||
|
@ -182,6 +186,9 @@ func Load(configFile string) Configuration {
|
|||
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
|
||||
}
|
||||
|
||||
if Conf.Server.Authentication == "local" && Conf.Server.DisableTLS {
|
||||
log.Fatalf("basicauth=local and disabletls are mutually exclusive")
|
||||
}
|
||||
return Conf
|
||||
|
||||
}
|
||||
|
|
|
@ -89,6 +89,8 @@ func main() {
|
|||
ConnectionType: conf.Client.ConnectionType,
|
||||
SplitUserDomain: conf.Client.SplitUserDomain,
|
||||
DefaultDomain: conf.Client.DefaultDomain,
|
||||
SocketAddress: conf.Server.AuthSocket,
|
||||
Authentication: conf.Server.Authentication,
|
||||
}
|
||||
api.NewApi()
|
||||
|
||||
|
@ -148,11 +150,16 @@ func main() {
|
|||
ServerConf: &handlerConfig,
|
||||
}
|
||||
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||
if conf.Server.Authentication == "local" {
|
||||
http.Handle("/connect", common.EnrichContext(api.BasicAuth(api.HandleDownload)))
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(api.BasicAuth(gw.HandleGatewayProtocol)))
|
||||
} else {
|
||||
// openid
|
||||
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||
http.HandleFunc("/callback", api.HandleCallback)
|
||||
}
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/tokeninfo", api.TokenInfo)
|
||||
http.HandleFunc("/callback", api.HandleCallback)
|
||||
|
||||
if conf.Server.DisableTLS {
|
||||
err = server.ListenAndServe()
|
||||
|
|
7
go.mod
7
go.mod
|
@ -7,17 +7,17 @@ require (
|
|||
github.com/go-jose/go-jose/v3 v3.0.0
|
||||
github.com/gorilla/sessions v1.2.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/james-barrow/golang-ipc v1.0.0
|
||||
github.com/knadh/koanf v1.4.2
|
||||
github.com/msteinert/pam v1.0.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/prometheus/client_golang v1.12.1
|
||||
github.com/thought-machine/go-flags v1.6.1
|
||||
golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c
|
||||
google.golang.org/grpc v1.49.0
|
||||
google.golang.org/protobuf v1.28.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.4.16 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/fsnotify/fsnotify v1.5.4 // indirect
|
||||
|
@ -35,8 +35,9 @@ require (
|
|||
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
|
||||
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/protobuf v1.28.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0 // indirect
|
||||
)
|
||||
|
|
19
proto/auth.proto
Normal file
19
proto/auth.proto
Normal file
|
@ -0,0 +1,19 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package auth;
|
||||
|
||||
option go_package = "./auth";
|
||||
|
||||
message UserPass {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
message AuthResponse {
|
||||
bool authenticated = 1;
|
||||
string error = 2;
|
||||
}
|
||||
|
||||
service Authenticate {
|
||||
rpc Authenticate (UserPass) returns (AuthResponse) {}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue