[Mlir-commits] [mlir] [mlir][LLVM] Add support for constant struct with multiple fields (PR #102752)

Sirui Mu llvmlistbot at llvm.org
Tue Aug 20 10:23:02 PDT 2024


https://github.com/Lancern updated https://github.com/llvm/llvm-project/pull/102752

>From 94acee41d84aaf8cc77295fd6d8f9538e2caaaf7 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Sat, 10 Aug 2024 21:40:14 +0800
Subject: [PATCH] [mlir][LLVM] Add support for constant struct with multiple
 fields

Currently `llvm.mlir.constant` of structure types requires the structure type
to effectively represent a complex type -- it must have exactly two fields of
the same type, and the field type must be either an integer type or a float
type.

This patch relaxes this restriction and allows the structure type to have an
arbitrary number of fields, as long as each field is an integer or float type.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td  | 37 +++++++++++-------
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp   | 41 ++++++++++----------
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 24 ++++++------
 mlir/test/Dialect/LLVMIR/invalid.mlir        | 10 ++---
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir  |  6 +--
 mlir/test/Target/LLVMIR/llvmir.mlir          |  6 +++
 6 files changed, 70 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 71f249fa538ca9..46bf1c9640c174 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1620,19 +1620,30 @@ def LLVM_ConstantOp
   let description = [{
     Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
     all constants must be created as SSA values before being used in other
-    operations. `llvm.mlir.constant` creates such values for scalars and
-    vectors. It has a mandatory `value` attribute, which may be an integer,
-    floating point attribute; dense or sparse attribute containing integers or
-    floats. The type of the attribute is one of the corresponding MLIR builtin
-    types. It may be omitted for `i64` and `f64` types that are implied.
-
-    The operation produces a new SSA value of the specified LLVM IR dialect
-    type. Certain builtin types such as integer, float and vector types are
-    also allowed. The result type _must_ correspond to the attribute type
-    converted to LLVM IR. In particular, the number of elements of a container
-    type must match the number of elements in the attribute. If the type is or
-    contains a scalable vector type, the attribute must be a splat elements
-    attribute.
+    operations. `llvm.mlir.constant` creates such values for scalars, vectors,
+    strings, and structs. It has a mandatory `value` attribute whose type
+    depends on the type of the constant value. The type of the constant value
+    must correspond to the attribute type converted to LLVM IR type.
+
+    When creating constant scalars, the `value` attribute must be either an
+    integer attribute or a floating point attribute. The type of the attribute
+    may be omitted for `i64` and `f64` types that are implied.
+
+    When creating constant vectors, the `value` attribute must be an array,
+    dense, or sparse attribute containing integers or floats, with as many
+    elements as the result vector. If the type is or contains a scalable
+    vector type, the attribute must be a splat elements attribute.
+
+    When creating constant strings, the `value` attribute must be a string
+    attribute. The type of the constant must be an LLVM array of `i8`s, and the
+    length of the array must match the length of the attribute.
+
+    When creating constant structs, the `value` attribute must be an array
+    attribute that contains integers or floats. The type of the constant must be
+    an LLVM struct type. The number of fields in the struct must match the
+    number of elements in the attribute, and the type of each LLVM struct field
+    must correspond to the type of the corresponding attribute element converted
+    to LLVM IR.
 
     Examples:
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 92f3984e5e6db6..4e1cb0b409b9b2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2710,32 +2710,31 @@ LogicalResult LLVM::ConstantOp::verify() {
     }
     return success();
   }
-  if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
-    if (structType.getBody().size() != 2 ||
-        structType.getBody()[0] != structType.getBody()[1]) {
-      return emitError() << "expected struct type with two elements of the "
-                            "same type, the type of a complex constant";
+  if (auto structType = dyn_cast<LLVMStructType>(getType())) {
+    auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
+    if (!arrayAttr) {
+      return emitOpError() << "expected array attribute for a struct constant";
     }
 
-    auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
-    if (!arrayAttr || arrayAttr.size() != 2) {
-      return emitOpError() << "expected array attribute with two elements, "
-                              "representing a complex constant";
-    }
-    auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
-    auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
-    if (!re || !im || re.getType() != im.getType()) {
-      return emitOpError()
-             << "expected array attribute with two elements of the same type";
+    ArrayRef<Type> elementTypes = structType.getBody();
+    if (arrayAttr.size() != elementTypes.size()) {
+      return emitOpError() << "expected array attribute of size "
+                           << elementTypes.size();
     }
 
-    Type elementType = structType.getBody()[0];
-    if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
-            elementType)) {
-      return emitError()
-             << "expected struct element types to be floating point type or "
-                "integer type";
+    for (size_t i = 0; i < elementTypes.size(); ++i) {
+      auto element = arrayAttr[i];
+      if (!isa<IntegerAttr, FloatAttr>(element)) {
+        return emitOpError() << "expected struct element types to be floating "
+                                "point type or integer type";
+      }
+      auto elementType = cast<TypedAttr>(element).getType();
+      if (elementType != elementTypes[i]) {
+        return emitOpError()
+               << "struct element at index " << i << " is of wrong type";
+      }
     }
+
     return success();
   }
   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 930300d26c4479..1900507dcfcf29 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -557,20 +557,20 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     return llvm::UndefValue::get(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
     auto arrayAttr = dyn_cast<ArrayAttr>(attr);
-    if (!arrayAttr || arrayAttr.size() != 2) {
-      emitError(loc, "expected struct type to be a complex number");
+    if (!arrayAttr) {
+      emitError(loc, "expected an array attribute for a struct constant");
       return nullptr;
     }
-    llvm::Type *elementType = structType->getElementType(0);
-    llvm::Constant *real =
-        getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
-    if (!real)
-      return nullptr;
-    llvm::Constant *imag =
-        getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
-    if (!imag)
-      return nullptr;
-    return llvm::ConstantStruct::get(structType, {real, imag});
+    llvm::SmallVector<llvm::Constant *, 8> structElements;
+    structElements.reserve(structType->getNumElements());
+    for (size_t i = 0; i < arrayAttr.size(); ++i) {
+      llvm::Constant *element = getLLVMConstant(
+          structType->getElementType(i), arrayAttr[i], loc, moduleTranslation);
+      if (!element)
+        return nullptr;
+      structElements.push_back(element);
+    }
+    return llvm::ConstantStruct::get(structType, structElements);
   }
   // For integer types, we allow a mismatch in sizes as the index type in
   // MLIR might have a different size than the index type in the LLVM module.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 62346ce0d2c4b1..6670e4b186c397 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -367,7 +367,7 @@ func.func @constant_wrong_type_string() {
 // -----
 
 llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+  // expected-error @+1 {{expected array attribute of size 2}}
   %0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -375,7 +375,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements of the same type}}
+  // expected-error @+1 {{struct element at index 1 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -383,7 +383,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+  // expected-error @+1 {{expected array attribute}}
   %0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -391,7 +391,7 @@ llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
-  // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+  // expected-error @+1 {{expected array attribute of size 1}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
   llvm.return %0 : !llvm.struct<(f64)>
 }
@@ -399,7 +399,7 @@ llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
 // -----
 
 llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
-  // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+  // expected-error @+1 {{struct element at index 1 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
   llvm.return %0 : !llvm.struct<(f64, f32)>
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 9cf922ad490a92..e04c54e8451681 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -16,7 +16,7 @@ llvm.func @vector_with_non_vector_type() -> f32 {
 // -----
 
 llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
+  // expected-error @below{{expected an array attribute for a struct constant}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
 }
@@ -24,7 +24,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
 // -----
 
 llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
+  // expected-error @below{{expected an array attribute for a struct constant}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
 }
@@ -32,7 +32,7 @@ llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct
 // -----
 
 llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @below{{FloatAttr does not match expected type of the constant}}
+  // expected-error @below{{struct element at index 0 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 8453983aa07c33..df61fef605fde0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1312,6 +1312,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
   llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
 }
 
+llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
+  %1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
+  // CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
+  llvm.return %1 : !llvm.struct<(i32, f32)>
+}
+
 // CHECK-LABEL: @indexconstantsplat
 llvm.func @indexconstantsplat() -> vector<3xi32> {
   %1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>



More information about the Mlir-commits mailing list