SMBServer: Each connection now use a dedicaded thread for send operations

This commit is contained in:
Tal Aloni 2017-03-04 12:56:11 +02:00
parent 30627e72d4
commit 84affda0ff
5 changed files with 52 additions and 24 deletions

View file

@ -38,6 +38,7 @@ namespace SMBLibrary.Server
public void ReleaseConnection(ConnectionState connection) public void ReleaseConnection(ConnectionState connection)
{ {
connection.SendQueue.Stop();
SocketUtils.ReleaseSocket(connection.ClientSocket); SocketUtils.ReleaseSocket(connection.ClientSocket);
RemoveConnection(connection); RemoveConnection(connection);
} }

View file

@ -20,6 +20,7 @@ namespace SMBLibrary.Server
public Socket ClientSocket; public Socket ClientSocket;
public IPEndPoint ClientEndPoint; public IPEndPoint ClientEndPoint;
public NBTConnectionReceiveBuffer ReceiveBuffer; public NBTConnectionReceiveBuffer ReceiveBuffer;
public BlockingQueue<SessionPacket> SendQueue;
protected LogDelegate LogToServerHandler; protected LogDelegate LogToServerHandler;
public SMBDialect Dialect; public SMBDialect Dialect;
public object AuthenticationContext; public object AuthenticationContext;
@ -27,6 +28,7 @@ namespace SMBLibrary.Server
public ConnectionState(LogDelegate logToServerHandler) public ConnectionState(LogDelegate logToServerHandler)
{ {
ReceiveBuffer = new NBTConnectionReceiveBuffer(); ReceiveBuffer = new NBTConnectionReceiveBuffer();
SendQueue = new BlockingQueue<SessionPacket>();
LogToServerHandler = logToServerHandler; LogToServerHandler = logToServerHandler;
Dialect = SMBDialect.NotSet; Dialect = SMBDialect.NotSet;
} }
@ -36,6 +38,7 @@ namespace SMBLibrary.Server
ClientSocket = state.ClientSocket; ClientSocket = state.ClientSocket;
ClientEndPoint = state.ClientEndPoint; ClientEndPoint = state.ClientEndPoint;
ReceiveBuffer = state.ReceiveBuffer; ReceiveBuffer = state.ReceiveBuffer;
SendQueue = state.SendQueue;
LogToServerHandler = state.LogToServerHandler; LogToServerHandler = state.LogToServerHandler;
Dialect = state.Dialect; Dialect = state.Dialect;
} }

View file

@ -49,7 +49,7 @@ namespace SMBLibrary.Server
index--; index--;
} }
} }
TrySendMessage(state, reply); EnqueueMessage(state, reply);
} }
} }
@ -58,7 +58,7 @@ namespace SMBLibrary.Server
SMB1Message reply = new SMB1Message(); SMB1Message reply = new SMB1Message();
reply.Header = header; reply.Header = header;
reply.Commands.Add(response); reply.Commands.Add(response);
TrySendMessage(state, reply); EnqueueMessage(state, reply);
} }
} }
@ -302,12 +302,12 @@ namespace SMBLibrary.Server
return new ErrorResponse(command.CommandName); return new ErrorResponse(command.CommandName);
} }
private static void TrySendMessage(ConnectionState state, SMB1Message response) private static void EnqueueMessage(ConnectionState state, SMB1Message response)
{ {
SessionMessagePacket packet = new SessionMessagePacket(); SessionMessagePacket packet = new SessionMessagePacket();
packet.Trailer = response.GetBytes(); packet.Trailer = response.GetBytes();
TrySendPacket(state, packet); state.SendQueue.Enqueue(packet);
state.LogToServer(Severity.Verbose, "SMB1 message sent: {0} responses, First response: {1}, Packet length: {2}", response.Commands.Count, response.Commands[0].CommandName.ToString(), packet.Length); state.LogToServer(Severity.Verbose, "SMB1 message queued: {0} responses, First response: {1}, Packet length: {2}", response.Commands.Count, response.Commands[0].CommandName.ToString(), packet.Length);
} }
private static void PrepareResponseHeader(SMB1Header responseHeader, SMB1Header requestHeader) private static void PrepareResponseHeader(SMB1Header responseHeader, SMB1Header requestHeader)

View file

@ -60,7 +60,7 @@ namespace SMBLibrary.Server
} }
if (responseChain.Count > 0) if (responseChain.Count > 0)
{ {
TrySendResponseChain(state, responseChain); EnqueueResponseChain(state, responseChain);
} }
} }
@ -223,15 +223,15 @@ namespace SMBLibrary.Server
return new ErrorResponse(command.CommandName, NTStatus.STATUS_NOT_SUPPORTED); return new ErrorResponse(command.CommandName, NTStatus.STATUS_NOT_SUPPORTED);
} }
private static void TrySendResponse(ConnectionState state, SMB2Command response) private static void EnqueueResponse(ConnectionState state, SMB2Command response)
{ {
SessionMessagePacket packet = new SessionMessagePacket(); SessionMessagePacket packet = new SessionMessagePacket();
packet.Trailer = response.GetBytes(); packet.Trailer = response.GetBytes();
TrySendPacket(state, packet); state.SendQueue.Enqueue(packet);
state.LogToServer(Severity.Verbose, "SMB2 response sent: {0}, Packet length: {1}", response.CommandName.ToString(), packet.Length); state.LogToServer(Severity.Verbose, "SMB2 response queued: {0}, Packet length: {1}", response.CommandName.ToString(), packet.Length);
} }
private static void TrySendResponseChain(ConnectionState state, List<SMB2Command> responseChain) private static void EnqueueResponseChain(ConnectionState state, List<SMB2Command> responseChain)
{ {
byte[] sessionKey = null; byte[] sessionKey = null;
if (state is SMB2ConnectionState) if (state is SMB2ConnectionState)
@ -252,8 +252,8 @@ namespace SMBLibrary.Server
SessionMessagePacket packet = new SessionMessagePacket(); SessionMessagePacket packet = new SessionMessagePacket();
packet.Trailer = SMB2Command.GetCommandChainBytes(responseChain, sessionKey); packet.Trailer = SMB2Command.GetCommandChainBytes(responseChain, sessionKey);
TrySendPacket(state, packet); state.SendQueue.Enqueue(packet);
state.LogToServer(Severity.Verbose, "SMB2 response chain sent: Response count: {0}, First response: {1}, Packet length: {2}", responseChain.Count, responseChain[0].CommandName.ToString(), packet.Length); state.LogToServer(Severity.Verbose, "SMB2 response chain queued: Response count: {0}, First response: {1}, Packet length: {2}", responseChain.Count, responseChain[0].CommandName.ToString(), packet.Length);
} }
private static void UpdateSMB2Header(SMB2Command response, SMB2Command request) private static void UpdateSMB2Header(SMB2Command response, SMB2Command request)

View file

@ -8,6 +8,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Threading;
using SMBLibrary.Authentication.GSSAPI; using SMBLibrary.Authentication.GSSAPI;
using SMBLibrary.NetBios; using SMBLibrary.NetBios;
using SMBLibrary.Services; using SMBLibrary.Services;
@ -117,6 +118,13 @@ namespace SMBLibrary.Server
state.ClientSocket = clientSocket; state.ClientSocket = clientSocket;
state.ClientEndPoint = clientSocket.RemoteEndPoint as IPEndPoint; state.ClientEndPoint = clientSocket.RemoteEndPoint as IPEndPoint;
state.LogToServer(Severity.Verbose, "New connection request"); state.LogToServer(Severity.Verbose, "New connection request");
Thread senderThread = new Thread(delegate()
{
ProcessSendQueue(state);
});
senderThread.IsBackground = true;
senderThread.Start();
try try
{ {
// Direct TCP transport packet is actually an NBT Session Message Packet, // Direct TCP transport packet is actually an NBT Session Message Packet,
@ -219,7 +227,7 @@ namespace SMBLibrary.Server
if (packet is SessionRequestPacket && m_transport == SMBTransportType.NetBiosOverTCP) if (packet is SessionRequestPacket && m_transport == SMBTransportType.NetBiosOverTCP)
{ {
PositiveSessionResponsePacket response = new PositiveSessionResponsePacket(); PositiveSessionResponsePacket response = new PositiveSessionResponsePacket();
TrySendPacket(state, response); state.SendQueue.Enqueue(response);
} }
else if (packet is SessionKeepAlivePacket && m_transport == SMBTransportType.NetBiosOverTCP) else if (packet is SessionKeepAlivePacket && m_transport == SMBTransportType.NetBiosOverTCP)
{ {
@ -265,7 +273,7 @@ namespace SMBLibrary.Server
state = new SMB2ConnectionState(state, AllocatePersistentFileID); state = new SMB2ConnectionState(state, AllocatePersistentFileID);
m_connectionManager.AddConnection(state); m_connectionManager.AddConnection(state);
} }
TrySendResponse(state, response); EnqueueResponse(state, response);
return; return;
} }
} }
@ -319,18 +327,34 @@ namespace SMBLibrary.Server
} }
} }
private static void TrySendPacket(ConnectionState state, SessionPacket response) private void ProcessSendQueue(ConnectionState state)
{ {
Socket clientSocket = state.ClientSocket; while (true)
try
{
clientSocket.Send(response.GetBytes());
}
catch (SocketException)
{
}
catch (ObjectDisposedException)
{ {
Log(Severity.Trace, "Entering ProcessSendQueue");
SessionPacket response;
bool stopped = !state.SendQueue.TryDequeue(out response);
if (stopped)
{
return;
}
Socket clientSocket = state.ClientSocket;
try
{
clientSocket.Send(response.GetBytes());
}
catch (SocketException ex)
{
Log(Severity.Debug, "[{0}] Failed to send packet. SocketException: {1}", state.ConnectionIdentifier, ex.Message);
Log(Severity.Trace, "Leaving ProcessSendQueue");
return;
}
catch (ObjectDisposedException)
{
Log(Severity.Debug, "[{0}] Failed to send packet. ObjectDisposedException.", state.ConnectionIdentifier);
Log(Severity.Trace, "Leaving ProcessSendQueue");
return;
}
} }
} }