[libcxx-commits] [libcxx] [libc++] Optimize __tree::__find_equal (PR #155245)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Mon Aug 25 06:50:05 PDT 2025


https://github.com/philnik777 created https://github.com/llvm/llvm-project/pull/155245

None

>From 648da5888874356068443a589eb4e46413f55007 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Fri, 22 Aug 2025 10:39:37 +0200
Subject: [PATCH] [libc++] Optimize __tree::__find_equal

---
 libcxx/include/CMakeLists.txt                 |  1 +
 libcxx/include/__tree                         | 32 +++++++++-
 libcxx/include/__type_traits/enable_if.h      |  3 +
 .../include/__utility/three_way_comparator.h  | 64 +++++++++++++++++++
 libcxx/include/map                            | 12 ++++
 libcxx/include/string                         | 17 ++++-
 6 files changed, 126 insertions(+), 3 deletions(-)
 create mode 100644 libcxx/include/__utility/three_way_comparator.h

diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index c6b87a34a43e9..acda31701adf3 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -932,6 +932,7 @@ set(files
   __utility/scope_guard.h
   __utility/small_buffer.h
   __utility/swap.h
+  __utility/three_way_comparator.h
   __utility/to_underlying.h
   __utility/unreachable.h
   __variant/monostate.h
diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index 0f3640ef6a834..670e28cbb8d2b 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -38,6 +38,7 @@
 #include <__utility/move.h>
 #include <__utility/pair.h>
 #include <__utility/swap.h>
+#include <__utility/three_way_comparator.h>
 #include <limits>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -1711,7 +1712,34 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, co
   __node_base_pointer* __nd_ptr = __root_ptr();
   if (__nd != nullptr) {
     while (true) {
-      if (value_comp()(__v, __nd->__get_value())) {
+#ifndef _LIBCPP_CXX03_LANG
+      static const bool __use_three_way = __has_three_way_comparator_v<_Compare, _Key, value_type>;
+
+      int __comp_res;
+      if constexpr (__use_three_way) {
+        __comp_res = __three_way_comparator<_Compare, _Key, value_type>()(__v, __nd->__get_value());
+      }
+#endif
+
+      auto __less = [&] {
+#ifndef _LIBCPP_CXX03_LANG
+        if constexpr (__use_three_way)
+          return __comp_res < 0;
+        else
+#endif
+          return value_comp()(__v, __nd->__get_value());
+      };
+
+      auto __greater = [&] {
+#ifndef _LIBCPP_CXX03_LANG
+        if constexpr (__use_three_way)
+          return __comp_res > 0;
+        else
+#endif
+          return value_comp()(__nd->__get_value(), __v);
+      };
+
+      if (__less()) {
         if (__nd->__left_ != nullptr) {
           __nd_ptr = std::addressof(__nd->__left_);
           __nd     = static_cast<__node_pointer>(__nd->__left_);
@@ -1719,7 +1747,7 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, co
           __parent = static_cast<__end_node_pointer>(__nd);
           return __parent->__left_;
         }
-      } else if (value_comp()(__nd->__get_value(), __v)) {
+      } else if (__greater()) {
         if (__nd->__right_ != nullptr) {
           __nd_ptr = std::addressof(__nd->__right_);
           __nd     = static_cast<__node_pointer>(__nd->__right_);
diff --git a/libcxx/include/__type_traits/enable_if.h b/libcxx/include/__type_traits/enable_if.h
index ae1af6ebf17d9..83f721892ac22 100644
--- a/libcxx/include/__type_traits/enable_if.h
+++ b/libcxx/include/__type_traits/enable_if.h
@@ -38,6 +38,11 @@ template <bool _Bp, class _Tp = void>
 using enable_if_t = typename enable_if<_Bp, _Tp>::type;
 #endif
 
+// Evaluates to _Tp, but is only well-formed when _Cond is true. Intended as the argument of a partial
+// specialization so that the specialization is only viable when _Cond holds.
+template <bool _Cond, class _Tp, __enable_if_t<_Cond, int> = 0>
+using __enable_specialization_if = _Tp;
+
 _LIBCPP_END_NAMESPACE_STD
 
 #endif // _LIBCPP___TYPE_TRAITS_ENABLE_IF_H
diff --git a/libcxx/include/__utility/three_way_comparator.h b/libcxx/include/__utility/three_way_comparator.h
new file mode 100644
index 0000000000000..aed6f85de581a
--- /dev/null
+++ b/libcxx/include/__utility/three_way_comparator.h
@@ -0,0 +1,74 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
+#define _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
+
+#include <__config>
+#include <__functional/operations.h>
+#include <__type_traits/desugars_to.h>
+#include <__type_traits/enable_if.h>
+#include <__type_traits/is_arithmetic.h>
+#include <__type_traits/remove_const_ref.h>
+#include <__type_traits/void_t.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+// Describes how to three-way compare a _LHS with a _RHS. A specialization that supports this defines the marker
+// typedef __three_way_comparable and a static operator()(__lhs, __rhs) returning a negative value, zero or a
+// positive value when __lhs is respectively less than, equivalent to or greater than __rhs.
+// The primary template provides no comparison.
+template <class _LHS, class _RHS, class = void>
+struct __three_way_comparison_traits {};
+
+// Arithmetic types are compared with the built-in relational operators.
+template <class _Tp>
+struct __three_way_comparison_traits<_Tp, _Tp, __enable_if_t<is_arithmetic<_Tp>::value> > {
+  using __three_way_comparable = void;
+
+  _LIBCPP_HIDE_FROM_ABI static int operator()(_Tp __lhs, _Tp __rhs) {
+    if (__lhs < __rhs)
+      return -1;
+    if (__lhs > __rhs)
+      return 1;
+    return 0;
+  }
+};
+
+// Three-way counterpart of the two-way comparator _Comparator applied to (_LHS, _RHS). It is provided here only
+// when _Comparator desugars to a less-than comparison (__less_tag), in which case it defers to
+// __three_way_comparison_traits for the cv-ref-stripped argument types. Other comparators may provide their own
+// specializations.
+template <class _Comparator, class _LHS, class _RHS, class = void>
+struct __three_way_comparator {};
+
+template <class _Comparator, class _LHS, class _RHS>
+struct __three_way_comparator<_Comparator,
+                              _LHS,
+                              _RHS,
+                              __enable_if_t<__desugars_to_v<__less_tag, _Comparator, _LHS, _RHS>>>
+    : __three_way_comparison_traits<__remove_const_ref_t<_LHS>, __remove_const_ref_t<_RHS>> {};
+
+// True when __three_way_comparator<_Comparator, _LHS, _RHS> exposes the __three_way_comparable marker.
+template <class _Comparator, class _LHS, class _RHS, class = void>
+inline const bool __has_three_way_comparator_v = false;
+
+template <class _Comparator, class _LHS, class _RHS>
+inline const bool __has_three_way_comparator_v<
+    _Comparator,
+    _LHS,
+    _RHS,
+    __void_t<typename __three_way_comparator<_Comparator, _LHS, _RHS>::__three_way_comparable>> = true;
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___UTILITY_THREE_WAY_COMPARATOR_H
diff --git a/libcxx/include/map b/libcxx/include/map
index 9bd2282e77a3c..ef044d3000b84 100644
--- a/libcxx/include/map
+++ b/libcxx/include/map
@@ -702,6 +702,18 @@ public:
 #  endif
 };
 
+template <class _Key, class _CP, class _Compare> // only viable when _Compare is three-way comparable on (_Key, _Key)
+struct __three_way_comparator<__enable_specialization_if<__has_three_way_comparator_v<_Compare, _Key, _Key>,
+                                                         __map_value_compare<_Key, _CP, _Compare>>,
+                              _Key,
+                              _CP> : __three_way_comparator<_Compare, _Key, _Key> {
+  using __base = __three_way_comparator<_Compare, _Key, _Key>;
+  // Compares __lhs against __rhs.first using __base, i.e. _Compare's three-way comparator on (_Key, _Key).
+  _LIBCPP_HIDE_FROM_ABI static int operator()(const _Key& __lhs, const _CP& __rhs) {
+    return __base()(__lhs, __rhs.first);
+  }
+};
+
 template <class _Key, class _CP, class _Compare, bool __b>
 inline _LIBCPP_HIDE_FROM_ABI void
 swap(__map_value_compare<_Key, _CP, _Compare, __b>& __x, __map_value_compare<_Key, _CP, _Compare, __b>& __y)
diff --git a/libcxx/include/string b/libcxx/include/string
index 1d197654b9fee..edfa80162d587 100644
--- a/libcxx/include/string
+++ b/libcxx/include/string
@@ -645,7 +645,8 @@ basic_string<char32_t> operator""s( const char32_t *str, size_t len );
 #  include <__utility/move.h>
 #  include <__utility/scope_guard.h>
 #  include <__utility/swap.h>
+#  include <__utility/three_way_comparator.h>
 #  include <__utility/unreachable.h>
 #  include <climits>
 #  include <cstdio> // EOF
 #  include <cstring>
@@ -2522,6 +2523,23 @@ _LIBCPP_STRING_V1_EXTERN_TEMPLATE_LIST(_LIBCPP_DECLARE, wchar_t)
 #  endif
 #  undef _LIBCPP_DECLARE
 
+// Three-way comparison of two strings: the common prefix is ordered by _Traits::compare and ties are broken by
+// length, which is the ordering defined for basic_string::compare.
+template <class _CharT, class _Traits, class _Alloc>
+struct __three_way_comparison_traits<basic_string<_CharT, _Traits, _Alloc>, basic_string<_CharT, _Traits, _Alloc>> {
+  using __string_t = basic_string<_CharT, _Traits, _Alloc>;
+
+  using __three_way_comparable = void;
+
+  _LIBCPP_HIDE_FROM_ABI static int operator()(const __string_t& __lhs, const __string_t& __rhs) {
+    auto __min_len = std::min(__lhs.size(), __rhs.size());
+    auto __ret     = _Traits::compare(__lhs.data(), __rhs.data(), __min_len);
+    if (__ret == 0)
+      return __lhs.size() == __rhs.size() ? 0 : __lhs.size() < __rhs.size() ? -1 : 1;
+    return __ret;
+  }
+};
+
 #  if _LIBCPP_STD_VER >= 17
 template <class _InputIterator,
           class _CharT     = __iter_value_type<_InputIterator>,



More information about the libcxx-commits mailing list