[Mlir-commits] [mlir] 876a480 - [mlir][Conversion] Add type converter parameter to ConvertToLLVMPatternInterface

Matthias Springer llvmlistbot at llvm.org
Wed Aug 9 00:11:41 PDT 2023


Author: Matthias Springer
Date: 2023-08-09T09:00:46+02:00
New Revision: 876a480cacf908feaa88acdbfb9aa3d40013e08e

URL: https://github.com/llvm/llvm-project/commit/876a480cacf908feaa88acdbfb9aa3d40013e08e
DIFF: https://github.com/llvm/llvm-project/commit/876a480cacf908feaa88acdbfb9aa3d40013e08e.diff

LOG: [mlir][Conversion] Add type converter parameter to ConvertToLLVMPatternInterface

Most `*-to-llvm` conversion patterns require a type converter. This
revision adds an `LLVMTypeConverter` parameter to the
`populateConvertToLLVMConversionPatterns` interface hook (and to
`populateConversionTargetFromOperation`) and implements the interface
for the MemRef dialect.

Differential Revision: https://reviews.llvm.org/D157387

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
    mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
    mlir/include/mlir/InitAllExtensions.h
    mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
    mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
    mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
index 8841c38deafecb..424ab38a13b24d 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h
@@ -40,13 +40,15 @@ class ConvertToLLVMPatternInterface
   /// Hook for derived dialect interface to provide conversion patterns
   /// and mark dialect legal for the conversion target.
   virtual void populateConvertToLLVMConversionPatterns(
-      ConversionTarget &target, RewritePatternSet &patterns) const = 0;
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const = 0;
 };
 
 /// Recursively walk the IR and collect all dialects implementing the interface,
 /// and populate the conversion patterns.
 void populateConversionTargetFromOperation(Operation *op,
                                            ConversionTarget &target,
+                                           LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns);
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
index 943e2108eb97f2..db1c222643d571 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 namespace mlir {
+class DialectRegistry;
 class Pass;
 class LLVMTypeConverter;
 class RewritePatternSet;
@@ -23,6 +24,9 @@ class RewritePatternSet;
 /// MemRef dialect to the LLVM dialect.
 void populateFinalizeMemRefToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns);
+
+void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MEMREFTOLLVM_MEMREFTOLLVM_H

diff  --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 14fc94fc86cd9f..ee51bee9ced1d8 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_INITALLEXTENSIONS_H_
 #define MLIR_INITALLEXTENSIONS_H_
 
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Target/LLVM/NVVM/Target.h"
@@ -29,6 +30,7 @@ namespace mlir {
 /// pipelines and transformations you are using.
 inline void registerAllExtensions(DialectRegistry &registry) {
   func::registerAllExtensions(registry);
+  registerConvertMemRefToLLVMInterface(registry);
   registerConvertNVVMToLLVMInterface(registry);
   registerNVVMTarget(registry);
 }

diff  --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
index c1f7165d526c4d..df7e3f995303c9 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMPass
   LINK_LIBS PUBLIC
   MLIRConvertToLLVMInterface
   MLIRIR
+  MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRPass
   MLIRRewrite

diff  --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index 9766e847c7f3b5..f838068dc0d555 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -62,6 +63,7 @@ class ConvertToLLVMPass
     : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
   std::shared_ptr<const FrozenRewritePatternSet> patterns;
   std::shared_ptr<const ConversionTarget> target;
+  std::shared_ptr<const LLVMTypeConverter> typeConverter;
 
 public:
   using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -72,23 +74,26 @@ class ConvertToLLVMPass
 
   ConvertToLLVMPass(const ConvertToLLVMPass &other)
       : ConvertToLLVMPassBase(other), patterns(other.patterns),
-        target(other.target) {}
+        target(other.target), typeConverter(other.typeConverter) {}
 
   LogicalResult initialize(MLIRContext *context) final {
     RewritePatternSet tempPatterns(context);
     auto target = std::make_shared<ConversionTarget>(*context);
     target->addLegalDialect<LLVM::LLVMDialect>();
+    auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
     for (Dialect *dialect : context->getLoadedDialects()) {
       // First time we encounter this dialect: if it implements the interface,
       // let's populate patterns !
       auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
       if (!iface)
         continue;
-      iface->populateConvertToLLVMConversionPatterns(*target, tempPatterns);
+      iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
+                                                     tempPatterns);
     }
     patterns =
         std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
     this->target = target;
+    this->typeConverter = typeConverter;
     return success();
   }
 

diff  --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 6be3defd8781ee..0f2cdbb549ee66 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -13,9 +13,9 @@
 
 using namespace mlir;
 
-void mlir::populateConversionTargetFromOperation(Operation *root,
-                                                 ConversionTarget &target,
-                                                 RewritePatternSet &patterns) {
+void mlir::populateConversionTargetFromOperation(
+    Operation *root, ConversionTarget &target, LLVMTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
   DenseSet<Dialect *> dialects;
   root->walk([&](Operation *op) {
     Dialect *dialect = op->getDialect();
@@ -26,6 +26,7 @@ void mlir::populateConversionTargetFromOperation(Operation *root,
     auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
     if (!iface)
       return;
-    iface->populateConvertToLLVMConversionPatterns(target, patterns);
+    iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
+                                                   patterns);
   });
 }

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index d69ee3ff82220e..80ae59e16e5a95 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 
 #include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -1935,4 +1936,27 @@ struct FinalizeMemRefToLLVMConversionPass
       signalPassFailure();
   }
 };
+
+/// Implement the interface to convert MemRef to LLVM.
+struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
+  }
+};
+
 } // namespace
+
+void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
+  });
+}

diff  --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 6d2726f949d9c9..52abbe998872ab 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -201,7 +201,8 @@ struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
   /// Hook for derived dialect interface to provide conversion patterns
   /// and mark dialect legal for the conversion target.
   void populateConvertToLLVMConversionPatterns(
-      ConversionTarget &target, RewritePatternSet &patterns) const final {
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
     populateNVVMToLLVMConversionPatterns(patterns);
   }
 };

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index a5513dcc8d7cb4..597e76fade3595 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1,6 +1,12 @@
 // RUN: mlir-opt -finalize-memref-to-llvm='use-opaque-pointers=1' %s -split-input-file | FileCheck %s
 // RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32 use-opaque-pointers=1' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
 
+// Same below, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass. This produces slightly different IR
+// because the conversion target is set up differently. Only one test case is
+// checked.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s
+
 // CHECK-LABEL: func @view(
 // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
 func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
@@ -88,6 +94,10 @@ func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK-LABEL: func @view_empty_memref(
 // CHECK:        %[[ARG0:.*]]: index,
 // CHECK:        %[[ARG1:.*]]: memref<0xi8>)
+
+// CHECK-INTERFACE-LABEL: func @view_empty_memref(
+// CHECK-INTERFACE:        %[[ARG0:.*]]: index,
+// CHECK-INTERFACE:        %[[ARG1:.*]]: memref<0xi8>)
 func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
 
   // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -101,6 +111,18 @@ func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
   // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: llvm.mlir.constant(4 : index) : i64
   // CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+
+  // CHECK-INTERFACE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64
+  // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64
+  // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-INTERFACE: llvm.mlir.constant(1 : index) : i64
+  // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64
+  // CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64
+  // CHECK-INTERFACE: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   %0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32>
 
   return


        


More information about the Mlir-commits mailing list