Add tracking of pages that were pinned via FOLL_PIN.

As mentioned in the FOLL_PIN documentation, callers who effectively set
FOLL_PIN are required to ultimately free such pages via put_user_page().
The effect is similar to FOLL_GET, and may be thought of as "FOLL_GET
for DIO and/or RDMA use".
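
The following user-space toy model (not kernel code) roughly illustrates why the release call has to match the acquire call. struct toy_page and the toy_* helpers are hypothetical. Only the GUP_PIN_COUNTING_BIAS value of (1UL << 10) comes from this patch, which adds that bias per pin instead of adding 1.

```c
#include <assert.h>

/* Same value as this patch's GUP_PIN_COUNTING_BIAS; the rest is a toy model. */
#define GUP_PIN_COUNTING_BIAS (1UL << 10)

/* Hypothetical stand-in for struct page: only the refcount is modeled. */
struct toy_page {
	unsigned long refcount;
};

/* FOLL_GET-style reference: +1, to be dropped with toy_put_page(). */
static void toy_get_page(struct toy_page *p)
{
	p->refcount += 1;
}

static void toy_put_page(struct toy_page *p)
{
	assert(p->refcount >= 1);	/* guard against underflow in the model */
	p->refcount -= 1;
}

/* FOLL_PIN-style reference: +bias, to be dropped with toy_put_user_page(). */
static void toy_pin_page(struct toy_page *p)
{
	p->refcount += GUP_PIN_COUNTING_BIAS;
}

static void toy_put_user_page(struct toy_page *p)
{
	assert(p->refcount >= GUP_PIN_COUNTING_BIAS);	/* underflow guard */
	p->refcount -= GUP_PIN_COUNTING_BIAS;
}
```

Starting from a count of 1, calling toy_pin_page() and then toy_put_user_page() brings the count back to 1. Calling toy_put_page() after the pin instead leaves GUP_PIN_COUNTING_BIAS - 1 stray references behind (a count of 1024), so the page keeps looking pinned.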

Pages that have been pinned via FOLL_PIN are identifiable via a
new function call:

   bool page_dma_pinned(struct page *page);
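
A similar toy model sketches how page_dma_pinned() separates pinned pages from ordinary ones, and why a true result only means "probably pinned". As before, struct toy_page and the toy_* helpers are hypothetical. The comparison mirrors the one this patch adds to include/linux/mm.h. The split into upper and lower counts only holds while normal references stay below GUP_PIN_COUNTING_BIAS.

```c
#include <assert.h>
#include <stdbool.h>

#define GUP_PIN_COUNTING_BIAS (1UL << 10)

/* Hypothetical stand-in for struct page; only the refcount is modeled. */
struct toy_page {
	unsigned long refcount;
};

/*
 * Mirrors the check in this patch's page_dma_pinned(): any refcount at or
 * above the bias is reported as "probably pinned".
 */
static bool toy_page_dma_pinned(const struct toy_page *p)
{
	return p->refcount >= GUP_PIN_COUNTING_BIAS;
}

/*
 * While normal references stay below the bias, the pin count lives in the
 * upper bits and the normal reference count in the lower bits.
 */
static unsigned long toy_pin_count(const struct toy_page *p)
{
	return p->refcount / GUP_PIN_COUNTING_BIAS;
}

static unsigned long toy_normal_refs(const struct toy_page *p)
{
	return p->refcount % GUP_PIN_COUNTING_BIAS;
}
```

A page with 3 normal references and 2 pins has a refcount of 2051, which splits back into 2 pins and 3 references. A page with 1024 ordinary references and no pins is still reported as pinned. That is the documented false positive.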

What to do in response to encountering such a page is left to later
patchsets. There is discussion about this in [1], [2], and [3].

This also changes a BUG_ON() to a WARN_ON_ONCE() in follow_page_mask().

[1] Some slow progress on get_user_pages() (Apr 2, 2019):
    https://lwn.net/Articles/784574/
[2] DMA and get_user_pages() (LPC: Dec 12, 2018):
    https://lwn.net/Articles/774411/
[3] The trouble with get_user_pages() (Apr 30, 2018):
    https://lwn.net/Articles/753027/

Suggested-by: Jan Kara <j...@suse.cz>
Suggested-by: Jérôme Glisse <jgli...@redhat.com>
Signed-off-by: John Hubbard <jhubb...@nvidia.com>
---
 Documentation/core-api/pin_user_pages.rst |   2 +-
 include/linux/mm.h                        |  86 +++++--
 include/linux/mmzone.h                    |   2 +
 include/linux/page_ref.h                  |  10 +
 mm/gup.c                                  | 290 ++++++++++++++++------
 mm/huge_memory.c                          |  54 +++-
 mm/hugetlb.c                              |  39 ++-
 mm/vmstat.c                               |   2 +
 8 files changed, 390 insertions(+), 95 deletions(-)

diff --git a/Documentation/core-api/pin_user_pages.rst b/Documentation/core-api/pin_user_pages.rst
index 4f26637a5005..baa288a44a77 100644
--- a/Documentation/core-api/pin_user_pages.rst
+++ b/Documentation/core-api/pin_user_pages.rst
@@ -53,7 +53,7 @@ Which flags are set by each wrapper
 For these pin_user_pages*() functions, FOLL_PIN is OR'd in with whatever gup
 flags the caller provides. The caller is required to pass in a non-null struct
 pages* array, and the function then pin pages by incrementing each by a special
-value. For now, that value is +1, just like get_user_pages*().::
+value: GUP_PIN_COUNTING_BIAS.::
 
  Function
  --------
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 568cbb895f03..ab311b356ab1 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1054,6 +1054,21 @@ static inline __must_check bool try_get_page(struct page *page)
        return true;
 }
 
+__must_check bool try_pin_compound_head(struct page *page, int refs);
+
+/**
+ * try_pin_page() - mark a page as being used by pin_user_pages*().
+ *
+ * This is the FOLL_PIN counterpart to try_get_page().
+ *
+ * @page:      pointer to page to be marked
+ * @Return:    true for success, false for failure
+ */
+static inline __must_check bool try_pin_page(struct page *page)
+{
+       return try_pin_compound_head(page, 1);
+}
+
 static inline void put_page(struct page *page)
 {
        page = compound_head(page);
@@ -1071,29 +1086,70 @@ static inline void put_page(struct page *page)
                __put_page(page);
 }
 
-/**
- * put_user_page() - release a gup-pinned page
- * @page:            pointer to page to be released
+/*
+ * GUP_PIN_COUNTING_BIAS, and the associated functions that use it, overload
+ * the page's refcount so that two separate items are tracked: the original page
+ * reference count, and also a new count of how many pin_user_pages() calls were
+ * made against the page. ("gup-pinned" is another term for the latter).
+ *
+ * With this scheme, pin_user_pages() becomes special: such pages are marked as
+ * distinct from normal pages. As such, the put_user_page() call (and its
+ * variants) must be used in order to release gup-pinned pages.
+ *
+ * Choice of value:
+ *
+ * By making GUP_PIN_COUNTING_BIAS a power of two, debugging of page reference
+ * counts with respect to pin_user_pages() and put_user_page() becomes simpler,
+ * due to the fact that adding an even power of two to the page refcount has the
+ * effect of using only the upper N bits, for the code that counts up using the
+ * bias value. This means that the lower bits are left for the exclusive use of
+ * the original code that increments and decrements by one (or at least, by much
+ * smaller values than the bias value).
  *
- * Pages that were pinned via pin_user_pages*() must be released via either
- * put_user_page(), or one of the put_user_pages*() routines. This is so that
- * eventually such pages can be separately tracked and uniquely handled. In
- * particular, interactions with RDMA and filesystems need special handling.
+ * Of course, once the lower bits overflow into the upper bits (and this is
+ * OK, because subtraction recovers the original values), then visual inspection
+ * no longer suffices to directly view the separate counts. However, for normal
+ * applications that don't have huge page reference counts, this won't be an
+ * issue.
  *
- * put_user_page() and put_page() are not interchangeable, despite this early
- * implementation that makes them look the same. put_user_page() calls must
- * be perfectly matched up with pin*() calls.
+ * Locking: the lockless algorithm described in page_cache_get_speculative()
+ * and page_cache_gup_pin_speculative() provides safe operation for
+ * get_user_pages and page_mkclean and other calls that race to set up page
+ * table entries.
  */
-static inline void put_user_page(struct page *page)
-{
-       put_page(page);
-}
+#define GUP_PIN_COUNTING_BIAS (1UL << 10)
 
+void put_user_page(struct page *page);
 void put_user_pages_dirty_lock(struct page **pages, unsigned long npages,
                               bool make_dirty);
-
 void put_user_pages(struct page **pages, unsigned long npages);
 
+/**
+ * page_dma_pinned() - report if a page is pinned for DMA.
+ *
+ * This function checks if a page has been pinned via a call to
+ * pin_user_pages*().
+ *
+ * The return value is partially fuzzy: false is not fuzzy, because it means
+ * "definitely not pinned for DMA", but true means "probably pinned for DMA, but
+ * possibly a false positive due to having at least GUP_PIN_COUNTING_BIAS worth
+ * of normal page references".
+ *
+ * False positives are OK, because: a) it's unlikely for a page to get that many
+ * refcounts, and b) all the callers of this routine are expected to be able to
+ * deal gracefully with a false positive.
+ *
+ * For more information, please see Documentation/core-api/pin_user_pages.rst.
+ *
+ * @page:      pointer to page to be queried.
+ * @Return:    True, if it is likely that the page has been "dma-pinned".
+ *             False, if the page is definitely not dma-pinned.
+ */
+static inline bool page_dma_pinned(struct page *page)
+{
+       return (page_ref_count(compound_head(page))) >= GUP_PIN_COUNTING_BIAS;
+}
+
 #if defined(CONFIG_SPARSEMEM) && !defined(CONFIG_SPARSEMEM_VMEMMAP)
 #define SECTION_IN_PAGE_FLAGS
 #endif
diff --git a/include/linux/mmzone.h b/include/linux/mmzone.h
index bda20282746b..0485cba38d23 100644
--- a/include/linux/mmzone.h
+++ b/include/linux/mmzone.h
@@ -244,6 +244,8 @@ enum node_stat_item {
        NR_DIRTIED,             /* page dirtyings since bootup */
        NR_WRITTEN,             /* page writings since bootup */
        NR_KERNEL_MISC_RECLAIMABLE,     /* reclaimable non-slab kernel pages */
+       NR_FOLL_PIN_REQUESTED,  /* via: pin_user_pages*(), gup flag: FOLL_PIN */
+       NR_FOLL_PIN_RETURNED,   /* pages returned via put_user_page() */
        NR_VM_NODE_STAT_ITEMS
 };
 
diff --git a/include/linux/page_ref.h b/include/linux/page_ref.h
index 14d14beb1f7f..b9cbe553d1e7 100644
--- a/include/linux/page_ref.h
+++ b/include/linux/page_ref.h
@@ -102,6 +102,16 @@ static inline void page_ref_sub(struct page *page, int nr)
                __page_ref_mod(page, -nr);
 }
 
+static inline int page_ref_sub_return(struct page *page, int nr)
+{
+       int ret = atomic_sub_return(nr, &page->_refcount);
+
+       if (page_ref_tracepoint_active(__tracepoint_page_ref_mod))
+               __page_ref_mod(page, -nr);
+
+       return ret;
+}
+
 static inline void page_ref_inc(struct page *page)
 {
        atomic_inc(&page->_refcount);
diff --git a/mm/gup.c b/mm/gup.c
index f72d7a1635b4..002816526670 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -51,6 +51,96 @@ static inline struct page *try_get_compound_head(struct page *page, int refs)
        return head;
 }
 
+#ifdef CONFIG_DEBUG_VM
+static inline void __update_proc_vmstat(struct page *page,
+                                       enum node_stat_item item, int count)
+{
+       mod_node_page_state(page_pgdat(page), item, count);
+}
+#else
+static inline void __update_proc_vmstat(struct page *page,
+                                       enum node_stat_item item, int count)
+{
+}
+#endif
+
+/**
+ * try_pin_compound_head() - mark a compound page as being used by
+ * pin_user_pages*().
+ *
+ * This is the FOLL_PIN counterpart to try_get_compound_head().
+ *
+ * @page:      pointer to page to be marked
+ * @Return:    true for success, false for failure
+ */
+__must_check bool try_pin_compound_head(struct page *page, int refs)
+{
+       page = try_get_compound_head(page, GUP_PIN_COUNTING_BIAS * refs);
+       if (!page)
+               return false;
+
+       __update_proc_vmstat(page, NR_FOLL_PIN_REQUESTED, refs);
+       return true;
+}
+
+#ifdef CONFIG_DEV_PAGEMAP_OPS
+static bool __put_devmap_managed_user_page(struct page *page)
+{
+       bool is_devmap = page_is_devmap_managed(page);
+
+       if (is_devmap) {
+               int count = page_ref_sub_return(page, GUP_PIN_COUNTING_BIAS);
+
+               __update_proc_vmstat(page, NR_FOLL_PIN_RETURNED, 1);
+               /*
+                * devmap page refcounts are 1-based, rather than 0-based: if
+                * refcount is 1, then the page is free and the refcount is
+                * stable because nobody holds a reference on the page.
+                */
+               if (count == 1)
+                       free_devmap_managed_page(page);
+               else if (!count)
+                       __put_page(page);
+       }
+
+       return is_devmap;
+}
+#else
+static bool __put_devmap_managed_user_page(struct page *page)
+{
+       return false;
+}
+#endif /* CONFIG_DEV_PAGEMAP_OPS */
+
+/**
+ * put_user_page() - release a dma-pinned page
+ * @page:            pointer to page to be released
+ *
+ * Pages that were pinned via pin_user_pages*() must be released via either
+ * put_user_page(), or one of the put_user_pages*() routines. This is so that
+ * such pages can be separately tracked and uniquely handled. In particular,
+ * interactions with RDMA and filesystems need special handling.
+ */
+void put_user_page(struct page *page)
+{
+       page = compound_head(page);
+
+       /*
+        * For devmap managed pages we need to catch the refcount transition
+        * from GUP_PIN_COUNTING_BIAS + 1 to 1: when the refcount reaches one,
+        * the page is free and we need to inform the device driver through a
+        * callback. See include/linux/memremap.h and HMM for details.
+        */
+       if (__put_devmap_managed_user_page(page))
+               return;
+
+       if (page_ref_sub_and_test(page, GUP_PIN_COUNTING_BIAS))
+               __put_page(page);
+
+       __update_proc_vmstat(page, NR_FOLL_PIN_RETURNED, 1);
+}
+EXPORT_SYMBOL(put_user_page);
+
 /**
  * put_user_pages_dirty_lock() - release and optionally dirty gup-pinned pages
  * @pages:  array of pages to be maybe marked dirty, and definitely released.
@@ -237,10 +327,11 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
        }
 
        page = vm_normal_page(vma, address, pte);
-       if (!page && pte_devmap(pte) && (flags & FOLL_GET)) {
+       if (!page && pte_devmap(pte) && (flags & (FOLL_GET | FOLL_PIN))) {
                /*
-                * Only return device mapping pages in the FOLL_GET case since
-                * they are only valid while holding the pgmap reference.
+                * Only return device mapping pages in the FOLL_GET or FOLL_PIN
+                * case since they are only valid while holding the pgmap
+                * reference.
                 */
                *pgmap = get_dev_pagemap(pte_pfn(pte), *pgmap);
                if (*pgmap)
@@ -283,6 +374,11 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
                        page = ERR_PTR(-ENOMEM);
                        goto out;
                }
+       } else if (flags & FOLL_PIN) {
+               if (unlikely(!try_pin_page(page))) {
+                       page = ERR_PTR(-ENOMEM);
+                       goto out;
+               }
        }
        if (flags & FOLL_TOUCH) {
                if ((flags & FOLL_WRITE) &&
@@ -544,8 +640,8 @@ static struct page *follow_page_mask(struct vm_area_struct *vma,
        /* make this handle hugepd */
        page = follow_huge_addr(mm, address, flags & FOLL_WRITE);
        if (!IS_ERR(page)) {
-               BUG_ON(flags & FOLL_GET);
-               return page;
+               WARN_ON_ONCE(flags & (FOLL_GET | FOLL_PIN));
+               return NULL;
        }
 
        pgd = pgd_offset(mm, address);
@@ -1125,6 +1221,36 @@ static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
        return pages_done;
 }
 
+static long __get_user_pages_remote(struct task_struct *tsk,
+                                   struct mm_struct *mm,
+                                   unsigned long start, unsigned long nr_pages,
+                                   unsigned int gup_flags, struct page **pages,
+                                   struct vm_area_struct **vmas, int *locked)
+{
+       /*
+        * Parts of FOLL_LONGTERM behavior are incompatible with
+        * FAULT_FLAG_ALLOW_RETRY because of the FS DAX check requirement on
+        * vmas. However, this only comes up if locked is set, and there are
+        * callers that do request FOLL_LONGTERM, but do not set locked. So,
+        * allow what we can.
+        */
+       if (gup_flags & FOLL_LONGTERM) {
+               if (WARN_ON_ONCE(locked))
+                       return -EINVAL;
+               /*
+                * This will check the vmas (even if our vmas arg is NULL)
+                * and return -ENOTSUPP if DAX isn't allowed in this case:
+                */
+               return __gup_longterm_locked(tsk, mm, start, nr_pages, pages,
+                                            vmas, gup_flags | FOLL_TOUCH |
+                                            FOLL_REMOTE);
+       }
+
+       return __get_user_pages_locked(tsk, mm, start, nr_pages, pages, vmas,
+                                      locked,
+                                      gup_flags | FOLL_TOUCH | FOLL_REMOTE);
+}
+
 /*
  * get_user_pages_remote() - pin user pages in memory
  * @tsk:       the task_struct to use for page fault accounting, or
@@ -1193,28 +1319,8 @@ long get_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
        if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
                return -EINVAL;
 
-       /*
-        * Parts of FOLL_LONGTERM behavior are incompatible with
-        * FAULT_FLAG_ALLOW_RETRY because of the FS DAX check requirement on
-        * vmas. However, this only comes up if locked is set, and there are
-        * callers that do request FOLL_LONGTERM, but do not set locked. So,
-        * allow what we can.
-        */
-       if (gup_flags & FOLL_LONGTERM) {
-               if (WARN_ON_ONCE(locked))
-                       return -EINVAL;
-               /*
-                * This will check the vmas (even if our vmas arg is NULL)
-                * and return -ENOTSUPP if DAX isn't allowed in this case:
-                */
-               return __gup_longterm_locked(tsk, mm, start, nr_pages, pages,
-                                            vmas, gup_flags | FOLL_TOUCH |
-                                            FOLL_REMOTE);
-       }
-
-       return __get_user_pages_locked(tsk, mm, start, nr_pages, pages, vmas,
-                                      locked,
-                                      gup_flags | FOLL_TOUCH | FOLL_REMOTE);
+       return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
+                                      pages, vmas, locked);
 }
 EXPORT_SYMBOL(get_user_pages_remote);
 
@@ -1842,13 +1948,17 @@ static inline pte_t gup_get_pte(pte_t *ptep)
 #endif /* CONFIG_GUP_GET_PTE_LOW_HIGH */
 
 static void __maybe_unused undo_dev_pagemap(int *nr, int nr_start,
+                                           unsigned int flags,
                                            struct page **pages)
 {
        while ((*nr) - nr_start) {
                struct page *page = pages[--(*nr)];
 
                ClearPageReferenced(page);
-               put_page(page);
+               if (flags & FOLL_PIN)
+                       put_user_page(page);
+               else
+                       put_page(page);
        }
 }
 
@@ -1881,7 +1991,7 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 
                        pgmap = get_dev_pagemap(pte_pfn(pte), pgmap);
                        if (unlikely(!pgmap)) {
-                               undo_dev_pagemap(nr, nr_start, pages);
+                               undo_dev_pagemap(nr, nr_start, flags, pages);
                                goto pte_unmap;
                        }
                } else if (pte_special(pte))
@@ -1890,9 +2000,15 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
                VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
                page = pte_page(pte);
 
-               head = try_get_compound_head(page, 1);
-               if (!head)
-                       goto pte_unmap;
+               if (flags & FOLL_PIN) {
+                       head = page;
+                       if (unlikely(!try_pin_page(head)))
+                               goto pte_unmap;
+               } else {
+                       head = try_get_compound_head(page, 1);
+                       if (!head)
+                               goto pte_unmap;
+               }
 
                if (unlikely(pte_val(pte) != pte_val(*ptep))) {
                        put_page(head);
@@ -1946,12 +2062,20 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
 
                pgmap = get_dev_pagemap(pfn, pgmap);
                if (unlikely(!pgmap)) {
-                       undo_dev_pagemap(nr, nr_start, pages);
+                       undo_dev_pagemap(nr, nr_start, flags, pages);
                        return 0;
                }
                SetPageReferenced(page);
                pages[*nr] = page;
-               get_page(page);
+
+               if (flags & FOLL_PIN) {
+                       if (unlikely(!try_pin_page(page))) {
+                               undo_dev_pagemap(nr, nr_start, flags, pages);
+                               return 0;
+                       }
+               } else
+                       get_page(page);
+
                (*nr)++;
                pfn++;
        } while (addr += PAGE_SIZE, addr != end);
@@ -1973,7 +2097,7 @@ static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
                return 0;
 
        if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) {
-               undo_dev_pagemap(nr, nr_start, pages);
+               undo_dev_pagemap(nr, nr_start, flags, pages);
                return 0;
        }
        return 1;
@@ -1991,7 +2115,7 @@ static int __gup_device_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
                return 0;
 
        if (unlikely(pud_val(orig) != pud_val(*pudp))) {
-               undo_dev_pagemap(nr, nr_start, pages);
+               undo_dev_pagemap(nr, nr_start, flags, pages);
                return 0;
        }
        return 1;
@@ -2025,6 +2149,20 @@ static int __record_subpages(struct page *page, unsigned long addr,
        return nr;
 }
 
+static bool __pin_compound_head(struct page *head, int refs, unsigned int flags)
+{
+       if (flags & FOLL_PIN) {
+               if (unlikely(!try_pin_compound_head(head, refs)))
+                       return false;
+       } else {
+               head = try_get_compound_head(head, refs);
+               if (!head)
+                       return false;
+       }
+
+       return true;
+}
+
 static void put_compound_head(struct page *page, int refs)
 {
        /* Do a get_page() first, in case refs == page->_refcount */
@@ -2066,8 +2204,7 @@ static int gup_hugepte(pte_t *ptep, unsigned long sz, unsigned long addr,
        page = head + ((addr & (sz-1)) >> PAGE_SHIFT);
        refs = __record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(head, refs);
-       if (!head)
+       if (!__pin_compound_head(head, refs, flags))
                return 0;
 
        if (unlikely(pte_val(pte) != pte_val(*ptep))) {
@@ -2126,8 +2263,8 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
        page = pmd_page(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
        refs = __record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(pmd_page(orig), refs);
-       if (!head)
+       head = pmd_page(orig);
+       if (!__pin_compound_head(head, refs, flags))
                return 0;
 
        if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) {
@@ -2160,8 +2297,8 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
        page = pud_page(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
        refs = __record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(pud_page(orig), refs);
-       if (!head)
+       head = pud_page(orig);
+       if (!__pin_compound_head(head, refs, flags))
                return 0;
 
        if (unlikely(pud_val(orig) != pud_val(*pudp))) {
@@ -2189,8 +2326,8 @@ static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
        page = pgd_page(orig) + ((addr & ~PGDIR_MASK) >> PAGE_SHIFT);
        refs = __record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(pgd_page(orig), refs);
-       if (!head)
+       head = pgd_page(orig);
+       if (!__pin_compound_head(head, refs, flags))
                return 0;
 
        if (unlikely(pgd_val(orig) != pgd_val(*pgdp))) {
@@ -2494,9 +2631,12 @@ EXPORT_SYMBOL_GPL(get_user_pages_fast);
 /**
  * pin_user_pages_fast() - pin user pages in memory without taking locks
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages_fast().
+ * Nearly the same as get_user_pages_fast(), except that FOLL_PIN is set. See
+ * get_user_pages_fast() for documentation on the function arguments, because
+ * the arguments here are identical.
+ *
+ * FOLL_PIN means that the pages must be released via put_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for further details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2504,21 +2644,24 @@ EXPORT_SYMBOL_GPL(get_user_pages_fast);
 int pin_user_pages_fast(unsigned long start, int nr_pages,
                        unsigned int gup_flags, struct page **pages)
 {
-       /*
-        * This is a placeholder, until the pin functionality is activated.
-        * Until then, just behave like the corresponding get_user_pages*()
-        * routine.
-        */
-       return get_user_pages_fast(start, nr_pages, gup_flags, pages);
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+               return -EINVAL;
+
+       gup_flags |= FOLL_PIN;
+       return internal_get_user_pages_fast(start, nr_pages, gup_flags, pages);
 }
 EXPORT_SYMBOL_GPL(pin_user_pages_fast);
 
 /**
  * pin_user_pages_remote() - pin pages of a remote process (task != current)
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages_remote().
+ * Nearly the same as get_user_pages_remote(), except that FOLL_PIN is set. See
+ * get_user_pages_remote() for documentation on the function arguments, because
+ * the arguments here are identical.
+ *
+ * FOLL_PIN means that the pages must be released via put_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2528,22 +2671,24 @@ long pin_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
                           unsigned int gup_flags, struct page **pages,
                           struct vm_area_struct **vmas, int *locked)
 {
-       /*
-        * This is a placeholder, until the pin functionality is activated.
-        * Until then, just behave like the corresponding get_user_pages*()
-        * routine.
-        */
-       return get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags, pages,
-                                    vmas, locked);
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+               return -EINVAL;
+
+       gup_flags |= FOLL_PIN;
+       return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
+                                      pages, vmas, locked);
 }
 EXPORT_SYMBOL(pin_user_pages_remote);
 
 /**
  * pin_user_pages() - pin user pages in memory for use by other devices
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages().
+ * Nearly the same as get_user_pages(), except that FOLL_TOUCH is not set, and
+ * FOLL_PIN is set.
+ *
+ * FOLL_PIN means that the pages must be released via put_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2552,11 +2697,12 @@ long pin_user_pages(unsigned long start, unsigned long nr_pages,
                    unsigned int gup_flags, struct page **pages,
                    struct vm_area_struct **vmas)
 {
-       /*
-        * This is a placeholder, until the pin functionality is activated.
-        * Until then, just behave like the corresponding get_user_pages*()
-        * routine.
-        */
-       return get_user_pages(start, nr_pages, gup_flags, pages, vmas);
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+               return -EINVAL;
+
+       gup_flags |= FOLL_PIN;
+       return __gup_longterm_locked(current, current->mm, start, nr_pages,
+                                    pages, vmas, gup_flags);
 }
 EXPORT_SYMBOL(pin_user_pages);
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 13cc93785006..e94297799041 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -945,6 +945,11 @@ struct page *follow_devmap_pmd(struct vm_area_struct *vma, unsigned long addr,
         */
        WARN_ONCE(flags & FOLL_COW, "mm: In follow_devmap_pmd with FOLL_COW set");
 
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
+                        (FOLL_PIN | FOLL_GET)))
+               return NULL;
+
        if (flags & FOLL_WRITE && !pmd_write(*pmd))
                return NULL;
 
@@ -960,7 +965,7 @@ struct page *follow_devmap_pmd(struct vm_area_struct *vma, unsigned long addr,
         * device mapped pages can only be returned if the
         * caller will manage the page reference count.
         */
-       if (!(flags & FOLL_GET))
+       if (!(flags & (FOLL_GET | FOLL_PIN)))
                return ERR_PTR(-EEXIST);
 
        pfn += (addr & ~PMD_MASK) >> PAGE_SHIFT;
@@ -968,7 +973,18 @@ struct page *follow_devmap_pmd(struct vm_area_struct *vma, unsigned long addr,
        if (!*pgmap)
                return ERR_PTR(-EFAULT);
        page = pfn_to_page(pfn);
-       get_page(page);
+
+       if (flags & FOLL_GET)
+               get_page(page);
+       else if (flags & FOLL_PIN) {
+               /*
+                * try_pin_page() is not actually expected to fail here because
+                * we hold the pmd lock so no one can unmap the pmd and free the
+                * page that it points to.
+                */
+               if (unlikely(!try_pin_page(page)))
+                       page = ERR_PTR(-EFAULT);
+       }
 
        return page;
 }
@@ -1088,6 +1104,11 @@ struct page *follow_devmap_pud(struct vm_area_struct *vma, unsigned long addr,
        if (flags & FOLL_WRITE && !pud_write(*pud))
                return NULL;
 
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
+                        (FOLL_PIN | FOLL_GET)))
+               return NULL;
+
        if (pud_present(*pud) && pud_devmap(*pud))
                /* pass */;
        else
@@ -1099,8 +1120,10 @@ struct page *follow_devmap_pud(struct vm_area_struct *vma, unsigned long addr,
        /*
         * device mapped pages can only be returned if the
         * caller will manage the page reference count.
+        *
+        * At least one of FOLL_GET | FOLL_PIN must be set, so check that here:
         */
-       if (!(flags & FOLL_GET))
+       if (!(flags & (FOLL_GET | FOLL_PIN)))
                return ERR_PTR(-EEXIST);
 
        pfn += (addr & ~PUD_MASK) >> PAGE_SHIFT;
@@ -1108,7 +1131,18 @@ struct page *follow_devmap_pud(struct vm_area_struct *vma, unsigned long addr,
        if (!*pgmap)
                return ERR_PTR(-EFAULT);
        page = pfn_to_page(pfn);
-       get_page(page);
+
+       if (flags & FOLL_GET)
+               get_page(page);
+       else if (flags & FOLL_PIN) {
+               /*
+                * try_pin_page() is not actually expected to fail here because
+                * we hold the pud lock so no one can unmap the pud and free the
+                * page that it points to.
+                */
+               if (unlikely(!try_pin_page(page)))
+                       page = ERR_PTR(-EFAULT);
+       }
 
        return page;
 }
@@ -1522,8 +1556,20 @@ struct page *follow_trans_huge_pmd(struct vm_area_struct *vma,
 skip_mlock:
        page += (addr & ~HPAGE_PMD_MASK) >> PAGE_SHIFT;
        VM_BUG_ON_PAGE(!PageCompound(page) && !is_zone_device_page(page), page);
+
        if (flags & FOLL_GET)
                get_page(page);
+       else if (flags & FOLL_PIN) {
+               /*
+                * try_pin_page() is not actually expected to fail here because
+                * we hold the pmd lock so no one can unmap the pmd and free the
+                * page that it points to.
+                */
+               if (unlikely(!try_pin_page(page))) {
+                       WARN_ON_ONCE(1);
+                       page = NULL;
+               }
+       }
 
 out:
        return page;
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index b45a95363a84..0abde7288127 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -4462,7 +4462,22 @@ long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
 same_page:
                if (pages) {
                        pages[i] = mem_map_offset(page, pfn_offset);
-                       get_page(pages[i]);
+
+                       if (flags & FOLL_GET)
+                               get_page(pages[i]);
+                       else if (flags & FOLL_PIN) {
+                               /*
+                                * try_pin_page() is not actually expected to
+                                * fail here because we hold the ptl.
+                                */
+                               if (unlikely(!try_pin_page(pages[i]))) {
+                                       spin_unlock(ptl);
+                                       remainder = 0;
+                                       err = -ENOMEM;
+                                       WARN_ON_ONCE(1);
+                                       break;
+                               }
+                       }
                }
 
                if (vmas)
@@ -5022,6 +5037,12 @@ follow_huge_pmd(struct mm_struct *mm, unsigned long address,
        struct page *page = NULL;
        spinlock_t *ptl;
        pte_t pte;
+
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
+                        (FOLL_PIN | FOLL_GET)))
+               return NULL;
+
 retry:
        ptl = pmd_lockptr(mm, pmd);
        spin_lock(ptl);
@@ -5034,8 +5055,20 @@ follow_huge_pmd(struct mm_struct *mm, unsigned long address,
        pte = huge_ptep_get((pte_t *)pmd);
        if (pte_present(pte)) {
                page = pmd_page(*pmd) + ((address & ~PMD_MASK) >> PAGE_SHIFT);
+
                if (flags & FOLL_GET)
                        get_page(page);
+               else if (flags & FOLL_PIN) {
+                       /*
+                        * try_pin_page() is not actually expected to fail
+                        * here because we hold the ptl.
+                        */
+                       if (unlikely(!try_pin_page(page))) {
+                               WARN_ON_ONCE(1);
+                               page = NULL;
+                               goto out;
+                       }
+               }
        } else {
                if (is_hugetlb_entry_migration(pte)) {
                        spin_unlock(ptl);
@@ -5056,7 +5089,7 @@ struct page * __weak
 follow_huge_pud(struct mm_struct *mm, unsigned long address,
                pud_t *pud, int flags)
 {
-       if (flags & FOLL_GET)
+       if (flags & (FOLL_GET | FOLL_PIN))
                return NULL;
 
        return pte_page(*(pte_t *)pud) + ((address & ~PUD_MASK) >> PAGE_SHIFT);
@@ -5065,7 +5098,7 @@ follow_huge_pud(struct mm_struct *mm, unsigned long address,
 struct page * __weak
 follow_huge_pgd(struct mm_struct *mm, unsigned long address, pgd_t *pgd, int flags)
 {
-       if (flags & FOLL_GET)
+       if (flags & (FOLL_GET | FOLL_PIN))
                return NULL;
 
        return pte_page(*(pte_t *)pgd) + ((address & ~PGDIR_MASK) >> PAGE_SHIFT);
diff --git a/mm/vmstat.c b/mm/vmstat.c
index a8222041bd44..fdad40ccde7b 100644
--- a/mm/vmstat.c
+++ b/mm/vmstat.c
@@ -1167,6 +1167,8 @@ const char * const vmstat_text[] = {
        "nr_dirtied",
        "nr_written",
        "nr_kernel_misc_reclaimable",
+       "nr_foll_pin_requested",
+       "nr_foll_pin_returned",
 
        /* enum writeback_stat_item counters */
        "nr_dirty_threshold",
-- 
2.24.0

_______________________________________________
dri-devel mailing list
dri-devel@lists.freedesktop.org
https://lists.freedesktop.org/mailman/listinfo/dri-devel
