[Mlir-commits] [mlir] 6cca6b9 - Add async_funcs_only option to AsyncToAsyncRuntime pass
Eugene Zhulenev
llvmlistbot at llvm.org
Wed Nov 30 10:27:05 PST 2022
Author: yijiagu
Date: 2022-11-30T10:27:00-08:00
New Revision: 6cca6b9ab9ff27d22cdb4b9b4aa98ec91b6b08d8
URL: https://github.com/llvm/llvm-project/commit/6cca6b9ab9ff27d22cdb4b9b4aa98ec91b6b08d8
DIFF: https://github.com/llvm/llvm-project/commit/6cca6b9ab9ff27d22cdb4b9b4aa98ec91b6b08d8.diff
LOG: Add async_funcs_only option to AsyncToAsyncRuntime pass
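This change splits the lowering of async functions out of AsyncToAsyncRuntimePass. Despite the title, it does not add an async_funcs_only option. Instead, it adds a new AsyncFuncToAsyncRuntime pass (`-async-func-to-async-runtime`) and `populateAsyncFuncToAsyncRuntimeConversionPatterns()`, which lower async.func, async.call and async.return to func.func coroutines. AsyncToAsyncRuntimePass no longer lowers these operations, so pipelines containing async.func must run the new pass first. The goal is to convert async functions to regular functions in the early stages of the compilation pipeline.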
Differential Revision: https://reviews.llvm.org/D138611
Added:
Modified:
mlir/include/mlir/Dialect/Async/Passes.h
mlir/include/mlir/Dialect/Async/Passes.td
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/test/Dialect/Async/async-to-async-runtime.mlir
mlir/test/mlir-cpu-runner/async-func.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 78216125ab91b..090768cd0209c 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -17,6 +17,7 @@
namespace mlir {
class ModuleOp;
+class ConversionTarget;
#define GEN_PASS_DECL
#include "mlir/Dialect/Async/Passes.h.inc"
@@ -27,6 +28,11 @@ std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
int32_t numWorkerThreads,
int32_t minTaskSize);
+void populateAsyncFuncToAsyncRuntimeConversionPatterns(
+ RewritePatternSet &patterns, ConversionTarget &target);
+
+std::unique_ptr<OperationPass<ModuleOp>> createAsyncFuncToAsyncRuntimePass();
+
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index aed5b4ff7865a..8b579dd591604 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -41,12 +41,19 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
}
def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
- let summary = "Lower high level async operations (e.g. async.execute) to the"
- "explicit async.runtime and async.coro operations";
+ // Adjacent TableGen string literals are concatenated verbatim, so every
+ // fragment but the last must end with the separating space.
+ let summary = "Lower all high level async operations (e.g. async.execute) to "
+ "the explicit async.runtime and async.coro operations";
let constructor = "mlir::createAsyncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
}
+def AsyncFuncToAsyncRuntime : Pass<"async-func-to-async-runtime", "ModuleOp"> {
+ // Adjacent TableGen string literals are concatenated verbatim, so every
+ // fragment but the last must end with the separating space.
+ let summary = "Lower async.func operations to the explicit async.runtime and "
+ "async.coro operations";
+ let description = [{
+ Converts async.func, async.call and async.return into func.func-based
+ coroutines built on async.runtime and async.coro operations. The
+ async.await, async.await_all, async.yield and cf.assert operations nested
+ inside those coroutines are lowered as well. Blocking waits in regular
+ functions are left untouched; they are lowered by -async-to-async-runtime,
+ which must run after this pass.
+ }];
+ let constructor = "mlir::createAsyncFuncToAsyncRuntimePass()";
+ let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
+}
+
def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
let summary = "Automatic reference counting for Async runtime operations";
let description = [{
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 66c5b731b6e76..e7859488b6b4e 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -30,6 +30,7 @@
namespace mlir {
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
+#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir
@@ -51,6 +52,17 @@ class AsyncToAsyncRuntimePass
} // namespace
+namespace {
+
+/// Module pass that lowers only async functions (async.func, async.call,
+/// async.return) to the explicit async.runtime / async.coro form, leaving the
+/// remaining high level async operations for AsyncToAsyncRuntimePass.
+struct AsyncFuncToAsyncRuntimePass
+ : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
+ AsyncFuncToAsyncRuntimePass() = default;
+
+ void runOnOperation() override;
+};
+
+} // namespace
+
/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension.
///
@@ -84,6 +96,9 @@ struct CoroMachinery {
};
} // namespace
+using FuncCoroMapPtr =
+ std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
+
/// Utility to partially update the regular function CFG to the coroutine CFG
/// compatible with LLVM coroutines switched-resume lowering using
/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
@@ -399,9 +414,8 @@ namespace {
class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
public:
- AsyncFuncOpLowering(MLIRContext *ctx,
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
- : OpConversionPattern<async::FuncOp>(ctx), coros(coros) {}
+ AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+ : OpConversionPattern<async::FuncOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
@@ -423,7 +437,7 @@ class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
newFuncOp.end());
CoroMachinery coro = setupCoroMachinery(newFuncOp);
- coros[newFuncOp] = coro;
+ (*coros_)[newFuncOp] = coro;
// no initial suspend, we should hot-start
rewriter.eraseOp(op);
@@ -431,7 +445,7 @@ class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
}
private:
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+ FuncCoroMapPtr coros_;
};
//===----------------------------------------------------------------------===//
@@ -458,16 +472,15 @@ class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
public:
- AsyncReturnOpLowering(MLIRContext *ctx,
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
- : OpConversionPattern<async::ReturnOp>(ctx), coros(coros) {}
+ AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+ : OpConversionPattern<async::ReturnOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto func = op->template getParentOfType<func::FuncOp>();
- auto funcCoro = coros.find(func);
- if (funcCoro == coros.end())
+ auto funcCoro = coros_->find(func);
+ if (funcCoro == coros_->end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
@@ -494,7 +507,7 @@ class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
}
private:
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+ FuncCoroMapPtr coros_;
};
} // namespace
@@ -509,9 +522,10 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
using AwaitAdaptor = typename AwaitType::Adaptor;
public:
- AwaitOpLoweringBase(MLIRContext *ctx,
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
- : OpConversionPattern<AwaitType>(ctx), coros(coros) {}
+ AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
+ bool should_lower_blocking_wait)
+ : OpConversionPattern<AwaitType>(ctx), coros_(coros),
+ should_lower_blocking_wait_(should_lower_blocking_wait) {}
LogicalResult
matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
@@ -521,16 +535,20 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
if (!op.getOperand().getType().template isa<AwaitableType>())
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
- // Check if await operation is inside the outlined coroutine function.
+ // Check if await operation is inside the coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
- auto funcCoro = coros.find(func);
- const bool isInCoroutine = funcCoro != coros.end();
+ auto funcCoro = coros_->find(func);
+ const bool isInCoroutine = funcCoro != coros_->end();
Location loc = op->getLoc();
Value operand = adaptor.getOperand();
Type i1 = rewriter.getI1Type();
+ // Delay lowering to block wait in case await op is inside async.execute
+ if (!isInCoroutine && !should_lower_blocking_wait_)
+ return failure();
+
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
@@ -602,7 +620,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
}
private:
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+ FuncCoroMapPtr coros_;
+ bool should_lower_blocking_wait_;
};
/// Lowering for `async.await` with a token operand.
@@ -645,17 +664,16 @@ class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
public:
- YieldOpLowering(MLIRContext *ctx,
- const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
- : OpConversionPattern<async::YieldOp>(ctx), coros(coros) {}
+ YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+ : OpConversionPattern<async::YieldOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if yield operation is inside the async coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
- auto funcCoro = coros.find(func);
- if (funcCoro == coros.end())
+ auto funcCoro = coros_->find(func);
+ if (funcCoro == coros_->end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
@@ -682,7 +700,7 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
}
private:
- const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+ FuncCoroMapPtr coros_;
};
//===----------------------------------------------------------------------===//
@@ -691,17 +709,16 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
public:
- AssertOpLowering(MLIRContext *ctx,
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
- : OpConversionPattern<cf::AssertOp>(ctx), coros(coros) {}
+ AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+ : OpConversionPattern<cf::AssertOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if assert operation is inside the async coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
- auto funcCoro = coros.find(func);
- if (funcCoro == coros.end())
+ auto funcCoro = coros_->find(func);
+ if (funcCoro == coros_->end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
@@ -721,7 +738,7 @@ class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
}
private:
- llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+ FuncCoroMapPtr coros_;
};
//===----------------------------------------------------------------------===//
@@ -730,22 +747,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
SymbolTable symbolTable(module);
// Functions with coroutine CFG setups, which are results of outlining
- // `async.execute` body regions and converting async.func.
- llvm::DenseMap<func::FuncOp, CoroMachinery> coros;
+ // `async.execute` body regions
+ FuncCoroMapPtr coros =
+ std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
module.walk([&](ExecuteOp execute) {
- coros.insert(outlineExecuteOp(symbolTable, execute));
+ coros->insert(outlineExecuteOp(symbolTable, execute));
});
LLVM_DEBUG({
- llvm::dbgs() << "Outlined " << coros.size()
+ llvm::dbgs() << "Outlined " << coros->size()
<< " functions built from async.execute operations\n";
});
// Returns true if operation is inside the coroutine.
auto isInCoroutine = [&](Operation *op) -> bool {
auto parentFunc = op->getParentOfType<func::FuncOp>();
- return coros.find(parentFunc) != coros.end();
+ return coros->find(parentFunc) != coros->end();
};
// Lower async operations to async.runtime operations.
@@ -762,22 +780,18 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
// types for async.runtime operations.
asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
- // Lower async.func to func.func with coroutine cfg.
- asyncPatterns.add<AsyncCallOpLowering>(ctx);
- asyncPatterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
-
- asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
- AwaitAllOpLowering, YieldOpLowering>(ctx, coros);
+ asyncPatterns
+ .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+ ctx, coros, /*should_lower_blocking_wait=*/true);
// Lower assertions to conditional branches into error blocks.
- asyncPatterns.add<AssertOpLowering>(ctx, coros);
+ asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
// All high level async operations must be lowered to the runtime operations.
ConversionTarget runtimeTarget(*ctx);
runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
- runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp,
- async::FuncOp, async::CallOp, async::ReturnOp>();
+ runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
// Decide if structured control flow has to be lowered to branch-based CFG.
runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
@@ -795,7 +809,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
[&](cf::AssertOp op) -> bool {
auto func = op->getParentOfType<func::FuncOp>();
- return coros.find(func) == coros.end();
+ return coros->find(func) == coros->end();
});
if (failed(applyPartialConversion(module, runtimeTarget,
@@ -805,6 +819,59 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
}
}
+//===----------------------------------------------------------------------===//
+
+/// Adds to `patterns` the conversions that lower async.func, async.call and
+/// async.return to func.func coroutines. It also adds the lowerings of
+/// async.await, async.await_all, async.yield and cf.assert for operations
+/// nested inside those coroutines.
+///
+/// Registers on `target` a legality rule that keeps these nested operations
+/// legal unless their parent func.func is one of the converted coroutines.
+/// Blocking waits in regular functions are therefore left for the
+/// AsyncToAsyncRuntime pass.
+///
+/// This function does not mark async.func, async.call or async.return
+/// illegal; the caller must do so on `target`.
+void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
+ RewritePatternSet &patterns, ConversionTarget &target) {
+ // Functions with coroutine CFG setups, which are results of converting
+ // async.func. Shared between the patterns and the legality callback.
+ FuncCoroMapPtr coros =
+ std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
+ MLIRContext *ctx = patterns.getContext();
+ // Lower async.func to func.func with coroutine cfg.
+ patterns.add<AsyncCallOpLowering>(ctx);
+ patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
+
+ // Awaits outside of a converted coroutine must not become blocking waits
+ // here; they are lowered later by the AsyncToAsyncRuntime pass.
+ patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+ ctx, coros, /*should_lower_blocking_wait=*/false);
+ patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
+
+ // Only ops nested in a converted coroutine have to be rewritten. `coros` is
+ // captured by value so the callback keeps the map alive after this returns.
+ target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, async::YieldOp,
+ cf::AssertOp>([coros](Operation *op) {
+ auto func = op->getParentOfType<func::FuncOp>();
+ return coros->find(func) == coros->end();
+ });
+}
+
+void AsyncFuncToAsyncRuntimePass::runOnOperation() {
+ ModuleOp moduleOp = getOperation();
+ MLIRContext *context = moduleOp->getContext();
+
+ // Patterns and legality rules that turn async functions into func.func
+ // coroutines expressed with async.runtime / async.coro operations.
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+ populateAsyncFuncToAsyncRuntimeConversionPatterns(patterns, target);
+
+ // The rest of the async and func dialects stays legal (subject to the
+ // dynamic rules registered above); the function-level async ops must go.
+ target.addLegalDialect<AsyncDialect, func::FuncDialect>();
+ target.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
+
+ // Additional ops the lowering patterns are allowed to produce.
+ target.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
+ cf::BranchOp, cf::CondBranchOp>();
+
+ if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+ signalPassFailure();
+}
+
+/// Creates the `async-to-async-runtime` pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
return std::make_unique<AsyncToAsyncRuntimePass>();
}
+
+/// Creates the `async-func-to-async-runtime` pass, which lowers async.func,
+/// async.call and async.return (and the async ops nested in the resulting
+/// coroutines) to the explicit async.runtime and async.coro operations.
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createAsyncFuncToAsyncRuntimePass() {
+ return std::make_unique<AsyncFuncToAsyncRuntimePass>();
+}
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 1551e55c90c08..38a88cc9de5b8 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime \
-// RUN: | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -split-input-file -async-func-to-async-runtime \
+// RUN: -async-to-async-runtime | FileCheck %s --dump-input=always
// CHECK-LABEL: @execute_no_async_args
func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
diff --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
index 8b3d728d4667f..6f21ba906b222 100644
--- a/mlir/test/mlir-cpu-runner/async-func.mlir
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-func-to-async-runtime,async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
// RUN: | mlir-cpu-runner \
// RUN: -e main -entry-point-result=void -O0 \
// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \
More information about the Mlir-commits
mailing list