[llvm] d80d3a3 - [mlir] Refactor ElementsAttr into an AttrInterface

River Riddle via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 20 18:58:14 PDT 2021


Author: River Riddle
Date: 2021-09-21T01:57:43Z
New Revision: d80d3a358fffce430c94c7e9c716a5641010e4d0

URL: https://github.com/llvm/llvm-project/commit/d80d3a358fffce430c94c7e9c716a5641010e4d0
DIFF: https://github.com/llvm/llvm-project/commit/d80d3a358fffce430c94c7e9c716a5641010e4d0.diff

LOG: [mlir] Refactor ElementsAttr into an AttrInterface

This revision refactors ElementsAttr into an Attribute Interface.
This enables a common interface for interacting with
element attributes, without needing to modify the builtin
dialect. It also removes most (if not all) of the need for
the current OpaqueElementsAttr, which was originally intended as
a way to opaquely represent data that was not representable by
the other builtin constructs.

The new ElementsAttr interface not only allows users to
natively represent their data in the way that best suits them,
it also allows for efficient opaque access and iteration of the
underlying data. Attributes using the ElementsAttr interface
can directly expose support for interacting with the held
elements using any C++ data type they claim to support. For
example, DenseIntOrFpElementsAttr supports iteration using
various native C++ integer/float data types, as well as
APInt/APFloat, and more. ElementsAttr instances that refer to
DenseIntOrFpElementsAttr can use all of these data types for
iteration:

```c++
DenseIntOrFpElementsAttr intElementsAttr = ...;

ElementsAttr attr = intElementsAttr;
for (uint64_t value : attr.getValues<uint64_t>())
  ...;
for (APInt value : attr.getValues<APInt>())
  ...;
for (IntegerAttr value : attr.getValues<IntegerAttr>())
  ...;
```

ElementsAttr also supports failable range/iterator access,
allowing for selective code paths depending on data type
support:

```c++
ElementsAttr attr = ...;
if (auto range = attr.tryGetValues<uint64_t>()) {
  for (uint64_t value : *range)
    ...;
}
```

Differential Revision: https://reviews.llvm.org/D109190

Added: 
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
    mlir/lib/IR/BuiltinAttributeInterfaces.cpp
    mlir/test/IR/elements-attr-interface.mlir
    mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 0c923e905b59c..a58042eb89854 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -285,6 +285,8 @@ class mapped_iterator
 
   ItTy getCurrent() { return this->I; }
 
+  const FuncTy &getFunction() const { return F; } ///< Mapper used by operator*.
+
   FuncReturnTy operator*() const { return F(*this->I); }
 
 private:

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
new file mode 100644
index 0000000000000..392bffc09da6e
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -0,0 +1,264 @@
+//===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
+#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/Any.h"
+#include "llvm/Support/raw_ostream.h"
+#include <complex>
+
+namespace mlir {
+class ShapedType;
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr
+//===----------------------------------------------------------------------===//
+namespace detail {
+/// This class provides support for indexing into the element range of an
+/// ElementsAttr. It is used to opaquely wrap either a contiguous range, via
+/// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via
+/// `ElementsAttrIndexer::nonContiguous`. A contiguous range is an array-like
+/// range, where all of the elements are laid out sequentially in memory. A
+/// non-contiguous range implies no contiguity, and elements may even be
+/// materialized when indexing, such as the case for a mapped_range.
+struct ElementsAttrIndexer {
+public:
+  ElementsAttrIndexer() // Contiguous splat state over a null element pointer.
+      : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {}
+  ElementsAttrIndexer(ElementsAttrIndexer &&rhs)
+      : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
+    if (isContiguous)
+      conState = std::move(rhs.conState);
+    else
+      new (&nonConState) NonContiguousState(std::move(rhs.nonConState));
+  }
+  ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
+      : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
+    if (isContiguous)
+      conState = rhs.conState;
+    else
+      new (&nonConState) NonContiguousState(rhs.nonConState);
+  }
+  ~ElementsAttrIndexer() { // Only the non-contiguous state owns resources.
+    if (!isContiguous)
+      nonConState.~NonContiguousState();
+  }
+
+  /// Construct an indexer for a non-contiguous range starting at the given
+  /// iterator. A non-contiguous range implies no contiguity, and elements may
+  /// even be materialized when indexing, such as the case for a mapped_range.
+  template <typename IteratorT>
+  static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) {
+    ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat);
+    new (&indexer.nonConState)
+        NonContiguousState(std::forward<IteratorT>(iterator));
+    return indexer;
+  }
+
+  /// Construct an indexer for a contiguous range starting at the given element
+  /// pointer. A contiguous range is an array-like range, where all of the
+  /// elements are laid out sequentially in memory.
+  template <typename T>
+  static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) {
+    ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat);
+    new (&indexer.conState) ContiguousState(firstEltPtr);
+    return indexer;
+  }
+
+  /// Access the element at the given index (always element 0 for a splat).
+  template <typename T> T at(uint64_t index) const {
+    if (isSplat)
+      index = 0;
+    return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
+  }
+  // NOTE(review): no usable copy/move assignment; confirm that is intended.
+private:
+  ElementsAttrIndexer(bool isContiguous, bool isSplat)
+      : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {}
+
+  /// This class contains all of the state necessary to index a contiguous
+  /// range.
+  class ContiguousState {
+  public:
+    ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}
+
+    /// Access the element at the given index, reading the storage as `T`s.
+    template <typename T> const T &at(uint64_t index) const {
+      return *(reinterpret_cast<const T *>(firstEltPtr) + index);
+    }
+
+  private:
+    const void *firstEltPtr;
+  };
+
+  /// This class contains all of the state necessary to index a non-contiguous
+  /// range.
+  class NonContiguousState {
+  private:
+    /// This class is used to represent the abstract base of an opaque iterator.
+    /// This allows for all iterator and element types to be completely
+    /// type-erased.
+    struct OpaqueIteratorBase {
+      virtual ~OpaqueIteratorBase() {}
+      virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0;
+    };
+    /// This class is used to represent the abstract base of an opaque iterator
+    /// that iterates over elements of type `T`. This allows for all iterator
+    /// types to be completely type-erased.
+    template <typename T>
+    struct OpaqueIteratorValueBase : public OpaqueIteratorBase {
+      virtual T at(uint64_t index) = 0;
+    };
+    /// This class is used to represent an opaque handle to an iterator of type
+    /// `IteratorT` that iterates over elements of type `T`.
+    template <typename IteratorT, typename T>
+    struct OpaqueIterator : public OpaqueIteratorValueBase<T> {
+      template <typename ItTy, typename FuncTy, typename FuncReturnTy>
+      static void isMappedIteratorTestFn(
+          llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {} // Probe only.
+      template <typename U, typename... Args>
+      using is_mapped_iterator =
+          decltype(isMappedIteratorTestFn(std::declval<U>()));
+      template <typename U>
+      using detect_is_mapped_iterator =
+          llvm::is_detected<is_mapped_iterator, U>;
+
+      /// Access the element within the iterator at the given index.
+      template <typename ItT>
+      static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
+      atImpl(ItT &&it, uint64_t index) {
+        return *std::next(it, index);
+      }
+      template <typename ItT>
+      static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
+      atImpl(ItT &&it, uint64_t index) {
+        // Special case mapped_iterator to avoid copying the function.
+        return it.getFunction()(*std::next(it.getCurrent(), index));
+      }
+
+    public:
+      template <typename U>
+      OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
+      std::unique_ptr<OpaqueIteratorBase> clone() const final {
+        return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
+      }
+
+      /// Access the element at the given index.
+      T at(uint64_t index) final { return atImpl(iterator, index); }
+
+    private:
+      IteratorT iterator;
+    };
+
+  public:
+    /// Construct the state with the given iterator type.
+    template <typename IteratorT, typename T = typename llvm::remove_cvref_t<
+                                      decltype(*std::declval<IteratorT>())>>
+    NonContiguousState(IteratorT iterator)
+        : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
+    NonContiguousState(const NonContiguousState &other)
+        : iterator(other.iterator->clone()) {}
+    NonContiguousState(NonContiguousState &&other) = default;
+
+    /// Access the element at index; unchecked: `T` must match the stored type.
+    template <typename T> T at(uint64_t index) const {
+      auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
+      return valueIt->at(index);
+    }
+
+    /// The opaque iterator state.
+    std::unique_ptr<OpaqueIteratorBase> iterator;
+  };
+
+  /// A boolean indicating if this range is contiguous or not.
+  bool isContiguous;
+  /// A boolean indicating if this range is a splat.
+  bool isSplat;
+  /// The underlying range state.
+  union {
+    ContiguousState conState;
+    NonContiguousState nonConState;
+  };
+};
+
+/// This class implements a generic random-access iterator for ElementsAttr.
+template <typename T>
+class ElementsAttrIterator
+    : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
+                                        std::random_access_iterator_tag, T,
+                                        std::ptrdiff_t, T, T> {
+public:
+  ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
+      : indexer(std::move(indexer)), index(dataIndex) {}
+
+  // Boilerplate iterator methods; equality and ordering only compare `index`.
+  ptrdiff_t operator-(const ElementsAttrIterator &rhs) const {
+    return index - rhs.index;
+  }
+  bool operator==(const ElementsAttrIterator &rhs) const {
+    return index == rhs.index;
+  }
+  bool operator<(const ElementsAttrIterator &rhs) const {
+    return index < rhs.index;
+  }
+  ElementsAttrIterator &operator+=(ptrdiff_t offset) {
+    index += offset;
+    return *this;
+  }
+  ElementsAttrIterator &operator-=(ptrdiff_t offset) {
+    index -= offset;
+    return *this;
+  }
+
+  /// Return (by value) the element at the current iterator position.
+  T operator*() const { return indexer.at<T>(index); }
+
+private:
+  ElementsAttrIndexer indexer;
+  ptrdiff_t index; // Flat element position within the attribute.
+};
+} // namespace detail
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Tablegen Interface Declarations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+/// Return an iterator to the first element of this attribute as type 'T'.
+template <typename T>
+auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
+  if (Optional<iterator<T>> iterator = try_value_begin<T>())
+    return std::move(*iterator);
+  llvm::errs() // Unsupported `T`: print a diagnostic, then llvm_unreachable.
+      << "ElementsAttr does not provide iteration facilities for type `"
+      << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n";
+  llvm_unreachable("invalid `T` for ElementsAttr::getValues");
+}
+template <typename T> // Like value_begin, but None if `T` is unsupported.
+auto ElementsAttr::try_value_begin() const
+    -> DefaultValueCheckT<T, Optional<iterator<T>>> {
+  FailureOr<detail::ElementsAttrIndexer> indexer =
+      getValuesImpl(TypeID::get<T>());
+  if (failed(indexer))
+    return llvm::None;
+  return iterator<T>(std::move(*indexer), 0);
+}
+} // namespace mlir
+
+#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H

diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
new file mode 100644
index 0000000000000..2df7603c004a9
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -0,0 +1,430 @@
+//===- BuiltinAttributeInterfaces.td - Attr interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of the ElementsAttr interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
+#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ElementsAttrInterface
+//===----------------------------------------------------------------------===//
+
+def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface is used for attributes that contain the constant elements of
+    a tensor or vector type. It allows for opaquely interacting with the
+    elements of the underlying attribute, and most importantly allows for
+    accessing the element values (including iteration) in any of the C++ data
+    types supported by the underlying attribute.
+
+    An attribute implementing this interface can expose the supported data types
+    in two steps:
+
+    * Define the set of iterable C++ data types:
+
+    An attribute may define the set of iterable types by providing a definition
+    of tuples `ContiguousIterableTypesT` and/or `NonContiguousIterableTypesT`.
+
+    -  `ContiguousIterableTypesT` should contain types which can be iterated
+       contiguously. A contiguous range is an array-like range, such as
+       ArrayRef, where all of the elements are laid out sequentially in memory.
+
+    -  `NonContiguousIterableTypesT` should contain types which cannot be
+       iterated contiguously. A non-contiguous range implies no contiguity,
+       and its elements may even be materialized when indexing, such as the
+       case for a mapped_range.
+
+    As an example, consider an attribute that only contains i64 elements, with
+    the elements being stored within an ArrayRef. This attribute could
+    potentially define the iterable types like so:
+
+    ```c++
+    using ContiguousIterableTypesT = std::tuple<uint64_t>;
+    using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>;
+    ```
+
+    * Provide an `iterator value_begin_impl(OverloadToken<T>) const` overload for
+      each iterable type.
+
+    These overloads should return an iterator to the start of the range for the
+    respective iterable type. Consider the example i64 elements attribute
+    described in the previous section. This attribute may define the
+    value_begin_impl overloads like so:
+
+    ```c++
+    /// Provide begin iterators for the various iterable types.
+    /// * uint64_t
+    auto value_begin_impl(OverloadToken<uint64_t>) const {
+      return getElements().begin();
+    }
+    /// * APInt
+    auto value_begin_impl(OverloadToken<llvm::APInt>) const {
+      return llvm::map_range(getElements(), [=](uint64_t value) {
+        return llvm::APInt(/*numBits=*/64, value);
+      }).begin();
+    }
+    /// * Attribute
+    auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
+      mlir::Type elementType = getType().getElementType();
+      return llvm::map_range(getElements(), [=](uint64_t value) {
+        return mlir::IntegerAttr::get(elementType,
+                                      llvm::APInt(/*numBits=*/64, value));
+      }).begin();
+    }
+    ```
+
+    After the above, ElementsAttr will now be able to iterate over elements
+    using each of the registered iterable data types:
+
+    ```c++
+    ElementsAttr attr = myI64ElementsAttr;
+
+    // We can access value ranges for the data types via `getValues<T>`.
+    for (uint64_t value : attr.getValues<uint64_t>())
+      ...;
+    for (llvm::APInt value : attr.getValues<llvm::APInt>())
+      ...;
+    for (mlir::IntegerAttr value : attr.getValues<mlir::IntegerAttr>())
+      ...;
+
+    // We can also access the value iterators directly.
+    auto it = attr.value_begin<uint64_t>(), e = attr.value_end<uint64_t>();
+    for (; it != e; ++it) {
+      uint64_t value = *it;
+      ...
+    }
+    ```
+
+    ElementsAttr also supports failable access to iterators and ranges. This
+    allows for safely checking if the attribute supports the data type, and can
+    also allow for code to have fast paths for native data types.
+
+    ```c++
+    // Using `tryGetValues<T>`, we can also safely handle when the attribute
+    // doesn't support the data type.
+    if (auto range = attr.tryGetValues<uint64_t>()) {
+      for (uint64_t value : *range)
+        ...;
+      return;
+    }
+
+    // We can also access the begin iterator safely, by using `try_value_begin`.
+    if (auto safeIt = attr.try_value_begin<uint64_t>()) {
+      auto it = *safeIt, e = attr.value_end<uint64_t>();
+      for (; it != e; ++it) {
+        uint64_t value = *it;
+        ...
+      }
+      return;
+    }
+    ```
+  }];
+  let methods = [
+    InterfaceMethod<[{
+      This method returns an opaque range indexer for the given elementID, which
+      corresponds to a desired C++ element data type. Returns the indexer if the
+      attribute supports the given data type, failure otherwise.
+    }],
+    "::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>", "getValuesImpl",
+    (ins "::mlir::TypeID":$elementID), [{}], /*defaultImplementation=*/[{
+      auto result = getValueImpl( // Try the contiguous iterable types first.
+        (typename ConcreteAttr::ContiguousIterableTypesT *)nullptr, elementID,
+        /*isContiguous=*/std::true_type());
+      if (succeeded(result))
+        return std::move(result);
+      // Otherwise, fall back to the non-contiguous iterable types.
+      return getValueImpl(
+        (typename ConcreteAttr::NonContiguousIterableTypesT *)nullptr,
+        elementID, /*isContiguous=*/std::false_type());
+    }]>,
+    InterfaceMethod<[{
+      Returns true if the attribute elements correspond to a splat, i.e. that
+      all elements of the attribute are the same value.
+    }], "bool", "isSplat", (ins), [{}], /*defaultImplementation=*/[{
+        // By default, only check for a single element splat.
+        return $_attr.getNumElements() == 1;
+    }]>
+  ];
+
+  string ElementsAttrInterfaceAccessors = [{
+    /// Return the attribute value at the given index. The index is expected to
+    /// refer to a valid element.
+    Attribute getValue(ArrayRef<uint64_t> index) const {
+      return getValue<Attribute>(index);
+    }
+    /// Return the value of type 'T' at the given index, where 'T' is a derived
+    /// Attribute type. Returns null if the element is not a 'T'.
+    template <typename T>
+    std::enable_if_t<!std::is_same<T, ::mlir::Attribute>::value &&
+                     std::is_base_of<::mlir::Attribute, T>::value, T>
+    getValue(ArrayRef<uint64_t> index) const {
+      return getValue(index).template dyn_cast_or_null<T>();
+    }
+    /// Return the value of type 'T' at the given index, where 'T' is either
+    /// Attribute itself or a non-Attribute C++ element type.
+    template <typename T>
+    std::enable_if_t<std::is_same<T, ::mlir::Attribute>::value ||
+                     !std::is_base_of<::mlir::Attribute, T>::value, T>
+    getValue(ArrayRef<uint64_t> index) const {
+      return getFlatValue<T>(getFlattenedIndex(index));
+    }
+    /// Return the number of elements held by this attribute.
+    int64_t size() const { return getNumElements(); }
+
+    /// Return if the attribute holds no elements.
+    bool empty() const { return size() == 0; }
+  }];
+
+  let extraTraitClassDeclaration = [{
+    // By default, no types are iterable.
+    using ContiguousIterableTypesT = std::tuple<>;
+    using NonContiguousIterableTypesT = std::tuple<>;
+
+    //===------------------------------------------------------------------===//
+    // Accessors
+    //===------------------------------------------------------------------===//
+
+    /// Return the element type of this ElementsAttr.
+    Type getElementType() const {
+      return ::mlir::ElementsAttr::getElementType($_attr);
+    }
+
+    /// Returns the number of elements held by this attribute.
+    int64_t getNumElements() const {
+      return ::mlir::ElementsAttr::getNumElements($_attr);
+    }
+
+    /// Return if the given 'index' refers to a valid element in this attribute.
+    bool isValidIndex(ArrayRef<uint64_t> index) const {
+      return ::mlir::ElementsAttr::isValidIndex($_attr, index);
+    }
+
+  protected:
+    /// Returns the 1-dimensional flattened row-major index from the given
+    /// multi-dimensional index.
+    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
+      return ::mlir::ElementsAttr::getFlattenedIndex($_attr, index);
+    }
+
+    //===------------------------------------------------------------------===//
+    // Value Iteration Internals
+    //===------------------------------------------------------------------===//
+  protected:
+    /// This class is used to allow specifying function overloads for different
+    /// types, without actually taking the types as parameters. This avoids the
+    /// need to build complicated SFINAE to select specific overloads.
+    template <typename T>
+    struct OverloadToken {};
+
+  private:
+    /// This function unpacks the types within a given tuple and then forwards
+    /// on to the unwrapped variant.
+    template <typename... Ts, typename IsContiguousT>
+    auto getValueImpl(std::tuple<Ts...> *, ::mlir::TypeID elementID,
+                      IsContiguousT isContiguous) const {
+      return getValueImpl<Ts...>(elementID, isContiguous);
+    }
+    /// Check to see if the given `elementID` matches the current type `T`. If
+    /// it does, build a value result using the current type. If it doesn't,
+    /// keep looking for the desired type.
+    template <typename T, typename... Ts, typename IsContiguousT>
+    auto getValueImpl(::mlir::TypeID elementID,
+                      IsContiguousT isContiguous) const {
+      if (::mlir::TypeID::get<T>() == elementID)
+        return buildValueResult<T>(isContiguous);
+      return getValueImpl<Ts...>(elementID, isContiguous);
+    }
+    /// Bottom out case for no matching type.
+    template <typename IsContiguousT>
+    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>
+    getValueImpl(::mlir::TypeID, IsContiguousT) const {
+      return failure();
+    }
+
+    /// Build an indexer for the given type `T`, which is represented via a
+    /// contiguous range.
+    template <typename T>
+    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
+        /*isContiguous*/std::true_type) const {
+      auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
+      return ::mlir::detail::ElementsAttrIndexer::contiguous(
+        $_attr.isSplat(), &*valueIt); // NOTE(review): derefs begin even if empty.
+    }
+    /// Build an indexer for the given type `T`, which is represented via a
+    /// non-contiguous range.
+    template <typename T>
+    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
+        /*isContiguous*/std::false_type) const {
+      auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
+      return ::mlir::detail::ElementsAttrIndexer::nonContiguous(
+        $_attr.isSplat(), valueIt);
+    }
+
+  public:
+    //===------------------------------------------------------------------===//
+    // Value Iteration
+    //===------------------------------------------------------------------===//
+
+    /// Return an iterator to the first element of this attribute as a value of
+    /// type `T`.
+    template <typename T>
+    auto value_begin() const {
+      return $_attr.value_begin_impl(OverloadToken<T>());
+    }
+
+    /// Return the elements of this attribute as values of type 'T'.
+    template <typename T>
+    auto getValues() const {
+      auto beginIt = $_attr.template value_begin<T>();
+      return llvm::make_range(beginIt, std::next(beginIt, size()));
+    }
+    /// Return the value at the given flattened index.
+    template <typename T> T getFlatValue(uint64_t index) const {
+      return *std::next($_attr.template value_begin<T>(), index);
+    }
+  }] # ElementsAttrInterfaceAccessors;
+
+  let extraClassDeclaration = [{
+    template <typename T>
+    using iterator = detail::ElementsAttrIterator<T>;
+    template <typename T>
+    using iterator_range = llvm::iterator_range<iterator<T>>;
+
+    //===------------------------------------------------------------------===//
+    // Accessors
+    //===------------------------------------------------------------------===//
+
+    /// Return the type of this attribute.
+    ShapedType getType() const;
+
+    /// Return the element type of this ElementsAttr.
+    Type getElementType() const { return getElementType(*this); }
+    static Type getElementType(Attribute elementsAttr);
+
+    /// Return if the given 'index' refers to a valid element in this attribute.
+    bool isValidIndex(ArrayRef<uint64_t> index) const {
+      return isValidIndex(*this, index);
+    }
+    static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
+    static bool isValidIndex(Attribute elementsAttr, ArrayRef<uint64_t> index);
+
+    /// Return the 1 dimensional flattened row-major index from the given
+    /// multi-dimensional index.
+    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
+      return getFlattenedIndex(*this, index);
+    }
+    static uint64_t getFlattenedIndex(Attribute elementsAttr,
+                                      ArrayRef<uint64_t> index);
+
+    /// Returns the number of elements held by this attribute.
+    int64_t getNumElements() const { return getNumElements(*this); }
+    static int64_t getNumElements(Attribute elementsAttr);
+
+    //===------------------------------------------------------------------===//
+    // Value Iteration
+    //===------------------------------------------------------------------===//
+
+    template <typename T>
+    using DerivedAttrValueCheckT =
+        typename std::enable_if_t<std::is_base_of<Attribute, T>::value &&
+                                  !std::is_same<Attribute, T>::value>;
+    template <typename T, typename ResultT>
+    using DefaultValueCheckT =
+        typename std::enable_if_t<std::is_same<Attribute, T>::value ||
+                                  !std::is_base_of<Attribute, T>::value,
+                                  ResultT>;
+
+    /// Return the element of this attribute at the given index as a value of
+    /// type 'T'.
+    template <typename T>
+    T getFlatValue(uint64_t index) const {
+      return *std::next(value_begin<T>(), index);
+    }
+
+    /// Return the splat value for this attribute. This asserts that the
+    /// attribute corresponds to a splat.
+    template <typename T>
+    T getSplatValue() const {
+      assert(isSplat() && "expected splat attribute");
+      return *value_begin<T>();
+    }
+
+    /// Return the elements of this attribute as values of type 'T'.
+    template <typename T>
+    DefaultValueCheckT<T, iterator_range<T>> getValues() const {
+      return iterator_range<T>(value_begin<T>(), value_end<T>());
+    }
+    template <typename T>
+    DefaultValueCheckT<T, iterator<T>> value_begin() const; // Out of line.
+    template <typename T>
+    DefaultValueCheckT<T, iterator<T>> value_end() const {
+      return iterator<T>({}, size()); // Default indexer; end is not dereferenced.
+    }
+
+    /// Return the held element values as a range of T, where T is a derived
+    /// attribute type.
+    template <typename T>
+    using DerivedAttrValueIterator =
+      llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
+    template <typename T>
+    using DerivedAttrValueIteratorRange =
+      llvm::iterator_range<DerivedAttrValueIterator<T>>;
+    template <typename T, typename = DerivedAttrValueCheckT<T>>
+    DerivedAttrValueIteratorRange<T> getValues() const {
+      auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
+      return llvm::map_range(getValues<Attribute>(),
+                             static_cast<T (*)(Attribute)>(castFn));
+    }
+    template <typename T, typename = DerivedAttrValueCheckT<T>>
+    DerivedAttrValueIterator<T> value_begin() const {
+      return getValues<T>().begin();
+    }
+    template <typename T, typename = DerivedAttrValueCheckT<T>>
+    DerivedAttrValueIterator<T> value_end() const {
+      return {value_end<Attribute>(), nullptr}; // Null mapper: never applied.
+    }
+
+    //===------------------------------------------------------------------===//
+    // Failable Value Iteration
+
+    /// If this attribute supports iterating over element values of type `T`,
+    /// return the iterable range. Otherwise, return llvm::None.
+    template <typename T>
+    DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
+      if (Optional<iterator<T>> beginIt = try_value_begin<T>())
+        return iterator_range<T>(*beginIt, value_end<T>());
+      return llvm::None;
+    }
+    template <typename T>
+    DefaultValueCheckT<T, Optional<iterator<T>>> try_value_begin() const;
+
+    /// If this attribute supports iterating over values of type `Attribute`,
+    /// return them cast to the derived attribute `T`; else, llvm::None.
+    template <typename T, typename = DerivedAttrValueCheckT<T>>
+    Optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
+      auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
+      if (auto values = tryGetValues<Attribute>())
+        return llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn));
+      return llvm::None;
+    }
+    template <typename T, typename = DerivedAttrValueCheckT<T>>
+    Optional<DerivedAttrValueIterator<T>> try_value_begin() const {
+      if (auto values = tryGetValues<T>())
+        return values->begin();
+      return llvm::None;
+    }
+  }] # ElementsAttrInterfaceAccessors;
+}
+
+#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 6edd56b09b55e..6938448de1ea7 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -9,7 +9,8 @@
 #ifndef MLIR_IR_BUILTINATTRIBUTES_H
 #define MLIR_IR_BUILTINATTRIBUTES_H
 
-#include "SubElementInterfaces.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/SubElementInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 #include <complex>
@@ -31,99 +32,8 @@ class ShapedType;
 //===----------------------------------------------------------------------===//
 
 namespace detail {
-template <typename T>
-class ElementsAttrIterator;
-template <typename T>
-class ElementsAttrRange;
-} // namespace detail
-
-/// A base attribute that represents a reference to a static shaped tensor or
-/// vector constant.
-class ElementsAttr : public Attribute {
-public:
-  using Attribute::Attribute;
-  template <typename T>
-  using iterator = detail::ElementsAttrIterator<T>;
-  template <typename T>
-  using iterator_range = detail::ElementsAttrRange<T>;
-
-  /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
-  /// with static shape.
-  ShapedType getType() const;
-
-  /// Return the element type of this ElementsAttr.
-  Type getElementType() const;
-
-  /// Return the value at the given index. The index is expected to refer to a
-  /// valid element.
-  Attribute getValue(ArrayRef<uint64_t> index) const;
-
-  /// Return the value of type 'T' at the given index, where 'T' corresponds to
-  /// an Attribute type.
-  template <typename T>
-  T getValue(ArrayRef<uint64_t> index) const {
-    return getValue(index).template cast<T>();
-  }
-
-  /// Return the elements of this attribute as a value of type 'T'. Note:
-  /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
-  /// iteration.
-  template <typename T> iterator_range<T> getValues() const;
-  template <typename T> iterator<T> value_begin() const;
-  template <typename T> iterator<T> value_end() const;
-
-  /// Return if the given 'index' refers to a valid element in this attribute.
-  bool isValidIndex(ArrayRef<uint64_t> index) const;
-  static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
-
-  /// Returns the 1-dimensional flattened row-major index from the given
-  /// multi-dimensional index.
-  uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const;
-  static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index);
-
-  /// Returns the number of elements held by this attribute.
-  int64_t getNumElements() const;
-
-  /// Returns the number of elements held by this attribute.
-  int64_t size() const { return getNumElements(); }
-
-  /// Returns if the number of elements held by this attribute is 0.
-  bool empty() const { return size() == 0; }
-
-  /// Generates a new ElementsAttr by mapping each int value to a new
-  /// underlying APInt. The new values can represent either an integer or float.
-  /// This ElementsAttr should contain integers.
-  ElementsAttr mapValues(Type newElementType,
-                         function_ref<APInt(const APInt &)> mapping) const;
-
-  /// Generates a new ElementsAttr by mapping each float value to a new
-  /// underlying APInt. The new values can represent either an integer or float.
-  /// This ElementsAttr should contain floats.
-  ElementsAttr mapValues(Type newElementType,
-                         function_ref<APInt(const APFloat &)> mapping) const;
-
-  /// Method for support type inquiry through isa, cast and dyn_cast.
-  static bool classof(Attribute attr);
-};
-
-namespace detail {
-/// DenseElementsAttr data is aligned to uint64_t, so this traits class is
-/// necessary to interop with PointerIntPair.
-class DenseElementDataPointerTypeTraits {
-public:
-  static inline const void *getAsVoidPointer(const char *ptr) { return ptr; }
-  static inline const char *getFromVoidPointer(const void *ptr) {
-    return static_cast<const char *>(ptr);
-  }
-
-  // Note: We could steal more bits if the need arises.
-  static constexpr int NumLowBitsAvailable = 1;
-};
-
 /// Pair of raw pointer and a boolean flag of whether the pointer holds a splat,
-using DenseIterPtrAndSplat =
-    llvm::PointerIntPair<const char *, 1, bool,
-                         DenseElementDataPointerTypeTraits>;
+using DenseIterPtrAndSplat = std::pair<const char *, bool>;
 
 /// Impl iterator for indexed DenseElementsAttr iterators that records a data
 /// pointer and data index that is adjusted for the case of a splat attribute.
@@ -142,12 +52,12 @@ class DenseElementIndexedIteratorImpl
   /// Return the current index for this iterator, adjusted for the case of a
   /// splat.
   ptrdiff_t getDataIndex() const {
-    bool isSplat = this->base.getInt();
+    bool isSplat = this->base.second;
     return isSplat ? 0 : this->index;
   }
 
   /// Return the data base pointer.
-  const char *getData() const { return this->base.getPointer(); }
+  const char *getData() const { return this->base.first; }
 };
 
 /// Type trait detector that checks if a given type T is a complex type.
@@ -159,9 +69,14 @@ struct is_complex_t<std::complex<T>> : public std::true_type {};
 
 /// An attribute that represents a reference to a dense vector or tensor object.
 ///
-class DenseElementsAttr : public ElementsAttr {
+class DenseElementsAttr : public Attribute {
 public:
-  using ElementsAttr::ElementsAttr;
+  using Attribute::Attribute;
+
+  /// Allow implicit conversion to ElementsAttr.
+  operator ElementsAttr() const {
+    return *this ? cast<ElementsAttr>() : nullptr;
+  }
 
   /// Type trait used to check if the given type T is a potentially valid C++
   /// floating point type that can be used to access the underlying element
@@ -440,7 +355,7 @@ class DenseElementsAttr : public ElementsAttr {
   template <typename T>
   T getValue(ArrayRef<uint64_t> index) const {
     // Skip to the element corresponding to the flattened index.
-    return getFlatValue<T>(getFlattenedIndex(index));
+    return getFlatValue<T>(ElementsAttr::getFlattenedIndex(*this, index));
   }
   /// Return the value at the given flattened index.
   template <typename T> T getFlatValue(uint64_t index) const {
@@ -678,6 +593,22 @@ class DenseElementsAttr : public ElementsAttr {
   /// Return the raw StringRef data held by this attribute.
   ArrayRef<StringRef> getRawStringData() const;
 
+  /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
+  /// with static shape.
+  ShapedType getType() const;
+
+  /// Return the element type of this DenseElementsAttr.
+  Type getElementType() const;
+
+  /// Returns the number of elements held by this attribute.
+  int64_t getNumElements() const;
+
+  /// Returns the number of elements held by this attribute.
+  int64_t size() const { return getNumElements(); }
+
+  /// Returns if the number of elements held by this attribute is 0.
+  bool empty() const { return size() == 0; }
+
   //===--------------------------------------------------------------------===//
   // Mutation Utilities
   //===--------------------------------------------------------------------===//
@@ -761,7 +692,6 @@ class SplatElementsAttr : public DenseElementsAttr {
     return denseAttr && denseAttr.isSplat();
   }
 };
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -954,159 +884,6 @@ template <typename T>
 auto SparseElementsAttr::value_end() const -> iterator<T> {
   return getValues<T>().end();
 }
-
-namespace detail {
-/// This class represents a general iterator over the values of an ElementsAttr.
-/// It supports all subclasses aside from OpaqueElementsAttr.
-template <typename T>
-class ElementsAttrIterator
-    : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
-                                        std::random_access_iterator_tag, T,
-                                        std::ptrdiff_t, T, T> {
-  // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
-  // inside of a conversion operator.
-  using DenseIteratorT = typename std::enable_if<
-      true, decltype(std::declval<DenseElementsAttr>().value_begin<T>())>::type;
-  using SparseIteratorT = SparseElementsAttr::iterator<T>;
-
-  /// A union containing the specific iterators for each derived attribute kind.
-  union Iterator {
-    Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {}
-    Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {}
-    Iterator() {}
-    ~Iterator() {}
-
-    operator const DenseIteratorT &() const { return denseIt; }
-    operator const SparseIteratorT &() const { return sparseIt; }
-    operator DenseIteratorT &() { return denseIt; }
-    operator SparseIteratorT &() { return sparseIt; }
-
-    /// An instance of a dense elements iterator.
-    DenseIteratorT denseIt;
-    /// An instance of a sparse elements iterator.
-    SparseIteratorT sparseIt;
-  };
-
-  /// Utility method to process a functor on each of the internal iterator
-  /// types.
-  template <typename RetT, template <typename> class ProcessFn,
-            typename... Args>
-  RetT process(Args &...args) const {
-    if (attr.isa<DenseElementsAttr>())
-      return ProcessFn<DenseIteratorT>()(args...);
-    if (attr.isa<SparseElementsAttr>())
-      return ProcessFn<SparseIteratorT>()(args...);
-    llvm_unreachable("unexpected attribute kind");
-  }
-
-  /// Utility functors used to generically implement the iterators methods.
-  template <typename ItT>
-  struct PlusAssign {
-    void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
-  };
-  template <typename ItT>
-  struct Minus {
-    ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
-  };
-  template <typename ItT>
-  struct MinusAssign {
-    void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
-  };
-  template <typename ItT>
-  struct Dereference {
-    T operator()(ItT &it) { return *it; }
-  };
-  template <typename ItT>
-  struct ConstructIter {
-    void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
-  };
-  template <typename ItT>
-  struct DestructIter {
-    void operator()(ItT &it) { it.~ItT(); }
-  };
-
-public:
-  ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
-    process<void, ConstructIter>(it, rhs.it);
-  }
-  ~ElementsAttrIterator() { process<void, DestructIter>(it); }
-
-  /// Methods necessary to support random access iteration.
-  ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
-    assert(attr == rhs.attr && "incompatible iterators");
-    return process<ptrdiff_t, Minus>(it, rhs.it);
-  }
-  bool operator==(const ElementsAttrIterator<T> &rhs) const {
-    return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
-  }
-  bool operator<(const ElementsAttrIterator<T> &rhs) const {
-    assert(attr == rhs.attr && "incompatible iterators");
-    return process<bool, std::less>(it, rhs.it);
-  }
-  ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
-    process<void, PlusAssign>(it, offset);
-    return *this;
-  }
-  ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
-    process<void, MinusAssign>(it, offset);
-    return *this;
-  }
-
-  /// Dereference the iterator at the current index.
-  T operator*() { return process<T, Dereference>(it); }
-
-private:
-  template <typename IteratorT>
-  ElementsAttrIterator(Attribute attr, IteratorT &&it)
-      : attr(attr), it(std::forward<IteratorT>(it)) {}
-
-  /// Allow accessing the constructor.
-  friend ElementsAttr;
-
-  /// The parent elements attribute.
-  Attribute attr;
-
-  /// A union containing the specific iterators for each derived kind.
-  Iterator it;
-};
-
-template <typename T>
-class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
-  using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
-};
-} // namespace detail
-
-/// Return the elements of this attribute as a value of type 'T'.
-template <typename T>
-auto ElementsAttr::getValues() const -> iterator_range<T> {
-  if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
-    auto values = denseAttr.getValues<T>();
-    return {iterator<T>(*this, values.begin()),
-            iterator<T>(*this, values.end())};
-  }
-  if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
-    auto values = sparseAttr.getValues<T>();
-    return {iterator<T>(*this, values.begin()),
-            iterator<T>(*this, values.end())};
-  }
-  llvm_unreachable("unexpected attribute kind");
-}
-
-template <typename T> auto ElementsAttr::value_begin() const -> iterator<T> {
-  if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
-    return iterator<T>(*this, denseAttr.value_begin<T>());
-  if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
-    return iterator<T>(*this, sparseAttr.value_begin<T>());
-  llvm_unreachable("unexpected attribute kind");
-}
-template <typename T> auto ElementsAttr::value_end() const -> iterator<T> {
-  if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
-    return iterator<T>(*this, denseAttr.value_end<T>());
-  if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
-    return iterator<T>(*this, sparseAttr.value_end<T>());
-  llvm_unreachable("unexpected attribute kind");
-}
-
 } // end namespace mlir.
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index e39d56f146ef1..0d3ead2383722 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -15,6 +15,7 @@
 #define BUILTIN_ATTRIBUTES
 
 include "mlir/IR/BuiltinDialect.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/SubElementInterfaces.td"
 
 // TODO: Currently the attributes defined in this file are prefixed with
@@ -136,8 +137,9 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
 // DenseIntOrFPElementsAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_DenseIntOrFPElementsAttr
-    : Builtin_Attr<"DenseIntOrFPElements", /*traits=*/[], "DenseElementsAttr"> {
+def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
+    "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
+  > {
   let summary = "An Attribute containing a dense multi-dimensional array of "
                 "integer or floating-point values";
   let description = [{
@@ -165,6 +167,42 @@ def Builtin_DenseIntOrFPElementsAttr
   let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
                         "ArrayRef<char>":$rawData);
   let extraClassDeclaration = [{
+    using DenseElementsAttr::empty;
+    using DenseElementsAttr::getFlatValue;
+    using DenseElementsAttr::getNumElements;
+    using DenseElementsAttr::getValue;
+    using DenseElementsAttr::getValues;
+    using DenseElementsAttr::isSplat;
+    using DenseElementsAttr::size;
+    using DenseElementsAttr::value_begin;
+
+    /// The set of data types that can be iterated by this attribute.
+    using ContiguousIterableTypesT = std::tuple<
+      // Integer types.
+      uint8_t, uint16_t, uint32_t, uint64_t,
+      int8_t, int16_t, int32_t, int64_t,
+      short, unsigned short, int, unsigned, long, unsigned long,
+      std::complex<uint8_t>, std::complex<uint16_t>, std::complex<uint32_t>,
+      std::complex<uint64_t>,
+      std::complex<int8_t>, std::complex<int16_t>, std::complex<int32_t>,
+      std::complex<int64_t>,
+      // Float types.
+      float, double, std::complex<float>, std::complex<double>
+    >;
+    using NonContiguousIterableTypesT = std::tuple<
+      Attribute,
+      // Integer types.
+      APInt, bool, std::complex<APInt>,
+      // Float types.
+      APFloat, std::complex<APFloat>
+    >;
+
+    /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+    template <typename T>
+    auto value_begin_impl(OverloadToken<T>) const {
+      return value_begin<T>();
+    }
+
     /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of
     /// the elements of `inRawData` has `type`. If `inRawData` is little endian
     /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is
@@ -231,8 +269,9 @@ def Builtin_DenseIntOrFPElementsAttr
 // DenseStringElementsAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_DenseStringElementsAttr
-    : Builtin_Attr<"DenseStringElements", /*traits=*/[], "DenseElementsAttr"> {
+def Builtin_DenseStringElementsAttr : Builtin_Attr<
+    "DenseStringElements", [ElementsAttrInterface], "DenseElementsAttr"
+  > {
   let summary = "An Attribute containing a dense multi-dimensional array of "
                 "strings";
   let description = [{
@@ -267,6 +306,25 @@ def Builtin_DenseStringElementsAttr
     }]>,
   ];
   let extraClassDeclaration = [{
+    using DenseElementsAttr::empty;
+    using DenseElementsAttr::getFlatValue;
+    using DenseElementsAttr::getNumElements;
+    using DenseElementsAttr::getValue;
+    using DenseElementsAttr::getValues;
+    using DenseElementsAttr::isSplat;
+    using DenseElementsAttr::size;
+    using DenseElementsAttr::value_begin;
+
+    /// The set of data types that can be iterated by this attribute.
+    using ContiguousIterableTypesT = std::tuple<StringRef>;
+    using NonContiguousIterableTypesT = std::tuple<Attribute>;
+
+    /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+    template <typename T>
+    auto value_begin_impl(OverloadToken<T>) const {
+      return value_begin<T>();
+    }
+
   protected:
     friend DenseElementsAttr;
 
@@ -594,8 +652,9 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> {
 // OpaqueElementsAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_OpaqueElementsAttr
-    : Builtin_Attr<"OpaqueElements", /*traits=*/[], "ElementsAttr"> {
+def Builtin_OpaqueElementsAttr : Builtin_Attr<
+    "OpaqueElements", [ElementsAttrInterface]
+  > {
   let summary = "An opaque representation of a multi-dimensional array";
   let description = [{
     Syntax:
@@ -650,7 +709,6 @@ def Builtin_OpaqueElementsAttr
     /// Returns false if decoding is successful. If not, returns true and leaves
     /// 'result' argument unspecified.
     bool decode(ElementsAttr &result);
-
   }];
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
@@ -660,8 +718,9 @@ def Builtin_OpaqueElementsAttr
 // SparseElementsAttr
 //===----------------------------------------------------------------------===//
 
-def Builtin_SparseElementsAttr
-    : Builtin_Attr<"SparseElements", /*traits=*/[], "ElementsAttr"> {
+def Builtin_SparseElementsAttr : Builtin_Attr<
+    "SparseElements", [ElementsAttrInterface]
+  > {
   let summary = "An opaque representation of a multi-dimensional array";
   let description = [{
     Syntax:
@@ -712,6 +771,33 @@ def Builtin_SparseElementsAttr
     }]>,
   ];
   let extraClassDeclaration = [{
+    /// The set of data types that can be iterated by this attribute.
+    // FIXME: Realistically, SparseElementsAttr could use ElementsAttr for the
+    // value storage. This would mean dispatching to `values` when accessing
+    // values. For now, we just add the types that can be iterated by
+    // DenseElementsAttr.
+    using NonContiguousIterableTypesT = std::tuple<
+      Attribute,
+      // Integer types.
+      APInt, bool, uint8_t, uint16_t, uint32_t, uint64_t,
+      int8_t, int16_t, int32_t, int64_t,
+      short, unsigned short, int, unsigned, long, unsigned long,
+      std::complex<APInt>, std::complex<uint8_t>, std::complex<uint16_t>,
+      std::complex<uint32_t>, std::complex<uint64_t>, std::complex<int8_t>,
+      std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>,
+      // Float types.
+      APFloat, float, double,
+      std::complex<APFloat>, std::complex<float>, std::complex<double>,
+      // String types.
+      StringRef
+    >;
+
+    /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+    template <typename T>
+    auto value_begin_impl(OverloadToken<T>) const {
+      return value_begin<T>();
+    }
+
     template <typename T>
     using iterator =
         llvm::mapped_iterator<typename decltype(llvm::seq<ptrdiff_t>(0, 0))::iterator,

diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 2757f3d6ead59..59305a594df2f 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -7,6 +7,11 @@ mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(MLIRBuiltinAttributesIncGen)
 
+set(LLVM_TARGET_DEFINITIONS BuiltinAttributeInterfaces.td)
+mlir_tablegen(BuiltinAttributeInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(BuiltinAttributeInterfaces.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRBuiltinAttributeInterfacesIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td)
 mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
 mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs)

diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index e5d526b1ae515..2c474fac59d87 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -93,6 +93,7 @@ class Interface : public BaseType {
       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
     assert((!t || impl) && "expected value to provide interface instance");
   }
+  Interface(std::nullptr_t) : BaseType(ValueT()), impl(nullptr) {}
 
   /// Construct an interface instance from a type that implements this
   /// interface's trait.

diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
new file mode 100644
index 0000000000000..6bfa1ee6633ab
--- /dev/null
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -0,0 +1,74 @@
+//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr
+//===----------------------------------------------------------------------===//
+
+ShapedType ElementsAttr::getType() const {
+  return Attribute::getType().cast<ShapedType>();
+}
+
+Type ElementsAttr::getElementType(Attribute elementsAttr) {
+  return elementsAttr.getType().cast<ShapedType>().getElementType();
+}
+
+int64_t ElementsAttr::getNumElements(Attribute elementsAttr) {
+  return elementsAttr.getType().cast<ShapedType>().getNumElements();
+}
+
+bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
+  // Verify that the rank of the indices matches the held type.
+  int64_t rank = type.getRank();
+  if (rank == 0 && index.size() == 1 && index[0] == 0)
+    return true;
+  if (rank != static_cast<int64_t>(index.size()))
+    return false;
+
+  // Verify that all of the indices are within the shape dimensions.
+  ArrayRef<int64_t> shape = type.getShape();
+  return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
+    int64_t dim = static_cast<int64_t>(index[i]);
+    return 0 <= dim && dim < shape[i];
+  });
+}
+bool ElementsAttr::isValidIndex(Attribute elementsAttr,
+                                ArrayRef<uint64_t> index) {
+  return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index);
+}
+
+uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr,
+                                         ArrayRef<uint64_t> index) {
+  ShapedType type = elementsAttr.getType().cast<ShapedType>();
+  assert(isValidIndex(type, index) && "expected valid multi-dimensional index");
+
+  // Reduce the provided multidimensional index into a flattended 1D row-major
+  // index.
+  auto rank = type.getRank();
+  auto shape = type.getShape();
+  uint64_t valueIndex = 0;
+  uint64_t dimMultiplier = 1;
+  for (int i = rank - 1; i >= 0; --i) {
+    valueIndex += index[i] * dimMultiplier;
+    dimMultiplier *= shape[i];
+  }
+  return valueIndex;
+}

diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index a3c7fb0af9293..e03bac1d3d5c2 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -382,92 +382,6 @@ LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// ElementsAttr
-//===----------------------------------------------------------------------===//
-
-ShapedType ElementsAttr::getType() const {
-  return Attribute::getType().cast<ShapedType>();
-}
-
-Type ElementsAttr::getElementType() const { return getType().getElementType(); }
-
-/// Returns the number of elements held by this attribute.
-int64_t ElementsAttr::getNumElements() const {
-  return getType().getNumElements();
-}
-
-/// Return the value at the given index. If index does not refer to a valid
-/// element, then a null attribute is returned.
-Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
-  if (auto denseAttr = dyn_cast<DenseElementsAttr>())
-    return denseAttr.getValue(index);
-  if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
-    return opaqueAttr.getValue(index);
-  return cast<SparseElementsAttr>().getValue(index);
-}
-
-bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
-  return isValidIndex(getType(), index);
-}
-bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
-  // Verify that the rank of the indices matches the held type.
-  int64_t rank = type.getRank();
-  if (rank == 0 && index.size() == 1 && index[0] == 0)
-    return true;
-  if (rank != static_cast<int64_t>(index.size()))
-    return false;
-
-  // Verify that all of the indices are within the shape dimensions.
-  ArrayRef<int64_t> shape = type.getShape();
-  return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
-    int64_t dim = static_cast<int64_t>(index[i]);
-    return 0 <= dim && dim < shape[i];
-  });
-}
-
-uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
-  return getFlattenedIndex(getType(), index);
-}
-uint64_t ElementsAttr::getFlattenedIndex(ShapedType type,
-                                         ArrayRef<uint64_t> index) {
-  assert(isValidIndex(type, index) && "expected valid multi-dimensional index");
-
-  // Reduce the provided multidimensional index into a flattended 1D row-major
-  // index.
-  auto rank = type.getRank();
-  auto shape = type.getShape();
-  uint64_t valueIndex = 0;
-  uint64_t dimMultiplier = 1;
-  for (int i = rank - 1; i >= 0; --i) {
-    valueIndex += index[i] * dimMultiplier;
-    dimMultiplier *= shape[i];
-  }
-  return valueIndex;
-}
-
-ElementsAttr
-ElementsAttr::mapValues(Type newElementType,
-                        function_ref<APInt(const APInt &)> mapping) const {
-  if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
-    return intOrFpAttr.mapValues(newElementType, mapping);
-  llvm_unreachable("unsupported ElementsAttr subtype");
-}
-
-ElementsAttr
-ElementsAttr::mapValues(Type newElementType,
-                        function_ref<APInt(const APFloat &)> mapping) const {
-  if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
-    return intOrFpAttr.mapValues(newElementType, mapping);
-  llvm_unreachable("unsupported ElementsAttr subtype");
-}
-
-/// Method for support type inquiry through isa, cast and dyn_cast.
-bool ElementsAttr::classof(Attribute attr) {
-  return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
-                  OpaqueElementsAttr, SparseElementsAttr>();
-}
-
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr Utilities
 //===----------------------------------------------------------------------===//
@@ -1065,6 +979,18 @@ DenseElementsAttr DenseElementsAttr::mapValues(
   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
 }
 
+ShapedType DenseElementsAttr::getType() const {
+  return Attribute::getType().cast<ShapedType>();
+}
+
+Type DenseElementsAttr::getElementType() const {
+  return getType().getElementType();
+}
+
+int64_t DenseElementsAttr::getNumElements() const {
+  return getType().getNumElements();
+}
+
 //===----------------------------------------------------------------------===//
 // DenseIntOrFPElementsAttr
 //===----------------------------------------------------------------------===//
@@ -1431,7 +1357,7 @@ SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   // Verify indices shape.
   size_t rank = type.getRank(), indicesRank = indicesType.getRank();
   if (indicesRank == 2) {
-    if (indicesType.getDimSize(1) != rank)
+    if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
       return emitShapeError();
   } else if (indicesRank != 1 || rank != 1) {
     return emitShapeError();

diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3f5e59d13b085..53326ad2ec648 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRIR
   Attributes.cpp
   Block.cpp
   Builders.cpp
+  BuiltinAttributeInterfaces.cpp
   BuiltinAttributes.cpp
   BuiltinDialect.cpp
   BuiltinTypes.cpp
@@ -36,6 +37,7 @@ add_mlir_library(MLIRIR
 
   DEPENDS
   MLIRBuiltinAttributesIncGen
+  MLIRBuiltinAttributeInterfacesIncGen
   MLIRBuiltinDialectIncGen
   MLIRBuiltinLocationAttributesIncGen
   MLIRBuiltinOpsIncGen

diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
new file mode 100644
index 0000000000000..a9b00d2339b02
--- /dev/null
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -test-elements-attr-interface -verify-diagnostics
+
+// This test contains various `ElementsAttr` attributes, and tests the support
+// for iterating the values of these attributes using various native C++ types.
+// This tests that the abstract iteration of ElementsAttr works properly, and
+// is properly failable when necessary.
+
+// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
+// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
+// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
+std.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64>
+
+// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
+// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
+// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
+std.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
+
+// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}}
+// expected-error@below {{Test iterating `APInt`: unable to iterate type}}
+// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}}
+std.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64>

diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 2dc583a8cfd03..8e36f634fdc92 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -15,6 +15,7 @@
 
 // To get the test dialect definition.
 include "TestOps.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
 
 // All of the attributes will extend this class.
 class Test_Attr<string name, list<Trait> traits = []>
@@ -63,4 +64,41 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
   let parameters = (ins );
 }
 
+// Test support for ElementsAttrInterface.
+//
+// `#test.i64_elements<[...]>` stores a flat list of `uint64_t` values together
+// with its shaped self type, and exposes those values through the
+// ElementsAttr interface as `uint64_t` (read directly from storage) and as
+// `mlir::Attribute` / `llvm::APInt` (built lazily, one element at a time).
+def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
+    ElementsAttrInterface
+  ]> {
+  let mnemonic = "i64_elements";
+  // `$type` is the attribute's own (self) type; `$elements` holds the raw
+  // element values.
+  let parameters = (ins
+    AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
+    ArrayRefParameter<"uint64_t">:$elements
+  );
+  let extraClassDeclaration = [{
+    /// The set of data types that can be iterated by this attribute.
+    /// `uint64_t` is served straight from the stored element array; the
+    /// remaining types are produced by mapping over that array.
+    using ContiguousIterableTypesT = std::tuple<uint64_t>;
+    using NonContiguousIterableTypesT = std::tuple<mlir::Attribute, llvm::APInt>;
+
+    /// Provide begin iterators for the various iterable types.
+    // * uint64_t: iterate the stored elements in place.
+    auto value_begin_impl(OverloadToken<uint64_t>) const {
+      return getElements().begin();
+    }
+    // * Attribute: wrap each element in an IntegerAttr whose type is the
+    //   element type of this attribute's shaped type.
+    auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
+      mlir::Type elementType = getType().getElementType();
+      return llvm::map_range(getElements(), [=](uint64_t value) {
+        return mlir::IntegerAttr::get(elementType,
+                                      llvm::APInt(/*numBits=*/64, value));
+      }).begin();
+    }
+    // * APInt: build a 64-bit APInt from each stored element.
+    auto value_begin_impl(OverloadToken<llvm::APInt>) const {
+      return llvm::map_range(getElements(), [=](uint64_t value) {
+        return llvm::APInt(/*numBits=*/64, value);
+      }).begin();
+    }
+  }];
+  // Declare a custom `verify` hook; its definition checks the shape/element
+  // count invariants (see TestAttributes.cpp).
+  let genVerifyDecl = 1;
+}
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 94b9ea8429944..3cc2c81fee486 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -89,6 +89,48 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const {
   printer << "]>";
 }
 
+//===----------------------------------------------------------------------===//
+// TestI64ElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Parses `i64_elements<[v0, v1, ...]>`, optionally followed by `: type`
+/// inside the angle brackets (the form emitted by `print`). If no inline type
+/// is given, the `type` supplied by the enclosing context is used. Returns a
+/// null Attribute on failure, after a diagnostic has been emitted.
+Attribute TestI64ElementsAttr::parse(MLIRContext *context,
+                                     DialectAsmParser &parser, Type type) {
+  SmallVector<uint64_t> elements;
+  if (parser.parseLess() || parser.parseLSquare())
+    return Attribute();
+
+  // An empty list `[]` is allowed; otherwise require a comma-separated list of
+  // integers. `parseInteger` reports malformed/missing integers as errors,
+  // whereas dereferencing `parseOptionalInteger`'s result when no integer is
+  // present is invalid.
+  if (failed(parser.parseOptionalRSquare())) {
+    do {
+      uint64_t intVal;
+      if (parser.parseInteger(intVal))
+        return Attribute();
+      elements.push_back(intVal);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (parser.parseRSquare())
+      return Attribute();
+  }
+
+  // Accept the `: type` suffix produced by `print` so the attribute
+  // round-trips through the printer and parser.
+  if (succeeded(parser.parseOptionalColon()) && parser.parseType(type))
+    return Attribute();
+  if (parser.parseGreater())
+    return Attribute();
+
+  // Diagnose a missing or non-shaped type instead of asserting in `cast`.
+  auto shapedType = type.dyn_cast_or_null<ShapedType>();
+  if (!shapedType) {
+    parser.emitError(parser.getNameLoc(), "expected a shaped type");
+    return Attribute();
+  }
+  return parser.getChecked<TestI64ElementsAttr>(context, shapedType, elements);
+}
+
+/// Prints the attribute as `i64_elements<[v0, v1, ...] : type>`.
+void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const {
+  ArrayRef<uint64_t> values = getElements();
+  printer << "i64_elements<[";
+  llvm::interleave(values, printer, ", ");
+  printer << "] : ";
+  printer << getType();
+  printer << ">";
+}
+
+/// Verifies that `type` is a statically shaped, rank-1 shaped type with a
+/// signless 64-bit integer element type, and that its element count matches
+/// the number of provided `elements`.
+LogicalResult
+TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                            ShapedType type, ArrayRef<uint64_t> elements) {
+  // Validate the shape before querying it: `getRank` requires a ranked type
+  // and `getNumElements` requires a static shape, so calling either on an
+  // unranked or dynamically shaped type would assert instead of diagnosing.
+  if (!type.hasRank() || type.getRank() != 1 || !type.hasStaticShape() ||
+      !type.getElementType().isSignlessInteger(64))
+    return emitError() << "expected single rank 64-bit shape type, but got: "
+                       << type;
+  if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
+    return emitError()
+           << "number of elements does not match the provided shape type, got: "
+           << elements.size() << ", but expected: " << type.getNumElements();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 8fc9a22e6c740..0b5833979ac94 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
+  TestBuiltinAttributeInterfaces.cpp
   TestDiagnostics.cpp
   TestDominance.cpp
   TestFunc.cpp

diff  --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
new file mode 100644
index 0000000000000..9f334603ca82f
--- /dev/null
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -0,0 +1,61 @@
+//===- TestBuiltinAttributeInterfaces.cpp ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestAttributes.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace test;
+
+namespace {
+/// Test pass: for every ElementsAttr found on any operation in the module,
+/// emits one error diagnostic per probed C++ element type. Each diagnostic
+/// either lists the iterated values or reports that the attribute cannot be
+/// iterated as that type. Meant to be driven by `-verify-diagnostics`.
+struct TestElementsAttrInterface
+    : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-elements-attr-interface"; }
+  StringRef getDescription() const final {
+    return "Test ElementsAttr interface support.";
+  }
+
+  void runOnOperation() override {
+    getOperation().walk([&](Operation *op) {
+      for (NamedAttribute namedAttr : op->getAttrs()) {
+        // Only attributes implementing the ElementsAttr interface are probed.
+        if (auto elements = namedAttr.second.dyn_cast<ElementsAttr>()) {
+          testElementsAttrIteration<uint64_t>(op, elements, "uint64_t");
+          testElementsAttrIteration<APInt>(op, elements, "APInt");
+          testElementsAttrIteration<IntegerAttr>(op, elements, "IntegerAttr");
+        }
+      }
+    });
+  }
+
+  /// Emits an error on `op` describing the outcome of iterating `attr` as
+  /// values of type `T`; `type` is the display name of `T` in the message.
+  template <typename T>
+  void testElementsAttrIteration(Operation *op, ElementsAttr attr,
+                                 StringRef type) {
+    InFlightDiagnostic diag = op->emitError();
+    diag << "Test iterating `" << type << "`: ";
+
+    if (auto values = attr.tryGetValues<T>()) {
+      // formatv renders any streamable T, so one path serves all probed types.
+      llvm::interleaveComma(*values, diag, [&](T value) {
+        diag << llvm::formatv("{0}", value).str();
+      });
+    } else {
+      diag << "unable to iterate type";
+    }
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Registers the `test-elements-attr-interface` pass so that tools such as
+/// mlir-opt can select it from the command line.
+void registerTestBuiltinAttributeInterfaces() {
+  PassRegistration<TestElementsAttrInterface> registration;
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2fba77f76b433..b3d9e54193cb2 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -62,6 +62,7 @@ void registerPatternsTestPass();
 void registerSimpleParametricTilingPass();
 void registerTestAffineLoopParametricTilingPass();
 void registerTestAliasAnalysisPass();
+void registerTestBuiltinAttributeInterfaces();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
 void registerTestConvVectorization();
@@ -146,6 +147,7 @@ void registerTestPasses() {
   mlir::test::registerSimpleParametricTilingPass();
   mlir::test::registerTestAffineLoopParametricTilingPass();
   mlir::test::registerTestAliasAnalysisPass();
+  mlir::test::registerTestBuiltinAttributeInterfaces();
   mlir::test::registerTestCallGraphPass();
   mlir::test::registerTestConstantFold();
   mlir::test::registerTestDiagnosticsPass();

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 631620d8b9f4e..f4dfb37914775 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -106,6 +106,8 @@ gentbl_cc_library(
 td_library(
     name = "BuiltinDialectTdFiles",
     srcs = [
+        "include/mlir/IR/BuiltinAttributeInterfaces.td",
+        "include/mlir/IR/BuiltinAttributes.td",
         "include/mlir/IR/BuiltinDialect.td",
         "include/mlir/IR/BuiltinLocationAttributes.td",
         "include/mlir/IR/BuiltinOps.td",
@@ -159,6 +161,24 @@ gentbl_cc_library(
     deps = [":BuiltinDialectTdFiles"],
 )
 
+# Runs mlir-tblgen over BuiltinAttributeInterfaces.td to produce the C++
+# declarations (.h.inc) and definitions (.cpp.inc) of the builtin attribute
+# interfaces.
+gentbl_cc_library(
+    name = "BuiltinAttributeInterfacesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["--gen-attr-interface-decls"],
+            "include/mlir/IR/BuiltinAttributeInterfaces.h.inc",
+        ),
+        (
+            ["--gen-attr-interface-defs"],
+            "include/mlir/IR/BuiltinAttributeInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/IR/BuiltinAttributeInterfaces.td",
+    deps = [":BuiltinDialectTdFiles"],
+)
+
 gentbl_cc_library(
     name = "BuiltinLocationAttributesIncGen",
     strip_include_prefix = "include",
@@ -249,6 +269,7 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
+        ":BuiltinAttributeInterfacesIncGen",
         ":BuiltinAttributesIncGen",
         ":BuiltinDialectIncGen",
         ":BuiltinLocationAttributesIncGen",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 72134154a686b..fb883c736462a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -35,6 +35,7 @@ td_library(
     name = "TestOpTdFiles",
     srcs = glob(["lib/Dialect/Test/*.td"]),
     deps = [
+        "//mlir:BuiltinDialectTdFiles",
         "//mlir:CallInterfacesTdFiles",
         "//mlir:ControlFlowInterfacesTdFiles",
         "//mlir:CopyOpInterfaceTdFiles",


        


More information about the llvm-commits mailing list