[libcxx-commits] [libcxx] cbd9e54 - [libc++][PSTL] Implement std::transform
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Mon May 15 06:48:50 PDT 2023
Author: Nikolas Klauser
Date: 2023-05-15T06:48:43-07:00
New Revision: cbd9e5454741ebe6b39521fe1a8ed4eed5c2c801
URL: https://github.com/llvm/llvm-project/commit/cbd9e5454741ebe6b39521fe1a8ed4eed5c2c801
DIFF: https://github.com/llvm/llvm-project/commit/cbd9e5454741ebe6b39521fe1a8ed4eed5c2c801.diff
LOG: [libc++][PSTL] Implement std::transform
Reviewed By: ldionne, #libc
Spies: libcxx-commits
Differential Revision: https://reviews.llvm.org/D149615
Added:
libcxx/include/__algorithm/pstl_transform.h
libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.binary.pass.cpp
libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.unary.pass.cpp
Modified:
libcxx/include/CMakeLists.txt
libcxx/include/__pstl/internal/glue_algorithm_defs.h
libcxx/include/__pstl/internal/glue_algorithm_impl.h
libcxx/include/algorithm
libcxx/test/support/type_algorithms.h
Removed:
################################################################################
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 4dd363de4d173..f304b5dafef82 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -82,6 +82,7 @@ set(files
__algorithm/pstl_find.h
__algorithm/pstl_for_each.h
__algorithm/pstl_frontend_dispatch.h
+ __algorithm/pstl_transform.h
__algorithm/push_heap.h
__algorithm/ranges_adjacent_find.h
__algorithm/ranges_all_of.h
diff --git a/libcxx/include/__algorithm/pstl_transform.h b/libcxx/include/__algorithm/pstl_transform.h
new file mode 100644
index 0000000000000..74a869583f515
--- /dev/null
+++ b/libcxx/include/__algorithm/pstl_transform.h
@@ -0,0 +1,129 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_PSTL_TRANSFORM_H
+#define _LIBCPP___ALGORITHM_PSTL_TRANSFORM_H
+
+#include <__algorithm/transform.h>
+#include <__config>
+#include <__iterator/iterator_traits.h>
+#include <__pstl/internal/parallel_backend.h>
+#include <__pstl/internal/unseq_backend_simd.h>
+#include <__type_traits/is_execution_policy.h>
+#include <__utility/terminate_on_exception.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+# pragma GCC system_header
+#endif
+
+#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class _ExecutionPolicy, // Unary transform overload whose first argument is an execution policy.
+ class _ForwardIterator,
+ class _ForwardOutIterator,
+ class _UnaryOperation,
+ enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0> // constraint: the cvref-stripped policy type must satisfy is_execution_policy_v
+_LIBCPP_HIDE_FROM_ABI _ForwardOutIterator transform(
+ _ExecutionPolicy&& __policy,
+ _ForwardIterator __first,
+ _ForwardIterator __last,
+ _ForwardOutIterator __result,
+ _UnaryOperation __op) {
+ if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && // branch 1: parallel policy, and both iterators are random access
+ __is_cpp17_random_access_iterator<_ForwardIterator>::value &&
+ __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) {
+ std::__terminate_on_exception([&] { // NOTE(review): presumably terminates if the callable throws, going by the helper's name -- confirm
+ __pstl::__par_backend::__parallel_for(
+ __pstl::__internal::__par_backend_tag{},
+ __policy,
+ __first,
+ __last,
+ [&__policy, __op, __first, __result](_ForwardIterator __brick_first, _ForwardIterator __brick_last) { // callback over a sub-range; __op, __first and __result are captured by copy
+ return std::transform(
+ std::__remove_parallel_policy(__policy), // NOTE(review): presumably yields a non-parallel policy so this nested call takes another branch -- confirm
+ __brick_first,
+ __brick_last,
+ __result + (__brick_first - __first), // output position at the same offset as the sub-range start has within the input
+ __op);
+ });
+ });
+ return __result + (__last - __first); // computed directly from the input length rather than taken from the backend call
+ } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && // branch 2: unsequenced policy, same random-access requirement
+ __is_cpp17_random_access_iterator<_ForwardIterator>::value &&
+ __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) {
+ return __pstl::__unseq_backend::__simd_walk_2( // NOTE(review): the result is returned as _ForwardOutIterator; presumably the end of the output -- confirm
+ __first,
+ __last - __first, // number of elements
+ __result,
+ [&](__iter_reference<_ForwardIterator> __in_value, __iter_reference<_ForwardOutIterator> __out_value) { // __op is captured by reference here
+ __out_value = __op(__in_value); // assign the transformed input through the output reference
+ });
+ } else {
+ return std::transform(__first, __last, __result, __op); // branch 3: call std::transform without a policy argument
+ }
+}
+
+template <class _ExecutionPolicy, // Binary transform overload whose first argument is an execution policy.
+ class _ForwardIterator1,
+ class _ForwardIterator2,
+ class _ForwardOutIterator,
+ class _BinaryOperation,
+ enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0> // constraint: the cvref-stripped policy type must satisfy is_execution_policy_v
+_LIBCPP_HIDE_FROM_ABI _ForwardOutIterator transform(
+ _ExecutionPolicy&& __policy,
+ _ForwardIterator1 __first1,
+ _ForwardIterator1 __last1,
+ _ForwardIterator2 __first2,
+ _ForwardOutIterator __result,
+ _BinaryOperation __op) {
+ if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && // branch 1: parallel policy, and all three iterators are random access
+ __is_cpp17_random_access_iterator<_ForwardIterator1>::value &&
+ __is_cpp17_random_access_iterator<_ForwardIterator2>::value &&
+ __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) {
+ std::__terminate_on_exception([&] { // NOTE(review): presumably terminates if the callable throws, going by the helper's name -- confirm
+ __pstl::__par_backend::__parallel_for(
+ __pstl::__internal::__par_backend_tag{},
+ __policy,
+ __first1,
+ __last1,
+ [&__policy, __op, __first1, __first2, __result]( // __op and the three range starts are captured by copy
+ _ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last) {
+ return std::transform(
+ std::__remove_parallel_policy(__policy), // NOTE(review): presumably yields a non-parallel policy so this nested call takes another branch -- confirm
+ __brick_first,
+ __brick_last,
+ __first2 + (__brick_first - __first1), // second input advanced by the sub-range's offset within the first input
+ __result + (__brick_first - __first1), // output advanced by the same offset
+ __op);
+ });
+ });
+ return __result + (__last1 - __first1); // computed directly from the first input's length rather than taken from the backend call
+ } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && // branch 2: unsequenced policy, same random-access requirement
+ __is_cpp17_random_access_iterator<_ForwardIterator1>::value &&
+ __is_cpp17_random_access_iterator<_ForwardIterator2>::value &&
+ __is_cpp17_random_access_iterator<_ForwardOutIterator>::value) {
+ return __pstl::__unseq_backend::__simd_walk_3( // NOTE(review): the result is returned as _ForwardOutIterator; presumably the end of the output -- confirm
+ __first1,
+ __last1 - __first1, // number of elements
+ __first2,
+ __result,
+ [&](__iter_reference<_ForwardIterator1> __in1, // __op is captured by reference here
+ __iter_reference<_ForwardIterator2> __in2,
+ __iter_reference<_ForwardOutIterator> __out) { __out = __op(__in1, __in2); });
+ } else {
+ return std::transform(__first1, __last1, __first2, __result, __op); // branch 3: call std::transform without a policy argument
+ }
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
+
+#endif // _LIBCPP___ALGORITHM_PSTL_TRANSFORM_H
diff --git a/libcxx/include/__pstl/internal/glue_algorithm_defs.h b/libcxx/include/__pstl/internal/glue_algorithm_defs.h
index 82bb3f508d5a4..de4501e56b2cf 100644
--- a/libcxx/include/__pstl/internal/glue_algorithm_defs.h
+++ b/libcxx/include/__pstl/internal/glue_algorithm_defs.h
@@ -134,29 +134,6 @@ template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterato
__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator2> swap_ranges(
_ExecutionPolicy&& __exec, _ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2);
-// [alg.transform]
-
-template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterator2, class _UnaryOperation>
-__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator2>
-transform(_ExecutionPolicy&& __exec,
- _ForwardIterator1 __first,
- _ForwardIterator1 __last,
- _ForwardIterator2 __result,
- _UnaryOperation __op);
-
-template <class _ExecutionPolicy,
- class _ForwardIterator1,
- class _ForwardIterator2,
- class _ForwardIterator,
- class _BinaryOperation>
-__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator>
-transform(_ExecutionPolicy&& __exec,
- _ForwardIterator1 __first1,
- _ForwardIterator1 __last1,
- _ForwardIterator2 __first2,
- _ForwardIterator __result,
- _BinaryOperation __op);
-
// [alg.replace]
template <class _ExecutionPolicy, class _ForwardIterator, class _UnaryPredicate, class _Tp>
diff --git a/libcxx/include/__pstl/internal/glue_algorithm_impl.h b/libcxx/include/__pstl/internal/glue_algorithm_impl.h
index db62705233b9e..bae5efa7d0575 100644
--- a/libcxx/include/__pstl/internal/glue_algorithm_impl.h
+++ b/libcxx/include/__pstl/internal/glue_algorithm_impl.h
@@ -251,27 +251,6 @@ __pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardItera
// [alg.transform]
-template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterator2, class _UnaryOperation>
-__pstl::__internal::__enable_if_execution_policy<_ExecutionPolicy, _ForwardIterator2>
-transform(_ExecutionPolicy&& __exec,
- _ForwardIterator1 __first,
- _ForwardIterator1 __last,
- _ForwardIterator2 __result,
- _UnaryOperation __op) {
- typedef typename iterator_traits<_ForwardIterator1>::reference _InputType;
- typedef typename iterator_traits<_ForwardIterator2>::reference _OutputType;
-
- auto __dispatch_tag = __pstl::__internal::__select_backend(__exec, __first, __result);
-
- return __pstl::__internal::__pattern_walk2(
- __dispatch_tag,
- std::forward<_ExecutionPolicy>(__exec),
- __first,
- __last,
- __result,
- [__op](_InputType __x, _OutputType __y) mutable { __y = __op(__x); });
-}
-
template <class _ExecutionPolicy,
class _ForwardIterator1,
class _ForwardIterator2,
diff --git a/libcxx/include/algorithm b/libcxx/include/algorithm
index 469bf17066281..18a89eb1a4dc9 100644
--- a/libcxx/include/algorithm
+++ b/libcxx/include/algorithm
@@ -1792,6 +1792,7 @@ template <class BidirectionalIterator, class Compare>
#include <__algorithm/pstl_fill.h>
#include <__algorithm/pstl_find.h>
#include <__algorithm/pstl_for_each.h>
+#include <__algorithm/pstl_transform.h>
#include <__algorithm/push_heap.h>
#include <__algorithm/ranges_adjacent_find.h>
#include <__algorithm/ranges_all_of.h>
diff --git a/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.binary.pass.cpp b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.binary.pass.cpp
new file mode 100644
index 0000000000000..1076a1548ee3d
--- /dev/null
+++ b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.binary.pass.cpp
@@ -0,0 +1,85 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14
+
+// UNSUPPORTED: libcpp-has-no-incomplete-pstl
+
+// <algorithm>
+
+// template<class ExecutionPolicy, class ForwardIterator1, class ForwardIterator2,
+// class ForwardIterator, class BinaryOperation>
+// ForwardIterator
+// transform(ExecutionPolicy&& exec,
+// ForwardIterator1 first1, ForwardIterator1 last1,
+// ForwardIterator2 first2, ForwardIterator result,
+// BinaryOperation binary_op);
+
+#include <algorithm>
+#include <vector>
+
+#include "test_macros.h"
+#include "test_execution_policies.h"
+#include "test_iterators.h"
+
+EXECUTION_POLICY_SFINAE_TEST(transform);
+
+static_assert(sfinae_test_transform<int, int*, int*, int*, int*, bool (*)(int)>);
+static_assert(!sfinae_test_transform<std::execution::parallel_policy, int*, int*, int*, int*, int (*)(int, int)>);
+
+template <class Iter1, class Iter2, class Iter3> // Iter1/Iter2 wrap the two input ranges, Iter3 wraps the output range
+struct Test { // Checks binary std::transform with an execution policy for one combination of iterator types.
+ template <class Policy>
+ void operator()(Policy&& policy) {
+ // simple test
+ for (const int size : {0, 1, 2, 100, 350}) { // range lengths exercised, including the empty range
+ std::vector<int> a(size);
+ std::vector<int> b(size);
+ for (int i = 0; i != size; ++i) {
+ a[i] = i + 1;
+ b[i] = i - 3;
+ }
+
+ std::vector<int> out(std::size(a));
+ decltype(auto) ret = std::transform(
+ policy,
+ Iter1(std::data(a)),
+ Iter1(std::data(a) + std::size(a)),
+ Iter2(std::data(b)),
+ Iter3(std::data(out)),
+ [](int i, int j) { return i + j + 3; }); // with the fill above: (i + 1) + (i - 3) + 3 == 2 * i + 1
+ static_assert(std::is_same_v<decltype(ret), Iter3>); // the return type must be exactly the output iterator type
+ assert(base(ret) == std::data(out) + std::size(out)); // the returned iterator must be one past the last output element
+ for (int i = 0; i != size; ++i) {
+ assert(out[i] == i * 2 + 1);
+ }
+ }
+ }
+};
+
+template <class Iter3> // type chosen by the outer level; it is passed as Test's second template argument below
+struct TestIterators2 { // Middle level of the nested iteration over iterator-type combinations.
+ template <class Iter2> // type chosen at this level; it is passed as Test's first template argument below
+ void operator()() {
+ types::for_each(types::forward_iterator_list<int*>{},
+ TestIteratorWithPolicies<types::partial_instantiation<Test, Iter2, Iter3>::template apply>{}); // apply<X> is Test<Iter2, Iter3, X>; NOTE(review): TestIteratorWithPolicies presumably runs it per policy -- confirm
+ }
+};
+
+struct TestIterators1 { // Outer level of the nested iteration over iterator-type combinations.
+ template <class Iter3>
+ void operator()() {
+ types::for_each(types::forward_iterator_list<int*>{}, TestIterators2<Iter3>{}); // invokes TestIterators2<Iter3>'s templated call operator once per type in the list
+ }
+};
+
+int main(int, char**) { // Entry point of the binary-transform test.
+ types::for_each(types::forward_iterator_list<int*>{}, TestIterators1{}); // NOTE(review): the list presumably holds forward-or-stronger iterator types over int* -- confirm
+
+ return 0;
+}
diff --git a/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.unary.pass.cpp b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.unary.pass.cpp
new file mode 100644
index 0000000000000..31069de4e5230
--- /dev/null
+++ b/libcxx/test/std/algorithms/alg.modifying.operations/alg.transform/pstl.transform.unary.pass.cpp
@@ -0,0 +1,67 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14
+
+// UNSUPPORTED: libcpp-has-no-incomplete-pstl
+
+// <algorithm>
+
+// template<class ExecutionPolicy, class ForwardIterator1, class ForwardIterator2,
+// class UnaryOperation>
+// ForwardIterator2
+// transform(ExecutionPolicy&& exec,
+// ForwardIterator1 first1, ForwardIterator1 last1,
+// ForwardIterator2 result, UnaryOperation op);
+
+#include <algorithm>
+#include <vector>
+
+#include "test_macros.h"
+#include "test_execution_policies.h"
+#include "test_iterators.h"
+
+// We can't test the execution-policy constraint here: a five-argument call whose first argument is not a policy
+// would still match the unconstrained binary std::transform overload, which takes no execution policy.
+
+template <class Iter1, class Iter2> // Iter1 wraps the input range, Iter2 wraps the output range
+struct Test { // Checks unary std::transform with an execution policy for one combination of iterator types.
+ template <class Policy>
+ void operator()(Policy&& policy) {
+ // simple test
+ for (const int size : {0, 1, 2, 100, 350}) { // range lengths exercised, including the empty range
+ std::vector<int> a(size);
+ for (int i = 0; i != size; ++i)
+ a[i] = i + 1;
+
+ std::vector<int> out(std::size(a));
+ decltype(auto) ret = std::transform(
+ policy, Iter1(std::data(a)), Iter1(std::data(a) + std::size(a)), Iter2(std::data(out)), [](int i) {
+ return i + 3; // with the fill above: (i + 1) + 3 == i + 4
+ });
+ static_assert(std::is_same_v<decltype(ret), Iter2>); // the return type must be exactly the output iterator type
+ assert(base(ret) == std::data(out) + std::size(out)); // the returned iterator must be one past the last output element
+ for (int i = 0; i != size; ++i)
+ assert(out[i] == i + 4);
+ }
+ }
+};
+
+struct TestIterators { // Outer level of the iteration over iterator-type pairs.
+ template <class Iter2> // type chosen at this level; it is passed as Test's first template argument below
+ void operator()() {
+ types::for_each(types::forward_iterator_list<int*>{},
+ TestIteratorWithPolicies<types::partial_instantiation<Test, Iter2>::template apply>{}); // apply<X> is Test<Iter2, X>; NOTE(review): TestIteratorWithPolicies presumably runs it per policy -- confirm
+ }
+};
+
+int main(int, char**) { // Entry point of the unary-transform test.
+ types::for_each(types::forward_iterator_list<int*>{}, TestIterators{}); // NOTE(review): the list presumably holds forward-or-stronger iterator types over int* -- confirm
+
+ return 0;
+}
diff --git a/libcxx/test/support/type_algorithms.h b/libcxx/test/support/type_algorithms.h
index 95a282b7b0bcf..ac3ee60b2ccfb 100644
--- a/libcxx/test/support/type_algorithms.h
+++ b/libcxx/test/support/type_algorithms.h
@@ -52,6 +52,13 @@ TEST_CONSTEXPR_CXX14 void for_each(type_list<Types...>, Functor f) {
swallow((f.template operator()<Types>(), 0)...);
}
+
+template <template <class...> class T, class... Args> // T: the template to instantiate; Args: its leading arguments, fixed up front
+struct partial_instantiation { // Fixes the leading template arguments of T and leaves the last one open.
+ template <class Other>
+ using apply = T<Args..., Other>; // Other is appended as the final template argument
+};
+
// type categories defined in [basic.fundamental] plus extensions (without CV-qualifiers)
using character_types =
More information about the libcxx-commits
mailing list