Corrected SPNEGO implementation

This commit is contained in:
Tal Aloni 2017-02-04 21:21:46 +02:00
parent 2f78339fc8
commit cd03f7c946
3 changed files with 146 additions and 191 deletions

View file

@ -24,25 +24,19 @@ namespace SMBLibrary.Authentication
{
if (token is SimpleProtectedNegotiationTokenInit)
{
List<TokenInitEntry> tokens = ((SimpleProtectedNegotiationTokenInit)token).Tokens;
foreach (TokenInitEntry entry in tokens)
{
foreach (byte[] identifier in entry.MechanismTypeList)
SimpleProtectedNegotiationTokenInit tokenInit = (SimpleProtectedNegotiationTokenInit)token;
foreach (byte[] identifier in tokenInit.MechanismTypeList)
{
if (ByteUtils.AreByteArraysEqual(identifier, NTLMSSPIdentifier))
{
return entry.MechanismToken;
}
return tokenInit.MechanismToken;
}
}
}
else
{
List<TokenResponseEntry> tokens = ((SimpleProtectedNegotiationTokenResponse)token).Tokens;
if (tokens.Count > 0)
{
return tokens[0].ResponseToken;
}
SimpleProtectedNegotiationTokenResponse tokenResponse = (SimpleProtectedNegotiationTokenResponse)token;
return tokenResponse.ResponseToken;
}
}
return null;
@ -51,30 +45,24 @@ namespace SMBLibrary.Authentication
public static byte[] GetGSSTokenInitNTLMSSPBytes()
{
SimpleProtectedNegotiationTokenInit token = new SimpleProtectedNegotiationTokenInit();
TokenInitEntry entry = new TokenInitEntry();
entry.MechanismTypeList = new List<byte[]>();
entry.MechanismTypeList.Add(NTLMSSPIdentifier);
token.Tokens.Add(entry);
token.MechanismTypeList = new List<byte[]>();
token.MechanismTypeList.Add(NTLMSSPIdentifier);
return SimpleProtectedNegotiationToken.GetTokenBytes(token);
}
public static byte[] GetGSSTokenResponseBytesFromNTLMSSPMessage(byte[] messageBytes)
{
SimpleProtectedNegotiationTokenResponse token = new SimpleProtectedNegotiationTokenResponse();
TokenResponseEntry entry = new TokenResponseEntry();
entry.NegState = NegState.AcceptIncomplete;
entry.SupportedMechanism = NTLMSSPIdentifier;
entry.ResponseToken = messageBytes;
token.Tokens.Add(entry);
token.NegState = NegState.AcceptIncomplete;
token.SupportedMechanism = NTLMSSPIdentifier;
token.ResponseToken = messageBytes;
return token.GetBytes();
}
public static byte[] GetGSSTokenAcceptCompletedResponse()
{
SimpleProtectedNegotiationTokenResponse token = new SimpleProtectedNegotiationTokenResponse();
TokenResponseEntry entry = new TokenResponseEntry();
entry.NegState = NegState.AcceptCompleted;
token.Tokens.Add(entry);
token.NegState = NegState.AcceptCompleted;
return token.GetBytes();
}
}

View file

@ -11,14 +11,6 @@ using Utilities;
namespace SMBLibrary.Authentication
{
public class TokenInitEntry
{
public List<byte[]> MechanismTypeList; // Optional
// reqFlags - Optional, RECOMMENDED to be left out
public byte[] MechanismToken; // Optional
public byte[] MechanismListMIC; // Optional
}
/// <summary>
/// RFC 4178 - negTokenInit
/// </summary>
@ -30,7 +22,10 @@ namespace SMBLibrary.Authentication
public const byte MechanismTokenTag = 0xA2;
public const byte MechanismListMICTag = 0xA3;
public List<TokenInitEntry> Tokens = new List<TokenInitEntry>();
public List<byte[]> MechanismTypeList; // Optional
// reqFlags - Optional, RECOMMENDED to be left out
public byte[] MechanismToken; // Optional
public byte[] MechanismListMIC; // Optional
public SimpleProtectedNegotiationTokenInit()
{
@ -40,23 +35,19 @@ namespace SMBLibrary.Authentication
public SimpleProtectedNegotiationTokenInit(byte[] buffer, int offset)
{
int constructionLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int sequenceEndOffset = offset + constructionLength;
byte tag = ByteReader.ReadByte(buffer, ref offset);
if (tag != (byte)DerEncodingTag.Sequence)
{
throw new InvalidDataException();
}
int sequenceLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int sequenceEndOffset = offset + sequenceLength;
while (offset < sequenceEndOffset)
{
int entryLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int entryEndOffset = offset + entryLength;
TokenInitEntry entry = new TokenInitEntry();
while (offset < entryEndOffset)
{
tag = ByteReader.ReadByte(buffer, ref offset);
if (tag == MechanismTypeListTag)
{
entry.MechanismTypeList = ReadMechanismTypeList(buffer, ref offset);
MechanismTypeList = ReadMechanismTypeList(buffer, ref offset);
}
else if (tag == RequiredFlagsTag)
{
@ -64,71 +55,65 @@ namespace SMBLibrary.Authentication
}
else if (tag == MechanismTokenTag)
{
entry.MechanismToken = ReadMechanismToken(buffer, ref offset);
MechanismToken = ReadMechanismToken(buffer, ref offset);
}
else if (tag == MechanismListMICTag)
{
entry.MechanismListMIC = ReadMechanismListMIC(buffer, ref offset);
MechanismListMIC = ReadMechanismListMIC(buffer, ref offset);
}
else
{
throw new InvalidDataException("Invalid negTokenInit structure");
}
}
Tokens.Add(entry);
}
}
public override byte[] GetBytes()
{
int sequenceLength = 0;
foreach (TokenInitEntry token in Tokens)
{
int entryLength = GetEntryLength(token);
sequenceLength += DerEncodingHelper.GetLengthFieldSize(entryLength) + entryLength;
}
int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + sequenceLength);
int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLength;
int sequenceLength = GetTokenFieldsLength();
int sequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(sequenceLength);
int constructionLength = 1 + sequenceLengthFieldSize + sequenceLength;
int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(constructionLength);
int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLengthFieldSize + sequenceLength;
byte[] buffer = new byte[bufferSize];
int offset = 0;
ByteWriter.WriteByte(buffer, ref offset, NegTokenInitTag);
DerEncodingHelper.WriteLength(buffer, ref offset, 1 + sequenceLength);
DerEncodingHelper.WriteLength(buffer, ref offset, constructionLength);
ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.Sequence);
foreach (TokenInitEntry token in Tokens)
DerEncodingHelper.WriteLength(buffer, ref offset, sequenceLength);
if (MechanismTypeList != null)
{
int entryLength = GetEntryLength(token);
DerEncodingHelper.WriteLength(buffer, ref offset, entryLength);
if (token.MechanismTypeList != null)
{
WriteMechanismTypeList(buffer, ref offset, token.MechanismTypeList);
WriteMechanismTypeList(buffer, ref offset, MechanismTypeList);
}
if (token.MechanismToken != null)
if (MechanismToken != null)
{
WriteMechanismToken(buffer, ref offset, token.MechanismToken);
WriteMechanismToken(buffer, ref offset, MechanismToken);
}
if (token.MechanismListMIC != null)
if (MechanismListMIC != null)
{
WriteMechanismListMIC(buffer, ref offset, token.MechanismListMIC);
}
WriteMechanismListMIC(buffer, ref offset, MechanismListMIC);
}
return buffer;
}
public int GetEntryLength(TokenInitEntry token)
private int GetTokenFieldsLength()
{
int result = 0;
if (token.MechanismTypeList != null)
if (MechanismTypeList != null)
{
int typeListSequenceLength = GetSequenceLength(token.MechanismTypeList);
int constructionLenthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + typeListSequenceLength);
int typeListLength = 1 + constructionLenthFieldSize + 1 + typeListSequenceLength;
int typeListSequenceLength = GetSequenceLength(MechanismTypeList);
int typeListSequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(typeListSequenceLength);
int typeListConstructionLength = 1 + typeListSequenceLengthFieldSize + typeListSequenceLength;
int typeListConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(typeListConstructionLength);
int typeListLength = 1 + typeListConstructionLengthFieldSize + 1 + typeListSequenceLengthFieldSize + typeListSequenceLength;
result += typeListLength;
}
if (token.MechanismToken != null)
if (MechanismToken != null)
{
int byteArrayFieldSize = DerEncodingHelper.GetLengthFieldSize(token.MechanismToken.Length);
int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + byteArrayFieldSize + token.MechanismToken.Length);
int tokenLength = 1 + constructionLengthFieldSize + 1 + byteArrayFieldSize + token.MechanismToken.Length;
int mechanismTokenBytesFieldSize = DerEncodingHelper.GetLengthFieldSize(MechanismToken.Length);
int mechanismTokenConstructionLength = 1 + mechanismTokenBytesFieldSize + MechanismToken.Length;
int mechanismTokenConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(mechanismTokenConstructionLength);
int tokenLength = 1 + mechanismTokenConstructionLengthFieldSize + 1 + mechanismTokenBytesFieldSize + MechanismToken.Length;
result += tokenLength;
}
return result;
@ -138,16 +123,15 @@ namespace SMBLibrary.Authentication
{
List<byte[]> result = new List<byte[]>();
int constructionLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int sequenceEndOffset = offset + constructionLength;
byte tag = ByteReader.ReadByte(buffer, ref offset);
if (tag != (byte)DerEncodingTag.Sequence)
{
throw new InvalidDataException();
}
int sequenceLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int sequenceEndOffset = offset + sequenceLength;
while (offset < sequenceEndOffset)
{
int entryLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int entryEndOffset = offset + entryLength;
tag = ByteReader.ReadByte(buffer, ref offset);
if (tag != (byte)DerEncodingTag.ObjectIdentifier)
{
@ -192,7 +176,7 @@ namespace SMBLibrary.Authentication
{
int lengthFieldSize = DerEncodingHelper.GetLengthFieldSize(mechanismType.Length);
int entryLength = 1 + lengthFieldSize + mechanismType.Length;
sequenceLength += DerEncodingHelper.GetLengthFieldSize(entryLength) + entryLength;
sequenceLength += entryLength;
}
return sequenceLength;
}
@ -200,15 +184,14 @@ namespace SMBLibrary.Authentication
private static void WriteMechanismTypeList(byte[] buffer, ref int offset, List<byte[]> mechanismTypeList)
{
int sequenceLength = GetSequenceLength(mechanismTypeList);
int sequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(sequenceLength);
int constructionLength = 1 + sequenceLengthFieldSize + sequenceLength;
ByteWriter.WriteByte(buffer, ref offset, MechanismTypeListTag);
DerEncodingHelper.WriteLength(buffer, ref offset, 1 + sequenceLength);
DerEncodingHelper.WriteLength(buffer, ref offset, constructionLength);
ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.Sequence);
DerEncodingHelper.WriteLength(buffer, ref offset, sequenceLength);
foreach (byte[] mechanismType in mechanismTypeList)
{
int lengthFieldSize = DerEncodingHelper.GetLengthFieldSize(mechanismType.Length);
int entryLength = 1 + lengthFieldSize + mechanismType.Length;
DerEncodingHelper.WriteLength(buffer, ref offset, entryLength);
ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.ObjectIdentifier);
DerEncodingHelper.WriteLength(buffer, ref offset, mechanismType.Length);
ByteWriter.WriteBytes(buffer, ref offset, mechanismType);

View file

@ -19,14 +19,6 @@ namespace SMBLibrary.Authentication
RequestMic = 0x03,
}
public class TokenResponseEntry
{
public NegState? NegState; // Optional
public byte[] SupportedMechanism; // Optional
public byte[] ResponseToken; // Optional
public byte[] MechanismListMIC; // Optional
}
/// <summary>
/// RFC 4178 - negTokenResp
/// </summary>
@ -38,7 +30,10 @@ namespace SMBLibrary.Authentication
public const byte ResponseTokenTag = 0xA2;
public const byte MechanismListMICTag = 0xA3;
public List<TokenResponseEntry> Tokens = new List<TokenResponseEntry>();
public NegState? NegState; // Optional
public byte[] SupportedMechanism; // Optional
public byte[] ResponseToken; // Optional
public byte[] MechanismListMIC; // Optional
public SimpleProtectedNegotiationTokenResponse()
{
@ -48,84 +43,98 @@ namespace SMBLibrary.Authentication
public SimpleProtectedNegotiationTokenResponse(byte[] buffer, int offset)
{
int constuctionLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int sequenceEndOffset = offset + constuctionLength;
byte tag = ByteReader.ReadByte(buffer, ref offset);
if (tag != (byte)DerEncodingTag.Sequence)
{
throw new InvalidDataException();
}
int sequenceLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int sequenceEndOffset = offset + sequenceLength;
while (offset < sequenceEndOffset)
{
int entryLength = DerEncodingHelper.ReadLength(buffer, ref offset);
int entryEndOffset = offset + entryLength;
TokenResponseEntry entry = new TokenResponseEntry();
while (offset < entryEndOffset)
{
tag = ByteReader.ReadByte(buffer, ref offset);
if (tag == NegStateTag)
{
entry.NegState = ReadNegState(buffer, ref offset);
NegState = ReadNegState(buffer, ref offset);
}
else if (tag == SupportedMechanismTag)
{
entry.SupportedMechanism = ReadSupportedMechanism(buffer, ref offset);
SupportedMechanism = ReadSupportedMechanism(buffer, ref offset);
}
else if (tag == ResponseTokenTag)
{
entry.ResponseToken = ReadResponseToken(buffer, ref offset);
ResponseToken = ReadResponseToken(buffer, ref offset);
}
else if (tag == MechanismListMICTag)
{
entry.MechanismListMIC = ReadMechanismListMIC(buffer, ref offset);
MechanismListMIC = ReadMechanismListMIC(buffer, ref offset);
}
else
{
throw new InvalidDataException("Invalid negTokenResp structure");
}
}
Tokens.Add(entry);
}
}
public override byte[] GetBytes()
{
int sequenceLength = 0;
foreach (TokenResponseEntry token in Tokens)
{
int entryLength = GetEntryLength(token);
sequenceLength += DerEncodingHelper.GetLengthFieldSize(entryLength) + entryLength;
}
int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(1 + sequenceLength);
int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLength;
int sequenceLength = GetTokenFieldsLength();
int sequenceLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(sequenceLength);
int constructionLength = 1 + sequenceLengthFieldSize + sequenceLength;
int constructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(constructionLength);
int bufferSize = 1 + constructionLengthFieldSize + 1 + sequenceLengthFieldSize + sequenceLength;
byte[] buffer = new byte[bufferSize];
int offset = 0;
ByteWriter.WriteByte(buffer, ref offset, NegTokenRespTag);
DerEncodingHelper.WriteLength(buffer, ref offset, 1 + sequenceLength);
DerEncodingHelper.WriteLength(buffer, ref offset, constructionLength);
ByteWriter.WriteByte(buffer, ref offset, (byte)DerEncodingTag.Sequence);
foreach (TokenResponseEntry token in Tokens)
DerEncodingHelper.WriteLength(buffer, ref offset, sequenceLength);
if (NegState.HasValue)
{
int entryLength = GetEntryLength(token);
DerEncodingHelper.WriteLength(buffer, ref offset, entryLength);
if (token.NegState.HasValue)
{
WriteNegState(buffer, ref offset, token.NegState.Value);
WriteNegState(buffer, ref offset, NegState.Value);
}
if (token.SupportedMechanism != null)
if (SupportedMechanism != null)
{
WriteSupportedMechanism(buffer, ref offset, token.SupportedMechanism);
WriteSupportedMechanism(buffer, ref offset, SupportedMechanism);
}
if (token.ResponseToken != null)
if (ResponseToken != null)
{
WriteResponseToken(buffer, ref offset, token.ResponseToken);
WriteResponseToken(buffer, ref offset, ResponseToken);
}
if (token.MechanismListMIC != null)
if (MechanismListMIC != null)
{
WriteMechanismListMIC(buffer, ref offset, token.MechanismListMIC);
}
WriteMechanismListMIC(buffer, ref offset, MechanismListMIC);
}
return buffer;
}
private int GetTokenFieldsLength()
{
int result = 0;
if (NegState.HasValue)
{
int negStateLength = 5;
result += negStateLength;
}
if (SupportedMechanism != null)
{
int supportedMechanismBytesLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(SupportedMechanism.Length);
int supportedMechanismConstructionLength = 1 + supportedMechanismBytesLengthFieldSize + SupportedMechanism.Length;
int supportedMechanismConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(supportedMechanismConstructionLength);
int supportedMechanismLength = 1 + supportedMechanismConstructionLengthFieldSize + 1 + supportedMechanismBytesLengthFieldSize + SupportedMechanism.Length;
result += supportedMechanismLength;
}
if (ResponseToken != null)
{
int responseTokenBytesLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(ResponseToken.Length);
int responseTokenConstructionLength = 1 + responseTokenBytesLengthFieldSize + ResponseToken.Length;
int responseTokenConstructionLengthFieldSize = DerEncodingHelper.GetLengthFieldSize(responseTokenConstructionLength);
int responseTokenLength = 1 + responseTokenConstructionLengthFieldSize + 1 + responseTokenBytesLengthFieldSize + ResponseToken.Length;
result += responseTokenLength;
}
return result;
}
private static NegState ReadNegState(byte[] buffer, ref int offset)
{
int length = DerEncodingHelper.ReadLength(buffer, ref offset);
@ -174,31 +183,6 @@ namespace SMBLibrary.Authentication
return ByteReader.ReadBytes(buffer, ref offset, length);
}
private static int GetEntryLength(TokenResponseEntry token)
{
int result = 0;
if (token.NegState.HasValue)
{
int negStateLength = 5;
result += negStateLength;
}
if (token.SupportedMechanism != null)
{
int supportedMechanismLength2FieldSize = DerEncodingHelper.GetLengthFieldSize(token.SupportedMechanism.Length);
int supportedMechanismLength1FieldSize = DerEncodingHelper.GetLengthFieldSize(1 + supportedMechanismLength2FieldSize + token.SupportedMechanism.Length);
int supportedMechanismLength = 1 + supportedMechanismLength1FieldSize + 1 + supportedMechanismLength2FieldSize + token.SupportedMechanism.Length;
result += supportedMechanismLength;
}
if (token.ResponseToken != null)
{
int responseToken2FieldSize = DerEncodingHelper.GetLengthFieldSize(token.ResponseToken.Length);
int responseToken1FieldSize = DerEncodingHelper.GetLengthFieldSize(1 + responseToken2FieldSize + token.ResponseToken.Length);
int responseTokenLength = 1 + responseToken1FieldSize + 1 + responseToken2FieldSize + token.ResponseToken.Length;
result += responseTokenLength;
}
return result;
}
private static void WriteNegState(byte[] buffer, ref int offset, NegState negState)
{
ByteWriter.WriteByte(buffer, ref offset, NegStateTag);