Use list of kdcs and ensure length is removed / added when necessary

This commit is contained in:
Bolke de Bruin 2024-03-16 13:10:30 +01:00
parent a67962b02d
commit ecbe63f175

View file

@ -23,6 +23,13 @@ type KdcProxyMsg struct {
Flags int `asn1:"tag:2,optional"` Flags int `asn1:"tag:2,optional"`
} }
type Kdc struct {
Realm string
Host string
Proto string
Conn net.Conn
}
type KerberosProxy struct { type KerberosProxy struct {
krb5Config *krbconfig.Config krb5Config *krbconfig.Config
} }
@ -97,39 +104,71 @@ func (k *KerberosProxy) forward(realm string, data []byte) (resp []byte, err err
} }
// load udp first as is the default for kerberos // load udp first as is the default for kerberos
c, kdcs, err := k.krb5Config.GetKDCs(realm, false) udpCnt, udpKdcs, err := k.krb5Config.GetKDCs(realm, false)
if err != nil || c < 1 { if err != nil {
return nil, fmt.Errorf("cannot get kdc for realm %s due to %s", realm, err) return nil, fmt.Errorf("cannot get udp kdc for realm %s due to %s", realm, err)
} }
// load tcp
tcpCnt, tcpKdcs, err := k.krb5Config.GetKDCs(realm, true)
if err != nil {
return nil, fmt.Errorf("cannot get tcp kdc for realm %s due to %s", realm, err)
}
if tcpCnt+udpCnt == 0 {
return nil, fmt.Errorf("cannot get any kdcs (tcp or udp) for realm %s", realm)
}
// merge the kdcs
kdcs := make([]Kdc, tcpCnt+udpCnt)
for i := range udpKdcs {
kdcs[i] = Kdc{Realm: realm, Host: udpKdcs[i], Proto: "udp"}
}
for i := range tcpKdcs {
kdcs[i+udpCnt] = Kdc{Realm: realm, Host: tcpKdcs[i], Proto: "tcp"}
}
replies := make(chan []byte, len(kdcs))
for i := range kdcs { for i := range kdcs {
conn, err := net.Dial("tcp", kdcs[i]) conn, err := net.Dial(kdcs[i].Proto, kdcs[i].Host)
if err != nil { if err != nil {
log.Printf("error connecting to %s due to %s, trying next if available", kdcs[i], err) log.Printf("error connecting to %s due to %s, trying next if available", kdcs[i], err)
continue continue
} }
conn.SetDeadline(time.Now().Add(timeout)) conn.SetDeadline(time.Now().Add(timeout))
_, err = conn.Write(data) // if we proxy over UDP remove the length prefix
if kdcs[i].Proto == "tcp" {
_, err = conn.Write(data)
} else {
_, err = conn.Write(data[4:])
}
if err != nil { if err != nil {
log.Printf("cannot write packet data to %s due to %s, trying next if available", kdcs[i], err) log.Printf("cannot write packet data to %s due to %s, trying next if available", kdcs[i], err)
conn.Close() conn.Close()
continue continue
} }
// todo check header kdcs[i].Conn = conn
resp, err = io.ReadAll(conn) go awaitReply(conn, kdcs[i].Proto == "udp", replies)
if err != nil {
log.Printf("error reading from kdc %s due to %s, trying next if available", kdcs[i], err)
conn.Close()
continue
}
conn.Close()
return resp, nil
} }
return nil, fmt.Errorf("no kdcs found for realm %s", realm) reply := <-replies
// close all the connections and return the first reply
for kdc := range kdcs {
if kdcs[kdc].Conn != nil {
kdcs[kdc].Conn.Close()
}
<-replies
}
if reply != nil {
return reply, nil
}
return nil, fmt.Errorf("no replies received from kdcs for realm %s", realm)
} }
func decode(data []byte) (msg *KdcProxyMsg, err error) { func decode(data []byte) (msg *KdcProxyMsg, err error) {
@ -155,3 +194,17 @@ func encode(krb5data []byte) (r []byte, err error) {
} }
return enc, nil return enc, nil
} }
func awaitReply(conn net.Conn, isUdp bool, reply chan<- []byte) {
resp, err := io.ReadAll(conn)
if err != nil {
log.Printf("error reading from kdc due to %s", err)
reply <- nil
return
}
if isUdp {
// udp will be missing the length prefix so add it
resp = append([]byte{byte(len(resp))}, resp...)
}
reply <- resp
}