mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2025-08-16 05:33:47 +02:00
Use list of kdcs and ensure length is removed / added when necessary
This commit is contained in:
parent
a67962b02d
commit
ecbe63f175
1 changed files with 69 additions and 16 deletions
|
@ -23,6 +23,13 @@ type KdcProxyMsg struct {
|
||||||
Flags int `asn1:"tag:2,optional"`
|
Flags int `asn1:"tag:2,optional"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Kdc struct {
|
||||||
|
Realm string
|
||||||
|
Host string
|
||||||
|
Proto string
|
||||||
|
Conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
type KerberosProxy struct {
|
type KerberosProxy struct {
|
||||||
krb5Config *krbconfig.Config
|
krb5Config *krbconfig.Config
|
||||||
}
|
}
|
||||||
|
@ -97,39 +104,71 @@ func (k *KerberosProxy) forward(realm string, data []byte) (resp []byte, err err
|
||||||
}
|
}
|
||||||
|
|
||||||
// load udp first as is the default for kerberos
|
// load udp first as is the default for kerberos
|
||||||
c, kdcs, err := k.krb5Config.GetKDCs(realm, false)
|
udpCnt, udpKdcs, err := k.krb5Config.GetKDCs(realm, false)
|
||||||
if err != nil || c < 1 {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot get kdc for realm %s due to %s", realm, err)
|
return nil, fmt.Errorf("cannot get udp kdc for realm %s due to %s", realm, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load tcp
|
||||||
|
tcpCnt, tcpKdcs, err := k.krb5Config.GetKDCs(realm, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot get tcp kdc for realm %s due to %s", realm, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tcpCnt+udpCnt == 0 {
|
||||||
|
return nil, fmt.Errorf("cannot get any kdcs (tcp or udp) for realm %s", realm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge the kdcs
|
||||||
|
kdcs := make([]Kdc, tcpCnt+udpCnt)
|
||||||
|
for i := range udpKdcs {
|
||||||
|
kdcs[i] = Kdc{Realm: realm, Host: udpKdcs[i], Proto: "udp"}
|
||||||
|
}
|
||||||
|
for i := range tcpKdcs {
|
||||||
|
kdcs[i+udpCnt] = Kdc{Realm: realm, Host: tcpKdcs[i], Proto: "tcp"}
|
||||||
|
}
|
||||||
|
|
||||||
|
replies := make(chan []byte, len(kdcs))
|
||||||
for i := range kdcs {
|
for i := range kdcs {
|
||||||
conn, err := net.Dial("tcp", kdcs[i])
|
conn, err := net.Dial(kdcs[i].Proto, kdcs[i].Host)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error connecting to %s due to %s, trying next if available", kdcs[i], err)
|
log.Printf("error connecting to %s due to %s, trying next if available", kdcs[i], err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
conn.SetDeadline(time.Now().Add(timeout))
|
conn.SetDeadline(time.Now().Add(timeout))
|
||||||
|
|
||||||
_, err = conn.Write(data)
|
// if we proxy over UDP remove the length prefix
|
||||||
|
if kdcs[i].Proto == "tcp" {
|
||||||
|
_, err = conn.Write(data)
|
||||||
|
} else {
|
||||||
|
_, err = conn.Write(data[4:])
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot write packet data to %s due to %s, trying next if available", kdcs[i], err)
|
log.Printf("cannot write packet data to %s due to %s, trying next if available", kdcs[i], err)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo check header
|
kdcs[i].Conn = conn
|
||||||
resp, err = io.ReadAll(conn)
|
go awaitReply(conn, kdcs[i].Proto == "udp", replies)
|
||||||
if err != nil {
|
|
||||||
log.Printf("error reading from kdc %s due to %s, trying next if available", kdcs[i], err)
|
|
||||||
conn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("no kdcs found for realm %s", realm)
|
reply := <-replies
|
||||||
|
|
||||||
|
// close all the connections and return the first reply
|
||||||
|
for kdc := range kdcs {
|
||||||
|
if kdcs[kdc].Conn != nil {
|
||||||
|
kdcs[kdc].Conn.Close()
|
||||||
|
}
|
||||||
|
<-replies
|
||||||
|
}
|
||||||
|
|
||||||
|
if reply != nil {
|
||||||
|
return reply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("no replies received from kdcs for realm %s", realm)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decode(data []byte) (msg *KdcProxyMsg, err error) {
|
func decode(data []byte) (msg *KdcProxyMsg, err error) {
|
||||||
|
@ -155,3 +194,17 @@ func encode(krb5data []byte) (r []byte, err error) {
|
||||||
}
|
}
|
||||||
return enc, nil
|
return enc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func awaitReply(conn net.Conn, isUdp bool, reply chan<- []byte) {
|
||||||
|
resp, err := io.ReadAll(conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error reading from kdc due to %s", err)
|
||||||
|
reply <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if isUdp {
|
||||||
|
// udp will be missing the length prefix so add it
|
||||||
|
resp = append([]byte{byte(len(resp))}, resp...)
|
||||||
|
}
|
||||||
|
reply <- resp
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue