SMB2: FileID related improvements

This commit is contained in:
Tal Aloni 2017-03-08 14:45:14 +02:00
parent 8e07373185
commit f71ef6b232
11 changed files with 66 additions and 72 deletions

View file

@ -12,18 +12,14 @@ using Utilities;
namespace SMBLibrary.Server namespace SMBLibrary.Server
{ {
public delegate ulong? AllocatePersistentFileID();
internal class SMB2ConnectionState : ConnectionState internal class SMB2ConnectionState : ConnectionState
{ {
// Key is SessionID // Key is SessionID
private Dictionary<ulong, SMB2Session> m_sessions = new Dictionary<ulong, SMB2Session>(); private Dictionary<ulong, SMB2Session> m_sessions = new Dictionary<ulong, SMB2Session>();
private ulong m_nextSessionID = 1; private ulong m_nextSessionID = 1;
public AllocatePersistentFileID AllocatePersistentFileID;
public SMB2ConnectionState(ConnectionState state, AllocatePersistentFileID allocatePersistentFileID) : base(state) public SMB2ConnectionState(ConnectionState state) : base(state)
{ {
AllocatePersistentFileID = allocatePersistentFileID;
} }
public ulong? AllocateSessionID() public ulong? AllocateSessionID()

View file

@ -24,10 +24,11 @@ namespace SMBLibrary.Server
private Dictionary<uint, ISMBShare> m_connectedTrees = new Dictionary<uint, ISMBShare>(); private Dictionary<uint, ISMBShare> m_connectedTrees = new Dictionary<uint, ISMBShare>();
private uint m_nextTreeID = 1; // TreeID uniquely identifies a tree connect within the scope of the session private uint m_nextTreeID = 1; // TreeID uniquely identifies a tree connect within the scope of the session
// Key is the persistent portion of the FileID // Key is the volatile portion of the FileID
private Dictionary<ulong, OpenFileObject> m_openFiles = new Dictionary<ulong, OpenFileObject>(); private Dictionary<ulong, OpenFileObject> m_openFiles = new Dictionary<ulong, OpenFileObject>();
private ulong m_nextVolatileFileID = 1;
// Key is the persistent portion of the FileID // Key is the volatile portion of the FileID
private Dictionary<ulong, OpenSearch> m_openSearches = new Dictionary<ulong, OpenSearch>(); private Dictionary<ulong, OpenSearch> m_openSearches = new Dictionary<ulong, OpenSearch>();
public SMB2Session(SMB2ConnectionState connection, ulong sessionID, string userName, string machineName, byte[] sessionKey, object accessToken) public SMB2Session(SMB2ConnectionState connection, ulong sessionID, string userName, string machineName, byte[] sessionKey, object accessToken)
@ -107,39 +108,58 @@ namespace SMBLibrary.Server
return m_connectedTrees.ContainsKey(treeID); return m_connectedTrees.ContainsKey(treeID);
} }
/// <returns>The persistent portion of the FileID</returns> // VolatileFileID MUST be unique for all volatile handles within the scope of a session
public ulong? AddOpenFile(uint treeID, string relativePath, object handle) private ulong? AllocateVolatileFileID()
{ {
ulong? persistentID = m_connection.AllocatePersistentFileID(); for (ulong offset = 0; offset < UInt64.MaxValue; offset++)
if (persistentID.HasValue)
{ {
lock (m_openFiles) ulong volatileFileID = (ulong)(m_nextVolatileFileID + offset);
if (volatileFileID == 0 || volatileFileID == 0xFFFFFFFFFFFFFFFF)
{ {
m_openFiles.Add(persistentID.Value, new OpenFileObject(treeID, relativePath, handle)); continue;
}
if (!m_openFiles.ContainsKey(volatileFileID))
{
m_nextVolatileFileID = (ulong)(volatileFileID + 1);
return volatileFileID;
} }
} }
return persistentID;
}
public OpenFileObject GetOpenFileObject(ulong fileID)
{
if (m_openFiles.ContainsKey(fileID))
{
return m_openFiles[fileID];
}
else
{
return null; return null;
} }
public FileID? AddOpenFile(uint treeID, string relativePath, object handle)
{
ulong? volatileFileID = AllocateVolatileFileID();
if (volatileFileID.HasValue)
{
FileID fileID = new FileID();
fileID.Volatile = volatileFileID.Value;
// [MS-SMB2] FileId.Persistent MUST be set to Open.DurableFileId.
// Note: We don't support durable handles so we use volatileFileID.
fileID.Persistent = volatileFileID.Value;
lock (m_openFiles)
{
m_openFiles.Add(volatileFileID.Value, new OpenFileObject(treeID, relativePath, handle));
}
return fileID;
}
return null;
} }
public void RemoveOpenFile(ulong fileID) public OpenFileObject GetOpenFileObject(FileID fileID)
{
OpenFileObject result;
m_openFiles.TryGetValue(fileID.Volatile, out result);
return result;
}
public void RemoveOpenFile(FileID fileID)
{ {
lock (m_openFiles) lock (m_openFiles)
{ {
m_openFiles.Remove(fileID); m_openFiles.Remove(fileID.Volatile);
} }
m_openSearches.Remove(fileID); m_openSearches.Remove(fileID.Volatile);
} }
public List<string> ListOpenFiles() public List<string> ListOpenFiles()
@ -155,23 +175,23 @@ namespace SMBLibrary.Server
return result; return result;
} }
public OpenSearch AddOpenSearch(ulong fileID, List<QueryDirectoryFileInformation> entries, int enumerationLocation) public OpenSearch AddOpenSearch(FileID fileID, List<QueryDirectoryFileInformation> entries, int enumerationLocation)
{ {
OpenSearch openSearch = new OpenSearch(entries, enumerationLocation); OpenSearch openSearch = new OpenSearch(entries, enumerationLocation);
m_openSearches.Add(fileID, openSearch); m_openSearches.Add(fileID.Volatile, openSearch);
return openSearch; return openSearch;
} }
public OpenSearch GetOpenSearch(ulong fileID) public OpenSearch GetOpenSearch(FileID fileID)
{ {
OpenSearch openSearch; OpenSearch openSearch;
m_openSearches.TryGetValue(fileID, out openSearch); m_openSearches.TryGetValue(fileID.Volatile, out openSearch);
return openSearch; return openSearch;
} }
public void RemoveOpenSearch(ulong fileID) public void RemoveOpenSearch(FileID fileID)
{ {
m_openSearches.Remove(fileID); m_openSearches.Remove(fileID.Volatile);
} }
/// <summary> /// <summary>

View file

@ -17,13 +17,13 @@ namespace SMBLibrary.Server.SMB2
internal static SMB2Command GetCloseResponse(CloseRequest request, ISMBShare share, SMB2ConnectionState state) internal static SMB2Command GetCloseResponse(CloseRequest request, ISMBShare share, SMB2ConnectionState state)
{ {
SMB2Session session = state.GetSession(request.Header.SessionID); SMB2Session session = state.GetSession(request.Header.SessionID);
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
if (openFile == null) if (openFile == null)
{ {
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED);
} }
share.FileStore.CloseFile(openFile.Handle); share.FileStore.CloseFile(openFile.Handle);
session.RemoveOpenFile(request.FileId.Persistent); session.RemoveOpenFile(request.FileId);
CloseResponse response = new CloseResponse(); CloseResponse response = new CloseResponse();
if (request.PostQueryAttributes) if (request.PostQueryAttributes)
{ {

View file

@ -40,8 +40,8 @@ namespace SMBLibrary.Server.SMB2
return new ErrorResponse(request.CommandName, createStatus); return new ErrorResponse(request.CommandName, createStatus);
} }
ulong? persistentFileID = session.AddOpenFile(request.Header.TreeID, path, handle); FileID? fileID = session.AddOpenFile(request.Header.TreeID, path, handle);
if (!persistentFileID.HasValue) if (fileID == null)
{ {
share.FileStore.CloseFile(handle); share.FileStore.CloseFile(handle);
return new ErrorResponse(request.CommandName, NTStatus.STATUS_TOO_MANY_OPENED_FILES); return new ErrorResponse(request.CommandName, NTStatus.STATUS_TOO_MANY_OPENED_FILES);
@ -49,12 +49,12 @@ namespace SMBLibrary.Server.SMB2
if (share is NamedPipeShare) if (share is NamedPipeShare)
{ {
return CreateResponseForNamedPipe(persistentFileID.Value, FileStatus.FILE_OPENED); return CreateResponseForNamedPipe(fileID.Value, FileStatus.FILE_OPENED);
} }
else else
{ {
FileNetworkOpenInformation fileInfo = NTFileStoreHelper.GetNetworkOpenInformation(share.FileStore, handle); FileNetworkOpenInformation fileInfo = NTFileStoreHelper.GetNetworkOpenInformation(share.FileStore, handle);
CreateResponse response = CreateResponseFromFileSystemEntry(fileInfo, persistentFileID.Value, fileStatus); CreateResponse response = CreateResponseFromFileSystemEntry(fileInfo, fileID.Value, fileStatus);
if (request.RequestedOplockLevel == OplockLevel.Batch) if (request.RequestedOplockLevel == OplockLevel.Batch)
{ {
response.OplockLevel = OplockLevel.Batch; response.OplockLevel = OplockLevel.Batch;
@ -63,16 +63,16 @@ namespace SMBLibrary.Server.SMB2
} }
} }
private static CreateResponse CreateResponseForNamedPipe(ulong persistentFileID, FileStatus fileStatus) private static CreateResponse CreateResponseForNamedPipe(FileID fileID, FileStatus fileStatus)
{ {
CreateResponse response = new CreateResponse(); CreateResponse response = new CreateResponse();
response.CreateAction = (CreateAction)fileStatus; response.CreateAction = (CreateAction)fileStatus;
response.FileAttributes = FileAttributes.Normal; response.FileAttributes = FileAttributes.Normal;
response.FileId.Persistent = persistentFileID; response.FileId = fileID;
return response; return response;
} }
private static CreateResponse CreateResponseFromFileSystemEntry(FileNetworkOpenInformation fileInfo, ulong persistentFileID, FileStatus fileStatus) private static CreateResponse CreateResponseFromFileSystemEntry(FileNetworkOpenInformation fileInfo, FileID fileID, FileStatus fileStatus)
{ {
CreateResponse response = new CreateResponse(); CreateResponse response = new CreateResponse();
response.CreateAction = (CreateAction)fileStatus; response.CreateAction = (CreateAction)fileStatus;
@ -83,7 +83,7 @@ namespace SMBLibrary.Server.SMB2
response.AllocationSize = fileInfo.AllocationSize; response.AllocationSize = fileInfo.AllocationSize;
response.EndofFile = fileInfo.EndOfFile; response.EndofFile = fileInfo.EndOfFile;
response.FileAttributes = fileInfo.FileAttributes; response.FileAttributes = fileInfo.FileAttributes;
response.FileId.Persistent = persistentFileID; response.FileId = fileID;
return response; return response;
} }
} }

View file

@ -24,7 +24,7 @@ namespace SMBLibrary.Server.SMB2
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FS_DRIVER_REQUIRED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FS_DRIVER_REQUIRED);
} }
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
object handle; object handle;
if (openFile == null) if (openFile == null)
{ {

View file

@ -17,7 +17,7 @@ namespace SMBLibrary.Server.SMB2
internal static SMB2Command GetQueryDirectoryResponse(QueryDirectoryRequest request, ISMBShare share, SMB2ConnectionState state) internal static SMB2Command GetQueryDirectoryResponse(QueryDirectoryRequest request, ISMBShare share, SMB2ConnectionState state)
{ {
SMB2Session session = state.GetSession(request.Header.SessionID); SMB2Session session = state.GetSession(request.Header.SessionID);
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
if (openFile == null) if (openFile == null)
{ {
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED);
@ -30,7 +30,7 @@ namespace SMBLibrary.Server.SMB2
FileSystemShare fileSystemShare = (FileSystemShare)share; FileSystemShare fileSystemShare = (FileSystemShare)share;
ulong fileID = request.FileId.Persistent; FileID fileID = request.FileId;
OpenSearch openSearch = session.GetOpenSearch(fileID); OpenSearch openSearch = session.GetOpenSearch(fileID);
if (openSearch == null || request.Reopen) if (openSearch == null || request.Reopen)
{ {

View file

@ -19,7 +19,7 @@ namespace SMBLibrary.Server.SMB2
SMB2Session session = state.GetSession(request.Header.SessionID); SMB2Session session = state.GetSession(request.Header.SessionID);
if (request.InfoType == InfoType.File) if (request.InfoType == InfoType.File)
{ {
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
if (openFile == null) if (openFile == null)
{ {
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED);

View file

@ -17,7 +17,7 @@ namespace SMBLibrary.Server.SMB2
internal static SMB2Command GetReadResponse(ReadRequest request, ISMBShare share, SMB2ConnectionState state) internal static SMB2Command GetReadResponse(ReadRequest request, ISMBShare share, SMB2ConnectionState state)
{ {
SMB2Session session = state.GetSession(request.Header.SessionID); SMB2Session session = state.GetSession(request.Header.SessionID);
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
if (openFile == null) if (openFile == null)
{ {
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED);
@ -37,7 +37,7 @@ namespace SMBLibrary.Server.SMB2
internal static SMB2Command GetWriteResponse(WriteRequest request, ISMBShare share, SMB2ConnectionState state) internal static SMB2Command GetWriteResponse(WriteRequest request, ISMBShare share, SMB2ConnectionState state)
{ {
SMB2Session session = state.GetSession(request.Header.SessionID); SMB2Session session = state.GetSession(request.Header.SessionID);
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
if (openFile == null) if (openFile == null)
{ {
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED);

View file

@ -19,7 +19,7 @@ namespace SMBLibrary.Server.SMB2
SMB2Session session = state.GetSession(request.Header.SessionID); SMB2Session session = state.GetSession(request.Header.SessionID);
if (request.InfoType == InfoType.File) if (request.InfoType == InfoType.File)
{ {
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
if (openFile == null) if (openFile == null)
{ {
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED);

View file

@ -15,28 +15,6 @@ namespace SMBLibrary.Server
{ {
public partial class SMBServer public partial class SMBServer
{ {
// Key is the persistent portion of the FileID
private Dictionary<ulong, OpenFileObject> m_globalOpenFiles = new Dictionary<ulong, OpenFileObject>();
private static ulong m_nextPersistentFileID = 1; // A numeric value that uniquely identifies the open handle to a file or a pipe within the scope of all opens granted by the server
private ulong? AllocatePersistentFileID()
{
for (ulong offset = 0; offset < UInt64.MaxValue; offset++)
{
ulong persistentID = (ulong)(m_nextPersistentFileID + offset);
if (persistentID == 0 || persistentID == 0xFFFFFFFFFFFFFFFF)
{
continue;
}
if (!m_globalOpenFiles.ContainsKey(persistentID))
{
m_nextPersistentFileID = (ulong)(persistentID + 1);
return persistentID;
}
}
return null;
}
private void ProcessSMB2RequestChain(List<SMB2Command> requestChain, ref ConnectionState state) private void ProcessSMB2RequestChain(List<SMB2Command> requestChain, ref ConnectionState state)
{ {
List<SMB2Command> responseChain = new List<SMB2Command>(); List<SMB2Command> responseChain = new List<SMB2Command>();
@ -77,7 +55,7 @@ namespace SMBLibrary.Server
SMB2Command response = NegotiateHelper.GetNegotiateResponse(request, m_securityProvider, state, m_serverGuid, m_serverStartTime); SMB2Command response = NegotiateHelper.GetNegotiateResponse(request, m_securityProvider, state, m_serverGuid, m_serverStartTime);
if (state.Dialect != SMBDialect.NotSet) if (state.Dialect != SMBDialect.NotSet)
{ {
state = new SMB2ConnectionState(state, AllocatePersistentFileID); state = new SMB2ConnectionState(state);
m_connectionManager.AddConnection(state); m_connectionManager.AddConnection(state);
} }
return response; return response;
@ -188,7 +166,7 @@ namespace SMBLibrary.Server
else if (command is FlushRequest) else if (command is FlushRequest)
{ {
FlushRequest request = (FlushRequest)command; FlushRequest request = (FlushRequest)command;
OpenFileObject openFile = session.GetOpenFileObject(request.FileId.Persistent); OpenFileObject openFile = session.GetOpenFileObject(request.FileId);
if (openFile == null) if (openFile == null)
{ {
return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED); return new ErrorResponse(request.CommandName, NTStatus.STATUS_FILE_CLOSED);

View file

@ -272,7 +272,7 @@ namespace SMBLibrary.Server
SMB2Command response = SMB2.NegotiateHelper.GetNegotiateResponse(smb2Dialects, m_securityProvider, state, m_serverGuid, m_serverStartTime); SMB2Command response = SMB2.NegotiateHelper.GetNegotiateResponse(smb2Dialects, m_securityProvider, state, m_serverGuid, m_serverStartTime);
if (state.Dialect != SMBDialect.NotSet) if (state.Dialect != SMBDialect.NotSet)
{ {
state = new SMB2ConnectionState(state, AllocatePersistentFileID); state = new SMB2ConnectionState(state);
m_connectionManager.AddConnection(state); m_connectionManager.AddConnection(state);
} }
EnqueueResponse(state, response); EnqueueResponse(state, response);