[llvm-branch-commits] [flang] [llvm] [mlir] [mlir][IR] Rename + merge `DenseIntOrFPElementsAttr` (PR #181559)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Feb 15 10:40:37 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/181559
>From 2e885b945d9094e836de3bf67994500343e75ae0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 15 Feb 2026 18:11:12 +0000
Subject: [PATCH 1/3] [mlir][IR] Rename + merge DenseElementsAttr
---
mlir/include/mlir-c/BuiltinAttributes.h | 2 +-
mlir/include/mlir/IR/BuiltinAttributes.h | 746 ++++--------------
mlir/include/mlir/IR/BuiltinAttributes.td | 293 ++++++-
.../include/mlir/IR/BuiltinDialectBytecode.td | 5 +-
mlir/lib/AsmParser/AttributeParser.cpp | 4 +-
mlir/lib/Bindings/Python/IRAttributes.cpp | 5 +-
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 4 +-
.../Transforms/ShardingInterfaceImpl.cpp | 2 +-
.../Linalg/Transforms/ConstantFold.cpp | 2 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2 +-
mlir/lib/IR/AsmPrinter.cpp | 26 +-
mlir/lib/IR/AttributeDetail.h | 10 +-
mlir/lib/IR/BuiltinAttributes.cpp | 102 +--
mlir/lib/Rewrite/ByteCode.cpp | 4 +-
mlir/utils/gdb-scripts/prettyprinters.py | 2 +-
15 files changed, 484 insertions(+), 725 deletions(-)
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 69a50942e8ee6..d2fcec0f2e62a 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -469,7 +469,7 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
/// Returns the typeID of an DenseIntOrFPElements attribute.
-MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void);
+MLIR_CAPI_EXPORTED MlirTypeID mlirDenseElementsAttrGetTypeID(void);
/// Creates a dense elements attribute with the given Shaped type and elements
/// in the same context as the type.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 3ba943c7ccd41..6a0fca6f92dba 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -32,7 +32,7 @@ class Operation;
class RankedTensorType;
namespace detail {
-struct DenseIntOrFPElementsAttrStorage;
+struct DenseElementsAttrStorage;
struct DenseStringElementsAttrStorage;
struct StringAttrStorage;
} // namespace detail
@@ -77,574 +77,169 @@ template <typename T>
struct is_complex_t<std::complex<T>> : public std::true_type {};
} // namespace detail
-/// An attribute that represents a reference to a dense vector or tensor
-/// object.
-class DenseElementsAttr : public Attribute {
-public:
- using Attribute::Attribute;
-
- /// Allow implicit conversion to ElementsAttr.
- operator ElementsAttr() const { return cast_if_present<ElementsAttr>(*this); }
- /// Allow implicit conversion to TypedAttr.
- operator TypedAttr() const { return ElementsAttr(*this); }
-
- /// Type trait used to check if the given type T is a potentially valid C++
- /// floating point type that can be used to access the underlying element
- /// types of a DenseElementsAttr.
- template <typename T>
- struct is_valid_cpp_fp_type {
- /// The type is a valid floating point type if it is a builtin floating
- /// point type, or is a potentially user defined floating point type. The
- /// latter allows for supporting users that have custom types defined for
- /// bfloat16/half/etc.
- static constexpr bool value = llvm::is_one_of<T, float, double>::value ||
- (std::numeric_limits<T>::is_specialized &&
- !std::numeric_limits<T>::is_integer);
- };
-
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr);
-
- /// Constructs a dense elements attribute from an array of element values.
- /// Each element attribute value is expected to be an element of 'type'.
- /// 'type' must be a vector or tensor with static shape. If the element of
- /// `type` is non-integer/index/float it is assumed to be a string type.
- static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
-
- /// Constructs a dense integer elements attribute from an array of integer
- /// or floating-point values. Each value is expected to be the same bitwidth
- /// of the element type of 'type'. 'type' must be a vector or tensor with
- /// static shape.
- template <typename T,
- typename = std::enable_if_t<std::numeric_limits<T>::is_integer ||
- is_valid_cpp_fp_type<T>::value>>
- static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
- const char *data = reinterpret_cast<const char *>(values.data());
- return getRawIntOrFloat(
- type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
- std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed);
- }
-
- /// Constructs a dense integer elements attribute from a single element.
- template <typename T,
- typename = std::enable_if_t<std::numeric_limits<T>::is_integer ||
- is_valid_cpp_fp_type<T>::value ||
- detail::is_complex_t<T>::value>>
- static DenseElementsAttr get(const ShapedType &type, T value) {
- return get(type, llvm::ArrayRef(value));
- }
-
- /// Constructs a dense complex elements attribute from an array of complex
- /// values. Each value is expected to be the same bitwidth of the element type
- /// of 'type'. 'type' must be a vector or tensor with static shape.
- template <
- typename T, typename ElementT = typename T::value_type,
- typename = std::enable_if_t<detail::is_complex_t<T>::value &&
- (std::numeric_limits<ElementT>::is_integer ||
- is_valid_cpp_fp_type<ElementT>::value)>>
- static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
- const char *data = reinterpret_cast<const char *>(values.data());
- return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
- sizeof(T), std::numeric_limits<ElementT>::is_integer,
- std::numeric_limits<ElementT>::is_signed);
- }
-
- /// Overload of the above 'get' method that is specialized for boolean values.
- static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
-
- /// Constructs a dense integer elements attribute from an array of APInt
- /// values. Each APInt value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
-
- /// Constructs a dense complex elements attribute from an array of APInt
- /// values. Each APInt value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type,
- ArrayRef<std::complex<APInt>> values);
-
- /// Constructs a dense float elements attribute from an array of APFloat
- /// values. Each APFloat value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
-
- /// Constructs a dense complex elements attribute from an array of APFloat
- /// values. Each APFloat value is expected to have the same bitwidth as the
- /// element type of 'type'. 'type' must be a vector or tensor with static
- /// shape.
- static DenseElementsAttr get(ShapedType type,
- ArrayRef<std::complex<APFloat>> values);
-
- /// Construct a dense elements attribute for an initializer_list of values.
- /// Each value is expected to be the same bitwidth of the element type of
- /// 'type'. 'type' must be a vector or tensor with static shape.
- template <typename T>
- static DenseElementsAttr get(const ShapedType &type,
- const std::initializer_list<T> &list) {
- return get(type, ArrayRef<T>(list));
- }
-
- /// Construct a dense elements attribute from a raw buffer representing the
- /// data for this attribute. Users are encouraged to use one of the
- /// constructors above, which provide more safeties. However, this
- /// constructor is useful for tools which may want to interop and can
- /// follow the precise definition.
- ///
- /// The format of the raw buffer is a densely packed array of values that
- /// can be bitcast to the storage format of the element type specified.
- /// Types that are not byte aligned will be rounded up to the next byte.
- static DenseElementsAttr getFromRawBuffer(ShapedType type,
- ArrayRef<char> rawBuffer);
-
- /// Returns true if the given buffer is a valid raw buffer for the given type.
- static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer);
-
- //===--------------------------------------------------------------------===//
- // Iterators
- //===--------------------------------------------------------------------===//
-
- /// The iterator range over the given iterator type T.
- template <typename IteratorT>
- using iterator_range_impl = detail::ElementsAttrRange<IteratorT>;
-
- /// The iterator for the given element type T.
- template <typename T, typename AttrT = DenseElementsAttr>
- using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
- /// The iterator range over the given element T.
- template <typename T, typename AttrT = DenseElementsAttr>
- using iterator_range =
- decltype(std::declval<AttrT>().template getValues<T>());
-
- /// A utility iterator that allows walking over the internal Attribute values
- /// of a dense elements attribute (DenseElementsAttr or
- /// DenseStringElementsAttr).
- class AttributeElementIterator
- : public llvm::indexed_accessor_iterator<AttributeElementIterator,
- const void *, Attribute,
- Attribute, Attribute> {
- public:
- /// Accesses the Attribute value at this iterator position.
- Attribute operator*() const;
-
- /// Constructs a new iterator. Accepts any attribute implementing
- /// ElementsAttr (e.g. DenseElementsAttr, DenseStringElementsAttr).
- AttributeElementIterator(Attribute attr, size_t index);
- };
-
- /// Iterator for walking raw element values of the specified type 'T', which
- /// may be any c++ data type matching the stored representation: int32_t,
- /// float, etc.
- template <typename T>
- class ElementIterator
- : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
- const T> {
- public:
- /// Accesses the raw value at this iterator position.
- const T &operator*() const {
- return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
- }
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- ElementIterator(const char *data, bool isSplat, size_t dataIndex)
- : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
- data, isSplat, dataIndex) {}
- };
-
- /// A utility iterator that allows walking over the internal bool values.
- class BoolElementIterator
- : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator,
- bool, bool, bool> {
- public:
- /// Accesses the bool value at this iterator position.
- bool operator*() const;
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- BoolElementIterator(DenseElementsAttr attr, size_t dataIndex);
- };
-
- /// A utility iterator that allows walking over the internal raw APInt values.
- class IntElementIterator
- : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
- APInt, APInt, APInt> {
- public:
- /// Accesses the raw APInt value at this iterator position.
- APInt operator*() const;
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- IntElementIterator(DenseElementsAttr attr, size_t dataIndex);
-
- /// The bitwidth of the element type.
- size_t bitWidth;
- };
-
- /// A utility iterator that allows walking over the internal raw complex APInt
- /// values.
- class ComplexIntElementIterator
- : public detail::DenseElementIndexedIteratorImpl<
- ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>,
- std::complex<APInt>> {
- public:
- /// Accesses the raw std::complex<APInt> value at this iterator position.
- std::complex<APInt> operator*() const;
-
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex);
-
- /// The bitwidth of the element type.
- size_t bitWidth;
- };
-
- /// Iterator for walking over APFloat values.
- class FloatElementIterator final
- : public llvm::mapped_iterator_base<FloatElementIterator,
- IntElementIterator, APFloat> {
- public:
- /// Map the element to the iterator result type.
- APFloat mapElement(const APInt &value) const {
- return APFloat(*smt, value);
- }
-
- private:
- friend DenseElementsAttr;
-
- /// Initializes the float element iterator to the specified iterator.
- FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it)
- : BaseT(it), smt(&smt) {}
-
- /// The float semantics to use when constructing the APFloat.
- const llvm::fltSemantics *smt;
- };
-
- /// Iterator for walking over complex APFloat values.
- class ComplexFloatElementIterator final
- : public llvm::mapped_iterator_base<ComplexFloatElementIterator,
- ComplexIntElementIterator,
- std::complex<APFloat>> {
- public:
- /// Map the element to the iterator result type.
- std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
- return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
- }
-
- private:
- friend DenseElementsAttr;
-
- /// Initializes the float element iterator to the specified iterator.
- ComplexFloatElementIterator(const llvm::fltSemantics &smt,
- ComplexIntElementIterator it)
- : BaseT(it), smt(&smt) {}
-
- /// The float semantics to use when constructing the APFloat.
- const llvm::fltSemantics *smt;
- };
-
- //===--------------------------------------------------------------------===//
- // Value Querying
- //===--------------------------------------------------------------------===//
-
- /// Returns true if this attribute corresponds to a splat, i.e. if all element
- /// values are the same.
- bool isSplat() const;
-
- /// Return the splat value for this attribute. This asserts that the attribute
- /// corresponds to a splat.
- template <typename T>
- std::enable_if_t<!std::is_base_of<Attribute, T>::value ||
- std::is_same<Attribute, T>::value,
- T>
- getSplatValue() const {
- assert(isSplat() && "expected the attribute to be a splat");
- return *value_begin<T>();
- }
- /// Return the splat value for derived attribute element types.
- template <typename T>
- std::enable_if_t<std::is_base_of<Attribute, T>::value &&
- !std::is_same<Attribute, T>::value,
- T>
- getSplatValue() const {
- return llvm::cast<T>(getSplatValue<Attribute>());
- }
-
- /// Try to get an iterator of the given type to the start of the held element
- /// values. Return failure if the type cannot be iterated.
- template <typename T>
- auto try_value_begin() const {
- auto range = tryGetValues<T>();
- using iterator = decltype(range->begin());
- return failed(range) ? FailureOr<iterator>(failure()) : range->begin();
- }
-
- /// Try to get an iterator of the given type to the end of the held element
- /// values. Return failure if the type cannot be iterated.
- template <typename T>
- auto try_value_end() const {
- auto range = tryGetValues<T>();
- using iterator = decltype(range->begin());
- return failed(range) ? FailureOr<iterator>(failure()) : range->end();
- }
-
- /// Return the held element values as a range of the given type.
- template <typename T>
- auto getValues() const {
- auto range = tryGetValues<T>();
- assert(succeeded(range) && "element type cannot be iterated");
- return std::move(*range);
- }
-
- /// Get an iterator of the given type to the start of the held element values.
- template <typename T>
- auto value_begin() const {
- return getValues<T>().begin();
- }
-
- /// Get an iterator of the given type to the end of the held element values.
- template <typename T>
- auto value_end() const {
- return getValues<T>().end();
- }
-
- /// Try to get the held element values as a range of integer or floating-point
- /// values.
- template <typename T>
- using IntFloatValueTemplateCheckT =
- std::enable_if_t<(!std::is_same<T, bool>::value &&
- std::numeric_limits<T>::is_integer) ||
- is_valid_cpp_fp_type<T>::value>;
- template <typename T, typename = IntFloatValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const {
- if (!isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
- std::numeric_limits<T>::is_signed))
- return failure();
- const char *rawData = getRawData().data();
- bool splat = isSplat();
- return iterator_range_impl<ElementIterator<T>>(
- getType(), ElementIterator<T>(rawData, splat, 0),
- ElementIterator<T>(rawData, splat, getNumElements()));
- }
-
- /// Try to get the held element values as a range of std::complex.
- template <typename T, typename ElementT>
- using ComplexValueTemplateCheckT =
- std::enable_if_t<detail::is_complex_t<T>::value &&
- (std::numeric_limits<ElementT>::is_integer ||
- is_valid_cpp_fp_type<ElementT>::value)>;
- template <typename T, typename ElementT = typename T::value_type,
- typename = ComplexValueTemplateCheckT<T, ElementT>>
- FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const {
- if (!isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
- std::numeric_limits<ElementT>::is_signed))
- return failure();
- const char *rawData = getRawData().data();
- bool splat = isSplat();
- return iterator_range_impl<ElementIterator<T>>(
- getType(), ElementIterator<T>(rawData, splat, 0),
- ElementIterator<T>(rawData, splat, getNumElements()));
- }
-
- /// Try to get the held element values as a range of Attributes.
- template <typename T>
- using AttributeValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, Attribute>::value>;
- template <typename T, typename = AttributeValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<AttributeElementIterator>>
- tryGetValues() const {
- return iterator_range_impl<AttributeElementIterator>(
- getType(), AttributeElementIterator(Attribute(*this), 0),
- AttributeElementIterator(Attribute(*this), getNumElements()));
- }
-
- /// Try to get the held element values a range of T, where T is a derived
- /// attribute type.
- template <typename T>
- using DerivedAttrValueTemplateCheckT =
- std::enable_if_t<std::is_base_of<Attribute, T>::value &&
- !std::is_same<Attribute, T>::value>;
- template <typename T>
- struct DerivedAttributeElementIterator
- : public llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
- AttributeElementIterator, T> {
- using llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
- AttributeElementIterator,
- T>::mapped_iterator_base;
-
- /// Map the element to the iterator result type.
- T mapElement(Attribute attr) const { return llvm::cast<T>(attr); }
- };
- template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<DerivedAttributeElementIterator<T>>>
- tryGetValues() const {
- using DerivedIterT = DerivedAttributeElementIterator<T>;
- return iterator_range_impl<DerivedIterT>(
- getType(), DerivedIterT(value_begin<Attribute>()),
- DerivedIterT(value_end<Attribute>()));
- }
-
- /// Try to get the held element values as a range of bool. The element type of
- /// this attribute must be of integer type of bitwidth 1.
- template <typename T>
- using BoolValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, bool>::value>;
- template <typename T, typename = BoolValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<BoolElementIterator>> tryGetValues() const {
- if (!isValidBool())
- return failure();
- return iterator_range_impl<BoolElementIterator>(
- getType(), BoolElementIterator(*this, 0),
- BoolElementIterator(*this, getNumElements()));
- }
-
- /// Try to get the held element values as a range of APInts. The element type
- /// of this attribute must be of integer type.
- template <typename T>
- using APIntValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, APInt>::value>;
- template <typename T, typename = APIntValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const {
- if (!getElementType().isIntOrIndex())
- return failure();
- return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
- raw_int_end());
- }
-
- /// Try to get the held element values as a range of complex APInts. The
- /// element type of this attribute must be a complex of integer type.
- template <typename T>
- using ComplexAPIntValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, std::complex<APInt>>::value>;
- template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<ComplexIntElementIterator>>
- tryGetValues() const {
- return tryGetComplexIntValues();
- }
-
- /// Try to get the held element values as a range of APFloat. The element type
- /// of this attribute must be of float type.
- template <typename T>
- using APFloatValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, APFloat>::value>;
- template <typename T, typename = APFloatValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<FloatElementIterator>> tryGetValues() const {
- return tryGetFloatValues();
- }
-
- /// Try to get the held element values as a range of complex APFloat. The
- /// element type of this attribute must be a complex of float type.
- template <typename T>
- using ComplexAPFloatValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, std::complex<APFloat>>::value>;
- template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
- tryGetValues() const {
- return tryGetComplexFloatValues();
- }
-
- /// Return the raw storage data held by this attribute. Users should generally
- /// not use this directly, as the internal storage format is not always in the
- /// form the user might expect.
- ArrayRef<char> getRawData() const;
-
- /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
- /// with static shape.
- ShapedType getType() const;
-
- /// Return the element type of this DenseElementsAttr.
- Type getElementType() const;
-
- /// Returns the number of elements held by this attribute.
- int64_t getNumElements() const;
-
- /// Returns the number of elements held by this attribute.
- int64_t size() const { return getNumElements(); }
-
- /// Returns if the number of elements held by this attribute is 0.
- bool empty() const { return size() == 0; }
-
- //===--------------------------------------------------------------------===//
- // Mutation Utilities
- //===--------------------------------------------------------------------===//
-
- /// Return a new DenseElementsAttr that has the same data as the current
- /// attribute, but has been reshaped to 'newType'. The new type must have the
- /// same total number of elements as well as element type.
- DenseElementsAttr reshape(ShapedType newType);
-
- /// Return a new DenseElementsAttr that has the same data as the current
- /// attribute, but with a different shape for a splat type. The new type must
- /// have the same element type.
- DenseElementsAttr resizeSplat(ShapedType newType);
+} // namespace mlir
- /// Return a new DenseElementsAttr that has the same data as the current
- /// attribute, but has bitcast elements to 'newElType'. The new type must have
- /// the same bitwidth as the current element type.
- DenseElementsAttr bitcast(Type newElType);
+// DenseResourceElementsHandle is used by the generated
+// DenseResourceElementsAttr.
+namespace mlir {
+using DenseResourceElementsHandle = DialectResourceBlobHandle<BuiltinDialect>;
+} // namespace mlir
- /// Generates a new DenseElementsAttr by mapping each int value to a new
- /// underlying APInt. The new values can represent either an integer or float.
- /// This underlying type must be an DenseIntElementsAttr.
- DenseElementsAttr mapValues(Type newElementType,
- function_ref<APInt(const APInt &)> mapping) const;
+// DenseElementsAttr is defined in TableGen (see Builtin_DenseElementsAttr in
+// BuiltinAttributes.td) and generated in BuiltinAttributes.h.inc below.
+#define GET_ATTRDEF_CLASSES
+#include "mlir/IR/BuiltinAttributes.h.inc"
- /// Generates a new DenseElementsAttr by mapping each float value to a new
- /// underlying APInt. the new values can represent either an integer or float.
- /// This underlying type must be an DenseFPElementsAttr.
- DenseElementsAttr
- mapValues(Type newElementType,
- function_ref<APInt(const APFloat &)> mapping) const;
+// Template method definitions for DenseElementsAttr (declared in TableGen).
+namespace mlir {
+template <typename T, typename>
+DenseElementsAttr DenseElementsAttr::get(const ShapedType &type,
+ ArrayRef<T> values) {
+ const char *data = reinterpret_cast<const char *>(values.data());
+ return getRawIntOrFloat(type, ArrayRef<char>(data, values.size() * sizeof(T)),
+ sizeof(T), std::numeric_limits<T>::is_integer,
+ std::numeric_limits<T>::is_signed);
+}
+template <typename T, typename>
+DenseElementsAttr DenseElementsAttr::get(const ShapedType &type, T value) {
+ return get(type, llvm::ArrayRef(value));
+}
+template <typename T, typename ElementT, typename>
+DenseElementsAttr DenseElementsAttr::get(const ShapedType &type,
+ ArrayRef<T> values) {
+ const char *data = reinterpret_cast<const char *>(values.data());
+ return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
+ sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed);
+}
+template <typename T>
+DenseElementsAttr DenseElementsAttr::get(const ShapedType &type,
+ const std::initializer_list<T> &list) {
+ return get(type, ArrayRef<T>(list));
+}
+template <typename T>
+std::enable_if_t<!std::is_base_of<Attribute, T>::value ||
+ std::is_same<Attribute, T>::value,
+ T>
+DenseElementsAttr::getSplatValue() const {
+ assert(isSplat() && "expected the attribute to be a splat");
+ return *value_begin<T>();
+}
+template <typename T>
+std::enable_if_t<std::is_base_of<Attribute, T>::value &&
+ !std::is_same<Attribute, T>::value,
+ T>
+DenseElementsAttr::getSplatValue() const {
+ return llvm::cast<T>(getSplatValue<Attribute>());
+}
+template <typename T>
+auto DenseElementsAttr::try_value_begin() const {
+ auto range = tryGetValues<T>();
+ using iterator = decltype(range->begin());
+ return failed(range) ? FailureOr<iterator>(failure()) : range->begin();
+}
+template <typename T>
+auto DenseElementsAttr::try_value_end() const {
+ auto range = tryGetValues<T>();
+ using iterator = decltype(range->begin());
+ return failed(range) ? FailureOr<iterator>(failure()) : range->end();
+}
+template <typename T>
+auto DenseElementsAttr::getValues() const {
+ auto range = tryGetValues<T>();
+ assert(succeeded(range) && "element type cannot be iterated");
+ return std::move(*range);
+}
+template <typename T>
+auto DenseElementsAttr::value_begin() const {
+ return getValues<T>().begin();
+}
+template <typename T>
+auto DenseElementsAttr::value_end() const {
+ return getValues<T>().end();
+}
-protected:
- /// Iterators to various elements that require out-of-line definition. These
- /// are hidden from the user to encourage consistent use of the
- /// getValues/value_begin/value_end API.
- IntElementIterator raw_int_begin() const {
- return IntElementIterator(*this, 0);
- }
- IntElementIterator raw_int_end() const {
- return IntElementIterator(*this, getNumElements());
- }
- FailureOr<iterator_range_impl<ComplexIntElementIterator>>
- tryGetComplexIntValues() const;
- FailureOr<iterator_range_impl<FloatElementIterator>>
- tryGetFloatValues() const;
- FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
- tryGetComplexFloatValues() const;
-
- /// Overload of the raw 'get' method that asserts that the given type is of
- /// complex type. This method is used to verify type invariants that the
- /// templatized 'get' method cannot.
- static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned);
-
- /// Overload of the raw 'get' method that asserts that the given type is of
- /// integer or floating-point type. This method is used to verify type
- /// invariants that the templatized 'get' method cannot.
- static DenseElementsAttr getRawIntOrFloat(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned);
-
- /// Check the information for a C++ data type, check if this type is valid for
- /// the current attribute. This method is used to verify specific type
- /// invariants that the templatized 'getValues' method cannot.
- bool isValidBool() const { return getElementType().isInteger(1); }
- bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
- bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
-};
+// tryGetValues template definitions (required for instantiation in other TUs).
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::ElementIterator<T>>>
+DenseElementsAttr::tryGetValues() const {
+ if (!isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+ std::numeric_limits<T>::is_signed))
+ return failure();
+ const char *rawData = getRawData().data();
+ bool splat = isSplat();
+ return iterator_range_impl<ElementIterator<T>>(
+ getType(), ElementIterator<T>(rawData, splat, 0),
+ ElementIterator<T>(rawData, splat, getNumElements()));
+}
+template <typename T, typename ElementT, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::ElementIterator<T>>>
+DenseElementsAttr::tryGetValues() const {
+ if (!isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed))
+ return failure();
+ const char *rawData = getRawData().data();
+ bool splat = isSplat();
+ return iterator_range_impl<ElementIterator<T>>(
+ getType(), ElementIterator<T>(rawData, splat, 0),
+ ElementIterator<T>(rawData, splat, getNumElements()));
+}
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::AttributeElementIterator>>
+DenseElementsAttr::tryGetValues() const {
+ return iterator_range_impl<AttributeElementIterator>(
+ getType(), AttributeElementIterator(Attribute(*this), 0),
+ AttributeElementIterator(Attribute(*this), getNumElements()));
+}
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::DerivedAttributeElementIterator<T>>>
+DenseElementsAttr::tryGetValues() const {
+ using DerivedIterT = DerivedAttributeElementIterator<T>;
+ return iterator_range_impl<DerivedIterT>(
+ getType(), DerivedIterT(value_begin<Attribute>()),
+ DerivedIterT(value_end<Attribute>()));
+}
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::BoolElementIterator>>
+DenseElementsAttr::tryGetValues() const {
+ if (!isValidBool())
+ return failure();
+ return iterator_range_impl<BoolElementIterator>(
+ getType(), BoolElementIterator(*this, 0),
+ BoolElementIterator(*this, getNumElements()));
+}
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::IntElementIterator>>
+DenseElementsAttr::tryGetValues() const {
+ if (!getElementType().isIntOrIndex())
+ return failure();
+ return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
+ raw_int_end());
+}
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::ComplexIntElementIterator>>
+DenseElementsAttr::tryGetValues() const {
+ return tryGetComplexIntValues();
+}
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::FloatElementIterator>>
+DenseElementsAttr::tryGetValues() const {
+ return tryGetFloatValues();
+}
+template <typename T, typename>
+FailureOr<DenseElementsAttr::iterator_range_impl<
+ DenseElementsAttr::ComplexFloatElementIterator>>
+DenseElementsAttr::tryGetValues() const {
+ return tryGetComplexFloatValues();
+}
/// An attribute that represents a reference to a splat vector or tensor
/// constant, meaning all of the elements have the same value.
@@ -659,21 +254,8 @@ class SplatElementsAttr : public DenseElementsAttr {
}
};
-//===----------------------------------------------------------------------===//
-// DenseResourceElementsAttr
-//===----------------------------------------------------------------------===//
-
-using DenseResourceElementsHandle = DialectResourceBlobHandle<BuiltinDialect>;
-
} // namespace mlir
-//===----------------------------------------------------------------------===//
-// Tablegen Attribute Declarations
-//===----------------------------------------------------------------------===//
-
-#define GET_ATTRDEF_CLASSES
-#include "mlir/IR/BuiltinAttributes.h.inc"
-
//===----------------------------------------------------------------------===//
// C++ Attribute Declarations
//===----------------------------------------------------------------------===//
@@ -874,11 +456,11 @@ class FlatSymbolRefAttr : public SymbolRefAttr {
/// An attribute that represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double.
-class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
+class DenseFPElementsAttr : public DenseElementsAttr {
public:
using iterator = DenseElementsAttr::FloatElementIterator;
- using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
+ using DenseElementsAttr::DenseElementsAttr;
/// Get an instance of a DenseFPElementsAttr with the given arguments. This
/// simply wraps the DenseElementsAttr::get calls.
@@ -913,13 +495,13 @@ class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
/// An attribute that represents a reference to a dense integer vector or tensor
/// object.
-class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
+class DenseIntElementsAttr : public DenseElementsAttr {
public:
/// DenseIntElementsAttr iterates on APInt, so we can use the raw element
/// iterator directly.
using iterator = DenseElementsAttr::IntElementIterator;
- using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
+ using DenseElementsAttr::DenseElementsAttr;
/// Get an instance of a DenseIntElementsAttr with the given arguments. This
/// simply wraps the DenseElementsAttr::get calls.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 064783ae5f87a..162759bfdbb8c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -169,7 +169,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array",
let summary = "A dense array of integer or floating point elements.";
let description = [{
A dense array attribute is an attribute that represents a dense array of
- primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
+ primitive element types. Contrary to DenseElementsAttr this is a
flat unidimensional array which does not have a storage optimization for
splat. This allows to expose the raw array through a C++ API as
`ArrayRef<T>` for compatible types. The element type must be bool or an
@@ -231,12 +231,11 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array",
}
//===----------------------------------------------------------------------===//
-// DenseIntOrFPElementsAttr
+// DenseElementsAttr
//===----------------------------------------------------------------------===//
-def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
- "DenseIntOrFPElements", "dense_int_or_fp_elements", [ElementsAttrInterface],
- "DenseElementsAttr"
+def Builtin_DenseElementsAttr : Builtin_Attr<
+ "DenseElements", "dense_int_or_fp_elements", [ElementsAttrInterface]
> {
let summary = "An Attribute containing a dense multi-dimensional array of "
"values";
@@ -285,17 +284,257 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
"ArrayRef<char>":$rawData);
let extraClassDeclaration = [{
- using DenseElementsAttr::empty;
- using DenseElementsAttr::getNumElements;
- using DenseElementsAttr::getElementType;
- using DenseElementsAttr::getValues;
- using DenseElementsAttr::isSplat;
- using DenseElementsAttr::size;
- using DenseElementsAttr::value_begin;
+ /// Allow implicit conversion to ElementsAttr.
+ operator ElementsAttr() const { return cast_if_present<ElementsAttr>(*this); }
+ /// Allow implicit conversion to TypedAttr.
+ operator TypedAttr() const { return ElementsAttr(*this); }
- /// The set of data types that can be iterated by this attribute.
+ template <typename T>
+ struct is_valid_cpp_fp_type {
+ static constexpr bool value = llvm::is_one_of<T, float, double>::value ||
+ (std::numeric_limits<T>::is_specialized &&
+ !std::numeric_limits<T>::is_integer);
+ };
+
+ static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
+ template <typename T,
+ typename = std::enable_if_t<std::numeric_limits<T>::is_integer ||
+ is_valid_cpp_fp_type<T>::value>>
+ static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values);
+ template <typename T,
+ typename = std::enable_if_t<std::numeric_limits<T>::is_integer ||
+ is_valid_cpp_fp_type<T>::value ||
+ detail::is_complex_t<T>::value>>
+ static DenseElementsAttr get(const ShapedType &type, T value);
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = std::enable_if_t<detail::is_complex_t<T>::value &&
+ (std::numeric_limits<ElementT>::is_integer ||
+ is_valid_cpp_fp_type<ElementT>::value)>>
+ static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values);
+ static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
+ static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
+ static DenseElementsAttr get(ShapedType type,
+ ArrayRef<std::complex<APInt>> values);
+ static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
+ static DenseElementsAttr get(ShapedType type,
+ ArrayRef<std::complex<APFloat>> values);
+ template <typename T>
+ static DenseElementsAttr get(const ShapedType &type,
+ const std::initializer_list<T> &list);
+ static DenseElementsAttr getFromRawBuffer(ShapedType type,
+ ArrayRef<char> rawBuffer);
+ static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer);
+
+ template <typename IteratorT>
+ using iterator_range_impl = detail::ElementsAttrRange<IteratorT>;
+ template <typename T, typename AttrT = DenseElementsAttr>
+ using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
+ template <typename T, typename AttrT = DenseElementsAttr>
+ using iterator_range =
+ decltype(std::declval<AttrT>().template getValues<T>());
+
+ class AttributeElementIterator
+ : public llvm::indexed_accessor_iterator<AttributeElementIterator,
+ const void *, Attribute,
+ Attribute, Attribute> {
+ public:
+ Attribute operator*() const;
+ AttributeElementIterator(Attribute attr, size_t index);
+ };
+ template <typename T>
+ class ElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>,
+ const T> {
+ public:
+ const T &operator*() const {
+ return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()];
+ }
+ private:
+ friend DenseElementsAttr;
+ ElementIterator(const char *data, bool isSplat, size_t dataIndex)
+ : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>(
+ data, isSplat, dataIndex) {}
+ };
+ class BoolElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator,
+ bool, bool, bool> {
+ public:
+ bool operator*() const;
+ private:
+ friend DenseElementsAttr;
+ BoolElementIterator(DenseElementsAttr attr, size_t dataIndex);
+ };
+ class IntElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<IntElementIterator,
+ APInt, APInt, APInt> {
+ public:
+ APInt operator*() const;
+ private:
+ friend DenseElementsAttr;
+ IntElementIterator(DenseElementsAttr attr, size_t dataIndex);
+ size_t bitWidth;
+ };
+ class ComplexIntElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<
+ ComplexIntElementIterator, std::complex<APInt>,
+ std::complex<APInt>, std::complex<APInt>> {
+ public:
+ std::complex<APInt> operator*() const;
+ private:
+ friend DenseElementsAttr;
+ ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex);
+ size_t bitWidth;
+ };
+ class FloatElementIterator final
+ : public llvm::mapped_iterator_base<FloatElementIterator,
+ IntElementIterator, APFloat> {
+ public:
+ APFloat mapElement(const APInt &value) const {
+ return APFloat(*smt, value);
+ }
+ private:
+ friend DenseElementsAttr;
+ FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it)
+ : BaseT(it), smt(&smt) {}
+ const llvm::fltSemantics *smt;
+ };
+ class ComplexFloatElementIterator final
+ : public llvm::mapped_iterator_base<ComplexFloatElementIterator,
+ ComplexIntElementIterator,
+ std::complex<APFloat>> {
+ public:
+ std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
+ return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
+ }
+ private:
+ friend DenseElementsAttr;
+ ComplexFloatElementIterator(const llvm::fltSemantics &smt,
+ ComplexIntElementIterator it)
+ : BaseT(it), smt(&smt) {}
+ const llvm::fltSemantics *smt;
+ };
+
+ bool isSplat() const;
+ template <typename T>
+ std::enable_if_t<!std::is_base_of<Attribute, T>::value ||
+ std::is_same<Attribute, T>::value,
+ T>
+ getSplatValue() const;
+ template <typename T>
+ std::enable_if_t<std::is_base_of<Attribute, T>::value &&
+ !std::is_same<Attribute, T>::value,
+ T>
+ getSplatValue() const;
+ template <typename T>
+ auto try_value_begin() const;
+ template <typename T>
+ auto try_value_end() const;
+ template <typename T>
+ auto getValues() const;
+ template <typename T>
+ auto value_begin() const;
+ template <typename T>
+ auto value_end() const;
+
+ template <typename T>
+ using IntFloatValueTemplateCheckT =
+ std::enable_if_t<(!std::is_same<T, bool>::value &&
+ std::numeric_limits<T>::is_integer) ||
+ is_valid_cpp_fp_type<T>::value>;
+ template <typename T, typename = IntFloatValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const;
+ template <typename T, typename ElementT>
+ using ComplexValueTemplateCheckT =
+ std::enable_if_t<detail::is_complex_t<T>::value &&
+ (std::numeric_limits<ElementT>::is_integer ||
+ is_valid_cpp_fp_type<ElementT>::value)>;
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = ComplexValueTemplateCheckT<T, ElementT>>
+ FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const;
+ template <typename T>
+ using AttributeValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, Attribute>::value>;
+ template <typename T, typename = AttributeValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<AttributeElementIterator>>
+ tryGetValues() const;
+ template <typename T>
+ using DerivedAttrValueTemplateCheckT =
+ std::enable_if_t<std::is_base_of<Attribute, T>::value &&
+ !std::is_same<Attribute, T>::value>;
+ template <typename T>
+ struct DerivedAttributeElementIterator
+ : public llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
+ AttributeElementIterator, T> {
+ using llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
+ AttributeElementIterator,
+ T>::mapped_iterator_base;
+ T mapElement(Attribute attr) const { return llvm::cast<T>(attr); }
+ };
+ template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<DerivedAttributeElementIterator<T>>>
+ tryGetValues() const;
+ template <typename T>
+ using BoolValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, bool>::value>;
+ template <typename T, typename = BoolValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<BoolElementIterator>> tryGetValues() const;
+ template <typename T>
+ using APIntValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, APInt>::value>;
+ template <typename T, typename = APIntValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const;
+ template <typename T>
+ using ComplexAPIntValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, std::complex<APInt>>::value>;
+ template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<ComplexIntElementIterator>>
+ tryGetValues() const;
+ template <typename T>
+ using APFloatValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, APFloat>::value>;
+ template <typename T, typename = APFloatValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<FloatElementIterator>> tryGetValues() const;
+ template <typename T>
+ using ComplexAPFloatValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, std::complex<APFloat>>::value>;
+ template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
+ FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
+ tryGetValues() const;
+
+ ArrayRef<char> getRawData() const;
+ ShapedType getType() const;
+ Type getElementType() const;
+ int64_t getNumElements() const;
+ int64_t size() const { return getNumElements(); }
+ bool empty() const { return size() == 0; }
+ DenseElementsAttr reshape(ShapedType newType);
+ DenseElementsAttr resizeSplat(ShapedType newType);
+ DenseElementsAttr bitcast(Type newElType);
+ DenseElementsAttr mapValues(Type newElementType,
+ function_ref<APInt(const APInt &)> mapping) const;
+ DenseElementsAttr
+ mapValues(Type newElementType,
+ function_ref<APInt(const APFloat &)> mapping) const;
+
+ protected:
+ IntElementIterator raw_int_begin() const {
+ return IntElementIterator(*this, 0);
+ }
+ IntElementIterator raw_int_end() const {
+ return IntElementIterator(*this, getNumElements());
+ }
+ FailureOr<iterator_range_impl<ComplexIntElementIterator>>
+ tryGetComplexIntValues() const;
+ FailureOr<iterator_range_impl<FloatElementIterator>>
+ tryGetFloatValues() const;
+ FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
+ tryGetComplexFloatValues() const;
+ bool isValidBool() const { return getElementType().isInteger(1); }
+ bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
+ bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
+
+ public:
using ContiguousIterableTypesT = std::tuple<
- // Integer types.
uint8_t, uint16_t, uint32_t, uint64_t,
int8_t, int16_t, int32_t, int64_t,
short, unsigned short, int, unsigned, long, unsigned long,
@@ -303,46 +542,24 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
std::complex<uint64_t>,
std::complex<int8_t>, std::complex<int16_t>, std::complex<int32_t>,
std::complex<int64_t>,
- // Float types.
float, double, std::complex<float>, std::complex<double>
>;
using NonContiguousIterableTypesT = std::tuple<
- Attribute,
- // Integer types.
- APInt, bool, std::complex<APInt>,
- // Float types.
- APFloat, std::complex<APFloat>
+ Attribute, APInt, bool, std::complex<APInt>, APFloat, std::complex<APFloat>
>;
-
- /// Provide a `try_value_begin_impl` to enable iteration within
- /// ElementsAttr.
template <typename T>
auto try_value_begin_impl(OverloadToken<T>) const {
return try_value_begin<T>();
}
-
- /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of
- /// the elements of `inRawData` has `type`. If `inRawData` is little endian
- /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is
- /// BE, converted to LE.
- static void
- convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,
+ static void convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,
MutableArrayRef<char> outRawData,
ShapedType type);
-
- /// Convert endianess of input for big-endian(BE) machines. The number of
- /// elements of `inRawData` is `numElements`, and each element has
- /// `elementBitWidth` bits. If `inRawData` is little endian (LE), it is
- /// converted to big endian (BE) and saved in `outRawData`. Conversely, if
- /// `inRawData` is BE, converted to LE.
static void convertEndianOfCharForBEmachine(const char *inRawData,
char *outRawData,
size_t elementBitWidth,
size_t numElements);
protected:
- friend DenseElementsAttr;
-
/// Constructs a dense elements attribute from an array of raw APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..495e4e18cd6cb 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -172,8 +172,7 @@ def DenseArrayAttr : DialectAttribute<(attr
Blob:$rawData
)>;
-def DenseElementsAttr : WithType<"DenseElementsAttr", Attribute>;
-def DenseIntOrFPElementsAttr : DialectAttribute<(attr
+def DenseElementsAttr : DialectAttribute<(attr
ShapedType:$type,
Blob:$rawData
)> {
@@ -333,7 +332,7 @@ def BuiltinDialectAttributes : DialectAttributes<"Builtin"> {
UnknownLoc,
DenseResourceElementsAttr,
DenseArrayAttr,
- DenseIntOrFPElementsAttr,
+ DenseElementsAttr,
DenseStringElementsAttr,
SparseElementsAttr,
DistinctAttr,
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 15c2e0225f98b..302dd0498b38f 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -732,8 +732,8 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
// machines.
SmallVector<char, 64> outDataVec(rawData.size());
MutableArrayRef<char> convRawData(outDataVec);
- DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
- rawData, convRawData, type);
+ DenseElementsAttr::convertEndianOfArrayRefForBEmachine(rawData, convRawData,
+ type);
return DenseElementsAttr::getFromRawBuffer(type, convRawData);
}
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 2a13889d773c1..31cfb5e6514b3 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1300,8 +1300,7 @@ nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
std::string msg =
- std::string(
- "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
+ std::string("Can't cast unknown element type DenseElementsAttr (") +
nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
throw nb::type_error(msg.c_str());
}
@@ -1399,7 +1398,7 @@ void populateIRAttributes(nb::module_ &m) {
PyDenseFPElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
- mlirDenseIntOrFPElementsAttrGetTypeID(),
+ mlirDenseElementsAttrGetTypeID(),
nb::cast<nb::callable>(
nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
PyDenseResourceElementsAttribute::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 7325179c047c5..dd6731faf7be0 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -559,8 +559,8 @@ bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
return llvm::isa<DenseFPElementsAttr>(unwrap(attr));
}
-MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
- return wrap(DenseIntOrFPElementsAttr::getTypeID());
+MlirTypeID mlirDenseElementsAttrGetTypeID(void) {
+ return wrap(DenseElementsAttr::getTypeID());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index 3e34246f66f2c..2831d68da50a4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -73,7 +73,7 @@ struct ConstantShardingInterface
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
auto cOp = cast<ConstantOp>(op);
- if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
+ if (auto value = dyn_cast<DenseElementsAttr>(cOp.getValue())) {
if (!value.isSplat() || !resultShardings[0]) {
// Currently non-splat constants are not supported.
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 0c53cd2589f42..e0b74fba2779b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -115,7 +115,7 @@ class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
// All inputs should be constants.
int numInputs = linalgOp.getNumDpsInputs();
- SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
+ SmallVector<DenseElementsAttr> inputValues(numInputs);
for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
if (!matchPattern(en.value()->get(),
m_Constant(&inputValues[en.index()])))
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 34e06bf52f70d..a8b288a0762a5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -575,7 +575,7 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
<< opType << ") does not match value type (" << valueType << ")";
return success();
}
- if (isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
+ if (isa<DenseElementsAttr, SparseElementsAttr>(value)) {
auto valueType = cast<TypedAttr>(value).getType();
if (valueType == opType)
return success();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index b3242f838fc1d..7880cb4b575d0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -507,12 +507,6 @@ class AsmPrinter::Impl {
/// Print a dense string elements attribute.
void printDenseStringElementsAttr(DenseStringElementsAttr attr);
- /// Print a dense elements attribute in the literal-first syntax. If
- /// 'allowHex' is true, a hex string is used instead of individual elements
- /// when the elements attr is large.
- void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
- bool allowHex);
-
/// Print a dense elements attribute using the type-first syntax and the
/// DenseElementTypeInterface, which provides the attribute printer for each
/// element.
@@ -2508,8 +2502,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
printSymbolReference(nestedRef.getValue(), os);
}
- } else if (auto intOrFpEltAttr =
- llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
+ } else if (auto intOrFpEltAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) {
if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
printElidedElementsAttr(os);
} else {
@@ -2519,7 +2512,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
// the existing printing format for backwards compatibility.
Type eltType = intOrFpEltAttr.getElementType();
if (isa<FloatType, IntegerType, IndexType, ComplexType>(eltType)) {
- printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
+ printDenseElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
} else {
printTypeFirstDenseElementsAttr(intOrFpEltAttr,
cast<DenseElementType>(eltType));
@@ -2545,7 +2538,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
os << "sparse<";
DenseIntElementsAttr indices = sparseEltAttr.getIndices();
if (indices.getNumElements() != 0) {
- printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
+ printDenseElementsAttr(indices, /*allowHex=*/false);
os << ", ";
printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
}
@@ -2645,15 +2638,6 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
- if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
- return printDenseStringElementsAttr(stringAttr);
-
- printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
- allowHex);
-}
-
-void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
- DenseIntOrFPElementsAttr attr, bool allowHex) {
auto type = attr.getType();
auto elementType = type.getElementType();
@@ -2665,8 +2649,8 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
// machines. It is converted here to print in LE format.
SmallVector<char, 64> outDataVec(rawData.size());
MutableArrayRef<char> convRawData(outDataVec);
- DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
- rawData, convRawData, type);
+ DenseElementsAttr::convertEndianOfArrayRefForBEmachine(rawData,
+ convRawData, type);
printHexString(convRawData);
} else {
printHexString(rawData);
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 8505149afdd9c..b8032fca005e0 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -47,8 +47,8 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
};
/// An attribute representing a reference to a dense vector or tensor object.
-struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
- DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data)
+struct DenseElementsAttrStorage : public DenseElementsAttributeStorage {
+ DenseElementsAttrStorage(ShapedType ty, ArrayRef<char> data)
: DenseElementsAttributeStorage(ty), data(data) {}
struct KeyTy {
@@ -108,7 +108,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
}
/// Construct a new storage instance.
- static DenseIntOrFPElementsAttrStorage *
+ static DenseElementsAttrStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
// If the data buffer is non-empty, we copy it into the allocator with a
// 64-bit alignment.
@@ -120,8 +120,8 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
copy = ArrayRef<char>(rawData, data.size());
}
- return new (allocator.allocate<DenseIntOrFPElementsAttrStorage>())
- DenseIntOrFPElementsAttrStorage(key.type, copy);
+ return new (allocator.allocate<DenseElementsAttrStorage>())
+ DenseElementsAttrStorage(key.type, copy);
}
ArrayRef<char> data;
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e288be3271fab..5b2b8da16ec9d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -481,13 +481,13 @@ static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
// ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
size_t lastWordPos = numFilledWords;
SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ DenseElementsAttr::convertEndianOfCharForBEmachine(
reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
// Extract actual APInt data from `valueLE`, convert endianness to BE format,
// and store it in `result`.
// ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ DenseElementsAttr::convertEndianOfCharForBEmachine(
valueLE.begin(), result + lastWordPos,
(numBytes - lastWordPos) * CHAR_BIT, 1);
}
@@ -514,13 +514,13 @@ static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
// ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
size_t lastWordPos = numFilledWords;
SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ DenseElementsAttr::convertEndianOfCharForBEmachine(
inArray + lastWordPos, inArrayLE.begin(),
(numBytes - lastWordPos) * CHAR_BIT, 1);
// Convert `inArrayLE` to BE format, and store it in last word of `result`.
// ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
- DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
+ DenseElementsAttr::convertEndianOfCharForBEmachine(
inArrayLE.begin(),
const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
lastWordPos,
@@ -853,11 +853,6 @@ template class DenseArrayAttrImpl<double>;
// DenseElementsAttr
//===----------------------------------------------------------------------===//
-/// Method for support type inquiry through isa, cast and dyn_cast.
-bool DenseElementsAttr::classof(Attribute attr) {
- return llvm::isa<DenseIntOrFPElementsAttr>(attr);
-}
-
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameNumElementsOrSplat(type, values));
@@ -870,14 +865,14 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
if (failed(result))
return {};
}
- return DenseIntOrFPElementsAttr::getRaw(type, data);
+ return DenseElementsAttr::getRaw(type, data);
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<bool> values) {
assert(hasSameNumElementsOrSplat(type, values));
assert(type.getElementType().isInteger(1));
- return DenseIntOrFPElementsAttr::getRaw(
+ return DenseElementsAttr::getRaw(
type, ArrayRef<char>(reinterpret_cast<const char *>(values.data()),
values.size()));
}
@@ -890,7 +885,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
assert(type.getElementType().isIntOrIndex());
assert(hasSameNumElementsOrSplat(type, values));
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
+ return DenseElementsAttr::getRaw(type, storageBitWidth, values);
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<std::complex<APInt>> values) {
@@ -900,7 +895,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
values.size() * 2);
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
+ return DenseElementsAttr::getRaw(type, storageBitWidth, intVals);
}
// Constructs a dense float elements attribute from an array of APFloat
@@ -911,7 +906,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
assert(llvm::isa<FloatType>(type.getElementType()));
assert(hasSameNumElementsOrSplat(type, values));
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
+ return DenseElementsAttr::getRaw(type, storageBitWidth, values);
}
DenseElementsAttr
DenseElementsAttr::get(ShapedType type,
@@ -922,7 +917,7 @@ DenseElementsAttr::get(ShapedType type,
ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
values.size() * 2);
size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
- return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
+ return DenseElementsAttr::getRaw(type, storageBitWidth, apVals);
}
/// Construct a dense elements attribute from a raw buffer representing the
@@ -930,7 +925,7 @@ DenseElementsAttr::get(ShapedType type,
/// the expected buffer format may not be a form the user expects.
DenseElementsAttr
DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
- return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
+ return DenseElementsAttr::getRaw(type, rawBuffer);
}
/// Returns true if the given buffer is a valid raw buffer for the given type.
@@ -986,23 +981,6 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
return valid;
}
-/// Defaults down the subclass implementation.
-DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize,
- bool isInt, bool isSigned) {
- return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
- isSigned);
-}
-DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize,
- bool isInt,
- bool isSigned) {
- return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
- isInt, isSigned);
-}
-
bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
bool isSigned) const {
return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
@@ -1064,7 +1042,7 @@ auto DenseElementsAttr::tryGetComplexFloatValues() const
/// Return the raw storage data held by this attribute.
ArrayRef<char> DenseElementsAttr::getRawData() const {
- return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
+ return static_cast<DenseElementsAttrStorage *>(impl)->data;
}
/// Return a new DenseElementsAttr that has the same data as the current
@@ -1079,7 +1057,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
"expected the same element type");
assert(newType.getNumElements() == curType.getNumElements() &&
"expected the same number of elements");
- return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
+ return DenseElementsAttr::getRaw(newType, getRawData());
}
DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
@@ -1091,7 +1069,7 @@ DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
assert(newType.getElementType() == curType.getElementType() &&
"expected the same element type");
- return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
+ return DenseElementsAttr::getRaw(newType, getRawData());
}
/// Return a new DenseElementsAttr that has the same data as the current
@@ -1107,8 +1085,7 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
assert(getDenseElementBitWidth(newElType) ==
getDenseElementBitWidth(curElType) &&
"expected element types with the same bitwidth");
- return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
- getRawData());
+ return DenseElementsAttr::getRaw(curType.clone(newElType), getRawData());
}
DenseElementsAttr
@@ -1137,7 +1114,7 @@ int64_t DenseElementsAttr::getNumElements() const {
}
//===----------------------------------------------------------------------===//
-// DenseIntOrFPElementsAttr
+// DenseElementsAttr
//===----------------------------------------------------------------------===//
/// Utility method to write a range of APInt values to a buffer.
@@ -1158,28 +1135,28 @@ static void writeAPIntsToBuffer(size_t storageWidth,
/// Constructs a dense elements attribute from an array of raw APFloat values.
/// Each APFloat value is expected to have the same bitwidth as the element
/// type of 'type'. 'type' must be a vector or tensor with static shape.
-DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
- size_t storageWidth,
- ArrayRef<APFloat> values) {
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+ size_t storageWidth,
+ ArrayRef<APFloat> values) {
SmallVector<char> data;
auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
- return DenseIntOrFPElementsAttr::getRaw(type, data);
+ return DenseElementsAttr::getRaw(type, data);
}
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'.
-DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
- size_t storageWidth,
- ArrayRef<APInt> values) {
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+ size_t storageWidth,
+ ArrayRef<APInt> values) {
SmallVector<char> data;
writeAPIntsToBuffer(storageWidth, data, values);
- return DenseIntOrFPElementsAttr::getRaw(type, data);
+ return DenseElementsAttr::getRaw(type, data);
}
-DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
- ArrayRef<char> data) {
+DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
+ ArrayRef<char> data) {
assert(type.hasStaticShape() && "type must have static shape");
assert(isValidRawBuffer(type, data));
return Base::get(type.getContext(), type, data);
@@ -1188,11 +1165,10 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
/// Overload of the raw 'get' method that asserts that the given type is of
/// complex type. This method is used to verify type invariants that the
/// templatized 'get' method cannot.
-DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
- ArrayRef<char> data,
- int64_t dataEltSize,
- bool isInt,
- bool isSigned) {
+DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize,
+ bool isInt, bool isSigned) {
assert(::isValidIntOrFloat(
llvm::cast<ComplexType>(type.getElementType()).getElementType(),
dataEltSize / 2, isInt, isSigned) &&
@@ -1207,10 +1183,11 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
/// Overload of the 'getRaw' method that asserts that the given type is of
/// integer type. This method is used to verify type invariants that the
/// templatized 'get' method cannot.
-DenseElementsAttr
-DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
- int64_t dataEltSize, bool isInt,
- bool isSigned) {
+DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize,
+ bool isInt,
+ bool isSigned) {
assert(::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt,
isSigned) &&
"Try re-running with -debug-only=builtinattributes");
@@ -1221,9 +1198,10 @@ DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
return getRaw(type, data);
}
-void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
- const char *inRawData, char *outRawData, size_t elementBitWidth,
- size_t numElements) {
+void DenseElementsAttr::convertEndianOfCharForBEmachine(const char *inRawData,
+ char *outRawData,
+ size_t elementBitWidth,
+ size_t numElements) {
using llvm::support::ulittle16_t;
using llvm::support::ulittle32_t;
using llvm::support::ulittle64_t;
@@ -1263,7 +1241,7 @@ void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
}
}
-void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
+void DenseElementsAttr::convertEndianOfArrayRefForBEmachine(
ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
ShapedType type) {
size_t numElements = type.getNumElements();
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index cf00216288115..2387781fc15a1 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -2075,7 +2075,7 @@ void ByteCodeExecutor::executeSwitchAttribute() {
void ByteCodeExecutor::executeSwitchOperandCount() {
LDBG() << "Executing SwitchOperandCount:";
Operation *op = read<Operation *>();
- auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
+ auto cases = read<DenseElementsAttr>().getValues<uint32_t>();
LDBG() << " * Operation: " << *op;
handleSwitch(op->getNumOperands(), cases);
@@ -2112,7 +2112,7 @@ void ByteCodeExecutor::executeSwitchOperationName() {
void ByteCodeExecutor::executeSwitchResultCount() {
LDBG() << "Executing SwitchResultCount:";
Operation *op = read<Operation *>();
- auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
+ auto cases = read<DenseElementsAttr>().getValues<uint32_t>();
LDBG() << " * Operation: " << *op;
handleSwitch(op->getNumResults(), cases);
diff --git a/mlir/utils/gdb-scripts/prettyprinters.py b/mlir/utils/gdb-scripts/prettyprinters.py
index 45fd0837c9391..d5a57081efb98 100644
--- a/mlir/utils/gdb-scripts/prettyprinters.py
+++ b/mlir/utils/gdb-scripts/prettyprinters.py
@@ -159,7 +159,7 @@ def to_string(self):
"TypeAttr",
"UnitAttr",
"DenseStringElementsAttr",
- "DenseIntOrFPElementsAttr",
+ "DenseElementsAttr",
"SparseElementsAttr",
# mlir/IR/BuiltinTypes.h
"ComplexType",
>From ecc8dd5be3e0d7228e9414832cf5b86a65078c55 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 15 Feb 2026 18:19:30 +0000
Subject: [PATCH 2/3] move to TD
---
mlir/include/mlir/IR/BuiltinAttributes.h | 42 -----------------------
mlir/include/mlir/IR/BuiltinAttributes.td | 24 ++++++++-----
2 files changed, 16 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 6a0fca6f92dba..93dbf22005c1c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -100,10 +100,6 @@ DenseElementsAttr DenseElementsAttr::get(const ShapedType &type,
sizeof(T), std::numeric_limits<T>::is_integer,
std::numeric_limits<T>::is_signed);
}
-template <typename T, typename>
-DenseElementsAttr DenseElementsAttr::get(const ShapedType &type, T value) {
- return get(type, llvm::ArrayRef(value));
-}
template <typename T, typename ElementT, typename>
DenseElementsAttr DenseElementsAttr::get(const ShapedType &type,
ArrayRef<T> values) {
@@ -113,11 +109,6 @@ DenseElementsAttr DenseElementsAttr::get(const ShapedType &type,
std::numeric_limits<ElementT>::is_signed);
}
template <typename T>
-DenseElementsAttr DenseElementsAttr::get(const ShapedType &type,
- const std::initializer_list<T> &list) {
- return get(type, ArrayRef<T>(list));
-}
-template <typename T>
std::enable_if_t<!std::is_base_of<Attribute, T>::value ||
std::is_same<Attribute, T>::value,
T>
@@ -126,13 +117,6 @@ DenseElementsAttr::getSplatValue() const {
return *value_begin<T>();
}
template <typename T>
-std::enable_if_t<std::is_base_of<Attribute, T>::value &&
- !std::is_same<Attribute, T>::value,
- T>
-DenseElementsAttr::getSplatValue() const {
- return llvm::cast<T>(getSplatValue<Attribute>());
-}
-template <typename T>
auto DenseElementsAttr::try_value_begin() const {
auto range = tryGetValues<T>();
using iterator = decltype(range->begin());
@@ -150,14 +134,6 @@ auto DenseElementsAttr::getValues() const {
assert(succeeded(range) && "element type cannot be iterated");
return std::move(*range);
}
-template <typename T>
-auto DenseElementsAttr::value_begin() const {
- return getValues<T>().begin();
-}
-template <typename T>
-auto DenseElementsAttr::value_end() const {
- return getValues<T>().end();
-}
// tryGetValues template definitions (required for instantiation in other TUs).
template <typename T, typename>
@@ -222,24 +198,6 @@ DenseElementsAttr::tryGetValues() const {
return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
raw_int_end());
}
-template <typename T, typename>
-FailureOr<DenseElementsAttr::iterator_range_impl<
- DenseElementsAttr::ComplexIntElementIterator>>
-DenseElementsAttr::tryGetValues() const {
- return tryGetComplexIntValues();
-}
-template <typename T, typename>
-FailureOr<DenseElementsAttr::iterator_range_impl<
- DenseElementsAttr::FloatElementIterator>>
-DenseElementsAttr::tryGetValues() const {
- return tryGetFloatValues();
-}
-template <typename T, typename>
-FailureOr<DenseElementsAttr::iterator_range_impl<
- DenseElementsAttr::ComplexFloatElementIterator>>
-DenseElementsAttr::tryGetValues() const {
- return tryGetComplexFloatValues();
-}
/// An attribute that represents a reference to a splat vector or tensor
/// constant, meaning all of the elements have the same value.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 162759bfdbb8c..11f8d79bfcb81 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -305,7 +305,9 @@ def Builtin_DenseElementsAttr : Builtin_Attr<
typename = std::enable_if_t<std::numeric_limits<T>::is_integer ||
is_valid_cpp_fp_type<T>::value ||
detail::is_complex_t<T>::value>>
- static DenseElementsAttr get(const ShapedType &type, T value);
+ static DenseElementsAttr get(const ShapedType &type, T value) {
+ return get(type, llvm::ArrayRef(value));
+ }
template <typename T, typename ElementT = typename T::value_type,
typename = std::enable_if_t<detail::is_complex_t<T>::value &&
(std::numeric_limits<ElementT>::is_integer ||
@@ -320,7 +322,9 @@ def Builtin_DenseElementsAttr : Builtin_Attr<
ArrayRef<std::complex<APFloat>> values);
template <typename T>
static DenseElementsAttr get(const ShapedType &type,
- const std::initializer_list<T> &list);
+ const std::initializer_list<T> &list) {
+ return get(type, ArrayRef<T>(list));
+ }
static DenseElementsAttr getFromRawBuffer(ShapedType type,
ArrayRef<char> rawBuffer);
static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer);
@@ -424,7 +428,9 @@ def Builtin_DenseElementsAttr : Builtin_Attr<
std::enable_if_t<std::is_base_of<Attribute, T>::value &&
!std::is_same<Attribute, T>::value,
T>
- getSplatValue() const;
+ getSplatValue() const {
+ return llvm::cast<T>(getSplatValue<Attribute>());
+ }
template <typename T>
auto try_value_begin() const;
template <typename T>
@@ -432,9 +438,9 @@ def Builtin_DenseElementsAttr : Builtin_Attr<
template <typename T>
auto getValues() const;
template <typename T>
- auto value_begin() const;
+ auto value_begin() const { return getValues<T>().begin(); }
template <typename T>
- auto value_end() const;
+ auto value_end() const { return getValues<T>().end(); }
template <typename T>
using IntFloatValueTemplateCheckT =
@@ -488,18 +494,20 @@ def Builtin_DenseElementsAttr : Builtin_Attr<
std::enable_if_t<std::is_same<T, std::complex<APInt>>::value>;
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
FailureOr<iterator_range_impl<ComplexIntElementIterator>>
- tryGetValues() const;
+ tryGetValues() const { return tryGetComplexIntValues(); }
template <typename T>
using APFloatValueTemplateCheckT =
std::enable_if_t<std::is_same<T, APFloat>::value>;
template <typename T, typename = APFloatValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<FloatElementIterator>> tryGetValues() const;
+ FailureOr<iterator_range_impl<FloatElementIterator>> tryGetValues() const {
+ return tryGetFloatValues();
+ }
template <typename T>
using ComplexAPFloatValueTemplateCheckT =
std::enable_if_t<std::is_same<T, std::complex<APFloat>>::value>;
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
- tryGetValues() const;
+ tryGetValues() const { return tryGetComplexFloatValues(); }
ArrayRef<char> getRawData() const;
ShapedType getType() const;
>From c3f5e45776d2ba7ab31d2c725475d4e6407e9859 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 15 Feb 2026 18:40:16 +0000
Subject: [PATCH 3/3] rename
---
.../llvm-prettyprinters/gdb/mlir-support.gdb | 4 ++--
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 6 +++---
mlir/docs/Tutorials/Toy/Ch-2.md | 2 +-
mlir/include/mlir-c/BuiltinAttributes.h | 2 +-
mlir/include/mlir/IR/BuiltinAttributes.td | 5 +----
mlir/lib/Bindings/Python/IRAttributes.cpp | 5 ++---
mlir/test/Dialect/Builtin/Bytecode/attrs.mlir | 6 +++---
7 files changed, 13 insertions(+), 17 deletions(-)
diff --git a/cross-project-tests/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.gdb b/cross-project-tests/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.gdb
index 269c24c83ee77..6263b045c47e6 100644
--- a/cross-project-tests/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.gdb
+++ b/cross-project-tests/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.gdb
@@ -143,5 +143,5 @@ print StringAttr
# CHECK-LABEL: +print ElementsAttr
print ElementsAttr
-# CHECK: typeID = mlir::TypeID::get<mlir::DenseIntOrFPElementsAttr>()
-# CHECK: members of mlir::detail::DenseIntOrFPElementsAttrStorage
+# CHECK: typeID = mlir::TypeID::get<mlir::DenseElementsAttr>()
+# CHECK: members of mlir::detail::DenseElementsAttrStorage
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 625701725003f..fd8040c3e55c8 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3236,10 +3236,10 @@ static inline bool attributeTypeIsCompatible(mlir::MLIRContext *ctx,
// Get attr's LLVM element type.
if (!attr)
return true;
- auto intOrFpEleAttr = mlir::dyn_cast<mlir::DenseIntOrFPElementsAttr>(attr);
- if (!intOrFpEleAttr)
+ auto denseEleAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(attr);
+ if (!denseEleAttr)
return true;
- auto tensorTy = mlir::dyn_cast<mlir::TensorType>(intOrFpEleAttr.getType());
+ auto tensorTy = mlir::dyn_cast<mlir::TensorType>(denseEleAttr.getType());
if (!tensorTy)
return true;
mlir::Type attrEleTy =
diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md
index 81e41615ee55d..d6c317314de47 100644
--- a/mlir/docs/Tutorials/Toy/Ch-2.md
+++ b/mlir/docs/Tutorials/Toy/Ch-2.md
@@ -241,7 +241,7 @@ operation. This operation will represent a constant value in the Toy language.
```
This operation takes zero operands, a
-[dense elements](../../Dialects/Builtin.md/#denseintorfpelementsattr) attribute named
+[dense elements](../../Dialects/Builtin.md/#denseelementsattr) attribute named
`value` to represent the constant value, and returns a single result of
[RankedTensorType](../../Dialects/Builtin.md/#rankedtensortype). An operation class
inherits from the [CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index d2fcec0f2e62a..fb0138ac77108 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -468,7 +468,7 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
-/// Returns the typeID of an DenseIntOrFPElements attribute.
+/// Returns the typeID of a DenseElements attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirDenseElementsAttrGetTypeID(void);
/// Creates a dense elements attribute with the given Shaped type and elements
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 11f8d79bfcb81..d385df66acb95 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -235,7 +235,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array",
//===----------------------------------------------------------------------===//
def Builtin_DenseElementsAttr : Builtin_Attr<
- "DenseElements", "dense_int_or_fp_elements", [ElementsAttrInterface]
+ "DenseElements", "dense_elements", [ElementsAttrInterface]
> {
let summary = "An Attribute containing a dense multi-dimensional array of "
"values";
@@ -257,9 +257,6 @@ def Builtin_DenseElementsAttr : Builtin_Attr<
offset "i * ceildiv(w, 8)". In other words, each element starts at a full
byte offset.
- TODO: The name `DenseIntOrFPElements` is no longer accurate. The attribute
- will be renamed in the future.
-
Examples:
```
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 31cfb5e6514b3..358e462dcdcd6 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1294,7 +1294,7 @@ nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
throw nb::type_error(msg.c_str());
}
-nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
+nb::object denseElementsAttributeCaster(PyAttribute &pyAttribute) {
if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
@@ -1399,8 +1399,7 @@ void populateIRAttributes(nb::module_ &m) {
PyDenseIntElementsAttribute::bind(m);
PyGlobals::get().registerTypeCaster(
mlirDenseElementsAttrGetTypeID(),
- nb::cast<nb::callable>(
- nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
+ nb::cast<nb::callable>(nb::cpp_function(denseElementsAttributeCaster)));
PyDenseResourceElementsAttribute::bind(m);
PyDictAttribute::bind(m);
diff --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
index 0f5643aa3bb43..c6fd9f8e17329 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
@@ -25,14 +25,14 @@ module @TestDenseArray attributes {
} {}
//===----------------------------------------------------------------------===//
-// DenseIntOfFPElementsAttr
+// DenseElementsAttr
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: @TestDenseIntOrFPElements
+// CHECK-LABEL: @TestDenseElements
// CHECK: bytecode.test1 = dense<true> : tensor<256xi1>
// CHECK: bytecode.test2 = dense<[10, 32, -1]> : tensor<3xi8>
// CHECK: bytecode.test3 = dense<[1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03]> : tensor<3xf64>
-module @TestDenseIntOrFPElements attributes {
+module @TestDenseElements attributes {
bytecode.test1 = dense<true> : tensor<256xi1>,
bytecode.test2 = dense<[10, 32, 255]> : tensor<3xi8>,
bytecode.test3 = dense<[10.0, 32.0, 1809.0]> : tensor<3xf64>
More information about the llvm-branch-commits
mailing list