[libcxx-commits] [libcxx] [libc++] Optimize the std::mismatch tail (PR #83440)
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Thu Feb 29 08:02:44 PST 2024
https://github.com/philnik777 created https://github.com/llvm/llvm-project/pull/83440
```
-----------------------------------------------------------------------
Benchmark old full vectors partial vector
-----------------------------------------------------------------------
bm_mismatch<char>/1 1.40 ns 1.62 ns 2.09 ns
bm_mismatch<char>/2 1.88 ns 2.10 ns 2.33 ns
bm_mismatch<char>/3 2.67 ns 2.56 ns 2.72 ns
bm_mismatch<char>/4 3.01 ns 3.20 ns 3.70 ns
bm_mismatch<char>/5 3.51 ns 3.73 ns 3.64 ns
bm_mismatch<char>/6 4.71 ns 4.85 ns 4.37 ns
bm_mismatch<char>/7 5.12 ns 5.33 ns 4.37 ns
bm_mismatch<char>/8 5.79 ns 6.02 ns 4.75 ns
bm_mismatch<char>/15 9.20 ns 10.5 ns 7.23 ns
bm_mismatch<char>/16 10.2 ns 10.1 ns 7.46 ns
bm_mismatch<char>/17 10.2 ns 10.8 ns 7.57 ns
bm_mismatch<char>/31 17.6 ns 17.1 ns 10.8 ns
bm_mismatch<char>/32 17.4 ns 1.64 ns 1.64 ns
bm_mismatch<char>/33 23.3 ns 2.10 ns 2.33 ns
bm_mismatch<char>/63 31.8 ns 16.9 ns 2.33 ns
bm_mismatch<char>/64 32.6 ns 2.10 ns 2.10 ns
bm_mismatch<char>/65 33.6 ns 2.57 ns 2.80 ns
bm_mismatch<char>/127 67.3 ns 18.1 ns 3.27 ns
bm_mismatch<char>/128 2.17 ns 2.14 ns 2.57 ns
bm_mismatch<char>/129 2.36 ns 2.80 ns 3.27 ns
bm_mismatch<char>/255 67.5 ns 19.6 ns 4.68 ns
bm_mismatch<char>/256 3.76 ns 3.71 ns 3.97 ns
bm_mismatch<char>/257 3.77 ns 4.04 ns 4.43 ns
bm_mismatch<char>/511 70.8 ns 22.1 ns 7.47 ns
bm_mismatch<char>/512 7.27 ns 7.30 ns 6.95 ns
bm_mismatch<char>/513 7.11 ns 7.05 ns 6.96 ns
bm_mismatch<char>/1023 75.9 ns 27.4 ns 13.3 ns
bm_mismatch<char>/1024 13.9 ns 13.8 ns 12.4 ns
bm_mismatch<char>/1025 13.6 ns 13.6 ns 12.8 ns
bm_mismatch<char>/2047 87.3 ns 37.5 ns 25.4 ns
bm_mismatch<char>/2048 26.8 ns 27.4 ns 24.0 ns
bm_mismatch<char>/2049 26.7 ns 27.3 ns 25.5 ns
bm_mismatch<char>/4095 112 ns 64.7 ns 48.7 ns
bm_mismatch<char>/4096 53.0 ns 54.2 ns 46.8 ns
bm_mismatch<char>/4097 52.7 ns 54.2 ns 48.4 ns
bm_mismatch<char>/8191 160 ns 118 ns 98.4 ns
bm_mismatch<char>/8192 107 ns 108 ns 96.0 ns
bm_mismatch<char>/8193 106 ns 108 ns 97.2 ns
bm_mismatch<char>/16383 283 ns 234 ns 215 ns
bm_mismatch<char>/16384 227 ns 223 ns 217 ns
bm_mismatch<char>/16385 221 ns 221 ns 215 ns
bm_mismatch<char>/32767 547 ns 499 ns 488 ns
bm_mismatch<char>/32768 495 ns 492 ns 492 ns
bm_mismatch<char>/32769 491 ns 489 ns 488 ns
bm_mismatch<char>/65535 1028 ns 979 ns 971 ns
bm_mismatch<char>/65536 976 ns 970 ns 974 ns
bm_mismatch<char>/65537 970 ns 965 ns 971 ns
bm_mismatch<char>/131071 2031 ns 1948 ns 2005 ns
bm_mismatch<char>/131072 1973 ns 1955 ns 1974 ns
bm_mismatch<char>/131073 1989 ns 1932 ns 2001 ns
bm_mismatch<char>/262143 4469 ns 4244 ns 4223 ns
bm_mismatch<char>/262144 4443 ns 4183 ns 4243 ns
bm_mismatch<char>/262145 4400 ns 4232 ns 4246 ns
bm_mismatch<char>/524287 10169 ns 9733 ns 9592 ns
bm_mismatch<char>/524288 10154 ns 9664 ns 9843 ns
bm_mismatch<char>/524289 10113 ns 9641 ns 10003 ns
bm_mismatch<short>/1 1.86 ns 2.53 ns 2.32 ns
bm_mismatch<short>/2 2.57 ns 2.77 ns 2.55 ns
bm_mismatch<short>/3 3.26 ns 3.00 ns 2.79 ns
bm_mismatch<short>/4 3.95 ns 3.39 ns 3.15 ns
bm_mismatch<short>/5 4.83 ns 3.97 ns 3.72 ns
bm_mismatch<short>/6 5.43 ns 4.34 ns 4.03 ns
bm_mismatch<short>/7 6.11 ns 4.73 ns 4.44 ns
bm_mismatch<short>/8 6.84 ns 5.02 ns 4.79 ns
bm_mismatch<short>/15 11.5 ns 7.12 ns 6.50 ns
bm_mismatch<short>/16 13.9 ns 1.87 ns 2.11 ns
bm_mismatch<short>/17 14.0 ns 3.00 ns 2.47 ns
bm_mismatch<short>/31 23.1 ns 7.87 ns 2.47 ns
bm_mismatch<short>/32 23.8 ns 2.57 ns 2.81 ns
bm_mismatch<short>/33 24.5 ns 3.70 ns 2.94 ns
bm_mismatch<short>/63 44.8 ns 9.37 ns 3.46 ns
bm_mismatch<short>/64 2.32 ns 2.57 ns 2.64 ns
bm_mismatch<short>/65 2.52 ns 3.02 ns 3.51 ns
bm_mismatch<short>/127 45.6 ns 9.97 ns 5.18 ns
bm_mismatch<short>/128 3.85 ns 3.93 ns 3.94 ns
bm_mismatch<short>/129 3.82 ns 4.20 ns 4.70 ns
bm_mismatch<short>/255 50.4 ns 12.6 ns 8.07 ns
bm_mismatch<short>/256 7.23 ns 6.91 ns 6.98 ns
bm_mismatch<short>/257 7.24 ns 7.19 ns 7.55 ns
bm_mismatch<short>/511 52.3 ns 17.8 ns 14.0 ns
bm_mismatch<short>/512 13.6 ns 13.7 ns 13.6 ns
bm_mismatch<short>/513 13.9 ns 13.8 ns 18.5 ns
bm_mismatch<short>/1023 60.9 ns 30.9 ns 26.3 ns
bm_mismatch<short>/1024 26.7 ns 27.7 ns 25.7 ns
bm_mismatch<short>/1025 27.7 ns 27.6 ns 25.3 ns
bm_mismatch<short>/2047 88.4 ns 58.0 ns 51.6 ns
bm_mismatch<short>/2048 52.8 ns 55.3 ns 50.6 ns
bm_mismatch<short>/2049 55.2 ns 54.8 ns 48.7 ns
bm_mismatch<short>/4095 153 ns 113 ns 102 ns
bm_mismatch<short>/4096 105 ns 110 ns 101 ns
bm_mismatch<short>/4097 110 ns 110 ns 99.1 ns
bm_mismatch<short>/8191 277 ns 219 ns 206 ns
bm_mismatch<short>/8192 226 ns 214 ns 250 ns
bm_mismatch<short>/8193 226 ns 207 ns 208 ns
bm_mismatch<short>/16383 519 ns 492 ns 488 ns
bm_mismatch<short>/16384 494 ns 492 ns 492 ns
bm_mismatch<short>/16385 492 ns 488 ns 489 ns
bm_mismatch<short>/32767 1007 ns 968 ns 964 ns
bm_mismatch<short>/32768 977 ns 972 ns 970 ns
bm_mismatch<short>/32769 972 ns 962 ns 967 ns
bm_mismatch<short>/65535 1978 ns 1918 ns 1956 ns
bm_mismatch<short>/65536 1940 ns 1927 ns 1970 ns
bm_mismatch<short>/65537 1937 ns 1922 ns 1959 ns
bm_mismatch<short>/131071 4524 ns 4193 ns 4304 ns
bm_mismatch<short>/131072 4445 ns 4196 ns 4306 ns
bm_mismatch<short>/131073 4452 ns 4278 ns 4311 ns
bm_mismatch<short>/262143 9801 ns 10188 ns 9634 ns
bm_mismatch<short>/262144 9738 ns 10151 ns 9651 ns
bm_mismatch<short>/262145 9716 ns 10171 ns 9715 ns
bm_mismatch<short>/524287 19944 ns 20718 ns 20044 ns
bm_mismatch<short>/524288 21139 ns 20647 ns 20008 ns
bm_mismatch<short>/524289 21162 ns 19512 ns 20068 ns
bm_mismatch<int>/1 1.40 ns 1.84 ns 1.87 ns
bm_mismatch<int>/2 1.87 ns 2.08 ns 2.09 ns
bm_mismatch<int>/3 2.36 ns 2.31 ns 2.87 ns
bm_mismatch<int>/4 3.06 ns 2.72 ns 2.95 ns
bm_mismatch<int>/5 3.66 ns 3.37 ns 3.42 ns
bm_mismatch<int>/6 4.55 ns 3.65 ns 3.73 ns
bm_mismatch<int>/7 5.03 ns 3.93 ns 3.94 ns
bm_mismatch<int>/8 5.67 ns 1.86 ns 1.87 ns
bm_mismatch<int>/15 9.89 ns 4.41 ns 2.34 ns
bm_mismatch<int>/16 10.1 ns 2.33 ns 2.34 ns
bm_mismatch<int>/17 10.2 ns 3.34 ns 2.86 ns
bm_mismatch<int>/31 17.2 ns 5.54 ns 3.28 ns
bm_mismatch<int>/32 2.16 ns 2.15 ns 2.58 ns
bm_mismatch<int>/33 2.36 ns 3.01 ns 3.28 ns
bm_mismatch<int>/63 17.7 ns 6.50 ns 4.93 ns
bm_mismatch<int>/64 3.81 ns 3.58 ns 3.90 ns
bm_mismatch<int>/65 3.74 ns 4.36 ns 4.45 ns
bm_mismatch<int>/127 19.5 ns 9.56 ns 7.74 ns
bm_mismatch<int>/128 7.30 ns 6.41 ns 6.85 ns
bm_mismatch<int>/129 7.09 ns 7.04 ns 7.06 ns
bm_mismatch<int>/255 24.7 ns 14.8 ns 13.3 ns
bm_mismatch<int>/256 14.0 ns 12.1 ns 12.3 ns
bm_mismatch<int>/257 13.8 ns 12.7 ns 12.8 ns
bm_mismatch<int>/511 34.3 ns 26.3 ns 24.8 ns
bm_mismatch<int>/512 27.6 ns 23.6 ns 23.9 ns
bm_mismatch<int>/513 27.3 ns 24.4 ns 25.1 ns
bm_mismatch<int>/1023 62.5 ns 50.9 ns 48.3 ns
bm_mismatch<int>/1024 54.4 ns 46.1 ns 46.6 ns
bm_mismatch<int>/1025 54.2 ns 48.4 ns 47.5 ns
bm_mismatch<int>/2047 116 ns 97.8 ns 94.1 ns
bm_mismatch<int>/2048 108 ns 92.6 ns 92.4 ns
bm_mismatch<int>/2049 108 ns 104 ns 94.0 ns
bm_mismatch<int>/4095 233 ns 222 ns 205 ns
bm_mismatch<int>/4096 226 ns 223 ns 225 ns
bm_mismatch<int>/4097 221 ns 219 ns 210 ns
bm_mismatch<int>/8191 499 ns 485 ns 488 ns
bm_mismatch<int>/8192 496 ns 490 ns 495 ns
bm_mismatch<int>/8193 491 ns 485 ns 488 ns
bm_mismatch<int>/16383 982 ns 962 ns 964 ns
bm_mismatch<int>/16384 974 ns 971 ns 971 ns
bm_mismatch<int>/16385 971 ns 961 ns 968 ns
bm_mismatch<int>/32767 2003 ns 1959 ns 1920 ns
bm_mismatch<int>/32768 1996 ns 1947 ns 1928 ns
bm_mismatch<int>/32769 1990 ns 1945 ns 1926 ns
bm_mismatch<int>/65535 4434 ns 4275 ns 4312 ns
bm_mismatch<int>/65536 4437 ns 4267 ns 4321 ns
bm_mismatch<int>/65537 4442 ns 4261 ns 4321 ns
bm_mismatch<int>/131071 9673 ns 9648 ns 9465 ns
bm_mismatch<int>/131072 9667 ns 9671 ns 9465 ns
bm_mismatch<int>/131073 9661 ns 9653 ns 9464 ns
bm_mismatch<int>/262143 20595 ns 19605 ns 19064 ns
bm_mismatch<int>/262144 19894 ns 19572 ns 19009 ns
bm_mismatch<int>/262145 19851 ns 19656 ns 18999 ns
bm_mismatch<int>/524287 39556 ns 39364 ns 38131 ns
bm_mismatch<int>/524288 39678 ns 39573 ns 38183 ns
bm_mismatch<int>/524289 40168 ns 39301 ns 38121 ns
```
>From 6c6f71abe66c9b15d2b2fac75f350b79c4fc5f13 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 19 Feb 2024 13:20:58 +0100
Subject: [PATCH 1/2] [libc++] Vectorize mismatch
---
libcxx/benchmarks/CMakeLists.txt | 1 +
.../benchmarks/algorithms/mismatch.bench.cpp | 31 ++++
libcxx/docs/ReleaseNotes/19.rst | 2 +
libcxx/include/CMakeLists.txt | 1 +
libcxx/include/__algorithm/mismatch.h | 81 ++++++++-
libcxx/include/__algorithm/simd_utils.h | 114 ++++++++++++
libcxx/include/__bit/bit_cast.h | 9 +
libcxx/include/__bit/countr.h | 13 +-
libcxx/include/libcxx.imp | 1 +
libcxx/include/module.modulemap.in | 6 +-
.../mismatch/mismatch.pass.cpp | 172 ++++++++++++------
.../mismatch/mismatch_pred.pass.cpp | 119 ------------
12 files changed, 362 insertions(+), 188 deletions(-)
create mode 100644 libcxx/benchmarks/algorithms/mismatch.bench.cpp
create mode 100644 libcxx/include/__algorithm/simd_utils.h
delete mode 100644 libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp
diff --git a/libcxx/benchmarks/CMakeLists.txt b/libcxx/benchmarks/CMakeLists.txt
index b436e96f178b70..16ecf9b775508a 100644
--- a/libcxx/benchmarks/CMakeLists.txt
+++ b/libcxx/benchmarks/CMakeLists.txt
@@ -182,6 +182,7 @@ set(BENCHMARK_TESTS
algorithms/make_heap_then_sort_heap.bench.cpp
algorithms/min.bench.cpp
algorithms/min_max_element.bench.cpp
+ algorithms/mismatch.bench.cpp
algorithms/pop_heap.bench.cpp
algorithms/pstl.stable_sort.bench.cpp
algorithms/push_heap.bench.cpp
diff --git a/libcxx/benchmarks/algorithms/mismatch.bench.cpp b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
new file mode 100644
index 00000000000000..9274932a764c55
--- /dev/null
+++ b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
@@ -0,0 +1,31 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <benchmark/benchmark.h>
+#include <random>
+
+// TODO: Look into benchmarking aligned and unaligned memory explicitly
+// (currently things happen to be aligned because they are malloced that way)
+template <class T>
+static void bm_mismatch(benchmark::State& state) {
+ std::vector<T> vec1(state.range(), '1');
+ std::vector<T> vec2(state.range(), '1');
+ std::mt19937_64 rng(std::random_device{}());
+
+ vec1.back() = '2';
+ for (auto _ : state) {
+ benchmark::DoNotOptimize(vec1);
+ benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
+ }
+}
+BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+
+BENCHMARK_MAIN();
diff --git a/libcxx/docs/ReleaseNotes/19.rst b/libcxx/docs/ReleaseNotes/19.rst
index 78c6bb87a5a402..8f6091503ce025 100644
--- a/libcxx/docs/ReleaseNotes/19.rst
+++ b/libcxx/docs/ReleaseNotes/19.rst
@@ -47,6 +47,8 @@ Improvements and New Features
-----------------------------
- The performance of growing ``std::vector`` has been improved for trivially relocatable types.
+- The ``std::mismatch`` algorithm has been optimized for integral types, which can lead to up to 40x performance
+ improvements.
Deprecations and Removals
-------------------------
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index cafd8c6e00d968..e0be4b2f340eed 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -217,6 +217,7 @@ set(files
__algorithm/shift_right.h
__algorithm/shuffle.h
__algorithm/sift_down.h
+ __algorithm/simd_utils.h
__algorithm/sort.h
__algorithm/sort_heap.h
__algorithm/stable_partition.h
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index d345b6048a7e9b..d0234e793cbca3 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -11,23 +11,92 @@
#define _LIBCPP___ALGORITHM_MISMATCH_H
#include <__algorithm/comp.h>
+#include <__algorithm/simd_utils.h>
+#include <__algorithm/unwrap_iter.h>
#include <__config>
-#include <__iterator/iterator_traits.h>
+#include <__functional/identity.h>
+#include <__type_traits/invoke.h>
+#include <__type_traits/is_constant_evaluated.h>
+#include <__type_traits/is_equality_comparable.h>
+#include <__type_traits/operation_traits.h>
+#include <__utility/move.h>
#include <__utility/pair.h>
+#include <__utility/unreachable.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
#endif
+_LIBCPP_PUSH_MACROS
+#include <__undef_macros>
+
_LIBCPP_BEGIN_NAMESPACE_STD
+template <class _Iter1, class _Sent1, class _Iter2, class _Pred, class _Proj1, class _Proj2>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter1, _Iter2>
+__mismatch_loop(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+ while (__first1 != __last1) {
+ if (!std::__invoke(__pred, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
+ break;
+ ++__first1;
+ ++__first2;
+ }
+ return std::make_pair(std::move(__first1), std::move(__first2));
+}
+
+template <class _Iter1, class _Sent1, class _Iter2, class _Pred, class _Proj1, class _Proj2>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter1, _Iter2>
+__mismatch(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+ return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+}
+
+#if _LIBCPP_VECTORIZE_ALGORIHTMS
+
+template <class _Tp,
+ class _Pred,
+ class _Proj1,
+ class _Proj2,
+ __enable_if_t<is_integral<_Tp>::value && __desugars_to<__equal_tag, _Pred, _Tp, _Tp>::value &&
+ __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
+ int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+ constexpr size_t __unroll_count = 4;
+ constexpr size_t __vec_size = __native_vector_size<_Tp>;
+ using __vec = __simd_vector<_Tp, __vec_size>;
+ while (!__libcpp_is_constant_evaluated() && static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size)
+ [[__unlikely__]] {
+ __vec __lhs[__unroll_count];
+ __vec __rhs[__unroll_count];
+
+ for (size_t __i = 0; __i != __unroll_count; ++__i) {
+ __lhs[__i] = std::__load_vector<__vec>(__first1 + __i * __vec_size);
+ __rhs[__i] = std::__load_vector<__vec>(__first2 + __i * __vec_size);
+ }
+
+ for (size_t __i = 0; __i != __unroll_count; ++__i) {
+ if (auto __cmp_res = __lhs[__i] == __rhs[__i]; !std::__all_of(__cmp_res)) {
+ auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res);
+ return {__first1 + __offset, __first2 + __offset};
+ }
+ }
+
+ __first1 += __unroll_count * __vec_size;
+ __first2 += __unroll_count * __vec_size;
+ }
+ // TODO: Consider vectorizing the tail
+ return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+}
+
+#endif // _LIBCPP_VECTORIZE_ALGORIHTMS
+
template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
_LIBCPP_NODISCARD_EXT inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_InputIterator1, _InputIterator2>
mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate __pred) {
- for (; __first1 != __last1; ++__first1, (void)++__first2)
- if (!__pred(*__first1, *__first2))
- break;
- return pair<_InputIterator1, _InputIterator2>(__first1, __first2);
+ __identity __proj;
+ auto __res = std::__mismatch(
+ std::__unwrap_iter(__first1), std::__unwrap_iter(__last1), std::__unwrap_iter(__first2), __pred, __proj, __proj);
+ return std::make_pair(std::__rewrap_iter(__first1, __res.first), std::__rewrap_iter(__first2, __res.second));
}
template <class _InputIterator1, class _InputIterator2>
@@ -59,4 +128,6 @@ mismatch(_InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __fi
_LIBCPP_END_NAMESPACE_STD
+_LIBCPP_POP_MACROS
+
#endif // _LIBCPP___ALGORITHM_MISMATCH_H
diff --git a/libcxx/include/__algorithm/simd_utils.h b/libcxx/include/__algorithm/simd_utils.h
new file mode 100644
index 00000000000000..b500ea09186ae6
--- /dev/null
+++ b/libcxx/include/__algorithm/simd_utils.h
@@ -0,0 +1,114 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_SIMD_UTILS_H
+#define _LIBCPP___ALGORITHM_SIMD_UTILS_H
+
+#include <__bit/bit_cast.h>
+#include <__bit/countr.h>
+#include <__config>
+#include <__type_traits/is_arithmetic.h>
+#include <__type_traits/is_same.h>
+#include <__utility/integer_sequence.h>
+#include <cstddef>
+#include <cstdint>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+# pragma GCC system_header
+#endif
+
+#if _LIBCPP_STD_VER >= 14 && __has_attribute(__ext_vector_type__) && __has_builtin(__builtin_reduce_and) && \
+ __has_builtin(__builtin_convertvector)
+# define _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS 1
+#else
+# define _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS 0
+#endif
+
+#if _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS && !defined(__OPTIMIZE_SIZE__)
+# define _LIBCPP_VECTORIZE_ALGORIHTMS 1
+#else
+# define _LIBCPP_VECTORIZE_ALGORIHTMS 0
+#endif
+
+#if _LIBCPP_HAS_ALGORITHM_VECTOR_UTILS
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+# if defined(__AVX__)
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 32 / sizeof(_Tp);
+# elif defined(__SSE__) || defined(__ARM_NEON__)
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 16 / sizeof(_Tp);
+# elif defined(__MMX__)
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 8 / sizeof(_Tp);
+# else
+template <class _Tp>
+inline constexpr size_t __native_vector_size = 1;
+# endif
+
+template <class _Tp, size_t _Np>
+using __simd_vector __attribute__((__ext_vector_type__(_Np))) = _Tp;
+
+template <class _VecT>
+inline constexpr size_t __simd_vector_size_v = []() -> size_t { static_assert(false, "Not a vector!"); }();
+
+template <class _Tp, size_t _Np>
+inline constexpr size_t __simd_vector_size_v<__simd_vector<_Tp, _Np>> = _Np;
+
+template <class _VecT>
+using __simd_vector_underlying_type_t =
+ decltype([]<class _Tp, size_t _Np>(__simd_vector<_Tp, _Np>) { return _Tp{}; }(_VecT{}));
+
+// This isn't inlined without always_inline when loading chars.
+template <class _VecT>
+[[__gnu__::__always_inline__]] _LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const auto* __ptr) noexcept {
+ return []<size_t... _Indices> [[__gnu__::__always_inline__]] (
+ const auto* __lptr, index_sequence<_Indices...>) static noexcept {
+ return _VecT{__lptr[_Indices]...};
+ }(__ptr, make_index_sequence<__simd_vector_size_v<_VecT>>{});
+}
+
+template <class _Tp, size_t _Np>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI bool __all_of(__simd_vector<_Tp, _Np> __vec) noexcept {
+ return __builtin_reduce_and(__builtin_convertvector(__vec, __simd_vector<bool, _Np>));
+}
+
+template <class _Tp, size_t _Np>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI size_t __find_first_set(__simd_vector<_Tp, _Np> __vec) noexcept {
+ using __mask_vec = __simd_vector<bool, _Np>;
+
+ auto __impl = [&]<class _MaskT>(_MaskT) noexcept {
+ return std::__countr_zero(std::__bit_cast<_MaskT>(__builtin_convertvector(__vec, __mask_vec)));
+ };
+
+ if constexpr (sizeof(__mask_vec) == sizeof(uint8_t)) {
+ return __impl(uint8_t{});
+ } else if constexpr (sizeof(__mask_vec) == sizeof(uint16_t)) {
+ return __impl(uint16_t{});
+ } else if constexpr (sizeof(__mask_vec) == sizeof(uint32_t)) {
+ return __impl(uint32_t{});
+ } else if constexpr (sizeof(__mask_vec) == sizeof(uint64_t)) {
+ return __impl(uint64_t{});
+ } else {
+ static_assert(sizeof(__mask_vec) == 0, "unexpected required size for mask integer type");
+ }
+}
+
+template <class _Tp, size_t _Np>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI size_t __find_first_not_set(__simd_vector<_Tp, _Np> __vec) noexcept {
+ return std::__find_first_set(~__vec);
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP_STD_VER >= 14 && __has_attribute(__ext_vector_type__) && __has_builtin(__builtin_reduce_and) &&
+ // __has_builtin(__builtin_convertvector)
+
+#endif // _LIBCPP___ALGORITHM_SIMD_UTILS_H
diff --git a/libcxx/include/__bit/bit_cast.h b/libcxx/include/__bit/bit_cast.h
index f20b39ae748b10..6298810f373303 100644
--- a/libcxx/include/__bit/bit_cast.h
+++ b/libcxx/include/__bit/bit_cast.h
@@ -19,6 +19,15 @@
_LIBCPP_BEGIN_NAMESPACE_STD
+#ifndef _LIBCPP_CXX03_LANG
+
+template <class _ToType, class _FromType>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI constexpr _ToType __bit_cast(const _FromType& __from) noexcept {
+ return __builtin_bit_cast(_ToType, __from);
+}
+
+#endif // _LIBCPP_CXX03_LANG
+
#if _LIBCPP_STD_VER >= 20
template <class _ToType, class _FromType>
diff --git a/libcxx/include/__bit/countr.h b/libcxx/include/__bit/countr.h
index 0cc679f87a99d9..b6b3ac52ca4e47 100644
--- a/libcxx/include/__bit/countr.h
+++ b/libcxx/include/__bit/countr.h
@@ -35,10 +35,8 @@ _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR int __libcpp_ct
return __builtin_ctzll(__x);
}
-#if _LIBCPP_STD_VER >= 20
-
-template <__libcpp_unsigned_integer _Tp>
-_LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_zero(_Tp __t) noexcept {
+template <class _Tp>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 int __countr_zero(_Tp __t) _NOEXCEPT {
if (__t == 0)
return numeric_limits<_Tp>::digits;
@@ -59,6 +57,13 @@ _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_zero(_Tp __t) n
}
}
+#if _LIBCPP_STD_VER >= 20
+
+template <__libcpp_unsigned_integer _Tp>
+_LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_zero(_Tp __t) noexcept {
+ return std::__countr_zero(__t);
+}
+
template <__libcpp_unsigned_integer _Tp>
_LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr int countr_one(_Tp __t) noexcept {
return __t != numeric_limits<_Tp>::max() ? std::countr_zero(static_cast<_Tp>(~__t)) : numeric_limits<_Tp>::digits;
diff --git a/libcxx/include/libcxx.imp b/libcxx/include/libcxx.imp
index 22fbea99b848bb..f606865380bc90 100644
--- a/libcxx/include/libcxx.imp
+++ b/libcxx/include/libcxx.imp
@@ -217,6 +217,7 @@
{ include: [ "<__algorithm/shift_right.h>", "private", "<algorithm>", "public" ] },
{ include: [ "<__algorithm/shuffle.h>", "private", "<algorithm>", "public" ] },
{ include: [ "<__algorithm/sift_down.h>", "private", "<algorithm>", "public" ] },
+ { include: [ "<__algorithm/simd_utils.h>", "private", "<algorithm>", "public" ] },
{ include: [ "<__algorithm/sort.h>", "private", "<algorithm>", "public" ] },
{ include: [ "<__algorithm/sort_heap.h>", "private", "<algorithm>", "public" ] },
{ include: [ "<__algorithm/stable_partition.h>", "private", "<algorithm>", "public" ] },
diff --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in
index 219906aa9a5668..3fccae7d8ea49f 100644
--- a/libcxx/include/module.modulemap.in
+++ b/libcxx/include/module.modulemap.in
@@ -697,7 +697,10 @@ module std_private_algorithm_minmax [system
export *
}
module std_private_algorithm_minmax_element [system] { header "__algorithm/minmax_element.h" }
-module std_private_algorithm_mismatch [system] { header "__algorithm/mismatch.h" }
+module std_private_algorithm_mismatch [system] {
+ header "__algorithm/mismatch.h"
+ export std_private_algorithm_simd_utils
+}
module std_private_algorithm_move [system] { header "__algorithm/move.h" }
module std_private_algorithm_move_backward [system] { header "__algorithm/move_backward.h" }
module std_private_algorithm_next_permutation [system] { header "__algorithm/next_permutation.h" }
@@ -1048,6 +1051,7 @@ module std_private_algorithm_sort [system
header "__algorithm/sort.h"
export std_private_debug_utils_strict_weak_ordering_check
}
+module std_private_algorithm_simd_utils [system] { header "__algorithm/simd_utils.h" }
module std_private_algorithm_sort_heap [system] { header "__algorithm/sort_heap.h" }
module std_private_algorithm_stable_partition [system] { header "__algorithm/stable_partition.h" }
module std_private_algorithm_stable_sort [system] { header "__algorithm/stable_sort.h" }
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index cc588c095ccfb2..f0716a1af1e1ef 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -16,79 +16,133 @@
// template<InputIterator Iter1, InputIterator Iter2Pred>
// constexpr pair<Iter1, Iter2> // constexpr after c++17
// mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2); // C++14
+//
+// template<InputIterator Iter1, InputIterator Iter2,
+// Predicate<auto, Iter1::value_type, Iter2::value_type> Pred>
+// requires CopyConstructible<Pred>
+// constexpr pair<Iter1, Iter2> // constexpr after c++17
+// mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Pred pred);
+//
+// template<InputIterator Iter1, InputIterator Iter2, Predicate Pred>
+// constexpr pair<Iter1, Iter2> // constexpr after c++17
+// mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2, Pred pred); // C++14
+
+// ADDITIONAL_COMPILE_FLAGS(has-fconstexpr-steps): -fconstexpr-steps=50000000
+// ADDITIONAL_COMPILE_FLAGS(has-fconstexpr-ops-limit): -fconstexpr-ops-limit=100000000
#include <algorithm>
+#include <array>
#include <cassert>
+#include <vector>
#include "test_macros.h"
#include "test_iterators.h"
-
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
- int ia[] = {1, 3, 6, 7};
- int ib[] = {1, 3};
- int ic[] = {1, 3, 5, 7};
- typedef cpp17_input_iterator<int*> II;
- typedef bidirectional_iterator<int*> BI;
-
- auto p1 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic));
- if (p1.first != ia+2 || p1.second != ic+2)
- return false;
-
- auto p2 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic), std::end(ic));
- if (p2.first != ia+2 || p2.second != ic+2)
- return false;
-
- auto p3 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic));
- if (p3.first != ib+2 || p3.second != ic+2)
- return false;
-
- auto p4 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic), std::end(ic));
- if (p4.first != ib+2 || p4.second != ic+2)
- return false;
-
- auto p5 = std::mismatch(II(std::begin(ib)), II(std::end(ib)), II(std::begin(ic)));
- if (p5.first != II(ib+2) || p5.second != II(ic+2))
- return false;
- auto p6 = std::mismatch(BI(std::begin(ib)), BI(std::end(ib)), BI(std::begin(ic)), BI(std::end(ic)));
- if (p6.first != BI(ib+2) || p6.second != BI(ic+2))
- return false;
-
- return true;
- }
+#include "type_algorithms.h"
+
+template <class Iter, class Container1, class Container2>
+TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
+ if (lhs.size() == rhs.size()) {
+ assert(std::mismatch(Iter(lhs.data()), Iter(lhs.data() + lhs.size()), Iter(rhs.data())) ==
+ std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
+
+ assert(std::mismatch(Iter(lhs.data()), Iter(lhs.data() + lhs.size()), Iter(rhs.data()), std::equal_to<int>()) ==
+ std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
+ }
+
+#if TEST_STD_VER >= 14
+ assert(
+ std::mismatch(Iter(lhs.data()), Iter(lhs.data() + lhs.size()), Iter(rhs.data()), Iter(rhs.data() + rhs.size())) ==
+ std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
+
+ assert(std::mismatch(Iter(lhs.data()),
+ Iter(lhs.data() + lhs.size()),
+ Iter(rhs.data()),
+ Iter(rhs.data() + rhs.size()),
+ std::equal_to<int>()) == std::make_pair(Iter(lhs.data() + offset), Iter(rhs.data() + offset)));
#endif
+}
-int main(int, char**)
-{
- int ia[] = {0, 1, 2, 2, 0, 1, 2, 3};
- const unsigned sa = sizeof(ia)/sizeof(ia[0]);
- int ib[] = {0, 1, 2, 3, 0, 1, 2, 3};
- const unsigned sb = sizeof(ib)/sizeof(ib[0]); ((void)sb); // unused in C++11
-
- typedef cpp17_input_iterator<const int*> II;
- typedef random_access_iterator<const int*> RAI;
-
- assert(std::mismatch(II(ia), II(ia + sa), II(ib))
- == (std::pair<II, II>(II(ia+3), II(ib+3))));
-
- assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib))
- == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
+template <class Iter>
+TEST_CONSTEXPR_CXX20 bool test() {
+ { // empty ranges
+ std::array<int, 0> lhs = {};
+ std::array<int, 0> rhs = {};
+ check<Iter>(lhs, rhs, 0);
+ }
+
+ { // same range without mismatch
+ std::array<int, 8> lhs = {0, 1, 2, 3, 0, 1, 2, 3};
+ std::array<int, 8> rhs = {0, 1, 2, 3, 0, 1, 2, 3};
+ check<Iter>(lhs, rhs, 8);
+ }
+
+ { // same range with mismatch
+ std::array<int, 8> lhs = {0, 1, 2, 2, 0, 1, 2, 3};
+ std::array<int, 8> rhs = {0, 1, 2, 3, 0, 1, 2, 3};
+ check<Iter>(lhs, rhs, 3);
+ }
+
+ { // second range is smaller
+ std::array<int, 8> lhs = {0, 1, 2, 2, 0, 1, 2, 3};
+ std::array<int, 2> rhs = {0, 1};
+ check<Iter>(lhs, rhs, 2);
+ }
+
+ { // first range is smaller
+ std::array<int, 2> lhs = {0, 1};
+ std::array<int, 8> rhs = {0, 1, 2, 2, 0, 1, 2, 3};
+ check<Iter>(lhs, rhs, 2);
+ }
+
+ return true;
+}
-#if TEST_STD_VER > 11 // We have the four iteration version
- assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib+sb))
- == (std::pair<II, II>(II(ia+3), II(ib+3))));
+struct Test {
+ template <class Iter>
+ TEST_CONSTEXPR_CXX20 void operator()() {
+ test<Iter>();
+ }
+};
- assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), RAI(ib+sb))
- == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
+TEST_CONSTEXPR_CXX20 bool test() {
+ types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
+ return true;
+}
- assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib+2))
- == (std::pair<II, II>(II(ia+2), II(ib+2))));
+int main(int, char**) {
+ test();
+#if TEST_STD_VER >= 20
+ static_assert(test());
#endif
-#if TEST_STD_VER > 17
- static_assert(test_constexpr());
-#endif
+ { // check with a lot of elements to test the vectorization optimization
+ {
+ std::vector<char> lhs(256);
+ std::vector<char> rhs(256);
+ for (size_t i = 0; i != lhs.size(); ++i) {
+ lhs[i] = 1;
+ check<char*>(lhs, rhs, i);
+ lhs[i] = 0;
+ rhs[i] = 1;
+ check<char*>(lhs, rhs, i);
+ rhs[i] = 0;
+ }
+ }
+
+ {
+ std::vector<int> lhs(256);
+ std::vector<int> rhs(256);
+ for (size_t i = 0; i != lhs.size(); ++i) {
+ lhs[i] = 1;
+ check<int*>(lhs, rhs, i);
+ lhs[i] = 0;
+ rhs[i] = 1;
+ check<int*>(lhs, rhs, i);
+ rhs[i] = 0;
+ }
+ }
+ }
return 0;
}
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp
deleted file mode 100644
index bda4ec7ba5ed60..00000000000000
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch_pred.pass.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-//===----------------------------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-// <algorithm>
-
-// template<InputIterator Iter1, InputIterator Iter2,
-// Predicate<auto, Iter1::value_type, Iter2::value_type> Pred>
-// requires CopyConstructible<Pred>
-// constexpr pair<Iter1, Iter2> // constexpr after c++17
-// mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Pred pred);
-//
-// template<InputIterator Iter1, InputIterator Iter2, Predicate Pred>
-// constexpr pair<Iter1, Iter2> // constexpr after c++17
-// mismatch(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2, Pred pred); // C++14
-
-#include <algorithm>
-#include <functional>
-#include <cassert>
-
-#include "test_macros.h"
-#include "test_iterators.h"
-#include "counting_predicates.h"
-
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool eq(int a, int b) { return a == b; }
-
-TEST_CONSTEXPR bool test_constexpr() {
- int ia[] = {1, 3, 6, 7};
- int ib[] = {1, 3};
- int ic[] = {1, 3, 5, 7};
- typedef cpp17_input_iterator<int*> II;
- typedef bidirectional_iterator<int*> BI;
-
- auto p1 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic), eq);
- if (p1.first != ia+2 || p1.second != ic+2)
- return false;
-
- auto p2 = std::mismatch(std::begin(ia), std::end(ia), std::begin(ic), std::end(ic), eq);
- if (p2.first != ia+2 || p2.second != ic+2)
- return false;
-
- auto p3 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic), eq);
- if (p3.first != ib+2 || p3.second != ic+2)
- return false;
-
- auto p4 = std::mismatch(std::begin(ib), std::end(ib), std::begin(ic), std::end(ic), eq);
- if (p4.first != ib+2 || p4.second != ic+2)
- return false;
-
- auto p5 = std::mismatch(II(std::begin(ib)), II(std::end(ib)), II(std::begin(ic)), eq);
- if (p5.first != II(ib+2) || p5.second != II(ic+2))
- return false;
- auto p6 = std::mismatch(BI(std::begin(ib)), BI(std::end(ib)), BI(std::begin(ic)), BI(std::end(ic)), eq);
- if (p6.first != BI(ib+2) || p6.second != BI(ic+2))
- return false;
-
- return true;
- }
-#endif
-
-
-#if TEST_STD_VER > 11
-#define HAS_FOUR_ITERATOR_VERSION
-#endif
-
-int main(int, char**)
-{
- int ia[] = {0, 1, 2, 2, 0, 1, 2, 3};
- const unsigned sa = sizeof(ia)/sizeof(ia[0]);
- int ib[] = {0, 1, 2, 3, 0, 1, 2, 3};
- const unsigned sb = sizeof(ib)/sizeof(ib[0]); ((void)sb); // unused in C++11
-
- typedef cpp17_input_iterator<const int*> II;
- typedef random_access_iterator<const int*> RAI;
- typedef std::equal_to<int> EQ;
-
- assert(std::mismatch(II(ia), II(ia + sa), II(ib), EQ())
- == (std::pair<II, II>(II(ia+3), II(ib+3))));
- assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), EQ())
- == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
-
- binary_counting_predicate<EQ, int> bcp((EQ()));
- assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), std::ref(bcp))
- == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
- assert(bcp.count() > 0 && bcp.count() < sa);
- bcp.reset();
-
-#if TEST_STD_VER >= 14
- assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib + sb), EQ())
- == (std::pair<II, II>(II(ia+3), II(ib+3))));
- assert(std::mismatch(RAI(ia), RAI(ia + sa), RAI(ib), RAI(ib + sb), EQ())
- == (std::pair<RAI, RAI>(RAI(ia+3), RAI(ib+3))));
-
- assert(std::mismatch(II(ia), II(ia + sa), II(ib), II(ib + sb), std::ref(bcp))
- == (std::pair<II, II>(II(ia+3), II(ib+3))));
- assert(bcp.count() > 0 && bcp.count() < std::min(sa, sb));
-#endif
-
- assert(std::mismatch(ia, ia + sa, ib, EQ()) ==
- (std::pair<int*,int*>(ia+3,ib+3)));
-
-#if TEST_STD_VER >= 14
- assert(std::mismatch(ia, ia + sa, ib, ib + sb, EQ()) ==
- (std::pair<int*,int*>(ia+3,ib+3)));
- assert(std::mismatch(ia, ia + sa, ib, ib + 2, EQ()) ==
- (std::pair<int*,int*>(ia+2,ib+2)));
-#endif
-
-#if TEST_STD_VER > 17
- static_assert(test_constexpr());
-#endif
-
- return 0;
-}
>From e1c9ff33e89ef62f322e73469aab5e9e7fa52b8e Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Tue, 27 Feb 2024 11:45:41 +0100
Subject: [PATCH 2/2] [libc++] Optimize mismatch tail
---
.../benchmarks/algorithms/mismatch.bench.cpp | 15 ++++-
libcxx/include/__algorithm/mismatch.h | 60 ++++++++++++++-----
.../mismatch/mismatch.pass.cpp | 28 +++++++++
3 files changed, 86 insertions(+), 17 deletions(-)
diff --git a/libcxx/benchmarks/algorithms/mismatch.bench.cpp b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
index 9274932a764c55..06289068bb0492 100644
--- a/libcxx/benchmarks/algorithms/mismatch.bench.cpp
+++ b/libcxx/benchmarks/algorithms/mismatch.bench.cpp
@@ -10,6 +10,15 @@
#include <benchmark/benchmark.h>
#include <random>
+void BenchmarkSizes(benchmark::internal::Benchmark* Benchmark) { // registers the input sizes shared by all bm_mismatch instantiations
+ Benchmark->DenseRange(1, 8); // every size from 1 to 8
+ for (size_t i = 16; i != 1 << 20; i *= 2) { // powers of two from 16, stopping before 1 << 20; NOTE(review): the old Range(16, 1 << 20) included 1 << 20 - confirm dropping it is intended
+ Benchmark->Arg(i - 1); // one element short of the power of two
+ Benchmark->Arg(i); // exactly the power of two
+ Benchmark->Arg(i + 1); // one element past the power of two
+ }
+}
+
// TODO: Look into benchmarking aligned and unaligned memory explicitly
// (currently things happen to be aligned because they are malloced that way)
template <class T>
@@ -24,8 +33,8 @@ static void bm_mismatch(benchmark::State& state) {
benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
}
}
-BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
-BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
-BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_mismatch<char>)->Apply(BenchmarkSizes);
+BENCHMARK(bm_mismatch<short>)->Apply(BenchmarkSizes);
+BENCHMARK(bm_mismatch<int>)->Apply(BenchmarkSizes);
BENCHMARK_MAIN();
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index d0234e793cbca3..776b08e279fdbd 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -64,27 +64,59 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
constexpr size_t __unroll_count = 4;
constexpr size_t __vec_size = __native_vector_size<_Tp>;
using __vec = __simd_vector<_Tp, __vec_size>;
- while (!__libcpp_is_constant_evaluated() && static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size)
- [[__unlikely__]] {
- __vec __lhs[__unroll_count];
- __vec __rhs[__unroll_count];
-
- for (size_t __i = 0; __i != __unroll_count; ++__i) {
- __lhs[__i] = std::__load_vector<__vec>(__first1 + __i * __vec_size);
- __rhs[__i] = std::__load_vector<__vec>(__first2 + __i * __vec_size);
+
+ if (!__libcpp_is_constant_evaluated()) {
+ auto __orig_first1 = __first1;
+ auto __last2 = __first2 + (__last1 - __first1);
+ while (static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size) [[__unlikely__]] {
+ __vec __lhs[__unroll_count];
+ __vec __rhs[__unroll_count];
+
+ for (size_t __i = 0; __i != __unroll_count; ++__i) {
+ __lhs[__i] = std::__load_vector<__vec>(__first1 + __i * __vec_size);
+ __rhs[__i] = std::__load_vector<__vec>(__first2 + __i * __vec_size);
+ }
+
+ for (size_t __i = 0; __i != __unroll_count; ++__i) {
+ if (auto __cmp_res = __lhs[__i] == __rhs[__i]; !std::__all_of(__cmp_res)) {
+ auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res);
+ return {__first1 + __offset, __first2 + __offset};
+ }
+ }
+
+ __first1 += __unroll_count * __vec_size;
+ __first2 += __unroll_count * __vec_size;
}
- for (size_t __i = 0; __i != __unroll_count; ++__i) {
- if (auto __cmp_res = __lhs[__i] == __rhs[__i]; !std::__all_of(__cmp_res)) {
- auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res);
+ // check the remaining 0-3 vectors
+ while (static_cast<size_t>(__last1 - __first1) >= __vec_size) {
+ if (auto __cmp_res = std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2);
+ !std::__all_of(__cmp_res)) {
+ auto __offset = std::__find_first_not_set(__cmp_res);
return {__first1 + __offset, __first2 + __offset};
}
+ __first1 += __vec_size;
+ __first2 += __vec_size;
}
- __first1 += __unroll_count * __vec_size;
- __first2 += __unroll_count * __vec_size;
+ if (__last1 - __first1 == 0)
+ return {__first1, __first2};
+
+ // Check if we can load elements in front of the current pointer. If that's the case load a vector at
+ // (last - vector_size) to check the remaining elements
+ if (static_cast<size_t>(__first1 - __orig_first1) >= __vec_size) {
+ __first1 = __last1 - __vec_size;
+ __first2 = __last2 - __vec_size;
+ auto __offset =
+ std::__find_first_not_set(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2));
+ return {__first1 + __offset, __first2 + __offset};
+ } // else loop over the elements individually
+
+ // TODO: Consider vectorizing the loop tail further with
+ // - smaller vectors
+ // - loading bytes out of range if it's known to be safe
}
- // TODO: Consider vectorizing the tail
+
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
}
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index f0716a1af1e1ef..7e508834e869d2 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -144,5 +144,33 @@ int main(int, char**) {
}
}
+ { // check the tail of the vectorized loop: each size in [1, 256) leaves a different number of trailing elements
+ for (size_t vec_size = 1; vec_size != 256; ++vec_size) {
+ { // char elements; the ranges are sized by vec_size so that every iteration exercises a distinct length
+ std::vector<char> lhs(vec_size);
+ std::vector<char> rhs(vec_size);
+
+ check<char*>(lhs, rhs, lhs.size()); // identical ranges: no mismatch before the end
+ lhs.back() = 1; // differ only in the last element, which is always part of the tail
+ check<char*>(lhs, rhs, lhs.size() - 1);
+ lhs.back() = 0;
+ rhs.back() = 1; // same check with the differing value on the other side
+ check<char*>(lhs, rhs, lhs.size() - 1);
+ rhs.back() = 0;
+ }
+ { // int elements: same checks with a wider element type
+ std::vector<int> lhs(vec_size);
+ std::vector<int> rhs(vec_size);
+
+ check<int*>(lhs, rhs, lhs.size()); // identical ranges: no mismatch before the end
+ lhs.back() = 1; // differ only in the last element
+ check<int*>(lhs, rhs, lhs.size() - 1);
+ lhs.back() = 0;
+ rhs.back() = 1; // same check with the differing value on the other side
+ check<int*>(lhs, rhs, lhs.size() - 1);
+ rhs.back() = 0;
+ }
+ }
+ }
return 0;
}
More information about the libcxx-commits
mailing list