https://git.reactos.org/?p=reactos.git;a=commitdiff;h=dfb776380d4ca289a30b999791040d80bd38e0ae
commit dfb776380d4ca289a30b999791040d80bd38e0ae Author: Ged Murphy <gedmur...@reactos.org> AuthorDate: Tue Nov 21 16:36:29 2017 +0000 [FLTMGR] Latest from my branch (#135) [FLTMGR][KMTEST] Squash and push my local branch across to master as the patch is getting a bit large. This is still WIP and none of this code is run in ros yet, so don't fret if you see ugly/unfinished code or int3's dotted around. [FLTMGR] Improve loading/reg of filters and start to implement client connections - Implement handling of connections from clients - Implement closing of client ports - Add a basic message waiter queue using CSQ's (untested) - Hand off messages for the comms object to be handled by the comms file - Initialize the connection list - Add a registry file which will contain lib functions for accessing filter service entries - [KMTEST] Initial usermode support for testing FS mini-filters - Add base routines to wrap the win32 'Filter' APis - Add support routines to be used when testing FS filter drivers - Move KmtCreateService to a private routine so it can be shared with KmtFltCreateService - Completely untested at the mo, so likely contains bugs at this point - Add support for adding altitude and flags registry entries for minifilters - Allow minifilters to setup without requiring instance attach/detach callbacks - Add tests for FltRegisterFilter and FltUnregisterFilter and start to add associated tests --- drivers/filters/fltmgr/CMakeLists.txt | 1 + drivers/filters/fltmgr/Filter.c | 223 +++++++++- drivers/filters/fltmgr/Interface.c | 22 +- drivers/filters/fltmgr/Messaging.c | 462 ++++++++++++++++++++- drivers/filters/fltmgr/Misc.c | 2 +- drivers/filters/fltmgr/Registry.c | 142 +++++++ drivers/filters/fltmgr/Registry.h | 20 + drivers/filters/fltmgr/fltmgr.h | 7 +- drivers/filters/fltmgr/fltmgrint.h | 180 +++++--- modules/rostests/kmtests/CMakeLists.txt | 4 + modules/rostests/kmtests/fltmgr/CMakeLists.txt | 1 + .../kmtests/fltmgr/fltmgr_load/CMakeLists.txt | 12 +- .../kmtests/fltmgr/fltmgr_load/fltmgr_load.c | 2 +- .../kmtests/fltmgr/fltmgr_load/fltmgr_user.c | 28 ++ .../kmtests/fltmgr/fltmgr_register/CMakeLists.txt | 15 + .../fltmgr/fltmgr_register/fltmgr_reg_user.c | 23 + .../fltmgr/fltmgr_register/fltmgr_register.c | 248 +++++++++++ modules/rostests/kmtests/include/kmt_platform.h | 1 + modules/rostests/kmtests/include/kmt_test.h | 17 +- modules/rostests/kmtests/kmtest/filter.c | 88 +--- modules/rostests/kmtests/kmtest/fltsupport.c | 335 +++++++++++++-- modules/rostests/kmtests/kmtest/kmtest.h | 11 +- modules/rostests/kmtests/kmtest/testlist.c | 4 + .../kmtests/kmtest_drv/kmtest_fsminifilter.c | 121 ++++-- 24 files changed, 1739 insertions(+), 230 deletions(-) diff --git a/drivers/filters/fltmgr/CMakeLists.txt b/drivers/filters/fltmgr/CMakeLists.txt index 5a02cab116..cb6ebbdd18 100644 --- a/drivers/filters/fltmgr/CMakeLists.txt +++ b/drivers/filters/fltmgr/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND SOURCE Messaging.c Misc.c Object.c + Registry.c Volume.c ${CMAKE_CURRENT_BINARY_DIR}/fltmgr.def fltmgr.h) diff --git a/drivers/filters/fltmgr/Filter.c b/drivers/filters/fltmgr/Filter.c index 71e45aea04..92a824c0d3 100644 --- a/drivers/filters/fltmgr/Filter.c +++ b/drivers/filters/fltmgr/Filter.c @@ -10,6 +10,7 @@ #include "fltmgr.h" #include "fltmgrint.h" +#include "Registry.h" #define NDEBUG #include <debug.h> @@ -25,6 +26,17 @@ FltpStartingToDrainObject( _Inout_ PFLT_OBJECT Object ); +VOID +FltpMiniFilterDriverUnload( +); + +static +NTSTATUS +GetFilterAltitude( + _In_ PFLT_FILTER Filter, + _Inout_ PUNICODE_STRING AltitudeString +); + /* EXPORTED FUNCTIONS ******************************************************/ @@ -56,8 +68,26 @@ NTSTATUS NTAPI FltUnloadFilter(_In_ PCUNICODE_STRING FilterName) { - UNREFERENCED_PARAMETER(FilterName); - return STATUS_NOT_IMPLEMENTED; + // + //FIXME: This is a temp hack, it needs properly implementing + // + + UNICODE_STRING DriverServiceName; + UNICODE_STRING ServicesKey; + CHAR Buffer[MAX_KEY_LENGTH]; + + /* Setup the base services key */ + RtlInitUnicodeString(&ServicesKey, SERVICES_KEY); + + /* Initialize the string data */ + DriverServiceName.Length = 0; + DriverServiceName.Buffer = (PWCH)Buffer; + DriverServiceName.MaximumLength = MAX_KEY_LENGTH; + + /* Create the full service key for this filter */ + RtlCopyUnicodeString(&DriverServiceName, &ServicesKey); + RtlAppendUnicodeStringToString(&DriverServiceName, FilterName); + return ZwUnloadDriver(&DriverServiceName); } NTSTATUS @@ -74,6 +104,8 @@ FltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject, PCHAR Ptr; NTSTATUS Status; + *RetFilter = NULL; + /* Make sure we're targeting the correct major revision */ if ((Registration->Version & 0xFF00) != FLT_MAJOR_VERSION) { @@ -145,6 +177,10 @@ FltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject, InitializeListHead(&Filter->ActiveOpens.mList); Filter->ActiveOpens.mCount = 0; + ExInitializeFastMutex(&Filter->ConnectionList.mLock); + InitializeListHead(&Filter->ConnectionList.mList); + Filter->ConnectionList.mCount = 0; + /* Initialize the usermode port list */ ExInitializeFastMutex(&Filter->PortList.mLock); InitializeListHead(&Filter->PortList.mList); @@ -199,15 +235,42 @@ FltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject, Filter->Name.Buffer = (PWCH)Ptr; RtlCopyUnicodeString(&Filter->Name, &DriverObject->DriverExtension->ServiceKeyName); + Status = GetFilterAltitude(Filter, &Filter->DefaultAltitude); + if (!NT_SUCCESS(Status)) + { + goto Quit; + } + // - // - Get the altitude string // - Slot the filter into the correct altitude location // - More stuff?? // + /* Store any existing driver unload routine before we make any changes */ + Filter->OldDriverUnload = (PFLT_FILTER_UNLOAD_CALLBACK)DriverObject->DriverUnload; + + /* Check we opted not to have an unload routine, or if we want to stop the driver from being unloaded */ + if (!FlagOn(Filter->Flags, FLTFL_REGISTRATION_DO_NOT_SUPPORT_SERVICE_STOP)) + { + DriverObject->DriverUnload = (PDRIVER_UNLOAD)FltpMiniFilterDriverUnload; + } + else + { + DriverObject->DriverUnload = (PDRIVER_UNLOAD)NULL; + } + + Quit: - if (!NT_SUCCESS(Status)) + + if (NT_SUCCESS(Status)) + { + DPRINT1("Loaded FS mini-filter %wZ\n", &DriverObject->DriverExtension->ServiceKeyName); + *RetFilter = Filter; + } + else { + DPRINT1("Failed to load FS mini-filter %wZ : 0x%X\n", &DriverObject->DriverExtension->ServiceKeyName, Status); + // Add cleanup for context resources ExDeleteResourceLite(&Filter->InstanceList.rLock); @@ -319,3 +382,155 @@ FltpStartingToDrainObject(_Inout_ PFLT_OBJECT Object) return STATUS_SUCCESS; } + +VOID +FltpMiniFilterDriverUnload() +{ + __debugbreak(); +} + +/* PRIVATE FUNCTIONS ******************************************************/ + +static +NTSTATUS +GetFilterAltitude( + _In_ PFLT_FILTER Filter, + _Inout_ PUNICODE_STRING AltitudeString) +{ + UNICODE_STRING InstancesKey = RTL_CONSTANT_STRING(L"Instances"); + UNICODE_STRING DefaultInstance = RTL_CONSTANT_STRING(L"DefaultInstance"); + UNICODE_STRING Altitude = RTL_CONSTANT_STRING(L"Altitude"); + OBJECT_ATTRIBUTES ObjectAttributes; + UNICODE_STRING FilterInstancePath; + ULONG BytesRequired; + HANDLE InstHandle = NULL; + HANDLE RootHandle; + PWCH InstBuffer = NULL; + PWCH AltBuffer = NULL; + NTSTATUS Status; + + /* Get a handle to the instances key in the filter's services key */ + Status = FltpOpenFilterServicesKey(Filter, + KEY_QUERY_VALUE, + &InstancesKey, + &RootHandle); + if (!NT_SUCCESS(Status)) + { + return Status; + } + + /* Read the size 'default instances' string value */ + Status = FltpReadRegistryValue(RootHandle, + &DefaultInstance, + REG_SZ, + NULL, + 0, + &BytesRequired); + + /* We should get a buffer too small error */ + if (Status == STATUS_BUFFER_TOO_SMALL) + { + /* Allocate the buffer we need to hold the string */ + InstBuffer = ExAllocatePoolWithTag(PagedPool, BytesRequired, FM_TAG_UNICODE_STRING); + if (InstBuffer == NULL) + { + Status = STATUS_INSUFFICIENT_RESOURCES; + goto Quit; + } + + /* Now read the string value */ + Status = FltpReadRegistryValue(RootHandle, + &DefaultInstance, + REG_SZ, + InstBuffer, + BytesRequired, + &BytesRequired); + } + + if (!NT_SUCCESS(Status)) + { + goto Quit; + } + + /* Convert the string to a unicode_string */ + RtlInitUnicodeString(&FilterInstancePath, InstBuffer); + + /* Setup the attributes using the root key handle */ + InitializeObjectAttributes(&ObjectAttributes, + &FilterInstancePath, + OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE, + RootHandle, + NULL); + + /* Now open the key name which was stored in the default instance */ + Status = ZwOpenKey(&InstHandle, KEY_QUERY_VALUE, &ObjectAttributes); + if (NT_SUCCESS(Status)) + { + /* Get the size of the buffer that holds the altitude */ + Status = FltpReadRegistryValue(InstHandle, + &Altitude, + REG_SZ, + NULL, + 0, + &BytesRequired); + if (Status == STATUS_BUFFER_TOO_SMALL) + { + /* Allocate the required buffer */ + AltBuffer = ExAllocatePoolWithTag(PagedPool, BytesRequired, FM_TAG_UNICODE_STRING); + if (AltBuffer == NULL) + { + Status = STATUS_INSUFFICIENT_RESOURCES; + goto Quit; + } + + /* And now finally read in the actual altitude string */ + Status = FltpReadRegistryValue(InstHandle, + &Altitude, + REG_SZ, + AltBuffer, + BytesRequired, + &BytesRequired); + if (NT_SUCCESS(Status)) + { + /* We made it, setup the return buffer */ + AltitudeString->Length = BytesRequired; + AltitudeString->MaximumLength = BytesRequired; + AltitudeString->Buffer = AltBuffer; + } + } + } + +Quit: + if (!NT_SUCCESS(Status)) + { + if (AltBuffer) + { + ExFreePoolWithTag(AltBuffer, FM_TAG_UNICODE_STRING); + } + } + + if (InstBuffer) + { + ExFreePoolWithTag(InstBuffer, FM_TAG_UNICODE_STRING); + } + + if (InstHandle) + { + ZwClose(InstHandle); + } + ZwClose(RootHandle); + + return Status; +} + + + +NTSTATUS +FltpReadRegistryValue( + _In_ HANDLE KeyHandle, + _In_ PUNICODE_STRING ValueName, + _In_opt_ ULONG Type, + _Out_writes_bytes_(BufferSize) PVOID Buffer, + _In_ ULONG BufferSize, + _Out_opt_ PULONG BytesRequired +); \ No newline at end of file diff --git a/drivers/filters/fltmgr/Interface.c b/drivers/filters/fltmgr/Interface.c index be1288a7c2..158303be92 100644 --- a/drivers/filters/fltmgr/Interface.c +++ b/drivers/filters/fltmgr/Interface.c @@ -28,6 +28,8 @@ ((_devObj)->DriverObject == Dispatcher::DriverObject) && \ ((_devObj)->DeviceExtension != NULL)) +extern PDEVICE_OBJECT CommsDeviceObject; + DRIVER_INITIALIZE DriverEntry; NTSTATUS @@ -454,6 +456,13 @@ FltpDispatch(_In_ PDEVICE_OBJECT DeviceObject, return Status; } + /* Check if this is a request for a the messaging device */ + if (DeviceObject == CommsDeviceObject) + { + /* Hand off to our internal routine */ + return FltpMsgDispatch(DeviceObject, Irp); + } + FLT_ASSERT(DeviceExtension && DeviceExtension->AttachedToDeviceObject); @@ -494,6 +503,13 @@ FltpCreate(_In_ PDEVICE_OBJECT DeviceObject, return STATUS_SUCCESS; } + /* Check if this is a request for a the new comms connection */ + if (DeviceObject == CommsDeviceObject) + { + /* Hand off to our internal routine */ + return FltpMsgCreate(DeviceObject, Irp); + } + FLT_ASSERT(DeviceExtension && DeviceExtension->AttachedToDeviceObject); @@ -2104,9 +2120,9 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, Status = SetupDispatchAndCallbacksTables(DriverObject); if (!NT_SUCCESS(Status)) goto Cleanup; - // - // TODO: Create fltmgr message device - // + /* Initialize the comms objects */ + Status = FltpSetupCommunicationObjects(DriverObject); + if (!NT_SUCCESS(Status)) goto Cleanup; /* Register for notifications when a new file system is loaded. This also enumerates any existing file systems */ Status = IoRegisterFsRegistrationChange(DriverObject, FltpFsNotification); diff --git a/drivers/filters/fltmgr/Messaging.c b/drivers/filters/fltmgr/Messaging.c index 97e6a94c9c..6a40e0d8c1 100644 --- a/drivers/filters/fltmgr/Messaging.c +++ b/drivers/filters/fltmgr/Messaging.c @@ -10,6 +10,7 @@ #include "fltmgr.h" #include "fltmgrint.h" +#include <fltmgr_shared.h> #define NDEBUG #include <debug.h> @@ -29,6 +30,32 @@ FltpDisconnectPort( _In_ PFLT_PORT_OBJECT PortObject ); +static +NTSTATUS +CreateClientPort( + _In_ PFILE_OBJECT FileObject, + _Inout_ PIRP Irp +); + +static +NTSTATUS +CloseClientPort( + _In_ PFILE_OBJECT FileObject, + _Inout_ PIRP Irp +); + +static +NTSTATUS +InitializeMessageWaiterQueue( + _Inout_ PFLT_MESSAGE_WAITER_QUEUE MsgWaiterQueue +); + +static +PPORT_CCB +CreatePortCCB( + _In_ PFLT_PORT_OBJECT PortObject +); + /* EXPORTED FUNCTIONS ******************************************************/ @@ -72,8 +99,8 @@ FltCreateCommunicationPort(_In_ PFLT_FILTER Filter, return Status; } - /* Create our new server port object */ - Status = ObCreateObject(0, + /* Create the server port object for this filter */ + Status = ObCreateObject(KernelMode, ServerPortObjectType, ObjectAttributes, KernelMode, @@ -191,6 +218,72 @@ FltSendMessage(_In_ PFLT_FILTER Filter, /* INTERNAL FUNCTIONS ******************************************************/ + +NTSTATUS +FltpMsgCreate(_In_ PDEVICE_OBJECT DeviceObject, + _Inout_ PIRP Irp) +{ + PIO_STACK_LOCATION StackPtr; + NTSTATUS Status; + + /* Get the stack location */ + StackPtr = IoGetCurrentIrpStackLocation(Irp); + + FLT_ASSERT(StackPtr->MajorFunction == IRP_MJ_CREATE); + + /* Check if this is a caller wanting to connect */ + if (StackPtr->MajorFunction == IRP_MJ_CREATE) + { + /* Create the client port for this connection and exit */ + Status = CreateClientPort(StackPtr->FileObject, Irp); + } + else + { + Status = STATUS_INVALID_PARAMETER; + } + + if (Status != STATUS_PENDING) + { + Irp->IoStatus.Status = Status; + Irp->IoStatus.Information = 0; + IoCompleteRequest(Irp, 0); + } + + return Status; +} + +NTSTATUS +FltpMsgDispatch(_In_ PDEVICE_OBJECT DeviceObject, + _Inout_ PIRP Irp) +{ + PIO_STACK_LOCATION StackPtr; + NTSTATUS Status; + + /* Get the stack location */ + StackPtr = IoGetCurrentIrpStackLocation(Irp); + + /* Check if this is a caller wanting to connect */ + if (StackPtr->MajorFunction == IRP_MJ_CLOSE) + { + /* Create the client port for this connection and exit */ + Status = CloseClientPort(StackPtr->FileObject, Irp); + } + else + { + // We don't support anything else yet + Status = STATUS_NOT_IMPLEMENTED; + } + + if (Status != STATUS_PENDING) + { + Irp->IoStatus.Status = Status; + Irp->IoStatus.Information = 0; + IoCompleteRequest(Irp, 0); + } + + return Status; +} + VOID NTAPI FltpServerPortClose(_In_opt_ PEPROCESS Process, @@ -369,5 +462,370 @@ Quit: return Status; } +/* CSQ IRP CALLBACKS *******************************************************/ + + +NTSTATUS +NTAPI +FltpAddMessageWaiter(_In_ PIO_CSQ Csq, + _In_ PIRP Irp, + _In_ PVOID InsertContext) +{ + PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue; + + /* Get the start of the waiter queue struct */ + MessageWaiterQueue = CONTAINING_RECORD(Csq, + FLT_MESSAGE_WAITER_QUEUE, + Csq); + + /* Insert the IRP at the end of the queue */ + InsertTailList(&MessageWaiterQueue->WaiterQ.mList, + &Irp->Tail.Overlay.ListEntry); + + /* return success */ + return STATUS_SUCCESS; +} + +VOID +NTAPI +FltpRemoveMessageWaiter(_In_ PIO_CSQ Csq, + _In_ PIRP Irp) +{ + /* Remove the IRP from the queue */ + RemoveEntryList(&Irp->Tail.Overlay.ListEntry); +} + +PIRP +NTAPI +FltpGetNextMessageWaiter(_In_ PIO_CSQ Csq, + _In_ PIRP Irp, + _In_ PVOID PeekContext) +{ + PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue; + PIRP NextIrp = NULL; + PLIST_ENTRY NextEntry; + PIO_STACK_LOCATION IrpStack; + + /* Get the start of the waiter queue struct */ + MessageWaiterQueue = CONTAINING_RECORD(Csq, + FLT_MESSAGE_WAITER_QUEUE, + Csq); + + /* Is the IRP valid? */ + if (Irp == NULL) + { + /* Start peeking from the listhead */ + NextEntry = MessageWaiterQueue->WaiterQ.mList.Flink; + } + else + { + /* Start peeking from that IRP onwards */ + NextEntry = Irp->Tail.Overlay.ListEntry.Flink; + } + + /* Loop through the queue */ + while (NextEntry != &MessageWaiterQueue->WaiterQ.mList) + { + /* Store the next IRP in the list */ + NextIrp = CONTAINING_RECORD(NextEntry, IRP, Tail.Overlay.ListEntry); + + /* Did we supply a PeekContext on insert? */ + if (!PeekContext) + { + /* We already have the next IRP */ + break; + } + else + { + /* Get the stack of the next IRP */ + IrpStack = IoGetCurrentIrpStackLocation(NextIrp); + + /* Does the PeekContext match the object? */ + if (IrpStack->FileObject == (PFILE_OBJECT)PeekContext) + { + /* We have a match */ + break; + } + + /* Move to the next IRP */ + NextIrp = NULL; + NextEntry = NextEntry->Flink; + } + } + + return NextIrp; +} + +_Acquires_lock_(((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, FLT_MESSAGE_WAITER_QUEUE, Csq))->WaiterQ.mLock) +_IRQL_saves_global_(Irql, ((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock) +_IRQL_raises_(DISPATCH_LEVEL) +VOID +NTAPI +FltpAcquireMessageWaiterLock(_In_ PIO_CSQ Csq, + _Out_ PKIRQL Irql) +{ + PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue; + + UNREFERENCED_PARAMETER(Irql); + + /* Get the start of the waiter queue struct */ + MessageWaiterQueue = CONTAINING_RECORD(Csq, + FLT_MESSAGE_WAITER_QUEUE, + Csq); + + /* Acquire the IRP queue lock */ + ExAcquireFastMutex(&MessageWaiterQueue->WaiterQ.mLock); +} + +_Releases_lock_(((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock) +_IRQL_restores_global_(Irql, ((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock) +_IRQL_requires_(DISPATCH_LEVEL) +VOID +NTAPI +FltpReleaseMessageWaiterLock(_In_ PIO_CSQ Csq, + _In_ KIRQL Irql) +{ + PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue; + + UNREFERENCED_PARAMETER(Irql); + + /* Get the start of the waiter queue struct */ + MessageWaiterQueue = CONTAINING_RECORD(Csq, + FLT_MESSAGE_WAITER_QUEUE, + Csq); + + /* Release the IRP queue lock */ + ExReleaseFastMutex(&MessageWaiterQueue->WaiterQ.mLock); +} + +VOID +NTAPI +FltpCancelMessageWaiter(_In_ PIO_CSQ Csq, + _In_ PIRP Irp) +{ + /* Cancel the IRP */ + Irp->IoStatus.Status = STATUS_CANCELLED; + Irp->IoStatus.Information = 0; + IoCompleteRequest(Irp, IO_NO_INCREMENT); +} + + /* PRIVATE FUNCTIONS ******************************************************/ +static +NTSTATUS +CreateClientPort(_In_ PFILE_OBJECT FileObject, + _Inout_ PIRP Irp) +{ + PFLT_SERVER_PORT_OBJECT ServerPortObject = NULL; + OBJECT_ATTRIBUTES ObjectAttributes; + PFILTER_PORT_DATA FilterPortData; + PFLT_PORT_OBJECT ClientPortObject = NULL; + PFLT_PORT PortHandle = NULL; + PPORT_CCB PortCCB = NULL; + //ULONG BufferLength; + LONG NumConns; + NTSTATUS Status; + + /* We received the buffer via FilterConnectCommunicationPort, cast it back to its original form */ + FilterPortData = Irp->AssociatedIrp.SystemBuffer; + + /* Get a reference to the server port the filter created */ + Status = ObReferenceObjectByName(&FilterPortData->PortName, + 0, + 0, + FLT_PORT_ALL_ACCESS, + ServerPortObjectType, + ExGetPreviousMode(), + 0, + (PVOID *)&ServerPortObject); + if (!NT_SUCCESS(Status)) + { + return Status; + } + + /* Increment the number of connections on the server port */ + NumConns = InterlockedIncrement(&ServerPortObject->NumberOfConnections); + if (NumConns > ServerPortObject->MaxConnections) + { + Status = STATUS_CONNECTION_COUNT_LIMIT; + goto Quit; + } + + /* Initialize a basic kernel handle request */ + InitializeObjectAttributes(&ObjectAttributes, + NULL, + OBJ_KERNEL_HANDLE, + NULL, + NULL); + + /* Now create the new client port object */ + Status = ObCreateObject(KernelMode, + ClientPortObjectType, + &ObjectAttributes, + KernelMode, + NULL, + sizeof(FLT_PORT_OBJECT), + 0, + 0, + (PVOID *)&ClientPortObject); + if (!NT_SUCCESS(Status)) + { + goto Quit; + } + + /* Clear out the buffer */ + RtlZeroMemory(ClientPortObject, sizeof(FLT_PORT_OBJECT)); + + /* Initialize the locks */ + ExInitializeRundownProtection(&ClientPortObject->MsgNotifRundownRef); + ExInitializeFastMutex(&ClientPortObject->Lock); + + /* Set the server port object this belongs to */ + ClientPortObject->ServerPort = ServerPortObject; + + /* Setup the message queue */ + Status = InitializeMessageWaiterQueue(&ClientPortObject->MsgQ); + if (!NT_SUCCESS(Status)) + { + goto Quit; + } + + /* Create the CCB which we'll attach to the file object */ + PortCCB = CreatePortCCB(ClientPortObject); + if (PortCCB == NULL) + { + Status = STATUS_INSUFFICIENT_RESOURCES; + goto Quit; + } + + /* Now insert the new client port into the object manager*/ + Status = ObInsertObject(ClientPortObject, 0, FLT_PORT_ALL_ACCESS, 1, 0, (PHANDLE)&PortHandle); + if (!NT_SUCCESS(Status)) + { + goto Quit; + } + + /* Add a reference to the filter to keep it alive while we do some work with it */ + Status = FltObjectReference(ServerPortObject->Filter); + if (NT_SUCCESS(Status)) + { + /* Invoke the callback to let the filter know we have a connection */ + Status = ServerPortObject->ConnectNotify(PortHandle, + ServerPortObject->Cookie, + NULL, //ConnectionContext + 0, //SizeOfContext + &ClientPortObject->Cookie); + if (NT_SUCCESS(Status)) + { + /* Add the client port CCB to the file object */ + FileObject->FsContext2 = PortCCB; + + /* Lock the port list on the filter and add this new port object to the list */ + ExAcquireFastMutex(&ServerPortObject->Filter->PortList.mLock); + InsertTailList(&ServerPortObject->Filter->PortList.mList, &ClientPortObject->FilterLink); + ExReleaseFastMutex(&ServerPortObject->Filter->PortList.mLock); + } + + /* We're done with the filter object, decremement the count */ + FltObjectDereference(ServerPortObject->Filter); + } + + +Quit: + if (!NT_SUCCESS(Status)) + { + if (ClientPortObject) + { + ObfDereferenceObject(ClientPortObject); + } + + if (PortHandle) + { + ZwClose(PortHandle); + } + else if (ServerPortObject) + { + InterlockedDecrement(&ServerPortObject->NumberOfConnections); + ObfDereferenceObject(ServerPortObject); + } + + if (PortCCB) + { + ExFreePoolWithTag(PortCCB, FM_TAG_CCB); + } + } + + return Status; +} + +static +NTSTATUS +CloseClientPort(_In_ PFILE_OBJECT FileObject, + _Inout_ PIRP Irp) +{ + PFLT_CCB Ccb; + + Ccb = (PFLT_CCB)FileObject->FsContext2; + + /* Remove the reference on the filter we added when we opened the port */ + ObDereferenceObject(Ccb->Data.Port.Port); + + // FIXME: Free the CCB + + return STATUS_SUCCESS; +} + +static +NTSTATUS +InitializeMessageWaiterQueue(_Inout_ PFLT_MESSAGE_WAITER_QUEUE MsgWaiterQueue) +{ + NTSTATUS Status; + + /* Setup the IRP queue */ + Status = IoCsqInitializeEx(&MsgWaiterQueue->Csq, + FltpAddMessageWaiter, + FltpRemoveMessageWaiter, + FltpGetNextMessageWaiter, + FltpAcquireMessageWaiterLock, + FltpReleaseMessageWaiterLock, + FltpCancelMessageWaiter); + if (!NT_SUCCESS(Status)) + { + return Status; + } + + /* Initialize the waiter queue */ + ExInitializeFastMutex(&MsgWaiterQueue->WaiterQ.mLock); + InitializeListHead(&MsgWaiterQueue->WaiterQ.mList); + MsgWaiterQueue->WaiterQ.mCount = 0; + + /* We don't have a minimum waiter length */ + MsgWaiterQueue->MinimumWaiterLength = (ULONG)-1; + + /* Init the semaphore and event used for counting and signaling available IRPs */ + KeInitializeSemaphore(&MsgWaiterQueue->Semaphore, 0, MAXLONG); + KeInitializeEvent(&MsgWaiterQueue->Event, NotificationEvent, FALSE); + + return STATUS_SUCCESS; +} + +static +PPORT_CCB +CreatePortCCB(_In_ PFLT_PORT_OBJECT PortObject) +{ + PPORT_CCB PortCCB; + + /* Allocate a CCB struct to hold the client port object info */ + PortCCB = ExAllocatePoolWithTag(NonPagedPool, sizeof(PPORT_CCB), FM_TAG_CCB); + if (PortCCB) + { + /* Initialize the structure */ + PortCCB->Port = PortObject; + PortCCB->ReplyWaiterList.mCount = 0; + ExInitializeFastMutex(&PortCCB->ReplyWaiterList.mLock); + KeInitializeEvent(&PortCCB->ReplyWaiterList.mLock.Event, SynchronizationEvent, 0); + } + + return PortCCB; +} \ No newline at end of file diff --git a/drivers/filters/fltmgr/Misc.c b/drivers/filters/fltmgr/Misc.c index 3ef1424d5f..8643c48882 100644 --- a/drivers/filters/fltmgr/Misc.c +++ b/drivers/filters/fltmgr/Misc.c @@ -29,8 +29,8 @@ FltBuildDefaultSecurityDescriptor( _In_ ACCESS_MASK DesiredAccess ) { - UNREFERENCED_PARAMETER(SecurityDescriptor); UNREFERENCED_PARAMETER(DesiredAccess); + *SecurityDescriptor = NULL; return 0; } diff --git a/drivers/filters/fltmgr/Registry.c b/drivers/filters/fltmgr/Registry.c new file mode 100644 index 0000000000..2470d70442 --- /dev/null +++ b/drivers/filters/fltmgr/Registry.c @@ -0,0 +1,142 @@ +/* + * PROJECT: Filesystem Filter Manager + * LICENSE: GPL - See COPYING in the top level directory + * FILE: drivers/filters/fltmgr/Misc.c + * PURPOSE: Uncataloged functions + * PROGRAMMERS: Ged Murphy (gedmur...@reactos.org) + */ + +/* INCLUDES ******************************************************************/ + +#include "fltmgr.h" +#include "fltmgrint.h" + +#define NDEBUG +#include <debug.h> + + +/* DATA *********************************************************************/ + +#define REG_SERVICES_KEY L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\" +#define REG_PATH_LENGTH 512 + + +/* INTERNAL FUNCTIONS ******************************************************/ + + +NTSTATUS +FltpOpenFilterServicesKey( + _In_ PFLT_FILTER Filter, + _In_ ACCESS_MASK DesiredAccess, + _In_opt_ PUNICODE_STRING SubKey, + _Out_ PHANDLE Handle) +{ + OBJECT_ATTRIBUTES ObjectAttributes; + UNICODE_STRING ServicesKey; + UNICODE_STRING Path; + WCHAR Buffer[REG_PATH_LENGTH]; + + /* Setup a local buffer to hold the services key path */ + Path.Length = 0; + Path.MaximumLength = REG_PATH_LENGTH; + Path.Buffer = Buffer; + + /* Build up the serices key name */ + RtlInitUnicodeString(&ServicesKey, REG_SERVICES_KEY); + RtlCopyUnicodeString(&Path, &ServicesKey); + RtlAppendUnicodeStringToString(&Path, &Filter->Name); + + if (SubKey) + { + /* Tag on any child key */ + RtlAppendUnicodeToString(&Path, L"\\"); + RtlAppendUnicodeStringToString(&Path, SubKey); + } + + InitializeObjectAttributes(&ObjectAttributes, + &Path, + OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE, + NULL, + NULL); + + /* Open and return the key handle param*/ + return ZwOpenKey(Handle, DesiredAccess, &ObjectAttributes); +} + +NTSTATUS +FltpReadRegistryValue(_In_ HANDLE KeyHandle, + _In_ PUNICODE_STRING ValueName, + _In_opt_ ULONG Type, + _Out_writes_bytes_(BufferSize) PVOID Buffer, + _In_ ULONG BufferSize, + _Out_opt_ PULONG BytesRequired) +{ + PKEY_VALUE_PARTIAL_INFORMATION Value = NULL; + ULONG ValueLength = 0; + NTSTATUS Status; + + PAGED_CODE(); + + /* Get the size of the buffer required to hold the string */ + Status = ZwQueryValueKey(KeyHandle, + ValueName, + KeyValuePartialInformation, + NULL, + 0, + &ValueLength); + if (Status != STATUS_BUFFER_TOO_SMALL && Status != STATUS_BUFFER_OVERFLOW) + { + return Status; + } + + /* Allocate the buffer */ + Value = (PKEY_VALUE_PARTIAL_INFORMATION)ExAllocatePoolWithTag(PagedPool, + ValueLength, + FM_TAG_TEMP_REGISTRY); + if (Value == NULL) + { + Status = STATUS_INSUFFICIENT_RESOURCES; + goto Quit; + } + + /* Now read in the value */ + Status = ZwQueryValueKey(KeyHandle, + ValueName, + KeyValuePartialInformation, + Value, + ValueLength, + &ValueLength); + if (!NT_SUCCESS(Status)) + { + goto Quit; + } + + /* Make sure we got the type expected */ + if (Value->Type != Type) + { + Status = STATUS_INVALID_PARAMETER; + goto Quit; + } + + if (BytesRequired) + { + *BytesRequired = Value->DataLength; + } + + /* Make sure the caller buffer is big enough to hold the data */ + if (!BufferSize || BufferSize < Value->DataLength) + { + Status = STATUS_BUFFER_TOO_SMALL; + goto Quit; + } + + /* Copy the data into the caller buffer */ + RtlCopyMemory(Buffer, Value->Data, Value->DataLength); + +Quit: + + if (Value) + ExFreePoolWithTag(Value, FM_TAG_TEMP_REGISTRY); + + return Status; +} diff --git a/drivers/filters/fltmgr/Registry.h b/drivers/filters/fltmgr/Registry.h new file mode 100644 index 0000000000..b295daefac --- /dev/null +++ b/drivers/filters/fltmgr/Registry.h @@ -0,0 +1,20 @@ +#pragma once + + +NTSTATUS +FltpOpenFilterServicesKey( + _In_ PFLT_FILTER Filter, + _In_ ACCESS_MASK DesiredAccess, + _In_opt_ PUNICODE_STRING SubKey, + _Out_ PHANDLE Handle +); + +NTSTATUS +FltpReadRegistryValue( + _In_ HANDLE KeyHandle, + _In_ PUNICODE_STRING ValueName, + _In_opt_ ULONG Type, + _Out_writes_bytes_(BufferSize) PVOID Buffer, + _In_ ULONG BufferSize, + _Out_opt_ PULONG BytesRequired +); \ No newline at end of file diff --git a/drivers/filters/fltmgr/fltmgr.h b/drivers/filters/fltmgr/fltmgr.h index 71e070b5f3..7008c1d094 100644 --- a/drivers/filters/fltmgr/fltmgr.h +++ b/drivers/filters/fltmgr/fltmgr.h @@ -18,6 +18,8 @@ #define FM_TAG_UNICODE_STRING 'suMF' #define FM_TAG_FILTER 'lfMF' #define FM_TAG_CONTEXT_REGISTA 'rcMF' +#define FM_TAG_CCB 'bcMF' +#define FM_TAG_TEMP_REGISTRY 'rtMF' #define MAX_DEVNAME_LENGTH 64 @@ -100,6 +102,9 @@ FltpReallocateUnicodeString(_In_ PUNICODE_STRING String, VOID FltpFreeUnicodeString(_In_ PUNICODE_STRING String); + + + //////////////////////////////////////////////// @@ -252,7 +257,7 @@ FltGetUpperInstance FltGetVolumeContext FltGetVolumeFromDeviceObject FltGetVolumeFromFileObject -FltGetVolumeFromInstance +FltLoadFilter FltGetVolumeFromName FltGetVolumeGuidName FltGetVolumeInstanceFromName diff --git a/drivers/filters/fltmgr/fltmgrint.h b/drivers/filters/fltmgr/fltmgrint.h index a060b70eec..64065be763 100644 --- a/drivers/filters/fltmgr/fltmgrint.h +++ b/drivers/filters/fltmgr/fltmgrint.h @@ -58,10 +58,45 @@ typedef struct _FLT_MUTEX_LIST_HEAD } FLT_MUTEX_LIST_HEAD, *PFLT_MUTEX_LIST_HEAD; +typedef struct _FLT_TYPE +{ + USHORT Signature; + USHORT Size; + +} FLT_TYPE, *PFLT_TYPE; + +// http://fsfilters.blogspot.co.uk/2010/02/filter-manager-concepts-part-1.html +typedef struct _FLTP_FRAME +{ + FLT_TYPE Type; + LIST_ENTRY Links; + unsigned int FrameID; + ERESOURCE AltitudeLock; + UNICODE_STRING AltitudeIntervalLow; + UNICODE_STRING AltitudeIntervalHigh; + char LargeIrpCtrlStackSize; + char SmallIrpCtrlStackSize; + FLT_RESOURCE_LIST_HEAD RegisteredFilters; + FLT_RESOURCE_LIST_HEAD AttachedVolumes; + LIST_ENTRY MountingVolumes; + FLT_MUTEX_LIST_HEAD AttachedFileSystems; + FLT_MUTEX_LIST_HEAD ZombiedFltObjectContexts; + ERESOURCE FilterUnloadLock; + FAST_MUTEX DeviceObjectAttachLock; + //FLT_PRCB *Prcb; + void *PrcbPoolToFree; + void *LookasidePoolToFree; + //FLTP_IRPCTRL_STACK_PROFILER IrpCtrlStackProfiler; + NPAGED_LOOKASIDE_LIST SmallIrpCtrlLookasideList; + NPAGED_LOOKASIDE_LIST LargeIrpCtrlLookasideList; + //STATIC_IRP_CONTROL GlobalSIC; + +} FLTP_FRAME, *PFLTP_FRAME; + typedef struct _FLT_FILTER // size = 0x120 { FLT_OBJECT Base; - PVOID Frame; //FLTP_FRAME + PFLTP_FRAME Frame; UNICODE_STRING Name; UNICODE_STRING DefaultAltitude; FLT_FILTER_FLAGS Flags; @@ -97,12 +132,7 @@ typedef enum _FLT_yINSTANCE_FLAGS } FLT_INSTANCE_FLAGS, *PFLT_INSTANCE_FLAGS; -typedef struct _FLT_TYPE -{ - USHORT Signature; - USHORT Size; -} FLT_TYPE, *PFLT_TYPE; typedef struct _FLT_INSTANCE // size = 0x144 (324) { @@ -121,34 +151,18 @@ typedef struct _FLT_INSTANCE // size = 0x144 (324) } FLT_INSTANCE, *PFLT_INSTANCE; -// http://fsfilters.blogspot.co.uk/2010/02/filter-manager-concepts-part-1.html -typedef struct _FLTP_FRAME + +typedef struct _TREE_ROOT { - FLT_TYPE Type; - LIST_ENTRY Links; - unsigned int FrameID; - ERESOURCE AltitudeLock; - UNICODE_STRING AltitudeIntervalLow; - UNICODE_STRING AltitudeIntervalHigh; - char LargeIrpCtrlStackSize; - char SmallIrpCtrlStackSize; - FLT_RESOURCE_LIST_HEAD RegisteredFilters; - FLT_RESOURCE_LIST_HEAD AttachedVolumes; - LIST_ENTRY MountingVolumes; - FLT_MUTEX_LIST_HEAD AttachedFileSystems; - FLT_MUTEX_LIST_HEAD ZombiedFltObjectContexts; - ERESOURCE FilterUnloadLock; - FAST_MUTEX DeviceObjectAttachLock; - //FLT_PRCB *Prcb; - void *PrcbPoolToFree; - void *LookasidePoolToFree; - //FLTP_IRPCTRL_STACK_PROFILER IrpCtrlStackProfiler; - NPAGED_LOOKASIDE_LIST SmallIrpCtrlLookasideList; - NPAGED_LOOKASIDE_LIST LargeIrpCtrlLookasideList; - //STATIC_IRP_CONTROL GlobalSIC; + RTL_SPLAY_LINKS *Tree; -} FLTP_FRAME, *PFLTP_FRAME; +} TREE_ROOT, *PTREE_ROOT; + +typedef struct _CONTEXT_LIST_CTRL +{ + TREE_ROOT List; +} CONTEXT_LIST_CTRL, *PCONTEXT_LIST_CTRL; // http://fsfilters.blogspot.co.uk/2010/02/filter-manager-concepts-part-6.html typedef struct _STREAM_LIST_CTRL // size = 0xC8 (200) @@ -156,16 +170,16 @@ typedef struct _STREAM_LIST_CTRL // size = 0xC8 (200) FLT_TYPE Type; FSRTL_PER_STREAM_CONTEXT ContextCtrl; LIST_ENTRY VolumeLink; - //STREAM_LIST_CTRL_FLAGS Flags; + ULONG Flags; //STREAM_LIST_CTRL_FLAGS Flags; int UseCount; ERESOURCE ContextLock; - //CONTEXT_LIST_CTRL StreamContexts; - //CONTEXT_LIST_CTRL StreamHandleContexts; + CONTEXT_LIST_CTRL StreamContexts; + CONTEXT_LIST_CTRL StreamHandleContexts; ERESOURCE NameCacheLock; LARGE_INTEGER LastRenameCompleted; - //NAME_CACHE_LIST_CTRL NormalizedNameCache; - // NAME_CACHE_LIST_CTRL ShortNameCache; - // NAME_CACHE_LIST_CTRL OpenedNameCache; + ULONG NormalizedNameCache; //NAME_CACHE_LIST_CTRL NormalizedNameCache; + ULONG ShortNameCache; // NAME_CACHE_LIST_CTRL ShortNameCache; + ULONG OpenedNameCache; // NAME_CACHE_LIST_CTRL OpenedNameCache; int AllNameContextsTemporary; } STREAM_LIST_CTRL, *PSTREAM_LIST_CTRL; @@ -186,6 +200,17 @@ typedef struct _FLT_SERVER_PORT_OBJECT } FLT_SERVER_PORT_OBJECT, *PFLT_SERVER_PORT_OBJECT; +typedef struct _FLT_MESSAGE_WAITER_QUEUE +{ + IO_CSQ Csq; + FLT_MUTEX_LIST_HEAD WaiterQ; + ULONG MinimumWaiterLength; + KSEMAPHORE Semaphore; + KEVENT Event; + +} FLT_MESSAGE_WAITER_QUEUE, *PFLT_MESSAGE_WAITER_QUEUE; + + typedef struct _FLT_PORT_OBJECT { LIST_ENTRY FilterLink; @@ -193,7 +218,7 @@ typedef struct _FLT_PORT_OBJECT PVOID Cookie; EX_RUNDOWN_REF MsgNotifRundownRef; FAST_MUTEX Lock; - PVOID MsgQ; // FLT_MESSAGE_WAITER_QUEUE MsgQ; + FLT_MESSAGE_WAITER_QUEUE MsgQ; ULONGLONG MessageId; KEVENT DisconnectEvent; BOOLEAN Disconnected; @@ -232,18 +257,6 @@ typedef struct _CALLBACK_CTRL } CALLBACK_CTRL, *PCALLBACK_CTRL; -typedef struct _TREE_ROOT -{ - RTL_SPLAY_LINKS *Tree; - -} TREE_ROOT, *PTREE_ROOT; - - -typedef struct _CONTEXT_LIST_CTRL -{ - TREE_ROOT List; - -} CONTEXT_LIST_CTRL, *PCONTEXT_LIST_CTRL; typedef struct _NAME_CACHE_LIST_CTRL_STATS { @@ -311,6 +324,58 @@ typedef struct _FLT_VOLUME } FLT_VOLUME, *PFLT_VOLUME; +typedef struct _MANAGER_CCB +{ + PFLTP_FRAME Frame; + unsigned int Iterator; + +} MANAGER_CCB, *PMANAGER_CCB; + +typedef struct _FILTER_CCB +{ + PFLT_FILTER Filter; + unsigned int Iterator; + +} FILTER_CCB, *PFILTER_CCB; + +typedef struct _INSTANCE_CCB +{ + PFLT_INSTANCE Instance; + +} INSTANCE_CCB, *PINSTANCE_CCB; + +typedef struct _VOLUME_CCB +{ + UNICODE_STRING Volume; + unsigned int Iterator; + +} VOLUME_CCB, *PVOLUME_CCB; + +typedef struct _PORT_CCB +{ + PFLT_PORT_OBJECT Port; + FLT_MUTEX_LIST_HEAD ReplyWaiterList; + +} PORT_CCB, *PPORT_CCB; + + +typedef union _CCB_TYPE +{ + MANAGER_CCB Manager; + FILTER_CCB Filter; + INSTANCE_CCB Instance; + VOLUME_CCB Volume; + PORT_CCB Port; + +} CCB_TYPE, *PCCB_TYPE; + + +typedef struct _FLT_CCB +{ + FLT_TYPE Type; + CCB_TYPE Data; + +} FLT_CCB, *PFLT_CCB; VOID FltpExInitializeRundownProtection( @@ -387,6 +452,21 @@ FltpDispatchHandler( _Inout_ PIRP Irp ); +NTSTATUS +FltpMsgCreate( + _In_ PDEVICE_OBJECT DeviceObject, + _Inout_ PIRP Irp +); +NTSTATUS +FltpMsgDispatch( + _In_ PDEVICE_OBJECT DeviceObject, + _Inout_ PIRP Irp +); + +NTSTATUS +FltpSetupCommunicationObjects( + _In_ PDRIVER_OBJECT DriverObject +); #endif /* _FLTMGR_INTERNAL_H */ diff --git a/modules/rostests/kmtests/CMakeLists.txt b/modules/rostests/kmtests/CMakeLists.txt index 64b088d221..c66ed9a38e 100644 --- a/modules/rostests/kmtests/CMakeLists.txt +++ b/modules/rostests/kmtests/CMakeLists.txt @@ -127,6 +127,10 @@ list(APPEND KMTEST_SOURCE kmtest/testlist.c example/Example_user.c + + fltmgr/fltmgr_load/fltmgr_user.c + fltmgr/fltmgr_register/fltmgr_reg_user.c + hidparse/HidP_user.c kernel32/FileAttributes_user.c kernel32/FindFile_user.c diff --git a/modules/rostests/kmtests/fltmgr/CMakeLists.txt b/modules/rostests/kmtests/fltmgr/CMakeLists.txt index fd7a55e6bc..f0fdac80bc 100644 --- a/modules/rostests/kmtests/fltmgr/CMakeLists.txt +++ b/modules/rostests/kmtests/fltmgr/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(fltmgr_load) add_subdirectory(fltmgr_create) +add_subdirectory(fltmgr_register) diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_load/CMakeLists.txt b/modules/rostests/kmtests/fltmgr/fltmgr_load/CMakeLists.txt index 7fa022d154..f93e516be8 100644 --- a/modules/rostests/kmtests/fltmgr/fltmgr_load/CMakeLists.txt +++ b/modules/rostests/kmtests/fltmgr/fltmgr_load/CMakeLists.txt @@ -5,10 +5,10 @@ list(APPEND FLTMGR_TEST_DRV_SOURCE ../../kmtest_drv/kmtest_fsminifilter.c fltmgr_load.c) -add_library(fltmgr_load SHARED ${FLTMGR_TEST_DRV_SOURCE}) -set_module_type(fltmgr_load kernelmodedriver) -target_link_libraries(fltmgr_load kmtest_printf ${PSEH_LIB}) -add_importlibs(fltmgr_load fltmgr ntoskrnl hal) -add_target_compile_definitions(fltmgr_load KMT_STANDALONE_DRIVER KMT_FILTER_DRIVER NTDDI_VERSION=NTDDI_WS03SP1) +add_library(FltMgrLoad_drv SHARED ${FLTMGR_TEST_DRV_SOURCE}) +set_module_type(FltMgrLoad_drv kernelmodedriver) +target_link_libraries(FltMgrLoad_drv kmtest_printf ${PSEH_LIB}) +add_importlibs(FltMgrLoad_drv fltmgr ntoskrnl hal) +add_target_compile_definitions(FltMgrLoad_drv KMT_STANDALONE_DRIVER KMT_FILTER_DRIVER NTDDI_VERSION=NTDDI_WS03SP1) #add_pch(example_drv ../include/kmt_test.h) -add_rostests_file(TARGET fltmgr_load) +add_rostests_file(TARGET FltMgrLoad_drv) diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_load.c b/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_load.c index 7b9f5235cb..e56d7eb62a 100644 --- a/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_load.c +++ b/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_load.c @@ -89,7 +89,7 @@ TestEntry( ok_irql(PASSIVE_LEVEL); TestDriverObject = DriverObject; - *DeviceName = L"fltmgr_load"; + *DeviceName = L"FltMgrLoad"; trace("Hi, this is the filter manager load test driver\n"); diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_user.c b/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_user.c new file mode 100644 index 0000000000..727e9c0703 --- /dev/null +++ b/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_user.c @@ -0,0 +1,28 @@ +/* + * PROJECT: ReactOS kernel-mode tests - Filter Manager + * LICENSE: GPLv2+ - See COPYING in the top level directory + * PURPOSE: Tests for checking filters load and connect correctly + * PROGRAMMER: Ged Murphy <gedmur...@reactos.org> + */ + +#include <kmt_test.h> + + +START_TEST(FltMgrLoad) +{ + static WCHAR FilterName[] = L"FltMgrLoad"; + SC_HANDLE hService; + HANDLE hPort; + + trace("Message from user-mode\n"); + + ok(KmtFltCreateService(FilterName, L"FltMgrLoad test driver", &hService) == ERROR_SUCCESS, "\n"); + ok(KmtFltLoadDriver(FALSE, FALSE, FALSE, &hPort) == ERROR_PRIVILEGE_NOT_HELD, "\n"); + ok(KmtFltLoadDriver(TRUE, FALSE, FALSE, &hPort) == ERROR_SUCCESS, "\n"); + + ok(KmtFltConnectComms(&hPort) == ERROR_SUCCESS, "\n"); + + ok(KmtFltDisconnectComms(hPort) == ERROR_SUCCESS, "\n"); + ok(KmtFltUnloadDriver(hPort, FALSE) == ERROR_SUCCESS, "\n"); + KmtFltDeleteService(NULL, &hService); +} diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_register/CMakeLists.txt b/modules/rostests/kmtests/fltmgr/fltmgr_register/CMakeLists.txt new file mode 100644 index 0000000000..2857a32bd6 --- /dev/null +++ b/modules/rostests/kmtests/fltmgr/fltmgr_register/CMakeLists.txt @@ -0,0 +1,15 @@ + +include_directories(../../include) +include_directories(${REACTOS_SOURCE_DIR}/drivers/filters/fltmgr) + +list(APPEND FLTMGR_TEST_DRV_SOURCE + ../../kmtest_drv/kmtest_fsminifilter.c + fltmgr_register.c) + +add_library(fltmgrreg_drv SHARED ${FLTMGR_TEST_DRV_SOURCE}) +set_module_type(fltmgrreg_drv kernelmodedriver) +target_link_libraries(fltmgrreg_drv kmtest_printf ${PSEH_LIB}) +add_importlibs(fltmgrreg_drv fltmgr ntoskrnl hal) +add_target_compile_definitions(fltmgrreg_drv KMT_STANDALONE_DRIVER KMT_FILTER_DRIVER NTDDI_VERSION=NTDDI_WS03SP1) +#add_pch(example_drv ../include/kmt_test.h) +add_rostests_file(TARGET fltmgrreg_drv) diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_reg_user.c b/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_reg_user.c new file mode 100644 index 0000000000..da7d072988 --- /dev/null +++ b/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_reg_user.c @@ -0,0 +1,23 @@ +/* + * PROJECT: ReactOS kernel-mode tests - Filter Manager + * LICENSE: GPLv2+ - See COPYING in the top level directory + * PURPOSE: Tests for checking filter registration + * PROGRAMMER: Ged Murphy <gedmur...@reactos.org> + */ + +#include <kmt_test.h> + + +START_TEST(FltMgrReg) +{ + static WCHAR FilterName[] = L"FltMgrReg"; + SC_HANDLE hService; + HANDLE hPort; + + ok(KmtFltCreateService(FilterName, L"FltMgrLoad test driver", &hService) == ERROR_SUCCESS, "Failed to create the reg entry\n"); + ok(KmtFltAddAltitude(L"123456") == ERROR_SUCCESS, "\n"); + ok(KmtFltLoadDriver(TRUE, FALSE, FALSE, &hPort) == ERROR_SUCCESS, "Failed to load the driver\n"); + //__debugbreak(); + ok(KmtFltUnloadDriver(hPort, FALSE) == ERROR_SUCCESS, "Failed to unload the driver\n"); + ok(KmtFltDeleteService(NULL, &hService) == ERROR_SUCCESS, "Failed to delete the driver\n"); +} diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_register.c b/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_register.c new file mode 100644 index 0000000000..9703dfa8fd --- /dev/null +++ b/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_register.c @@ -0,0 +1,248 @@ +/* + * PROJECT: ReactOS kernel-mode tests - Filter Manager + * LICENSE: GPLv2+ - See COPYING in the top level directory + * PURPOSE: Tests for checking filter registration + * PROGRAMMER: Ged Murphy <gedmur...@reactos.org> + */ + +// This tests needs to be run via a standalone driver because FltRegisterFilter +// uses the DriverObject in its internal structures, and we don't want it to be +// linked to a device object from the test suite itself. + +#include <kmt_test.h> +#include <fltkernel.h> +#include <fltmgrint.h> + +//#define NDEBUG +#include <debug.h> + +#define RESET_REGISTRATION(basic) \ + do { \ + RtlZeroMemory(&FilterRegistration, sizeof(FLT_REGISTRATION)); \ + if (basic) { \ + FilterRegistration.Size = sizeof(FLT_REGISTRATION); \ + FilterRegistration.Version = FLT_REGISTRATION_VERSION; \ + } \ + } while (0) + +#define RESET_UNLOAD(DO) DO->DriverUnload = NULL; + + +NTSTATUS +FLTAPI +TestRegFilterUnload( + _In_ FLT_FILTER_UNLOAD_FLAGS Flags +); + +/* Globals */ +static PDRIVER_OBJECT TestDriverObject; +static FLT_REGISTRATION FilterRegistration; +static PFLT_FILTER TestFilter = NULL; + + + + +BOOLEAN +TestFltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject) +{ + UNICODE_STRING Altitude; + UNICODE_STRING Name; + PFLT_FILTER Filter = NULL; + PFLT_FILTER Temp = NULL; + NTSTATUS Status; + + RESET_REGISTRATION(FALSE); +#if 0 + KmtStartSeh() + Status = FltRegisterFilter(NULL, &FilterRegistration, &Filter); + KmtEndSeh(STATUS_INVALID_PARAMETER); + + KmtStartSeh() + Status = FltRegisterFilter(DriverObject, NULL, &Filter); + KmtEndSeh(STATUS_INVALID_PARAMETER); + + KmtStartSeh() + Status = FltRegisterFilter(DriverObject, &FilterRegistration, NULL); + KmtEndSeh(STATUS_INVALID_PARAMETER) +#endif + + RESET_REGISTRATION(TRUE); + FilterRegistration.Version = 0x0100; + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter); + ok_eq_hex(Status, STATUS_INVALID_PARAMETER); + + RESET_REGISTRATION(TRUE); + FilterRegistration.Version = 0x0300; + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter); + ok_eq_hex(Status, STATUS_INVALID_PARAMETER); + + RESET_REGISTRATION(TRUE); + FilterRegistration.Version = 0x0200; + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter); + ok_eq_hex(Status, STATUS_SUCCESS); + FltUnregisterFilter(Filter); + + + /* Test invalid sizes. MSDN says this is required, but it doesn't appear to be */ + RESET_REGISTRATION(TRUE); + FilterRegistration.Size = 0; + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter); + ok_eq_hex(Status, STATUS_SUCCESS); + FltUnregisterFilter(Filter); + + RESET_REGISTRATION(TRUE); + FilterRegistration.Size = 0xFFFF; + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter); + ok_eq_hex(Status, STATUS_SUCCESS); + FltUnregisterFilter(Filter); + + + /* Now make a valid registration */ + RESET_REGISTRATION(TRUE); + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter); + ok_eq_hex(Status, STATUS_SUCCESS); + + /* Try to register again */ + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Temp); + ok_eq_hex(Status, STATUS_FLT_INSTANCE_ALTITUDE_COLLISION); + + + ok_eq_hex(Filter->Base.Flags, FLT_OBFL_TYPE_FILTER); + + /* Check we have the right filter name */ + RtlInitUnicodeString(&Name, L"Kmtest-FltMgrReg"); + ok_eq_long(RtlCompareUnicodeString(&Filter->Name, &Name, FALSE), 0); + + /* And the altitude is corect */ + RtlInitUnicodeString(&Altitude, L"123456"); + ok_eq_long(RtlCompareUnicodeString(&Filter->DefaultAltitude, &Altitude, FALSE), 0); + + // + // FIXME: More checks + // + + /* Cleanup the valid registration */ + FltUnregisterFilter(Filter); + + /* + * The last thing we'll do before we exit is to properly register with the filter manager + * and set an unload routine. This'll let us test the FltUnregisterFilter routine + */ + RESET_REGISTRATION(TRUE); + + /* Set a fake unload routine we'll use to test */ + DriverObject->DriverUnload = (PDRIVER_UNLOAD)0x1234FFFF; + + FilterRegistration.FilterUnloadCallback = TestRegFilterUnload; + Status = FltRegisterFilter(DriverObject, &FilterRegistration, &TestFilter); + ok_eq_hex(Status, STATUS_SUCCESS); + + /* Test all the unlod routines */ + ok_eq_pointer(TestFilter->FilterUnload, TestRegFilterUnload); + ok_eq_pointer(TestFilter->OldDriverUnload, (PFLT_FILTER_UNLOAD_CALLBACK)0x1234FFFF); + + // This should equal the fltmgr's private unload routine, but there's no easy way of testing it... + //ok_eq_pointer(DriverObject->DriverUnload, FltpMiniFilterDriverUnload); + + /* Make sure our test address is never actually called */ + TestFilter->OldDriverUnload = (PFLT_FILTER_UNLOAD_CALLBACK)NULL; + + return TRUE; +} + + +NTSTATUS +FLTAPI +TestRegFilterUnload( + _In_ FLT_FILTER_UNLOAD_FLAGS Flags) +{ + //__debugbreak(); + + ok_irql(PASSIVE_LEVEL); + ok(TestFilter != NULL, "Buffer is NULL\n"); + + // + // FIXME: Add tests + // + + FltUnregisterFilter(TestFilter); + + // + // FIXME: Add tests + // + + return STATUS_SUCCESS; +} + + + + + + + + + + +/* + * KMT Callback routines + */ + +NTSTATUS +TestEntry( + IN PDRIVER_OBJECT DriverObject, + IN PCUNICODE_STRING RegistryPath, + OUT PCWSTR *DeviceName, + IN OUT INT *Flags) +{ + NTSTATUS Status = STATUS_SUCCESS; + + PAGED_CODE(); + + UNREFERENCED_PARAMETER(RegistryPath); + + DPRINT("FltMgrReg Entry!\n"); + trace("Entered FltMgrReg tests\n"); + + /* We'll do the work ourselves in this test */ + *Flags = TESTENTRY_NO_ALL; + + ok_irql(PASSIVE_LEVEL); + TestDriverObject = DriverObject; + + + /* Run the tests */ + (VOID)TestFltRegisterFilter(DriverObject); + + return Status; +} + +VOID +TestFilterUnload( + IN ULONG Flags) +{ + PAGED_CODE(); + ok_irql(PASSIVE_LEVEL); +} + +NTSTATUS +TestInstanceSetup( + _In_ PCFLT_RELATED_OBJECTS FltObjects, + _In_ FLT_INSTANCE_SETUP_FLAGS Flags, + _In_ DEVICE_TYPE VolumeDeviceType, + _In_ FLT_FILESYSTEM_TYPE VolumeFilesystemType, + _In_ PUNICODE_STRING VolumeName, + _In_ ULONG SectorSize, + _In_ ULONG ReportedSectorSize +) +{ + return STATUS_FLT_DO_NOT_ATTACH; +} + +VOID +TestQueryTeardown( + _In_ PCFLT_RELATED_OBJECTS FltObjects, + _In_ FLT_INSTANCE_QUERY_TEARDOWN_FLAGS Flags) +{ + UNREFERENCED_PARAMETER(FltObjects); + UNREFERENCED_PARAMETER(Flags); +} diff --git a/modules/rostests/kmtests/include/kmt_platform.h b/modules/rostests/kmtests/include/kmt_platform.h index 636f099cf6..3e29d2c9af 100644 --- a/modules/rostests/kmtests/include/kmt_platform.h +++ b/modules/rostests/kmtests/include/kmt_platform.h @@ -44,6 +44,7 @@ #include <strsafe.h> #include <fltuser.h> + #ifdef KMT_EMULATE_KERNEL #define ok_irql(i) #define KIRQL int diff --git a/modules/rostests/kmtests/include/kmt_test.h b/modules/rostests/kmtests/include/kmt_test.h index 8fcc93a419..1c5ddeefa2 100644 --- a/modules/rostests/kmtests/include/kmt_test.h +++ b/modules/rostests/kmtests/include/kmt_test.h @@ -126,9 +126,12 @@ NTSTATUS KmtFilterRegisterCallbacks(_In_ CONST FLT_OPERATION_REGISTRATION *Opera typedef enum { - TESTENTRY_NO_REGISTER_FILTER = 1, - TESTENTRY_NO_CREATE_COMMS_PORT = 2, - TESTENTRY_NO_START_FILTERING = 4, + TESTENTRY_NO_REGISTER_FILTER = 0x01, + TESTENTRY_NO_CREATE_COMMS_PORT = 0x02, + TESTENTRY_NO_START_FILTERING = 0x04, + TESTENTRY_NO_INSTANCE_SETUP = 0x08, + TESTENTRY_NO_QUERY_TEARDOWN = 0x10, + TESTENTRY_NO_ALL = 0xFF } KMT_MINIFILTER_FLAGS; VOID TestFilterUnload(_In_ ULONG Flags); @@ -174,8 +177,14 @@ DWORD KmtSendWStringToDriver(IN DWORD ControlCode, IN PCWSTR String); DWORD KmtSendUlongToDriver(IN DWORD ControlCode, IN DWORD Value); DWORD KmtSendBufferToDriver(IN DWORD ControlCode, IN OUT PVOID Buffer OPTIONAL, IN DWORD InLength, IN OUT PDWORD OutLength); -DWORD KmtFltLoadDriver(_In_z_ PCWSTR ServiceName, _In_ BOOLEAN RestartIfRunning, _In_ BOOLEAN ConnectComms, _Out_ HANDLE *hPort); + +DWORD KmtFltCreateService(_In_z_ PCWSTR ServiceName, _In_z_ PCWSTR DisplayName, _Out_ SC_HANDLE *ServiceHandle); +DWORD KmtFltDeleteService(_In_opt_z_ PCWSTR ServiceName, _Inout_ SC_HANDLE *ServiceHandle); +DWORD KmtFltAddAltitude(_In_z_ LPWSTR Altitude); +DWORD KmtFltLoadDriver(_In_ BOOLEAN EnableDriverLoadPrivlege, _In_ BOOLEAN RestartIfRunning, _In_ BOOLEAN ConnectComms, _Out_ HANDLE *hPort); DWORD KmtFltUnloadDriver(_In_ HANDLE *hPort, _In_ BOOLEAN DisonnectComms); +DWORD KmtFltConnectComms(_Out_ HANDLE *hPort); +DWORD KmtFltDisconnectComms(_In_ HANDLE hPort); DWORD KmtFltRunKernelTest(_In_ HANDLE hPort, _In_z_ PCSTR TestName); DWORD KmtFltSendToDriver(_In_ HANDLE hPort, _In_ DWORD Message); DWORD KmtFltSendStringToDriver(_In_ HANDLE hPort, _In_ DWORD Message, _In_ PCSTR String); diff --git a/modules/rostests/kmtests/kmtest/filter.c b/modules/rostests/kmtests/kmtest/filter.c index cb5957477c..042d1743b2 100644 --- a/modules/rostests/kmtests/kmtest/filter.c +++ b/modules/rostests/kmtests/kmtest/filter.c @@ -14,50 +14,8 @@ #define SERVICE_ACCESS (SERVICE_START | SERVICE_STOP | DELETE) -/* - * We need to call the internal function in the service.c file - */ -DWORD -KmtpCreateService( - IN PCWSTR ServiceName, - IN PCWSTR ServicePath, - IN PCWSTR DisplayName OPTIONAL, - IN DWORD ServiceType, - OUT SC_HANDLE *ServiceHandle); - -static SC_HANDLE ScmHandle; -/** - * @name KmtFltCreateService - * - * Create the specified driver service and return a handle to it - * - * @param ServiceName - * Name of the service to create - * @param ServicePath - * File name of the driver, relative to the current directory - * @param DisplayName - * Service display name - * @param ServiceHandle - * Pointer to a variable to receive the handle to the service - * - * @return Win32 error code - */ -DWORD -KmtFltCreateService( - _In_z_ PCWSTR ServiceName, - _In_z_ PCWSTR ServicePath, - _In_z_ PCWSTR DisplayName OPTIONAL, - _Out_ SC_HANDLE *ServiceHandle) -{ - return KmtpCreateService(ServiceName, - ServicePath, - DisplayName, - SERVICE_FILE_SYSTEM_DRIVER, - ServiceHandle); -} - /** * @name KmtFltLoad * @@ -82,7 +40,7 @@ KmtFltLoad( return Error; } - +#if 0 /** * @name KmtFltCreateAndStartService * @@ -143,7 +101,7 @@ cleanup: assert(Error); return Error; } - +#endif /** * @name KmtFltConnect @@ -163,7 +121,7 @@ KmtFltConnect( _Out_ HANDLE *hPort) { HRESULT hResult; - DWORD Error = ERROR_SUCCESS; + DWORD Error; assert(ServiceName); assert(hPort); @@ -191,7 +149,7 @@ KmtFltConnect( */ DWORD KmtFltDisconnect( - _Out_ HANDLE *hPort) + _In_ HANDLE hPort) { DWORD Error = ERROR_SUCCESS; @@ -388,41 +346,3 @@ KmtFltUnload( return Error; } - -/** - * @name KmtFltDeleteService - * - * Delete the specified filter driver - * - * @param ServiceName - * If *ServiceHandle is NULL, name of the service to delete - * @param ServiceHandle - * Pointer to a variable containing the service handle. - * Will be set to NULL on success - * - * @return Win32 error code - */ -DWORD -KmtFltDeleteService( - _In_z_ PCWSTR ServiceName OPTIONAL, - _Inout_ SC_HANDLE *ServiceHandle) -{ - return KmtDeleteService(ServiceName, ServiceHandle); -} - -/** - * @name KmtFltCloseService - * - * Close the specified driver service handle - * - * @param ServiceHandle - * Pointer to a variable containing the service handle. - * Will be set to NULL on success - * - * @return Win32 error code - */ -DWORD KmtFltCloseService( - _Inout_ SC_HANDLE *ServiceHandle) -{ - return KmtCloseService(ServiceHandle); -} diff --git a/modules/rostests/kmtests/kmtest/fltsupport.c b/modules/rostests/kmtests/kmtest/fltsupport.c index 772e98a211..4127ccad24 100644 --- a/modules/rostests/kmtests/kmtest/fltsupport.c +++ b/modules/rostests/kmtests/kmtest/fltsupport.c @@ -7,12 +7,28 @@ #include <kmt_test.h> +#define KMT_FLT_USER_MODE #include "kmtest.h" #include <kmt_public.h> #include <assert.h> #include <debug.h> +/* + * We need to call the internal function in the service.c file + */ +DWORD +KmtpCreateService( + IN PCWSTR ServiceName, + IN PCWSTR ServicePath, + IN PCWSTR DisplayName OPTIONAL, + IN DWORD ServiceType, + OUT SC_HANDLE *ServiceHandle); + +DWORD EnablePrivilegeInCurrentProcess( + _In_z_ LPWSTR lpPrivName, + _In_ BOOL bEnable); + // move to a shared location typedef struct _KMTFLT_MESSAGE_HEADER { @@ -26,33 +42,30 @@ extern HANDLE KmtestHandle; static WCHAR TestServiceName[MAX_PATH]; + /** - * @name KmtFltLoadDriver + * @name KmtFltCreateService * - * Load the specified filter driver - * This routine will create the service entry if it doesn't already exist + * Create the specified driver service and return a handle to it * * @param ServiceName - * Name of the driver service (Kmtest- prefix will be added automatically) - * @param RestartIfRunning - * TRUE to stop and restart the service if it is already running - * @param ConnectComms - * TRUE to create a comms connection to the specified filter - * @param hPort - * Handle to the filter's comms port + * Name of the service to create + * @param ServicePath + * File name of the driver, relative to the current directory + * @param DisplayName + * Service display name + * @param ServiceHandle + * Pointer to a variable to receive the handle to the service * * @return Win32 error code */ DWORD -KmtFltLoadDriver( +KmtFltCreateService( _In_z_ PCWSTR ServiceName, - _In_ BOOLEAN RestartIfRunning, - _In_ BOOLEAN ConnectComms, - _Out_ HANDLE *hPort) + _In_z_ PCWSTR DisplayName, + _Out_ SC_HANDLE *ServiceHandle) { - DWORD Error = ERROR_SUCCESS; WCHAR ServicePath[MAX_PATH]; - SC_HANDLE TestServiceHandle; StringCbCopy(ServicePath, sizeof ServicePath, ServiceName); StringCbCat(ServicePath, sizeof ServicePath, L"_drv.sys"); @@ -60,11 +73,81 @@ KmtFltLoadDriver( StringCbCopy(TestServiceName, sizeof TestServiceName, L"Kmtest-"); StringCbCat(TestServiceName, sizeof TestServiceName, ServiceName); - Error = KmtFltCreateAndStartService(TestServiceName, ServicePath, NULL, &TestServiceHandle, TRUE); + return KmtpCreateService(TestServiceName, + ServicePath, + DisplayName, + SERVICE_FILE_SYSTEM_DRIVER, + ServiceHandle); +} + +/** + * @name KmtFltDeleteService + * + * Delete the specified filter driver + * + * @param ServiceName + * If *ServiceHandle is NULL, name of the service to delete + * @param ServiceHandle + * Pointer to a variable containing the service handle. + * Will be set to NULL on success + * + * @return Win32 error code + */ +DWORD +KmtFltDeleteService( + _In_opt_z_ PCWSTR ServiceName, + _Inout_ SC_HANDLE *ServiceHandle) +{ + return KmtDeleteService(ServiceName, ServiceHandle); +} - if (Error == ERROR_SUCCESS && ConnectComms) +/** + * @name KmtFltLoadDriver + * + * Delete the specified filter driver + * + * @return Win32 error code + */ +DWORD +KmtFltLoadDriver( + _In_ BOOLEAN EnableDriverLoadPrivlege, + _In_ BOOLEAN RestartIfRunning, + _In_ BOOLEAN ConnectComms, + _Out_ HANDLE *hPort +) +{ + DWORD Error; + + if (EnableDriverLoadPrivlege) + { + Error = EnablePrivilegeInCurrentProcess(SE_LOAD_DRIVER_NAME , TRUE); + if (Error) + { + return Error; + } + } + + Error = KmtFltLoad(TestServiceName); + if ((Error == ERROR_SERVICE_ALREADY_RUNNING) && RestartIfRunning) + { + Error = KmtFltUnload(TestServiceName); + if (Error) + { + // TODO + __debugbreak(); + } + + Error = KmtFltLoad(TestServiceName); + } + + if (Error) { - Error = KmtFltConnect(ServiceName, hPort); + return Error; + } + + if (ConnectComms) + { + Error = KmtFltConnectComms(hPort); } return Error; @@ -77,7 +160,7 @@ KmtFltLoadDriver( * * @param hPort * Handle to the filter's comms port - * @param ConnectComms + * @param DisonnectComms * TRUE to disconnect the comms connection before unloading * * @return Win32 error code @@ -110,6 +193,57 @@ KmtFltUnloadDriver( return Error; } +/** + * @name KmtFltConnectComms + * + * Create a comms connection to the specified filter + * + * @param hPort + * Handle to the filter's comms port + * + * @return Win32 error code + */ +DWORD +KmtFltConnectComms( + _Out_ HANDLE *hPort) +{ + return KmtFltConnect(TestServiceName, hPort); +} + +/** + * @name KmtFltDisconnectComms + * + * Disconenct from the comms port + * + * @param hPort + * Handle to the filter's comms port + * + * @return Win32 error code + */ +DWORD +KmtFltDisconnectComms( + _In_ HANDLE hPort) +{ + return KmtFltDisconnect(hPort); +} + + +/** +* @name KmtFltCloseService +* +* Close the specified driver service handle +* +* @param ServiceHandle +* Pointer to a variable containing the service handle. +* Will be set to NULL on success +* +* @return Win32 error code +*/ +DWORD KmtFltCloseService( + _Inout_ SC_HANDLE *ServiceHandle) +{ + return KmtCloseService(ServiceHandle); +} /** * @name KmtFltRunKernelTest @@ -141,7 +275,7 @@ KmtFltRunKernelTest( * @param Message * The message to send to the filter * -* @return Win32 error code as returned by DeviceIoControl +* @return Win32 error code */ DWORD KmtFltSendToDriver( @@ -165,7 +299,7 @@ KmtFltSendToDriver( * @param String * An ANSI string to send to the filter * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendStringToDriver( @@ -190,7 +324,7 @@ KmtFltSendStringToDriver( * @param String * An wide string to send to the filter * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendWStringToDriver( @@ -213,7 +347,7 @@ KmtFltSendWStringToDriver( * @param Value * An 32bit valueng to send to the filter * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendUlongToDriver( @@ -244,7 +378,7 @@ KmtFltSendUlongToDriver( * @param BytesReturned * Number of bytes written in the reply buffer * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendBufferToDriver( @@ -299,3 +433,154 @@ KmtFltSendBufferToDriver( return Error; } + +/** +* @name KmtFltAddAltitude +* +* Sets up the mini-filter altitude data in the registry +* +* @param hPort +* The altitude string to set +* +* @return Win32 error code +*/ +DWORD +KmtFltAddAltitude( + _In_z_ LPWSTR Altitude) +{ + WCHAR DefaultInstance[128]; + WCHAR KeyPath[256]; + HKEY hKey = NULL; + HKEY hSubKey = NULL; + DWORD Zero = 0; + LONG Error; + + StringCbCopy(KeyPath, sizeof KeyPath, L"SYSTEM\\CurrentControlSet\\Services\\"); + StringCbCat(KeyPath, sizeof KeyPath, TestServiceName); + StringCbCat(KeyPath, sizeof KeyPath, L"\\Instances\\"); + + Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE, + KeyPath, + 0, + NULL, + REG_OPTION_NON_VOLATILE, + KEY_CREATE_SUB_KEY | KEY_SET_VALUE, + NULL, + &hKey, + NULL); + if (Error != ERROR_SUCCESS) + { + return Error; + } + + StringCbCopy(DefaultInstance, sizeof DefaultInstance, TestServiceName); + StringCbCat(DefaultInstance, sizeof DefaultInstance, L" Instance"); + + Error = RegSetValueExW(hKey, + L"DefaultInstance", + 0, + REG_SZ, + (LPBYTE)DefaultInstance, + (wcslen(DefaultInstance) + 1) * sizeof(WCHAR)); + if (Error != ERROR_SUCCESS) + { + goto Quit; + } + + Error = RegCreateKeyW(hKey, DefaultInstance, &hSubKey); + if (Error != ERROR_SUCCESS) + { + goto Quit; + } + + Error = RegSetValueExW(hSubKey, + L"Altitude", + 0, + REG_SZ, + (LPBYTE)Altitude, + (wcslen(Altitude) + 1) * sizeof(WCHAR)); + if (Error != ERROR_SUCCESS) + { + goto Quit; + } + + Error = RegSetValueExW(hSubKey, + L"Flags", + 0, + REG_DWORD, + (LPBYTE)&Zero, + sizeof(DWORD)); + +Quit: + if (hSubKey) + { + RegCloseKey(hSubKey); + } + if (hKey) + { + RegCloseKey(hKey); + } + + return Error; + +} + +/* +* Private functions, not meant for use in kmtests +*/ + +DWORD EnablePrivilege( + _In_ HANDLE hToken, + _In_z_ LPWSTR lpPrivName, + _In_ BOOL bEnable) +{ + TOKEN_PRIVILEGES TokenPrivileges; + LUID luid; + BOOL bSuccess; + DWORD dwError = ERROR_SUCCESS; + + /* Get the luid for this privilege */ + if (!LookupPrivilegeValueW(NULL, lpPrivName, &luid)) + return GetLastError(); + + /* Setup the struct with the priv info */ + TokenPrivileges.PrivilegeCount = 1; + TokenPrivileges.Privileges[0].Luid = luid; + TokenPrivileges.Privileges[0].Attributes = bEnable ? SE_PRIVILEGE_ENABLED : 0; + + /* Enable the privilege info in the token */ + bSuccess = AdjustTokenPrivileges(hToken, + FALSE, + &TokenPrivileges, + sizeof(TOKEN_PRIVILEGES), + NULL, + NULL); + if (bSuccess == FALSE) dwError = GetLastError(); + + /* return status */ + return dwError; +} + +DWORD EnablePrivilegeInCurrentProcess( + _In_z_ LPWSTR lpPrivName, + _In_ BOOL bEnable) +{ + HANDLE hToken; + BOOL bSuccess; + DWORD dwError = ERROR_SUCCESS; + + /* Get a handle to our token */ + bSuccess = OpenProcessToken(GetCurrentProcess(), + TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, + &hToken); + if (bSuccess == FALSE) return GetLastError(); + + /* Enable the privilege in the agent token */ + dwError = EnablePrivilege(hToken, lpPrivName, bEnable); + + /* We're done with this now */ + CloseHandle(hToken); + + /* return status */ + return dwError; +} diff --git a/modules/rostests/kmtests/kmtest/kmtest.h b/modules/rostests/kmtests/kmtest/kmtest.h index 7329d18e48..f39c418b5d 100644 --- a/modules/rostests/kmtests/kmtest/kmtest.h +++ b/modules/rostests/kmtests/kmtest/kmtest.h @@ -60,16 +60,9 @@ KmtDeleteService( DWORD KmtCloseService( IN OUT SC_HANDLE *ServiceHandle); - -/* FS Filter management functions */ -DWORD -KmtFltCreateService( - _In_z_ PCWSTR ServiceName, - _In_z_ PCWSTR ServicePath, - _In_z_ PCWSTR DisplayName OPTIONAL, - _Out_ SC_HANDLE *ServiceHandle); +#ifdef KMT_FLT_USER_MODE DWORD KmtFltLoad( @@ -132,4 +125,6 @@ KmtFltDeleteService( DWORD KmtFltCloseService( _Inout_ SC_HANDLE *ServiceHandle); +#endif /* KMT_FILTER_DRIVER */ + #endif /* !defined _KMTESTS_H_ */ diff --git a/modules/rostests/kmtests/kmtest/testlist.c b/modules/rostests/kmtests/kmtest/testlist.c index 562691d6f0..01cb31cf53 100644 --- a/modules/rostests/kmtests/kmtest/testlist.c +++ b/modules/rostests/kmtests/kmtest/testlist.c @@ -11,6 +11,8 @@ KMT_TESTFUNC Test_CcCopyRead; KMT_TESTFUNC Test_Example; KMT_TESTFUNC Test_FileAttributes; KMT_TESTFUNC Test_FindFile; +KMT_TESTFUNC Test_FltMgrLoad; +KMT_TESTFUNC Test_FltMgrReg; KMT_TESTFUNC Test_HidPDescription; KMT_TESTFUNC Test_IoCreateFile; KMT_TESTFUNC Test_IoDeviceObject; @@ -37,6 +39,8 @@ const KMT_TEST TestList[] = { "-Example", Test_Example }, { "FileAttributes", Test_FileAttributes }, { "FindFile", Test_FindFile }, + { "FltMgrLoad", Test_FltMgrLoad }, + { "FltMgrReg", Test_FltMgrReg }, { "HidPDescription", Test_HidPDescription }, { "IoCreateFile", Test_IoCreateFile }, { "IoDeviceObject", Test_IoDeviceObject }, diff --git a/modules/rostests/kmtests/kmtest_drv/kmtest_fsminifilter.c b/modules/rostests/kmtests/kmtest_drv/kmtest_fsminifilter.c index 2de7f3b159..3bb686de9e 100644 --- a/modules/rostests/kmtests/kmtest_drv/kmtest_fsminifilter.c +++ b/modules/rostests/kmtests/kmtest_drv/kmtest_fsminifilter.c @@ -2,7 +2,8 @@ * PROJECT: ReactOS kernel-mode tests - Filter Manager * LICENSE: GPLv2+ - See COPYING in the top level directory * PURPOSE: FS Mini-filter wrapper to host the filter manager tests - * PROGRAMMER: Ged Murphy <ged.mur...@reactos.org> + * PROGRAMMER: Thomas Faber <thomas.fa...@reactos.org> + * Ged Murphy <ged.mur...@reactos.org> */ #include <ntifs.h> @@ -35,6 +36,7 @@ DRIVER_INITIALIZE DriverEntry; /* Globals */ static PDRIVER_OBJECT TestDriverObject; +static PDEVICE_OBJECT KmtestDeviceObject; static FILTER_DATA FilterData; static PFLT_OPERATION_REGISTRATION Callbacks = NULL; static PFLT_CONTEXT_REGISTRATION Contexts = NULL; @@ -133,6 +135,9 @@ DriverEntry( PSECURITY_DESCRIPTOR SecurityDescriptor; UNICODE_STRING DeviceName; WCHAR DeviceNameBuffer[128] = L"\\Device\\Kmtest-"; + UNICODE_STRING KmtestDeviceName; + PFILE_OBJECT KmtestFileObject; + PKMT_DEVICE_EXTENSION KmtestDeviceExtension; PCWSTR DeviceNameSuffix; INT Flags = 0; PKPRCB Prcb; @@ -148,6 +153,32 @@ DriverEntry( KmtIsMultiProcessorBuild = (Prcb->BuildType & PRCB_BUILD_UNIPROCESSOR) == 0; TestDriverObject = DriverObject; + /* get the Kmtest device, so that we get a ResultBuffer pointer */ + RtlInitUnicodeString(&KmtestDeviceName, KMTEST_DEVICE_DRIVER_PATH); + Status = IoGetDeviceObjectPointer(&KmtestDeviceName, FILE_ALL_ACCESS, &KmtestFileObject, &KmtestDeviceObject); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to get Kmtest device object pointer\n"); + goto cleanup; + } + + Status = ObReferenceObjectByPointer(KmtestDeviceObject, FILE_ALL_ACCESS, NULL, KernelMode); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to reference Kmtest device object\n"); + goto cleanup; + } + + ObDereferenceObject(KmtestFileObject); + KmtestFileObject = NULL; + KmtestDeviceExtension = KmtestDeviceObject->DeviceExtension; + ResultBuffer = KmtestDeviceExtension->ResultBuffer; + DPRINT("KmtestDeviceObject: %p\n", (PVOID)KmtestDeviceObject); + DPRINT("KmtestDeviceExtension: %p\n", (PVOID)KmtestDeviceExtension); + DPRINT("Setting ResultBuffer: %p\n", (PVOID)ResultBuffer); + /* call TestEntry */ RtlInitUnicodeString(&DeviceName, DeviceNameBuffer); @@ -243,6 +274,7 @@ FilterUnload( { PAGED_CODE(); UNREFERENCED_PARAMETER(Flags); + //__debugbreak(); DPRINT("DriverUnload\n"); @@ -312,50 +344,54 @@ FilterInstanceSetup( UNREFERENCED_PARAMETER(FltObjects); UNREFERENCED_PARAMETER(Flags); - RtlInitUnicodeString(&VolumeName, NULL); + if (!(Flags & TESTENTRY_NO_INSTANCE_SETUP)) + { + RtlInitUnicodeString(&VolumeName, NULL); + #if 0 // FltGetVolumeProperties is not yet implemented /* Get the properties of this volume */ - Status = FltGetVolumeProperties(Volume, - VolumeProperties, - sizeof(VolPropBuffer), - &LengthReturned); - if (NT_SUCCESS(Status)) - { - FLT_ASSERT((VolumeProperties->SectorSize == 0) || (VolumeProperties->SectorSize >= MIN_SECTOR_SIZE)); - SectorSize = max(VolumeProperties->SectorSize, MIN_SECTOR_SIZE); - ReportedSectorSize = VolumeProperties->SectorSize; - } - else - { - DPRINT1("Failed to get the volume properties : 0x%X", Status); - return Status; - } + Status = FltGetVolumeProperties(Volume, + VolumeProperties, + sizeof(VolPropBuffer), + &LengthReturned); + if (NT_SUCCESS(Status)) + { + FLT_ASSERT((VolumeProperties->SectorSize == 0) || (VolumeProperties->SectorSize >= MIN_SECTOR_SIZE)); + SectorSize = max(VolumeProperties->SectorSize, MIN_SECTOR_SIZE); + ReportedSectorSize = VolumeProperties->SectorSize; + } + else + { + DPRINT1("Failed to get the volume properties : 0x%X", Status); + return Status; + } #endif - /* Get the storage device object we want a name for */ - Status = FltGetDiskDeviceObject(FltObjects->Volume, &DeviceObject); - if (NT_SUCCESS(Status)) - { - /* Get the dos device name */ - Status = IoVolumeDeviceToDosName(DeviceObject, &VolumeName); + /* Get the storage device object we want a name for */ + Status = FltGetDiskDeviceObject(FltObjects->Volume, &DeviceObject); if (NT_SUCCESS(Status)) { - DPRINT("VolumeDeviceType %lu, VolumeFilesystemType %lu, Real SectSize=0x%04x, Reported SectSize=0x%04x, Name=\"%wZ\"", - VolumeDeviceType, - VolumeFilesystemType, - SectorSize, - ReportedSectorSize, - &VolumeName); - - Status = TestInstanceSetup(FltObjects, - Flags, - VolumeDeviceType, - VolumeFilesystemType, - &VolumeName, - SectorSize, - ReportedSectorSize); - - /* The buffer was allocated by the IoMgr */ - ExFreePool(VolumeName.Buffer); + /* Get the dos device name */ + Status = IoVolumeDeviceToDosName(DeviceObject, &VolumeName); + if (NT_SUCCESS(Status)) + { + DPRINT("VolumeDeviceType %lu, VolumeFilesystemType %lu, Real SectSize=0x%04x, Reported SectSize=0x%04x, Name=\"%wZ\"", + VolumeDeviceType, + VolumeFilesystemType, + SectorSize, + ReportedSectorSize, + &VolumeName); + + Status = TestInstanceSetup(FltObjects, + Flags, + VolumeDeviceType, + VolumeFilesystemType, + &VolumeName, + SectorSize, + ReportedSectorSize); + + /* The buffer was allocated by the IoMgr */ + ExFreePool(VolumeName.Buffer); + } } } @@ -384,7 +420,10 @@ FilterQueryTeardown( { PAGED_CODE(); - TestQueryTeardown(FltObjects, Flags); + if (!(Flags & TESTENTRY_NO_QUERY_TEARDOWN)) + { + TestQueryTeardown(FltObjects, Flags); + } /* We always allow a volume to detach */ return STATUS_SUCCESS;