[libcxx-commits] [libcxx] [libc++] Optimize {std, ranges}::for_each for iterating over __trees (PR #164405)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Wed Dec 10 01:26:47 PST 2025


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/164405

>From bebc3270d233804e9dbaf945ec3e8d150c27fae9 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Tue, 21 Oct 2025 14:07:25 +0200
Subject: [PATCH] [libc++] Optimize ranges::for_each for iterating over __trees

[libc++] Optimize std::for_each for __tree iterators
---
 libcxx/include/__algorithm/for_each.h         |   8 +-
 libcxx/include/__algorithm/ranges_for_each.h  |  10 +-
 .../__algorithm/specialized_algorithms.h      |   7 +
 libcxx/include/__tree                         | 104 +++++++++++
 libcxx/include/map                            |  77 ++++++++
 libcxx/include/set                            |  41 +++++
 .../nonmodifying/for_each.bench.cpp           |  63 +++++++
 .../alg.foreach/for_each.associative.pass.cpp |  78 ++++++++
 .../alg.foreach/for_each.pass.cpp             |   4 +-
 .../ranges.for_each.associative.pass copy.cpp | 168 ++++++++++++++++++
 10 files changed, 555 insertions(+), 5 deletions(-)
 create mode 100644 libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.associative.pass.cpp
 create mode 100644 libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/ranges.for_each.associative.pass copy.cpp

diff --git a/libcxx/include/__algorithm/for_each.h b/libcxx/include/__algorithm/for_each.h
index cb26aa4d2656a..85fedce3d936d 100644
--- a/libcxx/include/__algorithm/for_each.h
+++ b/libcxx/include/__algorithm/for_each.h
@@ -11,6 +11,7 @@
 #define _LIBCPP___ALGORITHM_FOR_EACH_H
 
 #include <__algorithm/for_each_segment.h>
+#include <__algorithm/specialized_algorithms.h>
 #include <__config>
 #include <__functional/identity.h>
 #include <__iterator/segmented_iterator.h>
@@ -27,7 +28,12 @@ template <class _InputIterator, class _Sent, class _Func, class _Proj>
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _InputIterator
 __for_each(_InputIterator __first, _Sent __last, _Func& __func, _Proj& __proj) {
 #ifndef _LIBCPP_CXX03_LANG
-  if constexpr (is_same<_InputIterator, _Sent>::value && __is_segmented_iterator_v<_InputIterator>) {
+  if constexpr (using _SpecialAlg =
+                    __specialized_algorithm<_Algorithm::__for_each, __iterator_pair<_InputIterator, _Sent>>;
+                _SpecialAlg::__has_algorithm) {
+    _SpecialAlg()(__first, __last, __func, __proj);
+    return __last;
+  } else if constexpr (is_same<_InputIterator, _Sent>::value && __is_segmented_iterator_v<_InputIterator>) {
     using __local_iterator_t = typename __segmented_iterator_traits<_InputIterator>::__local_iterator;
     std::__for_each_segment(__first, __last, [&](__local_iterator_t __lfirst, __local_iterator_t __llast) {
       std::__for_each(__lfirst, __llast, __func, __proj);
diff --git a/libcxx/include/__algorithm/ranges_for_each.h b/libcxx/include/__algorithm/ranges_for_each.h
index e9c84e8583f87..7a547fb269b4b 100644
--- a/libcxx/include/__algorithm/ranges_for_each.h
+++ b/libcxx/include/__algorithm/ranges_for_each.h
@@ -12,6 +12,7 @@
 #include <__algorithm/for_each.h>
 #include <__algorithm/for_each_n.h>
 #include <__algorithm/in_fun_result.h>
+#include <__algorithm/specialized_algorithms.h>
 #include <__concepts/assignable.h>
 #include <__config>
 #include <__functional/identity.h>
@@ -20,6 +21,7 @@
 #include <__ranges/access.h>
 #include <__ranges/concepts.h>
 #include <__ranges/dangling.h>
+#include <__type_traits/remove_cvref.h>
 #include <__utility/move.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -71,7 +73,13 @@ struct __for_each {
             indirectly_unary_invocable<projected<iterator_t<_Range>, _Proj>> _Func>
   _LIBCPP_HIDE_FROM_ABI constexpr for_each_result<borrowed_iterator_t<_Range>, _Func>
   operator()(_Range&& __range, _Func __func, _Proj __proj = {}) const {
-    return __for_each_impl(ranges::begin(__range), ranges::end(__range), __func, __proj);
+    using _SpecialAlg = __specialized_algorithm<_Algorithm::__for_each, __single_range<remove_cvref_t<_Range>>>;
+    if constexpr (_SpecialAlg::__has_algorithm) {
+      auto [__iter, __func2] = _SpecialAlg()(__range, std::move(__func), std::move(__proj));
+      return {std::move(__iter), std::move(__func2)}; // __func was moved from above; __func2 is the functor handed back
+    } else {
+      return __for_each_impl(ranges::begin(__range), ranges::end(__range), __func, __proj);
+    }
   }
 };
 
diff --git a/libcxx/include/__algorithm/specialized_algorithms.h b/libcxx/include/__algorithm/specialized_algorithms.h
index a2ffd36f0c87d..1d3bc8723c3f3 100644
--- a/libcxx/include/__algorithm/specialized_algorithms.h
+++ b/libcxx/include/__algorithm/specialized_algorithms.h
@@ -19,11 +19,18 @@ _LIBCPP_BEGIN_NAMESPACE_STD
 
 namespace _Algorithm {
 struct __fill_n {};
+struct __for_each {};
 } // namespace _Algorithm
 
 template <class>
 struct __single_iterator;
 
+template <class, class>
+struct __iterator_pair;
+
+template <class>
+struct __single_range;
+
 // This struct allows specializing algorithms for specific arguments. This is useful when we know a more efficient
 // algorithm implementation for e.g. library-defined iterators. _Alg is one of tags defined inside the _Algorithm
 // namespace above. _Ranges is an essentially arbitrary subset of the arguments to the algorithm that are used for
diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index ceae22bb48702..75fb950c4ed1c 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -11,6 +11,7 @@
 #define _LIBCPP___TREE
 
 #include <__algorithm/min.h>
+#include <__algorithm/specialized_algorithms.h>
 #include <__assert>
 #include <__config>
 #include <__fwd/pair.h>
@@ -36,6 +37,7 @@
 #include <__type_traits/is_swappable.h>
 #include <__type_traits/make_transparent.h>
 #include <__type_traits/remove_const.h>
+#include <__type_traits/remove_cvref.h>
 #include <__utility/forward.h>
 #include <__utility/lazy_synth_three_way_comparator.h>
 #include <__utility/move.h>
@@ -656,6 +658,50 @@ struct __generic_container_node_destructor<__tree_node<_Tp, _VoidPtr>, _Alloc> :
 };
 #endif
 
+// In-order traversal of the subtree rooted at `__root` (dereferenced, so non-null); true once `__break(node)` stops it.
+template <class _Reference, class _Break, class _NodePtr, class _Func, class _Proj>
+_LIBCPP_HIDE_FROM_ABI bool __tree_iterate_from_root(_Break __break, _NodePtr __root, _Func& __func, _Proj& __proj) {
+  if (__root->__left_) {
+    if (std::__tree_iterate_from_root<_Reference>(__break, static_cast<_NodePtr>(__root->__left_), __func, __proj))
+      return true;
+  }
+  if (__break(__root)) // tested before the call below, so the node that stops the traversal is not visited
+    return true;
+  std::__invoke(__func, std::__invoke(__proj, static_cast<_Reference>(__root->__get_value())));
+  if (__root->__right_)
+    return std::__tree_iterate_from_root<_Reference>(__break, static_cast<_NodePtr>(__root->__right_), __func, __proj);
+  return false;
+}
+
+// In-order traversal of [__first_it, __last_it), calling `__func` on the projection (`__proj`) of every element.
+template <class _NodeIter, class _Func, class _Proj>
+_LIBCPP_HIDE_FROM_ABI void
+__tree_iterate_subrange(_NodeIter __first_it, _NodeIter __last_it, _Func& __func, _Proj& __proj) {
+  using _NodePtr   = typename _NodeIter::__node_pointer;
+  using _Reference = typename _NodeIter::reference;
+
+  auto __first = __first_it.__ptr_;
+  auto __last  = __last_it.__ptr_;
+
+  while (true) {
+    if (__first == __last)
+      return;
+    const auto __nfirst = static_cast<_NodePtr>(__first);
+    std::__invoke(__func, std::__invoke(__proj, static_cast<_Reference>(__nfirst->__get_value())));
+    if (__nfirst->__right_) { // the right subtree is visited as a whole, stopping early if it contains __last
+      if (std::__tree_iterate_from_root<_Reference>(
+              [&](_NodePtr __node) -> bool { return __node == __last; },
+              static_cast<_NodePtr>(__nfirst->__right_),
+              __func,
+              __proj))
+        return;
+    }
+    while (!std::__tree_is_left_child(static_cast<_NodePtr>(__first))) // climb until we arrive from a left child
+      __first = static_cast<_NodePtr>(__first)->__parent_;
+    __first = static_cast<_NodePtr>(__first)->__parent_;
+  }
+}
+
 template <class _Tp, class _NodePtr, class _DiffType>
 class __tree_iterator {
   using _NodeTypes _LIBCPP_NODEBUG = __tree_node_types<_NodePtr>;
@@ -715,8 +761,28 @@ private:
   friend class __tree;
   template <class, class, class>
   friend class __tree_const_iterator;
+
+  template <class _NodeIter, class _Func, class _Proj>
+  friend void __tree_iterate_subrange(_NodeIter, _NodeIter, _Func&, _Proj&);
 };
 
+#ifndef _LIBCPP_CXX03_LANG
+// This also handles {multi,}set::iterator, since they're just aliases to __tree::iterator
+template <class _Tp, class _NodePtr, class _DiffType>
+struct __specialized_algorithm<
+    _Algorithm::__for_each,
+    __iterator_pair<__tree_iterator<_Tp, _NodePtr, _DiffType>, __tree_iterator<_Tp, _NodePtr, _DiffType>>> {
+  static const bool __has_algorithm = true;
+
+  using __iterator _LIBCPP_NODEBUG = __tree_iterator<_Tp, _NodePtr, _DiffType>;
+
+  template <class _Func, class _Proj>
+  _LIBCPP_HIDE_FROM_ABI static void operator()(__iterator __first, __iterator __last, _Func& __func, _Proj& __proj) {
+    std::__tree_iterate_subrange(__first, __last, __func, __proj);
+  }
+};
+#endif
+
 template <class _Tp, class _NodePtr, class _DiffType>
 class __tree_const_iterator {
   using _NodeTypes _LIBCPP_NODEBUG = __tree_node_types<_NodePtr>;
@@ -780,7 +846,27 @@ private:
 
   template <class, class, class>
   friend class __tree;
+
+  template <class _NodeIter, class _Func, class _Proj>
+  friend void __tree_iterate_subrange(_NodeIter, _NodeIter, _Func&, _Proj&);
+};
+
+#ifndef _LIBCPP_CXX03_LANG
+// This also handles {multi,}set::const_iterator, since they're just aliases to __tree::const_iterator
+template <class _Tp, class _NodePtr, class _DiffType>
+struct __specialized_algorithm<
+    _Algorithm::__for_each,
+    __iterator_pair<__tree_const_iterator<_Tp, _NodePtr, _DiffType>, __tree_const_iterator<_Tp, _NodePtr, _DiffType>>> {
+  static const bool __has_algorithm = true;
+
+  using __iterator _LIBCPP_NODEBUG = __tree_const_iterator<_Tp, _NodePtr, _DiffType>;
+
+  template <class _Func, class _Proj>
+  _LIBCPP_HIDE_FROM_ABI static void operator()(__iterator __first, __iterator __last, _Func& __func, _Proj& __proj) {
+    std::__tree_iterate_subrange(__first, __last, __func, __proj);
+  }
 };
+#endif
 
 template <class _Tp, class _Compare>
 #ifndef _LIBCPP_CXX03_LANG
@@ -1484,7 +1570,25 @@ private:
         [](value_type& __lhs, value_type& __rhs) { __assign_value(__lhs, std::move(__rhs)); },
         [this](__node_pointer __nd) { return __move_construct_tree(__nd); });
   }
+
+  friend struct __specialized_algorithm<_Algorithm::__for_each, __single_range<__tree> >;
+};
+
+#if _LIBCPP_STD_VER >= 14
+template <class _Tp, class _Compare, class _Allocator>
+struct __specialized_algorithm<_Algorithm::__for_each, __single_range<__tree<_Tp, _Compare, _Allocator> > > {
+  static const bool __has_algorithm = true;
+  using __node_pointer _LIBCPP_NODEBUG = typename __tree<_Tp, _Compare, _Allocator>::__node_pointer;
+
+  template <class _Tree, class _Func, class _Proj>
+  _LIBCPP_HIDE_FROM_ABI static auto operator()(_Tree&& __range, _Func __func, _Proj __proj) {
+    if (__range.size() != 0) // __tree_iterate_from_root dereferences its root unconditionally; skip empty trees
+      std::__tree_iterate_from_root<__copy_cvref_t<_Tree, typename __remove_cvref_t<_Tree>::value_type>>(
+          [](__node_pointer) { return false; }, __range.__root(), __func, __proj);
+    return std::make_pair(__range.end(), std::move(__func));
+  }
 };
+#endif
 
 // Precondition:  __size_ != 0
 template <class _Tp, class _Compare, class _Allocator>
diff --git a/libcxx/include/map b/libcxx/include/map
index 0dca11cabd12e..1ec36c62c3773 100644
--- a/libcxx/include/map
+++ b/libcxx/include/map
@@ -577,6 +577,7 @@ erase_if(multimap<Key, T, Compare, Allocator>& c, Predicate pred);  // C++20
 #  include <__algorithm/equal.h>
 #  include <__algorithm/lexicographical_compare.h>
 #  include <__algorithm/lexicographical_compare_three_way.h>
+#  include <__algorithm/specialized_algorithms.h>
 #  include <__assert>
 #  include <__config>
 #  include <__functional/binary_function.h>
@@ -818,7 +819,26 @@ public:
   friend class multimap;
   template <class>
   friend class __map_const_iterator;
+
+  template <class, class...>
+  friend struct __specialized_algorithm;
+};
+
+#  ifndef _LIBCPP_CXX03_LANG
+template <class _Alg, class _TreeIterator>
+struct __specialized_algorithm<_Alg, __iterator_pair<__map_iterator<_TreeIterator>, __map_iterator<_TreeIterator>>> {
+  using __base _LIBCPP_NODEBUG = __specialized_algorithm<_Alg, __iterator_pair<_TreeIterator, _TreeIterator>>;
+
+  static const bool __has_algorithm = __base::__has_algorithm;
+
+  using __iterator _LIBCPP_NODEBUG = __map_iterator<_TreeIterator>;
+
+  template <class... _Args>
+  _LIBCPP_HIDE_FROM_ABI static void operator()(__iterator __first, __iterator __last, _Args&&... __args) {
+    __base()(__first.__i_, __last.__i_, std::forward<_Args>(__args)...);
+  }
 };
+#  endif
 
 template <class _TreeIterator>
 class __map_const_iterator {
@@ -873,7 +893,28 @@ public:
   friend class multimap;
   template <class, class, class>
   friend class __tree_const_iterator;
+
+  template <class, class...>
+  friend struct __specialized_algorithm;
+};
+
+#  ifndef _LIBCPP_CXX03_LANG
+template <class _Alg, class _TreeIterator>
+struct __specialized_algorithm<
+    _Alg,
+    __iterator_pair<__map_const_iterator<_TreeIterator>, __map_const_iterator<_TreeIterator>>> {
+  using __base _LIBCPP_NODEBUG = __specialized_algorithm<_Alg, __iterator_pair<_TreeIterator, _TreeIterator>>;
+
+  static const bool __has_algorithm = __base::__has_algorithm;
+
+  using __iterator _LIBCPP_NODEBUG = __map_const_iterator<_TreeIterator>;
+
+  template <class... _Args>
+  _LIBCPP_HIDE_FROM_ABI static void operator()(__iterator __first, __iterator __last, _Args&&... __args) {
+    __base()(__first.__i_, __last.__i_, std::forward<_Args>(__args)...);
+  }
 };
+#  endif
 
 template <class _Key, class _Tp, class _Compare = less<_Key>, class _Allocator = allocator<pair<const _Key, _Tp> > >
 class multimap;
@@ -1370,6 +1411,8 @@ private:
 #  ifdef _LIBCPP_CXX03_LANG
   _LIBCPP_HIDE_FROM_ABI __node_holder __construct_node_with_key(const key_type& __k);
 #  endif
+
+  friend struct __specialized_algorithm<_Algorithm::__for_each, __single_range<map> >;
 };
 
 #  if _LIBCPP_STD_VER >= 17
@@ -1422,6 +1465,22 @@ map(initializer_list<pair<_Key, _Tp>>, _Allocator)
     -> map<remove_const_t<_Key>, _Tp, less<remove_const_t<_Key>>, _Allocator>;
 #  endif
 
+#  if _LIBCPP_STD_VER >= 14
+template <class _Key, class _Tp, class _Compare, class _Allocator>
+struct __specialized_algorithm<_Algorithm::__for_each, __single_range<map<_Key, _Tp, _Compare, _Allocator>>> {
+  using __map _LIBCPP_NODEBUG = map<_Key, _Tp, _Compare, _Allocator>;
+
+  static const bool __has_algorithm = true;
+
+  template <class _Map, class _Func, class _Proj>
+  _LIBCPP_HIDE_FROM_ABI static auto operator()(_Map&& __map, _Func __func, _Proj __proj) {
+    auto [_, __func2] = __specialized_algorithm<_Algorithm::__for_each, __single_range<typename __map::__base>>()(
+        __map.__tree_, std::move(__func), std::move(__proj));
+    return std::make_pair(__map.end(), std::move(__func2));
+  }
+};
+#  endif
+
 #  ifndef _LIBCPP_CXX03_LANG
 template <class _Key, class _Tp, class _Compare, class _Allocator>
 _Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) {
@@ -1920,6 +1979,8 @@ private:
 
   typedef __map_node_destructor<__node_allocator> _Dp;
   typedef unique_ptr<__node, _Dp> __node_holder;
+
+  friend struct __specialized_algorithm<_Algorithm::__for_each, __single_range<multimap> >;
 };
 
 #  if _LIBCPP_STD_VER >= 17
@@ -1972,6 +2033,22 @@ multimap(initializer_list<pair<_Key, _Tp>>, _Allocator)
     -> multimap<remove_const_t<_Key>, _Tp, less<remove_const_t<_Key>>, _Allocator>;
 #  endif
 
+#  if _LIBCPP_STD_VER >= 14
+template <class _Key, class _Tp, class _Compare, class _Allocator>
+struct __specialized_algorithm<_Algorithm::__for_each, __single_range<multimap<_Key, _Tp, _Compare, _Allocator>>> {
+  using __map _LIBCPP_NODEBUG = multimap<_Key, _Tp, _Compare, _Allocator>;
+
+  static const bool __has_algorithm = true;
+
+  template <class _Map, class _Func, class _Proj>
+  _LIBCPP_HIDE_FROM_ABI static auto operator()(_Map&& __map, _Func __func, _Proj __proj) {
+    auto [_, __func2] = __specialized_algorithm<_Algorithm::__for_each, __single_range<typename __map::__base>>()(
+        __map.__tree_, std::move(__func), std::move(__proj));
+    return std::make_pair(__map.end(), std::move(__func2));
+  }
+};
+#  endif
+
 template <class _Key, class _Tp, class _Compare, class _Allocator>
 inline _LIBCPP_HIDE_FROM_ABI bool
 operator==(const multimap<_Key, _Tp, _Compare, _Allocator>& __x, const multimap<_Key, _Tp, _Compare, _Allocator>& __y) {
diff --git a/libcxx/include/set b/libcxx/include/set
index 3d6f571a42a1a..bd6e139f6e075 100644
--- a/libcxx/include/set
+++ b/libcxx/include/set
@@ -518,6 +518,7 @@ erase_if(multiset<Key, Compare, Allocator>& c, Predicate pred);  // C++20
 #  include <__algorithm/equal.h>
 #  include <__algorithm/lexicographical_compare.h>
 #  include <__algorithm/lexicographical_compare_three_way.h>
+#  include <__algorithm/specialized_algorithms.h>
 #  include <__assert>
 #  include <__config>
 #  include <__functional/is_transparent.h>
@@ -898,6 +899,9 @@ public:
     return __tree_.__equal_range_multi(__k);
   }
 #  endif
+
+  template <class, class...>
+  friend struct __specialized_algorithm;
 };
 
 #  if _LIBCPP_STD_VER >= 17
@@ -944,6 +948,23 @@ template <class _Key, class _Allocator, class = enable_if_t<__is_allocator_v<_Al
 set(initializer_list<_Key>, _Allocator) -> set<_Key, less<_Key>, _Allocator>;
 #  endif
 
+#  if _LIBCPP_STD_VER >= 14
+template <class _Alg, class _Key, class _Compare, class _Allocator>
+struct __specialized_algorithm<_Alg, __single_range<set<_Key, _Compare, _Allocator>>> {
+  using __set _LIBCPP_NODEBUG = set<_Key, _Compare, _Allocator>;
+
+  static const bool __has_algorithm =
+      __specialized_algorithm<_Alg, __single_range<typename __set::__base>>::__has_algorithm;
+
+  // set's begin() and end() are identical with and without const qualification
+  template <class... _Args>
+  _LIBCPP_HIDE_FROM_ABI static auto operator()(const __set& __set, _Args&&... __args) {
+    return __specialized_algorithm<_Alg, __single_range<typename __set::__base>>()(
+        __set.__tree_, std::forward<_Args>(__args)...);
+  }
+};
+#  endif
+
 template <class _Key, class _Compare, class _Allocator>
 inline _LIBCPP_HIDE_FROM_ABI bool
 operator==(const set<_Key, _Compare, _Allocator>& __x, const set<_Key, _Compare, _Allocator>& __y) {
@@ -1342,6 +1363,9 @@ public:
     return __tree_.__equal_range_multi(__k);
   }
 #  endif
+
+  template <class, class...>
+  friend struct __specialized_algorithm;
 };
 
 #  if _LIBCPP_STD_VER >= 17
@@ -1389,6 +1413,23 @@ template <class _Key, class _Allocator, class = enable_if_t<__is_allocator_v<_Al
 multiset(initializer_list<_Key>, _Allocator) -> multiset<_Key, less<_Key>, _Allocator>;
 #  endif
 
+#  if _LIBCPP_STD_VER >= 14
+template <class _Alg, class _Key, class _Compare, class _Allocator>
+struct __specialized_algorithm<_Alg, __single_range<multiset<_Key, _Compare, _Allocator>>> {
+  using __set _LIBCPP_NODEBUG = multiset<_Key, _Compare, _Allocator>;
+
+  static const bool __has_algorithm =
+      __specialized_algorithm<_Alg, __single_range<typename __set::__base>>::__has_algorithm;
+
+  // set's begin() and end() are identical with and without const qualification
+  template <class... _Args>
+  _LIBCPP_HIDE_FROM_ABI static auto operator()(const __set& __set, _Args&&... __args) {
+    return __specialized_algorithm<_Alg, __single_range<typename __set::__base>>()(
+        __set.__tree_, std::forward<_Args>(__args)...);
+  }
+};
+#  endif
+
 template <class _Key, class _Compare, class _Allocator>
 inline _LIBCPP_HIDE_FROM_ABI bool
 operator==(const multiset<_Key, _Compare, _Allocator>& __x, const multiset<_Key, _Compare, _Allocator>& __y) {
diff --git a/libcxx/test/benchmarks/algorithms/nonmodifying/for_each.bench.cpp b/libcxx/test/benchmarks/algorithms/nonmodifying/for_each.bench.cpp
index f58f336f8b892..a1a3d07e27978 100644
--- a/libcxx/test/benchmarks/algorithms/nonmodifying/for_each.bench.cpp
+++ b/libcxx/test/benchmarks/algorithms/nonmodifying/for_each.bench.cpp
@@ -10,10 +10,13 @@
 
 #include <algorithm>
 #include <cstddef>
+#include <cstdint>
 #include <deque>
 #include <list>
+#include <map>
 #include <ranges>
 #include <string>
+#include <type_traits>
 #include <vector>
 
 #include <benchmark/benchmark.h>
@@ -52,6 +55,66 @@ int main(int argc, char** argv) {
     bm.operator()<std::list<int>>("rng::for_each(list<int>)", std::ranges::for_each);
   }
 
+  // {std,ranges}::for_each for associative containers
+  {
+    auto iterator_bm = []<class Container, bool IsMapLike>(
+                           std::type_identity<Container>, std::bool_constant<IsMapLike>, std::string name) {
+      benchmark::RegisterBenchmark(
+          name,
+          [](auto& st) {
+            Container c;
+            for (std::int64_t i = 0; i != st.range(0); ++i) {
+              if constexpr (IsMapLike)
+                c.emplace(i, i);
+              else
+                c.emplace(i);
+            }
+
+            for (auto _ : st) {
+              benchmark::DoNotOptimize(c);
+              std::for_each(c.begin(), c.end(), [](auto v) { benchmark::DoNotOptimize(&v); });
+            }
+          })
+          ->Arg(8)
+          ->Arg(32)
+          ->Arg(50) // non power-of-two
+          ->Arg(8192);
+    };
+    iterator_bm(std::type_identity<std::set<int>>{}, std::false_type{}, "std::for_each(set<int>::iterator)");
+    iterator_bm(std::type_identity<std::multiset<int>>{}, std::false_type{}, "std::for_each(multiset<int>::iterator)");
+    iterator_bm(std::type_identity<std::map<int, int>>{}, std::true_type{}, "std::for_each(map<int>::iterator)");
+    iterator_bm(
+        std::type_identity<std::multimap<int, int>>{}, std::true_type{}, "std::for_each(multimap<int>::iterator)");
+
+    auto container_bm = []<class Container, bool IsMapLike>(
+                            std::type_identity<Container>, std::bool_constant<IsMapLike>, std::string name) {
+      benchmark::RegisterBenchmark(
+          name,
+          [](auto& st) {
+            Container c;
+            for (std::int64_t i = 0; i != st.range(0); ++i) {
+              if constexpr (IsMapLike)
+                c.emplace(i, i);
+              else
+                c.emplace(i);
+            }
+
+            for (auto _ : st) {
+              benchmark::DoNotOptimize(c);
+              std::ranges::for_each(c, [](auto v) { benchmark::DoNotOptimize(&v); });
+            }
+          })
+          ->Arg(8)
+          ->Arg(32)
+          ->Arg(50) // non power-of-two
+          ->Arg(8192);
+    };
+    container_bm(std::type_identity<std::set<int>>{}, std::false_type{}, "rng::for_each(set<int>)");
+    container_bm(std::type_identity<std::multiset<int>>{}, std::false_type{}, "rng::for_each(multiset<int>)");
+    container_bm(std::type_identity<std::map<int, int>>{}, std::true_type{}, "rng::for_each(map<int>)");
+    container_bm(std::type_identity<std::multimap<int, int>>{}, std::true_type{}, "rng::for_each(multimap<int>)");
+  }
+
   // {std,ranges}::for_each for join_view
   {
     auto bm = []<class Container>(std::string name, auto for_each) {
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.associative.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.associative.pass.cpp
new file mode 100644
index 0000000000000..e976ffaaddfbc
--- /dev/null
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.associative.pass.cpp
@@ -0,0 +1,78 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// <algorithm>
+
+// Check that the special implementation of std::for_each for the associative container iterators works as expected
+
+// template<InputIterator Iter, class Function>
+//   constexpr Function   // constexpr since C++20
+//   for_each(Iter first, Iter last, Function f);
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <set>
+
+template <class Container, class Converter>
+void test_node_container(Converter conv) {
+  Container c;
+  using value_type = typename Container::value_type;
+  for (int i = 0; i != 10; ++i)
+    c.insert(conv(i));
+  { // Make sure that a start within the container works as expected
+    for (int i = 0; i != 10; ++i) {
+      int invoke_count = i;
+      std::for_each(std::next(c.begin(), i), c.end(), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 10);
+    }
+  }
+  { // Make sure that an end within the container works as expected
+    for (int i = 0; i != 10; ++i) {
+      int invoke_count = 0;
+      std::for_each(c.begin(), std::prev(c.end(), i), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 10 - i);
+    }
+  }
+  {   // Make sure that an empty range works
+    { // With an element as the pointee
+      int invoke_count = 0;
+      std::for_each(c.begin(), c.begin(), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 0);
+    }
+    { // With no element as the pointee
+      int invoke_count = 0;
+      std::for_each(c.end(), c.end(), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 0);
+    }
+  }
+  { // Make sure that a single-element range works
+    int invoke_count = 0;
+    std::for_each(c.begin(), std::next(c.begin()), [&c, &invoke_count](const value_type& i) {
+      assert(&i == &*std::next(c.begin(), invoke_count++));
+    });
+    assert(invoke_count == 1);
+  }
+}
+
+int main(int, char**) {
+  test_node_container<std::set<int> >([](int i) { return i; });
+  test_node_container<std::multiset<int> >([](int i) { return i; });
+  test_node_container<std::map<int, int> >([](int i) { return std::make_pair(i, i); });
+  test_node_container<std::multimap<int, int> >([](int i) { return std::make_pair(i, i); });
+
+  return 0;
+}
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.pass.cpp
index 3db0bde75abd7..3c0ff75fc56c7 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/for_each.pass.cpp
@@ -15,9 +15,7 @@
 #include <algorithm>
 #include <cassert>
 #include <deque>
-#if __has_include(<ranges>)
-#  include <ranges>
-#endif
+#include <ranges>
 #include <vector>
 
 #include "test_macros.h"
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/ranges.for_each.associative.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/ranges.for_each.associative.pass.cpp
new file mode 100644
index 0000000000000..b78adcc461ed1
--- /dev/null
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.foreach/ranges.for_each.associative.pass.cpp
@@ -0,0 +1,170 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// <algorithm>
+
+// UNSUPPORTED: c++03, c++11, c++14, c++17
+
+// Check that the special implementation of ranges::for_each for the associative container iterators works as expected
+
+// template<input_iterator I, sentinel_for<I> S, class Proj = identity,
+//          indirectly_unary_invocable<projected<I, Proj>> Fun>
+//   constexpr ranges::for_each_result<I, Fun>
+//     ranges::for_each(I first, S last, Fun f, Proj proj = {});
+// template<input_range R, class Proj = identity,
+//          indirectly_unary_invocable<projected<iterator_t<R>, Proj>> Fun>
+//   constexpr ranges::for_each_result<borrowed_iterator_t<R>, Fun>
+//     ranges::for_each(R&& r, Fun f, Proj proj = {});
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <set>
+
+template <class Container, class Converter>
+void test_node_container(Converter conv) {
+  using value_type = typename Container::value_type;
+
+  { // Check that an empty container works
+    Container c;
+    int invoke_count = 0;
+    std::ranges::for_each(c.begin(), c.end(), [&c, &invoke_count](const value_type& i) {
+      assert(&i == &*std::next(c.begin(), invoke_count++));
+    });
+    assert(invoke_count == 0);
+  }
+  { // Check that a single-element container works
+    Container c;
+    c.insert(conv(0));
+    int invoke_count = 0;
+    std::ranges::for_each(c.begin(), c.end(), [&c, &invoke_count](const value_type& i) {
+      assert(&i == &*std::next(c.begin(), invoke_count++));
+    });
+    assert(invoke_count == 1);
+  }
+  { // Check that a two-element container works
+    Container c;
+    c.insert(conv(0));
+    c.insert(conv(1));
+    int invoke_count = 0;
+    std::ranges::for_each(c.begin(), c.end(), [&c, &invoke_count](const value_type& i) {
+      assert(&i == &*std::next(c.begin(), invoke_count++));
+    });
+    assert(invoke_count == 2);
+  }
+
+  Container c;
+  for (int i = 0; i != 10; ++i)
+    c.insert(conv(i));
+
+  { // Simple check
+    {
+      int invoke_count = 0;
+      std::ranges::for_each(c.begin(), c.end(), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 10);
+    }
+    {
+      int invoke_count = 0;
+      std::ranges::for_each(c, [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 10);
+    }
+  }
+  { // Make sure that a start within the container works as expected
+    {
+      int invoke_count = 1;
+      std::ranges::for_each(std::next(c.begin()), c.end(), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 10);
+    }
+    {
+      int invoke_count = 1;
+      std::ranges::for_each(
+          std::ranges::subrange(std::next(c.begin()), c.end()),
+          [&c, &invoke_count](const value_type& i) { assert(&i == &*std::next(c.begin(), invoke_count++)); });
+      assert(invoke_count == 10);
+    }
+  }
+  { // Make sure that a start further within the container (two elements in) works as expected
+    {
+      int invoke_count = 2;
+      std::ranges::for_each(std::next(c.begin(), 2), c.end(), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 10);
+    }
+    {
+      int invoke_count = 2;
+      std::ranges::for_each(
+          std::ranges::subrange(std::next(c.begin(), 2), c.end()),
+          [&c, &invoke_count](const value_type& i) { assert(&i == &*std::next(c.begin(), invoke_count++)); });
+      assert(invoke_count == 10);
+    }
+  }
+  { // Make sure that an end within the container works as expected
+    {
+      int invoke_count = 1;
+      std::ranges::for_each(std::next(c.begin()), std::prev(c.end()), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 9);
+    }
+    {
+      int invoke_count = 1;
+      std::ranges::for_each(
+          std::ranges::subrange(std::next(c.begin()), std::prev(c.end())),
+          [&c, &invoke_count](const value_type& i) { assert(&i == &*std::next(c.begin(), invoke_count++)); });
+      assert(invoke_count == 9);
+    }
+  }
+  { // Make sure that an empty range works
+    {
+      int invoke_count = 0;
+      std::ranges::for_each(c.begin(), c.begin(), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 0);
+    }
+    {
+      int invoke_count = 0;
+      std::ranges::for_each(std::ranges::subrange(c.begin(), c.begin()), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 0);
+    }
+  }
+  { // Make sure that a single-element range works
+    {
+      int invoke_count = 0;
+      std::ranges::for_each(c.begin(), std::next(c.begin()), [&c, &invoke_count](const value_type& i) {
+        assert(&i == &*std::next(c.begin(), invoke_count++));
+      });
+      assert(invoke_count == 1);
+    }
+    {
+      int invoke_count = 0;
+      std::ranges::for_each(
+          std::ranges::subrange(c.begin(), std::next(c.begin())),
+          [&c, &invoke_count](const value_type& i) { assert(&i == &*std::next(c.begin(), invoke_count++)); });
+      assert(invoke_count == 1);
+    }
+  }
+}
+
+int main(int, char**) {
+  test_node_container<std::set<int> >([](int i) { return i; });
+  test_node_container<std::multiset<int> >([](int i) { return i; });
+  test_node_container<std::map<int, int> >([](int i) { return std::make_pair(i, i); });
+  test_node_container<std::multimap<int, int> >([](int i) { return std::make_pair(i, i); });
+
+  return 0;
+}



More information about the libcxx-commits mailing list