From: Jason Gunthorpe <j...@mellanox.com>

Using two bool output parameters instead of a returned flags value is
unnecessary and leads to bugs. A returned value is easier for the compiler
to check and easier to pass around through the code flow.

Convert the two bools into flags and push the change to all callers.

Signed-off-by: Jason Gunthorpe <j...@mellanox.com>
---
 mm/hmm.c | 183 ++++++++++++++++++++++++-------------------------------
 1 file changed, 81 insertions(+), 102 deletions(-)

diff --git a/mm/hmm.c b/mm/hmm.c
index 3a2610e0713329..d208ddd351066f 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -32,6 +32,12 @@ struct hmm_vma_walk {
        unsigned int            flags;
 };
 
+enum {
+       HMM_NEED_FAULT = 1 << 0,
+       HMM_NEED_WRITE_FAULT = 1 << 1,
+       HMM_NEED_ALL_BITS = HMM_NEED_FAULT | HMM_NEED_WRITE_FAULT,
+};
+
 static int hmm_pfns_fill(unsigned long addr, unsigned long end,
                struct hmm_range *range, enum hmm_pfn_value_e value)
 {
@@ -49,8 +55,7 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
  * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
  * @addr: range virtual start address (inclusive)
  * @end: range virtual end address (exclusive)
- * @fault: should we fault or not ?
- * @write_fault: write fault ?
+ * @required_fault: HMM_NEED_* flags
  * @walk: mm_walk structure
  * Return: -EBUSY after page fault, or page fault error
  *
@@ -58,8 +63,7 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
  * or whenever there is no page directory covering the virtual address range.
  */
 static int hmm_vma_fault(unsigned long addr, unsigned long end,
-                             bool fault, bool write_fault,
-                             struct mm_walk *walk)
+                        unsigned int required_fault, struct mm_walk *walk)
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
@@ -68,13 +72,13 @@ static int hmm_vma_fault(unsigned long addr, unsigned long end,
        unsigned long i = (addr - range->start) >> PAGE_SHIFT;
        unsigned int fault_flags = FAULT_FLAG_REMOTE;
 
-       WARN_ON_ONCE(!fault && !write_fault);
+       WARN_ON_ONCE(!required_fault);
        hmm_vma_walk->last = addr;
 
        if (!vma)
                goto out_error;
 
-       if (write_fault) {
+       if (required_fault & HMM_NEED_WRITE_FAULT) {
                if (!(vma->vm_flags & VM_WRITE))
                        return -EPERM;
                fault_flags |= FAULT_FLAG_WRITE;
@@ -91,14 +95,13 @@ static int hmm_vma_fault(unsigned long addr, unsigned long end,
        return -EFAULT;
 }
 
-static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
-                                     uint64_t pfns, uint64_t cpu_flags,
-                                     bool *fault, bool *write_fault)
+static unsigned int hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
+                                      uint64_t pfns, uint64_t cpu_flags)
 {
        struct hmm_range *range = hmm_vma_walk->range;
 
        if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT)
-               return;
+               return 0;
 
        /*
         * So we not only consider the individual per page request we also
@@ -114,37 +117,37 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 
        /* We aren't ask to do anything ... */
        if (!(pfns & range->flags[HMM_PFN_VALID]))
-               return;
+               return 0;
 
-       /* If CPU page table is not valid then we need to fault */
-       *fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
        /* Need to write fault ? */
        if ((pfns & range->flags[HMM_PFN_WRITE]) &&
-           !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
-               *write_fault = true;
-               *fault = true;
-       }
+           !(cpu_flags & range->flags[HMM_PFN_WRITE]))
+               return HMM_NEED_FAULT | HMM_NEED_WRITE_FAULT;
+
+       /* If CPU page table is not valid then we need to fault */
+       if (!(cpu_flags & range->flags[HMM_PFN_VALID]))
+               return HMM_NEED_FAULT;
+       return 0;
 }
 
-static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
-                                const uint64_t *pfns, unsigned long npages,
-                                uint64_t cpu_flags, bool *fault,
-                                bool *write_fault)
+static unsigned int
+hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
+                    const uint64_t *pfns, unsigned long npages,
+                    uint64_t cpu_flags)
 {
+       unsigned int required_fault = 0;
        unsigned long i;
 
-       if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT) {
-               *fault = *write_fault = false;
-               return;
-       }
+       if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT)
+               return 0;
 
-       *fault = *write_fault = false;
        for (i = 0; i < npages; ++i) {
-               hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
-                                  fault, write_fault);
-               if ((*write_fault))
-                       return;
+               required_fault |=
+                       hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags);
+               if (required_fault == HMM_NEED_ALL_BITS)
+                       return required_fault;
        }
+       return required_fault;
 }
 
 static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
@@ -152,17 +155,16 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       bool fault, write_fault;
+       unsigned int required_fault;
        unsigned long i, npages;
        uint64_t *pfns;
 
        i = (addr - range->start) >> PAGE_SHIFT;
        npages = (end - addr) >> PAGE_SHIFT;
        pfns = &range->pfns[i];
-       hmm_range_need_fault(hmm_vma_walk, pfns, npages,
-                            0, &fault, &write_fault);
-       if (fault || write_fault)
-               return hmm_vma_fault(addr, end, fault, write_fault, walk);
+       required_fault = hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0);
+       if (required_fault)
+               return hmm_vma_fault(addr, end, required_fault, walk);
        hmm_vma_walk->last = addr;
        return hmm_pfns_fill(addr, end, range, HMM_PFN_NONE);
 }
@@ -183,16 +185,15 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
        unsigned long pfn, npages, i;
-       bool fault, write_fault;
+       unsigned int required_fault;
        uint64_t cpu_flags;
 
        npages = (end - addr) >> PAGE_SHIFT;
        cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
-       hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
-                            &fault, &write_fault);
-
-       if (fault || write_fault)
-               return hmm_vma_fault(addr, end, fault, write_fault, walk);
+       required_fault =
+               hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags);
+       if (required_fault)
+               return hmm_vma_fault(addr, end, required_fault, walk);
 
        pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
        for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++)
@@ -229,18 +230,15 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       bool fault, write_fault;
+       unsigned int required_fault;
        uint64_t cpu_flags;
        pte_t pte = *ptep;
        uint64_t orig_pfn = *pfn;
 
        *pfn = range->values[HMM_PFN_NONE];
-       fault = write_fault = false;
-
        if (pte_none(pte)) {
-               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
-                                  &fault, &write_fault);
-               if (fault || write_fault)
+               required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0);
+               if (required_fault)
                        goto fault;
                return 0;
        }
@@ -261,9 +259,8 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
                        return 0;
                }
 
-               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0, &fault,
-                                  &write_fault);
-               if (!fault && !write_fault)
+               required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0);
+               if (!required_fault)
                        return 0;
 
                if (!non_swap_entry(entry))
@@ -283,9 +280,8 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
        }
 
        cpu_flags = pte_to_hmm_pfn_flags(range, pte);
-       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags, &fault,
-                          &write_fault);
-       if (fault || write_fault)
+       required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags);
+       if (required_fault)
                goto fault;
 
        /*
@@ -293,9 +289,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
         * fall through and treat it like a normal page.
         */
        if (pte_special(pte) && !is_zero_pfn(pte_pfn(pte))) {
-               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0, &fault,
-                                  &write_fault);
-               if (fault || write_fault) {
+               if (hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0)) {
                        pte_unmap(ptep);
                        return -EFAULT;
                }
@@ -309,7 +303,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 fault:
        pte_unmap(ptep);
        /* Fault any virtual address we were asked to fault */
-       return hmm_vma_fault(addr, end, fault, write_fault, walk);
+       return hmm_vma_fault(addr, end, required_fault, walk);
 }
 
 static int hmm_vma_walk_pmd(pmd_t *pmdp,
@@ -322,7 +316,6 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
        uint64_t *pfns = &range->pfns[(start - range->start) >> PAGE_SHIFT];
        unsigned long npages = (end - start) >> PAGE_SHIFT;
        unsigned long addr = start;
-       bool fault, write_fault;
        pte_t *ptep;
        pmd_t pmd;
 
@@ -332,9 +325,7 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
                return hmm_vma_walk_hole(start, end, -1, walk);
 
        if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
-               hmm_range_need_fault(hmm_vma_walk, pfns, npages,
-                                    0, &fault, &write_fault);
-               if (fault || write_fault) {
+               if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0)) {
                        hmm_vma_walk->last = addr;
                        pmd_migration_entry_wait(walk->mm, pmdp);
                        return -EBUSY;
@@ -343,9 +334,7 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
        }
 
        if (!pmd_present(pmd)) {
-               hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0, &fault,
-                                    &write_fault);
-               if (fault || write_fault)
+               if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0))
                        return -EFAULT;
                return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
        }
@@ -375,9 +364,7 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
         * recover.
         */
        if (pmd_bad(pmd)) {
-               hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0, &fault,
-                                    &write_fault);
-               if (fault || write_fault)
+               if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0))
                        return -EFAULT;
                return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
        }
@@ -434,8 +421,8 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
 
        if (pud_huge(pud) && pud_devmap(pud)) {
                unsigned long i, npages, pfn;
+               unsigned int required_fault;
                uint64_t *pfns, cpu_flags;
-               bool fault, write_fault;
 
                if (!pud_present(pud)) {
                        spin_unlock(ptl);
@@ -447,12 +434,11 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
                pfns = &range->pfns[i];
 
                cpu_flags = pud_to_hmm_pfn_flags(range, pud);
-               hmm_range_need_fault(hmm_vma_walk, pfns, npages,
-                                    cpu_flags, &fault, &write_fault);
-               if (fault || write_fault) {
+               required_fault = hmm_range_need_fault(hmm_vma_walk, pfns,
+                                                     npages, cpu_flags);
+               if (required_fault) {
                        spin_unlock(ptl);
-                       return hmm_vma_fault(addr, end, fault, write_fault,
-                                                 walk);
+                       return hmm_vma_fault(addr, end, required_fault, walk);
                }
 
                pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
@@ -484,7 +470,7 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
        struct hmm_range *range = hmm_vma_walk->range;
        struct vm_area_struct *vma = walk->vma;
        uint64_t orig_pfn, cpu_flags;
-       bool fault, write_fault;
+       unsigned int required_fault;
        spinlock_t *ptl;
        pte_t entry;
 
@@ -495,12 +481,10 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
        orig_pfn = range->pfns[i];
        range->pfns[i] = range->values[HMM_PFN_NONE];
        cpu_flags = pte_to_hmm_pfn_flags(range, entry);
-       fault = write_fault = false;
-       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                          &fault, &write_fault);
-       if (fault || write_fault) {
+       required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags);
+       if (required_fault) {
                spin_unlock(ptl);
-               return hmm_vma_fault(addr, end, fault, write_fault, walk);
+               return hmm_vma_fault(addr, end, required_fault, walk);
        }
 
        pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
@@ -522,37 +506,32 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
        struct hmm_range *range = hmm_vma_walk->range;
        struct vm_area_struct *vma = walk->vma;
 
+       if (!(vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP)) &&
+           vma->vm_flags & VM_READ)
+               return 0;
+
        /*
-        * Skip vma ranges that don't have struct page backing them or map I/O
-        * devices directly.
+        * vma ranges that don't have struct page backing them or map I/O
+        * devices directly cannot be handled by hmm_range_fault().
         *
         * If the vma does not allow read access, then assume that it does not
         * allow write access either. HMM does not support architectures that
         * allow write without read.
+        *
+        * If a fault is requested for an unsupported range then it is a hard
+        * failure.
         */
-       if ((vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP)) ||
-           !(vma->vm_flags & VM_READ)) {
-               bool fault, write_fault;
-
-               /*
-                * Check to see if a fault is requested for any page in the
-                * range.
-                */
-               hmm_range_need_fault(hmm_vma_walk, range->pfns +
-                                       ((start - range->start) >> PAGE_SHIFT),
-                                       (end - start) >> PAGE_SHIFT,
-                                       0, &fault, &write_fault);
-               if (fault || write_fault)
-                       return -EFAULT;
-
-               hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
-               hmm_vma_walk->last = end;
+       if (hmm_range_need_fault(hmm_vma_walk,
+                                range->pfns +
+                                        ((start - range->start) >> PAGE_SHIFT),
+                                (end - start) >> PAGE_SHIFT, 0))
+               return -EFAULT;
 
-               /* Skip this vma and continue processing the next vma. */
-               return 1;
-       }
+       hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
+       hmm_vma_walk->last = end;
 
-       return 0;
+       /* Skip this vma and continue processing the next vma. */
+       return 1;
 }
 
 static const struct mm_walk_ops hmm_walk_ops = {
-- 
2.25.2
