The methods layout_{left,right}::mapping::stride are defined
as

  \prod_{i = 0}^{r-1} E[i]      (layout_left)
  \prod_{i = r+1}^{n-1} E[i]    (layout_right)

This is computed as the product of a precomputed static product and the
product of the required dynamic extents.

Disassembly shows that even for low-rank extents, i.e. rank == 1 and
rank == 2, with at least one dynamic extent, the generated code loads
two values and then runs the loop over at most one element, e.g.

 220:  48 8b 0c f5 00 00 00   mov    rcx,QWORD PTR [rsi*8+0x0]
 227:  00
 228:  48 8b 04 f5 00 00 00   mov    rax,QWORD PTR [rsi*8+0x0]
 22f:  00
 230:  48 c1 e1 02            shl    rcx,0x2
 234:  74 1a                  je     250 <stride_left_d5+0x30>
 236:  48 01 f9               add    rcx,rdi
 239:  0f 1f 80 00 00 00 00   nop    DWORD PTR [rax+0x0]
 240:  48 63 17               movsxd rdx,DWORD PTR [rdi]
 243:  48 83 c7 04            add    rdi,0x4
 247:  48 0f af c2            imul   rax,rdx
 24b:  48 39 f9               cmp    rcx,rdi
 24e:  75 f0                  jne    240 <stride_left_d5+0x20>
 250:  c3                     ret

If there are no dynamic extents, it simply loads the precomputed product
of static extents.

For rank == 1 the answer is constant `1`; for rank == 2 it's either 1 or
extents.extent(k), with k == 0 for layout_left and k == 1 for
layout_right.

Consider,

  using Ed = std::extents<int, dyn>;
  int stride_left_d(const std::layout_left::mapping<Ed>& m, size_t r)
  { return m.stride(r); }

  using E3d = std::extents<int, 3, dyn>;
  int stride_left_3d(const std::layout_left::mapping<E3d>& m, size_t r)
  { return m.stride(r); }

  using Ed5 = std::extents<int, dyn, 5>;
  int stride_left_d5(const std::layout_left::mapping<Ed5>& m, size_t r)
  { return m.stride(r); }

The optimized code for these three cases is:

  0000000000000060 <stride_left_d>:
  60:  b8 01 00 00 00         mov    eax,0x1
  65:  c3                     ret

  0000000000000090 <stride_left_3d>:
  90:  48 83 fe 01            cmp    rsi,0x1
  94:  19 c0                  sbb    eax,eax
  96:  83 e0 fe               and    eax,0xfffffffe
  99:  83 c0 03               add    eax,0x3
  9c:  c3                     ret

  00000000000000a0 <stride_left_d5>:
  a0:  b8 01 00 00 00         mov    eax,0x1
  a5:  48 85 f6               test   rsi,rsi
  a8:  74 02                  je     ac <stride_left_d5+0xc>
  aa:  8b 07                  mov    eax,DWORD PTR [rdi]
  ac:  c3                     ret

For rank == 1 it simply returns 1 (as expected). For rank == 2, it
either implements a branchless formula, or conditionally loads one
value. In all cases involving a dynamic extent, this appears to do
clearly less work, both in terms of computation and loads.

For rank == 2, it trades loading one value for a branchless sequence of
four instructions that don't require loading any values.

libstdc++-v3/ChangeLog:

        * include/std/mdspan (layout_left::mapping::stride): Optimize
        for rank <= 2.
        (layout_right::mapping::stride): Ditto.

Signed-off-by: Luc Grosheintz <luc.groshei...@gmail.com>
---
 libstdc++-v3/include/std/mdspan | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/libstdc++-v3/include/std/mdspan b/libstdc++-v3/include/std/mdspan
index 06ccf3e3827..f288af96cdb 100644
--- a/libstdc++-v3/include/std/mdspan
+++ b/libstdc++-v3/include/std/mdspan
@@ -652,7 +652,12 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       requires (extents_type::rank() > 0)
       {
        __glibcxx_assert(__i < extents_type::rank());
-       return __mdspan::__fwd_prod(_M_extents, __i);
+       if constexpr (extents_type::rank() == 1)
+         return 1;
+       else if constexpr (extents_type::rank() == 2)
+         return __i == 0 ? 1 : _M_extents.extent(0);
+       else
+         return __mdspan::__fwd_prod(_M_extents, __i);
       }
 
       template<typename _OExtents>
@@ -797,7 +802,12 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       requires (extents_type::rank() > 0)
       {
        __glibcxx_assert(__i < extents_type::rank());
-       return __mdspan::__rev_prod(_M_extents, __i);
+       if constexpr (extents_type::rank() == 1)
+         return 1;
+       else if constexpr (extents_type::rank() == 2)
+         return __i == 0 ? _M_extents.extent(1) : 1;
+       else
+         return __mdspan::__rev_prod(_M_extents, __i);
       }
 
       template<typename _OExtents>
-- 
2.50.0

Reply via email to