[libcxx-commits] [libcxx] [libc++] Optimize search_n (PR #171389)

Nikolas Klauser via libcxx-commits libcxx-commits at lists.llvm.org
Tue Dec 23 01:16:18 PST 2025


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/171389

>From 9c29afcc59392259dc06ae1e6f446aefec422808 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Mon, 8 Dec 2025 19:03:05 +0100
Subject: [PATCH] [libc++] Optimize search_n

---
 libcxx/docs/ReleaseNotes/22.rst               |   1 +
 libcxx/include/__algorithm/ranges_search_n.h  |   4 +-
 libcxx/include/__algorithm/search_n.h         |  86 +++++++-----
 .../nonmodifying/search_n.bench.cpp           |  52 ++++++-
 .../alg.search/ranges.search_n.pass.cpp       |  54 +++++--
 .../alg.search/search_n.pass.cpp              | 132 +++++++++---------
 6 files changed, 212 insertions(+), 117 deletions(-)

diff --git a/libcxx/docs/ReleaseNotes/22.rst b/libcxx/docs/ReleaseNotes/22.rst
index 36048165003dd..36ea962b676d0 100644
--- a/libcxx/docs/ReleaseNotes/22.rst
+++ b/libcxx/docs/ReleaseNotes/22.rst
@@ -92,6 +92,7 @@ Improvements and New Features
 - ``std::for_each`` and ``ranges::for_each`` have been optimized to iterate more efficiently over the associative
   containers, resulting in performance improvements of up to 2x.
 
+- The performance of ``search_n`` has been significantly improved.
 - The ``num_get::do_get`` integral overloads have been optimized, resulting in a performance improvement of up to 2.8x.
 
 - The performance of ``std::align`` has been improved by making it an inline function, which allows the compiler to
diff --git a/libcxx/include/__algorithm/ranges_search_n.h b/libcxx/include/__algorithm/ranges_search_n.h
index 81b568c0965fd..746bfcc3d1a8f 100644
--- a/libcxx/include/__algorithm/ranges_search_n.h
+++ b/libcxx/include/__algorithm/ranges_search_n.h
@@ -54,8 +54,8 @@ struct __search_n {
       }
 
       if constexpr (random_access_iterator<_Iter1>) {
-        auto __ret = std::__search_n_random_access_impl<_RangeAlgPolicy>(
-            __first, __last, __count, __value, __pred, __proj, __size);
+        auto __ret =
+            std::__search_n_random_access_impl<_RangeAlgPolicy>(__first, __count, __value, __pred, __proj, __size);
         return {std::move(__ret.first), std::move(__ret.second)};
       }
     }
diff --git a/libcxx/include/__algorithm/search_n.h b/libcxx/include/__algorithm/search_n.h
index 38474e1b2379d..67a12af971464 100644
--- a/libcxx/include/__algorithm/search_n.h
+++ b/libcxx/include/__algorithm/search_n.h
@@ -14,11 +14,7 @@
 #include <__algorithm/iterator_operations.h>
 #include <__config>
 #include <__functional/identity.h>
-#include <__iterator/advance.h>
-#include <__iterator/concepts.h>
-#include <__iterator/distance.h>
 #include <__iterator/iterator_traits.h>
-#include <__ranges/concepts.h>
 #include <__type_traits/enable_if.h>
 #include <__type_traits/invoke.h>
 #include <__type_traits/is_callable.h>
@@ -68,44 +64,60 @@ _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter> __search_
   }
 }
 
-template <class _AlgPolicy, class _Pred, class _Iter, class _Sent, class _SizeT, class _Type, class _Proj, class _DiffT>
+// Returns an iterator to the start of the longest suffix of [__first, __last) where every element satisfies __pred.
+template <class _RAIter, class _Pred, class _Proj, class _ValueT>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _RAIter
+__find_longest_suffix(_RAIter __first, _RAIter __last, const _ValueT& __value, _Pred& __pred, _Proj& __proj) {
+  while (__first != __last) {
+    if (!std::__invoke(__pred, std::__invoke(__proj, *--__last), __value)) {
+      return ++__last;
+    }
+  }
+  return __first;
+}
+
+template <class _AlgPolicy, class _Pred, class _Iter, class _SizeT, class _Type, class _Proj, class _DiffT>
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 std::pair<_Iter, _Iter> __search_n_random_access_impl(
-    _Iter __first, _Sent __last, _SizeT __count, const _Type& __value, _Pred& __pred, _Proj& __proj, _DiffT __size1) {
-  using difference_type = typename iterator_traits<_Iter>::difference_type;
+    _Iter __first, _SizeT __count_in, const _Type& __value, _Pred& __pred, _Proj& __proj, _DiffT __size) {
+  auto __last  = __first + __size;
+  auto __count = static_cast<_DiffT>(__count_in);
+
   if (__count == 0)
     return std::make_pair(__first, __first);
-  if (__size1 < static_cast<_DiffT>(__count)) {
-    _IterOps<_AlgPolicy>::__advance_to(__first, __last);
-    return std::make_pair(__first, __first);
-  }
+  if (__size < __count)
+    return std::make_pair(__last, __last);
+
+  // [__match_start, __match_start + __count) is the candidate subrange currently being checked: we test whether all of
+  // its elements match. This subrange is returned if they do.
+  // [__match_start, __matched_until) is the longest subrange where all elements are known to match at any given point
+  // in time.
+  // [__matched_until, __match_start + __count) is the subrange where we don't know whether the elements match.
+
+  // This algorithm tries to expand the subrange [__match_start, __matched_until) into a range of sufficient length.
+  // When we fail to do that because we find a mismatching element, we move the subrange forward so that it starts at
+  // the first element after the mismatch, which is the earliest position where a full match is still possible.
+
+  const _Iter __try_match_until = __last - __count;
+  _Iter __match_start           = __first;
+  _Iter __matched_until         = __first;
 
-  const auto __s = __first + __size1 - difference_type(__count - 1); // Start of pattern match can't go beyond here
   while (true) {
-    // Find first element in sequence that matchs __value, with a mininum of loop checks
-    while (true) {
-      if (__first >= __s) { // return __last if no element matches __value
-        _IterOps<_AlgPolicy>::__advance_to(__first, __last);
-        return std::make_pair(__first, __first);
-      }
-      if (std::__invoke(__pred, std::__invoke(__proj, *__first), __value))
-        break;
-      ++__first;
-    }
-    // *__first matches __value_, now match elements after here
-    auto __m = __first;
-    _SizeT __c(0);
-    while (true) {
-      if (++__c == __count) // If pattern exhausted, __first is the answer (works for 1 element pattern)
-        return std::make_pair(__first, __first + _DiffT(__count));
-      ++__m; // no need to check range on __m because __s guarantees we have enough source
+    // There's no chance of expanding the subrange into a sequence of sufficient length, since we don't have enough
+    // elements in the haystack anymore.
+    if (__match_start > __try_match_until)
+      return std::make_pair(__last, __last);
 
-      // if there is a mismatch, restart with a new __first
-      if (!std::__invoke(__pred, std::__invoke(__proj, *__m), __value)) {
-        __first = __m;
-        ++__first;
-        break;
-      } // else there is a match, check next elements
-    }
+    auto __mismatch = std::__find_longest_suffix(__matched_until, __match_start + __count, __value, __pred, __proj);
+
+    // If all elements in [__matched_until, __match_start + __count) match, we know that
+    // [__match_start, __match_start + __count) is a full sequence of matching elements, so we're done.
+    if (__mismatch == __matched_until)
+      return std::make_pair(__match_start, __match_start + __count);
+
+    // Otherwise, restart just past the mismatching element. __mismatch already points to the first element after it,
+    // and everything in [__mismatch, __match_start + __count) is known to match, so it must not be advanced further.
+    __matched_until = __match_start + __count;
+    __match_start   = __mismatch;
   }
 }
 
@@ -119,7 +131,7 @@ template <class _Iter,
 _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pair<_Iter, _Iter>
 __search_n_impl(_Iter __first, _Sent __last, _DiffT __count, const _Type& __value, _Pred& __pred, _Proj& __proj) {
   return std::__search_n_random_access_impl<_ClassicAlgPolicy>(
-      __first, __last, __count, __value, __pred, __proj, __last - __first);
+      __first, __count, __value, __pred, __proj, __last - __first);
 }
 
 template <class _Iter1,
diff --git a/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp b/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp
index de404fedaed3a..91c6ad3eafa0d 100644
--- a/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp
+++ b/libcxx/test/benchmarks/algorithms/nonmodifying/search_n.bench.cpp
@@ -59,10 +59,56 @@ int main(int argc, char** argv) {
               benchmark::DoNotOptimize(result);
             }
           })
-          ->Arg(1000) // non power-of-two
+          ->Arg(32)
           ->Arg(1024)
-          ->Arg(8192)
-          ->Arg(1 << 20);
+          ->Arg(8192);
+    };
+    // {std,ranges}::search_n
+    bm.operator()<std::vector<int>>("std::search_n(vector<int>) (no match)", std_search_n);
+    bm.operator()<std::deque<int>>("std::search_n(deque<int>) (no match)", std_search_n);
+    bm.operator()<std::list<int>>("std::search_n(list<int>) (no match)", std_search_n);
+    bm.operator()<std::vector<int>>("rng::search_n(vector<int>) (no match)", std::ranges::search_n);
+    bm.operator()<std::deque<int>>("rng::search_n(deque<int>) (no match)", std::ranges::search_n);
+    bm.operator()<std::list<int>>("rng::search_n(list<int>) (no match)", std::ranges::search_n);
+
+    // {std,ranges}::search_n(pred)
+    bm.operator()<std::vector<int>>("std::search_n(vector<int>, pred) (no match)", std_search_n_pred);
+    bm.operator()<std::deque<int>>("std::search_n(deque<int>, pred) (no match)", std_search_n_pred);
+    bm.operator()<std::list<int>>("std::search_n(list<int>, pred) (no match)", std_search_n_pred);
+    bm.operator()<std::vector<int>>("rng::search_n(vector<int>, pred) (no match)", ranges_search_n_pred);
+    bm.operator()<std::deque<int>>("rng::search_n(deque<int>, pred) (no match)", ranges_search_n_pred);
+    bm.operator()<std::list<int>>("rng::search_n(list<int>, pred) (no match)", ranges_search_n_pred);
+  }
+
+  // Benchmark {std,ranges}::search_n where the haystack contains many near-matches of the needle.
+  {
+    auto bm = []<class Container>(std::string name, auto search_n) {
+      benchmark::RegisterBenchmark(
+          name,
+          [search_n](auto& st) {
+            std::size_t const size = st.range(0);
+            using ValueType        = typename Container::value_type;
+            ValueType x            = Generate<ValueType>::random();
+            ValueType y            = random_different_from({x});
+            Container haystack(size, x);
+            std::size_t n = size / 10; // needle size is 10% of the haystack
+
+            // Make sure there are no actual matches
+            for (size_t i = 0; i < size; i += getRandomInteger<size_t>(1, n - 1)) {
+              *std::next(haystack.begin(), i) = y;
+            }
+
+            for ([[maybe_unused]] auto _ : st) {
+              benchmark::DoNotOptimize(haystack);
+              benchmark::DoNotOptimize(n);
+              benchmark::DoNotOptimize(y);
+              auto result = search_n(haystack.begin(), haystack.end(), n, x);
+              benchmark::DoNotOptimize(result);
+            }
+          })
+          ->Arg(32)
+          ->Arg(1024)
+          ->Arg(8192);
     };
     // {std,ranges}::search_n
     bm.operator()<std::vector<int>>("std::search_n(vector<int>) (no match)", std_search_n);
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp
index 2f2e436c79130..f68c31ead7b8f 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/ranges.search_n.pass.cpp
@@ -186,18 +186,50 @@ constexpr void test_iterators() {
   }
 
   { // check that the first match is returned
-    {
-      int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 8};
-      auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 2, 6);
-      assert(base(ret.begin()) == a);
-      assert(base(ret.end()) == a + 2);
+    { // Match is at the start
+      {
+        int a[]  = {6, 6, 8, 6, 6, 8, 6, 6, 8};
+        auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 2, 6);
+        assert(base(ret.begin()) == a);
+        assert(base(ret.end()) == a + 2);
+      }
+      {
+        int a[]    = {6, 6, 8, 6, 6, 8, 6, 6, 8};
+        auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
+        auto ret   = std::ranges::search_n(range, 2, 6);
+        assert(base(ret.begin()) == a);
+        assert(base(ret.end()) == a + 2);
+      }
     }
-    {
-      int a[] = {6, 6, 8, 6, 6, 8, 6, 6, 8};
-      auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
-      auto ret = std::ranges::search_n(range, 2, 6);
-      assert(base(ret.begin()) == a);
-      assert(base(ret.end()) == a + 2);
+    { // Match is in the middle
+      {
+        int a[]  = {6, 8, 8, 6, 6, 8, 6, 6, 8};
+        auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 2, 6);
+        assert(base(ret.begin()) == a + 3);
+        assert(base(ret.end()) == a + 5);
+      }
+      {
+        int a[]    = {6, 8, 8, 6, 6, 8, 6, 6, 8};
+        auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
+        auto ret   = std::ranges::search_n(range, 2, 6);
+        assert(base(ret.begin()) == a + 3);
+        assert(base(ret.end()) == a + 5);
+      }
+    }
+    { // Match is at the end
+      {
+        int a[]  = {6, 6, 8, 6, 6, 8, 6, 6, 6};
+        auto ret = std::ranges::search_n(Iter(a), Sent(Iter(a + 9)), 3, 6);
+        assert(base(ret.begin()) == a + 6);
+        assert(base(ret.end()) == a + 9);
+      }
+      {
+        int a[]    = {6, 6, 8, 6, 6, 8, 6, 6, 6};
+        auto range = std::ranges::subrange(Iter(a), Sent(Iter(a + 9)));
+        auto ret   = std::ranges::search_n(range, 3, 6);
+        assert(base(ret.begin()) == a + 6);
+        assert(base(ret.end()) == a + 9);
+      }
     }
   }
 
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp
index 228bc6b768cf9..514dc9a911302 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.search/search_n.pass.cpp
@@ -14,79 +14,83 @@
 //            const T& value);
 
 #include <algorithm>
+#include <array>
 #include <cassert>
 
 #include "test_macros.h"
 #include "test_iterators.h"
-#include "user_defined_integral.h"
-
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
-    int ia[] = {0, 0, 1, 1, 2, 2};
-    return    (std::search_n(std::begin(ia), std::end(ia), 1, 0) == ia)
-           && (std::search_n(std::begin(ia), std::end(ia), 2, 1) == ia+2)
-           && (std::search_n(std::begin(ia), std::end(ia), 1, 3) == std::end(ia))
-           ;
-    }
-#endif
 
 template <class Iter>
-void
-test()
-{
-    int ia[] = {0, 1, 2, 3, 4, 5};
-    const unsigned sa = sizeof(ia)/sizeof(ia[0]);
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 0) == Iter(ia));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 0) == Iter(ia+0));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 0) == Iter(ia+sa));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 0) == Iter(ia+sa));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 3) == Iter(ia));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 3) == Iter(ia+3));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 3) == Iter(ia+sa));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 3) == Iter(ia+sa));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 0, 5) == Iter(ia));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 1, 5) == Iter(ia+5));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), 2, 5) == Iter(ia+sa));
-    assert(std::search_n(Iter(ia), Iter(ia+sa), sa, 5) == Iter(ia+sa));
-
-    int ib[] = {0, 0, 1, 1, 2, 2};
-    const unsigned sb = sizeof(ib)/sizeof(ib[0]);
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 0) == Iter(ib));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 0) == Iter(ib+0));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 0) == Iter(ib+0));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 0) == Iter(ib+sb));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 0) == Iter(ib+sb));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 1) == Iter(ib));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 1) == Iter(ib+2));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 1) == Iter(ib+2));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 1) == Iter(ib+sb));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 1) == Iter(ib+sb));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 0, 2) == Iter(ib));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 1, 2) == Iter(ib+4));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 2, 2) == Iter(ib+4));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), 3, 2) == Iter(ib+sb));
-    assert(std::search_n(Iter(ib), Iter(ib+sb), sb, 2) == Iter(ib+sb));
-
-    int ic[] = {0, 0, 0};
-    const unsigned sc = sizeof(ic)/sizeof(ic[0]);
-    assert(std::search_n(Iter(ic), Iter(ic+sc), 0, 0) == Iter(ic));
-    assert(std::search_n(Iter(ic), Iter(ic+sc), 1, 0) == Iter(ic));
-    assert(std::search_n(Iter(ic), Iter(ic+sc), 2, 0) == Iter(ic));
-    assert(std::search_n(Iter(ic), Iter(ic+sc), 3, 0) == Iter(ic));
-    assert(std::search_n(Iter(ic), Iter(ic+sc), 4, 0) == Iter(ic+sc));
+TEST_CONSTEXPR_CXX20 bool test() {
+  { // simple test
+    int a[]  = {1, 2, 3, 4, 5, 6};
+    auto ret = std::search_n(Iter(a), Iter(a + 6), 1, 3);
+    assert(base(ret) == a + 2);
+  }
+  { // matching part begins at the front
+    int a[]  = {7, 7, 3, 7, 3, 6};
+    auto ret = std::search_n(Iter(a), Iter(a + 6), 2, 7);
+    assert(base(ret) == a);
+  }
+  { // matching part ends at the back
+    int a[]  = {9, 3, 6, 4, 4};
+    auto ret = std::search_n(Iter(a), Iter(a + 5), 2, 4);
+    assert(base(ret) == a + 3);
+  }
+  { // pattern does not match
+    int a[]  = {9, 3, 6, 4, 8};
+    auto ret = std::search_n(Iter(a), Iter(a + 5), 1, 1);
+    assert(base(ret) == a + 5);
+  }
+  { // range and pattern are identical
+    int a[]  = {1, 1, 1, 1};
+    auto ret = std::search_n(Iter(a), Iter(a + 4), 4, 1);
+    assert(base(ret) == a);
+  }
+  { // pattern is longer than range
+    int a[]  = {3, 3, 3};
+    auto ret = std::search_n(Iter(a), Iter(a + 3), 4, 3);
+    assert(base(ret) == a + 3);
+  }
+  { // pattern has zero length
+    int a[]  = {6, 7, 8};
+    auto ret = std::search_n(Iter(a), Iter(a + 3), 0, 7);
+    assert(base(ret) == a);
+  }
+  { // range has zero length
+    std::array<int, 0> a = {};
+    auto ret             = std::search_n(Iter(a.data()), Iter(a.data()), 1, 1);
+    assert(base(ret) == a.data());
+  }
+  {   // check that the first match is returned
+    { // Match is at the start
+      int a[]  = {6, 6, 8, 6, 6, 8, 6, 6, 8};
+      auto ret = std::search_n(Iter(a), Iter(a + 9), 2, 6);
+      assert(base(ret) == a);
+    }
+    { // Match is in the middle
+      int a[]  = {6, 8, 8, 6, 6, 8, 6, 6, 8};
+      auto ret = std::search_n(Iter(a), Iter(a + 9), 2, 6);
+      assert(base(ret) == a + 3);
+    }
+    { // Match is at the end
+      int a[]  = {6, 6, 8, 6, 6, 8, 6, 6, 6};
+      auto ret = std::search_n(Iter(a), Iter(a + 9), 3, 6);
+      assert(base(ret) == a + 6);
+    }
+  }
 
-    // Check that we properly convert the size argument to an integral.
-    (void)std::search_n(Iter(ic), Iter(ic+sc), UserDefinedIntegral<unsigned>(0), 0);
+  return true;
 }
 
-int main(int, char**)
-{
-    test<forward_iterator<const int*> >();
-    test<bidirectional_iterator<const int*> >();
-    test<random_access_iterator<const int*> >();
-
-#if TEST_STD_VER > 17
-    static_assert(test_constexpr());
+int main(int, char**) {
+  test<forward_iterator<const int*> >();
+  test<bidirectional_iterator<const int*> >();
+  test<random_access_iterator<const int*> >();
+#if TEST_STD_VER >= 20
+  static_assert(test<forward_iterator<const int*> >());
+  static_assert(test<bidirectional_iterator<const int*> >());
+  static_assert(test<random_access_iterator<const int*> >());
 #endif
 
   return 0;



More information about the libcxx-commits mailing list