Improved connected trees management logic

This commit is contained in:
Tal Aloni 2017-01-13 23:42:39 +02:00
parent 40a9b52cfc
commit f809d337c8
5 changed files with 62 additions and 92 deletions

View file

@ -22,7 +22,7 @@ namespace SMBLibrary.Server
private ushort m_nextUID = 1;
// Key is TID
private Dictionary<ushort, string> m_connectedTrees = new Dictionary<ushort, string>();
private Dictionary<ushort, ISMBShare> m_connectedTrees = new Dictionary<ushort, ISMBShare>();
private ushort m_nextTID = 1;
// Key is FID
@ -117,17 +117,17 @@ namespace SMBLibrary.Server
return null;
}
public ushort? AddConnectedTree(string relativePath)
public ushort? AddConnectedTree(ISMBShare share)
{
ushort? treeID = AllocateTreeID();
if (treeID.HasValue)
{
m_connectedTrees.Add(treeID.Value, relativePath);
m_connectedTrees.Add(treeID.Value, share);
}
return treeID;
}
public string GetConnectedTreePath(ushort treeID)
public ISMBShare GetConnectedTree(ushort treeID)
{
if (m_connectedTrees.ContainsKey(treeID))
{
@ -149,12 +149,6 @@ namespace SMBLibrary.Server
return m_connectedTrees.ContainsKey(treeID);
}
public bool IsIPC(ushort treeID)
{
string relativePath = GetConnectedTreePath(treeID);
return String.Equals(relativePath, "\\IPC$", StringComparison.InvariantCultureIgnoreCase);
}
public ProcessStateObject GetProcessState(uint processID)
{
if (ProcessStateList.ContainsKey(processID))

View file

@ -1,4 +1,4 @@
/* Copyright (C) 2014 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
/* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
*
* You can redistribute this program and/or modify it under the terms of
* the GNU Lesser Public License as published by the Free Software Foundation,
@ -36,20 +36,32 @@ namespace SMBLibrary.Server
/// <returns>e.g. \*</returns>
public static string GetRelativeSharePath(string path)
{
if (path.StartsWith(@"\\"))
string relativePath = GetRelativeServerPath(path);
int index = relativePath.IndexOf('\\', 1);
if (index > 0)
{
int firstIndex = path.IndexOf('\\', 2);
int index = path.IndexOf('\\', firstIndex + 1);
if (index > 0)
{
return path.Substring(index);
}
else
{
return path;
}
return path.Substring(index);
}
return path;
else
{
return @"\";
}
}
public static string GetShareName(string path)
{
string relativePath = GetRelativeServerPath(path);
if (relativePath.StartsWith(@"\"))
{
relativePath = relativePath.Substring(1);
}
int indexOfSeparator = relativePath.IndexOf(@"\");
if (indexOfSeparator >= 0)
{
relativePath = relativePath.Substring(0, indexOfSeparator);
}
return relativePath;
}
}
}

View file

@ -14,64 +14,49 @@ namespace SMBLibrary.Server.SMB1
{
public class TreeConnectHelper
{
internal static SMB1Command GetTreeConnectResponse(SMB1Header header, TreeConnectAndXRequest request, SMB1ConnectionState state, ShareCollection shares)
internal static SMB1Command GetTreeConnectResponse(SMB1Header header, TreeConnectAndXRequest request, SMB1ConnectionState state, NamedPipeShare services, ShareCollection shares)
{
bool isExtended = (request.Flags & TreeConnectFlags.ExtendedResponse) > 0;
string relativePath = ServerPathUtils.GetRelativeServerPath(request.Path);
if (String.Equals(relativePath, "\\IPC$", StringComparison.InvariantCultureIgnoreCase))
string shareName = ServerPathUtils.GetShareName(request.Path);
ISMBShare share;
ServiceName serviceName;
if (String.Equals(shareName, NamedPipeShare.NamedPipeShareName, StringComparison.InvariantCultureIgnoreCase))
{
ushort? treeID = state.AddConnectedTree(relativePath);
if (!treeID.HasValue)
{
header.Status = NTStatus.STATUS_INSUFF_SERVER_RESOURCES;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
header.TID = treeID.Value;
if (isExtended)
{
return CreateTreeConnectResponseExtended(ServiceName.NamedPipe);
}
else
{
return CreateTreeConnectResponse(ServiceName.NamedPipe);
}
share = services;
serviceName = ServiceName.NamedPipe;
}
else
{
FileSystemShare share = shares.GetShareFromRelativePath(relativePath);
share = shares.GetShareFromName(shareName);
serviceName = ServiceName.DiskShare;
if (share == null)
{
header.Status = NTStatus.STATUS_OBJECT_PATH_NOT_FOUND;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
else
string userName = state.GetConnectedUserName(header.UID);
if (!((FileSystemShare)share).HasReadAccess(userName))
{
string userName = state.GetConnectedUserName(header.UID);
if (!share.HasReadAccess(userName))
{
header.Status = NTStatus.STATUS_ACCESS_DENIED;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
else
{
ushort? treeID = state.AddConnectedTree(relativePath);
if (!treeID.HasValue)
{
header.Status = NTStatus.STATUS_INSUFF_SERVER_RESOURCES;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
header.TID = treeID.Value;
if (isExtended)
{
return CreateTreeConnectResponseExtended(ServiceName.DiskShare);
}
else
{
return CreateTreeConnectResponse(ServiceName.DiskShare);
}
}
header.Status = NTStatus.STATUS_ACCESS_DENIED;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
}
ushort? treeID = state.AddConnectedTree(share);
if (!treeID.HasValue)
{
header.Status = NTStatus.STATUS_INSUFF_SERVER_RESOURCES;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
header.TID = treeID.Value;
if (isExtended)
{
return CreateTreeConnectResponseExtended(serviceName);
}
else
{
return CreateTreeConnectResponse(serviceName);
}
}
private static TreeConnectAndXResponse CreateTreeConnectResponse(ServiceName serviceName)

View file

@ -93,7 +93,7 @@ namespace SMBLibrary.Server
if (command is TreeConnectAndXRequest)
{
TreeConnectAndXRequest request = (TreeConnectAndXRequest)command;
return TreeConnectHelper.GetTreeConnectResponse(header, request, state, m_shares);
return TreeConnectHelper.GetTreeConnectResponse(header, request, state, m_services, m_shares);
}
else if (command is LogoffAndXRequest)
{
@ -101,17 +101,7 @@ namespace SMBLibrary.Server
}
else if (state.IsTreeConnected(header.TID))
{
string rootPath = state.GetConnectedTreePath(header.TID);
ISMBShare share;
if (state.IsIPC(header.TID))
{
share = m_services;
}
else
{
share = m_shares.GetShareFromRelativePath(rootPath);
}
ISMBShare share = state.GetConnectedTree(header.TID);
if (command is CreateDirectoryRequest)
{
if (!(share is FileSystemShare))

View file

@ -51,20 +51,9 @@ namespace SMBLibrary.Server
}
/// <param name="relativePath">e.g. \Shared</param>
public FileSystemShare GetShareFromRelativePath(string relativePath)
public FileSystemShare GetShareFromName(string shareName)
{
if (relativePath.StartsWith(@"\"))
{
relativePath = relativePath.Substring(1);
}
int indexOfSeparator = relativePath.IndexOf(@"\");
if (indexOfSeparator >= 0)
{
relativePath = relativePath.Substring(0, indexOfSeparator);
}
int index = IndexOf(relativePath, StringComparison.InvariantCultureIgnoreCase);
int index = IndexOf(shareName, StringComparison.InvariantCultureIgnoreCase);
if (index >= 0)
{
return this[index];