Convert the mshv driver's HMM fault path to use
hmm_range_fault_unlockable() instead of hmm_range_fault(). This enables
userfaultfd-backed guest memory regions by allowing the mmap lock to be
dropped during page fault handling.
Extract the per-VMA walk into a dedicated mshv_region_hmm_fault_walk()
helper. The outer mshv_region_hmm_fault_and_lock() handles the do/while
restart loop: if the lock is dropped during a fault (userfaultfd
resolution or similar) or an invalidation occurs (-EBUSY), the function
restarts the entire walk from the beginning with a fresh notifier_seq,
since the VMA layout may have changed.

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 drivers/hv/mshv_regions.c | 127 +++++++++++++++++++++++++++++++--------------
 1 file changed, 87 insertions(+), 40 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index d09940e88298e..05665446ca6d9 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -565,6 +565,75 @@ int mshv_region_get(struct mshv_region *region)
 	return kref_get_unless_zero(&region->mreg_refcount);
 }
 
+/**
+ * mshv_region_hmm_fault_walk - Walk VMAs and fault in pages for a range
+ * @region  : Pointer to the memory region structure
+ * @range   : HMM range structure (caller sets notifier and notifier_seq)
+ * @start   : Starting virtual address of the range to fault (inclusive)
+ * @end     : Ending virtual address of the range to fault (exclusive)
+ * @pfns    : Output array for page frame numbers with HMM flags
+ * @locked  : Pointer to lock state; set to 0 if mmap lock was dropped
+ * @do_fault: If true, fault in missing pages; if false, snapshot only
+ *
+ * Iterates through VMAs covering [start, end), collecting page frame
+ * numbers via hmm_range_fault_unlockable() for each VMA segment.
+ * When @do_fault is true, missing pages are faulted in and write faults
+ * are requested only when both the VMA and the hypervisor mapping permit
+ * writes, to avoid breaking copy-on-write semantics on read-only mappings.
+ *
+ * Return: 0 on success, negative error code on failure.
+ */
+static int mshv_region_hmm_fault_walk(struct mshv_region *region,
+				      struct hmm_range *range,
+				      unsigned long start,
+				      unsigned long end,
+				      unsigned long *pfns,
+				      int *locked,
+				      bool do_fault)
+{
+	unsigned long cur_start = start;
+	unsigned long *cur_pfns = pfns;
+
+	while (cur_start < end) {
+		struct vm_area_struct *vma;
+
+		vma = vma_lookup(range->notifier->mm, cur_start);
+		if (!vma)
+			return -EFAULT;
+
+		range->hmm_pfns = cur_pfns;
+		range->start = cur_start;
+		range->end = min(vma->vm_end, end);
+		range->default_flags = 0;
+		if (do_fault) {
+			range->default_flags = HMM_PFN_REQ_FAULT;
+			/*
+			 * Only request writable pages from HMM when
+			 * both the VMA and the hypervisor mapping allow
+			 * writes. Without this, hmm_range_fault() would
+			 * trigger COW on read-only mappings (e.g. shared
+			 * zero pages, file-backed pages), breaking
+			 * copy-on-write semantics and potentially
+			 * granting the guest write access to shared host
+			 * pages.
+			 */
+			if ((vma->vm_flags & VM_WRITE) &&
+			    (region->hv_map_flags & HV_MAP_GPA_WRITABLE))
+				range->default_flags |= HMM_PFN_REQ_WRITE;
+		}
+
+		int ret = hmm_range_fault_unlockable(range, locked);
+
+		if (ret || !*locked)
+			return ret;
+
+		cur_start = range->end;
+		cur_pfns += (range->end - range->start) >> PAGE_SHIFT;
+	}
+
+	return 0;
+}
+
 /**
  * mshv_region_hmm_fault_and_lock - Fault in pages across VMAs and lock
  * the memory region
@@ -575,11 +644,9 @@ int mshv_region_get(struct mshv_region *region)
  * @do_fault: If true, fault in missing pages; if false, snapshot only
  *            pages already present in page tables
  *
- * Iterates through VMAs covering [start, end), collecting page frame
- * numbers via hmm_range_fault() for each VMA segment. When @do_fault
- * is true, missing pages are faulted in and write faults are requested
- * only when both the VMA and the hypervisor mapping permit writes, to
- * avoid breaking copy-on-write semantics on read-only mappings.
+ * Faults in pages covering [start, end) and acquires region->mreg_mutex.
+ * If the mmap lock is dropped during the fault (e.g. by userfaultfd) or
+ * the mmu notifier sequence is invalidated, the entire walk is restarted.
  *
  * On success, returns with region->mreg_mutex held; the caller is
  * responsible for releasing it. Returns -EBUSY if the mmu notifier
@@ -597,47 +664,27 @@ static int mshv_region_hmm_fault_and_lock(struct mshv_region *region,
 		.notifier = &region->mreg_mni,
 	};
 	struct mm_struct *mm = region->mreg_mni.mm;
+	int locked;
 	int ret;
 
-	range.notifier_seq = mmu_interval_read_begin(range.notifier);
-	mmap_read_lock(mm);
-	while (start < end) {
-		struct vm_area_struct *vma;
+	do {
+		range.notifier_seq = mmu_interval_read_begin(range.notifier);
+		locked = 1;
+		mmap_read_lock(mm);
 
-		vma = vma_lookup(mm, start);
-		if (!vma) {
-			ret = -EFAULT;
-			break;
-		}
+		ret = mshv_region_hmm_fault_walk(region, &range, start, end,
+						 pfns, &locked, do_fault);
 
-		range.hmm_pfns = pfns;
-		range.start = start;
-		range.end = min(vma->vm_end, end);
-		range.default_flags = 0;
-		if (do_fault) {
-			range.default_flags = HMM_PFN_REQ_FAULT;
-			/*
-			 * Only request writable pages from HMM when both
-			 * the VMA and the hypervisor mapping allow writes.
-			 * Without this, hmm_range_fault() would trigger
-			 * COW on read-only mappings (e.g. shared zero
-			 * pages, file-backed pages), breaking
-			 * copy-on-write semantics and potentially granting
-			 * the guest write access to shared host pages.
-			 */
-			if ((vma->vm_flags & VM_WRITE) &&
-			    (region->hv_map_flags & HV_MAP_GPA_WRITABLE))
-				range.default_flags |= HMM_PFN_REQ_WRITE;
-		}
+		if (locked)
+			mmap_read_unlock(mm);
 
-		ret = hmm_range_fault(&range);
-		if (ret)
-			break;
+		/*
+		 * If the lock was dropped (by userfaultfd or similar), restart
+		 * the entire walk with a fresh notifier_seq since the VMA layout
+		 * may have changed. Also restart on -EBUSY (invalidation).
+		 */
+	} while (!locked || ret == -EBUSY);
 
-		start = range.end;
-		pfns += (range.end - range.start) >> PAGE_SHIFT;
-	}
-	mmap_read_unlock(mm);
 	if (ret)
 		return ret;
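
For reference, a minimal standalone sketch of the drop-and-retry contract
the commit message describes, separate from the patch itself. It assumes
hmm_range_fault_unlockable() with the (range, &locked) signature used in
the diff above (the helper is introduced elsewhere in this series), and
mshv_demo_fault_one() is a hypothetical caller written only for
illustration:

	#include <linux/hmm.h>
	#include <linux/mm.h>
	#include <linux/mmu_notifier.h>

	/*
	 * Fault a single page through an interval notifier, tolerating
	 * both a dropped mmap lock (e.g. for userfaultfd resolution) and
	 * an mmu notifier invalidation (-EBUSY).
	 */
	static int mshv_demo_fault_one(struct mmu_interval_notifier *mni,
				       unsigned long addr, unsigned long *pfn)
	{
		struct hmm_range range = {
			.notifier	= mni,
			.start		= addr,
			.end		= addr + PAGE_SIZE,
			.hmm_pfns	= pfn,
			.default_flags	= HMM_PFN_REQ_FAULT,
		};
		int locked, ret;

		do {
			/* Fresh sequence so racing invalidations are seen. */
			range.notifier_seq = mmu_interval_read_begin(mni);
			locked = 1;
			mmap_read_lock(mni->mm);

			ret = hmm_range_fault_unlockable(&range, &locked);

			/* The fault path may already have dropped the lock. */
			if (locked)
				mmap_read_unlock(mni->mm);

			/*
			 * Restart on a dropped lock (the VMA layout may have
			 * changed) or on invalidation (-EBUSY).
			 */
		} while (!locked || ret == -EBUSY);

		return ret;
	}

The reason notifier_seq is re-read at the top of each iteration is that an
invalidation which raced with the unlocked portion of the fault is then
observed on the retry rather than silently missed.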

