Firstly, pass the wp_copy variable into hugetlb_mcopy_atomic_pte() throughout
the stack.  Then, apply the UFFD_WP bit if UFFDIO_COPY_MODE_WP is specified
with UFFDIO_COPY.  Introduce huge_pte_mkuffd_wp() for it.

Note that, similar to how we've handled shmem, we keep setting the dirty bit
even if UFFDIO_COPY_MODE_WP is provided, so that the core mm knows this page
contains valid data and never drops it.

Signed-off-by: Peter Xu <[email protected]>
---
 include/asm-generic/hugetlb.h |  5 +++++
 include/linux/hugetlb.h       |  6 ++++--
 mm/hugetlb.c                  |  9 +++++++--
 mm/userfaultfd.c              | 12 ++++++++----
 4 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/include/asm-generic/hugetlb.h b/include/asm-generic/hugetlb.h
index 8e1e6244a89d..548212eccbd6 100644
--- a/include/asm-generic/hugetlb.h
+++ b/include/asm-generic/hugetlb.h
@@ -27,6 +27,11 @@ static inline pte_t huge_pte_mkdirty(pte_t pte)
        return pte_mkdirty(pte);
 }
 
+static inline pte_t huge_pte_mkuffd_wp(pte_t pte)
+{
+       return pte_mkuffd_wp(pte);
+}
+
 static inline pte_t huge_pte_modify(pte_t pte, pgprot_t newprot)
 {
        return pte_modify(pte, newprot);
diff --git a/include/linux/hugetlb.h b/include/linux/hugetlb.h
index ebca2ef02212..bd061f7eedcb 100644
--- a/include/linux/hugetlb.h
+++ b/include/linux/hugetlb.h
@@ -138,7 +138,8 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm, pte_t *dst_pte,
                                struct vm_area_struct *dst_vma,
                                unsigned long dst_addr,
                                unsigned long src_addr,
-                               struct page **pagep);
+                               struct page **pagep,
+                               bool wp_copy);
 int hugetlb_reserve_pages(struct inode *inode, long from, long to,
                                                struct vm_area_struct *vma,
                                                vm_flags_t vm_flags);
@@ -313,7 +314,8 @@ static inline int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
                                                struct vm_area_struct *dst_vma,
                                                unsigned long dst_addr,
                                                unsigned long src_addr,
-                                               struct page **pagep)
+                                               struct page **pagep,
+                                               bool wp_copy)
 {
        BUG();
        return 0;
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index dcbbba53bd10..563b8f70537f 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -4624,7 +4624,8 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
                            struct vm_area_struct *dst_vma,
                            unsigned long dst_addr,
                            unsigned long src_addr,
-                           struct page **pagep)
+                           struct page **pagep,
+                           bool wp_copy)
 {
        struct address_space *mapping;
        pgoff_t idx;
@@ -4717,8 +4718,12 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
        }
 
        _dst_pte = make_huge_pte(dst_vma, page, dst_vma->vm_flags & VM_WRITE);
-       if (dst_vma->vm_flags & VM_WRITE)
+       if (dst_vma->vm_flags & VM_WRITE) {
                _dst_pte = huge_pte_mkdirty(_dst_pte);
+               if (wp_copy)
+                       _dst_pte = huge_pte_mkuffd_wp(
+                           huge_pte_wrprotect(_dst_pte));
+       }
        _dst_pte = pte_mkyoung(_dst_pte);
 
        set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index 6d4b3b7c7f9f..b00e5e6b8b8b 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -207,7 +207,8 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
                                              unsigned long dst_start,
                                              unsigned long src_start,
                                              unsigned long len,
-                                             bool zeropage)
+                                             bool zeropage,
+                                             bool wp_copy)
 {
        int vm_alloc_shared = dst_vma->vm_flags & VM_SHARED;
        int vm_shared = dst_vma->vm_flags & VM_SHARED;
@@ -306,7 +307,8 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
                }
 
                err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma,
-                                               dst_addr, src_addr, &page);
+                                              dst_addr, src_addr, &page,
+                                              wp_copy);
 
                mutex_unlock(&hugetlb_fault_mutex_table[hash]);
                i_mmap_unlock_read(mapping);
@@ -408,7 +410,8 @@ extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
                                      unsigned long dst_start,
                                      unsigned long src_start,
                                      unsigned long len,
-                                     bool zeropage);
+                                     bool zeropage,
+                                     bool wp_copy);
 #endif /* CONFIG_HUGETLB_PAGE */
 
 static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
@@ -527,7 +530,8 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
         */
        if (is_vm_hugetlb_page(dst_vma))
                return  __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start,
-                                               src_start, len, zeropage);
+                                              src_start, len, zeropage,
+                                              wp_copy);
 
        if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))
                goto out_unlock;
-- 
2.26.2

Reply via email to