[libcxx-commits] [libcxx] [libc++] `std::ranges::advance`: avoid unneeded bounds checks when advancing iterator (PR #84126)

Jan Kokemüller via libcxx-commits libcxx-commits at lists.llvm.org
Fri Mar 29 02:54:55 PDT 2024


https://github.com/jiixyj updated https://github.com/llvm/llvm-project/pull/84126

>From dc039616cee1453a38729a4f2f6adc03fd93549c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Kokem=C3=BCller?= <jan.kokemueller at gmail.com>
Date: Wed, 6 Mar 2024 07:49:25 +0100
Subject: [PATCH 1/6] skip unneeded checks against '__bound_sentinel' when
 advancing iterator

---
 libcxx/include/__iterator/advance.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/libcxx/include/__iterator/advance.h b/libcxx/include/__iterator/advance.h
index 7959bdeae32643..296db1aaab6526 100644
--- a/libcxx/include/__iterator/advance.h
+++ b/libcxx/include/__iterator/advance.h
@@ -170,14 +170,14 @@ struct __fn {
     } else {
       // Otherwise, if `n` is non-negative, while `bool(i != bound_sentinel)` is true, increments `i` but at
       // most `n` times.
-      while (__i != __bound_sentinel && __n > 0) {
+      while (__n > 0 && __i != __bound_sentinel) {
         ++__i;
         --__n;
       }
 
       // Otherwise, while `bool(i != bound_sentinel)` is true, decrements `i` but at most `-n` times.
       if constexpr (bidirectional_iterator<_Ip> && same_as<_Ip, _Sp>) {
-        while (__i != __bound_sentinel && __n < 0) {
+        while (__n < 0 && __i != __bound_sentinel) {
           --__i;
           ++__n;
         }

>From b23e7fe87d6eaa65cf684749e2eb1ceee51176d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Kokem=C3=BCller?= <jan.kokemueller at gmail.com>
Date: Sat, 9 Mar 2024 11:22:59 +0100
Subject: [PATCH 2/6] add tests counting equality comparisons in 'ranges::advance'

---
 .../iterator_count_sentinel.pass.cpp          | 22 ++++++++++++++++++
 libcxx/test/support/test_iterators.h          | 23 +++++++++++++++++--
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
index a1c15640182162..72584e60d9f89f 100644
--- a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
+++ b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
@@ -42,6 +42,13 @@ constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n
     // regardless of the iterator category.
     assert(it.stride_count() == M);
     assert(it.stride_displacement() == M);
+    if (n == 0) {
+      assert(it.equals_count() == 0);
+    } else {
+      assert(it.equals_count() > 0);
+      assert(it.equals_count() == M || it.equals_count() == M + 1);
+      assert(it.equals_count() <= n);
+    }
   }
 }
 
@@ -95,6 +102,7 @@ constexpr void check_backward(int* first, int* last, std::iter_difference_t<It>
     (void)std::ranges::advance(it, n, sent);
     assert(it.stride_count() <= 1);
     assert(it.stride_displacement() <= 1);
+    assert(it.equals_count() == 0);
   }
 }
 
@@ -213,6 +221,20 @@ constexpr bool test() {
     assert(i == iota_iterator{INT_MIN+1});
   }
 
+  // Check that we don't do an unneeded bounds check when decrementing a
+  // `bidirectional_iterator` that doesn't model `sized_sentinel_for`.
+  {
+    static_assert(std::bidirectional_iterator<bidirectional_iterator<iota_iterator>>);
+    static_assert(!std::sized_sentinel_for<bidirectional_iterator<iota_iterator>,
+                                           bidirectional_iterator<iota_iterator>>);
+
+    auto it = stride_counting_iterator(bidirectional_iterator(iota_iterator{+1}));
+    auto sent = stride_counting_iterator(bidirectional_iterator(iota_iterator{-2}));
+    assert(std::ranges::advance(it, -3, sent) == 0);
+    assert(base(base(it)) == iota_iterator{-2});
+    assert(it.equals_count() == 3);
+  }
+
   return true;
 }
 
diff --git a/libcxx/test/support/test_iterators.h b/libcxx/test/support/test_iterators.h
index c92ce375348ff4..e551ab5a62d08b 100644
--- a/libcxx/test/support/test_iterators.h
+++ b/libcxx/test/support/test_iterators.h
@@ -725,11 +725,14 @@ struct common_input_iterator {
 #  endif // TEST_STD_VER >= 20
 
 // Iterator adaptor that counts the number of times the iterator has had a successor/predecessor
-// operation called. Has two recorders:
+// operation or an equality comparison operation called. Has three recorders:
 // * `stride_count`, which records the total number of calls to an op++, op--, op+=, or op-=.
 // * `stride_displacement`, which records the displacement of the calls. This means that both
 //   op++/op+= will increase the displacement counter by 1, and op--/op-= will decrease the
 //   displacement counter by 1.
+// * `equals_count`, which records the total number of calls to an op== or op!=. If compared
+//   against a sentinel object, that sentinel object must call the `record_equality_comparison`
+//   function so that the comparison is counted correctly.
 template <class It>
 class stride_counting_iterator {
 public:
@@ -754,6 +757,8 @@ class stride_counting_iterator {
 
     constexpr difference_type stride_displacement() const { return stride_displacement_; }
 
+    constexpr difference_type equals_count() const { return equals_count_; }
+
     constexpr decltype(auto) operator*() const { return *It(base_); }
 
     constexpr decltype(auto) operator[](difference_type n) const { return It(base_)[n]; }
@@ -838,9 +843,15 @@ class stride_counting_iterator {
         return base(x) - base(y);
     }
 
+    constexpr void record_equality_comparison() const
+    {
+        ++equals_count_;
+    }
+
     constexpr bool operator==(stride_counting_iterator const& other) const
         requires std::sentinel_for<It, It>
     {
+        record_equality_comparison();
         return It(base_) == It(other.base_);
     }
 
@@ -875,6 +886,7 @@ class stride_counting_iterator {
     decltype(base(std::declval<It>())) base_;
     difference_type stride_count_ = 0;
     difference_type stride_displacement_ = 0;
+    mutable difference_type equals_count_ = 0;
 };
 template <class It>
 stride_counting_iterator(It) -> stride_counting_iterator<It>;
@@ -887,7 +899,14 @@ class sentinel_wrapper {
 public:
     explicit sentinel_wrapper() = default;
     constexpr explicit sentinel_wrapper(const It& it) : base_(base(it)) {}
-    constexpr bool operator==(const It& other) const { return base_ == base(other); }
+    constexpr bool operator==(const It& other) const {
+      // If supported, record statistics about the equality operator call
+      // inside `other`.
+      if constexpr (requires { other.record_equality_comparison(); }) {
+        other.record_equality_comparison();
+      }
+      return base_ == base(other);
+    }
     friend constexpr It base(const sentinel_wrapper& s) { return It(s.base_); }
 private:
     decltype(base(std::declval<It>())) base_;

>From 4051b3f569e6f70cb55755537123cdd67ddcd2c4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Kokem=C3=BCller?= <jan.kokemueller at gmail.com>
Date: Sat, 9 Mar 2024 11:27:49 +0100
Subject: [PATCH 3/6] apply clang-format

---
 .../iterator_count_sentinel.pass.cpp                     | 6 +++---
 libcxx/test/support/test_iterators.h                     | 9 +++------
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
index 72584e60d9f89f..5dffb5267c391b 100644
--- a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
+++ b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
@@ -225,10 +225,10 @@ constexpr bool test() {
   // `bidirectional_iterator` that doesn't model `sized_sentinel_for`.
   {
     static_assert(std::bidirectional_iterator<bidirectional_iterator<iota_iterator>>);
-    static_assert(!std::sized_sentinel_for<bidirectional_iterator<iota_iterator>,
-                                           bidirectional_iterator<iota_iterator>>);
+    static_assert(
+        !std::sized_sentinel_for<bidirectional_iterator<iota_iterator>, bidirectional_iterator<iota_iterator>>);
 
-    auto it = stride_counting_iterator(bidirectional_iterator(iota_iterator{+1}));
+    auto it   = stride_counting_iterator(bidirectional_iterator(iota_iterator{+1}));
     auto sent = stride_counting_iterator(bidirectional_iterator(iota_iterator{-2}));
     assert(std::ranges::advance(it, -3, sent) == 0);
     assert(base(base(it)) == iota_iterator{-2});
diff --git a/libcxx/test/support/test_iterators.h b/libcxx/test/support/test_iterators.h
index e551ab5a62d08b..7ffb74990fa4dd 100644
--- a/libcxx/test/support/test_iterators.h
+++ b/libcxx/test/support/test_iterators.h
@@ -843,16 +843,13 @@ class stride_counting_iterator {
         return base(x) - base(y);
     }
 
-    constexpr void record_equality_comparison() const
-    {
-        ++equals_count_;
-    }
+    constexpr void record_equality_comparison() const { ++equals_count_; }
 
     constexpr bool operator==(stride_counting_iterator const& other) const
         requires std::sentinel_for<It, It>
     {
-        record_equality_comparison();
-        return It(base_) == It(other.base_);
+      record_equality_comparison();
+      return It(base_) == It(other.base_);
     }
 
     friend constexpr bool operator<(stride_counting_iterator const& x, stride_counting_iterator const& y)

>From c41370a113a718f11256ca57d9dd62d4835e46f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Kokem=C3=BCller?= <jan.kokemueller at gmail.com>
Date: Tue, 19 Mar 2024 19:32:55 +0100
Subject: [PATCH 4/6] remove redundant 'equals_count() > 0' assertion

---
 .../range.iter.ops.advance/iterator_count_sentinel.pass.cpp      | 1 -
 1 file changed, 1 deletion(-)

diff --git a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
index 5dffb5267c391b..843dcc39d944ca 100644
--- a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
+++ b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
@@ -45,7 +45,6 @@ constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n
     if (n == 0) {
       assert(it.equals_count() == 0);
     } else {
-      assert(it.equals_count() > 0);
       assert(it.equals_count() == M || it.equals_count() == M + 1);
       assert(it.equals_count() <= n);
     }

>From a4da00177fadef0acdea54aabe8a5c2c8e558187 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Kokem=C3=BCller?= <jan.kokemueller at gmail.com>
Date: Fri, 29 Mar 2024 09:58:02 +0100
Subject: [PATCH 5/6] test for the extra bounds check iff 'n > M'

---
 .../iterator_count_sentinel.pass.cpp                      | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
index 843dcc39d944ca..44be579001be13 100644
--- a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
+++ b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
@@ -45,7 +45,13 @@ constexpr void check_forward(int* first, int* last, std::iter_difference_t<It> n
     if (n == 0) {
       assert(it.equals_count() == 0);
     } else {
-      assert(it.equals_count() == M || it.equals_count() == M + 1);
+      if (n > M) {
+        // We "hit" the bound, so there is one extra equality check.
+        assert(it.equals_count() == M + 1);
+      } else {
+        assert(it.equals_count() == M);
+      }
+      // In any case, there must not be more than `n` bounds checks.
       assert(it.equals_count() <= n);
     }
   }

>From 7f8ba1f01885567de8e4e6c82f3ca341ee698914 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Kokem=C3=BCller?= <jan.kokemueller at gmail.com>
Date: Fri, 29 Mar 2024 10:54:06 +0100
Subject: [PATCH 6/6] adapt 'check_backward' test so that it can test
 bidirectional iterators as well

---
 .../iterator_count_sentinel.pass.cpp          | 59 ++++++++++++-------
 1 file changed, 37 insertions(+), 22 deletions(-)

diff --git a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
index 44be579001be13..c3f378d3af1a30 100644
--- a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
+++ b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
@@ -88,7 +88,11 @@ constexpr void check_forward_sized_sentinel(int* first, int* last, std::iter_dif
 
 template <typename It>
 constexpr void check_backward(int* first, int* last, std::iter_difference_t<It> n, int* expected) {
-  static_assert(std::random_access_iterator<It>, "This test doesn't support non random access iterators");
+  // Check preconditions for `advance` when called with negative `n`:
+  // <https://eel.is/c++draft/iterators#range.iter.op.advance-5>
+  assert(n < 0);
+  static_assert(std::bidirectional_iterator<It>);
+
   using Difference = std::iter_difference_t<It>;
   Difference const M = (expected - last); // expected travel distance (which is negative)
 
@@ -104,10 +108,34 @@ constexpr void check_backward(int* first, int* last, std::iter_difference_t<It>
   {
     auto it = stride_counting_iterator(It(last));
     auto sent = stride_counting_iterator(It(first));
+    static_assert(std::bidirectional_iterator<stride_counting_iterator<It>>);
+
     (void)std::ranges::advance(it, n, sent);
-    assert(it.stride_count() <= 1);
-    assert(it.stride_displacement() <= 1);
-    assert(it.equals_count() == 0);
+
+    if constexpr (std::sized_sentinel_for<It, It>) {
+      if (expected == first) {
+        // In this case, the algorithm can just do `it = std::move(sent);`
+        // instead of doing iterator arithmetic:
+        // <https://eel.is/c++draft/iterators#range.iter.op.advance-4.1>
+        assert(it.stride_count() == 0);
+        assert(it.stride_displacement() == 0);
+      } else {
+        assert(it.stride_count() == 1);
+        assert(it.stride_displacement() == 1);
+      }
+      assert(it.equals_count() == 0);
+    } else {
+      assert(it.stride_count() == -M);
+      assert(it.stride_displacement() == M);
+      if (-n > -M) {
+        // We "hit" the bound, so there is one extra equality check.
+        assert(it.equals_count() == -M + 1);
+      } else {
+        assert(it.equals_count() == -M);
+      }
+      // In any case, there must not be more than `-n` bounds checks.
+      assert(it.equals_count() <= -n);
+    }
   }
 }
 
@@ -201,11 +229,12 @@ constexpr bool test() {
         check_forward_sized_sentinel<int*>(                        range, range+size, n, expected);
       }
 
-      {
-        // Note that we can only test ranges::advance with a negative n for iterators that
-        // are sized sentinels for themselves, because ranges::advance is UB otherwise.
-        // In particular, that excludes bidirectional_iterators since those are not sized sentinels.
+      // Exclude the `n == 0` case for the backwards checks.
+      // Input and forward iterators are not tested as the backwards case does
+      // not apply for them.
+      if (n > 0) {
         int* expected = n > size ? range : range + size - n;
+        check_backward<bidirectional_iterator<int*>>(range, range+size, -n, expected);
         check_backward<random_access_iterator<int*>>(range, range+size, -n, expected);
         check_backward<contiguous_iterator<int*>>(   range, range+size, -n, expected);
         check_backward<int*>(                        range, range+size, -n, expected);
@@ -226,20 +255,6 @@ constexpr bool test() {
     assert(i == iota_iterator{INT_MIN+1});
   }
 
-  // Check that we don't do an unneeded bounds check when decrementing a
-  // `bidirectional_iterator` that doesn't model `sized_sentinel_for`.
-  {
-    static_assert(std::bidirectional_iterator<bidirectional_iterator<iota_iterator>>);
-    static_assert(
-        !std::sized_sentinel_for<bidirectional_iterator<iota_iterator>, bidirectional_iterator<iota_iterator>>);
-
-    auto it   = stride_counting_iterator(bidirectional_iterator(iota_iterator{+1}));
-    auto sent = stride_counting_iterator(bidirectional_iterator(iota_iterator{-2}));
-    assert(std::ranges::advance(it, -3, sent) == 0);
-    assert(base(base(it)) == iota_iterator{-2});
-    assert(it.equals_count() == 3);
-  }
-
   return true;
 }
 



More information about the libcxx-commits mailing list