[libcxx-commits] [libcxx] [libc++] Optimize {set, map}::{lower, upper}_bound (PR #161366)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Wed Oct 8 04:14:52 PDT 2025


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/161366

>From 7ee59ec131965a9489ac3ccabd0cbb4ecefc0d30 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Tue, 30 Sep 2025 13:02:09 +0200
Subject: [PATCH] [libc++] Optimize {set,map}::{lower,upper}_bound

---
 libcxx/include/__tree | 105 ++++++++++++++++++++++++++++++++----------
 libcxx/include/map    |  50 +++++++++++++-------
 libcxx/include/set    |  50 +++++++++++++-------
 3 files changed, 149 insertions(+), 56 deletions(-)

diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index ef960d481cb7b..d7d074a00f555 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -1166,32 +1166,87 @@ public:
   template <class _Key>
   _LIBCPP_HIDE_FROM_ABI size_type __count_multi(const _Key& __k) const;
 
+  template <bool _LowerBound, class _Key> // _LowerBound: true computes the lower bound, false the upper bound
+  _LIBCPP_HIDE_FROM_ABI __end_node_pointer __lower_upper_bound_unique_impl(const _Key& __v) const {
+    auto __rt     = __root(); // node currently being examined
+    auto __result = __end_node(); // candidate answer; stays the end node if the search never goes left
+    auto __comp   = __lazy_synth_three_way_comparator<_Compare, _Key, value_type>(value_comp());
+    while (__rt != nullptr) {
+      auto __comp_res = __comp(__v, __rt->__get_value());
+
+      if (__comp_res.__less()) { // __v compares less than this node: remember the node, continue left
+        __result = static_cast<__end_node_pointer>(__rt);
+        __rt     = static_cast<__node_pointer>(__rt->__left_);
+      } else if (__comp_res.__greater()) { // __v compares greater than this node: continue right
+        __rt = static_cast<__node_pointer>(__rt->__right_);
+      } else if _LIBCPP_CONSTEXPR (_LowerBound) { // neither less nor greater: this node is the lower bound
+        return static_cast<__end_node_pointer>(__rt);
+      } else { // upper bound on a match: std::__tree_min of the right subtree, else the remembered candidate
+        return __rt->__right_ ? static_cast<__end_node_pointer>(std::__tree_min(__rt->__right_)) : __result;
+      }
+    }
+    return __result;
+  }
+
+  template <class _Key>
+  _LIBCPP_HIDE_FROM_ABI iterator __lower_bound_unique(const _Key& __v) {
+    return iterator(__lower_upper_bound_unique_impl<true>(__v));
+  }
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const _Key& __v) {
-    return __lower_bound(__v, __root(), __end_node());
+  _LIBCPP_HIDE_FROM_ABI const_iterator __lower_bound_unique(const _Key& __v) const {
+    return const_iterator(__lower_upper_bound_unique_impl<true>(__v));
   }
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI iterator __lower_bound(const _Key& __v, __node_pointer __root, __end_node_pointer __result);
+  _LIBCPP_HIDE_FROM_ABI iterator __upper_bound_unique(const _Key& __v) {
+    return iterator(__lower_upper_bound_unique_impl<false>(__v));
+  }
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const _Key& __v) const {
-    return __lower_bound(__v, __root(), __end_node());
+  _LIBCPP_HIDE_FROM_ABI const_iterator __upper_bound_unique(const _Key& __v) const {
+    return const_iterator(__lower_upper_bound_unique_impl<false>(__v)); // <false> selects the upper bound
   }
+
+private:
+  template <class _Key>
+  _LIBCPP_HIDE_FROM_ABI iterator
+  __lower_bound_multi(const _Key& __v, __node_pointer __root, __end_node_pointer __result);
+
   template <class _Key>
   _LIBCPP_HIDE_FROM_ABI const_iterator
-  __lower_bound(const _Key& __v, __node_pointer __root, __end_node_pointer __result) const;
+  __lower_bound_multi(const _Key& __v, __node_pointer __root, __end_node_pointer __result) const;
+
+public:
+  template <class _Key>
+  _LIBCPP_HIDE_FROM_ABI iterator __lower_bound_multi(const _Key& __v) {
+    return __lower_bound_multi(__v, __root(), __end_node());
+  }
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const _Key& __v) {
-    return __upper_bound(__v, __root(), __end_node());
+  _LIBCPP_HIDE_FROM_ABI const_iterator __lower_bound_multi(const _Key& __v) const {
+    return __lower_bound_multi(__v, __root(), __end_node());
   }
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI iterator __upper_bound(const _Key& __v, __node_pointer __root, __end_node_pointer __result);
+  _LIBCPP_HIDE_FROM_ABI iterator __upper_bound_multi(const _Key& __v) {
+    return __upper_bound_multi(__v, __root(), __end_node());
+  }
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const _Key& __v) const {
-    return __upper_bound(__v, __root(), __end_node());
+  _LIBCPP_HIDE_FROM_ABI const_iterator __upper_bound_multi(const _Key& __v) const {
+    return __upper_bound_multi(__v, __root(), __end_node());
   }
+
+private:
+  template <class _Key>
+  _LIBCPP_HIDE_FROM_ABI iterator
+  __upper_bound_multi(const _Key& __v, __node_pointer __root, __end_node_pointer __result);
+
   template <class _Key>
   _LIBCPP_HIDE_FROM_ABI const_iterator
-  __upper_bound(const _Key& __v, __node_pointer __root, __end_node_pointer __result) const;
+  __upper_bound_multi(const _Key& __v, __node_pointer __root, __end_node_pointer __result) const;
+
+public:
   template <class _Key>
   _LIBCPP_HIDE_FROM_ABI pair<iterator, iterator> __equal_range_unique(const _Key& __k);
   template <class _Key>
@@ -2100,16 +2155,16 @@ __tree<_Tp, _Compare, _Allocator>::__count_multi(const _Key& __k) const {
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
       return std::distance(
-          __lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
-          __upper_bound(__k, static_cast<__node_pointer>(__rt->__right_), __result));
+          __lower_bound_multi(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
+          __upper_bound_multi(__k, static_cast<__node_pointer>(__rt->__right_), __result));
   }
   return 0;
 }
 
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::iterator
-__tree<_Tp, _Compare, _Allocator>::__lower_bound(const _Key& __v, __node_pointer __root, __end_node_pointer __result) {
+typename __tree<_Tp, _Compare, _Allocator>::iterator __tree<_Tp, _Compare, _Allocator>::__lower_bound_multi(
+    const _Key& __v, __node_pointer __root, __end_node_pointer __result) {
   while (__root != nullptr) {
     if (!value_comp()(__root->__get_value(), __v)) {
       __result = static_cast<__end_node_pointer>(__root);
@@ -2122,7 +2177,7 @@ __tree<_Tp, _Compare, _Allocator>::__lower_bound(const _Key& __v, __node_pointer
 
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::const_iterator __tree<_Tp, _Compare, _Allocator>::__lower_bound(
+typename __tree<_Tp, _Compare, _Allocator>::const_iterator __tree<_Tp, _Compare, _Allocator>::__lower_bound_multi(
     const _Key& __v, __node_pointer __root, __end_node_pointer __result) const {
   while (__root != nullptr) {
     if (!value_comp()(__root->__get_value(), __v)) {
@@ -2136,8 +2191,8 @@ typename __tree<_Tp, _Compare, _Allocator>::const_iterator __tree<_Tp, _Compare,
 
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::iterator
-__tree<_Tp, _Compare, _Allocator>::__upper_bound(const _Key& __v, __node_pointer __root, __end_node_pointer __result) {
+typename __tree<_Tp, _Compare, _Allocator>::iterator __tree<_Tp, _Compare, _Allocator>::__upper_bound_multi(
+    const _Key& __v, __node_pointer __root, __end_node_pointer __result) {
   while (__root != nullptr) {
     if (value_comp()(__v, __root->__get_value())) {
       __result = static_cast<__end_node_pointer>(__root);
@@ -2150,7 +2205,7 @@ __tree<_Tp, _Compare, _Allocator>::__upper_bound(const _Key& __v, __node_pointer
 
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::const_iterator __tree<_Tp, _Compare, _Allocator>::__upper_bound(
+typename __tree<_Tp, _Compare, _Allocator>::const_iterator __tree<_Tp, _Compare, _Allocator>::__upper_bound_multi(
     const _Key& __v, __node_pointer __root, __end_node_pointer __result) const {
   while (__root != nullptr) {
     if (value_comp()(__v, __root->__get_value())) {
@@ -2226,8 +2281,9 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) {
     } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
-      return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
-                 __upper_bound(__k, static_cast<__node_pointer>(__rt->__right_), __result));
+      return _Pp(
+          __lower_bound_multi(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
+          __upper_bound_multi(__k, static_cast<__node_pointer>(__rt->__right_), __result));
   }
   return _Pp(iterator(__result), iterator(__result));
 }
@@ -2249,8 +2305,9 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) const {
     } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
-      return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
-                 __upper_bound(__k, static_cast<__node_pointer>(__rt->__right_), __result));
+      return _Pp(
+          __lower_bound_multi(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
+          __upper_bound_multi(__k, static_cast<__node_pointer>(__rt->__right_), __result));
   }
   return _Pp(const_iterator(__result), const_iterator(__result));
 }
diff --git a/libcxx/include/map b/libcxx/include/map
index 035f913bd3497..3ff849afcde09 100644
--- a/libcxx/include/map
+++ b/libcxx/include/map
@@ -1300,38 +1300,48 @@ public:
   }
 #  endif // _LIBCPP_STD_VER >= 20
 
-  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.lower_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const { return __tree_.lower_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.__lower_bound_unique(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const {
+    return __tree_.__lower_bound_unique(__k);
+  }
+
+  // The transparent versions of the lookup functions use the _multi version, since a non-element key is allowed to
+  // match multiple elements.
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2,
             enable_if_t<__is_transparent_v<_Compare, _K2> || __is_transparently_comparable_v<_Compare, key_type, _K2>,
                         int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const _K2& __k) {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 
   template <typename _K2,
             enable_if_t<__is_transparent_v<_Compare, _K2> || __is_transparently_comparable_v<_Compare, key_type, _K2>,
                         int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const _K2& __k) const {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 #  endif
 
-  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.upper_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const { return __tree_.upper_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.__upper_bound_unique(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const {
+    return __tree_.__upper_bound_unique(__k);
+  }
+
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2,
             enable_if_t<__is_transparent_v<_Compare, _K2> || __is_transparently_comparable_v<_Compare, key_type, _K2>,
                         int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const _K2& __k) {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
   template <typename _K2,
             enable_if_t<__is_transparent_v<_Compare, _K2> || __is_transparently_comparable_v<_Compare, key_type, _K2>,
                         int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const _K2& __k) const {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
 #  endif
 
@@ -1871,30 +1881,38 @@ public:
   }
 #  endif // _LIBCPP_STD_VER >= 20
 
-  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.lower_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const { return __tree_.lower_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.__lower_bound_multi(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const {
+    return __tree_.__lower_bound_multi(__k);
+  }
+
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const _K2& __k) {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const _K2& __k) const {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 #  endif
 
-  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.upper_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const { return __tree_.upper_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.__upper_bound_multi(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const {
+    return __tree_.__upper_bound_multi(__k);
+  }
+
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const _K2& __k) {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const _K2& __k) const {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
 #  endif
 
diff --git a/libcxx/include/set b/libcxx/include/set
index 4203c69b55c84..59ed0155c1def 100644
--- a/libcxx/include/set
+++ b/libcxx/include/set
@@ -849,30 +849,40 @@ public:
   }
 #  endif // _LIBCPP_STD_VER >= 20
 
-  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.lower_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const { return __tree_.lower_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.__lower_bound_unique(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const {
+    return __tree_.__lower_bound_unique(__k);
+  }
+
+  // The transparent versions of the lookup functions use the _multi version, since a non-element key is allowed to
+  // match multiple elements.
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const _K2& __k) {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const _K2& __k) const {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 #  endif
 
-  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.upper_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const { return __tree_.upper_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.__upper_bound_unique(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const {
+    return __tree_.__upper_bound_unique(__k);
+  }
+
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const _K2& __k) {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const _K2& __k) const {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
 #  endif
 
@@ -1301,30 +1311,38 @@ public:
   }
 #  endif // _LIBCPP_STD_VER >= 20
 
-  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.lower_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const { return __tree_.lower_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const key_type& __k) { return __tree_.__lower_bound_multi(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const key_type& __k) const {
+    return __tree_.__lower_bound_multi(__k);
+  }
+
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator lower_bound(const _K2& __k) {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator lower_bound(const _K2& __k) const {
-    return __tree_.lower_bound(__k);
+    return __tree_.__lower_bound_multi(__k);
   }
 #  endif
 
-  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.upper_bound(__k); }
-  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const { return __tree_.upper_bound(__k); }
+  _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const key_type& __k) { return __tree_.__upper_bound_multi(__k); }
+
+  _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const key_type& __k) const {
+    return __tree_.__upper_bound_multi(__k);
+  }
+
 #  if _LIBCPP_STD_VER >= 14
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI iterator upper_bound(const _K2& __k) {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
   template <typename _K2, enable_if_t<__is_transparent_v<_Compare, _K2>, int> = 0>
   _LIBCPP_HIDE_FROM_ABI const_iterator upper_bound(const _K2& __k) const {
-    return __tree_.upper_bound(__k);
+    return __tree_.__upper_bound_multi(__k);
   }
 #  endif
 



More information about the libcxx-commits mailing list