Improved connection buffer implementation

This commit is contained in:
Tal Aloni 2017-01-03 16:01:27 +02:00
parent f3f11ba20a
commit 19cb25c463
4 changed files with 155 additions and 85 deletions

View file

@ -116,6 +116,7 @@
<Compile Include="Server\ResponseHelpers\ServerResponseHelper.cs" /> <Compile Include="Server\ResponseHelpers\ServerResponseHelper.cs" />
<Compile Include="Server\FileSystemShare.cs" /> <Compile Include="Server\FileSystemShare.cs" />
<Compile Include="Server\ShareCollection.cs" /> <Compile Include="Server\ShareCollection.cs" />
<Compile Include="Server\SMBConnectionReceiveBuffer.cs" />
<Compile Include="Server\SMBServer.cs" /> <Compile Include="Server\SMBServer.cs" />
<Compile Include="Server\StateObject.cs" /> <Compile Include="Server\StateObject.cs" />
<Compile Include="Server\User.cs" /> <Compile Include="Server\User.cs" />

View file

@ -0,0 +1,131 @@
/* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
*
* You can redistribute this program and/or modify it under the terms of
* the GNU Lesser Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*/
using System;
using System.Collections.Generic;
using System.Text;
using SMBLibrary.NetBios;
using Utilities;
namespace SMBLibrary.Server
{
public class SMBConnectionReceiveBuffer
{
private byte[] m_buffer;
private int m_readOffset = 0;
private int m_bytesInBuffer = 0;
private int? m_packetLength;
/// <param name="bufferLength">Must be large enough to hold the largest possible packet</param>
public SMBConnectionReceiveBuffer(int bufferLength)
{
m_buffer = new byte[bufferLength];
}
public void SetNumberOfBytesReceived(int numberOfBytesReceived)
{
m_bytesInBuffer += numberOfBytesReceived;
}
public bool HasCompletePacket()
{
if (m_bytesInBuffer >= 4)
{
if (!m_packetLength.HasValue)
{
// The packet is either Direct TCP transport packet (which is an NBT Session Message
// Packet) or an NBT packet.
byte flags = ByteReader.ReadByte(m_buffer, m_readOffset + 1);
int trailerLength = (flags & 0x01) << 16 | BigEndianConverter.ToUInt16(m_buffer, m_readOffset + 2);
m_packetLength = 4 + trailerLength;
}
return m_bytesInBuffer >= m_packetLength.Value;
}
return false;
}
/// <summary>
/// HasCompletePacket must be called and return true before calling DequeuePacket
/// </summary>
/// <exception cref="System.IO.InvalidDataException"></exception>
public SessionPacket DequeuePacket()
{
SessionPacket packet;
try
{
packet = SessionPacket.GetSessionPacket(m_buffer, m_readOffset);
}
catch (IndexOutOfRangeException ex)
{
throw new System.IO.InvalidDataException("Invalid Packet", ex);
}
RemovePacketBytes();
return packet;
}
/// <summary>
/// HasCompletePDU must be called and return true before calling DequeuePDUBytes
/// </summary>
public byte[] DequeuePacketBytes()
{
byte[] packetBytes = ByteReader.ReadBytes(m_buffer, m_readOffset, m_packetLength.Value);
RemovePacketBytes();
return packetBytes;
}
private void RemovePacketBytes()
{
m_bytesInBuffer -= m_packetLength.Value;
if (m_bytesInBuffer == 0)
{
m_readOffset = 0;
m_packetLength = null;
}
else
{
m_readOffset += m_packetLength.Value;
m_packetLength = null;
if (!HasCompletePacket())
{
Array.Copy(m_buffer, m_readOffset, m_buffer, 0, m_bytesInBuffer);
m_readOffset = 0;
}
}
}
public byte[] Buffer
{
get
{
return m_buffer;
}
}
public int WriteOffset
{
get
{
return m_readOffset + m_bytesInBuffer;
}
}
public int BytesInBuffer
{
get
{
return m_bytesInBuffer;
}
}
public int AvailableLength
{
get
{
return m_buffer.Length - (m_readOffset + m_bytesInBuffer);
}
}
}
}

View file

@ -95,13 +95,12 @@ namespace SMBLibrary.Server
} }
StateObject state = new StateObject(); StateObject state = new StateObject();
state.ReceiveBuffer = new byte[StateObject.ReceiveBufferSize];
// Disable the Nagle Algorithm for this tcp socket: // Disable the Nagle Algorithm for this tcp socket:
clientSocket.NoDelay = true; clientSocket.NoDelay = true;
state.ClientSocket = clientSocket; state.ClientSocket = clientSocket;
try try
{ {
clientSocket.BeginReceive(state.ReceiveBuffer, 0, StateObject.ReceiveBufferSize, 0, ReceiveCallback, state); clientSocket.BeginReceive(state.ReceiveBuffer.Buffer, state.ReceiveBuffer.WriteOffset, state.ReceiveBuffer.AvailableLength, 0, ReceiveCallback, state);
} }
catch (ObjectDisposedException) catch (ObjectDisposedException)
{ {
@ -123,13 +122,10 @@ namespace SMBLibrary.Server
return; return;
} }
byte[] receiveBuffer = state.ReceiveBuffer; int numberOfBytesReceived;
int bytesReceived;
try try
{ {
bytesReceived = clientSocket.EndReceive(result); numberOfBytesReceived = clientSocket.EndReceive(result);
} }
catch (ObjectDisposedException) catch (ObjectDisposedException)
{ {
@ -140,7 +136,7 @@ namespace SMBLibrary.Server
return; return;
} }
if (bytesReceived == 0) if (numberOfBytesReceived == 0)
{ {
// The other side has closed the connection // The other side has closed the connection
System.Diagnostics.Debug.Print("[{0}] The other side closed the connection", DateTime.Now.ToString("HH:mm:ss:ffff")); System.Diagnostics.Debug.Print("[{0}] The other side closed the connection", DateTime.Now.ToString("HH:mm:ss:ffff"));
@ -148,16 +144,15 @@ namespace SMBLibrary.Server
return; return;
} }
byte[] currentBuffer = new byte[bytesReceived]; SMBConnectionReceiveBuffer receiveBuffer = state.ReceiveBuffer;
Array.Copy(receiveBuffer, currentBuffer, bytesReceived); receiveBuffer.SetNumberOfBytesReceived(numberOfBytesReceived);
ProcessConnectionBuffer(state);
ProcessCurrentBuffer(currentBuffer, state);
if (clientSocket.Connected) if (clientSocket.Connected)
{ {
try try
{ {
clientSocket.BeginReceive(state.ReceiveBuffer, 0, StateObject.ReceiveBufferSize, 0, ReceiveCallback, state); clientSocket.BeginReceive(state.ReceiveBuffer.Buffer, state.ReceiveBuffer.WriteOffset, state.ReceiveBuffer.AvailableLength, 0, ReceiveCallback, state);
} }
catch (ObjectDisposedException) catch (ObjectDisposedException)
{ {
@ -168,88 +163,32 @@ namespace SMBLibrary.Server
} }
} }
public void ProcessCurrentBuffer(byte[] currentBuffer, StateObject state) public void ProcessConnectionBuffer(StateObject state)
{ {
Socket clientSocket = state.ClientSocket; Socket clientSocket = state.ClientSocket;
if (state.ConnectionBuffer.Length == 0) SMBConnectionReceiveBuffer receiveBuffer = state.ReceiveBuffer;
{ while (receiveBuffer.HasCompletePacket())
state.ConnectionBuffer = currentBuffer;
}
else
{
byte[] oldConnectionBuffer = state.ConnectionBuffer;
state.ConnectionBuffer = new byte[oldConnectionBuffer.Length + currentBuffer.Length];
Array.Copy(oldConnectionBuffer, state.ConnectionBuffer, oldConnectionBuffer.Length);
Array.Copy(currentBuffer, 0, state.ConnectionBuffer, oldConnectionBuffer.Length, currentBuffer.Length);
}
// we now have all SMB message bytes received so far in state.ConnectionBuffer
int bytesLeftInBuffer = state.ConnectionBuffer.Length;
while (bytesLeftInBuffer >= 4)
{
// The packet is either Direct TCP transport packet (which is an NBT Session Message
// Packet) or an NBT packet.
int bufferOffset = state.ConnectionBuffer.Length - bytesLeftInBuffer;
byte flags = ByteReader.ReadByte(state.ConnectionBuffer, bufferOffset + 1);
int trailerLength = (flags & 0x01) << 16 | BigEndianConverter.ToUInt16(state.ConnectionBuffer, bufferOffset + 2);
int packetLength = 4 + trailerLength;
if (flags > 0x01)
{
System.Diagnostics.Debug.Print("[{0}] Invalid NBT flags", DateTime.Now.ToString("HH:mm:ss:ffff"));
state.ClientSocket.Close();
return;
}
if (packetLength > bytesLeftInBuffer)
{
break;
}
else
{
byte[] packetBytes = new byte[packetLength];
Array.Copy(state.ConnectionBuffer, bufferOffset, packetBytes, 0, packetLength);
ProcessPacket(packetBytes, state);
bytesLeftInBuffer -= packetLength;
if (!clientSocket.Connected)
{
// Do not continue to process the buffer if the other side closed the connection
return;
}
}
}
if (bytesLeftInBuffer > 0)
{
byte[] newReceiveBuffer = new byte[bytesLeftInBuffer];
Array.Copy(state.ConnectionBuffer, state.ConnectionBuffer.Length - bytesLeftInBuffer, newReceiveBuffer, 0, bytesLeftInBuffer);
state.ConnectionBuffer = newReceiveBuffer;
}
else
{
state.ConnectionBuffer = new byte[0];
}
}
public void ProcessPacket(byte[] packetBytes, StateObject state)
{ {
SessionPacket packet = null; SessionPacket packet = null;
#if DEBUG
packet = SessionPacket.GetSessionPacket(packetBytes, 0);
#else
try try
{ {
packet = SessionPacket.GetSessionPacket(packetBytes, 0); packet = receiveBuffer.DequeuePacket();
} }
catch (Exception) catch (Exception)
{ {
state.ClientSocket.Close(); state.ClientSocket.Close();
return;
} }
#endif
if (packet != null)
{
ProcessPacket(packet, state);
}
}
}
public void ProcessPacket(SessionPacket packet, StateObject state)
{
if (packet is SessionRequestPacket && m_transport == SMBTransportType.NetBiosOverTCP) if (packet is SessionRequestPacket && m_transport == SMBTransportType.NetBiosOverTCP)
{ {
PositiveSessionResponsePacket response = new PositiveSessionResponsePacket(); PositiveSessionResponsePacket response = new PositiveSessionResponsePacket();

View file

@ -1,4 +1,4 @@
/* Copyright (C) 2014-2016 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved. /* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
* *
* You can redistribute this program and/or modify it under the terms of * You can redistribute this program and/or modify it under the terms of
* the GNU Lesser Public License as published by the Free Software Foundation, * the GNU Lesser Public License as published by the Free Software Foundation,
@ -16,9 +16,8 @@ namespace SMBLibrary.Server
public class StateObject public class StateObject
{ {
public Socket ClientSocket = null; public Socket ClientSocket = null;
public const int ReceiveBufferSize = 65536; public const int ReceiveBufferSize = 131075; // Largest NBT Session Packet
public byte[] ReceiveBuffer = new byte[ReceiveBufferSize]; // immediate receive buffer public SMBConnectionReceiveBuffer ReceiveBuffer = new SMBConnectionReceiveBuffer(ReceiveBufferSize);
public byte[] ConnectionBuffer = new byte[0]; // we append the receive buffer here until we have a complete Message
public int MaxBufferSize; public int MaxBufferSize;
public bool LargeRead; public bool LargeRead;