Let E denote a multi-dimensional extent; n the rank of E; r = 0, ...,
n; E[i] the i-th extent (i = 0, ..., n-1); and D the (possibly empty)
array of dynamic extents, with D[k] its k-th entry.

The two partial products (an empty product being 1):

  \prod_{i = 0}^{r-1} E[i]     (fwd), r = 0, ..., n
  \prod_{i = r+1}^{n-1} E[i]   (rev), r = 0, ..., n-1

can each be computed as the product of the static extents and the
product of the dynamic extents in the relevant index range. The static
parts of the fwd and rev products can be computed at compile time for
all values of r.

Three methods are directly affected by this optimization:

   layout_left::mapping::stride
   layout_right::mapping::stride
   mdspan::size

We'll check the generated code (-O2) for all three methods for a generic,
(artificially) high-dimensional extents type.

Consider a generic case:

  using Extents = std::extents<int, 3, 5, dyn, dyn, dyn, 7, dyn>;

  int stride_left(const std::layout_left::mapping<Extents>& m, size_t r)
  { return m.stride(r); }

The code generated by master:

  80:  49 89 f0               mov    r8,rsi
  83:  49 89 f1               mov    r9,rsi
  86:  49 c1 e0 03            shl    r8,0x3
  8a:  74 6c                  je     f8 <stride_left+0x78>
  8c:  49 81 c0 00 00 00 00   add    r8,0x0
  93:  ba 00 00 00 00         mov    edx,0x0
  98:  b8 01 00 00 00         mov    eax,0x1
  9d:  0f 1f 00               nop    DWORD PTR [rax]
  a0:  48 8b 0a               mov    rcx,QWORD PTR [rdx]
  a3:  48 89 c6               mov    rsi,rax
  a6:  48 0f af f1            imul   rsi,rcx
  aa:  48 83 f9 ff            cmp    rcx,0xffffffffffffffff
  ae:  48 0f 45 c6            cmovne rax,rsi
  b2:  48 83 c2 08            add    rdx,0x8
  b6:  49 39 d0               cmp    r8,rdx
  b9:  75 e5                  jne    a0 <stride_left+0x20>
  bb:  31 d2                  xor    edx,edx
  bd:  48 85 c0               test   rax,rax
  c0:  74 30                  je     f2 <stride_left+0x72>
  c2:  4a 8b 0c cd 00 00 00   mov    rcx,QWORD PTR [r9*8+0x0]
  c9:  00
  ca:  48 c1 e1 02            shl    rcx,0x2
  ce:  74 20                  je     f0 <stride_left+0x70>
  d0:  48 01 f9               add    rcx,rdi
  d3:  66 66 2e 0f 1f 84 00   data16 cs nop WORD PTR [rax+rax*1+0x0]
  da:  00 00 00 00
  de:  66 90                  xchg   ax,ax
  e0:  48 63 17               movsxd rdx,DWORD PTR [rdi]
  e3:  48 83 c7 04            add    rdi,0x4
  e7:  48 0f af c2            imul   rax,rdx
  eb:  48 39 f9               cmp    rcx,rdi
  ee:  75 f0                  jne    e0 <stride_left+0x60>
  f0:  89 c2                  mov    edx,eax
  f2:  89 d0                  mov    eax,edx
  f4:  c3                     ret
  f5:  0f 1f 00               nop    DWORD PTR [rax]
  f8:  b8 01 00 00 00         mov    eax,0x1
  fd:  eb c3                  jmp    c2 <stride_left+0x42>

is reduced to:

  50:  48 8b 0c f5 00 00 00   mov    rcx,QWORD PTR [rsi*8+0x0]
  57:  00
  58:  48 8b 04 f5 00 00 00   mov    rax,QWORD PTR [rsi*8+0x0]
  5f:  00
  60:  48 c1 e1 02            shl    rcx,0x2
  64:  74 1a                  je     80 <stride_left+0x30>
  66:  48 01 f9               add    rcx,rdi
  69:  0f 1f 80 00 00 00 00   nop    DWORD PTR [rax+0x0]
  70:  48 63 17               movsxd rdx,DWORD PTR [rdi]
  73:  48 83 c7 04            add    rdi,0x4
  77:  48 0f af c2            imul   rax,rdx
  7b:  48 39 f9               cmp    rcx,rdi
  7e:  75 f0                  jne    70 <stride_left+0x20>
  80:  c3                     ret

Loosely speaking this does the following:

  1. Load the number k of dynamic extents among E[0], ..., E[r-1],
  i.e. the end position in the array of dynamic extents.
  2. Load the partial product of static extents.
  3. Multiply it by \prod_{i = 0}^{k-1} D[i] in a loop.

It shows that the span used for passing in the dynamic extents is
completely eliminated. The compiler also exploits the fact that the
product always starts at the beginning (fwd) or runs to the end (rev) of
the array of dynamic extents. This eliminates one indirection, namely
the lookup of that fixed boundary position in the array of dynamic
extents.

The analogous code is generated for layout_right.

Next, consider

  using E2 = std::extents<int, 3, 5, dyn, dyn, 7, dyn, 11>;
  int size2(const std::mdspan<double, E2>& md)
  { return md.size(); }

on master the generated code is

  10:  b8 00 00 00 00         mov    eax,0x0
  15:  ba 01 00 00 00         mov    edx,0x1
  1a:  66 0f 1f 44 00 00      nop    WORD PTR [rax+rax*1+0x0]
  20:  48 8b 08               mov    rcx,QWORD PTR [rax]
  23:  48 89 d6               mov    rsi,rdx
  26:  48 0f af f1            imul   rsi,rcx
  2a:  48 83 f9 ff            cmp    rcx,0xffffffffffffffff
  2e:  48 0f 45 d6            cmovne rdx,rsi
  32:  48 83 c0 08            add    rax,0x8
  36:  48 3d 00 00 00 00      cmp    rax,0x0
  3c:  75 e2                  jne    20 <size2+0x10>
  3e:  31 c0                  xor    eax,eax
  40:  48 85 d2               test   rdx,rdx
  43:  74 12                  je     57 <size2+0x47>
  45:  48 63 07               movsxd rax,DWORD PTR [rdi]
  48:  48 63 4f 04            movsxd rcx,DWORD PTR [rdi+0x4]
  4c:  48 0f af c1            imul   rax,rcx
  50:  0f af 47 08            imul   eax,DWORD PTR [rdi+0x8]
  54:  0f af c2               imul   eax,edx
  57:  c3                     ret

the optimized version is:

  10:  48 63 07               movsxd rax,DWORD PTR [rdi]
  13:  48 63 57 04            movsxd rdx,DWORD PTR [rdi+0x4]
  17:  48 0f af c2            imul   rax,rdx
  1b:  0f af 47 08            imul   eax,DWORD PTR [rdi+0x8]
  1f:  69 c0 83 04 00 00      imul   eax,eax,0x483
  25:  c3                     ret

Which simply computes the product:

  D[0] * D[1] * D[2] * const

where const is the product of all static extents. That is, the loop
computing the product of dynamic extents has been fully unrolled and all
constant factors have been precomputed.

libstdc++-v3/ChangeLog:

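        * include/std/mdspan (__mdspan::__static_prod): New function.
        (__mdspan::_FwdProd, __mdspan::_RevProd): New class templates.
        (__mdspan::__static_extents_prod): Remove.
        (__mdspan::__exts_prod): Rename to...
        (__mdspan::__extents_prod): ...this.  Take the pre-computed
        product of static extents as an argument.
        (__mdspan::__fwd_prod): Compute as the product of the
        pre-computed static extents and the product of the dynamic
        extents.
        (__mdspan::__rev_prod): Ditto.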

Signed-off-by: Luc Grosheintz <luc.groshei...@gmail.com>
---
 libstdc++-v3/include/std/mdspan | 81 +++++++++++++++++++++++----------
 1 file changed, 56 insertions(+), 25 deletions(-)

diff --git a/libstdc++-v3/include/std/mdspan b/libstdc++-v3/include/std/mdspan
index 5e79d4bfb59..06ccf3e3827 100644
--- a/libstdc++-v3/include/std/mdspan
+++ b/libstdc++-v3/include/std/mdspan
@@ -184,6 +184,49 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       __valid_static_extent = _Extent == dynamic_extent
        || _Extent <= numeric_limits<_IndexType>::max();
 
+    // Product of the static extents E[i], __begin <= i < __end, where a
+    // dynamic extent counts as 1.
+    template<typename _Extents>
+      consteval size_t
+      __static_prod(size_t __begin, size_t __end)
+      {
+       size_t __ret = 1;
+       for (size_t __i = __begin; __i < __end; ++__i)
+         if (_Extents::static_extent(__i) != dynamic_extent)
+           __ret *= _Extents::static_extent(__i);
+       return __ret;
+      }
+
+    // Static part of the forward products \prod_{i = 0}^{r-1} E[i] for
+    // r = 0, ..., n, where n = rank(); a dynamic extent counts as 1.
+    template<typename _Extents>
+      struct _FwdProd
+      {
+       static constexpr std::array<size_t, _Extents::rank() + 1> _S_value =
+       [] consteval
+       {
+         std::array<size_t, _Extents::rank() + 1> __ret;
+         for (size_t __r = 0; __r <= _Extents::rank(); ++__r)
+           __ret[__r] = __static_prod<_Extents>(0, __r);
+         return __ret;
+       }();
+      };
+
+    // Static part of the reverse products \prod_{i = r+1}^{n-1} E[i] for
+    // r = 0, ..., n-1, where n = rank(); a dynamic extent counts as 1.
+    template<typename _Extents>
+      struct _RevProd
+      {
+       static constexpr std::array<size_t, _Extents::rank()> _S_value =
+       [] consteval
+       {
+         std::array<size_t, _Extents::rank()> __ret;
+         for (size_t __r = 0; __r < _Extents::rank(); ++__r)
+           __ret[__r] = __static_prod<_Extents>(__r + 1, _Extents::rank());
+         return __ret;
+       }();
+      };
+
     template<typename _Extents>
       constexpr span<const size_t>
       __static_extents(size_t __begin = 0, size_t __end = _Extents::rank())
@@ -352,46 +395,34 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
          return false;
       }
 
-    constexpr size_t
-    __static_extents_prod(const auto& __sta_exts) noexcept
-    {
-      size_t __ret = 1;
-      for (auto __factor : __sta_exts)
-       if (__factor != dynamic_extent)
-         __ret *= __factor;
-      return __ret;
-    }
-
     template<typename _Extents>
       constexpr typename _Extents::index_type
-      __exts_prod(const _Extents& __exts, size_t __begin, size_t __end) noexcept
+      __extents_prod(const _Extents& __exts, size_t __sta_prod, size_t __begin,
+                    size_t __end) noexcept
       {
-       using _IndexType = typename _Extents::index_type;
-
-       size_t __ret = 1;
-       if constexpr (_Extents::rank_dynamic() != _Extents::rank())
-         {
-           auto __sta_exts = __static_extents<_Extents>(__begin, __end);
-           __ret = __static_extents_prod(__sta_exts);
-           if (__ret == 0)
-             return 0;
-         }
-
+       size_t __ret = __sta_prod;
        if constexpr (_Extents::rank_dynamic() > 0)
          for (auto __factor : __dynamic_extents(__exts, __begin, __end))
            __ret *= size_t(__factor);
-       return _IndexType(__ret);
+       return static_cast<typename _Extents::index_type>(__ret);
       }
 
     template<typename _Extents>
       constexpr typename _Extents::index_type
       __fwd_prod(const _Extents& __exts, size_t __r) noexcept
-      { return __exts_prod(__exts, 0, __r); }
+      {
+       const auto& __sta_prods = _FwdProd<_Extents>::_S_value;
+       return __extents_prod(__exts, __sta_prods[__r], 0, __r);
+      }
 
     template<typename _Extents>
       constexpr typename _Extents::index_type
       __rev_prod(const _Extents& __exts, size_t __r) noexcept
-      { return __exts_prod(__exts, __r + 1, __exts.rank()); }
+      {
+       const auto& __sta_prods = _RevProd<_Extents>::_S_value;
+       const size_t __end = _Extents::rank();
+       return __extents_prod(__exts, __sta_prods[__r], __r + 1, __end);
+      }
 
     template<typename _Extents>
       constexpr typename _Extents::index_type
-- 
2.50.0

Reply via email to