SMB1 command processing code refactoring

This commit is contained in:
Tal Aloni 2017-02-19 20:59:05 +02:00
parent 23f6127808
commit 4d51c7eed4
5 changed files with 104 additions and 77 deletions

View file

@ -1,4 +1,4 @@
/* Copyright (C) 2014 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
/* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
*
* You can redistribute this program and/or modify it under the terms of
* the GNU Lesser Public License as published by the Free Software Foundation,
@ -264,5 +264,12 @@ namespace SMBLibrary.SMB1
throw new NotImplementedException("SMB Command 0x" + ((byte)commandName).ToString("X"));
}
}
public static implicit operator List<SMB1Command>(SMB1Command command)
{
List<SMB1Command> result = new List<SMB1Command>();
result.Add(command);
return result;
}
}
}

View file

@ -17,7 +17,7 @@ namespace SMBLibrary.Server.SMB1
/// <summary>
/// The client MUST send as many secondary requests as are needed to complete the transfer of the transaction request.
/// </summary>
internal static SMB1Command GetNTTransactResponse(SMB1Header header, NTTransactRequest request, ISMBShare share, SMB1ConnectionState state, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetNTTransactResponse(SMB1Header header, NTTransactRequest request, ISMBShare share, SMB1ConnectionState state)
{
if (request.TransParameters.Length < request.TotalParameterCount ||
request.TransData.Length < request.TotalDataCount)
@ -37,7 +37,7 @@ namespace SMBLibrary.Server.SMB1
else
{
// We have a complete command
return GetCompleteNTTransactResponse(header, request.Function, request.Setup, request.TransParameters, request.TransData, share, state, sendQueue);
return GetCompleteNTTransactResponse(header, request.Function, request.Setup, request.TransParameters, request.TransData, share, state);
}
}
@ -45,7 +45,7 @@ namespace SMBLibrary.Server.SMB1
/// There are no secondary response messages.
/// The client MUST send as many secondary requests as are needed to complete the transfer of the transaction request.
/// </summary>
internal static SMB1Command GetNTTransactResponse(SMB1Header header, NTTransactSecondaryRequest request, ISMBShare share, SMB1ConnectionState state, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetNTTransactResponse(SMB1Header header, NTTransactSecondaryRequest request, ISMBShare share, SMB1ConnectionState state)
{
ProcessStateObject processState = state.GetProcessState(header.PID);
if (processState == null)
@ -60,16 +60,16 @@ namespace SMBLibrary.Server.SMB1
if (processState.TransactionParametersReceived < processState.TransactionParameters.Length ||
processState.TransactionDataReceived < processState.TransactionData.Length)
{
return null;
return new List<SMB1Command>();
}
else
{
// We have a complete command
return GetCompleteNTTransactResponse(header, (NTTransactSubcommandName)processState.SubcommandID, processState.TransactionSetup, processState.TransactionParameters, processState.TransactionData, share, state, sendQueue);
return GetCompleteNTTransactResponse(header, (NTTransactSubcommandName)processState.SubcommandID, processState.TransactionSetup, processState.TransactionParameters, processState.TransactionData, share, state);
}
}
internal static SMB1Command GetCompleteNTTransactResponse(SMB1Header header, NTTransactSubcommandName subcommandName, byte[] requestSetup, byte[] requestParameters, byte[] requestData, ISMBShare share, SMB1ConnectionState state, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetCompleteNTTransactResponse(SMB1Header header, NTTransactSubcommandName subcommandName, byte[] requestSetup, byte[] requestParameters, byte[] requestData, ISMBShare share, SMB1ConnectionState state)
{
NTTransactSubcommand subcommand = NTTransactSubcommand.GetSubcommandRequest(subcommandName, requestSetup, requestParameters, requestData, header.UnicodeFlag);
NTTransactSubcommand subcommandResponse = null;
@ -107,9 +107,7 @@ namespace SMBLibrary.Server.SMB1
byte[] responseSetup = subcommandResponse.GetSetup();
byte[] responseParameters = subcommandResponse.GetParameters(header.UnicodeFlag);
byte[] responseData = subcommandResponse.GetData();
NTTransactResponse response = new NTTransactResponse();
PrepareResponse(response, responseSetup, responseParameters, responseData, state.MaxBufferSize, sendQueue);
return response;
return GetNTTransactResponse(responseSetup, responseParameters, responseData, state.MaxBufferSize);
}
private static NTTransactIOCTLResponse GetSubcommandResponse(SMB1Header header, NTTransactIOCTLRequest subcommand, ISMBShare share, SMB1ConnectionState state)
@ -143,15 +141,17 @@ namespace SMBLibrary.Server.SMB1
}
}
private static void PrepareResponse(NTTransactResponse response, byte[] responseSetup, byte[] responseParameters, byte[] responseData, int maxBufferSize, List<SMB1Command> sendQueue)
private static List<SMB1Command> GetNTTransactResponse(byte[] responseSetup, byte[] responseParameters, byte[] responseData, int maxBufferSize)
{
if (NTTransactResponse.CalculateMessageSize(responseSetup.Length, responseParameters.Length, responseData.Length) <= maxBufferSize)
{
NTTransactResponse response = new NTTransactResponse();
response.Setup = responseSetup;
response.TotalParameterCount = (ushort)responseParameters.Length;
response.TotalDataCount = (ushort)responseData.Length;
response.TransParameters = responseParameters;
response.TransData = responseData;
return response;
}
else
{

View file

@ -43,17 +43,15 @@ namespace SMBLibrary.Server.SMB1
return new FindClose2Response();
}
internal static EchoResponse GetEchoResponse(EchoRequest request, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetEchoResponse(EchoRequest request)
{
EchoResponse response = new EchoResponse();
response.SequenceNumber = 0;
response.SMBData = request.SMBData;
for (int index = 1; index < request.EchoCount; index++)
List<SMB1Command> response = new List<SMB1Command>();
for (int index = 0; index < request.EchoCount; index++)
{
EchoResponse echo = new EchoResponse();
echo.SequenceNumber = (ushort)index;
echo.SMBData = request.SMBData;
sendQueue.Add(echo);
response.Add(echo);
}
return response;
}

View file

@ -21,7 +21,7 @@ namespace SMBLibrary.Server.SMB1
/// The client MUST send as many secondary requests as are needed to complete the transfer of the transaction request.
/// The server MUST respond to the transaction request as a whole.
/// </summary>
internal static SMB1Command GetTransactionResponse(SMB1Header header, TransactionRequest request, ISMBShare share, SMB1ConnectionState state, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetTransactionResponse(SMB1Header header, TransactionRequest request, ISMBShare share, SMB1ConnectionState state)
{
ProcessStateObject processState = state.ObtainProcessState(header.PID);
processState.MaxDataCount = request.MaxDataCount;
@ -38,18 +38,18 @@ namespace SMBLibrary.Server.SMB1
ByteWriter.WriteBytes(processState.TransactionData, 0, request.TransData);
processState.TransactionParametersReceived += request.TransParameters.Length;
processState.TransactionDataReceived += request.TransData.Length;
return null;
return new List<SMB1Command>();
}
else
{
// We have a complete command
if (request is Transaction2Request)
{
return GetCompleteTransaction2Response(header, request.Setup, request.TransParameters, request.TransData, share, state, sendQueue);
return GetCompleteTransaction2Response(header, request.Setup, request.TransParameters, request.TransData, share, state);
}
else
{
return GetCompleteTransactionResponse(header, request.Name, request.Setup, request.TransParameters, request.TransData, share, state, sendQueue);
return GetCompleteTransactionResponse(header, request.Name, request.Setup, request.TransParameters, request.TransData, share, state);
}
}
}
@ -59,7 +59,7 @@ namespace SMBLibrary.Server.SMB1
/// The client MUST send as many secondary requests as are needed to complete the transfer of the transaction request.
/// The server MUST respond to the transaction request as a whole.
/// </summary>
internal static SMB1Command GetTransactionResponse(SMB1Header header, TransactionSecondaryRequest request, ISMBShare share, SMB1ConnectionState state, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetTransactionResponse(SMB1Header header, TransactionSecondaryRequest request, ISMBShare share, SMB1ConnectionState state)
{
ProcessStateObject processState = state.GetProcessState(header.PID);
if (processState == null)
@ -74,23 +74,23 @@ namespace SMBLibrary.Server.SMB1
if (processState.TransactionParametersReceived < processState.TransactionParameters.Length ||
processState.TransactionDataReceived < processState.TransactionData.Length)
{
return null;
return new List<SMB1Command>();
}
else
{
// We have a complete command
if (request is Transaction2SecondaryRequest)
{
return GetCompleteTransaction2Response(header, processState.TransactionSetup, processState.TransactionParameters, processState.TransactionData, share, state, sendQueue);
return GetCompleteTransaction2Response(header, processState.TransactionSetup, processState.TransactionParameters, processState.TransactionData, share, state);
}
else
{
return GetCompleteTransactionResponse(header, processState.Name, processState.TransactionSetup, processState.TransactionParameters, processState.TransactionData, share, state, sendQueue);
return GetCompleteTransactionResponse(header, processState.Name, processState.TransactionSetup, processState.TransactionParameters, processState.TransactionData, share, state);
}
}
}
internal static SMB1Command GetCompleteTransactionResponse(SMB1Header header, string name, byte[] requestSetup, byte[] requestParameters, byte[] requestData, ISMBShare share, SMB1ConnectionState state, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetCompleteTransactionResponse(SMB1Header header, string name, byte[] requestSetup, byte[] requestParameters, byte[] requestData, ISMBShare share, SMB1ConnectionState state)
{
if (String.Equals(name, @"\pipe\lanman", StringComparison.InvariantCultureIgnoreCase))
{
@ -169,12 +169,10 @@ namespace SMBLibrary.Server.SMB1
byte[] responseSetup = subcommandResponse.GetSetup();
byte[] responseParameters = subcommandResponse.GetParameters(header.UnicodeFlag);
byte[] responseData = subcommandResponse.GetData();
TransactionResponse response = new TransactionResponse();
PrepareResponse(response, responseSetup, responseParameters, responseData, state.MaxBufferSize, sendQueue);
return response;
return GetTransactionResponse(false, responseSetup, responseParameters, responseData, state.MaxBufferSize);
}
internal static SMB1Command GetCompleteTransaction2Response(SMB1Header header, byte[] requestSetup, byte[] requestParameters, byte[] requestData, ISMBShare share, SMB1ConnectionState state, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetCompleteTransaction2Response(SMB1Header header, byte[] requestSetup, byte[] requestParameters, byte[] requestData, ISMBShare share, SMB1ConnectionState state)
{
Transaction2Subcommand subcommand;
try
@ -237,13 +235,22 @@ namespace SMBLibrary.Server.SMB1
byte[] responseSetup = subcommandResponse.GetSetup();
byte[] responseParameters = subcommandResponse.GetParameters(header.UnicodeFlag);
byte[] responseData = subcommandResponse.GetData(header.UnicodeFlag);
Transaction2Response response = new Transaction2Response();
PrepareResponse(response, responseSetup, responseParameters, responseData, state.MaxBufferSize, sendQueue);
return response;
return GetTransactionResponse(true, responseSetup, responseParameters, responseData, state.MaxBufferSize);
}
internal static void PrepareResponse(TransactionResponse response, byte[] responseSetup, byte[] responseParameters, byte[] responseData, int maxBufferSize, List<SMB1Command> sendQueue)
internal static List<SMB1Command> GetTransactionResponse(bool transaction2Response, byte[] responseSetup, byte[] responseParameters, byte[] responseData, int maxBufferSize)
{
List<SMB1Command> result = new List<SMB1Command>();
TransactionResponse response;
if (transaction2Response)
{
response = new Transaction2Response();
}
else
{
response = new TransactionResponse();
}
result.Add(response);
int responseSize = TransactionResponse.CalculateMessageSize(responseSetup.Length, responseParameters.Length, responseData.Length);
if (responseSize <= maxBufferSize)
{
@ -268,7 +275,7 @@ namespace SMBLibrary.Server.SMB1
while (dataBytesLeftToSend > 0)
{
TransactionResponse additionalResponse;
if (response is Transaction2Response)
if (transaction2Response)
{
additionalResponse = new Transaction2Response();
}
@ -290,11 +297,12 @@ namespace SMBLibrary.Server.SMB1
additionalResponse.TransData = buffer;
additionalResponse.ParameterDisplacement = (ushort)response.TransParameters.Length;
additionalResponse.DataDisplacement = (ushort)dataBytesSent;
sendQueue.Add(additionalResponse);
result.Add(additionalResponse);
dataBytesLeftToSend -= currentDataLength;
}
}
return result;
}
}
}

View file

@ -17,41 +17,55 @@ namespace SMBLibrary.Server
{
public void ProcessSMB1Message(SMB1Message message, ref ConnectionState state)
{
SMB1Message reply = new SMB1Message();
PrepareResponseHeader(reply, message);
SMB1Header header = new SMB1Header();
PrepareResponseHeader(header, message.Header);
List<SMB1Command> sendQueue = new List<SMB1Command>();
bool isBatchedRequest = (message.Commands.Count > 1);
foreach (SMB1Command command in message.Commands)
{
SMB1Command response = ProcessSMB1Command(reply.Header, command, ref state, sendQueue);
if (response != null)
{
reply.Commands.Add(response);
}
if (reply.Header.Status != NTStatus.STATUS_SUCCESS)
List<SMB1Command> responses = ProcessSMB1Command(header, command, ref state);
sendQueue.AddRange(responses);
if (header.Status != NTStatus.STATUS_SUCCESS)
{
break;
}
}
if (reply.Commands.Count > 0)
if (isBatchedRequest)
{
TrySendMessage(state, reply);
foreach (SMB1Command command in sendQueue)
if (sendQueue.Count > 0)
{
SMB1Message secondaryReply = new SMB1Message();
secondaryReply.Header = reply.Header;
secondaryReply.Commands.Add(command);
TrySendMessage(state, secondaryReply);
// The server MUST batch the response into an AndX Response chain.
SMB1Message reply = new SMB1Message();
reply.Header = header;
for (int index = 0; index < sendQueue.Count; index++)
{
if (sendQueue[index] is SMBAndXCommand || index == sendQueue.Count - 1)
{
reply.Commands.Add(sendQueue[index]);
sendQueue.RemoveAt(index);
index--;
}
}
TrySendMessage(state, reply);
}
}
foreach (SMB1Command response in sendQueue)
{
SMB1Message reply = new SMB1Message();
reply.Header = header;
reply.Commands.Add(response);
TrySendMessage(state, reply);
}
}
/// <summary>
/// May return null
/// May return an empty list
/// </summary>
public SMB1Command ProcessSMB1Command(SMB1Header header, SMB1Command command, ref ConnectionState state, List<SMB1Command> sendQueue)
public List<SMB1Command> ProcessSMB1Command(SMB1Header header, SMB1Command command, ref ConnectionState state)
{
if (state.ServerDialect == SMBDialect.NotSet)
{
@ -92,11 +106,11 @@ namespace SMBLibrary.Server
}
else
{
return ProcessSMB1Command(header, command, (SMB1ConnectionState)state, sendQueue);
return ProcessSMB1Command(header, command, (SMB1ConnectionState)state);
}
}
private SMB1Command ProcessSMB1Command(SMB1Header header, SMB1Command command, SMB1ConnectionState state, List<SMB1Command> sendQueue)
private List<SMB1Command> ProcessSMB1Command(SMB1Header header, SMB1Command command, SMB1ConnectionState state)
{
if (command is SessionSetupAndXRequest)
{
@ -112,7 +126,7 @@ namespace SMBLibrary.Server
}
else if (command is EchoRequest)
{
return ServerResponseHelper.GetEchoResponse((EchoRequest)command, sendQueue);
return ServerResponseHelper.GetEchoResponse((EchoRequest)command);
}
else
{
@ -244,7 +258,7 @@ namespace SMBLibrary.Server
TransactionRequest request = (TransactionRequest)command;
try
{
return TransactionHelper.GetTransactionResponse(header, request, share, state, sendQueue);
return TransactionHelper.GetTransactionResponse(header, request, share, state);
}
catch (UnsupportedInformationLevelException)
{
@ -257,7 +271,7 @@ namespace SMBLibrary.Server
TransactionSecondaryRequest request = (TransactionSecondaryRequest)command;
try
{
return TransactionHelper.GetTransactionResponse(header, request, share, state, sendQueue);
return TransactionHelper.GetTransactionResponse(header, request, share, state);
}
catch (UnsupportedInformationLevelException)
{
@ -268,12 +282,12 @@ namespace SMBLibrary.Server
else if (command is NTTransactRequest)
{
NTTransactRequest request = (NTTransactRequest)command;
return NTTransactHelper.GetNTTransactResponse(header, request, share, state, sendQueue);
return NTTransactHelper.GetNTTransactResponse(header, request, share, state);
}
else if (command is NTTransactSecondaryRequest)
{
NTTransactSecondaryRequest request = (NTTransactSecondaryRequest)command;
return NTTransactHelper.GetNTTransactResponse(header, request, share, state, sendQueue);
return NTTransactHelper.GetNTTransactResponse(header, request, share, state);
}
else if (command is NTCreateAndXRequest)
{
@ -295,31 +309,31 @@ namespace SMBLibrary.Server
state.LogToServer(Severity.Verbose, "SMB1 message sent: {0} responses, First response: {1}, Packet length: {2}", response.Commands.Count, response.Commands[0].CommandName.ToString(), packet.Length);
}
private static void PrepareResponseHeader(SMB1Message response, SMB1Message request)
private static void PrepareResponseHeader(SMB1Header responseHeader, SMB1Header requestHeader)
{
response.Header.Status = NTStatus.STATUS_SUCCESS;
response.Header.Flags = HeaderFlags.CaseInsensitive | HeaderFlags.CanonicalizedPaths | HeaderFlags.Reply;
response.Header.Flags2 = HeaderFlags2.NTStatusCode;
if ((request.Header.Flags2 & HeaderFlags2.LongNamesAllowed) > 0)
responseHeader.Status = NTStatus.STATUS_SUCCESS;
responseHeader.Flags = HeaderFlags.CaseInsensitive | HeaderFlags.CanonicalizedPaths | HeaderFlags.Reply;
responseHeader.Flags2 = HeaderFlags2.NTStatusCode;
if ((requestHeader.Flags2 & HeaderFlags2.LongNamesAllowed) > 0)
{
response.Header.Flags2 |= HeaderFlags2.LongNamesAllowed | HeaderFlags2.LongNameUsed;
responseHeader.Flags2 |= HeaderFlags2.LongNamesAllowed | HeaderFlags2.LongNameUsed;
}
if ((request.Header.Flags2 & HeaderFlags2.ExtendedAttributes) > 0)
if ((requestHeader.Flags2 & HeaderFlags2.ExtendedAttributes) > 0)
{
response.Header.Flags2 |= HeaderFlags2.ExtendedAttributes;
responseHeader.Flags2 |= HeaderFlags2.ExtendedAttributes;
}
if ((request.Header.Flags2 & HeaderFlags2.ExtendedSecurity) > 0)
if ((requestHeader.Flags2 & HeaderFlags2.ExtendedSecurity) > 0)
{
response.Header.Flags2 |= HeaderFlags2.ExtendedSecurity;
responseHeader.Flags2 |= HeaderFlags2.ExtendedSecurity;
}
if ((request.Header.Flags2 & HeaderFlags2.Unicode) > 0)
if ((requestHeader.Flags2 & HeaderFlags2.Unicode) > 0)
{
response.Header.Flags2 |= HeaderFlags2.Unicode;
responseHeader.Flags2 |= HeaderFlags2.Unicode;
}
response.Header.MID = request.Header.MID;
response.Header.PID = request.Header.PID;
response.Header.UID = request.Header.UID;
response.Header.TID = request.Header.TID;
responseHeader.MID = requestHeader.MID;
responseHeader.PID = requestHeader.PID;
responseHeader.UID = requestHeader.UID;
responseHeader.TID = requestHeader.TID;
}
}
}