[libcxx-commits] [libcxx] [libc++] Optimize __tree::find and __tree::__erase_unique (PR #152370)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Thu Aug 14 00:14:26 PDT 2025


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/152370

>From fc6edb706d878ee9f6de227cedd45763ac234042 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Tue, 5 Aug 2025 09:50:59 +0200
Subject: [PATCH] [libc++] Optimize __tree::find and __tree::__erase_unique

---
 libcxx/docs/ReleaseNotes/22.rst               |  2 +
 libcxx/include/__tree                         | 36 ++++++-------
 .../multimap/multimap.ops/find.pass.cpp       | 51 +++++++++++--------
 3 files changed, 47 insertions(+), 42 deletions(-)

diff --git a/libcxx/docs/ReleaseNotes/22.rst b/libcxx/docs/ReleaseNotes/22.rst
index 8b8dce5083149..bb6e0a2fa9c8f 100644
--- a/libcxx/docs/ReleaseNotes/22.rst
+++ b/libcxx/docs/ReleaseNotes/22.rst
@@ -45,6 +45,8 @@ Improvements and New Features
 
 - The performance of ``map::map(const map&)`` has been improved up to 2.3x
 - The performance of ``map::operator=(const map&)`` has been improved by up to 11x
+- The performance of ``map::erase(key)`` and ``set::erase(key)`` has been improved by up to 2x
+- The performance of ``find(key)`` in ``map``, ``set``, ``multimap`` and ``multiset`` has been improved by up to 2.3x
 
 Deprecations and Removals
 -------------------------
diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index 6dadd0915c984..8f9c508ee1996 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -1009,9 +1009,22 @@ public:
   __insert_node_at(__end_node_pointer __parent, __node_base_pointer& __child, __node_base_pointer __new_node) _NOEXCEPT;
 
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __v);
+  _LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __key) {
+    __end_node_pointer __parent;
+    __node_base_pointer __match = __find_equal(__parent, __key);
+    if (__match == nullptr)
+      return end();
+    return iterator(static_cast<__node_pointer>(__match));
+  }
+
   template <class _Key>
-  _LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __v) const;
+  _LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __key) const {
+    __end_node_pointer __parent;
+    __node_base_pointer __match = __find_equal(__parent, __key);
+    if (__match == nullptr)
+      return end();
+    return const_iterator(static_cast<__node_pointer>(__match));
+  }
 
   template <class _Key>
   _LIBCPP_HIDE_FROM_ABI size_type __count_unique(const _Key& __k) const;
@@ -2031,25 +2044,6 @@ __tree<_Tp, _Compare, _Allocator>::__erase_multi(const _Key& __k) {
   return __r;
 }
 
-template <class _Tp, class _Compare, class _Allocator>
-template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::iterator __tree<_Tp, _Compare, _Allocator>::find(const _Key& __v) {
-  iterator __p = __lower_bound(__v, __root(), __end_node());
-  if (__p != end() && !value_comp()(__v, *__p))
-    return __p;
-  return end();
-}
-
-template <class _Tp, class _Compare, class _Allocator>
-template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::const_iterator
-__tree<_Tp, _Compare, _Allocator>::find(const _Key& __v) const {
-  const_iterator __p = __lower_bound(__v, __root(), __end_node());
-  if (__p != end() && !value_comp()(__v, *__p))
-    return __p;
-  return end();
-}
-
 template <class _Tp, class _Compare, class _Allocator>
 template <class _Key>
 typename __tree<_Tp, _Compare, _Allocator>::size_type
diff --git a/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp b/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp
index 6d5018ff5263e..15df6c15bfa78 100644
--- a/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp
+++ b/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp
@@ -21,6 +21,15 @@
 #include "private_constructor.h"
 #include "is_transparent.h"
 
+template <class Iter>
+bool iter_in_range(Iter first, Iter last, Iter to_find) {
+  for (; first != last; ++first) {
+    if (first == to_find)
+      return true;
+  }
+  return false;
+}
+
 int main(int, char**) {
   typedef std::pair<const int, double> V;
   {
@@ -30,15 +39,15 @@ int main(int, char**) {
       V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
       M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
       R r = m.find(5);
-      assert(r == m.begin());
+      assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
       r = m.find(6);
       assert(r == m.end());
       r = m.find(7);
-      assert(r == std::next(m.begin(), 3));
+      assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
       r = m.find(8);
       assert(r == m.end());
       r = m.find(9);
-      assert(r == std::next(m.begin(), 6));
+      assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
       r = m.find(10);
       assert(r == m.end());
     }
@@ -47,15 +56,15 @@ int main(int, char**) {
       V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
       const M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
       R r = m.find(5);
-      assert(r == m.begin());
+      assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
       r = m.find(6);
       assert(r == m.end());
       r = m.find(7);
-      assert(r == std::next(m.begin(), 3));
+      assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
       r = m.find(8);
       assert(r == m.end());
       r = m.find(9);
-      assert(r == std::next(m.begin(), 6));
+      assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
       r = m.find(10);
       assert(r == m.end());
     }
@@ -68,15 +77,15 @@ int main(int, char**) {
       V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
       M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
       R r = m.find(5);
-      assert(r == m.begin());
+      assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
       r = m.find(6);
       assert(r == m.end());
       r = m.find(7);
-      assert(r == std::next(m.begin(), 3));
+      assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
       r = m.find(8);
       assert(r == m.end());
       r = m.find(9);
-      assert(r == std::next(m.begin(), 6));
+      assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
       r = m.find(10);
       assert(r == m.end());
     }
@@ -85,15 +94,15 @@ int main(int, char**) {
       V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
       const M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
       R r = m.find(5);
-      assert(r == m.begin());
+      assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
       r = m.find(6);
       assert(r == m.end());
       r = m.find(7);
-      assert(r == std::next(m.begin(), 3));
+      assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
       r = m.find(8);
       assert(r == m.end());
       r = m.find(9);
-      assert(r == std::next(m.begin(), 6));
+      assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
       r = m.find(10);
       assert(r == m.end());
     }
@@ -107,28 +116,28 @@ int main(int, char**) {
     V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
     M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
     R r = m.find(5);
-    assert(r == m.begin());
+    assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
     r = m.find(6);
     assert(r == m.end());
     r = m.find(7);
-    assert(r == std::next(m.begin(), 3));
+    assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
     r = m.find(8);
     assert(r == m.end());
     r = m.find(9);
-    assert(r == std::next(m.begin(), 6));
+    assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
     r = m.find(10);
     assert(r == m.end());
 
     r = m.find(C2Int(5));
-    assert(r == m.begin());
+    assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
     r = m.find(C2Int(6));
     assert(r == m.end());
     r = m.find(C2Int(7));
-    assert(r == std::next(m.begin(), 3));
+    assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
     r = m.find(C2Int(8));
     assert(r == m.end());
     r = m.find(C2Int(9));
-    assert(r == std::next(m.begin(), 6));
+    assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
     r = m.find(C2Int(10));
     assert(r == m.end());
   }
@@ -150,15 +159,15 @@ int main(int, char**) {
     m.insert(std::make_pair<PC, double>(PC::make(9), 3));
 
     R r = m.find(5);
-    assert(r == m.begin());
+    assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
     r = m.find(6);
     assert(r == m.end());
     r = m.find(7);
-    assert(r == std::next(m.begin(), 3));
+    assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
     r = m.find(8);
     assert(r == m.end());
     r = m.find(9);
-    assert(r == std::next(m.begin(), 6));
+    assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
     r = m.find(10);
     assert(r == m.end());
   }



More information about the libcxx-commits mailing list