[Mlir-commits] [mlir] ae40d62 - [mlir] Refactor ElementsAttr's value access API

River Riddle llvmlistbot at llvm.org
Mon Nov 8 16:21:29 PST 2021


Author: River Riddle
Date: 2021-11-09T00:15:08Z
New Revision: ae40d625410036d65cfe09f2122b81450f62ea99

URL: https://github.com/llvm/llvm-project/commit/ae40d625410036d65cfe09f2122b81450f62ea99
DIFF: https://github.com/llvm/llvm-project/commit/ae40d625410036d65cfe09f2122b81450f62ea99.diff

LOG: [mlir] Refactor ElementsAttr's value access API

Several aspects of the API are either hard to use or make it deceptively easy
to do the wrong thing. The main change of this commit is to remove all of the
`getValue<T>`/`getFlatValue<T>` methods from ElementsAttr and instead provide
operator[] methods on the ranges returned by `getValues<T>`. This provides a
much more convenient API for the value ranges. It also removes the
easy-to-miss inefficiency of getValue/getFlatValue, which under the hood
constructed a new range for the type `T` on every call. Constructing a range is
not necessarily cheap in all cases, and doing so inside a loop could lead to
very poor performance; e.g. if you were to naively write something like:

```
DenseElementsAttr attr = ...;
for (int i = 0; i < size; ++i) {
  // We are internally rebuilding the APFloat value range on each iteration!!
  APFloat it = attr.getFlatValue<APFloat>(i);
}
```

Differential Revision: https://reviews.llvm.org/D113229

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/IR/BuiltinAttributeInterfaces.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
    mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index c48a359383ff4..2ed1c84ee537f 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -227,6 +227,33 @@ class ElementsAttrIterator
   ElementsAttrIndexer indexer;
   ptrdiff_t index;
 };
+
+/// This class provides iterator utilities for an ElementsAttr range.
+template <typename IteratorT>
+class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
+public:
+  using reference = typename IteratorT::reference;
+
+  ElementsAttrRange(Type shapeType,
+                    const llvm::iterator_range<IteratorT> &range)
+      : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
+  ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt)
+      : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
+
+  /// Return the value at the given index.
+  reference operator[](ArrayRef<uint64_t> index) const;
+  reference operator[](uint64_t index) const {
+    return *std::next(this->begin(), index);
+  }
+
+  /// Return the size of this range.
+  size_t size() const { return llvm::size(*this); }
+
+private:
+  /// The shaped type of the parent ElementsAttr.
+  Type shapeType;
+};
+
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -256,6 +283,16 @@ verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
+namespace detail {
+/// Return the value at the given index.
+template <typename IteratorT>
+auto ElementsAttrRange<IteratorT>::operator[](ArrayRef<uint64_t> index) const
+    -> reference {
+  // Skip to the element corresponding to the flattened index.
+  return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
+}
+} // namespace detail
+
 /// Return the elements of this attribute as a value of type 'T'.
 template <typename T>
 auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 30b3ea7ca09a8..45295e874f3bd 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -158,27 +158,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
   ];
 
   string ElementsAttrInterfaceAccessors = [{
-    /// Return the attribute value at the given index. The index is expected to
-    /// refer to a valid element.
-    Attribute getValue(ArrayRef<uint64_t> index) const {
-      return getValue<Attribute>(index);
-    }
-
-    /// Return the value of type 'T' at the given index, where 'T' corresponds
-    /// to an Attribute type.
-    template <typename T>
-    std::enable_if_t<!std::is_same<T, ::mlir::Attribute>::value &&
-                     std::is_base_of<T, ::mlir::Attribute>::value>
-    getValue(ArrayRef<uint64_t> index) const {
-      return getValue(index).template dyn_cast_or_null<T>();
-    }
-
-    /// Return the value of type 'T' at the given index.
-    template <typename T>
-    T getValue(ArrayRef<uint64_t> index) const {
-      return getFlatValue<T>(getFlattenedIndex(index));
-    }
-
     /// Return the number of elements held by this attribute.
     int64_t size() const { return getNumElements(); }
 
@@ -281,6 +260,14 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     // Value Iteration
     //===------------------------------------------------------------------===//
 
+    /// The iterator for the given element type T.
+    template <typename T, typename AttrT = ConcreteAttr>
+    using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
+    /// The iterator range over the given element T.
+    template <typename T, typename AttrT = ConcreteAttr>
+    using iterator_range =
+        decltype(std::declval<AttrT>().template getValues<T>());
+
     /// Return an iterator to the first element of this attribute as a value of
     /// type `T`.
     template <typename T>
@@ -292,11 +279,8 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     template <typename T>
     auto getValues() const {
       auto beginIt = $_attr.template value_begin<T>();
-      return llvm::make_range(beginIt, std::next(beginIt, size()));
-    }
-    /// Return the value at the given flattened index.
-    template <typename T> T getFlatValue(uint64_t index) const {
-      return *std::next($_attr.template value_begin<T>(), index);
+      return detail::ElementsAttrRange<decltype(beginIt)>(
+        Attribute($_attr).getType(), beginIt, std::next(beginIt, size()));
     }
   }] # ElementsAttrInterfaceAccessors;
 
@@ -304,7 +288,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     template <typename T>
     using iterator = detail::ElementsAttrIterator<T>;
     template <typename T>
-    using iterator_range = llvm::iterator_range<iterator<T>>;
+    using iterator_range = detail::ElementsAttrRange<iterator<T>>;
 
     //===------------------------------------------------------------------===//
     // Accessors
@@ -329,8 +313,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
       return getFlattenedIndex(*this, index);
     }
-    static uint64_t getFlattenedIndex(Attribute elementsAttr,
+    static uint64_t getFlattenedIndex(Type type,
                                       ArrayRef<uint64_t> index);
+    static uint64_t getFlattenedIndex(Attribute elementsAttr,
+                                      ArrayRef<uint64_t> index) {
+      return getFlattenedIndex(elementsAttr.getType(), index);
+    }
 
     /// Returns the number of elements held by this attribute.
     int64_t getNumElements() const { return getNumElements(*this); }
@@ -350,13 +338,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
                                   !std::is_base_of<Attribute, T>::value,
                                   ResultT>;
 
-    /// Return the element of this attribute at the given index as a value of
-    /// type 'T'.
-    template <typename T>
-    T getFlatValue(uint64_t index) const {
-      return *std::next(value_begin<T>(), index);
-    }
-
     /// Return the splat value for this attribute. This asserts that the
     /// attribute corresponds to a splat.
     template <typename T>
@@ -368,7 +349,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     /// Return the elements of this attribute as a value of type 'T'.
     template <typename T>
     DefaultValueCheckT<T, iterator_range<T>> getValues() const {
-      return iterator_range<T>(value_begin<T>(), value_end<T>());
+      return {Attribute::getType(), value_begin<T>(), value_end<T>()};
     }
     template <typename T>
     DefaultValueCheckT<T, iterator<T>> value_begin() const;
@@ -384,12 +365,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
       llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
     template <typename T>
     using DerivedAttrValueIteratorRange =
-      llvm::iterator_range<DerivedAttrValueIterator<T>>;
+      detail::ElementsAttrRange<DerivedAttrValueIterator<T>>;
     template <typename T, typename = DerivedAttrValueCheckT<T>>
     DerivedAttrValueIteratorRange<T> getValues() const {
       auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-      return llvm::map_range(getValues<Attribute>(),
-                             static_cast<T (*)(Attribute)>(castFn));
+      return {Attribute::getType(), llvm::map_range(getValues<Attribute>(),
+                             static_cast<T (*)(Attribute)>(castFn))};
     }
     template <typename T, typename = DerivedAttrValueCheckT<T>>
     DerivedAttrValueIterator<T> value_begin() const {
@@ -407,8 +388,10 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     /// return the iterable range. Otherwise, return llvm::None.
     template <typename T>
     DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
-      if (Optional<iterator<T>> beginIt = try_value_begin<T>())
-        return iterator_range<T>(*beginIt, value_end<T>());
+      if (Optional<iterator<T>> beginIt = try_value_begin<T>()) {
+        return iterator_range<T>(Attribute::getType(), *beginIt,
+                                 value_end<T>());
+      }
       return llvm::None;
     }
     template <typename T>
@@ -418,10 +401,15 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     /// return the iterable range. Otherwise, return llvm::None.
     template <typename T, typename = DerivedAttrValueCheckT<T>>
     Optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
+      auto values = tryGetValues<Attribute>();
+      if (!values)
+        return llvm::None;
+
       auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-      if (auto values = tryGetValues<Attribute>())
-        return llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn));
-      return llvm::None;
+      return DerivedAttrValueIteratorRange<T>(
+        Attribute::getType(),
+        llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
+      );
     }
     template <typename T, typename = DerivedAttrValueCheckT<T>>
     Optional<DerivedAttrValueIterator<T>> try_value_begin() const {

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index cc1d1d74615ea..37da2eb9150b2 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -61,8 +61,7 @@ class DenseElementIndexedIteratorImpl
 };
 
 /// Type trait detector that checks if a given type T is a complex type.
-template <typename T>
-struct is_complex_t : public std::false_type {};
+template <typename T> struct is_complex_t : public std::false_type {};
 template <typename T>
 struct is_complex_t<std::complex<T>> : public std::true_type {};
 } // namespace detail
@@ -82,8 +81,7 @@ class DenseElementsAttr : public Attribute {
   /// floating point type that can be used to access the underlying element
   /// types of a DenseElementsAttr.
   // TODO: Use std::disjunction when C++17 is supported.
-  template <typename T>
-  struct is_valid_cpp_fp_type {
+  template <typename T> struct is_valid_cpp_fp_type {
     /// The type is a valid floating point type if it is a builtin floating
     /// point type, or is a potentially user defined floating point type. The
     /// latter allows for supporting users that have custom types defined for
@@ -219,6 +217,18 @@ class DenseElementsAttr : public Attribute {
   // Iterators
   //===--------------------------------------------------------------------===//
 
+  /// The iterator range over the given iterator type T.
+  template <typename IteratorT>
+  using iterator_range_impl = detail::ElementsAttrRange<IteratorT>;
+
+  /// The iterator for the given element type T.
+  template <typename T, typename AttrT = DenseElementsAttr>
+  using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
+  /// The iterator range over the given element T.
+  template <typename T, typename AttrT = DenseElementsAttr>
+  using iterator_range =
+      decltype(std::declval<AttrT>().template getValues<T>());
+
   /// A utility iterator that allows walking over the internal Attribute values
   /// of a DenseElementsAttr.
   class AttributeElementIterator
@@ -358,22 +368,7 @@ class DenseElementsAttr : public Attribute {
                               !std::is_same<Attribute, T>::value,
                           T>::type
   getSplatValue() const {
-    return getSplatValue().template cast<T>();
-  }
-
-  /// Return the value at the given index. The 'index' is expected to refer to a
-  /// valid element.
-  Attribute getValue(ArrayRef<uint64_t> index) const {
-    return getValue<Attribute>(index);
-  }
-  template <typename T>
-  T getValue(ArrayRef<uint64_t> index) const {
-    // Skip to the element corresponding to the flattened index.
-    return getFlatValue<T>(ElementsAttr::getFlattenedIndex(*this, index));
-  }
-  /// Return the value at the given flattened index.
-  template <typename T> T getFlatValue(uint64_t index) const {
-    return *std::next(value_begin<T>(), index);
+    return getSplatValue<Attribute>().template cast<T>();
   }
 
   /// Return the held element values as a range of integer or floating-point
@@ -384,12 +379,12 @@ class DenseElementsAttr : public Attribute {
                                std::numeric_limits<T>::is_integer) ||
                               is_valid_cpp_fp_type<T>::value>::type;
   template <typename T, typename = IntFloatValueTemplateCheckT<T>>
-  llvm::iterator_range<ElementIterator<T>> getValues() const {
+  iterator_range_impl<ElementIterator<T>> getValues() const {
     assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
                              std::numeric_limits<T>::is_signed));
     const char *rawData = getRawData().data();
     bool splat = isSplat();
-    return {ElementIterator<T>(rawData, splat, 0),
+    return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
   template <typename T, typename = IntFloatValueTemplateCheckT<T>>
@@ -413,12 +408,12 @@ class DenseElementsAttr : public Attribute {
                                is_valid_cpp_fp_type<ElementT>::value)>::type;
   template <typename T, typename ElementT = typename T::value_type,
             typename = ComplexValueTemplateCheckT<T, ElementT>>
-  llvm::iterator_range<ElementIterator<T>> getValues() const {
+  iterator_range_impl<ElementIterator<T>> getValues() const {
     assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
                           std::numeric_limits<ElementT>::is_signed));
     const char *rawData = getRawData().data();
     bool splat = isSplat();
-    return {ElementIterator<T>(rawData, splat, 0),
+    return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
   template <typename T, typename ElementT = typename T::value_type,
@@ -441,11 +436,11 @@ class DenseElementsAttr : public Attribute {
   using StringRefValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, StringRef>::value>::type;
   template <typename T, typename = StringRefValueTemplateCheckT<T>>
-  llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
+  iterator_range_impl<ElementIterator<StringRef>> getValues() const {
     auto stringRefs = getRawStringData();
     const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
     bool splat = isSplat();
-    return {ElementIterator<StringRef>(ptr, splat, 0),
+    return {Attribute::getType(), ElementIterator<StringRef>(ptr, splat, 0),
             ElementIterator<StringRef>(ptr, splat, getNumElements())};
   }
   template <typename T, typename = StringRefValueTemplateCheckT<T>>
@@ -464,8 +459,9 @@ class DenseElementsAttr : public Attribute {
   using AttributeValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, Attribute>::value>::type;
   template <typename T, typename = AttributeValueTemplateCheckT<T>>
-  llvm::iterator_range<AttributeElementIterator> getValues() const {
-    return {value_begin<Attribute>(), value_end<Attribute>()};
+  iterator_range_impl<AttributeElementIterator> getValues() const {
+    return {Attribute::getType(), value_begin<Attribute>(),
+            value_end<Attribute>()};
   }
   template <typename T, typename = AttributeValueTemplateCheckT<T>>
   AttributeElementIterator value_begin() const {
@@ -486,10 +482,11 @@ class DenseElementsAttr : public Attribute {
   using DerivedAttributeElementIterator =
       llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
-  llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
+  iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
     auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
-    return llvm::map_range(getValues<Attribute>(),
-                           static_cast<T (*)(Attribute)>(castFn));
+    return {Attribute::getType(),
+            llvm::map_range(getValues<Attribute>(),
+                            static_cast<T (*)(Attribute)>(castFn))};
   }
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
   DerivedAttributeElementIterator<T> value_begin() const {
@@ -508,9 +505,9 @@ class DenseElementsAttr : public Attribute {
   using BoolValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, bool>::value>::type;
   template <typename T, typename = BoolValueTemplateCheckT<T>>
-  llvm::iterator_range<BoolElementIterator> getValues() const {
+  iterator_range_impl<BoolElementIterator> getValues() const {
     assert(isValidBool() && "bool is not the value of this elements attribute");
-    return {BoolElementIterator(*this, 0),
+    return {Attribute::getType(), BoolElementIterator(*this, 0),
             BoolElementIterator(*this, getNumElements())};
   }
   template <typename T, typename = BoolValueTemplateCheckT<T>>
@@ -530,9 +527,9 @@ class DenseElementsAttr : public Attribute {
   using APIntValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, APInt>::value>::type;
   template <typename T, typename = APIntValueTemplateCheckT<T>>
-  llvm::iterator_range<IntElementIterator> getValues() const {
+  iterator_range_impl<IntElementIterator> getValues() const {
     assert(getElementType().isIntOrIndex() && "expected integral type");
-    return {raw_int_begin(), raw_int_end()};
+    return {Attribute::getType(), raw_int_begin(), raw_int_end()};
   }
   template <typename T, typename = APIntValueTemplateCheckT<T>>
   IntElementIterator value_begin() const {
@@ -551,7 +548,7 @@ class DenseElementsAttr : public Attribute {
   using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
       std::is_same<T, std::complex<APInt>>::value>::type;
   template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
-  llvm::iterator_range<ComplexIntElementIterator> getValues() const {
+  iterator_range_impl<ComplexIntElementIterator> getValues() const {
     return getComplexIntValues();
   }
   template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
@@ -569,7 +566,7 @@ class DenseElementsAttr : public Attribute {
   using APFloatValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, APFloat>::value>::type;
   template <typename T, typename = APFloatValueTemplateCheckT<T>>
-  llvm::iterator_range<FloatElementIterator> getValues() const {
+  iterator_range_impl<FloatElementIterator> getValues() const {
     return getFloatValues();
   }
   template <typename T, typename = APFloatValueTemplateCheckT<T>>
@@ -587,7 +584,7 @@ class DenseElementsAttr : public Attribute {
   using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
       std::is_same<T, std::complex<APFloat>>::value>::type;
   template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
-  llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
+  iterator_range_impl<ComplexFloatElementIterator> getValues() const {
     return getComplexFloatValues();
   }
   template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
@@ -660,13 +657,13 @@ class DenseElementsAttr : public Attribute {
   IntElementIterator raw_int_end() const {
     return IntElementIterator(*this, getNumElements());
   }
-  llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
+  iterator_range_impl<ComplexIntElementIterator> getComplexIntValues() const;
   ComplexIntElementIterator complex_value_begin() const;
   ComplexIntElementIterator complex_value_end() const;
-  llvm::iterator_range<FloatElementIterator> getFloatValues() const;
+  iterator_range_impl<FloatElementIterator> getFloatValues() const;
   FloatElementIterator float_value_begin() const;
   FloatElementIterator float_value_end() const;
-  llvm::iterator_range<ComplexFloatElementIterator>
+  iterator_range_impl<ComplexFloatElementIterator>
   getComplexFloatValues() const;
   ComplexFloatElementIterator complex_float_value_begin() const;
   ComplexFloatElementIterator complex_float_value_end() const;
@@ -872,8 +869,7 @@ class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
 //===----------------------------------------------------------------------===//
 
 template <typename T>
-auto SparseElementsAttr::getValues() const
-    -> llvm::iterator_range<iterator<T>> {
+auto SparseElementsAttr::value_begin() const -> iterator<T> {
   auto zeroValue = getZeroValue<T>();
   auto valueIt = getValues().value_begin<T>();
   const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
@@ -888,15 +884,7 @@ auto SparseElementsAttr::getValues() const
         // Otherwise, return the zero value.
         return zeroValue;
       };
-  return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
-}
-template <typename T>
-auto SparseElementsAttr::value_begin() const -> iterator<T> {
-  return getValues<T>().begin();
-}
-template <typename T>
-auto SparseElementsAttr::value_end() const -> iterator<T> {
-  return getValues<T>().end();
+  return iterator<T>(llvm::seq<ptrdiff_t>(0, getNumElements()).begin(), mapFn);
 }
 } // end namespace mlir.
 

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 01af84c421e97..c6631cd79fe58 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -174,9 +174,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
                         "ArrayRef<char>":$rawData);
   let extraClassDeclaration = [{
     using DenseElementsAttr::empty;
-    using DenseElementsAttr::getFlatValue;
     using DenseElementsAttr::getNumElements;
-    using DenseElementsAttr::getValue;
     using DenseElementsAttr::getValues;
     using DenseElementsAttr::isSplat;
     using DenseElementsAttr::size;
@@ -313,9 +311,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
   ];
   let extraClassDeclaration = [{
     using DenseElementsAttr::empty;
-    using DenseElementsAttr::getFlatValue;
     using DenseElementsAttr::getNumElements;
-    using DenseElementsAttr::getValue;
     using DenseElementsAttr::getValues;
     using DenseElementsAttr::isSplat;
     using DenseElementsAttr::size;
@@ -712,10 +708,6 @@ def Builtin_OpaqueElementsAttr : Builtin_Attr<
   let extraClassDeclaration = [{
     using ValueType = StringRef;
 
-    /// Return the value at the given index. The 'index' is expected to refer to
-    /// a valid element.
-    Attribute getValue(ArrayRef<uint64_t> index) const;
-
     /// Decodes the attribute value using dialect-specific decoding hook.
     /// Returns false if decoding is successful. If not, returns true and leaves
     /// 'result' argument unspecified.
@@ -802,6 +794,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
       // String types.
       StringRef
     >;
+    using ElementsAttr::Trait<SparseElementsAttr>::getValues;
 
     /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
     template <typename T>
@@ -817,13 +810,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
     /// Return the values of this attribute in the form of the given type 'T'.
     /// 'T'  may be any of Attribute, APInt, APFloat, c++ integer/float types,
     /// etc.
-    template <typename T> llvm::iterator_range<iterator<T>> getValues() const;
     template <typename T> iterator<T> value_begin() const;
-    template <typename T> iterator<T> value_end() const;
-
-    /// Return the value of the element at the given index. The 'index' is
-    /// expected to refer to a valid element.
-    Attribute getValue(ArrayRef<uint64_t> index) const;
 
   private:
     /// Get a zero APFloat for the given sparse attribute.

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 3b15212e30023..8d6c4ccf6a8b0 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -288,8 +288,9 @@ bool mlirAttributeIsAElements(MlirAttribute attr) {
 
 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
                                        uint64_t *idxs) {
-  return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
-      llvm::makeArrayRef(idxs, rank)));
+  return wrap(unwrap(attr)
+                  .cast<ElementsAttr>()
+                  .getValues<Attribute>()[llvm::makeArrayRef(idxs, rank)]);
 }
 
 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
@@ -482,7 +483,8 @@ bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
 }
 
 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
-  return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
+  return wrap(
+      unwrap(attr).cast<DenseElementsAttr>().getSplatValue<Attribute>());
 }
 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
@@ -520,36 +522,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
 // Indexed accessors.
 
 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<bool>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<bool>()[pos];
 }
 int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int8_t>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>()[pos];
 }
 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint8_t>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
 }
 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int32_t>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
 }
 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint32_t>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>()[pos];
 }
 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int64_t>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>()[pos];
 }
 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint64_t>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>()[pos];
 }
 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<float>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<float>()[pos];
 }
 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
-  return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<double>(pos);
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<double>()[pos];
 }
 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
                                                   intptr_t pos) {
   return wrap(
-      unwrap(attr).cast<DenseElementsAttr>().getFlatValue<StringRef>(pos));
+      unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>()[pos]);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 88eca4600a423..d55deb5ce84ac 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -169,7 +169,7 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite(
     return failure();
 
   auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
-  auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
+  auto val = workGroupSizeAttr.getValues<int32_t>()[index.getValue()];
   auto convertedType =
       getTypeConverter()->convertType(op.getResult().getType());
   if (!convertedType)

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 46c689d9b1775..b4ea696f80d0c 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -451,7 +451,7 @@ struct GlobalMemrefOpLowering
       // For scalar memrefs, the global variable created is of the element type,
       // so unpack the elements attribute to extract the value.
       if (type.getRank() == 0)
-        initialValue = elementsAttr.getValue({});
+        initialValue = elementsAttr.getValues<Attribute>()[0];
     }
 
     uint64_t alignment = global.alignment().getValueOr(0);

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2dae2feb2f4c6..1acb0a565daeb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2415,8 +2415,7 @@ LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
 // AffineMinMaxOpBase
 //===----------------------------------------------------------------------===//
 
-template <typename T>
-static LogicalResult verifyAffineMinMaxOp(T op) {
+template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
   // Verify that operand count matches affine map dimension and symbol count.
   if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
     return op.emitOpError(
@@ -2424,8 +2423,7 @@ static LogicalResult verifyAffineMinMaxOp(T op) {
   return success();
 }
 
-template <typename T>
-static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
+template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
   p << ' ' << op->getAttr(T::getMapAttrName());
   auto operands = op.getOperands();
   unsigned numDims = op.map().getNumDims();
@@ -2532,8 +2530,7 @@ struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
 ///
 ///   %1 = affine.min affine_map<
 ///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
-template <typename T>
-struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
+template <typename T> struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
   using OpRewritePattern<T>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(T affineOp,
@@ -2890,19 +2887,19 @@ AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
 }
 
 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
+  auto values = lowerBoundsGroups().getValues<int32_t>();
   unsigned start = 0;
   for (unsigned i = 0; i < pos; ++i)
-    start += lowerBoundsGroups().getValue<int32_t>(i);
-  return lowerBoundsMap().getSliceMap(
-      start, lowerBoundsGroups().getValue<int32_t>(pos));
+    start += values[i];
+  return lowerBoundsMap().getSliceMap(start, values[pos]);
 }
 
 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
+  auto values = upperBoundsGroups().getValues<int32_t>();
   unsigned start = 0;
   for (unsigned i = 0; i < pos; ++i)
-    start += upperBoundsGroups().getValue<int32_t>(i);
-  return upperBoundsMap().getSliceMap(
-      start, upperBoundsGroups().getValue<int32_t>(pos));
+    start += values[i];
+  return upperBoundsMap().getSliceMap(start, values[pos]);
 }
 
 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index d9560bf9139d4..ea4d7a69c0633 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -163,8 +163,8 @@ constexpr const static unsigned kBitsInByte = 8;
 /// Returns the value that corresponds to named position `pos` from the
 /// attribute `attr` assuming it's a dense integer elements attribute.
 static unsigned extractPointerSpecValue(Attribute attr, DLEntryPos pos) {
-  return attr.cast<DenseIntElementsAttr>().getValue<unsigned>(
-      static_cast<unsigned>(pos));
+  return attr.cast<DenseIntElementsAttr>()
+      .getValues<unsigned>()[static_cast<unsigned>(pos)];
 }
 
 /// Returns the part of the data layout entry that corresponds to `pos` for the

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ee5622d2662db..fa836ed9577aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1184,7 +1184,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
               splatAttr.isSplat() &&
               splatAttr.getType().getElementType().isIntOrFloat()) {
-            constantAttr = splatAttr.getSplatValue();
+            constantAttr = splatAttr.getSplatValue<Attribute>();
             return true;
           }
         }
@@ -1455,10 +1455,9 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     bool isFloat = elementType.isa<FloatType>();
     if (isFloat) {
-      SmallVector<iterator_range<DenseElementsAttr::FloatElementIterator>>
-          inputFpIterators;
+      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
       for (int i = 0; i < numInputs; ++i)
-        inputFpIterators.push_back(inputValues[i].getValues<APFloat>());
+        inFpRanges.push_back(inputValues[i].getValues<APFloat>());
 
       computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
 
@@ -1469,22 +1468,17 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
         computeRemappedLinearIndex(linearIndex);
 
         // Collect constant elements for all inputs at this loop iteration.
-        for (int i = 0; i < numInputs; ++i) {
-          computeFnInputs.apFloats[i] =
-              *(inputFpIterators[i].begin() + srcLinearIndices[i]);
-        }
+        for (int i = 0; i < numInputs; ++i)
+          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
 
         // Invoke the computation to get the corresponding constant output
         // element.
-        APIntOrFloat outputs = computeFn(computeFnInputs);
-
-        fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue();
+        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
       }
     } else {
-      SmallVector<iterator_range<DenseElementsAttr::IntElementIterator>>
-          inputIntIterators;
+      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
       for (int i = 0; i < numInputs; ++i)
-        inputIntIterators.push_back(inputValues[i].getValues<APInt>());
+        inIntRanges.push_back(inputValues[i].getValues<APInt>());
 
       computeFnInputs.apInts.resize(numInputs);
 
@@ -1495,25 +1489,19 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
         computeRemappedLinearIndex(linearIndex);
 
         // Collect constant elements for all inputs at this loop iteration.
-        for (int i = 0; i < numInputs; ++i) {
-          computeFnInputs.apInts[i] =
-              *(inputIntIterators[i].begin() + srcLinearIndices[i]);
-        }
+        for (int i = 0; i < numInputs; ++i)
+          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
 
         // Invoke the computation to get the corresponding constant output
         // element.
-        APIntOrFloat outputs = computeFn(computeFnInputs);
-
-        intOutputValues[dstLinearIndex] = outputs.apInt.getValue();
+        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
       }
     }
 
-    DenseIntOrFPElementsAttr outputAttr;
-    if (isFloat) {
-      outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
-    } else {
-      outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
-    }
+    DenseElementsAttr outputAttr =
+        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
+                : DenseElementsAttr::get(outputType, intOutputValues);
+
     rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
     return success();
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 758aeecf380d3..af3e528212f7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -913,9 +913,9 @@ struct DownscaleSizeOneWindowed2DConvolution final
         loc, newOutputType, output, ioReshapeIndices);
 
     // We need to shrink the strides and dilations too.
-    auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
+    auto stride = convOp.strides().getValues<int64_t>()[removeH ? 1 : 0];
     auto stridesAttr = rewriter.getI64VectorAttr(stride);
-    auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
+    auto dilation = convOp.dilations().getValues<int64_t>()[removeH ? 1 : 0];
     auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
 
     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 437c762b5f550..c5125860d4379 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -56,7 +56,7 @@ static Attribute extractCompositeElement(Attribute composite,
 
   if (auto vector = composite.dyn_cast<ElementsAttr>()) {
     assert(indices.size() == 1 && "must have exactly one index for a vector");
-    return vector.getValue({indices[0]});
+    return vector.getValues<Attribute>()[indices[0]];
   }
 
   if (auto array = composite.dyn_cast<ArrayAttr>()) {

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 635d3c05a72e3..27d60b5f02f34 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1138,7 +1138,7 @@ OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
     return nullptr;
   if (dim.getValue() >= elements.getNumElements())
     return nullptr;
-  return elements.getValue({(uint64_t)dim.getValue()});
+  return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
 }
 
 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 91470be5e7a5c..6bc2d7fd436d9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1304,13 +1304,14 @@ static void printSwitchOpCases(
   if (!caseValues)
     return;
 
-  for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
+  for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
     p << ',';
     p.printNewline();
     p << "  ";
-    p << caseValues.getValue<APInt>(i).getLimitedValue();
+    p << it.value().getLimitedValue();
     p << ": ";
-    p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]);
+    p.printSuccessorAndUseList(caseDestinations[it.index()],
+                               caseOperands[it.index()]);
   }
   p.printNewline();
 }
@@ -1353,9 +1354,9 @@ Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
 
   SuccessorRange caseDests = getCaseDestinations();
   if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
-    for (int64_t i = 0, size = getCaseValues()->size(); i < size; ++i)
-      if (value == caseValues->getValue<IntegerAttr>(i))
-        return caseDests[i];
+    for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
+      if (it.value() == value.getValue())
+        return caseDests[it.index()];
     return getDefaultDestination();
   }
   return nullptr;
@@ -1394,15 +1395,15 @@ dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
   auto caseValues = op.getCaseValues();
   auto caseDests = op.getCaseDestinations();
 
-  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
-    if (caseDests[i] == op.getDefaultDestination() &&
-        op.getCaseOperands(i) == op.getDefaultOperands()) {
+  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
+    if (caseDests[it.index()] == op.getDefaultDestination() &&
+        op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
       requiresChange = true;
       continue;
     }
-    newCaseDestinations.push_back(caseDests[i]);
-    newCaseOperands.push_back(op.getCaseOperands(i));
-    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+    newCaseDestinations.push_back(caseDests[it.index()]);
+    newCaseOperands.push_back(op.getCaseOperands(it.index()));
+    newCaseValues.push_back(it.value());
   }
 
   if (!requiresChange)
@@ -1424,10 +1425,11 @@ dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
                        APInt caseValue) {
   auto caseValues = op.getCaseValues();
-  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
-    if (caseValues->getValue<APInt>(i) == caseValue) {
-      rewriter.replaceOpWithNewOp<BranchOp>(op, op.getCaseDestinations()[i],
-                                            op.getCaseOperands(i));
+  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
+    if (it.value() == caseValue) {
+      rewriter.replaceOpWithNewOp<BranchOp>(
+          op, op.getCaseDestinations()[it.index()],
+          op.getCaseOperands(it.index()));
       return;
     }
   }
@@ -1551,22 +1553,16 @@ simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
     return failure();
 
   // Fold this switch to an unconditional branch.
-  APInt caseValue;
-  bool isDefault = true;
   SuccessorRange predDests = predSwitch.getCaseDestinations();
-  Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
-  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
-    if (currentBlock == predDests[i]) {
-      caseValue = predCaseValues->getValue<APInt>(i);
-      isDefault = false;
-      break;
-    }
-  }
-  if (isDefault)
+  auto it = llvm::find(predDests, currentBlock);
+  if (it != predDests.end()) {
+    Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
+    foldSwitch(op, rewriter,
+               predCaseValues->getValues<APInt>()[it - predDests.begin()]);
+  } else {
     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
                                           op.getDefaultOperands());
-  else
-    foldSwitch(op, rewriter, caseValue);
+  }
   return success();
 }
 
@@ -1613,7 +1609,7 @@ simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
   auto predCaseValues = predSwitch.getCaseValues();
-  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
-    if (currentBlock != predDests[i])
-      caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
+  for (const auto &it : llvm::enumerate(predCaseValues->getValues<APInt>()))
+    if (currentBlock != predDests[it.index()])
+      caseValuesToRemove.insert(it.value());
 
   SmallVector<Block *> newCaseDestinations;
   SmallVector<ValueRange> newCaseOperands;
@@ -1622,14 +1618,14 @@ simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
 
   auto caseValues = op.getCaseValues();
   auto caseDests = op.getCaseDestinations();
-  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
-    if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
+  for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
+    if (caseValuesToRemove.contains(it.value())) {
       requiresChange = true;
       continue;
     }
-    newCaseDestinations.push_back(caseDests[i]);
-    newCaseOperands.push_back(op.getCaseOperands(i));
-    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+    newCaseDestinations.push_back(caseDests[it.index()]);
+    newCaseOperands.push_back(op.getCaseOperands(it.index()));
+    newCaseValues.push_back(it.value());
   }
 
   if (!requiresChange)

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 287289b4cd6d3..1d8d8e2d50688 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -340,7 +340,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
   // If this is a splat elements attribute, simply return the value. All of the
   // elements of a splat attribute are the same.
   if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
-    return splatTensor.getSplatValue();
+    return splatTensor.getSplatValue<Attribute>();
 
   // Otherwise, collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
@@ -353,7 +353,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
   // If this is an elements attribute, query the value at the given indices.
   auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
   if (elementsAttr && elementsAttr.isValidIndex(indices))
-    return elementsAttr.getValue(indices);
+    return elementsAttr.getValues<Attribute>()[indices];
   return {};
 }
 
@@ -440,7 +440,7 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
   Attribute dest = operands[1];
   if (scalar && dest)
     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
-      if (scalar == splatDest.getSplatValue())
+      if (scalar == splatDest.getSplatValue<Attribute>())
         return dest;
   return {};
 }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 85415d92bd1b6..2a435476e5bef 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -230,6 +230,7 @@ struct ConstantTransposeOptimization
     // Transpose the input constant. Because we don't know its rank in advance,
     // we need to loop over the range [0, element count) and delinearize the
     // index.
+    auto attrValues = inputValues.getValues<Attribute>();
     for (int srcLinearIndex = 0; srcLinearIndex < numElements;
          ++srcLinearIndex) {
       SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
@@ -247,7 +248,7 @@ struct ConstantTransposeOptimization
       for (int dim = 1; dim < outputType.getRank(); ++dim)
         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
 
-      outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
+      outputValues[dstLinearIndex] = attrValues[srcIndices];
     }
 
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(
@@ -424,8 +425,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
-template <typename T>
-static LogicalResult verifyConvOp(T op) {
+template <typename T> static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
   auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();

diff  --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 96992c219aa0c..fd289917c64cc 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -56,15 +56,15 @@ bool ElementsAttr::isValidIndex(Attribute elementsAttr,
   return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index);
 }
 
-uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr,
-                                         ArrayRef<uint64_t> index) {
-  ShapedType type = elementsAttr.getType().cast<ShapedType>();
-  assert(isValidIndex(type, index) && "expected valid multi-dimensional index");
+uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
+  ShapedType shapeType = type.cast<ShapedType>();
+  assert(isValidIndex(shapeType, index) &&
+         "expected valid multi-dimensional index");
 
   // Reduce the provided multidimensional index into a flattended 1D row-major
   // index.
-  auto rank = type.getRank();
-  auto shape = type.getShape();
+  auto rank = shapeType.getRank();
+  ArrayRef<int64_t> shape = shapeType.getShape();
   uint64_t valueIndex = 0;
   uint64_t dimMultiplier = 1;
   for (int i = rank - 1; i >= 0; --i) {

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 41a8c46c0c6d9..38c8430268985 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -902,10 +902,10 @@ LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
 }
 
 auto DenseElementsAttr::getComplexIntValues() const
-    -> llvm::iterator_range<ComplexIntElementIterator> {
+    -> iterator_range_impl<ComplexIntElementIterator> {
   assert(isComplexOfIntType(getElementType()) &&
          "expected complex integral type");
-  return {ComplexIntElementIterator(*this, 0),
+  return {getType(), ComplexIntElementIterator(*this, 0),
           ComplexIntElementIterator(*this, getNumElements())};
 }
 auto DenseElementsAttr::complex_value_begin() const
@@ -923,10 +923,10 @@ auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
 /// Return the held element values as a range of APFloat. The element type of
 /// this attribute must be of float type.
 auto DenseElementsAttr::getFloatValues() const
-    -> llvm::iterator_range<FloatElementIterator> {
+    -> iterator_range_impl<FloatElementIterator> {
   auto elementType = getElementType().cast<FloatType>();
   const auto &elementSemantics = elementType.getFloatSemantics();
-  return {FloatElementIterator(elementSemantics, raw_int_begin()),
+  return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
           FloatElementIterator(elementSemantics, raw_int_end())};
 }
 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
@@ -939,11 +939,12 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
 }
 
 auto DenseElementsAttr::getComplexFloatValues() const
-    -> llvm::iterator_range<ComplexFloatElementIterator> {
+    -> iterator_range_impl<ComplexFloatElementIterator> {
   Type eltTy = getElementType().cast<ComplexType>().getElementType();
   assert(eltTy.isa<FloatType>() && "expected complex float type");
   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
-  return {{semantics, {*this, 0}},
+  return {getType(),
+          {semantics, {*this, 0}},
           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
 }
 auto DenseElementsAttr::complex_float_value_begin() const
@@ -1248,13 +1249,6 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
 // OpaqueElementsAttr
 //===----------------------------------------------------------------------===//
 
-/// Return the value at the given index. If index does not refer to a valid
-/// element, then a null attribute is returned.
-Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
-  assert(isValidIndex(index) && "expected valid multi-dimensional index");
-  return Attribute();
-}
-
 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
   Dialect *dialect = getDialect().getDialect();
   if (!dialect)
@@ -1279,47 +1273,6 @@ OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // SparseElementsAttr
 //===----------------------------------------------------------------------===//
 
-/// Return the value of the element at the given index.
-Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
-  assert(isValidIndex(index) && "expected valid multi-dimensional index");
-  auto type = getType();
-
-  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
-  // as a 1-D index array.
-  auto sparseIndices = getIndices();
-  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
-
-  // Check to see if the indices are a splat.
-  if (sparseIndices.isSplat()) {
-    // If the index is also not a splat of the index value, we know that the
-    // value is zero.
-    auto splatIndex = *sparseIndexValues.begin();
-    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
-      return getZeroAttr();
-
-    // If the indices are a splat, we also expect the values to be a splat.
-    assert(getValues().isSplat() && "expected splat values");
-    return getValues().getSplatValue();
-  }
-
-  // Build a mapping between known indices and the offset of the stored element.
-  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
-  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
-  size_t rank = type.getRank();
-  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
-    mappedIndices.try_emplace(
-        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
-
-  // Look for the provided index key within the mapped indices. If the provided
-  // index is not found, then return a zero attribute.
-  auto it = mappedIndices.find(index);
-  if (it == mappedIndices.end())
-    return getZeroAttr();
-
-  // Otherwise, return the held sparse value element.
-  return getValues().getValue(it->second);
-}
-
 /// Get a zero APFloat for the given sparse attribute.
 APFloat SparseElementsAttr::getZeroAPFloat() const {
   auto eltType = getElementType().cast<FloatType>();

diff  --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 9676553aeaa57..67c9ccbaec5be 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -71,7 +71,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
     return t.cast<ShapedType>().getDimSize(index);
   if (auto attr = val.dyn_cast<Attribute>())
     return attr.cast<DenseIntElementsAttr>()
-        .getFlatValue<APInt>(index)
+        .getValues<APInt>()[index]
         .getSExtValue();
   auto *stc = val.get<ShapedTypeComponents *>();
   return stc->getDims()[index];

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 864bba19b2ca8..6c10e61dcc857 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -386,14 +386,12 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
-    auto weights = condbrOp.getBranchWeights();
     llvm::MDNode *branchWeights = nullptr;
-    if (weights) {
+    if (auto weights = condbrOp.getBranchWeights()) {
       // Map weight attributes to LLVM metadata.
-      auto trueWeight =
-          weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
-      auto falseWeight =
-          weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
+      auto weightValues = weights->getValues<APInt>();
+      auto trueWeight = weightValues[0].getSExtValue();
+      auto falseWeight = weightValues[1].getSExtValue();
       branchWeights =
           llvm::MDBuilder(moduleTranslation.getLLVMContext())
               .createBranchWeights(static_cast<uint32_t>(trueWeight),

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 24eea6c317116..5701b44a2a9f2 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -706,11 +706,12 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
   if (shapedType.getRank() == dim) {
     if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
       return attr.getType().getElementType().isInteger(1)
-                 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
-                 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
+                 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
+                 : prepareConstantInt(loc,
+                                      attr.getValues<IntegerAttr>()[index]);
     }
     if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
-      return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
+      return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
     }
     return 0;
   }

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index b1b8077016767..b8edbea19a401 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -154,9 +154,9 @@ structured_op: !LinalgStructuredOpConfig
 #  ODS-NEXT:  LogicalResult verifyIndexingMapRequiredAttributes();
 
 #       IMPL:  getSymbolBindings(Test2Op self)
-#       IMPL:  cst2 = self.strides().getValue<int64_t>({ 0 });
+#       IMPL:  cst2 = self.strides().getValues<int64_t>()[0];
 #  IMPL-NEXT:  getAffineConstantExpr(cst2, context)
-#       IMPL:  cst3 = self.strides().getValue<int64_t>({ 1 });
+#       IMPL:  cst3 = self.strides().getValues<int64_t>()[1];
 #  IMPL-NEXT:  getAffineConstantExpr(cst3, context)
 
 #       IMPL:  Test2Op::indexing_maps()

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index bcf3616f8b0af..507713e256782 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -142,8 +142,7 @@ namespace yaml {
 /// Top-level type containing op metadata and one of a concrete op type.
 /// Currently, the only defined op type is `structured_op` (maps to
 /// `LinalgStructuredOpConfig`).
-template <>
-struct MappingTraits<LinalgOpConfig> {
+template <> struct MappingTraits<LinalgOpConfig> {
   static void mapping(IO &io, LinalgOpConfig &info) {
     io.mapOptional("metadata", info.metadata);
     io.mapOptional("structured_op", info.structuredOp);
@@ -156,8 +155,7 @@ struct MappingTraits<LinalgOpConfig> {
 ///   - List of indexing maps (see `LinalgIndexingMaps`).
 ///   - Iterator types (see `LinalgIteratorTypeDef`).
 ///   - List of scalar level assignment (see `ScalarAssign`).
-template <>
-struct MappingTraits<LinalgStructuredOpConfig> {
+template <> struct MappingTraits<LinalgStructuredOpConfig> {
   static void mapping(IO &io, LinalgStructuredOpConfig &info) {
     io.mapRequired("args", info.args);
     io.mapRequired("indexing_maps", info.indexingMaps);
@@ -180,8 +178,7 @@ struct MappingTraits<LinalgStructuredOpConfig> {
 ///     attribute symbols. During op creation these symbols are replaced by the
 ///     corresponding `name` attribute values. Only attribute arguments have
 ///     an `attribute_map`.
-template <>
-struct MappingTraits<LinalgOperandDef> {
+template <> struct MappingTraits<LinalgOperandDef> {
   static void mapping(IO &io, LinalgOperandDef &info) {
     io.mapRequired("name", info.name);
     io.mapRequired("usage", info.usage);
@@ -192,8 +189,7 @@ struct MappingTraits<LinalgOperandDef> {
 };
 
 /// Usage enum for a named argument.
-template <>
-struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
+template <> struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
   static void enumeration(IO &io, LinalgOperandDefUsage &value) {
     io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
     io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
@@ -202,8 +198,7 @@ struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
 };
 
 /// Iterator type enum.
-template <>
-struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
+template <> struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
   static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
     io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
     io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
@@ -211,8 +206,7 @@ struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
 };
 
 /// Metadata about the op (name, C++ name, and documentation).
-template <>
-struct MappingTraits<LinalgOpMetadata> {
+template <> struct MappingTraits<LinalgOpMetadata> {
   static void mapping(IO &io, LinalgOpMetadata &info) {
     io.mapRequired("name", info.name);
     io.mapRequired("cpp_class_name", info.cppClassName);
@@ -226,8 +220,7 @@ struct MappingTraits<LinalgOpMetadata> {
 ///     some symbols that bind to attributes of the op. Each indexing map must
 ///     be normalized over the same list of dimensions, and its symbols must
 ///     match the symbols for argument shapes.
-template <>
-struct MappingTraits<LinalgIndexingMapsConfig> {
+template <> struct MappingTraits<LinalgIndexingMapsConfig> {
   static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
     io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
   }
@@ -237,8 +230,7 @@ struct MappingTraits<LinalgIndexingMapsConfig> {
 ///   - The `arg` name must match a named output.
 ///   - The `value` is a scalar expression for computing the value to
 ///     assign (see `ScalarExpression`).
-template <>
-struct MappingTraits<ScalarAssign> {
+template <> struct MappingTraits<ScalarAssign> {
   static void mapping(IO &io, ScalarAssign &info) {
     io.mapRequired("arg", info.arg);
     io.mapRequired("value", info.value);
@@ -250,8 +242,7 @@ struct MappingTraits<ScalarAssign> {
 ///   - `scalar_apply`: Result of evaluating a named function (see
 ///      `ScalarApply`).
 ///   - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere.
-template <>
-struct MappingTraits<ScalarExpression> {
+template <> struct MappingTraits<ScalarExpression> {
   static void mapping(IO &io, ScalarExpression &info) {
     io.mapOptional("scalar_arg", info.arg);
     io.mapOptional("scalar_const", info.constant);
@@ -266,16 +257,14 @@ struct MappingTraits<ScalarExpression> {
 /// functions include:
 ///   - `add(lhs, rhs)`
 ///   - `mul(lhs, rhs)`
-template <>
-struct MappingTraits<ScalarApply> {
+template <> struct MappingTraits<ScalarApply> {
   static void mapping(IO &io, ScalarApply &info) {
     io.mapRequired("fn_name", info.fnName);
     io.mapRequired("operands", info.operands);
   }
 };
 
-template <>
-struct MappingTraits<ScalarSymbolicCast> {
+template <> struct MappingTraits<ScalarSymbolicCast> {
   static void mapping(IO &io, ScalarSymbolicCast &info) {
     io.mapRequired("type_var", info.typeVar);
     io.mapRequired("operands", info.operands);
@@ -285,8 +274,7 @@ struct MappingTraits<ScalarSymbolicCast> {
 
 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
 /// the same.
-template <>
-struct ScalarTraits<SerializedAffineMap> {
+template <> struct ScalarTraits<SerializedAffineMap> {
   static void output(const SerializedAffineMap &value, void *rawYamlContext,
                      raw_ostream &out) {
     assert(value.affineMapAttr);
@@ -726,7 +714,7 @@ static SmallVector<AffineExpr> getSymbolBindings({0} self) {
       // {1}: Symbol position
       // {2}: Attribute index
       static const char structuredOpAccessAttrFormat[] = R"FMT(
-int64_t cst{1} = self.{0}().getValue<int64_t>({ {2} });
+int64_t cst{1} = self.{0}().getValues<int64_t>()[{2}];
 exprs.push_back(getAffineConstantExpr(cst{1}, context));
 )FMT";
       // Update all symbol bindings mapped to an attribute.

diff  --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
index 33c6360a4e908..5125413a6c11c 100644
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
@@ -113,7 +113,8 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
   EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());
 
   // Check Elements attribute element value is expected.
-  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+  auto firstValue =
+      returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
   EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
 }
 
@@ -138,7 +139,8 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
   EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>());
 
   // Check Elements attribute element value is expected.
-  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+  auto firstValue =
+      returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
   EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
 }
 
@@ -162,7 +164,8 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
   EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
 
   // Check Elements attribute element value is expected.
-  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+  auto firstValue =
+      returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
   EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
 }
 

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index aaff61e7d5f9f..19b57fa754df1 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -202,7 +202,7 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
   RankedTensorType shape = RankedTensorType::get({}, intTy);
 
   auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
-  EXPECT_TRUE(attr.getValue({0}) == value);
+  EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
 }
 
 TEST(SparseElementsAttrTest, GetZero) {
@@ -238,15 +238,15 @@ TEST(SparseElementsAttrTest, GetZero) {
 
   // Only index (0, 0) contains an element, others are supposed to return
   // the zero/empty value.
-  auto zeroIntValue = sparseInt.getValue({1, 1});
+  auto zeroIntValue = sparseInt.getValues<Attribute>()[{1, 1}];
   EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
   EXPECT_TRUE(zeroIntValue.getType() == intTy);
 
-  auto zeroFloatValue = sparseFloat.getValue({1, 1});
+  auto zeroFloatValue = sparseFloat.getValues<Attribute>()[{1, 1}];
   EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
 
-  auto zeroStringValue = sparseString.getValue({1, 1});
+  auto zeroStringValue = sparseString.getValues<Attribute>()[{1, 1}];
   EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }


        


More information about the Mlir-commits mailing list