[Mlir-commits] [mlir] [mlir][LLVM] Add support for constant struct with multiple fields (PR #102752)
Sirui Mu
llvmlistbot at llvm.org
Sat Aug 10 06:47:33 PDT 2024
https://github.com/Lancern created https://github.com/llvm/llvm-project/pull/102752
Currently, `llvm.mlir.constant` of a structure type requires the structure type to effectively represent a complex type -- it must have exactly two fields of the same type, and the field type must be either an integer type or a float type.
This PR relaxes this restriction and allows the structure type to have an arbitrary number of fields.
>From f5ff4a958956c01859fd5bb9f36f488b41cee6ba Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Sat, 10 Aug 2024 21:40:14 +0800
Subject: [PATCH] [mlir][LLVM] Add support for constant struct with multiple
fields
Currently, `llvm.mlir.constant` of a structure type requires the structure type
to effectively represent a complex type -- it must have exactly two fields of
the same type, and the field type must be either an integer type or a float
type.
This patch relaxes this restriction and allows the structure type to have an
arbitrary number of fields.
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 39 ++++++++++----------
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 24 ++++++------
mlir/test/Dialect/LLVMIR/invalid.mlir | 10 ++---
mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 6 +--
mlir/test/Target/LLVMIR/llvmir.mlir | 6 +++
5 files changed, 45 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 90610118a45cd2..9ee74cb562d81b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2678,31 +2678,30 @@ LogicalResult LLVM::ConstantOp::verify() {
return success();
}
if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
- if (structType.getBody().size() != 2 ||
- structType.getBody()[0] != structType.getBody()[1]) {
- return emitError() << "expected struct type with two elements of the "
- "same type, the type of a complex constant";
- }
-
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
- if (!arrayAttr || arrayAttr.size() != 2) {
- return emitOpError() << "expected array attribute with two elements, "
- "representing a complex constant";
+ if (!arrayAttr) {
+ return emitOpError() << "expected array attribute for a struct constant";
}
- auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
- auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
- if (!re || !im || re.getType() != im.getType()) {
- return emitOpError()
- << "expected array attribute with two elements of the same type";
+
+ ArrayRef<Type> elementTypes = structType.getBody();
+ if (arrayAttr.size() != elementTypes.size()) {
+ return emitOpError() << "expected array attribute of size "
+ << elementTypes.size();
}
- Type elementType = structType.getBody()[0];
- if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
- elementType)) {
- return emitError()
- << "expected struct element types to be floating point type or "
- "integer type";
+ for (size_t i = 0; i < elementTypes.size(); ++i) {
+ auto element = arrayAttr[i];
+ if (!mlir::isa<IntegerAttr, FloatAttr>(element)) {
+ return emitOpError() << "expected struct element types to be floating "
+ "point type or integer type";
+ }
+ auto elementType = mlir::cast<TypedAttr>(element).getType();
+ if (elementType != elementTypes[i]) {
+ return emitOpError()
+ << "struct element at index " << i << " is of wrong type";
+ }
}
+
return success();
}
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b468228ea78b78..f2f992a2180c7e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -557,20 +557,20 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
- if (!arrayAttr || arrayAttr.size() != 2) {
- emitError(loc, "expected struct type to be a complex number");
+ if (!arrayAttr) {
+ emitError(loc, "expected an array attribute for a struct constant");
return nullptr;
}
- llvm::Type *elementType = structType->getElementType(0);
- llvm::Constant *real =
- getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
- if (!real)
- return nullptr;
- llvm::Constant *imag =
- getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
- if (!imag)
- return nullptr;
- return llvm::ConstantStruct::get(structType, {real, imag});
+ llvm::SmallVector<llvm::Constant *, 8> structElements;
+ structElements.reserve(structType->getNumElements());
+ for (size_t i = 0; i < arrayAttr.size(); ++i) {
+ llvm::Constant *element = getLLVMConstant(
+ structType->getElementType(i), arrayAttr[i], loc, moduleTranslation);
+ if (!element)
+ return nullptr;
+ structElements.push_back(element);
+ }
+ return llvm::ConstantStruct::get(structType, structElements);
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe288dab973f5a..04f8aeb0afcd8b 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -367,7 +367,7 @@ func.func @constant_wrong_type_string() {
// -----
llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
- // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+ // expected-error @+1 {{expected array attribute of size 2}}
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
@@ -375,7 +375,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
// -----
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
- // expected-error @+1 {{expected array attribute with two elements of the same type}}
+ // expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
@@ -383,7 +383,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
// -----
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
- // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+ // expected-error @+1 {{expected array attribute}}
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
@@ -391,7 +391,7 @@ llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
// -----
llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
- // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+ // expected-error @+1 {{expected array attribute of size 1}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
llvm.return %0 : !llvm.struct<(f64)>
}
@@ -399,7 +399,7 @@ llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
// -----
llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
- // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+ // expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
llvm.return %0 : !llvm.struct<(f64, f32)>
}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 9cf922ad490a92..e04c54e8451681 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -16,7 +16,7 @@ llvm.func @vector_with_non_vector_type() -> f32 {
// -----
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
- // expected-error @below{{expected struct type to be a complex number}}
+ // expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}
@@ -24,7 +24,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
// -----
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
- // expected-error @below{{expected struct type to be a complex number}}
+ // expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}
@@ -32,7 +32,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
// -----
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
- // expected-error @below{{FloatAttr does not match expected type of the constant}}
+ // expected-error @below{{struct element at index 0 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fbdf725f3ec17b..b85f3e82da3d25 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1306,6 +1306,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
}
+llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
+ %1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
+ // CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
+ llvm.return %1 : !llvm.struct<(i32, f32)>
+}
+
// CHECK-LABEL: @indexconstantsplat
llvm.func @indexconstantsplat() -> vector<3xi32> {
%1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>
More information about the Mlir-commits
mailing list