[libcxx-commits] [libcxx] [libc++] Refactor __tree::__find_equal to not have an out parameter (PR #147345)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Tue Sep 2 23:28:45 PDT 2025


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/147345

>From cb3793a711cc522092eb33a65148929548477822 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 7 Jul 2025 18:42:36 +0200
Subject: [PATCH 1/2] [libc++] Refactor __tree::__find_equal to not have an out
 parameter

---
 libcxx/include/__tree | 179 +++++++++++++++++++-----------------------
 libcxx/include/map    |  11 +--
 2 files changed, 86 insertions(+), 104 deletions(-)

diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index 3ad2129ba9ddf..603189e971293 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -910,10 +910,9 @@ public:
   _LIBCPP_HIDE_FROM_ABI pair<iterator, bool> __emplace_unique(_Args&&... __args) {
     return std::__try_key_extraction<key_type>(
         [this](const key_type& __key, _Args&&... __args2) {
-          __end_node_pointer __parent;
-          __node_base_pointer& __child = __find_equal(__parent, __key);
-          __node_pointer __r           = static_cast<__node_pointer>(__child);
-          bool __inserted              = false;
+          auto [__parent, __child] = __find_equal(__key);
+          __node_pointer __r       = static_cast<__node_pointer>(__child);
+          bool __inserted          = false;
           if (__child == nullptr) {
             __node_holder __h = __construct_node(std::forward<_Args>(__args2)...);
             __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
@@ -923,11 +922,10 @@ public:
           return pair<iterator, bool>(iterator(__r), __inserted);
         },
         [this](_Args&&... __args2) {
-          __node_holder __h = __construct_node(std::forward<_Args>(__args2)...);
-          __end_node_pointer __parent;
-          __node_base_pointer& __child = __find_equal(__parent, __h->__get_value());
-          __node_pointer __r           = static_cast<__node_pointer>(__child);
-          bool __inserted              = false;
+          __node_holder __h        = __construct_node(std::forward<_Args>(__args2)...);
+          auto [__parent, __child] = __find_equal(__h->__get_value());
+          __node_pointer __r       = static_cast<__node_pointer>(__child);
+          bool __inserted          = false;
           if (__child == nullptr) {
             __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
             __r        = __h.release();
@@ -942,11 +940,10 @@ public:
   _LIBCPP_HIDE_FROM_ABI pair<iterator, bool> __emplace_hint_unique(const_iterator __p, _Args&&... __args) {
     return std::__try_key_extraction<key_type>(
         [this, __p](const key_type& __key, _Args&&... __args2) {
-          __end_node_pointer __parent;
           __node_base_pointer __dummy;
-          __node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __key);
-          __node_pointer __r           = static_cast<__node_pointer>(__child);
-          bool __inserted              = false;
+          auto [__parent, __child] = __find_equal(__p, __dummy, __key);
+          __node_pointer __r       = static_cast<__node_pointer>(__child);
+          bool __inserted          = false;
           if (__child == nullptr) {
             __node_holder __h = __construct_node(std::forward<_Args>(__args2)...);
             __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
@@ -957,10 +954,9 @@ public:
         },
         [this, __p](_Args&&... __args2) {
           __node_holder __h = __construct_node(std::forward<_Args>(__args2)...);
-          __end_node_pointer __parent;
           __node_base_pointer __dummy;
-          __node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __h->__get_value());
-          __node_pointer __r           = static_cast<__node_pointer>(__child);
+          auto [__parent, __child] = __find_equal(__p, __dummy, __h->__get_value());
+          __node_pointer __r       = static_cast<__node_pointer>(__child);
           if (__child == nullptr) {
             __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
             __r = __h.release();
@@ -1060,8 +1056,7 @@ public:
 
   template <class _Key>
   _LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __key) {
-    __end_node_pointer __parent;
-    __node_base_pointer __match = __find_equal(__parent, __key);
+    auto [__, __match] = __find_equal(__key);
     if (__match == nullptr)
       return end();
     return iterator(static_cast<__node_pointer>(__match));
@@ -1069,8 +1064,7 @@ public:
 
   template <class _Key>
   _LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __key) const {
-    __end_node_pointer __parent;
-    __node_base_pointer __match = __find_equal(__parent, __key);
+    auto [__, __match] = __find_equal(__key);
     if (__match == nullptr)
       return end();
     return const_iterator(static_cast<__node_pointer>(__match));
@@ -1125,14 +1119,16 @@ public:
   // FIXME: Make this function const qualified. Unfortunately doing so
   // breaks existing code which uses non-const callable comparators.
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v);
+  _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v);
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v) const {
-    return const_cast<__tree*>(this)->__find_equal(__parent, __v);
+  _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v) const {
+    return const_cast<__tree*>(this)->__find_equal(__v);
   }
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI __node_base_pointer&
-  __find_equal(const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v);
+  _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&>
+  __find_equal(const_iterator __hint, __node_base_pointer& __dummy, const _Key& __v);
 
   _LIBCPP_HIDE_FROM_ABI void __copy_assign_alloc(const __tree& __t) {
     __copy_assign_alloc(__t, integral_constant<bool, __node_traits::propagate_on_container_copy_assignment::value>());
@@ -1685,92 +1681,85 @@ typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Co
   return __find_leaf_low(__parent, __v);
 }
 
-// Find place to insert if __v doesn't exist
-// Set __parent to parent of null leaf
-// Return reference to null leaf
-// If __v exists, set parent to node of __v and return reference to node of __v
+// Find __v
+// If __v exists, return the parent of the node of __v a reference to the pointer to the node of __v.
+// If __v doesn't exist, return the parent of the null leaf and a reference to the pointer to the null leaf.
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer&
-__tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, const _Key& __v) {
-  __node_pointer __nd           = __root();
-  __node_base_pointer* __nd_ptr = __root_ptr();
-  if (__nd != nullptr) {
-    while (true) {
-      if (value_comp()(__v, __nd->__get_value())) {
-        if (__nd->__left_ != nullptr) {
-          __nd_ptr = std::addressof(__nd->__left_);
-          __nd     = static_cast<__node_pointer>(__nd->__left_);
-        } else {
-          __parent = static_cast<__end_node_pointer>(__nd);
-          return __parent->__left_;
-        }
-      } else if (value_comp()(__nd->__get_value(), __v)) {
-        if (__nd->__right_ != nullptr) {
-          __nd_ptr = std::addressof(__nd->__right_);
-          __nd     = static_cast<__node_pointer>(__nd->__right_);
-        } else {
-          __parent = static_cast<__end_node_pointer>(__nd);
-          return __nd->__right_;
-        }
-      } else {
-        __parent = static_cast<__end_node_pointer>(__nd);
-        return *__nd_ptr;
-      }
+_LIBCPP_HIDE_FROM_ABI pair<typename __tree<_Tp, _Compare, _Allocator>::__end_node_pointer,
+                           typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer&>
+__tree<_Tp, _Compare, _Allocator>::__find_equal(const _Key& __v) {
+  using _Pair = pair<__end_node_pointer, __node_base_pointer&>;
+
+  __node_pointer __nd = __root();
+
+  if (__nd == nullptr) {
+    auto __end = __end_node();
+    return _Pair(__end, __end->__left_);
+  }
+
+  __node_base_pointer* __node_ptr = __root_ptr();
+  while (true) {
+    if (value_comp()(__v, __nd->__get_value())) {
+      if (__nd->__left_ == nullptr)
+        return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__left_);
+
+      __node_ptr = std::addressof(__nd->__left_);
+      __nd       = static_cast<__node_pointer>(__nd->__left_);
+    } else if (value_comp()(__nd->__get_value(), __v)) {
+      if (__nd->__right_ == nullptr)
+        return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__right_);
+
+      __node_ptr = std::addressof(__nd->__right_);
+      __nd       = static_cast<__node_pointer>(__nd->__right_);
+    } else {
+      return _Pair(static_cast<__end_node_pointer>(__nd), *__node_ptr);
     }
   }
-  __parent = __end_node();
-  return __parent->__left_;
 }
 
-// Find place to insert if __v doesn't exist
+// Find __v
 // First check prior to __hint.
 // Next check after __hint.
 // Next do O(log N) search.
-// Set __parent to parent of null leaf
-// Return reference to null leaf
-// If __v exists, set parent to node of __v and return reference to node of __v
+// If __v exists, return the node of __v and a reference to a pointer to that node (this may be __dummy).
+// If __v doesn't exist, return the parent of the null leaf and a reference to the pointer to the null leaf.
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Compare, _Allocator>::__find_equal(
-    const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v) {
-  if (__hint == end() || value_comp()(__v, *__hint)) // check before
-  {
+_LIBCPP_HIDE_FROM_ABI pair<typename __tree<_Tp, _Compare, _Allocator>::__end_node_pointer,
+                           typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer&>
+__tree<_Tp, _Compare, _Allocator>::__find_equal(const_iterator __hint, __node_base_pointer& __dummy, const _Key& __v) {
+  using _Pair = pair<__end_node_pointer, __node_base_pointer&>;
+
+  if (__hint == end() || value_comp()(__v, *__hint)) { // check before
     // __v < *__hint
     const_iterator __prior = __hint;
     if (__prior == begin() || value_comp()(*--__prior, __v)) {
       // *prev(__hint) < __v < *__hint
-      if (__hint.__ptr_->__left_ == nullptr) {
-        __parent = __hint.__ptr_;
-        return __parent->__left_;
-      } else {
-        __parent = __prior.__ptr_;
-        return static_cast<__node_base_pointer>(__prior.__ptr_)->__right_;
-      }
+      if (__hint.__ptr_->__left_ == nullptr)
+        return _Pair(__hint.__ptr_, __hint.__ptr_->__left_);
+      return _Pair(__prior.__ptr_, static_cast<__node_pointer>(__prior.__ptr_)->__right_);
     }
     // __v <= *prev(__hint)
-    return __find_equal(__parent, __v);
-  } else if (value_comp()(*__hint, __v)) // check after
-  {
+    return __find_equal(__v);
+  }
+
+  if (value_comp()(*__hint, __v)) { // check after
     // *__hint < __v
     const_iterator __next = std::next(__hint);
     if (__next == end() || value_comp()(__v, *__next)) {
       // *__hint < __v < *std::next(__hint)
-      if (__hint.__get_np()->__right_ == nullptr) {
-        __parent = __hint.__ptr_;
-        return static_cast<__node_base_pointer>(__hint.__ptr_)->__right_;
-      } else {
-        __parent = __next.__ptr_;
-        return __parent->__left_;
-      }
+      if (__hint.__get_np()->__right_ == nullptr)
+        return _Pair(__hint.__ptr_, static_cast<__node_pointer>(__hint.__ptr_)->__right_);
+      return _Pair(__next.__ptr_, __next.__ptr_->__left_);
     }
     // *next(__hint) <= __v
-    return __find_equal(__parent, __v);
+    return __find_equal(__v);
   }
+
   // else __v == *__hint
-  __parent = __hint.__ptr_;
-  __dummy  = static_cast<__node_base_pointer>(__hint.__ptr_);
-  return __dummy;
+  __dummy = static_cast<__node_base_pointer>(__hint.__ptr_);
+  return _Pair(__hint.__ptr_, __dummy);
 }
 
 template <class _Tp, class _Compare, class _Allocator>
@@ -1823,10 +1812,9 @@ __tree<_Tp, _Compare, _Allocator>::__emplace_hint_multi(const_iterator __p, _Arg
 template <class _Tp, class _Compare, class _Allocator>
 pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
 __tree<_Tp, _Compare, _Allocator>::__node_assign_unique(const value_type& __v, __node_pointer __nd) {
-  __end_node_pointer __parent;
-  __node_base_pointer& __child = __find_equal(__parent, __v);
-  __node_pointer __r           = static_cast<__node_pointer>(__child);
-  bool __inserted              = false;
+  auto [__parent, __child] = __find_equal(__v);
+  __node_pointer __r       = static_cast<__node_pointer>(__child);
+  bool __inserted          = false;
   if (__child == nullptr) {
     __assign_value(__nd->__get_value(), __v);
     __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__nd));
@@ -1875,8 +1863,7 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(_NodeHandle&& __n
     return _InsertReturnType{end(), false, _NodeHandle()};
 
   __node_pointer __ptr = __nh.__ptr_;
-  __end_node_pointer __parent;
-  __node_base_pointer& __child = __find_equal(__parent, __ptr->__get_value());
+  auto [__parent, __child] = __find_equal(__ptr->__get_value());
   if (__child != nullptr)
     return _InsertReturnType{iterator(static_cast<__node_pointer>(__child)), false, std::move(__nh)};
 
@@ -1893,10 +1880,9 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(const_iterator __
     return end();
 
   __node_pointer __ptr = __nh.__ptr_;
-  __end_node_pointer __parent;
   __node_base_pointer __dummy;
-  __node_base_pointer& __child = __find_equal(__hint, __parent, __dummy, __ptr->__get_value());
-  __node_pointer __r           = static_cast<__node_pointer>(__child);
+  auto [__parent, __child] = __find_equal(__hint, __dummy, __ptr->__get_value());
+  __node_pointer __r       = static_cast<__node_pointer>(__child);
   if (__child == nullptr) {
     __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__ptr));
     __r = __ptr;
@@ -1929,8 +1915,7 @@ _LIBCPP_HIDE_FROM_ABI void __tree<_Tp, _Compare, _Allocator>::__node_handle_merg
 
   for (typename _Tree::iterator __i = __source.begin(); __i != __source.end();) {
     __node_pointer __src_ptr = __i.__get_np();
-    __end_node_pointer __parent;
-    __node_base_pointer& __child = __find_equal(__parent, __src_ptr->__get_value());
+    auto [__parent, __child] = __find_equal(__src_ptr->__get_value());
     ++__i;
     if (__child != nullptr)
       continue;
diff --git a/libcxx/include/map b/libcxx/include/map
index b53e1c4213487..af24e8a55896a 100644
--- a/libcxx/include/map
+++ b/libcxx/include/map
@@ -1418,9 +1418,8 @@ map<_Key, _Tp, _Compare, _Allocator>::__construct_node_with_key(const key_type&
 
 template <class _Key, class _Tp, class _Compare, class _Allocator>
 _Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) {
-  __parent_pointer __parent;
-  __node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
-  __node_pointer __r           = static_cast<__node_pointer>(__child);
+  auto [__parent, __child] = __tree_.__find_equal(__k);
+  __node_pointer __r       = static_cast<__node_pointer>(__child);
   if (__child == nullptr) {
     __node_holder __h = __construct_node_with_key(__k);
     __tree_.__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
@@ -1433,8 +1432,7 @@ _Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) {
 
 template <class _Key, class _Tp, class _Compare, class _Allocator>
 _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) {
-  __parent_pointer __parent;
-  __node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
+  auto [_, __child] = __tree_.__find_equal(__k);
   if (__child == nullptr)
     std::__throw_out_of_range("map::at:  key not found");
   return static_cast<__node_pointer>(__child)->__get_value().second;
@@ -1442,8 +1440,7 @@ _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) {
 
 template <class _Key, class _Tp, class _Compare, class _Allocator>
 const _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) const {
-  __parent_pointer __parent;
-  __node_base_pointer __child = __tree_.__find_equal(__parent, __k);
+  auto [_, __child] = __tree_.__find_equal(__k);
   if (__child == nullptr)
     std::__throw_out_of_range("map::at:  key not found");
   return static_cast<__node_pointer>(__child)->__get_value().second;

>From 7e16c7be4848b043b762a9985f5f5b909f321d7f Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Wed, 3 Sep 2025 08:28:37 +0200
Subject: [PATCH 2/2] Update libcxx/include/__tree

Co-authored-by: Louis Dionne <ldionne.2 at gmail.com>
---
 libcxx/include/__tree | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index 603189e971293..505ea6db162db 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -1682,7 +1682,7 @@ typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Co
 }
 
 // Find __v
-// If __v exists, return the parent of the node of __v a reference to the pointer to the node of __v.
+// If __v exists, return the node of __v and a reference to the pointer to the node of __v.
 // If __v doesn't exist, return the parent of the null leaf and a reference to the pointer to the null leaf.
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>



More information about the libcxx-commits mailing list