Refactored ReadWriteResponseHelper

This commit is contained in:
Tal Aloni 2017-01-14 19:52:55 +02:00
parent 2049be2220
commit d42240ffa1

View file

@ -19,10 +19,17 @@ namespace SMBLibrary.Server.SMB1
{ {
internal static SMB1Command GetReadResponse(SMB1Header header, ReadRequest request, ISMBShare share, SMB1ConnectionState state) internal static SMB1Command GetReadResponse(SMB1Header header, ReadRequest request, ISMBShare share, SMB1ConnectionState state)
{ {
byte[] data = PerformRead(header, share, request.FID, request.ReadOffsetInBytes, request.CountOfBytesToRead, state); OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
if (openedFile == null)
{
header.Status = NTStatus.STATUS_INVALID_HANDLE;
return null;
}
byte[] data;
header.Status = ReadFile(out data, openedFile, request.ReadOffsetInBytes, request.CountOfBytesToRead, state);
if (header.Status != NTStatus.STATUS_SUCCESS) if (header.Status != NTStatus.STATUS_SUCCESS)
{ {
return new ErrorResponse(CommandName.SMB_COM_READ); return new ErrorResponse(request.CommandName);
} }
ReadResponse response = new ReadResponse(); ReadResponse response = new ReadResponse();
@ -33,15 +40,22 @@ namespace SMBLibrary.Server.SMB1
internal static SMB1Command GetReadResponse(SMB1Header header, ReadAndXRequest request, ISMBShare share, SMB1ConnectionState state) internal static SMB1Command GetReadResponse(SMB1Header header, ReadAndXRequest request, ISMBShare share, SMB1ConnectionState state)
{ {
OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
if (openedFile == null)
{
header.Status = NTStatus.STATUS_INVALID_HANDLE;
return null;
}
uint maxCount = request.MaxCount; uint maxCount = request.MaxCount;
if ((share is FileSystemShare) && state.LargeRead) if ((share is FileSystemShare) && state.LargeRead)
{ {
maxCount = request.MaxCountLarge; maxCount = request.MaxCountLarge;
} }
byte[] data = PerformRead(header, share, request.FID, request.Offset, maxCount, state); byte[] data;
header.Status = ReadFile(out data, openedFile, (long)request.Offset, (int)maxCount, state);
if (header.Status != NTStatus.STATUS_SUCCESS) if (header.Status != NTStatus.STATUS_SUCCESS)
{ {
return new ErrorResponse(CommandName.SMB_COM_READ_ANDX); return new ErrorResponse(request.CommandName);
} }
ReadAndXResponse response = new ReadAndXResponse(); ReadAndXResponse response = new ReadAndXResponse();
@ -54,47 +68,31 @@ namespace SMBLibrary.Server.SMB1
return response; return response;
} }
public static byte[] PerformRead(SMB1Header header, ISMBShare share, ushort FID, ulong offset, uint maxCount, SMB1ConnectionState state) public static NTStatus ReadFile(out byte[] data, OpenedFileObject openedFile, long offset, int maxCount, ConnectionState state)
{ {
if (offset > Int64.MaxValue || maxCount > Int32.MaxValue) data = null;
{
throw new NotImplementedException("Underlying filesystem does not support unsigned offset / read count");
}
return PerformRead(header, share, FID, (long)offset, (int)maxCount, state);
}
public static byte[] PerformRead(SMB1Header header, ISMBShare share, ushort FID, long offset, int maxCount, SMB1ConnectionState state)
{
OpenedFileObject openedFile = state.GetOpenedFileObject(FID);
if (openedFile == null)
{
header.Status = NTStatus.STATUS_INVALID_HANDLE;
return null;
}
string openedFilePath = openedFile.Path; string openedFilePath = openedFile.Path;
Stream stream = openedFile.Stream; Stream stream = openedFile.Stream;
if (share is NamedPipeShare) if (stream is RPCPipeStream)
{ {
byte[] data = new byte[maxCount]; data = new byte[maxCount];
int bytesRead = stream.Read(data, 0, maxCount); int bytesRead = stream.Read(data, 0, maxCount);
if (bytesRead < maxCount) if (bytesRead < maxCount)
{ {
// EOF, we must trim the response data array // EOF, we must trim the response data array
data = ByteReader.ReadBytes(data, 0, bytesRead); data = ByteReader.ReadBytes(data, 0, bytesRead);
} }
return data; return NTStatus.STATUS_SUCCESS;
} }
else // FileSystemShare else // File
{ {
if (stream == null) if (stream == null)
{ {
header.Status = NTStatus.STATUS_ACCESS_DENIED; state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}', Invalid Operation.", openedFilePath);
return null; return NTStatus.STATUS_ACCESS_DENIED;
} }
int bytesRead; int bytesRead;
byte[] data;
try try
{ {
stream.Seek(offset, SeekOrigin.Begin); stream.Seek(offset, SeekOrigin.Begin);
@ -107,28 +105,24 @@ namespace SMBLibrary.Server.SMB1
if (errorCode == (ushort)Win32Error.ERROR_SHARING_VIOLATION) if (errorCode == (ushort)Win32Error.ERROR_SHARING_VIOLATION)
{ {
// Returning STATUS_SHARING_VIOLATION is undocumented but apparently valid // Returning STATUS_SHARING_VIOLATION is undocumented but apparently valid
state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}'. Sharing Violation.", openedFilePath); state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}'. Sharing Violation.", openedFilePath);
header.Status = NTStatus.STATUS_SHARING_VIOLATION; return NTStatus.STATUS_SHARING_VIOLATION;
return null;
} }
else else
{ {
state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}'. Data Error.", openedFilePath); state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}'. Data Error.", openedFilePath);
header.Status = NTStatus.STATUS_DATA_ERROR; return NTStatus.STATUS_DATA_ERROR;
return null;
} }
} }
catch (ArgumentOutOfRangeException) catch (ArgumentOutOfRangeException)
{ {
state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}'. Offset Out Of Range.", openedFilePath); state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}'. Offset Out Of Range.", openedFilePath);
header.Status = NTStatus.STATUS_DATA_ERROR; return NTStatus.STATUS_DATA_ERROR;
return null;
} }
catch (UnauthorizedAccessException) catch (UnauthorizedAccessException)
{ {
state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}', Access Denied.", openedFilePath); state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}', Access Denied.", openedFilePath);
header.Status = NTStatus.STATUS_ACCESS_DENIED; return NTStatus.STATUS_ACCESS_DENIED;
return null;
} }
if (bytesRead < maxCount) if (bytesRead < maxCount)
@ -136,33 +130,45 @@ namespace SMBLibrary.Server.SMB1
// EOF, we must trim the response data array // EOF, we must trim the response data array
data = ByteReader.ReadBytes(data, 0, bytesRead); data = ByteReader.ReadBytes(data, 0, bytesRead);
} }
return data; return NTStatus.STATUS_SUCCESS;
} }
} }
internal static SMB1Command GetWriteResponse(SMB1Header header, WriteRequest request, ISMBShare share, SMB1ConnectionState state) internal static SMB1Command GetWriteResponse(SMB1Header header, WriteRequest request, ISMBShare share, SMB1ConnectionState state)
{ {
ushort bytesWritten = (ushort)PerformWrite(header, share, request.FID, request.WriteOffsetInBytes, request.Data, state); OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
if (openedFile == null)
{
header.Status = NTStatus.STATUS_INVALID_HANDLE;
return new ErrorResponse(request.CommandName);
}
int numberOfBytesWritten;
header.Status = WriteFile(out numberOfBytesWritten, openedFile, request.WriteOffsetInBytes, request.Data, state);
if (header.Status != NTStatus.STATUS_SUCCESS) if (header.Status != NTStatus.STATUS_SUCCESS)
{ {
return new ErrorResponse(CommandName.SMB_COM_WRITE_ANDX); return new ErrorResponse(request.CommandName);
} }
WriteResponse response = new WriteResponse(); WriteResponse response = new WriteResponse();
response.CountOfBytesWritten = bytesWritten; response.CountOfBytesWritten = (ushort)numberOfBytesWritten;
return response; return response;
} }
internal static SMB1Command GetWriteResponse(SMB1Header header, WriteAndXRequest request, ISMBShare share, SMB1ConnectionState state) internal static SMB1Command GetWriteResponse(SMB1Header header, WriteAndXRequest request, ISMBShare share, SMB1ConnectionState state)
{ {
uint bytesWritten = PerformWrite(header, share, request.FID, request.Offset, request.Data, state); OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
if (openedFile == null)
{
header.Status = NTStatus.STATUS_INVALID_HANDLE;
return new ErrorResponse(request.CommandName);
}
int numberOfBytesWritten;
header.Status = WriteFile(out numberOfBytesWritten, openedFile, (long)request.Offset, request.Data, state);
if (header.Status != NTStatus.STATUS_SUCCESS) if (header.Status != NTStatus.STATUS_SUCCESS)
{ {
return new ErrorResponse(CommandName.SMB_COM_WRITE_ANDX); return new ErrorResponse(request.CommandName);
} }
WriteAndXResponse response = new WriteAndXResponse(); WriteAndXResponse response = new WriteAndXResponse();
response.Count = bytesWritten; response.Count = (uint)numberOfBytesWritten;
if (share is FileSystemShare) if (share is FileSystemShare)
{ {
// If the client wrote to a disk file, this field MUST be set to 0xFFFF. // If the client wrote to a disk file, this field MUST be set to 0xFFFF.
@ -171,65 +177,62 @@ namespace SMBLibrary.Server.SMB1
return response; return response;
} }
public static uint PerformWrite(SMB1Header header, ISMBShare share, ushort FID, ulong offset, byte[] data, SMB1ConnectionState state) public static NTStatus WriteFile(out int numberOfBytesWritten, OpenedFileObject openedFile, long offset, byte[] data, ConnectionState state)
{ {
OpenedFileObject openedFile = state.GetOpenedFileObject(FID); numberOfBytesWritten = 0;
if (openedFile == null)
{
header.Status = NTStatus.STATUS_INVALID_HANDLE;
return 0;
}
string openedFilePath = openedFile.Path; string openedFilePath = openedFile.Path;
Stream stream = openedFile.Stream; Stream stream = openedFile.Stream;
if (share is NamedPipeShare) if (stream is RPCPipeStream)
{ {
stream.Write(data, 0, data.Length); stream.Write(data, 0, data.Length);
return (uint)data.Length; numberOfBytesWritten = data.Length;
return NTStatus.STATUS_SUCCESS;
} }
else // FileSystemShare else // File
{ {
if (stream == null) if (stream == null)
{ {
header.Status = NTStatus.STATUS_ACCESS_DENIED; state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Invalid Operation.", openedFilePath);
return 0; return NTStatus.STATUS_ACCESS_DENIED;
} }
try try
{ {
stream.Seek((long)offset, SeekOrigin.Begin); stream.Seek(offset, SeekOrigin.Begin);
stream.Write(data, 0, data.Length); stream.Write(data, 0, data.Length);
return (uint)data.Length; numberOfBytesWritten = data.Length;
return NTStatus.STATUS_SUCCESS;
} }
catch (IOException ex) catch (IOException ex)
{ {
ushort errorCode = IOExceptionHelper.GetWin32ErrorCode(ex); ushort errorCode = IOExceptionHelper.GetWin32ErrorCode(ex);
if (errorCode == (ushort)Win32Error.ERROR_DISK_FULL) if (errorCode == (ushort)Win32Error.ERROR_DISK_FULL)
{ {
header.Status = NTStatus.STATUS_DISK_FULL; state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Disk Full.", openedFilePath);
return 0; return NTStatus.STATUS_DISK_FULL;
} }
else if (errorCode == (ushort)Win32Error.ERROR_SHARING_VIOLATION) else if (errorCode == (ushort)Win32Error.ERROR_SHARING_VIOLATION)
{ {
state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Sharing Violation.", openedFilePath);
// Returning STATUS_SHARING_VIOLATION is undocumented but apparently valid // Returning STATUS_SHARING_VIOLATION is undocumented but apparently valid
header.Status = NTStatus.STATUS_SHARING_VIOLATION; return NTStatus.STATUS_SHARING_VIOLATION;
return 0;
} }
else else
{ {
header.Status = NTStatus.STATUS_DATA_ERROR; state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Data Error.", openedFilePath);
return 0; return NTStatus.STATUS_DATA_ERROR;
} }
} }
catch (ArgumentOutOfRangeException) catch (ArgumentOutOfRangeException)
{ {
header.Status = NTStatus.STATUS_DATA_ERROR; state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Offset Out Of Range.", openedFilePath);
return 0; return NTStatus.STATUS_DATA_ERROR;
} }
catch (UnauthorizedAccessException) catch (UnauthorizedAccessException)
{ {
state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Access Denied.", openedFilePath);
// The user may have tried to write to a readonly file // The user may have tried to write to a readonly file
header.Status = NTStatus.STATUS_ACCESS_DENIED; return NTStatus.STATUS_ACCESS_DENIED;
return 0;
} }
} }
} }