Use list of kdcs and ensure length is removed / added when necessary

This commit is contained in:
Bolke de Bruin 2024-03-16 13:10:30 +01:00
parent a67962b02d
commit ecbe63f175

View file

@ -23,6 +23,13 @@ type KdcProxyMsg struct {
Flags int `asn1:"tag:2,optional"`
}
type Kdc struct {
Realm string
Host string
Proto string
Conn net.Conn
}
type KerberosProxy struct {
krb5Config *krbconfig.Config
}
@ -97,39 +104,71 @@ func (k *KerberosProxy) forward(realm string, data []byte) (resp []byte, err err
}
// load udp first as is the default for kerberos
c, kdcs, err := k.krb5Config.GetKDCs(realm, false)
if err != nil || c < 1 {
return nil, fmt.Errorf("cannot get kdc for realm %s due to %s", realm, err)
udpCnt, udpKdcs, err := k.krb5Config.GetKDCs(realm, false)
if err != nil {
return nil, fmt.Errorf("cannot get udp kdc for realm %s due to %s", realm, err)
}
// load tcp
tcpCnt, tcpKdcs, err := k.krb5Config.GetKDCs(realm, true)
if err != nil {
return nil, fmt.Errorf("cannot get tcp kdc for realm %s due to %s", realm, err)
}
if tcpCnt+udpCnt == 0 {
return nil, fmt.Errorf("cannot get any kdcs (tcp or udp) for realm %s", realm)
}
// merge the kdcs
kdcs := make([]Kdc, tcpCnt+udpCnt)
for i := range udpKdcs {
kdcs[i] = Kdc{Realm: realm, Host: udpKdcs[i], Proto: "udp"}
}
for i := range tcpKdcs {
kdcs[i+udpCnt] = Kdc{Realm: realm, Host: tcpKdcs[i], Proto: "tcp"}
}
replies := make(chan []byte, len(kdcs))
for i := range kdcs {
conn, err := net.Dial("tcp", kdcs[i])
conn, err := net.Dial(kdcs[i].Proto, kdcs[i].Host)
if err != nil {
log.Printf("error connecting to %s due to %s, trying next if available", kdcs[i], err)
continue
}
conn.SetDeadline(time.Now().Add(timeout))
// if we proxy over UDP remove the length prefix
if kdcs[i].Proto == "tcp" {
_, err = conn.Write(data)
} else {
_, err = conn.Write(data[4:])
}
if err != nil {
log.Printf("cannot write packet data to %s due to %s, trying next if available", kdcs[i], err)
conn.Close()
continue
}
// todo check header
resp, err = io.ReadAll(conn)
if err != nil {
log.Printf("error reading from kdc %s due to %s, trying next if available", kdcs[i], err)
conn.Close()
continue
}
conn.Close()
return resp, nil
kdcs[i].Conn = conn
go awaitReply(conn, kdcs[i].Proto == "udp", replies)
}
return nil, fmt.Errorf("no kdcs found for realm %s", realm)
reply := <-replies
// close all the connections and return the first reply
for kdc := range kdcs {
if kdcs[kdc].Conn != nil {
kdcs[kdc].Conn.Close()
}
<-replies
}
if reply != nil {
return reply, nil
}
return nil, fmt.Errorf("no replies received from kdcs for realm %s", realm)
}
func decode(data []byte) (msg *KdcProxyMsg, err error) {
@ -155,3 +194,17 @@ func encode(krb5data []byte) (r []byte, err error) {
}
return enc, nil
}
func awaitReply(conn net.Conn, isUdp bool, reply chan<- []byte) {
resp, err := io.ReadAll(conn)
if err != nil {
log.Printf("error reading from kdc due to %s", err)
reply <- nil
return
}
if isUdp {
// udp will be missing the length prefix so add it
resp = append([]byte{byte(len(resp))}, resp...)
}
reply <- resp
}