[Mlir-commits] [mlir] 1c14441 - Refactor AsyncToAsyncRuntime pass to boost understandability.
Eugene Zhulenev
llvmlistbot at llvm.org
Thu Jul 29 12:01:19 PDT 2021
Author: bakhtiyar
Date: 2021-07-29T12:01:07-07:00
New Revision: 1c144410e791032dfea7dc744518a2c051ec0510
URL: https://github.com/llvm/llvm-project/commit/1c144410e791032dfea7dc744518a2c051ec0510
DIFF: https://github.com/llvm/llvm-project/commit/1c144410e791032dfea7dc744518a2c051ec0510.diff
LOG: Refactor AsyncToAsyncRuntime pass to boost understandability.
Depends On D106730
Reviewed By: ezhulenev
Differential Revision: https://reviews.llvm.org/D106731
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
mlir/test/Dialect/Async/async-to-async-runtime.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 5ca0d632b67e2..10dcba1f30444 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -67,6 +67,7 @@ struct CoroMachinery {
llvm::SmallVector<Value, 4> returnValues; // returned async values
Value coroHandle; // coroutine handle (!async.coro.handle value)
+ Block *entry; // coroutine entry block
Block *setError; // switch completion token and all values to error state
Block *cleanup; // coroutine cleanup block
Block *suspend; // coroutine suspension block
@@ -75,16 +76,15 @@ struct CoroMachinery {
/// Utility to partially update the regular function CFG to the coroutine CFG
/// compatible with LLVM coroutines switched-resume lowering using
-/// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block
-/// by prepending its ops with coroutine setup. Also inserts trailing blocks.
+/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
+/// branching into the preexisting entry block. Also inserts trailing blocks.
///
/// The result types of the passed `func` must start with an `async.token`
/// and be continued with some number of `async.value`s.
///
-/// It's up to the caller of this function to fix up the terminators of the
-/// preexisting blocks of the passed func op. If the passed `func` is legal,
-/// this typically means rewriting every return op as a yield op and a branch op
-/// to the suspend block.
+/// The func given to this function needs to have been preprocessed to have
+/// either branch or yield ops as terminators. Branches to the cleanup block are
+/// inserted after each yield.
///
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
@@ -104,9 +104,9 @@ struct CoroMachinery {
/// %value = <async value> : !async.value<T> // create async value
/// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
-/// /* other ops of the preexisting entry block */
+/// br ^preexisting_entry_block
///
-/// /* other preexisting blocks */
+/// /* preexisting blocks modified to branch to the cleanup block */
///
/// ^set_error: // this block created lazily only if needed (see code below)
/// async.runtime.set_error %token : !async.token
@@ -127,6 +127,8 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
MLIRContext *ctx = func.getContext();
Block *entryBlock = &func.getBlocks().front();
+ Block *originalEntryBlock =
+ entryBlock->splitBlock(entryBlock->getOperations().begin());
auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
// ------------------------------------------------------------------------ //
@@ -144,6 +146,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
auto coroHdlOp =
builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
+ builder.create<BranchOp>(originalEntryBlock);
Block *cleanupBlock = func.addBlock();
Block *suspendBlock = func.addBlock();
@@ -175,11 +178,23 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
+ for (Block &block : func.body().getBlocks()) {
+ if (&block == entryBlock || &block == cleanupBlock ||
+ &block == suspendBlock)
+ continue;
+ Operation *terminator = block.getTerminator();
+ if (auto yield = dyn_cast<YieldOp>(terminator)) {
+ builder.setInsertionPointToEnd(&block);
+ builder.create<BranchOp>(cleanupBlock);
+ }
+ }
+
CoroMachinery machinery;
machinery.func = func;
machinery.asyncToken = retToken;
machinery.returnValues = retValues;
machinery.coroHandle = coroHdlOp.handle();
+ machinery.entry = entryBlock;
machinery.setError = nullptr; // created lazily only if needed
machinery.cleanup = cleanupBlock;
machinery.suspend = suspendBlock;
@@ -241,68 +256,69 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
symbolTable.insert(func);
SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
+ auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
+
+ // Prepare for coroutine conversion by creating the body of the function.
+ {
+ size_t numDependencies = execute.dependencies().size();
+ size_t numOperands = execute.operands().size();
+
+ // Await on all dependencies before starting to execute the body region.
+ for (size_t i = 0; i < numDependencies; ++i)
+ builder.create<AwaitOp>(func.getArgument(i));
+
+ // Await on all async value operands and unwrap the payload.
+ SmallVector<Value, 4> unwrappedOperands(numOperands);
+ for (size_t i = 0; i < numOperands; ++i) {
+ Value operand = func.getArgument(numDependencies + i);
+ unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
+ }
+
+ // Map from function inputs defined above the execute op to the function
+ // arguments.
+ BlockAndValueMapping valueMapping;
+ valueMapping.map(functionInputs, func.getArguments());
+ valueMapping.map(execute.body().getArguments(), unwrappedOperands);
+
+ // Clone all operations from the execute operation body into the outlined
+ // function body.
+ for (Operation &op : execute.body().getOps())
+ builder.clone(op, valueMapping);
+ }
- // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
- // blocks, adding async.coro operations and setting up control flow.
- func.addEntryBlock();
+ // Adding entry/cleanup/suspend blocks.
CoroMachinery coro = setupCoroMachinery(func);
// Suspend async function at the end of an entry block, and resume it using
// Async resume operation (execution will be resumed in a thread managed by
// the async runtime).
- Block *entryBlock = &func.getBlocks().front();
- auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock);
+ {
+ BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
+ builder.setInsertionPointToEnd(coro.entry);
- // Save the coroutine state: async.coro.save
- auto coroSaveOp =
- builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
+ // Save the coroutine state: async.coro.save
+ auto coroSaveOp =
+ builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
- // Pass coroutine to the runtime to be resumed on a runtime managed thread.
- builder.create<RuntimeResumeOp>(coro.coroHandle);
- builder.create<BranchOp>(coro.cleanup);
+ // Pass coroutine to the runtime to be resumed on a runtime managed
+ // thread.
+ builder.create<RuntimeResumeOp>(coro.coroHandle);
- // Split the entry block before the terminator (branch to suspend block).
- auto *terminatorOp = entryBlock->getTerminator();
- Block *suspended = terminatorOp->getBlock();
- Block *resume = suspended->splitBlock(terminatorOp);
-
- // Add async.coro.suspend as a suspended block terminator.
- builder.setInsertionPointToEnd(suspended);
- builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
- coro.cleanup);
-
- size_t numDependencies = execute.dependencies().size();
- size_t numOperands = execute.operands().size();
-
- // Await on all dependencies before starting to execute the body region.
- builder.setInsertionPointToStart(resume);
- for (size_t i = 0; i < numDependencies; ++i)
- builder.create<AwaitOp>(func.getArgument(i));
-
- // Await on all async value operands and unwrap the payload.
- SmallVector<Value, 4> unwrappedOperands(numOperands);
- for (size_t i = 0; i < numOperands; ++i) {
- Value operand = func.getArgument(numDependencies + i);
- unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
- }
+ // Add async.coro.suspend as the entry block terminator.
+ builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
+ branch.getDest(), coro.cleanup);
- // Map from function inputs defined above the execute op to the function
- // arguments.
- BlockAndValueMapping valueMapping;
- valueMapping.map(functionInputs, func.getArguments());
- valueMapping.map(execute.body().getArguments(), unwrappedOperands);
-
- // Clone all operations from the execute operation body into the outlined
- // function body.
- for (Operation &op : execute.body().getOps())
- builder.clone(op, valueMapping);
+ branch.erase();
+ }
// Replace the original `async.execute` with a call to outlined function.
- ImplicitLocOpBuilder callBuilder(loc, execute);
- auto callOutlinedFunc = callBuilder.create<CallOp>(
- func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
- execute.replaceAllUsesWith(callOutlinedFunc.getResults());
- execute.erase();
+ {
+ ImplicitLocOpBuilder callBuilder(loc, execute);
+ auto callOutlinedFunc = callBuilder.create<CallOp>(
+ func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
+ execute.replaceAllUsesWith(callOutlinedFunc.getResults());
+ execute.erase();
+ }
return {func, coro};
}
@@ -575,20 +591,15 @@ static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
[](Type type) { return ValueType::get(type); });
func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
func.insertResult(0, TokenType::get(ctx), {});
- CoroMachinery coro = setupCoroMachinery(func);
for (Block &block : func.getBlocks()) {
- if (&block == coro.suspend)
- continue;
-
Operation *terminator = block.getTerminator();
if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
ImplicitLocOpBuilder builder(loc, returnOp);
builder.create<YieldOp>(returnOp.getOperands());
- builder.create<BranchOp>(coro.cleanup);
returnOp.erase();
}
}
- return coro;
+ return setupCoroMachinery(func);
}
/// Rewrites a call into a function that has been rewritten as a coroutine.
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
index 7718d06888db4..5baf362552889 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -10,19 +10,20 @@ func @simple_callee(%arg0: f32) -> (f32 {builtin.foo = "bar"}) {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
-
-// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32
+// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
+// CHECK: ^[[ORIGINAL_ENTRY]]:
+// CHECK: %[[VAL:.*]] = addf %[[ARG]], %[[ARG]] : f32
%0 = addf %arg0, %arg0 : f32
-// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
+// CHECK: %[[VAL_STORAGE:.*]] = async.runtime.create : !async.value<f32>
%1 = async.runtime.create: !async.value<f32>
-// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK: async.runtime.store %[[VAL]], %[[VAL_STORAGE]] : !async.value<f32>
async.runtime.store %0, %1: !async.value<f32>
-// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
+// CHECK: async.runtime.set_available %[[VAL_STORAGE]] : !async.value<f32>
async.runtime.set_available %1: !async.value<f32>
-// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
-// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]]
-// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK: async.runtime.await_and_resume %[[VAL_STORAGE]], %[[HDL]]
+// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
%2 = async.await %1 : !async.value<f32>
@@ -62,13 +63,15 @@ func @simple_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
+// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
+// CHECK: ^[[ORIGINAL_ENTRY]]:
-// CHECK: %[[CONSTANT:.*]] = constant
+// CHECK: %[[CONSTANT:.*]] = constant
%c = constant 1.0 : f32
-// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
-// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
-// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]]
-// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER]]#0, %[[HDL]]
+// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
%r = call @simple_callee(%c): (f32) -> f32
@@ -109,13 +112,15 @@ func @double_caller() -> f32 {
// CHECK: %[[RETURNED_STORAGE:.*]] = async.runtime.create : !async.value<f32>
// CHECK: %[[ID:.*]] = async.coro.id
// CHECK: %[[HDL:.*]] = async.coro.begin %[[ID]]
+// CHECK: br ^[[ORIGINAL_ENTRY:.*]]
+// CHECK: ^[[ORIGINAL_ENTRY]]:
-// CHECK: %[[CONSTANT:.*]] = constant
+// CHECK: %[[CONSTANT:.*]] = constant
%c = constant 1.0 : f32
-// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
-// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]]
-// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]]
-// CHECK: async.coro.suspend %[[SAVED_1]]
+// CHECK: %[[RETURNED_TO_CALLER_1:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK: %[[SAVED_1:.*]] = async.coro.save %[[HDL]]
+// CHECK: async.runtime.await_and_resume %[[RETURNED_TO_CALLER_1]]#0, %[[HDL]]
+// CHECK: async.coro.suspend %[[SAVED_1]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME_1:.*]], ^[[CLEANUP:.*]]
%r = call @simple_callee(%c): (f32) -> f32
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 148a0b3c257d2..661d208e17662 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -328,8 +328,8 @@ func @async_value_operands() {
// -----
-// CHECK-LABEL: @execute_asserttion
-func @execute_asserttion(%arg0: i1) {
+// CHECK-LABEL: @execute_assertion
+func @execute_assertion(%arg0: i1) {
%token = async.execute {
assert %arg0, "error"
async.yield
More information about the Mlir-commits
mailing list