Use FileId of previous request in request chain if operations are related

This commit is contained in:
Tal Aloni 2017-02-01 23:19:41 +02:00
parent e597bd082a
commit c4b6d9a08b
2 changed files with 85 additions and 14 deletions

View file

@ -37,6 +37,33 @@ namespace SMBLibrary.Server
return null;
}
public void ProcessSMB2RequestChain(List<SMB2Command> requestChain, ref ConnectionState state)
{
List<SMB2Command> responseChain = new List<SMB2Command>();
FileID? fileID = null;
foreach (SMB2Command request in requestChain)
{
if (request.Header.IsRelatedOperations && fileID.HasValue)
{
SetRequestFileID(request, fileID.Value);
}
SMB2Command response = ProcessSMB2Command(request, ref state);
if (response != null)
{
UpdateSMB2Header(response, request);
responseChain.Add(response);
if (!request.Header.IsRelatedOperations)
{
fileID = GetResponseFileID(response);
}
}
}
if (responseChain.Count > 0)
{
TrySendResponseChain(state, responseChain);
}
}
/// <summary>
/// May return null
/// </summary>
@ -204,5 +231,62 @@ namespace SMBLibrary.Server
response.Header.TreeID = request.Header.TreeID;
}
}
private static void SetRequestFileID(SMB2Command command, FileID fileID)
{
if (command is ChangeNotifyRequest)
{
((ChangeNotifyRequest)command).FileId = fileID;
}
else if (command is CloseRequest)
{
((CloseRequest)command).FileId = fileID;
}
else if (command is FlushRequest)
{
((FlushRequest)command).FileId = fileID;
}
else if (command is IOCtlRequest)
{
((IOCtlRequest)command).FileId = fileID;
}
else if (command is LockRequest)
{
((LockRequest)command).FileId = fileID;
}
else if (command is QueryDirectoryRequest)
{
((QueryDirectoryRequest)command).FileId = fileID;
}
else if (command is QueryInfoRequest)
{
((QueryInfoRequest)command).FileId = fileID;
}
else if (command is ReadRequest)
{
((ReadRequest)command).FileId = fileID;
}
else if (command is SetInfoRequest)
{
((SetInfoRequest)command).FileId = fileID;
}
else if (command is WriteRequest)
{
((WriteRequest)command).FileId = fileID;
}
}
private static FileID? GetResponseFileID(SMB2Command command)
{
if (command is CreateResponse)
{
return ((CreateResponse)command).FileId;
}
else if (command is IOCtlResponse)
{
return ((IOCtlResponse)command).FileId;
}
return null;
}
}
}