sme_pgtable_calc() is unnecessarily complex. It can be rewritten in a
more streamlined way.

As a side effect, this gets the code ready for boot-time switching
between paging modes: the compile-time IS_ENABLED(CONFIG_X86_5LEVEL)
checks are replaced with PTRS_PER_P4D checks.

Signed-off-by: Kirill A. Shutemov <[email protected]>
---
 arch/x86/mm/mem_encrypt_identity.c | 42 +++++++++++---------------------------
 1 file changed, 12 insertions(+), 30 deletions(-)

diff --git a/arch/x86/mm/mem_encrypt_identity.c b/arch/x86/mm/mem_encrypt_identity.c
index 69635a02ce9e..613686cc56ae 100644
--- a/arch/x86/mm/mem_encrypt_identity.c
+++ b/arch/x86/mm/mem_encrypt_identity.c
@@ -230,8 +230,7 @@ static void __init sme_map_range_decrypted_wp(struct sme_populate_pgd_data *ppd)
 
 static unsigned long __init sme_pgtable_calc(unsigned long len)
 {
-	unsigned long p4d_size, pud_size, pmd_size, pte_size;
-	unsigned long total;
+	unsigned long entries = 0, tables = 0;
 
 	/*
 	 * Perform a relatively simplistic calculation of the pagetable
@@ -245,42 +244,25 @@ static unsigned long __init sme_pgtable_calc(unsigned long len)
 	 * Incrementing the count for each covers the case where the addresses
 	 * cross entries.
 	 */
-	if (IS_ENABLED(CONFIG_X86_5LEVEL)) {
-		p4d_size = (ALIGN(len, PGDIR_SIZE) / PGDIR_SIZE) + 1;
-		p4d_size *= sizeof(p4d_t) * PTRS_PER_P4D;
-		pud_size = (ALIGN(len, P4D_SIZE) / P4D_SIZE) + 1;
-		pud_size *= sizeof(pud_t) * PTRS_PER_PUD;
-	} else {
-		p4d_size = 0;
-		pud_size = (ALIGN(len, PGDIR_SIZE) / PGDIR_SIZE) + 1;
-		pud_size *= sizeof(pud_t) * PTRS_PER_PUD;
-	}
-	pmd_size = (ALIGN(len, PUD_SIZE) / PUD_SIZE) + 1;
-	pmd_size *= sizeof(pmd_t) * PTRS_PER_PMD;
-	pte_size = 2 * sizeof(pte_t) * PTRS_PER_PTE;
 
-	total = p4d_size + pud_size + pmd_size + pte_size;
+	/* PGDIR_SIZE is equal to P4D_SIZE on 4-level machines. */
+	if (PTRS_PER_P4D > 1)
+		entries += (DIV_ROUND_UP(len, PGDIR_SIZE) + 1) * sizeof(p4d_t) * PTRS_PER_P4D;
+	entries += (DIV_ROUND_UP(len, P4D_SIZE) + 1) * sizeof(pud_t) * PTRS_PER_PUD;
+	entries += (DIV_ROUND_UP(len, PUD_SIZE) + 1) * sizeof(pmd_t) * PTRS_PER_PMD;
+	entries += 2 * sizeof(pte_t) * PTRS_PER_PTE;
 
 	/*
 	 * Now calculate the added pagetable structures needed to populate
 	 * the new pagetables.
 	 */
-	if (IS_ENABLED(CONFIG_X86_5LEVEL)) {
-		p4d_size = ALIGN(total, PGDIR_SIZE) / PGDIR_SIZE;
-		p4d_size *= sizeof(p4d_t) * PTRS_PER_P4D;
-		pud_size = ALIGN(total, P4D_SIZE) / P4D_SIZE;
-		pud_size *= sizeof(pud_t) * PTRS_PER_PUD;
-	} else {
-		p4d_size = 0;
-		pud_size = ALIGN(total, PGDIR_SIZE) / PGDIR_SIZE;
-		pud_size *= sizeof(pud_t) * PTRS_PER_PUD;
-	}
-	pmd_size = ALIGN(total, PUD_SIZE) / PUD_SIZE;
-	pmd_size *= sizeof(pmd_t) * PTRS_PER_PMD;
 
-	total += p4d_size + pud_size + pmd_size;
+	if (PTRS_PER_P4D > 1)
+		tables += DIV_ROUND_UP(entries, PGDIR_SIZE) * sizeof(p4d_t) * PTRS_PER_P4D;
+	tables += DIV_ROUND_UP(entries, P4D_SIZE) * sizeof(pud_t) * PTRS_PER_PUD;
+	tables += DIV_ROUND_UP(entries, PUD_SIZE) * sizeof(pmd_t) * PTRS_PER_PMD;
 
-	return total;
+	return entries + tables;
 }
 
 void __init __nostackprotector sme_encrypt_kernel(struct boot_params *bp)
-- 
2.15.1

Reply via email to