[llvm] 38774c4 - [ADT] Avoid needless iterator copies in `zippy`

Jakub Kuderski via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 6 11:45:19 PST 2023


Author: Jakub Kuderski
Date: 2023-03-06T14:39:40-05:00
New Revision: 38774c4f39a78deab1dd35ff45bc557300cd9b29

URL: https://github.com/llvm/llvm-project/commit/38774c4f39a78deab1dd35ff45bc557300cd9b29
DIFF: https://github.com/llvm/llvm-project/commit/38774c4f39a78deab1dd35ff45bc557300cd9b29.diff

LOG: [ADT] Avoid needless iterator copies in `zippy`

Make `zip_common` increment and decrement iterators in place.

This improves performance with iterator types that have non-trivial
copy constructors.

Reviewed By: zero9178

Differential Revision: https://reviews.llvm.org/D145337

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/unittests/ADT/IteratorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 74fbfc958beac..dc7bf6a106862 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -775,6 +775,7 @@ using zip_traits = iterator_facade_base<
 template <typename ZipType, typename... Iters>
 struct zip_common : public zip_traits<ZipType, Iters...> {
   using Base = zip_traits<ZipType, Iters...>;
+  using IndexSequence = std::index_sequence_for<Iters...>;
   using value_type = typename Base::value_type;
 
   std::tuple<Iters...> iterators;
@@ -784,19 +785,17 @@ struct zip_common : public zip_traits<ZipType, Iters...> {
     return value_type(*std::get<Ns>(iterators)...);
   }
 
-  template <size_t... Ns>
-  decltype(iterators) tup_inc(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::next(std::get<Ns>(iterators))...);
+  template <size_t... Ns> void tup_inc(std::index_sequence<Ns...>) {
+    (++std::get<Ns>(iterators), ...);
   }
 
-  template <size_t... Ns>
-  decltype(iterators) tup_dec(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::prev(std::get<Ns>(iterators))...);
+  template <size_t... Ns> void tup_dec(std::index_sequence<Ns...>) {
+    (--std::get<Ns>(iterators), ...);
   }
 
   template <size_t... Ns>
   bool test_all_equals(const zip_common &other,
-            std::index_sequence<Ns...>) const {
+                       std::index_sequence<Ns...>) const {
     return ((std::get<Ns>(this->iterators) == std::get<Ns>(other.iterators)) &&
             ...);
   }
@@ -804,25 +803,23 @@ struct zip_common : public zip_traits<ZipType, Iters...> {
 public:
   zip_common(Iters &&... ts) : iterators(std::forward<Iters>(ts)...) {}
 
-  value_type operator*() const {
-    return deref(std::index_sequence_for<Iters...>{});
-  }
+  value_type operator*() const { return deref(IndexSequence{}); }
 
   ZipType &operator++() {
-    iterators = tup_inc(std::index_sequence_for<Iters...>{});
-    return *reinterpret_cast<ZipType *>(this);
+    tup_inc(IndexSequence{});
+    return static_cast<ZipType &>(*this);
   }
 
   ZipType &operator--() {
     static_assert(Base::IsBidirectional,
                   "All inner iterators must be at least bidirectional.");
-    iterators = tup_dec(std::index_sequence_for<Iters...>{});
-    return *reinterpret_cast<ZipType *>(this);
+    tup_dec(IndexSequence{});
+    return static_cast<ZipType &>(*this);
   }
 
   /// Return true if all the iterator are matching `other`'s iterators.
   bool all_equals(zip_common &other) {
-    return test_all_equals(other, std::index_sequence_for<Iters...>{});
+    return test_all_equals(other, IndexSequence{});
   }
 };
 

diff  --git a/llvm/unittests/ADT/IteratorTest.cpp b/llvm/unittests/ADT/IteratorTest.cpp
index b2a11c4c6bd7d..7d10729c2dd9f 100644
--- a/llvm/unittests/ADT/IteratorTest.cpp
+++ b/llvm/unittests/ADT/IteratorTest.cpp
@@ -692,6 +692,49 @@ TEST(ZipIteratorTest, Reverse) {
   EXPECT_TRUE(all_of(ascending, [](unsigned n) { return (n & 0x01) == 0; }));
 }
 
+// Int iterator that keeps track of the number of its copies.
+struct CountingIntIterator : IntIterator {
+  unsigned *cnt;
+
+  CountingIntIterator(int *it, unsigned &counter)
+      : IntIterator(it), cnt(&counter) {}
+
+  CountingIntIterator(const CountingIntIterator &other)
+      : IntIterator(other.I), cnt(other.cnt) {
+    ++(*cnt);
+  }
+  CountingIntIterator &operator=(const CountingIntIterator &other) {
+    this->I = other.I;
+    this->cnt = other.cnt;
+    ++(*cnt);
+    return *this;
+  }
+};
+
+// Check that the iterators do not get copied with each `zippy` iterator
+// increment.
+TEST(ZipIteratorTest, IteratorCopies) {
+  std::vector<int> ints(1000, 42);
+  unsigned total_copy_count = 0;
+  CountingIntIterator begin(ints.data(), total_copy_count);
+  CountingIntIterator end(ints.data() + ints.size(), total_copy_count);
+
+  size_t iters = 0;
+  auto zippy = zip_equal(ints, llvm::make_range(begin, end));
+  const unsigned creation_copy_count = total_copy_count;
+
+  for (auto [a, b] : zippy) {
+    EXPECT_EQ(a, b);
+    ++iters;
+  }
+  EXPECT_EQ(iters, ints.size());
+
+  // We expect the number of copies to be much smaller than the number of loop
+  // iterations.
+  unsigned loop_copy_count = total_copy_count - creation_copy_count;
+  EXPECT_LT(loop_copy_count, 10u);
+}
+
 TEST(RangeTest, Distance) {
   std::vector<int> v1;
   std::vector<int> v2{1, 2, 3};


        


More information about the llvm-commits mailing list