timshen updated this revision to Diff 139043.
timshen added a comment.
Herald added a subscriber: christof.

Rebase.


https://reviews.llvm.org/D41843

Files:
  libcxx/include/experimental/simd
  
  libcxx/test/std/experimental/simd/simd.whereexpr/const_where_expression.pass.cpp
  libcxx/test/std/experimental/simd/simd.whereexpr/where.pass.cpp
  libcxx/test/std/experimental/simd/simd.whereexpr/where_expression.pass.cpp

Index: libcxx/test/std/experimental/simd/simd.whereexpr/where_expression.pass.cpp
===================================================================
--- /dev/null
+++ libcxx/test/std/experimental/simd/simd.whereexpr/where_expression.pass.cpp
@@ -0,0 +1,366 @@
+//===----------------------------------------------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is dual licensed under the MIT and the University of Illinois Open
+// Source Licenses. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++98, c++03
+
+// <experimental/simd>
+//
+// // [simd.whereexpr]
+// template <class M, class T>
+// class where_expression : public const_where_expression<M, T> {
+// public:
+//   where_expression(const where_expression&) = delete;
+//   where_expression& operator=(const where_expression&) = delete;
+//   template <class U> void operator=(U&& x);
+//   template <class U> void operator+=(U&& x);
+//   template <class U> void operator-=(U&& x);
+//   template <class U> void operator*=(U&& x);
+//   template <class U> void operator/=(U&& x);
+//   template <class U> void operator%=(U&& x);
+//   template <class U> void operator&=(U&& x);
+//   template <class U> void operator|=(U&& x);
+//   template <class U> void operator^=(U&& x);
+//   template <class U> void operator<<=(U&& x);
+//   template <class U> void operator>>=(U&& x);
+//   void operator++();
+//   void operator++(int);
+//   void operator--();
+//   void operator--(int);
+//   template <class U, class Flags> void copy_from(const U* mem, Flags);
+// };
+
+#include <experimental/simd>
+#include <cassert>
+#include <cstdint>
+#include <algorithm>
+
+using namespace std::experimental::parallelism_v2;
+
+// Exercises each assignment / compound assignment / increment / decrement of
+// where_expression on a simd, with both a broadcast scalar and a simd
+// right-hand side. Lanes outside the mask must keep their original value.
+void test_operators_simd() {
+  using V = fixed_size_simd<int, 4>;
+  // Asserts that the four lanes of `v` are {e0, e1, e2, e3}.
+  auto expect = [](const V& v, int e0, int e1, int e2, int e3) {
+    assert(v[0] == e0);
+    assert(v[1] == e1);
+    assert(v[2] == e2);
+    assert(v[3] == e3);
+  };
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) = -1;
+    expect(a, -1, -1, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) = V(-1);
+    expect(a, -1, -1, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) += -1;
+    expect(a, -1, 0, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) += V(-1);
+    expect(a, -1, 0, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) -= -1;
+    expect(a, 1, 2, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) -= V(-1);
+    expect(a, 1, 2, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) *= -1;
+    expect(a, 0, -1, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) *= V(-1);
+    expect(a, 0, -1, 2, 3);
+  }
+  {
+    V a([](int i) { return 3 * i; });
+    where(a >= 6, a) /= 2;
+    expect(a, 0, 3, 3, 4);
+  }
+  {
+    V a([](int i) { return 3 * i; });
+    where(a >= 6, a) /= V(2);
+    expect(a, 0, 3, 3, 4);
+  }
+  {
+    // The unselected (even) lanes have a zero divisor; it must never be used.
+    V a([](int i) { return i; });
+    where(a % 2 == 1, a) /= V([](int i) { return i % 2 * 2; });
+    expect(a, 0, 0, 2, 1);
+  }
+  {
+    V a([](int i) { return 3 * i; });
+    where(a >= 6, a) %= 2;
+    expect(a, 0, 3, 0, 1);
+  }
+  {
+    V a([](int i) { return 3 * i; });
+    where(a >= 6, a) %= V(2);
+    expect(a, 0, 3, 0, 1);
+  }
+  {
+    // The unselected (even) lanes have a zero divisor; it must never be used.
+    V a([](int i) { return i; });
+    where(a % 2 == 1, a) %= V([](int i) { return i % 2 * 2; });
+    expect(a, 0, 1, 2, 1);
+  }
+  {
+    // Mask selects only lanes 2 and 3 so that masking is actually exercised.
+    V a([](int i) { return i; });
+    where(a >= 2, a) &= 1;
+    expect(a, 0, 1, 0, 1);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a >= 2, a) &= V(1);
+    expect(a, 0, 1, 0, 1);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) |= 2;
+    expect(a, 2, 3, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) |= V(2);
+    expect(a, 2, 3, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) ^= 1;
+    expect(a, 1, 0, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) ^= V(1);
+    expect(a, 1, 0, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) <<= 1;
+    expect(a, 0, 2, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a) <<= V(1);
+    expect(a, 0, 2, 2, 3);
+  }
+  {
+    V a([](int i) { return 2 * i; });
+    where(a < 4, a) >>= 1;
+    expect(a, 0, 1, 4, 6);
+  }
+  {
+    V a([](int i) { return 2 * i; });
+    where(a < 4, a) >>= V(1);
+    expect(a, 0, 1, 4, 6);
+  }
+  {
+    V a([](int i) { return i; });
+    ++where(a < 2, a);
+    expect(a, 1, 2, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a)++;
+    expect(a, 1, 2, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    --where(a < 2, a);
+    expect(a, -1, 0, 2, 3);
+  }
+  {
+    V a([](int i) { return i; });
+    where(a < 2, a)--;
+    expect(a, -1, 0, 2, 3);
+  }
+}
+
+// Applies the mask-compatible where_expression operators to a simd_mask whose
+// lanes start as {false, true, true, false}. Every case ends with all four
+// lanes equal, so each check only needs the expected common value.
+void test_operators_mask() {
+  using M = fixed_size_simd_mask<int, 4>;
+  auto make_input = [] {
+    M m;
+    m[0] = false;
+    m[1] = true;
+    m[2] = true;
+    m[3] = false;
+    return m;
+  };
+  auto expect_all = [](const M& m, bool value) {
+    for (int lane = 0; lane < 4; ++lane)
+      assert(m[lane] == value);
+  };
+  {
+    M a = make_input();
+    where(a, a) = M(false);
+    expect_all(a, false);
+  }
+  {
+    M a = make_input();
+    where(a, a) &= M(false);
+    expect_all(a, false);
+  }
+  {
+    M a = make_input();
+    where(!a, a) |= M(true);
+    expect_all(a, true);
+  }
+  {
+    M a = make_input();
+    where(a, a) ^= M(true);
+    expect_all(a, false);
+  }
+  {
+    M a = make_input();
+    where(!a, a) ^= M(true);
+    expect_all(a, true);
+  }
+}
+
+// Exercises where_expression::copy_from: only the elements selected by the
+// mask are loaded from the buffer; all others (or the scalar, when the mask
+// is false) keep their previous value.
+void test_copy_from() {
+  {
+    const int buffer[] = {-1, -2, -3, -4};
+    fixed_size_simd<int, 4> a([](int i) { return i; });
+    where(a < 2, a).copy_from(buffer, element_aligned_tag());
+    assert(a[0] == -1);
+    assert(a[1] == -2);
+    assert(a[2] == 2);
+    assert(a[3] == 3);
+  }
+  {
+    const int buffer[] = {-1, -2, -3, -4};
+    fixed_size_simd<int, 4> a([](int i) { return i; });
+    where(a >= 2, a).copy_from(buffer, element_aligned_tag());
+    assert(a[0] == 0);
+    assert(a[1] == 1);
+    assert(a[2] == -3);
+    assert(a[3] == -4);
+  }
+  {
+    fixed_size_simd_mask<int, 4> a;
+    const bool input[] = {false, true, true, false};
+    a.copy_from(input, element_aligned_tag());
+
+    const bool buffer[] = {true, true, false, false};
+    where(a, a).copy_from(buffer, element_aligned_tag());
+    assert(!a[0]);
+    assert(a[1]);
+    assert(!a[2]);
+    assert(!a[3]);
+  }
+  {
+    fixed_size_simd_mask<int, 4> a;
+    const bool input[] = {false, true, true, false};
+    a.copy_from(input, element_aligned_tag());
+
+    const bool buffer[] = {true, true, false, false};
+    where(!a, a).copy_from(buffer, element_aligned_tag());
+    assert(a[0]);
+    assert(a[1]);
+    assert(a[2]);
+    assert(!a[3]);
+  }
+  {
+    const int b = 1;
+    int a = 3;
+    where(true, a).copy_from(&b, element_aligned_tag());
+    assert(a == 1);
+  }
+  {
+    const int b = 1;
+    int a = 3;
+    where(false, a).copy_from(&b, element_aligned_tag());
+    // The mask is false, so the destination must be left untouched. (The
+    // source `b` is const and could never change, so checking it proves
+    // nothing.)
+    assert(a == 3);
+  }
+}
+
+// Runs every where_expression test case; failures abort via assert.
+int main() {
+  test_operators_simd();
+  test_operators_mask();
+  test_copy_from();
+  return 0;
+}
Index: libcxx/test/std/experimental/simd/simd.whereexpr/where.pass.cpp
===================================================================
--- /dev/null
+++ libcxx/test/std/experimental/simd/simd.whereexpr/where.pass.cpp
@@ -0,0 +1,93 @@
+//===----------------------------------------------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is dual licensed under the MIT and the University of Illinois Open
+// Source Licenses. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++98, c++03
+
+// <experimental/simd>
+//
+// // masked assignment [simd.mask.where]
+// template <class T, class Abi>
+// where_expression<simd_mask<T, Abi>, simd<T, Abi>>
+// where(const typename simd<T, Abi>::mask_type&, simd<T, Abi>&) noexcept;
+//
+// template <class T, class Abi>
+// const_where_expression<simd_mask<T, Abi>, const simd<T, Abi>>
+// where(const typename simd<T, Abi>::mask_type&, const simd<T, Abi>&) noexcept;
+//
+// template <class T, class Abi>
+// where_expression<simd_mask<T, Abi>, simd_mask<T, Abi>>
+// where(const nodeduce_t<simd_mask<T, Abi>>&, simd_mask<T, Abi>&) noexcept;
+//
+// template <class T, class Abi>
+// const_where_expression<simd_mask<T, Abi>, const simd_mask<T, Abi>>
+// where(const nodeduce_t<simd_mask<T, Abi>>&, const simd_mask<T, Abi>&) noexcept;
+//
+// template <class T> where_expression<bool, T> where(see below k, T& d) noexcept;
+//
+// template <class T>
+// const_where_expression<bool, const T> where(see below k, const T& d) noexcept;
+
+#include <experimental/simd>
+#include <cassert>
+#include <cstdint>
+#include <algorithm>
+
+using namespace std::experimental::parallelism_v2;
+
+// where() applied to const operands must yield a const_where_expression whose
+// value type is const-qualified, for simd, simd_mask and scalar operands.
+void compile_const_where() {
+  {
+    const native_simd<int> v{};
+    using Expected =
+        const_where_expression<native_simd_mask<int>, const native_simd<int>>;
+    static_assert(std::is_same<decltype(where(v < 2, v)), Expected>::value,
+                  "");
+  }
+  {
+    const native_simd_mask<int> m{};
+    using Expected = const_where_expression<native_simd_mask<int>,
+                                            const native_simd_mask<int>>;
+    static_assert(std::is_same<decltype(where(m, m)), Expected>::value, "");
+  }
+  {
+    const bool flag = true;
+    using Expected = const_where_expression<bool, const int>;
+    static_assert(std::is_same<decltype(where(flag, 3)), Expected>::value, "");
+  }
+}
+
+// where() applied to non-const lvalues must yield a mutable where_expression
+// for simd, simd_mask and scalar operands.
+void compile_where() {
+  {
+    native_simd<int> v;
+    using Expected = where_expression<native_simd_mask<int>, native_simd<int>>;
+    static_assert(std::is_same<decltype(where(v < 2, v)), Expected>::value,
+                  "");
+  }
+  {
+    native_simd_mask<int> m;
+    using Expected =
+        where_expression<native_simd_mask<int>, native_simd_mask<int>>;
+    static_assert(std::is_same<decltype(where(m, m)), Expected>::value, "");
+  }
+  {
+    int value = 3;
+    using Expected = where_expression<bool, int>;
+    static_assert(std::is_same<decltype(where(true, value)), Expected>::value,
+                  "");
+  }
+}
+
+// Every check in this file is a static_assert; there is nothing to run.
+int main() { return 0; }
Index: libcxx/test/std/experimental/simd/simd.whereexpr/const_where_expression.pass.cpp
===================================================================
--- /dev/null
+++ libcxx/test/std/experimental/simd/simd.whereexpr/const_where_expression.pass.cpp
@@ -0,0 +1,103 @@
+//===----------------------------------------------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is dual licensed under the MIT and the University of Illinois Open
+// Source Licenses. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++98, c++03
+
+// <experimental/simd>
+//
+// // [simd.whereexpr]
+// template <class M, class T>
+// class const_where_expression {
+//   const M& mask; // exposition only
+//   T& data; // exposition only
+// public:
+//   const_where_expression(const const_where_expression&) = delete;
+//   const_where_expression& operator=(const const_where_expression&) = delete;
+//   remove_const_t<T> operator-() const &&;
+//   template <class U, class Flags> void copy_to(U* mem, Flags f) const &&;
+// };
+
+#include <experimental/simd>
+#include <cassert>
+#include <cstdint>
+#include <algorithm>
+
+using namespace std::experimental::parallelism_v2;
+
+// Unary minus on a const_where_expression negates the selected elements and
+// passes the unselected ones through unchanged; scalar form likewise.
+void test_operator_minus() {
+  {
+    const fixed_size_simd<int, 4> v([](int i) { return i; });
+    const auto negated = -where(v < 2, v);
+    const int expected[] = {0, -1, 2, 3};
+    for (int lane = 0; lane < 4; ++lane)
+      assert(negated[lane] == expected[lane]);
+  }
+  assert((-where(true, 3)) == -3);
+  assert((-where(false, 3)) == 3);
+}
+
+// copy_to on a const_where_expression writes only the selected elements into
+// the destination buffer and leaves the other buffer entries untouched.
+void test_copy_to() {
+  const fixed_size_simd<int, 4> v([](int i) { return i - 2; });
+  {
+    int out[] = {1, 2, 3, 4};
+    where(v < 0, v).copy_to(out, element_aligned_tag());
+    const int expected[] = {-2, -1, 3, 4};
+    for (int i = 0; i < 4; ++i)
+      assert(out[i] == expected[i]);
+  }
+  {
+    int out[] = {1, 2, 3, 4};
+    where(v >= 0, v).copy_to(out, element_aligned_tag());
+    const int expected[] = {1, 2, 0, 1};
+    for (int i = 0; i < 4; ++i)
+      assert(out[i] == expected[i]);
+  }
+  {
+    fixed_size_simd_mask<int, 4> m;
+    bool lanes[] = {false, true, true, false};
+    m.copy_from(lanes, element_aligned_tag());
+    {
+      bool out[] = {true, true, false, false};
+      where(m, m).copy_to(out, element_aligned_tag());
+      const bool expected[] = {true, true, true, false};
+      for (int i = 0; i < 4; ++i)
+        assert(out[i] == expected[i]);
+    }
+    {
+      bool out[] = {true, true, false, false};
+      where(!m, m).copy_to(out, element_aligned_tag());
+      const bool expected[] = {false, true, false, false};
+      for (int i = 0; i < 4; ++i)
+        assert(out[i] == expected[i]);
+    }
+  }
+  {
+    int dst = 1;
+    where(true, 3).copy_to(&dst, element_aligned_tag());
+    assert(dst == 3);
+  }
+  {
+    int dst = 1;
+    where(false, 3).copy_to(&dst, element_aligned_tag());
+    assert(dst == 1);
+  }
+}
+
+// Runs every const_where_expression test case; failures abort via assert.
+int main() {
+  test_operator_minus();
+  test_copy_to();
+  return 0;
+}
Index: libcxx/include/experimental/simd
===================================================================
--- libcxx/include/experimental/simd
+++ libcxx/include/experimental/simd
@@ -592,6 +592,7 @@
 #include <algorithm>
 #include <array>
 #include <cstddef>
+#include <cstring>
 #include <functional>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -1341,6 +1342,20 @@
   return __concat_array(__arr, std::make_index_sequence<_Np>());
 }
 
+// Reinterprets the object representation of every element of __v as _Up
+// (which must be arithmetic and the same size as _Tp), element by element.
+template <class _Up, class _Tp, class _Abi>
+simd<_Up, _Abi> __bit_cast(const simd<_Tp, _Abi>& __v) {
+  static_assert(std::is_arithmetic<_Up>::value, "");
+  static_assert(sizeof(_Up) == sizeof(_Tp), "");
+  simd<_Up, _Abi> __result;
+  const size_t __n = __v.size();
+  for (size_t __lane = 0; __lane != __n; ++__lane) {
+    const _Tp __src = __v[__lane];
+    _Up __dst;
+    // memcpy is the well-defined way to type-pun between same-sized types.
+    std::memcpy(&__dst, &__src, sizeof(__src));
+    __result[__lane] = __dst;
+  }
+  return __result;
+}
+
 struct __simd_mask_friend {
   template <class _Tp, class _Abi>
   static fixed_size_simd_mask<_Tp, simd_size<_Tp, _Abi>::value>
@@ -1428,6 +1443,31 @@
     }
     return concat(__arr);
   }
+
+  // Element-wise blend: returns __true_values where __m is set and
+  // __false_values elsewhere. Both operands are reinterpreted as same-sized
+  // unsigned integers and combined with the mask's storage __m.__s_.
+  // NOTE(review): this assumes __m.__s_ holds an all-ones / all-zeros bit
+  // pattern per element, so the two masked terms have disjoint bits and `+`
+  // acts as a bitwise OR — confirm against simd_mask's storage definition.
+  template <class _Tp, class _Abi>
+  static simd<_Tp, _Abi> __simd_select(const simd<_Tp, _Abi>& __false_values,
+                                       const simd<_Tp, _Abi>& __true_values,
+                                       const simd_mask<_Tp, _Abi>& __m) {
+    using __unsigned = typename __unsigned_traits<sizeof(_Tp)>::type;
+    return __bit_cast<_Tp>(
+        (__bit_cast<__unsigned>(__false_values) & ~__m.__s_) +
+        (__bit_cast<__unsigned>(__true_values) & __m.__s_));
+  }
+
+  // simd_mask overload: blends the underlying storage (__s_) of the two masks
+  // by forwarding to the simd overload of __simd_select. The selector is
+  // rebuilt as simd_mask<__unsigned, _Abi> from __m.__s_.
+  // NOTE(review): relies on the forwarded result converting back to
+  // simd_mask<_Tp, _Abi> on return — confirm that conversion path.
+  template <class _Tp, class _Abi>
+  static simd_mask<_Tp, _Abi>
+  __simd_select(const simd_mask<_Tp, _Abi>& __false_values,
+                const simd_mask<_Tp, _Abi>& __true_values,
+                const simd_mask<_Tp, _Abi>& __m) {
+    using __unsigned = typename __unsigned_traits<sizeof(_Tp)>::type;
+    return __simd_select(__false_values.__s_, __true_values.__s_,
+                         simd_mask<__unsigned, _Abi>(__m.__s_));
+  }
+
+  // Scalar (bool mask) form: yields __true_value when __m is set, otherwise
+  // __false_value.
+  template <class _Tp>
+  static _Tp __simd_select(_Tp __false_value, _Tp __true_value, bool __m) {
+    if (!__m)
+      return __false_value;
+    return __true_value;
+  }
 };
 
 template <class _Tp, class _Abi>
@@ -1593,38 +1633,6 @@
   return 0;
 }
 
-// masked assignment [simd.whereexpr]
-template <class _MaskType, class _Tp>
-class const_where_expression;
-template <class _MaskType, class _Tp>
-class where_expression;
-
-// masked assignment [simd.mask.where]
-template <class _Tp, class _Abi>
-where_expression<simd_mask<_Tp, _Abi>, simd<_Tp, _Abi>>
-where(const typename simd<_Tp, _Abi>::mask_type&, simd<_Tp, _Abi>&) noexcept;
-
-template <class _Tp, class _Abi>
-const_where_expression<simd_mask<_Tp, _Abi>, const simd<_Tp, _Abi>>
-where(const typename simd<_Tp, _Abi>::mask_type&,
-      const simd<_Tp, _Abi>&) noexcept;
-
-template <class _Tp, class _Abi>
-where_expression<simd_mask<_Tp, _Abi>, simd_mask<_Tp, _Abi>>
-where(const typename __nodeduce<simd_mask<_Tp, _Abi>>::type&,
-      simd_mask<_Tp, _Abi>&) noexcept;
-
-template <class _Tp, class _Abi>
-const_where_expression<simd_mask<_Tp, _Abi>, const simd_mask<_Tp, _Abi>>
-where(const typename __nodeduce<simd_mask<_Tp, _Abi>>::type&,
-      const simd_mask<_Tp, _Abi>&) noexcept;
-
-template <class _Tp>
-where_expression<bool, _Tp> where(bool, _Tp&) noexcept;
-
-template <class _Tp>
-const_where_expression<bool, const _Tp> where(bool, const _Tp&) noexcept;
-
 // reductions [simd.reductions]
 template <class _Tp, class _Abi, class _BinaryOp = std::plus<_Tp>>
 _Tp reduce(const simd<_Tp, _Abi>& __v, _BinaryOp __op = _BinaryOp()) {
@@ -1635,36 +1643,6 @@
   return __acc;
 }
 
-template <class _MaskType, class _SimdType, class _BinaryOp>
-typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       typename _SimdType::value_type neutral_element, _BinaryOp binary_op);
-
-template <class _MaskType, class _SimdType>
-typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       plus<typename _SimdType::value_type> binary_op = {});
-
-template <class _MaskType, class _SimdType>
-typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       multiplies<typename _SimdType::value_type> binary_op);
-
-template <class _MaskType, class _SimdType>
-typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       bit_and<typename _SimdType::value_type> binary_op);
-
-template <class _MaskType, class _SimdType>
-typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       bit_or<typename _SimdType::value_type> binary_op);
-
-template <class _MaskType, class _SimdType>
-typename _SimdType::value_type
-reduce(const const_where_expression<_MaskType, _SimdType>&,
-       bit_xor<typename _SimdType::value_type> binary_op);
-
 template <class _Tp, class _Abi>
 _Tp hmin(const simd<_Tp, _Abi>& __v) {
   _Tp __acc = __v[0];
@@ -1674,10 +1652,6 @@
   return __acc;
 }
 
-template <class _MaskType, class _SimdType>
-typename _SimdType::value_type
-hmin(const const_where_expression<_MaskType, _SimdType>&);
-
 template <class _Tp, class _Abi>
 _Tp hmax(const simd<_Tp, _Abi>& __v) {
   _Tp __acc = __v[0];
@@ -1687,10 +1661,6 @@
   return __acc;
 }
 
-template <class _MaskType, class _SimdType>
-typename _SimdType::value_type
-hmax(const const_where_expression<_MaskType, _SimdType>&);
-
 // algorithms [simd.alg]
 template <class _Tp, class _Abi>
 simd<_Tp, _Abi> min(const simd<_Tp, _Abi>& __a,
@@ -1724,53 +1694,6 @@
   return min(max(__v, __lo), __hi);
 }
 
-// [simd.whereexpr]
-// TODO implement where expressions.
-template <class _MaskType, class _Tp>
-class const_where_expression {
-public:
-  const_where_expression(const const_where_expression&) = delete;
-  const_where_expression& operator=(const const_where_expression&) = delete;
-  typename remove_const<_Tp>::type operator-() const&&;
-  template <class _Up, class _Flags>
-  void copy_to(_Up*, _Flags) const&&;
-};
-
-template <class _MaskType, class _Tp>
-class where_expression : public const_where_expression<_MaskType, _Tp> {
-public:
-  where_expression(const where_expression&) = delete;
-  where_expression& operator=(const where_expression&) = delete;
-  template <class _Up>
-  void operator=(_Up&&);
-  template <class _Up>
-  void operator+=(_Up&&);
-  template <class _Up>
-  void operator-=(_Up&&);
-  template <class _Up>
-  void operator*=(_Up&&);
-  template <class _Up>
-  void operator/=(_Up&&);
-  template <class _Up>
-  void operator%=(_Up&&);
-  template <class _Up>
-  void operator&=(_Up&&);
-  template <class _Up>
-  void operator|=(_Up&&);
-  template <class _Up>
-  void operator^=(_Up&&);
-  template <class _Up>
-  void operator<<=(_Up&&);
-  template <class _Up>
-  void operator>>=(_Up&&);
-  void operator++();
-  void operator++(int);
-  void operator--();
-  void operator--(int);
-  template <class _Up, class _Flags>
-  void copy_from(const _Up*, _Flags);
-};
-
 // [simd.class]
 template <class _Tp, class _Abi>
 class simd {
@@ -2258,6 +2181,340 @@
   }
 };
 
+// Stores element i of __v, converted to _Up, into __buffer[i] for every i
+// selected by __m; unselected buffer entries are not written. The _Flags tag
+// is accepted but currently ignored (stores are element by element).
+template <class _Tp, class _Abi, class _Up, class _Flags>
+void __mask_copy_to(const simd<_Tp, _Abi>& __v, const simd_mask<_Tp, _Abi>& __m,
+                    _Up* __buffer, _Flags) {
+  const size_t __n = __v.size();
+  for (size_t __lane = 0; __lane < __n; ++__lane) {
+    if (!__m[__lane])
+      continue;
+    __buffer[__lane] = static_cast<_Up>(__v[__lane]);
+  }
+}
+
+// simd_mask form: stores the boolean element i of __v, converted to _Up, into
+// __buffer[i] for every i selected by __m; other entries are not written.
+template <class _Tp, class _Abi, class _Up, class _Flags>
+void __mask_copy_to(const simd_mask<_Tp, _Abi>& __v,
+                    const simd_mask<_Tp, _Abi>& __m, _Up* __buffer, _Flags) {
+  const size_t __n = __v.size();
+  for (size_t __lane = 0; __lane < __n; ++__lane) {
+    if (!__m[__lane])
+      continue;
+    __buffer[__lane] = static_cast<_Up>(__v[__lane]);
+  }
+}
+
+// Scalar form: writes __v (converted to _Up) to *__buffer only when __m is
+// true.
+template <class _Tp, class _Up, class _Flags>
+void __mask_copy_to(_Tp __v, bool __m, _Up* __buffer, _Flags) {
+  if (!__m)
+    return;
+  *__buffer = static_cast<_Up>(__v);
+}
+
+// Loads __buffer[i], converted to _Tp, into element i of __v for every i
+// selected by __m; unselected elements keep their value. The _Flags tag is
+// accepted but currently ignored (loads are element by element).
+template <class _Tp, class _Abi, class _Up, class _Flags>
+void __mask_copy_from(simd<_Tp, _Abi>& __v, const simd_mask<_Tp, _Abi>& __m,
+                      const _Up* __buffer, _Flags) {
+  // TODO: optimize for overaligned flags
+  const size_t __n = __v.size();
+  for (size_t __lane = 0; __lane < __n; ++__lane) {
+    if (!__m[__lane])
+      continue;
+    __v[__lane] = static_cast<_Tp>(__buffer[__lane]);
+  }
+}
+
+// simd_mask form: loads __buffer[i], converted to bool, into element i of __v
+// for every i selected by __m; unselected elements keep their value.
+template <class _Tp, class _Abi, class _Up, class _Flags>
+void __mask_copy_from(simd_mask<_Tp, _Abi>& __v,
+                      const simd_mask<_Tp, _Abi>& __m, const _Up* __buffer,
+                      _Flags) {
+  // TODO: optimize based on bool's bit pattern.
+  const size_t __n = __v.size();
+  for (size_t __lane = 0; __lane < __n; ++__lane) {
+    if (!__m[__lane])
+      continue;
+    __v[__lane] = static_cast<bool>(__buffer[__lane]);
+  }
+}
+
+// Scalar form: assigns *__buffer (converted to _Tp) to __v only when __m is
+// true.
+template <class _Tp, class _Up, class _Flags>
+void __mask_copy_from(_Tp& __v, bool __m, const _Up* __buffer, _Flags) {
+  if (!__m)
+    return;
+  __v = static_cast<_Tp>(*__buffer);
+}
+
+// Maps the operand type of a where-expression to its element type: an
+// arithmetic scalar maps to itself, and simd<T, Abi> / simd_mask<T, Abi> map
+// to T. Non-arithmetic element types are rejected with a static_assert.
+template <class _ValueType>
+struct __simd_value_type_traits {
+  static_assert(std::is_arithmetic<_ValueType>::value, "");
+  typedef _ValueType type;
+};
+
+template <class _Tp, class _Abi>
+struct __simd_value_type_traits<simd<_Tp, _Abi>> {
+  static_assert(std::is_arithmetic<_Tp>::value, "");
+  typedef _Tp type;
+};
+
+template <class _Tp, class _Abi>
+struct __simd_value_type_traits<simd_mask<_Tp, _Abi>> {
+  static_assert(std::is_arithmetic<_Tp>::value, "");
+  typedef _Tp type;
+};
+
+// [simd.whereexpr]
+// Read-only masked view of a value: an arithmetic scalar (with a bool mask),
+// a simd, or a simd_mask. It stores a copy of the mask (__m_) and a reference
+// to the value (__v_). Its constructors are private, so instances are created
+// only by the befriended where() overloads and by where_expression.
+template <class _MaskType, class _ValueType>
+class const_where_expression {
+  static_assert(
+      std::is_arithmetic<typename remove_const<_ValueType>::type>::value ||
+          is_simd<typename remove_const<_ValueType>::type>::value ||
+          is_simd_mask<typename remove_const<_ValueType>::type>::value,
+      "");
+
+  // Element type of the referenced value (see __simd_value_type_traits).
+  using _Tp = typename __simd_value_type_traits<
+      typename remove_const<_ValueType>::type>::type;
+
+  // The mask is held by value (plain bool for the scalar form), so writes
+  // through __v_ cannot change which elements are selected, even in
+  // where(m, m).
+  typename std::conditional<std::is_same<bool, _MaskType>::value, bool,
+                            const _MaskType>::type __m_;
+  // NOTE(review): a reference — the referenced value must outlive this
+  // object (where() on a temporary is only safe within one full-expression).
+  _ValueType& __v_;
+
+  const_where_expression(const _MaskType& __m, _ValueType& __v)
+      : __m_(__m), __v_(__v) {}
+
+  // Private copy: lets the befriended where() overloads return by value while
+  // user code cannot copy the expression.
+  const_where_expression(const const_where_expression&) = default;
+
+  template <class, class>
+  friend class where_expression;
+
+  // These friend declarations must match the where() definitions exactly,
+  // otherwise the definitions would not be granted access.
+  template <class _Up, class _Ap>
+  friend const_where_expression<simd_mask<_Up, _Ap>, const simd<_Up, _Ap>>
+  where(const typename simd<_Up, _Ap>::mask_type& __m,
+        const simd<_Up, _Ap>& __v) noexcept;
+
+  template <class _Up, class _Ap>
+  friend const_where_expression<simd_mask<_Up, _Ap>, const simd_mask<_Up, _Ap>>
+  where(const typename __nodeduce<simd_mask<_Up, _Ap>>::type& __m,
+        const simd_mask<_Up, _Ap>& __v) noexcept;
+
+  template <class _Up, class _Mp>
+  friend typename std::enable_if<std::is_same<_Mp, bool>::value,
+                                 const_where_expression<bool, const _Up>>::type
+  where(_Mp __m, const _Up& __v) noexcept;
+
+public:
+  const_where_expression& operator=(const const_where_expression&) = delete;
+
+  // Returns a copy of the value with the selected elements negated:
+  // select(v, 0) - select(0, v) is v - 0 for unselected elements and 0 - v
+  // for selected ones.
+  typename std::remove_const<_ValueType>::type operator-() const&& {
+    static_assert(!is_simd_mask<typename remove_const<_ValueType>::type>::value,
+                  "Library extension: operator-() doesn't really make sense "
+                  "when operating on simd_mask<>.");
+    return __simd_mask_friend::__simd_select(__v_, _ValueType(0), __m_) -
+           __simd_mask_friend::__simd_select(_ValueType(0), __v_, __m_);
+  }
+
+  // Writes the selected elements of the value, converted to _Up, to the
+  // matching positions of __buffer via __mask_copy_to; other buffer entries
+  // are left untouched.
+  // NOTE(review): the enable_if condition is true whenever _Tp is not bool,
+  // so it only constrains _Up when _Tp is bool — confirm this is the intent.
+  template <class _Up, class _Flags>
+  typename std::enable_if<std::is_same<_Tp, _Up>::value ||
+                          !std::is_same<_Tp, bool>::value>::type
+  copy_to(_Up* __buffer, _Flags) const&& {
+    __mask_copy_to(__v_, __m_, __buffer, _Flags());
+  }
+};
+
+// [simd.whereexpr]
+// Mutable masked view: every assignment, compound assignment, increment and
+// decrement updates only the elements of the referenced value selected by the
+// mask and leaves the others unchanged. Constructed only through where().
+template <class _MaskType, class _ValueType>
+class where_expression : public const_where_expression<_MaskType, _ValueType> {
+  using _Tp = typename __simd_value_type_traits<
+      typename remove_const<_ValueType>::type>::type;
+
+  where_expression(const _MaskType& __m, _ValueType& __v)
+      : const_where_expression<_MaskType, _ValueType>(__m, __v) {}
+
+  where_expression(const where_expression&) = default;
+
+  template <class _Up, class _Ap>
+  friend where_expression<simd_mask<_Up, _Ap>, simd<_Up, _Ap>>
+  where(const typename simd<_Up, _Ap>::mask_type& __m,
+        simd<_Up, _Ap>& __v) noexcept;
+
+  template <class _Up, class _Ap>
+  friend where_expression<simd_mask<_Up, _Ap>, simd_mask<_Up, _Ap>>
+  where(const typename __nodeduce<simd_mask<_Up, _Ap>>::type& __m,
+        simd_mask<_Up, _Ap>& __v) noexcept;
+
+  template <class _Up, class _Mp>
+  friend typename std::enable_if<std::is_same<_Mp, bool>::value,
+                                 where_expression<bool, _Up>>::type
+  where(_Mp __m, _Up& __v) noexcept;
+
+  // Tag selecting the scalar (bool mask) or the simd/simd_mask code path.
+  using __is_scalar_mask = typename std::is_same<_MaskType, bool>::type;
+
+  // Scalar division: evaluated only when the mask is set, and with the
+  // caller's operand type, so `where(true, i) /= 2.5` behaves like
+  // `i /= 2.5`. Converting the divisor to _ValueType first would truncate it
+  // (and could turn e.g. 0.5 into a division by zero).
+  template <class _Up>
+  void __masked_divide(_Up&& __u, std::true_type) {
+    if (this->__m_)
+      this->__v_ = _ValueType(this->__v_ / std::forward<_Up>(__u));
+  }
+
+  // simd division: unselected elements divide by 1, so a zero divisor in an
+  // unselected element is never evaluated and those elements keep their value.
+  template <class _Up>
+  void __masked_divide(_Up&& __u, std::false_type) {
+    this->__v_ =
+        this->__v_ /
+        __simd_mask_friend::__simd_select(
+            _ValueType(1), _ValueType(std::forward<_Up>(__u)), this->__m_);
+  }
+
+  // Scalar remainder; the mask is tested first for the same reasons as in the
+  // scalar __masked_divide.
+  template <class _Up>
+  void __masked_modulo(_Up&& __u, std::true_type) {
+    if (this->__m_)
+      this->__v_ = _ValueType(this->__v_ % std::forward<_Up>(__u));
+  }
+
+  // simd remainder: unselected elements compute `v % 1` (never `% 0`), and the
+  // outer select then restores their original value.
+  template <class _Up>
+  void __masked_modulo(_Up&& __u, std::false_type) {
+    this->__v_ = __simd_mask_friend::__simd_select(
+        this->__v_,
+        this->__v_ %
+            __simd_mask_friend::__simd_select(
+                _ValueType(1), _ValueType(std::forward<_Up>(__u)), this->__m_),
+        this->__m_);
+  }
+
+public:
+  where_expression& operator=(const where_expression&) = delete;
+
+  // Assigns __u (converted to _ValueType) to the selected elements only.
+  template <class _Up>
+  auto operator=(_Up&& __u)
+      -> decltype(this->__v_ = std::forward<_Up>(__u), void()) {
+    this->__v_ = __simd_mask_friend::__simd_select(
+        this->__v_, _ValueType(std::forward<_Up>(__u)), this->__m_);
+  }
+
+  // The compound assignments below compute the full result and store it
+  // through operator=, which writes only the selected elements.
+  template <class _Up>
+  auto operator+=(_Up&& __u)
+      -> decltype(this->__v_ + std::forward<_Up>(__u), void()) {
+    *this = this->__v_ + std::forward<_Up>(__u);
+  }
+
+  template <class _Up>
+  auto operator-=(_Up&& __u)
+      -> decltype(this->__v_ - std::forward<_Up>(__u), void()) {
+    *this = this->__v_ - std::forward<_Up>(__u);
+  }
+
+  template <class _Up>
+  auto operator*=(_Up&& __u)
+      -> decltype(this->__v_ * std::forward<_Up>(__u), void()) {
+    *this = this->__v_ * std::forward<_Up>(__u);
+  }
+
+  template <class _Up>
+  auto operator/=(_Up&& __u)
+      -> decltype(this->__v_ / std::forward<_Up>(__u), void()) {
+    __masked_divide(std::forward<_Up>(__u), __is_scalar_mask());
+  }
+
+  template <class _Up>
+  auto operator%=(_Up&& __u)
+      -> decltype(this->__v_ % std::forward<_Up>(__u), void()) {
+    __masked_modulo(std::forward<_Up>(__u), __is_scalar_mask());
+  }
+
+  template <class _Up>
+  auto operator&=(_Up&& __u)
+      -> decltype(this->__v_ & std::forward<_Up>(__u), void()) {
+    *this = this->__v_ & std::forward<_Up>(__u);
+  }
+
+  template <class _Up>
+  auto operator|=(_Up&& __u)
+      -> decltype(this->__v_ | std::forward<_Up>(__u), void()) {
+    *this = this->__v_ | std::forward<_Up>(__u);
+  }
+
+  template <class _Up>
+  auto operator^=(_Up&& __u)
+      -> decltype(this->__v_ ^ std::forward<_Up>(__u), void()) {
+    *this = this->__v_ ^ std::forward<_Up>(__u);
+  }
+
+  template <class _Up>
+  auto operator<<=(_Up&& __u)
+      -> decltype(this->__v_ << std::forward<_Up>(__u), void()) {
+    *this = this->__v_ << std::forward<_Up>(__u);
+  }
+
+  template <class _Up>
+  auto operator>>=(_Up&& __u)
+      -> decltype(this->__v_ >> std::forward<_Up>(__u), void()) {
+    *this = this->__v_ >> std::forward<_Up>(__u);
+  }
+
+  // Increment / decrement the selected elements by one.
+  void operator++() { *this += _ValueType(1); }
+
+  void operator++(int) { ++*this; }
+
+  void operator--() { *this -= _ValueType(1); }
+
+  void operator--(int) { --*this; }
+
+  // Loads the selected elements from __buffer via __mask_copy_from; the other
+  // elements keep their value.
+  template <class _Up, class _Flags>
+  typename std::enable_if<std::is_same<_Tp, _Up>::value ||
+                          !std::is_same<_Tp, bool>::value>::type
+  copy_from(const _Up* __buffer, _Flags) {
+    __mask_copy_from(this->__v_, this->__m_, __buffer, _Flags());
+  }
+};
+
+// [simd.mask.where] factories. Each returns a (const_)where_expression that
+// refers to __v and is masked by __m. The signatures must stay identical to
+// the friend declarations inside const_where_expression / where_expression.
+template <class _Tp, class _Abi>
+where_expression<simd_mask<_Tp, _Abi>, simd<_Tp, _Abi>>
+where(const typename simd<_Tp, _Abi>::mask_type& __m,
+      simd<_Tp, _Abi>& __v) noexcept {
+  return {__m, __v};
+}
+
+template <class _Tp, class _Abi>
+const_where_expression<simd_mask<_Tp, _Abi>, const simd<_Tp, _Abi>>
+where(const typename simd<_Tp, _Abi>::mask_type& __m,
+      const simd<_Tp, _Abi>& __v) noexcept {
+  return {__m, __v};
+}
+
+template <class _Tp, class _Abi>
+where_expression<simd_mask<_Tp, _Abi>, simd_mask<_Tp, _Abi>>
+where(const typename __nodeduce<simd_mask<_Tp, _Abi>>::type& __m,
+      simd_mask<_Tp, _Abi>& __v) noexcept {
+  return {__m, __v};
+}
+
+template <class _Tp, class _Abi>
+const_where_expression<simd_mask<_Tp, _Abi>, const simd_mask<_Tp, _Abi>>
+where(const typename __nodeduce<simd_mask<_Tp, _Abi>>::type& __m,
+      const simd_mask<_Tp, _Abi>& __v) noexcept {
+  return {__m, __v};
+}
+
+// Scalar forms: enabled only when the mask argument is exactly bool.
+template <class _Tp, class _MaskType>
+typename std::enable_if<std::is_same<_MaskType, bool>::value,
+                        where_expression<bool, _Tp>>::type
+where(_MaskType __m, _Tp& __v) noexcept {
+  return {__m, __v};
+}
+
+template <class _Tp, class _MaskType>
+typename std::enable_if<std::is_same<_MaskType, bool>::value,
+                        const_where_expression<bool, const _Tp>>::type
+where(_MaskType __m, const _Tp& __v) noexcept {
+  return {__m, __v};
+}
+
+// Masked reductions [simd.reductions] over a const_where_expression.
+// NOTE(review): only declarations appear in this patch; no definitions are
+// visible here, so calls would fail to link unless they are defined
+// elsewhere. TODO: confirm or implement.
+template <class _MaskType, class _SimdType, class _BinaryOp>
+typename _SimdType::value_type
+reduce(const const_where_expression<_MaskType, _SimdType>&,
+       typename _SimdType::value_type neutral_element, _BinaryOp binary_op);
+
+template <class _MaskType, class _SimdType>
+typename _SimdType::value_type
+reduce(const const_where_expression<_MaskType, _SimdType>&,
+       plus<typename _SimdType::value_type> binary_op = {});
+
+template <class _MaskType, class _SimdType>
+typename _SimdType::value_type
+reduce(const const_where_expression<_MaskType, _SimdType>&,
+       multiplies<typename _SimdType::value_type> binary_op);
+
+template <class _MaskType, class _SimdType>
+typename _SimdType::value_type
+reduce(const const_where_expression<_MaskType, _SimdType>&,
+       bit_and<typename _SimdType::value_type> binary_op);
+
+template <class _MaskType, class _SimdType>
+typename _SimdType::value_type
+reduce(const const_where_expression<_MaskType, _SimdType>&,
+       bit_or<typename _SimdType::value_type> binary_op);
+
+template <class _MaskType, class _SimdType>
+typename _SimdType::value_type
+reduce(const const_where_expression<_MaskType, _SimdType>&,
+       bit_xor<typename _SimdType::value_type> binary_op);
+
+// Masked horizontal minimum / maximum (declarations only, see note above).
+template <class _MaskType, class _SimdType>
+typename _SimdType::value_type
+hmin(const const_where_expression<_MaskType, _SimdType>&);
+
+template <class _MaskType, class _SimdType>
+typename _SimdType::value_type
+hmax(const const_where_expression<_MaskType, _SimdType>&);
+
 _LIBCPP_END_NAMESPACE_EXPERIMENTAL_SIMD
 
 #endif /* _LIBCPP_EXPERIMENTAL_SIMD */
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to