[libcxx-commits] [libcxx] [libc++] Vectorize std::adjacent_find (PR #89757)

via libcxx-commits libcxx-commits at lists.llvm.org
Thu May 30 09:09:42 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-libcxx

Author: Nikolas Klauser (philnik777)

<details>
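<summary>Changes</summary>

Benchmark results for `std::adjacent_find` on a `std::vector` whose adjacent elements are all distinct (so the whole range is scanned), before (old) and after (new) this patch:

```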
--------------------------------------------------------------
Benchmark                                  old             new
--------------------------------------------------------------
bm_adjacent_find<char>/1              0.317 ns         1.27 ns
bm_adjacent_find<char>/2              0.533 ns         1.49 ns
bm_adjacent_find<char>/3              0.757 ns         1.91 ns
bm_adjacent_find<char>/4              0.986 ns         2.33 ns
bm_adjacent_find<char>/5               1.20 ns         2.75 ns
bm_adjacent_find<char>/6               1.45 ns         3.17 ns
bm_adjacent_find<char>/7               1.63 ns         3.60 ns
bm_adjacent_find<char>/8               1.83 ns         4.03 ns
bm_adjacent_find<char>/15              3.61 ns         6.96 ns
bm_adjacent_find<char>/16              3.81 ns         7.37 ns
bm_adjacent_find<char>/17              5.07 ns         7.79 ns
bm_adjacent_find<char>/31              9.32 ns         13.7 ns
bm_adjacent_find<char>/32              9.41 ns         14.1 ns
bm_adjacent_find<char>/33              10.2 ns         1.92 ns
bm_adjacent_find<char>/63              18.5 ns         14.6 ns
bm_adjacent_find<char>/64              18.2 ns         15.0 ns
bm_adjacent_find<char>/65              18.2 ns         2.58 ns
bm_adjacent_find<char>/127             31.2 ns         3.02 ns
bm_adjacent_find<char>/128             31.3 ns         3.01 ns
bm_adjacent_find<char>/129             33.6 ns         2.78 ns
bm_adjacent_find<char>/255             58.0 ns         4.71 ns
bm_adjacent_find<char>/256             58.6 ns         4.72 ns
bm_adjacent_find<char>/257             58.4 ns         4.28 ns
bm_adjacent_find<char>/511              112 ns         7.91 ns
bm_adjacent_find<char>/512              112 ns         7.86 ns
bm_adjacent_find<char>/513              112 ns         7.75 ns
bm_adjacent_find<char>/1023             220 ns         15.3 ns
bm_adjacent_find<char>/1024             219 ns         15.3 ns
bm_adjacent_find<char>/1025             219 ns         13.9 ns
bm_adjacent_find<char>/2047             436 ns         27.5 ns
bm_adjacent_find<char>/2048             436 ns         27.4 ns
bm_adjacent_find<char>/2049             434 ns         26.4 ns
bm_adjacent_find<char>/4095             863 ns         53.0 ns
bm_adjacent_find<char>/4096             863 ns         52.7 ns
bm_adjacent_find<char>/4097             862 ns         52.0 ns
bm_adjacent_find<char>/8191            1720 ns          104 ns
bm_adjacent_find<char>/8192            1726 ns          104 ns
bm_adjacent_find<char>/8193            1721 ns          102 ns
bm_adjacent_find<char>/16383           3436 ns          210 ns
bm_adjacent_find<char>/16384           3445 ns          211 ns
bm_adjacent_find<char>/16385           3442 ns          210 ns
bm_adjacent_find<char>/32767           7015 ns          425 ns
bm_adjacent_find<char>/32768           7033 ns          424 ns
bm_adjacent_find<char>/32769           7047 ns          426 ns
bm_adjacent_find<char>/65535          14046 ns          884 ns
bm_adjacent_find<char>/65536          14046 ns          884 ns
bm_adjacent_find<char>/65537          14046 ns          887 ns
bm_adjacent_find<char>/131071         28426 ns         1754 ns
bm_adjacent_find<char>/131072         28233 ns         1754 ns
bm_adjacent_find<char>/131073         28335 ns         1754 ns
bm_adjacent_find<char>/262143         56740 ns         3517 ns
bm_adjacent_find<char>/262144         56228 ns         3508 ns
bm_adjacent_find<char>/262145         56085 ns         3530 ns
bm_adjacent_find<char>/524287        112343 ns         7158 ns
bm_adjacent_find<char>/524288        112179 ns         7154 ns
bm_adjacent_find<char>/524289        112175 ns         7162 ns
bm_adjacent_find<short>/1             0.318 ns         1.27 ns
bm_adjacent_find<short>/2             0.580 ns         1.49 ns
bm_adjacent_find<short>/3             0.783 ns         1.91 ns
bm_adjacent_find<short>/4              1.03 ns         2.14 ns
bm_adjacent_find<short>/5              1.23 ns         2.33 ns
bm_adjacent_find<short>/6              1.51 ns         2.55 ns
bm_adjacent_find<short>/7              1.69 ns         2.75 ns
bm_adjacent_find<short>/8              1.86 ns         2.97 ns
bm_adjacent_find<short>/15             6.09 ns         5.19 ns
bm_adjacent_find<short>/16             6.78 ns         5.29 ns
bm_adjacent_find<short>/17             7.24 ns         2.34 ns
bm_adjacent_find<short>/31             13.1 ns         5.93 ns
bm_adjacent_find<short>/32             13.5 ns         6.14 ns
bm_adjacent_find<short>/33             14.0 ns         2.58 ns
bm_adjacent_find<short>/63             32.1 ns         3.22 ns
bm_adjacent_find<short>/64             32.9 ns         3.22 ns
bm_adjacent_find<short>/65             36.6 ns         2.79 ns
bm_adjacent_find<short>/127            49.5 ns         4.72 ns
bm_adjacent_find<short>/128            49.3 ns         4.73 ns
bm_adjacent_find<short>/129            48.6 ns         4.28 ns
bm_adjacent_find<short>/255            73.9 ns         7.97 ns
bm_adjacent_find<short>/256            74.4 ns         7.86 ns
bm_adjacent_find<short>/257            75.7 ns         7.38 ns
bm_adjacent_find<short>/511             128 ns         14.9 ns
bm_adjacent_find<short>/512             134 ns         14.6 ns
bm_adjacent_find<short>/513             131 ns         14.6 ns
bm_adjacent_find<short>/1023            246 ns         27.4 ns
bm_adjacent_find<short>/1024            245 ns         27.2 ns
bm_adjacent_find<short>/1025            238 ns         26.4 ns
bm_adjacent_find<short>/2047            459 ns         53.8 ns
bm_adjacent_find<short>/2048            452 ns         53.8 ns
bm_adjacent_find<short>/2049            463 ns         53.6 ns
bm_adjacent_find<short>/4095            918 ns          105 ns
bm_adjacent_find<short>/4096            920 ns          105 ns
bm_adjacent_find<short>/4097            910 ns          105 ns
bm_adjacent_find<short>/8191           1759 ns          214 ns
bm_adjacent_find<short>/8192           1746 ns          214 ns
bm_adjacent_find<short>/8193           1739 ns          215 ns
bm_adjacent_find<short>/16383          3563 ns          433 ns
bm_adjacent_find<short>/16384          3563 ns          434 ns
bm_adjacent_find<short>/16385          3563 ns          432 ns
bm_adjacent_find<short>/32767          7147 ns          894 ns
bm_adjacent_find<short>/32768          7125 ns          893 ns
bm_adjacent_find<short>/32769          7123 ns          894 ns
bm_adjacent_find<short>/65535         14229 ns         1772 ns
bm_adjacent_find<short>/65536         14222 ns         1770 ns
bm_adjacent_find<short>/65537         14215 ns         1772 ns
bm_adjacent_find<short>/131071        28435 ns         3528 ns
bm_adjacent_find<short>/131072        28388 ns         3537 ns
bm_adjacent_find<short>/131073        28377 ns         3528 ns
bm_adjacent_find<short>/262143        56742 ns         7184 ns
bm_adjacent_find<short>/262144        56713 ns         7161 ns
bm_adjacent_find<short>/262145        56683 ns         7163 ns
bm_adjacent_find<short>/524287       114215 ns        14991 ns
bm_adjacent_find<short>/524288       114503 ns        15067 ns
bm_adjacent_find<short>/524289       114622 ns        14981 ns
bm_adjacent_find<int>/1               0.321 ns         1.70 ns
bm_adjacent_find<int>/2               0.564 ns         1.90 ns
bm_adjacent_find<int>/3               0.756 ns         2.33 ns
bm_adjacent_find<int>/4               1.000 ns         2.55 ns
bm_adjacent_find<int>/5                1.23 ns         2.92 ns
bm_adjacent_find<int>/6                1.47 ns         3.31 ns
bm_adjacent_find<int>/7                1.65 ns         3.64 ns
bm_adjacent_find<int>/8                1.94 ns         3.75 ns
bm_adjacent_find<int>/15               4.50 ns         3.96 ns
bm_adjacent_find<int>/16               4.79 ns         4.35 ns
bm_adjacent_find<int>/17               5.47 ns         2.58 ns
bm_adjacent_find<int>/31               9.40 ns         3.22 ns
bm_adjacent_find<int>/32               10.1 ns         3.23 ns
bm_adjacent_find<int>/33               10.2 ns         3.22 ns
bm_adjacent_find<int>/63               17.7 ns         5.16 ns
bm_adjacent_find<int>/64               17.9 ns         5.14 ns
bm_adjacent_find<int>/65               18.2 ns         5.13 ns
bm_adjacent_find<int>/127              31.2 ns         9.06 ns
bm_adjacent_find<int>/128              31.4 ns         9.00 ns
bm_adjacent_find<int>/129              31.6 ns         9.23 ns
bm_adjacent_find<int>/255              57.9 ns         17.6 ns
bm_adjacent_find<int>/256              58.1 ns         17.6 ns
bm_adjacent_find<int>/257              58.3 ns         18.0 ns
bm_adjacent_find<int>/511               111 ns         34.0 ns
bm_adjacent_find<int>/512               112 ns         34.1 ns
bm_adjacent_find<int>/513               112 ns         34.2 ns
bm_adjacent_find<int>/1023              219 ns         67.8 ns
bm_adjacent_find<int>/1024              219 ns         67.7 ns
bm_adjacent_find<int>/1025              219 ns         68.0 ns
bm_adjacent_find<int>/2047              433 ns          135 ns
bm_adjacent_find<int>/2048              434 ns          135 ns
bm_adjacent_find<int>/2049              434 ns          135 ns
bm_adjacent_find<int>/4095              863 ns          269 ns
bm_adjacent_find<int>/4096              863 ns          270 ns
bm_adjacent_find<int>/4097              863 ns          271 ns
bm_adjacent_find<int>/8191             1799 ns          538 ns
bm_adjacent_find<int>/8192             1799 ns          538 ns
bm_adjacent_find<int>/8193             1799 ns          539 ns
bm_adjacent_find<int>/16383            3652 ns         1064 ns
bm_adjacent_find<int>/16384            3656 ns         1064 ns
bm_adjacent_find<int>/16385            3653 ns         1068 ns
bm_adjacent_find<int>/32767            7299 ns         2110 ns
bm_adjacent_find<int>/32768            7325 ns         2110 ns
bm_adjacent_find<int>/32769            7287 ns         2114 ns
bm_adjacent_find<int>/65535           14513 ns         4202 ns
bm_adjacent_find<int>/65536           14524 ns         4188 ns
bm_adjacent_find<int>/65537           14534 ns         4202 ns
bm_adjacent_find<int>/131071          29374 ns         8534 ns
bm_adjacent_find<int>/131072          29376 ns         8534 ns
bm_adjacent_find<int>/131073          29360 ns         8535 ns
bm_adjacent_find<int>/262143          59728 ns        17492 ns
bm_adjacent_find<int>/262144          59691 ns        17505 ns
bm_adjacent_find<int>/262145          59651 ns        17495 ns
bm_adjacent_find<int>/524287         119947 ns        34309 ns
bm_adjacent_find<int>/524288         120793 ns        34219 ns
bm_adjacent_find<int>/524289         121228 ns        34155 ns
```

---
Full diff: https://github.com/llvm/llvm-project/pull/89757.diff


4 Files Affected:

- (modified) libcxx/benchmarks/CMakeLists.txt (+1) 
- (added) libcxx/benchmarks/algorithms/adjacent_find.bench.cpp (+42) 
- (modified) libcxx/include/__algorithm/adjacent_find.h (+80-2) 
- (modified) libcxx/test/std/algorithms/alg.nonmodifying/alg.adjacent.find/adjacent_find.pass.cpp (+99-25) 


``````````diff
diff --git a/libcxx/benchmarks/CMakeLists.txt b/libcxx/benchmarks/CMakeLists.txt
index 5dc3be0c367e5..14d8799c509c8 100644
--- a/libcxx/benchmarks/CMakeLists.txt
+++ b/libcxx/benchmarks/CMakeLists.txt
@@ -173,6 +173,7 @@ endfunction()
 #==============================================================================
 set(BENCHMARK_TESTS
     algorithms.partition_point.bench.cpp
+    algorithms/adjacent_find.bench.cpp
     algorithms/count.bench.cpp
     algorithms/equal.bench.cpp
     algorithms/find.bench.cpp
diff --git a/libcxx/benchmarks/algorithms/adjacent_find.bench.cpp b/libcxx/benchmarks/algorithms/adjacent_find.bench.cpp
new file mode 100644
index 0000000000000..4467acbcfbefd
--- /dev/null
+++ b/libcxx/benchmarks/algorithms/adjacent_find.bench.cpp
@@ -0,0 +1,42 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <benchmark/benchmark.h>
+#include <vector>
+
+// Sizes 1-8, then 2^k - 1, 2^k and 2^k + 1 for 2^k in [16, 2^19], so both full-vector and ragged tails are measured.
+static void BenchmarkSizes(benchmark::internal::Benchmark* Benchmark) {
+  Benchmark->DenseRange(1, 8);
+  for (size_t i = 16; i != 1 << 20; i *= 2) {
+    Benchmark->Arg(i - 1);
+    Benchmark->Arg(i);
+    Benchmark->Arg(i + 1);
+  }
+}
+
+// TODO: Look into benchmarking aligned and unaligned memory explicitly
+// (currently things happen to be aligned because they are malloced that way)
+template <class T>
+static void bm_adjacent_find(benchmark::State& state) {
+  std::vector<T> vec1(state.range());
+
+  size_t val = 1; // consecutive elements never compare equal, so every run scans the whole range
+  for (auto& e : vec1)
+    e = static_cast<T>(val++);
+
+  for (auto _ : state) {
+    benchmark::DoNotOptimize(vec1);
+    benchmark::DoNotOptimize(std::adjacent_find(vec1.begin(), vec1.end()));
+  }
+}
+BENCHMARK(bm_adjacent_find<char>)->Apply(BenchmarkSizes);
+BENCHMARK(bm_adjacent_find<short>)->Apply(BenchmarkSizes);
+BENCHMARK(bm_adjacent_find<int>)->Apply(BenchmarkSizes);
+
+BENCHMARK_MAIN();
diff --git a/libcxx/include/__algorithm/adjacent_find.h b/libcxx/include/__algorithm/adjacent_find.h
index 6f15456e3a4d0..a456724b80b21 100644
--- a/libcxx/include/__algorithm/adjacent_find.h
+++ b/libcxx/include/__algorithm/adjacent_find.h
@@ -12,9 +12,14 @@
 
 #include <__algorithm/comp.h>
 #include <__algorithm/iterator_operations.h>
+#include <__algorithm/simd_utils.h>
+#include <__algorithm/unwrap_iter.h>
 #include <__config>
 #include <__iterator/iterator_traits.h>
+#include <__type_traits/desugars_to.h>
+#include <__type_traits/is_constant_evaluated.h>
 #include <__utility/move.h>
+#include <cstddef>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
@@ -27,7 +32,7 @@ _LIBCPP_BEGIN_NAMESPACE_STD
 
 template <class _Iter, class _Sent, class _BinaryPredicate>
 _LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Iter
-__adjacent_find(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) {
+__adjacent_find_loop(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) {
   if (__first == __last)
     return __first;
   _Iter __i = __first;
@@ -39,10 +44,94 @@ __adjacent_find(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) {
   return __i;
 }
 
+// Generic overload for arbitrary iterators and predicates: checks one pair of elements at a time.
+template <class _Iter, class _Sent, class _BinaryPredicate>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Iter
+__adjacent_find(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) {
+  return std::__adjacent_find_loop(__first, __last, __pred);
+}
+
+#if _LIBCPP_VECTORIZE_ALGORITHMS
+
+// Returns a pointer to the first element of the first pair of equal adjacent elements in [__first, __last), or
+// __last if there is no such pair. Outside of constant evaluation each SIMD step compares the vector loaded at __p
+// with the vector loaded at __p + 1, i.e. checks __vec_size adjacent pairs at once. __pred is only used by the
+// scalar fallback; the pointer overload of __adjacent_find below only dispatches here for equality predicates.
+template <class _Tp, class _Pred>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Tp*
+__adjacent_find_vectorized(_Tp* __first, _Tp* __last, _Pred& __pred) {
+  constexpr size_t __unroll_count = 4;
+  constexpr size_t __vec_size     = __native_vector_size<_Tp>;
+  using __vec                     = __simd_vector<_Tp, __vec_size>;
+
+  if (!__libcpp_is_constant_evaluated()) {
+    auto __orig_first = __first;
+    // Checking N pairs starting at __p reads the N + 1 elements [__p, __p + N], hence the strict `>` in both loops.
+    while (static_cast<size_t>(__last - __first) > __unroll_count * __vec_size) [[__unlikely__]] {
+      __vec __cmp_res[__unroll_count];
+
+      for (size_t __i = 0; __i != __unroll_count; ++__i) {
+        __cmp_res[__i] = std::__load_vector<__vec>(__first + __i * __vec_size) !=
+                         std::__load_vector<__vec>(__first + __i * __vec_size + 1);
+      }
+
+      for (size_t __i = 0; __i != __unroll_count; ++__i) {
+        if (!std::__all_of(__cmp_res[__i])) {
+          auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res[__i]);
+          return __first + __offset;
+        }
+      }
+
+      __first += __unroll_count * __vec_size;
+    }
+
+    // check the last 0-3 vectors
+    while (static_cast<size_t>(__last - __first) > __vec_size) [[__unlikely__]] {
+      if (auto __cmp_res = std::__load_vector<__vec>(__first) != std::__load_vector<__vec>(__first + 1);
+          !std::__all_of(__cmp_res)) {
+        auto __offset = std::__find_first_not_set(__cmp_res);
+        return __first + __offset;
+      }
+      __first += __vec_size;
+    }
+
+    if (__first == __last)
+      return __first;
+
+    // Between 1 and __vec_size elements are left. If the whole range has more than __vec_size elements (always
+    // the case once a vector has been processed above), the vector at (__last - __vec_size - 1) stays within
+    // [__orig_first, __last) and covers every remaining pair; the pairs it re-checks before __first are already
+    // known to be unequal. Shorter ranges are handled by the scalar loop.
+    if (static_cast<size_t>(__last - __orig_first) > __vec_size) {
+      __first = __last - __vec_size - 1;
+      auto __offset =
+          std::__find_first_not_set(std::__load_vector<__vec>(__first) != std::__load_vector<__vec>(__first + 1));
+      if (__offset == __vec_size)
+        return __last;
+      return __first + __offset;
+    }
+  } // else loop over the elements individually
+  return std::__adjacent_find_loop(__first, __last, __pred);
+}
+
+// Raw pointers to integral elements, compared with a predicate that __desugars_to_v reports as equality, take the
+// vectorized path.
+template <class _Tp,
+          class _Pred,
+          __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp>, int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Tp*
+__adjacent_find(_Tp* __first, _Tp* __last, _Pred& __pred) {
+  return std::__adjacent_find_vectorized(__first, __last, __pred);
+}
+
+#endif // _LIBCPP_VECTORIZE_ALGORITHMS

 template <class _ForwardIterator, class _BinaryPredicate>
 _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
 adjacent_find(_ForwardIterator __first, _ForwardIterator __last, _BinaryPredicate __pred) {
-  return std::__adjacent_find(std::move(__first), std::move(__last), __pred);
+  // Unwrap the iterators (to raw pointers where possible) so that the pointer overload of __adjacent_find can apply.
+  return std::__rewrap_iter(
+      __first, std::__adjacent_find(std::__unwrap_iter(__first), std::__unwrap_iter(__last), __pred));
 }
 
 template <class _ForwardIterator>
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.adjacent.find/adjacent_find.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.adjacent.find/adjacent_find.pass.cpp
index 6d57c5869ab70..94d2947cf629f 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.adjacent.find/adjacent_find.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.adjacent.find/adjacent_find.pass.cpp
@@ -14,39 +14,111 @@
 //   adjacent_find(Iter first, Iter last);

 #include <algorithm>
+#include <array>
 #include <cassert>
+#include <cstddef>
+#include <vector>

 #include "test_macros.h"
 #include "test_iterators.h"

-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
-    int ia[] = {0, 1, 2, 2, 0, 1, 2, 3};
-    int ib[] = {0, 1, 2, 7, 0, 1, 2, 3};
+// Checks a match, an empty range and a range without a match using the iterator type Iter.
+struct Test {
+  template <class Iter>
+  TEST_CONSTEXPR_CXX20 void operator()() {
+    int ia[]          = {0, 1, 2, 2, 0, 1, 2, 3};
+    const unsigned sa = sizeof(ia) / sizeof(ia[0]);
+    assert(std::adjacent_find(Iter(ia), Iter(ia + sa)) == Iter(ia + 2));
+    assert(std::adjacent_find(Iter(ia), Iter(ia)) == Iter(ia));
+    assert(std::adjacent_find(Iter(ia + 3), Iter(ia + sa)) == Iter(ia + sa));
+  }
+};

-    return  (std::adjacent_find(std::begin(ia), std::end(ia)) == ia+2)
-         && (std::adjacent_find(std::begin(ib), std::end(ib)) == std::end(ib))
-         ;
-    }
-#endif
+// Non-integral element type, so the generic (non-vectorized) code path is used.
+struct NonTrivial {
+  int i_;
+
+  TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
+  TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
+
+  TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
+};
+
+// Considers two ints equal when they have the same parity.
+struct ModTwoComp {
+  TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) const { return lhs % 2 == rhs % 2; }
+};
+
+// Runs at runtime and, since C++20, also during constant evaluation.
+TEST_CONSTEXPR_CXX20 bool test() {
+  types::for_each(types::forward_iterator_list<int*>(), Test());
+  types::for_each(types::forward_iterator_list<const int*>(), Test());
+
+  { // use a non-integer type to also test the general case - no match
+    std::array<NonTrivial, 8> arr = {1, 2, 3, 4, 5, 6, 7, 8};
+    assert(std::adjacent_find(arr.begin(), arr.end()) == arr.end());
+  }
+
+  { // use a non-integer type to also test the general case - match
+    std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 4, 6, 7, 8};
+    assert(std::adjacent_find(lhs.begin(), lhs.end()) == lhs.begin() + 3);
+  }
+
+  { // use a custom comparator
+    std::array<int, 8> lhs = {0, 1, 2, 3, 5, 6, 7, 8};
+    assert(std::adjacent_find(lhs.begin(), lhs.end(), ModTwoComp()) == lhs.begin() + 3);
+  }
+
+  return true;
+}

-int main(int, char**)
-{
-    int ia[] = {0, 1, 2, 2, 0, 1, 2, 3};
-    const unsigned sa = sizeof(ia)/sizeof(ia[0]);
-    assert(std::adjacent_find(forward_iterator<const int*>(ia),
-                              forward_iterator<const int*>(ia + sa)) ==
-                              forward_iterator<const int*>(ia+2));
-    assert(std::adjacent_find(forward_iterator<const int*>(ia),
-                              forward_iterator<const int*>(ia)) ==
-                              forward_iterator<const int*>(ia));
-    assert(std::adjacent_find(forward_iterator<const int*>(ia+3),
-                              forward_iterator<const int*>(ia + sa)) ==
-                              forward_iterator<const int*>(ia+sa));
-
-#if TEST_STD_VER > 17
-    static_assert(test_constexpr());
+// Fills vec with 0, 1, 2, ... (converted to T), so that no two adjacent elements compare equal.
+template <class T>
+void fill_vec(std::vector<T>& vec) {
+  for (std::size_t i = 0; i != vec.size(); ++i) {
+    vec[i] = static_cast<T>(i);
+  }
+}
+
+// Exercises the vectorized implementation (where it is enabled) for element type T: a match at every position of
+// a range large enough for the unrolled loop even with 64-byte vectors, through both mutable and const
+// iterators, and a match in the last pair for every size below 256 to cover the tail handling.
+template <class T>
+void test_vectorized() {
+  { // check with a lot of elements to test the vectorization optimization
+    std::vector<T> vec(1024);
+    fill_vec(vec);
+    const std::vector<T>& cvec = vec;
+    for (std::size_t i = 0; i != vec.size() - 1; ++i) {
+      vec[i] = static_cast<T>(i + 1);
+      assert(std::adjacent_find(vec.begin(), vec.end()) == vec.begin() + i);
+      assert(std::adjacent_find(cvec.begin(), cvec.end()) == cvec.begin() + i);
+      vec[i] = static_cast<T>(i);
+    }
+  }
+
+  { // check the tail of the vectorized loop
+    for (std::size_t vec_size = 2; vec_size != 256; ++vec_size) {
+      std::vector<T> vec(vec_size);
+      fill_vec(vec);
+
+      assert(std::adjacent_find(vec.begin(), vec.end()) == vec.end());
+      vec.back() = static_cast<T>(vec.size() - 2);
+      assert(std::adjacent_find(vec.begin(), vec.end()) == vec.end() - 2);
+    }
+  }
+}
+
+int main(int, char**) {
+  test();
+#if TEST_STD_VER >= 20
+  static_assert(test());
 #endif

+  test_vectorized<char>();
+  test_vectorized<short>();
+  test_vectorized<int>();
+  test_vectorized<long long>();
+
   return 0;
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/89757


More information about the libcxx-commits mailing list