From df58913372883bc65a1c1326740bb534cd287da1 Mon Sep 17 00:00:00 2001 From: Tal Aloni Date: Sun, 3 Sep 2017 20:06:48 +0300 Subject: [PATCH] SMBServer: SMB2: Improved handling of compunded related requests --- SMBLibrary/Server/SMBServer.SMB2.cs | 104 ++++++++++++++++++++++++++-- 1 file changed, 100 insertions(+), 4 deletions(-) diff --git a/SMBLibrary/Server/SMBServer.SMB2.cs b/SMBLibrary/Server/SMBServer.SMB2.cs index e8aba4f..a966472 100644 --- a/SMBLibrary/Server/SMBServer.SMB2.cs +++ b/SMBLibrary/Server/SMBServer.SMB2.cs @@ -19,20 +19,51 @@ namespace SMBLibrary.Server { List responseChain = new List(); FileID? fileID = null; + NTStatus? fileIDStatus = null; foreach (SMB2Command request in requestChain) { - if (request.Header.IsRelatedOperations && fileID.HasValue) + SMB2Command response; + if (request.Header.IsRelatedOperations && RequestContainsFileID(request)) { - SetRequestFileID(request, fileID.Value); + if (fileIDStatus != null && fileIDStatus != NTStatus.STATUS_SUCCESS && fileIDStatus != NTStatus.STATUS_BUFFER_OVERFLOW) + { + // [MS-SMB2] When the current request requires a FileId and the previous request either contains + // or generates a FileId, if the previous request fails with an error, the server SHOULD fail the + // current request with the same error code returned by the previous request. + state.LogToServer(Severity.Verbose, "Compunded related request {0} failed because FileId generation failed.", request.CommandName); + response = new ErrorResponse(request.CommandName, fileIDStatus.Value); + } + else if (fileID.HasValue) + { + SetRequestFileID(request, fileID.Value); + response = ProcessSMB2Command(request, ref state); + } + else + { + // [MS-SMB2] When the current request requires a FileId, and if the previous request neither contains + // nor generates a FileId, the server MUST fail the compounded request with STATUS_INVALID_PARAMETER. + state.LogToServer(Severity.Verbose, "Compunded related request {0} failed, the previous request neither contains nor generates a FileId.", request.CommandName); + response = new ErrorResponse(request.CommandName, NTStatus.STATUS_INVALID_PARAMETER); + } } - SMB2Command response = ProcessSMB2Command(request, ref state); + else + { + fileID = GetRequestFileID(request); + response = ProcessSMB2Command(request, ref state); + } + if (response != null) { UpdateSMB2Header(response, request, state); responseChain.Add(response); - if (!request.Header.IsRelatedOperations) + if (GeneratesFileID(response)) { fileID = GetResponseFileID(response); + fileIDStatus = response.Header.Status; + } + else if (RequestContainsFileID(request)) + { + fileIDStatus = response.Header.Status; } } } @@ -250,6 +281,65 @@ namespace SMBLibrary.Server response.Header.IsSigned = (request.Header.IsSigned || signingRequired) && !isInterimResponse; } + private static bool RequestContainsFileID(SMB2Command command) + { + return (command is ChangeNotifyRequest || + command is CloseRequest || + command is FlushRequest || + command is IOCtlRequest || + command is LockRequest || + command is QueryDirectoryRequest || + command is QueryInfoRequest || + command is ReadRequest || + command is SetInfoRequest || + command is WriteRequest); + } + + private static FileID? GetRequestFileID(SMB2Command command) + { + if (command is ChangeNotifyRequest) + { + return ((ChangeNotifyRequest)command).FileId; + } + else if (command is CloseRequest) + { + return ((CloseRequest)command).FileId; + } + else if (command is FlushRequest) + { + return ((FlushRequest)command).FileId; + } + else if (command is IOCtlRequest) + { + return ((IOCtlRequest)command).FileId; + } + else if (command is LockRequest) + { + return ((LockRequest)command).FileId; + } + else if (command is QueryDirectoryRequest) + { + return ((QueryDirectoryRequest)command).FileId; + } + else if (command is QueryInfoRequest) + { + return ((QueryInfoRequest)command).FileId; + } + else if (command is ReadRequest) + { + return ((ReadRequest)command).FileId; + } + else if (command is SetInfoRequest) + { + return ((SetInfoRequest)command).FileId; + } + else if (command is WriteRequest) + { + return ((WriteRequest)command).FileId; + } + return null; + } + private static void SetRequestFileID(SMB2Command command, FileID fileID) { if (command is ChangeNotifyRequest) @@ -294,6 +384,12 @@ namespace SMBLibrary.Server } } + private static bool GeneratesFileID(SMB2Command command) + { + return (command.CommandName == SMB2CommandName.Create || + command.CommandName == SMB2CommandName.IOCtl); + } + private static FileID? GetResponseFileID(SMB2Command command) { if (command is CreateResponse)