[libcxx-commits] [libcxx] [libc++] Vectorize std::adjacent_find (PR #89757)
Louis Dionne via libcxx-commits
libcxx-commits at lists.llvm.org
Thu Aug 22 09:34:43 PDT 2024
================
@@ -39,10 +44,83 @@ __adjacent_find(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) {
return __i;
}
+template <class _Iter, class _Sent, class _BinaryPredicate>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Iter
+__adjacent_find(_Iter __first, _Sent __last, _BinaryPredicate&& __pred) { // Generic overload: always uses the element-by-element __adjacent_find_loop.
+ return std::__adjacent_find_loop(__first, __last, __pred); // NOTE(review): presumably more constrained overloads select the vectorized path instead - confirm
+}
+
+#if _LIBCPP_VECTORIZE_ALGORITHMS
+
+template <class _Tp, class _Pred> // NOTE(review): the vector path hard-codes `!=` and ignores __pred, so this is only correct when __pred is equivalent to `==`.
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _Tp*
+__adjacent_find_vectorized(_Tp* __first, _Tp* __last, _Pred& __pred) { // Returns a pointer to the first element equal to its successor in [__first, __last), or __last if none.
+ constexpr size_t __unroll_count = 4; // number of vectors compared per main-loop iteration
+ constexpr size_t __vec_size = __native_vector_size<_Tp>; // number of _Tp lanes in __vec
+ using __vec = __simd_vector<_Tp, __vec_size>;
+
+ if (!__libcpp_is_constant_evaluated()) { // constant evaluation skips straight to the scalar loop at the end
+ auto __orig_first = __first; // kept so the tail below can check whether re-loading elements before __first is in bounds
+ while (static_cast<size_t>(__last - __first) > __unroll_count * __vec_size) [[__unlikely__]] { // more than __unroll_count * __vec_size elements remain, so the loads at "+ 1" stay in bounds
+ __vec __cmp_res[__unroll_count]; // lane __j of __cmp_res[__i] is set when __first[__i * __vec_size + __j] != its successor
+
+ for (size_t __i = 0; __i != __unroll_count; ++__i) {
+ __cmp_res[__i] = std::__load_vector<__vec>(__first + __i * __vec_size) !=
+ std::__load_vector<__vec>(__first + __i * __vec_size + 1);
+ }
+
+ for (size_t __i = 0; __i != __unroll_count; ++__i) {
+ if (!std::__all_of(__cmp_res[__i])) { // an unset lane means an equal adjacent pair starts in this vector
+ auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res[__i]); // vectors are scanned in order, so this is the earliest match
+ return __first + __offset;
+ }
+ }
+
+ __first += __unroll_count * __vec_size; // no equal pair starts in the block just processed
+ }
+
+ // Check the remaining full vectors one at a time: at most __unroll_count * __vec_size elements are left, so this runs 0-3 times.
+ while (static_cast<size_t>(__last - __first) > __vec_size) [[__unlikely__]] {
+ if (auto __cmp_res = std::__load_vector<__vec>(__first) != std::__load_vector<__vec>(__first + 1);
+ !std::__all_of(__cmp_res)) {
+ auto __offset = std::__find_first_not_set(__cmp_res);
+ return __first + __offset;
+ }
+ __first += __vec_size;
+ }
+
+ if (__first == __last) // only possible for an empty input: each loop above leaves at least one element
+ return __first;
+
+ // If the original range has more than __vec_size elements, re-load the last __vec_size + 1 elements (from
+ // __last - __vec_size - 1). The overlapping pairs were already found unequal above, so any hit is a new one.
+ if (static_cast<size_t>(__last - __orig_first) > __vec_size) {
+ __first = __last - __vec_size - 1; // in bounds: the check above guarantees __last - __vec_size - 1 >= __orig_first
+ auto __offset =
+ std::__find_first_not_set(std::__load_vector<__vec>(__first) != std::__load_vector<__vec>(__first + 1));
+ if (__offset == __vec_size) // every lane set: no equal pair in the tail, hence none in the whole range
+ return __last;
+ return __first + __offset;
+ }
+ } // Constant evaluation, or a non-empty range of at most __vec_size elements: use the scalar loop (the only place __pred is called).
+ return std::__adjacent_find_loop(__first, __last, __pred);
+}
+
+template <class _Tp,
+ class _Pred,
+ __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp>, int> = 0>
----------------
ldionne wrote:
If we call the predicate on the `__vec`s above (instead of hard-coding `!=`), then this overload can accept predicates other than `==`: it should work with any predicate that can be applied to `__vec`s.
https://github.com/llvm/llvm-project/pull/89757
More information about the libcxx-commits
mailing list