Firstly, introduce two new flags, MM_CP_UFFD_WP[_RESOLVE], for
change_protection() when used with uffd-wp, and make sure the two new
flags are mutually exclusive.  Then,

  - For MM_CP_UFFD_WP: apply the _PAGE_UFFD_WP bit and remove _PAGE_RW
    when a range of memory is write protected by uffd

  - For MM_CP_UFFD_WP_RESOLVE: remove the _PAGE_UFFD_WP bit and recover
    _PAGE_RW when write protection is resolved from userspace

And use this new interface in mwriteprotect_range() to replace the old
MM_CP_DIRTY_ACCT.
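
For illustration only (not part of this patch): a caller such as
mwriteprotect_range() is expected to pass exactly one of the two flags,
depending on whether userspace is enabling or resolving write
protection.  A minimal sketch, assuming the cp_flags-based
change_protection() signature, with a hypothetical helper name and
parameters, and with locking and validation omitted:

  static unsigned long uffd_wp_range(struct vm_area_struct *vma,
                                     unsigned long start, unsigned long len,
                                     bool enable_wp)
  {
          unsigned long cp_flags;
          pgprot_t newprot;

          if (enable_wp) {
                  /* change_protection() will set _PAGE_UFFD_WP, drop _PAGE_RW */
                  cp_flags = MM_CP_UFFD_WP;
                  newprot = vm_get_page_prot(vma->vm_flags & ~(VM_WRITE));
          } else {
                  /* change_protection() will clear _PAGE_UFFD_WP, recover _PAGE_RW */
                  cp_flags = MM_CP_UFFD_WP_RESOLVE;
                  newprot = vm_get_page_prot(vma->vm_flags);
          }

          return change_protection(vma, start, start + len, newprot, cp_flags);
  }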

Apply this change to both PTEs and huge PMDs.  Then we can start to
identify which PTE/PMD is write protected for general reasons (e.g.,
COW or soft-dirty tracking), and which is write protected for
userfaultfd-wp.

Since _PAGE_UFFD_WP should be preserved across pte_modify(), add it to
_PAGE_CHG_MASK as well.  With the new bit in place, we can also be
stricter when detecting uffd-wp page faults in do_wp_page() and
wp_huge_pmd().
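
The stricter checks require both the VMA-level VM_UFFD_WP flag and the
per-entry uffd-wp bit before routing a write fault to
handle_userfault().  As a rough sketch (the real definitions of these
helpers live in the userfaultfd headers and are not part of this diff),
the userfaultfd_pte_wp()/userfaultfd_huge_pmd_wp() checks used below
could look like:

  /* Sketch: treat the fault as uffd-wp only if both the VMA flag and the entry bit are set */
  static inline bool userfaultfd_pte_wp(struct vm_area_struct *vma,
                                        pte_t pte)
  {
          return userfaultfd_wp(vma) && pte_uffd_wp(pte);
  }

  static inline bool userfaultfd_huge_pmd_wp(struct vm_area_struct *vma,
                                             pmd_t pmd)
  {
          return userfaultfd_wp(vma) && pmd_uffd_wp(pmd);
  }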

Reviewed-by: Jerome Glisse <jgli...@redhat.com>
Reviewed-by: Mike Rapoport <r...@linux.vnet.ibm.com>
Signed-off-by: Peter Xu <pet...@redhat.com>
---
 include/linux/mm.h |  5 +++++
 mm/huge_memory.c   | 14 +++++++++++++-
 mm/memory.c        |  4 ++--
 mm/mprotect.c      | 12 ++++++++++++
 mm/userfaultfd.c   |  8 ++++++--
 5 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 086e69d4439d..a5ac81188523 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1652,6 +1652,11 @@ extern unsigned long move_page_tables(struct vm_area_struct *vma,
 #define  MM_CP_DIRTY_ACCT                  (1UL << 0)
 /* Whether this protection change is for NUMA hints */
 #define  MM_CP_PROT_NUMA                   (1UL << 1)
+/* Whether this change is for write protecting */
+#define  MM_CP_UFFD_WP                     (1UL << 2) /* do wp */
+#define  MM_CP_UFFD_WP_RESOLVE             (1UL << 3) /* Resolve wp */
+#define  MM_CP_UFFD_WP_ALL                 (MM_CP_UFFD_WP | \
+                                           MM_CP_UFFD_WP_RESOLVE)
 
extern unsigned long change_protection(struct vm_area_struct *vma, unsigned long start,
                              unsigned long end, pgprot_t newprot,
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 64d26b1989d2..3885747d4901 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1907,6 +1907,8 @@ int change_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
        bool preserve_write;
        int ret;
        bool prot_numa = cp_flags & MM_CP_PROT_NUMA;
+       bool uffd_wp = cp_flags & MM_CP_UFFD_WP;
+       bool uffd_wp_resolve = cp_flags & MM_CP_UFFD_WP_RESOLVE;
 
        ptl = __pmd_trans_huge_lock(pmd, vma);
        if (!ptl)
@@ -1973,6 +1975,13 @@ int change_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
        entry = pmd_modify(entry, newprot);
        if (preserve_write)
                entry = pmd_mk_savedwrite(entry);
+       if (uffd_wp) {
+               entry = pmd_wrprotect(entry);
+               entry = pmd_mkuffd_wp(entry);
+       } else if (uffd_wp_resolve) {
+               entry = pmd_mkwrite(entry);
+               entry = pmd_clear_uffd_wp(entry);
+       }
        ret = HPAGE_PMD_NR;
        set_pmd_at(mm, addr, pmd, entry);
        BUG_ON(vma_is_anonymous(vma) && !preserve_write && pmd_write(entry));
@@ -2120,7 +2129,7 @@ static void __split_huge_pmd_locked(struct vm_area_struct *vma, pmd_t *pmd,
        struct page *page;
        pgtable_t pgtable;
        pmd_t old_pmd, _pmd;
-       bool young, write, soft_dirty, pmd_migration = false;
+       bool young, write, soft_dirty, pmd_migration = false, uffd_wp = false;
        unsigned long addr;
        int i;
 
@@ -2202,6 +2211,7 @@ static void __split_huge_pmd_locked(struct vm_area_struct *vma, pmd_t *pmd,
                write = pmd_write(old_pmd);
                young = pmd_young(old_pmd);
                soft_dirty = pmd_soft_dirty(old_pmd);
+               uffd_wp = pmd_uffd_wp(old_pmd);
        }
        VM_BUG_ON_PAGE(!page_count(page), page);
        page_ref_add(page, HPAGE_PMD_NR - 1);
@@ -2235,6 +2245,8 @@ static void __split_huge_pmd_locked(struct vm_area_struct *vma, pmd_t *pmd,
                                entry = pte_mkold(entry);
                        if (soft_dirty)
                                entry = pte_mksoft_dirty(entry);
+                       if (uffd_wp)
+                               entry = pte_mkuffd_wp(entry);
                }
                pte = pte_offset_map(&_pmd, addr);
                BUG_ON(!pte_none(*pte));
diff --git a/mm/memory.c b/mm/memory.c
index 8ccd4927b58d..64bd8075f054 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -2492,7 +2492,7 @@ static vm_fault_t do_wp_page(struct vm_fault *vmf)
 {
        struct vm_area_struct *vma = vmf->vma;
 
-       if (userfaultfd_wp(vma)) {
+       if (userfaultfd_pte_wp(vma, *vmf->pte)) {
                pte_unmap_unlock(vmf->pte, vmf->ptl);
                return handle_userfault(vmf, VM_UFFD_WP);
        }
@@ -3713,7 +3713,7 @@ static inline vm_fault_t create_huge_pmd(struct vm_fault *vmf)
 static inline vm_fault_t wp_huge_pmd(struct vm_fault *vmf, pmd_t orig_pmd)
 {
        if (vma_is_anonymous(vmf->vma)) {
-               if (userfaultfd_wp(vmf->vma))
+               if (userfaultfd_huge_pmd_wp(vmf->vma, orig_pmd))
                        return handle_userfault(vmf, VM_UFFD_WP);
                return do_huge_pmd_wp_page(vmf, orig_pmd);
        }
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 98091408bd11..732d9b6d1d21 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -46,6 +46,8 @@ static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
        int target_node = NUMA_NO_NODE;
        bool dirty_accountable = cp_flags & MM_CP_DIRTY_ACCT;
        bool prot_numa = cp_flags & MM_CP_PROT_NUMA;
+       bool uffd_wp = cp_flags & MM_CP_UFFD_WP;
+       bool uffd_wp_resolve = cp_flags & MM_CP_UFFD_WP_RESOLVE;
 
        /*
         * Can be called with only the mmap_sem for reading by
@@ -117,6 +119,14 @@ static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
                        if (preserve_write)
                                ptent = pte_mk_savedwrite(ptent);
 
+                       if (uffd_wp) {
+                               ptent = pte_wrprotect(ptent);
+                               ptent = pte_mkuffd_wp(ptent);
+                       } else if (uffd_wp_resolve) {
+                               ptent = pte_mkwrite(ptent);
+                               ptent = pte_clear_uffd_wp(ptent);
+                       }
+
                        /* Avoid taking write faults for known dirty pages */
                        if (dirty_accountable && pte_dirty(ptent) &&
                                        (pte_soft_dirty(ptent) ||
@@ -301,6 +311,8 @@ unsigned long change_protection(struct vm_area_struct *vma, unsigned long start,
 {
        unsigned long pages;
 
+       BUG_ON((cp_flags & MM_CP_UFFD_WP_ALL) == MM_CP_UFFD_WP_ALL);
+
        if (is_vm_hugetlb_page(vma))
                pages = hugetlb_change_protection(vma, start, end, newprot);
        else
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index eaecc21806da..240de2a8492d 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -73,8 +73,12 @@ static int mcopy_atomic_pte(struct mm_struct *dst_mm,
                goto out_release;
 
        _dst_pte = pte_mkdirty(mk_pte(page, dst_vma->vm_page_prot));
-       if ((dst_vma->vm_flags & VM_WRITE) && !wp_copy)
-               _dst_pte = pte_mkwrite(_dst_pte);
+       if (dst_vma->vm_flags & VM_WRITE) {
+               if (wp_copy)
+                       _dst_pte = pte_mkuffd_wp(_dst_pte);
+               else
+                       _dst_pte = pte_mkwrite(_dst_pte);
+       }
 
        dst_pte = pte_offset_map_lock(dst_mm, dst_pmd, dst_addr, &ptl);
        if (dst_vma->vm_file) {
-- 
2.17.1
