[Mlir-commits] [mlir] [mlir][LLVM] Add support for constant struct with multiple fields (PR #102752)

Sirui Mu llvmlistbot at llvm.org
Wed Aug 21 09:05:50 PDT 2024


https://github.com/Lancern updated https://github.com/llvm/llvm-project/pull/102752

>From 922f8d3edbb53c112d99d2daec5466db4bbbf131 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Sat, 10 Aug 2024 21:40:14 +0800
Subject: [PATCH] [mlir][LLVM] Add support for constant struct with multiple
 fields

Currently, `llvm.mlir.constant` of a structure type requires that the structure
type effectively represent a complex type -- it must have exactly two fields
of the same type, and the field type must be either an integer type or a float
type.

This patch relaxes this restriction and allows the structure type to have an
arbitrary number of fields, possibly of different types, as long as each field
is of an integer type or a float type.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td  | 37 +++++++++------
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp   | 47 +++++++++++---------
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 25 ++++++-----
 mlir/test/Dialect/LLVMIR/invalid.mlir        | 10 ++---
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir  | 26 ++++++++---
 mlir/test/Target/LLVMIR/llvmir.mlir          |  6 +++
 6 files changed, 96 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 71f249fa538ca9..46bf1c9640c174 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1620,19 +1620,30 @@ def LLVM_ConstantOp
   let description = [{
     Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
     all constants must be created as SSA values before being used in other
-    operations. `llvm.mlir.constant` creates such values for scalars and
-    vectors. It has a mandatory `value` attribute, which may be an integer,
-    floating point attribute; dense or sparse attribute containing integers or
-    floats. The type of the attribute is one of the corresponding MLIR builtin
-    types. It may be omitted for `i64` and `f64` types that are implied.
-
-    The operation produces a new SSA value of the specified LLVM IR dialect
-    type. Certain builtin types such as integer, float and vector types are
-    also allowed. The result type _must_ correspond to the attribute type
-    converted to LLVM IR. In particular, the number of elements of a container
-    type must match the number of elements in the attribute. If the type is or
-    contains a scalable vector type, the attribute must be a splat elements
-    attribute.
+    operations. `llvm.mlir.constant` creates such values for scalars, vectors,
+    strings, and structs. It has a mandatory `value` attribute whose type
+    depends on the type of the constant value. The type of the constant value
+    must correspond to the attribute type converted to LLVM IR type.
+
+    When creating constant scalars, the `value` attribute must be either an
+    integer attribute or a floating point attribute. The type of the attribute
+    may be omitted for `i64` and `f64` types that are implied.
+
+    When creating constant vectors, the `value` attribute must be either an
+    array attribute, a dense attribute, or a sparse attribute that contains
+    integers or floats. The number of elements in the result vector must match
+    the number of elements in the attribute.
+
+    When creating constant strings, the `value` attribute must be a string
+    attribute. The type of the constant must be an LLVM array of `i8`s, and the
+    length of the array must match the length of the attribute.
+
+    When creating constant structs, the `value` attribute must be an array
+    attribute that contains integers or floats. The type of the constant must be
+    an LLVM struct type. The number of fields in the struct must match the
+    number of elements in the attribute, and the type of each LLVM struct field
+    must correspond to the type of the corresponding attribute element converted
+    to LLVM IR.
 
     Examples:
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 92f3984e5e6db6..3870aab52f199d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -2710,32 +2711,38 @@ LogicalResult LLVM::ConstantOp::verify() {
     }
     return success();
   }
-  if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
-    if (structType.getBody().size() != 2 ||
-        structType.getBody()[0] != structType.getBody()[1]) {
-      return emitError() << "expected struct type with two elements of the "
-                            "same type, the type of a complex constant";
+  if (auto structType = dyn_cast<LLVMStructType>(getType())) {
+    auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
+    if (!arrayAttr) {
+      return emitOpError() << "expected array attribute for a struct constant";
     }
 
-    auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
-    if (!arrayAttr || arrayAttr.size() != 2) {
-      return emitOpError() << "expected array attribute with two elements, "
-                              "representing a complex constant";
+    ArrayRef<Type> elementTypes = structType.getBody();
+    if (arrayAttr.size() != elementTypes.size()) {
+      return emitOpError() << "expected array attribute of size "
+                           << elementTypes.size();
     }
-    auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
-    auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
-    if (!re || !im || re.getType() != im.getType()) {
-      return emitOpError()
-             << "expected array attribute with two elements of the same type";
+    for (auto elementTy : elementTypes) {
+      if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
+        return emitOpError() << "expected struct element types to be floating "
+                                "point type or integer type";
+      }
     }
 
-    Type elementType = structType.getBody()[0];
-    if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
-            elementType)) {
-      return emitError()
-             << "expected struct element types to be floating point type or "
-                "integer type";
+    for (size_t i = 0; i < elementTypes.size(); ++i) {
+      Attribute element = arrayAttr[i];
+      if (!isa<IntegerAttr, FloatAttr>(element)) {
+        return emitOpError()
+               << "expected struct element attribute types to be floating "
+                  "point type or integer type";
+      }
+      auto elementType = cast<TypedAttr>(element).getType();
+      if (elementType != elementTypes[i]) {
+        return emitOpError()
+               << "struct element at index " << i << " is of wrong type";
+      }
     }
+
     return success();
   }
   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 930300d26c4479..adf70e6aab5d14 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -557,20 +557,21 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     return llvm::UndefValue::get(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
     auto arrayAttr = dyn_cast<ArrayAttr>(attr);
-    if (!arrayAttr || arrayAttr.size() != 2) {
-      emitError(loc, "expected struct type to be a complex number");
+    if (!arrayAttr) {
+      emitError(loc, "expected an array attribute for a struct constant");
       return nullptr;
     }
-    llvm::Type *elementType = structType->getElementType(0);
-    llvm::Constant *real =
-        getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
-    if (!real)
-      return nullptr;
-    llvm::Constant *imag =
-        getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
-    if (!imag)
-      return nullptr;
-    return llvm::ConstantStruct::get(structType, {real, imag});
+    SmallVector<llvm::Constant *> structElements;
+    structElements.reserve(structType->getNumElements());
+    for (auto [elemType, elemAttr] :
+         zip_equal(structType->elements(), arrayAttr)) {
+      llvm::Constant *element =
+          getLLVMConstant(elemType, elemAttr, loc, moduleTranslation);
+      if (!element)
+        return nullptr;
+      structElements.push_back(element);
+    }
+    return llvm::ConstantStruct::get(structType, structElements);
   }
   // For integer types, we allow a mismatch in sizes as the index type in
   // MLIR might have a different size than the index type in the LLVM module.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 62346ce0d2c4b1..6670e4b186c397 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -367,7 +367,7 @@ func.func @constant_wrong_type_string() {
 // -----
 
 llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+  // expected-error @+1 {{expected array attribute of size 2}}
   %0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -375,7 +375,7 @@ llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements of the same type}}
+  // expected-error @+1 {{struct element at index 1 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -383,7 +383,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
+  // expected-error @+1 {{expected array attribute}}
   %0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
@@ -391,7 +391,7 @@ llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
 // -----
 
 llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
-  // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+  // expected-error @+1 {{expected array attribute of size 1}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
   llvm.return %0 : !llvm.struct<(f64)>
 }
@@ -399,7 +399,7 @@ llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
 // -----
 
 llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
-  // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
+  // expected-error @+1 {{struct element at index 1 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
   llvm.return %0 : !llvm.struct<(f64, f32)>
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 9cf922ad490a92..0e2afe6fb004d8 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -15,24 +15,40 @@ llvm.func @vector_with_non_vector_type() -> f32 {
 
 // -----
 
-llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
+llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+  // expected-error @below{{expected an array attribute for a struct constant}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
 }
 
 // -----
 
-llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
+llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+  // expected-error @below{{expected an array attribute for a struct constant}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
 }
 
 // -----
 
+llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
+  // expected-error @below{{expected struct element types to be floating point type or integer type}}
+  %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
+  llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
+}
+
+// -----
+
+llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
+  // expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
+  %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
+  llvm.return %0 : !llvm.struct<(f64, f64)>
+}
+
+// -----
+
 llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
-  // expected-error @below{{FloatAttr does not match expected type of the constant}}
+  // expected-error @below{{struct element at index 0 is of wrong type}}
   %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 8453983aa07c33..df61fef605fde0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1312,6 +1312,12 @@ llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.st
   llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
 }
 
+llvm.func @structconstant() -> !llvm.struct<(i32, f32)> {
+  %1 = llvm.mlir.constant([1 : i32, 2.000000e+00 : f32]) : !llvm.struct<(i32, f32)>
+  // CHECK: ret { i32, float } { i32 1, float 2.000000e+00 }
+  llvm.return %1 : !llvm.struct<(i32, f32)>
+}
+
 // CHECK-LABEL: @indexconstantsplat
 llvm.func @indexconstantsplat() -> vector<3xi32> {
   %1 = llvm.mlir.constant(dense<42> : vector<3xindex>) : vector<3xi32>



More information about the Mlir-commits mailing list