[Mlir-commits] [mlir] MLIR->LLVMIR: Support creating constant arrays (and structs) from arbitrary ArrayAttr for constants (PR #94143)

Ryan Thomas Lynch llvmlistbot at llvm.org
Sat Jun 1 23:06:19 PDT 2024


https://github.com/emosy created https://github.com/llvm/llvm-project/pull/94143

Currently not working _exactly_ as expected, but the intent is that MLIR such as:
`llvm.mlir.global external constant @array([1,2,3]) : !llvm.array<3 x i64>`
can be converted to LLVM IR:
`@array = constant [3 x i64] [i64 1, i64 2, i64 3]`

While I was at it, I also tried to support structs, which might be what is causing the problems...

I also added a `getNumElements` method to `LLVMStructType` in the LLVM dialect, matching the method available on a real LLVM struct type and on array types (both in LLVM and in MLIR).

>From 76ab905a0c308f569e71ab622ed3baff94722650 Mon Sep 17 00:00:00 2001
From: Ryan Thomas Lynch <rlynch34 at gatech.edu>
Date: Sun, 2 Jun 2024 05:30:44 +0000
Subject: [PATCH] BROKEN+WIP first try: LLVM dialect to LLVMIR constant arrays
 + structs conversion

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h |  3 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp   | 30 ++++++
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp     |  4 +
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 97 ++++++++++++++++----
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir  | 21 -----
 mlir/test/Target/LLVMIR/llvmir.mlir          | 24 +++++
 6 files changed, 139 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 93733ccd4929a..89d7403e57bc0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,6 +180,9 @@ class LLVMStructType
   /// Returns the list of element types contained in a non-opaque struct.
   ArrayRef<Type> getBody() const;
 
+  /// Returns getBody().size(), the number of members of a non-opaque struct.
+  size_t getNumElements() const;
+
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                               StringRef, bool);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 60b911948d4a0..ba61665a3cf1e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2076,6 +2076,19 @@ LogicalResult GlobalOp::verify() {
       return emitOpError(
           "requires an i8 array type of the length equal to that of the string "
           "attribute");
+  } else if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
+             arrayType && getValueOrNull()) {
+    // An ArrayAttr initializer must provide one attribute per array element.
+    // StringAttr initializers are verified above; other kinds (e.g. elements
+    // attributes) are not checked here.
+    auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValueOrNull());
+    if (arrayAttr && arrayAttr.size() != arrayType.getNumElements())
+      return emitOpError()
+             << "array type requires an array attribute of the same size, but "
+                "was provided with array attribute of size "
+             << arrayAttr.size() << " for an array type of size "
+             << arrayType.getNumElements();
+    // TODO: also verify that each element attribute matches the element type.
   }
 
   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
@@ -2088,6 +2101,23 @@ LogicalResult GlobalOp::verify() {
                               "initialized with zero-initializer";
   }
 
+  if (auto structType = llvm::dyn_cast<LLVMStructType>(getType());
+      structType && getValueOrNull()) {
+    // Struct constants can currently only be built from an ArrayAttr that
+    // provides exactly one attribute per struct member.
+    auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValueOrNull());
+    if (!arrayAttr)
+      return emitOpError()
+             << "struct type requires an array attribute initializer";
+    if (arrayAttr.size() != structType.getNumElements())
+      return emitOpError()
+             << "struct type requires an array attribute of the same size, but "
+                "was provided with array attribute of size "
+             << arrayAttr.size() << " for a struct type of size "
+             << structType.getNumElements();
+    // TODO: also verify that each element attribute matches its member type.
+  }
+
   if (getLinkage() == Linkage::Common) {
     if (Attribute value = getValueOrNull()) {
       if (!isZeroAttribute(value)) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index cf3f38b710130..2ba852c81f5f9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -492,6 +492,10 @@ ArrayRef<Type> LLVMStructType::getBody() const {
                          : getImpl()->getTypeList();
 }
 
+size_t LLVMStructType::getNumElements() const {
+  return getBody().size();
+}
+
 LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
                                      StringRef, bool) {
   return success();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 176821f82434d..09ca5e8bb6cd0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -352,6 +352,74 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
   } while (true);
 }
 
+/// Convert an array attribute to an LLVM IR constant of array type by creating
+/// a constant of the array's element type for each attribute in the array.
+/// Reports errors at `loc` and returns nullptr on failure.
+static llvm::Constant *
+convertArrayAttrToArrayType(Location loc, ArrayAttr arrayAttr,
+                            llvm::ArrayType *arrayType,
+                            const ModuleTranslation &moduleTranslation) {
+  // getLLVMConstant only dispatches here once both casts have succeeded.
+  assert(arrayAttr && arrayType && "expected an array attribute and type");
+
+  const auto numElementsInAttr = arrayAttr.size();
+  const auto numElementsInLLVMType = arrayType->getNumElements();
+  if (numElementsInLLVMType != numElementsInAttr) {
+    emitError(loc, "Number of elements in provided MLIR array attribute and "
+                   "desired LLVM array type do not match. ArrayAttr size: ")
+        << numElementsInAttr << " != ArrayType size: " << numElementsInLLVMType;
+    return nullptr;
+  }
+
+  // Convert every element attribute to the (shared) array element type.
+  llvm::Type *llvmElemType = arrayType->getElementType();
+  SmallVector<llvm::Constant *> constants;
+  constants.reserve(numElementsInAttr);
+  for (Attribute attr : arrayAttr) {
+    llvm::Constant *constant = LLVM::detail::getLLVMConstant(
+        llvmElemType, attr, loc, moduleTranslation);
+    if (!constant)
+      return nullptr;
+    constants.push_back(constant);
+  }
+  return llvm::ConstantArray::get(arrayType, constants);
+}
+
+/// Convert an array attribute to an LLVM IR constant of struct type, converting
+/// each element to its member type. Reports errors at `loc`; returns nullptr.
+static llvm::Constant *
+convertArrayAttrToStructType(Location loc, ArrayAttr arrayAttr,
+                             llvm::StructType *structType,
+                             const ModuleTranslation &moduleTranslation) {
+  if (!arrayAttr) {
+    emitError(loc, "expected an array attribute for a struct constant");
+    return nullptr;
+  }
+
+  const auto numElementsInAttr = arrayAttr.size();
+  const auto numElementsInLLVMType = structType->getNumElements();
+  if (numElementsInLLVMType != numElementsInAttr) {
+    emitError(loc, "Number of elements in provided MLIR array attribute and "
+                   "desired LLVM struct type do not match. ArrayAttr size: ")
+        << numElementsInAttr
+        << " != StructType size: " << numElementsInLLVMType;
+    return nullptr;
+  }
+
+  // Unlike array elements, each struct member may have a different type.
+  SmallVector<llvm::Constant *> constants;
+  constants.reserve(numElementsInAttr);
+  for (auto [index, attr] : llvm::enumerate(arrayAttr)) {
+    llvm::Type *llvmElemType = structType->getElementType(index);
+    llvm::Constant *constant = LLVM::detail::getLLVMConstant(
+        llvmElemType, attr, loc, moduleTranslation);
+    if (!constant)
+      return nullptr;
+    constants.push_back(constant);
+  }
+  return llvm::ConstantStruct::get(structType, constants);
+}
+
 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
 /// storage if possible. This supports elements attributes of tensor or vector
 /// type and avoids constructing separate objects for individual values of the
@@ -545,31 +613,17 @@ static llvm::Constant *convertDenseResourceElementsAttr(
 }
 
 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
-/// This currently supports integer, floating point, splat and dense element
-/// attributes and combinations thereof. Also, an array attribute with two
-/// elements is supported to represent a complex constant.  In case of error,
-/// report it to `loc` and return nullptr.
+/// This currently supports integer, floating point, splat and dense element
+/// attributes, and array attributes for LLVM array and struct types, as well
+/// as combinations thereof. On error, report it to `loc` and return nullptr.
 llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     llvm::Type *llvmType, Attribute attr, Location loc,
     const ModuleTranslation &moduleTranslation) {
   if (!attr)
     return llvm::UndefValue::get(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
-    auto arrayAttr = dyn_cast<ArrayAttr>(attr);
-    if (!arrayAttr || arrayAttr.size() != 2) {
-      emitError(loc, "expected struct type to be a complex number");
-      return nullptr;
-    }
-    llvm::Type *elementType = structType->getElementType(0);
-    llvm::Constant *real =
-        getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
-    if (!real)
-      return nullptr;
-    llvm::Constant *imag =
-        getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
-    if (!imag)
-      return nullptr;
-    return llvm::ConstantStruct::get(structType, {real, imag});
+    return convertArrayAttrToStructType(loc, dyn_cast<::mlir::ArrayAttr>(attr),
+                                        structType, moduleTranslation);
   }
   // For integer types, we allow a mismatch in sizes as the index type in
   // MLIR might have a different size than the index type in the LLVM module.
@@ -712,6 +766,11 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
         ArrayRef<char>{stringAttr.getValue().data(),
                        stringAttr.getValue().size()});
   }
+  // Array attributes for array types: the helper reports its own errors.
+  if (auto arrayAttr = dyn_cast<::mlir::ArrayAttr>(attr))
+    if (auto *arrayType = dyn_cast<::llvm::ArrayType>(llvmType))
+      return convertArrayAttrToArrayType(loc, arrayAttr, arrayType,
+                                         moduleTranslation);
   emitError(loc, "unsupported constant value");
   return nullptr;
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 1b685d3783002..d0d739a8d52c6 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -15,22 +15,6 @@ llvm.func @vector_with_non_vector_type() -> f32 {
 
 // -----
 
-llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
-  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
-  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
-}
-
-// -----
-
-llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
-  // expected-error @below{{expected struct type to be a complex number}}
-  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
-  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
-}
-
-// -----
-
 llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
   // expected-error @below{{FloatAttr does not match expected type of the constant}}
   %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
@@ -63,11 +47,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
 
 // -----
 
-// expected-error @below{{unsupported constant value}}
-llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
-
-// -----
-
 // expected-error @below{{LLVM attribute 'noinline' does not expect a value}}
 llvm.func @passthrough_unexpected_value() attributes {passthrough = [["noinline", "42"]]}
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 41a7eec1d8dfc..6eafc8bf7f485 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1314,6 +1314,30 @@ llvm.func @indexconstantarray() -> vector<3xi32> {
   llvm.return %1 : vector<3xi32>
 }
 
+// Constants built from (nested) array attributes. The two globals have no
+// CHECK lines here because globals are printed before all functions. Dense
+// integer elements still cannot initialize non-complex structs.
+llvm.mlir.global external constant @test_array_zero([0]) {addr_space = 0 : i32} : !llvm.array<1 x i64>
+llvm.mlir.global external constant @test_array_ints([0, 1, 2]) {addr_space = 0 : i32} : !llvm.array<3 x i64>
+
+llvm.func @array_of_single_element_structs() -> !llvm.array<2 x struct<(i32)>> {
+  %0 = llvm.mlir.constant([[1], [2]]) : !llvm.array<2 x struct<(i32)>>
+  // CHECK{LITERAL}: ret [2 x { i32 }] [{ i32 } { i32 1 }, { i32 } { i32 2 }]
+  llvm.return %0 : !llvm.array<2 x struct<(i32)>>
+}
+
+llvm.func @array_of_three_element_structs() -> !llvm.array<2 x struct<(i32, i32, i32)>> {
+  %0 = llvm.mlir.constant([[1, 2, 3], [4, 5, 6]]) : !llvm.array<2 x struct<(i32, i32, i32)>>
+  // CHECK{LITERAL}: ret [2 x { i32, i32, i32 }] [{ i32, i32, i32 } { i32 1, i32 2, i32 3 }, { i32, i32, i32 } { i32 4, i32 5, i32 6 }]
+  llvm.return %0 : !llvm.array<2 x struct<(i32, i32, i32)>>
+}
+
+llvm.func @nondenseintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
+  %1 = llvm.mlir.constant([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
+  // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
+  llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
+}
+
 llvm.func @noreach() {
 // CHECK:    unreachable
   llvm.unreachable



More information about the Mlir-commits mailing list