[libcxx-commits] [libcxx] 9391330 - [libc++][PSTL] Fix std::copy frontend dispatching

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Mon Jul 31 18:39:08 PDT 2023


Author: Nikolas Klauser
Date: 2023-07-31T18:39:02-07:00
New Revision: 9391330293163deb22e4dddf12c4186d202854ce

URL: https://github.com/llvm/llvm-project/commit/9391330293163deb22e4dddf12c4186d202854ce
DIFF: https://github.com/llvm/llvm-project/commit/9391330293163deb22e4dddf12c4186d202854ce.diff

LOG: [libc++][PSTL] Fix std::copy frontend dispatching

Reviewed By: #libc, Mordante

Spies: Mordante, libcxx-commits

Differential Revision: https://reviews.llvm.org/D155325

Added: 
    

Modified: 
    libcxx/include/__algorithm/pstl_copy.h
    libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp

Removed: 
    


################################################################################
diff  --git a/libcxx/include/__algorithm/pstl_copy.h b/libcxx/include/__algorithm/pstl_copy.h
index 83c712c35407f2..9ac268b6f95f07 100644
--- a/libcxx/include/__algorithm/pstl_copy.h
+++ b/libcxx/include/__algorithm/pstl_copy.h
@@ -10,6 +10,8 @@
 #define _LIBCPP___ALGORITHM_PSTL_COPY_H
 
 #include <__algorithm/copy_n.h>
+#include <__algorithm/pstl_backend.h>
+#include <__algorithm/pstl_frontend_dispatch.h>
 #include <__algorithm/pstl_transform.h>
 #include <__config>
 #include <__functional/identity.h>
@@ -30,26 +32,48 @@ _LIBCPP_BEGIN_NAMESPACE_STD
 
 // TODO: Use the std::copy/move shenanigans to forward to std::memmove
 
+template <class>
+void __pstl_copy(); // name-only declaration; presumably lets __pstl_copy<...>(...) parse as a template call whose real (backend) overloads are found via ADL — TODO confirm
+
 template <class _ExecutionPolicy,
           class _ForwardIterator,
           class _ForwardOutIterator,
-          enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0>
+          class _RawPolicy                                    = __remove_cvref_t<_ExecutionPolicy>, // cv/ref-stripped policy type, checked by the enable_if below
+          enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
 _LIBCPP_HIDE_FROM_ABI _ForwardOutIterator
 copy(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _ForwardOutIterator __result) {
-  return std::transform(__policy, __first, __last, __result, __identity());
+  return std::__pstl_frontend_dispatch( // presumably calls the backend's __pstl_copy when one exists, else the lambda below — TODO confirm
+      _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_copy),
+      [&__policy](_ForwardIterator __g_first, _ForwardIterator __g_last, _ForwardOutIterator __g_result) {
+        return std::transform(__policy, __g_first, __g_last, __g_result, __identity()); // fallback: copy as an identity transform under the same policy
+      },
+      std::move(__first),
+      std::move(__last),
+      std::move(__result));
 }
 
+template <class>
+void __pstl_copy_n(); // name-only declaration; presumably lets __pstl_copy_n<...>(...) parse as a template call whose real (backend) overloads are found via ADL — TODO confirm
+
 template <class _ExecutionPolicy,
           class _ForwardIterator,
           class _ForwardOutIterator,
           class _Size,
-          enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0>
+          class _RawPolicy                                    = __remove_cvref_t<_ExecutionPolicy>, // cv/ref-stripped policy type, checked by the enable_if below
+          enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
 _LIBCPP_HIDE_FROM_ABI _ForwardOutIterator
 copy_n(_ExecutionPolicy&& __policy, _ForwardIterator __first, _Size __n, _ForwardOutIterator __result) {
-  if constexpr (__has_random_access_iterator_category<_ForwardIterator>::value)
-    return std::copy(__policy, __first, __first + __n, __result);
-  else
-    return std::copy_n(__first, __n, __result);
+  return std::__pstl_frontend_dispatch( // presumably calls the backend's __pstl_copy_n when one exists, else the lambda below — TODO confirm
+      _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_copy_n),
+      [&__policy](_ForwardIterator __g_first, _Size __g_n, _ForwardOutIterator __g_result) {
+        if constexpr (__has_random_access_iterator_category<_ForwardIterator>::value) // random access: the end iterator is computable, so reuse the policy-taking std::copy
+          return std::copy(__policy, __g_first, __g_first + __g_n, __g_result);
+        else // otherwise use std::copy_n without the execution policy
+          return std::copy_n(__g_first, __g_n, __g_result);
+      },
+      std::move(__first),
+      __n,
+      std::move(__result));
 }
 
 _LIBCPP_END_NAMESPACE_STD

diff  --git a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
index 919945f72adaf7..0b5d9c4e93cdf5 100644
--- a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
@@ -42,6 +42,24 @@ bool __pstl_all_of(TestBackend, ForwardIterator, ForwardIterator, Pred) {
   return true;
 }
 
+bool pstl_copy_called = false; // set by the __pstl_copy customization point below; main() asserts it after std::copy
+
+template <class, class ForwardIterator, class ForwardOutIterator> // first (unnamed) parameter is the explicitly supplied template argument
+ForwardOutIterator __pstl_copy(TestBackend, ForwardIterator, ForwardIterator, ForwardOutIterator res) { // returns the output iterator type, like std::copy
+  assert(!pstl_copy_called);
+  pstl_copy_called = true;
+  return res; // works for any output iterator type, not only raw pointers
+}
+
+bool pstl_copy_n_called = false; // set by the __pstl_copy_n customization point below; main() asserts it after std::copy_n
+
+template <class, class ForwardIterator, class Size, class ForwardOutIterator> // first (unnamed) parameter is the explicitly supplied template argument
+ForwardOutIterator __pstl_copy_n(TestBackend, ForwardIterator, Size, ForwardOutIterator res) { // returns the output iterator type, like std::copy_n
+  assert(!pstl_copy_n_called);
+  pstl_copy_n_called = true;
+  return res; // works for any output iterator type, not only raw pointers
+}
+
 bool pstl_count_called = false;
 
 template <class, class ForwardIterator, class T>
@@ -290,6 +308,10 @@ int main(int, char**) {
   assert(std::pstl_all_of_called);
   (void)std::none_of(TestPolicy{}, std::begin(a), std::end(a), pred);
   assert(std::pstl_none_of_called);
+  std::copy(TestPolicy{}, std::begin(a), std::end(a), std::begin(a));
+  assert(std::pstl_copy_called);
+  std::copy_n(TestPolicy{}, std::begin(a), 1, std::begin(a));
+  assert(std::pstl_copy_n_called);
   (void)std::count(TestPolicy{}, std::begin(a), std::end(a), 0);
   assert(std::pstl_count_called);
   (void)std::count_if(TestPolicy{}, std::begin(a), std::end(a), pred);


        


More information about the libcxx-commits mailing list