diff --git a/SMBLibrary/Win32/Authentication/SSPIHelper.cs b/SMBLibrary/Win32/Authentication/SSPIHelper.cs index 04321dd..0c0ccca 100644 --- a/SMBLibrary/Win32/Authentication/SSPIHelper.cs +++ b/SMBLibrary/Win32/Authentication/SSPIHelper.cs @@ -147,7 +147,7 @@ namespace SMBLibrary.Authentication.Win32 ); [DllImport("Secur32.dll")] - private extern static int DeleteSecurityContext( + public extern static int DeleteSecurityContext( ref SecHandle phContext ); @@ -205,13 +205,14 @@ namespace SMBLibrary.Authentication.Win32 public static byte[] GetType1Message(string domainName, string userName, string password, out SecHandle clientContext) { - SecHandle handle = AcquireNTLMCredentialsHandle(domainName, userName, password); + SecHandle credentialsHandle = AcquireNTLMCredentialsHandle(domainName, userName, password); clientContext = new SecHandle(); - SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE); + SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE); + SecBufferDesc output = new SecBufferDesc(outputBuffer); uint contextAttributes; SECURITY_INTEGER expiry; - int result = InitializeSecurityContext(ref handle, IntPtr.Zero, null, ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY, 0, SECURITY_NATIVE_DREP, IntPtr.Zero, 0, ref clientContext, ref output, out contextAttributes, out expiry); + int result = InitializeSecurityContext(ref credentialsHandle, IntPtr.Zero, null, ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY, 0, SECURITY_NATIVE_DREP, IntPtr.Zero, 0, ref clientContext, ref output, out contextAttributes, out expiry); if (result != SEC_E_OK && result != SEC_I_CONTINUE_NEEDED) { if ((uint)result == SEC_E_INVALID_HANDLE) @@ -227,14 +228,20 @@ namespace SMBLibrary.Authentication.Win32 throw new Exception("InitializeSecurityContext failed, Error code " + ((uint)result).ToString("X")); } } - return output.GetSecBufferBytes(); + FreeCredentialsHandle(ref credentialsHandle); + byte[] messageBytes = outputBuffer.GetBufferBytes(); + outputBuffer.Dispose(); + output.Dispose(); + return messageBytes; } public static byte[] GetType3Message(SecHandle clientContext, byte[] type2Message) { SecHandle newContext = new SecHandle(); - SecBufferDesc input = new SecBufferDesc(type2Message); - SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE); + SecBuffer inputBuffer = new SecBuffer(type2Message); + SecBufferDesc input = new SecBufferDesc(inputBuffer); + SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE); + SecBufferDesc output = new SecBufferDesc(outputBuffer); uint contextAttributes; SECURITY_INTEGER expiry; @@ -254,19 +261,26 @@ namespace SMBLibrary.Authentication.Win32 throw new Exception("InitializeSecurityContext failed, error code " + ((uint)result).ToString("X")); } } - return output.GetSecBufferBytes(); + byte[] messageBytes = outputBuffer.GetBufferBytes(); + inputBuffer.Dispose(); + input.Dispose(); + outputBuffer.Dispose(); + output.Dispose(); + return messageBytes; } public static byte[] GetType2Message(byte[] type1MessageBytes, out SecHandle serverContext) { - SecHandle handle = AcquireNTLMCredentialsHandle(); - SecBufferDesc type1Message = new SecBufferDesc(type1MessageBytes); + SecHandle credentialsHandle = AcquireNTLMCredentialsHandle(); + SecBuffer inputBuffer = new SecBuffer(type1MessageBytes); + SecBufferDesc input = new SecBufferDesc(inputBuffer); serverContext = new SecHandle(); - SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE); + SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE); + SecBufferDesc output = new SecBufferDesc(outputBuffer); uint contextAttributes; SECURITY_INTEGER timestamp; - int result = AcceptSecurityContext(ref handle, IntPtr.Zero, ref type1Message, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref serverContext, ref output, out contextAttributes, out timestamp); + int result = AcceptSecurityContext(ref credentialsHandle, IntPtr.Zero, ref input, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref serverContext, ref output, out contextAttributes, out timestamp); if (result != SEC_E_OK && result != SEC_I_CONTINUE_NEEDED) { if ((uint)result == SEC_E_INVALID_HANDLE) @@ -282,8 +296,13 @@ namespace SMBLibrary.Authentication.Win32 throw new Exception("AcceptSecurityContext failed, error code " + ((uint)result).ToString("X")); } } - FreeCredentialsHandle(ref handle); - return output.GetSecBufferBytes(); + FreeCredentialsHandle(ref credentialsHandle); + byte[] messageBytes = outputBuffer.GetBufferBytes(); + inputBuffer.Dispose(); + input.Dispose(); + outputBuffer.Dispose(); + output.Dispose(); + return messageBytes; } /// @@ -303,13 +322,20 @@ namespace SMBLibrary.Authentication.Win32 public static bool AuthenticateType3Message(SecHandle serverContext, byte[] type3MessageBytes) { SecHandle newContext = new SecHandle(); - SecBufferDesc type3Message = new SecBufferDesc(type3MessageBytes); - SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE); + SecBuffer inputBuffer = new SecBuffer(type3MessageBytes); + SecBufferDesc input = new SecBufferDesc(inputBuffer); + SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE); + SecBufferDesc output = new SecBufferDesc(outputBuffer); uint contextAttributes; SECURITY_INTEGER timestamp; - int result = AcceptSecurityContext(IntPtr.Zero, ref serverContext, ref type3Message, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref newContext, ref output, out contextAttributes, out timestamp); - + int result = AcceptSecurityContext(IntPtr.Zero, ref serverContext, ref input, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref newContext, ref output, out contextAttributes, out timestamp); + + inputBuffer.Dispose(); + input.Dispose(); + outputBuffer.Dispose(); + output.Dispose(); + if (result == SEC_E_OK) { return true; diff --git a/SMBLibrary/Win32/Authentication/SecBufferDesc.cs b/SMBLibrary/Win32/Authentication/SecBufferDesc.cs index e8d1850..e6bdfb4 100644 --- a/SMBLibrary/Win32/Authentication/SecBufferDesc.cs +++ b/SMBLibrary/Win32/Authentication/SecBufferDesc.cs @@ -1,4 +1,4 @@ -/* Copyright (C) 2014 Tal Aloni . All rights reserved. +/* Copyright (C) 2014-2017 Tal Aloni . All rights reserved. * * You can redistribute this program and/or modify it under the terms of * the GNU Lesser Public License as published by the Free Software Foundation, @@ -11,7 +11,7 @@ using System.Text; namespace SMBLibrary.Authentication.Win32 { - public enum SecBufferType + public enum SecBufferType : uint { SECBUFFER_VERSION = 0, SECBUFFER_EMPTY = 0, @@ -20,33 +20,33 @@ namespace SMBLibrary.Authentication.Win32 } [StructLayout(LayoutKind.Sequential)] - public struct SecBuffer + public struct SecBuffer : IDisposable { - public int cbBuffer; - public int BufferType; - public IntPtr pvBuffer; + public uint cbBuffer; // Specifies the size, in bytes, of the buffer pointed to by the pvBuffer member. + public uint BufferType; + public IntPtr pvBuffer; // A pointer to a buffer. public SecBuffer(int bufferSize) { - cbBuffer = bufferSize; - BufferType = (int)SecBufferType.SECBUFFER_TOKEN; + cbBuffer = (uint)bufferSize; + BufferType = (uint)SecBufferType.SECBUFFER_TOKEN; pvBuffer = Marshal.AllocHGlobal(bufferSize); } public SecBuffer(byte[] secBufferBytes) { - cbBuffer = secBufferBytes.Length; - BufferType = (int)SecBufferType.SECBUFFER_TOKEN; - pvBuffer = Marshal.AllocHGlobal(cbBuffer); - Marshal.Copy(secBufferBytes, 0, pvBuffer, cbBuffer); + cbBuffer = (uint)secBufferBytes.Length; + BufferType = (uint)SecBufferType.SECBUFFER_TOKEN; + pvBuffer = Marshal.AllocHGlobal(secBufferBytes.Length); + Marshal.Copy(secBufferBytes, 0, pvBuffer, secBufferBytes.Length); } public SecBuffer(byte[] secBufferBytes, SecBufferType bufferType) { - cbBuffer = secBufferBytes.Length; - BufferType = (int)bufferType; - pvBuffer = Marshal.AllocHGlobal(cbBuffer); - Marshal.Copy(secBufferBytes, 0, pvBuffer, cbBuffer); + cbBuffer = (uint)secBufferBytes.Length; + BufferType = (uint)bufferType; + pvBuffer = Marshal.AllocHGlobal(secBufferBytes.Length); + Marshal.Copy(secBufferBytes, 0, pvBuffer, secBufferBytes.Length); } public void Dispose() @@ -58,63 +58,50 @@ namespace SMBLibrary.Authentication.Win32 } } - public byte[] GetBytes() + public byte[] GetBufferBytes() { byte[] buffer = null; if (cbBuffer > 0) { buffer = new byte[cbBuffer]; - Marshal.Copy(pvBuffer, buffer, 0, cbBuffer); + Marshal.Copy(pvBuffer, buffer, 0, (int)cbBuffer); } return buffer; } } - /// - /// Simplified SecBufferDesc struct with only one SecBuffer - /// [StructLayout(LayoutKind.Sequential)] - public struct SecBufferDesc + public struct SecBufferDesc : IDisposable { - public int ulVersion; - public int cBuffers; - public IntPtr pBuffers; + public uint ulVersion; + public uint cBuffers; // Indicates the number of SecBuffer structures in the pBuffers array. + public IntPtr pBuffers; // Pointer to an array of SecBuffer structures. - public SecBufferDesc(int bufferSize) + public SecBufferDesc(SecBuffer buffer) : this(new SecBuffer[] { buffer }) { - ulVersion = (int)SecBufferType.SECBUFFER_VERSION; - cBuffers = 1; - SecBuffer secBuffer = new SecBuffer(bufferSize); - pBuffers = Marshal.AllocHGlobal(Marshal.SizeOf(secBuffer)); - Marshal.StructureToPtr(secBuffer, pBuffers, false); } - public SecBufferDesc(byte[] secBufferBytes) + public SecBufferDesc(SecBuffer[] buffers) { - ulVersion = (int)SecBufferType.SECBUFFER_VERSION; - cBuffers = 1; - SecBuffer secBuffer = new SecBuffer(secBufferBytes); - pBuffers = Marshal.AllocHGlobal(Marshal.SizeOf(secBuffer)); - Marshal.StructureToPtr(secBuffer, pBuffers, false); + int secBufferSize = Marshal.SizeOf(typeof(SecBuffer)); + ulVersion = (uint)SecBufferType.SECBUFFER_VERSION; + cBuffers = (uint)buffers.Length; + pBuffers = Marshal.AllocHGlobal(buffers.Length * secBufferSize); + IntPtr currentBuffer = pBuffers; + for (int index = 0; index < buffers.Length; index++) + { + Marshal.StructureToPtr(buffers[index], currentBuffer, false); + currentBuffer = new IntPtr(currentBuffer.ToInt64() + secBufferSize); + } } public void Dispose() { if (pBuffers != IntPtr.Zero) { - SecBuffer secBuffer = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer)); - secBuffer.Dispose(); Marshal.FreeHGlobal(pBuffers); pBuffers = IntPtr.Zero; } } - - public byte[] GetSecBufferBytes() - { - if (pBuffers == IntPtr.Zero) - throw new ObjectDisposedException("SecBufferDesc"); - SecBuffer secBuffer = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer)); - return secBuffer.GetBytes(); - } } }