diff --git a/SMBLibrary/Server/ConnectionState/SMB1Session.cs b/SMBLibrary/Server/ConnectionState/SMB1Session.cs index fe67675..a7587e5 100644 --- a/SMBLibrary/Server/ConnectionState/SMB1Session.cs +++ b/SMBLibrary/Server/ConnectionState/SMB1Session.cs @@ -42,12 +42,15 @@ namespace SMBLibrary.Server public ushort? AddConnectedTree(ISMBShare share) { - ushort? treeID = m_connection.AllocateTreeID(); - if (treeID.HasValue) + lock (m_connection) { - m_connectedTrees.Add(treeID.Value, share); + ushort? treeID = m_connection.AllocateTreeID(); + if (treeID.HasValue) + { + m_connectedTrees.Add(treeID.Value, share); + } + return treeID; } - return treeID; } public ISMBShare GetConnectedTree(ushort treeID) @@ -63,7 +66,7 @@ namespace SMBLibrary.Server m_connectedTrees.TryGetValue(treeID, out share); if (share != null) { - lock (m_openFiles) + lock (m_connection) { List fileIDList = new List(m_openFiles.Keys); foreach (ushort fileID in fileIDList) @@ -75,8 +78,8 @@ namespace SMBLibrary.Server m_openFiles.Remove(fileID); } } + m_connectedTrees.Remove(treeID); } - m_connectedTrees.Remove(treeID); } } @@ -94,15 +97,15 @@ namespace SMBLibrary.Server public ushort? AddOpenFile(ushort treeID, string shareName, string relativePath, object handle) { - ushort? fileID = m_connection.AllocateFileID(); - if (fileID.HasValue) + lock (m_connection) { - lock (m_openFiles) + ushort? fileID = m_connection.AllocateFileID(); + if (fileID.HasValue) { m_openFiles.Add(fileID.Value, new OpenFileObject(treeID, shareName, relativePath, handle)); } + return fileID; } - return fileID; } public OpenFileObject GetOpenFileObject(ushort fileID) @@ -114,7 +117,7 @@ namespace SMBLibrary.Server public void RemoveOpenFile(ushort fileID) { - lock (m_openFiles) + lock (m_connection) { m_openFiles.Remove(fileID); } @@ -123,7 +126,7 @@ namespace SMBLibrary.Server public List ListOpenFiles() { List result = new List(); - lock (m_openFiles) + lock (m_connection) { foreach (OpenFileObject openFile in m_openFiles.Values) { diff --git a/SMBLibrary/Server/ConnectionState/SMB2Session.cs b/SMBLibrary/Server/ConnectionState/SMB2Session.cs index cea9bc8..f90c8b8 100644 --- a/SMBLibrary/Server/ConnectionState/SMB2Session.cs +++ b/SMBLibrary/Server/ConnectionState/SMB2Session.cs @@ -62,24 +62,22 @@ namespace SMBLibrary.Server public uint? AddConnectedTree(ISMBShare share) { - uint? treeID = AllocateTreeID(); - if (treeID.HasValue) + lock (m_connectedTrees) { - m_connectedTrees.Add(treeID.Value, share); + uint? treeID = AllocateTreeID(); + if (treeID.HasValue) + { + m_connectedTrees.Add(treeID.Value, share); + } + return treeID; } - return treeID; } public ISMBShare GetConnectedTree(uint treeID) { - if (m_connectedTrees.ContainsKey(treeID)) - { - return m_connectedTrees[treeID]; - } - else - { - return null; - } + ISMBShare result; + m_connectedTrees.TryGetValue(treeID, out result); + return result; } public void DisconnectTree(uint treeID) @@ -101,7 +99,10 @@ namespace SMBLibrary.Server } } } - m_connectedTrees.Remove(treeID); + lock (m_connectedTrees) + { + m_connectedTrees.Remove(treeID); + } } } @@ -131,19 +132,19 @@ namespace SMBLibrary.Server public FileID? AddOpenFile(uint treeID, string shareName, string relativePath, object handle) { - ulong? volatileFileID = AllocateVolatileFileID(); - if (volatileFileID.HasValue) + lock (m_openFiles) { - FileID fileID = new FileID(); - fileID.Volatile = volatileFileID.Value; - // [MS-SMB2] FileId.Persistent MUST be set to Open.DurableFileId. - // Note: We don't support durable handles so we use volatileFileID. - fileID.Persistent = volatileFileID.Value; - lock (m_openFiles) + ulong? volatileFileID = AllocateVolatileFileID(); + if (volatileFileID.HasValue) { + FileID fileID = new FileID(); + fileID.Volatile = volatileFileID.Value; + // [MS-SMB2] FileId.Persistent MUST be set to Open.DurableFileId. + // Note: We don't support durable handles so we use volatileFileID. + fileID.Persistent = volatileFileID.Value; m_openFiles.Add(volatileFileID.Value, new OpenFileObject(treeID, shareName, relativePath, handle)); + return fileID; } - return fileID; } return null; }