From cd03f7c946485f76fd8f266d59eb293e37ada768 Mon Sep 17 00:00:00 2001 From: Tal Aloni Date: Sat, 4 Feb 2017 21:21:46 +0200 Subject: [PATCH] Corrected SPNEGO implementation --- .../Authentication/GSSAPI/GSSAPIHelper.cs | 36 ++-- .../SimpleProtectedNegotiationTokenInit.cs | 139 +++++++-------- ...SimpleProtectedNegotiationTokenResponse.cs | 162 ++++++++---------- 3 files changed, 146 insertions(+), 191 deletions(-) diff --git a/SMBLibrary/Authentication/GSSAPI/GSSAPIHelper.cs b/SMBLibrary/Authentication/GSSAPI/GSSAPIHelper.cs index 06a6ade..86c9c2e 100644 --- a/SMBLibrary/Authentication/GSSAPI/GSSAPIHelper.cs +++ b/SMBLibrary/Authentication/GSSAPI/GSSAPIHelper.cs @@ -24,25 +24,19 @@ namespace SMBLibrary.Authentication { if (token is SimpleProtectedNegotiationTokenInit) { - List tokens = ((SimpleProtectedNegotiationTokenInit)token).Tokens; - foreach (TokenInitEntry entry in tokens) + SimpleProtectedNegotiationTokenInit tokenInit = (SimpleProtectedNegotiationTokenInit)token; + foreach (byte[] identifier in tokenInit.MechanismTypeList) { - foreach (byte[] identifier in entry.MechanismTypeList) + if (ByteUtils.AreByteArraysEqual(identifier, NTLMSSPIdentifier)) { - if (ByteUtils.AreByteArraysEqual(identifier, NTLMSSPIdentifier)) - { - return entry.MechanismToken; - } + return tokenInit.MechanismToken; } } } else { - List tokens = ((SimpleProtectedNegotiationTokenResponse)token).Tokens; - if (tokens.Count > 0) - { - return tokens[0].ResponseToken; - } + SimpleProtectedNegotiationTokenResponse tokenResponse = (SimpleProtectedNegotiationTokenResponse)token; + return tokenResponse.ResponseToken; } } return null; @@ -51,30 +45,24 @@ namespace SMBLibrary.Authentication public static byte[] GetGSSTokenInitNTLMSSPBytes() { SimpleProtectedNegotiationTokenInit token = new SimpleProtectedNegotiationTokenInit(); - TokenInitEntry entry = new TokenInitEntry(); - entry.MechanismTypeList = new List(); - entry.MechanismTypeList.Add(NTLMSSPIdentifier); - token.Tokens.Add(entry); + token.MechanismTypeList = new List(); + token.MechanismTypeList.Add(NTLMSSPIdentifier); return SimpleProtectedNegotiationToken.GetTokenBytes(token); } public static byte[] GetGSSTokenResponseBytesFromNTLMSSPMessage(byte[] messageBytes) { SimpleProtectedNegotiationTokenResponse token = new SimpleProtectedNegotiationTokenResponse(); - TokenResponseEntry entry = new TokenResponseEntry(); - entry.NegState = NegState.AcceptIncomplete; - entry.SupportedMechanism = NTLMSSPIdentifier; - entry.ResponseToken = messageBytes; - token.Tokens.Add(entry); + token.NegState = NegState.AcceptIncomplete; + token.SupportedMechanism = NTLMSSPIdentifier; + token.ResponseToken = messageBytes; return token.GetBytes(); } public static byte[] GetGSSTokenAcceptCompletedResponse() { SimpleProtectedNegotiationTokenResponse token = new SimpleProtectedNegotiationTokenResponse(); - TokenResponseEntry entry = new TokenResponseEntry(); - entry.NegState = NegState.AcceptCompleted; - token.Tokens.Add(entry); + token.NegState = NegState.AcceptCompleted; return token.GetBytes(); } } diff --git a/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenInit.cs b/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenInit.cs index fd2acea..4082fa2 100644 --- a/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenInit.cs +++ b/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenInit.cs @@ -11,14 +11,6 @@ using Utilities; namespace SMBLibrary.Authentication { - public class TokenInitEntry - { - public List MechanismTypeList; // Optional - // reqFlags - Optional, RECOMMENDED to be left out - public byte[] MechanismToken; // Optional - public byte[] MechanismListMIC; // Optional - } - /// /// RFC 4178 - negTokenInit /// @@ -30,7 +22,10 @@ namespace SMBLibrary.Authentication public const byte MechanismTokenTag = 0xA2; public const byte MechanismListMICTag = 0xA3; - public List Tokens = new List(); + public List MechanismTypeList; // Optional + // reqFlags - Optional, RECOMMENDED to be left out + public byte[] MechanismToken; // Optional + public byte[] MechanismListMIC; // Optional public SimpleProtectedNegotiationTokenInit() { @@ -40,95 +35,85 @@ namespace SMBLibrary.Authentication public SimpleProtectedNegotiationTokenInit(byte[] buffer, int offset) { int constructionLength = DerEncodingHelper.ReadLength(buffer, ref offset); - int sequenceEndOffset = offset + constructionLength; byte tag = ByteReader.ReadByte(buffer, ref offset); if (tag != (byte)DerEncodingTag.Sequence) { throw new InvalidDataException(); } + int sequenceLength = DerEncodingHelper.ReadLength(buffer, ref offset); + int sequenceEndOffset = offset + sequenceLength; while (offset < sequenceEndOffset) { - int entryLength = DerEncodingHelper.ReadLength(buffer, ref offset); - int entryEndOffset = offset + entryLength; - TokenInitEntry entry = new TokenInitEntry(); - while (offset < entryEndOffset) + tag = ByteReader.ReadByte(buffer, ref offset); + if (tag == MechanismTypeListTag) { - tag = ByteReader.ReadByte(buffer, ref offset); - if (tag == MechanismTypeListTag) - { - entry.MechanismTypeList = ReadMechanismTypeList(buffer, ref offset); - } - else if (tag == RequiredFlagsTag) - { - throw new NotImplementedException("negTokenInit.ReqFlags is not implemented"); - } - else if (tag == MechanismTokenTag) - { - entry.MechanismToken = ReadMechanismToken(buffer, ref offset); - } - else if (tag == MechanismListMICTag) - { - entry.MechanismListMIC = ReadMechanismListMIC(buffer, ref offset); - } - else - { - throw new InvalidDataException("Invalid negTokenInit structure"); - } + MechanismTypeList = ReadMechanismTypeList(buffer, ref offset); + } + else if (tag == RequiredFlagsTag) + { + throw new NotImplementedException("negTokenInit.ReqFlags is not implemented"); + } + else if (tag == MechanismTokenTag) + { + MechanismToken = ReadMechanismToken(buffer, ref offset); + } + else if (tag == MechanismListMICTag) + { + MechanismListMIC = ReadMechanismListMIC(buffer, ref offset); + } + else + { + throw new InvalidDataException("Invalid negTokenInit structure"); } - Tokens.Add(entry); } } public override byte[] GetBytes() { - int sequenceLength = 0; - foreach (TokenInitEntry token in Tokens) - { - int entryLength = GetEntryLength(token); - sequenceLength += DerEncodingHelper.GetLengthFieldSize(entryLength) + entryLength; - } - int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + sequenceLength); - int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLength; + int sequenceLength = GetTokenFieldsLength(); + int sequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(sequenceLength); + int constructionLength = 1 + sequenceLengthFieldSize + sequenceLength; + int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(constructionLength); + int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLengthFieldSize + sequenceLength; byte[] buffer = new byte[bufferSize]; int offset = 0; ByteWriter.WriteByte(buffer, ref offset, NegTokenInitTag); - DerEncodingHelper.WriteLength(buffer, ref offset, 1 + sequenceLength); + DerEncodingHelper.WriteLength(buffer, ref offset, constructionLength); ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.Sequence); - foreach (TokenInitEntry token in Tokens) + DerEncodingHelper.WriteLength(buffer, ref offset, sequenceLength); + if (MechanismTypeList != null) { - int entryLength = GetEntryLength(token); - DerEncodingHelper.WriteLength(buffer, ref offset, entryLength); - if (token.MechanismTypeList != null) - { - WriteMechanismTypeList(buffer, ref offset, token.MechanismTypeList); - } - if (token.MechanismToken != null) - { - WriteMechanismToken(buffer, ref offset, token.MechanismToken); - } - if (token.MechanismListMIC != null) - { - WriteMechanismListMIC(buffer, ref offset, token.MechanismListMIC); - } + WriteMechanismTypeList(buffer, ref offset, MechanismTypeList); + } + if (MechanismToken != null) + { + WriteMechanismToken(buffer, ref offset, MechanismToken); + } + if (MechanismListMIC != null) + { + WriteMechanismListMIC(buffer, ref offset, MechanismListMIC); } return buffer; } - public int GetEntryLength(TokenInitEntry token) + private int GetTokenFieldsLength() { int result = 0; - if (token.MechanismTypeList != null) + if (MechanismTypeList != null) { - int typeListSequenceLength = GetSequenceLength(token.MechanismTypeList); - int constructionLenthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + typeListSequenceLength); - int typeListLength = 1 + constructionLenthFieldSize + 1 + typeListSequenceLength; + int typeListSequenceLength = GetSequenceLength(MechanismTypeList); + int typeListSequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(typeListSequenceLength); + int typeListConstructionLength = 1 + typeListSequenceLengthFieldSize + typeListSequenceLength; + int typeListConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(typeListConstructionLength); + int typeListLength = 1 + typeListConstructionLengthFieldSize + 1 + typeListSequenceLengthFieldSize + typeListSequenceLength; result += typeListLength; } - if (token.MechanismToken != null) + if (MechanismToken != null) { - int byteArrayFieldSize = DerEncodingHelper.GetLengthFieldSize(token.MechanismToken.Length); - int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + byteArrayFieldSize + token.MechanismToken.Length); - int tokenLength = 1 + constructionLengthFieldSize + 1 + byteArrayFieldSize + token.MechanismToken.Length; + int mechanismTokenBytesFieldSize = DerEncodingHelper.GetLengthFieldSize(MechanismToken.Length); + int mechanismTokenConstructionLength = 1 + mechanismTokenBytesFieldSize + MechanismToken.Length; + int mechanismTokenConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(mechanismTokenConstructionLength); + int tokenLength = 1 + mechanismTokenConstructionLengthFieldSize + 1 + mechanismTokenBytesFieldSize + MechanismToken.Length; result += tokenLength; } return result; @@ -138,16 +123,15 @@ namespace SMBLibrary.Authentication { List result = new List(); int constructionLength = DerEncodingHelper.ReadLength(buffer, ref offset); - int sequenceEndOffset = offset + constructionLength; byte tag = ByteReader.ReadByte(buffer, ref offset); if (tag != (byte)DerEncodingTag.Sequence) { throw new InvalidDataException(); } + int sequenceLength = DerEncodingHelper.ReadLength(buffer, ref offset); + int sequenceEndOffset = offset + sequenceLength; while (offset < sequenceEndOffset) { - int entryLength = DerEncodingHelper.ReadLength(buffer, ref offset); - int entryEndOffset = offset + entryLength; tag = ByteReader.ReadByte(buffer, ref offset); if (tag != (byte)DerEncodingTag.ObjectIdentifier) { @@ -192,7 +176,7 @@ namespace SMBLibrary.Authentication { int lengthFieldSize = DerEncodingHelper.GetLengthFieldSize(mechanismType.Length); int entryLength = 1 + lengthFieldSize + mechanismType.Length; - sequenceLength += DerEncodingHelper.GetLengthFieldSize(entryLength) + entryLength; + sequenceLength += entryLength; } return sequenceLength; } @@ -200,15 +184,14 @@ namespace SMBLibrary.Authentication private static void WriteMechanismTypeList(byte[] buffer, ref int offset, List mechanismTypeList) { int sequenceLength = GetSequenceLength(mechanismTypeList); + int sequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(sequenceLength); + int constructionLength = 1 + sequenceLengthFieldSize + sequenceLength; ByteWriter.WriteByte(buffer, ref offset, MechanismTypeListTag); - DerEncodingHelper.WriteLength(buffer, ref offset, 1 + sequenceLength); + DerEncodingHelper.WriteLength(buffer, ref offset, constructionLength); ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.Sequence); + DerEncodingHelper.WriteLength(buffer, ref offset, sequenceLength); foreach (byte[] mechanismType in mechanismTypeList) { - int lengthFieldSize = DerEncodingHelper.GetLengthFieldSize(mechanismType.Length); - int entryLength = 1 + lengthFieldSize + mechanismType.Length; - - DerEncodingHelper.WriteLength(buffer, ref offset, entryLength); ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.ObjectIdentifier); DerEncodingHelper.WriteLength(buffer, ref offset, mechanismType.Length); ByteWriter.WriteBytes(buffer, ref offset, mechanismType); diff --git a/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenResponse.cs b/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenResponse.cs index 6a3e5d2..ca94657 100644 --- a/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenResponse.cs +++ b/SMBLibrary/Authentication/GSSAPI/SimpleProtectedNegotiationTokenResponse.cs @@ -19,14 +19,6 @@ namespace SMBLibrary.Authentication RequestMic = 0x03, } - public class TokenResponseEntry - { - public NegState? NegState; // Optional - public byte[] SupportedMechanism; // Optional - public byte[] ResponseToken; // Optional - public byte[] MechanismListMIC; // Optional - } - /// /// RFC 4178 - negTokenResp /// @@ -38,7 +30,10 @@ namespace SMBLibrary.Authentication public const byte ResponseTokenTag = 0xA2; public const byte MechanismListMICTag = 0xA3; - public List Tokens = new List(); + public NegState? NegState; // Optional + public byte[] SupportedMechanism; // Optional + public byte[] ResponseToken; // Optional + public byte[] MechanismListMIC; // Optional public SimpleProtectedNegotiationTokenResponse() { @@ -48,84 +43,98 @@ namespace SMBLibrary.Authentication public SimpleProtectedNegotiationTokenResponse(byte[] buffer, int offset) { int constuctionLength = DerEncodingHelper.ReadLength(buffer, ref offset); - int sequenceEndOffset = offset + constuctionLength; byte tag = ByteReader.ReadByte(buffer, ref offset); if (tag != (byte)DerEncodingTag.Sequence) { throw new InvalidDataException(); } + int sequenceLength = DerEncodingHelper.ReadLength(buffer, ref offset); + int sequenceEndOffset = offset + sequenceLength; while (offset < sequenceEndOffset) { - int entryLength = DerEncodingHelper.ReadLength(buffer, ref offset); - int entryEndOffset = offset + entryLength; - TokenResponseEntry entry = new TokenResponseEntry(); - while (offset < entryEndOffset) + tag = ByteReader.ReadByte(buffer, ref offset); + if (tag == NegStateTag) { - tag = ByteReader.ReadByte(buffer, ref offset); - if (tag == NegStateTag) - { - entry.NegState = ReadNegState(buffer, ref offset); - } - else if (tag == SupportedMechanismTag) - { - entry.SupportedMechanism = ReadSupportedMechanism(buffer, ref offset); - } - else if (tag == ResponseTokenTag) - { - entry.ResponseToken = ReadResponseToken(buffer, ref offset); - } - else if (tag == MechanismListMICTag) - { - entry.MechanismListMIC = ReadMechanismListMIC(buffer, ref offset); - } - else - { - throw new InvalidDataException("Invalid negTokenResp structure"); - } + NegState = ReadNegState(buffer, ref offset); + } + else if (tag == SupportedMechanismTag) + { + SupportedMechanism = ReadSupportedMechanism(buffer, ref offset); + } + else if (tag == ResponseTokenTag) + { + ResponseToken = ReadResponseToken(buffer, ref offset); + } + else if (tag == MechanismListMICTag) + { + MechanismListMIC = ReadMechanismListMIC(buffer, ref offset); + } + else + { + throw new InvalidDataException("Invalid negTokenResp structure"); } - Tokens.Add(entry); } } public override byte[] GetBytes() { - int sequenceLength = 0; - foreach (TokenResponseEntry token in Tokens) - { - int entryLength = GetEntryLength(token); - sequenceLength += DerEncodingHelper.GetLengthFieldSize(entryLength) + entryLength; - } - int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + sequenceLength); - int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLength; + int sequenceLength = GetTokenFieldsLength(); + int sequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(sequenceLength); + int constructionLength = 1 + sequenceLengthFieldSize + sequenceLength; + int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(constructionLength); + int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLengthFieldSize + sequenceLength; byte[] buffer = new byte[bufferSize]; int offset = 0; ByteWriter.WriteByte(buffer, ref offset, NegTokenRespTag); - DerEncodingHelper.WriteLength(buffer, ref offset, 1 + sequenceLength); + DerEncodingHelper.WriteLength(buffer, ref offset, constructionLength); ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.Sequence); - foreach (TokenResponseEntry token in Tokens) + DerEncodingHelper.WriteLength(buffer, ref offset, sequenceLength); + if (NegState.HasValue) { - int entryLength = GetEntryLength(token); - DerEncodingHelper.WriteLength(buffer, ref offset, entryLength); - if (token.NegState.HasValue) - { - WriteNegState(buffer, ref offset, token.NegState.Value); - } - if (token.SupportedMechanism != null) - { - WriteSupportedMechanism(buffer, ref offset, token.SupportedMechanism); - } - if (token.ResponseToken != null) - { - WriteResponseToken(buffer, ref offset, token.ResponseToken); - } - if (token.MechanismListMIC != null) - { - WriteMechanismListMIC(buffer, ref offset, token.MechanismListMIC); - } + WriteNegState(buffer, ref offset, NegState.Value); + } + if (SupportedMechanism != null) + { + WriteSupportedMechanism(buffer, ref offset, SupportedMechanism); + } + if (ResponseToken != null) + { + WriteResponseToken(buffer, ref offset, ResponseToken); + } + if (MechanismListMIC != null) + { + WriteMechanismListMIC(buffer, ref offset, MechanismListMIC); } return buffer; } + private int GetTokenFieldsLength() + { + int result = 0; + if (NegState.HasValue) + { + int negStateLength = 5; + result += negStateLength; + } + if (SupportedMechanism != null) + { + int supportedMechanismBytesLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(SupportedMechanism.Length); + int supportedMechanismConstructionLength = 1 + supportedMechanismBytesLengthFieldSize + SupportedMechanism.Length; + int supportedMechanismConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(supportedMechanismConstructionLength); + int supportedMechanismLength = 1 + supportedMechanismConstructionLengthFieldSize + 1 + supportedMechanismBytesLengthFieldSize + SupportedMechanism.Length; + result += supportedMechanismLength; + } + if (ResponseToken != null) + { + int responseTokenBytesLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(ResponseToken.Length); + int responseTokenConstructionLength = 1 + responseTokenBytesLengthFieldSize + ResponseToken.Length; + int responseTokenConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(responseTokenConstructionLength); + int responseTokenLength = 1 + responseTokenConstructionLengthFieldSize + 1 + responseTokenBytesLengthFieldSize + ResponseToken.Length; + result += responseTokenLength; + } + return result; + } + private static NegState ReadNegState(byte[] buffer, ref int offset) { int length = DerEncodingHelper.ReadLength(buffer, ref offset); @@ -174,31 +183,6 @@ namespace SMBLibrary.Authentication return ByteReader.ReadBytes(buffer, ref offset, length); } - private static int GetEntryLength(TokenResponseEntry token) - { - int result = 0; - if (token.NegState.HasValue) - { - int negStateLength = 5; - result += negStateLength; - } - if (token.SupportedMechanism != null) - { - int supportedMechanismLength2FieldSize = DerEncodingHelper.GetLengthFieldSize(token.SupportedMechanism.Length); - int supportedMechanismLength1FieldSize = DerEncodingHelper.GetLengthFieldSize(1 + supportedMechanismLength2FieldSize + token.SupportedMechanism.Length); - int supportedMechanismLength = 1 + supportedMechanismLength1FieldSize + 1 + supportedMechanismLength2FieldSize + token.SupportedMechanism.Length; - result += supportedMechanismLength; - } - if (token.ResponseToken != null) - { - int responseToken2FieldSize = DerEncodingHelper.GetLengthFieldSize(token.ResponseToken.Length); - int responseToken1FieldSize = DerEncodingHelper.GetLengthFieldSize(1 + responseToken2FieldSize + token.ResponseToken.Length); - int responseTokenLength = 1 + responseToken1FieldSize + 1 + responseToken2FieldSize + token.ResponseToken.Length; - result += responseTokenLength; - } - return result; - } - private static void WriteNegState(byte[] buffer, ref int offset, NegState negState) { ByteWriter.WriteByte(buffer, ref offset, NegStateTag);