[libcxx-commits] [libcxx] [libc++] Optimize most of the __tree search algorithms (PR #155245)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Fri Aug 29 02:47:27 PDT 2025


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/155245

>From 9a4316bc7e0954bcb653195bd1462f641ecaf695 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Fri, 22 Aug 2025 10:39:37 +0200
Subject: [PATCH] [libc++] Optimize __tree::__find_equal

---
 libcxx/include/CMakeLists.txt                 |   1 +
 libcxx/include/__tree                         |  45 ++++---
 .../include/__utility/three_way_comparator.h  | 114 ++++++++++++++++++
 libcxx/include/map                            |  42 +++++++
 libcxx/include/module.modulemap.in            |   1 +
 libcxx/include/string                         |  17 ++-
 6 files changed, 204 insertions(+), 16 deletions(-)
 create mode 100644 libcxx/include/__utility/three_way_comparator.h

diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 6fd16419f0c49..284465a50d4aa 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -931,6 +931,7 @@ set(files
   __utility/scope_guard.h
   __utility/small_buffer.h
   __utility/swap.h
+  __utility/three_way_comparator.h
   __utility/to_underlying.h
   __utility/try_key_extraction.h
   __utility/unreachable.h
diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index 3ad2129ba9ddf..f9ecefbe61e65 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -37,6 +37,7 @@
 #include <__utility/move.h>
 #include <__utility/pair.h>
 #include <__utility/swap.h>
+#include <__utility/three_way_comparator.h>
 #include <__utility/try_key_extraction.h>
 #include <limits>
 
@@ -1696,8 +1697,12 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, co
   __node_pointer __nd           = __root();
   __node_base_pointer* __nd_ptr = __root_ptr();
   if (__nd != nullptr) {
+    auto __comp = __lazy_synth_three_way_comparator<_Compare, _Key, value_type>(value_comp());
+
     while (true) {
-      if (value_comp()(__v, __nd->__get_value())) {
+      auto __comp_res = __comp(__v, __nd->__get_value());
+
+      if (__comp_res.__less()) {
         if (__nd->__left_ != nullptr) {
           __nd_ptr = std::addressof(__nd->__left_);
           __nd     = static_cast<__node_pointer>(__nd->__left_);
@@ -1705,7 +1710,7 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, co
           __parent = static_cast<__end_node_pointer>(__nd);
           return __parent->__left_;
         }
-      } else if (value_comp()(__nd->__get_value(), __v)) {
+      } else if (__comp_res.__greater()) {
         if (__nd->__right_ != nullptr) {
           __nd_ptr = std::addressof(__nd->__right_);
           __nd     = static_cast<__node_pointer>(__nd->__right_);
@@ -2030,10 +2035,12 @@ template <class _Key>
 typename __tree<_Tp, _Compare, _Allocator>::size_type
 __tree<_Tp, _Compare, _Allocator>::__count_unique(const _Key& __k) const {
   __node_pointer __rt = __root();
+  auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
   while (__rt != nullptr) {
-    if (value_comp()(__k, __rt->__get_value())) {
+    auto __comp_res = __comp(__k, __rt->__get_value());
+    if (__comp_res.__less()) {
       __rt = static_cast<__node_pointer>(__rt->__left_);
-    } else if (value_comp()(__rt->__get_value(), __k))
+    } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
       return 1;
@@ -2047,11 +2054,13 @@ typename __tree<_Tp, _Compare, _Allocator>::size_type
 __tree<_Tp, _Compare, _Allocator>::__count_multi(const _Key& __k) const {
   __end_node_pointer __result = __end_node();
   __node_pointer __rt         = __root();
+  auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
   while (__rt != nullptr) {
-    if (value_comp()(__k, __rt->__get_value())) {
+    auto __comp_res = __comp(__k, __rt->__get_value());
+    if (__comp_res.__less()) {
       __result = static_cast<__end_node_pointer>(__rt);
       __rt     = static_cast<__node_pointer>(__rt->__left_);
-    } else if (value_comp()(__rt->__get_value(), __k))
+    } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
       return std::distance(
@@ -2124,11 +2133,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_unique(const _Key& __k) {
   typedef pair<iterator, iterator> _Pp;
   __end_node_pointer __result = __end_node();
   __node_pointer __rt         = __root();
+  auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
   while (__rt != nullptr) {
-    if (value_comp()(__k, __rt->__get_value())) {
+    auto __comp_res = __comp(__k, __rt->__get_value());
+    if (__comp_res.__less()) {
       __result = static_cast<__end_node_pointer>(__rt);
       __rt     = static_cast<__node_pointer>(__rt->__left_);
-    } else if (value_comp()(__rt->__get_value(), __k))
+    } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
       return _Pp(iterator(__rt),
@@ -2146,11 +2157,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_unique(const _Key& __k) const {
   typedef pair<const_iterator, const_iterator> _Pp;
   __end_node_pointer __result = __end_node();
   __node_pointer __rt         = __root();
+  auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
   while (__rt != nullptr) {
-    if (value_comp()(__k, __rt->__get_value())) {
+    auto __comp_res = __comp(__k, __rt->__get_value());
+    if (__comp_res.__less()) {
       __result = static_cast<__end_node_pointer>(__rt);
       __rt     = static_cast<__node_pointer>(__rt->__left_);
-    } else if (value_comp()(__rt->__get_value(), __k))
+    } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
       return _Pp(
@@ -2168,11 +2181,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) {
   typedef pair<iterator, iterator> _Pp;
   __end_node_pointer __result = __end_node();
   __node_pointer __rt     = __root();
+  auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
   while (__rt != nullptr) {
-    if (value_comp()(__k, __rt->__get_value())) {
+    auto __comp_res = __comp(__k, __rt->__get_value());
+    if (__comp_res.__less()) {
       __result = static_cast<__end_node_pointer>(__rt);
       __rt     = static_cast<__node_pointer>(__rt->__left_);
-    } else if (value_comp()(__rt->__get_value(), __k))
+    } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
       return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
@@ -2189,11 +2204,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) const {
   typedef pair<const_iterator, const_iterator> _Pp;
   __end_node_pointer __result = __end_node();
   __node_pointer __rt     = __root();
+  auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
   while (__rt != nullptr) {
-    if (value_comp()(__k, __rt->__get_value())) {
+    auto __comp_res = __comp(__k, __rt->__get_value());
+    if (__comp_res.__less()) {
       __result = static_cast<__end_node_pointer>(__rt);
       __rt     = static_cast<__node_pointer>(__rt->__left_);
-    } else if (value_comp()(__rt->__get_value(), __k))
+    } else if (__comp_res.__greater())
       __rt = static_cast<__node_pointer>(__rt->__right_);
     else
       return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
diff --git a/libcxx/include/__utility/three_way_comparator.h b/libcxx/include/__utility/three_way_comparator.h
new file mode 100644
index 0000000000000..c34023328b5c3
--- /dev/null
+++ b/libcxx/include/__utility/three_way_comparator.h
@@ -0,0 +1,114 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
+#define _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
+
+#include <__config>
+#include <__type_traits/desugars_to.h>
+#include <__type_traits/enable_if.h>
+#include <__type_traits/is_arithmetic.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+// This file adds a __lazy_synth_three_way_comparator, which tries to build an efficient three way comparison from a
+// binary comparator. That is done in multiple steps:
+// 1) Check whether the comparator desugars to a less than operator
+//    If that is the case, check whether there exists a specialization of `__default_three_way_comparator`, which
+//    can be specialized to implement a three way comparator for the specific types.
+// 2) Fall back to doing a lazy less than/greater than comparison
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+// This struct can be specialized to provide a three way comparator between _LHS and _RHS.
+// The return value should be
+// - less than zero if (lhs_val < rhs_val)
+// - greater than zero if (rhs_val < lhs_val)
+// - zero otherwise
+template <class _LHS, class _RHS, class = void>
+struct __default_three_way_comparator;
+
+template <class _Tp>
+struct __default_three_way_comparator<_Tp, _Tp, __enable_if_t<is_arithmetic<_Tp>::value> > {
+  _LIBCPP_HIDE_FROM_ABI static int operator()(_Tp __lhs, _Tp __rhs) {
+    if (__lhs < __rhs)
+      return -1;
+    if (__lhs > __rhs)
+      return 1;
+    return 0;
+  }
+};
+
+template <class _LHS, class _RHS, bool = true>
+inline const bool __has_default_three_way_comparator_v = false;
+
+template <class _LHS, class _RHS>
+inline const bool
+    __has_default_three_way_comparator_v< _LHS, _RHS, sizeof(__default_three_way_comparator<_LHS, _RHS>) >= 0> = true;
+
+template <class _Comparator, class _LHS, class _RHS>
+struct __lazy_compare_result {
+  const _Comparator& __comp_;
+  const _LHS& __lhs_;
+  const _RHS& __rhs_;
+
+  __lazy_compare_result(_LIBCPP_LIFETIMEBOUND const _Comparator& __comp,
+                        _LIBCPP_LIFETIMEBOUND const _LHS& __lhs,
+                        _LIBCPP_LIFETIMEBOUND const _RHS& __rhs)
+      : __comp_(__comp), __lhs_(__lhs), __rhs_(__rhs) {}
+
+  bool __less() const { return __comp_(__lhs_, __rhs_); }
+  bool __greater() const { return __comp_(__rhs_, __lhs_); }
+};
+
+// This class provides three way comparison between _LHS and _RHS as efficiently as possible. This can be specialized if
+// a comparator only compares part of the object, potentially allowing an efficient three way comparison between the
+// subobjects. The specialization should use the __lazy_synth_three_way_comparator for the subobjects to achieve this.
+template <class _Comparator, class _LHS, class _RHS, class = void>
+struct __lazy_synth_three_way_comparator {
+  const _Comparator& __comp_;
+
+  __lazy_synth_three_way_comparator(_LIBCPP_LIFETIMEBOUND const _Comparator& __comp) : __comp_(__comp) {}
+
+  __lazy_compare_result<_Comparator, _LHS, _RHS>
+  operator()(_LIBCPP_LIFETIMEBOUND const _LHS& __lhs, _LIBCPP_LIFETIMEBOUND const _RHS& __rhs) const {
+    return __lazy_compare_result<_Comparator, _LHS, _RHS>(__comp_, __lhs, __rhs);
+  }
+};
+
+struct __eager_compare_result {
+  int __res_;
+
+  explicit __eager_compare_result(int __res) : __res_(__res) {}
+
+  bool __less() const { return __res_ < 0; }
+  bool __greater() const { return __res_ > 0; }
+};
+
+template <class _Comparator, class _LHS, class _RHS>
+struct __lazy_synth_three_way_comparator<_Comparator,
+                              _LHS,
+                              _RHS,
+                              __enable_if_t<__desugars_to_v<__less_tag, _Comparator, _LHS, _RHS> &&
+                                            __has_default_three_way_comparator_v<_LHS, _RHS> > > {
+  // This lifetimebound annotation is technically incorrect, but other specializations actually capture the lifetime of
+  // the comparator.
+  __lazy_synth_three_way_comparator(_LIBCPP_LIFETIMEBOUND const _Comparator&) {}
+
+  // Same comment as above.
+  static __eager_compare_result
+  operator()(_LIBCPP_LIFETIMEBOUND const _LHS& __lhs, _LIBCPP_LIFETIMEBOUND const _RHS& __rhs) {
+    return __eager_compare_result(__default_three_way_comparator<_LHS, _RHS>()(__lhs, __rhs));
+  }
+};
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
diff --git a/libcxx/include/map b/libcxx/include/map
index b53e1c4213487..3f62b98125f85 100644
--- a/libcxx/include/map
+++ b/libcxx/include/map
@@ -606,6 +606,7 @@ erase_if(multimap<Key, T, Compare, Allocator>& c, Predicate pred);  // C++20
 #  include <__utility/pair.h>
 #  include <__utility/piecewise_construct.h>
 #  include <__utility/swap.h>
+#  include <__utility/three_way_comparator.h>
 #  include <stdexcept>
 #  include <tuple>
 #  include <version>
@@ -702,6 +703,47 @@ public:
 #  endif
 };
 
+#  if _LIBCPP_STD_VER >= 14
+template <class _MapValueT, class _Key, class _Compare>
+struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _MapValueT, _MapValueT> {
+  __lazy_synth_three_way_comparator<_Compare, _Key, _Key> __comp_;
+
+  __lazy_synth_three_way_comparator(_LIBCPP_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
+      : __comp_(__comp.key_comp()) {}
+
+  _LIBCPP_HIDE_FROM_ABI auto
+  operator()(_LIBCPP_LIFETIMEBOUND const _MapValueT& __lhs, _LIBCPP_LIFETIMEBOUND const _MapValueT& __rhs) const {
+    return __comp_(__lhs.first, __rhs.first);
+  }
+};
+
+template <class _MapValueT, class _Key, class _TransparentKey, class _Compare>
+struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _TransparentKey, _MapValueT> {
+  __lazy_synth_three_way_comparator<_Compare, _TransparentKey, _Key> __comp_;
+
+  __lazy_synth_three_way_comparator(_LIBCPP_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
+      : __comp_(__comp.key_comp()) {}
+
+  _LIBCPP_HIDE_FROM_ABI auto
+  operator()(_LIBCPP_LIFETIMEBOUND const _TransparentKey& __lhs, _LIBCPP_LIFETIMEBOUND const _MapValueT& __rhs) const {
+    return __comp_(__lhs, __rhs.first);
+  }
+};
+
+template <class _MapValueT, class _Key, class _TransparentKey, class _Compare>
+struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _MapValueT, _TransparentKey> {
+  __lazy_synth_three_way_comparator<_Compare, _Key, _TransparentKey> __comp_;
+
+  __lazy_synth_three_way_comparator(_LIBCPP_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
+      : __comp_(__comp.key_comp()) {}
+
+  _LIBCPP_HIDE_FROM_ABI auto
+  operator()(_LIBCPP_LIFETIMEBOUND const _MapValueT& __lhs, _LIBCPP_LIFETIMEBOUND const _TransparentKey& __rhs) const {
+    return __comp_(__lhs.first, __rhs);
+  }
+};
+#  endif // _LIBCPP_STD_VER >= 14
+
 template <class _Key, class _CP, class _Compare, bool __b>
 inline _LIBCPP_HIDE_FROM_ABI void
 swap(__map_value_compare<_Key, _CP, _Compare, __b>& __x, __map_value_compare<_Key, _CP, _Compare, __b>& __y)
diff --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in
index ee18d04c78d0e..e25086271ff7c 100644
--- a/libcxx/include/module.modulemap.in
+++ b/libcxx/include/module.modulemap.in
@@ -2176,6 +2176,7 @@ module std [system] {
     module scope_guard                { header "__utility/scope_guard.h" }
     module small_buffer               { header "__utility/small_buffer.h" }
     module swap                       { header "__utility/swap.h" }
+    module three_way_comparator       { header "__utility/three_way_comparator.h" }
     module to_underlying              { header "__utility/to_underlying.h" }
     module try_key_extraction         { header "__utility/try_key_extraction.h" }
     module unreachable                { header "__utility/unreachable.h" }
diff --git a/libcxx/include/string b/libcxx/include/string
index 1d197654b9fee..032ac13c7ccf9 100644
--- a/libcxx/include/string
+++ b/libcxx/include/string
@@ -645,7 +645,7 @@ basic_string<char32_t> operator""s( const char32_t *str, size_t len );
 #  include <__utility/move.h>
 #  include <__utility/scope_guard.h>
 #  include <__utility/swap.h>
-#  include <__utility/unreachable.h>
+#  include <__utility/three_way_comparator.h>
 #  include <climits>
 #  include <cstdio> // EOF
 #  include <cstring>
@@ -966,7 +966,7 @@ private:
         std::__wrap_iter<const_pointer>(__get_pointer() + size()));
 #  else
     return const_iterator(__p);
-#  endif // _LIBCPP_ABI_BOUNDED_ITERATORS_IN_STRING
+#  endif                    // _LIBCPP_ABI_BOUNDED_ITERATORS_IN_STRING
   }
 
 public:
@@ -2522,6 +2522,19 @@ _LIBCPP_STRING_V1_EXTERN_TEMPLATE_LIST(_LIBCPP_DECLARE, wchar_t)
 #  endif
 #  undef _LIBCPP_DECLARE
 
+template <class _CharT, class _Traits, class _Alloc>
+struct __default_three_way_comparator<basic_string<_CharT, _Traits, _Alloc>, basic_string<_CharT, _Traits, _Alloc>> {
+  using __string_t _LIBCPP_NODEBUG = basic_string<_CharT, _Traits, _Alloc>;
+
+  _LIBCPP_HIDE_FROM_ABI static int operator()(const __string_t& __lhs, const __string_t& __rhs) {
+    auto __min_len = std::min(__lhs.size(), __rhs.size());
+    auto __ret     = _Traits::compare(__lhs.data(), __rhs.data(), __min_len);
+    if (__ret == 0)
+      return __lhs.size() == __rhs.size() ? 0 : __lhs.size() < __rhs.size() ? -1 : 1;
+    return __ret;
+  }
+};
+
 #  if _LIBCPP_STD_VER >= 17
 template <class _InputIterator,
           class _CharT     = __iter_value_type<_InputIterator>,



More information about the libcxx-commits mailing list