[Mlir-commits] [mlir] [mlir][WIP] `DenseElementsAttr` generalized (PR #179122)

Matthias Springer llvmlistbot at llvm.org
Thu Feb 5 07:00:15 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/179122

>From e199909fb8562064d09595fc9c3ffa9fd569935d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 1 Feb 2026 17:41:37 +0000
Subject: [PATCH 1/2] [mlir][WIP] `DenseElementsAttr` generalized

---
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td |  55 +++++++
 mlir/lib/AsmParser/AttributeParser.cpp        | 149 +++++++++++++++++-
 mlir/lib/IR/AsmPrinter.cpp                    |  72 +++++++--
 mlir/lib/IR/AttributeDetail.h                 |   4 +
 mlir/lib/IR/BuiltinAttributes.cpp             |  20 +++
 mlir/lib/IR/BuiltinTypes.cpp                  |   1 +
 .../IR/dense-elements-type-interface.mlir     |  28 ++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  14 ++
 mlir/test/lib/Dialect/Test/TestTypes.cpp      |  28 ++++
 mlir/test/lib/Dialect/Test/TestTypes.h        |   1 +
 10 files changed, 353 insertions(+), 19 deletions(-)
 create mode 100644 mlir/test/IR/dense-elements-type-interface.mlir

diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 9ef08b7020b99..f111bb5b32b49 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -338,4 +338,59 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// DenseElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface allows custom types to be used as element types in
+    DenseElementsAttr. Types implementing this interface define:
+
+    1. The bit size for element storage. Only full byte sizes are supported
+       at the moment.
+    2. Helper methods for converting from/to Attribute. This assumes that there
+       is a corresponding attribute for each type that implements this
+       interface.
+
+    The helper methods for converting from/to Attribute are utilized when
+    parsing/printing IR or iterating over the elements via Attribute.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of bits required to store one element in dense
+        storage. This must be a compile-time constant for the type and must
+        be a multiple of 8 (byte-aligned).
+      }],
+      /*retTy=*/"size_t",
+      /*methodName=*/"getDenseElementBitSize",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Convert raw storage bytes to an attribute representing this element
+        value. The `rawData` array contains exactly `getDenseElementBitSize()/8`
+        bytes.
+      }],
+      /*retTy=*/"::mlir::Attribute",
+      /*methodName=*/"convertToAttribute",
+      /*args=*/(ins "::llvm::ArrayRef<char>":$rawData)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Convert an attribute to raw storage bytes. Appends exactly
+        `getDenseElementBitSize()/8` bytes to `result`. Returns failure if the
+        attribute is incompatible with this element type.
+      }],
+      /*retTy=*/"::llvm::LogicalResult",
+      /*methodName=*/"convertFromAttribute",
+      /*args=*/(ins "::mlir::Attribute":$attr,
+                    "::llvm::SmallVectorImpl<char>&":$result)
+    >,
+  ];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 519609a38be6e..09e5146c24e00 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
@@ -954,6 +955,137 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
   return eltParser.getAttr();
 }
 
+/// Try to parse a dense elements attribute with the type-first syntax.
+/// Syntax: dense<TYPE : [ATTR, ATTR, ...]>
+/// This is used for element types implementing DenseElementTypeInterface.
+///
+/// Returns:
+///   - failure() on parse error
+///   - Attribute() (null) if this is not the type-first syntax
+///   - A valid Attribute on success
+static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
+  // Try to parse an optional type. Skip l_paren because parseOptionalType
+  // would try to parse it as a tuple/function type, but '(' starts a complex
+  // literal like (0, 1) in dense syntax.
+  Type type;
+  OptionalParseResult typeResult = p.getToken().is(Token::l_paren)
+                                       ? OptionalParseResult(std::nullopt)
+                                       : p.parseOptionalType(type);
+  if (!typeResult.has_value())
+    return Attribute(); // Not type-first syntax.
+
+  if (failed(*typeResult))
+    return failure(); // Type parse error.
+
+  // We parsed a type. Check for ':' to confirm type-first syntax.
+  if (!p.getToken().is(Token::colon)) {
+    p.emitError(loc, "expected ':' after type in dense attribute");
+    return failure();
+  }
+
+  // Validate the type.
+  auto shapedType = dyn_cast<ShapedType>(type);
+  if (!shapedType) {
+    p.emitError(loc, "expected a shaped type for dense elements");
+    return failure();
+  }
+
+  if (!shapedType.hasStaticShape()) {
+    p.emitError(loc, "dense elements type must have static shape");
+    return failure();
+  }
+
+  // Check that the element type implements DenseElementTypeInterface.
+  auto denseEltType = dyn_cast<DenseElementType>(shapedType.getElementType());
+  if (!denseEltType) {
+    p.emitError(loc, "element type must implement DenseElementTypeInterface "
+                     "for type-first dense syntax");
+    return failure();
+  }
+
+  // Consume the ':' that separates the type from the element list.
+  p.consumeToken(Token::colon);
+
+  ArrayRef<int64_t> shape = shapedType.getShape();
+
+  // Parse the element attributes and convert to raw bytes.
+  SmallVector<char> rawData;
+  size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+
+  // Helper to parse a single element.
+  auto parseSingleElement = [&]() -> ParseResult {
+    Attribute elemAttr = p.parseAttribute();
+    if (!elemAttr)
+      return failure();
+    if (failed(denseEltType.convertFromAttribute(elemAttr, rawData))) {
+      p.emitError("incompatible attribute for element type");
+      return failure();
+    }
+    return success();
+  };
+
+  // Recursively parse elements matching the expected shape.
+  std::function<ParseResult(ArrayRef<int64_t>)> parseElements;
+  parseElements = [&](ArrayRef<int64_t> remainingShape) -> ParseResult {
+    // Leaf: parse a single element.
+    if (remainingShape.empty())
+      return parseSingleElement();
+
+    // Non-leaf: expect a list with the correct number of elements.
+    int64_t expectedCount = remainingShape[0];
+    ArrayRef<int64_t> innerShape = remainingShape.drop_front();
+    int64_t actualCount = 0;
+
+    auto parseOne = [&]() -> ParseResult {
+      if (parseElements(innerShape))
+        return failure();
+      ++actualCount;
+      return success();
+    };
+
+    if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOne))
+      return failure();
+
+    if (actualCount != expectedCount) {
+      p.emitError() << "expected " << expectedCount
+                    << " elements in dimension, got " << actualCount;
+      return failure();
+    }
+    return success();
+  };
+
+  // Check for splat (single element for the whole tensor).
+  bool isSplat = false;
+  if (!p.getToken().is(Token::l_square)) {
+    // Single element - parse as splat.
+    if (parseSingleElement())
+      return failure();
+    isSplat = shapedType.getNumElements() != 1;
+  } else if (shape.empty()) {
+    // Scalar type shouldn't have a list.
+    p.emitError(loc, "expected single element for scalar type, got list");
+    return failure();
+  } else {
+    // Parse structured literal matching the shape.
+    if (parseElements(shape))
+      return failure();
+  }
+
+  // Verify element count (should match unless it's a splat).
+  int64_t numElements = shapedType.getNumElements();
+  if (!isSplat && rawData.size() != byteSize * numElements) {
+    p.emitError(loc) << "parsed " << (rawData.size() / byteSize)
+                     << " elements, but type expects " << numElements;
+    return failure();
+  }
+
+  if (p.parseToken(Token::greater, "expected '>' to close dense attribute"))
+    return failure();
+
+  // Create the attribute from raw buffer.
+  return DenseElementsAttr::getFromRawBuffer(shapedType, rawData);
+}
+
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
   auto attribLoc = getToken().getLoc();
@@ -961,7 +1093,16 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
   if (parseToken(Token::less, "expected '<' after 'dense'"))
     return nullptr;
 
-  // Parse the literal data if necessary.
+  // Try to parse the type-first syntax: dense<TYPE : [ATTR, ...]>
+  // This is used for element types implementing DenseElementTypeInterface.
+  FailureOr<Attribute> typedResult =
+      parseDenseElementsAttrTyped(*this, attribLoc);
+  if (failed(typedResult))
+    return nullptr;
+  if (*typedResult)
+    return *typedResult;
+
+  // Parse the literal data if necessary (old syntax: dense<LITERAL> : TYPE).
   TensorLiteralParser literalParser(*this);
   if (!consumeIf(Token::greater)) {
     if (literalParser.parse(/*allowHex=*/true) ||
@@ -969,10 +1110,10 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
       return nullptr;
   }
 
-  auto type = parseElementsLiteralType(attribLoc, attrType);
-  if (!type)
+  auto literalType = parseElementsLiteralType(attribLoc, attrType);
+  if (!literalType)
     return nullptr;
-  return literalParser.getAttr(attribLoc, type);
+  return literalParser.getAttr(attribLoc, literalType);
 }
 
 Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 81455699421cc..0897065849c25 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -512,6 +512,11 @@ class AsmPrinter::Impl {
   void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                      bool allowHex);
 
+  /// Print a dense elements attribute using DenseElementTypeInterface.
+  /// Uses the type-first syntax: dense<TYPE : [ATTR, ...]>
+  void printDenseElementsAttrWithInterface(DenseElementsAttr attr,
+                                           DenseElementType denseEltType);
+
   /// Print a dense array attribute.
   void printDenseArrayAttr(DenseArrayAttr attr);
 
@@ -2501,23 +2506,41 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
       printSymbolReference(nestedRef.getValue(), os);
     }
 
-  } else if (auto intOrFpEltAttr =
-                 llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
-    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
-      printElidedElementsAttr(os);
-    } else {
-      os << "dense<";
-      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
-      os << '>';
+  } else if (auto denseEltAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) {
+    // Check if the element type implements DenseElementTypeInterface.
+    // If so, use the type-first syntax which embeds the type in the attribute.
+    Type eltType = denseEltAttr.getElementType();
+    if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
+      if (printerFlags.shouldElideElementsAttr(denseEltAttr)) {
+        printElidedElementsAttr(os);
+      } else {
+        os << "dense<";
+        printDenseElementsAttrWithInterface(denseEltAttr, denseEltType);
+        os << '>';
+      }
+      // Type is embedded in the syntax, don't print it again.
+      return;
     }
 
-  } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
-    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
-      printElidedElementsAttr(os);
-    } else {
-      os << "dense<";
-      printDenseStringElementsAttr(strEltAttr);
-      os << '>';
+    // Fall back to existing printing for built-in element types.
+    if (auto intOrFpEltAttr =
+            llvm::dyn_cast<DenseIntOrFPElementsAttr>(denseEltAttr)) {
+      if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
+        printElidedElementsAttr(os);
+      } else {
+        os << "dense<";
+        printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
+        os << '>';
+      }
+    } else if (auto strEltAttr =
+                   llvm::dyn_cast<DenseStringElementsAttr>(denseEltAttr)) {
+      if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
+        printElidedElementsAttr(os);
+      } else {
+        os << "dense<";
+        printDenseStringElementsAttr(strEltAttr);
+        os << '>';
+      }
     }
 
   } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
@@ -2705,6 +2728,25 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
 }
 
+void AsmPrinter::Impl::printDenseElementsAttrWithInterface(
+    DenseElementsAttr attr, DenseElementType denseEltType) {
+  // Print the type first: dense<TYPE : [ELEMENTS]>
+  printType(attr.getType());
+  os << " : ";
+
+  ArrayRef<char> rawData = attr.getRawData();
+  size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+
+  // Print elements: convert raw bytes to attribute, then print attribute.
+  printDenseElementsAttrImpl(
+      attr.isSplat(), attr.getType(), os, [&](unsigned index) {
+        size_t offset = attr.isSplat() ? 0 : index * byteSize;
+        ArrayRef<char> elemData = rawData.slice(offset, byteSize);
+        Attribute elemAttr = denseEltType.convertToAttribute(elemData);
+        printAttributeImpl(elemAttr);
+      });
+}
+
 void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
   Type type = attr.getElementType();
   unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index cb9d21bf3e611..9055d58c5fe5d 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/AttributeSupport.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
@@ -32,6 +33,9 @@ namespace detail {
 
 /// Return the bit width which DenseElementsAttr should use for this type.
 inline size_t getDenseElementBitWidth(Type eltType) {
+  // Check for DenseElementTypeInterface first.
+  if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
+    return denseEltType.getDenseElementBitSize();
   // Align the width for complex to 8 to make storage and interpretation easier.
   if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
     return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 6f880f810d651..90dc6d0129658 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -651,6 +651,13 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
     ArrayRef<StringRef> vals = owner.getRawStringData();
     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
   }
+  // Check if the element type implements DenseElementTypeInterface.
+  if (auto denseEltTy = llvm::dyn_cast<DenseElementType>(eltTy)) {
+    ArrayRef<char> rawData = owner.getRawData();
+    size_t byteSize = denseEltTy.getDenseElementBitSize() / CHAR_BIT;
+    size_t offset = owner.isSplat() ? 0 : index * byteSize;
+    return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize));
+  }
   llvm_unreachable("unexpected element type");
 }
 
@@ -946,6 +953,19 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
     return DenseElementsAttr::get(type, complexValues);
   }
 
+  // Check if the element type implements DenseElementTypeInterface.
+  if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
+    SmallVector<char> data;
+    for (Attribute attr : values) {
+      SmallVector<char> elementData;
+      if (failed(denseEltType.convertFromAttribute(attr, elementData))) {
+        llvm_unreachable("incompatible attribute for DenseElementType");
+      }
+      llvm::append_range(data, elementData);
+    }
+    return DenseIntOrFPElementsAttr::getRaw(type, data);
+  }
+
   // If the element type is not based on int/float/index, assume it is a string
   // type.
   if (!eltType.isIntOrIndexOrFloat()) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1e198043c590a..d0165db058683 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/TensorEncoding.h"
diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir
new file mode 100644
index 0000000000000..8aa1386cfeb73
--- /dev/null
+++ b/mlir/test/IR/dense-elements-type-interface.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+
+// Test dense elements attribute with custom element type using DenseElementTypeInterface.
+// Uses the new type-first syntax: dense<TYPE : [ATTR, ...]>
+// Note: The type is embedded in the attribute, so it's not printed again at the end.
+
+// CHECK-LABEL: func @dense_custom_element_type
+func.func @dense_custom_element_type() {
+  // The type is embedded in the dense attribute syntax, not printed separately.
+  // CHECK: "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>}
+  "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @dense_custom_element_type_2d
+func.func @dense_custom_element_type_2d() {
+  // CHECK: "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
+  "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @dense_custom_element_splat
+func.func @dense_custom_element_splat() {
+  // A splat should be detected and stored efficiently
+  // CHECK: "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
+  "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 964792ceebc07..cfbadc6aa8a7a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -18,6 +18,7 @@ include "TestDialect.td"
 include "TestAttrDefs.td"
 include "TestInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 
@@ -512,4 +513,17 @@ def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> {
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Test type for DenseElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TestTypeDenseElement : Test_Type<"TestDenseElement",
+    [DeclareTypeInterfaceMethods<DenseElementTypeInterface>]> {
+  let mnemonic = "dense_element";
+  let description = [{
+    A test type that implements DenseElementTypeInterface to test dense
+    elements with custom element types. Elements are stored as 32-bit integers.
+  }];
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 71dd25b0093e0..ef3396fc4f610 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -15,6 +15,7 @@
 #include "TestDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/Types.h"
@@ -22,6 +23,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/TypeSize.h"
+#include <cstring>
 #include <optional>
 
 using namespace mlir;
@@ -605,3 +607,29 @@ void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const {
   printer.printNewline();
   printer << ">";
 }
+
+//===----------------------------------------------------------------------===//
+// TestDenseElementType - DenseElementTypeInterface Implementation
+//===----------------------------------------------------------------------===//
+
+// Each element occupies 32 bits in dense storage (one int32_t per element).
+size_t TestDenseElementType::getDenseElementBitSize() const { return 32; }
+
+Attribute
+TestDenseElementType::convertToAttribute(ArrayRef<char> rawData) const {
+  assert(rawData.size() == 4 && "expected 4 bytes for TestDenseElement");
+  int32_t value; // Filled from the 4 raw bytes in host byte order.
+  std::memcpy(&value, rawData.data(), sizeof(value));
+  return IntegerAttr::get(IntegerType::get(getContext(), 32), value);
+}
+
+LogicalResult TestDenseElementType::convertFromAttribute(
+    Attribute attr, SmallVectorImpl<char> &result) const {
+  auto intAttr = dyn_cast<IntegerAttr>(attr);
+  if (!intAttr || !intAttr.getType().isInteger(32)) // Also rejects `index`.
+    return failure();
+  int32_t value = intAttr.getValue().getSExtValue(); // Fits: width is 32.
+  result.append(reinterpret_cast<const char *>(&value),
+                reinterpret_cast<const char *>(&value) + sizeof(value));
+  return success();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 6499a96f495d0..705fb86e9e9b3 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -19,6 +19,7 @@
 
 #include "TestTraits.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"

>From 11e3abd1e927b45be087a18daa47aec60f8b618f Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 4 Feb 2026 18:16:37 +0000
Subject: [PATCH 2/2] getter / iterator via interface

---
 mlir/include/mlir/IR/BuiltinTypeInterfaces.h  |  20 +++
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 128 ++++++++-------
 mlir/include/mlir/IR/BuiltinTypes.td          |  13 +-
 mlir/lib/AsmParser/AttributeParser.cpp        |  26 +--
 mlir/lib/IR/AsmPrinter.cpp                    |  74 ++++-----
 mlir/lib/IR/AttributeDetail.h                 |  10 +-
 mlir/lib/IR/BuiltinAttributes.cpp             | 150 +++++-------------
 mlir/lib/IR/BuiltinTypeInterfaces.cpp         |  50 ++++++
 mlir/lib/IR/BuiltinTypes.cpp                  | 121 ++++++++++++++
 .../IR/dense-elements-type-interface.mlir     |  25 ++-
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |   8 +-
 11 files changed, 385 insertions(+), 240 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index 5f14517d8dd71..c6e6e86d64b9c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -19,6 +19,26 @@ struct fltSemantics;
 namespace mlir {
 class FloatType;
 class MLIRContext;
+
+namespace detail {
+/// Default implementation of DenseElementTypeInterface::getDenseElementBitSize.
+size_t getDefaultDenseElementBitSize(Type type);
+
+/// Default implementation of DenseElementTypeInterface::convertToAttribute.
+Attribute defaultConvertToAttribute(Type type, llvm::ArrayRef<char> rawData);
+
+/// Default implementation of DenseElementTypeInterface::convertFromAttribute.
+LogicalResult defaultConvertFromAttribute(Type type, Attribute attr,
+                                          llvm::SmallVectorImpl<char> &result);
+
+/// Read `bitWidth` bits from byte-aligned position `bitPos` in `rawData` and
+/// return them as an APInt. Handles endianness correctly.
+llvm::APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth);
+
+/// Write `value` to byte-aligned position `bitPos` in `rawData`. Handles
+/// endianness correctly.
+void writeBits(char *rawData, size_t bitPos, llvm::APInt value);
+} // namespace detail
 } // namespace mlir
 
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index f111bb5b32b49..6463f62b1923b 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -41,12 +41,83 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// DenseElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface allows custom types to be used as element types in
+    DenseElementsAttr. Types implementing this interface define:
+
+    1. The bit size of one element. Elements are currently padded to a full
+       byte in storage.
+    2. Helper methods for converting from/to Attribute. This assumes that there
+       is a corresponding attribute for each type that implements this
+       interface.
+
+    The helper methods for converting from/to Attribute are utilized when
+    parsing/printing IR or iterating over the elements via Attribute.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of bits required to store one element in dense
+        storage.
+        
+        Note: The DenseElementsAttr infrastructure will automatically align
+        every element to a full byte in storage. This limitation could be lifted
+        in the future to support dense packing of non-byte-sized elements.
+      }],
+      /*retTy=*/"size_t",
+      /*methodName=*/"getDenseElementBitSize",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getDefaultDenseElementBitSize($_type);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Attribute deserialization / attribute factory: Convert raw storage bytes
+        into an MLIR attribute. The size of `rawData` is
+        "ceilDiv(getDenseElementBitSize(), 8)".
+      }],
+      /*retTy=*/"::mlir::Attribute",
+      /*methodName=*/"convertToAttribute",
+      /*args=*/(ins "::llvm::ArrayRef<char>":$rawData),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::defaultConvertToAttribute($_type, rawData);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Attribute serialization: Convert an MLIR attribute into raw bytes.
+        Implementations must append "ceilDiv(getDenseElementBitSize(), 8)"
+        bytes to `result`. Return "failure" if the attribute is incompatible
+        with this element type.
+      }],
+      /*retTy=*/"::llvm::LogicalResult",
+      /*methodName=*/"convertFromAttribute",
+      /*args=*/(ins "::mlir::Attribute":$attr,
+                    "::llvm::SmallVectorImpl<char>&":$result),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::defaultConvertFromAttribute($_type, attr, result);
+      }]
+    >,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // FloatTypeInterface
 //===----------------------------------------------------------------------===//
 
 def FloatTypeInterface : TypeInterface<"FloatType",
-    [VectorElementTypeInterface]> {
+    [DenseElementTypeInterface, VectorElementTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
@@ -338,59 +409,4 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// DenseElementTypeInterface
-//===----------------------------------------------------------------------===//
-
-def DenseElementTypeInterface : TypeInterface<"DenseElementType"> {
-  let cppNamespace = "::mlir";
-  let description = [{
-    This interface allows custom types to be used as element types in
-    DenseElementsAttr. Types implementing this interface define:
-
-    1. The bit size for element storage. Only full byte sizes are supported
-       at the moment.
-    2. Helper methods for converting from/to Attribute. This assumes that there
-       is a corresponding attribute for each type that implements this
-       interface.
-
-    The helper methods for converting from/to Attribute are utilized when
-    parsing/printing IR or iterating over the elements via Attribute.
-  }];
-
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the number of bits required to store one element in dense
-        storage. This must be a compile-time constant for the type and must
-        be a multiple of 8 (byte-aligned).
-      }],
-      /*retTy=*/"size_t",
-      /*methodName=*/"getDenseElementBitSize",
-      /*args=*/(ins)
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Convert raw storage bytes to an attribute representing this element
-        value. The `rawData` array contains exactly `getDenseElementBitSize()/8`
-        bytes.
-      }],
-      /*retTy=*/"::mlir::Attribute",
-      /*methodName=*/"convertToAttribute",
-      /*args=*/(ins "::llvm::ArrayRef<char>":$rawData)
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Convert an attribute to raw storage bytes. Appends exactly
-        `getDenseElementBitSize()/8` bytes to `result`. Returns failure if the
-        attribute is incompatible with this element type.
-      }],
-      /*retTy=*/"::llvm::LogicalResult",
-      /*methodName=*/"convertFromAttribute",
-      /*args=*/(ins "::mlir::Attribute":$attr,
-                    "::llvm::SmallVectorImpl<char>&":$result)
-    >,
-  ];
-}
-
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 08847dd11c685..e671f96f2d0f3 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -44,7 +44,10 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
 // ComplexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
+def Builtin_Complex : Builtin_Type<"Complex", "complex",
+    [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+      ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>
+    ]> {
   let summary = "Complex number with a parameterized element type";
   let description = [{
     Syntax:
@@ -470,7 +473,9 @@ def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
 //===----------------------------------------------------------------------===//
 
 def Builtin_Index : Builtin_Type<"Index", "index",
-    [VectorElementTypeInterface]> {
+    [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+      ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>,
+     VectorElementTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -501,7 +506,9 @@ def Builtin_Index : Builtin_Type<"Index", "index",
 //===----------------------------------------------------------------------===//
 
 def Builtin_Integer : Builtin_Type<"Integer", "integer",
-    [VectorElementTypeInterface]> {
+    [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+      ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>,
+     VectorElementTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 09e5146c24e00..81cafb45abff2 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -957,12 +957,12 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
 
 /// Try to parse a dense elements attribute with the type-first syntax.
 /// Syntax: dense<TYPE : [ATTR, ATTR, ...]>
-/// This is used for element types implementing DenseElementTypeInterface.
+/// This syntax is used for types other than int, float, index and complex.
 ///
 /// Returns:
-///   - failure() on parse error
-///   - Attribute() (null) if this is not the type-first syntax
-///   - A valid Attribute on success
+///   - "null" attribute if this is not the type-first syntax.
+///   - "failure" in case of a parse error.
+///   - A valid Attribute otherwise.
 static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
   // Try to parse an optional type. Skip l_paren because parseOptionalType
   // would try to parse it as a tuple/function type, but '(' starts a complex
@@ -1010,7 +1010,11 @@ static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
 
   // Parse the element attributes and convert to raw bytes.
   SmallVector<char> rawData;
-  size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+  // Storage is byte-aligned: align bit size up to next byte boundary. This
+  // limitation could be lifted in the future to support dense packing of
+  // non-byte-sized elements.
+  size_t bitSize = denseEltType.getDenseElementBitSize();
+  size_t byteSize = llvm::divideCeil(bitSize, static_cast<size_t>(CHAR_BIT));
 
   // Helper to parse a single element.
   auto parseSingleElement = [&]() -> ParseResult {
@@ -1032,7 +1036,7 @@ static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) {
       return parseSingleElement();
 
     // Non-leaf: expect a list with the correct number of elements.
-    int64_t expectedCount = remainingShape[0];
+    int64_t expectedCount = remainingShape.front();
     ArrayRef<int64_t> innerShape = remainingShape.drop_front();
     int64_t actualCount = 0;
 
@@ -1094,7 +1098,6 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
     return nullptr;
 
   // Try to parse the type-first syntax: dense<TYPE : [ATTR, ...]>
-  // This is used for element types implementing DenseElementTypeInterface.
   FailureOr<Attribute> typedResult =
       parseDenseElementsAttrTyped(*this, attribLoc);
   if (failed(typedResult))
@@ -1102,7 +1105,8 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
   if (*typedResult)
     return *typedResult;
 
-  // Parse the literal data if necessary (old syntax: dense<LITERAL> : TYPE).
+  // Try to parse the literal-first syntax, which is the default format for
+  // int, float, index and complex element types.
   TensorLiteralParser literalParser(*this);
   if (!consumeIf(Token::greater)) {
     if (literalParser.parse(/*allowHex=*/true) ||
@@ -1110,10 +1114,10 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
       return nullptr;
   }
 
-  auto literalType = parseElementsLiteralType(attribLoc, attrType);
-  if (!literalType)
+  auto type = parseElementsLiteralType(attribLoc, attrType);
+  if (!type)
     return nullptr;
-  return literalParser.getAttr(attribLoc, literalType);
+  return literalParser.getAttr(attribLoc, type);
 }
 
 Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 0897065849c25..b4f5a2b0ff67b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -507,15 +507,17 @@ class AsmPrinter::Impl {
   /// Print a dense string elements attribute.
   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
 
-  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
-  /// used instead of individual elements when the elements attr is large.
+  /// Print a dense elements attribute in the literal-first syntax. If
+  /// 'allowHex' is true, a hex string is used instead of individual elements
+  /// when the elements attr is large.
   void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                      bool allowHex);
 
-  /// Print a dense elements attribute using DenseElementTypeInterface.
-  /// Uses the type-first syntax: dense<TYPE : [ATTR, ...]>
-  void printDenseElementsAttrWithInterface(DenseElementsAttr attr,
-                                           DenseElementType denseEltType);
+  /// Print a dense elements attribute using the type-first syntax and the
+  /// DenseElementTypeInterface, which provides the attribute printer for each
+  /// element.
+  void printTypeFirstDenseElementsAttr(DenseElementsAttr attr,
+                                       DenseElementType denseEltType);
 
   /// Print a dense array attribute.
   void printDenseArrayAttr(DenseArrayAttr attr);
@@ -2506,41 +2508,33 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
       printSymbolReference(nestedRef.getValue(), os);
     }
 
-  } else if (auto denseEltAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) {
-    // Check if the element type implements DenseElementTypeInterface.
-    // If so, use the type-first syntax which embeds the type in the attribute.
-    Type eltType = denseEltAttr.getElementType();
-    if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
-      if (printerFlags.shouldElideElementsAttr(denseEltAttr)) {
-        printElidedElementsAttr(os);
+  } else if (auto intOrFpEltAttr =
+                 llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
+    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
+      printElidedElementsAttr(os);
+    } else {
+      os << "dense<";
+      // Built-in element types (int, float, index, complex) keep the
+      // literal-first format for backwards compatibility; every other element
+      // type uses the type-first syntax via DenseElementTypeInterface.
+      Type eltType = intOrFpEltAttr.getElementType();
+      if (isa<FloatType, IntegerType, IndexType, ComplexType>(eltType)) {
+        printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
       } else {
-        os << "dense<";
-        printDenseElementsAttrWithInterface(denseEltAttr, denseEltType);
-        os << '>';
+        printTypeFirstDenseElementsAttr(intOrFpEltAttr,
+                                        cast<DenseElementType>(eltType));
+        typeElision = AttrTypeElision::Must;
       }
-      // Type is embedded in the syntax, don't print it again.
-      return;
+      os << '>';
     }
 
-    // Fall back to existing printing for built-in element types.
-    if (auto intOrFpEltAttr =
-            llvm::dyn_cast<DenseIntOrFPElementsAttr>(denseEltAttr)) {
-      if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
-        printElidedElementsAttr(os);
-      } else {
-        os << "dense<";
-        printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
-        os << '>';
-      }
-    } else if (auto strEltAttr =
-                   llvm::dyn_cast<DenseStringElementsAttr>(denseEltAttr)) {
-      if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
-        printElidedElementsAttr(os);
-      } else {
-        os << "dense<";
-        printDenseStringElementsAttr(strEltAttr);
-        os << '>';
-      }
+  } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
+    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
+      printElidedElementsAttr(os);
+    } else {
+      os << "dense<";
+      printDenseStringElementsAttr(strEltAttr);
+      os << '>';
     }
 
   } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
@@ -2728,14 +2722,16 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
 }
 
-void AsmPrinter::Impl::printDenseElementsAttrWithInterface(
+void AsmPrinter::Impl::printTypeFirstDenseElementsAttr(
     DenseElementsAttr attr, DenseElementType denseEltType) {
   // Print the type first: dense<TYPE : [ELEMENTS]>
   printType(attr.getType());
   os << " : ";
 
   ArrayRef<char> rawData = attr.getRawData();
-  size_t byteSize = denseEltType.getDenseElementBitSize() / CHAR_BIT;
+  // Storage is byte-aligned: align bit size up to next byte boundary.
+  size_t bitSize = denseEltType.getDenseElementBitSize();
+  size_t byteSize = llvm::divideCeil(bitSize, (size_t)CHAR_BIT);
 
   // Print elements: convert raw bytes to attribute, then print attribute.
   printDenseElementsAttrImpl(
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 9055d58c5fe5d..b6b3a0551079d 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -33,14 +33,12 @@ namespace detail {
 
 /// Return the bit width which DenseElementsAttr should use for this type.
 inline size_t getDenseElementBitWidth(Type eltType) {
-  // Check for DenseElementTypeInterface first.
+  // i1 is stored as a single bit (bit-packed storage).
+  if (eltType.isInteger(1))
+    return 1;
+  // Check for DenseElementTypeInterface.
   if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
     return denseEltType.getDenseElementBitSize();
-  // Align the width for complex to 8 to make storage and interpretation easier.
-  if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
-    return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
-  if (eltType.isIndex())
-    return IndexType::kInternalStorageBitWidth;
   return eltType.getIntOrFloatBitWidth();
 }
 
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 90dc6d0129658..d9c5fd9acb811 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -10,6 +10,7 @@
 #include "AttributeDetail.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
@@ -543,7 +544,7 @@ static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
 }
 
 /// Writes value to the bit position `bitPos` in array `rawData`.
-static void writeBits(char *rawData, size_t bitPos, APInt value) {
+void mlir::detail::writeBits(char *rawData, size_t bitPos, APInt value) {
   size_t bitWidth = value.getBitWidth();
 
   // If the bitwidth is 1 we just toggle the specific bit.
@@ -569,7 +570,8 @@ static void writeBits(char *rawData, size_t bitPos, APInt value) {
 
 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
 /// `rawData`.
-static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
+APInt mlir::detail::readBits(const char *rawData, size_t bitPos,
+                             size_t bitWidth) {
   // Handle a boolean bit position.
   if (bitWidth == 1)
     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
@@ -619,46 +621,27 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
   auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
   Type eltTy = owner.getElementType();
-  if (llvm::dyn_cast<IntegerType>(eltTy))
-    return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
-  if (llvm::isa<IndexType>(eltTy))
-    return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
-  if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) {
-    IntElementIterator intIt(owner, index);
-    FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
-    return FloatAttr::get(eltTy, *floatIt);
-  }
-  if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) {
-    auto complexEltTy = complexTy.getElementType();
-    ComplexIntElementIterator complexIntIt(owner, index);
-    if (llvm::isa<IntegerType>(complexEltTy)) {
-      auto value = *complexIntIt;
-      auto real = IntegerAttr::get(complexEltTy, value.real());
-      auto imag = IntegerAttr::get(complexEltTy, value.imag());
-      return ArrayAttr::get(complexTy.getContext(),
-                            ArrayRef<Attribute>{real, imag});
-    }
 
-    ComplexFloatElementIterator complexFloatIt(
-        llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt);
-    auto value = *complexFloatIt;
-    auto real = FloatAttr::get(complexEltTy, value.real());
-    auto imag = FloatAttr::get(complexEltTy, value.imag());
-    return ArrayAttr::get(complexTy.getContext(),
-                          ArrayRef<Attribute>{real, imag});
+  // Handle i1 (boolean) specially - it's bit-packed and doesn't use interface.
+  if (eltTy.isInteger(1)) {
+    bool value = *BoolElementIterator(owner, index);
+    return IntegerAttr::get(eltTy, APInt(1, value));
   }
+
+  // Handle strings specially.
   if (llvm::isa<DenseStringElementsAttr>(owner)) {
     ArrayRef<StringRef> vals = owner.getRawStringData();
     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
   }
-  // Check if the element type implements DenseElementTypeInterface.
-  if (auto denseEltTy = llvm::dyn_cast<DenseElementType>(eltTy)) {
-    ArrayRef<char> rawData = owner.getRawData();
-    size_t byteSize = denseEltTy.getDenseElementBitSize() / CHAR_BIT;
-    size_t offset = owner.isSplat() ? 0 : index * byteSize;
-    return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize));
-  }
-  llvm_unreachable("unexpected element type");
+
+  // All other types should implement DenseElementTypeInterface.
+  auto denseEltTy = llvm::cast<DenseElementType>(eltTy);
+  ArrayRef<char> rawData = owner.getRawData();
+  // Storage is byte-aligned: align bit size up to next byte boundary.
+  size_t bitSize = denseEltTy.getDenseElementBitSize();
+  size_t byteSize = llvm::divideCeil(bitSize, CHAR_BIT);
+  size_t offset = owner.isSplat() ? 0 : index * byteSize;
+  return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize));
 }
 
 //===----------------------------------------------------------------------===//
@@ -920,94 +903,37 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 
   Type eltType = type.getElementType();
 
-  // Take care complex type case first.
-  if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) {
-    if (complexType.getElementType().isIntOrIndex()) {
-      SmallVector<std::complex<APInt>> complexValues;
-      complexValues.reserve(values.size());
-      for (Attribute attr : values) {
-        assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
-        auto arrayAttr = llvm::cast<ArrayAttr>(attr);
-        assert(arrayAttr.size() == 2 && "expected 2 element for complex");
-        auto attr0 = arrayAttr[0];
-        auto attr1 = arrayAttr[1];
-        complexValues.push_back(
-            std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(),
-                                llvm::cast<IntegerAttr>(attr1).getValue()));
-      }
-      return DenseElementsAttr::get(type, complexValues);
-    }
-    // Must be float.
-    SmallVector<std::complex<APFloat>> complexValues;
-    complexValues.reserve(values.size());
-    for (Attribute attr : values) {
-      assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
-      auto arrayAttr = llvm::cast<ArrayAttr>(attr);
-      assert(arrayAttr.size() == 2 && "expected 2 element for complex");
-      auto attr0 = arrayAttr[0];
-      auto attr1 = arrayAttr[1];
-      complexValues.push_back(
-          std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(),
-                                llvm::cast<FloatAttr>(attr1).getValue()));
-    }
-    return DenseElementsAttr::get(type, complexValues);
-  }
-
-  // Check if the element type implements DenseElementTypeInterface.
-  if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) {
-    SmallVector<char> data;
-    for (Attribute attr : values) {
-      SmallVector<char> elementData;
-      if (failed(denseEltType.convertFromAttribute(attr, elementData))) {
-        llvm_unreachable("incompatible attribute for DenseElementType");
-      }
-      llvm::append_range(data, elementData);
-    }
-    return DenseIntOrFPElementsAttr::getRaw(type, data);
+  // Handle i1 (boolean) specially - it's bit-packed.
+  if (eltType.isInteger(1)) {
+    SmallVector<bool> boolValues;
+    boolValues.reserve(values.size());
+    for (Attribute attr : values)
+      boolValues.push_back(llvm::cast<IntegerAttr>(attr).getValue().isOne());
+    return get(type, boolValues);
   }
 
-  // If the element type is not based on int/float/index, assume it is a string
-  // type.
-  if (!eltType.isIntOrIndexOrFloat()) {
+  // Types without DenseElementTypeInterface are assumed to be string types.
+  if (!llvm::isa<DenseElementType>(eltType)) {
     SmallVector<StringRef, 8> stringValues;
     stringValues.reserve(values.size());
     for (Attribute attr : values) {
       assert(llvm::isa<StringAttr>(attr) &&
-             "expected string value for non integer/index/float element");
+             "expected string value for non-DenseElementType element");
       stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
     }
     return get(type, stringValues);
   }
 
-  // Otherwise, get the raw storage width to use for the allocation.
-  size_t bitWidth = getDenseElementBitWidth(eltType);
-  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
-
-  // Compress the attribute values into a character buffer.
-  SmallVector<char, 8> data(
-      llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
-  APInt intVal;
-  for (unsigned i = 0, e = values.size(); i < e; ++i) {
-    if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) {
-      assert(floatAttr.getType() == eltType &&
-             "expected float attribute type to equal element type");
-      intVal = floatAttr.getValue().bitcastToAPInt();
-    } else {
-      auto intAttr = llvm::cast<IntegerAttr>(values[i]);
-      assert(intAttr.getType() == eltType &&
-             "expected integer attribute type to equal element type");
-      intVal = intAttr.getValue();
-    }
-
-    assert(intVal.getBitWidth() == bitWidth &&
-           "expected value to have same bitwidth as element type");
-    writeBits(data.data(), i * storageBitWidth, intVal);
+  // All other types go through DenseElementTypeInterface.
+  auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType);
+  assert(denseEltType &&
+         "attempted to get DenseElementsAttr with unsupported element type");
+  SmallVector<char> data;
+  for (Attribute attr : values) {
+    LogicalResult result = denseEltType.convertFromAttribute(attr, data);
+    assert(succeeded(result) && "incompatible attribute for DenseElementType");
+    (void)result;
   }
-
-  // Handle the special encoding of splat of bool.
-  if (values.size() == 1 && eltType.isInteger(1))
-    data[0] = data[0] ? -1 : 0;
-
   return DenseIntOrFPElementsAttr::getRaw(type, data);
 }
 
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index 2f063be3e7cd0..4db0ed7a0c80a 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -6,9 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/Support/CheckedArithmetic.h"
+#include "llvm/Support/MathExtras.h"
+#include <climits>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -19,6 +22,53 @@ using namespace mlir::detail;
 
 #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// DenseElementTypeInterface default implementations
+//===----------------------------------------------------------------------===//
+
+size_t mlir::detail::getDefaultDenseElementBitSize(Type type) {
+  // TODO: This implementation should be defined on FloatType. However, due to
+  // TableGen limitations, a type interface cannot provide an implementation for
+  // an interface method from a base type interface.
+  auto floatType = dyn_cast<FloatType>(type);
+  if (!floatType)
+    llvm_unreachable("getDenseElementBitSize not implemented");
+  return floatType.getWidth();
+}
+
+Attribute mlir::detail::defaultConvertToAttribute(Type type,
+                                                  ArrayRef<char> rawData) {
+  // TODO: This implementation should be defined on FloatType. However, due to
+  // TableGen limitations, a type interface cannot provide an implementation for
+  // an interface method from a base type interface.
+  auto floatType = dyn_cast<FloatType>(type);
+  if (!floatType)
+    llvm_unreachable("convertToAttribute not implemented");
+  APInt intVal = readBits(rawData.data(), /*bitPos=*/0, floatType.getWidth());
+  APFloat floatVal(floatType.getFloatSemantics(), intVal);
+  return FloatAttr::get(type, floatVal);
+}
+
+LogicalResult
+mlir::detail::defaultConvertFromAttribute(Type type, Attribute attr,
+                                          SmallVectorImpl<char> &result) {
+  // TODO: This implementation should be defined on FloatType. However, due to
+  // TableGen limitations, a type interface cannot provide an implementation for
+  // an interface method from a base type interface.
+  auto floatType = dyn_cast<FloatType>(type);
+  if (!floatType)
+    llvm_unreachable("convertFromAttribute not implemented");
+  auto floatAttr = dyn_cast<FloatAttr>(attr);
+  if (!floatAttr || floatAttr.getType() != type)
+    return failure();
+  size_t byteSize =
+      llvm::divideCeil(floatType.getWidth(), static_cast<unsigned>(CHAR_BIT));
+  size_t bitPos = result.size() * CHAR_BIT;
+  result.resize(result.size() + byteSize);
+  writeBits(result.data(), bitPos, floatAttr.getValue().bitcastToAPInt());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FloatType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d0165db058683..7c0a75d82879b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -18,9 +18,11 @@
 #include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CheckedArithmetic.h"
+#include <cstring>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -86,6 +88,125 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
   return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
 }
 
+size_t IntegerType::getDenseElementBitSize() const {
+  // Return the actual bit width. Storage alignment is handled separately.
+  // Note: i1 is bit-packed and should be special-cased by the caller.
+  return getWidth();
+}
+
+Attribute IntegerType::convertToAttribute(ArrayRef<char> rawData) const {
+  APInt value = detail::readBits(rawData.data(), /*bitPos=*/0, getWidth());
+  return IntegerAttr::get(*this, value);
+}
+
+LogicalResult
+IntegerType::convertFromAttribute(Attribute attr,
+                                  SmallVectorImpl<char> &result) const {
+  auto intAttr = dyn_cast<IntegerAttr>(attr);
+  if (!intAttr || intAttr.getType() != *this)
+    return failure();
+
+  size_t byteSize = llvm::divideCeil(getDenseElementBitSize(), CHAR_BIT);
+  size_t bitPos = result.size() * CHAR_BIT;
+  result.resize(result.size() + byteSize);
+  detail::writeBits(result.data(), bitPos, intAttr.getValue());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Index Type
+//===----------------------------------------------------------------------===//
+
+size_t IndexType::getDenseElementBitSize() const {
+  return kInternalStorageBitWidth;
+}
+
+Attribute IndexType::convertToAttribute(ArrayRef<char> rawData) const {
+  APInt value =
+      detail::readBits(rawData.data(), /*bitPos=*/0, kInternalStorageBitWidth);
+  return IntegerAttr::get(*this, value);
+}
+
+LogicalResult
+IndexType::convertFromAttribute(Attribute attr,
+                                SmallVectorImpl<char> &result) const {
+  auto intAttr = dyn_cast<IntegerAttr>(attr);
+  if (!intAttr || intAttr.getType() != *this)
+    return failure();
+
+  size_t byteSize = kInternalStorageBitWidth / CHAR_BIT;
+  size_t bitPos = result.size() * CHAR_BIT;
+  result.resize(result.size() + byteSize);
+  detail::writeBits(result.data(), bitPos, intAttr.getValue());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Complex Type
+//===----------------------------------------------------------------------===//
+
+size_t ComplexType::getDenseElementBitSize() const {
+  Type eltType = getElementType();
+  if (auto intType = dyn_cast<IntegerType>(eltType))
+    return 2 * llvm::alignTo<CHAR_BIT>(intType.getWidth());
+  return 2 * cast<FloatType>(eltType).getWidth();
+}
+
+Attribute ComplexType::convertToAttribute(ArrayRef<char> rawData) const {
+  Type eltType = getElementType();
+  size_t bitWidth = getDenseElementBitSize() / 2;
+
+  if (auto intType = dyn_cast<IntegerType>(eltType)) {
+    APInt realVal = detail::readBits(rawData.data(), /*bitPos=*/0, bitWidth);
+    APInt imagVal = detail::readBits(rawData.data(), bitWidth, bitWidth);
+    auto real = IntegerAttr::get(eltType, realVal);
+    auto imag = IntegerAttr::get(eltType, imagVal);
+    return ArrayAttr::get(getContext(), {real, imag});
+  }
+
+  auto floatType = cast<FloatType>(eltType);
+  const auto &semantics = floatType.getFloatSemantics();
+  APInt realVal = detail::readBits(rawData.data(), /*bitPos=*/0, bitWidth);
+  APInt imagVal = detail::readBits(rawData.data(), bitWidth, bitWidth);
+  auto real = FloatAttr::get(eltType, APFloat(semantics, realVal));
+  auto imag = FloatAttr::get(eltType, APFloat(semantics, imagVal));
+  return ArrayAttr::get(getContext(), {real, imag});
+}
+
+LogicalResult
+ComplexType::convertFromAttribute(Attribute attr,
+                                  SmallVectorImpl<char> &result) const {
+  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+  if (!arrayAttr || arrayAttr.size() != 2)
+    return failure();
+
+  Type eltType = getElementType();
+  size_t bitWidth = getDenseElementBitSize() / 2;
+  size_t byteSize = llvm::divideCeil(bitWidth, (size_t)CHAR_BIT);
+  size_t bitPos = result.size() * CHAR_BIT;
+  result.resize(result.size() + 2 * byteSize);
+
+  if (auto intType = dyn_cast<IntegerType>(eltType)) {
+    auto realAttr = dyn_cast<IntegerAttr>(arrayAttr[0]);
+    auto imagAttr = dyn_cast<IntegerAttr>(arrayAttr[1]);
+    if (!realAttr || !imagAttr)
+      return failure();
+    detail::writeBits(result.data(), bitPos, realAttr.getValue());
+    detail::writeBits(result.data(), bitPos + bitWidth, imagAttr.getValue());
+    return success();
+  }
+
+  auto realAttr = dyn_cast<FloatAttr>(arrayAttr[0]);
+  auto imagAttr = dyn_cast<FloatAttr>(arrayAttr[1]);
+  if (!realAttr || !imagAttr)
+    return failure();
+  detail::writeBits(result.data(), bitPos,
+                    realAttr.getValue().bitcastToAPInt());
+  detail::writeBits(result.data(), bitPos + bitWidth,
+                    imagAttr.getValue().bitcastToAPInt());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Float Types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir
index 8aa1386cfeb73..579a2ca3d1551 100644
--- a/mlir/test/IR/dense-elements-type-interface.mlir
+++ b/mlir/test/IR/dense-elements-type-interface.mlir
@@ -6,23 +6,32 @@
 
 // CHECK-LABEL: func @dense_custom_element_type
 func.func @dense_custom_element_type() {
-  // The type is embedded in the dense attribute syntax, not printed separately.
-  // CHECK: "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>}
-  "unregistered_op"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
+  // CHECK: "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>}
+  "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
   return
 }
 
 // CHECK-LABEL: func @dense_custom_element_type_2d
 func.func @dense_custom_element_type_2d() {
-  // CHECK: "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
-  "unregistered_op"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> ()
+  // CHECK: "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>}
+  "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> ()
   return
 }
 
 // CHECK-LABEL: func @dense_custom_element_splat
 func.func @dense_custom_element_splat() {
-  // A splat should be detected and stored efficiently
-  // CHECK: "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
-  "unregistered_op"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> ()
+  // CHECK: "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>}
+  "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @dense_i32_1d
+func.func @dense_i32_1d() {
+  // The default assembly format for int, index, float, complex element types is
+  // the literal-first syntax. Such a dense elements attribute can be parsed
+  // with the type-first syntax, but it will come back with the literal-first
+  // syntax.
+  // CHECK: "test.dummy"() {attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
+  "test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> ()
   return
 }
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index cfbadc6aa8a7a..08600ce713a17 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -513,12 +513,10 @@ def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> {
   let hasCustomAssemblyFormat = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// Test type for DenseElementTypeInterface
-//===----------------------------------------------------------------------===//
-
 def TestTypeDenseElement : Test_Type<"TestDenseElement",
-    [DeclareTypeInterfaceMethods<DenseElementTypeInterface>]> {
+    [DeclareTypeInterfaceMethods<DenseElementTypeInterface,
+      ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>
+    ]> {
   let mnemonic = "dense_element";
   let description = [{
     A test type that implements DenseElementTypeInterface to test dense



More information about the Mlir-commits mailing list