[libcxx-commits] [libcxx] [libc++] Vectorize mismatch (PR #73255)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Sat Mar 23 05:48:24 PDT 2024


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/73255

>From f5fe730d87162ac051c280f54b5450757c13937c Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 19 Feb 2024 13:20:58 +0100
Subject: [PATCH] [libc++] Vectorize mismatch

---
 libcxx/benchmarks/CMakeLists.txt              |   1 +
 .../benchmarks/algorithms/mismatch.bench.cpp  |  31 +++
 libcxx/docs/ReleaseNotes/19.rst               |   2 +
 libcxx/include/CMakeLists.txt                 |   1 +
 libcxx/include/__algorithm/mismatch.h         |  82 ++++++-
 libcxx/include/__algorithm/simd_utils.h       | 125 ++++++++++
 libcxx/include/__bit/bit_cast.h               |   9 +
 libcxx/include/__bit/countr.h                 |  13 +-
 libcxx/include/libcxx.imp                     |   1 +
 libcxx/include/module.modulemap               |   6 +-
 .../mismatch/mismatch.pass.cpp                | 214 +++++++++++++-----
 .../mismatch/mismatch_pred.pass.cpp           | 119 ----------
 12 files changed, 415 insertions(+), 189 deletions(-)
 create mode 100644 libcxx/benchmarks/algorithms/mismatch.bench.cpp
 create mode 100644 libcxx/include/__algorithm/simd_utils.h
 delete mode 100644 libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp

diff --git a/libcxx/benchmarks/CMakeLists.txt b/libcxx/benchmarks/CMakeLists.txt
index 3dec6faea13a0c..387e013afeb6c4 100644
--- a/libcxx/benchmarks/CMakeLists.txt
+++ b/libcxx/benchmarks/CMakeLists.txt
@@ -183,6 +183,7 @@ set(BENCHMARK_TESTS
     algorithms/make_heap_then_sort_heap.bench.cpp
     algorithms/min.bench.cpp
     algorithms/min_max_element.bench.cpp
+    algorithms/mismatch.bench.cpp
     algorithms/pop_heap.bench.cpp
     algorithms/pstl.stable_sort.bench.cpp
     algorithms/push_heap.bench.cpp
diff --git a/libcxx/benchmarks/algorithms/mismatch.bench.cpp b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
new file mode 100644
index 00000000000000..9274932a764c55
--- /dev/null
+++ b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
@@ -0,0 +1,31 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <benchmark/benchmark.h>
+#include <random>
+
+// TODO: Look into benchmarking aligned and unaligned memory explicitly
+// (currently things happen to be aligned because they are malloced that way)
+template <class T>
+static void bm_mismatch(benchmark::State& state) {
+  std::vector<T> vec1(state.range(), '1');
+  std::vector<T> vec2(state.range(), '1');
+  std::mt19937_64 rng(std::random_device{}());
+
+  vec1.back() = '2';
+  for (auto _ : state) {
+    benchmark::DoNotOptimize(vec1);
+    benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
+  }
+}
+BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+
+BENCHMARK_MAIN();
diff --git a/libcxx/docs/ReleaseNotes/19.rst b/libcxx/docs/ReleaseNotes/19.rst
index c70ae477fafc1d..4e06a1dc3040d1 100644
--- a/libcxx/docs/ReleaseNotes/19.rst
+++ b/libcxx/docs/ReleaseNotes/19.rst
@@ -51,6 +51,8 @@ Improvements and New Features
 - The performance of growing ``std::vector`` has been improved for trivially relocatable types.
 - The performance of ``ranges::fill`` and ``ranges::fill_n`` has been improved for ``vector<bool>::iterator``\s,
   resulting in a performance increase of up to 1400x.
+- The ``std::mismatch`` algorithm has been optimized for integral types, which can lead to up to 40x performance
+  improvements.
 
 Deprecations and Removals
 -------------------------
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 6ed8d21d98a15a..982b85e4e2d62f 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -217,6 +217,7 @@ set(files
   __algorithm/shift_right.h
   __algorithm/shuffle.h
   __algorithm/sift_down.h
+  __algorithm/simd_utils.h
   __algorithm/sort.h
   __algorithm/sort_heap.h
   __algorithm/stable_partition.h
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index d345b6048a7e9b..4eb693a1f2e9d8 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -11,23 +11,93 @@
 #define _LIBCPP___ALGORITHM_MISMATCH_H
 
 #include <__algorithm/comp.h>
+#include <__algorithm/simd_utils.h>
+#include <__algorithm/unwrap_iter.h>
 #include <__config>
-#include <__iterator/iterator_traits.h>
+#include <__functional/identity.h>
+#include <__type_traits/invoke.h>
+#include <__type_traits/is_constant_evaluated.h>
+#include <__type_traits/is_equality_comparable.h>
+#include <__type_traits/operation_traits.h>
+#include <__utility/move.h>
 #include <__utility/pair.h>
+#include <__utility/unreachable.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
 #endif
 
+_LIBCPP_PUSH_MACROS
+#include <__undef_macros>
+
 _LIBCPP_BEGIN_NAMESPACE_STD
 
+template <class _Iter1, class _Sent1, class _Iter2, class _Pred, class _Proj1, class _Proj2>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter1, _Iter2>
+__mismatch_loop(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+  while (__first1 != __last1) {
+    if (!std::__invoke(__pred, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
+      break;
+    ++__first1;
+    ++__first2;
+  }
+  return std::make_pair(std::move(__first1), std::move(__first2));
+}
+
+template <class _Iter1, class _Sent1, class _Iter2, class _Pred, class _Proj1, class _Proj2>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter1, _Iter2>
+__mismatch(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+  return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+}
+
+#if _LIBCPP_VECTORIZE_ALGORITHMS
+
+template <class _Tp,
+          class _Pred,
+          class _Proj1,
+          class _Proj2,
+          __enable_if_t<is_integral<_Tp>::value && __desugars_to<__equal_tag, _Pred, _Tp, _Tp>::value &&
+                            __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
+                        int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+  constexpr size_t __unroll_count = 4;
+  constexpr size_t __vec_size     = __native_vector_size<_Tp>;
+  using __vec                     = __simd_vector<_Tp, __vec_size>;
+  if (!__libcpp_is_constant_evaluated()) {
+    while (static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size) [[__unlikely__]] {
+      __vec __lhs[__unroll_count];
+      __vec __rhs[__unroll_count];
+
+      for (size_t __i = 0; __i != __unroll_count; ++__i) {
+        __lhs[__i] = std::__load_vector<__vec>(__first1 + __i * __vec_size);
+        __rhs[__i] = std::__load_vector<__vec>(__first2 + __i * __vec_size);
+      }
+
+      for (size_t __i = 0; __i != __unroll_count; ++__i) {
+        if (auto __cmp_res = __lhs[__i] == __rhs[__i]; !std::__all_of(__cmp_res)) {
+          auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res);
+          return {__first1 + __offset, __first2 + __offset};
+        }
+      }
+
+      __first1 += __unroll_count * __vec_size;
+      __first2 += __unroll_count * __vec_size;
+    }
+  }
+  // TODO: Consider vectorizing the tail
+  return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+}
+
+#endif // _LIBCPP_VECTORIZE_ALGORITHMS
+
 template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
 _LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
 mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate __pred) {
-  for (; __first1 != __last1; ++__first1, (void)++__first2)
-    if (!__pred(*__first1, *__first2))
-      break;
-  return pair<_InputIterator1, _InputIterator2>(__first1, __first2);
+  __identity __proj;
+  auto __res = std::__mismatch(
+      std::__unwrap_iter(__first1), std::__unwrap_iter(__last1), std::__unwrap_iter(__first2), __pred, __proj, __proj);
+  return std::make_pair(std::__rewrap_iter(__first1, __res.first), std::__rewrap_iter(__first2, __res.second));
 }
 
 template <class _InputIterator1, class _InputIterator2>
@@ -59,4 +129,6 @@ mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __fi
 
 _LIBCPP_END_NAMESPACE_STD
 
+_LIBCPP_POP_MACROS
+
 #endif // _LIBCPP___ALGORITHM_MISMATCH_H
diff --git a/libcxx/include/__algorithm/simd_utils.h b/libcxx/include/__algorithm/simd_utils.h
new file mode 100644
index 00000000000000..1eb21cbb9b32dc
--- /dev/null
+++ b/libcxx/include/__algorithm/simd_utils.h
@@ -0,0 +1,125 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_SIMD_UTILS_H
+#define _LIBCPP___ALGORITHM_SIMD_UTILS_H
+
+#include <__bit/bit_cast.h>
+#include <__bit/countr.h>
+#include <__config>
+#include <__type_traits/is_arithmetic.h>
+#include <__type_traits/is_same.h>
+#include <__utility/integer_sequence.h>
+#include <cstddef>
+#include <cstdint>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+// TODO: Find out how altivec changes things and allow vectorizations there too.
+#if _LIBCPP_STD_VER >= 14 && defined(_LIBCPP_CLANG_VER) && _LIBCPP_CLANG_VER >= 1700 && !defined(__ALTIVEC__)
+#  define _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS 1
+#else
+#  define _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS 0
+#endif
+
+#if _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS && !defined(__OPTIMIZE_SIZE__)
+#  define _LIBCPP_VECTORIZE_ALGORITHMS 1
+#else
+#  define _LIBCPP_VECTORIZE_ALGORITHMS 0
+#endif
+
+#if _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+// This isn't specialized for 64 byte vectors on purpose. They have the potential to significantly reduce performance
+// in mixed simd/non-simd workloads and don't provide any performance improvement for currently vectorized algorithms
+// as far as benchmarks are concerned.
+#  if defined(__AVX__)
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 32 / sizeof(_Tp);
+#  elif defined(__SSE__) || defined(__ARM_NEON__)
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 16 / sizeof(_Tp);
+#  elif defined(__MMX__)
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 8 / sizeof(_Tp);
+#  else
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 1;
+#  endif
+
+template <class _ArithmeticT, size_t _Np>
+using __simd_vector __attribute__((__ext_vector_type__(_Np))) = _ArithmeticT;
+
+template <class _VecT>
+inline constexpr size_t __simd_vector_size_v = []<bool _False = false>() -> size_t {
+  static_assert(_False, "Not a vector!");
+}();
+
+template <class _Tp, size_t _Np>
+inline constexpr size_t __simd_vector_size_v<__simd_vector<_Tp, _Np>> = _Np;
+
+template <class _Tp, size_t _Np>
+_LIBCPP_HIDE_FROM_ABI _Tp __simd_vector_underlying_type_impl(__simd_vector<_Tp, _Np>) {
+  return _Tp{};
+}
+
+template <class _VecT>
+using __simd_vector_underlying_type_t = decltype(std::__simd_vector_underlying_type_impl(_VecT{}));
+
+// This isn't inlined without always_inline when loading chars.
+template <class _VecT, class _Tp>
+_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const _Tp* __ptr) noexcept {
+  return [=]<size_t... _Indices>(index_sequence<_Indices...>) _LIBCPP_ALWAYS_INLINE noexcept {
+    return _VecT{__ptr[_Indices]...};
+  }(make_index_sequence<__simd_vector_size_v<_VecT>>{});
+}
+
+template <class _Tp, size_t _Np>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI bool __all_of(__simd_vector<_Tp, _Np> __vec) noexcept {
+  return __builtin_reduce_and(__builtin_convertvector(__vec, __simd_vector<bool, _Np>));
+}
+
+// This has MSan disabled due to https://github.com/llvm/llvm-project/issues/85876
+template <class _Tp, size_t _Np>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_NO_SANITIZE("memory") size_t
+__find_first_set(__simd_vector<_Tp, _Np> __vec) noexcept {
+  using __mask_vec = __simd_vector<bool, _Np>;
+
+  auto __impl = [&]<class _MaskT>(_MaskT) noexcept {
+    return std::__countr_zero(std::__bit_cast<_MaskT>(__builtin_convertvector(__vec, __mask_vec)));
+  };
+
+  if constexpr (sizeof(__mask_vec) == sizeof(uint8_t)) {
+    return __impl(uint8_t{});
+  } else if constexpr (sizeof(__mask_vec) == sizeof(uint16_t)) {
+    return __impl(uint16_t{});
+  } else if constexpr (sizeof(__mask_vec) == sizeof(uint32_t)) {
+    return __impl(uint32_t{});
+  } else if constexpr (sizeof(__mask_vec) == sizeof(uint64_t)) {
+    return __impl(uint64_t{});
+  } else {
+    static_assert(sizeof(__mask_vec) == 0, "unexpected required size for mask integer type");
+    return 0;
+  }
+}
+
+template <class _Tp, size_t _Np>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI size_t __find_first_not_set(__simd_vector<_Tp, _Np> __vec) noexcept {
+  return std::__find_first_set(~__vec);
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS (requires C++14 or later and Clang 17 or later, and is disabled when
+       // targeting AltiVec)
+
+#endif // _LIBCPP___ALGORITHM_SIMD_UTILS_H
diff --git a/libcxx/include/__bit/bit_cast.h b/libcxx/include/__bit/bit_cast.h
index f20b39ae748b10..6298810f373303 100644
--- a/libcxx/include/__bit/bit_cast.h
+++ b/libcxx/include/__bit/bit_cast.h
@@ -19,6 +19,15 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
+#ifndef _LIBCPP_CXX03_LANG
+
+template <class _ToType, class _FromType>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI constexpr _ToType __bit_cast(const _FromType& __from) noexcept {
+  return __builtin_bit_cast(_ToType, __from);
+}
+
+#endif // _LIBCPP_CXX03_LANG
+
 #if _LIBCPP_STD_VER >= 20
 
 template <class _ToType, class _FromType>
diff --git a/libcxx/include/__bit/countr.h b/libcxx/include/__bit/countr.h
index 0cc679f87a99d9..b6b3ac52ca4e47 100644
--- a/libcxx/include/__bit/countr.h
+++ b/libcxx/include/__bit/countr.h
@@ -35,10 +35,8 @@ _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR int __libcpp_ct
   return __builtin_ctzll(__x);
 }
 
-#if _LIBCPP_STD_VER >= 20
-
-template <__libcpp_unsigned_integer _Tp>
-_LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_zero(_Tp __t) noexcept {
+template <class _Tp>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 int __countr_zero(_Tp __t) _NOEXCEPT {
   if (__t == 0)
     return numeric_limits<_Tp>::digits;
 
@@ -59,6 +57,13 @@ _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_zero(_Tp __t) n
   }
 }
 
+#if _LIBCPP_STD_VER >= 20
+
+template <__libcpp_unsigned_integer _Tp>
+_LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_zero(_Tp __t) noexcept {
+  return std::__countr_zero(__t);
+}
+
 template <__libcpp_unsigned_integer _Tp>
 _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_one(_Tp __t) noexcept {
   return __t != numeric_limits<_Tp>::max() ? std::countr_zero(static_cast<_Tp>(~__t)) : numeric_limits<_Tp>::digits;
diff --git a/libcxx/include/libcxx.imp b/libcxx/include/libcxx.imp
index 77b7befd44f56c..56ea58262828a0 100644
--- a/libcxx/include/libcxx.imp
+++ b/libcxx/include/libcxx.imp
@@ -217,6 +217,7 @@
   { include: [ "<__algorithm/shift_right.h>", "private", "<algorithm>", "public" ] },
   { include: [ "<__algorithm/shuffle.h>", "private", "<algorithm>", "public" ] },
   { include: [ "<__algorithm/sift_down.h>", "private", "<algorithm>", "public" ] },
+  { include: [ "<__algorithm/simd_utils.h>", "private", "<algorithm>", "public" ] },
   { include: [ "<__algorithm/sort.h>", "private", "<algorithm>", "public" ] },
   { include: [ "<__algorithm/sort_heap.h>", "private", "<algorithm>", "public" ] },
   { include: [ "<__algorithm/stable_partition.h>", "private", "<algorithm>", "public" ] },
diff --git a/libcxx/include/module.modulemap b/libcxx/include/module.modulemap
index f36a47cef00977..03d18775631ed6 100644
--- a/libcxx/include/module.modulemap
+++ b/libcxx/include/module.modulemap
@@ -697,7 +697,10 @@ module std_private_algorithm_minmax                                      [system
   export *
 }
 module std_private_algorithm_minmax_element                              [system] { header "__algorithm/minmax_element.h" }
-module std_private_algorithm_mismatch                                    [system] { header "__algorithm/mismatch.h" }
+module std_private_algorithm_mismatch                                    [system] {
+  header "__algorithm/mismatch.h"
+  export std_private_algorithm_simd_utils
+}
 module std_private_algorithm_move                                        [system] { header "__algorithm/move.h" }
 module std_private_algorithm_move_backward                               [system] { header "__algorithm/move_backward.h" }
 module std_private_algorithm_next_permutation                            [system] { header "__algorithm/next_permutation.h" }
@@ -1048,6 +1051,7 @@ module std_private_algorithm_sort                                        [system
   header "__algorithm/sort.h"
   export std_private_debug_utils_strict_weak_ordering_check
 }
+module std_private_algorithm_simd_utils                                  [system] { header "__algorithm/simd_utils.h" }
 module std_private_algorithm_sort_heap                                   [system] { header "__algorithm/sort_heap.h" }
 module std_private_algorithm_stable_partition                            [system] { header "__algorithm/stable_partition.h" }
 module std_private_algorithm_stable_sort                                 [system] { header "__algorithm/stable_sort.h" }
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index cc588c095ccfb2..e7f3994d977dcd 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -16,79 +16,173 @@
 // template<InputIterator Iter1, InputIterator Iter2Pred>
 //   constexpr pair<Iter1, Iter2>   // constexpr after c++17
 //   mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2); // C++14
+//
+// template<InputIterator Iter1, InputIterator Iter2,
+//          Predicate<auto, Iter1::value_type, Iter2::value_type> Pred>
+//   requires CopyConstructible<Pred>
+//   constexpr pair<Iter1, Iter2>   // constexpr after c++17
+//   mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Pred pred);
+//
+// template<InputIterator Iter1, InputIterator Iter2, Predicate Pred>
+//   constexpr pair<Iter1, Iter2>   // constexpr after c++17
+//   mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2, Pred pred); // C++14
+
+// ADDITIONAL_COMPILE_FLAGS(has-fconstexpr-steps): -fconstexpr-steps=50000000
+// ADDITIONAL_COMPILE_FLAGS(has-fconstexpr-ops-limit): -fconstexpr-ops-limit=100000000
 
 #include <algorithm>
+#include <array>
 #include <cassert>
+#include <vector>
 
 #include "test_macros.h"
 #include "test_iterators.h"
-
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
-    int ia[] = {1, 3, 6, 7};
-    int ib[] = {1, 3};
-    int ic[] = {1, 3, 5, 7};
-    typedef cpp17_input_iterator<int*>         II;
-    typedef bidirectional_iterator<int*> BI;
-
-    auto p1 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic));
-    if (p1.first != ia+2 || p1.second != ic+2)
-        return false;
-
-    auto p2 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic), std::end(ic));
-    if (p2.first != ia+2 || p2.second != ic+2)
-        return false;
-
-    auto p3 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic));
-    if (p3.first != ib+2 || p3.second != ic+2)
-        return false;
-
-    auto p4 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic), std::end(ic));
-    if (p4.first != ib+2 || p4.second != ic+2)
-        return false;
-
-    auto p5 = std::mismatch(II(std::begin(ib)), II(std::end(ib)), II(std::begin(ic)));
-    if (p5.first != II(ib+2) || p5.second != II(ic+2))
-        return false;
-    auto p6 = std::mismatch(BI(std::begin(ib)), BI(std::end(ib)), BI(std::begin(ic)), BI(std::end(ic)));
-    if (p6.first != BI(ib+2) || p6.second != BI(ic+2))
-        return false;
-
-    return true;
-    }
+#include "type_algorithms.h"
+
+template <class Iter, class Container1, class Container2>
+TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
+  if (lhs.size() == rhs.size()) {
+    assert(std::mismatch(Iter(lhs.data()), Iter(lhs.data() + lhs.size()), Iter(rhs.data())) ==
+           std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
+
+    assert(std::mismatch(Iter(lhs.data()),
+                         Iter(lhs.data() + lhs.size()),
+                         Iter(rhs.data()),
+                         std::equal_to<typename Container1::value_type>()) ==
+           std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
+  }
+
+#if TEST_STD_VER >= 14
+  assert(
+      std::mismatch(Iter(lhs.data()), Iter(lhs.data() + lhs.size()), Iter(rhs.data()), Iter(rhs.data() + rhs.size())) ==
+      std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
+
+  assert(std::mismatch(Iter(lhs.data()),
+                       Iter(lhs.data() + lhs.size()),
+                       Iter(rhs.data()),
+                       Iter(rhs.data() + rhs.size()),
+                       std::equal_to<typename Container1::value_type>()) ==
+         std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
 #endif
+}
 
-int main(int, char**)
-{
-    int ia[] = {0, 1, 2, 2, 0, 1, 2, 3};
-    const unsigned sa = sizeof(ia)/sizeof(ia[0]);
-    int ib[] = {0, 1, 2, 3, 0, 1, 2, 3};
-    const unsigned sb = sizeof(ib)/sizeof(ib[0]); ((void)sb); // unused in C++11
-
-    typedef cpp17_input_iterator<const int*> II;
-    typedef random_access_iterator<const int*>  RAI;
-
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib))
-            == (std::pair<II, II>(II(ia+3), II(ib+3))));
-
-    assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib))
-            == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
-
-#if TEST_STD_VER > 11 // We have the four iteration version
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib+sb))
-            == (std::pair<II, II>(II(ia+3), II(ib+3))));
+struct NonTrivial {
+  int i_;
+
+  TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
+  TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
+
+  TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
+};
+
+struct ModTwoComp {
+  TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) { return lhs % 2 == rhs % 2; }
+};
+
+template <class Iter>
+TEST_CONSTEXPR_CXX20 bool test() {
+  { // empty ranges
+    std::array<int, 0> lhs = {};
+    std::array<int, 0> rhs = {};
+    check<Iter>(lhs, rhs, 0);
+  }
+
+  { // same range without mismatch
+    std::array<int, 8> lhs = {0, 1, 2, 3, 0, 1, 2, 3};
+    std::array<int, 8> rhs = {0, 1, 2, 3, 0, 1, 2, 3};
+    check<Iter>(lhs, rhs, 8);
+  }
+
+  { // same range with mismatch
+    std::array<int, 8> lhs = {0, 1, 2, 2, 0, 1, 2, 3};
+    std::array<int, 8> rhs = {0, 1, 2, 3, 0, 1, 2, 3};
+    check<Iter>(lhs, rhs, 3);
+  }
+
+  { // second range is smaller
+    std::array<int, 8> lhs = {0, 1, 2, 2, 0, 1, 2, 3};
+    std::array<int, 2> rhs = {0, 1};
+    check<Iter>(lhs, rhs, 2);
+  }
+
+  { // first range is smaller
+    std::array<int, 2> lhs = {0, 1};
+    std::array<int, 8> rhs = {0, 1, 2, 2, 0, 1, 2, 3};
+    check<Iter>(lhs, rhs, 2);
+  }
+
+  { // use a custom comparator
+    std::array<int, 4> lhs = {0, 2, 3, 4};
+    std::array<int, 4> rhs = {0, 0, 4, 4};
+    assert(std::mismatch(lhs.data(), lhs.data() + lhs.size(), rhs.data(), ModTwoComp()) ==
+           std::make_pair(lhs.data() + 2, rhs.data() + 2));
+#if TEST_STD_VER >= 14
+    assert(std::mismatch(lhs.data(), lhs.data() + lhs.size(), rhs.data(), rhs.data() + rhs.size(), ModTwoComp()) ==
+           std::make_pair(lhs.data() + 2, rhs.data() + 2));
+#endif
+  }
 
-    assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), RAI(ib+sb))
-            == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
+  return true;
+}
 
+struct Test {
+  template <class Iter>
+  TEST_CONSTEXPR_CXX20 void operator()() {
+    test<Iter>();
+  }
+};
+
+TEST_CONSTEXPR_CXX20 bool test() {
+  types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
+
+  { // use a non-integer type to also test the general case - all elements match
+    std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<NonTrivial*>(std::move(lhs), std::move(rhs), 8);
+  }
+
+  { // use a non-integer type to also test the general case - not all elements match
+    std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+    std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<NonTrivial*>(std::move(lhs), std::move(rhs), 4);
+  }
+
+  return true;
+}
 
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib+2))
-            == (std::pair<II, II>(II(ia+2), II(ib+2))));
+int main(int, char**) {
+  test();
+#if TEST_STD_VER >= 20
+  static_assert(test());
 #endif
 
-#if TEST_STD_VER > 17
-    static_assert(test_constexpr());
-#endif
+  { // check with a lot of elements to test the vectorization optimization
+    {
+      std::vector<char> lhs(256);
+      std::vector<char> rhs(256);
+      for (size_t i = 0; i != lhs.size(); ++i) {
+        lhs[i] = 1;
+        check<char*>(lhs, rhs, i);
+        lhs[i] = 0;
+        rhs[i] = 1;
+        check<char*>(lhs, rhs, i);
+        rhs[i] = 0;
+      }
+    }
+
+    {
+      std::vector<int> lhs(256);
+      std::vector<int> rhs(256);
+      for (size_t i = 0; i != lhs.size(); ++i) {
+        lhs[i] = 1;
+        check<int*>(lhs, rhs, i);
+        lhs[i] = 0;
+        rhs[i] = 1;
+        check<int*>(lhs, rhs, i);
+        rhs[i] = 0;
+      }
+    }
+  }
 
   return 0;
 }
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp
deleted file mode 100644
index bda4ec7ba5ed60..00000000000000
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-//===----------------------------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-// <algorithm>
-
-// template<InputIterator Iter1, InputIterator Iter2,
-//          Predicate<auto, Iter1::value_type, Iter2::value_type> Pred>
-//   requires CopyConstructible<Pred>
-//   constexpr pair<Iter1, Iter2>   // constexpr after c++17
-//   mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Pred pred);
-//
-// template<InputIterator Iter1, InputIterator Iter2, Predicate Pred>
-//   constexpr pair<Iter1, Iter2>   // constexpr after c++17
-//   mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2, Pred pred); // C++14
-
-#include <algorithm>
-#include <functional>
-#include <cassert>
-
-#include "test_macros.h"
-#include "test_iterators.h"
-#include "counting_predicates.h"
-
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool eq(int a, int b) { return a == b; }
-
-TEST_CONSTEXPR bool test_constexpr() {
-    int ia[] = {1, 3, 6, 7};
-    int ib[] = {1, 3};
-    int ic[] = {1, 3, 5, 7};
-    typedef cpp17_input_iterator<int*>         II;
-    typedef bidirectional_iterator<int*> BI;
-
-    auto p1 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic), eq);
-    if (p1.first != ia+2 || p1.second != ic+2)
-        return false;
-
-    auto p2 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic), std::end(ic), eq);
-    if (p2.first != ia+2 || p2.second != ic+2)
-        return false;
-
-    auto p3 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic), eq);
-    if (p3.first != ib+2 || p3.second != ic+2)
-        return false;
-
-    auto p4 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic), std::end(ic), eq);
-    if (p4.first != ib+2 || p4.second != ic+2)
-        return false;
-
-    auto p5 = std::mismatch(II(std::begin(ib)), II(std::end(ib)), II(std::begin(ic)), eq);
-    if (p5.first != II(ib+2) || p5.second != II(ic+2))
-        return false;
-    auto p6 = std::mismatch(BI(std::begin(ib)), BI(std::end(ib)), BI(std::begin(ic)), BI(std::end(ic)), eq);
-    if (p6.first != BI(ib+2) || p6.second != BI(ic+2))
-        return false;
-
-    return true;
-    }
-#endif
-
-
-#if TEST_STD_VER > 11
-#define HAS_FOUR_ITERATOR_VERSION
-#endif
-
-int main(int, char**)
-{
-    int ia[] = {0, 1, 2, 2, 0, 1, 2, 3};
-    const unsigned sa = sizeof(ia)/sizeof(ia[0]);
-    int ib[] = {0, 1, 2, 3, 0, 1, 2, 3};
-    const unsigned sb = sizeof(ib)/sizeof(ib[0]); ((void)sb); // unused in C++11
-
-    typedef cpp17_input_iterator<const int*> II;
-    typedef random_access_iterator<const int*>  RAI;
-    typedef std::equal_to<int> EQ;
-
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib), EQ())
-            == (std::pair<II, II>(II(ia+3), II(ib+3))));
-    assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), EQ())
-            == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
-
-    binary_counting_predicate<EQ, int> bcp((EQ()));
-    assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), std::ref(bcp))
-            == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
-    assert(bcp.count() > 0 && bcp.count() < sa);
-    bcp.reset();
-
-#if TEST_STD_VER >= 14
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib + sb), EQ())
-            == (std::pair<II, II>(II(ia+3), II(ib+3))));
-    assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), RAI(ib + sb), EQ())
-            == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
-
-    assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib + sb), std::ref(bcp))
-            == (std::pair<II, II>(II(ia+3), II(ib+3))));
-    assert(bcp.count() > 0 && bcp.count() < std::min(sa, sb));
-#endif
-
-    assert(std::mismatch(ia, ia + sa, ib, EQ()) ==
-           (std::pair<int*,int*>(ia+3,ib+3)));
-
-#if TEST_STD_VER >= 14
-    assert(std::mismatch(ia, ia + sa, ib, ib + sb, EQ()) ==
-           (std::pair<int*,int*>(ia+3,ib+3)));
-    assert(std::mismatch(ia, ia + sa, ib, ib + 2, EQ()) ==
-           (std::pair<int*,int*>(ia+2,ib+2)));
-#endif
-
-#if TEST_STD_VER > 17
-    static_assert(test_constexpr());
-#endif
-
-  return 0;
-}



More information about the libcxx-commits mailing list