[Mlir-commits] [mlir] 8a316b0 - [mlir] Convert async dialect passes from function passes to op-agnostic passes

Eugene Zhulenev llvmlistbot at llvm.org
Tue Apr 13 11:46:09 PDT 2021


Author: Eugene Zhulenev
Date: 2021-04-13T11:46:00-07:00
New Revision: 8a316b00d63d5370e4f75d482e12d62b54c64308

URL: https://github.com/llvm/llvm-project/commit/8a316b00d63d5370e4f75d482e12d62b54c64308
DIFF: https://github.com/llvm/llvm-project/commit/8a316b00d63d5370e4f75d482e12d62b54c64308.diff

LOG: [mlir] Convert async dialect passes from function passes to op-agnostic passes

Differential Revision: https://reviews.llvm.org/D100401

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/include/mlir/Dialect/Async/Passes.td
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
    mlir/test/Integration/GPU/CUDA/async.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index ddcfc8bdaeddf..d790835c76125 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -17,16 +17,15 @@
 
 namespace mlir {
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
+std::unique_ptr<Pass> createAsyncParallelForPass();
 
-std::unique_ptr<OperationPass<FuncOp>>
-createAsyncParallelForPass(int numWorkerThreads);
+std::unique_ptr<Pass> createAsyncParallelForPass(int numWorkerThreads);
 
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
+std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
+std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();
 
 //===----------------------------------------------------------------------===//
 // Registration

diff  --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index 155e23572bf80..d5640f3ae65a6 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -11,7 +11,7 @@
 
 include "mlir/Pass/PassBase.td"
 
-def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
+def AsyncParallelFor : Pass<"async-parallel-for"> {
   let summary = "Convert scf.parallel operations to multiple async regions "
                 "executed concurrently for non-overlapping iteration ranges";
   let constructor = "mlir::createAsyncParallelForPass()";
@@ -31,7 +31,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
   let dependentDialects = ["async::AsyncDialect"];
 }
 
-def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
+def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
   let summary = "Automatic reference counting for Async runtime operations";
   let description = [{
     This pass works at the async runtime abtraction level, after all
@@ -48,8 +48,7 @@ def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
   let dependentDialects = ["async::AsyncDialect"];
 }
 
-def AsyncRuntimeRefCountingOpt :
-    FunctionPass<"async-runtime-ref-counting-opt"> {
+def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
   let summary = "Optimize automatic reference counting operations for the"
                 "Async runtime by removing redundant operations";
   let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 3627635ed0606..ce2bc7081faf1 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -100,7 +100,7 @@ struct AsyncParallelForPass
     assert(numWorkerThreads >= 1);
     numConcurrentAsyncExecute = numWorkerThreads;
   }
-  void runOnFunction() override;
+  void runOnOperation() override;
 };
 
 } // namespace
@@ -267,21 +267,20 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   return success();
 }
 
-void AsyncParallelForPass::runOnFunction() {
+void AsyncParallelForPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
 
   RewritePatternSet patterns(ctx);
   patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
 
-  if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncParallelForPass(int numWorkerThreads) {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass(int numWorkerThreads) {
   return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
 }

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index af443918df970..6516e163c94ce 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -32,7 +32,7 @@ class AsyncRuntimeRefCountingPass
     : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
 public:
   AsyncRuntimeRefCountingPass() = default;
-  void runOnFunction() override;
+  void runOnOperation() override;
 
 private:
   /// Adds an automatic reference counting to the `value`.
@@ -323,13 +323,13 @@ AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
   return success();
 }
 
-void AsyncRuntimeRefCountingPass::runOnFunction() {
-  FuncOp func = getFunction();
+void AsyncRuntimeRefCountingPass::runOnOperation() {
+  Operation *op = getOperation();
 
   // Check that we do not have high level async operations in the IR because
   // otherwise automatic reference counting will produce incorrect results after
   // execute operations will be lowered to `async.runtime`
-  WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
+  WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
     if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
       return WalkResult::advance();
 
@@ -343,7 +343,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
   }
 
   // Add reference counting to block arguments.
-  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
     for (BlockArgument arg : block->getArguments())
       if (isRefCounted(arg.getType()))
         if (failed(addAutomaticRefCounting(arg)))
@@ -358,7 +358,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
   }
 
   // Add reference counting to operation results.
-  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
     for (unsigned i = 0; i < op->getNumResults(); ++i)
       if (isRefCounted(op->getResultTypes()[i]))
         if (failed(addAutomaticRefCounting(op->getResult(i))))
@@ -371,7 +371,6 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
     signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRuntimeRefCountingPass() {
+std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
   return std::make_unique<AsyncRuntimeRefCountingPass>();
 }

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
index cb00d706ce0c8..358cbbb602aee 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
@@ -26,7 +26,7 @@ class AsyncRuntimeRefCountingOptPass
     : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
 public:
   AsyncRuntimeRefCountingOptPass() = default;
-  void runOnFunction() override;
+  void runOnOperation() override;
 
 private:
   LogicalResult optimizeReferenceCounting(
@@ -124,8 +124,8 @@ LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
   return success();
 }
 
-void AsyncRuntimeRefCountingOptPass::runOnFunction() {
-  FuncOp func = getFunction();
+void AsyncRuntimeRefCountingOptPass::runOnOperation() {
+  Operation *op = getOperation();
 
   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
   //
@@ -134,7 +134,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
 
   // Optimize reference counting for values defined by block arguments.
-  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
     for (BlockArgument arg : block->getArguments())
       if (isRefCounted(arg.getType()))
         if (failed(optimizeReferenceCounting(arg, cancellable)))
@@ -147,7 +147,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
     signalPassFailure();
 
   // Optimize reference counting for values defined by operation results.
-  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
     for (unsigned i = 0; i < op->getNumResults(); ++i)
       if (isRefCounted(op->getResultTypes()[i]))
         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
@@ -171,7 +171,6 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
   }
 }
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRuntimeRefCountingOptPass() {
+std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
 }

diff  --git a/mlir/test/Integration/GPU/CUDA/async.mlir b/mlir/test/Integration/GPU/CUDA/async.mlir
index fd9bc4749dd09..69256af2428ae 100644
--- a/mlir/test/Integration/GPU/CUDA/async.mlir
+++ b/mlir/test/Integration/GPU/CUDA/async.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s \
 // RUN:   -gpu-kernel-outlining \
 // RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin)' \
-// RUN:   -gpu-async-region -async-ref-counting -gpu-to-llvm \
-// RUN:   -async-to-async-runtime -convert-async-to-llvm -convert-std-to-llvm \
+// RUN:   -gpu-async-region -gpu-to-llvm \
+// RUN:   -async-to-async-runtime -async-runtime-ref-counting \
+// RUN:   -convert-async-to-llvm -convert-std-to-llvm \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \


        


More information about the Mlir-commits mailing list