[Mlir-commits] [mlir] 80816e7 - [mlir][LLVM] handle ArrayAttr for constant array of structs (#139724)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 20 01:45:33 PDT 2025


Author: jeanPerier
Date: 2025-05-20T10:45:29+02:00
New Revision: 80816e792382da286b29f937938ab54ae159f482

URL: https://github.com/llvm/llvm-project/commit/80816e792382da286b29f937938ab54ae159f482
DIFF: https://github.com/llvm/llvm-project/commit/80816e792382da286b29f937938ab54ae159f482.diff

LOG: [mlir][LLVM] handle ArrayAttr for constant array of structs (#139724)

While the LLVM IR dialect has a way to represent arbitrary LLVM constant
arrays of structs via an insert chain, in practice this is very expensive
in compilation time as soon as the array has more than a couple hundred
elements. This is because generating, and later folding, such an insert
chain is costly.

This patch allows constant arrays of structs to be represented via
ArrayAttr in the LLVM dialect.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f19f9d5a3083c..61ba8f7b991c8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -2073,9 +2073,9 @@ def LLVM_ConstantOp
     Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
     all constants must be created as SSA values before being used in other
     operations. `llvm.mlir.constant` creates such values for scalars, vectors,
-    strings, and structs. It has a mandatory `value` attribute whose type
-    depends on the type of the constant value. The type of the constant value
-    must correspond to the attribute type converted to LLVM IR type.
+    strings, structs, and arrays of structs. It has a mandatory `value` attribute
+    whose type depends on the type of the constant value. The type of the constant
+    value must correspond to the attribute type converted to LLVM IR type.
 
     When creating constant scalars, the `value` attribute must be either an
     integer attribute or a floating point attribute. The type of the attribute
@@ -2097,6 +2097,11 @@ def LLVM_ConstantOp
     must correspond to the type of the corresponding attribute element converted
     to LLVM IR.
 
+    When creating an array of structs, the `value` attribute must be an array
+    attribute containing zero, undef, or array attributes for each nested
+    array dimension; at the innermost level, each element must be a zero or
+    undef attribute, or an array attribute matching the struct element types.
+
     Examples:
 
     ```mlir

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c757f3ceb90e3..d8abf6fd41301 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3142,6 +3142,74 @@ static bool hasScalableVectorType(Type t) {
   return false;
 }
 
+/// Verifies the constant array represented by `arrayAttr` matches the provided
+/// `arrayType`.
+static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
+                                               LLVM::LLVMArrayType arrayType,
+                                               ArrayAttr arrayAttr, int dim) {
+  if (arrayType.getNumElements() != arrayAttr.size())
+    return op.emitOpError()
+           << "array attribute size does not match array type size in "
+              "dimension "
+           << dim << ": " << arrayAttr.size() << " vs. "
+           << arrayType.getNumElements();
+
+  llvm::DenseSet<Attribute> elementsVerified;
+
+  // Recursively verify sub-dimensions for multidimensional arrays.
+  if (auto subArrayType =
+          dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType())) {
+    for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr))
+      if (elementsVerified.insert(elementAttr).second) {
+        if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
+          continue;
+        auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
+        if (!subArrayAttr)
+          return op.emitOpError()
+                 << "nested attribute for sub-array in dimension " << dim
+                 << " at index " << idx
+                 << " must be a zero, or undef, or array attribute";
+        if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr,
+                                             dim + 1)))
+          return failure();
+      }
+    return success();
+  }
+
+  // Forbid usages of ArrayAttr for simple array types that should use
+  // DenseElementsAttr instead. Note that there would be a use case for such
+  // array types when one element value is obtained via a ptr-to-int conversion
+  // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
+  // user needs this so far, and it seems better to avoid people misusing the
+  // ArrayAttr for simple types.
+  auto structType = dyn_cast<LLVM::LLVMStructType>(arrayType.getElementType());
+  if (!structType)
+    return op.emitOpError() << "for array with an array attribute must have a "
+                               "struct element type";
+
+  // Shallow verification that leaf attributes are appropriate as struct initial
+  // value.
+  size_t numStructElements = structType.getBody().size();
+  for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
+    if (elementsVerified.insert(elementAttr).second) {
+      if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
+        continue;
+      auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
+      if (!subArrayAttr)
+        return op.emitOpError()
+               << "nested attribute for struct element at index " << idx
+               << " must be a zero, or undef, or array attribute";
+      if (subArrayAttr.size() != numStructElements)
+        return op.emitOpError()
+               << "nested array attribute size for struct element at index "
+               << idx << " must match struct size: " << subArrayAttr.size()
+               << " vs. " << numStructElements;
+    }
+  }
+
+  return success();
+}
+
 LogicalResult LLVM::ConstantOp::verify() {
   if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
     auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
@@ -3208,7 +3276,7 @@ LogicalResult LLVM::ConstantOp::verify() {
     if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
       return emitOpError() << "expected integer type of width " << floatWidth;
     }
-  } else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
+  } else if (isa<ElementsAttr>(getValue())) {
     if (hasScalableVectorType(getType())) {
       // The exact number of elements of a scalable vector is unknown, so we
       // allow only splat attributes.
@@ -3221,15 +3289,20 @@ LogicalResult LLVM::ConstantOp::verify() {
     if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
       return emitOpError() << "expected vector or array type";
     // The number of elements of the attribute and the type must match.
-    int64_t attrNumElements;
-    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
-      attrNumElements = elementsAttr.getNumElements();
-    else
-      attrNumElements = cast<ArrayAttr>(getValue()).size();
-    if (getNumElements(getType()) != attrNumElements)
-      return emitOpError()
-             << "type and attribute have a different number of elements: "
-             << getNumElements(getType()) << " vs. " << attrNumElements;
+    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
+      int64_t attrNumElements = elementsAttr.getNumElements();
+      if (getNumElements(getType()) != attrNumElements)
+        return emitOpError()
+               << "type and attribute have a different number of elements: "
+               << getNumElements(getType()) << " vs. " << attrNumElements;
+    }
+  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
+    auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
+    if (!arrayType)
+      return emitOpError() << "expected array type";
+    // When the attribute is an ArrayAttr, check that its nesting matches the
+    // corresponding LLVM array type nesting.
+    return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
   } else {
     return emitOpError()
            << "only supports integer, float, string or elements attributes";

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 95b8ee0331c55..9b5c93171abfd 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -553,8 +553,10 @@ static llvm::Constant *convertDenseResourceElementsAttr(
 llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     llvm::Type *llvmType, Attribute attr, Location loc,
     const ModuleTranslation &moduleTranslation) {
-  if (!attr)
+  if (!attr || isa<UndefAttr>(attr))
     return llvm::UndefValue::get(llvmType);
+  if (isa<ZeroAttr>(attr))
+    return llvm::Constant::getNullValue(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
     auto arrayAttr = dyn_cast<ArrayAttr>(attr);
     if (!arrayAttr) {
@@ -713,6 +715,33 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
         ArrayRef<char>{stringAttr.getValue().data(),
                        stringAttr.getValue().size()});
   }
+
+  // Handle arrays of structs that cannot be represented as DenseElementsAttr
+  // in MLIR.
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
+      llvm::Type *elementType = arrayTy->getElementType();
+      Attribute previousElementAttr;
+      llvm::Constant *elementCst = nullptr;
+      SmallVector<llvm::Constant *> constants;
+      constants.reserve(arrayTy->getNumElements());
+      for (Attribute elementAttr : arrayAttr) {
+        // Arrays with a single value or with repeating values are quite common.
+        // Short-circuit the translation when the element value is the same as
+        // the previous one.
+        if (!previousElementAttr || previousElementAttr != elementAttr) {
+          previousElementAttr = elementAttr;
+          elementCst =
+              getLLVMConstant(elementType, elementAttr, loc, moduleTranslation);
+          if (!elementCst)
+            return nullptr;
+        }
+        constants.push_back(elementCst);
+      }
+      return llvm::ConstantArray::get(arrayTy, constants);
+    }
+  }
+
   emitError(loc, "unsupported constant value");
   return nullptr;
 }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f9ea066a63624..f5adf4b3bf33d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1850,3 +1850,35 @@ llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) {
   llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
   llvm.return
 }
+
+// -----
+
+llvm.mlir.global @bad_struct_array_init_size() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
+  // expected-error@below {{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}}
+  %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
+  llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>>
+}
+
+// -----
+
+llvm.mlir.global @bad_struct_array_init_nesting() : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> {
+  // expected-error@below {{'llvm.mlir.constant' op nested attribute for sub-array in dimension 1 at index 0 must be a zero, or undef, or array attribute}}
+  %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
+  llvm.return %0 : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
+}
+
+// -----
+
+llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct<(i32, f32)>> {
+  // expected-error@below {{'llvm.mlir.constant' op nested array attribute size for struct element at index 0 must match struct size: 1 vs. 2}}
+  %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.struct<(i32, f32)>>
+  llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>>
+}
+
+// -----
+
+llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 x f64> {
+  // expected-error@below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}}
+  %0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64>
+  llvm.return %0 : !llvm.array<2 x f64>
+}

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 90c0f5ac55cb1..24a7b42557278 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -79,11 +79,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
 
 // -----
 
-// expected-error @below{{unsupported constant value}}
-llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
-
-// -----
-
 // expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
 llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
 

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 3c8de1cf63b94..237612244d8de 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -3022,3 +3022,29 @@ llvm.func internal @i(%arg0: i32) attributes {dso_local} {
   llvm.call @testfn3(%arg0) : (i32 {llvm.alignstack = 8 : i64}) -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK: @test_array_attr_2 = global [2 x { i32, float }] [{ i32, float } { i32 42, float 1.000000e+00 }, { i32, float } { i32 42, float 1.000000e+00 }]
+llvm.mlir.global @test_array_attr_2() : !llvm.array<2 x !llvm.struct<(i32, f32)>> {
+  %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2 x !llvm.struct<(i32, f32)>>
+  llvm.return %0 : !llvm.array<2 x !llvm.struct<(i32, f32)>>
+}
+
+// CHECK: @test_array_attr_3 = global [2 x [3 x { i32, float }]{{.*}}[3 x { i32, float }] [{ i32, float } { i32 1, float 1.000000e+00 }, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } { i32 3, float 1.000000e+00 }], [3 x { i32, float }] [{ i32, float } { i32 4, float 1.000000e+00 }, { i32, float } { i32 5, float 1.000000e+00 }, { i32, float } { i32 6, float 1.000000e+00 }
+llvm.mlir.global @test_array_attr_3() : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> {
+  %0 = llvm.mlir.constant([[[1 : i32, 1.000000e+00 : f32], [2 : i32, 1.000000e+00 : f32], [3 : i32, 1.000000e+00 : f32]], [[4 : i32, 1.000000e+00 : f32], [5 : i32, 1.000000e+00 : f32], [6 : i32, 1.000000e+00 : f32]]]) : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>>
+  llvm.return %0 : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>>
+}
+
+// CHECK: @test_array_attr_struct_with_ptr = internal constant [2 x { ptr }] [{ ptr } zeroinitializer, { ptr } undef]
+llvm.mlir.global internal constant @test_array_attr_struct_with_ptr() : !llvm.array<2 x struct<(ptr)>> {
+  %0 = llvm.mlir.constant([[#llvm.zero], [#llvm.undef]]) : !llvm.array<2 x struct<(ptr)>>
+  llvm.return %0 : !llvm.array<2 x struct<(ptr)>>
+}
+
+// CHECK: @test_array_attr_struct_with_struct = internal constant [3 x { i32, float }] [{ i32, float } zeroinitializer, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } undef]
+llvm.mlir.global internal constant @test_array_attr_struct_with_struct() : !llvm.array<3 x struct<(i32, f32)>> {
+  %0 = llvm.mlir.constant([#llvm.zero, [2 : i32, 1.0 : f32], #llvm.undef]) : !llvm.array<3 x struct<(i32, f32)>>
+  llvm.return %0 : !llvm.array<3 x struct<(i32, f32)>>
+}


        


More information about the Mlir-commits mailing list