Improved connected trees management logic

This commit is contained in:
Tal Aloni 2017-01-13 23:42:39 +02:00
parent 40a9b52cfc
commit f809d337c8
5 changed files with 62 additions and 92 deletions

View file

@ -22,7 +22,7 @@ namespace SMBLibrary.Server
private ushort m_nextUID = 1; private ushort m_nextUID = 1;
// Key is TID // Key is TID
private Dictionary<ushort, string> m_connectedTrees = new Dictionary<ushort, string>(); private Dictionary<ushort, ISMBShare> m_connectedTrees = new Dictionary<ushort, ISMBShare>();
private ushort m_nextTID = 1; private ushort m_nextTID = 1;
// Key is FID // Key is FID
@ -117,17 +117,17 @@ namespace SMBLibrary.Server
return null; return null;
} }
public ushort? AddConnectedTree(string relativePath) public ushort? AddConnectedTree(ISMBShare share)
{ {
ushort? treeID = AllocateTreeID(); ushort? treeID = AllocateTreeID();
if (treeID.HasValue) if (treeID.HasValue)
{ {
m_connectedTrees.Add(treeID.Value, relativePath); m_connectedTrees.Add(treeID.Value, share);
} }
return treeID; return treeID;
} }
public string GetConnectedTreePath(ushort treeID) public ISMBShare GetConnectedTree(ushort treeID)
{ {
if (m_connectedTrees.ContainsKey(treeID)) if (m_connectedTrees.ContainsKey(treeID))
{ {
@ -149,12 +149,6 @@ namespace SMBLibrary.Server
return m_connectedTrees.ContainsKey(treeID); return m_connectedTrees.ContainsKey(treeID);
} }
public bool IsIPC(ushort treeID)
{
string relativePath = GetConnectedTreePath(treeID);
return String.Equals(relativePath, "\\IPC$", StringComparison.InvariantCultureIgnoreCase);
}
public ProcessStateObject GetProcessState(uint processID) public ProcessStateObject GetProcessState(uint processID)
{ {
if (ProcessStateList.ContainsKey(processID)) if (ProcessStateList.ContainsKey(processID))

View file

@ -1,4 +1,4 @@
/* Copyright (C) 2014 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved. /* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
* *
* You can redistribute this program and/or modify it under the terms of * You can redistribute this program and/or modify it under the terms of
* the GNU Lesser Public License as published by the Free Software Foundation, * the GNU Lesser Public License as published by the Free Software Foundation,
@ -36,20 +36,32 @@ namespace SMBLibrary.Server
/// <returns>e.g. \*</returns> /// <returns>e.g. \*</returns>
public static string GetRelativeSharePath(string path) public static string GetRelativeSharePath(string path)
{ {
if (path.StartsWith(@"\\")) string relativePath = GetRelativeServerPath(path);
int index = relativePath.IndexOf('\\', 1);
if (index > 0)
{ {
int firstIndex = path.IndexOf('\\', 2); return path.Substring(index);
int index = path.IndexOf('\\', firstIndex + 1);
if (index > 0)
{
return path.Substring(index);
}
else
{
return path;
}
} }
return path; else
{
return @"\";
}
}
public static string GetShareName(string path)
{
string relativePath = GetRelativeServerPath(path);
if (relativePath.StartsWith(@"\"))
{
relativePath = relativePath.Substring(1);
}
int indexOfSeparator = relativePath.IndexOf(@"\");
if (indexOfSeparator >= 0)
{
relativePath = relativePath.Substring(0, indexOfSeparator);
}
return relativePath;
} }
} }
} }

View file

@ -14,64 +14,49 @@ namespace SMBLibrary.Server.SMB1
{ {
public class TreeConnectHelper public class TreeConnectHelper
{ {
internal static SMB1Command GetTreeConnectResponse(SMB1Header header, TreeConnectAndXRequest request, SMB1ConnectionState state, ShareCollection shares) internal static SMB1Command GetTreeConnectResponse(SMB1Header header, TreeConnectAndXRequest request, SMB1ConnectionState state, NamedPipeShare services, ShareCollection shares)
{ {
bool isExtended = (request.Flags & TreeConnectFlags.ExtendedResponse) > 0; bool isExtended = (request.Flags & TreeConnectFlags.ExtendedResponse) > 0;
string relativePath = ServerPathUtils.GetRelativeServerPath(request.Path); string shareName = ServerPathUtils.GetShareName(request.Path);
if (String.Equals(relativePath, "\\IPC$", StringComparison.InvariantCultureIgnoreCase)) ISMBShare share;
ServiceName serviceName;
if (String.Equals(shareName, NamedPipeShare.NamedPipeShareName, StringComparison.InvariantCultureIgnoreCase))
{ {
ushort? treeID = state.AddConnectedTree(relativePath); share = services;
if (!treeID.HasValue) serviceName = ServiceName.NamedPipe;
{
header.Status = NTStatus.STATUS_INSUFF_SERVER_RESOURCES;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
header.TID = treeID.Value;
if (isExtended)
{
return CreateTreeConnectResponseExtended(ServiceName.NamedPipe);
}
else
{
return CreateTreeConnectResponse(ServiceName.NamedPipe);
}
} }
else else
{ {
FileSystemShare share = shares.GetShareFromRelativePath(relativePath); share = shares.GetShareFromName(shareName);
serviceName = ServiceName.DiskShare;
if (share == null) if (share == null)
{ {
header.Status = NTStatus.STATUS_OBJECT_PATH_NOT_FOUND; header.Status = NTStatus.STATUS_OBJECT_PATH_NOT_FOUND;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX); return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
} }
else
string userName = state.GetConnectedUserName(header.UID);
if (!((FileSystemShare)share).HasReadAccess(userName))
{ {
string userName = state.GetConnectedUserName(header.UID); header.Status = NTStatus.STATUS_ACCESS_DENIED;
if (!share.HasReadAccess(userName)) return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
{
header.Status = NTStatus.STATUS_ACCESS_DENIED;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
else
{
ushort? treeID = state.AddConnectedTree(relativePath);
if (!treeID.HasValue)
{
header.Status = NTStatus.STATUS_INSUFF_SERVER_RESOURCES;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
header.TID = treeID.Value;
if (isExtended)
{
return CreateTreeConnectResponseExtended(ServiceName.DiskShare);
}
else
{
return CreateTreeConnectResponse(ServiceName.DiskShare);
}
}
} }
} }
ushort? treeID = state.AddConnectedTree(share);
if (!treeID.HasValue)
{
header.Status = NTStatus.STATUS_INSUFF_SERVER_RESOURCES;
return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
}
header.TID = treeID.Value;
if (isExtended)
{
return CreateTreeConnectResponseExtended(serviceName);
}
else
{
return CreateTreeConnectResponse(serviceName);
}
} }
private static TreeConnectAndXResponse CreateTreeConnectResponse(ServiceName serviceName) private static TreeConnectAndXResponse CreateTreeConnectResponse(ServiceName serviceName)

View file

@ -93,7 +93,7 @@ namespace SMBLibrary.Server
if (command is TreeConnectAndXRequest) if (command is TreeConnectAndXRequest)
{ {
TreeConnectAndXRequest request = (TreeConnectAndXRequest)command; TreeConnectAndXRequest request = (TreeConnectAndXRequest)command;
return TreeConnectHelper.GetTreeConnectResponse(header, request, state, m_shares); return TreeConnectHelper.GetTreeConnectResponse(header, request, state, m_services, m_shares);
} }
else if (command is LogoffAndXRequest) else if (command is LogoffAndXRequest)
{ {
@ -101,17 +101,7 @@ namespace SMBLibrary.Server
} }
else if (state.IsTreeConnected(header.TID)) else if (state.IsTreeConnected(header.TID))
{ {
string rootPath = state.GetConnectedTreePath(header.TID); ISMBShare share = state.GetConnectedTree(header.TID);
ISMBShare share;
if (state.IsIPC(header.TID))
{
share = m_services;
}
else
{
share = m_shares.GetShareFromRelativePath(rootPath);
}
if (command is CreateDirectoryRequest) if (command is CreateDirectoryRequest)
{ {
if (!(share is FileSystemShare)) if (!(share is FileSystemShare))

View file

@ -51,20 +51,9 @@ namespace SMBLibrary.Server
} }
/// <param name="relativePath">e.g. \Shared</param> /// <param name="relativePath">e.g. \Shared</param>
public FileSystemShare GetShareFromRelativePath(string relativePath) public FileSystemShare GetShareFromName(string shareName)
{ {
if (relativePath.StartsWith(@"\")) int index = IndexOf(shareName, StringComparison.InvariantCultureIgnoreCase);
{
relativePath = relativePath.Substring(1);
}
int indexOfSeparator = relativePath.IndexOf(@"\");
if (indexOfSeparator >= 0)
{
relativePath = relativePath.Substring(0, indexOfSeparator);
}
int index = IndexOf(relativePath, StringComparison.InvariantCultureIgnoreCase);
if (index >= 0) if (index >= 0)
{ {
return this[index]; return this[index];