[Mlir-commits] [mlir] [mlir][LLVM] Add support for constant struct with multiple fields (PR #102752)
Sirui Mu
llvmlistbot at llvm.org
Tue Aug 20 10:23:02 PDT 2024
https://github.com/Lancern updated https://github.com/llvm/llvm-project/pull/102752
>From 94acee41d84aaf8cc77295fd6d8f9538e2caaaf7 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Sat, 10 Aug 2024 21:40:14 +0800
Subject: [PATCH] [mlir][LLVM] Add support for constant struct with multiple
fields
Currently `llvm.mlir.constant` of structure type requires that the structure
type effectively represent a complex type -- it must have exactly two fields
of the same type, and the field type must be either an integer type or a float
type.
This patch relaxes this restriction: the structure type may now have an
arbitrary number of fields, and the fields may have different integer or
floating-point types.
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 37 +++++++++++-------
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 41 ++++++++++----------
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 24 ++++++------
mlir/test/Dialect/LLVMIR/invalid.mlir | 10 ++---
mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 6 +--
mlir/test/Target/LLVMIR/llvmir.mlir | 6 +++
6 files changed, 70 insertions(+), 54 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 71f249fa538ca9..46bf1c9640c174 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1620,19 +1620,30 @@ def LLVM_ConstantOp
let description = [{
Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
all constants must be created as SSA values before being used in other
- operations. `llvm.mlir.constant` creates such values for scalars and
- vectors. It has a mandatory `value` attribute, which may be an integer,
- floating point attribute; dense or sparse attribute containing integers or
- floats. The type of the attribute is one of the corresponding MLIR builtin
- types. It may be omitted for `i64` and `f64` types that are implied.
-
- The operation produces a new SSA value of the specified LLVM IR dialect
- type. Certain builtin types such as integer, float and vector types are
- also allowed. The result type _must_ correspond to the attribute type
- converted to LLVM IR. In particular, the number of elements of a container
- type must match the number of elements in the attribute. If the type is or
- contains a scalable vector type, the attribute must be a splat elements
- attribute.
+ operations. `llvm.mlir.constant` creates such values for scalars, vectors,
+ strings, and structs. It has a mandatory `value` attribute whose type
+ depends on the type of the constant value. The type of the constant value
+ must correspond to the attribute type converted to an LLVM IR type.
+
+ When creating constant scalars, the `value` attribute must be either an
+ integer attribute or a floating point attribute. The type of the attribute
+ may be omitted for `i64` and `f64` types that are implied.
+
+ When creating constant vectors, the `value` attribute must be either an
+ array attribute, a dense attribute, or a sparse attribute that contains
+ integers or floats. The number of elements in the result vector must match
+ the number of elements in the attribute.
+
+ When creating constant strings, the `value` attribute must be a string
+ attribute. The type of the constant must be an LLVM array of `i8`s, and the
+ length of the array must match the length of the attribute.
+
+ When creating constant structs, the `value` attribute must be an array
+ attribute whose elements are integer or float attributes. The type of the constant must be
+ an LLVM struct type. The number of fields in the struct must match the
+ number of elements in the attribute, and the type of each LLVM struct field
+ must correspond to the type of the corresponding attribute element converted
+ to LLVM IR.
Examples:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 92f3984e5e6db6..4e1cb0b409b9b2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2710,32 +2710,31 @@ LogicalResult LLVM::ConstantOp::verify() {
}
return success();
}
- if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
- if (structType.getBody().size() != 2 ||
- structType.getBody()[0] != structType.getBody()[1]) {
- return emitError() << "expected struct type with two elements of the "
- "same type, the type of a complex constant";
+ if (auto structType = dyn_cast<LLVMStructType>(getType())) {
+ auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
+ if (!arrayAttr) {
+ return emitOpError() << "expected array attribute for a struct constant";
}
- auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
- if (!arrayAttr || arrayAttr.size() != 2) {
- return emitOpError() << "expected array attribute with two elements, "
- "representing a complex constant";
- }
- auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
- auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
- if (!re || !im || re.getType() != im.getType()) {
- return emitOpError()
- << "expected array attribute with two elements of the same type";
+ ArrayRef<Type> elementTypes = structType.getBody();
+ if (arrayAttr.size() != elementTypes.size()) {
+ return emitOpError() << "expected array attribute of size "
+ << elementTypes.size();
}
- Type elementType = structType.getBody()[0];
- if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
- elementType)) {
- return emitError()
- << "expected struct element types to be floating point type or "
- "integer type";
+ for (size_t i = 0; i < elementTypes.size(); ++i) {
+ auto element = arrayAttr[i];
+ if (!isa<IntegerAttr, FloatAttr>(element)) {
+ return emitOpError() << "expected struct element types to be floating "
+ "point type or integer type";
+ }
+ auto elementType = cast<TypedAttr>(element).getType();
+ if (elementType != elementTypes[i]) {
+ return emitOpError()
+ << "struct element at index " << i << " is of wrong type";
+ }
}
+
return success();
}
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 930300d26c4479..1900507dcfcf29 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -557,20 +557,20 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
- if (!arrayAttr || arrayAttr.size() != 2) {
- emitError(loc, "expected struct type to be a complex number");
+ if (!arrayAttr) {
+ emitError(loc, "expected an array attribute for a struct constant");
return nullptr;
}
- llvm::Type *elementType = structType->getElementType(0);
- llvm::Constant *real =
- getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
- if (!real)
- return nullptr;
- llvm::Constant *imag =
- getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
- if (!imag)
- return nullptr;
- return llvm::ConstantStruct::get(structType, {real, imag});
+ llvm::SmallVector<llvm::Constant *, 8> structElements;
+ structElements.reserve(structType->getNumElements());
+ for (size_t i = 0; i < arrayAttr.size(); ++i) {
+ llvm::Constant *element = getLLVMConstant(
+ structType->getElementType(i), arrayAttr[i], loc, moduleTranslation);
+ if (!element)
+ return nullptr;
+ structElements.push_back(element);
+ }
+ return llvm::ConstantStruct::get(structType, structElements);
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 62346ce0d2c4b1..6670e4b186c397 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -367,7 +367,7 @@ func.func @constant_wrong_type_string() {
// -----
llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
- // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+ // expected-error @+1 {{expected array attribute of size 2}}
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
@@ -375,7 +375,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
// -----
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
- // expected-error @+1 {{expected array attribute with two elements of the same type}}
+ // expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
@@ -383,7 +383,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
// -----
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
- // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+ // expected-error @+1 {{expected array attribute}}
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
@@ -391,7 +391,7 @@ llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
// -----
llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
- // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+ // expected-error @+1 {{expected array attribute of size 1}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
llvm.return %0 : !llvm.struct<(f64)>
}
@@ -399,7 +399,7 @@ llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
// -----
llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
- // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+ // expected-error @+1 {{struct element at index 1 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
llvm.return %0 : !llvm.struct<(f64, f32)>
}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 9cf922ad490a92..e04c54e8451681 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -16,7 +16,7 @@ llvm.func @vector_with_non_vector_type() -> f32 {
// -----
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
- // expected-error @below{{expected struct type to be a complex number}}
+ // expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}
@@ -24,7 +24,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
// -----
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
- // expected-error @below{{expected struct type to be a complex number}}
+ // expected-error @below{{expected an array attribute for a struct constant}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}
@@ -32,7 +32,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
// -----
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
- // expected-error @below{{FloatAttr does not match expected type of the constant}}
+ // expected-error @below{{struct element at index 0 is of wrong type}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 8453983aa07c33..df61fef605fde0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1312,6 +1312,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
}
+llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
+ %1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
+ // CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
+ llvm.return %1 : !llvm.struct<(i32, f32)>
+}
+
// CHECK-LABEL: @indexconstantsplat
llvm.func @indexconstantsplat() -> vector<3xi32> {
%1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>
More information about the Mlir-commits
mailing list