[libcxx-commits] [libcxx] [libc++] Refactor __tree::__find_equal to not have an out parameter (PR #147345)
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Wed Aug 20 01:53:03 PDT 2025
https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/147345
>From 97ca0093a889819716f8eb8c2dfc745f6b83eddb Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 7 Jul 2025 18:42:36 +0200
Subject: [PATCH] [libc++] Refactor __tree::__find_equal to not have an out
parameter
---
libcxx/include/__tree | 224 +++++++++++++++++++-----------------------
libcxx/include/map | 11 +--
2 files changed, 104 insertions(+), 131 deletions(-)
diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index a84a0e43d3dda..74e714bf84843 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -1045,8 +1045,7 @@ public:
template <class _Key>
_LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __key) {
- __end_node_pointer __parent;
- __node_base_pointer __match = __find_equal(__parent, __key);
+ auto [__, __match] = __find_equal(__key); // __match: link to the node holding __key, or a null leaf slot if absent
if (__match == nullptr)
return end();
return iterator(static_cast<__node_pointer>(__match));
@@ -1054,8 +1053,7 @@ public:
template <class _Key>
_LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __key) const {
- __end_node_pointer __parent;
- __node_base_pointer __match = __find_equal(__parent, __key);
+ auto [__, __match] = __find_equal(__key); // __match: link to the node holding __key, or a null leaf slot if absent
if (__match == nullptr)
return end();
return const_iterator(static_cast<__node_pointer>(__match));
@@ -1109,15 +1107,89 @@ public:
// FIXME: Make this function const qualified. Unfortunately doing so
// breaks existing code which uses non-const callable comparators.
+
+ // Find the slot for __v. If no equivalent element exists, returns {parent of the
+ // null leaf where __v would be inserted, reference to that null leaf}. Otherwise
+ // returns {node holding __v, reference to the child pointer that points to that
+ // node (the slot from __root_ptr(), or the parent's __left_/__right_)}.
template <class _Key>
- _LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v);
+ _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v) {
+ using _Pair = pair<__end_node_pointer, __node_base_pointer&>;
+
+ __node_pointer __nd = __root();
+
+ if (__nd == nullptr) {
+ auto __end = __end_node();
+ return _Pair(__end, __end->__left_);
+ }
+
+ __node_base_pointer* __node_ptr = __root_ptr();
+ while (true) {
+ if (value_comp()(__v, __nd->__get_value())) {
+ if (__nd->__left_ == nullptr)
+ return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__left_);
+
+ __node_ptr = std::addressof(__nd->__left_);
+ __nd = static_cast<__node_pointer>(__nd->__left_);
+ } else if (value_comp()(__nd->__get_value(), __v)) {
+ if (__nd->__right_ == nullptr)
+ return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__right_);
+
+ __node_ptr = std::addressof(__nd->__right_);
+ __nd = static_cast<__node_pointer>(__nd->__right_);
+ } else {
+ return _Pair(static_cast<__end_node_pointer>(__nd), *__node_ptr);
+ }
+ }
+ }
+
template <class _Key>
- _LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v) const {
- return const_cast<__tree*>(this)->__find_equal(__parent, __v);
+ _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v) const {
+ return const_cast<__tree*>(this)->__find_equal(__v);
}
+
+ // Hinted variant of __find_equal(__v); returns the same {parent, slot} pair.
+ // First checks whether __v belongs immediately before __hint,
+ // then whether it belongs immediately after __hint,
+ // and otherwise falls back to the O(log N) search.
+ // If __v is equivalent to *__hint, __dummy is set to __hint's node and the
+ // returned slot is a reference to __dummy, not to a pointer inside the tree,
+ // so the caller's __dummy must outlive any use of the returned reference.
template <class _Key>
- _LIBCPP_HIDE_FROM_ABI __node_base_pointer&
- __find_equal(const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v);
+ _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&>
+ __find_equal(const_iterator __hint, __node_base_pointer& __dummy, const _Key& __v) {
+ using _Pair = pair<__end_node_pointer, __node_base_pointer&>;
+
+ if (__hint == end() || value_comp()(__v, *__hint)) { // check before
+ // __v < *__hint
+ const_iterator __prior = __hint;
+ if (__prior == begin() || value_comp()(*--__prior, __v)) {
+ // *prev(__hint) < __v < *__hint
+ if (__hint.__ptr_->__left_ == nullptr)
+ return _Pair(__hint.__ptr_, __hint.__ptr_->__left_);
+ return _Pair(__prior.__ptr_, __prior.__get_np()->__right_);
+ }
+ // __v <= *prev(__hint)
+ return __find_equal(__v);
+ }
+
+ if (value_comp()(*__hint, __v)) { // check after
+ // *__hint < __v
+ const_iterator __next = std::next(__hint);
+ if (__next == end() || value_comp()(__v, *__next)) {
+ // *__hint < __v < *std::next(__hint)
+ if (__hint.__get_np()->__right_ == nullptr)
+ return _Pair(__hint.__ptr_, __hint.__get_np()->__right_);
+ return _Pair(__next.__ptr_, __next.__ptr_->__left_);
+ }
+ // *next(__hint) <= __v
+ return __find_equal(__v);
+ }
+
+ // else __v == *__hint
+ __dummy = static_cast<__node_base_pointer>(__hint.__ptr_);
+ return _Pair(__hint.__ptr_, __dummy);
+ }
_LIBCPP_HIDE_FROM_ABI void __copy_assign_alloc(const __tree& __t) {
__copy_assign_alloc(__t, integral_constant<bool, __node_traits::propagate_on_container_copy_assignment::value>());
@@ -1670,94 +1742,6 @@ typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Co
return __find_leaf_low(__parent, __v);
}
-// Find place to insert if __v doesn't exist
-// Set __parent to parent of null leaf
-// Return reference to null leaf
-// If __v exists, set parent to node of __v and return reference to node of __v
-template <class _Tp, class _Compare, class _Allocator>
-template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer&
-__tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, const _Key& __v) {
- __node_pointer __nd = __root();
- __node_base_pointer* __nd_ptr = __root_ptr();
- if (__nd != nullptr) {
- while (true) {
- if (value_comp()(__v, __nd->__get_value())) {
- if (__nd->__left_ != nullptr) {
- __nd_ptr = std::addressof(__nd->__left_);
- __nd = static_cast<__node_pointer>(__nd->__left_);
- } else {
- __parent = static_cast<__end_node_pointer>(__nd);
- return __parent->__left_;
- }
- } else if (value_comp()(__nd->__get_value(), __v)) {
- if (__nd->__right_ != nullptr) {
- __nd_ptr = std::addressof(__nd->__right_);
- __nd = static_cast<__node_pointer>(__nd->__right_);
- } else {
- __parent = static_cast<__end_node_pointer>(__nd);
- return __nd->__right_;
- }
- } else {
- __parent = static_cast<__end_node_pointer>(__nd);
- return *__nd_ptr;
- }
- }
- }
- __parent = __end_node();
- return __parent->__left_;
-}
-
-// Find place to insert if __v doesn't exist
-// First check prior to __hint.
-// Next check after __hint.
-// Next do O(log N) search.
-// Set __parent to parent of null leaf
-// Return reference to null leaf
-// If __v exists, set parent to node of __v and return reference to node of __v
-template <class _Tp, class _Compare, class _Allocator>
-template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Compare, _Allocator>::__find_equal(
- const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v) {
- if (__hint == end() || value_comp()(__v, *__hint)) // check before
- {
- // __v < *__hint
- const_iterator __prior = __hint;
- if (__prior == begin() || value_comp()(*--__prior, __v)) {
- // *prev(__hint) < __v < *__hint
- if (__hint.__ptr_->__left_ == nullptr) {
- __parent = __hint.__ptr_;
- return __parent->__left_;
- } else {
- __parent = __prior.__ptr_;
- return static_cast<__node_base_pointer>(__prior.__ptr_)->__right_;
- }
- }
- // __v <= *prev(__hint)
- return __find_equal(__parent, __v);
- } else if (value_comp()(*__hint, __v)) // check after
- {
- // *__hint < __v
- const_iterator __next = std::next(__hint);
- if (__next == end() || value_comp()(__v, *__next)) {
- // *__hint < __v < *std::next(__hint)
- if (__hint.__get_np()->__right_ == nullptr) {
- __parent = __hint.__ptr_;
- return static_cast<__node_base_pointer>(__hint.__ptr_)->__right_;
- } else {
- __parent = __next.__ptr_;
- return __parent->__left_;
- }
- }
- // *next(__hint) <= __v
- return __find_equal(__parent, __v);
- }
- // else __v == *__hint
- __parent = __hint.__ptr_;
- __dummy = static_cast<__node_base_pointer>(__hint.__ptr_);
- return __dummy;
-}
-
template <class _Tp, class _Compare, class _Allocator>
void __tree<_Tp, _Compare, _Allocator>::__insert_node_at(
__end_node_pointer __parent, __node_base_pointer& __child, __node_base_pointer __new_node) _NOEXCEPT {
@@ -1776,10 +1760,9 @@ template <class _Tp, class _Compare, class _Allocator>
template <class _Key, class... _Args>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__emplace_unique_key_args(_Key const& __k, _Args&&... __args) {
- __end_node_pointer __parent;
- __node_base_pointer& __child = __find_equal(__parent, __k);
- __node_pointer __r = static_cast<__node_pointer>(__child);
- bool __inserted = false;
+ auto [__parent, __child] = __find_equal(__k); // __child aliases the tree slot (the pair holds a reference)
+ __node_pointer __r = static_cast<__node_pointer>(__child);
+ bool __inserted = false;
if (__child == nullptr) {
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
@@ -1794,11 +1777,10 @@ template <class _Key, class... _Args>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__emplace_hint_unique_key_args(
const_iterator __p, _Key const& __k, _Args&&... __args) {
- __end_node_pointer __parent;
__node_base_pointer __dummy;
- __node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __k);
- __node_pointer __r = static_cast<__node_pointer>(__child);
- bool __inserted = false;
+ auto [__parent, __child] = __find_equal(__p, __dummy, __k); // __child may refer to __dummy if __k is equivalent to *__p
+ __node_pointer __r = static_cast<__node_pointer>(__child);
+ bool __inserted = false;
if (__child == nullptr) {
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
@@ -1823,11 +1805,10 @@ template <class _Tp, class _Compare, class _Allocator>
template <class... _Args>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__emplace_unique_impl(_Args&&... __args) {
- __node_holder __h = __construct_node(std::forward<_Args>(__args)...);
- __end_node_pointer __parent;
- __node_base_pointer& __child = __find_equal(__parent, __h->__get_value());
- __node_pointer __r = static_cast<__node_pointer>(__child);
- bool __inserted = false;
+ __node_holder __h = __construct_node(std::forward<_Args>(__args)...);
+ auto [__parent, __child] = __find_equal(__h->__get_value()); // null __child: no equivalent value is present yet
+ __node_pointer __r = static_cast<__node_pointer>(__child);
+ bool __inserted = false;
if (__child == nullptr) {
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
__r = __h.release();
@@ -1841,10 +1822,9 @@ template <class... _Args>
typename __tree<_Tp, _Compare, _Allocator>::iterator
__tree<_Tp, _Compare, _Allocator>::__emplace_hint_unique_impl(const_iterator __p, _Args&&... __args) {
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
- __end_node_pointer __parent;
__node_base_pointer __dummy;
- __node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __h->__get_value());
- __node_pointer __r = static_cast<__node_pointer>(__child);
+ auto [__parent, __child] = __find_equal(__p, __dummy, __h->__get_value()); // __child may refer to __dummy
+ __node_pointer __r = static_cast<__node_pointer>(__child);
if (__child == nullptr) {
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
__r = __h.release();
@@ -1877,10 +1857,9 @@ __tree<_Tp, _Compare, _Allocator>::__emplace_hint_multi(const_iterator __p, _Arg
template <class _Tp, class _Compare, class _Allocator>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__node_assign_unique(const value_type& __v, __node_pointer __nd) {
- __end_node_pointer __parent;
- __node_base_pointer& __child = __find_equal(__parent, __v);
- __node_pointer __r = static_cast<__node_pointer>(__child);
- bool __inserted = false;
+ auto [__parent, __child] = __find_equal(__v); // __child aliases the tree slot (the pair holds a reference)
+ __node_pointer __r = static_cast<__node_pointer>(__child);
+ bool __inserted = false;
if (__child == nullptr) {
__assign_value(__nd->__get_value(), __v);
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__nd));
@@ -1929,8 +1908,7 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(_NodeHandle&& __n
return _InsertReturnType{end(), false, _NodeHandle()};
__node_pointer __ptr = __nh.__ptr_;
- __end_node_pointer __parent;
- __node_base_pointer& __child = __find_equal(__parent, __ptr->__get_value());
+ auto [__parent, __child] = __find_equal(__ptr->__get_value()); // non-null __child: an equivalent key is already present
if (__child != nullptr)
return _InsertReturnType{iterator(static_cast<__node_pointer>(__child)), false, std::move(__nh)};
@@ -1947,10 +1925,9 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(const_iterator __
return end();
__node_pointer __ptr = __nh.__ptr_;
- __end_node_pointer __parent;
__node_base_pointer __dummy;
- __node_base_pointer& __child = __find_equal(__hint, __parent, __dummy, __ptr->__get_value());
- __node_pointer __r = static_cast<__node_pointer>(__child);
+ auto [__parent, __child] = __find_equal(__hint, __dummy, __ptr->__get_value()); // __child may refer to __dummy
+ __node_pointer __r = static_cast<__node_pointer>(__child);
if (__child == nullptr) {
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__ptr));
__r = __ptr;
@@ -1983,8 +1960,7 @@ _LIBCPP_HIDE_FROM_ABI void __tree<_Tp, _Compare, _Allocator>::__node_handle_merg
for (typename _Tree::iterator __i = __source.begin(); __i != __source.end();) {
__node_pointer __src_ptr = __i.__get_np();
- __end_node_pointer __parent;
- __node_base_pointer& __child = __find_equal(__parent, __src_ptr->__get_value());
+ auto [__parent, __child] = __find_equal(__src_ptr->__get_value()); // null __child: no equivalent key in *this
++__i;
if (__child != nullptr)
continue;
diff --git a/libcxx/include/map b/libcxx/include/map
index 4dfce70e50e7f..9b54903b5b334 100644
--- a/libcxx/include/map
+++ b/libcxx/include/map
@@ -1429,9 +1429,8 @@ map<_Key, _Tp, _Compare, _Allocator>::__construct_node_with_key(const key_type&
template <class _Key, class _Tp, class _Compare, class _Allocator>
_Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) {
- __parent_pointer __parent;
- __node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
- __node_pointer __r = static_cast<__node_pointer>(__child);
+ auto [__parent, __child] = __tree_.__find_equal(__k); // __child aliases the tree slot, so inserting through it links the node
+ __node_pointer __r = static_cast<__node_pointer>(__child);
if (__child == nullptr) {
__node_holder __h = __construct_node_with_key(__k);
__tree_.__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
@@ -1444,8 +1443,7 @@ _Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) {
template <class _Key, class _Tp, class _Compare, class _Allocator>
_Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) {
- __parent_pointer __parent;
- __node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
+ auto [__, __child] = __tree_.__find_equal(__k); // reserved `__`, not `_` (user code may #define _); null __child: absent
if (__child == nullptr)
std::__throw_out_of_range("map::at: key not found");
return static_cast<__node_pointer>(__child)->__get_value().second;
@@ -1453,8 +1451,7 @@ _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) {
template <class _Key, class _Tp, class _Compare, class _Allocator>
const _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) const {
- __parent_pointer __parent;
- __node_base_pointer __child = __tree_.__find_equal(__parent, __k);
+ auto [__, __child] = __tree_.__find_equal(__k); // reserved `__`, not `_` (user code may #define _); null __child: absent
if (__child == nullptr)
std::__throw_out_of_range("map::at: key not found");
return static_cast<__node_pointer>(__child)->__get_value().second;
More information about the libcxx-commits
mailing list