[Mlir-commits] [mlir] [mlir][LLVM] Add support for constant struct with multiple fields (PR #102752)

Sirui Mu llvmlistbot at llvm.org
Sat Aug 10 06:47:33 PDT 2024


https://github.com/Lancern created https://github.com/llvm/llvm-project/pull/102752

Currently, the verifier for `llvm.mlir.constant` requires that a structure-typed constant effectively represent a complex number: the structure must have exactly two fields of the same type, and that type must be either an integer type or a floating-point type.

This PR relaxes that restriction, allowing the structure type to have an arbitrary number of fields.

>From f5ff4a958956c01859fd5bb9f36f488b41cee6ba Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Sat, 10 Aug 2024 21:40:14 +0800
Subject: [PATCH] [mlir][LLVM] Add support for constant struct with multiple
 fields

Currently, the verifier for `llvm.mlir.constant` requires that a
structure-typed constant effectively represent a complex number: the structure
must have exactly two fields of the same type, and that type must be either an
integer type or a floating-point type.

This patch relaxes that restriction, allowing the structure type to have an
arbitrary number of fields.
---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp   | 39 ++++++++++----------
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 24 ++++++------
 mlir/test/Dialect/LLVMIR/invalid.mlir        | 10 ++---
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir  |  6 +--
 mlir/test/Target/LLVMIR/llvmir.mlir          |  6 +++
 5 files changed, 45 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 90610118a45cd2..9ee74cb562d81b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2678,31 +2678,30 @@ LogicalResult LLVM::ConstantOp::verify() {
     return success();
   }
   if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
-    if (structType.getBody().size() != 2 ||
-        structType.getBody()[0] != structType.getBody()[1]) {
-      return emitError() << "expected struct type with two elements of the "
-                            "same type, the type of a complex constant";
-    }
-
     auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
-    if (!arrayAttr || arrayAttr.size() != 2) {
-      return emitOpError() << "expected array attribute with two elements, "
-                              "representing a complex constant";
+    if (!arrayAttr) {
+      return emitOpError() << "expected array attribute for a struct constant";
     }
-    auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
-    auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
-    if (!re || !im || re.getType() != im.getType()) {
-      return emitOpError()
-             << "expected array attribute with two elements of the same type";
+
+    ArrayRef<Type> elementTypes = structType.getBody();
+    if (arrayAttr.size() != elementTypes.size()) {
+      return emitOpError() << "expected array attribute of size "
+                           << elementTypes.size();
     }
 
-    Type elementType = structType.getBody()[0];
-    if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
-            elementType)) {
-      return emitError()
-             << "expected struct element types to be floating point type or "
-                "integer type";
+    for (size_t i = 0; i < elementTypes.size(); ++i) {
+      auto element = arrayAttr[i];
+      if (!mlir::isa<IntegerAttr, FloatAttr>(element)) {
+        return emitOpError() << "expected struct element types to be floating "
+                                "point type or integer type";
+      }
+      auto elementType = mlir::cast<TypedAttr>(element).getType();
+      if (elementType != elementTypes[i]) {
+        return emitOpError()
+               << "struct element at index " << i << " is of wrong type";
+      }
     }
+
     return success();
   }
   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b468228ea78b78..f2f992a2180c7e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -557,20 +557,20 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     return llvm::UndefValue::get(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
     auto arrayAttr = dyn_cast<ArrayAttr>(attr);
-    if (!arrayAttr || arrayAttr.size() != 2) {
-      emitError(loc, "expected struct type to be a complex number");
+    if (!arrayAttr) {
+      emitError(loc, "expected an array attribute for a struct constant");
       return nullptr;
     }
-    llvm::Type *elementType = structType->getElementType(0);
-    llvm::Constant *real =
-        getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
-    if (!real)
-      return nullptr;
-    llvm::Constant *imag =
-        getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
-    if (!imag)
-      return nullptr;
-    return llvm::ConstantStruct::get(structType, {real, imag});
+    llvm::SmallVector<llvm::Constant *, 8> structElements;
+    structElements.reserve(structType->getNumElements());
+    for (size_t i = 0; i < arrayAttr.size(); ++i) {
+      llvm::Constant *element = getLLVMConstant(
+          structType->getElementType(i), arrayAttr[i], loc, moduleTranslation);
+      if (!element)
+        return nullptr;
+      structElements.push_back(element);
+    }
+    return llvm::ConstantStruct::get(structType, structElements);
   }
   // For integer types, we allow a mismatch in sizes as the index type in
   // MLIR might have a different size than the index type in the LLVM module.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe288dab973f5a..04f8aeb0afcd8b 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -367,7 +367,7 @@ func.func @constant_wrong_type_string() {
 // -----
 
 llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+  // expected-error @+1 {{expected array attribute of size 2}}
   %0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -375,7 +375,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements of the same type}}
+  // expected-error @+1 {{struct element at index 1 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -383,7 +383,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+  // expected-error @+1 {{expected array attribute}}
   %0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -391,7 +391,7 @@ llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
-  // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+  // expected-error @+1 {{expected array attribute of size 1}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
   llvm.return %0 : !llvm.struct<(f64)>
 }
@@ -399,7 +399,7 @@ llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
 // -----
 
 llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
-  // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+  // expected-error @+1 {{struct element at index 1 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
   llvm.return %0 : !llvm.struct<(f64, f32)>
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 9cf922ad490a92..e04c54e8451681 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -16,7 +16,7 @@ llvm.func @vector_with_non_vector_type() -> f32 {
 // -----
 
 llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
+  // expected-error @below{{expected an array attribute for a struct constant}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
 }
@@ -24,7 +24,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
 // -----
 
 llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
+  // expected-error @below{{expected an array attribute for a struct constant}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
 }
@@ -32,7 +32,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
 // -----
 
 llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @below{{FloatAttr does not match expected type of the constant}}
+  // expected-error @below{{struct element at index 0 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fbdf725f3ec17b..b85f3e82da3d25 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1306,6 +1306,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
   llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
 }
 
+llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
+  %1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
+  // CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
+  llvm.return %1 : !llvm.struct<(i32, f32)>
+}
+
 // CHECK-LABEL: @indexconstantsplat
 llvm.func @indexconstantsplat() -> vector<3xi32> {
   %1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>



More information about the Mlir-commits mailing list