[libcxx-commits] [libcxx] [libc++] Vectorize std::mismatch with trivially equality comparable types (PR #87716)
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Fri May 10 02:47:29 PDT 2024
https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/87716
>From 5a7d78b8afe577bcbe1a36f859e2f834217624e4 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 25 Mar 2024 08:25:23 +0100
Subject: [PATCH] [libc++] Vectorize std::mismatch with trivially equality comparable types
---
libcxx/include/CMakeLists.txt | 1 +
libcxx/include/__algorithm/mismatch.h | 52 ++++++--
libcxx/include/__algorithm/simd_utils.h | 34 ++++-
libcxx/include/__iterator/aliasing_iterator.h | 125 ++++++++++++++++++
.../__type_traits/is_equality_comparable.h | 2 +
libcxx/include/module.modulemap | 2 +
.../iterators/aliasing_iterator.pass.cpp | 45 +++++++
.../mismatch/mismatch.pass.cpp | 47 +++++--
8 files changed, 283 insertions(+), 25 deletions(-)
create mode 100644 libcxx/include/__iterator/aliasing_iterator.h
create mode 100644 libcxx/test/libcxx/iterators/aliasing_iterator.pass.cpp
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index fd7eb125e007b..d5599cd7c58d8 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -446,6 +446,7 @@ set(files
__ios/fpos.h
__iterator/access.h
__iterator/advance.h
+ __iterator/aliasing_iterator.h
__iterator/back_insert_iterator.h
__iterator/bounded_iter.h
__iterator/common_iterator.h
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index c2b3f8938f711..632bec02406a4 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -16,6 +16,7 @@
#include <__algorithm/unwrap_iter.h>
#include <__config>
#include <__functional/identity.h>
+#include <__iterator/aliasing_iterator.h>
#include <__type_traits/desugars_to.h>
#include <__type_traits/invoke.h>
#include <__type_traits/is_constant_evaluated.h>
@@ -55,18 +56,13 @@ __mismatch(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Pro
#if _LIBCPP_VECTORIZE_ALGORITHMS
-template <class _Tp,
- class _Pred,
- class _Proj1,
- class _Proj2,
- __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
- __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
- int> = 0>
-_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
-__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+template <class _Iter>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter, _Iter>
+__mismatch_vectorized(_Iter __first1, _Iter __last1, _Iter __first2) {
+ using __value_type = __iter_value_type<_Iter>;
constexpr size_t __unroll_count = 4;
- constexpr size_t __vec_size = __native_vector_size<_Tp>;
- using __vec = __simd_vector<_Tp, __vec_size>;
+ constexpr size_t __vec_size = __native_vector_size<__value_type>;
+ using __vec = __simd_vector<__value_type, __vec_size>;
if (!__libcpp_is_constant_evaluated()) {
auto __orig_first1 = __first1;
@@ -116,9 +112,41 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
} // else loop over the elements individually
}
- return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+ __equal_to __pred;
+ __identity __proj;
+ return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj, __proj);
+}
+
+// Integral element types: the elements can be loaded into SIMD vectors as they are, so forward directly to the
+// vectorized implementation. The predicate and projections are not used because the constraints below guarantee
+// that they are equivalent to == and to the identity, respectively.
+template <class _Tp,
+ class _Pred,
+ class _Proj1,
+ class _Proj2,
+ __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
+ __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
+ int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred&, _Proj1&, _Proj2&) {
+ return std::__mismatch_vectorized(__first1, __last1, __first2);
}
+// Non-integral element types whose equality is equivalent to comparing object representations (trivially equality
+// comparable) and whose size and alignment match one of the fixed-width unsigned integer types. Such ranges are
+// compared as ranges of that integer type through __aliasing_iterator, which lets them use the vectorized path.
+template <class _Tp,
+          class _Pred,
+          class _Proj1,
+          class _Proj2,
+          __enable_if_t<!is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
+                            __is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
+                            __can_map_to_integer_v<_Tp> && __libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value,
+                        int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+  // __aliasing_iterator's operations are not constexpr, so constant evaluation uses the element-wise loop.
+  if (__libcpp_is_constant_evaluated())
+    return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+
+  using __int_iter = __aliasing_iterator<_Tp*, __get_as_integer_type_t<_Tp>>;
+  pair<__int_iter, __int_iter> __result =
+      std::__mismatch_vectorized(__int_iter(__first1), __int_iter(__last1), __int_iter(__first2));
+  // Translate the positions found in the integer view back into pointers to _Tp.
+  return pair<_Tp*, _Tp*>(__result.first.__base(), __result.second.__base());
+}
#endif // _LIBCPP_VECTORIZE_ALGORITHMS
template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
diff --git a/libcxx/include/__algorithm/simd_utils.h b/libcxx/include/__algorithm/simd_utils.h
index 8d540ae2cce88..71d65e8f4afb5 100644
--- a/libcxx/include/__algorithm/simd_utils.h
+++ b/libcxx/include/__algorithm/simd_utils.h
@@ -43,6 +43,34 @@ _LIBCPP_PUSH_MACROS
_LIBCPP_BEGIN_NAMESPACE_STD
+// True if _Tp has exactly the size of one of the unsigned integer types mapped below (1, 2, 4 or 8 bytes) and its
+// alignment is equal to its size.
+template <class _Tp>
+inline constexpr bool __can_map_to_integer_v =
+    alignof(_Tp) == sizeof(_Tp) && (sizeof(_Tp) == 8 || sizeof(_Tp) == 4 || sizeof(_Tp) == 2 || sizeof(_Tp) == 1);
+
+// Maps an object size in bytes to the unsigned integer type of exactly that size. Only the sizes specialized below
+// are defined; asking for any other size is ill-formed.
+template <size_t _TypeSize>
+struct __get_as_integer_type_impl;
+
+template <>
+struct __get_as_integer_type_impl<1> {
+  using type = uint8_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<2> {
+  using type = uint16_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<4> {
+  using type = uint32_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<8> {
+  using type = uint64_t;
+};
+
+// The unsigned integer type with the same size as _Tp. Requires sizeof(_Tp) to be 1, 2, 4 or 8.
+template <class _Tp>
+using __get_as_integer_type_t = typename __get_as_integer_type_impl<sizeof(_Tp)>::type;
+
// This isn't specialized for 64 byte vectors on purpose. They have the potential to significantly reduce performance
// in mixed simd/non-simd workloads and don't provide any performance improvement for currently vectorized algorithms
// as far as benchmarks are concerned.
@@ -80,10 +108,10 @@ template <class _VecT>
using __simd_vector_underlying_type_t = decltype(std::__simd_vector_underlying_type_impl(_VecT{}));
// This isn't inlined without always_inline when loading chars.
-template <class _VecT, class _Tp>
-_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const _Tp* __ptr) noexcept {
+// Builds a _VecT from the __simd_vector_size_v<_VecT> consecutive elements __iter[0], __iter[1], ... . _Iter only has
+// to support operator[] with those indices (e.g. a plain pointer or an iterator wrapper such as __aliasing_iterator).
+template <class _VecT, class _Iter>
+_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(_Iter __iter) noexcept {
return [=]<size_t... _Indices>(index_sequence<_Indices...>) _LIBCPP_ALWAYS_INLINE noexcept {
- return _VecT{__ptr[_Indices]...};
+ return _VecT{__iter[_Indices]...};
}(make_index_sequence<__simd_vector_size_v<_VecT>>{});
}
diff --git a/libcxx/include/__iterator/aliasing_iterator.h b/libcxx/include/__iterator/aliasing_iterator.h
new file mode 100644
index 0000000000000..e0742491bb5f8
--- /dev/null
+++ b/libcxx/include/__iterator/aliasing_iterator.h
@@ -0,0 +1,148 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
+#define _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
+
+#include <__config>
+#include <__iterator/iterator_traits.h>
+#include <__memory/pointer_traits.h>
+#include <__type_traits/is_trivial.h>
+#include <cstddef>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+// This iterator wrapper is used to type-pun an iterator to return a different type. This is done without UB by not
+// actually punning the type, but instead inspecting the object representation of the base type and copying that into
+// an instance of the alias type. For that reason the alias type has to be trivial. The alias is returned as a prvalue
+// when dereferencing the iterator, since it is temporary storage. This wrapper is used to vectorize some algorithms.
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class _BaseIter, class _Alias>
+struct __aliasing_iterator_wrapper {
+  class __iterator {
+    _BaseIter __base_ = _BaseIter();
+
+    using __iter_traits     = iterator_traits<_BaseIter>;
+    using __base_value_type = typename __iter_traits::value_type;
+
+    static_assert(__has_random_access_iterator_category<_BaseIter>::value,
+                  "The base iterator has to be a random access iterator!");
+
+  public:
+    using iterator_category = random_access_iterator_tag;
+    using value_type        = _Alias;
+    using difference_type   = ptrdiff_t;
+    // Dereferencing yields a prvalue copy and there is no operator->. All five member types have to be present for
+    // iterator_traits<__iterator> (and thus __iter_value_type) to be usable before C++20.
+    using reference         = value_type;
+    using pointer           = void;
+
+    static_assert(is_trivial<value_type>::value, "The alias type has to be trivial");
+    static_assert(sizeof(__base_value_type) == sizeof(value_type),
+                  "The alias type has to have the same size as the value type of the base iterator");
+
+    _LIBCPP_HIDE_FROM_ABI __iterator() = default;
+    _LIBCPP_HIDE_FROM_ABI __iterator(_BaseIter __base) _NOEXCEPT : __base_(__base) {}
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator++() _NOEXCEPT {
+      ++__base_;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator operator++(int) _NOEXCEPT {
+      __iterator __tmp(*this);
+      ++__base_;
+      return __tmp;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator--() _NOEXCEPT {
+      --__base_;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator operator--(int) _NOEXCEPT {
+      __iterator __tmp(*this);
+      --__base_;
+      return __tmp;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend __iterator operator+(__iterator __iter, difference_type __n) _NOEXCEPT {
+      return __iterator(__iter.__base_ + __n);
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend __iterator operator+(difference_type __n, __iterator __iter) _NOEXCEPT {
+      return __iterator(__n + __iter.__base_);
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator+=(difference_type __n) _NOEXCEPT {
+      __base_ += __n;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend __iterator operator-(__iterator __iter, difference_type __n) _NOEXCEPT {
+      return __iterator(__iter.__base_ - __n);
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend difference_type operator-(__iterator __lhs, __iterator __rhs) _NOEXCEPT {
+      return __lhs.__base_ - __rhs.__base_;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator-=(difference_type __n) _NOEXCEPT {
+      __base_ -= __n;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI _BaseIter __base() const _NOEXCEPT { return __base_; }
+
+    // Copies the object representation of the current base element into a fresh value_type object.
+    _LIBCPP_HIDE_FROM_ABI value_type operator*() const _NOEXCEPT {
+      value_type __val;
+      __builtin_memcpy(&__val, std::__to_address(__base_), sizeof(value_type));
+      return __val;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI value_type operator[](difference_type __n) const _NOEXCEPT { return *(*this + __n); }
+
+    _LIBCPP_HIDE_FROM_ABI friend bool operator==(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ == __rhs.__base_;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend bool operator!=(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ != __rhs.__base_;
+    }
+
+    // The iterator advertises random_access_iterator_tag, so it also has to provide the relational operators.
+    _LIBCPP_HIDE_FROM_ABI friend bool operator<(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ < __rhs.__base_;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend bool operator>(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ > __rhs.__base_;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend bool operator<=(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ <= __rhs.__base_;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend bool operator>=(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ >= __rhs.__base_;
+    }
+  };
+};
+
+// This is required to avoid ADL instantiations on _BaseT
+template <class _BaseT, class _Alias>
+using __aliasing_iterator = typename __aliasing_iterator_wrapper<_BaseT, _Alias>::__iterator;
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
diff --git a/libcxx/include/__type_traits/is_equality_comparable.h b/libcxx/include/__type_traits/is_equality_comparable.h
index d4142218b641a..4397f743e5ee9 100644
--- a/libcxx/include/__type_traits/is_equality_comparable.h
+++ b/libcxx/include/__type_traits/is_equality_comparable.h
@@ -44,6 +44,8 @@ struct __is_equality_comparable<_Tp, _Up, __void_t<decltype(std::declval<_Tp>()
// pointers that don't have the same type (ignoring cv-qualifiers): pointers to virtual bases are equality comparable,
// but don't have the same bit-pattern. An exception to this is comparing to a void-pointer. There the bit-pattern is
// always compared.
+// objects with padding bytes: such objects may compare equal even though their object representations differ,
+// since the values of the padding bytes are not taken into account by the comparison.
template <class _Tp, class _Up, class = void>
struct __libcpp_is_trivially_equality_comparable_impl : false_type {};
diff --git a/libcxx/include/module.modulemap b/libcxx/include/module.modulemap
index 2974d12500c4c..0c11a57a47743 100644
--- a/libcxx/include/module.modulemap
+++ b/libcxx/include/module.modulemap
@@ -700,6 +700,7 @@ module std_private_algorithm_minmax_element [system
module std_private_algorithm_mismatch [system] {
header "__algorithm/mismatch.h"
export std_private_algorithm_simd_utils
+ export std_private_iterator_aliasing_iterator
}
module std_private_algorithm_move [system] { header "__algorithm/move.h" }
module std_private_algorithm_move_backward [system] { header "__algorithm/move_backward.h" }
@@ -1390,6 +1391,7 @@ module std_private_iosfwd_streambuf_fwd [system] { header "__fwd/streambuf.h" }
module std_private_iterator_access [system] { header "__iterator/access.h" }
module std_private_iterator_advance [system] { header "__iterator/advance.h" }
+module std_private_iterator_aliasing_iterator [system] { header "__iterator/aliasing_iterator.h" }
module std_private_iterator_back_insert_iterator [system] { header "__iterator/back_insert_iterator.h" }
module std_private_iterator_bounded_iter [system] { header "__iterator/bounded_iter.h" }
module std_private_iterator_common_iterator [system] { header "__iterator/common_iterator.h" }
diff --git a/libcxx/test/libcxx/iterators/aliasing_iterator.pass.cpp b/libcxx/test/libcxx/iterators/aliasing_iterator.pass.cpp
new file mode 100644
index 0000000000000..60587d5bfe5d7
--- /dev/null
+++ b/libcxx/test/libcxx/iterators/aliasing_iterator.pass.cpp
@@ -0,0 +1,71 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// ADDITIONAL_COMPILE_FLAGS(clang): -Wprivate-header
+
+#include <__iterator/aliasing_iterator.h>
+#include <cassert>
+
+// Deliberately not trivial (user-provided copy operations and destructor): the aliasing iterator only requires the
+// alias type to be trivial, the element type of the base iterator may be anything of the same size.
+struct NonTrivial {
+  int i_;
+
+  NonTrivial(int i) : i_(i) {}
+  NonTrivial(const NonTrivial& other) : i_(other.i_) {}
+
+  NonTrivial& operator=(const NonTrivial& other) {
+    i_ = other.i_;
+    return *this;
+  }
+
+  ~NonTrivial() {}
+};
+
+int main(int, char**) {
+  {
+    NonTrivial arr[] = {1, 2, 3, 4};
+    std::__aliasing_iterator<NonTrivial*, int> iter(arr);
+
+    assert(*iter == 1);
+    assert(iter[0] == 1);
+    assert(iter[1] == 2);
+    ++iter;
+    assert(*iter == 2);
+    assert(iter[-1] == 1);
+    assert(iter.__base() == arr + 1);
+    assert(iter == iter);
+    assert(iter != (iter + 1));
+  }
+
+  { // Arithmetic operators
+    typedef std::__aliasing_iterator<NonTrivial*, int> Iter;
+    NonTrivial arr[] = {1, 2, 3, 4};
+    Iter first(arr);
+    Iter last(arr + 4);
+
+    assert(last - first == 4);
+    assert(first - last == -4);
+    assert(*(first + 2) == 3);
+    assert(*(2 + first) == 3);
+    assert(*(last - 1) == 4);
+
+    Iter iter = first;
+    iter += 3;
+    assert(*iter == 4);
+    iter -= 2;
+    assert(*iter == 2);
+    assert(*iter++ == 2);
+    assert(*iter == 3);
+    assert(*iter-- == 3);
+    assert(*--iter == 1);
+    assert(iter.__base() == arr);
+  }
+
+  return 0;
+}
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index eb5f7cacdde34..dd37555ffcce5 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -66,14 +66,27 @@ TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
#endif
}
-struct NonTrivial {
+// Compares modulo 4 to make sure we only forward to the vectorized version if we are trivially equality comparable
+struct NonTrivialMod4Comp {
int i_;
- TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
- TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
+ TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(int i) : i_(i) {}
+ // User-provided move constructor that resets the source, so the type is not trivially copyable.
+ TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(NonTrivialMod4Comp&& other) : i_(other.i_) { other.i_ = 0; }
- TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
+ // Equal values modulo 4 compare equal (e.g. 1 == 5), which differs from comparing the object representations.
+ TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivialMod4Comp& lhs, const NonTrivialMod4Comp& rhs) {
+ return lhs.i_ % 4 == rhs.i_ % 4;
+ }
+};
+
+#if TEST_STD_VER >= 20
+struct TriviallyEqualityComparable {
+ int i_;
+
+ TEST_CONSTEXPR_CXX20 TriviallyEqualityComparable(int i) : i_(i) {}
+
+ // Defaulted comparison of the single int member: == is equivalent to comparing the object representations, so
+ // std::mismatch is expected to forward ranges of this type to its vectorized implementation.
+ TEST_CONSTEXPR_CXX20 friend bool operator==(TriviallyEqualityComparable, TriviallyEqualityComparable) = default;
};
+#endif // TEST_STD_VER >= 20
struct ModTwoComp {
TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) { return lhs % 2 == rhs % 2; }
@@ -136,16 +149,30 @@ TEST_CONSTEXPR_CXX20 bool test() {
types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
{ // use a non-integer type to also test the general case - all elements match
- std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
- std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
- check<NonTrivial*>(std::move(lhs), std::move(rhs), 8);
+ std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 1, 6, 7, 8};
+ check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 8);
}
{ // use a non-integer type to also test the general case - not all elements match
- std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
- std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
- check<NonTrivial*>(std::move(lhs), std::move(rhs), 4);
+ std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+ std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 4);
+ }
+
+#if TEST_STD_VER >= 20
+ { // trivially equality comparable class type to test forwarding to the vectorized version - all elements match
+ std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 8);
+ }
+
+ { // trivially equality comparable class type to test forwarding to the vectorized version - not all elements match
+ std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+ std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 4);
}
+#endif // TEST_STD_VER >= 20
return true;
}
More information about the libcxx-commits
mailing list