[libcxx-commits] [libcxx] [libc++] Vectorize mismatch (PR #73255)

via libcxx-commits libcxx-commits at lists.llvm.org
Thu Nov 23 09:14:40 PST 2023


https://github.com/philnik777 created https://github.com/llvm/llvm-project/pull/73255

None

>From f5a4e3b1408ae234546163f9221016233278335c Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Sun, 8 Oct 2023 12:45:40 +0200
Subject: [PATCH] [libc++] Vectorize mismatch

---
 .../benchmarks/algorithms/mismatch.bench.cpp  |  31 ++++++
 libcxx/include/CMakeLists.txt                 |   1 +
 libcxx/include/__algorithm/mismatch.h         | 101 +++++++++++++++---
 libcxx/include/__algorithm/vectorization.h    |  69 ++++++++++++
 libcxx/include/experimental/__simd/abi_tag.h  |   1 +
 .../include/experimental/__simd/aligned_tag.h |   3 +-
 .../include/experimental/__simd/declaration.h |   3 +
 .../__simd/internal_declaration.h             |   2 +
 .../include/experimental/__simd/reference.h   |   1 +
 libcxx/include/experimental/__simd/simd.h     |  11 ++
 .../include/experimental/__simd/simd_mask.h   |  19 ++++
 libcxx/include/experimental/__simd/utility.h  |   1 +
 libcxx/include/experimental/__simd/vec_ext.h  |  33 ++++++
 .../mismatch/mismatch.pass.cpp                |  77 ++++---------
 14 files changed, 279 insertions(+), 74 deletions(-)
 create mode 100644 libcxx/benchmarks/algorithms/mismatch.bench.cpp
 create mode 100644 libcxx/include/__algorithm/vectorization.h

diff --git a/libcxx/benchmarks/algorithms/mismatch.bench.cpp b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
new file mode 100644
index 000000000000000..3cce0c108ee1ecd
--- /dev/null
+++ b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
@@ -0,0 +1,33 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <benchmark/benchmark.h>
+#include <random>
+#include <vector>
+
+// Measures std::mismatch on two equal-length ranges that differ at exactly one, randomly chosen, position.
+template <class T>
+static void bm_mismatch(benchmark::State& state) {
+  std::vector<T> vec1(state.range(), '1');
+  std::vector<T> vec2(state.range(), '1');
+  std::mt19937_64 rng(std::random_device{}());
+
+  for (auto _ : state) {
+    auto idx  = rng() % vec1.size();
+    vec1[idx] = '2';
+    benchmark::DoNotOptimize(vec1);
+    benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
+    vec1[idx] = '1';
+  }
+}
+BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+
+BENCHMARK_MAIN();
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 889d7fedbf2965f..f16bd2784001018 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -225,6 +225,7 @@ set(files
   __algorithm/unwrap_iter.h
   __algorithm/unwrap_range.h
   __algorithm/upper_bound.h
+  __algorithm/vectorization.h
   __assert
   __atomic/aliases.h
   __atomic/atomic.h
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index e5b014f45738a1d..6ae118f1bbcede0 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -11,9 +11,19 @@
 #define _LIBCPP___ALGORITHM_MISMATCH_H
 
 #include <__algorithm/comp.h>
+#include <__algorithm/unwrap_iter.h>
+#include <__algorithm/vectorization.h>
 #include <__config>
+#include <__functional/identity.h>
 #include <__iterator/iterator_traits.h>
+#include <__type_traits/invoke.h>
+#include <__type_traits/is_equality_comparable.h>
+#include <__type_traits/predicate_traits.h>
+#include <__utility/move.h>
 #include <__utility/pair.h>
+#include <experimental/__simd/abi_tag.h>
+#include <experimental/__simd/simd.h>
+#include <experimental/__simd/simd_mask.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
@@ -21,29 +31,102 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
-_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY
-    _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
-    mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate __pred) {
-  for (; __first1 != __last1; ++__first1, (void)++__first2)
-    if (!__pred(*__first1, *__first2))
+// Scalar implementation shared by the __mismatch overloads: advances both ranges in lock-step and stops at the
+// first position where __pred rejects the projected elements, or at __last1. Returns the positions reached.
+template <class _InIter1, class _Sent1, class _InIter2, class _Pred, class _Proj1, class _Proj2>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InIter1, _InIter2>
+__mismatch_loop(_InIter1 __first1, _Sent1 __last1, _InIter2 __first2, _Pred __pred, _Proj1 __proj1, _Proj2 __proj2) {
+  while (__first1 != __last1) {
+    if (!std::__invoke(__pred, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
       break;
-  return pair<_InputIterator1, _InputIterator2>(__first1, __first2);
+    ++__first1;
+    ++__first2;
+  }
+  return {std::move(__first1), std::move(__first2)};
+}
+
+// Generic entry point. The raw-pointer overload below is more specialized and is picked by overload resolution
+// whenever its constraints are satisfied.
+template <class _InIter1, class _Sent1, class _InIter2, class _Pred, class _Proj1, class _Proj2>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InIter1, _InIter2>
+__mismatch(_InIter1 __first1, _Sent1 __last1, _InIter2 __first2, _Pred __pred, _Proj1 __proj1, _Proj2 __proj2) {
+  return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+}
+
+#if _LIBCPP_VECTORIZE_ALGORITHMS
+// Vectorized implementation for two contiguous ranges of the same element type (raw pointers) compared with a
+// trivial equality predicate and identity projections. The element type must be trivially equality comparable with
+// a size and alignment matching an integer lane type, or floating point under -ffast-math.
+template <class _Tp,
+          class _Pred,
+          class _Proj1,
+          class _Proj2,
+          enable_if_t<__is_trivial_equality_predicate<_Pred, _Tp, _Tp>::value && __is_identity<_Proj1>::value &&
+                          __is_identity<_Proj2>::value &&
+                          ((__libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value && __fits_in_vector<_Tp> &&
+                            alignof(_Tp) >= alignof(__get_arithmetic_type<_Tp>)) ||
+                           (_LIBCPP_VECTORIZE_FLOATING_POINT_ALGORITHMS && is_floating_point_v<_Tp>)),
+                      int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred __pred, _Proj1 __proj1, _Proj2 __proj2) {
+  using __vec                     = __arithmetic_vec<_Tp>;
+  constexpr size_t __unroll_count = 4;
+
+  // SIMD loads cannot be constant-evaluated, so constant evaluation goes straight to the scalar loop. The distance
+  // is never negative, so converting it to size_t (the type of __vec::size()) is lossless.
+  while (!__libcpp_is_constant_evaluated() &&
+         static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec::size()) {
+    __vec __lhs[__unroll_count];
+    __vec __rhs[__unroll_count];
+
+    // Issue all loads of this chunk before doing any comparison.
+    for (size_t __i = 0; __i != __unroll_count; ++__i) {
+      __lhs[__i] = std::__load_as_arithmetic(__first1 + __i * __vec::size());
+      __rhs[__i] = std::__load_as_arithmetic(__first2 + __i * __vec::size());
+    }
+
+    for (size_t __i = 0; __i != __unroll_count; ++__i) {
+      if (auto __res = __lhs[__i] == __rhs[__i]; !experimental::all_of(__res)) {
+        // The first unequal lane of the first unequal vector is the first mismatch of the whole chunk.
+        auto __offset = __i * __vec::size() + experimental::find_first_set(__res);
+        return std::make_pair(__first1 + __offset, __first2 + __offset);
+      }
+    }
+
+    __first1 += __unroll_count * __vec::size();
+    __first2 += __unroll_count * __vec::size();
+  }
+
+  // Fewer than __unroll_count vectors' worth of elements are left, or this is constant evaluation.
+  return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+}
+#endif // _LIBCPP_VECTORIZE_ALGORITHMS
+
+template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
+_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
+mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate __pred) {
+  // Unwrap the iterators so that __mismatch can select the vectorized overload when they unwrap to raw pointers,
+  // then rewrap the result into the caller's iterator types.
+  __identity __proj;
+  auto __res = std::__mismatch(
+      std::__unwrap_iter(__first1), std::__unwrap_iter(__last1), std::__unwrap_iter(__first2), __pred, __proj, __proj);
+  return std::make_pair(std::__rewrap_iter(__first1, __res.first), std::__rewrap_iter(__first2, __res.second));
 }
 
 template <class _InputIterator1, class _InputIterator2>
-_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY
-    _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
-    mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2) {
+_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
+mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2) {
   return std::mismatch(__first1, __last1, __first2, __equal_to());
 }
 
 #if _LIBCPP_STD_VER >= 14
 template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
-_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY
-    _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
-    mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _InputIterator2 __last2,
-             _BinaryPredicate __pred) {
+_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
+mismatch(_InputIterator1 __first1,
+         _InputIterator1 __last1,
+         _InputIterator2 __first2,
+         _InputIterator2 __last2,
+         _BinaryPredicate __pred) {
   for (; __first1 != __last1 && __first2 != __last2; ++__first1, (void)++__first2)
     if (!__pred(*__first1, *__first2))
       break;
@@ -51,9 +134,8 @@ _LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY
 }
 
 template <class _InputIterator1, class _InputIterator2>
-_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY
-    _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
-    mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _InputIterator2 __last2) {
+_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
+mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _InputIterator2 __last2) {
   return std::mismatch(__first1, __last1, __first2, __last2, __equal_to());
 }
 #endif
diff --git a/libcxx/include/__algorithm/vectorization.h b/libcxx/include/__algorithm/vectorization.h
new file mode 100644
index 000000000000000..acdeb065d27b4bd
--- /dev/null
+++ b/libcxx/include/__algorithm/vectorization.h
@@ -0,0 +1,107 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_VECTORIZATION_H
+#define _LIBCPP___ALGORITHM_VECTORIZATION_H
+
+#include <__config>
+#include <__type_traits/is_floating_point.h>
+#include <__type_traits/remove_cv.h>
+#include <__utility/integer_sequence.h>
+#include <cstddef>
+#include <cstdint>
+#include <experimental/__simd/aligned_tag.h>
+#include <experimental/__simd/simd.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+// The vectorized algorithms are built on <experimental/simd>, so they need C++17 and _LIBCPP_ENABLE_EXPERIMENTAL.
+// They are not used when optimizing for size (__OPTIMIZE_SIZE__).
+#if _LIBCPP_STD_VER >= 17 && defined(_LIBCPP_ENABLE_EXPERIMENTAL) && !defined(__OPTIMIZE_SIZE__)
+#  define _LIBCPP_VECTORIZE_ALGORITHMS 1
+#else
+#  define _LIBCPP_VECTORIZE_ALGORITHMS 0
+#endif
+
+// Floating point element types are only vectorized when compiling with -ffast-math (__FAST_MATH__).
+#if _LIBCPP_VECTORIZE_ALGORITHMS && defined(__FAST_MATH__)
+#  define _LIBCPP_VECTORIZE_FLOATING_POINT_ALGORITHMS 1
+#else
+#  define _LIBCPP_VECTORIZE_FLOATING_POINT_ALGORITHMS 0
+#endif
+
+#if _LIBCPP_VECTORIZE_ALGORITHMS
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+// True if _Tp has the size of one of the unsigned integer lane types (1, 2, 4 or 8 bytes).
+// `inline constexpr` (rather than `static`) gives the variable a single definition across translation units.
+template <class _Tp>
+inline constexpr bool __fits_in_vector =
+    sizeof(_Tp) == 1 || sizeof(_Tp) == 2 || sizeof(_Tp) == 4 || sizeof(_Tp) == 8;
+
+// __get_arithmetic_type<_Tp> is the SIMD lane type used to process elements of type _Tp:
+//  - floating point types map to themselves, without cv-qualifiers;
+//  - other types of size 1, 2, 4 or 8 map to the unsigned integer type of that size.
+// Any other _Tp has no `type` member, so __get_arithmetic_type<_Tp> is a substitution failure that SFINAE
+// constraints can test for, rather than a hard error. Class template specializations are used instead of a
+// template lambda in decltype (C++20-only) because this header is also used in C++17.
+template <class _Tp, size_t _Size = sizeof(_Tp), bool _IsFloatingPoint = is_floating_point<_Tp>::value>
+struct __arithmetic_type_impl {};
+
+template <class _Tp, size_t _Size>
+struct __arithmetic_type_impl<_Tp, _Size, true> {
+  using type = __remove_cv_t<_Tp>;
+};
+
+template <class _Tp>
+struct __arithmetic_type_impl<_Tp, 1, false> {
+  using type = uint8_t;
+};
+
+template <class _Tp>
+struct __arithmetic_type_impl<_Tp, 2, false> {
+  using type = uint16_t;
+};
+
+template <class _Tp>
+struct __arithmetic_type_impl<_Tp, 4, false> {
+  using type = uint32_t;
+};
+
+template <class _Tp>
+struct __arithmetic_type_impl<_Tp, 8, false> {
+  using type = uint64_t;
+};
+
+template <class _Tp>
+using __get_arithmetic_type = typename __arithmetic_type_impl<_Tp>::type;
+
+// Native-width SIMD vector whose lanes have type __get_arithmetic_type<_Tp>.
+template <class _Tp>
+using __arithmetic_vec = experimental::native_simd<__get_arithmetic_type<_Tp>>;
+
+// Loads __arithmetic_vec<_Tp>::size() consecutive elements starting at __values, reading each one as its lane type.
+// _Tp may be const-qualified, hence the cast to a pointer-to-const lane type.
+// NOTE(review): for non-arithmetic _Tp the objects are read through a different type; confirm this aliasing is
+// acceptable or switch to a memcpy-based load.
+template <class _Tp>
+_LIBCPP_HIDE_FROM_ABI __arithmetic_vec<_Tp> __load_as_arithmetic(_Tp* __values) {
+  // Must call the (pointer, flags) load constructor: `return (pointer, 0);` would be a comma expression that
+  // discards the pointer and broadcasts 0 to every lane.
+  return __arithmetic_vec<_Tp>(
+      reinterpret_cast<const __get_arithmetic_type<_Tp>*>(__values), experimental::element_aligned_tag());
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP_VECTORIZE_ALGORITHMS
+
+#endif // _LIBCPP___ALGORITHM_VECTORIZATION_H
diff --git a/libcxx/include/experimental/__simd/abi_tag.h b/libcxx/include/experimental/__simd/abi_tag.h
index a9d51c0683b1dc8..2fa2492bb2fecc3 100644
--- a/libcxx/include/experimental/__simd/abi_tag.h
+++ b/libcxx/include/experimental/__simd/abi_tag.h
@@ -10,6 +10,7 @@
 #ifndef _LIBCPP_EXPERIMENTAL___SIMD_ABI_TAG_H
 #define _LIBCPP_EXPERIMENTAL___SIMD_ABI_TAG_H
 
+#include <experimental/__config>
 #include <experimental/__simd/scalar.h>
 #include <experimental/__simd/vec_ext.h>
 
diff --git a/libcxx/include/experimental/__simd/aligned_tag.h b/libcxx/include/experimental/__simd/aligned_tag.h
index d3816ae1b0717d7..d216a21c073f3a3 100644
--- a/libcxx/include/experimental/__simd/aligned_tag.h
+++ b/libcxx/include/experimental/__simd/aligned_tag.h
@@ -12,7 +12,8 @@
 
 #include <__bit/bit_ceil.h>
 #include <__memory/assume_aligned.h>
-#include <cstdint>
+#include <cstddef>
+#include <experimental/__config>
 
 #if _LIBCPP_STD_VER >= 17 && defined(_LIBCPP_ENABLE_EXPERIMENTAL)
 
diff --git a/libcxx/include/experimental/__simd/declaration.h b/libcxx/include/experimental/__simd/declaration.h
index 747f87be63e535b..065faeaec3841f9 100644
--- a/libcxx/include/experimental/__simd/declaration.h
+++ b/libcxx/include/experimental/__simd/declaration.h
@@ -10,6 +10,9 @@
 #ifndef _LIBCPP_EXPERIMENTAL___SIMD_DECLARATION_H
 #define _LIBCPP_EXPERIMENTAL___SIMD_DECLARATION_H
 
+#include <experimental/__config>
+#include <experimental/__simd/abi_tag.h>
+
 #if _LIBCPP_STD_VER >= 17 && defined(_LIBCPP_ENABLE_EXPERIMENTAL)
 
 _LIBCPP_BEGIN_NAMESPACE_EXPERIMENTAL
diff --git a/libcxx/include/experimental/__simd/internal_declaration.h b/libcxx/include/experimental/__simd/internal_declaration.h
index 294b54d63bb5d9f..a027f265c7d8e5b 100644
--- a/libcxx/include/experimental/__simd/internal_declaration.h
+++ b/libcxx/include/experimental/__simd/internal_declaration.h
@@ -10,6 +10,8 @@
 #ifndef _LIBCPP_EXPERIMENTAL___SIMD_INTERNAL_DECLARATION_H
 #define _LIBCPP_EXPERIMENTAL___SIMD_INTERNAL_DECLARATION_H
 
+#include <experimental/__config>
+
 #if _LIBCPP_STD_VER >= 17 && defined(_LIBCPP_ENABLE_EXPERIMENTAL)
 
 _LIBCPP_BEGIN_NAMESPACE_EXPERIMENTAL
diff --git a/libcxx/include/experimental/__simd/reference.h b/libcxx/include/experimental/__simd/reference.h
index 8c58d24f2f2dc42..88c24634657fcfd 100644
--- a/libcxx/include/experimental/__simd/reference.h
+++ b/libcxx/include/experimental/__simd/reference.h
@@ -11,6 +11,7 @@
 #define _LIBCPP_EXPERIMENTAL___SIMD_REFERENCE_H
 
 #include <__type_traits/is_assignable.h>
+#include <__utility/forward.h>
 #include <experimental/__simd/utility.h>
 
 #if _LIBCPP_STD_VER >= 17 && defined(_LIBCPP_ENABLE_EXPERIMENTAL)
diff --git a/libcxx/include/experimental/__simd/simd.h b/libcxx/include/experimental/__simd/simd.h
index 29a566608603d65..fbbbaeb4e0e0eaf 100644
--- a/libcxx/include/experimental/__simd/simd.h
+++ b/libcxx/include/experimental/__simd/simd.h
@@ -11,6 +11,7 @@
 #define _LIBCPP_EXPERIMENTAL___SIMD_SIMD_H
 
 #include <__type_traits/remove_cvref.h>
+#include <experimental/__config>
 #include <experimental/__simd/abi_tag.h>
 #include <experimental/__simd/declaration.h>
 #include <experimental/__simd/reference.h>
@@ -42,6 +43,11 @@ class simd {
 
   _LIBCPP_HIDE_FROM_ABI simd() noexcept = default;
 
+  // load constructor: initializes the elements from the array at __data via _Impl::__load. The _Flags (alignment
+  // tag) argument is currently not used.
+  template <class _Up, class _Flags>
+  _LIBCPP_HIDE_FROM_ABI simd(const _Up* __data, _Flags) noexcept : __s_(_Impl::__load(__data)) {}
+
   // broadcast constructor
   template <class _Up, enable_if_t<__can_broadcast_v<value_type, __remove_cvref_t<_Up>>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI simd(_Up&& __v) noexcept : __s_(_Impl::__broadcast(static_cast<value_type>(__v))) {}
@@ -54,6 +60,14 @@ class simd {
   // scalar access [simd.subscr]
   _LIBCPP_HIDE_FROM_ABI reference operator[](size_t __i) noexcept { return reference(__s_, __i); }
   _LIBCPP_HIDE_FROM_ABI value_type operator[](size_t __i) const noexcept { return __s_.__get(__i); }
+
+  // Element-wise equality: element __i of the result is __lhs[__i] == __rhs[__i].
+  friend _LIBCPP_HIDE_FROM_ABI mask_type operator==(const simd& __lhs, const simd& __rhs) noexcept {
+    mask_type __result;
+    for (size_t __i = 0; __i != size(); ++__i)
+      __result[__i] = __lhs[__i] == __rhs[__i];
+    return __result;
+  }
 };
 
 template <class _Tp>
diff --git a/libcxx/include/experimental/__simd/simd_mask.h b/libcxx/include/experimental/__simd/simd_mask.h
index 2e47b678913f851..19e488b44d2b6ae 100644
--- a/libcxx/include/experimental/__simd/simd_mask.h
+++ b/libcxx/include/experimental/__simd/simd_mask.h
@@ -10,6 +10,7 @@
 #ifndef _LIBCPP_EXPERIMENTAL___SIMD_SIMD_MASK_H
 #define _LIBCPP_EXPERIMENTAL___SIMD_SIMD_MASK_H
 
+#include <__utility/unreachable.h>
 #include <experimental/__simd/abi_tag.h>
 #include <experimental/__simd/declaration.h>
 #include <experimental/__simd/reference.h>
@@ -46,6 +47,12 @@ class simd_mask {
   // scalar access [simd.mask.subscr]
   _LIBCPP_HIDE_FROM_ABI reference operator[](size_t __i) noexcept { return reference(__s_, __i); }
   _LIBCPP_HIDE_FROM_ABI value_type operator[](size_t __i) const noexcept { return __s_.__get(__i); }
+
+private:
+  // all_of() reads __s_ directly. This friend is the first declaration of all_of, so it carries the same
+  // _LIBCPP_HIDE_FROM_ABI as the definition (Clang rejects an abi_tag that first appears on a redeclaration).
+  template <class _Tp2, class _Abi2>
+  friend _LIBCPP_HIDE_FROM_ABI bool all_of(const simd_mask<_Tp2, _Abi2>&) noexcept;
 };
 
 template <class _Tp>
@@ -54,6 +61,24 @@ using native_simd_mask = simd_mask<_Tp, simd_abi::native<_Tp>>;
 template <class _Tp, int _Np>
 using fixed_size_simd_mask = simd_mask<_Tp, simd_abi::fixed_size<_Np>>;
 
+// Returns true if every element of __mask is set.
+template <class _Tp, class _Abi>
+_LIBCPP_HIDE_FROM_ABI bool all_of(const simd_mask<_Tp, _Abi>& __mask) noexcept {
+  return __mask_operations<_Tp, _Abi>::all_of(__mask.__s_);
+}
+
+// Returns the index of the first set element of __mask.
+// Precondition: at least one element of __mask is set; otherwise the behavior is undefined.
+template <class _Tp, class _Abi>
+_LIBCPP_HIDE_FROM_ABI int find_first_set(const simd_mask<_Tp, _Abi>& __mask) noexcept {
+  for (int __i = 0; __i != static_cast<int>(__mask.size()); ++__i) {
+    if (__mask[__i])
+      return __i;
+  }
+  // std::unreachable() is only declared in C++23; the experimental simd headers are also used in C++17/20.
+  std::__libcpp_unreachable();
+}
+
 } // namespace parallelism_v2
 _LIBCPP_END_NAMESPACE_EXPERIMENTAL
 
diff --git a/libcxx/include/experimental/__simd/utility.h b/libcxx/include/experimental/__simd/utility.h
index 847d006629c8d3b..c9d21aaf2ef2787 100644
--- a/libcxx/include/experimental/__simd/utility.h
+++ b/libcxx/include/experimental/__simd/utility.h
@@ -21,6 +21,7 @@
 #include <__utility/declval.h>
 #include <__utility/integer_sequence.h>
 #include <cstdint>
+#include <experimental/__config>
 #include <limits>
 
 _LIBCPP_PUSH_MACROS
diff --git a/libcxx/include/experimental/__simd/vec_ext.h b/libcxx/include/experimental/__simd/vec_ext.h
index 4b23bfe384477ed..ab414c136df8b7c 100644
--- a/libcxx/include/experimental/__simd/vec_ext.h
+++ b/libcxx/include/experimental/__simd/vec_ext.h
@@ -16,6 +16,10 @@
 #include <experimental/__simd/internal_declaration.h>
 #include <experimental/__simd/utility.h>
 
+#if __has_include(<immintrin.h>)
+#  include <immintrin.h>
+#endif
+
 #if _LIBCPP_STD_VER >= 17 && defined(_LIBCPP_ENABLE_EXPERIMENTAL)
 
 _LIBCPP_BEGIN_NAMESPACE_EXPERIMENTAL
@@ -67,6 +71,16 @@ struct __simd_operations<_Tp, simd_abi::__vec_ext<_Np>> {
   static _LIBCPP_HIDE_FROM_ABI _SimdStorage __generate(_Generator&& __g) noexcept {
     return __generate_init(std::forward<_Generator>(__g), std::make_index_sequence<_Np>());
   }
+
+  // Reads __data[0] .. __data[_Np - 1] into the lanes of a new storage object, one __set per lane, so __data needs
+  // no alignment beyond that of _Up.
+  template <class _Up>
+  static _LIBCPP_HIDE_FROM_ABI _SimdStorage __load(const _Up* __data) noexcept {
+    _SimdStorage __result;
+    for (size_t __i = 0; __i != _Np; ++__i)
+      __result.__set(__i, __data[__i]);
+    return __result;
+  }
 };
 
 template <class _Tp, int _Np>
@@ -81,6 +95,30 @@ struct __mask_operations<_Tp, simd_abi::__vec_ext<_Np>> {
     }
     return __result;
   }
+
+  // Returns true if every lane of __mask is set. With AVX2/SSE2, masks whose lanes span 32/16 bytes in total are
+  // checked with one byte-wise movemask, which assumes a set lane has the top bit of each of its bytes set
+  // (NOTE(review): confirm for this mask storage); other sizes are checked lane by lane.
+  inline static _LIBCPP_HIDE_FROM_ABI bool all_of(_MaskStorage __mask) noexcept {
+    [[maybe_unused]] constexpr auto __vec_size = sizeof(_Tp) * _Np;
+#  ifdef __AVX2__
+    if constexpr (__vec_size == 32) {
+      return static_cast<unsigned>(_mm256_movemask_epi8((__m256i)__mask.__data)) == 0xffffffffU;
+    } else
+#  endif
+#  ifdef __SSE2__
+    if constexpr (__vec_size == 16) {
+      return _mm_movemask_epi8((__m128i)__mask.__data) == 0xffff;
+    } else
+#  endif
+    {
+      for (int __i = 0; __i != _Np; ++__i) {
+        if (!__mask.__get(__i))
+          return false;
+      }
+      return true;
+    }
+  }
 };
 
 } // namespace parallelism_v2
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index cc588c095ccfb2f..e5f481bc6cfae09 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -23,71 +23,67 @@
 #include "test_macros.h"
 #include "test_iterators.h"
 
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
-    int ia[] = {1, 3, 6, 7};
-    int ib[] = {1, 3};
-    int ic[] = {1, 3, 5, 7};
-    typedef cpp17_input_iterator<int*>         II;
-    typedef bidirectional_iterator<int*> BI;
+// Calls std::mismatch on raw pointers, which lets the library use its vectorized implementation when enabled.
+// N covers several unrolled vector iterations plus a scalar tail for vectors of up to 64 bytes. Equal ranges of
+// every length and a single difference at every position are checked, with mutable and with const pointers.
+template <class T>
+void test_pointers() {
+  const int N = 531;
+  T lhs[N];
+  T rhs[N];
+  for (int i = 0; i != N; ++i) {
+    lhs[i] = T(1);
+    rhs[i] = T(1);
+  }
+
+  for (int size = 0; size <= N; ++size)
+    assert(std::mismatch(lhs, lhs + size, rhs) == std::make_pair(lhs + size, rhs + size));
+
+  const T* const clhs = lhs;
+  const T* const crhs = rhs;
+  for (int i = 0; i != N; ++i) {
+    rhs[i] = T(2);
+    assert(std::mismatch(lhs, lhs + N, rhs) == std::make_pair(lhs + i, rhs + i));
+    assert(std::mismatch(clhs, clhs + N, crhs) == std::make_pair(clhs + i, crhs + i));
+    rhs[i] = T(1);
+  }
+}
+
+TEST_CONSTEXPR_CXX20 bool test() {
+  int ia[]          = {0, 1, 2, 2, 0, 1, 2, 3};
+  const unsigned sa = sizeof(ia) / sizeof(ia[0]);
+  int ib[]          = {0, 1, 2, 3, 0, 1, 2, 3};
+  const unsigned sb = sizeof(ib) / sizeof(ib[0]);
+  ((void)sb); // unused in C++11
 
-    auto p1 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic));
-    if (p1.first != ia+2 || p1.second != ic+2)
-        return false;
+  typedef cpp17_input_iterator<const int*> II;
+  typedef random_access_iterator<const int*> RAI;
 
-    auto p2 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic), std::end(ic));
-    if (p2.first != ia+2 || p2.second != ic+2)
-        return false;
+  assert(std::mismatch(II(ia), II(ia + sa), II(ib)) == (std::pair<II, II>(II(ia + 3), II(ib + 3))));
 
-    auto p3 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic));
-    if (p3.first != ib+2 || p3.second != ic+2)
-        return false;
-
-    auto p4 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic), std::end(ic));
-    if (p4.first != ib+2 || p4.second != ic+2)
-        return false;
-
-    auto p5 = std::mismatch(II(std::begin(ib)), II(std::end(ib)), II(std::begin(ic)));
-    if (p5.first != II(ib+2) || p5.second != II(ic+2))
-        return false;
-    auto p6 = std::mismatch(BI(std::begin(ib)), BI(std::end(ib)), BI(std::begin(ic)), BI(std::end(ic)));
-    if (p6.first != BI(ib+2) || p6.second != BI(ic+2))
-        return false;
-
-    return true;
-    }
-#endif
-
-int main(int, char**)
-{
-    int ia[] = {0, 1, 2, 2, 0, 1, 2, 3};
-    const unsigned sa = sizeof(ia)/sizeof(ia[0]);
-    int ib[] = {0, 1, 2, 3, 0, 1, 2, 3};
-    const unsigned sb = sizeof(ib)/sizeof(ib[0]); ((void)sb); // unused in C++11
-
-    typedef cpp17_input_iterator<const int*> II;
-    typedef random_access_iterator<const int*>  RAI;
-
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib))
-            == (std::pair<II, II>(II(ia+3), II(ib+3))));
-
-    assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib))
-            == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
+  assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib)) == (std::pair<RAI, RAI>(RAI(ia + 3), RAI(ib + 3))));
 
 #if TEST_STD_VER > 11 // We have the four iteration version
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib+sb))
-            == (std::pair<II, II>(II(ia+3), II(ib+3))));
-
-    assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), RAI(ib+sb))
-            == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
+  assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib + sb)) == (std::pair<II, II>(II(ia + 3), II(ib + 3))));
 
+  assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), RAI(ib + sb)) ==
+         (std::pair<RAI, RAI>(RAI(ia + 3), RAI(ib + 3))));
 
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib+2))
-            == (std::pair<II, II>(II(ia+2), II(ib+2))));
+  assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib + 2)) == (std::pair<II, II>(II(ia + 2), II(ib + 2))));
 #endif
 
-#if TEST_STD_VER > 17
-    static_assert(test_constexpr());
+  return true;
+}
+
+int main(int, char**) {
+  test();
+  test_pointers<char>();
+  test_pointers<short>();
+  test_pointers<int>();
+  test_pointers<long>();
+  test_pointers<double>();
+#if TEST_STD_VER >= 20
+  static_assert(test());
 #endif
 
   return 0;



More information about the libcxx-commits mailing list