[Mlir-commits] [mlir] [mlir][LLVM] handle ArrayAttr for constant array of structs (PR #139724)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 13 06:00:22 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (jeanPerier)
<details>
<summary>Changes</summary>
While the LLVM IR dialect has a way to represent arbitrary LLVM constant arrays of structs via an insert chain, it is in practice very expensive in compilation time as soon as the array is bigger than a couple hundred elements. This is because generating and later folding such an insert chain is really not cheap.
This is an issue for flang because it is very easy to generate rather big array-of-struct constant initializers in Fortran, and unlike C++, dynamic initialization of globals is not a feature of the language. Initializers must be static.
For instance, here are the compile times I measured for the following program when changing the size N:
```
! test.F90
module m
type t
integer :: i = 42
real :: x = 1.0
end type
type(t) :: some_global(N)
end module
```
```
/usr/bin/time flang -c -DN=1000 test.F90
0.08user 0.07system 0:00.11elapsed 140%CPU (0avgtext+0avgdata 83968maxresident)k
8inputs+40outputs (13major+4840minor)pagefaults 0swaps
/usr/bin/time flang -c -DN=10000 test.F90
1.40user 0.08system 0:01.44elapsed 102%CPU (0avgtext+0avgdata 89088maxresident)k
8inputs+184outputs (13major+8764minor)pagefaults 0swaps
/usr/bin/time flang -c -DN=100000 test.F90
137.79user 0.22system 2:18.00elapsed 100%CPU (0avgtext+0avgdata 145540maxresident)k
8inputs+1584outputs (10major+82461minor)pagefaults 0swap
```
In the last case, more than 99.99% of the time is spent folding the insert chain in ModuleTranslation.cpp.
With this patch (and updating flang to generate an `ArrayAttr` instead of an insert chain), the last case with 100000 elements takes 0.15s to compile (~1000x compilation speed up :)).
This is not a silver bullet because there are cases where an insert chain will still currently be needed, like when the initial values contain symbol references, but this is not very common for my use case.
---
Full diff: https://github.com/llvm/llvm-project/pull/139724.diff
5 Files Affected:
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+44-9)
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+27)
- (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+22)
- (modified) mlir/test/Target/LLVMIR/llvmir-invalid.mlir (-5)
- (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+17)
``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c757f3ceb90e3..1868995e3f5ed 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3221,15 +3221,50 @@ LogicalResult LLVM::ConstantOp::verify() {
if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
return emitOpError() << "expected vector or array type";
// The number of elements of the attribute and the type must match.
- int64_t attrNumElements;
- if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
- attrNumElements = elementsAttr.getNumElements();
- else
- attrNumElements = cast<ArrayAttr>(getValue()).size();
- if (getNumElements(getType()) != attrNumElements)
- return emitOpError()
- << "type and attribute have a different number of elements: "
- << getNumElements(getType()) << " vs. " << attrNumElements;
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
+ int64_t attrNumElements = elementsAttr.getNumElements();
+ if (getNumElements(getType()) != attrNumElements)
+ return emitOpError()
+ << "type and attribute have a different number of elements: "
+ << getNumElements(getType()) << " vs. " << attrNumElements;
+ } else {
+ // When the attribute is an ArrayAttr, check that its nesting matches the
+ // corresponding ArrayType or VectorType nesting.
+ Type dimType = getType();
+ Attribute dimVal = getValue();
+ int dim = 0;
+ while (true) {
+ int64_t dimSize =
+ llvm::TypeSwitch<Type, int64_t>(dimType)
+ .Case<VectorType, LLVMArrayType>([&dimType](auto t) -> int64_t {
+ dimType = t.getElementType();
+ return t.getNumElements();
+ })
+ .Default([](auto) -> int64_t { return -1; });
+ if (dimSize < 0)
+ break;
+ auto arrayAttr = dyn_cast<ArrayAttr>(dimVal);
+ if (!arrayAttr)
+ return emitOpError()
+ << "array attribute nesting must match array type nesting";
+ if (dimSize != static_cast<int64_t>(arrayAttr.size()))
+ return emitOpError()
+ << "array attribute size does not match array type size in "
+ "dimension "
+ << dim << ": " << arrayAttr.size() << " vs. " << dimSize;
+ if (arrayAttr.size() == 0)
+ break;
+ dimVal = arrayAttr.getValue()[0];
+ ++dim;
+ }
+ if (auto structType = dyn_cast<LLVMStructType>(dimType)) {
+ auto arrayAttr = dyn_cast<ArrayAttr>(dimVal);
+ if (!arrayAttr || arrayAttr.size() != structType.getBody().size())
+ return emitOpError()
+ << "nested attribute must be an array attribute with the same "
+ "number of elements as the struct type";
+ }
+ }
} else {
return emitOpError()
<< "only supports integer, float, string or elements attributes";
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 1168b9f339904..1d4509ccb044e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -713,6 +713,33 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
ArrayRef<char>{stringAttr.getValue().data(),
stringAttr.getValue().size()});
}
+
+ // Handle arrays of structs that cannot be represented as DenseElementsAttr
+ // in MLIR.
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
+ llvm::Type *elementType = arrayTy->getElementType();
+ Attribute previousElementAttr;
+ llvm::Constant *elementCst = nullptr;
+ SmallVector<llvm::Constant *> constants;
+ constants.reserve(arrayTy->getNumElements());
+ for (auto elementAttr : arrayAttr) {
+ // Arrays with a single value or with repeating values are quite common.
+ // short-circuit the translation when the element value is the same as
+ // the previous one.
+ if (!previousElementAttr || previousElementAttr != elementAttr) {
+ previousElementAttr = elementAttr;
+ elementCst =
+ getLLVMConstant(elementType, elementAttr, loc, moduleTranslation);
+ if (!elementCst)
+ return nullptr;
+ }
+ constants.push_back(elementCst);
+ }
+ return llvm::ConstantArray::get(arrayTy, constants);
+ }
+ }
+
emitError(loc, "unsupported constant value");
return nullptr;
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f9ea066a63624..4c82e586b8a3c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1850,3 +1850,25 @@ llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) {
llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
llvm.return
}
+
+// -----
+
+llvm.mlir.global @x1() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
+ // expected-error at +1{{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}}
+ %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
+ llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>>
+}
+
+// -----
+llvm.mlir.global @x2() : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> {
+ // expected-error at +1{{'llvm.mlir.constant' op array attribute nesting must match array type nesting}}
+ %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
+ llvm.return %0 : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
+}
+
+// -----
+llvm.mlir.global @x3() : !llvm.array<1x!llvm.struct<(i32, f32)>> {
+ // expected-error at +1{{'llvm.mlir.constant' op nested attribute must be an array attribute with the same number of elements as the struct type}}
+ %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.struct<(i32, f32)>>
+ llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>>
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 90c0f5ac55cb1..24a7b42557278 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -79,11 +79,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
// -----
-// expected-error @below{{unsupported constant value}}
-llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
-
-// -----
-
// expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 4ef68fa83a70d..242a151116fb3 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -3000,3 +3000,20 @@ llvm.func internal @i(%arg0: i32) attributes {dso_local} {
llvm.call @testfn3(%arg0) : (i32 {llvm.alignstack = 8 : i64}) -> ()
llvm.return
}
+
+// -----
+
+// CHECK: @test_array_attr_1 = internal constant [2 x double] [double 2.500000e+00, double 7.400000e+00]
+llvm.mlir.global internal constant @test_array_attr_1([2.5, 7.4]) : !llvm.array<2 x f64>
+
+// CHECK: @test_array_attr_2 = global [2 x { i32, float }] [{ i32, float } { i32 42, float 1.000000e+00 }, { i32, float } { i32 42, float 1.000000e+00 }]
+llvm.mlir.global @test_array_attr_2() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
+ %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
+ llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>>
+}
+
+// CHECK: @test_array_attr_3 = global [2 x [3 x { i32, float }]{{.*}}[3 x { i32, float }] [{ i32, float } { i32 1, float 1.000000e+00 }, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } { i32 3, float 1.000000e+00 }], [3 x { i32, float }] [{ i32, float } { i32 4, float 1.000000e+00 }, { i32, float } { i32 5, float 1.000000e+00 }, { i32, float } { i32 6, float 1.000000e+00 }
+llvm.mlir.global @test_array_attr_3() : !llvm.array<2x!llvm.array<3x!llvm.struct<(i32, f32)>>> {
+ %0 = llvm.mlir.constant([[[1 : i32, 1.000000e+00 : f32],[2 : i32, 1.000000e+00 : f32],[3 : i32, 1.000000e+00 : f32]],[[4 : i32, 1.000000e+00 : f32],[5 : i32, 1.000000e+00 : f32],[6 : i32, 1.000000e+00 : f32]]]) : !llvm.array<2x!llvm.array<3x!llvm.struct<(i32, f32)>>>
+ llvm.return %0 : !llvm.array<2x!llvm.array<3x!llvm.struct<(i32, f32)>>>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/139724
More information about the Mlir-commits
mailing list