[libcxx-commits] [libcxx] 7a3b528 - [libc++][PSTL] Implement std::count{,_if}
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Tue Jun 6 08:43:04 PDT 2023
Author: Nikolas Klauser
Date: 2023-06-06T08:42:59-07:00
New Revision: 7a3b528e1b540b4c98f4b557f917447481872749
URL: https://github.com/llvm/llvm-project/commit/7a3b528e1b540b4c98f4b557f917447481872749
DIFF: https://github.com/llvm/llvm-project/commit/7a3b528e1b540b4c98f4b557f917447481872749.diff
LOG: [libc++][PSTL] Implement std::count{,_if}
Reviewed By: ldionne, #libc
Spies: libcxx-commits
Differential Revision: https://reviews.llvm.org/D150128
Added:
libcxx/include/__algorithm/pstl_count.h
libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count.pass.cpp
libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count_if.pass.cpp
Modified:
libcxx/include/CMakeLists.txt
libcxx/include/__algorithm/pstl_backend.h
libcxx/include/algorithm
libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
Removed:
################################################################################
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 910e6b1727e69..62d6ea5962527 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -85,6 +85,7 @@ set(files
__algorithm/pstl_backends/cpu_backends/transform.h
__algorithm/pstl_backends/cpu_backends/transform_reduce.h
__algorithm/pstl_copy.h
+ __algorithm/pstl_count.h
__algorithm/pstl_fill.h
__algorithm/pstl_find.h
__algorithm/pstl_for_each.h
diff --git a/libcxx/include/__algorithm/pstl_backend.h b/libcxx/include/__algorithm/pstl_backend.h
index c25a8b1d0a930..73e2c48deb081 100644
--- a/libcxx/include/__algorithm/pstl_backend.h
+++ b/libcxx/include/__algorithm/pstl_backend.h
@@ -113,6 +113,12 @@ implemented, all the algorithms will eventually forward to the basis algorithms
temlate <class _ExecutionPolicy, class _Iterator>
__iter_value_type<_Iterator> __pstl_reduce(_Backend, _Iterator __first, _Iterator __last);
+ template <class _ExecutionPolicy, class _Iterator, class _Tp>
+ __iter_diff_t<_Iterator> __pstl_count(_Backend, _Iterator __first, _Iterator __last, const _Tp& __value);
+
+ template <class _ExecutionPolicy, class _Iterator, class _Predicate>
+ __iter_diff_t<_Iterator> __pstl_count_if(_Backend, _Iterator __first, _Iterator __last, _Predicate __pred);
+
// TODO: Complete this list
*/
diff --git a/libcxx/include/__algorithm/pstl_count.h b/libcxx/include/__algorithm/pstl_count.h
new file mode 100644
index 0000000000000..7f591c99915ca
--- /dev/null
+++ b/libcxx/include/__algorithm/pstl_count.h
@@ -0,0 +1,86 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_PSTL_COUNT_H
+#define _LIBCPP___ALGORITHM_PSTL_COUNT_H
+
+#include <__algorithm/count.h>
+#include <__algorithm/for_each.h>
+#include <__algorithm/pstl_backend.h>
+#include <__algorithm/pstl_for_each.h>
+#include <__algorithm/pstl_frontend_dispatch.h>
+#include <__atomic/atomic.h>
+#include <__config>
+#include <__iterator/iterator_traits.h>
+#include <__numeric/pstl_transform_reduce.h>
+#include <__type_traits/is_execution_policy.h>
+#include <__type_traits/remove_cvref.h>
+#include <__utility/terminate_on_exception.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+# pragma GCC system_header
+#endif
+
+#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class>
+void __pstl_count_if(); // declaration needed for the frontend dispatch below
+
+template <class _ExecutionPolicy,
+ class _ForwardIterator,
+ class _Predicate,
+ class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
+ enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
+_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
+count_if(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) {
+ using __diff_t = __iter_diff_t<_ForwardIterator>;
+ return std::__pstl_frontend_dispatch(
+ _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count_if),
+ [&](_ForwardIterator __g_first, _ForwardIterator __g_last, _Predicate __g_pred) {
+ return std::transform_reduce(
+ __policy,
+ std::move(__g_first),
+ std::move(__g_last),
+ __diff_t(),
+ std::plus{},
+ [&](__iter_reference<_ForwardIterator> __element) -> bool { return __g_pred(__element); });
+ },
+ std::move(__first),
+ std::move(__last),
+ std::move(__pred));
+}
+
+template <class>
+void __pstl_count(); // declaration needed for the frontend dispatch below
+
+template <class _ExecutionPolicy,
+ class _ForwardIterator,
+ class _Tp,
+ class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
+ enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
+_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
+count(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, const _Tp& __value) {
+ return std::__pstl_frontend_dispatch(
+ _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count),
+ [&](_ForwardIterator __g_first, _ForwardIterator __g_last, const _Tp& __g_value) {
+ return std::count_if(__policy, __g_first, __g_last, [&](__iter_reference<_ForwardIterator> __v) {
+ return __v == __g_value;
+ });
+ },
+ std::move(__first),
+ std::move(__last),
+ __value);
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
+
+#endif // _LIBCPP___ALGORITHM_PSTL_COUNT_H
diff --git a/libcxx/include/algorithm b/libcxx/include/algorithm
index 24d29fd777f54..04a21c040ef82 100644
--- a/libcxx/include/algorithm
+++ b/libcxx/include/algorithm
@@ -1802,6 +1802,7 @@ template <class BidirectionalIterator, class Compare>
#include <__algorithm/prev_permutation.h>
#include <__algorithm/pstl_any_all_none_of.h>
#include <__algorithm/pstl_copy.h>
+#include <__algorithm/pstl_count.h>
#include <__algorithm/pstl_fill.h>
#include <__algorithm/pstl_find.h>
#include <__algorithm/pstl_for_each.h>
diff --git a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
index 76188e70a441b..2a634c38d425f 100644
--- a/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
+++ b/libcxx/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
@@ -42,6 +42,26 @@ bool __pstl_all_of(TestBackend, ForwardIterator, ForwardIterator, Pred) {
return true;
}
+bool pstl_count_called = false;
+
+template <class, class ForwardIterator, class T>
+typename std::iterator_traits<ForwardIterator>::difference_type
+__pstl_count(TestBackend, ForwardIterator, ForwardIterator, const T&) {
+ assert(!pstl_count_called);
+ pstl_count_called = true;
+ return 0;
+}
+
+bool pstl_count_if_called = false;
+
+template <class, class ForwardIterator, class Pred>
+typename std::iterator_traits<ForwardIterator>::difference_type
+__pstl_count_if(TestBackend, ForwardIterator, ForwardIterator, Pred) {
+ assert(!pstl_count_if_called);
+ pstl_count_if_called = true;
+ return 0;
+}
+
bool pstl_none_of_called = false;
template <class, class ForwardIterator, class Pred>
@@ -197,6 +217,10 @@ int main(int, char**) {
assert(std::pstl_all_of_called);
(void)std::none_of(TestPolicy{}, std::begin(a), std::end(a), pred);
assert(std::pstl_none_of_called);
+ (void)std::count(TestPolicy{}, std::begin(a), std::end(a), 0);
+ assert(std::pstl_count_called);
+ (void)std::count_if(TestPolicy{}, std::begin(a), std::end(a), pred);
+ assert(std::pstl_count_if_called);
(void)std::fill(TestPolicy{}, std::begin(a), std::end(a), 0);
assert(std::pstl_fill_called);
(void)std::fill_n(TestPolicy{}, std::begin(a), std::size(a), 0);
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count.pass.cpp
new file mode 100644
index 0000000000000..f00861f66bfe9
--- /dev/null
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count.pass.cpp
@@ -0,0 +1,86 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14
+
+// UNSUPPORTED: libcpp-has-no-incomplete-pstl
+
+// <algorithm>
+
+// template<class ExecutionPolicy, class ForwardIterator, class T>
+// typename iterator_traits<ForwardIterator>::difference_type
+// count(ExecutionPolicy&& exec,
+// ForwardIterator first, ForwardIterator last, const T& value);
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <vector>
+
+#include "test_macros.h"
+#include "test_execution_policies.h"
+#include "test_iterators.h"
+
+EXECUTION_POLICY_SFINAE_TEST(count);
+
+static_assert(sfinae_test_count<int, int*, int*, bool (*)(int)>);
+static_assert(!sfinae_test_count<std::execution::parallel_policy, int*, int*, int>);
+
+template <class Iter>
+struct Test {
+ template <class Policy>
+ void operator()(Policy&& policy) {
+ { // simple test
+ int a[] = {1, 2, 3, 4, 5};
+ decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 1);
+ }
+
+ { // test that an empty range works
+ std::array<int, 0> a;
+ decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 0);
+ }
+
+ { // test that a single-element range works
+ int a[] = {1};
+ decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 1);
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 1);
+ }
+
+ { // test that a two-element range works
+ int a[] = {1, 3};
+ decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 1);
+ }
+
+ { // test that a three-element range works
+ int a[] = {3, 1, 3};
+ decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 2);
+ }
+
+ { // test that a large range works
+ std::vector<int> a(100, 2);
+ decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 2);
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 100);
+ }
+ }
+};
+
+int main(int, char**) {
+ types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
+
+ return 0;
+}
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count_if.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count_if.pass.cpp
new file mode 100644
index 0000000000000..489c7a7332a6e
--- /dev/null
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count_if.pass.cpp
@@ -0,0 +1,86 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14
+
+// UNSUPPORTED: libcpp-has-no-incomplete-pstl
+
+// <algorithm>
+
+// template<class ExecutionPolicy, class ForwardIterator, class Predicate>
+// typename iterator_traits<ForwardIterator>::difference_type
+// count_if(ExecutionPolicy&& exec,
+// ForwardIterator first, ForwardIterator last, Predicate pred);
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <vector>
+
+#include "test_macros.h"
+#include "test_execution_policies.h"
+#include "test_iterators.h"
+
+EXECUTION_POLICY_SFINAE_TEST(count_if);
+
+static_assert(sfinae_test_count_if<int, int*, int*, bool (*)(int)>);
+static_assert(!sfinae_test_count_if<std::execution::parallel_policy, int*, int*, int>);
+
+template <class Iter>
+struct Test {
+ template <class Policy>
+ void operator()(Policy&& policy) {
+ { // simple test
+ int a[] = {1, 2, 3, 4, 5};
+ decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 2);
+ }
+
+ { // test that an empty range works
+ std::array<int, 0> a;
+ decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 0);
+ }
+
+ { // test that a single-element range works
+ int a[] = {1};
+ decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 1);
+ }
+
+ { // test that a two-element range works
+ int a[] = {1, 3};
+ decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 1);
+ }
+
+ { // test that a three-element range works
+ int a[] = {2, 3, 2};
+ decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 2);
+ }
+
+ { // test that a large range works
+ std::vector<int> a(100, 2);
+ decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+ static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+ assert(ret == 100);
+ }
+ }
+};
+
+int main(int, char**) {
+ types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
+
+ return 0;
+}
More information about the libcxx-commits
mailing list