[Mlir-commits] [mlir] 9a5bc83 - Add an escape-hatch for conversion of funcs with blocking awaits to coroutines.

Eugene Zhulenev llvmlistbot at llvm.org
Thu Jul 29 08:52:35 PDT 2021


Author: bakhtiyar
Date: 2021-07-29T08:52:28-07:00
New Revision: 9a5bc83660ed6978521dcfa4faac140cf5b2e895

URL: https://github.com/llvm/llvm-project/commit/9a5bc83660ed6978521dcfa4faac140cf5b2e895
DIFF: https://github.com/llvm/llvm-project/commit/9a5bc83660ed6978521dcfa4faac140cf5b2e895.diff

LOG: Add an escape-hatch for conversion of funcs with blocking awaits to coroutines.

TFRT does not currently support top-level coroutines, so this escape hatch allows a function to keep a single blocking await at the top level until TFRT implements the necessary support.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D106730

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
    mlir/include/mlir/Dialect/Async/Passes.td
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
index 6a9575c87fc2c..4bf56586fa950 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
@@ -28,6 +28,15 @@ def AsyncDialect : Dialect {
   }];
 
   let cppNamespace = "::mlir::async";
+
+  let extraClassDeclaration = [{
+    // Name of the unit attribute marking funcs that are allowed to keep
+    // blocking async.runtime.await ops. It only matters together with the
+    // 'eliminate-blocking-await-ops' pass option, which may otherwise convert
+    // such a func into a coroutine.
+    static constexpr StringRef kAllowedToBlockAttrName = "async.allowed_to_block";
+  }];
+
 }
 
 #endif // ASYNC_DIALECT_TD

diff  --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index d975eb1b39a17..f9f9804f244b4 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -44,7 +44,7 @@ def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
     Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
             /*default=*/"false",
            "Rewrite functions with blocking async.runtime.await as coroutines "
-           "with async.runtime.await_and_resume.">
+           "with async.runtime.await_and_resume.">,
   ];
   let dependentDialects = ["async::AsyncDialect"];
 }

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 3e325e21b1c56..67b4096122745 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -16,6 +16,8 @@ using namespace mlir::async;
 
 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
 
+constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
+
 void AsyncDialect::initialize() {
   addOperations<
 #define GET_OP_LIST

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 1ffcda4cb9dbb..5ca0d632b67e2 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -614,6 +614,14 @@ static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
   oldCall.erase();
 }
 
+// Returns true if `func` carries the AsyncDialect::kAllowedToBlockAttrName
+// unit attribute. A null `func` (an op with no enclosing FuncOp) is never
+// allowed to block; it is checked here instead of being dereferenced.
+static bool isAllowedToBlock(FuncOp func) {
+  return func &&
+         func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
+}
+
 static LogicalResult
 funcsToCoroutines(ModuleOp module,
                   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
@@ -628,12 +632,15 @@ funcsToCoroutines(ModuleOp module,
   // Careful, it's okay to add a func to the worklist multiple times if and only
   // if the loop processing the worklist will skip the functions that have
   // already been converted to coroutines.
-  auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
+  auto addToWorklist = [&](FuncOp func) {
+    // Never queue funcs that are allowed to block for coroutine conversion.
+    if (isAllowedToBlock(func))
+      return;
     // N.B. To refactor this code into a separate pass the lookup in
     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
     // func and recognizing if it has a coroutine structure is messy. Passing
     // this dict between the passes is ugly.
     if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
       for (Operation &op : func.body().getOps()) {
         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
           funcWorklist.push_back(func);
@@ -759,7 +766,10 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   });
 
   if (eliminateBlockingAwaitOps)
-    runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
+    runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
+        [&](RuntimeAwaitOp op) -> bool {
+          return isAllowedToBlock(op->getParentOfType<FuncOp>());
+        });
 
   if (failed(applyPartialConversion(module, runtimeTarget,
                                     std::move(asyncPatterns)))) {

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
index 1013966b15946..7718d06888db4 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -302,3 +302,18 @@ return
 // CHECK:   async.coro.end %[[HDL]]
 // CHECK:   return %[[TOKEN]] : !async.token
 }
+
+// CHECK-LABEL: func @caller_allowed_to_block
+// CHECK-SAME: () -> f32
+func @caller_allowed_to_block() -> f32 attributes { async.allowed_to_block } {
+// CHECK: %[[CONSTANT:.*]] = constant
+  %c = constant 1.0 : f32
+// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#0
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#1
+// CHECK: %[[RETURNED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1
+  %r = call @simple_callee(%c): (f32) -> f32
+
+// CHECK:   return %[[RETURNED]] : f32
+  return %r: f32
+}


        


More information about the Mlir-commits mailing list