diff --git a/SMBLibrary/Server/SMBServer.SMB2.cs b/SMBLibrary/Server/SMBServer.SMB2.cs index c57abc0..9a10b24 100644 --- a/SMBLibrary/Server/SMBServer.SMB2.cs +++ b/SMBLibrary/Server/SMBServer.SMB2.cs @@ -37,6 +37,33 @@ namespace SMBLibrary.Server return null; } + public void ProcessSMB2RequestChain(List requestChain, ref ConnectionState state) + { + List responseChain = new List(); + FileID? fileID = null; + foreach (SMB2Command request in requestChain) + { + if (request.Header.IsRelatedOperations && fileID.HasValue) + { + SetRequestFileID(request, fileID.Value); + } + SMB2Command response = ProcessSMB2Command(request, ref state); + if (response != null) + { + UpdateSMB2Header(response, request); + responseChain.Add(response); + if (!request.Header.IsRelatedOperations) + { + fileID = GetResponseFileID(response); + } + } + } + if (responseChain.Count > 0) + { + TrySendResponseChain(state, responseChain); + } + } + /// /// May return null /// @@ -204,5 +231,62 @@ namespace SMBLibrary.Server response.Header.TreeID = request.Header.TreeID; } } + + private static void SetRequestFileID(SMB2Command command, FileID fileID) + { + if (command is ChangeNotifyRequest) + { + ((ChangeNotifyRequest)command).FileId = fileID; + } + else if (command is CloseRequest) + { + ((CloseRequest)command).FileId = fileID; + } + else if (command is FlushRequest) + { + ((FlushRequest)command).FileId = fileID; + } + else if (command is IOCtlRequest) + { + ((IOCtlRequest)command).FileId = fileID; + } + else if (command is LockRequest) + { + ((LockRequest)command).FileId = fileID; + } + else if (command is QueryDirectoryRequest) + { + ((QueryDirectoryRequest)command).FileId = fileID; + } + else if (command is QueryInfoRequest) + { + ((QueryInfoRequest)command).FileId = fileID; + } + else if (command is ReadRequest) + { + ((ReadRequest)command).FileId = fileID; + } + else if (command is SetInfoRequest) + { + ((SetInfoRequest)command).FileId = fileID; + } + else if (command is WriteRequest) + { + ((WriteRequest)command).FileId = fileID; + } + } + + private static FileID? GetResponseFileID(SMB2Command command) + { + if (command is CreateResponse) + { + return ((CreateResponse)command).FileId; + } + else if (command is IOCtlResponse) + { + return ((IOCtlResponse)command).FileId; + } + return null; + } } } diff --git a/SMBLibrary/Server/SMBServer.cs b/SMBLibrary/Server/SMBServer.cs index a7ff737..4d008d4 100644 --- a/SMBLibrary/Server/SMBServer.cs +++ b/SMBLibrary/Server/SMBServer.cs @@ -287,20 +287,7 @@ namespace SMBLibrary.Server return; } state.LogToServer(Severity.Verbose, "SMB2 request chain received: {0} requests, First request: {1}, Packet length: {2}", requestChain.Count, requestChain[0].CommandName.ToString(), packet.Length); - List responseChain = new List(); - foreach (SMB2Command request in requestChain) - { - SMB2Command response = ProcessSMB2Command(request, ref state); - if (response != null) - { - UpdateSMB2Header(response, request); - responseChain.Add(response); - } - } - if (responseChain.Count > 0) - { - TrySendResponseChain(state, responseChain); - } + ProcessSMB2RequestChain(requestChain, ref state); } else {