[Mlir-commits] [mlir] [mlir][IR] Generalize `DenseElementsAttr` to custom element types (PR #179122)
Matthias Springer
llvmlistbot at llvm.org
Sat Feb 7 01:31:53 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/179122
>From e199909fb8562064d09595fc9c3ffa9fd569935d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 1 Feb 2026 17:41:37 +0000
Subject: [PATCH 1/5] [mlir][WIP] `DenseElementsAttr` generalized
---
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 55 +++++++
mlir/lib/AsmParser/AttributeParser.cpp | 149 +++++++++++++++++-
mlir/lib/IR/AsmPrinter.cpp | 72 +++++++--
mlir/lib/IR/AttributeDetail.h | 4 +
mlir/lib/IR/BuiltinAttributes.cpp | 20 +++
mlir/lib/IR/BuiltinTypes.cpp | 1 +
.../IR/dense-elements-type-interface.mlir | 28 ++++
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 14 ++
mlir/test/lib/Dialect/Test/TestTypes.cpp | 28 ++++
mlir/test/lib/Dialect/Test/TestTypes.h | 1 +
10 files changed, 353 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/IR/dense-elements-type-interface.mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 9ef08b7020b99..f111bb5b32b49 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -338,4 +338,59 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}
+//===----------------------------------------------------------------------===//
+// DenseElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface allows custom types to be used as element types in
+ DenseElementsAttr. Types implementing this interface define:
+
+ 1. The bit size for element storage. Only full byte sizes are supported
+ at the moment.
+ 2. Helper methods for converting from/to Attribute. This assumes that there
+ is a corresponding attribute for each type that implements this
+ interface.
+
+ The helper methods for converting from/to Attribute are utilized when
+ parsing/printing IR or iterating over the elements via Attribute.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of bits required to store one element in dense
+ storage. This must be a compile-time constant for the type and must
+ be a multiple of 8 (byte-aligned).
+ }],
+ /*retTy=*/"size_t",
+ /*methodName=*/"getDenseElementBitSize",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Convert raw storage bytes to an attribute representing this element
+ value. The `rawData` array contains exactly `getDenseElementBitSize()/8`
+ bytes.
+ }],
+ /*retTy=*/"::mlir::Attribute",
+ /*methodName=*/"convertToAttribute",
+ /*args=*/(ins "::llvm::ArrayRef<char>":$rawData)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Convert an attribute to raw storage bytes. Appends exactly
+ `getDenseElementBitSize()/8` bytes to `result`. Returns failure if the
+ attribute is incompatible with this element type.
+ }],
+ /*retTy=*/"::llvm::LogicalResult",
+ /*methodName=*/"convertFromAttribute",
+ /*args=*/(ins "::mlir::Attribute":$attr,
+ "::llvm::SmallVectorImpl<char>&":$result)
+ >,
+ ];
+}
+
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 519609a38be6e..09e5146c24e00 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
@@ -954,6 +955,137 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
return eltParser.getAttr();
}
+/// Try to parse a dense elements attribute with the type-first syntax.
+/// Syntax: dense<TYPE : [ATTR, ATTR, ...]>
+/// This is used for element types implementing DenseElementTypeInterface.
+///
+/// Returns:
+/// - failure() on parse error
+/// - Attribute() (null) if this is not the type-first syntax
+/// - A valid Attribute on success
+static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
+ // Try to parse an optional type. Skip l_paren because parseOptionalType
+ // would try to parse it as a tuple/function type, but '(' starts a complex
+ // literal like (0, 1) in dense syntax.
+ Type type;
+ OptionalParseResult typeResult = p.getToken().is(Token::l_paren)
+ ? OptionalParseResult(std::nullopt)
+ : p.parseOptionalType(type);
+ if (!typeResult.has_value())
+ return Attribute(); // Not type-first syntax.
+
+ if (failed(*typeResult))
+ return failure(); // Type parse error.
+
+ // We parsed a type. Check for ':' to confirm type-first syntax.
+ if (!p.getToken().is(Token::colon)) {
+ p.emitError(loc, "expected ':' after type in dense attribute");
+ return failure();
+ }
+
+ // Validate the type.
+ auto shapedType = dyn_cast<ShapedType>(type);
+ if (!shapedType) {
+ p.emitError(loc, "expected a shaped type for dense elements");
+ return failure();
+ }
+
+ if (!shapedType.hasStaticShape()) {
+ p.emitError(loc, "dense elements type must have static shape");
+ return failure();
+ }
+
+ // Check that the element type implements DenseElementTypeInterface.
+ auto denseEltType = dyn_cast<DenseElementType>(shapedType.getElementType());
+ if (!denseEltType) {
+ p.emitError(loc, "element type must implement DenseElementTypeInterface "
+ "for type-first dense syntax");
+ return failure();
+ }
+
+ // Consume the ':' that separates the type from the element list.
+ p.consumeToken(Token::colon);
+
+ ArrayRef<int64_t> shape = shapedType.getShape();
+
+ // Parse the element attributes and convert to raw bytes.
+ SmallVector<char> rawData;
+ size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+
+ // Helper to parse a single element.
+ auto parseSingleElement = [&]() -> ParseResult {
+ Attribute elemAttr = p.parseAttribute();
+ if (!elemAttr)
+ return failure();
+ if (failed(denseEltType.convertFromAttribute(elemAttr, rawData))) {
+ p.emitError("incompatible attribute for element type");
+ return failure();
+ }
+ return success();
+ };
+
+ // Recursively parse elements matching the expected shape.
+ std::function<ParseResult(ArrayRef<int64_t>)> parseElements;
+ parseElements = [&](ArrayRef<int64_t> remainingShape) -> ParseResult {
+ // Leaf: parse a single element.
+ if (remainingShape.empty())
+ return parseSingleElement();
+
+ // Non-leaf: expect a list with the correct number of elements.
+ int64_t expectedCount = remainingShape[0];
+ ArrayRef<int64_t> innerShape = remainingShape.drop_front();
+ int64_t actualCount = 0;
+
+ auto parseOne = [&]() -> ParseResult {
+ if (parseElements(innerShape))
+ return failure();
+ ++actualCount;
+ return success();
+ };
+
+ if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOne))
+ return failure();
+
+ if (actualCount != expectedCount) {
+ p.emitError() << "expected " << expectedCount
+ << " elements in dimension, got " << actualCount;
+ return failure();
+ }
+ return success();
+ };
+
+ // Check for splat (single element for the whole tensor).
+ bool isSplat = false;
+ if (!p.getToken().is(Token::l_square)) {
+ // Single element - parse as splat.
+ if (parseSingleElement())
+ return failure();
+ isSplat = shapedType.getNumElements() != 1;
+ } else if (shape.empty()) {
+ // Scalar type shouldn't have a list.
+ p.emitError(loc, "expected single element for scalar type, got list");
+ return failure();
+ } else {
+ // Parse structured literal matching the shape.
+ if (parseElements(shape))
+ return failure();
+ }
+
+ // Verify the byte count: one element for a splat, numElements otherwise.
+ int64_t numElements = shapedType.getNumElements();
+ if (rawData.size() != byteSize * (isSplat ? 1 : numElements)) {
+ p.emitError(loc) << "parsed " << (byteSize ? rawData.size() / byteSize : 0)
+ << " elements, but type expects " << numElements;
+ return failure();
+ }
+
+ if (p.parseToken(Token::greater, "expected '>' to close dense attribute"))
+ return failure();
+
+ // Buffer size was validated above; getFromRawBuffer asserts on it.
+ return DenseElementsAttr::getFromRawBuffer(shapedType, rawData);
+}
+
/// Parse a dense elements attribute.
Attribute Parser::parseDenseElementsAttr(Type attrType) {
auto attribLoc = getToken().getLoc();
@@ -961,7 +1093,16 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
- // Parse the literal data if necessary.
+ // Try to parse the type-first syntax: dense<TYPE : [ATTR, ...]>
+ // This is used for element types implementing DenseElementTypeInterface.
+ FailureOr<Attribute> typedResult =
+ parseDenseElementsAttrTyped(*this, attribLoc);
+ if (failed(typedResult))
+ return nullptr;
+ if (*typedResult)
+ return *typedResult;
+
+ // Parse the literal data if necessary (old syntax: dense<LITERAL> : TYPE).
TensorLiteralParser literalParser(*this);
if (!consumeIf(Token::greater)) {
if (literalParser.parse(/*allowHex=*/true) ||
@@ -969,10 +1110,10 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return nullptr;
}
- auto type = parseElementsLiteralType(attribLoc, attrType);
- if (!type)
+ auto literalType = parseElementsLiteralType(attribLoc, attrType);
+ if (!literalType)
return nullptr;
- return literalParser.getAttr(attribLoc, type);
+ return literalParser.getAttr(attribLoc, literalType);
}
Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 81455699421cc..0897065849c25 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -512,6 +512,11 @@ class AsmPrinter::Impl {
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex);
+ /// Print a dense elements attribute using DenseElementTypeInterface.
+ /// Uses the type-first syntax: dense<TYPE : [ATTR, ...]>
+ void printDenseElementsAttrWithInterface(DenseElementsAttr attr,
+ DenseElementType denseEltType);
+
/// Print a dense array attribute.
void printDenseArrayAttr(DenseArrayAttr attr);
@@ -2501,23 +2506,41 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
printSymbolReference(nestedRef.getValue(), os);
}
- } else if (auto intOrFpEltAttr =
- llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
- if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
- printElidedElementsAttr(os);
- } else {
- os << "dense<";
- printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
- os << '>';
+ } else if (auto denseEltAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) {
+ // Check if the element type implements DenseElementTypeInterface.
+ // If so, use the type-first syntax which embeds the type in the attribute.
+ Type eltType = denseEltAttr.getElementType();
+ if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
+ if (printerFlags.shouldElideElementsAttr(denseEltAttr)) {
+ printElidedElementsAttr(os);
+ } else {
+ os << "dense<";
+ printDenseElementsAttrWithInterface(denseEltAttr, denseEltType);
+ os << '>';
+ }
+ // Type is embedded in the syntax, don't print it again.
+ return;
}
- } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
- if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
- printElidedElementsAttr(os);
- } else {
- os << "dense<";
- printDenseStringElementsAttr(strEltAttr);
- os << '>';
+ // Fall back to existing printing for built-in element types.
+ if (auto intOrFpEltAttr =
+ llvm::dyn_cast<DenseIntOrFPElementsAttr>(denseEltAttr)) {
+ if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
+ printElidedElementsAttr(os);
+ } else {
+ os << "dense<";
+ printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
+ os << '>';
+ }
+ } else if (auto strEltAttr =
+ llvm::dyn_cast<DenseStringElementsAttr>(denseEltAttr)) {
+ if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
+ printElidedElementsAttr(os);
+ } else {
+ os << "dense<";
+ printDenseStringElementsAttr(strEltAttr);
+ os << '>';
+ }
}
} else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
@@ -2705,6 +2728,25 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
+void AsmPrinter::Impl::printDenseElementsAttrWithInterface(
+ DenseElementsAttr attr, DenseElementType denseEltType) {
+ // Print the type first: dense<TYPE : [ELEMENTS]>
+ printType(attr.getType());
+ os << " : ";
+
+ ArrayRef<char> rawData = attr.getRawData();
+ size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+
+ // Print elements: convert raw bytes to attribute, then print attribute.
+ printDenseElementsAttrImpl(
+ attr.isSplat(), attr.getType(), os, [&](unsigned index) {
+ size_t offset = attr.isSplat() ? 0 : index * byteSize;
+ ArrayRef<char> elemData = rawData.slice(offset, byteSize);
+ Attribute elemAttr = denseEltType.convertToAttribute(elemData);
+ printAttributeImpl(elemAttr);
+ });
+}
+
void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
Type type = attr.getElementType();
unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index cb9d21bf3e611..9055d58c5fe5d 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -16,6 +16,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
@@ -32,6 +33,9 @@ namespace detail {
/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
+ // Check for DenseElementTypeInterface first.
+ if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
+ return denseEltType.getDenseElementBitSize();
// Align the width for complex to 8 to make storage and interpretation easier.
if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 6f880f810d651..90dc6d0129658 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -651,6 +651,13 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
ArrayRef<StringRef> vals = owner.getRawStringData();
return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
}
+ // Check if the element type implements DenseElementTypeInterface.
+ if (auto denseEltTy = llvm::dyn_cast<DenseElementType>(eltTy)) {
+ ArrayRef<char> rawData = owner.getRawData();
+ size_t byteSize = denseEltTy.getDenseElementBitSize() / CHAR_BIT;
+ size_t offset = owner.isSplat() ? 0 : index * byteSize;
+ return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize));
+ }
llvm_unreachable("unexpected element type");
}
@@ -946,6 +953,19 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
return DenseElementsAttr::get(type, complexValues);
}
+ // Check if the element type implements DenseElementTypeInterface.
+ if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
+ SmallVector<char> data;
+ for (Attribute attr : values) {
+ SmallVector<char> elementData;
+ if (failed(denseEltType.convertFromAttribute(attr, elementData))) {
+ llvm_unreachable("incompatible attribute for DenseElementType");
+ }
+ llvm::append_range(data, elementData);
+ }
+ return DenseIntOrFPElementsAttr::getRaw(type, data);
+ }
+
// If the element type is not based on int/float/index, assume it is a string
// type.
if (!eltType.isIntOrIndexOrFloat()) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1e198043c590a..d0165db058683 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir
new file mode 100644
index 0000000000000..8aa1386cfeb73
--- /dev/null
+++ b/mlir/test/IR/dense-elements-type-interface.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+
+// Test dense elements attribute with custom element type using DenseElementTypeInterface.
+// Uses the new type-first syntax: dense<TYPE : [ATTR, ...]>
+// Note: The type is embedded in the attribute, so it's not printed again at the end.
+
+// CHECK-LABEL: func @dense_custom_element_type
+func.func @dense_custom_element_type() {
+ // The type is embedded in the dense attribute syntax, not printed separately.
+ // CHECK: "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>}
+ "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @dense_custom_element_type_2d
+func.func @dense_custom_element_type_2d() {
+ // CHECK: "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
+ "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @dense_custom_element_splat
+func.func @dense_custom_element_splat() {
+ // A splat should be detected and stored efficiently
+ // CHECK: "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
+ "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 964792ceebc07..cfbadc6aa8a7a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -18,6 +18,7 @@ include "TestDialect.td"
include "TestAttrDefs.td"
include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -512,4 +513,17 @@ def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> {
let hasCustomAssemblyFormat = 1;
}
+//===----------------------------------------------------------------------===//
+// Test type for DenseElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TestTypeDenseElement : Test_Type<"TestDenseElement",
+ [DeclareTypeInterfaceMethods<DenseElementTypeInterface>]> {
+ let mnemonic = "dense_element";
+ let description = [{
+ A test type that implements DenseElementTypeInterface to test dense
+ elements with custom element types. Elements are stored as 32-bit integers.
+ }];
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 71dd25b0093e0..ef3396fc4f610 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -15,6 +15,7 @@
#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
@@ -22,6 +23,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/TypeSize.h"
+#include <cstring>
#include <optional>
using namespace mlir;
@@ -605,3 +607,29 @@ void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const {
printer.printNewline();
printer << ">";
}
+
+//===----------------------------------------------------------------------===//
+// TestDenseElementType - DenseElementTypeInterface Implementation
+//===----------------------------------------------------------------------===//
+
+// Elements are stored as 32-bit integers.
+size_t TestDenseElementType::getDenseElementBitSize() const { return 32; }
+
+Attribute
+TestDenseElementType::convertToAttribute(ArrayRef<char> rawData) const {
+ assert(rawData.size() == 4 && "expected 4 bytes for TestDenseElement");
+ int32_t value;
+ std::memcpy(&value, rawData.data(), sizeof(value));
+ return IntegerAttr::get(IntegerType::get(getContext(), 32), value);
+}
+
+LogicalResult TestDenseElementType::convertFromAttribute(
+ Attribute attr, SmallVectorImpl<char> &result) const {
+ auto intAttr = dyn_cast<IntegerAttr>(attr);
+ if (!intAttr || !intAttr.getType().isInteger(32))
+ return failure();
+ int32_t value = intAttr.getValue().getSExtValue();
+ result.append(reinterpret_cast<const char *>(&value),
+ reinterpret_cast<const char *>(&value) + sizeof(value));
+ return success();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 6499a96f495d0..705fb86e9e9b3 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -19,6 +19,7 @@
#include "TestTraits.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
>From 11e3abd1e927b45be087a18daa47aec60f8b618f Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 4 Feb 2026 18:16:37 +0000
Subject: [PATCH 2/5] getter / iterator via interface
---
mlir/include/mlir/IR/BuiltinTypeInterfaces.h | 20 +++
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 128 ++++++++-------
mlir/include/mlir/IR/BuiltinTypes.td | 13 +-
mlir/lib/AsmParser/AttributeParser.cpp | 26 +--
mlir/lib/IR/AsmPrinter.cpp | 74 ++++-----
mlir/lib/IR/AttributeDetail.h | 10 +-
mlir/lib/IR/BuiltinAttributes.cpp | 150 +++++-------------
mlir/lib/IR/BuiltinTypeInterfaces.cpp | 50 ++++++
mlir/lib/IR/BuiltinTypes.cpp | 121 ++++++++++++++
.../IR/dense-elements-type-interface.mlir | 25 ++-
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 8 +-
11 files changed, 385 insertions(+), 240 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index 5f14517d8dd71..c6e6e86d64b9c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -19,6 +19,26 @@ struct fltSemantics;
namespace mlir {
class FloatType;
class MLIRContext;
+
+namespace detail {
+/// Default implementation of DenseElementTypeInterface::getDenseElementBitSize.
+size_t getDefaultDenseElementBitSize(Type type);
+
+/// Default implementation of DenseElementTypeInterface::convertToAttribute.
+Attribute defaultConvertToAttribute(Type type, llvm::ArrayRef<char> rawData);
+
+/// Default implementation of DenseElementTypeInterface::convertFromAttribute.
+LogicalResult defaultConvertFromAttribute(Type type, Attribute attr,
+ llvm::SmallVectorImpl<char> &result);
+
+/// Read `bitWidth` bits from byte-aligned position in `rawData` and return as
+/// an APInt. Handles endianness correctly.
+llvm::APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth);
+
+/// Write `value` to byte-aligned position `bitPos` in `rawData`. Handles
+/// endianness correctly.
+void writeBits(char *rawData, size_t bitPos, llvm::APInt value);
+} // namespace detail
} // namespace mlir
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index f111bb5b32b49..6463f62b1923b 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -41,12 +41,83 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
}];
}
+//===----------------------------------------------------------------------===//
+// DenseElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This interface allows custom types to be used as element types in
+ DenseElementsAttr. Types implementing this interface define:
+
+ 1. The bit size of one element. In storage, every element is currently
+ padded to a full byte, i.e., it occupies ceilDiv(bitSize, 8) bytes.
+ 2. Helper methods for converting from/to Attribute. This assumes that there
+ is a corresponding attribute for each type that implements this
+ interface.
+
+ The helper methods for converting from/to Attribute are utilized when
+ parsing/printing IR or iterating over the elements via Attribute.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of bits required to store one element in dense
+ storage.
+
+ Note: The DenseElementsAttr infrastructure will automatically align
+ every element to a full byte in storage. This limitation could be lifted
+ in the future to support dense packing of non-byte-sized elements.
+ }],
+ /*retTy=*/"size_t",
+ /*methodName=*/"getDenseElementBitSize",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::getDefaultDenseElementBitSize($_type);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Attribute deserialization / attribute factory: Convert raw storage bytes
+ into an MLIR attribute. The size of `rawData` is
+ "ceilDiv(getDenseElementBitSize(), 8)".
+ }],
+ /*retTy=*/"::mlir::Attribute",
+ /*methodName=*/"convertToAttribute",
+ /*args=*/(ins "::llvm::ArrayRef<char>":$rawData),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::defaultConvertToAttribute($_type, rawData);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Attribute serialization: Convert an MLIR attribute into raw bytes and
+ append exactly "ceilDiv(getDenseElementBitSize(), 8)" bytes to
+ `result`. Return "failure" if the attribute is incompatible with this
+ element type.
+ }],
+ /*retTy=*/"::llvm::LogicalResult",
+ /*methodName=*/"convertFromAttribute",
+ /*args=*/(ins "::mlir::Attribute":$attr,
+ "::llvm::SmallVectorImpl<char>&":$result),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::defaultConvertFromAttribute($_type, attr, result);
+ }]
+ >,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// FloatTypeInterface
//===----------------------------------------------------------------------===//
def FloatTypeInterface : TypeInterface<"FloatType",
- [VectorElementTypeInterface]> {
+ [DenseElementTypeInterface, VectorElementTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
@@ -338,59 +409,4 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
}];
}
-//===----------------------------------------------------------------------===//
-// DenseElementTypeInterface
-//===----------------------------------------------------------------------===//
-
-def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
- let cppNamespace = "::mlir";
- let description = [{
- This interface allows custom types to be used as element types in
- DenseElementsAttr. Types implementing this interface define:
-
- 1. The bit size for element storage. Only full byte sizes are supported
- at the moment.
- 2. Helper methods for converting from/to Attribute. This assumes that there
- is a corresponding attribute for each type that implements this
- interface.
-
- The helper methods for converting from/to Attribute are utilized when
- parsing/printing IR or iterating over the elements via Attribute.
- }];
-
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Return the number of bits required to store one element in dense
- storage. This must be a compile-time constant for the type and must
- be a multiple of 8 (byte-aligned).
- }],
- /*retTy=*/"size_t",
- /*methodName=*/"getDenseElementBitSize",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Convert raw storage bytes to an attribute representing this element
- value. The `rawData` array contains exactly `getDenseElementBitSize()/8`
- bytes.
- }],
- /*retTy=*/"::mlir::Attribute",
- /*methodName=*/"convertToAttribute",
- /*args=*/(ins "::llvm::ArrayRef<char>":$rawData)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Convert an attribute to raw storage bytes. Appends exactly
- `getDenseElementBitSize()/8` bytes to `result`. Returns failure if the
- attribute is incompatible with this element type.
- }],
- /*retTy=*/"::llvm::LogicalResult",
- /*methodName=*/"convertFromAttribute",
- /*args=*/(ins "::mlir::Attribute":$attr,
- "::llvm::SmallVectorImpl<char>&":$result)
- >,
- ];
-}
-
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 08847dd11c685..e671f96f2d0f3 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -44,7 +44,10 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
// ComplexType
//===----------------------------------------------------------------------===//
-def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
+def Builtin_Complex : Builtin_Type<"Complex", "complex",
+ [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+ ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>
+ ]> {
let summary = "Complex number with a parameterized element type";
let description = [{
Syntax:
@@ -470,7 +473,9 @@ def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
//===----------------------------------------------------------------------===//
def Builtin_Index : Builtin_Type<"Index", "index",
- [VectorElementTypeInterface]> {
+ [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+ ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>,
+ VectorElementTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -501,7 +506,9 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//
def Builtin_Integer : Builtin_Type<"Integer", "integer",
- [VectorElementTypeInterface]> {
+ [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+ ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>,
+ VectorElementTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 09e5146c24e00..81cafb45abff2 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -957,12 +957,12 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
/// Try to parse a dense elements attribute with the type-first syntax.
/// Syntax: dense<TYPE : [ATTR, ATTR, ...]>
-/// This is used for element types implementing DenseElementTypeInterface.
+/// This syntax is used for types other than int, float, index and complex.
///
/// Returns:
-/// - failure() on parse error
-/// - Attribute() (null) if this is not the type-first syntax
-/// - A valid Attribute on success
+/// - "null" attribute if this is not the type-first syntax.
+/// - "failure" in case of a parse error.
+/// - A valid Attribute otherwise.
static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
// Try to parse an optional type. Skip l_paren because parseOptionalType
// would try to parse it as a tuple/function type, but '(' starts a complex
@@ -1010,7 +1010,11 @@ static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
// Parse the element attributes and convert to raw bytes.
SmallVector<char> rawData;
- size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+ // Storage is byte-aligned: align bit size up to next byte boundary. This
+ // limitation could be lifted in the future to support dense packing of
+ // non-byte-sized elements.
+ size_t bitSize = denseEltType.getDenseElementBitSize();
+ size_t byteSize = llvm::divideCeil(bitSize, static_cast<size_t>(CHAR_BIT));
// Helper to parse a single element.
auto parseSingleElement = [&]() -> ParseResult {
@@ -1032,7 +1036,7 @@ static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
return parseSingleElement();
// Non-leaf: expect a list with the correct number of elements.
- int64_t expectedCount = remainingShape[0];
+ int64_t expectedCount = remainingShape.front();
ArrayRef<int64_t> innerShape = remainingShape.drop_front();
int64_t actualCount = 0;
@@ -1094,7 +1098,6 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return nullptr;
// Try to parse the type-first syntax: dense<TYPE : [ATTR, ...]>
- // This is used for element types implementing DenseElementTypeInterface.
FailureOr<Attribute> typedResult =
parseDenseElementsAttrTyped(*this, attribLoc);
if (failed(typedResult))
@@ -1102,7 +1105,8 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
if (*typedResult)
return *typedResult;
- // Parse the literal data if necessary (old syntax: dense<LITERAL> : TYPE).
+ // Try to parse the literal-first syntax, which is the default format for
+ // int, float, index and complex element types.
TensorLiteralParser literalParser(*this);
if (!consumeIf(Token::greater)) {
if (literalParser.parse(/*allowHex=*/true) ||
@@ -1110,10 +1114,10 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return nullptr;
}
- auto literalType = parseElementsLiteralType(attribLoc, attrType);
- if (!literalType)
+ auto type = parseElementsLiteralType(attribLoc, attrType);
+ if (!type)
return nullptr;
- return literalParser.getAttr(attribLoc, literalType);
+ return literalParser.getAttr(attribLoc, type);
}
Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0897065849c25..b4f5a2b0ff67b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -507,15 +507,17 @@ class AsmPrinter::Impl {
/// Print a dense string elements attribute.
void printDenseStringElementsAttr(DenseStringElementsAttr attr);
- /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
- /// used instead of individual elements when the elements attr is large.
+ /// Print a dense elements attribute in the literal-first syntax. If
+ /// 'allowHex' is true, a hex string is used instead of individual elements
+ /// when the elements attr is large.
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex);
- /// Print a dense elements attribute using DenseElementTypeInterface.
- /// Uses the type-first syntax: dense<TYPE : [ATTR, ...]>
- void printDenseElementsAttrWithInterface(DenseElementsAttr attr,
- DenseElementType denseEltType);
+ /// Print a dense elements attribute using the type-first syntax and the
+ /// DenseElementTypeInterface, which provides the attribute printer for each
+ /// element.
+ void printTypeFirstDenseElementsAttr(DenseElementsAttr attr,
+ DenseElementType denseEltType);
/// Print a dense array attribute.
void printDenseArrayAttr(DenseArrayAttr attr);
@@ -2506,41 +2508,33 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
printSymbolReference(nestedRef.getValue(), os);
}
- } else if (auto denseEltAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) {
- // Check if the element type implements DenseElementTypeInterface.
- // If so, use the type-first syntax which embeds the type in the attribute.
- Type eltType = denseEltAttr.getElementType();
- if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
- if (printerFlags.shouldElideElementsAttr(denseEltAttr)) {
- printElidedElementsAttr(os);
+ } else if (auto intOrFpEltAttr =
+ llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
+ if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
+ printElidedElementsAttr(os);
+ } else {
+ os << "dense<";
+ // Check if the element type implements DenseElementTypeInterface and is
+ // not a built-in type. Built-in types (int, float, index, complex) use
+ // the existing printing format for backwards compatibility.
+ Type eltType = intOrFpEltAttr.getElementType();
+ if (isa<FloatType, IntegerType, IndexType, ComplexType>(eltType)) {
+ printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
} else {
- os << "dense<";
- printDenseElementsAttrWithInterface(denseEltAttr, denseEltType);
- os << '>';
+ printTypeFirstDenseElementsAttr(intOrFpEltAttr,
+ cast<DenseElementType>(eltType));
+ typeElision = AttrTypeElision::Must;
}
- // Type is embedded in the syntax, don't print it again.
- return;
+ os << '>';
}
- // Fall back to existing printing for built-in element types.
- if (auto intOrFpEltAttr =
- llvm::dyn_cast<DenseIntOrFPElementsAttr>(denseEltAttr)) {
- if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
- printElidedElementsAttr(os);
- } else {
- os << "dense<";
- printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
- os << '>';
- }
- } else if (auto strEltAttr =
- llvm::dyn_cast<DenseStringElementsAttr>(denseEltAttr)) {
- if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
- printElidedElementsAttr(os);
- } else {
- os << "dense<";
- printDenseStringElementsAttr(strEltAttr);
- os << '>';
- }
+ } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
+ if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
+ printElidedElementsAttr(os);
+ } else {
+ os << "dense<";
+ printDenseStringElementsAttr(strEltAttr);
+ os << '>';
}
} else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
@@ -2728,14 +2722,16 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
-void AsmPrinter::Impl::printDenseElementsAttrWithInterface(
+void AsmPrinter::Impl::printTypeFirstDenseElementsAttr(
DenseElementsAttr attr, DenseElementType denseEltType) {
// Print the type first: dense<TYPE : [ELEMENTS]>
printType(attr.getType());
os << " : ";
ArrayRef<char> rawData = attr.getRawData();
- size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+ // Storage is byte-aligned: align bit size up to next byte boundary.
+ size_t bitSize = denseEltType.getDenseElementBitSize();
+ size_t byteSize = llvm::divideCeil(bitSize, (size_t)CHAR_BIT);
// Print elements: convert raw bytes to attribute, then print attribute.
printDenseElementsAttrImpl(
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 9055d58c5fe5d..b6b3a0551079d 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -33,14 +33,12 @@ namespace detail {
/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
- // Check for DenseElementTypeInterface first.
+ // i1 is stored as a single bit (bit-packed storage).
+ if (eltType.isInteger(1))
+ return 1;
+ // Check for DenseElementTypeInterface.
if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
return denseEltType.getDenseElementBitSize();
- // Align the width for complex to 8 to make storage and interpretation easier.
- if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
- return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
- if (eltType.isIndex())
- return IndexType::kInternalStorageBitWidth;
return eltType.getIntOrFloatBitWidth();
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 90dc6d0129658..d9c5fd9acb811 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
@@ -543,7 +544,7 @@ static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
}
/// Writes value to the bit position `bitPos` in array `rawData`.
-static void writeBits(char *rawData, size_t bitPos, APInt value) {
+void mlir::detail::writeBits(char *rawData, size_t bitPos, APInt value) {
size_t bitWidth = value.getBitWidth();
// If the bitwidth is 1 we just toggle the specific bit.
@@ -569,7 +570,8 @@ static void writeBits(char *rawData, size_t bitPos, APInt value) {
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData`.
-static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
+APInt mlir::detail::readBits(const char *rawData, size_t bitPos,
+ size_t bitWidth) {
// Handle a boolean bit position.
if (bitWidth == 1)
return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
@@ -619,46 +621,27 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
Type eltTy = owner.getElementType();
- if (llvm::dyn_cast<IntegerType>(eltTy))
- return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
- if (llvm::isa<IndexType>(eltTy))
- return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
- if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) {
- IntElementIterator intIt(owner, index);
- FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
- return FloatAttr::get(eltTy, *floatIt);
- }
- if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) {
- auto complexEltTy = complexTy.getElementType();
- ComplexIntElementIterator complexIntIt(owner, index);
- if (llvm::isa<IntegerType>(complexEltTy)) {
- auto value = *complexIntIt;
- auto real = IntegerAttr::get(complexEltTy, value.real());
- auto imag = IntegerAttr::get(complexEltTy, value.imag());
- return ArrayAttr::get(complexTy.getContext(),
- ArrayRef<Attribute>{real, imag});
- }
- ComplexFloatElementIterator complexFloatIt(
- llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt);
- auto value = *complexFloatIt;
- auto real = FloatAttr::get(complexEltTy, value.real());
- auto imag = FloatAttr::get(complexEltTy, value.imag());
- return ArrayAttr::get(complexTy.getContext(),
- ArrayRef<Attribute>{real, imag});
+ // Handle i1 (boolean) specially - it's bit-packed and doesn't use interface.
+ if (eltTy.isInteger(1)) {
+ bool value = *BoolElementIterator(owner, index);
+ return IntegerAttr::get(eltTy, APInt(1, value));
}
+
+ // Handle strings specially.
if (llvm::isa<DenseStringElementsAttr>(owner)) {
ArrayRef<StringRef> vals = owner.getRawStringData();
return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
}
- // Check if the element type implements DenseElementTypeInterface.
- if (auto denseEltTy = llvm::dyn_cast<DenseElementType>(eltTy)) {
- ArrayRef<char> rawData = owner.getRawData();
- size_t byteSize = denseEltTy.getDenseElementBitSize() / CHAR_BIT;
- size_t offset = owner.isSplat() ? 0 : index * byteSize;
- return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize));
- }
- llvm_unreachable("unexpected element type");
+
+ // All other types should implement DenseElementTypeInterface.
+ auto denseEltTy = llvm::cast<DenseElementType>(eltTy);
+ ArrayRef<char> rawData = owner.getRawData();
+ // Storage is byte-aligned: align bit size up to next byte boundary.
+ size_t bitSize = denseEltTy.getDenseElementBitSize();
+ size_t byteSize = llvm::divideCeil(bitSize, CHAR_BIT);
+ size_t offset = owner.isSplat() ? 0 : index * byteSize;
+ return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize));
}
//===----------------------------------------------------------------------===//
@@ -920,94 +903,37 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
Type eltType = type.getElementType();
- // Take care complex type case first.
- if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) {
- if (complexType.getElementType().isIntOrIndex()) {
- SmallVector<std::complex<APInt>> complexValues;
- complexValues.reserve(values.size());
- for (Attribute attr : values) {
- assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
- auto arrayAttr = llvm::cast<ArrayAttr>(attr);
- assert(arrayAttr.size() == 2 && "expected 2 element for complex");
- auto attr0 = arrayAttr[0];
- auto attr1 = arrayAttr[1];
- complexValues.push_back(
- std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(),
- llvm::cast<IntegerAttr>(attr1).getValue()));
- }
- return DenseElementsAttr::get(type, complexValues);
- }
- // Must be float.
- SmallVector<std::complex<APFloat>> complexValues;
- complexValues.reserve(values.size());
- for (Attribute attr : values) {
- assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
- auto arrayAttr = llvm::cast<ArrayAttr>(attr);
- assert(arrayAttr.size() == 2 && "expected 2 element for complex");
- auto attr0 = arrayAttr[0];
- auto attr1 = arrayAttr[1];
- complexValues.push_back(
- std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(),
- llvm::cast<FloatAttr>(attr1).getValue()));
- }
- return DenseElementsAttr::get(type, complexValues);
- }
-
- // Check if the element type implements DenseElementTypeInterface.
- if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
- SmallVector<char> data;
- for (Attribute attr : values) {
- SmallVector<char> elementData;
- if (failed(denseEltType.convertFromAttribute(attr, elementData))) {
- llvm_unreachable("incompatible attribute for DenseElementType");
- }
- llvm::append_range(data, elementData);
- }
- return DenseIntOrFPElementsAttr::getRaw(type, data);
+ // Handle i1 (boolean) specially - it's bit-packed.
+ if (eltType.isInteger(1)) {
+ SmallVector<bool> boolValues;
+ boolValues.reserve(values.size());
+ for (Attribute attr : values)
+ boolValues.push_back(llvm::cast<IntegerAttr>(attr).getValue().isOne());
+ return get(type, boolValues);
}
- // If the element type is not based on int/float/index, assume it is a string
- // type.
- if (!eltType.isIntOrIndexOrFloat()) {
+ // Handle strings specially.
+ if (!llvm::isa<DenseElementType>(eltType)) {
SmallVector<StringRef, 8> stringValues;
stringValues.reserve(values.size());
for (Attribute attr : values) {
assert(llvm::isa<StringAttr>(attr) &&
- "expected string value for non integer/index/float element");
+ "expected string value for non-DenseElementType element");
stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
}
return get(type, stringValues);
}
- // Otherwise, get the raw storage width to use for the allocation.
- size_t bitWidth = getDenseElementBitWidth(eltType);
- size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
-
- // Compress the attribute values into a character buffer.
- SmallVector<char, 8> data(
- llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
- APInt intVal;
- for (unsigned i = 0, e = values.size(); i < e; ++i) {
- if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) {
- assert(floatAttr.getType() == eltType &&
- "expected float attribute type to equal element type");
- intVal = floatAttr.getValue().bitcastToAPInt();
- } else {
- auto intAttr = llvm::cast<IntegerAttr>(values[i]);
- assert(intAttr.getType() == eltType &&
- "expected integer attribute type to equal element type");
- intVal = intAttr.getValue();
- }
-
- assert(intVal.getBitWidth() == bitWidth &&
- "expected value to have same bitwidth as element type");
- writeBits(data.data(), i * storageBitWidth, intVal);
+ // All other types go through DenseElementTypeInterface.
+ auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType);
+ assert(denseEltType &&
+ "attempted to get DenseElementsAttr with unsupported element type");
+ SmallVector<char> data;
+ for (Attribute attr : values) {
+ LogicalResult result = denseEltType.convertFromAttribute(attr, data);
+ assert(succeeded(result) && "incompatible attribute for DenseElementType");
+ (void)result;
}
-
- // Handle the special encoding of splat of bool.
- if (values.size() == 1 && eltType.isInteger(1))
- data[0] = data[0] ? -1 : 0;
-
return DenseIntOrFPElementsAttr::getRaw(type, data);
}
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 2f063be3e7cd0..4db0ed7a0c80a 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -6,9 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/CheckedArithmetic.h"
+#include "llvm/Support/MathExtras.h"
+#include <climits>
using namespace mlir;
using namespace mlir::detail;
@@ -19,6 +22,53 @@ using namespace mlir::detail;
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
+//===----------------------------------------------------------------------===//
+// DenseElementTypeInterface default implementations
+//===----------------------------------------------------------------------===//
+
+size_t mlir::detail::getDefaultDenseElementBitSize(Type type) {
+ // TODO: This implementation should be defined on FloatType. However, due to
+ // TableGen limitations, a type interface cannot provide an implementation for
+ // an interface method from a base type interface.
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
+ llvm_unreachable("getDenseElementBitSize not implemented");
+ return floatType.getWidth();
+}
+
+Attribute mlir::detail::defaultConvertToAttribute(Type type,
+ ArrayRef<char> rawData) {
+ // TODO: This implementation should be defined on FloatType. However, due to
+ // TableGen limitations, a type interface cannot provide an implementation for
+ // an interface method from a base type interface.
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
+ llvm_unreachable("convertToAttribute not implemented");
+ APInt intVal = readBits(rawData.data(), /*bitPos=*/0, floatType.getWidth());
+ APFloat floatVal(floatType.getFloatSemantics(), intVal);
+ return FloatAttr::get(type, floatVal);
+}
+
+LogicalResult
+mlir::detail::defaultConvertFromAttribute(Type type, Attribute attr,
+ SmallVectorImpl<char> &result) {
+ // TODO: This implementation should be defined on FloatType. However, due to
+ // TableGen limitations, a type interface cannot provide an implementation for
+ // an interface method from a base type interface.
+ auto floatType = dyn_cast<FloatType>(type);
+ if (!floatType)
+ llvm_unreachable("convertFromAttribute not implemented");
+ auto floatAttr = dyn_cast<FloatAttr>(attr);
+ if (!floatAttr || floatAttr.getType() != type)
+ return failure();
+ size_t byteSize =
+ llvm::divideCeil(floatType.getWidth(), static_cast<unsigned>(CHAR_BIT));
+ size_t bitPos = result.size() * CHAR_BIT;
+ result.resize(result.size() + byteSize);
+ writeBits(result.data(), bitPos, floatAttr.getValue().bitcastToAPInt());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d0165db058683..7c0a75d82879b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -18,9 +18,11 @@
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CheckedArithmetic.h"
+#include <cstring>
using namespace mlir;
using namespace mlir::detail;
@@ -86,6 +88,125 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
+size_t IntegerType::getDenseElementBitSize() const {
+ // Return the actual bit width. Storage alignment is handled separately.
+ // Note: i1 is bit-packed and should be special-cased by the caller.
+ return getWidth();
+}
+
+Attribute IntegerType::convertToAttribute(ArrayRef<char> rawData) const {
+ APInt value = detail::readBits(rawData.data(), /*bitPos=*/0, getWidth());
+ return IntegerAttr::get(*this, value);
+}
+
+LogicalResult
+IntegerType::convertFromAttribute(Attribute attr,
+ SmallVectorImpl<char> &result) const {
+ auto intAttr = dyn_cast<IntegerAttr>(attr);
+ if (!intAttr || intAttr.getType() != *this)
+ return failure();
+
+ size_t byteSize = llvm::divideCeil(getDenseElementBitSize(), CHAR_BIT);
+ size_t bitPos = result.size() * CHAR_BIT;
+ result.resize(result.size() + byteSize);
+ detail::writeBits(result.data(), bitPos, intAttr.getValue());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Index Type
+//===----------------------------------------------------------------------===//
+
+size_t IndexType::getDenseElementBitSize() const {
+ return kInternalStorageBitWidth;
+}
+
+Attribute IndexType::convertToAttribute(ArrayRef<char> rawData) const {
+ APInt value =
+ detail::readBits(rawData.data(), /*bitPos=*/0, kInternalStorageBitWidth);
+ return IntegerAttr::get(*this, value);
+}
+
+LogicalResult
+IndexType::convertFromAttribute(Attribute attr,
+ SmallVectorImpl<char> &result) const {
+ auto intAttr = dyn_cast<IntegerAttr>(attr);
+ if (!intAttr || intAttr.getType() != *this)
+ return failure();
+
+ size_t byteSize = kInternalStorageBitWidth / CHAR_BIT;
+ size_t bitPos = result.size() * CHAR_BIT;
+ result.resize(result.size() + byteSize);
+ detail::writeBits(result.data(), bitPos, intAttr.getValue());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Complex Type
+//===----------------------------------------------------------------------===//
+
+size_t ComplexType::getDenseElementBitSize() const {
+  // Components are stored byte-aligned, as readBits/writeBits require.
+  return 2 * llvm::alignTo<CHAR_BIT>(getElementType().getIntOrFloatBitWidth());
+}
+
+Attribute ComplexType::convertToAttribute(ArrayRef<char> rawData) const {
+  Type eltType = getElementType();
+  // `width` is the component's own bit width; `stride` is its storage size.
+  unsigned width = eltType.getIntOrFloatBitWidth();
+  size_t stride = getDenseElementBitSize() / 2;
+  APInt realVal = detail::readBits(rawData.data(), /*bitPos=*/0, width);
+  APInt imagVal = detail::readBits(rawData.data(), stride, width);
+
+  Attribute real, imag;
+  if (isa<IntegerType>(eltType)) {
+    real = IntegerAttr::get(eltType, realVal);
+    imag = IntegerAttr::get(eltType, imagVal);
+  } else {
+    const auto &semantics = cast<FloatType>(eltType).getFloatSemantics();
+    real = FloatAttr::get(eltType, APFloat(semantics, realVal));
+    imag = FloatAttr::get(eltType, APFloat(semantics, imagVal));
+  }
+  return ArrayAttr::get(getContext(), {real, imag});
+}
+
+LogicalResult
+ComplexType::convertFromAttribute(Attribute attr,
+                                  SmallVectorImpl<char> &result) const {
+  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+  if (!arrayAttr || arrayAttr.size() != 2)
+    return failure();
+
+  // Extract the raw bits of both components before touching `result`, so a
+  // failed conversion leaves the output buffer unchanged.
+  Type eltType = getElementType();
+  APInt realVal, imagVal;
+  if (isa<IntegerType>(eltType)) {
+    auto realAttr = dyn_cast<IntegerAttr>(arrayAttr[0]);
+    auto imagAttr = dyn_cast<IntegerAttr>(arrayAttr[1]);
+    if (!realAttr || !imagAttr || realAttr.getType() != eltType ||
+        imagAttr.getType() != eltType)
+      return failure();
+    realVal = realAttr.getValue();
+    imagVal = imagAttr.getValue();
+  } else {
+    auto realAttr = dyn_cast<FloatAttr>(arrayAttr[0]);
+    auto imagAttr = dyn_cast<FloatAttr>(arrayAttr[1]);
+    if (!realAttr || !imagAttr || realAttr.getType() != eltType ||
+        imagAttr.getType() != eltType)
+      return failure();
+    realVal = realAttr.getValue().bitcastToAPInt();
+    imagVal = imagAttr.getValue().bitcastToAPInt();
+  }
+
+  size_t stride = getDenseElementBitSize() / 2;
+  size_t bitPos = result.size() * CHAR_BIT;
+  result.resize(result.size() + 2 * stride / CHAR_BIT);
+  detail::writeBits(result.data(), bitPos, realVal);
+  detail::writeBits(result.data(), bitPos + stride, imagVal);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Float Types
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir
index 8aa1386cfeb73..579a2ca3d1551 100644
--- a/mlir/test/IR/dense-elements-type-interface.mlir
+++ b/mlir/test/IR/dense-elements-type-interface.mlir
@@ -6,23 +6,32 @@
// CHECK-LABEL: func @dense_custom_element_type
func.func @dense_custom_element_type() {
- // The type is embedded in the dense attribute syntax, not printed separately.
- // CHECK: "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>}
- "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
+ // CHECK: "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>}
+ "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
return
}
// CHECK-LABEL: func @dense_custom_element_type_2d
func.func @dense_custom_element_type_2d() {
- // CHECK: "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
- "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> ()
+ // CHECK: "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
+ "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> ()
return
}
// CHECK-LABEL: func @dense_custom_element_splat
func.func @dense_custom_element_splat() {
- // A splat should be detected and stored efficiently
- // CHECK: "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
- "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> ()
+ // CHECK: "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
+ "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @dense_i32_1d
+func.func @dense_i32_1d() {
+  // The default assembly format for int, index, float, complex element types is
+  // the literal-first syntax. Such a dense elements attribute can be parsed
+  // with the type-first syntax, but it will come back with the literal-first
+  // syntax.
+  // CHECK: "test.dummy"() {attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
+  "test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index cfbadc6aa8a7a..08600ce713a17 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -513,12 +513,10 @@ def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> {
let hasCustomAssemblyFormat = 1;
}
-//===----------------------------------------------------------------------===//
-// Test type for DenseElementTypeInterface
-//===----------------------------------------------------------------------===//
-
def TestTypeDenseElement : Test_Type<"TestDenseElement",
- [DeclareTypeInterfaceMethods<DenseElementTypeInterface>]> {
+ [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+ ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>
+ ]> {
let mnemonic = "dense_element";
let description = [{
A test type that implements DenseElementTypeInterface to test dense
>From 3407cf96d65c8631172dce6bf352e2d51d7c67d1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 6 Feb 2026 16:32:49 +0000
Subject: [PATCH 3/5] extraTraitClassDeclaration to provide default FloatType
impls
---
mlir/include/mlir/IR/BuiltinTypeInterfaces.h | 21 +++++++-----
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 33 ++++++++++--------
mlir/lib/IR/BuiltinTypeInterfaces.cpp | 34 +++++--------------
3 files changed, 39 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index c6e6e86d64b9c..9425d554b427c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -21,15 +21,18 @@ class FloatType;
class MLIRContext;
namespace detail {
-/// Default implementation of DenseElementTypeInterface::getDenseElementBitSize.
-size_t getDefaultDenseElementBitSize(Type type);
-
-/// Default implementation of DenseElementTypeInterface::convertToAttribute.
-Attribute defaultConvertToAttribute(Type type, llvm::ArrayRef<char> rawData);
-
-/// Default implementation of DenseElementTypeInterface::convertFromAttribute.
-LogicalResult defaultConvertFromAttribute(Type type, Attribute attr,
- llvm::SmallVectorImpl<char> &result);
+/// Float type implementation of
+/// DenseElementTypeInterface::getDenseElementBitSize.
+size_t getFloatTypeDenseElementBitSize(Type type);
+
+/// Float type implementation of DenseElementTypeInterface::convertToAttribute.
+Attribute convertFloatTypeToAttribute(Type type, llvm::ArrayRef<char> rawData);
+
+/// Float type implementation of
+/// DenseElementTypeInterface::convertFromAttribute.
+LogicalResult
+convertFloatTypeFromAttribute(Type type, Attribute attr,
+ llvm::SmallVectorImpl<char> &result);
/// Read `bitWidth` bits from byte-aligned position in `rawData` and return as
/// an APInt. Handles endianness correctly.
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 6463f62b1923b..37d28c5c31a43 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -73,11 +73,7 @@ def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
}],
/*retTy=*/"size_t",
/*methodName=*/"getDenseElementBitSize",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return ::mlir::detail::getDefaultDenseElementBitSize($_type);
- }]
+ /*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
@@ -87,11 +83,7 @@ def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
}],
/*retTy=*/"::mlir::Attribute",
/*methodName=*/"convertToAttribute",
- /*args=*/(ins "::llvm::ArrayRef<char>":$rawData),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return ::mlir::detail::defaultConvertToAttribute($_type, rawData);
- }]
+ /*args=*/(ins "::llvm::ArrayRef<char>":$rawData)
>,
InterfaceMethod<
/*desc=*/[{
@@ -103,11 +95,7 @@ def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
/*retTy=*/"::llvm::LogicalResult",
/*methodName=*/"convertFromAttribute",
/*args=*/(ins "::mlir::Attribute":$attr,
- "::llvm::SmallVectorImpl<char>&":$result),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return ::mlir::detail::defaultConvertFromAttribute($_type, attr, result);
- }]
+ "::llvm::SmallVectorImpl<char>&":$result)
>,
];
}
@@ -154,6 +142,21 @@ def FloatTypeInterface : TypeInterface<"FloatType",
/// The width includes the integer bit.
unsigned getFPMantissaWidth();
}];
+
+ let extraTraitClassDeclaration = [{
+ /// DenseElementTypeInterface implementations for float types.
+ size_t getDenseElementBitSize() const {
+ return ::mlir::detail::getFloatTypeDenseElementBitSize($_type);
+ }
+ ::mlir::Attribute convertToAttribute(::llvm::ArrayRef<char> rawData) const {
+ return ::mlir::detail::convertFloatTypeToAttribute($_type, rawData);
+ }
+ ::llvm::LogicalResult
+ convertFromAttribute(::mlir::Attribute attr,
+ ::llvm::SmallVectorImpl<char> &result) const {
+ return ::mlir::detail::convertFloatTypeFromAttribute($_type, attr, result);
+ }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 4db0ed7a0c80a..29303d95eb003 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -23,41 +23,25 @@ using namespace mlir::detail;
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
-// DenseElementTypeInterface default implementations
+// DenseElementTypeInterface implementations for float types
//===----------------------------------------------------------------------===//
-size_t mlir::detail::getDefaultDenseElementBitSize(Type type) {
- // TODO: This implementation should be defined on FloatType. However, due to
- // TableGen limitations, a type interface cannot provide an implementation for
- // an interface method from a base type interface.
- auto floatType = dyn_cast<FloatType>(type);
- if (!floatType)
- llvm_unreachable("getDenseElementBitSize not implemented");
- return floatType.getWidth();
+size_t mlir::detail::getFloatTypeDenseElementBitSize(Type type) {
+ return cast<FloatType>(type).getWidth();
}
-Attribute mlir::detail::defaultConvertToAttribute(Type type,
- ArrayRef<char> rawData) {
- // TODO: This implementation should be defined on FloatType. However, due to
- // TableGen limitations, a type interface cannot provide an implementation for
- // an interface method from a base type interface.
- auto floatType = dyn_cast<FloatType>(type);
- if (!floatType)
- llvm_unreachable("convertToAttribute not implemented");
+Attribute mlir::detail::convertFloatTypeToAttribute(Type type,
+ ArrayRef<char> rawData) {
+ auto floatType = cast<FloatType>(type);
APInt intVal = readBits(rawData.data(), /*bitPos=*/0, floatType.getWidth());
APFloat floatVal(floatType.getFloatSemantics(), intVal);
return FloatAttr::get(type, floatVal);
}
LogicalResult
-mlir::detail::defaultConvertFromAttribute(Type type, Attribute attr,
- SmallVectorImpl<char> &result) {
- // TODO: This implementation should be defined on FloatType. However, due to
- // TableGen limitations, a type interface cannot provide an implementation for
- // an interface method from a base type interface.
- auto floatType = dyn_cast<FloatType>(type);
- if (!floatType)
- llvm_unreachable("convertFromAttribute not implemented");
+mlir::detail::convertFloatTypeFromAttribute(Type type, Attribute attr,
+ SmallVectorImpl<char> &result) {
+ auto floatType = cast<FloatType>(type);
auto floatAttr = dyn_cast<FloatAttr>(attr);
if (!floatAttr || floatAttr.getType() != type)
return failure();
>From f709e8892a2a5fc3c5efc995c33625cdc0bf3703 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 6 Feb 2026 17:49:29 +0000
Subject: [PATCH 4/5] address comments
---
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 3 +-
mlir/lib/IR/AttributeDetail.h | 2 +-
mlir/lib/IR/BuiltinTypes.cpp | 118 +++++++-----------
3 files changed, 44 insertions(+), 79 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 37d28c5c31a43..93c8c0694b467 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -51,8 +51,7 @@ def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
This interface allows custom types to be used as element types in
DenseElementsAttr. Types implementing this interface define:
- 1. The bit size for element storage. Only full byte sizes are supported
- at the moment.
+ 1. The bit size for element storage.
2. Helper methods for converting from/to Attribute. This assumes that there
is a corresponding attribute for each type that implements this
interface.
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index b6b3a0551079d..7af5c8cd9191d 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -39,7 +39,7 @@ inline size_t getDenseElementBitWidth(Type eltType) {
// Check for DenseElementTypeInterface.
if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
return denseEltType.getDenseElementBitSize();
- return eltType.getIntOrFloatBitWidth();
+ llvm_unreachable("unsupported element type");
}
/// An attribute representing a reference to a dense vector or tensor object.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 7c0a75d82879b..0f0025bd5bfc0 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -61,6 +61,39 @@ LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+size_t ComplexType::getDenseElementBitSize() const {
+ auto elemTy = cast<DenseElementType>(getElementType());
+ return llvm::alignTo<8>(elemTy.getDenseElementBitSize()) * 2;
+}
+
+Attribute ComplexType::convertToAttribute(ArrayRef<char> rawData) const {
+ auto elemTy = cast<DenseElementType>(getElementType());
+ size_t singleElementBytes =
+ llvm::alignTo<8>(elemTy.getDenseElementBitSize()) / 8;
+ Attribute real =
+ elemTy.convertToAttribute(rawData.take_front(singleElementBytes));
+ Attribute imag =
+ elemTy.convertToAttribute(rawData.take_back(singleElementBytes));
+ return ArrayAttr::get(getContext(), {real, imag});
+}
+
+LogicalResult
+ComplexType::convertFromAttribute(Attribute attr,
+ SmallVectorImpl<char> &result) const {
+ auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+ if (!arrayAttr || arrayAttr.size() != 2)
+ return failure();
+ auto elemTy = cast<DenseElementType>(getElementType());
+ SmallVector<char> realData, imagData;
+ if (failed(elemTy.convertFromAttribute(arrayAttr[0], realData)))
+ return failure();
+ if (failed(elemTy.convertFromAttribute(arrayAttr[1], imagData)))
+ return failure();
+ result.append(realData);
+ result.append(imagData);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
@@ -99,17 +132,20 @@ Attribute IntegerType::convertToAttribute(ArrayRef<char> rawData) const {
return IntegerAttr::get(*this, value);
}
+static void writeAPIntToVector(APInt apInt, SmallVectorImpl<char> &result) {
+ size_t byteSize = llvm::divideCeil(apInt.getBitWidth(), CHAR_BIT);
+ size_t bitPos = result.size() * CHAR_BIT;
+ result.resize(result.size() + byteSize);
+ detail::writeBits(result.data(), bitPos, apInt);
+}
+
LogicalResult
IntegerType::convertFromAttribute(Attribute attr,
SmallVectorImpl<char> &result) const {
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr || intAttr.getType() != *this)
return failure();
-
- size_t byteSize = llvm::divideCeil(getDenseElementBitSize(), CHAR_BIT);
- size_t bitPos = result.size() * CHAR_BIT;
- result.resize(result.size() + byteSize);
- detail::writeBits(result.data(), bitPos, intAttr.getValue());
+ writeAPIntToVector(intAttr.getValue(), result);
return success();
}
@@ -133,77 +169,7 @@ IndexType::convertFromAttribute(Attribute attr,
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr || intAttr.getType() != *this)
return failure();
-
- size_t byteSize = kInternalStorageBitWidth / CHAR_BIT;
- size_t bitPos = result.size() * CHAR_BIT;
- result.resize(result.size() + byteSize);
- detail::writeBits(result.data(), bitPos, intAttr.getValue());
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Complex Type
-//===----------------------------------------------------------------------===//
-
-size_t ComplexType::getDenseElementBitSize() const {
- Type eltType = getElementType();
- if (auto intType = dyn_cast<IntegerType>(eltType))
- return 2 * llvm::alignTo<CHAR_BIT>(intType.getWidth());
- return 2 * cast<FloatType>(eltType).getWidth();
-}
-
-Attribute ComplexType::convertToAttribute(ArrayRef<char> rawData) const {
- Type eltType = getElementType();
- size_t bitWidth = getDenseElementBitSize() / 2;
-
- if (auto intType = dyn_cast<IntegerType>(eltType)) {
- APInt realVal = detail::readBits(rawData.data(), /*bitPos=*/0, bitWidth);
- APInt imagVal = detail::readBits(rawData.data(), bitWidth, bitWidth);
- auto real = IntegerAttr::get(eltType, realVal);
- auto imag = IntegerAttr::get(eltType, imagVal);
- return ArrayAttr::get(getContext(), {real, imag});
- }
-
- auto floatType = cast<FloatType>(eltType);
- const auto &semantics = floatType.getFloatSemantics();
- APInt realVal = detail::readBits(rawData.data(), /*bitPos=*/0, bitWidth);
- APInt imagVal = detail::readBits(rawData.data(), bitWidth, bitWidth);
- auto real = FloatAttr::get(eltType, APFloat(semantics, realVal));
- auto imag = FloatAttr::get(eltType, APFloat(semantics, imagVal));
- return ArrayAttr::get(getContext(), {real, imag});
-}
-
-LogicalResult
-ComplexType::convertFromAttribute(Attribute attr,
- SmallVectorImpl<char> &result) const {
- auto arrayAttr = dyn_cast<ArrayAttr>(attr);
- if (!arrayAttr || arrayAttr.size() != 2)
- return failure();
-
- Type eltType = getElementType();
- size_t bitWidth = getDenseElementBitSize() / 2;
- size_t byteSize = llvm::divideCeil(bitWidth, (size_t)CHAR_BIT);
- size_t bitPos = result.size() * CHAR_BIT;
- result.resize(result.size() + 2 * byteSize);
-
- if (auto intType = dyn_cast<IntegerType>(eltType)) {
- auto realAttr = dyn_cast<IntegerAttr>(arrayAttr[0]);
- auto imagAttr = dyn_cast<IntegerAttr>(arrayAttr[1]);
- if (!realAttr || !imagAttr)
- return failure();
- detail::writeBits(result.data(), bitPos, realAttr.getValue());
- detail::writeBits(result.data(), bitPos + bitWidth, imagAttr.getValue());
- return success();
- }
-
- auto realAttr = dyn_cast<FloatAttr>(arrayAttr[0]);
- auto imagAttr = dyn_cast<FloatAttr>(arrayAttr[1]);
- if (!realAttr || !imagAttr)
- return failure();
- detail::writeBits(result.data(), bitPos,
- realAttr.getValue().bitcastToAPInt());
- detail::writeBits(result.data(), bitPos + bitWidth,
- imagAttr.getValue().bitcastToAPInt());
+ writeAPIntToVector(intAttr.getValue(), result);
return success();
}
>From 3463c3f3fbd52e80c301a3905d929fe615ae9d6c Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 7 Feb 2026 09:31:26 +0000
Subject: [PATCH 5/5] simplify parser
---
mlir/lib/AsmParser/AttributeParser.cpp | 69 ++++++++-----------
.../IR/dense-elements-type-interface.mlir | 48 ++++++++++++-
2 files changed, 75 insertions(+), 42 deletions(-)
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 81cafb45abff2..01db7505b6665 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -964,57 +964,41 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
/// - "failure" in case of a parse error.
/// - A valid Attribute otherwise.
static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
- // Try to parse an optional type. Skip l_paren because parseOptionalType
- // would try to parse it as a tuple/function type, but '(' starts a complex
- // literal like (0, 1) in dense syntax.
- Type type;
- OptionalParseResult typeResult = p.getToken().is(Token::l_paren)
- ? OptionalParseResult(std::nullopt)
- : p.parseOptionalType(type);
- if (!typeResult.has_value())
- return Attribute(); // Not type-first syntax.
-
- if (failed(*typeResult))
- return failure(); // Type parse error.
-
- // We parsed a type. Check for ':' to confirm type-first syntax.
- if (!p.getToken().is(Token::colon)) {
- p.emitError(loc, "expected ':' after type in dense attribute");
+ // Skip l_paren because "parseType" would try to parse it as a tuple/function
+ // type, but '(' starts a complex literal like in the literal-first syntax.
+ if (p.getToken().is(Token::l_paren))
+ return Attribute();
+
+ // Parse type and validate that it's a shaped type.
+ auto typeLoc = p.getToken().getLoc();
+ Type type = p.parseType();
+ if (!type)
return failure();
- }
-
- // Validate the type.
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType) {
- p.emitError(loc, "expected a shaped type for dense elements");
+ p.emitError(typeLoc, "expected a shaped type for dense elements");
return failure();
}
-
if (!shapedType.hasStaticShape()) {
- p.emitError(loc, "dense elements type must have static shape");
+ p.emitError(typeLoc, "dense elements type must have static shape");
return failure();
}
// Check that the element type implements DenseElementTypeInterface.
auto denseEltType = dyn_cast<DenseElementType>(shapedType.getElementType());
if (!denseEltType) {
- p.emitError(loc, "element type must implement DenseElementTypeInterface "
- "for type-first dense syntax");
+ p.emitError(typeLoc,
+ "element type must implement DenseElementTypeInterface "
+ "for type-first dense syntax");
return failure();
}
- // Consume the ':' that separates the type from the element list.
- p.consumeToken(Token::colon);
-
- ArrayRef<int64_t> shape = shapedType.getShape();
+ // Parse colon.
+ if (p.parseToken(Token::colon, "expected ':' after type in dense attribute"))
+ return failure();
// Parse the element attributes and convert to raw bytes.
SmallVector<char> rawData;
- // Storage is byte-aligned: align bit size up to next byte boundary. This
- // limitation could be lifted in the future to support dense packing of
- // non-byte-sized elements.
- size_t bitSize = denseEltType.getDenseElementBitSize();
- size_t byteSize = llvm::divideCeil(bitSize, static_cast<size_t>(CHAR_BIT));
// Helper to parse a single element.
auto parseSingleElement = [&]() -> ParseResult {
@@ -1065,22 +1049,25 @@ static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
if (parseSingleElement())
return failure();
isSplat = shapedType.getNumElements() != 1;
- } else if (shape.empty()) {
+ } else if (shapedType.getShape().empty()) {
// Scalar type shouldn't have a list.
p.emitError(loc, "expected single element for scalar type, got list");
return failure();
} else {
// Parse structured literal matching the shape.
- if (parseElements(shape))
+ if (parseElements(shapedType.getShape()))
return failure();
}
- // Verify element count (should match unless it's a splat).
- int64_t numElements = shapedType.getNumElements();
- if (!isSplat && rawData.size() != byteSize * numElements) {
- p.emitError(loc) << "parsed " << (rawData.size() / byteSize)
- << " elements, but type expects " << numElements;
- return failure();
+ if (!isSplat) {
+ // Safety check to protect against incorrect interface implementations.
+ // Storage is byte-aligned: each element starts at the beginning of a byte.
+ // (This limitation could be lifted in the future to support dense packing
+ // of non-byte-sized elements.)
+ size_t bitSize = denseEltType.getDenseElementBitSize();
+ size_t byteSize = llvm::divideCeil(bitSize, static_cast<size_t>(CHAR_BIT));
+ assert(rawData.size() == byteSize * shapedType.getNumElements() &&
+ "incorrect number of bytes in result buffer");
}
if (p.parseToken(Token::greater, "expected '>' to close dense attribute"))
diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir
index 579a2ca3d1551..8749e562087c2 100644
--- a/mlir/test/IR/dense-elements-type-interface.mlir
+++ b/mlir/test/IR/dense-elements-type-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
// Test dense elements attribute with custom element type using DenseElementTypeInterface.
// Uses the new type-first syntax: dense<TYPE : [ATTR, ...]>
@@ -11,6 +11,8 @@ func.func @dense_custom_element_type() {
return
}
+// -----
+
// CHECK-LABEL: func @dense_custom_element_type_2d
func.func @dense_custom_element_type_2d() {
// CHECK: "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
@@ -18,6 +20,8 @@ func.func @dense_custom_element_type_2d() {
return
}
+// -----
+
// CHECK-LABEL: func @dense_custom_element_splat
func.func @dense_custom_element_splat() {
// CHECK: "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
@@ -25,6 +29,8 @@ func.func @dense_custom_element_splat() {
return
}
+// -----
+
// CHECK-LABEL func @dense_i32_1d
func.func @dense_i32_1d() {
// The default assembly format for int, index, float, complex element types is
@@ -35,3 +41,43 @@ func.func @dense_i32_1d() {
"test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
return
}
+
+// -----
+
+func.func @invalid_element() {
+ // expected-error @+1 {{expected attribute value}}
+ "test.dummy"() {attr = dense<tensor<3xi32> : [foo]>} : () -> ()
+ return
+}
+
+// -----
+
+func.func @incompatible_attribute() {
+ // expected-error @+1 {{incompatible attribute for element type}}
+ "test.dummy"() {attr = dense<tensor<3xi32> : ["foo"]>} : () -> ()
+ return
+}
+
+// -----
+
+func.func @shape_mismatch() {
+ // expected-error @+1 {{expected 3 elements in dimension, got 2}}
+ "test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32]>} : () -> ()
+ return
+}
+
+// -----
+
+func.func @dynamic_shape() {
+ // expected-error @+1 {{dense elements type must have static shape}}
+ "test.dummy"() {attr = dense<tensor<?xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
+ return
+}
+
+// -----
+
+func.func @invalid_type() {
+ // expected-error @+1 {{expected a shaped type for dense elements}}
+ "test.dummy"() {attr = dense<i32 : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
+ return
+}
More information about the Mlir-commits
mailing list