[Mlir-commits] [mlir] 8a316b0 - [mlir] Convert async dialect passes from function passes to op agnostic passes
Eugene Zhulenev
llvmlistbot at llvm.org
Tue Apr 13 11:46:09 PDT 2021
Author: Eugene Zhulenev
Date: 2021-04-13T11:46:00-07:00
New Revision: 8a316b00d63d5370e4f75d482e12d62b54c64308
URL: https://github.com/llvm/llvm-project/commit/8a316b00d63d5370e4f75d482e12d62b54c64308
DIFF: https://github.com/llvm/llvm-project/commit/8a316b00d63d5370e4f75d482e12d62b54c64308.diff
LOG: [mlir] Convert async dialect passes from function passes to op agnostic passes
Differential Revision: https://reviews.llvm.org/D100401
Added:
Modified:
mlir/include/mlir/Dialect/Async/Passes.h
mlir/include/mlir/Dialect/Async/Passes.td
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
mlir/test/Integration/GPU/CUDA/async.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index ddcfc8bdaeddf..d790835c76125 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -17,16 +17,15 @@
namespace mlir {
-std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
+std::unique_ptr<Pass> createAsyncParallelForPass();
-std::unique_ptr<OperationPass<FuncOp>>
-createAsyncParallelForPass(int numWorkerThreads);
+std::unique_ptr<Pass> createAsyncParallelForPass(int numWorkerThreads);
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
+std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
+std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index 155e23572bf80..d5640f3ae65a6 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -11,7 +11,7 @@
include "mlir/Pass/PassBase.td"
-def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
+def AsyncParallelFor : Pass<"async-parallel-for"> {
let summary = "Convert scf.parallel operations to multiple async regions "
"executed concurrently for non-overlapping iteration ranges";
let constructor = "mlir::createAsyncParallelForPass()";
@@ -31,7 +31,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
let dependentDialects = ["async::AsyncDialect"];
}
-def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
+def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
let summary = "Automatic reference counting for Async runtime operations";
let description = [{
This pass works at the async runtime abtraction level, after all
@@ -48,8 +48,7 @@ def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
let dependentDialects = ["async::AsyncDialect"];
}
-def AsyncRuntimeRefCountingOpt :
- FunctionPass<"async-runtime-ref-counting-opt"> {
+def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
let summary = "Optimize automatic reference counting operations for the"
"Async runtime by removing redundant operations";
let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 3627635ed0606..ce2bc7081faf1 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -100,7 +100,7 @@ struct AsyncParallelForPass
assert(numWorkerThreads >= 1);
numConcurrentAsyncExecute = numWorkerThreads;
}
- void runOnFunction() override;
+ void runOnOperation() override;
};
} // namespace
@@ -267,21 +267,20 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
return success();
}
-void AsyncParallelForPass::runOnFunction() {
+void AsyncParallelForPass::runOnOperation() {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
- if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
-std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
return std::make_unique<AsyncParallelForPass>();
}
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncParallelForPass(int numWorkerThreads) {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass(int numWorkerThreads) {
return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index af443918df970..6516e163c94ce 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -32,7 +32,7 @@ class AsyncRuntimeRefCountingPass
: public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
public:
AsyncRuntimeRefCountingPass() = default;
- void runOnFunction() override;
+ void runOnOperation() override;
private:
/// Adds an automatic reference counting to the `value`.
@@ -323,13 +323,13 @@ AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
return success();
}
-void AsyncRuntimeRefCountingPass::runOnFunction() {
- FuncOp func = getFunction();
+void AsyncRuntimeRefCountingPass::runOnOperation() {
+ Operation *op = getOperation();
// Check that we do not have high level async operations in the IR because
// otherwise automatic reference counting will produce incorrect results after
// execute operations will be lowered to `async.runtime`
- WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
+ WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
return WalkResult::advance();
@@ -343,7 +343,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
}
// Add reference counting to block arguments.
- WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+ WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(addAutomaticRefCounting(arg)))
@@ -358,7 +358,7 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
}
// Add reference counting to operation results.
- WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+ WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(addAutomaticRefCounting(op->getResult(i))))
@@ -371,7 +371,6 @@ void AsyncRuntimeRefCountingPass::runOnFunction() {
signalPassFailure();
}
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRuntimeRefCountingPass() {
+std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
return std::make_unique<AsyncRuntimeRefCountingPass>();
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
index cb00d706ce0c8..358cbbb602aee 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
@@ -26,7 +26,7 @@ class AsyncRuntimeRefCountingOptPass
: public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
public:
AsyncRuntimeRefCountingOptPass() = default;
- void runOnFunction() override;
+ void runOnOperation() override;
private:
LogicalResult optimizeReferenceCounting(
@@ -124,8 +124,8 @@ LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
return success();
}
-void AsyncRuntimeRefCountingOptPass::runOnFunction() {
- FuncOp func = getFunction();
+void AsyncRuntimeRefCountingOptPass::runOnOperation() {
+ Operation *op = getOperation();
// Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
//
@@ -134,7 +134,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
llvm::SmallDenseMap<Operation *, Operation *> cancellable;
// Optimize reference counting for values defined by block arguments.
- WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+ WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
for (BlockArgument arg : block->getArguments())
if (isRefCounted(arg.getType()))
if (failed(optimizeReferenceCounting(arg, cancellable)))
@@ -147,7 +147,7 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
signalPassFailure();
// Optimize reference counting for values defined by operation results.
- WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+ WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
for (unsigned i = 0; i < op->getNumResults(); ++i)
if (isRefCounted(op->getResultTypes()[i]))
if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
@@ -171,7 +171,6 @@ void AsyncRuntimeRefCountingOptPass::runOnFunction() {
}
}
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRuntimeRefCountingOptPass() {
+std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}
diff --git a/mlir/test/Integration/GPU/CUDA/async.mlir b/mlir/test/Integration/GPU/CUDA/async.mlir
index fd9bc4749dd09..69256af2428ae 100644
--- a/mlir/test/Integration/GPU/CUDA/async.mlir
+++ b/mlir/test/Integration/GPU/CUDA/async.mlir
@@ -1,8 +1,9 @@
// RUN: mlir-opt %s \
// RUN: -gpu-kernel-outlining \
// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin)' \
-// RUN: -gpu-async-region -async-ref-counting -gpu-to-llvm \
-// RUN: -async-to-async-runtime -convert-async-to-llvm -convert-std-to-llvm \
+// RUN: -gpu-async-region -gpu-to-llvm \
+// RUN: -async-to-async-runtime -async-runtime-ref-counting \
+// RUN: -convert-async-to-llvm -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
More information about the Mlir-commits mailing list