[llvm-branch-commits] [mlir] d789a23 - Revert "[mlir][IR] Generalize `DenseElementsAttr` to custom element types (#1…"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 16 07:48:24 PST 2026
Author: Matthias Springer
Date: 2026-02-16T17:48:21+02:00
New Revision: d789a2392b6d2bec6f5c26aeeb4cb48eade48682
URL: https://github.com/llvm/llvm-project/commit/d789a2392b6d2bec6f5c26aeeb4cb48eade48682
DIFF: https://github.com/llvm/llvm-project/commit/d789a2392b6d2bec6f5c26aeeb4cb48eade48682.diff
LOG: Revert "[mlir][IR] Generalize `DenseElementsAttr` to custom element types (#1…"
This reverts commit f13301084923ed131aaea7fbadb9307a76bb21f6.
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/BuiltinTypeInterfaces.h
mlir/include/mlir/IR/BuiltinTypeInterfaces.td
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypeInterfaces.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
Removed:
mlir/test/IR/dense-elements-type-interface.mlir
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index dced379d1f979..798d3c84f9618 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -239,48 +239,29 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
"DenseElementsAttr"
> {
let summary = "An Attribute containing a dense multi-dimensional array of "
- "values";
+ "integer or floating-point values";
let description = [{
- A dense elements attribute stores one or multiple elements of the same type.
- The term "dense" refers to the fact that elements are not stored as
- individual MLIR attributes, but in a raw buffer. The attribute provides a
- covenience API to access elements in the form of MLIR attributes, but users
- should avoid that API in performance-critical code and utilize APIs that
- operate on raw bytes instead.
-
- The number of elements is determined by the `type` shaped type. (Unranked
- shaped types are not supported.) The element type of the shaped type must
- implement the `DenseElementType` interface. This type interface defines the
- bitwidth of an element and provides a serializer/deserializer to/from MLIR
- attributes.
-
- Storage format: Given an element bitwidth "w", element "i" starts at byte
- offset "i * ceildiv(w, 8)". In other words, each element starts at a full
- byte offset.
-
- TODO: The name `DenseIntOrFPElements` is no longer accurate. The attribute
- will be renamed in the future.
+ Syntax:
+
+ ```
+ tensor-literal ::= integer-literal | float-literal | bool-literal | [] | [tensor-literal (, tensor-literal)* ]
+ dense-intorfloat-elements-attribute ::= `dense` `<` tensor-literal `>` `:`
+ ( tensor-type | vector-type )
+ ```
+
+ A dense int-or-float elements attribute is an elements attribute containing
+ a densely packed vector or tensor of integer or floating-point values. The
+ element type of this attribute is required to be either an `IntegerType` or
+ a `FloatType`.
Examples:
```
- // Literal-first syntax: A splat tensor of integer values.
+ // A splat tensor of integer values.
dense<10> : tensor<2xi32>
-
- // Literal-first syntax: A tensor of 2 float32 elements.
+ // A tensor of 2 float32 elements.
dense<[10.0, 11.0]> : tensor<2xf32>
-
- // Type-first syntax: A splat tensor of integer values.
- dense<tensor<2xi32> : 10 : i32>
-
- // Type-first syntax: A tensor of 2 float32 elements.
- dense<tensor<2xf32> : [10.0, 11.0]>
```
-
- Note: The literal-first syntax is supported only for complex, float, index,
- int element types. The parser/print have special casing for these types.
- Dense element attributes with other element types must use the type-first
- syntax.
}];
let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
"ArrayRef<char>":$rawData);
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index 9425d554b427c..5f14517d8dd71 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -19,29 +19,6 @@ struct fltSemantics;
namespace mlir {
class FloatType;
class MLIRContext;
-
-namespace detail {
-/// Float type implementation of
-/// DenseElementTypeInterface::getDenseElementBitSize.
-size_t getFloatTypeDenseElementBitSize(Type type);
-
-/// Float type implementation of DenseElementTypeInterface::convertToAttribute.
-Attribute convertFloatTypeToAttribute(Type type, llvm::ArrayRef<char> rawData);
-
-/// Float type implementation of
-/// DenseElementTypeInterface::convertFromAttribute.
-LogicalResult
-convertFloatTypeFromAttribute(Type type, Attribute attr,
- llvm::SmallVectorImpl<char> &result);
-
-/// Read `bitWidth` bits from byte-aligned position in `rawData` and return as
-/// an APInt. Handles endianness correctly.
-llvm::APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth);
-
-/// Write `value` to byte-aligned position `bitPos` in `rawData`. Handles
-/// endianness correctly.
-void writeBits(char *rawData, size_t bitPos, llvm::APInt value);
-} // namespace detail
} // namespace mlir
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 93c8c0694b467..9ef08b7020b99 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -41,70 +41,12 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
}];
}
-//===----------------------------------------------------------------------===//
-// DenseElementTypeInterface
-//===----------------------------------------------------------------------===//
-
-def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
- let cppNamespace = "::mlir";
- let description = [{
- This interface allows custom types to be used as element types in
- DenseElementsAttr. Types implementing this interface define:
-
- 1. The bit size for element storage.
- 2. Helper methods for converting from/to Attribute. This assumes that there
- is a corresponding attribute for each type that implements this
- interface.
-
- The helper methods for converting from/to Attribute are utilized when
- parsing/printing IR or iterating over the elements via Attribute.
- }];
-
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Return the number of bits required to store one element in dense
- storage.
-
- Note: The DenseElementsAttr infrastructure will automatically align
- every element to a full byte in storage. This limitation could be lifted
- in the future to support dense packing of non-byte-sized elements.
- }],
- /*retTy=*/"size_t",
- /*methodName=*/"getDenseElementBitSize",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Attribute deserialization / attribute factory: Convert raw storage bytes
- into an MLIR attribute. The size of `rawData` is
- "ceilDiv(getDenseElementBitSize(), 8)".
- }],
- /*retTy=*/"::mlir::Attribute",
- /*methodName=*/"convertToAttribute",
- /*args=*/(ins "::llvm::ArrayRef<char>":$rawData)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Attribute serialization: Convert an MLIR attribute into raw bytes.
- Implementations must append "getDenseElementBitSize() / 8" values to
- `result`. Return "failure" if the attribute is incompatible with this
- element type.
- }],
- /*retTy=*/"::llvm::LogicalResult",
- /*methodName=*/"convertFromAttribute",
- /*args=*/(ins "::mlir::Attribute":$attr,
- "::llvm::SmallVectorImpl<char>&":$result)
- >,
- ];
-}
-
//===----------------------------------------------------------------------===//
// FloatTypeInterface
//===----------------------------------------------------------------------===//
def FloatTypeInterface : TypeInterface<"FloatType",
- [DenseElementTypeInterface, VectorElementTypeInterface]> {
+ [VectorElementTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
@@ -141,21 +83,6 @@ def FloatTypeInterface : TypeInterface<"FloatType",
/// The width includes the integer bit.
unsigned getFPMantissaWidth();
}];
-
- let extraTraitClassDeclaration = [{
- /// DenseElementTypeInterface implementations for float types.
- size_t getDenseElementBitSize() const {
- return ::mlir::detail::getFloatTypeDenseElementBitSize($_type);
- }
- ::mlir::Attribute convertToAttribute(::llvm::ArrayRef<char> rawData) const {
- return ::mlir::detail::convertFloatTypeToAttribute($_type, rawData);
- }
- ::llvm::LogicalResult
- convertFromAttribute(::mlir::Attribute attr,
- ::llvm::SmallVectorImpl<char> &result) const {
- return ::mlir::detail::convertFloatTypeFromAttribute($_type, attr, result);
- }
- }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e7d0a03a85e7d..806064faeda00 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -45,10 +45,7 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
// ComplexType
//===----------------------------------------------------------------------===//
-def Builtin_Complex : Builtin_Type<"Complex", "complex",
- [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
- ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>
- ]> {
+def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
let summary = "Complex number with a parameterized element type";
let description = [{
Syntax:
@@ -563,9 +560,7 @@ def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
//===----------------------------------------------------------------------===//
def Builtin_Index : Builtin_Type<"Index", "index",
- [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
- ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>,
- VectorElementTypeInterface]> {
+ [VectorElementTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -596,10 +591,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface, QuantStorageTypeInterface,
- DeclareTypeInterfaceMethods<DenseElementTypeInterface, [
- "getDenseElementBitSize", "convertToAttribute",
- "convertFromAttribute"]>]> {
+ [VectorElementTypeInterface, QuantStorageTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index dc9744a42b730..5978a11d06bc9 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -16,7 +16,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
@@ -954,119 +953,6 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
return eltParser.getAttr();
}
-/// Try to parse a dense elements attribute with the type-first syntax.
-/// Syntax: dense<TYPE : [ATTR, ATTR, ...]>
-/// This syntax is used for types other than int, float, index and complex.
-///
-/// Returns:
-/// - "null" attribute if this is not the type-first syntax.
-/// - "failure" in case of a parse error.
-/// - A valid Attribute otherwise.
-static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
- // Skip l_paren because "parseType" would try to parse it as a tuple/function
- // type, but '(' starts a complex literal like in the literal-first syntax.
- if (p.getToken().is(Token::l_paren))
- return Attribute();
-
- // Parse type and valdiate that it's a shaped type.
- auto typeLoc = p.getToken().getLoc();
- Type type;
- OptionalParseResult typeResult = p.parseOptionalType(type);
- if (!typeResult.has_value())
- return Attribute(); // Not type-first syntax.
- if (failed(*typeResult))
- return failure(); // Type parse error.
-
- auto shapedType = dyn_cast<ShapedType>(type);
- if (!shapedType) {
- p.emitError(typeLoc, "expected a shaped type for dense elements");
- return failure();
- }
- if (!shapedType.hasStaticShape()) {
- p.emitError(typeLoc, "dense elements type must have static shape");
- return failure();
- }
-
- // Check that the element type implements DenseElementTypeInterface.
- auto denseEltType = dyn_cast<DenseElementType>(shapedType.getElementType());
- if (!denseEltType) {
- p.emitError(typeLoc,
- "element type must implement DenseElementTypeInterface "
- "for type-first dense syntax");
- return failure();
- }
-
- // Parse colon.
- if (p.parseToken(Token::colon, "expected ':' after type in dense attribute"))
- return failure();
-
- // Parse the element attributes and convert to raw bytes.
- SmallVector<char> rawData;
-
- // Helper to parse a single element.
- auto parseSingleElement = [&]() -> ParseResult {
- Attribute elemAttr = p.parseAttribute();
- if (!elemAttr)
- return failure();
- if (failed(denseEltType.convertFromAttribute(elemAttr, rawData))) {
- p.emitError("incompatible attribute for element type");
- return failure();
- }
- return success();
- };
-
- // Recursively parse elements matching the expected shape.
- std::function<ParseResult(ArrayRef<int64_t>)> parseElements;
- parseElements = [&](ArrayRef<int64_t> remainingShape) -> ParseResult {
- // Leaf: parse a single element.
- if (remainingShape.empty())
- return parseSingleElement();
-
- // Non-leaf: expect a list with the correct number of elements.
- int64_t expectedCount = remainingShape.front();
- ArrayRef<int64_t> innerShape = remainingShape.drop_front();
- int64_t actualCount = 0;
-
- auto parseOne = [&]() -> ParseResult {
- if (parseElements(innerShape))
- return failure();
- ++actualCount;
- return success();
- };
-
- if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOne))
- return failure();
-
- if (actualCount != expectedCount) {
- p.emitError() << "expected " << expectedCount
- << " elements in dimension, got " << actualCount;
- return failure();
- }
- return success();
- };
-
- // Parse elements.
- if (!p.getToken().is(Token::l_square)) {
- // Single element - parse as splat.
- if (parseSingleElement())
- return failure();
- } else if (shapedType.getShape().empty()) {
- // Scalar type shouldn't have a list.
- p.emitError(loc, "expected single element for scalar type, got list");
- return failure();
- } else {
- // Parse structured literal matching the shape.
- if (parseElements(shapedType.getShape()))
- return failure();
- }
-
- if (p.parseToken(Token::greater, "expected '>' to close dense attribute"))
- return failure();
-
- // Create the attribute from raw buffer.
- return DenseElementsAttr::getFromRawBuffer(shapedType, rawData);
-}
-
/// Parse a dense elements attribute.
Attribute Parser::parseDenseElementsAttr(Type attrType) {
auto attribLoc = getToken().getLoc();
@@ -1074,16 +960,7 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
- // Try to parse the type-first syntax: dense<TYPE : [ATTR, ...]>
- FailureOr<Attribute> typedResult =
- parseDenseElementsAttrTyped(*this, attribLoc);
- if (failed(typedResult))
- return nullptr;
- if (*typedResult)
- return *typedResult;
-
- // Try to parse the literal-first syntax, which is the default format for
- // int, float, index and complex element types.
+ // Parse the literal data if necessary.
TensorLiteralParser literalParser(*this);
if (!consumeIf(Token::greater)) {
if (literalParser.parse(/*allowHex=*/true) ||
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index b3242f838fc1d..81455699421cc 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -507,18 +507,11 @@ class AsmPrinter::Impl {
/// Print a dense string elements attribute.
void printDenseStringElementsAttr(DenseStringElementsAttr attr);
- /// Print a dense elements attribute in the literal-first syntax. If
- /// 'allowHex' is true, a hex string is used instead of individual elements
- /// when the elements attr is large.
+ /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
+ /// used instead of individual elements when the elements attr is large.
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex);
- /// Print a dense elements attribute using the type-first syntax and the
- /// DenseElementTypeInterface, which provides the attribute printer for each
- /// element.
- void printTypeFirstDenseElementsAttr(DenseElementsAttr attr,
- DenseElementType denseEltType);
-
/// Print a dense array attribute.
void printDenseArrayAttr(DenseArrayAttr attr);
@@ -2514,17 +2507,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
printElidedElementsAttr(os);
} else {
os << "dense<";
- // Check if the element type implements DenseElementTypeInterface and is
- // not a built-in type. Built-in types (int, float, index, complex) use
- // the existing printing format for backwards compatibility.
- Type eltType = intOrFpEltAttr.getElementType();
- if (isa<FloatType, IntegerType, IndexType, ComplexType>(eltType)) {
- printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
- } else {
- printTypeFirstDenseElementsAttr(intOrFpEltAttr,
- cast<DenseElementType>(eltType));
- typeElision = AttrTypeElision::Must;
- }
+ printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
os << '>';
}
@@ -2722,27 +2705,6 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
-void AsmPrinter::Impl::printTypeFirstDenseElementsAttr(
- DenseElementsAttr attr, DenseElementType denseEltType) {
- // Print the type first: dense<TYPE : [ELEMENTS]>
- printType(attr.getType());
- os << " : ";
-
- ArrayRef<char> rawData = attr.getRawData();
- // Storage is byte-aligned: align bit size up to next byte boundary.
- size_t bitSize = denseEltType.getDenseElementBitSize();
- size_t byteSize = llvm::divideCeil(bitSize, static_cast<size_t>(CHAR_BIT));
-
- // Print elements: convert raw bytes to attribute, then print attribute.
- printDenseElementsAttrImpl(
- attr.isSplat(), attr.getType(), os, [&](unsigned index) {
- size_t offset = attr.isSplat() ? 0 : index * byteSize;
- ArrayRef<char> elemData = rawData.slice(offset, byteSize);
- Attribute elemAttr = denseEltType.convertToAttribute(elemData);
- printAttributeImpl(elemAttr);
- });
-}
-
void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
Type type = attr.getElementType();
unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 8505149afdd9c..1f268603cf37f 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -16,7 +16,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
@@ -33,9 +32,12 @@ namespace detail {
/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
- if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
- return denseEltType.getDenseElementBitSize();
- llvm_unreachable("unsupported element type");
+ // Align the width for complex to 8 to make storage and interpretation easier.
+ if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
+ return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
+ if (eltType.isIndex())
+ return IndexType::kInternalStorageBitWidth;
+ return eltType.getIntOrFloatBitWidth();
}
/// An attribute representing a reference to a dense vector or tensor object.
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index bbbc9198a68ab..1a29fc534b40f 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -10,7 +10,6 @@
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
@@ -528,7 +527,7 @@ static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
}
/// Writes value to the bit position `bitPos` in array `rawData`.
-void mlir::detail::writeBits(char *rawData, size_t bitPos, APInt value) {
+static void writeBits(char *rawData, size_t bitPos, APInt value) {
size_t bitWidth = value.getBitWidth();
// The bit position is guaranteed to be byte aligned.
@@ -550,8 +549,7 @@ void mlir::detail::writeBits(char *rawData, size_t bitPos, APInt value) {
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData`.
-APInt mlir::detail::readBits(const char *rawData, size_t bitPos,
- size_t bitWidth) {
+static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
// The bit position is guaranteed to be byte aligned.
assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
APInt result(bitWidth, 0);
@@ -597,21 +595,39 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
Type eltTy = owner.getElementType();
+ if (llvm::dyn_cast<IntegerType>(eltTy))
+ return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
+ if (llvm::isa<IndexType>(eltTy))
+ return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
+ if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) {
+ IntElementIterator intIt(owner, index);
+ FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
+ return FloatAttr::get(eltTy, *floatIt);
+ }
+ if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) {
+ auto complexEltTy = complexTy.getElementType();
+ ComplexIntElementIterator complexIntIt(owner, index);
+ if (llvm::isa<IntegerType>(complexEltTy)) {
+ auto value = *complexIntIt;
+ auto real = IntegerAttr::get(complexEltTy, value.real());
+ auto imag = IntegerAttr::get(complexEltTy, value.imag());
+ return ArrayAttr::get(complexTy.getContext(),
+ ArrayRef<Attribute>{real, imag});
+ }
- // Handle strings specially.
+ ComplexFloatElementIterator complexFloatIt(
+ llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt);
+ auto value = *complexFloatIt;
+ auto real = FloatAttr::get(complexEltTy, value.real());
+ auto imag = FloatAttr::get(complexEltTy, value.imag());
+ return ArrayAttr::get(complexTy.getContext(),
+ ArrayRef<Attribute>{real, imag});
+ }
if (llvm::isa<DenseStringElementsAttr>(owner)) {
ArrayRef<StringRef> vals = owner.getRawStringData();
return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
}
-
- // All other types should implement DenseElementTypeInterface.
- auto denseEltTy = llvm::cast<DenseElementType>(eltTy);
- ArrayRef<char> rawData = owner.getRawData();
- // Storage is byte-aligned: align bit size up to next byte boundary.
- size_t bitSize = denseEltTy.getDenseElementBitSize();
- size_t byteSize = llvm::divideCeil(bitSize, CHAR_BIT);
- size_t offset = owner.isSplat() ? 0 : index * byteSize;
- return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize));
+ llvm_unreachable("unexpected element type");
}
//===----------------------------------------------------------------------===//
@@ -872,28 +888,79 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
assert(hasSameNumElementsOrSplat(type, values));
Type eltType = type.getElementType();
- // Handle strings specially.
- if (!llvm::isa<DenseElementType>(eltType)) {
+ // Take care complex type case first.
+ if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) {
+ if (complexType.getElementType().isIntOrIndex()) {
+ SmallVector<std::complex<APInt>> complexValues;
+ complexValues.reserve(values.size());
+ for (Attribute attr : values) {
+ assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
+ auto arrayAttr = llvm::cast<ArrayAttr>(attr);
+ assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+ auto attr0 = arrayAttr[0];
+ auto attr1 = arrayAttr[1];
+ complexValues.push_back(
+ std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(),
+ llvm::cast<IntegerAttr>(attr1).getValue()));
+ }
+ return DenseElementsAttr::get(type, complexValues);
+ }
+ // Must be float.
+ SmallVector<std::complex<APFloat>> complexValues;
+ complexValues.reserve(values.size());
+ for (Attribute attr : values) {
+ assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
+ auto arrayAttr = llvm::cast<ArrayAttr>(attr);
+ assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+ auto attr0 = arrayAttr[0];
+ auto attr1 = arrayAttr[1];
+ complexValues.push_back(
+ std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(),
+ llvm::cast<FloatAttr>(attr1).getValue()));
+ }
+ return DenseElementsAttr::get(type, complexValues);
+ }
+
+ // If the element type is not based on int/float/index, assume it is a string
+ // type.
+ if (!eltType.isIntOrIndexOrFloat()) {
SmallVector<StringRef, 8> stringValues;
stringValues.reserve(values.size());
for (Attribute attr : values) {
assert(llvm::isa<StringAttr>(attr) &&
- "expected string value for non-DenseElementType element");
+ "expected string value for non integer/index/float element");
stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
}
return get(type, stringValues);
}
- // All other types go through DenseElementTypeInterface.
- auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType);
- assert(denseEltType &&
- "attempted to get DenseElementsAttr with unsupported element type");
- SmallVector<char> data;
- for (Attribute attr : values) {
- LogicalResult result = denseEltType.convertFromAttribute(attr, data);
- if (failed(result))
+ // Otherwise, get the raw storage width to use for the allocation.
+ size_t bitWidth = getDenseElementBitWidth(eltType);
+ size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+
+ // Compress the attribute values into a character buffer.
+ SmallVector<char, 8> data(
+ llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
+ APInt intVal;
+ for (unsigned i = 0, e = values.size(); i < e; ++i) {
+ if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) {
+ assert(floatAttr.getType() == eltType &&
+ "expected float attribute type to equal element type");
+ intVal = floatAttr.getValue().bitcastToAPInt();
+ } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(values[i])) {
+ assert(intAttr.getType() == eltType &&
+ "expected integer attribute type to equal element type");
+ intVal = intAttr.getValue();
+ } else {
+ // Unsupported attribute type.
return {};
+ }
+
+ assert(intVal.getBitWidth() == bitWidth &&
+ "expected value to have same bitwidth as element type");
+ writeBits(data.data(), i * storageBitWidth, intVal);
}
+
return DenseIntOrFPElementsAttr::getRaw(type, data);
}
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 29303d95eb003..2f063be3e7cd0 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -6,12 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/CheckedArithmetic.h"
-#include "llvm/Support/MathExtras.h"
-#include <climits>
using namespace mlir;
using namespace mlir::detail;
@@ -22,37 +19,6 @@ using namespace mlir::detail;
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
-//===----------------------------------------------------------------------===//
-// DenseElementTypeInterface implementations for float types
-//===----------------------------------------------------------------------===//
-
-size_t mlir::detail::getFloatTypeDenseElementBitSize(Type type) {
- return cast<FloatType>(type).getWidth();
-}
-
-Attribute mlir::detail::convertFloatTypeToAttribute(Type type,
- ArrayRef<char> rawData) {
- auto floatType = cast<FloatType>(type);
- APInt intVal = readBits(rawData.data(), /*bitPos=*/0, floatType.getWidth());
- APFloat floatVal(floatType.getFloatSemantics(), intVal);
- return FloatAttr::get(type, floatVal);
-}
-
-LogicalResult
-mlir::detail::convertFloatTypeFromAttribute(Type type, Attribute attr,
- SmallVectorImpl<char> &result) {
- auto floatType = cast<FloatType>(type);
- auto floatAttr = dyn_cast<FloatAttr>(attr);
- if (!floatAttr || floatAttr.getType() != type)
- return failure();
- size_t byteSize =
- llvm::divideCeil(floatType.getWidth(), static_cast<unsigned>(CHAR_BIT));
- size_t bitPos = result.size() * CHAR_BIT;
- result.resize(result.size() + byteSize);
- writeBits(result.data(), bitPos, floatAttr.getValue().bitcastToAPInt());
- return success();
-}
-
//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 786c30851a071..1e198043c590a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -12,17 +12,14 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
-#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CheckedArithmetic.h"
-#include <cstring>
using namespace mlir;
using namespace mlir::detail;
@@ -61,39 +58,6 @@ LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-size_t ComplexType::getDenseElementBitSize() const {
- auto elemTy = cast<DenseElementType>(getElementType());
- return llvm::alignTo<8>(elemTy.getDenseElementBitSize()) * 2;
-}
-
-Attribute ComplexType::convertToAttribute(ArrayRef<char> rawData) const {
- auto elemTy = cast<DenseElementType>(getElementType());
- size_t singleElementBytes =
- llvm::alignTo<8>(elemTy.getDenseElementBitSize()) / 8;
- Attribute real =
- elemTy.convertToAttribute(rawData.take_front(singleElementBytes));
- Attribute imag =
- elemTy.convertToAttribute(rawData.take_back(singleElementBytes));
- return ArrayAttr::get(getContext(), {real, imag});
-}
-
-LogicalResult
-ComplexType::convertFromAttribute(Attribute attr,
- SmallVectorImpl<char> &result) const {
- auto arrayAttr = dyn_cast<ArrayAttr>(attr);
- if (!arrayAttr || arrayAttr.size() != 2)
- return failure();
- auto elemTy = cast<DenseElementType>(getElementType());
- SmallVector<char> realData, imagData;
- if (failed(elemTy.convertFromAttribute(arrayAttr[0], realData)))
- return failure();
- if (failed(elemTy.convertFromAttribute(arrayAttr[1], imagData)))
- return failure();
- result.append(realData);
- result.append(imagData);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
@@ -121,57 +85,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
-size_t IntegerType::getDenseElementBitSize() const {
- // Return the actual bit width. Storage alignment is handled separately.
- return getWidth();
-}
-
-Attribute IntegerType::convertToAttribute(ArrayRef<char> rawData) const {
- APInt value = detail::readBits(rawData.data(), /*bitPos=*/0, getWidth());
- return IntegerAttr::get(*this, value);
-}
-
-static void writeAPIntToVector(APInt apInt, SmallVectorImpl<char> &result) {
- size_t byteSize = llvm::divideCeil(apInt.getBitWidth(), CHAR_BIT);
- size_t bitPos = result.size() * CHAR_BIT;
- result.resize(result.size() + byteSize);
- detail::writeBits(result.data(), bitPos, apInt);
-}
-
-LogicalResult
-IntegerType::convertFromAttribute(Attribute attr,
- SmallVectorImpl<char> &result) const {
- auto intAttr = dyn_cast<IntegerAttr>(attr);
- if (!intAttr || intAttr.getType() != *this)
- return failure();
- writeAPIntToVector(intAttr.getValue(), result);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Index Type
-//===----------------------------------------------------------------------===//
-
-size_t IndexType::getDenseElementBitSize() const {
- return kInternalStorageBitWidth;
-}
-
-Attribute IndexType::convertToAttribute(ArrayRef<char> rawData) const {
- APInt value =
- detail::readBits(rawData.data(), /*bitPos=*/0, kInternalStorageBitWidth);
- return IntegerAttr::get(*this, value);
-}
-
-LogicalResult
-IndexType::convertFromAttribute(Attribute attr,
- SmallVectorImpl<char> &result) const {
- auto intAttr = dyn_cast<IntegerAttr>(attr);
- if (!intAttr || intAttr.getType() != *this)
- return failure();
- writeAPIntToVector(intAttr.getValue(), result);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Float Types
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir
deleted file mode 100644
index 8749e562087c2..0000000000000
--- a/mlir/test/IR/dense-elements-type-interface.mlir
+++ /dev/null
@@ -1,83 +0,0 @@
-// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
-
-// Test dense elements attribute with custom element type using DenseElementTypeInterface.
-// Uses the new type-first syntax: dense<TYPE : [ATTR, ...]>
-// Note: The type is embedded in the attribute, so it's not printed again at the end.
-
-// CHECK-LABEL: func @dense_custom_element_type
-func.func @dense_custom_element_type() {
- // CHECK: "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>}
- "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @dense_custom_element_type_2d
-func.func @dense_custom_element_type_2d() {
- // CHECK: "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
- "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> ()
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @dense_custom_element_splat
-func.func @dense_custom_element_splat() {
- // CHECK: "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
- "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> ()
- return
-}
-
-// -----
-
-// CHECK-LABEL func @dense_i32_1d
-func.func @dense_i32_1d() {
- // The default assembly format for int, index, float, complex element types is
- // the literal-first syntax. Such a dense elements attribute can be parsed
- // with the type-first syntax, but it will come back with the literal-first
- // syntax.
- // CHECK: "test.dummy"() {attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
- "test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
- return
-}
-
-// -----
-
-func.func @invalid_element() {
- // expected-error @+1 {{expected attribute value}}
- "test.dummy"() {attr = dense<tensor<3xi32> : [foo]>} : () -> ()
- return
-}
-
-// -----
-
-func.func @incompatible_attribute() {
- // expected-error @+1 {{incompatible attribute for element type}}
- "test.dummy"() {attr = dense<tensor<3xi32> : ["foo"]>} : () -> ()
- return
-}
-
-// -----
-
-func.func @shape_mismatch() {
- // expected-error @+1 {{expected 3 elements in dimension, got 2}}
- "test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32]>} : () -> ()
- return
-}
-
-// -----
-
-func.func @dynamic_shape() {
- // expected-error @+1 {{dense elements type must have static shape}}
- "test.dummy"() {attr = dense<tensor<?xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
- return
-}
-
-// -----
-
-func.func @invalid_type() {
- // expected-error @+1 {{expected a shaped type for dense elements}}
- "test.dummy"() {attr = dense<i32 : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
- return
-}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 08600ce713a17..964792ceebc07 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -18,7 +18,6 @@ include "TestDialect.td"
include "TestAttrDefs.td"
include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
-include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -513,15 +512,4 @@ def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> {
let hasCustomAssemblyFormat = 1;
}
-def TestTypeDenseElement : Test_Type<"TestDenseElement",
- [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
- ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>
- ]> {
- let mnemonic = "dense_element";
- let description = [{
- A test type that implements DenseElementTypeInterface to test dense
- elements with custom element types. Elements are stored as 32-bit integers.
- }];
-}
-
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index ef3396fc4f610..71dd25b0093e0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -15,7 +15,6 @@
#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
@@ -23,7 +22,6 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/TypeSize.h"
-#include <cstring>
#include <optional>
using namespace mlir;
@@ -607,29 +605,3 @@ void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const {
printer.printNewline();
printer << ">";
}
-
-//===----------------------------------------------------------------------===//
-// TestDenseElementType - DenseElementTypeInterface Implementation
-//===----------------------------------------------------------------------===//
-
-// Elements are stored as 32-bit integers.
-size_t TestDenseElementType::getDenseElementBitSize() const { return 32; }
-
-Attribute
-TestDenseElementType::convertToAttribute(ArrayRef<char> rawData) const {
- assert(rawData.size() == 4 && "expected 4 bytes for TestDenseElement");
- int32_t value;
- std::memcpy(&value, rawData.data(), sizeof(value));
- return IntegerAttr::get(IntegerType::get(getContext(), 32), value);
-}
-
-LogicalResult TestDenseElementType::convertFromAttribute(
- Attribute attr, SmallVectorImpl<char> &result) const {
- auto intAttr = dyn_cast<IntegerAttr>(attr);
- if (!intAttr || intAttr.getType().getIntOrFloatBitWidth() != 32)
- return failure();
- int32_t value = intAttr.getValue().getSExtValue();
- result.append(reinterpret_cast<const char *>(&value),
- reinterpret_cast<const char *>(&value) + sizeof(value));
- return success();
-}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 705fb86e9e9b3..6499a96f495d0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -19,7 +19,6 @@
#include "TestTraits.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
More information about the llvm-branch-commits
mailing list