rdpgw/cmd/auth/ntlm/ntlm.go
m7913d 372dc43ef2
Support for NTLM authentication added (#109)
* Support for NTLM authentication added

To support NTLM authentication, a database is added as an authentication source.
Currently, only the configuration file is supported as a database.
Database authentication supports the Basic and NTLM authentication protocols.

ServerConfig.BasicAuthEnabled is renamed to LocalEnabled, as Basic auth can be used with either NTLM or Local authentication.
2024-04-24 14:12:41 +02:00

160 lines
4.7 KiB
Go

// Package ntlm implements server-side NTLM authentication for rdpgw,
// validating a client's NTLM handshake against passwords looked up in a
// database.Database.
package ntlm
import (
"encoding/base64"
"errors"
"github.com/bolkedebruin/rdpgw/cmd/auth/database"
"github.com/bolkedebruin/rdpgw/shared/auth"
"github.com/patrickmn/go-cache"
"github.com/m7913d/go-ntlm/ntlm"
"fmt"
"log"
"time"
)
// Timing parameters for the per-session NTLM handshake cache.
const (
	// cacheExpiration is how long a handshake context stays cached after it
	// is stored.
	cacheExpiration = 1 * time.Minute

	// cleanupInterval is how often the cache purges expired contexts.
	cleanupInterval = 5 * time.Minute
)
// NTLMAuth authenticates clients using NTLM, checking their responses
// against passwords obtained from Database. In-progress handshakes are kept
// per session in contextCache. Construct it with NewNTLMAuth; the zero value
// has a nil contextCache.
type NTLMAuth struct {
	// contextCache maps a session identifier to its *ntlmContext.
	contextCache *cache.Cache
	// Information about the server, returned to the client during authentication
	ServerName    string // e.g. EXAMPLE1
	DomainName    string // e.g. EXAMPLE
	DnsServerName string // e.g. example1.example.com
	DnsDomainName string // e.g. example.com
	DnsTreeName   string // e.g. example.com
	// Database supplies the password for a username; an empty password is
	// treated as an unknown user.
	Database database.Database
}
// NewNTLMAuth returns an NTLMAuth that checks credentials against db and has
// an empty handshake cache. The server-identity fields (ServerName,
// DomainName, DnsServerName, DnsDomainName, DnsTreeName) are left empty for
// the caller to fill in.
func NewNTLMAuth(db database.Database) *NTLMAuth {
	a := new(NTLMAuth)
	a.contextCache = cache.New(cacheExpiration, cleanupInterval)
	a.Database = db
	return a
}
// Authenticate runs one step of the NTLM handshake identified by
// message.Session, using the base64-encoded message.NtlmMessage as input.
//
// The returned response is never nil. When the step produces a server reply
// (the CHALLENGE answering a NEGOTIATE) it is placed in NtlmMessage; when the
// client is authenticated, Authenticated and Username are filled in.
func (h *NTLMAuth) Authenticate(message *auth.NtlmRequest) (*auth.NtlmResponse, error) {
	resp := &auth.NtlmResponse{Authenticated: false}

	switch {
	case message.Session == "":
		return resp, errors.New("Invalid (empty) session specified")
	case message.NtlmMessage == "":
		return resp, errors.New("Empty NTLM message specified")
	}

	sessionCtx := h.getContext(message.Session)
	err := sessionCtx.Authenticate(message.NtlmMessage, resp)

	// Drop the cached state when the step errored or the client is now
	// authenticated; otherwise keep it for the next message of the handshake.
	if finished := err != nil || resp.Authenticated; finished {
		h.removeContext(message.Session)
	}
	return resp, err
}
// getContext returns the handshake state cached under session. If there is
// none (or the cached value is not an *ntlmContext), a fresh context bound to
// h is created, cached with the default expiration, and returned.
func (h *NTLMAuth) getContext(session string) *ntlmContext {
	if cached, found := h.contextCache.Get(session); found {
		if existing, ok := cached.(*ntlmContext); ok {
			return existing
		}
	}

	fresh := &ntlmContext{h: h}
	h.contextCache.Set(session, fresh, cache.DefaultExpiration)
	return fresh
}
// removeContext discards any handshake state cached under session.
func (h *NTLMAuth) removeContext(session string) { h.contextCache.Delete(session) }
// ntlmContext holds the server-side state of one NTLM handshake, cached per
// session by NTLMAuth.
type ntlmContext struct {
	// session is nil for a freshly created context and is (re)created by
	// negotiate when a NEGOTIATE message arrives.
	session ntlm.ServerSession
	// h is the owning NTLMAuth, which supplies the server identity and the
	// credential database.
	h *NTLMAuth
}
// Authenticate decodes one base64-encoded NTLM message and dispatches it.
//
// A NEGOTIATE (type 1) message always (re)starts the handshake through
// negotiate. Any other message requires a handshake already started on this
// context and is parsed as an AUTHENTICATE message. Results are written into
// r; returned errors wrap the underlying decode/parse error where there is one.
func (c *ntlmContext) Authenticate(authorisationEncoded string, r *auth.NtlmResponse) error {
	authorisation, err := base64.StdEncoding.DecodeString(authorisationEncoded)
	if err != nil {
		return fmt.Errorf("Failed to decode NTLM Authorisation header: %w", err)
	}

	nm, err := ntlm.ParseNegotiateMessage(authorisation)
	if err == nil {
		return c.negotiate(nm, r)
	}
	// The message claims to be a NEGOTIATE message but is malformed.
	if nm != nil && nm.MessageType == 1 {
		return fmt.Errorf("Failed to parse NTLM Authorisation header: %w", err)
	}
	if c.session == nil {
		return errors.New("New NTLM auth sequence should start with negotiate request")
	}

	// 2 selects NTLMv2, matching the ntlm.Version2 server session created in negotiate.
	am, err := ntlm.ParseAuthenticateMessage(authorisation, 2)
	if err != nil {
		return fmt.Errorf("Failed to parse NTLM Authorisation header: %w", err)
	}
	return c.authenticate(am, r)
}
// negotiate starts a fresh NTLMv2 (connection-oriented) server session for
// the NEGOTIATE message nm and stores the base64-encoded CHALLENGE reply in
// r.NtlmMessage.
//
// The new session is published to c.session only after a challenge has been
// generated. On any failure c.session is left nil, so a later AUTHENTICATE
// message can never be checked against a half-initialised session that issued
// no challenge.
func (c *ntlmContext) negotiate(nm *ntlm.NegotiateMessage, r *auth.NtlmResponse) error {
	// Any previous handshake on this context is abandoned by a new NEGOTIATE.
	c.session = nil

	session, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode)
	if err != nil {
		return fmt.Errorf("Failed to create NTLM server session: %w", err)
	}
	session.SetRequireNtHash(true)
	// Server identity advertised to the client in the CHALLENGE message.
	session.SetDomainName(c.h.DomainName)
	session.SetComputerName(c.h.ServerName)
	session.SetDnsDomainName(c.h.DnsDomainName)
	session.SetDnsComputerName(c.h.DnsServerName)
	session.SetDnsTreeName(c.h.DnsTreeName)

	if err := session.ProcessNegotiateMessage(nm); err != nil {
		return fmt.Errorf("Failed to process NTLM negotiate message: %w", err)
	}
	cm, err := session.GenerateChallengeMessage()
	if err != nil {
		return fmt.Errorf("Failed to generate NTLM challenge message: %w", err)
	}

	c.session = session
	r.NtlmMessage = base64.StdEncoding.EncodeToString(cm.Bytes())
	return nil
}
// authenticate verifies the AUTHENTICATE message am against the challenge
// issued by the current session. It uses the password that c.h.Database
// returns for the username in am.
//
// On success r.Authenticated is set to true and r.Username to the username.
// An unknown user or a failed verification is logged, and r is left
// unauthenticated with a nil error. An error is returned only when no
// handshake has been started.
//
// The server challenge is single-use: the session is cleared before the
// message is processed. A retry after a failed attempt must therefore begin
// with a new NEGOTIATE instead of replaying responses against the same
// challenge.
func (c *ntlmContext) authenticate(am *ntlm.AuthenticateMessage, r *auth.NtlmResponse) error {
	if c.session == nil {
		return errors.New("NTLM Authenticate requires active session: first call negotiate")
	}
	session := c.session
	c.session = nil

	username := am.UserName.String()
	password := c.h.Database.GetPassword(username)
	if password == "" {
		// %q escapes control characters in this client-supplied value,
		// preventing forged log lines.
		log.Printf("NTLM: unknown username specified: %q", username)
		return nil
	}

	session.SetUserInfo(username, password, "")
	if err := session.ProcessAuthenticateMessage(am); err != nil {
		log.Printf("Failed to process NTLM authenticate message: %s", err)
		return nil
	}

	r.Authenticated = true
	r.Username = username
	return nil
}