[Mlir-commits] [mlir] 27dad99 - [mlir][LLVM] Make the nested type restriction on complex constants less aggressive
Benjamin Kramer
llvmlistbot at llvm.org
Thu May 12 03:07:34 PDT 2022
Author: Benjamin Kramer
Date: 2022-05-12T11:47:01+02:00
New Revision: 27dad99622bb16fc7ba94beda26dded9023bf2cd
URL: https://github.com/llvm/llvm-project/commit/27dad99622bb16fc7ba94beda26dded9023bf2cd
DIFF: https://github.com/llvm/llvm-project/commit/27dad99622bb16fc7ba94beda26dded9023bf2cd.diff
LOG: [mlir][LLVM] Make the nested type restriction on complex constants less aggressive
Complex nested in other types is perfectly fine, just nested structs
aren't supported. Instead of checking whether there's nesting just check
whether the struct we're dealing with is a complex number.
Differential Revision: https://reviews.llvm.org/D125381
Added:
Modified:
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/LLVMIR/llvmir-invalid.mlir
mlir/test/Target/LLVMIR/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 1750c0c9a3cfd..f2b5066925033 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -351,8 +351,7 @@ SetVector<Block *> getTopologicallySortedBlocks(Region &region);
/// report it to `loc` and return nullptr.
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
Location loc,
- const ModuleTranslation &moduleTranslation,
- bool isTopLevel = true);
+ const ModuleTranslation &moduleTranslation);
/// Creates a call to an LLVM IR intrinsic function with the given arguments.
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 127e7e15ccab9..b0e231ded6d54 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
(type.isa<VectorType>() || hasVectorElementType)) {
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
- moduleTranslation, /*isTopLevel=*/false);
+ moduleTranslation);
llvm::Constant *splatVector =
llvm::ConstantDataVector::getSplat(0, splatValue);
SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
@@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
/// report it to `loc` and return nullptr.
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *llvmType, Attribute attr, Location loc,
- const ModuleTranslation &moduleTranslation, bool isTopLevel) {
+ const ModuleTranslation &moduleTranslation) {
if (!attr)
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
- if (!isTopLevel) {
- emitError(loc, "nested struct types are not supported in constants");
+ auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+ if (!arrayAttr || arrayAttr.size() != 2) {
+ emitError(loc, "expected struct type to be a complex number");
return nullptr;
}
- auto arrayAttr = attr.cast<ArrayAttr>();
llvm::Type *elementType = structType->getElementType(0);
- llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
- moduleTranslation, false);
+ llvm::Constant *real =
+ getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
if (!real)
return nullptr;
- llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
- moduleTranslation, false);
+ llvm::Constant *imag =
+ getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
if (!imag)
return nullptr;
return llvm::ConstantStruct::get(structType, {real, imag});
@@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
elementType,
elementTypeSequential ? splatAttr
: splatAttr.getSplatValue<Attribute>(),
- loc, moduleTranslation, false);
+ loc, moduleTranslation);
if (!child)
return nullptr;
if (llvmType->isVectorTy())
@@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *innermostType = getInnermostElementType(llvmType);
for (auto n : elementsAttr.getValues<Attribute>()) {
constants.push_back(
- getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
+ getLLVMConstant(innermostType, n, loc, moduleTranslation));
if (!constants.back())
return nullptr;
}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index fdbbf9e8fcc98..ba23c8700c48d 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
// -----
-llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
- // expected-error @+1 {{nested struct types are not supported in constants}}
+llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+ // expected-error @+1 {{expected struct type to be a complex number}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}
// -----
+llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+ // expected-error @+1 {{expected struct type to be a complex number}}
+ %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+ llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+}
+
+// -----
+
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{FloatAttr does not match expected type of the constant}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index cd14641944c30..b4a2dbcf02d8a 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
llvm.return %1 : !llvm.struct<(i32, i32)>
}
+llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
+ %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+ // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
+ llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+}
+
+llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
+ %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
+ // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
+ llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
+}
+
llvm.func @noreach() {
// CHECK: unreachable
llvm.unreachable
More information about the Mlir-commits
mailing list