On Sun, Apr 16, 2017 at 07:42:27PM -0600, Alex Williamson wrote:

[...]

> -static void vfio_lock_acct(struct task_struct *task, long npage)
> +static int vfio_lock_acct(struct task_struct *task, long npage, bool lock_cap)
>  {
> -     struct vwork *vwork;
>       struct mm_struct *mm;
>       bool is_current;
> +     int ret;
>  
>       if (!npage)
> -             return;
> +             return 0;
>  
>       is_current = (task->mm == current->mm);
>  
>       mm = is_current ? task->mm : get_task_mm(task);
>       if (!mm)
> -             return; /* process exited */
> +             return -ESRCH; /* process exited */
>  
> -     if (down_write_trylock(&mm->mmap_sem)) {
> -             mm->locked_vm += npage;
> -             up_write(&mm->mmap_sem);
> -             if (!is_current)
> -                     mmput(mm);
> -             return;
> -     }
> +     ret = down_write_killable(&mm->mmap_sem);
> +     if (!ret) {
> +             if (npage < 0 || lock_cap) {

Nit: maybe we can avoid passing lock_cap in from every caller of
vfio_lock_acct(), and instead fetch it via has_capability() inside the
function, only when npage > 0 (the capability is irrelevant when
npage < 0)? IMHO that would keep the vfio_lock_acct() interface
cleaner, and we wouldn't need to pass in "false" every time we unpin.

[...]

> @@ -405,7 +379,7 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
>  static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
>                                 long npage, unsigned long *pfn_base)
>  {
> -     unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
> +     unsigned long pfn = 0, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
>       bool lock_cap = capable(CAP_IPC_LOCK);
>       long ret, pinned = 0, lock_acct = 0;
>       bool rsvd;
> @@ -442,8 +416,6 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
>       /* Lock all the consecutive pages from pfn_base */
>       for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
>            pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
> -             unsigned long pfn = 0;
> -
>               ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
>               if (ret)
>                       break;
> @@ -460,14 +432,25 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
>                               put_pfn(pfn, dma->prot);
>                               pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
>                                       __func__, limit << PAGE_SHIFT);
> -                             break;
> +                             ret = -ENOMEM;
> +                             goto unpin_out;
>                       }
>                       lock_acct++;
>               }
>       }
>  
>  out:
> -     vfio_lock_acct(current, lock_acct);
> +     ret = vfio_lock_acct(current, lock_acct, lock_cap);

I didn't notice this in the previous review, but... do we need to
check !rsvd here as well before doing the accounting?

Thanks!

> +
> +unpin_out:
> +     if (ret) {
> +             if (!rsvd) {
> +                     for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
> +                             put_pfn(pfn, dma->prot);
> +             }
> +
> +             return ret;
> +     }
>  
>       return pinned;
>  }

-- 
Peter Xu

Reply via email to