Conversion is straightforward: mmap_sem is used within the same function context most of the time. No change in semantics.
Signed-off-by: Davidlohr Bueso <dbu...@suse.de> --- fs/aio.c | 5 +++-- fs/coredump.c | 5 +++-- fs/exec.c | 19 +++++++++------- fs/io_uring.c | 5 +++-- fs/proc/base.c | 23 ++++++++++++-------- fs/proc/internal.h | 2 ++ fs/proc/task_mmu.c | 32 +++++++++++++++------------ fs/proc/task_nommu.c | 22 +++++++++++-------- fs/userfaultfd.c | 50 ++++++++++++++++++++++++++----------------- include/linux/userfaultfd_k.h | 5 +++-- 10 files changed, 100 insertions(+), 68 deletions(-) diff --git a/fs/aio.c b/fs/aio.c index 3490d1fa0e16..215d19dbbefa 100644 --- a/fs/aio.c +++ b/fs/aio.c @@ -461,6 +461,7 @@ static const struct address_space_operations aio_ctx_aops = { static int aio_setup_ring(struct kioctx *ctx, unsigned int nr_events) { + DEFINE_RANGE_LOCK_FULL(mmrange); struct aio_ring *ring; struct mm_struct *mm = current->mm; unsigned long size, unused; @@ -521,7 +522,7 @@ static int aio_setup_ring(struct kioctx *ctx, unsigned int nr_events) ctx->mmap_size = nr_pages * PAGE_SIZE; pr_debug("attempting mmap of %lu bytes\n", ctx->mmap_size); - if (down_write_killable(&mm->mmap_sem)) { + if (mm_write_lock_killable(mm, &mmrange)) { ctx->mmap_size = 0; aio_free_ring(ctx); return -EINTR; @@ -530,7 +531,7 @@ static int aio_setup_ring(struct kioctx *ctx, unsigned int nr_events) ctx->mmap_base = do_mmap_pgoff(ctx->aio_ring_file, 0, ctx->mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, 0, &unused, NULL); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); if (IS_ERR((void *)ctx->mmap_base)) { ctx->mmap_size = 0; aio_free_ring(ctx); diff --git a/fs/coredump.c b/fs/coredump.c index e42e17e55bfd..433713b63187 100644 --- a/fs/coredump.c +++ b/fs/coredump.c @@ -409,6 +409,7 @@ static int zap_threads(struct task_struct *tsk, struct mm_struct *mm, static int coredump_wait(int exit_code, struct core_state *core_state) { + DEFINE_RANGE_LOCK_FULL(mmrange); struct task_struct *tsk = current; struct mm_struct *mm = tsk->mm; int core_waiters = -EBUSY; @@ -417,12 +418,12 @@ static int 
coredump_wait(int exit_code, struct core_state *core_state) core_state->dumper.task = tsk; core_state->dumper.next = NULL; - if (down_write_killable(&mm->mmap_sem)) + if (mm_write_lock_killable(mm, &mmrange)) return -EINTR; if (!mm->core_state) core_waiters = zap_threads(tsk, mm, core_state, exit_code); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); if (core_waiters > 0) { struct core_thread *ptr; diff --git a/fs/exec.c b/fs/exec.c index e96fd5328739..fbcb36bc4fd1 100644 --- a/fs/exec.c +++ b/fs/exec.c @@ -241,6 +241,7 @@ static void flush_arg_page(struct linux_binprm *bprm, unsigned long pos, static int __bprm_mm_init(struct linux_binprm *bprm) { + DEFINE_RANGE_LOCK_FULL(mmrange); int err; struct vm_area_struct *vma = NULL; struct mm_struct *mm = bprm->mm; @@ -250,7 +251,7 @@ static int __bprm_mm_init(struct linux_binprm *bprm) return -ENOMEM; vma_set_anonymous(vma); - if (down_write_killable(&mm->mmap_sem)) { + if (mm_write_lock_killable(mm, &mmrange)) { err = -EINTR; goto err_free; } @@ -273,11 +274,11 @@ static int __bprm_mm_init(struct linux_binprm *bprm) mm->stack_vm = mm->total_vm = 1; arch_bprm_mm_init(mm, vma); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); bprm->p = vma->vm_end - sizeof(void *); return 0; err: - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); err_free: bprm->vma = NULL; vm_area_free(vma); @@ -691,6 +692,7 @@ int setup_arg_pages(struct linux_binprm *bprm, unsigned long stack_top, int executable_stack) { + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long ret; unsigned long stack_shift; struct mm_struct *mm = current->mm; @@ -738,7 +740,7 @@ int setup_arg_pages(struct linux_binprm *bprm, bprm->loader -= stack_shift; bprm->exec -= stack_shift; - if (down_write_killable(&mm->mmap_sem)) + if (mm_write_lock_killable(mm, &mmrange)) return -EINTR; vm_flags = VM_STACK_FLAGS; @@ -795,7 +797,7 @@ int setup_arg_pages(struct linux_binprm *bprm, ret = -EFAULT; out_unlock: - up_write(&mm->mmap_sem); + mm_write_unlock(mm, 
&mmrange); return ret; } EXPORT_SYMBOL(setup_arg_pages); @@ -1010,6 +1012,7 @@ static int exec_mmap(struct mm_struct *mm) { struct task_struct *tsk; struct mm_struct *old_mm, *active_mm; + DEFINE_RANGE_LOCK_FULL(mmrange); /* Notify parent that we're no longer interested in the old VM */ tsk = current; @@ -1024,9 +1027,9 @@ static int exec_mmap(struct mm_struct *mm) * through with the exec. We must hold mmap_sem around * checking core_state and changing tsk->mm. */ - down_read(&old_mm->mmap_sem); + mm_read_lock(old_mm, &mmrange); if (unlikely(old_mm->core_state)) { - up_read(&old_mm->mmap_sem); + mm_read_unlock(old_mm, &mmrange); return -EINTR; } } @@ -1039,7 +1042,7 @@ static int exec_mmap(struct mm_struct *mm) vmacache_flush(tsk); task_unlock(tsk); if (old_mm) { - up_read(&old_mm->mmap_sem); + mm_read_unlock(old_mm, &mmrange); BUG_ON(active_mm != old_mm); setmax_mm_hiwater_rss(&tsk->signal->maxrss, old_mm); mm_update_next_owner(old_mm); diff --git a/fs/io_uring.c b/fs/io_uring.c index e11d77181398..16c06811193b 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -2597,6 +2597,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg, struct page **pages = NULL; int i, j, got_pages = 0; int ret = -EINVAL; + DEFINE_RANGE_LOCK_FULL(mmrange); if (ctx->user_bufs) return -EBUSY; @@ -2671,7 +2672,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg, } ret = 0; - down_read(&current->mm->mmap_sem); + mm_read_lock(current->mm, &mmrange); pret = get_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM, pages, vmas); @@ -2689,7 +2690,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg, } else { ret = pret < 0 ? 
pret : -EFAULT; } - up_read(&current->mm->mmap_sem); + mm_read_unlock(current->mm, &mmrange); if (ret) { /* * if we did partial map, or found file backed vmas, diff --git a/fs/proc/base.c b/fs/proc/base.c index 9c8ca6cd3ce4..63d0fea104af 100644 --- a/fs/proc/base.c +++ b/fs/proc/base.c @@ -1962,9 +1962,11 @@ static int map_files_d_revalidate(struct dentry *dentry, unsigned int flags) goto out; if (!dname_to_vma_addr(dentry, &vm_start, &vm_end)) { - down_read(&mm->mmap_sem); + DEFINE_RANGE_LOCK_FULL(mmrange); + + mm_read_lock(mm, &mmrange); exact_vma_exists = !!find_exact_vma(mm, vm_start, vm_end); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); } mmput(mm); @@ -1995,6 +1997,7 @@ static int map_files_get_link(struct dentry *dentry, struct path *path) struct task_struct *task; struct mm_struct *mm; int rc; + DEFINE_RANGE_LOCK_FULL(mmrange); rc = -ENOENT; task = get_proc_task(d_inode(dentry)); @@ -2011,14 +2014,14 @@ static int map_files_get_link(struct dentry *dentry, struct path *path) goto out_mmput; rc = -ENOENT; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_exact_vma(mm, vm_start, vm_end); if (vma && vma->vm_file) { *path = vma->vm_file->f_path; path_get(path); rc = 0; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); out_mmput: mmput(mm); @@ -2089,6 +2092,7 @@ static struct dentry *proc_map_files_lookup(struct inode *dir, struct task_struct *task; struct dentry *result; struct mm_struct *mm; + DEFINE_RANGE_LOCK_FULL(mmrange); result = ERR_PTR(-ENOENT); task = get_proc_task(dir); @@ -2107,7 +2111,7 @@ static struct dentry *proc_map_files_lookup(struct inode *dir, if (!mm) goto out_put_task; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); vma = find_exact_vma(mm, vm_start, vm_end); if (!vma) goto out_no_vma; @@ -2117,7 +2121,7 @@ static struct dentry *proc_map_files_lookup(struct inode *dir, (void *)(unsigned long)vma->vm_file->f_mode); out_no_vma: - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); mmput(mm); 
out_put_task: put_task_struct(task); @@ -2141,6 +2145,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx) GENRADIX(struct map_files_info) fa; struct map_files_info *p; int ret; + DEFINE_RANGE_LOCK_FULL(mmrange); genradix_init(&fa); @@ -2160,7 +2165,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx) mm = get_task_mm(task); if (!mm) goto out_put_task; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); nr_files = 0; @@ -2183,7 +2188,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx) p = genradix_ptr_alloc(&fa, nr_files++, GFP_KERNEL); if (!p) { ret = -ENOMEM; - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); mmput(mm); goto out_put_task; } @@ -2192,7 +2197,7 @@ proc_map_files_readdir(struct file *file, struct dir_context *ctx) p->end = vma->vm_end; p->mode = vma->vm_file->f_mode; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); mmput(mm); for (i = 0; i < nr_files; i++) { diff --git a/fs/proc/internal.h b/fs/proc/internal.h index d1671e97f7fe..df6f0ec84a8f 100644 --- a/fs/proc/internal.h +++ b/fs/proc/internal.h @@ -15,6 +15,7 @@ #include <linux/spinlock.h> #include <linux/atomic.h> #include <linux/binfmts.h> +#include <linux/range_lock.h> #include <linux/sched/coredump.h> #include <linux/sched/task.h> @@ -287,6 +288,7 @@ struct proc_maps_private { #ifdef CONFIG_NUMA struct mempolicy *task_mempolicy; #endif + struct range_lock mmrange; } __randomize_layout; struct mm_struct *proc_mem_open(struct inode *inode, unsigned int mode); diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c index a1c2ad9f960a..7ab5c6f5b8aa 100644 --- a/fs/proc/task_mmu.c +++ b/fs/proc/task_mmu.c @@ -128,7 +128,7 @@ static void vma_stop(struct proc_maps_private *priv) struct mm_struct *mm = priv->mm; release_task_mempolicy(priv); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &priv->mmrange); mmput(mm); } @@ -166,7 +166,9 @@ static void *m_start(struct seq_file *m, loff_t *ppos) if (!mm || !mmget_not_zero(mm)) 
return NULL; - down_read(&mm->mmap_sem); + range_lock_init_full(&priv->mmrange); + + mm_read_lock(mm, &priv->mmrange); hold_task_mempolicy(priv); priv->tail_vma = get_gate_vma(mm); @@ -828,7 +830,7 @@ static int show_smaps_rollup(struct seq_file *m, void *v) memset(&mss, 0, sizeof(mss)); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &priv->mmrange); hold_task_mempolicy(priv); for (vma = priv->mm->mmap; vma; vma = vma->vm_next) { @@ -844,7 +846,7 @@ static int show_smaps_rollup(struct seq_file *m, void *v) __show_smap(m, &mss); release_task_mempolicy(priv); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &priv->mmrange); mmput(mm); out_put_task: @@ -1080,6 +1082,7 @@ static int clear_refs_test_walk(unsigned long start, unsigned long end, static ssize_t clear_refs_write(struct file *file, const char __user *buf, size_t count, loff_t *ppos) { + DEFINE_RANGE_LOCK_FULL(mmrange); struct task_struct *task; char buffer[PROC_NUMBUF]; struct mm_struct *mm; @@ -1118,7 +1121,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf, }; if (type == CLEAR_REFS_MM_HIWATER_RSS) { - if (down_write_killable(&mm->mmap_sem)) { + if (mm_write_lock_killable(mm, &mmrange)) { count = -EINTR; goto out_mm; } @@ -1128,18 +1131,18 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf, * resident set size to this mm's current rss value. 
*/ reset_mm_hiwater_rss(mm); - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); goto out_mm; } - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); tlb_gather_mmu(&tlb, mm, 0, -1); if (type == CLEAR_REFS_SOFT_DIRTY) { for (vma = mm->mmap; vma; vma = vma->vm_next) { if (!(vma->vm_flags & VM_SOFTDIRTY)) continue; - up_read(&mm->mmap_sem); - if (down_write_killable(&mm->mmap_sem)) { + mm_read_unlock(mm, &mmrange); + if (mm_write_lock_killable(mm, &mmrange)) { count = -EINTR; goto out_mm; } @@ -1158,14 +1161,14 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf, * failed like if * get_proc_task() fails? */ - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); goto out_mm; } for (vma = mm->mmap; vma; vma = vma->vm_next) { vma->vm_flags &= ~VM_SOFTDIRTY; vma_set_page_prot(vma); } - downgrade_write(&mm->mmap_sem); + mm_downgrade_write(mm, &mmrange); break; } @@ -1177,7 +1180,7 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf, if (type == CLEAR_REFS_SOFT_DIRTY) mmu_notifier_invalidate_range_end(&range); tlb_finish_mmu(&tlb, 0, -1); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); out_mm: mmput(mm); } @@ -1484,6 +1487,7 @@ static ssize_t pagemap_read(struct file *file, char __user *buf, unsigned long start_vaddr; unsigned long end_vaddr; int ret = 0, copied = 0; + DEFINE_RANGE_LOCK_FULL(mmrange); if (!mm || !mmget_not_zero(mm)) goto out; @@ -1539,9 +1543,9 @@ static ssize_t pagemap_read(struct file *file, char __user *buf, /* overflow ? 
*/ if (end < start_vaddr || end > end_vaddr) end = end_vaddr; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); ret = walk_page_range(start_vaddr, end, &pagemap_walk); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); start_vaddr = end; len = min(count, PM_ENTRY_BYTES * pm.pos); diff --git a/fs/proc/task_nommu.c b/fs/proc/task_nommu.c index 36bf0f2e102e..32bf2860eff3 100644 --- a/fs/proc/task_nommu.c +++ b/fs/proc/task_nommu.c @@ -23,9 +23,10 @@ void task_mem(struct seq_file *m, struct mm_struct *mm) struct vm_area_struct *vma; struct vm_region *region; struct rb_node *p; + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long bytes = 0, sbytes = 0, slack = 0, size; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) { vma = rb_entry(p, struct vm_area_struct, vm_rb); @@ -77,7 +78,7 @@ void task_mem(struct seq_file *m, struct mm_struct *mm) "Shared:\t%8lu bytes\n", bytes, slack, sbytes); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); } unsigned long task_vsize(struct mm_struct *mm) @@ -85,13 +86,14 @@ unsigned long task_vsize(struct mm_struct *mm) struct vm_area_struct *vma; struct rb_node *p; unsigned long vsize = 0; + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) { vma = rb_entry(p, struct vm_area_struct, vm_rb); vsize += vma->vm_end - vma->vm_start; } - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); return vsize; } @@ -103,8 +105,9 @@ unsigned long task_statm(struct mm_struct *mm, struct vm_region *region; struct rb_node *p; unsigned long size = kobjsize(mm); + DEFINE_RANGE_LOCK_FULL(mmrange); - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) { vma = rb_entry(p, struct vm_area_struct, vm_rb); size += kobjsize(vma); @@ -119,7 +122,7 @@ unsigned long task_statm(struct mm_struct *mm, >> PAGE_SHIFT; *data = 
(PAGE_ALIGN(mm->start_stack) - (mm->start_data & PAGE_MASK)) >> PAGE_SHIFT; - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); size >>= PAGE_SHIFT; size += *text + *data; *resident = size; @@ -201,6 +204,7 @@ static void *m_start(struct seq_file *m, loff_t *pos) struct mm_struct *mm; struct rb_node *p; loff_t n = *pos; + DEFINE_RANGE_LOCK_FULL(mmrange); /* pin the task and mm whilst we play with them */ priv->task = get_proc_task(priv->inode); @@ -211,13 +215,13 @@ static void *m_start(struct seq_file *m, loff_t *pos) if (!mm || !mmget_not_zero(mm)) return NULL; - down_read(&mm->mmap_sem); + mm_read_lock(mm, &mmrange); /* start from the Nth VMA */ for (p = rb_first(&mm->mm_rb); p; p = rb_next(p)) if (n-- == 0) return p; - up_read(&mm->mmap_sem); + mm_read_unlock(mm, &mmrange); mmput(mm); return NULL; } @@ -227,7 +231,7 @@ static void m_stop(struct seq_file *m, void *_vml) struct proc_maps_private *priv = m->private; if (!IS_ERR_OR_NULL(_vml)) { - up_read(&priv->mm->mmap_sem); + mm_read_unlock(priv->mm, &mmrange); mmput(priv->mm); } if (priv->task) { diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 3b30301c90ec..3592f6d71778 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -220,13 +220,14 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx, struct vm_area_struct *vma, unsigned long address, unsigned long flags, - unsigned long reason) + unsigned long reason, + struct range_lock *mmrange) { struct mm_struct *mm = ctx->mm; pte_t *ptep, pte; bool ret = true; - VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem)); + VM_BUG_ON(!mm_is_locked(mm, mmrange)); ptep = huge_pte_offset(mm, address, vma_mmu_pagesize(vma)); @@ -252,7 +253,9 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx, struct vm_area_struct *vma, unsigned long address, unsigned long flags, - unsigned long reason) + unsigned long reason, + struct range_lock *mmrange) + { return false; /* should never get here */ } @@ -268,7 +271,8 @@ static inline 
bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx, static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx, unsigned long address, unsigned long flags, - unsigned long reason) + unsigned long reason, + struct range_lock *mmrange) { struct mm_struct *mm = ctx->mm; pgd_t *pgd; @@ -278,7 +282,7 @@ static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx, pte_t *pte; bool ret = true; - VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem)); + VM_BUG_ON(!mm_is_locked(mm, mmrange)); pgd = pgd_offset(mm, address); if (!pgd_present(*pgd)) @@ -368,7 +372,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason) * Coredumping runs without mmap_sem so we can only check that * the mmap_sem is held, if PF_DUMPCORE was not set. */ - WARN_ON_ONCE(!rwsem_is_locked(&mm->mmap_sem)); + WARN_ON_ONCE(!mm_is_locked(mm, vmf->lockrange)); ctx = vmf->vma->vm_userfaultfd_ctx.ctx; if (!ctx) @@ -476,12 +480,13 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason) if (!is_vm_hugetlb_page(vmf->vma)) must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags, - reason); + reason, vmf->lockrange); else must_wait = userfaultfd_huge_must_wait(ctx, vmf->vma, vmf->address, - vmf->flags, reason); - up_read(&mm->mmap_sem); + vmf->flags, reason, + vmf->lockrange); + mm_read_unlock(mm, vmf->lockrange); if (likely(must_wait && !READ_ONCE(ctx->released) && (return_to_userland ? !signal_pending(current) : @@ -535,7 +540,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason) * and there's no need to retake the mmap_sem * in such case. 
*/ - down_read(&mm->mmap_sem); + mm_read_lock(mm, vmf->lockrange); ret = VM_FAULT_NOPAGE; } } @@ -628,9 +633,10 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, if (release_new_ctx) { struct vm_area_struct *vma; struct mm_struct *mm = release_new_ctx->mm; + DEFINE_RANGE_LOCK_FULL(mmrange); /* the various vma->vm_userfaultfd_ctx still points to it */ - down_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); /* no task can run (and in turn coredump) yet */ VM_WARN_ON(!mmget_still_valid(mm)); for (vma = mm->mmap; vma; vma = vma->vm_next) @@ -638,7 +644,7 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; vma->vm_flags &= ~(VM_UFFD_WP | VM_UFFD_MISSING); } - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); userfaultfd_ctx_put(release_new_ctx); } @@ -780,7 +786,8 @@ void mremap_userfaultfd_complete(struct vm_userfaultfd_ctx *vm_ctx, } bool userfaultfd_remove(struct vm_area_struct *vma, - unsigned long start, unsigned long end) + unsigned long start, unsigned long end, + struct range_lock *mmrange) { struct mm_struct *mm = vma->vm_mm; struct userfaultfd_ctx *ctx; @@ -792,7 +799,7 @@ bool userfaultfd_remove(struct vm_area_struct *vma, userfaultfd_ctx_get(ctx); WRITE_ONCE(ctx->mmap_changing, true); - up_read(&mm->mmap_sem); + mm_read_unlock(mm, mmrange); msg_init(&ewq.msg); @@ -872,6 +879,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) /* len == 0 means wake all */ struct userfaultfd_wake_range range = { .len = 0, }; unsigned long new_flags; + DEFINE_RANGE_LOCK_FULL(mmrange); WRITE_ONCE(ctx->released, true); @@ -886,7 +894,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) * it's critical that released is set to true (above), before * taking the mmap_sem for writing. 
*/ - down_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); if (!mmget_still_valid(mm)) goto skip_mm; prev = NULL; @@ -912,7 +920,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; } skip_mm: - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); mmput(mm); wakeup: /* @@ -1299,6 +1307,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, unsigned long vm_flags, new_flags; bool found; bool basic_ioctls; + DEFINE_RANGE_LOCK_FULL(mmrange); unsigned long start, end, vma_end; user_uffdio_register = (struct uffdio_register __user *) arg; @@ -1339,7 +1348,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, if (!mmget_not_zero(mm)) goto out; - down_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); if (!mmget_still_valid(mm)) goto out_unlock; vma = find_vma_prev(mm, start, &prev); @@ -1483,7 +1492,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, vma = vma->vm_next; } while (vma && vma->vm_start < end); out_unlock: - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); mmput(mm); if (!ret) { /* @@ -1511,6 +1520,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, bool found; unsigned long start, end, vma_end; const void __user *buf = (void __user *)arg; + DEFINE_RANGE_LOCK_FULL(mmrange); ret = -EFAULT; if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) @@ -1528,7 +1538,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, if (!mmget_not_zero(mm)) goto out; - down_write(&mm->mmap_sem); + mm_write_lock(mm, &mmrange); if (!mmget_still_valid(mm)) goto out_unlock; vma = find_vma_prev(mm, start, &prev); @@ -1645,7 +1655,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, vma = vma->vm_next; } while (vma && vma->vm_start < end); out_unlock: - up_write(&mm->mmap_sem); + mm_write_unlock(mm, &mmrange); mmput(mm); out: return ret; diff --git a/include/linux/userfaultfd_k.h 
b/include/linux/userfaultfd_k.h index ac9d71e24b81..c8d3c102ce5e 100644 --- a/include/linux/userfaultfd_k.h +++ b/include/linux/userfaultfd_k.h @@ -68,7 +68,7 @@ extern void mremap_userfaultfd_complete(struct vm_userfaultfd_ctx *, extern bool userfaultfd_remove(struct vm_area_struct *vma, unsigned long start, - unsigned long end); + unsigned long end, struct range_lock *mmrange); extern int userfaultfd_unmap_prep(struct vm_area_struct *vma, unsigned long start, unsigned long end, @@ -125,7 +125,8 @@ static inline void mremap_userfaultfd_complete(struct vm_userfaultfd_ctx *ctx, static inline bool userfaultfd_remove(struct vm_area_struct *vma, unsigned long start, - unsigned long end) + unsigned long end, + struct range_lock *mmrange) { return true; } -- 2.16.4