[libcxx-commits] [libcxx] [libc++] Optimize __tree::find and __tree::__erase_unique (PR #152370)
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Thu Aug 14 00:14:26 PDT 2025
https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/152370
>From fc6edb706d878ee9f6de227cedd45763ac234042 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Tue, 5 Aug 2025 09:50:59 +0200
Subject: [PATCH] [libc++] Optimize __tree::find and __tree::__erase_unique
---
libcxx/docs/ReleaseNotes/22.rst | 2 +
libcxx/include/__tree | 36 ++++++-------
.../multimap/multimap.ops/find.pass.cpp | 51 +++++++++++--------
3 files changed, 47 insertions(+), 42 deletions(-)
diff --git a/libcxx/docs/ReleaseNotes/22.rst b/libcxx/docs/ReleaseNotes/22.rst
index 8b8dce5083149..bb6e0a2fa9c8f 100644
--- a/libcxx/docs/ReleaseNotes/22.rst
+++ b/libcxx/docs/ReleaseNotes/22.rst
@@ -45,6 +45,8 @@ Improvements and New Features
- The performance of ``map::map(const map&)`` has been improved up to 2.3x
- The performance of ``map::operator=(const map&)`` has been improved by up to 11x
+- The performance of ``map::erase`` and ``set::erase`` has been improved by up to 2x
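+- The performance of ``find(key)`` in ``map``, ``set``, ``multimap`` and ``multiset`` has been improved by up to 2.3x. As permitted by the standard, ``multimap::find`` and ``multiset::find`` may now return any element equivalent to ``key`` instead of the first one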
Deprecations and Removals
-------------------------
diff --git a/libcxx/include/__tree b/libcxx/include/__tree
index 6dadd0915c984..8f9c508ee1996 100644
--- a/libcxx/include/__tree
+++ b/libcxx/include/__tree
@@ -1009,9 +1009,22 @@ public:
__insert_node_at(__end_node_pointer __parent, __node_base_pointer& __child, __node_base_pointer __new_node) _NOEXCEPT;
template <class _Key>
- _LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __v);
+ _LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __key) {
+ __end_node_pointer __parent;
+ __node_base_pointer __match = __find_equal(__parent, __key);
+ if (__match == nullptr)
+ return end();
+ return iterator(static_cast<__node_pointer>(__match));
+ }
+
template <class _Key>
- _LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __v) const;
+ _LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __key) const {
+ __end_node_pointer __parent;
+ __node_base_pointer __match = __find_equal(__parent, __key);
+ if (__match == nullptr)
+ return end();
+ return const_iterator(static_cast<__node_pointer>(__match));
+ }
template <class _Key>
_LIBCPP_HIDE_FROM_ABI size_type __count_unique(const _Key& __k) const;
@@ -2031,25 +2044,6 @@ __tree<_Tp, _Compare, _Allocator>::__erase_multi(const _Key& __k) {
return __r;
}
-template <class _Tp, class _Compare, class _Allocator>
-template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::iterator __tree<_Tp, _Compare, _Allocator>::find(const _Key& __v) {
- iterator __p = __lower_bound(__v, __root(), __end_node());
- if (__p != end() && !value_comp()(__v, *__p))
- return __p;
- return end();
-}
-
-template <class _Tp, class _Compare, class _Allocator>
-template <class _Key>
-typename __tree<_Tp, _Compare, _Allocator>::const_iterator
-__tree<_Tp, _Compare, _Allocator>::find(const _Key& __v) const {
- const_iterator __p = __lower_bound(__v, __root(), __end_node());
- if (__p != end() && !value_comp()(__v, *__p))
- return __p;
- return end();
-}
-
template <class _Tp, class _Compare, class _Allocator>
template <class _Key>
typename __tree<_Tp, _Compare, _Allocator>::size_type
diff --git a/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp b/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp
index 6d5018ff5263e..15df6c15bfa78 100644
--- a/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp
+++ b/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp
@@ -21,6 +21,15 @@
#include "private_constructor.h"
#include "is_transparent.h"
+template <class Iter>
+bool iter_in_range(Iter first, Iter last, Iter to_find) {
+ for (; first != last; ++first) {
+ if (first == to_find)
+ return true;
+ }
+ return false;
+}
+
int main(int, char**) {
typedef std::pair<const int, double> V;
{
@@ -30,15 +39,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
- assert(r == m.begin());
+ assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
- assert(r == std::next(m.begin(), 3));
+ assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
- assert(r == std::next(m.begin(), 6));
+ assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
@@ -47,15 +56,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
const M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
- assert(r == m.begin());
+ assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
- assert(r == std::next(m.begin(), 3));
+ assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
- assert(r == std::next(m.begin(), 6));
+ assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
@@ -68,15 +77,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
- assert(r == m.begin());
+ assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
- assert(r == std::next(m.begin(), 3));
+ assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
- assert(r == std::next(m.begin(), 6));
+ assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
@@ -85,15 +94,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
const M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
- assert(r == m.begin());
+ assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
- assert(r == std::next(m.begin(), 3));
+ assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
- assert(r == std::next(m.begin(), 6));
+ assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
@@ -107,28 +116,28 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
- assert(r == m.begin());
+ assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
- assert(r == std::next(m.begin(), 3));
+ assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
- assert(r == std::next(m.begin(), 6));
+ assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
r = m.find(C2Int(5));
- assert(r == m.begin());
+ assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(C2Int(6));
assert(r == m.end());
r = m.find(C2Int(7));
- assert(r == std::next(m.begin(), 3));
+ assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(C2Int(8));
assert(r == m.end());
r = m.find(C2Int(9));
- assert(r == std::next(m.begin(), 6));
+ assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(C2Int(10));
assert(r == m.end());
}
@@ -150,15 +159,15 @@ int main(int, char**) {
m.insert(std::make_pair<PC, double>(PC::make(9), 3));
R r = m.find(5);
- assert(r == m.begin());
+ assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
- assert(r == std::next(m.begin(), 3));
+ assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
- assert(r == std::next(m.begin(), 6));
+ assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
More information about the libcxx-commits
mailing list