[Mlir-commits] [mlir] 27dad99 - [mlir][LLVM] Make the nested type restriction on complex constants less aggressive

Benjamin Kramer llvmlistbot at llvm.org
Thu May 12 03:07:34 PDT 2022


Author: Benjamin Kramer
Date: 2022-05-12T11:47:01+02:00
New Revision: 27dad99622bb16fc7ba94beda26dded9023bf2cd

URL: https://github.com/llvm/llvm-project/commit/27dad99622bb16fc7ba94beda26dded9023bf2cd
DIFF: https://github.com/llvm/llvm-project/commit/27dad99622bb16fc7ba94beda26dded9023bf2cd.diff

LOG: [mlir][LLVM] Make the nested type restriction on complex constants less aggressive

Complex numbers nested in other types are perfectly fine; only nested structs
aren't supported. Instead of checking whether there's nesting, just check
whether the struct we're dealing with is a complex number.

Differential Revision: https://reviews.llvm.org/D125381

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 1750c0c9a3cfd..f2b5066925033 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -351,8 +351,7 @@ SetVector<Block *> getTopologicallySortedBlocks(Region &region);
 /// report it to `loc` and return nullptr.
 llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
                                 Location loc,
-                                const ModuleTranslation &moduleTranslation,
-                                bool isTopLevel = true);
+                                const ModuleTranslation &moduleTranslation);
 
 /// Creates a call to an LLVM IR intrinsic function with the given arguments.
 llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 127e7e15ccab9..b0e231ded6d54 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
       (type.isa<VectorType>() || hasVectorElementType)) {
     llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
         innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
-        moduleTranslation, /*isTopLevel=*/false);
+        moduleTranslation);
     llvm::Constant *splatVector =
         llvm::ConstantDataVector::getSplat(0, splatValue);
     SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
@@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
 /// report it to `loc` and return nullptr.
 llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     llvm::Type *llvmType, Attribute attr, Location loc,
-    const ModuleTranslation &moduleTranslation, bool isTopLevel) {
+    const ModuleTranslation &moduleTranslation) {
   if (!attr)
     return llvm::UndefValue::get(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
-    if (!isTopLevel) {
-      emitError(loc, "nested struct types are not supported in constants");
+    auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+    if (!arrayAttr || arrayAttr.size() != 2) {
+      emitError(loc, "expected struct type to be a complex number");
       return nullptr;
     }
-    auto arrayAttr = attr.cast<ArrayAttr>();
     llvm::Type *elementType = structType->getElementType(0);
-    llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
-                                           moduleTranslation, false);
+    llvm::Constant *real =
+        getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
     if (!real)
       return nullptr;
-    llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
-                                           moduleTranslation, false);
+    llvm::Constant *imag =
+        getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
     if (!imag)
       return nullptr;
     return llvm::ConstantStruct::get(structType, {real, imag});
@@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
         elementType,
         elementTypeSequential ? splatAttr
                               : splatAttr.getSplatValue<Attribute>(),
-        loc, moduleTranslation, false);
+        loc, moduleTranslation);
     if (!child)
       return nullptr;
     if (llvmType->isVectorTy())
@@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     llvm::Type *innermostType = getInnermostElementType(llvmType);
     for (auto n : elementsAttr.getValues<Attribute>()) {
       constants.push_back(
-          getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
+          getLLVMConstant(innermostType, n, loc, moduleTranslation));
       if (!constants.back())
         return nullptr;
     }

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index fdbbf9e8fcc98..ba23c8700c48d 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
 
 // -----
 
-llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
-  // expected-error @+1 {{nested struct types are not supported in constants}}
+llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
+  // expected-error @+1 {{expected struct type to be a complex number}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
   llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
 }
 
 // -----
 
+llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
+  // expected-error @+1 {{expected struct type to be a complex number}}
+  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+  llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
+}
+
+// -----
+
 llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
   // expected-error @+1 {{FloatAttr does not match expected type of the constant}}
   %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index cd14641944c30..b4a2dbcf02d8a 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
   llvm.return %1 : !llvm.struct<(i32, i32)>
 }
 
+llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
+  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
+  llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+}
+
+llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
+  %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
+  // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
+  llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
+}
+
 llvm.func @noreach() {
 // CHECK:    unreachable
   llvm.unreachable


        


More information about the Mlir-commits mailing list