https://git.reactos.org/?p=reactos.git;a=commitdiff;h=c8fb3f75145b94cdab5cfe908d46f8ed7326494d

commit c8fb3f75145b94cdab5cfe908d46f8ed7326494d
Author:     Jérôme Gardou <jerome.gar...@reactos.org>
AuthorDate: Fri May 28 16:58:59 2021 +0200
Commit:     Jérôme Gardou <zefk...@users.noreply.github.com>
CommitDate: Wed Jun 9 11:27:18 2021 +0200

    [NTOS:MM] Implement proper refcounting of page tables on amd64
    
    CORE-17552
---
 ntoskrnl/mm/ARM3/mdlsup.c   |  27 ++----
 ntoskrnl/mm/ARM3/miarm.h    | 215 +++++++++++++++++++++++++++++++++++++-------
 ntoskrnl/mm/ARM3/pagfault.c |   5 ++
 ntoskrnl/mm/ARM3/pfnlist.c  |   2 +
 ntoskrnl/mm/ARM3/session.c  |   2 +-
 ntoskrnl/mm/ARM3/virtual.c  |  15 ++--
 ntoskrnl/mm/i386/page.c     |   8 +-
 sdk/include/ndk/mmtypes.h   |  11 +++
 8 files changed, 217 insertions(+), 68 deletions(-)

diff --git a/ntoskrnl/mm/ARM3/mdlsup.c b/ntoskrnl/mm/ARM3/mdlsup.c
index 10a2aa02f68..9332ec3153e 100644
--- a/ntoskrnl/mm/ARM3/mdlsup.c
+++ b/ntoskrnl/mm/ARM3/mdlsup.c
@@ -248,6 +248,7 @@ MiMapLockedPagesInUserSpace(
 
         /* Acquire a share count */
         Pfn1 = MI_PFN_ELEMENT(PointerPde->u.Hard.PageFrameNumber);
+        DPRINT("Incrementing %p from %p\n", Pfn1, _ReturnAddress());
         OldIrql = MiAcquirePfnLock();
         Pfn1->u2.ShareCount++;
         MiReleasePfnLock(OldIrql);
@@ -330,9 +331,6 @@ MiUnmapLockedPagesInUserSpace(
         ASSERT(MiAddressToPte(PointerPte)->u.Hard.Valid == 1);
         ASSERT(PointerPte->u.Hard.Valid == 1);
 
-        /* Dereference the page */
-        MiDecrementPageTableReferences(BaseAddress);
-
         /* Invalidate it */
         MI_ERASE_PTE(PointerPte);
 
@@ -341,28 +339,17 @@ MiUnmapLockedPagesInUserSpace(
         PageTablePage = PointerPde->u.Hard.PageFrameNumber;
         MiDecrementShareCount(MiGetPfnEntry(PageTablePage), PageTablePage);
 
+        if (MiDecrementPageTableReferences(BaseAddress) == 0)
+        {
+            ASSERT(MiIsPteOnPdeBoundary(PointerPte + 1) || (NumberOfPages == 1));
+            MiDeletePde(PointerPde, Process);
+        }
+
         /* Next page */
         PointerPte++;
         NumberOfPages--;
         BaseAddress = (PVOID)((ULONG_PTR)BaseAddress + PAGE_SIZE);
         MdlPages++;
-
-        /* Moving to a new PDE? */
-        if (PointerPde != MiAddressToPde(BaseAddress))
-        {
-            /* See if we should delete it */
-            KeFlushProcessTb();
-            PointerPde = MiPteToPde(PointerPte - 1);
-            ASSERT(PointerPde->u.Hard.Valid == 1);
-            if (MiQueryPageTableReferences(BaseAddress) == 0)
-            {
-                ASSERT(PointerPde->u.Long != 0);
-                MiDeletePte(PointerPde,
-                            MiPteToAddress(PointerPde),
-                            Process,
-                            NULL);
-            }
-        }
     }
 
     KeFlushProcessTb();
diff --git a/ntoskrnl/mm/ARM3/miarm.h b/ntoskrnl/mm/ARM3/miarm.h
index a784d08e15f..8c165759277 100644
--- a/ntoskrnl/mm/ARM3/miarm.h
+++ b/ntoskrnl/mm/ARM3/miarm.h
@@ -1823,40 +1823,7 @@ MiReferenceUnusedPageAndBumpLockCount(IN PMMPFN Pfn1)
     }
 }
 
-FORCEINLINE
-VOID
-MiIncrementPageTableReferences(IN PVOID Address)
-{
-    PUSHORT RefCount;
-
-    RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
-
-    *RefCount += 1;
-    ASSERT(*RefCount <= PTE_PER_PAGE);
-}
 
-FORCEINLINE
-VOID
-MiDecrementPageTableReferences(IN PVOID Address)
-{
-    PUSHORT RefCount;
-
-    RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
-
-    *RefCount -= 1;
-    ASSERT(*RefCount < PTE_PER_PAGE);
-}
-
-FORCEINLINE
-USHORT
-MiQueryPageTableReferences(IN PVOID Address)
-{
-    PUSHORT RefCount;
-
-    RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
-
-    return *RefCount;
-}
 
 CODE_SEG("INIT")
 BOOLEAN
@@ -2484,8 +2451,190 @@ MiSynchronizeSystemPde(PMMPDE PointerPde)
 }
 #endif
 
+#if _MI_PAGING_LEVELS == 2
+FORCEINLINE
+USHORT
+MiIncrementPageTableReferences(IN PVOID Address)
+{
+    PUSHORT RefCount;
+
+    RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
+
+    *RefCount += 1;
+    ASSERT(*RefCount <= PTE_PER_PAGE);
+    return *RefCount;
+}
+
+FORCEINLINE
+USHORT
+MiDecrementPageTableReferences(IN PVOID Address)
+{
+    PUSHORT RefCount;
+
+    RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
+
+    *RefCount -= 1;
+    ASSERT(*RefCount < PTE_PER_PAGE);
+    return *RefCount;
+}
+
+FORCEINLINE
+USHORT
+MiQueryPageTableReferences(IN PVOID Address)
+{
+    PUSHORT RefCount;
+
+    RefCount = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(Address)];
+
+    return *RefCount;
+}
+#else
+FORCEINLINE
+USHORT
+MiIncrementPageTableReferences(IN PVOID Address)
+{
+    PMMPDE PointerPde = MiAddressToPde(Address);
+    PMMPFN Pfn;
+
+    /* We should not tinker with this one. */
+    ASSERT(PointerPde != (PMMPDE)PXE_SELFMAP);
+    DPRINT("Incrementing %p from %p\n", Address, _ReturnAddress());
+
+    /* Make sure we're locked */
+    ASSERT(PsGetCurrentThread()->OwnsProcessWorkingSetExclusive);
+
+    /* If we're bumping refcount, then it must be valid! */
+    ASSERT(PointerPde->u.Hard.Valid == 1);
+
+    /* This lies on the PFN */
+    Pfn = MiGetPfnEntry(PFN_FROM_PDE(PointerPde));
+    Pfn->OriginalPte.u.Soft.UsedPageTableEntries++;
+
+    ASSERT(Pfn->OriginalPte.u.Soft.UsedPageTableEntries <= PTE_PER_PAGE);
+
+    return Pfn->OriginalPte.u.Soft.UsedPageTableEntries;
+}
+
+FORCEINLINE
+USHORT
+MiDecrementPageTableReferences(IN PVOID Address)
+{
+    PMMPDE PointerPde = MiAddressToPde(Address);
+    PMMPFN Pfn;
+
+    /* We should not tinker with this one. */
+    ASSERT(PointerPde != (PMMPDE)PXE_SELFMAP);
+
+    DPRINT("Decrementing %p from %p\n", PointerPde, _ReturnAddress());
+
+    /* Make sure we're locked */
+    ASSERT(PsGetCurrentThread()->OwnsProcessWorkingSetExclusive);
+
+    /* If we're decreasing refcount, then it must be valid! */
+    ASSERT(PointerPde->u.Hard.Valid == 1);
+
+    /* This lies on the PFN */
+    Pfn = MiGetPfnEntry(PFN_FROM_PDE(PointerPde));
+
+    ASSERT(Pfn->OriginalPte.u.Soft.UsedPageTableEntries != 0);
+    Pfn->OriginalPte.u.Soft.UsedPageTableEntries--;
+
+    ASSERT(Pfn->OriginalPte.u.Soft.UsedPageTableEntries < PTE_PER_PAGE);
+
+    return Pfn->OriginalPte.u.Soft.UsedPageTableEntries;
+}
+
+FORCEINLINE
+USHORT
+MiQueryPageTableReferences(IN PVOID Address)
+{
+    PMMPDE PointerPde;
+    PMMPPE PointerPpe;
+#if _MI_PAGING_LEVELS == 4
+    PMMPXE PointerPxe;
+#endif
+    PMMPFN Pfn;
+
+    /* Make sure we're locked */
+    ASSERT((PsGetCurrentThread()->OwnsProcessWorkingSetExclusive) || (PsGetCurrentThread()->OwnsProcessWorkingSetShared));
+
+    /* Check if PXE or PPE have references first. */
+#if _MI_PAGING_LEVELS == 4
+    PointerPxe = MiAddressToPxe(Address);
+    if ((PointerPxe->u.Hard.Valid == 1) || (PointerPxe->u.Soft.Transition == 1))
+    {
+        Pfn = MiGetPfnEntry(PFN_FROM_PXE(PointerPxe));
+        if (Pfn->OriginalPte.u.Soft.UsedPageTableEntries == 0)
+            return 0;
+    }
+    else if (PointerPxe->u.Soft.UsedPageTableEntries == 0)
+    {
+        return 0;
+    }
+
+    if (PointerPxe->u.Hard.Valid == 0)
+    {
+        MiMakeSystemAddressValid(MiPteToAddress(PointerPxe), PsGetCurrentProcess());
+    }
+#endif
+
+    PointerPpe = MiAddressToPpe(Address);
+    if ((PointerPpe->u.Hard.Valid == 1) || (PointerPpe->u.Soft.Transition == 1))
+    {
+        Pfn = MiGetPfnEntry(PFN_FROM_PPE(PointerPpe));
+        if (Pfn->OriginalPte.u.Soft.UsedPageTableEntries == 0)
+            return 0;
+    }
+    else if (PointerPpe->u.Soft.UsedPageTableEntries == 0)
+    {
+        return 0;
+    }
+
+    if (PointerPpe->u.Hard.Valid == 0)
+    {
+        MiMakeSystemAddressValid(MiPteToAddress(PointerPpe), PsGetCurrentProcess());
+    }
+
+    PointerPde = MiAddressToPde(Address);
+    if ((PointerPde->u.Hard.Valid == 0) && (PointerPde->u.Soft.Transition == 0))
+    {
+        return PointerPde->u.Soft.UsedPageTableEntries;
+    }
+
+    /* This lies on the PFN */
+    Pfn = MiGetPfnEntry(PFN_FROM_PDE(PointerPde));
+    return Pfn->OriginalPte.u.Soft.UsedPageTableEntries;
+}
+#endif
+
 #ifdef __cplusplus
 } // extern "C"
 #endif
 
+FORCEINLINE
+VOID
+MiDeletePde(
+    _In_ PMMPDE PointerPde,
+    _In_ PEPROCESS CurrentProcess)
+{
+    /* Only for user-mode ones */
+    ASSERT(MiIsUserPde(PointerPde));
+
+    /* Kill this one as a PTE */
+    MiDeletePte((PMMPTE)PointerPde, MiPdeToPte(PointerPde), CurrentProcess, NULL);
+#if _MI_PAGING_LEVELS >= 3
+    /* Cascade down */
+    if (MiDecrementPageTableReferences(MiPdeToPte(PointerPde)) == 0)
+    {
+        MiDeletePte(MiPdeToPpe(PointerPde), PointerPde, CurrentProcess, NULL);
+#if _MI_PAGING_LEVELS == 4
+        if (MiDecrementPageTableReferences(PointerPde) == 0)
+        {
+            MiDeletePte(MiPdeToPxe(PointerPde), MiPdeToPpe(PointerPde), CurrentProcess, NULL);
+        }
+#endif
+    }
+#endif
+}
+
 /* EOF */
diff --git a/ntoskrnl/mm/ARM3/pagfault.c b/ntoskrnl/mm/ARM3/pagfault.c
index 87c789c1742..b6e2f9e8287 100644
--- a/ntoskrnl/mm/ARM3/pagfault.c
+++ b/ntoskrnl/mm/ARM3/pagfault.c
@@ -2145,6 +2145,7 @@ UserFault:
 
         /* We should come back with a valid PPE */
         ASSERT(PointerPpe->u.Hard.Valid == 1);
+        MiIncrementPageTableReferences(PointerPde);
     }
 #endif
 
@@ -2184,6 +2185,10 @@ UserFault:
                                  MM_EXECUTE_READWRITE,
                                  CurrentProcess,
                                  MM_NOIRQL);
+#if _MI_PAGING_LEVELS >= 3
+        MiIncrementPageTableReferences(PointerPte);
+#endif
+
 #if MI_TRACE_PFNS
         UserPdeFault = FALSE;
 #endif
diff --git a/ntoskrnl/mm/ARM3/pfnlist.c b/ntoskrnl/mm/ARM3/pfnlist.c
index b9838f29602..726143870f4 100644
--- a/ntoskrnl/mm/ARM3/pfnlist.c
+++ b/ntoskrnl/mm/ARM3/pfnlist.c
@@ -1027,6 +1027,8 @@ MiInitializePfn(IN PFN_NUMBER PageFrameIndex,
     ASSERT(PageFrameIndex != 0);
     Pfn1->u4.PteFrame = PageFrameIndex;
 
+    DPRINT("Incrementing share count of %lp from %p\n", PageFrameIndex, _ReturnAddress());
+
     /* Increase its share count so we don't get rid of it */
     Pfn1 = MI_PFN_ELEMENT(PageFrameIndex);
     Pfn1->u2.ShareCount++;
diff --git a/ntoskrnl/mm/ARM3/session.c b/ntoskrnl/mm/ARM3/session.c
index cb4de5f9b5f..70ae1ec9889 100644
--- a/ntoskrnl/mm/ARM3/session.c
+++ b/ntoskrnl/mm/ARM3/session.c
@@ -477,7 +477,7 @@ MiSessionInitializeWorkingSetList(VOID)
 
     /* Fill out the two pointers */
     MmSessionSpace->Vm.VmWorkingSetList = WorkingSetList;
-    MmSessionSpace->Wsle = (PMMWSLE)WorkingSetList->UsedPageTableEntries;
+    MmSessionSpace->Wsle = (PMMWSLE)((&WorkingSetList->VadBitMapHint) + 1);
 
     /* Get the PDE for the working set, and check if it's already allocated */
     PointerPde = MiAddressToPde(WorkingSetList);
diff --git a/ntoskrnl/mm/ARM3/virtual.c b/ntoskrnl/mm/ARM3/virtual.c
index 1f897c4e10d..f43351b7933 100644
--- a/ntoskrnl/mm/ARM3/virtual.c
+++ b/ntoskrnl/mm/ARM3/virtual.c
@@ -727,18 +727,15 @@ MiDeleteVirtualAddresses(IN ULONG_PTR Va,
         /* Check remaining PTE count (go back 1 page due to above loop) */
         if (MiQueryPageTableReferences((PVOID)(Va - PAGE_SIZE)) == 0)
         {
-            if (PointerPde->u.Long != 0)
-            {
-                /* Delete the PTE proper */
-                MiDeletePte(PointerPde,
-                            MiPteToAddress(PointerPde),
-                            CurrentProcess,
-                            NULL);
-            }
+            ASSERT(PointerPde->u.Long != 0);
+
+            /* Delete the PDE proper */
+            MiDeletePde(PointerPde, CurrentProcess);
         }
 
-        /* Release the lock and get out if we're done */
+        /* Release the lock */
         MiReleasePfnLock(OldIrql);
+
         if (Va > EndingAddress) return;
 
         /* Otherwise, we exited because we hit a new PDE boundary, so start over */
diff --git a/ntoskrnl/mm/i386/page.c b/ntoskrnl/mm/i386/page.c
index d15a9f74964..7e7db5bd431 100644
--- a/ntoskrnl/mm/i386/page.c
+++ b/ntoskrnl/mm/i386/page.c
@@ -238,11 +238,10 @@ MmDeleteVirtualMapping(PEPROCESS Process, PVOID Address,
     if (Address < MmSystemRangeStart)
     {
         /* Remove PDE reference */
-        MiDecrementPageTableReferences(Address);
-        if (MiQueryPageTableReferences(Address) == 0)
+        if (MiDecrementPageTableReferences(Address) == 0)
         {
             KIRQL OldIrql = MiAcquirePfnLock();
-            MiDeletePte(MiAddressToPte(PointerPte), PointerPte, Process, NULL);
+            MiDeletePde(MiAddressToPde(Address), Process);
             MiReleasePfnLock(OldIrql);
         }
 
@@ -293,8 +292,7 @@ MmDeletePageFileMapping(
     }
 
     /* This used to be a non-zero PTE, now we can let the PDE go. */
-    MiDecrementPageTableReferences(Address);
-    if (MiQueryPageTableReferences(Address) == 0)
+    if (MiDecrementPageTableReferences(Address) == 0)
     {
         /* We can let it go */
         KIRQL OldIrql = MiAcquirePfnLock();
diff --git a/sdk/include/ndk/mmtypes.h b/sdk/include/ndk/mmtypes.h
index 28d6cb00006..f33d5cc8b14 100644
--- a/sdk/include/ndk/mmtypes.h
+++ b/sdk/include/ndk/mmtypes.h
@@ -879,8 +879,19 @@ typedef struct _MMWSL
     PVOID HighestPermittedHashAddress;
     ULONG NumberOfImageWaiters;
     ULONG VadBitMapHint;
+#ifndef _M_AMD64
     USHORT UsedPageTableEntries[768];
     ULONG CommittedPageTables[24];
+#else
+    VOID* HighestUserAddress;
+    ULONG MaximumUserPageTablePages;
+    ULONG MaximumUserPageDirectoryPages;
+    ULONG* CommittedPageTables;
+    ULONG NumberOfCommittedPageDirectories;
+    ULONG* CommittedPageDirectories;
+    ULONG NumberOfCommittedPageDirectoryParents;
+    ULONGLONG CommittedPageDirectoryParents[1];
+#endif
 } MMWSL, *PMMWSL;
 
 //

Reply via email to