[Mlir-commits] [mlir] e64c766 - [mlir] recursively convert builtin types to LLVM when possible

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 10 09:11:06 PST 2021


Author: Alex Zinenko
Date: 2021-11-10T18:11:00+01:00
New Revision: e64c76672f5c50d28298222047b3678ee5b2f251

URL: https://github.com/llvm/llvm-project/commit/e64c76672f5c50d28298222047b3678ee5b2f251
DIFF: https://github.com/llvm/llvm-project/commit/e64c76672f5c50d28298222047b3678ee5b2f251.diff

LOG: [mlir] recursively convert builtin types to LLVM when possible

Given that LLVM dialect types may now optionally contain types from other
dialects, which itself is motivated by dialect interoperability and progressive
lowering, the conversion should no longer assume that the outermost LLVM
dialect type can be left as is. Instead, it should inspect the types it
contains and attempt to convert them to the LLVM dialect. Introduce this
capability for LLVM array, pointer, structure and function types. Only literal
structures are currently supported, because handling identified structures
requires the conversion infrastructure to have a mechanism for avoiding infinite
recursion in case of recursive types.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D112550

Added: 
    mlir/test/Conversion/StandardToLLVM/convert-types.mlir

Modified: 
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ca1882899edff..28be3300fb382 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -38,12 +38,53 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
   addConversion([&](VectorType type) { return convertVectorType(type); });
 
-  // LLVM-compatible types are legal, so add a pass-through conversion.
+  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
+  // before the conversions below since conversions are attempted in reverse
+  // order and those should take priority.
   addConversion([](Type type) {
     return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
                                         : llvm::None;
   });
 
+  // LLVM container types may (recursively) contain other types that must be
+  // converted even when the outer type is compatible.
+  addConversion([&](LLVM::LLVMPointerType type) -> llvm::Optional<Type> {
+    if (auto pointee = convertType(type.getElementType()))
+      return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
+    return llvm::None;
+  });
+  addConversion([&](LLVM::LLVMStructType type) -> llvm::Optional<Type> {
+    // TODO: handle conversion of identified structs, which may be recursive.
+    if (type.isIdentified())
+      return type;
+
+    SmallVector<Type> convertedSubtypes;
+    convertedSubtypes.reserve(type.getBody().size());
+    if (failed(convertTypes(type.getBody(), convertedSubtypes)))
+      return llvm::None;
+
+    return LLVM::LLVMStructType::getLiteral(type.getContext(),
+                                            convertedSubtypes, type.isPacked());
+  });
+  addConversion([&](LLVM::LLVMArrayType type) -> llvm::Optional<Type> {
+    if (auto element = convertType(type.getElementType()))
+      return LLVM::LLVMArrayType::get(element, type.getNumElements());
+    return llvm::None;
+  });
+  addConversion([&](LLVM::LLVMFunctionType type) -> llvm::Optional<Type> {
+    Type convertedResType = convertType(type.getReturnType());
+    if (!convertedResType)
+      return llvm::None;
+
+    SmallVector<Type> convertedArgTypes;
+    convertedArgTypes.reserve(type.getNumParams());
+    if (failed(convertTypes(type.getParams(), convertedArgTypes)))
+      return llvm::None;
+
+    return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
+                                       type.isVarArg());
+  });
+
   // Materialization for memrefs creates descriptor structs from individual
   // values constituting them, when descriptors are used, i.e. more than one
   // value represents a memref.

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-types.mlir b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
new file mode 100644
index 0000000000000..8a15625b2298a
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -test-convert-call-op %s | FileCheck %s
+
+// CHECK-LABEL: @ptr
+// CHECK: !llvm.ptr<i42>
+func private @ptr() -> !llvm.ptr<!test.smpla>
+
+// CHECK-LABEL: @ptr_ptr()
+// CHECK: !llvm.ptr<ptr<i42>> 
+func private @ptr_ptr() -> !llvm.ptr<!llvm.ptr<!test.smpla>>
+
+// CHECK-LABEL: @struct_ptr()
+// CHECK: !llvm.struct<(ptr<i42>)> 
+func private @struct_ptr() -> !llvm.struct<(ptr<!test.smpla>)>
+
+// CHECK-LABEL: @named_struct_ptr()
+// CHECK: !llvm.struct<"named", (ptr<!test.smpla>)> 
+func private @named_struct_ptr() -> !llvm.struct<"named", (ptr<!test.smpla>)>
+
+// CHECK-LABEL: @array_ptr()
+// CHECK: !llvm.array<10 x ptr<i42>> 
+func private @array_ptr() -> !llvm.array<10 x ptr<!test.smpla>>
+
+// CHECK-LABEL: @func()
+// CHECK: !llvm.ptr<func<i42 (i42)>>
+func private @func() -> !llvm.ptr<!llvm.func<!test.smpla (!test.smpla)>>
+
+// TODO: support conversion of recursive types in the conversion infra.
+// CHECK-LABEL: @named_recursive()
+// CHECK: !llvm.struct<"recursive", (ptr<!test.smpla>, ptr<struct<"recursive">>)> 
+func private @named_recursive() -> !llvm.struct<"recursive", (ptr<!test.smpla>, ptr<struct<"recursive">>)>
+

diff  --git a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
index 43f57fecd8e9e..dbe2a74cd55ac 100644
--- a/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Conversion/StandardToLLVM/TestConvertCallOp.cpp
@@ -52,6 +52,9 @@ class TestConvertCallOp
     typeConverter.addConversion([&](test::TestType type) {
       return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
     });
+    typeConverter.addConversion([&](test::SimpleAType type) {
+      return IntegerType::get(type.getContext(), 42);
+    });
 
     // Populate patterns.
     RewritePatternSet patterns(m.getContext());


        


More information about the Mlir-commits mailing list