[Mlir-commits] [mlir] 0cb5d7f - [mlir] Add value_begin/value_end methods to DenseElementsAttr

River Riddle llvmlistbot at llvm.org
Mon Sep 20 18:58:12 PDT 2021


Author: River Riddle
Date: 2021-09-21T01:57:43Z
New Revision: 0cb5d7fc7fd3eeb40b6ecf9b34a497d46bcba6c6

URL: https://github.com/llvm/llvm-project/commit/0cb5d7fc7fd3eeb40b6ecf9b34a497d46bcba6c6
DIFF: https://github.com/llvm/llvm-project/commit/0cb5d7fc7fd3eeb40b6ecf9b34a497d46bcba6c6.diff

LOG: [mlir] Add value_begin/value_end methods to DenseElementsAttr

Currently DenseElementsAttr only exposes the ability to get the full range of values for a given type T, but there are many situations where we just want the beginning/end iterator. This revision adds proper value_begin/value_end methods for all of the supported T types, and also cleans up a bit of the interface.

Differential Revision: https://reviews.llvm.org/D104173

Added: 
    

Modified: 
    mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
    mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
    mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/TableGen/StructsGenTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 8acbc37c77f7c..a647a51f0c9ab 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
     // functor recursively walks the dimensions of the constant shape,
     // generating a store when the recursion hits the base case.
     SmallVector<Value, 2> indices;
-    auto valueIt = constantValue.getValues<FloatAttr>().begin();
+    auto valueIt = constantValue.value_begin<FloatAttr>();
     std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
       // The last dimension is the base case of the recursion, at this point
       // we store the element at the given index.

diff  --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index c1ad4dc66e996..1684dc6cf32bc 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -164,7 +164,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
     // functor recursively walks the dimensions of the constant shape,
     // generating a store when the recursion hits the base case.
     SmallVector<Value, 2> indices;
-    auto valueIt = constantValue.getValues<FloatAttr>().begin();
+    auto valueIt = constantValue.value_begin<FloatAttr>();
     std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
       // The last dimension is the base case of the recursion, at this point
       // we store the element at the given index.

diff  --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 8acbc37c77f7c..a647a51f0c9ab 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
     // functor recursively walks the dimensions of the constant shape,
     // generating a store when the recursion hits the base case.
     SmallVector<Value, 2> indices;
-    auto valueIt = constantValue.getValues<FloatAttr>().begin();
+    auto valueIt = constantValue.value_begin<FloatAttr>();
     std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
       // The last dimension is the base case of the recursion, at this point
       // we store the element at the given index.

diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index fb5991fa72af7..a52b89027c5b6 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -58,8 +58,8 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
     auto lhs = operands[0].cast<ElementsAttr>();
     auto rhs = operands[1].cast<ElementsAttr>();
 
-    auto lhsIt = lhs.getValues<ElementValueT>().begin();
-    auto rhsIt = rhs.getValues<ElementValueT>().begin();
+    auto lhsIt = lhs.value_begin<ElementValueT>();
+    auto rhsIt = rhs.value_begin<ElementValueT>();
     SmallVector<ElementValueT, 4> elementResults;
     elementResults.reserve(lhs.getNumElements());
     for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index d332718bb9b17..6edd56b09b55e 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -51,6 +51,9 @@ class ElementsAttr : public Attribute {
   /// with static shape.
   ShapedType getType() const;
 
+  /// Return the element type of this ElementsAttr.
+  Type getElementType() const;
+
   /// Return the value at the given index. The index is expected to refer to a
   /// valid element.
   Attribute getValue(ArrayRef<uint64_t> index) const;
@@ -65,8 +68,9 @@ class ElementsAttr : public Attribute {
   /// Return the elements of this attribute as a value of type 'T'. Note:
   /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
   /// iteration.
-  template <typename T>
-  iterator_range<T> getValues() const;
+  template <typename T> iterator_range<T> getValues() const;
+  template <typename T> iterator<T> value_begin() const;
+  template <typename T> iterator<T> value_end() const;
 
   /// Return if the given 'index' refers to a valid element in this attribute.
   bool isValidIndex(ArrayRef<uint64_t> index) const;
@@ -417,7 +421,7 @@ class DenseElementsAttr : public ElementsAttr {
                           T>::type
   getSplatValue() const {
     assert(isSplat() && "expected the attribute to be a splat");
-    return *getValues<T>().begin();
+    return *value_begin<T>();
   }
   /// Return the splat value for derived attribute element types.
   template <typename T>
@@ -436,15 +440,21 @@ class DenseElementsAttr : public ElementsAttr {
   template <typename T>
   T getValue(ArrayRef<uint64_t> index) const {
     // Skip to the element corresponding to the flattened index.
-    return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
+    return getFlatValue<T>(getFlattenedIndex(index));
+  }
+  /// Return the value at the given flattened index.
+  template <typename T> T getFlatValue(uint64_t index) const {
+    return *std::next(value_begin<T>(), index);
   }
 
   /// Return the held element values as a range of integer or floating-point
   /// values.
-  template <typename T, typename = typename std::enable_if<
-                            (!std::is_same<T, bool>::value &&
-                             std::numeric_limits<T>::is_integer) ||
-                            is_valid_cpp_fp_type<T>::value>::type>
+  template <typename T>
+  using IntFloatValueTemplateCheckT =
+      typename std::enable_if<(!std::is_same<T, bool>::value &&
+                               std::numeric_limits<T>::is_integer) ||
+                              is_valid_cpp_fp_type<T>::value>::type;
+  template <typename T, typename = IntFloatValueTemplateCheckT<T>>
   llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
                              std::numeric_limits<T>::is_signed));
@@ -453,13 +463,27 @@ class DenseElementsAttr : public ElementsAttr {
     return {ElementIterator<T>(rawData, splat, 0),
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
+  template <typename T, typename = IntFloatValueTemplateCheckT<T>>
+  ElementIterator<T> value_begin() const {
+    assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+                             std::numeric_limits<T>::is_signed));
+    return ElementIterator<T>(getRawData().data(), isSplat(), 0);
+  }
+  template <typename T, typename = IntFloatValueTemplateCheckT<T>>
+  ElementIterator<T> value_end() const {
+    assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+                             std::numeric_limits<T>::is_signed));
+    return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
+  }
 
   /// Return the held element values as a range of std::complex.
+  template <typename T, typename ElementT>
+  using ComplexValueTemplateCheckT =
+      typename std::enable_if<detail::is_complex_t<T>::value &&
+                              (std::numeric_limits<ElementT>::is_integer ||
+                               is_valid_cpp_fp_type<ElementT>::value)>::type;
   template <typename T, typename ElementT = typename T::value_type,
-            typename = typename std::enable_if<
-                detail::is_complex_t<T>::value &&
-                (std::numeric_limits<ElementT>::is_integer ||
-                 is_valid_cpp_fp_type<ElementT>::value)>::type>
+            typename = ComplexValueTemplateCheckT<T, ElementT>>
   llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
                           std::numeric_limits<ElementT>::is_signed));
@@ -468,10 +492,26 @@ class DenseElementsAttr : public ElementsAttr {
     return {ElementIterator<T>(rawData, splat, 0),
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
+  template <typename T, typename ElementT = typename T::value_type,
+            typename = ComplexValueTemplateCheckT<T, ElementT>>
+  ElementIterator<T> value_begin() const {
+    assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+                          std::numeric_limits<ElementT>::is_signed));
+    return ElementIterator<T>(getRawData().data(), isSplat(), 0);
+  }
+  template <typename T, typename ElementT = typename T::value_type,
+            typename = ComplexValueTemplateCheckT<T, ElementT>>
+  ElementIterator<T> value_end() const {
+    assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+                          std::numeric_limits<ElementT>::is_signed));
+    return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
+  }
 
   /// Return the held element values as a range of StringRef.
-  template <typename T, typename = typename std::enable_if<
-                            std::is_same<T, StringRef>::value>::type>
+  template <typename T>
+  using StringRefValueTemplateCheckT =
+      typename std::enable_if<std::is_same<T, StringRef>::value>::type;
+  template <typename T, typename = StringRefValueTemplateCheckT<T>>
   llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
     auto stringRefs = getRawStringData();
     const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
@@ -479,80 +519,156 @@ class DenseElementsAttr : public ElementsAttr {
     return {ElementIterator<StringRef>(ptr, splat, 0),
             ElementIterator<StringRef>(ptr, splat, getNumElements())};
   }
+  template <typename T, typename = StringRefValueTemplateCheckT<T>>
+  ElementIterator<StringRef> value_begin() const {
+    const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
+    return ElementIterator<StringRef>(ptr, isSplat(), 0);
+  }
+  template <typename T, typename = StringRefValueTemplateCheckT<T>>
+  ElementIterator<StringRef> value_end() const {
+    const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
+    return ElementIterator<StringRef>(ptr, isSplat(), getNumElements());
+  }
 
   /// Return the held element values as a range of Attributes.
-  llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
-  template <typename T, typename = typename std::enable_if<
-                            std::is_same<T, Attribute>::value>::type>
+  template <typename T>
+  using AttributeValueTemplateCheckT =
+      typename std::enable_if<std::is_same<T, Attribute>::value>::type;
+  template <typename T, typename = AttributeValueTemplateCheckT<T>>
   llvm::iterator_range<AttributeElementIterator> getValues() const {
-    return getAttributeValues();
+    return {value_begin<Attribute>(), value_end<Attribute>()};
+  }
+  template <typename T, typename = AttributeValueTemplateCheckT<T>>
+  AttributeElementIterator value_begin() const {
+    return AttributeElementIterator(*this, 0);
+  }
+  template <typename T, typename = AttributeValueTemplateCheckT<T>>
+  AttributeElementIterator value_end() const {
+    return AttributeElementIterator(*this, getNumElements());
   }
-  AttributeElementIterator attr_value_begin() const;
-  AttributeElementIterator attr_value_end() const;
 
   /// Return the held element values a range of T, where T is a derived
   /// attribute type.
   template <typename T>
+  using DerivedAttrValueTemplateCheckT =
+      typename std::enable_if<std::is_base_of<Attribute, T>::value &&
+                              !std::is_same<Attribute, T>::value>::type;
+  template <typename T>
   using DerivedAttributeElementIterator =
       llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
-  template <typename T, typename = typename std::enable_if<
-                            std::is_base_of<Attribute, T>::value &&
-                            !std::is_same<Attribute, T>::value>::type>
+  template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
   llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
     auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-    return llvm::map_range(getAttributeValues(),
+    return llvm::map_range(getValues<Attribute>(),
                            static_cast<T (*)(Attribute)>(castFn));
   }
+  template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
+  DerivedAttributeElementIterator<T> value_begin() const {
+    auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
+    return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+  }
+  template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
+  DerivedAttributeElementIterator<T> value_end() const {
+    auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
+    return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+  }
 
   /// Return the held element values as a range of bool. The element type of
   /// this attribute must be of integer type of bitwidth 1.
-  llvm::iterator_range<BoolElementIterator> getBoolValues() const;
-  template <typename T, typename = typename std::enable_if<
-                            std::is_same<T, bool>::value>::type>
+  template <typename T>
+  using BoolValueTemplateCheckT =
+      typename std::enable_if<std::is_same<T, bool>::value>::type;
+  template <typename T, typename = BoolValueTemplateCheckT<T>>
   llvm::iterator_range<BoolElementIterator> getValues() const {
-    return getBoolValues();
+    assert(isValidBool() && "bool is not the value of this elements attribute");
+    return {BoolElementIterator(*this, 0),
+            BoolElementIterator(*this, getNumElements())};
+  }
+  template <typename T, typename = BoolValueTemplateCheckT<T>>
+  BoolElementIterator value_begin() const {
+    assert(isValidBool() && "bool is not the value of this elements attribute");
+    return BoolElementIterator(*this, 0);
+  }
+  template <typename T, typename = BoolValueTemplateCheckT<T>>
+  BoolElementIterator value_end() const {
+    assert(isValidBool() && "bool is not the value of this elements attribute");
+    return BoolElementIterator(*this, getNumElements());
   }
 
   /// Return the held element values as a range of APInts. The element type of
   /// this attribute must be of integer type.
-  llvm::iterator_range<IntElementIterator> getIntValues() const;
-  template <typename T, typename = typename std::enable_if<
-                            std::is_same<T, APInt>::value>::type>
+  template <typename T>
+  using APIntValueTemplateCheckT =
+      typename std::enable_if<std::is_same<T, APInt>::value>::type;
+  template <typename T, typename = APIntValueTemplateCheckT<T>>
   llvm::iterator_range<IntElementIterator> getValues() const {
-    return getIntValues();
+    assert(getElementType().isIntOrIndex() && "expected integral type");
+    return {raw_int_begin(), raw_int_end()};
+  }
+  template <typename T, typename = APIntValueTemplateCheckT<T>>
+  IntElementIterator value_begin() const {
+    assert(getElementType().isIntOrIndex() && "expected integral type");
+    return raw_int_begin();
+  }
+  template <typename T, typename = APIntValueTemplateCheckT<T>>
+  IntElementIterator value_end() const {
+    assert(getElementType().isIntOrIndex() && "expected integral type");
+    return raw_int_end();
   }
-  IntElementIterator int_value_begin() const;
-  IntElementIterator int_value_end() const;
 
   /// Return the held element values as a range of complex APInts. The element
   /// type of this attribute must be a complex of integer type.
-  llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
-  template <typename T, typename = typename std::enable_if<
-                            std::is_same<T, std::complex<APInt>>::value>::type>
+  template <typename T>
+  using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
+      std::is_same<T, std::complex<APInt>>::value>::type;
+  template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
   llvm::iterator_range<ComplexIntElementIterator> getValues() const {
     return getComplexIntValues();
   }
+  template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
+  ComplexIntElementIterator value_begin() const {
+    return complex_value_begin();
+  }
+  template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
+  ComplexIntElementIterator value_end() const {
+    return complex_value_end();
+  }
 
   /// Return the held element values as a range of APFloat. The element type of
   /// this attribute must be of float type.
-  llvm::iterator_range<FloatElementIterator> getFloatValues() const;
-  template <typename T, typename = typename std::enable_if<
-                            std::is_same<T, APFloat>::value>::type>
+  template <typename T>
+  using APFloatValueTemplateCheckT =
+      typename std::enable_if<std::is_same<T, APFloat>::value>::type;
+  template <typename T, typename = APFloatValueTemplateCheckT<T>>
   llvm::iterator_range<FloatElementIterator> getValues() const {
     return getFloatValues();
   }
-  FloatElementIterator float_value_begin() const;
-  FloatElementIterator float_value_end() const;
+  template <typename T, typename = APFloatValueTemplateCheckT<T>>
+  FloatElementIterator value_begin() const {
+    return float_value_begin();
+  }
+  template <typename T, typename = APFloatValueTemplateCheckT<T>>
+  FloatElementIterator value_end() const {
+    return float_value_end();
+  }
 
   /// Return the held element values as a range of complex APFloat. The element
   /// type of this attribute must be a complex of float type.
-  llvm::iterator_range<ComplexFloatElementIterator>
-  getComplexFloatValues() const;
-  template <typename T, typename = typename std::enable_if<std::is_same<
-                            T, std::complex<APFloat>>::value>::type>
+  template <typename T>
+  using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
+      std::is_same<T, std::complex<APFloat>>::value>::type;
+  template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
   llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
     return getComplexFloatValues();
   }
+  template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
+  ComplexFloatElementIterator value_begin() const {
+    return complex_float_value_begin();
+  }
+  template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
+  ComplexFloatElementIterator value_end() const {
+    return complex_float_value_end();
+  }
 
   /// Return the raw storage data held by this attribute. Users should generally
   /// not use this directly, as the internal storage format is not always in the
@@ -590,13 +706,25 @@ class DenseElementsAttr : public ElementsAttr {
             function_ref<APInt(const APFloat &)> mapping) const;
 
 protected:
-  /// Get iterators to the raw APInt values for each element in this attribute.
+  /// Iterators to various elements that require out-of-line definition. These
+  /// are hidden from the user to encourage consistent use of the
+  /// getValues/value_begin/value_end API.
   IntElementIterator raw_int_begin() const {
     return IntElementIterator(*this, 0);
   }
   IntElementIterator raw_int_end() const {
     return IntElementIterator(*this, getNumElements());
   }
+  llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
+  ComplexIntElementIterator complex_value_begin() const;
+  ComplexIntElementIterator complex_value_end() const;
+  llvm::iterator_range<FloatElementIterator> getFloatValues() const;
+  FloatElementIterator float_value_begin() const;
+  FloatElementIterator float_value_end() const;
+  llvm::iterator_range<ComplexFloatElementIterator>
+  getComplexFloatValues() const;
+  ComplexFloatElementIterator complex_float_value_begin() const;
+  ComplexFloatElementIterator complex_float_value_end() const;
 
   /// Overload of the raw 'get' method that asserts that the given type is of
   /// complex type. This method is used to verify type invariants that the
@@ -616,11 +744,8 @@ class DenseElementsAttr : public ElementsAttr {
   /// Check the information for a C++ data type, check if this type is valid for
   /// the current attribute. This method is used to verify specific type
   /// invariants that the templatized 'getValues' method cannot.
+  bool isValidBool() const { return getElementType().isInteger(1); }
   bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
-
-  /// Check the information for a C++ data type, check if this type is valid for
-  /// the current attribute. This method is used to verify specific type
-  /// invariants that the templatized 'getValues' method cannot.
   bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
 };
 
@@ -806,7 +931,7 @@ template <typename T>
 auto SparseElementsAttr::getValues() const
     -> llvm::iterator_range<iterator<T>> {
   auto zeroValue = getZeroValue<T>();
-  auto valueIt = getValues().getValues<T>().begin();
+  auto valueIt = getValues().value_begin<T>();
   const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
   std::function<T(ptrdiff_t)> mapFn =
       [flatSparseIndices{std::move(flatSparseIndices)},
@@ -821,6 +946,14 @@ auto SparseElementsAttr::getValues() const
       };
   return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
 }
+template <typename T>
+auto SparseElementsAttr::value_begin() const -> iterator<T> {
+  return getValues<T>().begin();
+}
+template <typename T>
+auto SparseElementsAttr::value_end() const -> iterator<T> {
+  return getValues<T>().end();
+}
 
 namespace detail {
 /// This class represents a general iterator over the values of an ElementsAttr.
@@ -833,8 +966,7 @@ class ElementsAttrIterator
   // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
   // inside of a conversion operator.
   using DenseIteratorT = typename std::enable_if<
-      true,
-      decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
+      true, decltype(std::declval<DenseElementsAttr>().value_begin<T>())>::type;
   using SparseIteratorT = SparseElementsAttr::iterator<T>;
 
   /// A union containing the specific iterators for each derived attribute kind.
@@ -960,6 +1092,21 @@ auto ElementsAttr::getValues() const -> iterator_range<T> {
   llvm_unreachable("unexpected attribute kind");
 }
 
+template <typename T> auto ElementsAttr::value_begin() const -> iterator<T> {
+  if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
+    return iterator<T>(*this, denseAttr.value_begin<T>());
+  if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
+    return iterator<T>(*this, sparseAttr.value_begin<T>());
+  llvm_unreachable("unexpected attribute kind");
+}
+template <typename T> auto ElementsAttr::value_end() const -> iterator<T> {
+  if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
+    return iterator<T>(*this, denseAttr.value_end<T>());
+  if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
+    return iterator<T>(*this, sparseAttr.value_end<T>());
+  llvm_unreachable("unexpected attribute kind");
+}
+
 } // end namespace mlir.
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 25e54cbfd68c9..e39d56f146ef1 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -721,6 +721,8 @@ def Builtin_SparseElementsAttr
     /// 'T'  may be any of Attribute, APInt, APFloat, c++ integer/float types,
     /// etc.
     template <typename T> llvm::iterator_range<iterator<T>> getValues() const;
+    template <typename T> iterator<T> value_begin() const;
+    template <typename T> iterator<T> value_end() const;
 
     /// Return the value of the element at the given index. The 'index' is
     /// expected to refer to a valid element.

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 4ae54d4caad4e..a2ee06722f0d8 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -505,48 +505,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
 // Indexed accessors.
 
 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
-  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
-           pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<bool>(pos);
 }
 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
-  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>().begin() +
-           pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int8_t>(pos);
 }
 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
-  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>().begin() +
-           pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint8_t>(pos);
 }
 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
-  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
-           pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int32_t>(pos);
 }
 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
-  return *(
-      unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
-      pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint32_t>(pos);
 }
 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
-  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
-           pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int64_t>(pos);
 }
 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
-  return *(
-      unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
-      pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint64_t>(pos);
 }
 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
-  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
-           pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<float>(pos);
 }
 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
-  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
-           pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<double>(pos);
 }
 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
                                                   intptr_t pos) {
   return wrap(
-      *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
-        pos));
+      unwrap(attr).cast<DenseElementsAttr>().getFlatValue<StringRef>(pos));
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index c20f4ed6b567d..bd1e4ade4ec7a 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -127,7 +127,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
 
   if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
     return failure();
-  if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
+  if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
                    [](const APInt &size) { return !size.isOneValue(); }))
     return failure();
 

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index c59f291adc1e9..0da2209e2df08 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -558,9 +558,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
   if (srcElemType != dstElemType) {
     SmallVector<Attribute, 8> elements;
     if (srcElemType.isa<FloatType>()) {
-      for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
-        FloatAttr dstAttr = convertFloatAttr(
-            srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
+      for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
+        FloatAttr dstAttr =
+            convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
         if (!dstAttr)
           return failure();
         elements.push_back(dstAttr);
@@ -568,10 +568,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
     } else if (srcElemType.isInteger(1)) {
       return failure();
     } else {
-      for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
-        IntegerAttr dstAttr =
-            convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
-                               dstElemType.cast<IntegerType>(), rewriter);
+      for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
+        IntegerAttr dstAttr = convertIntegerAttr(
+            srcAttr, dstElemType.cast<IntegerType>(), rewriter);
         if (!dstAttr)
           return failure();
         elements.push_back(dstAttr);

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 558ba4d04acbb..4184b9cadf4dc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1610,7 +1610,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 
     SmallVector<AffineExpr, 2> inputExprs;
     inputExprs.resize(resultTy.getRank());
-    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
+    for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
       inputExprs[permutation.value().getZExtValue()] =
           rewriter.getAffineDimExpr(permutation.index());
     }

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b3fe10d05ce04..26ab9551854a5 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -337,11 +337,12 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
   auto attrName =
       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
   auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
+
+  // Async dependencies is the only variadic operand.
   if (!sizeAttr)
-    return; // Async dependencies is the only variadic operand.
-  SmallVector<int32_t, 8> sizes;
-  for (auto size : sizeAttr.getIntValues())
-    sizes.push_back(size.getSExtValue());
+    return;
+
+  SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>());
   ++sizes.front();
   op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
 }

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index b6d072f9f60a1..f3b69ae85182e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1825,8 +1825,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
     // and hence was replaced.
     if (complexElementType.isa<IntegerType>()) {
       bool isSigned = !complexElementType.isUnsignedInteger();
+      auto valueIt = attr.value_begin<std::complex<APInt>>();
       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-        auto complexValue = *(attr.getComplexIntValues().begin() + index);
+        auto complexValue = *(valueIt + index);
         os << "(";
         printDenseIntElement(complexValue.real(), os, isSigned);
         os << ",";
@@ -1834,8 +1835,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
         os << ")";
       });
     } else {
+      auto valueIt = attr.value_begin<std::complex<APFloat>>();
       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-        auto complexValue = *(attr.getComplexFloatValues().begin() + index);
+        auto complexValue = *(valueIt + index);
         os << "(";
         printFloatValue(complexValue.real(), os);
         os << ",";
@@ -1845,15 +1847,15 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
     }
   } else if (elementType.isIntOrIndex()) {
     bool isSigned = !elementType.isUnsignedInteger();
-    auto intValues = attr.getIntValues();
+    auto valueIt = attr.value_begin<APInt>();
     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-      printDenseIntElement(*(intValues.begin() + index), os, isSigned);
+      printDenseIntElement(*(valueIt + index), os, isSigned);
     });
   } else {
     assert(elementType.isa<FloatType>() && "unexpected element type");
-    auto floatValues = attr.getFloatValues();
+    auto valueIt = attr.value_begin<APFloat>();
     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-      printFloatValue(*(floatValues.begin() + index), os);
+      printFloatValue(*(valueIt + index), os);
     });
   }
 }

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index d906adc3e8151..a3c7fb0af9293 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -390,6 +390,8 @@ ShapedType ElementsAttr::getType() const {
   return Attribute::getType().cast<ShapedType>();
 }
 
+Type ElementsAttr::getElementType() const { return getType().getElementType(); }
+
 /// Returns the number of elements held by this attribute.
 int64_t ElementsAttr::getNumElements() const {
   return getType().getNumElements();
@@ -635,7 +637,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
 
 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
-  Type eltTy = owner.getType().getElementType();
+  Type eltTy = owner.getElementType();
   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
   if (eltTy.isa<IndexType>())
@@ -690,7 +692,7 @@ DenseElementsAttr::IntElementIterator::IntElementIterator(
     DenseElementsAttr attr, size_t dataIndex)
     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
           attr.getRawData().data(), attr.isSplat(), dataIndex),
-      bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
+      bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
 
 APInt DenseElementsAttr::IntElementIterator::operator*() const {
   return readBits(getData(),
@@ -707,7 +709,7 @@ DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
                                       std::complex<APInt>, std::complex<APInt>,
                                       std::complex<APInt>>(
           attr.getRawData().data(), attr.isSplat(), dataIndex) {
-  auto complexType = attr.getType().getElementType().cast<ComplexType>();
+  auto complexType = attr.getElementType().cast<ComplexType>();
   bitWidth = getDenseElementBitWidth(complexType.getElementType());
 }
 
@@ -930,21 +932,15 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
                                                     isInt, isSigned);
 }
 
-/// A method used to verify specific type invariants that the templatized 'get'
-/// method cannot.
 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
                                           bool isSigned) const {
-  return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
-                             isSigned);
+  return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
 }
-
-/// Check the information for a C++ data type, check if this type is valid for
-/// the current attribute.
 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
                                        bool isSigned) const {
   return ::isValidIntOrFloat(
-      getType().getElementType().cast<ComplexType>().getElementType(),
-      dataEltSize / 2, isInt, isSigned);
+      getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
+      isInt, isSigned);
 }
 
 /// Returns true if this attribute corresponds to a splat, i.e. if all element
@@ -953,76 +949,69 @@ bool DenseElementsAttr::isSplat() const {
   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
 }
 
-/// Return the held element values as a range of Attributes.
-auto DenseElementsAttr::getAttributeValues() const
-    -> llvm::iterator_range<AttributeElementIterator> {
-  return {attr_value_begin(), attr_value_end()};
-}
-auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
-  return AttributeElementIterator(*this, 0);
-}
-auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
-  return AttributeElementIterator(*this, getNumElements());
+/// Return if the given complex type has an integer element type.
+static bool isComplexOfIntType(Type type) {
+  return type.cast<ComplexType>().getElementType().isa<IntegerType>();
 }
 
-/// Return the held element values as a range of bool. The element type of
-/// this attribute must be of integer type of bitwidth 1.
-auto DenseElementsAttr::getBoolValues() const
-    -> llvm::iterator_range<BoolElementIterator> {
-  auto eltType = getType().getElementType().dyn_cast<IntegerType>();
-  assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
-  (void)eltType;
-  return {BoolElementIterator(*this, 0),
-          BoolElementIterator(*this, getNumElements())};
-}
-
-/// Return the held element values as a range of APInts. The element type of
-/// this attribute must be of integer type.
-auto DenseElementsAttr::getIntValues() const
-    -> llvm::iterator_range<IntElementIterator> {
-  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
-  return {raw_int_begin(), raw_int_end()};
-}
-auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
-  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
-  return raw_int_begin();
-}
-auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
-  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
-  return raw_int_end();
-}
 auto DenseElementsAttr::getComplexIntValues() const
     -> llvm::iterator_range<ComplexIntElementIterator> {
-  Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
-  (void)eltTy;
-  assert(eltTy.isa<IntegerType>() && "expected complex integral type");
+  assert(isComplexOfIntType(getElementType()) &&
+         "expected complex integral type");
   return {ComplexIntElementIterator(*this, 0),
           ComplexIntElementIterator(*this, getNumElements())};
 }
+auto DenseElementsAttr::complex_value_begin() const
+    -> ComplexIntElementIterator {
+  assert(isComplexOfIntType(getElementType()) &&
+         "expected complex integral type");
+  return ComplexIntElementIterator(*this, 0);
+}
+auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
+  assert(isComplexOfIntType(getElementType()) &&
+         "expected complex integral type");
+  return ComplexIntElementIterator(*this, getNumElements());
+}
 
 /// Return the held element values as a range of APFloat. The element type of
 /// this attribute must be of float type.
 auto DenseElementsAttr::getFloatValues() const
     -> llvm::iterator_range<FloatElementIterator> {
-  auto elementType = getType().getElementType().cast<FloatType>();
+  auto elementType = getElementType().cast<FloatType>();
   const auto &elementSemantics = elementType.getFloatSemantics();
   return {FloatElementIterator(elementSemantics, raw_int_begin()),
           FloatElementIterator(elementSemantics, raw_int_end())};
 }
 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
-  return getFloatValues().begin();
+  auto elementType = getElementType().cast<FloatType>();
+  return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
 }
 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
-  return getFloatValues().end();
+  auto elementType = getElementType().cast<FloatType>();
+  return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
 }
+
 auto DenseElementsAttr::getComplexFloatValues() const
     -> llvm::iterator_range<ComplexFloatElementIterator> {
-  Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
+  Type eltTy = getElementType().cast<ComplexType>().getElementType();
   assert(eltTy.isa<FloatType>() && "expected complex float type");
   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
   return {{semantics, {*this, 0}},
           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
 }
+auto DenseElementsAttr::complex_float_value_begin() const
+    -> ComplexFloatElementIterator {
+  Type eltTy = getElementType().cast<ComplexType>().getElementType();
+  assert(eltTy.isa<FloatType>() && "expected complex float type");
+  return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
+}
+auto DenseElementsAttr::complex_float_value_end() const
+    -> ComplexFloatElementIterator {
+  Type eltTy = getElementType().cast<ComplexType>().getElementType();
+  assert(eltTy.isa<FloatType>() && "expected complex float type");
+  return {eltTy.cast<FloatType>().getFloatSemantics(),
+          {*this, static_cast<size_t>(getNumElements())}};
+}
 
 /// Return the raw storage data held by this attribute.
 ArrayRef<char> DenseElementsAttr::getRawData() const {
@@ -1374,19 +1363,19 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
 
 /// Get a zero APFloat for the given sparse attribute.
 APFloat SparseElementsAttr::getZeroAPFloat() const {
-  auto eltType = getType().getElementType().cast<FloatType>();
+  auto eltType = getElementType().cast<FloatType>();
   return APFloat(eltType.getFloatSemantics());
 }
 
 /// Get a zero APInt for the given sparse attribute.
 APInt SparseElementsAttr::getZeroAPInt() const {
-  auto eltType = getType().getElementType().cast<IntegerType>();
+  auto eltType = getElementType().cast<IntegerType>();
   return APInt::getZero(eltType.getWidth());
 }
 
 /// Get a zero attribute for the given attribute type.
 Attribute SparseElementsAttr::getZeroAttr() const {
-  auto eltType = getType().getElementType();
+  auto eltType = getElementType();
 
   // Handle floating point elements.
   if (eltType.isa<FloatType>())

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 457c993b92ad3..6cba5742a3e0f 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1024,7 +1024,7 @@ LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
     return op->emitOpError("requires 1D i32 elements attribute '")
            << attrName << "'";
 
-  if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
+  if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) {
         return !element.isNonNegative();
       }))
     return op->emitOpError("'")

diff  --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 7a3f602ca3d43..9676553aeaa57 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -51,7 +51,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
     auto dattr = attr.cast<DenseIntElementsAttr>();
     res.clear();
     res.reserve(dattr.size());
-    for (auto it : dattr.getIntValues())
+    for (auto it : dattr.getValues<APInt>())
       res.push_back(it.getSExtValue());
   } else {
     auto vals = val.get<ShapedTypeComponents *>()->getDims();
@@ -71,7 +71,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
     return t.cast<ShapedType>().getDimSize(index);
   if (auto attr = val.dyn_cast<Attribute>())
     return attr.cast<DenseIntElementsAttr>()
-        .getValue<APInt>({static_cast<uint64_t>(index)})
+        .getFlatValue<APInt>(index)
         .getSExtValue();
   auto *stc = val.get<ShapedTypeComponents *>();
   return stc->getDims()[index];
@@ -94,7 +94,7 @@ bool ShapeAdaptor::hasStaticShape() const {
     return t.cast<ShapedType>().hasStaticShape();
   if (auto attr = val.dyn_cast<Attribute>()) {
     auto dattr = attr.cast<DenseIntElementsAttr>();
-    for (auto index : dattr.getIntValues())
+    for (auto index : dattr.getValues<APInt>())
       if (ShapedType::isDynamic(index.getSExtValue()))
         return false;
     return true;
@@ -115,7 +115,7 @@ int64_t ShapeAdaptor::getNumElements() const {
   if (auto attr = val.dyn_cast<Attribute>()) {
     auto dattr = attr.cast<DenseIntElementsAttr>();
     int64_t num = 1;
-    for (auto index : dattr.getIntValues()) {
+    for (auto index : dattr.getValues<APInt>()) {
       num *= index.getZExtValue();
       assert(num >= 0 && "integer overflow in element count computation");
     }

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 3945cfa6ee0f7..199a8c49aca8c 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -294,7 +294,8 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
       if (!nested)
         return nullptr;
 
-      values.append(nested.attr_value_begin(), nested.attr_value_end());
+      values.append(nested.value_begin<Attribute>(),
+                    nested.value_end<Attribute>());
     }
 
     return DenseElementsAttr::get(outerType, values);

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index e40fa9b8e6b7a..9562e9b1c1f92 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -83,12 +83,14 @@ const char *opSegmentSizeAttrInitCode = R"(
   auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>();
 )";
 const char *attrSizedSegmentValueRangeCalcCode = R"(
-  auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
+  const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
+  if (sizeAttr.isSplat())
+    return {*sizeAttrValueIt * index, *sizeAttrValueIt};
+
   unsigned start = 0;
   for (unsigned i = 0; i < index; ++i)
-    start += *(sizeAttrValues.begin() + i);
-  unsigned size = *(sizeAttrValues.begin() + index);
-  return {start, size};
+    start += sizeAttrValueIt[i];
+  return {start, sizeAttrValueIt[index]};
 )";
 // The logic to calculate the actual value range for a declared operand
 // of an op with variadic of variadic operands within the OpAdaptor.

diff  --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index ef0bdd81ee3a5..4fcdf1d673195 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -158,7 +158,7 @@ TEST(StructsGenTest, GetElements) {
   auto denseAttr = returnedAttr.dyn_cast<mlir::DenseElementsAttr>();
   ASSERT_TRUE(denseAttr);
 
-  for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) {
+  for (const auto &valIndexIt : llvm::enumerate(denseAttr.getValues<APInt>())) {
     EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1);
   }
 }


        


More information about the Mlir-commits mailing list