[libcxx-commits] [libcxx] 84ae8cb - [libc++] `std::ranges::advance`: avoid unneeded bounds checks when advancing iterator (#84126)

via libcxx-commits libcxx-commits at lists.llvm.org
Tue Apr 2 16:09:30 PDT 2024


Author: Jan Kokemüller
Date: 2024-04-02T16:09:26-07:00
New Revision: 84ae8cb4af9abafe9f45e69744607aadb38d649a

URL: https://github.com/llvm/llvm-project/commit/84ae8cb4af9abafe9f45e69744607aadb38d649a
DIFF: https://github.com/llvm/llvm-project/commit/84ae8cb4af9abafe9f45e69744607aadb38d649a.diff

LOG: [libc++] `std::ranges::advance`: avoid unneeded bounds checks when advancing iterator (#84126)

Currently, the bounds check (`it != s`) in `std::ranges::advance(it, n, s)`
is done _before_ `n` is checked. Once `n` has reached zero, this results in
one extra, unneeded bounds check.

Thus, `std::ranges::advance(it, 1, s)` currently is _not_ simply
equivalent to:

```c++
if (it != s) {
    ++it;
}
```

This difference in behavior matters when the check involves some
"expensive" logic. For example, the `==` operator of
`std::istreambuf_iterator` may actually have to read the underlying
`streambuf`.

Swapping the order of the two conditions in each `while` loop, so that `n`
is checked before the bounds check, results in the expected behavior.

Added: 
    

Modified: 
    libcxx/include/__iterator/advance.h
    libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
    libcxx/test/support/test_iterators.h

Removed: 
    


################################################################################
diff  --git a/libcxx/include/__iterator/advance.h b/libcxx/include/__iterator/advance.h
index 7959bdeae32643..296db1aaab6526 100644
--- a/libcxx/include/__iterator/advance.h
+++ b/libcxx/include/__iterator/advance.h
@@ -170,14 +170,14 @@ struct __fn {
     } else {
       // Otherwise, if `n` is non-negative, while `bool(i != bound_sentinel)` is true, increments `i` but at
       // most `n` times.
-      while (__i != __bound_sentinel && __n > 0) {
+      while (__n > 0 && __i != __bound_sentinel) {
         ++__i;
         --__n;
       }
 
       // Otherwise, while `bool(i != bound_sentinel)` is true, decrements `i` but at most `-n` times.
       if constexpr (bidirectional_iterator<_Ip> && same_as<_Ip, _Sp>) {
-        while (__i != __bound_sentinel && __n < 0) {
+        while (__n < 0 && __i != __bound_sentinel) {
           --__i;
           ++__n;
         }

diff  --git a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
index a1c15640182162..76439ef93a607a 100644
--- a/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
+++ b/libcxx/test/std/iterators/iterator.primitives/range.iter.ops/range.iter.ops.advance/iterator_count_sentinel.pass.cpp
@@ -21,9 +21,12 @@
 #include "../types.h"
 
 template <bool Count, typename It>
-constexpr void check_forward(int* first, int* last, std::iter_
diff erence_t<It> n, int* expected) {
+constexpr void
+check_forward(int* first, int* last, std::iter_
diff erence_t<It> n, int* expected, int expected_equals_count = -1) {
   using Difference = std::iter_
diff erence_t<It>;
   Difference const M = (expected - first); // expected travel distance
+  // `expected_equals_count` is only relevant when `Count` is true.
+  assert(Count || expected_equals_count == -1);
 
   {
     It it(first);
@@ -42,6 +45,7 @@ constexpr void check_forward(int* first, int* last, std::iter_
diff erence_t<It> n
     // regardless of the iterator category.
     assert(it.stride_count() == M);
     assert(it.stride_displacement() == M);
+    assert(it.equals_count() == expected_equals_count);
   }
 }
 
@@ -74,9 +78,20 @@ constexpr void check_forward_sized_sentinel(int* first, int* last, std::iter_dif
   }
 }
 
-template <typename It>
-constexpr void check_backward(int* first, int* last, std::iter_
diff erence_t<It> n, int* expected) {
-  static_assert(std::random_access_iterator<It>, "This test doesn't support non random access iterators");
+struct Expected {
+  int stride_count;
+  int stride_displacement;
+  int equals_count;
+};
+
+template <bool Count, typename It>
+constexpr void
+check_backward(int* first, int* last, std::iter_
diff erence_t<It> n, int* expected, Expected expected_counts) {
+  // Check preconditions for `advance` when called with negative `n`
+  // (see [range.iter.op.advance]). In addition, allow `n == 0`.
+  assert(n <= 0);
+  static_assert(std::bidirectional_iterator<It>);
+
   using Difference = std::iter_
diff erence_t<It>;
   Difference const M = (expected - last); // expected travel distance (which is negative)
 
@@ -92,9 +107,14 @@ constexpr void check_backward(int* first, int* last, std::iter_
diff erence_t<It>
   {
     auto it = stride_counting_iterator(It(last));
     auto sent = stride_counting_iterator(It(first));
+    static_assert(std::bidirectional_iterator<stride_counting_iterator<It>>);
+    static_assert(Count == !std::sized_sentinel_for<It, It>);
+
     (void)std::ranges::advance(it, n, sent);
-    assert(it.stride_count() <= 1);
-    assert(it.stride_displacement() <= 1);
+
+    assert(it.stride_count() == expected_counts.stride_count);
+    assert(it.stride_displacement() == expected_counts.stride_displacement);
+    assert(it.equals_count() == expected_counts.equals_count);
   }
 }
 
@@ -171,13 +191,17 @@ constexpr bool test() {
 
       {
         int* expected = n > size ? range + size : range + n;
+        int equals_count = n > size ? size + 1 : n;
+
+        // clang-format off
         check_forward<false, cpp17_input_iterator<int*>>(  range, range+size, n, expected);
         check_forward<false, cpp20_input_iterator<int*>>(  range, range+size, n, expected);
-        check_forward<true,  forward_iterator<int*>>(      range, range+size, n, expected);
-        check_forward<true,  bidirectional_iterator<int*>>(range, range+size, n, expected);
-        check_forward<true,  random_access_iterator<int*>>(range, range+size, n, expected);
-        check_forward<true,  contiguous_iterator<int*>>(   range, range+size, n, expected);
-        check_forward<true,  int*>(                        range, range+size, n, expected);
+        check_forward<true,  forward_iterator<int*>>(      range, range+size, n, expected, equals_count);
+        check_forward<true,  bidirectional_iterator<int*>>(range, range+size, n, expected, equals_count);
+        check_forward<true,  random_access_iterator<int*>>(range, range+size, n, expected, equals_count);
+        check_forward<true,  contiguous_iterator<int*>>(   range, range+size, n, expected, equals_count);
+        check_forward<true,  int*>(                        range, range+size, n, expected, equals_count);
+        // clang-format on
 
         check_forward_sized_sentinel<cpp17_input_iterator<int*>>(  range, range+size, n, expected);
         check_forward_sized_sentinel<cpp20_input_iterator<int*>>(  range, range+size, n, expected);
@@ -188,14 +212,32 @@ constexpr bool test() {
         check_forward_sized_sentinel<int*>(                        range, range+size, n, expected);
       }
 
+      // Input and forward iterators are not tested as the backwards case does
+      // not apply for them.
       {
-        // Note that we can only test ranges::advance with a negative n for iterators that
-        // are sized sentinels for themselves, because ranges::advance is UB otherwise.
-        // In particular, that excludes bidirectional_iterators since those are not sized sentinels.
         int* expected = n > size ? range : range + size - n;
-        check_backward<random_access_iterator<int*>>(range, range+size, -n, expected);
-        check_backward<contiguous_iterator<int*>>(   range, range+size, -n, expected);
-        check_backward<int*>(                        range, range+size, -n, expected);
+        {
+          Expected expected_counts = {
+              .stride_count        = static_cast<int>(range + size - expected),
+              .stride_displacement = -expected_counts.stride_count,
+              .equals_count        = n > size ? size + 1 : n,
+          };
+
+          check_backward<true, bidirectional_iterator<int*>>(range, range + size, -n, expected, expected_counts);
+        }
+        {
+          Expected expected_counts = {
+              // If `n >= size`, the algorithm can just do `it = std::move(sent);`
+              // instead of doing iterator arithmetic.
+              .stride_count        = (n >= size) ? 0 : 1,
+              .stride_displacement = (n >= size) ? 0 : 1,
+              .equals_count        = 0,
+          };
+
+          check_backward<false, random_access_iterator<int*>>(range, range + size, -n, expected, expected_counts);
+          check_backward<false, contiguous_iterator<int*>>(range, range + size, -n, expected, expected_counts);
+          check_backward<false, int*>(range, range + size, -n, expected, expected_counts);
+        }
       }
     }
   }

diff  --git a/libcxx/test/support/test_iterators.h b/libcxx/test/support/test_iterators.h
index c92ce375348ff4..7ffb74990fa4dd 100644
--- a/libcxx/test/support/test_iterators.h
+++ b/libcxx/test/support/test_iterators.h
@@ -725,11 +725,14 @@ struct common_input_iterator {
 #  endif // TEST_STD_VER >= 20
 
 // Iterator adaptor that counts the number of times the iterator has had a successor/predecessor
-// operation called. Has two recorders:
+// operation or an equality comparison operation called. Has three recorders:
 // * `stride_count`, which records the total number of calls to an op++, op--, op+=, or op-=.
 // * `stride_displacement`, which records the displacement of the calls. This means that both
 //   op++/op+= will increase the displacement counter by 1, and op--/op-= will decrease the
 //   displacement counter by 1.
+// * `equals_count`, which records the total number of calls to an op== or op!=. If compared
+//   against a sentinel object, that sentinel object must call the `record_equality_comparison`
+//   function so that the comparison is counted correctly.
 template <class It>
 class stride_counting_iterator {
 public:
@@ -754,6 +757,8 @@ class stride_counting_iterator {
 
     constexpr 
diff erence_type stride_displacement() const { return stride_displacement_; }
 
+    constexpr 
diff erence_type equals_count() const { return equals_count_; }
+
     constexpr decltype(auto) operator*() const { return *It(base_); }
 
     constexpr decltype(auto) operator[](
diff erence_type n) const { return It(base_)[n]; }
@@ -838,10 +843,13 @@ class stride_counting_iterator {
         return base(x) - base(y);
     }
 
+    constexpr void record_equality_comparison() const { ++equals_count_; }
+
     constexpr bool operator==(stride_counting_iterator const& other) const
         requires std::sentinel_for<It, It>
     {
-        return It(base_) == It(other.base_);
+      record_equality_comparison();
+      return It(base_) == It(other.base_);
     }
 
     friend constexpr bool operator<(stride_counting_iterator const& x, stride_counting_iterator const& y)
@@ -875,6 +883,7 @@ class stride_counting_iterator {
     decltype(base(std::declval<It>())) base_;
     
diff erence_type stride_count_ = 0;
     
diff erence_type stride_displacement_ = 0;
+    mutable 
diff erence_type equals_count_ = 0;
 };
 template <class It>
 stride_counting_iterator(It) -> stride_counting_iterator<It>;
@@ -887,7 +896,14 @@ class sentinel_wrapper {
 public:
     explicit sentinel_wrapper() = default;
     constexpr explicit sentinel_wrapper(const It& it) : base_(base(it)) {}
-    constexpr bool operator==(const It& other) const { return base_ == base(other); }
+    constexpr bool operator==(const It& other) const {
+      // If supported, record statistics about the equality operator call
+      // inside `other`.
+      if constexpr (requires { other.record_equality_comparison(); }) {
+        other.record_equality_comparison();
+      }
+      return base_ == base(other);
+    }
     friend constexpr It base(const sentinel_wrapper& s) { return It(s.base_); }
 private:
     decltype(base(std::declval<It>())) base_;


        


More information about the libcxx-commits mailing list