[Mlir-commits] [mlir] a0a7680 - [ADT] Allow `llvm::enumerate` to enumerate over multiple ranges

Jakub Kuderski llvmlistbot at llvm.org
Wed Mar 15 16:36:49 PDT 2023


Author: Jakub Kuderski
Date: 2023-03-15T19:34:22-04:00
New Revision: a0a76804c4b56058ba3dcd7374bcaec2fec3978e

URL: https://github.com/llvm/llvm-project/commit/a0a76804c4b56058ba3dcd7374bcaec2fec3978e
DIFF: https://github.com/llvm/llvm-project/commit/a0a76804c4b56058ba3dcd7374bcaec2fec3978e.diff

LOG: [ADT] Allow `llvm::enumerate` to enumerate over multiple ranges

This does not work by a mere composition of `enumerate` and `zip_equal`,
because C++17 does not allow for recursive expansion of structured
bindings.

This implementation uses `zippy` to manage the iteratees and adds the
stream of indices as the first zipped range. Because we have an upfront
assertion that all input ranges are of the same length, we only need to
check if the second range has ended during iteration.

As a consequence of using `zippy`, `enumerate` will now follow the
reference and lifetime semantics of the `zip*` family of functions. The
main difference is that `enumerate` exposes each tuple of references
through a new tuple-like type `enumerator_result`, with the familiar
`.index()` and `.value()` member functions.

Because the `enumerator_result` returned on dereference is a
temporary, the enumeration result can no longer be bound to a non-const
lvalue reference; use `auto` or `const auto &` instead.

Reviewed By: dblaikie, zero9178

Differential Revision: https://reviews.llvm.org/D144503

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    llvm/lib/Target/AArch64/AArch64PerfectShuffle.h
    llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/tools/llvm-mca/Views/InstructionInfoView.cpp
    llvm/unittests/ADT/STLExtrasTest.cpp
    llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 545e888c18230..bf33d79801065 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -755,26 +755,25 @@ template<typename... Iters> struct ZipTupleType {
   using type = std::tuple<decltype(*declval<Iters>())...>;
 };
 
-template <typename ZipType, typename... Iters>
+template <typename ZipType, typename ReferenceTupleType, typename... Iters>
 using zip_traits = iterator_facade_base<
     ZipType,
     std::common_type_t<
         std::bidirectional_iterator_tag,
         typename std::iterator_traits<Iters>::iterator_category...>,
     // ^ TODO: Implement random access methods.
-    typename ZipTupleType<Iters...>::type,
+    ReferenceTupleType,
     typename std::iterator_traits<
         std::tuple_element_t<0, std::tuple<Iters...>>>::difference_type,
     // ^ FIXME: This follows boost::make_zip_iterator's assumption that all
     // inner iterators have the same difference_type. It would fail if, for
     // instance, the second field's difference_type were non-numeric while the
     // first is.
-    typename ZipTupleType<Iters...>::type *,
-    typename ZipTupleType<Iters...>::type>;
+    ReferenceTupleType *, ReferenceTupleType>;
 
-template <typename ZipType, typename... Iters>
-struct zip_common : public zip_traits<ZipType, Iters...> {
-  using Base = zip_traits<ZipType, Iters...>;
+template <typename ZipType, typename ReferenceTupleType, typename... Iters>
+struct zip_common : public zip_traits<ZipType, ReferenceTupleType, Iters...> {
+  using Base = zip_traits<ZipType, ReferenceTupleType, Iters...>;
   using IndexSequence = std::index_sequence_for<Iters...>;
   using value_type = typename Base::value_type;
 
@@ -824,8 +823,10 @@ struct zip_common : public zip_traits<ZipType, Iters...> {
 };
 
 template <typename... Iters>
-struct zip_first : zip_common<zip_first<Iters...>, Iters...> {
-  using zip_common<zip_first, Iters...>::zip_common;
+struct zip_first : zip_common<zip_first<Iters...>,
+                              typename ZipTupleType<Iters...>::type, Iters...> {
+  using zip_common<zip_first, typename ZipTupleType<Iters...>::type,
+                   Iters...>::zip_common;
 
   bool operator==(const zip_first &other) const {
     return std::get<0>(this->iterators) == std::get<0>(other.iterators);
@@ -833,8 +834,11 @@ struct zip_first : zip_common<zip_first<Iters...>, Iters...> {
 };
 
 template <typename... Iters>
-struct zip_shortest : zip_common<zip_shortest<Iters...>, Iters...> {
-  using zip_common<zip_shortest, Iters...>::zip_common;
+struct zip_shortest
+    : zip_common<zip_shortest<Iters...>, typename ZipTupleType<Iters...>::type,
+                 Iters...> {
+  using zip_common<zip_shortest, typename ZipTupleType<Iters...>::type,
+                   Iters...>::zip_common;
 
   bool operator==(const zip_shortest &other) const {
     return any_iterator_equals(other, std::index_sequence_for<Iters...>{});
@@ -2213,113 +2217,182 @@ template <typename T> struct deref {
 
 namespace detail {
 
-template <typename R> class enumerator_iter;
+/// Tuple-like type for `zip_enumerator` dereference.
+template <typename... Refs> struct enumerator_result;
 
-template <typename R> struct result_pair {
-  using value_reference =
-      typename std::iterator_traits<IterOfRange<R>>::reference;
-
-  friend class enumerator_iter<R>;
-
-  result_pair(std::size_t Index, IterOfRange<R> Iter)
-      : Index(Index), Iter(Iter) {}
-
-  std::size_t index() const { return Index; }
-  value_reference value() const { return *Iter; }
-
-private:
-  std::size_t Index = std::numeric_limits<std::size_t>::max();
-  IterOfRange<R> Iter;
-};
-
-template <std::size_t i, typename R>
-decltype(auto) get(const result_pair<R> &Pair) {
-  static_assert(i < 2);
-  if constexpr (i == 0) {
-    return Pair.index();
-  } else {
-    return Pair.value();
-  }
-}
-
-template <typename R>
-class enumerator_iter
-    : public iterator_facade_base<enumerator_iter<R>, std::forward_iterator_tag,
-                                  const result_pair<R>> {
-  using result_type = result_pair<R>;
-
-public:
-  explicit enumerator_iter(IterOfRange<R> EndIter)
-      : Result(std::numeric_limits<size_t>::max(), EndIter) {}
-
-  enumerator_iter(std::size_t Index, IterOfRange<R> Iter)
-      : Result(Index, Iter) {}
-
-  const result_type &operator*() const { return Result; }
+template <typename... Iters>
+using EnumeratorTupleType = enumerator_result<decltype(*declval<Iters>())...>;
+
+/// Zippy iterator that uses the second iterator for comparisons. For the
+/// increment to be safe, the second range has to be the shortest.
+/// Returns `enumerator_result` on dereference to provide `.index()` and
+/// `.value()` member functions.
+/// Note: Because the dereference operator returns `enumerator_result` by
+/// value instead of by reference, this iterator does not strictly conform to
+/// the C++17 definition of a forward iterator. However, it satisfies all the
+/// forward_iterator requirements that `zip_common` and `zippy` depend on
+/// and fully conforms to the C++20 definition of forward iterator.
+/// This is similar to `std::vector<bool>::iterator` that returns bit reference
+/// wrappers on dereference.
+template <typename... Iters>
+struct zip_enumerator : zip_common<zip_enumerator<Iters...>,
+                                   EnumeratorTupleType<Iters...>, Iters...> {
+  static_assert(sizeof...(Iters) >= 2, "Expected at least two iteratees");
+  using zip_common<zip_enumerator<Iters...>, EnumeratorTupleType<Iters...>,
+                   Iters...>::zip_common;
 
-  enumerator_iter &operator++() {
-    assert(Result.Index != std::numeric_limits<size_t>::max());
-    ++Result.Iter;
-    ++Result.Index;
-    return *this;
+  bool operator==(const zip_enumerator &Other) const {
+    return std::get<1>(this->iterators) == std::get<1>(Other.iterators);
   }
+};
 
-  bool operator==(const enumerator_iter &RHS) const {
-    // Don't compare indices here, only iterators.  It's possible for an end
-    // iterator to have different indices depending on whether it was created
-    // by calling std::end() versus incrementing a valid iterator.
-    return Result.Iter == RHS.Result.Iter;
+template <typename... Refs> struct enumerator_result<std::size_t, Refs...> {
+  static constexpr std::size_t NumRefs = sizeof...(Refs);
+  static_assert(NumRefs != 0);
+  // `NumValues` includes the index.
+  static constexpr std::size_t NumValues = NumRefs + 1;
+
+  // Tuple type whose element types are references for each `Ref`.
+  using range_reference_tuple = std::tuple<Refs...>;
+  // Tuple type whose elements are all values: the index (held by value)
+  // followed by the `Refs` reference types.
+  using value_reference_tuple = std::tuple<std::size_t, Refs...>;
+
+  enumerator_result(std::size_t Index, Refs &&...Rs)
+      : Idx(Index), Storage(std::forward<Refs>(Rs)...) {}
+
+  /// Returns the 0-based index of the current position within the original
+  /// input range(s).
+  std::size_t index() const { return Idx; }
+
+  /// Returns the value(s) for the current iterator. This does not include the
+  /// index.
+  decltype(auto) value() const {
+    if constexpr (NumRefs == 1)
+      return std::get<0>(Storage);
+    else
+      return Storage;
+  }
+
+  /// Returns the value at index `I`. This includes the index.
+  template <std::size_t I>
+  friend decltype(auto) get(const enumerator_result &Result) {
+    static_assert(I < NumValues, "Index out of bounds");
+    if constexpr (I == 0)
+      return Result.Idx;
+    else
+      return std::get<I - 1>(Result.Storage);
+  }
+
+  template <typename... Ts>
+  friend bool operator==(const enumerator_result &Result,
+                         const std::tuple<std::size_t, Ts...> &Other) {
+    static_assert(NumRefs == sizeof...(Ts), "Size mismatch");
+    if (Result.Idx != std::get<0>(Other))
+      return false;
+    return Result.is_value_equal(Other, std::make_index_sequence<NumRefs>{});
   }
 
 private:
-  result_type Result;
+  template <typename Tuple, std::size_t... Idx>
+  bool is_value_equal(const Tuple &Other, std::index_sequence<Idx...>) const {
+    return ((std::get<Idx>(Storage) == std::get<Idx + 1>(Other)) && ...);
+  }
+
+  std::size_t Idx;
+  // Make this tuple mutable to avoid casts that obfuscate const-correctness
+  // issues. Const-correctness of references is taken care of by `zippy` that
+  // defines non-const and const iterator types that will propagate down to
+  // `enumerator_result`'s `Refs`.
+  // Note that unlike the results of `zip*` functions, `enumerate`'s results
+  // are supposed to be modifiable even when the result object itself is
+  // `const`.
+  mutable range_reference_tuple Storage;
 };
 
-template <typename R> class enumerator {
-public:
-  explicit enumerator(R &&Range) : TheRange(std::forward<R>(Range)) {}
+/// Infinite stream of increasing 0-based `size_t` indices.
+struct index_stream {
+  struct iterator : iterator_facade_base<iterator, std::forward_iterator_tag,
+                                         const iterator> {
+    iterator &operator++() {
+      assert(Index != std::numeric_limits<std::size_t>::max() &&
+             "Attempting to increment end iterator");
+      ++Index;
+      return *this;
+    }
 
-  enumerator_iter<R> begin() {
-    return enumerator_iter<R>(0, adl_begin(TheRange));
-  }
-  enumerator_iter<R> begin() const {
-    return enumerator_iter<R>(0, adl_begin(TheRange));
-  }
+    // Note: This dereference operator returns a value instead of a reference
+    // and does not strictly conform to the C++17's definition of forward
+    // iterator. However, it satisfies all the forward_iterator requirements
+    // that the `zip_common` depends on and fully conforms to the C++20
+    // definition of forward iterator.
+    std::size_t operator*() const { return Index; }
 
-  enumerator_iter<R> end() { return enumerator_iter<R>(adl_end(TheRange)); }
-  enumerator_iter<R> end() const {
-    return enumerator_iter<R>(adl_end(TheRange));
-  }
+    friend bool operator==(const iterator &Lhs, const iterator &Rhs) {
+      return Lhs.Index == Rhs.Index;
+    }
 
-private:
-  R TheRange;
+    std::size_t Index = 0;
+  };
+
+  iterator begin() const { return {}; }
+  iterator end() const {
+    // We approximate 'infinity' with the max size_t value, which should be good
+    // enough to index over any container.
+    iterator It;
+    It.Index = std::numeric_limits<std::size_t>::max();
+    return It;
+  }
 };
 
 } // end namespace detail
 
-/// Given an input range, returns a new range whose values are are pair (A,B)
-/// such that A is the 0-based index of the item in the sequence, and B is
-/// the value from the original sequence.  Example:
+/// Given one or more input ranges, returns a new range whose values are
+/// tuples (A, B, C, ...), such that A is the 0-based index of the item in the
+/// sequence, and B, C, ..., are the values from the original input ranges. All
+/// input ranges are required to have equal lengths. Note that the returned
+/// iterator allows for the values (B, C, ...) to be modified.  Example:
+///
+/// ```c++
+/// std::vector<char> Letters = {'A', 'B', 'C', 'D'};
+/// std::vector<int> Vals = {10, 11, 12, 13};
 ///
-/// std::vector<char> Items = {'A', 'B', 'C', 'D'};
-/// for (auto X : enumerate(Items)) {
-///   printf("Item %zu - %c\n", X.index(), X.value());
+/// for (auto [Index, Letter, Value] : enumerate(Letters, Vals)) {
+///   printf("Item %zu - %c: %d\n", Index, Letter, Value);
+///   Value -= 10;
 /// }
+/// ```
 ///
-/// or using structured bindings:
+/// Output:
+///   Item 0 - A: 10
+///   Item 1 - B: 11
+///   Item 2 - C: 12
+///   Item 3 - D: 13
 ///
-/// for (auto [Index, Value] : enumerate(Items)) {
-///   printf("Item %zu - %c\n", Index, Value);
+/// or, starting again from Vals = {10, 11, 12, 13}, via .index()/.value():
+/// ```c++
+/// for (auto it : enumerate(Vals)) {
+///   it.value() += 10;
+///   printf("Item %zu: %d\n", it.index(), it.value());
 /// }
+/// ```
 ///
 /// Output:
-///   Item 0 - A
-///   Item 1 - B
-///   Item 2 - C
-///   Item 3 - D
+///   Item 0: 20
+///   Item 1: 21
+///   Item 2: 22
+///   Item 3: 23
 ///
-template <typename R> detail::enumerator<R> enumerate(R &&TheRange) {
-  return detail::enumerator<R>(std::forward<R>(TheRange));
+template <typename FirstRange, typename... RestRanges>
+auto enumerate(FirstRange &&First, RestRanges &&...Rest) {
+  assert((sizeof...(Rest) == 0 ||
+          all_equal({std::distance(adl_begin(First), adl_end(First)),
+                     std::distance(adl_begin(Rest), adl_end(Rest))...})) &&
+         "Ranges have different length");
+  using enumerator = detail::zippy<detail::zip_enumerator, detail::index_stream,
+                                   FirstRange, RestRanges...>;
+  return enumerator(detail::index_stream{}, std::forward<FirstRange>(First),
+                    std::forward<RestRanges>(Rest)...);
 }
 
 namespace detail {
@@ -2451,15 +2524,17 @@ template <class T> constexpr T *to_address(T *P) { return P; }
 } // end namespace llvm
 
 namespace std {
-template <typename R>
-struct tuple_size<llvm::detail::result_pair<R>>
-    : std::integral_constant<std::size_t, 2> {};
+template <typename... Refs>
+struct tuple_size<llvm::detail::enumerator_result<Refs...>>
+    : std::integral_constant<std::size_t, sizeof...(Refs)> {};
 
-template <std::size_t i, typename R>
-struct tuple_element<i, llvm::detail::result_pair<R>>
-    : std::conditional<i == 0, std::size_t,
-                       typename llvm::detail::result_pair<R>::value_reference> {
-};
+template <std::size_t I, typename... Refs>
+struct tuple_element<I, llvm::detail::enumerator_result<Refs...>>
+    : std::tuple_element<I, std::tuple<Refs...>> {};
+
+template <std::size_t I, typename... Refs>
+struct tuple_element<I, const llvm::detail::enumerator_result<Refs...>>
+    : std::tuple_element<I, std::tuple<Refs...>> {};
 
 } // namespace std
 

diff  --git a/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h b/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h
index 4555f1a3ebb08..5846fd454b654 100644
--- a/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h
+++ b/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h
@@ -6590,11 +6590,11 @@ static unsigned getPerfectShuffleCost(llvm::ArrayRef<int> M) {
   assert(M.size() == 4 && "Expected a 4 entry perfect shuffle");
 
   // Special case zero-cost nop copies, from either LHS or RHS.
-  if (llvm::all_of(llvm::enumerate(M), [](auto &E) {
+  if (llvm::all_of(llvm::enumerate(M), [](const auto &E) {
         return E.value() < 0 || E.value() == (int)E.index();
       }))
     return 0;
-  if (llvm::all_of(llvm::enumerate(M), [](auto &E) {
+  if (llvm::all_of(llvm::enumerate(M), [](const auto &E) {
         return E.value() < 0 || E.value() == (int)E.index() + 4;
       }))
     return 0;

diff  --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
index c8b156a6f2c0f..1e0a7d11d8ac2 100644
--- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
+++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
@@ -1249,7 +1249,7 @@ bool LowOverheadLoop::ValidateMVEInst(MachineInstr *MI) {
   const MCInstrDesc &MCID = MI->getDesc();
   bool IsUse = false;
   unsigned LastOpIdx = MI->getNumOperands() - 1;
-  for (auto &Op : enumerate(reverse(MCID.operands()))) {
+  for (const auto &Op : enumerate(reverse(MCID.operands()))) {
     const MachineOperand &MO = MI->getOperand(LastOpIdx - Op.index());
     if (!MO.isReg() || !MO.isUse() || MO.getReg() != ARM::VPR)
       continue;

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 08f1c9ffb738a..d525365762293 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1651,11 +1651,11 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
                                        StringRef &ErrInfo) const {
   MCInstrDesc const &Desc = MI.getDesc();
 
-  for (auto &OI : enumerate(Desc.operands())) {
-    unsigned OpType = OI.value().OperandType;
+  for (const auto &[Index, Operand] : enumerate(Desc.operands())) {
+    unsigned OpType = Operand.OperandType;
     if (OpType >= RISCVOp::OPERAND_FIRST_RISCV_IMM &&
         OpType <= RISCVOp::OPERAND_LAST_RISCV_IMM) {
-      const MachineOperand &MO = MI.getOperand(OI.index());
+      const MachineOperand &MO = MI.getOperand(Index);
       if (MO.isImm()) {
         int64_t Imm = MO.getImm();
         bool Ok;

diff  --git a/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp b/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp
index 257fdca8cb366..3262de48d41e3 100644
--- a/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp
+++ b/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp
@@ -55,10 +55,7 @@ void InstructionInfoView::printView(raw_ostream &OS) const {
     }
   }
 
-  int Index = 0;
-  for (const auto &I : enumerate(zip(IIVD, Source))) {
-    const InstructionInfoViewData &IIVDEntry = std::get<0>(I.value());
-
+  for (const auto &[Index, IIVDEntry, Inst] : enumerate(IIVD, Source)) {
     TempStream << ' ' << IIVDEntry.NumMicroOpcodes << "    ";
     if (IIVDEntry.NumMicroOpcodes < 10)
       TempStream << "  ";
@@ -92,7 +89,7 @@ void InstructionInfoView::printView(raw_ostream &OS) const {
     }
 
     if (PrintEncodings) {
-      StringRef Encoding(CE.getEncoding(I.index()));
+      StringRef Encoding(CE.getEncoding(Index));
       unsigned EncodingSize = Encoding.size();
       TempStream << " " << EncodingSize
                  << (EncodingSize < 10 ? "     " : "    ");
@@ -104,9 +101,7 @@ void InstructionInfoView::printView(raw_ostream &OS) const {
       FOS.flush();
     }
 
-    const MCInst &Inst = std::get<1>(I.value());
     TempStream << printInstructionString(Inst) << '\n';
-    ++Index;
   }
 
   TempStream.flush();

diff  --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 8ec3df10bc6c3..bb602bb6c39f7 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -13,8 +13,11 @@
 
 #include <array>
 #include <climits>
+#include <cstddef>
+#include <initializer_list>
 #include <list>
 #include <tuple>
+#include <type_traits>
 #include <utility>
 #include <vector>
 
@@ -153,6 +156,131 @@ TEST(STLExtrasTest, EnumerateModifyRValue) {
                                    PairType(2u, '4')));
 }
 
+TEST(STLExtrasTest, EnumerateTwoRanges) {
+  using Tuple = std::tuple<size_t, int, bool>;
+
+  std::vector<int> Ints = {1, 2};
+  std::vector<bool> Bools = {true, false};
+  EXPECT_THAT(llvm::enumerate(Ints, Bools),
+              ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false)));
+
+  // Check that we can modify the values when the temporary is a const
+  // reference.
+  for (const auto &[Idx, Int, Bool] : llvm::enumerate(Ints, Bools)) {
+    (void)Idx;
+    Bool = false;
+    Int = -1;
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(-1, -1));
+  EXPECT_THAT(Bools, ElementsAre(false, false));
+
+  // Check that we can modify the values when the result gets copied.
+  for (auto [Idx, Bool, Int] : llvm::enumerate(Bools, Ints)) {
+    (void)Idx;
+    Int = 3;
+    Bool = true;
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(3, 3));
+  EXPECT_THAT(Bools, ElementsAre(true, true));
+
+  // Check that we can modify the values through `.value()`.
+  size_t Iters = 0;
+  for (auto It : llvm::enumerate(Bools, Ints)) {
+    EXPECT_EQ(It.index(), Iters);
+    ++Iters;
+
+    std::get<0>(It.value()) = false;
+    std::get<1>(It.value()) = 4;
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(4, 4));
+  EXPECT_THAT(Bools, ElementsAre(false, false));
+}
+
+TEST(STLExtrasTest, EnumerateThreeRanges) {
+  using Tuple = std::tuple<size_t, int, bool, char>;
+
+  std::vector<int> Ints = {1, 2};
+  std::vector<bool> Bools = {true, false};
+  char Chars[] = {'X', 'D'};
+  EXPECT_THAT(llvm::enumerate(Ints, Bools, Chars),
+              ElementsAre(Tuple(0, 1, true, 'X'), Tuple(1, 2, false, 'D')));
+
+  for (auto [Idx, Int, Bool, Char] : llvm::enumerate(Ints, Bools, Chars)) {
+    (void)Idx;
+    Int = 0;
+    Bool = true;
+    Char = '!';
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(0, 0));
+  EXPECT_THAT(Bools, ElementsAre(true, true));
+  EXPECT_THAT(Chars, ElementsAre('!', '!'));
+
+  // Check that we can modify the values through `.value()`.
+  size_t Iters = 0;
+  for (auto It : llvm::enumerate(Ints, Bools, Chars)) {
+    EXPECT_EQ(It.index(), Iters);
+    ++Iters;
+    auto [Int, Bool, Char] = It.value();
+    Int = 42;
+    Bool = false;
+    Char = '$';
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(42, 42));
+  EXPECT_THAT(Bools, ElementsAre(false, false));
+  EXPECT_THAT(Chars, ElementsAre('$', '$'));
+}
+
+TEST(STLExtrasTest, EnumerateTemporaries) {
+  using Tuple = std::tuple<size_t, int, bool>;
+
+  EXPECT_THAT(
+      llvm::enumerate(llvm::SmallVector<int>({1, 2, 3}),
+                      std::vector<bool>({true, false, true})),
+      ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false), Tuple(2, 3, true)));
+
+  size_t Iters = 0;
+  // This is fine from the point of view of range lifetimes because `zippy` will
+  // move all temporaries into its storage. No lifetime extension is necessary.
+  for (auto [Idx, Int, Bool] :
+       llvm::enumerate(llvm::SmallVector<int>({1, 2, 3}),
+                       std::vector<bool>({true, false, true}))) {
+    EXPECT_EQ(Idx, Iters);
+    ++Iters;
+    Int = 0;
+    Bool = true;
+  }
+
+  Iters = 0;
+  // The same thing but with the result as a const reference.
+  for (const auto &[Idx, Int, Bool] :
+       llvm::enumerate(llvm::SmallVector<int>({1, 2, 3}),
+                       std::vector<bool>({true, false, true}))) {
+    EXPECT_EQ(Idx, Iters);
+    ++Iters;
+    Int = 0;
+    Bool = true;
+  }
+}
+
+#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG)
+TEST(STLExtrasTest, EnumerateDifferentLengths) {
+  std::vector<int> Ints = {0, 1};
+  bool Bools[] = {true, false, true};
+  std::string Chars = "abc";
+  EXPECT_DEATH(llvm::enumerate(Ints, Bools, Chars),
+               "Ranges have different length");
+  EXPECT_DEATH(llvm::enumerate(Bools, Ints, Chars),
+               "Ranges have different length");
+  EXPECT_DEATH(llvm::enumerate(Bools, Chars, Ints),
+               "Ranges have different length");
+}
+#endif
+
 template <bool B> struct CanMove {};
 template <> struct CanMove<false> {
   CanMove(CanMove &&) = delete;
@@ -190,8 +318,8 @@ class Counted : CanMove<Moveable>, CanCopy<Copyable> {
 template <bool Moveable, bool Copyable>
 struct Range : Counted<Moveable, Copyable> {
   using Counted<Moveable, Copyable>::Counted;
-  int *begin() { return nullptr; }
-  int *end() { return nullptr; }
+  int *begin() const { return nullptr; }
+  int *end() const { return nullptr; }
 };
 
 TEST(STLExtrasTest, EnumerateLifetimeSemanticsPRValue) {

diff  --git a/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp b/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp
index 493c6b35c8406..fcf9c25ae74a3 100644
--- a/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp
+++ b/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp
@@ -338,9 +338,9 @@ void GIMatchTreeBuilder::runStep() {
          "Must always partition into at least one partition");
 
   TreeNode->setNumChildren(Partitioner->getNumPartitions());
-  for (auto &C : enumerate(TreeNode->children())) {
-    SubtreeBuilders.emplace_back(&C.value(), NextInstrID);
-    Partitioner->applyForPartition(C.index(), *this, SubtreeBuilders.back());
+  for (const auto &[Idx, Child] : enumerate(TreeNode->children())) {
+    SubtreeBuilders.emplace_back(&Child, NextInstrID);
+    Partitioner->applyForPartition(Idx, *this, SubtreeBuilders.back());
   }
 
   TreeNode->setPartitioner(std::move(Partitioner));

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6e2aa2c2755bc..4df55ddd62e0f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1394,7 +1394,7 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
   //   i1 = (i_{folded} / d2) % d1
   //   i0 = i_{folded} / (d1 * d2)
   llvm::DenseMap<unsigned, Value> indexReplacementVals;
-  for (auto &foldedDims :
+  for (auto foldedDims :
        enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
     Value newIndexVal =

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c8999943d868e..27f1eab09eec6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1871,15 +1871,13 @@ LogicalResult ReinterpretCastOp::verify() {
            << srcType << " and result memref type " << resultType;
 
   // Match sizes in result memref type and in static_sizes attribute.
-  for (auto &en :
-       llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) {
-    int64_t resultSize = std::get<0>(en.value());
-    int64_t expectedSize = std::get<1>(en.value());
+  for (auto [idx, resultSize, expectedSize] :
+       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
     if (!ShapedType::isDynamic(resultSize) &&
         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
-             << " in dim = " << en.index();
+             << " in dim = " << idx;
   }
 
   // Match offset and strides in static_offset and static_strides attributes. If
@@ -1900,16 +1898,14 @@ LogicalResult ReinterpretCastOp::verify() {
            << resultOffset << " instead of " << expectedOffset;
 
   // Match strides in result memref type and in static_strides attribute.
-  for (auto &en :
-       llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) {
-    int64_t resultStride = std::get<0>(en.value());
-    int64_t expectedStride = std::get<1>(en.value());
+  for (auto [idx, resultStride, expectedStride] :
+       llvm::enumerate(resultStrides, getStaticStrides())) {
     if (!ShapedType::isDynamic(resultStride) &&
         !ShapedType::isDynamic(expectedStride) &&
         resultStride != expectedStride)
       return emitError("expected result type with stride = ")
              << expectedStride << " instead of " << resultStride
-             << " in dim = " << en.index();
+             << " in dim = " << idx;
   }
 
   return success();

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2b59ada88c338..f3fcd5ac20263 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1891,10 +1891,7 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
     auto elseYieldArgs = op.elseYield().getOperands();
 
     SmallVector<Type> nonHoistable;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
-      Value trueVal = std::get<0>(it.value());
-      Value falseVal = std::get<1>(it.value());
+    for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
       if (&op.getThenRegion() == trueVal.getParentRegion() ||
           &op.getElseRegion() == falseVal.getParentRegion())
         nonHoistable.push_back(trueVal.getType());

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 7a802379355fe..915e4b4ed1c56 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -423,7 +423,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
     return b.notifyMatchFailure(
         op, "only support ops with one reduction dimension.");
   int reductionDim;
-  for (auto &[idx, iteratorType] :
+  for (auto [idx, iteratorType] :
        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
     if (iteratorType == utils::IteratorType::reduction) {
       reductionDim = idx;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 71c78d9061a9b..ae6e40f1c19cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1148,10 +1148,9 @@ class SparseExtractSliceConverter
     desc.setSpecifier(newSpec);
 
     // Fills in slice information.
-    for (const auto &it : llvm::enumerate(llvm::zip(
-             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()))) {
-      Dimension dim = it.index();
-      auto [offset, size, stride] = it.value();
+    for (auto [idx, offset, size, stride] : llvm::enumerate(
+             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
+      Dimension dim = idx;
 
       Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
       Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index eb73a2c872956..ea087c1357aec 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -93,7 +93,7 @@ static OpFoldResult getExpandedOutputDimFromInputShape(
                         .cast<AffineDimExpr>()
                         .getPosition();
   int64_t linearizedStaticDim = 1;
-  for (auto &d :
+  for (auto d :
        llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
     if (d.index() + startPos == static_cast<unsigned>(dimIndex))
       continue;

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index fdeb1e490033d..43dedf1c4ce06 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -762,7 +762,7 @@ struct InsertSliceOpInterface
           return isConstantIntValue(ofr, 0);
         });
     bool sizesMatchDestSizes = llvm::all_of(
-        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](auto &it) {
+        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
           return getConstantIntValue(it.value()) ==
                  destType.getDimSize(it.index());
         });

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index a51bdb50da9ed..52ea148179357 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -869,14 +869,14 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
           if (arg.kind != LinalgOperandDefKind::IndexAttr)
             continue;
           assert(arg.indexAttrMap);
-          for (auto &en :
+          for (auto [idx, result] :
                llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
-            if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
+            if (auto symbol = result.dyn_cast<AffineSymbolExpr>()) {
               std::string argName = arg.name;
               argName[0] = toupper(argName[0]);
               symbolBindings[symbol.getPosition()] =
                   llvm::formatv(structuredOpAccessAttrFormat, argName,
-                                symbol.getPosition(), en.index());
+                                symbol.getPosition(), idx);
             }
           }
         }


        


More information about the Mlir-commits mailing list