[Mlir-commits] [mlir] a006af5 - [llvm] Add enum iteration to Sequence

Guillaume Chatelet llvmlistbot at llvm.org
Tue Jul 13 09:22:31 PDT 2021


Author: Guillaume Chatelet
Date: 2021-07-13T16:22:19Z
New Revision: a006af5d6ec6280034ae4249f6d2266d726ccef4

URL: https://github.com/llvm/llvm-project/commit/a006af5d6ec6280034ae4249f6d2266d726ccef4
DIFF: https://github.com/llvm/llvm-project/commit/a006af5d6ec6280034ae4249f6d2266d726ccef4.diff

LOG: [llvm] Add enum iteration to Sequence

This patch allows iterating typed enum via the ADT/Sequence utility.

Differential Revision: https://reviews.llvm.org/D103900

Added: 
    

Modified: 
    llvm/include/llvm/ADT/Sequence.h
    llvm/include/llvm/Support/MachineValueType.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/tools/llvm-exegesis/lib/X86/Target.cpp
    llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp
    llvm/unittests/ADT/SequenceTest.cpp
    llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
    llvm/unittests/IR/ConstantRangeTest.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h
index d033c01ecc573..0348f3dd4959e 100644
--- a/llvm/include/llvm/ADT/Sequence.h
+++ b/llvm/include/llvm/ADT/Sequence.h
@@ -15,46 +15,29 @@
 #ifndef LLVM_ADT_SEQUENCE_H
 #define LLVM_ADT_SEQUENCE_H
 
-#include <cstddef>  //std::ptrdiff_t
-#include <iterator> //std::random_access_iterator_tag
+#include <cassert>     // assert
+#include <cstddef>     // std::ptrdiff_t
+#include <iterator>    // std::random_access_iterator_tag
+#include <limits>      // std::numeric_limits
+#include <type_traits> // std::underlying_type, std::is_enum
 
 namespace llvm {
 
 namespace detail {
 
-template <typename T, bool IsReversed> struct iota_range_iterator {
+template <typename T, typename U, bool IsReversed> struct iota_range_iterator {
   using iterator_category = std::random_access_iterator_tag;
   using value_type = T;
   using difference_type = std::ptrdiff_t;
   using pointer = T *;
   using reference = T &;
 
-private:
-  struct Forward {
-    static void increment(T &V) { ++V; }
-    static void decrement(T &V) { --V; }
-    static void offset(T &V, difference_type Offset) { V += Offset; }
-    static T add(const T &V, difference_type Offset) { return V + Offset; }
-    static difference_type difference(const T &A, const T &B) { return A - B; }
-  };
-
-  struct Reverse {
-    static void increment(T &V) { --V; }
-    static void decrement(T &V) { ++V; }
-    static void offset(T &V, difference_type Offset) { V -= Offset; }
-    static T add(const T &V, difference_type Offset) { return V - Offset; }
-    static difference_type difference(const T &A, const T &B) { return B - A; }
-  };
-
-  using Op = std::conditional_t<!IsReversed, Forward, Reverse>;
-
-public:
   // default-constructible
   iota_range_iterator() = default;
   // copy-constructible
   iota_range_iterator(const iota_range_iterator &) = default;
   // value constructor
-  explicit iota_range_iterator(T Value) : Value(Value) {}
+  explicit iota_range_iterator(U Value) : Value(Value) {}
   // copy-assignable
   iota_range_iterator &operator=(const iota_range_iterator &) = default;
   // destructible
@@ -83,8 +66,10 @@ template <typename T, bool IsReversed> struct iota_range_iterator {
   }
 
   // Dereference
-  T operator*() const { return Value; }
-  T operator[](difference_type Offset) const { return Op::add(Value, Offset); }
+  T operator*() const { return static_cast<T>(Value); }
+  T operator[](difference_type Offset) const {
+    return static_cast<T>(Op::add(Value, Offset));
+  }
 
   // Arithmetic
   iota_range_iterator operator+(difference_type Offset) const {
@@ -132,46 +117,116 @@ template <typename T, bool IsReversed> struct iota_range_iterator {
   }
 
 private:
-  T Value;
+  struct Forward {
+    static void increment(U &V) { ++V; }
+    static void decrement(U &V) { --V; }
+    static void offset(U &V, difference_type Offset) { V += Offset; }
+    static U add(const U &V, difference_type Offset) { return V + Offset; }
+    static difference_type difference(const U &A, const U &B) {
+      return difference_type(A) - difference_type(B);
+    }
+  };
+
+  struct Reverse {
+    static void increment(U &V) { --V; }
+    static void decrement(U &V) { ++V; }
+    static void offset(U &V, difference_type Offset) { V -= Offset; }
+    static U add(const U &V, difference_type Offset) { return V - Offset; }
+    static difference_type difference(const U &A, const U &B) {
+      return difference_type(B) - difference_type(A);
+    }
+  };
+
+  using Op = std::conditional_t<!IsReversed, Forward, Reverse>;
+
+  U Value;
 };
 
+// Providing std::type_identity for C++14.
+template <class T> struct type_identity { using type = T; };
+
 } // namespace detail
 
-template <typename ValueT> struct iota_range {
-  static_assert(std::is_integral<ValueT>::value,
-                "ValueT must be an integral type");
+template <typename T> struct iota_range {
+private:
+  using underlying_type =
+      typename std::conditional_t<std::is_enum<T>::value,
+                                  std::underlying_type<T>,
+                                  detail::type_identity<T>>::type;
+  using numeric_type =
+      typename std::conditional_t<std::is_signed<underlying_type>::value,
+                                  intmax_t, uintmax_t>;
+
+  static numeric_type compute_past_end(numeric_type End, bool Inclusive) {
+    if (Inclusive) {
+      // This assertion forbids overflow of `PastEndValue`.
+      assert(End != std::numeric_limits<numeric_type>::max() &&
+             "Forbidden End value for seq_inclusive.");
+      return End + 1;
+    }
+    return End;
+  }
+  static numeric_type raw(T Value) { return static_cast<numeric_type>(Value); }
 
-  using value_type = ValueT;
-  using reference = ValueT &;
-  using const_reference = const ValueT &;
-  using iterator = detail::iota_range_iterator<value_type, false>;
+  numeric_type BeginValue;
+  numeric_type PastEndValue;
+
+public:
+  using value_type = T;
+  using reference = T &;
+  using const_reference = const T &;
+  using iterator = detail::iota_range_iterator<value_type, numeric_type, false>;
   using const_iterator = iterator;
-  using reverse_iterator = detail::iota_range_iterator<value_type, true>;
+  using reverse_iterator =
+      detail::iota_range_iterator<value_type, numeric_type, true>;
   using const_reverse_iterator = reverse_iterator;
   using difference_type = std::ptrdiff_t;
   using size_type = std::size_t;
 
-  value_type Begin;
-  value_type End;
-
-  explicit iota_range(ValueT Begin, ValueT End) : Begin(Begin), End(End) {}
+  explicit iota_range(T Begin, T End, bool Inclusive)
+      : BeginValue(raw(Begin)),
+        PastEndValue(compute_past_end(raw(End), Inclusive)) {
+    assert(Begin <= End && "Begin must be less or equal to End.");
+  }
 
-  size_t size() const { return End - Begin; }
-  bool empty() const { return Begin == End; }
+  size_t size() const { return PastEndValue - BeginValue; }
+  bool empty() const { return BeginValue == PastEndValue; }
 
-  auto begin() const { return const_iterator(Begin); }
-  auto end() const { return const_iterator(End); }
+  auto begin() const { return const_iterator(BeginValue); }
+  auto end() const { return const_iterator(PastEndValue); }
 
-  auto rbegin() const { return const_reverse_iterator(End - 1); }
-  auto rend() const { return const_reverse_iterator(Begin - 1); }
+  auto rbegin() const { return const_reverse_iterator(PastEndValue - 1); }
+  auto rend() const {
+    assert(std::is_unsigned<numeric_type>::value ||
+           BeginValue != std::numeric_limits<numeric_type>::min() &&
+               "Forbidden Begin value for reverse iteration");
+    return const_reverse_iterator(BeginValue - 1);
+  }
 
 private:
-  static_assert(std::is_same<ValueT, std::remove_cv_t<ValueT>>::value,
-                "ValueT must not be const nor volatile");
+  static_assert(std::is_integral<T>::value || std::is_enum<T>::value,
+                "T must be an integral or enum type");
+  static_assert(std::is_same<T, std::remove_cv_t<T>>::value,
+                "T must not be const nor volatile");
+  static_assert(std::is_integral<numeric_type>::value,
+                "numeric_type must be an integral type");
 };
 
-template <typename ValueT> auto seq(ValueT Begin, ValueT End) {
-  return iota_range<ValueT>(Begin, End);
+/// Iterate over an integral/enum type from Begin up to - but not including -
+/// End.
+/// Note on enum iteration: `seq` will generate each consecutive value, even if
+/// no enumerator with that value exists.
+template <typename T> auto seq(T Begin, T End) {
+  return iota_range<T>(Begin, End, false);
+}
+
+/// Iterate over an integral/enum type from Begin to End inclusive.
+/// Note on enum iteration: `seq_inclusive` will generate each consecutive
+/// value, even if no enumerator with that value exists.
+/// To prevent overflow, `End` must be different from INTMAX_MAX if T is signed
+/// (resp. UINTMAX_MAX if T is unsigned).
+template <typename T> auto seq_inclusive(T Begin, T End) {
+  return iota_range<T>(Begin, End, true);
 }
 
 } // end namespace llvm

diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
index 4b8d937bde379..31f2d5a481832 100644
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_SUPPORT_MACHINEVALUETYPE_H
 #define LLVM_SUPPORT_MACHINEVALUETYPE_H
 
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
@@ -1398,84 +1399,55 @@ namespace llvm {
     /// returned as Other, otherwise they are invalid.
     static MVT getVT(Type *Ty, bool HandleUnknown = false);
 
-  private:
-    /// A simple iterator over the MVT::SimpleValueType enum.
-    struct mvt_iterator {
-      SimpleValueType VT;
-
-      mvt_iterator(SimpleValueType VT) : VT(VT) {}
-
-      MVT operator*() const { return VT; }
-      bool operator!=(const mvt_iterator &LHS) const { return VT != LHS.VT; }
-
-      mvt_iterator& operator++() {
-        VT = (MVT::SimpleValueType)((int)VT + 1);
-        assert((int)VT <= MVT::MAX_ALLOWED_VALUETYPE &&
-               "MVT iterator overflowed.");
-        return *this;
-      }
-    };
-
-    /// A range of the MVT::SimpleValueType enum.
-    using mvt_range = iterator_range<mvt_iterator>;
-
   public:
     /// SimpleValueType Iteration
     /// @{
-    static mvt_range all_valuetypes() {
-      return mvt_range(MVT::FIRST_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_VALUETYPE + 1));
+    static auto all_valuetypes() {
+      return seq_inclusive(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE);
     }
 
-    static mvt_range integer_valuetypes() {
-      return mvt_range(MVT::FIRST_INTEGER_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_INTEGER_VALUETYPE + 1));
+    static auto integer_valuetypes() {
+      return seq_inclusive(MVT::FIRST_INTEGER_VALUETYPE,
+                           MVT::LAST_INTEGER_VALUETYPE);
     }
 
-    static mvt_range fp_valuetypes() {
-      return mvt_range(MVT::FIRST_FP_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_FP_VALUETYPE + 1));
+    static auto fp_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE);
     }
 
-    static mvt_range vector_valuetypes() {
-      return mvt_range(MVT::FIRST_VECTOR_VALUETYPE,
-                       (MVT::SimpleValueType)(MVT::LAST_VECTOR_VALUETYPE + 1));
+    static auto vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_VECTOR_VALUETYPE,
+                           MVT::LAST_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fixedlen_vector_valuetypes() {
-      return mvt_range(
-               MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
-               (MVT::SimpleValueType)(MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto fixedlen_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
+                           MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range scalable_vector_valuetypes() {
-      return mvt_range(
-               MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
-               (MVT::SimpleValueType)(MVT::LAST_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto scalable_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
+                           MVT::LAST_SCALABLE_VECTOR_VALUETYPE);
     }
 
-    static mvt_range integer_fixedlen_vector_valuetypes() {
-      return mvt_range(
-       MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
-       (MVT::SimpleValueType)(MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto integer_fixedlen_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
+                           MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fp_fixedlen_vector_valuetypes() {
-      return mvt_range(
-          MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
-          (MVT::SimpleValueType)(MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE + 1));
+    static auto fp_fixedlen_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
+                           MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE);
     }
 
-    static mvt_range integer_scalable_vector_valuetypes() {
-      return mvt_range(
-       MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
-       (MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto integer_scalable_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
+                           MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE);
     }
 
-    static mvt_range fp_scalable_vector_valuetypes() {
-      return mvt_range(
-            MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
-            (MVT::SimpleValueType)(MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE + 1));
+    static auto fp_scalable_vector_valuetypes() {
+      return seq_inclusive(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
+                           MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE);
     }
     /// @}
   };

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index f286bc9067b75..91242bbf866f9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -4634,8 +4634,7 @@ SDValue DAGTypeLegalizer::WidenVecOp_EXTEND(SDNode *N) {
   EVT InVT = InOp.getValueType();
   if (InVT.getSizeInBits() != VT.getSizeInBits()) {
     EVT InEltVT = InVT.getVectorElementType();
-    for (int i = MVT::FIRST_VECTOR_VALUETYPE, e = MVT::LAST_VECTOR_VALUETYPE; i < e; ++i) {
-      EVT FixedVT = (MVT::SimpleValueType)i;
+    for (EVT FixedVT : MVT::vector_valuetypes()) {
       EVT FixedEltVT = FixedVT.getVectorElementType();
       if (TLI.isTypeLegal(FixedVT) &&
           FixedVT.getSizeInBits() == VT.getSizeInBits() &&
@@ -5162,14 +5161,11 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
   if (!Scalable && Width == WidenEltWidth)
     return RetVT;
 
-  // See if there is larger legal integer than the element type to load/store.
-  unsigned VT;
   // Don't bother looking for an integer type if the vector is scalable, skip
   // to vector types.
   if (!Scalable) {
-    for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
-         VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
-      EVT MemVT((MVT::SimpleValueType) VT);
+    // See if there is larger legal integer than the element type to load/store.
+    for (EVT MemVT : reverse(MVT::integer_valuetypes())) {
       unsigned MemVTWidth = MemVT.getSizeInBits();
       if (MemVT.getSizeInBits() <= WidenEltWidth)
         break;
@@ -5190,9 +5186,7 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
 
   // See if there is a larger vector type to load/store that has the same vector
   // element type and is evenly divisible with the WidenVT.
-  for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE;
-       VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) {
-    EVT MemVT = (MVT::SimpleValueType) VT;
+  for (EVT MemVT : reverse(MVT::vector_valuetypes())) {
     // Skip vector MVTs which don't match the scalable property of WidenVT.
     if (Scalable != MemVT.isScalableVector())
       continue;

diff --git a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
index 40fed0789a819..1be119a508d54 100644
--- a/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
+++ b/llvm/tools/llvm-exegesis/lib/X86/Target.cpp
@@ -918,9 +918,9 @@ std::vector<InstructionTemplate> ExegesisX86Target::generateInstructionVariants(
       continue;
     case X86::OperandType::OPERAND_COND_CODE: {
       Exploration = true;
-      auto CondCodes = seq((int)X86::CondCode::COND_O,
-                           1 + (int)X86::CondCode::LAST_VALID_COND);
-      Choices.reserve(std::distance(CondCodes.begin(), CondCodes.end()));
+      auto CondCodes =
+          seq_inclusive(X86::CondCode::COND_O, X86::CondCode::LAST_VALID_COND);
+      Choices.reserve(CondCodes.size());
       for (int CondCode : CondCodes)
         Choices.emplace_back(MCOperand::createImm(CondCode));
       break;

diff --git a/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp b/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp
index 26b77bfd5aba5..223866ba52c25 100644
--- a/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp
+++ b/llvm/tools/llvm-reduce/deltas/ReduceAttributes.cpp
@@ -84,7 +84,8 @@ class AttributeRemapper : public InstVisitor<AttributeRemapper> {
                           AttrPtrVecVecTy &AttributeSetsToPreserve) {
     assert(AttributeSetsToPreserve.empty() && "Should not be sharing vectors.");
     AttributeSetsToPreserve.reserve(AL.getNumAttrSets());
-    for (unsigned SetIdx : seq(AL.index_begin(), AL.index_end())) {
+    for (unsigned SetIdx = AL.index_begin(), SetEndIdx = AL.index_end();
+         SetIdx != SetEndIdx; ++SetIdx) {
       AttrPtrIdxVecVecTy AttributesToPreserve;
       AttributesToPreserve.first = SetIdx;
       visitAttributeSet(AL.getAttributes(AttributesToPreserve.first),

diff --git a/llvm/unittests/ADT/SequenceTest.cpp b/llvm/unittests/ADT/SequenceTest.cpp
index f10e80ff4125c..dc3ca7e464226 100644
--- a/llvm/unittests/ADT/SequenceTest.cpp
+++ b/llvm/unittests/ADT/SequenceTest.cpp
@@ -7,12 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/Sequence.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 #include <list>
 
 using namespace llvm;
 
+using testing::ElementsAre;
+
 namespace {
 
 TEST(SequenceTest, Forward) {
@@ -48,4 +51,108 @@ TEST(SequenceTest, Dereference) {
   EXPECT_EQ(Backward[2], 7);
 }
 
+enum class CharEnum : char { A = 1, B, C, D, E };
+
+TEST(SequenceTest, ForwardIteration) {
+  EXPECT_THAT(seq_inclusive(CharEnum::C, CharEnum::E),
+              ElementsAre(CharEnum::C, CharEnum::D, CharEnum::E));
+}
+
+TEST(SequenceTest, BackwardIteration) {
+  EXPECT_THAT(reverse(seq_inclusive(CharEnum::B, CharEnum::D)),
+              ElementsAre(CharEnum::D, CharEnum::C, CharEnum::B));
+}
+
+using IntegralTypes =
+    testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, uintmax_t, //
+                   int8_t, int16_t, int32_t, int64_t, intmax_t>;
+
+template <class T> class SequenceTest : public testing::Test {
+public:
+  const T min = std::numeric_limits<T>::min();
+  const T minp1 = min + 1;
+  const T max = std::numeric_limits<T>::max();
+  const T maxm1 = max - 1;
+
+  void checkIteration() const {
+    // Forward
+    EXPECT_THAT(seq(min, min), ElementsAre());
+    EXPECT_THAT(seq(min, minp1), ElementsAre(min));
+    EXPECT_THAT(seq(maxm1, max), ElementsAre(maxm1));
+    EXPECT_THAT(seq(max, max), ElementsAre());
+    // Reverse
+    if (!std::is_same<T, intmax_t>::value) {
+      EXPECT_THAT(reverse(seq(min, min)), ElementsAre());
+      EXPECT_THAT(reverse(seq(min, minp1)), ElementsAre(min));
+    }
+    EXPECT_THAT(reverse(seq(maxm1, max)), ElementsAre(maxm1));
+    EXPECT_THAT(reverse(seq(max, max)), ElementsAre());
+    // Inclusive
+    EXPECT_THAT(seq_inclusive(min, min), ElementsAre(min));
+    EXPECT_THAT(seq_inclusive(min, minp1), ElementsAre(min, minp1));
+    EXPECT_THAT(seq_inclusive(maxm1, maxm1), ElementsAre(maxm1));
+    // Inclusive Reverse
+    if (!std::is_same<T, intmax_t>::value) {
+      EXPECT_THAT(reverse(seq_inclusive(min, min)), ElementsAre(min));
+      EXPECT_THAT(reverse(seq_inclusive(min, minp1)), ElementsAre(minp1, min));
+    }
+    EXPECT_THAT(reverse(seq_inclusive(maxm1, maxm1)), ElementsAre(maxm1));
+  }
+
+  void checkIterators() const {
+    auto checkValidIterators = [](auto sequence) {
+      EXPECT_LE(sequence.begin(), sequence.end());
+    };
+    checkValidIterators(seq(min, min));
+    checkValidIterators(seq(max, max));
+    checkValidIterators(seq_inclusive(min, min));
+    checkValidIterators(seq_inclusive(maxm1, maxm1));
+  }
+};
+TYPED_TEST_SUITE(SequenceTest, IntegralTypes);
+TYPED_TEST(SequenceTest, Boundaries) {
+  this->checkIteration();
+  this->checkIterators();
+}
+
+#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG)
+template <class T> class SequenceDeathTest : public SequenceTest<T> {
+public:
+  using SequenceTest<T>::min;
+  using SequenceTest<T>::minp1;
+  using SequenceTest<T>::max;
+  using SequenceTest<T>::maxm1;
+
+  void checkInvalidOrder() const {
+    EXPECT_DEATH(seq(max, min), "Begin must be less or equal to End.");
+    EXPECT_DEATH(seq(minp1, min), "Begin must be less or equal to End.");
+    EXPECT_DEATH(seq_inclusive(maxm1, min),
+                 "Begin must be less or equal to End.");
+    EXPECT_DEATH(seq_inclusive(minp1, min),
+                 "Begin must be less or equal to End.");
+  }
+  void checkInvalidValues() const {
+    if (std::is_same<T, intmax_t>::value || std::is_same<T, uintmax_t>::value) {
+      EXPECT_DEATH(seq_inclusive(min, max),
+                   "Forbidden End value for seq_inclusive.");
+      EXPECT_DEATH(seq_inclusive(minp1, max),
+                   "Forbidden End value for seq_inclusive.");
+    }
+    if (std::is_same<T, intmax_t>::value) {
+      EXPECT_DEATH(reverse(seq(min, min)),
+                   "Forbidden Begin value for reverse iteration");
+      EXPECT_DEATH(reverse(seq_inclusive(min, min)),
+                   "Forbidden Begin value for reverse iteration");
+      // Note it is fine to use `Begin == 0` when `iota_range::numeric_type ==
+      // uintmax_t` as unsigned integer underflow is well-defined.
+    }
+  }
+};
+TYPED_TEST_SUITE(SequenceDeathTest, IntegralTypes);
+TYPED_TEST(SequenceDeathTest, DeathTests) {
+  this->checkInvalidOrder();
+  this->checkInvalidValues();
+}
+#endif // defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG)
+
 } // anonymous namespace

diff --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
index e38d1045586cb..1f6d23859186e 100644
--- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
+++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
@@ -18,7 +18,7 @@ using namespace llvm;
 namespace {
 
 TEST(ScalableVectorMVTsTest, IntegerMVTs) {
-  for (auto VecTy : MVT::integer_scalable_vector_valuetypes()) {
+  for (MVT VecTy : MVT::integer_scalable_vector_valuetypes()) {
     ASSERT_TRUE(VecTy.isValid());
     ASSERT_TRUE(VecTy.isInteger());
     ASSERT_TRUE(VecTy.isVector());
@@ -30,7 +30,7 @@ TEST(ScalableVectorMVTsTest, IntegerMVTs) {
 }
 
 TEST(ScalableVectorMVTsTest, FloatMVTs) {
-  for (auto VecTy : MVT::fp_scalable_vector_valuetypes()) {
+  for (MVT VecTy : MVT::fp_scalable_vector_valuetypes()) {
     ASSERT_TRUE(VecTy.isValid());
     ASSERT_TRUE(VecTy.isFloatingPoint());
     ASSERT_TRUE(VecTy.isVector());

diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index f8816e4d43df4..8eca261e65e29 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1551,9 +1551,9 @@ void ICmpTestImpl(CmpInst::Predicate Pred) {
 }
 
 TEST(ConstantRange, ICmp) {
-  for (auto Pred : seq<unsigned>(CmpInst::Predicate::FIRST_ICMP_PREDICATE,
-                                 1 + CmpInst::Predicate::LAST_ICMP_PREDICATE))
-    ICmpTestImpl((CmpInst::Predicate)Pred);
+  for (auto Pred : seq_inclusive(CmpInst::Predicate::FIRST_ICMP_PREDICATE,
+                                 CmpInst::Predicate::LAST_ICMP_PREDICATE))
+    ICmpTestImpl(Pred);
 }
 
 TEST(ConstantRange, MakeGuaranteedNoWrapRegion) {

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4ef942d547765..a86cb4f050c8c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1282,7 +1282,7 @@ getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
   unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
   AffineExpr expr;
   SmallVector<Value, 2> dynamicDims;
-  for (auto dim : llvm::seq(startPos, endPos + 1)) {
+  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
     dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
     AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
     expr = (expr ? expr * currExpr : currExpr);
@@ -1315,7 +1315,7 @@ getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
         map.value().getResults().front().cast<AffineDimExpr>().getPosition();
     unsigned endPos =
         map.value().getResults().back().cast<AffineDimExpr>().getPosition();
-    for (auto dim : llvm::seq(startPos, endPos + 1)) {
+    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
       expandedDimToCollapsedDim[dim] = map.index();
     }
   }


        


More information about the Mlir-commits mailing list