[Mlir-commits] [mlir] 9a5bc83 - Add an escape-hatch for conversion of funcs with blocking awaits to coroutines.
Eugene Zhulenev
llvmlistbot at llvm.org
Thu Jul 29 08:52:35 PDT 2021
Author: bakhtiyar
Date: 2021-07-29T08:52:28-07:00
New Revision: 9a5bc83660ed6978521dcfa4faac140cf5b2e895
URL: https://github.com/llvm/llvm-project/commit/9a5bc83660ed6978521dcfa4faac140cf5b2e895
DIFF: https://github.com/llvm/llvm-project/commit/9a5bc83660ed6978521dcfa4faac140cf5b2e895.diff
LOG: Add an escape-hatch for conversion of funcs with blocking awaits to coroutines.
Currently TFRT does not support top-level coroutines, so this functionality allows a single blocking await at the top level until TFRT implements the necessary functionality.
Reviewed By: ezhulenev
Differential Revision: https://reviews.llvm.org/D106730
Added:
Modified:
mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
mlir/include/mlir/Dialect/Async/Passes.td
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
index 6a9575c87fc2c..4bf56586fa950 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
@@ -28,6 +28,15 @@ def AsyncDialect : Dialect {
}];
let cppNamespace = "::mlir::async";
+
+ let extraClassDeclaration = [{
+ // The name of a unit attribute on funcs that are allowed to have a blocking
+ // async.runtime.await ops. Only useful in combination with
+ // 'eliminate-blocking-await-ops' option, which in absence of this attribute
+ // might convert a func to a coroutine.
+ static constexpr StringRef kAllowedToBlockAttrName = "async.allowed_to_block";
+ }];
+
}
#endif // ASYNC_DIALECT_TD
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index d975eb1b39a17..f9f9804f244b4 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -44,7 +44,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
/*default=*/"false",
"Rewrite functions with blocking async.runtime.await as coroutines "
- "with async.runtime.await_and_resume.">
+ "with async.runtime.await_and_resume.">,
];
let dependentDialects = ["async::AsyncDialect"];
}
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 3e325e21b1c56..67b4096122745 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -16,6 +16,8 @@ using namespace mlir::async;
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
+constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
+
void AsyncDialect::initialize() {
addOperations<
#define GET_OP_LIST
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 1ffcda4cb9dbb..5ca0d632b67e2 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -614,6 +614,10 @@ static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
oldCall.erase();
}
+static bool isAllowedToBlock(FuncOp func) {
+ return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
+}
+
static LogicalResult
funcsToCoroutines(ModuleOp module,
llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
@@ -628,12 +632,15 @@ funcsToCoroutines(ModuleOp module,
// Careful, it's okay to add a func to the worklist multiple times if and only
// if the loop processing the worklist will skip the functions that have
// already been converted to coroutines.
- auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
+ auto addToWorklist = [&](FuncOp func) {
+ if (isAllowedToBlock(func))
+ return;
// N.B. To refactor this code into a separate pass the lookup in
// outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
// func and recognizing if it has a coroutine structure is messy. Passing
// this dict between the passes is ugly.
- if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
+ if (isAllowedToBlock(func) ||
+ outlinedFunctions.find(func) == outlinedFunctions.end()) {
for (Operation &op : func.body().getOps()) {
if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
funcWorklist.push_back(func);
@@ -759,7 +766,10 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
});
if (eliminateBlockingAwaitOps)
- runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
+ runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
+ [&](RuntimeAwaitOp op) -> bool {
+ return isAllowedToBlock(op->getParentOfType<FuncOp>());
+ });
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
index 1013966b15946..7718d06888db4 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -302,3 +302,18 @@ return
// CHECK: async.coro.end %[[HDL]]
// CHECK: return %[[TOKEN]] : !async.token
}
+
+// CHECK-LABEL: func @caller_allowed_to_block
+// CHECK-SAME: () -> f32
+func @caller_allowed_to_block() -> f32 attributes { async.allowed_to_block } {
+// CHECK: %[[CONSTANT:.*]] = constant
+ %c = constant 1.0 : f32
+// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#0
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#1
+// CHECK: %[[RETURNED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1
+ %r = call @simple_callee(%c): (f32) -> f32
+
+// CHECK: return %[[RETURNED]] : f32
+ return %r: f32
+}
More information about the Mlir-commits
mailing list