[libcxx-commits] [libcxx] [libc++] Optimize search_n (PR #171389)
Nikolas Klauser via libcxx-commits
libcxx-commits at lists.llvm.org
Tue Dec 23 01:16:18 PST 2025
https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/171389
>From 9c29afcc59392259dc06ae1e6f446aefec422808 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 8 Dec 2025 19:03:05 +0100
Subject: [PATCH] [libc++] Optimize search_n
---
libcxx/docs/ReleaseNotes/22.rst | 1 +
libcxx/include/__algorithm/ranges_search_n.h | 4 +-
libcxx/include/__algorithm/search_n.h | 86 +++++++-----
.../nonmodifying/search_n.bench.cpp | 52 ++++++-
.../alg.search/ranges.search_n.pass.cpp | 54 +++++--
.../alg.search/search_n.pass.cpp | 132 +++++++++---------
6 files changed, 212 insertions(+), 117 deletions(-)
diff --git a/libcxx/docs/ReleaseNotes/22.rst b/libcxx/docs/ReleaseNotes/22.rst
index 36048165003dd..36ea962b676d0 100644
--- a/libcxx/docs/ReleaseNotes/22.rst
+++ b/libcxx/docs/ReleaseNotes/22.rst
@@ -92,6 +92,7 @@ Improvements and New Features
- ``std::for_each`` and ``ranges::for_each`` have been optimized to iterate more efficiently over the associative
containers, resulting in performance improvements of up to 2x.
+- The performance of ``std::search_n`` and ``ranges::search_n`` has been significantly improved.
- The ``num_get::do_get`` integral overloads have been optimized, resulting in a performance improvement of up to 2.8x.
- The performance of ``std::align`` has been improved by making it an inline function, which allows the compiler to
diff --git a/libcxx/include/__algorithm/ranges_search_n.h b/libcxx/include/__algorithm/ranges_search_n.h
index 81b568c0965fd..746bfcc3d1a8f 100644
--- a/libcxx/include/__algorithm/ranges_search_n.h
+++ b/libcxx/include/__algorithm/ranges_search_n.h
@@ -54,8 +54,8 @@ struct __search_n {
}
if constexpr (random_access_iterator<_Iter1>) {
- auto __ret = std::__search_n_random_access_impl<_RangeAlgPolicy>(
- __first, __last, __count, __value, __pred, __proj, __size);
+ auto __ret =
+ std::__search_n_random_access_impl<_RangeAlgPolicy>(__first, __count, __value, __pred, __proj, __size);
return {std::move(__ret.first), std::move(__ret.second)};
}
}
diff --git a/libcxx/include/__algorithm/search_n.h b/libcxx/include/__algorithm/search_n.h
index 38474e1b2379d..67a12af971464 100644
--- a/libcxx/include/__algorithm/search_n.h
+++ b/libcxx/include/__algorithm/search_n.h
@@ -14,11 +14,7 @@
#include <__algorithm/iterator_operations.h>
#include <__config>
#include <__functional/identity.h>
-#include <__iterator/advance.h>
-#include <__iterator/concepts.h>
-#include <__iterator/distance.h>
#include <__iterator/iterator_traits.h>
-#include <__ranges/concepts.h>
#include <__type_traits/enable_if.h>
#include <__type_traits/invoke.h>
#include <__type_traits/is_callable.h>
@@ -68,44 +64,60 @@ _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter> __search_
}
}
-template <class _AlgPolicy, class _Pred, class _Iter, class _Sent, class _SizeT, class _Type, class _Proj, class _DiffT>
+// Returns the start of the longest suffix of [__first, __last) whose elements all satisfy __pred(__proj(*__it), __value).
+template <class _RAIter, class _Pred, class _Proj, class _ValueT>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _RAIter
+__find_longest_suffix(_RAIter __first, _RAIter __last, const _ValueT& __value, _Pred& __pred, _Proj& __proj) {
+ while (__first != __last) {
+ if (!std::__invoke(__pred, std::__invoke(__proj, *--__last), __value)) {
+ return ++__last;
+ }
+ }
+ return __first;
+}
+
+template <class _AlgPolicy, class _Pred, class _Iter, class _SizeT, class _Type, class _Proj, class _DiffT>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 std::pair<_Iter, _Iter> __search_n_random_access_impl(
- _Iter __first, _Sent __last, _SizeT __count, const _Type& __value, _Pred& __pred, _Proj& __proj, _DiffT __size1) {
- using difference_type = typename iterator_traits<_Iter>::difference_type;
- if (__count == 0)
+ _Iter __first, _SizeT __count_in, const _Type& __value, _Pred& __pred, _Proj& __proj, _DiffT __size) {
+ auto __last = __first + __size;
+ auto __count = static_cast<_DiffT>(__count_in);
+
+ if (__count <= 0) // A non-positive count matches (vacuously) at __first.
return std::make_pair(__first, __first);
- if (__size1 < static_cast<_DiffT>(__count)) {
- _IterOps<_AlgPolicy>::__advance_to(__first, __last);
- return std::make_pair(__first, __first);
- }
+ if (__size < __count)
+ return std::make_pair(__last, __last);
+
+ // [__match_start, __match_start + __count) is the subrange which we currently check whether it only contains matching
+ // elements. This subrange is returned in case all the elements match.
+ // [__match_start, __matched_until) is the longest subrange where all elements are known to match at any given point
+ // in time.
+ // [__matched_until, __match_start + __count) is the subrange where we don't know whether the elements match.
+
+ // This algorithm tries to expand the subrange [__match_start, __matched_until) into a range of sufficient length.
+ // When we fail to do that because we find a mismatching element, we restart the subrange right after the last
+ // mismatching element, keeping the elements after it that are already known to match.
+
+ const _Iter __try_match_until = __last - __count;
+ _Iter __match_start = __first;
+ _Iter __matched_until = __first;
- const auto __s = __first + __size1 - difference_type(__count - 1); // Start of pattern match can't go beyond here
while (true) {
- // Find first element in sequence that matchs __value, with a mininum of loop checks
- while (true) {
- if (__first >= __s) { // return __last if no element matches __value
- _IterOps<_AlgPolicy>::__advance_to(__first, __last);
- return std::make_pair(__first, __first);
- }
- if (std::__invoke(__pred, std::__invoke(__proj, *__first), __value))
- break;
- ++__first;
- }
- // *__first matches __value_, now match elements after here
- auto __m = __first;
- _SizeT __c(0);
- while (true) {
- if (++__c == __count) // If pattern exhausted, __first is the answer (works for 1 element pattern)
- return std::make_pair(__first, __first + _DiffT(__count));
- ++__m; // no need to check range on __m because __s guarantees we have enough source
+ // There's no chance of expanding the subrange into a sequence of sufficient length, since we don't have enough
+ // elements in the haystack anymore.
+ if (__match_start > __try_match_until)
+ return std::make_pair(__last, __last);
- // if there is a mismatch, restart with a new __first
- if (!std::__invoke(__pred, std::__invoke(__proj, *__m), __value)) {
- __first = __m;
- ++__first;
- break;
- } // else there is a match, check next elements
- }
+ auto __suffix_start = std::__find_longest_suffix(__matched_until, __match_start + __count, __value, __pred, __proj);
+
+ // If all elements in [__matched_until, __match_start + __count) match, we know that
+ // [__match_start, __match_start + __count) is a full sequence of matching elements, so we're done.
+ if (__suffix_start == __matched_until)
+ return std::make_pair(__match_start, __match_start + __count);
+
+ // Otherwise, the element right before __suffix_start doesn't match, so no match can start at or before it. All of
+ // [__suffix_start, __match_start + __count) is known to match, so it becomes the new known-matching subrange.
+ __matched_until = __match_start + __count;
+ __match_start = __suffix_start;
}
}
@@ -119,7 +131,7 @@ template <class _Iter,
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
__search_n_impl(_Iter __first, _Sent __last, _DiffT __count, const _Type& __value, _Pred& __pred, _Proj& __proj) {
return std::__search_n_random_access_impl<_ClassicAlgPolicy>(
- __first, __last, __count, __value, __pred, __proj, __last - __first);
+ __first, __count, __value, __pred, __proj, __last - __first);
}
template <class _Iter1,
diff --git a/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp b/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp
index de404fedaed3a..91c6ad3eafa0d 100644
--- a/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp
+++ b/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp
@@ -59,10 +59,56 @@ int main(int argc, char** argv) {
benchmark::DoNotOptimize(result);
}
})
- ->Arg(1000) // non power-of-two
+ ->Arg(32)
->Arg(1024)
- ->Arg(8192)
- ->Arg(1 << 20);
+ ->Arg(8192);
+ };
+ // {std,ranges}::search_n
+ bm.operator()<std::vector<int>>("std::search_n(vector<int>) (no match)", std_search_n);
+ bm.operator()<std::deque<int>>("std::search_n(deque<int>) (no match)", std_search_n);
+ bm.operator()<std::list<int>>("std::search_n(list<int>) (no match)", std_search_n);
+ bm.operator()<std::vector<int>>("rng::search_n(vector<int>) (no match)", std::ranges::search_n);
+ bm.operator()<std::deque<int>>("rng::search_n(deque<int>) (no match)", std::ranges::search_n);
+ bm.operator()<std::list<int>>("rng::search_n(list<int>) (no match)", std::ranges::search_n);
+
+ // {std,ranges}::search_n(pred)
+ bm.operator()<std::vector<int>>("std::search_n(vector<int>, pred) (no match)", std_search_n_pred);
+ bm.operator()<std::deque<int>>("std::search_n(deque<int>, pred) (no match)", std_search_n_pred);
+ bm.operator()<std::list<int>>("std::search_n(list<int>, pred) (no match)", std_search_n_pred);
+ bm.operator()<std::vector<int>>("rng::search_n(vector<int>, pred) (no match)", ranges_search_n_pred);
+ bm.operator()<std::deque<int>>("rng::search_n(deque<int>, pred) (no match)", ranges_search_n_pred);
+ bm.operator()<std::list<int>>("rng::search_n(list<int>, pred) (no match)", ranges_search_n_pred);
+ }
+
+ // Benchmark {std,ranges}::search_n where the needle almost matches a lot.
+ {
+ auto bm = []<class Container>(std::string name, auto search_n) {
+ benchmark::RegisterBenchmark(
+ name,
+ [search_n](auto& st) {
+ std::size_t const size = st.range(0);
+ using ValueType = typename Container::value_type;
+ ValueType x = Generate<ValueType>::random();
+ ValueType y = random_different_from({x});
+ Container haystack(size, x);
+ std::size_t n = size / 10; // needle size is 10% of the haystack
+
+ // Make sure there are no actual matches
+ for (size_t i = 0; i < size; i += getRandomInteger<size_t>(1, n - 1)) {
+ *std::next(haystack.begin(), i) = y;
+ }
+
+ for ([[maybe_unused]] auto _ : st) {
+ benchmark::DoNotOptimize(haystack);
+ benchmark::DoNotOptimize(n);
+ benchmark::DoNotOptimize(y);
+ auto result = search_n(haystack.begin(), haystack.end(), n, x);
+ benchmark::DoNotOptimize(result);
+ }
+ })
+ ->Arg(32)
+ ->Arg(1024)
+ ->Arg(8192);
};
// {std,ranges}::search_n
bm.operator()<std::vector<int>>("std::search_n(vector<int>) (no match)", std_search_n);
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp
index 2f2e436c79130..f68c31ead7b8f 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp
@@ -186,18 +186,50 @@ constexpr void test_iterators() {
}
{ // check that the first match is returned
- {
- int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 8};
- auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 2, 6);
- assert(base(ret.begin()) == a);
- assert(base(ret.end()) == a + 2);
+ { // Match is at the start
+ {
+ int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 8};
+ auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 2, 6);
+ assert(base(ret.begin()) == a);
+ assert(base(ret.end()) == a + 2);
+ }
+ {
+ int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 8};
+ auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
+ auto ret = std::ranges::search_n(range, 2, 6);
+ assert(base(ret.begin()) == a);
+ assert(base(ret.end()) == a + 2);
+ }
}
- {
- int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 8};
- auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
- auto ret = std::ranges::search_n(range, 2, 6);
- assert(base(ret.begin()) == a);
- assert(base(ret.end()) == a + 2);
+ { // Match is in the middle
+ {
+ int a[] = {6, 8, 8, 6, 6, 8, 6, 6, 8};
+ auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 2, 6);
+ assert(base(ret.begin()) == a + 3);
+ assert(base(ret.end()) == a + 5);
+ }
+ {
+ int a[] = {6, 8, 8, 6, 6, 8, 6, 6, 8};
+ auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
+ auto ret = std::ranges::search_n(range, 2, 6);
+ assert(base(ret.begin()) == a + 3);
+ assert(base(ret.end()) == a + 5);
+ }
+ }
+ { // Match is at the end
+ {
+ int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 6};
+ auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 3, 6);
+ assert(base(ret.begin()) == a + 6);
+ assert(base(ret.end()) == a + 9);
+ }
+ {
+ int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 6};
+ auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
+ auto ret = std::ranges::search_n(range, 3, 6);
+ assert(base(ret.begin()) == a + 6);
+ assert(base(ret.end()) == a + 9);
+ }
}
}
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp
index 228bc6b768cf9..514dc9a911302 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp
@@ -14,79 +14,83 @@
// const T& value);
#include <algorithm>
+#include <array>
#include <cassert>
#include "test_macros.h"
#include "test_iterators.h"
-#include "user_defined_integral.h"
-
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
- int ia[] = {0, 0, 1, 1, 2, 2};
- return (std::search_n(std::begin(ia), std::end(ia), 1, 0) == ia)
- && (std::search_n(std::begin(ia), std::end(ia), 2, 1) == ia+2)
- && (std::search_n(std::begin(ia), std::end(ia), 1, 3) == std::end(ia))
- ;
- }
-#endif
template <class Iter>
-void
-test()
-{
- int ia[] = {0, 1, 2, 3, 4, 5};
- const unsigned sa = sizeof(ia)/sizeof(ia[0]);
- assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 0) == Iter(ia));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 0) == Iter(ia+0));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 0) == Iter(ia+sa));
- assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 0) == Iter(ia+sa));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 3) == Iter(ia));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 3) == Iter(ia+3));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 3) == Iter(ia+sa));
- assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 3) == Iter(ia+sa));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 5) == Iter(ia));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 5) == Iter(ia+5));
- assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 5) == Iter(ia+sa));
- assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 5) == Iter(ia+sa));
-
- int ib[] = {0, 0, 1, 1, 2, 2};
- const unsigned sb = sizeof(ib)/sizeof(ib[0]);
- assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 0) == Iter(ib));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 0) == Iter(ib+0));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 0) == Iter(ib+0));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 0) == Iter(ib+sb));
- assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 0) == Iter(ib+sb));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 1) == Iter(ib));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 1) == Iter(ib+2));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 1) == Iter(ib+2));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 1) == Iter(ib+sb));
- assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 1) == Iter(ib+sb));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 2) == Iter(ib));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 2) == Iter(ib+4));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 2) == Iter(ib+4));
- assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 2) == Iter(ib+sb));
- assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 2) == Iter(ib+sb));
-
- int ic[] = {0, 0, 0};
- const unsigned sc = sizeof(ic)/sizeof(ic[0]);
- assert(std::search_n(Iter(ic), Iter(ic+sc), 0, 0) == Iter(ic));
- assert(std::search_n(Iter(ic), Iter(ic+sc), 1, 0) == Iter(ic));
- assert(std::search_n(Iter(ic), Iter(ic+sc), 2, 0) == Iter(ic));
- assert(std::search_n(Iter(ic), Iter(ic+sc), 3, 0) == Iter(ic));
- assert(std::search_n(Iter(ic), Iter(ic+sc), 4, 0) == Iter(ic+sc));
+TEST_CONSTEXPR_CXX20 bool test() {
+ { // single-element pattern found in the middle of the range
+ int a[] = {1, 2, 3, 4, 5, 6};
+ auto ret = std::search_n(Iter(a), Iter(a + 6), 1, 3);
+ assert(base(ret) == a + 2);
+ }
+ { // matching part begins at the front
+ int a[] = {7, 7, 3, 7, 3, 6};
+ auto ret = std::search_n(Iter(a), Iter(a + 6), 2, 7);
+ assert(base(ret) == a);
+ }
+ { // matching part ends at the back
+ int a[] = {9, 3, 6, 4, 4};
+ auto ret = std::search_n(Iter(a), Iter(a + 5), 2, 4);
+ assert(base(ret) == a + 3);
+ }
+ { // pattern does not match
+ int a[] = {9, 3, 6, 4, 8};
+ auto ret = std::search_n(Iter(a), Iter(a + 5), 1, 1);
+ assert(base(ret) == a + 5);
+ }
+ { // range and pattern are identical
+ int a[] = {1, 1, 1, 1};
+ auto ret = std::search_n(Iter(a), Iter(a + 4), 4, 1);
+ assert(base(ret) == a);
+ }
+ { // pattern is longer than range
+ int a[] = {3, 3, 3};
+ auto ret = std::search_n(Iter(a), Iter(a + 3), 4, 3);
+ assert(base(ret) == a + 3);
+ }
+ { // pattern has zero length
+ int a[] = {6, 7, 8};
+ auto ret = std::search_n(Iter(a), Iter(a + 3), 0, 7);
+ assert(base(ret) == a);
+ }
+ { // range has zero length
+ std::array<int, 0> a = {};
+ auto ret = std::search_n(Iter(a.data()), Iter(a.data()), 1, 1);
+ assert(base(ret) == a.data());
+ }
+ { // check that the first match is returned
+ { // Match is at the start
+ int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 8};
+ auto ret = std::search_n(Iter(a), Iter(a + 9), 2, 6);
+ assert(base(ret) == a);
+ }
+ { // Match is in the middle, after a partial match that is interrupted by a mismatch
+ int a[] = {6, 8, 8, 6, 6, 8, 6, 6, 8};
+ auto ret = std::search_n(Iter(a), Iter(a + 9), 2, 6);
+ assert(base(ret) == a + 3);
+ }
+ { // Match is at the end, after two shorter runs of matching elements
+ int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 6};
+ auto ret = std::search_n(Iter(a), Iter(a + 9), 3, 6);
+ assert(base(ret) == a + 6);
+ }
+ }
- // Check that we properly convert the size argument to an integral.
- (void)std::search_n(Iter(ic), Iter(ic+sc), UserDefinedIntegral<unsigned>(0), 0);
+ return true;
}
-int main(int, char**)
-{
- test<forward_iterator<const int*> >();
- test<bidirectional_iterator<const int*> >();
- test<random_access_iterator<const int*> >();
-
-#if TEST_STD_VER > 17
- static_assert(test_constexpr());
+int main(int, char**) {
+ test<forward_iterator<const int*> >();
+ test<bidirectional_iterator<const int*> >();
+ test<random_access_iterator<const int*> >();
+#if TEST_STD_VER >= 20
+ static_assert(test<forward_iterator<const int*> >());
+ static_assert(test<bidirectional_iterator<const int*> >());
+ static_assert(test<random_access_iterator<const int*> >());
#endif
return 0;
More information about the libcxx-commits
mailing list