[llvm] [ADT] Make Zippy more iterator-like for lifetime safety (PR #112441)

David Blaikie via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 15 14:52:27 PDT 2024


https://github.com/dwblaikie created https://github.com/llvm/llvm-project/pull/112441

@geoffromer identified/encountered a lifetime issue when using concat+zip: zip returned its elements by value, and concat took references to those values and used them in its result after they had expired.

This is a common problem with range adapters and the lifetime of values.

But it's also, I think, non-conforming with the C++ iterator requirements, partly because operator-> should be supported (which I haven't done here), and that essentially has to return a pointer.

So the best thing is to stash a value in the iterator and return a pointer/reference to that.

(some context that may or may not be relevant to this part of the code may be in https://github.com/llvm/llvm-project/commit/981ce8fa15afa11d083033240edb1daff29081c7 )

>From 75eb8c747a41c5e9a0275def57e534083d973921 Mon Sep 17 00:00:00 2001
From: David Blaikie <dblaikie at gmail.com>
Date: Tue, 15 Oct 2024 21:48:24 +0000
Subject: [PATCH] [ADT] Make Zippy more iterator-like for lifetime safety

@geoffromer identified/encountered a lifetime issue when using
concat+zip: zip returned its elements by value, and concat took
references to those values and used them in its result after they
had expired.

This is a common problem with range adapters and the lifetime of values.

But it's also, I think, non-conforming with the C++ iterator
requirements, partly because operator-> should be supported (which I
haven't done here), and that essentially has to return a pointer.

So the best thing is to stash a value in the iterator and return a
pointer/reference to that.
---
 llvm/include/llvm/ADT/STLExtras.h             |  8 +++--
 .../Transforms/Vectorize/SLPVectorizer.cpp    |  2 +-
 llvm/unittests/ADT/IteratorTest.cpp           | 36 ++++++++++---------
 3 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index eb441bb31c9bc8..3692b199c9ec1c 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -704,10 +704,12 @@ struct zip_common : public zip_traits<ZipType, ReferenceTupleType, Iters...> {
   using value_type = typename Base::value_type;
 
   std::tuple<Iters...> iterators;
+  mutable std::optional<value_type> value;
 
 protected:
-  template <size_t... Ns> value_type deref(std::index_sequence<Ns...>) const {
-    return value_type(*std::get<Ns>(iterators)...);
+  template <size_t... Ns> const value_type &deref(std::index_sequence<Ns...>) const {
+    value.emplace(*std::get<Ns>(iterators)...);
+    return *value;
   }
 
   template <size_t... Ns> void tup_inc(std::index_sequence<Ns...>) {
@@ -728,7 +730,7 @@ struct zip_common : public zip_traits<ZipType, ReferenceTupleType, Iters...> {
 public:
   zip_common(Iters &&... ts) : iterators(std::forward<Iters>(ts)...) {}
 
-  value_type operator*() const { return deref(IndexSequence{}); }
+  const value_type &operator*() const { return deref(IndexSequence{}); }
 
   ZipType &operator++() {
     tup_inc(IndexSequence{});
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 336126cc1fbc21..61aeb03634645c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9991,7 +9991,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
       if (!E->ReorderIndices.empty() && CommonVF == E->ReorderIndices.size() &&
           CommonVF == CommonMask.size() &&
           any_of(enumerate(CommonMask),
-                 [](const auto &&P) {
+                 [](const auto &P) {
                    return P.value() != PoisonMaskElem &&
                           static_cast<unsigned>(P.value()) != P.index();
                  }) &&
diff --git a/llvm/unittests/ADT/IteratorTest.cpp b/llvm/unittests/ADT/IteratorTest.cpp
index a0d3c9b564d857..6c9aba8fcad2b8 100644
--- a/llvm/unittests/ADT/IteratorTest.cpp
+++ b/llvm/unittests/ADT/IteratorTest.cpp
@@ -482,39 +482,26 @@ TEST(ZipIteratorTest, ZipEqualConstCorrectness) {
   EXPECT_THAT(first, ElementsAre(0, 0, 0));
   EXPECT_THAT(second, ElementsAre(true, true, true));
 
-  std::vector<bool> nemesis = {true, false, true};
-  const std::vector<bool> c_nemesis = nemesis;
-
-  for (auto &&[a, b, c, d] : zip_equal(first, c_first, nemesis, c_nemesis)) {
+  for (auto &&[a, b] : zip_equal(first, c_first)) {
     a = 2;
-    c = true;
     static_assert(!IsConstRef<decltype(a)>);
     static_assert(IsConstRef<decltype(b)>);
-    static_assert(!IsBoolConstRef<decltype(c)>);
-    static_assert(IsBoolConstRef<decltype(d)>);
   }
 
   EXPECT_THAT(first, ElementsAre(2, 2, 2));
-  EXPECT_THAT(nemesis, ElementsAre(true, true, true));
 
   unsigned iters = 0;
-  for (const auto &[a, b, c, d] :
-       zip_equal(first, c_first, nemesis, c_nemesis)) {
+  for (const auto &[a, b] : zip_equal(first, c_first)) {
     static_assert(!IsConstRef<decltype(a)>);
     static_assert(IsConstRef<decltype(b)>);
-    static_assert(!IsBoolConstRef<decltype(c)>);
-    static_assert(IsBoolConstRef<decltype(d)>);
     ++iters;
   }
   EXPECT_EQ(iters, 3u);
   iters = 0;
 
-  for (const auto &[a, b, c, d] :
-       MakeConst(zip_equal(first, c_first, nemesis, c_nemesis))) {
+  for (const auto &[a, b] : MakeConst(zip_equal(first, c_first))) {
     static_assert(!IsConstRef<decltype(a)>);
     static_assert(IsConstRef<decltype(b)>);
-    static_assert(!IsBoolConstRef<decltype(c)>);
-    static_assert(IsBoolConstRef<decltype(d)>);
     ++iters;
   }
   EXPECT_EQ(iters, 3u);
@@ -643,6 +630,23 @@ TEST(ZipIteratorTest, Mutability) {
   }
 }
 
+TEST(ZipIteratorTest, Lifetime) {
+  SmallVector<unsigned> v1 = {1, 2, 3, 4};
+  SmallVector<unsigned> v2 = {5, 6, 7, 8};
+
+  auto zipper = zip(v1, v2);
+
+  auto zip_iter = zipper.begin();
+  EXPECT_NE(zip_iter, zipper.end());
+
+  auto *elem = &*zip_iter;
+  EXPECT_EQ(std::get<0>(*elem), 1u);
+  EXPECT_EQ(std::get<1>(*elem), 5u);
+
+  std::get<0>(*elem) = 42;
+  EXPECT_EQ(v1[0], 42u);
+}
+
 TEST(ZipIteratorTest, ZipFirstMutability) {
   using namespace std;
   vector<unsigned> pi{3, 1, 4, 1, 5, 9};



More information about the llvm-commits mailing list