[Mlir-commits] [mlir] 1c14441 - Refactor AsyncToAsyncRuntime pass to boost understandability.

Eugene Zhulenev llvmlistbot at llvm.org
Thu Jul 29 12:01:19 PDT 2021


Author: bakhtiyar
Date: 2021-07-29T12:01:07-07:00
New Revision: 1c144410e791032dfea7dc744518a2c051ec0510

URL: https://github.com/llvm/llvm-project/commit/1c144410e791032dfea7dc744518a2c051ec0510
DIFF: https://github.com/llvm/llvm-project/commit/1c144410e791032dfea7dc744518a2c051ec0510.diff

LOG: Refactor AsyncToAsyncRuntime pass to boost understandability.

Depends On D106730

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D106731

Added: 
    

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
    mlir/test/Dialect/Async/async-to-async-runtime.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 5ca0d632b67e2..10dcba1f30444 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -67,6 +67,7 @@ struct CoroMachinery {
   llvm::SmallVector<Value, 4> returnValues; // returned async values
 
   Value coroHandle; // coroutine handle (!async.coro.handle value)
+  Block *entry;     // coroutine entry block
   Block *setError;  // switch completion token and all values to error state
   Block *cleanup;   // coroutine cleanup block
   Block *suspend;   // coroutine suspension block
@@ -75,16 +76,15 @@ struct CoroMachinery {
 
 /// Utility to partially update the regular function CFG to the coroutine CFG
 /// compatible with LLVM coroutines switched-resume lowering using
-/// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block
-/// by prepending its ops with coroutine setup. Also inserts trailing blocks.
+/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
+/// that branches to the preexisting entry block. Also inserts trailing blocks.
 ///
 /// The result types of the passed `func` must start with an `async.token`
 /// and be continued with some number of `async.value`s.
 ///
-/// It's up to the caller of this function to fix up the terminators of the
-/// preexisting blocks of the passed func op. If the passed `func` is legal,
-/// this typically means rewriting every return op as a yield op and a branch op
-/// to the suspend block.
+/// The func given to this function needs to have been preprocessed to have
+/// either branch or yield ops as terminators. Branches to the cleanup block are
+/// inserted after each yield.
 ///
 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
 ///
@@ -104,9 +104,9 @@ struct CoroMachinery {
 ///       %value = <async value> : !async.value<T> // create async value
 ///       %id = async.coro.id                      // create a coroutine id
 ///       %hdl = async.coro.begin %id              // create a coroutine handle
-///       /* other ops of the preexisting entry block */
+///       br ^preexisting_entry_block
 ///
-///     /* other preexisting blocks */
+///     /* preexisting blocks modified to branch to the cleanup block */
 ///
 ///     ^set_error: // this block created lazily only if needed (see code below)
 ///       async.runtime.set_error %token : !async.token
@@ -127,6 +127,8 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
 
   MLIRContext *ctx = func.getContext();
   Block *entryBlock = &func.getBlocks().front();
+  Block *originalEntryBlock =
+      entryBlock->splitBlock(entryBlock->getOperations().begin());
   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
 
   // ------------------------------------------------------------------------ //
@@ -144,6 +146,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
   auto coroHdlOp =
       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
+  builder.create<BranchOp>(originalEntryBlock);
 
   Block *cleanupBlock = func.addBlock();
   Block *suspendBlock = func.addBlock();
@@ -175,11 +178,23 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   // `async.await` op lowering will create resume blocks for async
   // continuations, and will conditionally branch to cleanup or suspend blocks.
 
+  for (Block &block : func.body().getBlocks()) {
+    if (&block == entryBlock || &block == cleanupBlock ||
+        &block == suspendBlock)
+      continue;
+    Operation *terminator = block.getTerminator();
+    if (auto yield = dyn_cast<YieldOp>(terminator)) {
+      builder.setInsertionPointToEnd(&block);
+      builder.create<BranchOp>(cleanupBlock);
+    }
+  }
+
   CoroMachinery machinery;
   machinery.func = func;
   machinery.asyncToken = retToken;
   machinery.returnValues = retValues;
   machinery.coroHandle = coroHdlOp.handle();
+  machinery.entry = entryBlock;
   machinery.setError = nullptr; // created lazily only if needed
   machinery.cleanup = cleanupBlock;
   machinery.suspend = suspendBlock;
@@ -241,68 +256,69 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   symbolTable.insert(func);
 
   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
+  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
+
+  // Prepare for coroutine conversion by creating the body of the function.
+  {
+    size_t numDependencies = execute.dependencies().size();
+    size_t numOperands = execute.operands().size();
+
+    // Await on all dependencies before starting to execute the body region.
+    for (size_t i = 0; i < numDependencies; ++i)
+      builder.create<AwaitOp>(func.getArgument(i));
+
+    // Await on all async value operands and unwrap the payload.
+    SmallVector<Value, 4> unwrappedOperands(numOperands);
+    for (size_t i = 0; i < numOperands; ++i) {
+      Value operand = func.getArgument(numDependencies + i);
+      unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
+    }
+
+    // Map from function inputs defined above the execute op to the function
+    // arguments.
+    BlockAndValueMapping valueMapping;
+    valueMapping.map(functionInputs, func.getArguments());
+    valueMapping.map(execute.body().getArguments(), unwrappedOperands);
+
+    // Clone all operations from the execute operation body into the outlined
+    // function body.
+    for (Operation &op : execute.body().getOps())
+      builder.clone(op, valueMapping);
+  }
 
-  // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
-  // blocks, adding async.coro operations and setting up control flow.
-  func.addEntryBlock();
+  // Adding entry/cleanup/suspend blocks.
   CoroMachinery coro = setupCoroMachinery(func);
 
   // Suspend async function at the end of an entry block, and resume it using
   // Async resume operation (execution will be resumed in a thread managed by
   // the async runtime).
-  Block *entryBlock = &func.getBlocks().front();
-  auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock);
+  {
+    BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
+    builder.setInsertionPointToEnd(coro.entry);
 
-  // Save the coroutine state: async.coro.save
-  auto coroSaveOp =
-      builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
+    // Save the coroutine state: async.coro.save
+    auto coroSaveOp =
+        builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
 
-  // Pass coroutine to the runtime to be resumed on a runtime managed thread.
-  builder.create<RuntimeResumeOp>(coro.coroHandle);
-  builder.create<BranchOp>(coro.cleanup);
+    // Pass coroutine to the runtime to be resumed on a runtime managed
+    // thread.
+    builder.create<RuntimeResumeOp>(coro.coroHandle);
 
-  // Split the entry block before the terminator (branch to suspend block).
-  auto *terminatorOp = entryBlock->getTerminator();
-  Block *suspended = terminatorOp->getBlock();
-  Block *resume = suspended->splitBlock(terminatorOp);
-
-  // Add async.coro.suspend as a suspended block terminator.
-  builder.setInsertionPointToEnd(suspended);
-  builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
-                                coro.cleanup);
-
-  size_t numDependencies = execute.dependencies().size();
-  size_t numOperands = execute.operands().size();
-
-  // Await on all dependencies before starting to execute the body region.
-  builder.setInsertionPointToStart(resume);
-  for (size_t i = 0; i < numDependencies; ++i)
-    builder.create<AwaitOp>(func.getArgument(i));
-
-  // Await on all async value operands and unwrap the payload.
-  SmallVector<Value, 4> unwrappedOperands(numOperands);
-  for (size_t i = 0; i < numOperands; ++i) {
-    Value operand = func.getArgument(numDependencies + i);
-    unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
-  }
+    // Add async.coro.suspend as a suspended block terminator.
+    builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
+                                  branch.getDest(), coro.cleanup);
 
-  // Map from function inputs defined above the execute op to the function
-  // arguments.
-  BlockAndValueMapping valueMapping;
-  valueMapping.map(functionInputs, func.getArguments());
-  valueMapping.map(execute.body().getArguments(), unwrappedOperands);
-
-  // Clone all operations from the execute operation body into the outlined
-  // function body.
-  for (Operation &op : execute.body().getOps())
-    builder.clone(op, valueMapping);
+    branch.erase();
+  }
 
   // Replace the original `async.execute` with a call to outlined function.
-  ImplicitLocOpBuilder callBuilder(loc, execute);
-  auto callOutlinedFunc = callBuilder.create<CallOp>(
-      func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
-  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
-  execute.erase();
+  {
+    ImplicitLocOpBuilder callBuilder(loc, execute);
+    auto callOutlinedFunc = callBuilder.create<CallOp>(
+        func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
+    execute.replaceAllUsesWith(callOutlinedFunc.getResults());
+    execute.erase();
+  }
 
   return {func, coro};
 }
@@ -575,20 +591,15 @@ static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
                   [](Type type) { return ValueType::get(type); });
   func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
   func.insertResult(0, TokenType::get(ctx), {});
-  CoroMachinery coro = setupCoroMachinery(func);
   for (Block &block : func.getBlocks()) {
-    if (&block == coro.suspend)
-      continue;
-
     Operation *terminator = block.getTerminator();
     if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
       ImplicitLocOpBuilder builder(loc, returnOp);
       builder.create<YieldOp>(returnOp.getOperands());
-      builder.create<BranchOp>(coro.cleanup);
       returnOp.erase();
     }
   }
-  return coro;
+  return setupCoroMachinery(func);
 }
 
 /// Rewrites a call into a function that has been rewritten as a coroutine.

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
index 7718d06888db4..5baf362552889 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -10,19 +10,20 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
 // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
 // CHECK: %[[ID:.*]] = async.coro.id
 // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
-
-// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32
+// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
+// CHECK: ^[[ORIGINAL_ENTRY]]:
+// CHECK:   %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32
   %0 = addf %arg0, %arg0 : f32
-// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
+// CHECK:   %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
   %1 = async.runtime.create: !async.value<f32>
-// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK:   async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
   async.runtime.store %0, %1: !async.value<f32>
-// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK:   async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
   async.runtime.set_available %1: !async.value<f32>
 
-// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
-// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]]
-// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK:   %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK:   async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]]
+// CHECK:   async.coro.suspend %[[SAVED]]
 // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
   %2 = async.await %1 : !async.value<f32>
 
@@ -62,13 +63,15 @@ func @simple_caller() -> f32 {
 // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
 // CHECK: %[[ID:.*]] = async.coro.id
 // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
+// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
+// CHECK: ^[[ORIGINAL_ENTRY]]:
 
-// CHECK: %[[CONSTANT:.*]] = constant
+// CHECK:   %[[CONSTANT:.*]] = constant
   %c = constant 1.0 : f32
-// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
-// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
-// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]]
-// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK:   %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK:   %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK:   async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]]
+// CHECK:   async.coro.suspend %[[SAVED]]
 // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
   %r = call @simple_callee(%c): (f32) -> f32
 
@@ -109,13 +112,15 @@ func @double_caller() -> f32 {
 // CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
 // CHECK: %[[ID:.*]] = async.coro.id
 // CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
+// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
+// CHECK: ^[[ORIGINAL_ENTRY]]:
 
-// CHECK: %[[CONSTANT:.*]] = constant
+// CHECK:   %[[CONSTANT:.*]] = constant
   %c = constant 1.0 : f32
-// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
-// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]]
-// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]]
-// CHECK: async.coro.suspend %[[SAVED_1]]
+// CHECK:   %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK:   %[[SAVED_1:.*]] = async.coro.save %[[HDL]]
+// CHECK:   async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]]
+// CHECK:   async.coro.suspend %[[SAVED_1]]
 // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]]
   %r = call @simple_callee(%c): (f32) -> f32
 

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 148a0b3c257d2..661d208e17662 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -328,8 +328,8 @@ func @async_value_operands() {
 
 // -----
 
-// CHECK-LABEL: @execute_asserttion
-func @execute_asserttion(%arg0: i1) {
+// CHECK-LABEL: @execute_assertion
+func @execute_assertion(%arg0: i1) {
   %token = async.execute {
     assert %arg0, "error"
     async.yield


        


More information about the Mlir-commits mailing list