[Mlir-commits] [mlir] MLIR->LLVMIR: Support creating constant arrays (and structs) from arbitrary ArrayAttr for constants (PR #94143)
Ryan Thomas Lynch
llvmlistbot at llvm.org
Sat Jun 1 23:06:19 PDT 2024
https://github.com/emosy created https://github.com/llvm/llvm-project/pull/94143
This is currently not working _exactly_ as expected, but the intent is that something like this in MLIR:
`llvm.mlir.global external constant @array([1,2,3]) : !llvm.array<3 x i64>`
can be converted to LLVMIR:
`@array = constant [3 x i64] [i64 1, i64 2, i64 3]`
While I was at it, I decided to try supporting structs too. That might be causing the problems...
I also added a `getNumElements` method to the `LLVMStructType` type in the LLVM dialect. It matches the method of the same name that is available on LLVM IR's `llvm::StructType` and on the array types (both LLVM IR's `llvm::ArrayType` and MLIR's `LLVMArrayType`).
>From 76ab905a0c308f569e71ab622ed3baff94722650 Mon Sep 17 00:00:00 2001
From: Ryan Thomas Lynch <rlynch34 at gatech.edu>
Date: Sun, 2 Jun 2024 05:30:44 +0000
Subject: [PATCH] BROKEN+WIP first try: LLVM dialect to LLVMIR constant arrays
+ structs conversion
---
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 3 +
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 30 ++++++
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 4 +
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 97 ++++++++++++++++----
mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 21 -----
mlir/test/Target/LLVMIR/llvmir.mlir | 24 +++++
6 files changed, 139 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 93733ccd4929a..89d7403e57bc0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,6 +180,9 @@ class LLVMStructType
/// Returns the list of element types contained in a non-opaque struct.
ArrayRef<Type> getBody() const;
+ /// Returns the number of elements in a struct.
+ size_t getNumElements() const;
+
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
StringRef, bool);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 60b911948d4a0..ba61665a3cf1e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2076,6 +2076,19 @@ LogicalResult GlobalOp::verify() {
return emitOpError(
"requires an i8 array type of the length equal to that of the string "
"attribute");
+ } else if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
+ arrayType && getValueOrNull()) {
+ // Currently, creating arrays is only supported from ArrayAttrs
+ // or from StringAttr (but that was verified above)
+ auto arrayAttr = llvm::dyn_cast_or_null<ArrayAttr>(getValueOrNull());
+ if (arrayAttr && arrayAttr.size() != arrayType.getNumElements())
+ return emitOpError()
+ << "array type requires an array attribute of the same size, but "
+ "was provided with array attribute of size "
+ << arrayAttr.size() << " for an array type of size "
+ << arrayType.getNumElements();
+ // Checking that the element types match would be nice,
+ // but is not as trivial
}
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
@@ -2088,6 +2101,23 @@ LogicalResult GlobalOp::verify() {
"initialized with zero-initializer";
}
+ if (llvm::isa<LLVMStructType>(getType()) && getValueOrNull()) {
+ // Currently, creating structs is only supported from ArrayAttrs
+ auto structType = llvm::dyn_cast<LLVMStructType>(getType());
+ auto arrayAttr = llvm::dyn_cast_or_null<ArrayAttr>(getValueOrNull());
+ if (!arrayAttr)
+ return emitOpError()
+ << "struct type requires an array attribute of the same size";
+ if (arrayAttr.size() != structType.getNumElements())
+ return emitOpError()
+ << "struct type requires an array attribute of the same size, but "
+ "was provided with array attribute of size "
+ << arrayAttr.size() << " for an struct type of size "
+ << structType.getNumElements();
+ // Checking that the element types match would be nice,
+ // but is not as trivial
+ }
+
if (getLinkage() == Linkage::Common) {
if (Attribute value = getValueOrNull()) {
if (!isZeroAttribute(value)) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index cf3f38b710130..2ba852c81f5f9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -492,6 +492,10 @@ ArrayRef<Type> LLVMStructType::getBody() const {
: getImpl()->getTypeList();
}
+size_t LLVMStructType::getNumElements() const {
+ return getBody().size();
+}
+
LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
StringRef, bool) {
return success();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 176821f82434d..09ca5e8bb6cd0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -352,6 +352,74 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
} while (true);
}
+/// Convert an array attribute to an LLVM IR constant of array type by creating
+/// constants for each attribute in the array. Reports errors at `loc`.
+static llvm::Constant *
+convertArrayAttrToArrayType(Location loc, ArrayAttr arrayAttr,
+ llvm::ArrayType *arrayType,
+ const ModuleTranslation &moduleTranslation) {
+ if (!arrayAttr || !arrayType)
+ return nullptr;
+
+ const auto numElementsInAttr = arrayAttr.size();
+ const auto numElementsInLLVMType = arrayType->getNumElements();
+ if (numElementsInLLVMType != numElementsInAttr) {
+ emitError(loc, "Number of elements in provided MLIR array attribute and "
+ "desired LLVM array type do not match. ArrayAttr size: ")
+ << numElementsInAttr << " != ArrayType size: " << numElementsInLLVMType;
+ return nullptr;
+ }
+
+ // Create constants for array elements
+ llvm::Type *llvmElemType = arrayType->getElementType();
+ SmallVector<llvm::Constant *> constants;
+ constants.reserve(numElementsInAttr);
+ for (auto attr : arrayAttr) {
+ llvm::Constant *constant = LLVM::detail::getLLVMConstant(
+ llvmElemType, attr, loc, moduleTranslation);
+ if (!constant)
+ return nullptr;
+ constants.push_back(constant);
+ }
+ ArrayRef<llvm::Constant *> constantsRef = constants;
+ return llvm::ConstantArray::get(arrayType, constantsRef);
+}
+
+/// Convert an array attribute to an LLVM IR constant of struct type by creating
+/// constants for each attribute in the array. Reports errors at `loc`.
+static llvm::Constant *
+convertArrayAttrToStructType(Location loc, ArrayAttr arrayAttr,
+ llvm::StructType *structType,
+ const ModuleTranslation &moduleTranslation) {
+ if (!arrayAttr || !structType)
+ return nullptr;
+
+ const auto numElementsInAttr = arrayAttr.size();
+ const auto numElementsInLLVMType = structType->getNumElements();
+ if (numElementsInLLVMType != numElementsInAttr) {
+ emitError(loc, "Number of elements in provided MLIR array attribute and "
+ "desired LLVM struct type do not match. ArrayAttr size: ")
+ << numElementsInAttr
+ << " != StructType size: " << numElementsInLLVMType;
+ return nullptr;
+ }
+
+ // Create constants for array elements
+ SmallVector<llvm::Constant *> constants;
+ constants.reserve(numElementsInAttr);
+ for (auto [index, attr] : llvm::enumerate(arrayAttr)) {
+ llvm::Type *llvmElemType = structType->getElementType(index);
+ llvm::Constant *constant = LLVM::detail::getLLVMConstant(
+ llvmElemType, attr, loc, moduleTranslation);
+ if (!constant)
+ return nullptr;
+ constants.push_back(constant);
+ }
+
+ ArrayRef<llvm::Constant *> constantsRef = constants;
+ return llvm::ConstantStruct::get(structType, constantsRef);
+}
+
/// Convert a dense elements attribute to an LLVM IR constant using its raw data
/// storage if possible. This supports elements attributes of tensor or vector
/// type and avoids constructing separate objects for individual values of the
@@ -545,31 +613,17 @@ static llvm::Constant *convertDenseResourceElementsAttr(
}
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
-/// This currently supports integer, floating point, splat and dense element
-/// attributes and combinations thereof. Also, an array attribute with two
-/// elements is supported to represent a complex constant. In case of error,
-/// report it to `loc` and return nullptr.
+/// This currently supports integer, floating point, array, struct, splat
+/// and dense element attributes and combinations thereof.
+/// In case of error, report it to `loc` and return nullptr.
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *llvmType, Attribute attr, Location loc,
const ModuleTranslation &moduleTranslation) {
if (!attr)
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
- auto arrayAttr = dyn_cast<ArrayAttr>(attr);
- if (!arrayAttr || arrayAttr.size() != 2) {
- emitError(loc, "expected struct type to be a complex number");
- return nullptr;
- }
- llvm::Type *elementType = structType->getElementType(0);
- llvm::Constant *real =
- getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
- if (!real)
- return nullptr;
- llvm::Constant *imag =
- getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
- if (!imag)
- return nullptr;
- return llvm::ConstantStruct::get(structType, {real, imag});
+ return convertArrayAttrToStructType(loc, dyn_cast<::mlir::ArrayAttr>(attr),
+ structType, moduleTranslation);
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
@@ -712,6 +766,11 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
ArrayRef<char>{stringAttr.getValue().data(),
stringAttr.getValue().size()});
}
+ if (llvm::Constant *result = convertArrayAttrToArrayType(
+ loc, dyn_cast<::mlir::ArrayAttr>(attr),
+ dyn_cast<::llvm::ArrayType>(llvmType), moduleTranslation)) {
+ return result;
+ }
emitError(loc, "unsupported constant value");
return nullptr;
}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 1b685d3783002..d0d739a8d52c6 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -15,22 +15,6 @@ llvm.func @vector_with_non_vector_type() -> f32 {
// -----
-llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
- // expected-error @below{{expected struct type to be a complex number}}
- %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
- llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
-}
-
-// -----
-
-llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
- // expected-error @below{{expected struct type to be a complex number}}
- %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
- llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
-}
-
-// -----
-
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @below{{FloatAttr does not match expected type of the constant}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
@@ -63,11 +47,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
// -----
-// expected-error @below{{unsupported constant value}}
-llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
-
-// -----
-
// expected-error @below{{LLVM attribute 'noinline' does not expect a value}}
llvm.func @passthrough_unexpected_value() attributes {passthrough = [["noinline", "42"]]}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 41a7eec1d8dfc..6eafc8bf7f485 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1314,6 +1314,30 @@ llvm.func @indexconstantarray() -> vector<3xi32> {
llvm.return %1 : vector<3xi32>
}
+// FIXME: WIP tests by @emosy
+
+llvm.mlir.global external constant @test_array_zero([0]) {addr_space = 0 : i32} : !llvm.array<1 x i64>
+llvm.mlir.global external constant @test_array_ints([0, 1, 2]) {addr_space = 0 : i32} : !llvm.array<3 x i64>
+
+// FIXME: this one becomes a return void? i'm not sure why!
+llvm.func @non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+ %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
+ llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
+}
+
+llvm.func @non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+ %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+ llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+}
+
+llvm.func @nondenseintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
+ %1 = llvm.mlir.constant(<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
+ // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
+ llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
+}
+
+
+
llvm.func @noreach() {
// CHECK: unreachable
llvm.unreachable
More information about the Mlir-commits
mailing list