[Mlir-commits] [mlir] [mlir] Adopt `ConvertToLLVMPatternInterface` GpuToLLVMConversionPass to align with `convert-to-llvm` (PR #73761)

Mehdi Amini llvmlistbot at llvm.org
Wed Nov 29 02:19:24 PST 2023


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/73761

>From be17e0b35165ec538bf356eb7c7451532a0e5cda Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 29 Nov 2023 00:05:18 -0800
Subject: [PATCH] [mlir] Adopt `ConvertToLLVMPatternInterface` in
 GpuToLLVMConversionPass to align with `convert-to-llvm`

This is a follow-up to the introduction of `convert-to-llvm`: it is supposed
to be a unifying pass through the `ConvertToLLVMPatternInterface`, but some
specific conversions (like the GPU target) aren't vanilla LLVM targets. Instead,
they need extra customizations that are specific to LLVM-on-GPUs and our
custom runtime wrappers.
This change makes the GpuToLLVMConversionPass just as pluggable as
`convert-to-llvm` by using the same mechanism.
---
 .../Conversion/ConvertToLLVM/ToLLVMPass.h     |  4 +++
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       |  5 +++
 .../GPUCommon/GPUToLLVMConversion.cpp         | 36 +++++++++++++------
 3 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
index 2eddf52d7abc520..73deef49c4175d3 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
@@ -22,6 +22,10 @@ namespace mlir {
 /// implementing `ConvertToLLVMPatternInterface`.
 std::unique_ptr<Pass> createConvertToLLVMPass();
 
+/// Register the extension that will load dependent dialects for LLVM
+/// conversion. This is useful to implement a pass similar to "convert-to-llvm".
+void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index a90e557b1fdbd9c..6135117348a5b86 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -124,6 +124,11 @@ class ConvertToLLVMPass
 
 } // namespace
 
+void mlir::registerConvertToLLVMDependentDialectLoading(
+    DialectRegistry &registry) {
+  registry.addExtensions<LoadDependentDialectExtension>();
+}
+
 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
   return std::make_unique<ConvertToLLVMPass>();
 }
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 2da97c20e9c984e..75dee09d2f64fd0 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -18,6 +18,8 @@
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -38,6 +40,8 @@
 #include "llvm/Support/Error.h"
 #include "llvm/Support/FormatVariadic.h"
 
+#define DEBUG_TYPE "gpu-to-llvm"
+
 namespace mlir {
 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -48,12 +52,14 @@ using namespace mlir;
 static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
 
 namespace {
-
 class GpuToLLVMConversionPass
     : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
 public:
   using Base::Base;
-
+  void getDependentDialects(DialectRegistry &registry) const final {
+    Base::getDependentDialects(registry);
+    registerConvertToLLVMDependentDialectLoading(registry);
+  }
   // Run the dialect converter on the module.
   void runOnOperation() override;
 };
@@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
-  LowerToLLVMOptions options(&getContext());
+  MLIRContext *context = &getContext();
+  SymbolTable symbolTable = SymbolTable(getOperation());
+  LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
+  RewritePatternSet patterns(context);
+  ConversionTarget target(*context);
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  LLVMTypeConverter converter(context, options);
+
+  // Populate all patterns from all dialects that implement the
+  // `ConvertToLLVMPatternInterface` interface.
+  for (Dialect *dialect : context->getLoadedDialects()) {
+    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+    if (!iface)
+      continue;
+    iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
+  }
 
-  LLVMTypeConverter converter(&getContext(), options);
-  RewritePatternSet patterns(&getContext());
-  LLVMConversionTarget target(getContext());
-
-  SymbolTable symbolTable = SymbolTable(getOperation());
   // Preserve GPU modules if they have target attributes.
   target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
       [](gpu::GPUModuleOp module) -> bool {
@@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
                 !module.getTargetsAttr().empty());
       });
 
-  mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
-  mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+  // These aren't covered by the ConvertToLLVMPatternInterface right now.
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
-  populateFuncToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
   populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,



More information about the Mlir-commits mailing list