IFileSystem: Added FileOptions parameter to OpenFile

This commit is contained in:
Tal Aloni 2017-05-19 13:37:11 +03:00
parent 340f55c26e
commit 03cff5e469
6 changed files with 74 additions and 238 deletions

View file

@ -182,7 +182,7 @@ namespace SMBLibrary
// Truncate the file
try
{
Stream temp = m_fileSystem.OpenFile(path, FileMode.Truncate, FileAccess.ReadWrite, FileShare.ReadWrite);
Stream temp = m_fileSystem.OpenFile(path, FileMode.Truncate, FileAccess.ReadWrite, FileShare.ReadWrite, FileOptions.None);
temp.Close();
}
catch (Exception ex)
@ -266,17 +266,13 @@ namespace SMBLibrary
private NTStatus OpenFileStream(out Stream stream, string path, FileAccess fileAccess, ShareAccess shareAccess, CreateOptions openOptions)
{
stream = null;
// When FILE_OPEN_REPARSE_POINT is specified, the operation should continue normally if the file is not a reparse point.
// FILE_OPEN_REPARSE_POINT is a hint that the caller does not intend to actually read the file, with the exception
// of a file copy operation (where the caller will attempt to simply copy the reparse point).
bool openReparsePoint = (openOptions & CreateOptions.FILE_OPEN_REPARSE_POINT) > 0;
bool disableBuffering = (openOptions & CreateOptions.FILE_NO_INTERMEDIATE_BUFFERING) > 0;
bool buffered = (openOptions & CreateOptions.FILE_RANDOM_ACCESS) == 0 && !disableBuffering && !openReparsePoint;
FileShare fileShare = NTFileStoreHelper.ToFileShare(shareAccess);
FileOptions fileOptions = ToFileOptions(openOptions);
string fileShareString = fileShare.ToString().Replace(", ", "|");
string fileOptionsString = ToFileOptionsString(fileOptions);
try
{
stream = m_fileSystem.OpenFile(path, FileMode.Open, fileAccess, fileShare);
stream = m_fileSystem.OpenFile(path, FileMode.Open, fileAccess, fileShare, fileOptions);
}
catch (Exception ex)
{
@ -285,12 +281,7 @@ namespace SMBLibrary
return status;
}
Log(Severity.Information, "OpenFileStream: Opened '{0}', Access={1}, Share={2}, Buffered={3}", path, fileAccess, fileShareString, buffered);
if (buffered)
{
stream = new PrefetchedStream(stream);
}
Log(Severity.Information, "OpenFileStream: Opened '{0}', Access={1}, Share={2}, FileOptions={3}", path, fileAccess, fileShareString, fileOptionsString);
return NTStatus.STATUS_SUCCESS;
}
@ -303,18 +294,6 @@ namespace SMBLibrary
fileHandle.Stream.Close();
}
if (fileHandle.DeleteOnClose)
{
try
{
m_fileSystem.Delete(fileHandle.Path);
Log(Severity.Verbose, "CloseFile: Deleted '{0}'.", fileHandle.Path);
}
catch
{
Log(Severity.Verbose, "CloseFile: Error deleting '{0}'.", fileHandle.Path);
}
}
return NTStatus.STATUS_SUCCESS;
}
@ -462,6 +441,63 @@ namespace SMBLibrary
}
}
private static FileOptions ToFileOptions(CreateOptions createOptions)
{
const FileOptions FILE_FLAG_OPEN_REPARSE_POINT = (FileOptions)0x00200000;
const FileOptions FILE_FLAG_NO_BUFFERING = (FileOptions)0x20000000;
FileOptions result = FileOptions.None;
if ((createOptions & CreateOptions.FILE_OPEN_REPARSE_POINT) > 0)
{
result |= FILE_FLAG_OPEN_REPARSE_POINT;
}
if ((createOptions & CreateOptions.FILE_NO_INTERMEDIATE_BUFFERING) > 0)
{
result |= FILE_FLAG_NO_BUFFERING;
}
if ((createOptions & CreateOptions.FILE_RANDOM_ACCESS) > 0)
{
result |= FileOptions.RandomAccess;
}
if ((createOptions & CreateOptions.FILE_SEQUENTIAL_ONLY) > 0)
{
result |= FileOptions.SequentialScan;
}
if ((createOptions & CreateOptions.FILE_WRITE_THROUGH) > 0)
{
result |= FileOptions.WriteThrough;
}
if ((createOptions & CreateOptions.FILE_DELETE_ON_CLOSE) > 0)
{
result |= FileOptions.DeleteOnClose;
}
return result;
}
private static string ToFileOptionsString(FileOptions options)
{
string result = String.Empty;
const FileOptions FILE_FLAG_OPEN_REPARSE_POINT = (FileOptions)0x00200000;
const FileOptions FILE_FLAG_NO_BUFFERING = (FileOptions)0x20000000;
if ((options & FILE_FLAG_OPEN_REPARSE_POINT) > 0)
{
result += "ReparsePoint|";
options &= ~FILE_FLAG_OPEN_REPARSE_POINT;
}
if ((options & FILE_FLAG_NO_BUFFERING) > 0)
{
result += "NoBuffering|";
options &= ~FILE_FLAG_NO_BUFFERING;
}
if (result == String.Empty || options != FileOptions.None)
{
result += options.ToString().Replace(", ", "|");
}
result = result.TrimEnd(new char[] { '|' });
return result;
}
/// <summary>
/// Will return a virtual allocation size, assuming 4096 bytes per cluster
/// </summary>

View file

@ -540,7 +540,6 @@
<Compile Include="Tests\RPCTests.cs" />
<Compile Include="Tests\SMB2SigningTests.cs" />
<Compile Include="Utilities\LogEntry.cs" />
<Compile Include="Utilities\PrefetchedStream.cs" />
<Compile Include="Utilities\SocketUtils.cs" />
<Compile Include="Win32\IntegratedNTLMAuthenticationProvider.cs" />
<Compile Include="Win32\Security\LoginAPI.cs" />

View file

@ -1,205 +0,0 @@
/* Copyright (C) 2016-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
*
* You can redistribute this program and/or modify it under the terms of
* the GNU Lesser Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*/
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
namespace Utilities
{
public class PrefetchedStream : Stream
{
public const int CacheSize = 524288; // 512 KB
public const int ReadAheadThershold = 65536; // 64 KB
private long m_cacheOffset;
private byte[] m_cache = new byte[0];
private Stream m_stream;
public PrefetchedStream(Stream stream)
{
m_stream = stream;
if (m_stream.CanRead)
{
ScheduleReadAhead();
}
}
private void ScheduleReadAhead()
{
new Thread(delegate()
{
ReadAhead();
}).Start();
}
private void ReadAhead()
{
lock (m_stream)
{
long position = this.Position;
bool isInCache = (position >= m_cacheOffset) && (position < m_cacheOffset + m_cache.Length);
int bytesAlreadyRead;
if (isInCache)
{
int offsetInCache = (int)(position - m_cacheOffset);
bytesAlreadyRead = m_cache.Length - offsetInCache;
byte[] oldCache = m_cache;
m_cache = new byte[CacheSize];
Array.Copy(oldCache, offsetInCache, m_cache, 0, bytesAlreadyRead);
this.Position = position + bytesAlreadyRead;
}
else
{
bytesAlreadyRead = 0;
m_cache = new byte[CacheSize];
}
m_cacheOffset = position;
int bytesRead = m_stream.Read(m_cache, bytesAlreadyRead, CacheSize - bytesAlreadyRead);
System.Diagnostics.Debug.Print("[{0}] {1} bytes have been read ahead from offset {2}.", DateTime.Now.ToString("HH:mm:ss:ffff"), bytesRead, position);
if (bytesAlreadyRead + bytesRead < CacheSize)
{
// EOF, we must trim the response data array
m_cache = ByteReader.ReadBytes(m_cache, 0, bytesAlreadyRead + bytesRead);
}
this.Position = position;
}
}
public override int Read(byte[] buffer, int offset, int count)
{
int bytesCopied;
lock (m_stream)
{
long position = this.Position;
bool isInCache = (position >= m_cacheOffset) && (position < m_cacheOffset + m_cache.Length);
if (isInCache)
{
int offsetInCache = (int)(position - m_cacheOffset);
int bytesAvailableInCache = m_cache.Length - offsetInCache;
bytesCopied = Math.Min(count, bytesAvailableInCache);
Array.Copy(m_cache, offsetInCache, buffer, offset, bytesCopied);
this.Position = position + bytesCopied;
if (bytesCopied < count)
{
int bytesMissing = count - bytesCopied;
int bytesRead = m_stream.Read(buffer, offset + bytesCopied, bytesMissing);
}
if (offsetInCache + ReadAheadThershold >= m_cache.Length)
{
ScheduleReadAhead();
}
}
else
{
bytesCopied = m_stream.Read(buffer, 0, count);
ScheduleReadAhead();
}
}
return bytesCopied;
}
public override void Write(byte[] buffer, int offset, int count)
{
lock (m_stream)
{
m_cache = new byte[0];
m_stream.Write(buffer, offset, count);
}
}
public override void Close()
{
lock (m_stream)
{
m_stream.Close();
}
base.Close();
}
public override bool CanRead
{
get
{
return m_stream.CanRead;
}
}
public override bool CanSeek
{
get
{
return m_stream.CanSeek;
}
}
public override bool CanWrite
{
get
{
return m_stream.CanWrite;
}
}
public override long Length
{
get
{
lock (m_stream)
{
return m_stream.Length;
}
}
}
public override long Position
{
get
{
lock (m_stream)
{
return m_stream.Position;
}
}
set
{
lock (m_stream)
{
m_stream.Position = value;
}
}
}
public override void Flush()
{
lock (m_stream)
{
m_stream.Flush();
}
}
public override long Seek(long offset, SeekOrigin origin)
{
lock (m_stream)
{
return m_stream.Seek(offset, origin);
}
}
public override void SetLength(long value)
{
lock (m_stream)
{
m_stream.SetLength(value);
}
}
}
}

View file

@ -137,11 +137,12 @@ namespace SMBServer
return result;
}
public override Stream OpenFile(string path, FileMode mode, FileAccess access, FileShare share)
public override Stream OpenFile(string path, FileMode mode, FileAccess access, FileShare share, FileOptions options)
{
ValidatePath(path);
string fullPath = m_directory.FullName + path;
FileStream fileStream = File.Open(fullPath, mode, access, share);
const int DefaultBufferSize = 4096;
FileStream fileStream = new FileStream(fullPath, mode, access, share, DefaultBufferSize, options);
if (!m_openHandles.ContainsKey(fullPath.ToLower()))
{
m_openHandles.Add(fullPath.ToLower(), fileStream.SafeFileHandle);

View file

@ -13,7 +13,7 @@ namespace Utilities
public abstract void Move(string source, string destination);
public abstract void Delete(string path);
public abstract List<FileSystemEntry> ListEntriesInDirectory(string path);
public abstract Stream OpenFile(string path, FileMode mode, FileAccess access, FileShare share);
public abstract Stream OpenFile(string path, FileMode mode, FileAccess access, FileShare share, FileOptions options);
public abstract void SetAttributes(string path, bool? isHidden, bool? isReadonly, bool? isArchived);
public abstract void SetDates(string path, DateTime? creationDT, DateTime? lastWriteDT, DateTime? lastAccessDT);
@ -22,6 +22,11 @@ namespace Utilities
return ListEntriesInDirectory(@"\");
}
public Stream OpenFile(string path, FileMode mode, FileAccess access, FileShare share)
{
return OpenFile(path, mode, access, share, FileOptions.None);
}
public void CopyFile(string sourcePath, string destinationPath)
{
const int bufferLength = 1024 * 1024;
@ -41,8 +46,8 @@ namespace Utilities
{
destinationFile = CreateFile(destinationPath);
}
Stream sourceStream = OpenFile(sourcePath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite);
Stream destinationStream = OpenFile(destinationPath, FileMode.Open, FileAccess.ReadWrite, FileShare.ReadWrite);
Stream sourceStream = OpenFile(sourcePath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite, FileOptions.SequentialScan);
Stream destinationStream = OpenFile(destinationPath, FileMode.Open, FileAccess.ReadWrite, FileShare.ReadWrite, FileOptions.None);
while (sourceStream.Position < sourceStream.Length)
{
int readSize = (int)Math.Max(bufferLength, sourceStream.Length - sourceStream.Position);

View file

@ -48,7 +48,7 @@ namespace Utilities
/// <exception cref="System.IO.FileNotFoundException"></exception>
/// <exception cref="System.IO.IOException"></exception>
/// <exception cref="System.UnauthorizedAccessException"></exception>
Stream OpenFile(string path, FileMode mode, FileAccess access, FileShare share);
Stream OpenFile(string path, FileMode mode, FileAccess access, FileShare share, FileOptions options);
/// <exception cref="System.ArgumentException"></exception>
/// <exception cref="System.IO.FileNotFoundException"></exception>