From 8870fe9acdb84d6cc1af0c03ea29db6da5c51414 Mon Sep 17 00:00:00 2001 From: raufismayilov Date: Wed, 29 May 2024 12:29:33 +0400 Subject: [PATCH] * Added proper socket disconnecting sequence * Fixed socket.EndReceive call --- SMBLibrary/Client/SMB2Client.cs | 57 ++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/SMBLibrary/Client/SMB2Client.cs b/SMBLibrary/Client/SMB2Client.cs index 4486e9a..b03ed00 100644 --- a/SMBLibrary/Client/SMB2Client.cs +++ b/SMBLibrary/Client/SMB2Client.cs @@ -35,6 +35,7 @@ namespace SMBLibrary.Client private Socket m_clientSocket; private ConnectionState m_connectionState; private int m_responseTimeoutInMilliseconds; + private EventWaitHandle m_disconnectedEventHandle = new EventWaitHandle(false, EventResetMode.ManualReset); private object m_incomingQueueLock = new object(); private List m_incomingQueue = new List(); @@ -120,7 +121,7 @@ namespace SMBLibrary.Client SessionPacket sessionResponsePacket = WaitForSessionResponsePacket(); if (!(sessionResponsePacket is PositiveSessionResponsePacket)) { - m_clientSocket.Disconnect(false); + DisconnectSocket(); if (!ConnectSocket(serverAddress, port)) { return false; @@ -147,6 +148,7 @@ namespace SMBLibrary.Client bool supportsDialect = NegotiateDialect(); if (!supportsDialect) { + DisconnectSocket(); m_clientSocket.Close(); } else @@ -159,6 +161,7 @@ namespace SMBLibrary.Client private bool ConnectSocket(IPAddress serverAddress, int port) { + m_disconnectedEventHandle.Reset(); m_clientSocket = new Socket(serverAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp); try @@ -176,11 +179,18 @@ namespace SMBLibrary.Client return true; } + private void DisconnectSocket() + { + m_clientSocket.Shutdown(SocketShutdown.Send); + m_disconnectedEventHandle.WaitOne(); + m_clientSocket.Shutdown(SocketShutdown.Receive); + } + public void Disconnect() { if (m_isConnected) { - m_clientSocket.Disconnect(false); + DisconnectSocket(); m_clientSocket.Close(); m_connectionState.ReceiveBuffer.Dispose(); m_isConnected = false; @@ -357,12 +367,6 @@ namespace SMBLibrary.Client ConnectionState state = (ConnectionState)ar.AsyncState; Socket clientSocket = state.ClientSocket; - if (!clientSocket.Connected) - { - state.ReceiveBuffer.Dispose(); - return; - } - int numberOfBytesReceived = 0; try { @@ -371,18 +375,21 @@ namespace SMBLibrary.Client catch (ArgumentException) // The IAsyncResult object was not returned from the corresponding synchronous method on this class. { state.ReceiveBuffer.Dispose(); + m_disconnectedEventHandle.Set(); return; } catch (ObjectDisposedException) { Log("[ReceiveCallback] EndReceive ObjectDisposedException"); state.ReceiveBuffer.Dispose(); + m_disconnectedEventHandle.Set(); return; } catch (SocketException ex) { Log("[ReceiveCallback] EndReceive SocketException: " + ex.Message); state.ReceiveBuffer.Dispose(); + m_disconnectedEventHandle.Set(); return; } @@ -390,6 +397,7 @@ namespace SMBLibrary.Client { m_isConnected = false; state.ReceiveBuffer.Dispose(); + m_disconnectedEventHandle.Set(); } else { @@ -397,24 +405,23 @@ namespace SMBLibrary.Client buffer.SetNumberOfBytesReceived(numberOfBytesReceived); ProcessConnectionBuffer(state); - if (clientSocket.Connected) + try { - try - { - clientSocket.BeginReceive(buffer.Buffer, buffer.WriteOffset, buffer.AvailableLength, SocketFlags.None, new AsyncCallback(OnClientSocketReceive), state); - } - catch (ObjectDisposedException) - { - m_isConnected = false; - Log("[ReceiveCallback] BeginReceive ObjectDisposedException"); - buffer.Dispose(); - } - catch (SocketException ex) - { - m_isConnected = false; - Log("[ReceiveCallback] BeginReceive SocketException: " + ex.Message); - buffer.Dispose(); - } + clientSocket.BeginReceive(buffer.Buffer, buffer.WriteOffset, buffer.AvailableLength, SocketFlags.None, new AsyncCallback(OnClientSocketReceive), state); + } + catch (ObjectDisposedException) + { + m_isConnected = false; + Log("[ReceiveCallback] BeginReceive ObjectDisposedException"); + buffer.Dispose(); + m_disconnectedEventHandle.Set(); + } + catch (SocketException ex) + { + m_isConnected = false; + Log("[ReceiveCallback] BeginReceive SocketException: " + ex.Message); + buffer.Dispose(); + m_disconnectedEventHandle.Set(); } } }