Convert the mshv driver's HMM fault path to use
hmm_range_fault_unlockable() instead of hmm_range_fault(). This enables
userfaultfd-backed guest memory regions by allowing the mmap lock to be
dropped during page fault handling.

Extract the per-VMA walk into a dedicated mshv_region_hmm_fault_walk()
helper. The outer mshv_region_hmm_fault_and_lock() handles the do/while
restart loop: if the lock is dropped during a fault (userfaultfd resolution
or similar) or an invalidation occurs (-EBUSY), the function restarts the
entire walk from the beginning with a fresh notifier_seq, since the VMA
layout may have changed.

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 drivers/hv/mshv_regions.c |  127 +++++++++++++++++++++++++++++++--------------
 1 file changed, 87 insertions(+), 40 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index d09940e88298e..05665446ca6d9 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -565,6 +565,75 @@ int mshv_region_get(struct mshv_region *region)
        return kref_get_unless_zero(&region->mreg_refcount);
 }
 
+/**
+ * mshv_region_hmm_fault_walk - Walk VMAs and fault in pages for a range
+ * @region : Pointer to the memory region structure
+ * @range  : HMM range structure (caller sets notifier and notifier_seq)
+ * @start  : Starting virtual address of the range to fault (inclusive)
+ * @end    : Ending virtual address of the range to fault (exclusive)
+ * @pfns   : Output array for page frame numbers with HMM flags
+ * @locked : Pointer to lock state; set to 0 if mmap lock was dropped
+ * @do_fault: If true, fault in missing pages; if false, snapshot only
+ *
+ * Iterates through VMAs covering [start, end), collecting page frame
+ * numbers via hmm_range_fault_unlockable() for each VMA segment.
+ * When @do_fault is true, missing pages are faulted in and write faults
+ * are requested only when both the VMA and the hypervisor mapping permit
+ * writes, to avoid breaking copy-on-write semantics on read-only mappings.
+ *
+ * Return: 0 on success, negative error code on failure.
+ */
+static int mshv_region_hmm_fault_walk(struct mshv_region *region,
+                                     struct hmm_range *range,
+                                     unsigned long start,
+                                     unsigned long end,
+                                     unsigned long *pfns,
+                                     int *locked,
+                                     bool do_fault)
+{
+       unsigned long cur_start = start;
+       unsigned long *cur_pfns = pfns;
+
+       while (cur_start < end) {
+               struct vm_area_struct *vma;
+
+               vma = vma_lookup(range->notifier->mm, cur_start);
+               if (!vma)
+                       return -EFAULT;
+
+               range->hmm_pfns = cur_pfns;
+               range->start = cur_start;
+               range->end = min(vma->vm_end, end);
+               range->default_flags = 0;
+               if (do_fault) {
+                       range->default_flags = HMM_PFN_REQ_FAULT;
+                       /*
+                        * Only request writable pages from HMM when
+                        * both the VMA and the hypervisor mapping allow
+                        * writes. Without this, hmm_range_fault() would
+                        * trigger COW on read-only mappings (e.g. shared
+                        * zero pages, file-backed pages), breaking
+                        * copy-on-write semantics and potentially
+                        * granting the guest write access to shared host
+                        * pages.
+                        */
+                       if ((vma->vm_flags & VM_WRITE) &&
+                           (region->hv_map_flags & HV_MAP_GPA_WRITABLE))
+                               range->default_flags |= HMM_PFN_REQ_WRITE;
+               }
+
+               int ret = hmm_range_fault_unlockable(range, locked);
+
+               if (ret || !*locked)
+                       return ret;
+
+               cur_start = range->end;
+               cur_pfns += (range->end - range->start) >> PAGE_SHIFT;
+       }
+
+       return 0;
+}
+
 /**
  * mshv_region_hmm_fault_and_lock - Fault in pages across VMAs and lock
  *                                  the memory region
@@ -575,11 +644,9 @@ int mshv_region_get(struct mshv_region *region)
  * @do_fault: If true, fault in missing pages; if false, snapshot only
  *            pages already present in page tables
  *
- * Iterates through VMAs covering [start, end), collecting page frame
- * numbers via hmm_range_fault() for each VMA segment.  When @do_fault
- * is true, missing pages are faulted in and write faults are requested
- * only when both the VMA and the hypervisor mapping permit writes, to
- * avoid breaking copy-on-write semantics on read-only mappings.
+ * Faults in pages covering [start, end) and acquires region->mreg_mutex.
+ * If the mmap lock is dropped during the fault (e.g. by userfaultfd) or
+ * the mmu notifier sequence is invalidated, the entire walk is restarted.
  *
  * On success, returns with region->mreg_mutex held; the caller is
  * responsible for releasing it.  Returns -EBUSY if the mmu notifier
@@ -597,47 +664,27 @@ static int mshv_region_hmm_fault_and_lock(struct mshv_region *region,
                .notifier = &region->mreg_mni,
        };
        struct mm_struct *mm = region->mreg_mni.mm;
+       int locked;
        int ret;
 
-       range.notifier_seq = mmu_interval_read_begin(range.notifier);
-       mmap_read_lock(mm);
-       while (start < end) {
-               struct vm_area_struct *vma;
+       do {
+               range.notifier_seq = mmu_interval_read_begin(range.notifier);
+               locked = 1;
+               mmap_read_lock(mm);
 
-               vma = vma_lookup(mm, start);
-               if (!vma) {
-                       ret = -EFAULT;
-                       break;
-               }
+               ret = mshv_region_hmm_fault_walk(region, &range, start, end,
+                                                pfns, &locked, do_fault);
 
-               range.hmm_pfns = pfns;
-               range.start = start;
-               range.end = min(vma->vm_end, end);
-               range.default_flags = 0;
-               if (do_fault) {
-                       range.default_flags = HMM_PFN_REQ_FAULT;
-                       /*
-                        * Only request writable pages from HMM when both
-                        * the VMA and the hypervisor mapping allow writes.
-                        * Without this, hmm_range_fault() would trigger
-                        * COW on read-only mappings (e.g. shared zero
-                        * pages, file-backed pages), breaking
-                        * copy-on-write semantics and potentially granting
-                        * the guest write access to shared host pages.
-                        */
-                       if ((vma->vm_flags & VM_WRITE) &&
-                           (region->hv_map_flags & HV_MAP_GPA_WRITABLE))
-                               range.default_flags |= HMM_PFN_REQ_WRITE;
-               }
+               if (locked)
+                       mmap_read_unlock(mm);
 
-               ret = hmm_range_fault(&range);
-               if (ret)
-                       break;
+               /*
+                * If the lock was dropped (by userfaultfd or similar), restart
+                * the entire walk with a fresh notifier_seq since the VMA layout
+                * may have changed. Also restart on -EBUSY (invalidation).
+                */
+       } while (!locked || ret == -EBUSY);
 
-               start = range.end;
-               pfns += (range.end - range.start) >> PAGE_SHIFT;
-       }
-       mmap_read_unlock(mm);
        if (ret)
                return ret;
 


