[libcxx-commits] [libcxx] [libc++] Vectorize trivially equality comparable types (PR #87716)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Fri Apr 5 02:46:11 PDT 2024


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/87716

>From dc09211e98c33ea14fd7378d8d07441587e11e42 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 25 Mar 2024 08:25:23 +0100
Subject: [PATCH] [libc++] Vectorize trivially equality comparable types

---
 libcxx/include/__algorithm/mismatch.h         | 26 ++++++++++
 libcxx/include/__algorithm/simd_utils.h       | 24 +++++++++-
 .../mismatch/mismatch.pass.cpp                | 47 +++++++++++++++----
 3 files changed, 86 insertions(+), 11 deletions(-)

diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index 4ada29eabc470c..c35b3ae26cc6b9 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -16,6 +16,7 @@
 #include <__algorithm/unwrap_iter.h>
 #include <__config>
 #include <__functional/identity.h>
+#include <__type_traits/copy_cv.h>
 #include <__type_traits/desugars_to.h>
 #include <__type_traits/invoke.h>
 #include <__type_traits/is_constant_evaluated.h>
@@ -119,6 +120,44 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
   return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
 }
 
+// Trivially equality comparable types compare equal exactly when their object representations are equal. If such a
+// type also matches the size and alignment of a fixed-width unsigned integer type, the ranges can be compared as
+// ranges of that integer type, reusing the vectorized implementation for integral types.
+template <class _Tp,
+          class _Pred,
+          class _Proj1,
+          class _Proj2,
+          __enable_if_t<!is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
+                            __is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
+                            __can_map_to_integer_v<_Tp> && __libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value,
+                        int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+  // reinterpret_cast can't be used during constant evaluation, so compare element by element there.
+  if (__libcpp_is_constant_evaluated())
+    return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+
+  using __integer_t = __copy_cv_t<_Tp, __get_as_integer_type<_Tp>>;
+
+  // The constraints guarantee that __pred compares for equality and that both projections are identities, so pass
+  // canonical versions of them on. __pred itself may only be known to compare for equality on _Tp operands (e.g.
+  // equal_to<_Tp>), in which case it would neither select the vectorized integer overload nor necessarily be
+  // callable with __integer_t operands.
+  __equal_to __eq_pred;
+  __identity __id_proj;
+
+  // This is valid because we disable TBAA when loading vectors. Alignment requirements still have to be fulfilled.
+  // NOTE(review): elements past the last full vector are presumably compared through plain __integer_t loads by the
+  // integer overload; confirm that those are also exempt from type-based alias analysis.
+  auto __ret = std::__mismatch(
+      reinterpret_cast<__integer_t*>(__first1),
+      reinterpret_cast<__integer_t*>(__last1),
+      reinterpret_cast<__integer_t*>(__first2),
+      __eq_pred,
+      __id_proj,
+      __id_proj);
+  return {reinterpret_cast<_Tp*>(__ret.first), reinterpret_cast<_Tp*>(__ret.second)};
+}
 #endif // _LIBCPP_VECTORIZE_ALGORITHMS
 
 template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
diff --git a/libcxx/include/__algorithm/simd_utils.h b/libcxx/include/__algorithm/simd_utils.h
index 989a1957987e1e..c3b59e0f98ef0d 100644
--- a/libcxx/include/__algorithm/simd_utils.h
+++ b/libcxx/include/__algorithm/simd_utils.h
@@ -43,6 +43,42 @@ _LIBCPP_PUSH_MACROS
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
+// Whether a pointer to _Tp can be reinterpreted as a pointer to the unsigned integer type of the same size. The size
+// has to be one that __get_as_integer_type supports, and the alignment has to equal the size so that (assuming the
+// integer types are naturally aligned) the reinterpreted pointer meets the integer type's alignment requirement.
+template <class _Tp>
+inline constexpr bool __can_map_to_integer_v =
+    sizeof(_Tp) == alignof(_Tp) && (sizeof(_Tp) == 1 || sizeof(_Tp) == 2 || sizeof(_Tp) == 4 || sizeof(_Tp) == 8);
+
+// Maps an object size to the unsigned integer type of exactly that size; other sizes fail to compile. This uses
+// specializations rather than an `if constexpr` chain ending in `static_assert(false)`, because compilers that don't
+// implement P2593 reject such an assertion even in a discarded branch of a template.
+template <size_t _TypeSize>
+struct __get_as_integer_type_impl;
+
+template <>
+struct __get_as_integer_type_impl<1> {
+  using type = uint8_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<2> {
+  using type = uint16_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<4> {
+  using type = uint32_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<8> {
+  using type = uint64_t;
+};
+
+template <class _Tp>
+using __get_as_integer_type = typename __get_as_integer_type_impl<sizeof(_Tp)>::type;
+
 // This isn't specialized for 64 byte vectors on purpose. They have the potential to significantly reduce performance
 // in mixed simd/non-simd workloads and don't provide any performance improvement for currently vectorized algorithms
 // as far as benchmarks are concerned.
@@ -83,7 +119,12 @@ using __simd_vector_underlying_type_t = decltype(std::__simd_vector_underlying_t
 template <class _VecT, class _Tp>
 _LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const _Tp* __ptr) noexcept {
   return [=]<size_t... _Indices>(index_sequence<_Indices...>) _LIBCPP_ALWAYS_INLINE noexcept {
-    return _VecT{__ptr[_Indices]...};
+    // may_alias is a type attribute: attached to a pointer variable it would not affect loads through the pointer.
+    // Attach it to an alias of the pointee type instead, so that loads through that type are exempt from type-based
+    // alias analysis and callers may pass pointers to objects whose dynamic type isn't _Tp.
+    using __aliasing_type [[__gnu__::__may_alias__]] = _Tp;
+    const __aliasing_type* __aliasing_ptr = __ptr;
+    return _VecT{__aliasing_ptr[_Indices]...};
   }(make_index_sequence<__simd_vector_size_v<_VecT>>{});
 }
 
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index eb5f7cacdde34b..72df17628dad78 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -66,14 +66,27 @@ TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
 #endif
 }
 
-struct NonTrivial {
+// Compares modulo 4 to make sure we only forward to the vectorized version if we are trivially equality comparable
+struct NonTrivialMod4Comp {
   int i_;
 
-  TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
-  TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
+  TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(int i) : i_(i) {}
+  TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(NonTrivialMod4Comp&& other) : i_(other.i_) { other.i_ = 0; }
 
-  TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
+  TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivialMod4Comp& lhs, const NonTrivialMod4Comp& rhs) {
+    return lhs.i_ % 4 == rhs.i_ % 4;
+  }
+};
+
+#if TEST_STD_VER >= 20
+struct TriviallyEqualityComparable {
+  int i_;
+
+  TEST_CONSTEXPR_CXX20 TriviallyEqualityComparable(int i) : i_(i) {}
+
+  TEST_CONSTEXPR_CXX20 friend bool operator==(TriviallyEqualityComparable, TriviallyEqualityComparable) = default;
 };
+#endif // TEST_STD_VER >= 20
 
 struct ModTwoComp {
   TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) { return lhs % 2 == rhs % 2; }
@@ -136,16 +149,30 @@ TEST_CONSTEXPR_CXX20 bool test() {
   types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
 
   { // use a non-integer type to also test the general case - all elements match
-    std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
-    std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
-    check<NonTrivial*>(std::move(lhs), std::move(rhs), 8);
+    std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 1, 6, 7, 8};
+    check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 8);
   }
 
   { // use a non-integer type to also test the general case - not all elements match
-    std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
-    std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
-    check<NonTrivial*>(std::move(lhs), std::move(rhs), 4);
+    std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+    std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 4);
+  }
+
+#if TEST_STD_VER >= 20
+  { // trivially equality comparable class type to test forwarding to the vectorized version - all elements match
+    std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 8);
+  }
+
+  { // trivially equality comparable class type to test forwarding to the vectorized version - not all elements match
+    std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+    std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 4);
   }
+#endif // TEST_STD_VER >= 20
 
   return true;
 }



More information about the libcxx-commits mailing list