[libcxx-commits] [libcxx] 9391330 - [libc++][PSTL] Fix std::copy frontend dispatching
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Mon Jul 31 18:39:08 PDT 2023
Author: Nikolas Klauser
Date: 2023-07-31T18:39:02-07:00
New Revision: 9391330293163deb22e4dddf12c4186d202854ce
URL: https://github.com/llvm/llvm-project/commit/9391330293163deb22e4dddf12c4186d202854ce
DIFF: https://github.com/llvm/llvm-project/commit/9391330293163deb22e4dddf12c4186d202854ce.diff
LOG: [libc++][PSTL] Fix std::copy frontend dispatching
Reviewed By: #libc, Mordante
Spies: Mordante, libcxx-commits
Differential Revision: https://reviews.llvm.org/D155325
Added:
Modified:
libcxx/include/__algorithm/pstl_copy.h
libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
Removed:
################################################################################
diff --git a/libcxx/include/__algorithm/pstl_copy.h b/libcxx/include/__algorithm/pstl_copy.h
index 83c712c35407f2..9ac268b6f95f07 100644
--- a/libcxx/include/__algorithm/pstl_copy.h
+++ b/libcxx/include/__algorithm/pstl_copy.h
@@ -10,6 +10,8 @@
#define _LIBCPP___ALGORITHM_PSTL_COPY_H
#include <__algorithm/copy_n.h>
+#include <__algorithm/pstl_backend.h>
+#include <__algorithm/pstl_frontend_dispatch.h>
#include <__algorithm/pstl_transform.h>
#include <__config>
#include <__functional/identity.h>
@@ -30,26 +32,48 @@ _LIBCPP_BEGIN_NAMESPACE_STD
// TODO: Use the std::copy/move shenanigans to forward to std::memmove
+template <class>
+void __pstl_copy();
+
template <class _ExecutionPolicy,
class _ForwardIterator,
class _ForwardOutIterator,
- enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0>
+ class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
+ enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
_LIBCPP_HIDE_FROM_ABI _ForwardOutIterator
copy(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _ForwardOutIterator __result) {
- return std::transform(__policy, __first, __last, __result, __identity());
+ return std::__pstl_frontend_dispatch(
+ _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_copy),
+ [&__policy](_ForwardIterator __g_first, _ForwardIterator __g_last, _ForwardOutIterator __g_result) {
+ return std::transform(__policy, __g_first, __g_last, __g_result, __identity());
+ },
+ std::move(__first),
+ std::move(__last),
+ std::move(__result));
}
+template <class>
+void __pstl_copy_n();
+
template <class _ExecutionPolicy,
class _ForwardIterator,
class _ForwardOutIterator,
class _Size,
- enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0>
+ class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
+ enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
_LIBCPP_HIDE_FROM_ABI _ForwardOutIterator
copy_n(_ExecutionPolicy&& __policy, _ForwardIterator __first, _Size __n, _ForwardOutIterator __result) {
- if constexpr (__has_random_access_iterator_category<_ForwardIterator>::value)
- return std::copy(__policy, __first, __first + __n, __result);
- else
- return std::copy_n(__first, __n, __result);
+ return std::__pstl_frontend_dispatch(
+ _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_copy_n),
+ [&__policy](_ForwardIterator __g_first, _Size __g_n, _ForwardOutIterator __g_result) {
+ if constexpr (__has_random_access_iterator_category<_ForwardIterator>::value)
+ return std::copy(__policy, __g_first, __g_first + __g_n, __g_result);
+ else
+ return std::copy_n(__g_first, __g_n, __g_result);
+ },
+ std::move(__first),
+ __n,
+ std::move(__result));
}
_LIBCPP_END_NAMESPACE_STD
diff --git a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
index 919945f72adaf7..0b5d9c4e93cdf5 100644
--- a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
@@ -42,6 +42,24 @@ bool __pstl_all_of(TestBackend, ForwardIterator, ForwardIterator, Pred) {
return true;
}
+bool pstl_copy_called = false;
+
+template <class, class ForwardIterator, class ForwardOutIterator>
+ForwardIterator __pstl_copy(TestBackend, ForwardIterator, ForwardIterator, ForwardOutIterator) {
+ assert(!pstl_copy_called);
+ pstl_copy_called = true;
+ return 0;
+}
+
+bool pstl_copy_n_called = false;
+
+template <class, class ForwardIterator, class Size, class ForwardOutIterator>
+ForwardIterator __pstl_copy_n(TestBackend, ForwardIterator, Size, ForwardOutIterator) {
+ assert(!pstl_copy_n_called);
+ pstl_copy_n_called = true;
+ return 0;
+}
+
bool pstl_count_called = false;
template <class, class ForwardIterator, class T>
@@ -290,6 +308,10 @@ int main(int, char**) {
assert(std::pstl_all_of_called);
(void)std::none_of(TestPolicy{}, std::begin(a), std::end(a), pred);
assert(std::pstl_none_of_called);
+ std::copy(TestPolicy{}, std::begin(a), std::end(a), std::begin(a));
+ assert(std::pstl_copy_called);
+ std::copy_n(TestPolicy{}, std::begin(a), 1, std::begin(a));
+ assert(std::pstl_copy_n_called);
(void)std::count(TestPolicy{}, std::begin(a), std::end(a), 0);
assert(std::pstl_count_called);
(void)std::count_if(TestPolicy{}, std::begin(a), std::end(a), pred);
More information about the libcxx-commits mailing list