On Mon, Mar 25, 2019 at 10:40:02AM -0400, Jerome Glisse wrote:
> From: Jérôme Glisse <jgli...@redhat.com>
> 
> Every time I read the code to check that the HMM structure does not
> vanish before it should, thanks to the many locks protecting its removal,
> I get a headache. Switch to reference counting instead; it is much
> easier to follow and harder to break. This also removes some code that
> is no longer needed with refcounting.
> 
> Changes since v1:
>     - removed a bunch of useless checks (if the API is used with bogus
>       arguments, better to fail loudly so users fix their code)
>     - s/hmm_get/mm_get_hmm/
> 
> Signed-off-by: Jérôme Glisse <jgli...@redhat.com>
> Reviewed-by: Ralph Campbell <rcampb...@nvidia.com>
> Cc: John Hubbard <jhubb...@nvidia.com>
> Cc: Andrew Morton <a...@linux-foundation.org>
> Cc: Dan Williams <dan.j.willi...@intel.com>
> ---
>  include/linux/hmm.h |   2 +
>  mm/hmm.c            | 170 ++++++++++++++++++++++++++++----------------
>  2 files changed, 112 insertions(+), 60 deletions(-)
> 
> diff --git a/include/linux/hmm.h b/include/linux/hmm.h
> index ad50b7b4f141..716fc61fa6d4 100644
> --- a/include/linux/hmm.h
> +++ b/include/linux/hmm.h
> @@ -131,6 +131,7 @@ enum hmm_pfn_value_e {
>  /*
>   * struct hmm_range - track invalidation lock on virtual address range
>   *
> + * @hmm: the core HMM structure this range is active against
>   * @vma: the vm area struct for the range
>   * @list: all range lock are on a list
>   * @start: range virtual start address (inclusive)
> @@ -142,6 +143,7 @@ enum hmm_pfn_value_e {
>   * @valid: pfns array did not change since it has been fill by an HMM function
>   */
>  struct hmm_range {
> +     struct hmm              *hmm;
>       struct vm_area_struct   *vma;
>       struct list_head        list;
>       unsigned long           start;
> diff --git a/mm/hmm.c b/mm/hmm.c
> index fe1cd87e49ac..306e57f7cded 100644
> --- a/mm/hmm.c
> +++ b/mm/hmm.c
> @@ -50,6 +50,7 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
>   */
>  struct hmm {
>       struct mm_struct        *mm;
> +     struct kref             kref;
>       spinlock_t              lock;
>       struct list_head        ranges;
>       struct list_head        mirrors;
> @@ -57,6 +58,16 @@ struct hmm {
>       struct rw_semaphore     mirrors_sem;
>  };
>  
> +static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
> +{
> +     struct hmm *hmm = READ_ONCE(mm->hmm);
> +
> +     if (hmm && kref_get_unless_zero(&hmm->kref))
> +             return hmm;
> +
> +     return NULL;
> +}
> +
>  /*
>   * hmm_register - register HMM against an mm (HMM internal)
>   *
> @@ -67,14 +78,9 @@ struct hmm {
>   */
>  static struct hmm *hmm_register(struct mm_struct *mm)
>  {
> -     struct hmm *hmm = READ_ONCE(mm->hmm);
> +     struct hmm *hmm = mm_get_hmm(mm);

FWIW: having hmm_register == "hmm get" is a bit confusing...

Ira

>       bool cleanup = false;
>  
> -     /*
> -      * The hmm struct can only be freed once the mm_struct goes away,
> -      * hence we should always have pre-allocated an new hmm struct
> -      * above.
> -      */
>       if (hmm)
>               return hmm;
>  
> @@ -86,6 +92,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
>       hmm->mmu_notifier.ops = NULL;
>       INIT_LIST_HEAD(&hmm->ranges);
>       spin_lock_init(&hmm->lock);
> +     kref_init(&hmm->kref);
>       hmm->mm = mm;
>  
>       spin_lock(&mm->page_table_lock);
> @@ -106,7 +113,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
>       if (__mmu_notifier_register(&hmm->mmu_notifier, mm))
>               goto error_mm;
>  
> -     return mm->hmm;
> +     return hmm;
>  
>  error_mm:
>       spin_lock(&mm->page_table_lock);
> @@ -118,9 +125,41 @@ static struct hmm *hmm_register(struct mm_struct *mm)
>       return NULL;
>  }
>  
> +static void hmm_free(struct kref *kref)
> +{
> +     struct hmm *hmm = container_of(kref, struct hmm, kref);
> +     struct mm_struct *mm = hmm->mm;
> +
> +     mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
> +
> +     spin_lock(&mm->page_table_lock);
> +     if (mm->hmm == hmm)
> +             mm->hmm = NULL;
> +     spin_unlock(&mm->page_table_lock);
> +
> +     kfree(hmm);
> +}
> +
> +static inline void hmm_put(struct hmm *hmm)
> +{
> +     kref_put(&hmm->kref, hmm_free);
> +}
> +
>  void hmm_mm_destroy(struct mm_struct *mm)
>  {
> -     kfree(mm->hmm);
> +     struct hmm *hmm;
> +
> +     spin_lock(&mm->page_table_lock);
> +     hmm = mm_get_hmm(mm);
> +     mm->hmm = NULL;
> +     if (hmm) {
> +             hmm->mm = NULL;
> +             spin_unlock(&mm->page_table_lock);
> +             hmm_put(hmm);
> +             return;
> +     }
> +
> +     spin_unlock(&mm->page_table_lock);
>  }
>  
>  static int hmm_invalidate_range(struct hmm *hmm, bool device,
> @@ -165,7 +204,7 @@ static int hmm_invalidate_range(struct hmm *hmm, bool device,
>  static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
>  {
>       struct hmm_mirror *mirror;
> -     struct hmm *hmm = mm->hmm;
> +     struct hmm *hmm = mm_get_hmm(mm);
>  
>       down_write(&hmm->mirrors_sem);
>       mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
> @@ -186,13 +225,16 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
>                                                 struct hmm_mirror, list);
>       }
>       up_write(&hmm->mirrors_sem);
> +
> +     hmm_put(hmm);
>  }
>  
>  static int hmm_invalidate_range_start(struct mmu_notifier *mn,
>                       const struct mmu_notifier_range *range)
>  {
> +     struct hmm *hmm = mm_get_hmm(range->mm);
>       struct hmm_update update;
> -     struct hmm *hmm = range->mm->hmm;
> +     int ret;
>  
>       VM_BUG_ON(!hmm);
>  
> @@ -200,14 +242,16 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
>       update.end = range->end;
>       update.event = HMM_UPDATE_INVALIDATE;
>       update.blockable = range->blockable;
> -     return hmm_invalidate_range(hmm, true, &update);
> +     ret = hmm_invalidate_range(hmm, true, &update);
> +     hmm_put(hmm);
> +     return ret;
>  }
>  
>  static void hmm_invalidate_range_end(struct mmu_notifier *mn,
>                       const struct mmu_notifier_range *range)
>  {
> +     struct hmm *hmm = mm_get_hmm(range->mm);
>       struct hmm_update update;
> -     struct hmm *hmm = range->mm->hmm;
>  
>       VM_BUG_ON(!hmm);
>  
> @@ -216,6 +260,7 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
>       update.event = HMM_UPDATE_INVALIDATE;
>       update.blockable = true;
>       hmm_invalidate_range(hmm, false, &update);
> +     hmm_put(hmm);
>  }
>  
>  static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
> @@ -241,24 +286,13 @@ int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
>       if (!mm || !mirror || !mirror->ops)
>               return -EINVAL;
>  
> -again:
>       mirror->hmm = hmm_register(mm);
>       if (!mirror->hmm)
>               return -ENOMEM;
>  
>       down_write(&mirror->hmm->mirrors_sem);
> -     if (mirror->hmm->mm == NULL) {
> -             /*
> -              * A racing hmm_mirror_unregister() is about to destroy the hmm
> -              * struct. Try again to allocate a new one.
> -              */
> -             up_write(&mirror->hmm->mirrors_sem);
> -             mirror->hmm = NULL;
> -             goto again;
> -     } else {
> -             list_add(&mirror->list, &mirror->hmm->mirrors);
> -             up_write(&mirror->hmm->mirrors_sem);
> -     }
> +     list_add(&mirror->list, &mirror->hmm->mirrors);
> +     up_write(&mirror->hmm->mirrors_sem);
>  
>       return 0;
>  }
> @@ -273,33 +307,18 @@ EXPORT_SYMBOL(hmm_mirror_register);
>   */
>  void hmm_mirror_unregister(struct hmm_mirror *mirror)
>  {
> -     bool should_unregister = false;
> -     struct mm_struct *mm;
> -     struct hmm *hmm;
> +     struct hmm *hmm = READ_ONCE(mirror->hmm);
>  
> -     if (mirror->hmm == NULL)
> +     if (hmm == NULL)
>               return;
>  
> -     hmm = mirror->hmm;
>       down_write(&hmm->mirrors_sem);
>       list_del_init(&mirror->list);
> -     should_unregister = list_empty(&hmm->mirrors);
> +     /* To protect us against double unregister ... */
>       mirror->hmm = NULL;
> -     mm = hmm->mm;
> -     hmm->mm = NULL;
>       up_write(&hmm->mirrors_sem);
>  
> -     if (!should_unregister || mm == NULL)
> -             return;
> -
> -     mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
> -
> -     spin_lock(&mm->page_table_lock);
> -     if (mm->hmm == hmm)
> -             mm->hmm = NULL;
> -     spin_unlock(&mm->page_table_lock);
> -
> -     kfree(hmm);
> +     hmm_put(hmm);
>  }
>  EXPORT_SYMBOL(hmm_mirror_unregister);
>  
> @@ -708,6 +727,8 @@ int hmm_vma_get_pfns(struct hmm_range *range)
>       struct mm_walk mm_walk;
>       struct hmm *hmm;
>  
> +     range->hmm = NULL;
> +
>       /* Sanity check, this really should not happen ! */
>       if (range->start < vma->vm_start || range->start >= vma->vm_end)
>               return -EINVAL;
> @@ -717,14 +738,18 @@ int hmm_vma_get_pfns(struct hmm_range *range)
>       hmm = hmm_register(vma->vm_mm);
>       if (!hmm)
>               return -ENOMEM;
> -     /* Caller must have registered a mirror, via hmm_mirror_register() ! */
> -     if (!hmm->mmu_notifier.ops)
> +
> +     /* Check if hmm_mm_destroy() was call. */
> +     if (hmm->mm == NULL) {
> +             hmm_put(hmm);
>               return -EINVAL;
> +     }
>  
>       /* FIXME support hugetlb fs */
>       if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
>                       vma_is_dax(vma)) {
>               hmm_pfns_special(range);
> +             hmm_put(hmm);
>               return -EINVAL;
>       }
>  
> @@ -736,6 +761,7 @@ int hmm_vma_get_pfns(struct hmm_range *range)
>                * operations such has atomic access would not work.
>                */
>               hmm_pfns_clear(range, range->pfns, range->start, range->end);
> +             hmm_put(hmm);
>               return -EPERM;
>       }
>  
> @@ -758,6 +784,12 @@ int hmm_vma_get_pfns(struct hmm_range *range)
>       mm_walk.pte_hole = hmm_vma_walk_hole;
>  
>       walk_page_range(range->start, range->end, &mm_walk);
> +     /*
> +      * Transfer hmm reference to the range struct it will be drop inside
> +      * the hmm_vma_range_done() function (which _must_ be call if this
> +      * function return 0).
> +      */
> +     range->hmm = hmm;
>       return 0;
>  }
>  EXPORT_SYMBOL(hmm_vma_get_pfns);
> @@ -802,25 +834,27 @@ EXPORT_SYMBOL(hmm_vma_get_pfns);
>   */
>  bool hmm_vma_range_done(struct hmm_range *range)
>  {
> -     unsigned long npages = (range->end - range->start) >> PAGE_SHIFT;
> -     struct hmm *hmm;
> +     bool ret = false;
>  
> -     if (range->end <= range->start) {
> +     /* Sanity check this really should not happen. */
> +     if (range->hmm == NULL || range->end <= range->start) {
>               BUG();
>               return false;
>       }
>  
> -     hmm = hmm_register(range->vma->vm_mm);
> -     if (!hmm) {
> -             memset(range->pfns, 0, sizeof(*range->pfns) * npages);
> -             return false;
> -     }
> -
> -     spin_lock(&hmm->lock);
> +     spin_lock(&range->hmm->lock);
>       list_del_rcu(&range->list);
> -     spin_unlock(&hmm->lock);
> +     ret = range->valid;
> +     spin_unlock(&range->hmm->lock);
>  
> -     return range->valid;
> +     /* Is the mm still alive ? */
> +     if (range->hmm->mm == NULL)
> +             ret = false;
> +
> +     /* Drop reference taken by hmm_vma_fault() or hmm_vma_get_pfns() */
> +     hmm_put(range->hmm);
> +     range->hmm = NULL;
> +     return ret;
>  }
>  EXPORT_SYMBOL(hmm_vma_range_done);
>  
> @@ -880,6 +914,8 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
>       struct hmm *hmm;
>       int ret;
>  
> +     range->hmm = NULL;
> +
>       /* Sanity check, this really should not happen ! */
>       if (range->start < vma->vm_start || range->start >= vma->vm_end)
>               return -EINVAL;
> @@ -891,14 +927,18 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
>               hmm_pfns_clear(range, range->pfns, range->start, range->end);
>               return -ENOMEM;
>       }
> -     /* Caller must have registered a mirror using hmm_mirror_register() */
> -     if (!hmm->mmu_notifier.ops)
> +
> +     /* Check if hmm_mm_destroy() was call. */
> +     if (hmm->mm == NULL) {
> +             hmm_put(hmm);
>               return -EINVAL;
> +     }
>  
>       /* FIXME support hugetlb fs */
>       if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
>                       vma_is_dax(vma)) {
>               hmm_pfns_special(range);
> +             hmm_put(hmm);
>               return -EINVAL;
>       }
>  
> @@ -910,6 +950,7 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
>                * operations such has atomic access would not work.
>                */
>               hmm_pfns_clear(range, range->pfns, range->start, range->end);
> +             hmm_put(hmm);
>               return -EPERM;
>       }
>  
> @@ -945,7 +986,16 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
>               hmm_pfns_clear(range, &range->pfns[i], hmm_vma_walk.last,
>                              range->end);
>               hmm_vma_range_done(range);
> +             hmm_put(hmm);
> +     } else {
> +             /*
> +              * Transfer hmm reference to the range struct it will be drop
> +              * inside the hmm_vma_range_done() function (which _must_ be
> +              * call if this function return 0).
> +              */
> +             range->hmm = hmm;
>       }
> +
>       return ret;
>  }
>  EXPORT_SYMBOL(hmm_vma_fault);
> -- 
> 2.17.2
> 

Reply via email to