https://gcc.gnu.org/g:de19b516edbf919d31e9d22fdbf6066342d904a2

commit r15-1857-gde19b516edbf919d31e9d22fdbf6066342d904a2
Author: Jonathan Wakely <jwak...@redhat.com>
Date:   Wed Jun 5 16:01:26 2024 +0100

    libstdc++: Use memchr to optimize std::find [PR88545]
    
    This optimizes std::find and std::ranges::find to use memchr when
    searching for an integer (or std::byte) in a contiguous range of bytes.
    
    libstdc++-v3/ChangeLog:
    
            PR libstdc++/88545
            PR libstdc++/115040
            * include/bits/cpp_type_traits.h (__can_use_memchr_for_find):
            New variable template.
            * include/bits/ranges_util.h (__find_fn): Use memchr when
            possible.
            * include/bits/stl_algo.h (find): Likewise.
            * testsuite/25_algorithms/find/bytes.cc: New test.

Diff:
---
 libstdc++-v3/include/bits/cpp_type_traits.h        |  13 ++
 libstdc++-v3/include/bits/ranges_util.h            |  21 ++++
 libstdc++-v3/include/bits/stl_algo.h               |  35 +++++-
 libstdc++-v3/testsuite/25_algorithms/find/bytes.cc | 135 +++++++++++++++++++++
 4 files changed, 202 insertions(+), 2 deletions(-)

diff --git a/libstdc++-v3/include/bits/cpp_type_traits.h 
b/libstdc++-v3/include/bits/cpp_type_traits.h
index abe0c7603e3..4bfb4521e06 100644
--- a/libstdc++-v3/include/bits/cpp_type_traits.h
+++ b/libstdc++-v3/include/bits/cpp_type_traits.h
@@ -35,6 +35,10 @@
 #pragma GCC system_header
 
 #include <bits/c++config.h>
+#include <bits/version.h>
+#if __glibcxx_type_trait_variable_templates
+# include <type_traits> // is_same_v, is_integral_v
+#endif
 
 //
 // This file provides some compile-time information about various types.
@@ -547,6 +551,17 @@ __INT_N(__GLIBCXX_TYPE_INT_N_3)
     { static constexpr bool __value = false; };
 #endif

+#if __glibcxx_type_trait_variable_templates
+  // True if std::find and std::ranges::find may use memchr to search a range
+  // whose value_type is _ValT for a value of type _Tp.  The element type must
+  // be a byte type (a narrow character type or std::byte), and the value must
+  // be an integer or have the same type as the elements.
+  template<typename _ValT, typename _Tp>
+    constexpr bool __can_use_memchr_for_find
+      = __is_byte<_ValT>::__value
+         && (is_integral_v<_Tp> || is_same_v<_Tp, _ValT>);
+#endif
+
   //
   // Move iterator type
   //
diff --git a/libstdc++-v3/include/bits/ranges_util.h 
b/libstdc++-v3/include/bits/ranges_util.h
index 9b79c3a229d..186acae4f70 100644
--- a/libstdc++-v3/include/bits/ranges_util.h
+++ b/libstdc++-v3/include/bits/ranges_util.h
@@ -34,6 +34,7 @@
 # include <bits/ranges_base.h>
 # include <bits/utility.h>
 # include <bits/invoke.h>
+# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
 
 #ifdef __glibcxx_ranges
 namespace std _GLIBCXX_VISIBILITY(default)
@@ -494,6 +495,32 @@ namespace ranges
       operator()(_Iter __first, _Sent __last,
                 const _Tp& __value, _Proj __proj = {}) const
       {
+       // Use memchr to search a contiguous, sized range of non-volatile bytes
+       // when the identity projection is used (see __can_use_memchr_for_find).
+       if constexpr (is_same_v<_Proj, identity>)
+         if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
+           if constexpr (sized_sentinel_for<_Sent, _Iter>)
+             if constexpr (contiguous_iterator<_Iter>
+                             && !is_volatile_v<
+                                  remove_reference_t<iter_reference_t<_Iter>>>)
+               if (!is_constant_evaluated())
+                 {
+                   using _ValT = iter_value_t<_Iter>;
+                   auto __n = __last - __first;
+                   // A value that is changed by conversion to _ValT can never
+                   // compare equal to an element, so the result is the end of
+                   // the range.  Return it as __first + __n, not __last,
+                   // because _Sent need not be convertible to _Iter.
+                   if (static_cast<_ValT>(__value) == __value && __n > 0)
+                     {
+                       const int __ival = static_cast<int>(__value);
+                       const void* __p0 = std::to_address(__first);
+                       if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+                         __n = (const char*)__p1 - (const char*)__p0;
+                     }
+                   return __first + __n;
+                 }
+
        while (__first != __last
            && !(std::__invoke(__proj, *__first) == __value))
          ++__first;
diff --git a/libstdc++-v3/include/bits/stl_algo.h 
b/libstdc++-v3/include/bits/stl_algo.h
index 1a996aa61da..45c3b591326 100644
--- a/libstdc++-v3/include/bits/stl_algo.h
+++ b/libstdc++-v3/include/bits/stl_algo.h
@@ -3838,14 +3838,64 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
   template<typename _InputIterator, typename _Tp>
     _GLIBCXX20_CONSTEXPR
     inline _InputIterator
-    find(_InputIterator __first, _InputIterator __last,
-        const _Tp& __val)
+    find(_InputIterator __first, _InputIterator __last, const _Tp& __val)
     {
       // concept requirements
       __glibcxx_function_requires(_InputIteratorConcept<_InputIterator>)
       __glibcxx_function_requires(_EqualOpConcept<
                typename iterator_traits<_InputIterator>::value_type, _Tp>)
       __glibcxx_requires_valid_range(__first, __last);
+
+#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
+      using _ValT = typename iterator_traits<_InputIterator>::value_type;
+      if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
+       {
+         // If converting the value to the 1-byte value_type alters its value,
+         // then it would not be found by std::find using equality comparison.
+         // We need to check this here, because otherwise something like
+         // memchr("a", 'a'+256, 1) would give a false positive match.
+         if (!(static_cast<_ValT>(__val) == __val))
+           return __last;
+         else if (!__is_constant_evaluated())
+           {
+             // memchr needs a pointer to contiguous, non-volatile storage.
+             // Each use of it is inside an 'if constexpr' branch, so that
+             // __first + __n is never instantiated for iterators that do not
+             // support it, such as std::list<char>::iterator.  All other
+             // cases fall through to the generic search below.
+             using _Ref = typename iterator_traits<_InputIterator>::reference;
+             constexpr bool __non_volatile
+               = !is_volatile_v<remove_reference_t<_Ref>>;
+             if constexpr (__non_volatile
+                             && is_pointer_v<decltype(std::__niter_base(__first))>)
+               {
+                 if (auto __n = std::distance(__first, __last); __n > 0)
+                   {
+                     const int __ival = static_cast<int>(__val);
+                     const void* __p0 = std::__niter_base(__first);
+                     if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+                       return __first + ((const char*)__p1 - (const char*)__p0);
+                   }
+                 return __last;
+               }
+#if __cpp_lib_concepts
+             else if constexpr (__non_volatile
+                                  && contiguous_iterator<_InputIterator>)
+               {
+                 if (auto __n = std::distance(__first, __last); __n > 0)
+                   {
+                     const int __ival = static_cast<int>(__val);
+                     const void* __p0 = std::to_address(__first);
+                     if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+                       return __first + ((const char*)__p1 - (const char*)__p0);
+                   }
+                 return __last;
+               }
+#endif
+           }
+       }
+#endif
+
       return std::__find_if(__first, __last,
                             __gnu_cxx::__ops::__iter_equals_val(__val));
     }
diff --git a/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc 
b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
new file mode 100644
index 00000000000..f4ac5d4018d
--- /dev/null
+++ b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
@@ -0,0 +1,147 @@
+// { dg-do run }
+
+#include <algorithm>
+#include <cstddef> // std::byte
+#include <testsuite_hooks.h>
+
+// PR libstdc++/88545 made std::find use memchr as an optimization.
+// This test verifies that it didn't change any semantics.
+
+template<typename C>
+void
+test_char()
+{
+  const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
+  const C* end = a + sizeof(a);
+  const C* res = std::find(a, end, a[0]);
+  VERIFY( res == a );
+  res = std::find(a, end, a[2]);
+  VERIFY( res == a+2 );
+  res = std::find(a, end, a[0] + 256);
+  VERIFY( res == end );
+  res = std::find(a, end, a[0] - 256);
+  VERIFY( res == end );
+  res = std::find(a, end, 256);
+  VERIFY( res == end );
+
+  // memchr converts the value to unsigned char, so check that an element
+  // with all bits set is found, but a value 256 greater than it is not.
+  const C ff[] = { (C)'a', (C)-1 };
+  res = std::find(ff, ff+2, ff[1]);
+  VERIFY( res == ff+1 );
+  res = std::find(ff, ff+2, (int)ff[1]);
+  VERIFY( res == ff+1 );
+  res = std::find(ff, ff+2, ff[1] + 256);
+  VERIFY( res == ff+2 );
+
+#ifdef __cpp_lib_ranges
+  res = std::ranges::find(a, a[0]);
+  VERIFY( res == a );
+  res = std::ranges::find(a, a[2]);
+  VERIFY( res == a+2 );
+  res = std::ranges::find(a, a[0] + 256);
+  VERIFY( res == end );
+  res = std::ranges::find(a, a[0] - 256);
+  VERIFY( res == end );
+  res = std::ranges::find(a, 256);
+  VERIFY( res == end );
+  res = std::ranges::find(ff, (int)ff[1]);
+  VERIFY( res == ff+1 );
+  res = std::ranges::find(ff, ff[1] + 256);
+  VERIFY( res == ff+2 );
+#endif
+}
+
+// Trivial type of size 1, with custom equality.
+struct S {
+  bool operator==(const S&) const { return true; }
+  char c;
+};
+
+// Enumeration type with custom equality (of size 1 since C++11).
+enum E
+#if __cplusplus >= 201103L
+: unsigned char
+#endif
+{ e1 = 1, e255 = 255 };
+
+bool operator==(E l, E r) { return (l % 3) == (r % 3); }
+
+struct X { char c; };
+bool operator==(X, char) { return false; }
+bool operator==(char, X) { return false; }
+
+bool operator==(E, char) { return false; }
+bool operator==(char, E) { return false; }
+
+void
+test_non_characters()
+{
+  S s[3] = { {'a'}, {'b'}, {'c'} };
+  S sx = {'x'};
+  S* sres = std::find(s, s+3, sx);
+  VERIFY( sres == s ); // memchr optimization would not find a match
+
+  E e[3] = { E(1), E(2), E(3) };
+  E* eres = std::find(e, e+3, E(4));
+  VERIFY( eres == e ); // memchr optimization would not find a match
+
+  char x[1] = { 'x' };
+  X xx = { 'x' };
+  char* xres = std::find(x, x+1, xx);
+  VERIFY( xres == x+1 ); // memchr optimization would find a match
+  xres = std::find(x, x+1, E('x'));
+  VERIFY( xres == x+1 ); // memchr optimization would find a match
+
+#ifdef __cpp_lib_byte
+  std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
+  std::byte* bres = std::find(b, b+4, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::find(b, b+2, std::byte{3});
+  VERIFY( bres == b+2 );
+  bres = std::find(b, b+3, std::byte{3});
+  VERIFY( bres == b+3 );
+#endif
+
+#ifdef __cpp_lib_ranges
+  sres = std::ranges::find(s, sx);
+  VERIFY( sres == s );
+
+  eres = std::ranges::find(e, e+3, E(4));
+  VERIFY( eres == e );
+
+  // std::equality_comparable_with<X, char> is not satisfied, so can't do
+  // std::ranges::find(x, xx)
+
+  bres = std::ranges::find(b, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::ranges::find(b, b+2, std::byte{3});
+  VERIFY( bres == b+2 );
+  bres = std::ranges::find(b, std::byte{3});
+  VERIFY( bres == b+3 );
+
+  xres = std::ranges::find(x, E('x'));
+  VERIFY( xres == std::ranges::end(x) );
+#endif
+}
+
+int main()
+{
+  test_char<char>();
+  test_char<signed char>();
+  test_char<unsigned char>();
+  test_non_characters();
+
+#if __cpp_lib_constexpr_algorithms
+  static_assert( [] {
+    char c[] = "abcd";
+    return std::find(c, c+4, 'b') == c+1;
+  }() );
+#ifdef __cpp_lib_ranges
+  static_assert( [] {
+    char c[] = "abcd";
+    return std::ranges::find(c, 'b') == c+1;
+  }() );
+#endif
+#endif
+}

Reply via email to