[libcxx-commits] [libcxx] [libc++] Optimize __tree::__find_equal (PR #155245)

via libcxx-commits libcxx-commits at lists.llvm.org
Mon Aug 25 08:56:06 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-libcxx

Author: Nikolas Klauser (philnik777)

<details>
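<summary>Changes</summary>

This patch lets `__tree::__find_equal` perform a single three-way comparison per visited node instead of calling `value_comp()` up to twice (once for "less", once for "greater"). A new internal `__three_way_comparator` (in `<__utility/three_way_comparator.h>`) provides this for comparators that desugar to a less-than comparison (`__less_tag`) on arithmetic types and on `std::basic_string`; `<map>` forwards it through `__map_value_compare` by comparing the key against the element's `.first`. Other comparators, and C++03 mode, keep the existing two-comparison path.

Benchmark results (old vs. new) follow. Lookups of absent keys improve the most, e.g. `std::map<std::string, int>::find(key) (non-existent)/8192` goes from 119 ns to 70.6 ns. Some insertion cases are slower, e.g. `std::map<int, int>::insert(value) (new value)/8192` goes from 149 ns to 175 ns.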

```
-----------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                  old            new
-----------------------------------------------------------------------------------------------------------------------------
std::map<int, int>::insert(value) (already present)/0                                                  1.10 ns        1.13 ns
std::map<int, int>::insert(value) (already present)/32                                                 3.63 ns        3.54 ns
std::map<int, int>::insert(value) (already present)/1024                                               13.2 ns        12.6 ns
std::map<int, int>::insert(value) (already present)/8192                                               16.9 ns        17.2 ns
std::map<int, int>::insert(value) (new value)/0                                                        21.6 ns        21.9 ns
std::map<int, int>::insert(value) (new value)/32                                                       30.2 ns        29.7 ns
std::map<int, int>::insert(value) (new value)/1024                                                     68.7 ns        76.6 ns
std::map<int, int>::insert(value) (new value)/8192                                                      149 ns         175 ns
std::map<int, int>::insert(hint, value) (good hint)/0                                                  22.0 ns        22.5 ns
std::map<int, int>::insert(hint, value) (good hint)/32                                                 29.7 ns        29.6 ns
std::map<int, int>::insert(hint, value) (good hint)/1024                                               60.1 ns        58.0 ns
std::map<int, int>::insert(hint, value) (good hint)/8192                                                112 ns         119 ns
std::map<int, int>::insert(hint, value) (bad hint)/0                                                   21.5 ns        21.8 ns
std::map<int, int>::insert(hint, value) (bad hint)/32                                                  30.4 ns        30.7 ns
std::map<int, int>::insert(hint, value) (bad hint)/1024                                                78.9 ns        73.4 ns
std::map<int, int>::insert(hint, value) (bad hint)/8192                                                 176 ns         178 ns
std::map<int, int>::insert(iterator, iterator) (all new keys)/0                                         451 ns         455 ns
std::map<int, int>::insert(iterator, iterator) (all new keys)/32                                        976 ns        1002 ns
std::map<int, int>::insert(iterator, iterator) (all new keys)/1024                                    27746 ns       28672 ns
std::map<int, int>::insert(iterator, iterator) (all new keys)/8192                                   301483 ns      306763 ns
std::map<int, int>::insert(iterator, iterator) (half new keys)/0                                        451 ns         456 ns
std::map<int, int>::insert(iterator, iterator) (half new keys)/32                                       774 ns         793 ns
std::map<int, int>::insert(iterator, iterator) (half new keys)/1024                                   31396 ns       32557 ns
std::map<int, int>::insert(iterator, iterator) (half new keys)/8192                                  467307 ns      461338 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from same type)/0                      454 ns         457 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from same type)/32                    1025 ns        1037 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from same type)/1024                 29587 ns       30245 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from same type)/8192                339035 ns      343855 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from zip_view)/0                       452 ns         455 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from zip_view)/32                      948 ns         954 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from zip_view)/1024                  29199 ns       29829 ns
std::map<int, int>::insert(iterator, iterator) (product_iterator from zip_view)/8192                 282469 ns      286135 ns
std::map<int, int>::erase(key) (existent)/0                                                            24.7 ns        24.7 ns
std::map<int, int>::erase(key) (existent)/32                                                           31.6 ns        30.0 ns
std::map<int, int>::erase(key) (existent)/1024                                                         51.0 ns        51.1 ns
std::map<int, int>::erase(key) (existent)/8192                                                         65.2 ns        66.9 ns
std::map<int, int>::erase(key) (non-existent)/0                                                       0.327 ns       0.507 ns
std::map<int, int>::erase(key) (non-existent)/32                                                       5.06 ns        4.72 ns
std::map<int, int>::erase(key) (non-existent)/1024                                                     11.1 ns        11.1 ns
std::map<int, int>::erase(key) (non-existent)/8192                                                     16.7 ns        17.0 ns
std::map<int, int>::erase(iterator)/0                                                                  24.3 ns        24.0 ns
std::map<int, int>::erase(iterator)/32                                                                 24.9 ns        24.4 ns
std::map<int, int>::erase(iterator)/1024                                                               26.3 ns        25.5 ns
std::map<int, int>::erase(iterator)/8192                                                               27.4 ns        26.9 ns
std::map<int, int>::erase(iterator, iterator) (erase half the container)/0                              452 ns         453 ns
std::map<int, int>::erase(iterator, iterator) (erase half the container)/32                             687 ns         699 ns
std::map<int, int>::erase(iterator, iterator) (erase half the container)/1024                          8843 ns        8624 ns
std::map<int, int>::erase(iterator, iterator) (erase half the container)/8192                         66617 ns       67326 ns
std::map<int, int>::find(key) (existent)/0                                                            0.007 ns       0.007 ns
std::map<int, int>::find(key) (existent)/32                                                            4.30 ns        4.08 ns
std::map<int, int>::find(key) (existent)/1024                                                          7.41 ns        7.38 ns
std::map<int, int>::find(key) (existent)/8192                                                          11.9 ns        11.8 ns
std::map<int, int>::find(key) (non-existent)/0                                                        0.512 ns       0.305 ns
std::map<int, int>::find(key) (non-existent)/32                                                        4.51 ns        4.61 ns
std::map<int, int>::find(key) (non-existent)/1024                                                      21.4 ns        11.2 ns
std::map<int, int>::find(key) (non-existent)/8192                                                      27.3 ns        16.8 ns
std::map<std::string, int>::insert(value) (already present)/0                                          14.8 ns        13.7 ns
std::map<std::string, int>::insert(value) (already present)/32                                         34.0 ns        32.3 ns
std::map<std::string, int>::insert(value) (already present)/1024                                       54.7 ns        52.0 ns
std::map<std::string, int>::insert(value) (already present)/8192                                       67.3 ns        67.9 ns
std::map<std::string, int>::insert(value) (new value)/0                                                42.4 ns        45.2 ns
std::map<std::string, int>::insert(value) (new value)/32                                               78.1 ns        86.8 ns
std::map<std::string, int>::insert(value) (new value)/1024                                              158 ns         125 ns
std::map<std::string, int>::insert(value) (new value)/8192                                              191 ns         161 ns
std::map<std::string, int>::insert(hint, value) (good hint)/0                                          45.6 ns        57.9 ns
std::map<std::string, int>::insert(hint, value) (good hint)/32                                         62.1 ns        73.1 ns
std::map<std::string, int>::insert(hint, value) (good hint)/1024                                        108 ns         120 ns
std::map<std::string, int>::insert(hint, value) (good hint)/8192                                        134 ns         138 ns
std::map<std::string, int>::insert(hint, value) (bad hint)/0                                           62.1 ns        41.2 ns
std::map<std::string, int>::insert(hint, value) (bad hint)/32                                           128 ns        88.1 ns
std::map<std::string, int>::insert(hint, value) (bad hint)/1024                                         165 ns         134 ns
std::map<std::string, int>::insert(hint, value) (bad hint)/8192                                         197 ns         184 ns
std::map<std::string, int>::insert(iterator, iterator) (all new keys)/0                                 457 ns         453 ns
std::map<std::string, int>::insert(iterator, iterator) (all new keys)/32                               2435 ns        2128 ns
std::map<std::string, int>::insert(iterator, iterator) (all new keys)/1024                           184763 ns      169332 ns
std::map<std::string, int>::insert(iterator, iterator) (all new keys)/8192                          1757497 ns     1446335 ns
std::map<std::string, int>::insert(iterator, iterator) (half new keys)/0                                457 ns         453 ns
std::map<std::string, int>::insert(iterator, iterator) (half new keys)/32                              1997 ns        1587 ns
std::map<std::string, int>::insert(iterator, iterator) (half new keys)/1024                          116950 ns      118256 ns
std::map<std::string, int>::insert(iterator, iterator) (half new keys)/8192                         1236227 ns     1037691 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/0              460 ns         455 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/32            1615 ns        1628 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/1024         93612 ns      116950 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/8192        907143 ns      925279 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/0               458 ns         454 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/32             1762 ns        1776 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/1024         117041 ns      105967 ns
std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/8192         938247 ns      807760 ns
std::map<std::string, int>::erase(key) (existent)/0                                                    63.7 ns        74.4 ns
std::map<std::string, int>::erase(key) (existent)/32                                                   67.8 ns        99.8 ns
std::map<std::string, int>::erase(key) (existent)/1024                                                  119 ns         114 ns
std::map<std::string, int>::erase(key) (existent)/8192                                                  105 ns         130 ns
std::map<std::string, int>::erase(key) (non-existent)/0                                               0.309 ns       0.323 ns
std::map<std::string, int>::erase(key) (non-existent)/32                                               38.4 ns        22.8 ns
std::map<std::string, int>::erase(key) (non-existent)/1024                                             89.8 ns        53.8 ns
std::map<std::string, int>::erase(key) (non-existent)/8192                                              121 ns        69.7 ns
std::map<std::string, int>::erase(iterator)/0                                                          36.7 ns        38.7 ns
std::map<std::string, int>::erase(iterator)/32                                                         37.3 ns        37.7 ns
std::map<std::string, int>::erase(iterator)/1024                                                       37.3 ns        38.8 ns
std::map<std::string, int>::erase(iterator)/8192                                                       37.5 ns        38.2 ns
std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/0                      444 ns         452 ns
std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/32                     788 ns         824 ns
std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/1024                 22047 ns       21292 ns
std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/8192                134951 ns      141552 ns
std::map<std::string, int>::find(key) (existent)/0                                                    0.057 ns       0.011 ns
std::map<std::string, int>::find(key) (existent)/32                                                    32.6 ns        19.9 ns
std::map<std::string, int>::find(key) (existent)/1024                                                  59.5 ns        43.7 ns
std::map<std::string, int>::find(key) (existent)/8192                                                  76.6 ns        61.9 ns
std::map<std::string, int>::find(key) (non-existent)/0                                                0.293 ns       0.338 ns
std::map<std::string, int>::find(key) (non-existent)/32                                                38.2 ns        24.2 ns
std::map<std::string, int>::find(key) (non-existent)/1024                                              91.0 ns        54.6 ns
std::map<std::string, int>::find(key) (non-existent)/8192                                               119 ns        70.6 ns
```

---
Full diff: https://github.com/llvm/llvm-project/pull/155245.diff


6 Files Affected:

- (modified) libcxx/include/CMakeLists.txt (+1) 
- (modified) libcxx/include/__tree (+30-2) 
- (modified) libcxx/include/__type_traits/enable_if.h (+3) 
- (added) libcxx/include/__utility/three_way_comparator.h (+64) 
- (modified) libcxx/include/map (+12) 
- (modified) libcxx/include/string (+16-1) 


``````````diff
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index c6b87a34a43e9..acda31701adf3 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -932,6 +932,7 @@ set(files
   __utility/scope_guard.h
   __utility/small_buffer.h
   __utility/swap.h
+  __utility/three_way_comparator.h
   __utility/to_underlying.h
   __utility/unreachable.h
   __variant/monostate.h
diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index 0f3640ef6a834..670e28cbb8d2b 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -38,6 +38,7 @@
 #include <__utility/move.h>
 #include <__utility/pair.h>
 #include <__utility/swap.h>
+#include <__utility/three_way_comparator.h>
 #include <limits>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -1711,7 +1712,34 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, co
   __node_base_pointer* __nd_ptr = __root_ptr();
   if (__nd != nullptr) {
     while (true) {
-      if (value_comp()(__v, __nd->__get_value())) {
+#ifndef _LIBCPP_CXX03_LANG
+      static const bool __use_three_way = __has_three_way_comparator_v<_Compare, _Key, value_type>;
+
+      int __comp_res;
+      if constexpr (__use_three_way) {
+        __comp_res = __three_way_comparator<_Compare, _Key, value_type>()(__v, __nd->__get_value());
+      }
+#endif
+
+      auto __less = [&] {
+#ifndef _LIBCPP_CXX03_LANG
+        if constexpr (__use_three_way)
+          return __comp_res < 0;
+        else
+#endif
+          return value_comp()(__v, __nd->__get_value());
+      };
+
+      auto __greater = [&] {
+#ifndef _LIBCPP_CXX03_LANG
+        if constexpr (__use_three_way)
+          return __comp_res > 0;
+        else
+#endif
+          return value_comp()(__nd->__get_value(), __v);
+      };
+
+      if (__less()) {
         if (__nd->__left_ != nullptr) {
           __nd_ptr = std::addressof(__nd->__left_);
           __nd     = static_cast<__node_pointer>(__nd->__left_);
@@ -1719,7 +1747,7 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, co
           __parent = static_cast<__end_node_pointer>(__nd);
           return __parent->__left_;
         }
-      } else if (value_comp()(__nd->__get_value(), __v)) {
+      } else if (__greater()) {
         if (__nd->__right_ != nullptr) {
           __nd_ptr = std::addressof(__nd->__right_);
           __nd     = static_cast<__node_pointer>(__nd->__right_);
diff --git a/libcxx/include/__type_traits/enable_if.h b/libcxx/include/__type_traits/enable_if.h
index ae1af6ebf17d9..83f721892ac22 100644
--- a/libcxx/include/__type_traits/enable_if.h
+++ b/libcxx/include/__type_traits/enable_if.h
@@ -38,6 +38,9 @@ template <bool _Bp, class _Tp = void>
 using enable_if_t = typename enable_if<_Bp, _Tp>::type;
 #endif
 
+template <bool __cond, class _Tp, __enable_if_t<__cond, int> = 0>
+using __enable_specialization_if = _Tp;
+
 _LIBCPP_END_NAMESPACE_STD
 
 #endif // _LIBCPP___TYPE_TRAITS_ENABLE_IF_H
diff --git a/libcxx/include/__utility/three_way_comparator.h b/libcxx/include/__utility/three_way_comparator.h
new file mode 100644
index 0000000000000..aed6f85de581a
--- /dev/null
+++ b/libcxx/include/__utility/three_way_comparator.h
@@ -0,0 +1,64 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
+#define _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
+
+#include <__config>
+#include <__type_traits/desugars_to.h>
+#include <__type_traits/enable_if.h>
+#include <__type_traits/is_arithmetic.h>
+#include <__type_traits/remove_const_ref.h>
+#include <__type_traits/void_t.h>
+#include <__functional/operations.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class _LHS, class _RHS, class = void>
+struct __three_way_comparison_traits {};
+
+template <class _Tp>
+struct __three_way_comparison_traits<_Tp, _Tp, __enable_if_t<is_arithmetic<_Tp>::value> > {
+  using __three_way_comparable = void;
+
+  _LIBCPP_HIDE_FROM_ABI static int operator()(_Tp __lhs, _Tp __rhs) {
+    if (__lhs < __rhs)
+      return -1;
+    if (__lhs > __rhs)
+      return 1;
+    return 0;
+  }
+};
+
+template <class _Comparator, class _LHS, class _RHS, class = void>
+struct __three_way_comparator {};
+
+template <class _Comparator, class _LHS, class _RHS>
+struct __three_way_comparator<_Comparator,
+                              _LHS,
+                              _RHS,
+                              __enable_if_t<__desugars_to_v<__less_tag, _Comparator, _LHS, _RHS>>>
+    : __three_way_comparison_traits<__remove_const_ref_t<_LHS>, __remove_const_ref_t<_RHS>> {};
+
+template <class _Comparator, class _LHS, class _RHS, class = void>
+inline const bool __has_three_way_comparator_v = false;
+
+template <class _Comparator, class _LHS, class _RHS>
+inline const bool __has_three_way_comparator_v<
+    _Comparator,
+    _LHS,
+    _RHS,
+    __void_t<typename __three_way_comparator<_Comparator, _LHS, _RHS>::__three_way_comparable>> = true;
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
diff --git a/libcxx/include/map b/libcxx/include/map
index 9bd2282e77a3c..ef044d3000b84 100644
--- a/libcxx/include/map
+++ b/libcxx/include/map
@@ -702,6 +702,18 @@ public:
 #  endif
 };
 
+template <class _Key, class _CP, class _Compare>
+struct __three_way_comparator<__enable_specialization_if<__has_three_way_comparator_v<_Compare, _Key, _Key>,
+                                                         __map_value_compare<_Key, _CP, _Compare>>,
+                              _Key,
+                              _CP> : __three_way_comparator<_Compare, _Key, _Key> {
+  using __base = __three_way_comparator<_Compare, _Key, _Key>;
+
+  _LIBCPP_HIDE_FROM_ABI static int operator()(const _Key& __lhs, const _CP& __rhs) {
+    return __base()(__lhs, __rhs.first);
+  }
+};
+
 template <class _Key, class _CP, class _Compare, bool __b>
 inline _LIBCPP_HIDE_FROM_ABI void
 swap(__map_value_compare<_Key, _CP, _Compare, __b>& __x, __map_value_compare<_Key, _CP, _Compare, __b>& __y)
diff --git a/libcxx/include/string b/libcxx/include/string
index 1d197654b9fee..edfa80162d587 100644
--- a/libcxx/include/string
+++ b/libcxx/include/string
@@ -645,7 +645,7 @@ basic_string<char32_t> operator""s( const char32_t *str, size_t len );
 #  include <__utility/move.h>
 #  include <__utility/scope_guard.h>
 #  include <__utility/swap.h>
-#  include <__utility/unreachable.h>
+#  include <__utility/three_way_comparator.h>
 #  include <climits>
 #  include <cstdio> // EOF
 #  include <cstring>
@@ -2522,6 +2522,21 @@ _LIBCPP_STRING_V1_EXTERN_TEMPLATE_LIST(_LIBCPP_DECLARE, wchar_t)
 #  endif
 #  undef _LIBCPP_DECLARE
 
+template <class _CharT, class _Traits, class _Alloc>
+struct __three_way_comparison_traits<basic_string<_CharT, _Traits, _Alloc>, basic_string<_CharT, _Traits, _Alloc>> {
+  using __string_t = basic_string<_CharT, _Traits, _Alloc>;
+
+  using __three_way_comparable = void;
+
+  _LIBCPP_HIDE_FROM_ABI static int operator()(const __string_t& __lhs, const __string_t& __rhs) {
+    auto __min_len = std::min(__lhs.size(), __rhs.size());
+    auto __ret     = _Traits::compare(__lhs.data(), __rhs.data(), __min_len);
+    if (__ret == 0)
+      return __lhs.size() == __rhs.size() ? 0 : __lhs.size() < __rhs.size() ? -1 : 1;
+    return __ret;
+  }
+};
+
 #  if _LIBCPP_STD_VER >= 17
 template <class _InputIterator,
           class _CharT     = __iter_value_type<_InputIterator>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/155245


More information about the libcxx-commits mailing list