[libcxx-commits] [libcxx] [libc++] Vectorize std::adjacent_find (PR #89757)

Louis Dionne via libcxx-commits libcxx-commits at lists.llvm.org
Thu May 30 09:32:03 PDT 2024


================
@@ -39,10 +44,83 @@ __adjacent_find(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) {
   return __i;
 }
 
+// Generic overload: forwards its arguments unchanged to
+// std::__adjacent_find_loop and returns that function's result.
+// Note: __pred is passed on as an lvalue (it is not std::forward-ed).
+template <class _Iter, class _Sent, class _BinaryPredicate>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Iter
+__adjacent_find(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) {
+  return std::__adjacent_find_loop(__first, __last, __pred);
+}
+
+#if _LIBCPP_VECTORIZE_ALGORITHMS
+
+template <class _Tp, class _Pred>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Tp*
+__adjacent_find_vectorized(_Tp* __first, _Tp* __last, _Pred& __pred) {
+  constexpr size_t __unroll_count = 4;
+  constexpr size_t __vec_size     = __native_vector_size<_Tp>;
+  using __vec                     = __simd_vector<_Tp, __vec_size>;
+
+  if (!__libcpp_is_constant_evaluated()) {
+    auto __orig_first = __first;
+    while (static_cast<size_t>(__last - __first) > __unroll_count * __vec_size) [[__unlikely__]] {
+      __vec __cmp_res[__unroll_count];
+
+      for (size_t __i = 0; __i != __unroll_count; ++__i) {
+        __cmp_res[__i] = std::__load_vector<__vec>(__first + __i * __vec_size) !=
+                         std::__load_vector<__vec>(__first + __i * __vec_size + 1);
+      }
+
+      for (size_t __i = 0; __i != __unroll_count; ++__i) {
+        if (!std::__all_of(__cmp_res[__i])) {
+          auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res[__i]);
+          return __first + __offset;
+        }
+      }
+
+      __first += __unroll_count * __vec_size;
+    }
+
+    // check the last 0-3 vectors
+    while (static_cast<size_t>(__last - __first) > __vec_size) [[__unlikely__]] {
----------------
ldionne wrote:

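Based on these benchmark results, it looks like something is wrong with how you determine when vectors can be loaded. Note the sharp drops in the second column between sizes 32 and 33 and between 64 and 65: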

```
bm_adjacent_find<char>/32              9.41 ns         14.1 ns
bm_adjacent_find<char>/33              10.2 ns         1.92 ns
bm_adjacent_find<char>/63              18.5 ns         14.6 ns
bm_adjacent_find<char>/64              18.2 ns         15.0 ns
bm_adjacent_find<char>/65              18.2 ns         2.58 ns
bm_adjacent_find<char>/127             31.2 ns         3.02 ns
bm_adjacent_find<char>/128             31.3 ns         3.01 ns
```

Just leaving a note so you can take a look.

https://github.com/llvm/llvm-project/pull/89757


More information about the libcxx-commits mailing list