[Mlir-commits] [mlir] f81f880 - [mlir] Lower async.func with async.coro and async.runtime operations
Eugene Zhulenev
llvmlistbot at llvm.org
Mon Nov 7 09:54:03 PST 2022
Author: yijiagu
Date: 2022-11-07T09:53:58-08:00
New Revision: f81f880871e04ef0284af14a141a58905e81cdd9
URL: https://github.com/llvm/llvm-project/commit/f81f880871e04ef0284af14a141a58905e81cdd9
DIFF: https://github.com/llvm/llvm-project/commit/f81f880871e04ef0284af14a141a58905e81cdd9.diff
LOG: [mlir] Lower async.func with async.coro and async.runtime operations
Lower async.func with async.coro and async.runtime operations
- This patch modifies the AsyncToAsyncRuntime pass to lower async.func ops to func.func ops with a coroutine CFG.
Example:
```
async.func @foo() -> !async.value<f32> {
%cst = arith.constant 42.0 : f32
return %cst: f32
}
```
After lowering:
```
func.func @foo() -> !async.value<f32> attributes {passthrough = ["presplitcoroutine"]} {
%0 = async.runtime.create : !async.value<f32>
%1 = async.coro.id
%2 = async.coro.begin %1
cf.br ^bb1
^bb1: // pred: ^bb0
%cst = arith.constant 4.200000e+01 : f32
async.runtime.store %cst, %0 : <f32>
async.runtime.set_available %0 : !async.value<f32>
cf.br ^bb2
^bb2: // pred: ^bb1
async.coro.free %1, %2
cf.br ^bb3
^bb3: // pred: ^bb2
async.coro.end %2
return %0 : !async.value<f32>
}
```
Reviewed By: ezhulenev
Differential Revision: https://reviews.llvm.org/D137462
Added:
mlir/test/mlir-cpu-runner/async-func.mlir
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/test/Dialect/Async/async-to-async-runtime.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 38f3717c70f9b..66c5b731b6e76 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -51,10 +51,6 @@ class AsyncToAsyncRuntimePass
} // namespace
-//===----------------------------------------------------------------------===//
-// async.execute op outlining to the coroutine functions.
-//===----------------------------------------------------------------------===//
-
/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension.
///
@@ -64,6 +60,12 @@ namespace {
struct CoroMachinery {
func::FuncOp func;
+ // Async function returns an optional token, followed by some async values
+ //
+ // async.func @foo() -> !async.value<T> {
+ // %cst = arith.constant 42.0 : T
+ // return %cst: T
+ // }
// Async execute region returns a completion token, and an async value for
// each yielded value.
//
@@ -71,12 +73,12 @@ struct CoroMachinery {
// %0 = arith.constant ... : T
// async.yield %0 : T
// }
- Value asyncToken; // token representing completion of the async region
+ Optional<Value> asyncToken; // returned completion token
llvm::SmallVector<Value, 4> returnValues; // returned async values
Value coroHandle; // coroutine handle (!async.coro.getHandle value)
Block *entry; // coroutine entry block
- Block *setError; // switch completion token and all values to error state
+ Optional<Block *> setError; // set returned values to error state
Block *cleanup; // coroutine cleanup block
Block *suspend; // coroutine suspension block
};
@@ -87,13 +89,9 @@ struct CoroMachinery {
/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
/// that branches into preexisting entry block. Also inserts trailing blocks.
///
-/// The result types of the passed `func` must start with an `async.token`
+/// Result types of the passed `func` must start with an optional `async.token`
/// and be continued with some number of `async.value`s.
///
-/// The func given to this function needs to have been preprocessed to have
-/// either branch or yield ops as terminators. Branches to the cleanup block are
-/// inserted after each yield.
-///
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
/// - `entry` block sets up the coroutine.
@@ -110,7 +108,7 @@ struct CoroMachinery {
/// ^entry(<function-arguments>):
/// %token = <async token> : !async.token // create async runtime token
/// %value = <async value> : !async.value<T> // create async value
-/// %id = async.coro.getId // create a coroutine id
+/// %id = async.coro.getId // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
/// cf.br ^preexisting_entry_block
///
@@ -142,11 +140,20 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// ------------------------------------------------------------------------ //
// Allocate async token/values that we will return from a ramp function.
// ------------------------------------------------------------------------ //
- auto retToken =
- builder.create<RuntimeCreateOp>(TokenType::get(ctx)).getResult();
+
+  // We treat TokenType as a state update marker to represent side-effects of
+  // async computations.
+ bool isStateful = func.getCallableResults().front().isa<TokenType>();
+
+ Optional<Value> retToken;
+ if (isStateful)
+ retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
llvm::SmallVector<Value, 4> retValues;
- for (auto resType : func.getCallableResults().drop_front())
+ ArrayRef<Type> resValueTypes = isStateful
+ ? func.getCallableResults().drop_front()
+ : func.getCallableResults();
+ for (auto resType : resValueTypes)
retValues.emplace_back(
builder.create<RuntimeCreateOp>(resType).getResult());
@@ -179,26 +186,17 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// Mark the end of a coroutine: async.coro.end
builder.create<CoroEndOp>(coroHdlOp.getHandle());
- // Return created `async.token` and `async.values` from the suspend block.
- // This will be the return value of a coroutine ramp function.
- SmallVector<Value, 4> ret{retToken};
+ // Return created optional `async.token` and `async.values` from the suspend
+ // block. This will be the return value of a coroutine ramp function.
+ SmallVector<Value, 4> ret;
+ if (retToken)
+ ret.push_back(*retToken);
ret.insert(ret.end(), retValues.begin(), retValues.end());
builder.create<func::ReturnOp>(ret);
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
- for (Block &block : func.getBody().getBlocks()) {
- if (&block == entryBlock || &block == cleanupBlock ||
- &block == suspendBlock)
- continue;
- Operation *terminator = block.getTerminator();
- if (auto yield = dyn_cast<YieldOp>(terminator)) {
- builder.setInsertionPointToEnd(&block);
- builder.create<cf::BranchOp>(cleanupBlock);
- }
- }
-
// The switch-resumed API based coroutine should be marked with
// coroutine.presplit attribute to mark the function as a coroutine.
func->setAttr("passthrough", builder.getArrayAttr(
@@ -210,7 +208,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
machinery.returnValues = retValues;
machinery.coroHandle = coroHdlOp.getHandle();
machinery.entry = entryBlock;
- machinery.setError = nullptr; // created lazily only if needed
+ machinery.setError = None; // created lazily only if needed
machinery.cleanup = cleanupBlock;
machinery.suspend = suspendBlock;
return machinery;
@@ -220,25 +218,31 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// runtime operations (see for example lowering of assert operation).
static Block *setupSetErrorBlock(CoroMachinery &coro) {
if (coro.setError)
- return coro.setError;
+ return *coro.setError;
coro.setError = coro.func.addBlock();
- coro.setError->moveBefore(coro.cleanup);
+ (*coro.setError)->moveBefore(coro.cleanup);
auto builder =
- ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
+ ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
// Coroutine set_error block: set error on token and all returned values.
- builder.create<RuntimeSetErrorOp>(coro.asyncToken);
+ if (coro.asyncToken)
+ builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
+
for (Value retValue : coro.returnValues)
builder.create<RuntimeSetErrorOp>(retValue);
// Branch into the cleanup block.
builder.create<cf::BranchOp>(coro.cleanup);
- return coro.setError;
+ return *coro.setError;
}
+//===----------------------------------------------------------------------===//
+// async.execute op outlining to the coroutine functions.
+//===----------------------------------------------------------------------===//
+
/// Outline the body region attached to the `async.execute` op into a standalone
/// function.
///
@@ -382,6 +386,118 @@ class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
};
} // namespace
+//===----------------------------------------------------------------------===//
+// Convert async.func, async.return and async.call operations to non-blocking
+// operations based on LLVM coroutines.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Convert async.func operation to func.func
+//===----------------------------------------------------------------------===//
+
+class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
+public:
+ AsyncFuncOpLowering(MLIRContext *ctx,
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+ : OpConversionPattern<async::FuncOp>(ctx), coros(coros) {}
+
+ LogicalResult
+ matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ auto newFuncOp =
+ rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
+
+ SymbolTable::setSymbolVisibility(newFuncOp,
+ SymbolTable::getSymbolVisibility(op));
+ // Copy over all attributes other than the name.
+ for (const auto &namedAttr : op->getAttrs()) {
+ if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
+ newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+ }
+
+ rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+
+ CoroMachinery coro = setupCoroMachinery(newFuncOp);
+ coros[newFuncOp] = coro;
+    // No initial suspend is inserted: the coroutine should hot-start.
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+};
+
+//===----------------------------------------------------------------------===//
+// Convert async.call operation to func.call
+//===----------------------------------------------------------------------===//
+
+class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
+public:
+ AsyncCallOpLowering(MLIRContext *ctx)
+ : OpConversionPattern<async::CallOp>(ctx) {}
+
+ LogicalResult
+ matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<func::CallOp>(
+ op, op.getCallee(), op.getResultTypes(), op.getOperands());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Convert async.return operation to async.runtime operations.
+//===----------------------------------------------------------------------===//
+
+class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
+public:
+ AsyncReturnOpLowering(MLIRContext *ctx,
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+ : OpConversionPattern<async::ReturnOp>(ctx), coros(coros) {}
+
+ LogicalResult
+ matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto func = op->template getParentOfType<func::FuncOp>();
+ auto funcCoro = coros.find(func);
+ if (funcCoro == coros.end())
+ return rewriter.notifyMatchFailure(
+ op, "operation is not inside the async coroutine function");
+
+ Location loc = op->getLoc();
+ const CoroMachinery &coro = funcCoro->getSecond();
+ rewriter.setInsertionPointAfter(op);
+
+ // Store return values into the async values storage and switch async
+ // values state to available.
+ for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
+ Value returnValue = std::get<0>(tuple);
+ Value asyncValue = std::get<1>(tuple);
+ rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
+ rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
+ }
+
+ if (coro.asyncToken)
+ // Switch the coroutine completion token to available state.
+ rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
+
+ rewriter.eraseOp(op);
+ rewriter.create<cf::BranchOp>(loc, coro.cleanup);
+ return success();
+ }
+
+private:
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// Convert async.await and async.await_all operations to the async.runtime.await
// or async.runtime.await_and_resume operations.
@@ -393,11 +509,9 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
using AwaitAdaptor = typename AwaitType::Adaptor;
public:
- AwaitOpLoweringBase(
- MLIRContext *ctx,
- llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
- : OpConversionPattern<AwaitType>(ctx),
- outlinedFunctions(outlinedFunctions) {}
+ AwaitOpLoweringBase(MLIRContext *ctx,
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+ : OpConversionPattern<AwaitType>(ctx), coros(coros) {}
LogicalResult
matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
@@ -409,8 +523,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Check if await operation is inside the outlined coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
- auto outlined = outlinedFunctions.find(func);
- const bool isInCoroutine = outlined != outlinedFunctions.end();
+ auto funcCoro = coros.find(func);
+ const bool isInCoroutine = funcCoro != coros.end();
Location loc = op->getLoc();
Value operand = adaptor.getOperand();
@@ -436,7 +550,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside the coroutine we convert await operation into coroutine suspension
// point, and resume execution asynchronously.
if (isInCoroutine) {
- CoroMachinery &coro = outlined->getSecond();
+ CoroMachinery &coro = funcCoro->getSecond();
Block *suspended = op->getBlock();
ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
@@ -488,7 +602,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
}
private:
- llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
};
/// Lowering for `async.await` with a token operand.
@@ -531,24 +645,22 @@ class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
public:
- YieldOpLowering(
- MLIRContext *ctx,
- const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
- : OpConversionPattern<async::YieldOp>(ctx),
- outlinedFunctions(outlinedFunctions) {}
+ YieldOpLowering(MLIRContext *ctx,
+ const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+ : OpConversionPattern<async::YieldOp>(ctx), coros(coros) {}
LogicalResult
matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if yield operation is inside the async coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
- auto outlined = outlinedFunctions.find(func);
- if (outlined == outlinedFunctions.end())
+ auto funcCoro = coros.find(func);
+ if (funcCoro == coros.end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
Location loc = op->getLoc();
- const CoroMachinery &coro = outlined->getSecond();
+ const CoroMachinery &coro = funcCoro->getSecond();
// Store yielded values into the async values storage and switch async
// values state to available.
@@ -559,14 +671,18 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
}
- // Switch the coroutine completion token to available state.
- rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
+ if (coro.asyncToken)
+ // Switch the coroutine completion token to available state.
+ rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
+
+ rewriter.eraseOp(op);
+ rewriter.create<cf::BranchOp>(loc, coro.cleanup);
return success();
}
private:
- const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
+ const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
};
//===----------------------------------------------------------------------===//
@@ -575,24 +691,22 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
public:
- AssertOpLowering(
- MLIRContext *ctx,
- llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
- : OpConversionPattern<cf::AssertOp>(ctx),
- outlinedFunctions(outlinedFunctions) {}
+ AssertOpLowering(MLIRContext *ctx,
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+ : OpConversionPattern<cf::AssertOp>(ctx), coros(coros) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if assert operation is inside the async coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
- auto outlined = outlinedFunctions.find(func);
- if (outlined == outlinedFunctions.end())
+ auto funcCoro = coros.find(func);
+ if (funcCoro == coros.end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
Location loc = op->getLoc();
- CoroMachinery &coro = outlined->getSecond();
+ CoroMachinery &coro = funcCoro->getSecond();
Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
rewriter.setInsertionPointToEnd(cont->getPrevNode());
@@ -607,7 +721,7 @@ class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
}
private:
- llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
+ llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
};
//===----------------------------------------------------------------------===//
@@ -615,22 +729,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
ModuleOp module = getOperation();
SymbolTable symbolTable(module);
- // Outline all `async.execute` body regions into async functions (coroutines).
- llvm::DenseMap<func::FuncOp, CoroMachinery> outlinedFunctions;
+ // Functions with coroutine CFG setups, which are results of outlining
+ // `async.execute` body regions and converting async.func.
+ llvm::DenseMap<func::FuncOp, CoroMachinery> coros;
module.walk([&](ExecuteOp execute) {
- outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
+ coros.insert(outlineExecuteOp(symbolTable, execute));
});
LLVM_DEBUG({
- llvm::dbgs() << "Outlined " << outlinedFunctions.size()
+ llvm::dbgs() << "Outlined " << coros.size()
<< " functions built from async.execute operations\n";
});
// Returns true if operation is inside the coroutine.
auto isInCoroutine = [&](Operation *op) -> bool {
auto parentFunc = op->getParentOfType<func::FuncOp>();
- return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
+ return coros.find(parentFunc) != coros.end();
};
// Lower async operations to async.runtime operations.
@@ -646,18 +761,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
+
+ // Lower async.func to func.func with coroutine cfg.
+ asyncPatterns.add<AsyncCallOpLowering>(ctx);
+ asyncPatterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
+
asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
- AwaitAllOpLowering, YieldOpLowering>(ctx,
- outlinedFunctions);
+ AwaitAllOpLowering, YieldOpLowering>(ctx, coros);
// Lower assertions to conditional branches into error blocks.
- asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
+ asyncPatterns.add<AssertOpLowering>(ctx, coros);
// All high level async operations must be lowered to the runtime operations.
ConversionTarget runtimeTarget(*ctx);
- runtimeTarget.addLegalDialect<AsyncDialect>();
+ runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
- runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
+ runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp,
+ async::FuncOp, async::CallOp, async::ReturnOp>();
// Decide if structured control flow has to be lowered to branch-based CFG.
runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
@@ -675,7 +795,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
[&](cf::AssertOp op) -> bool {
auto func = op->getParentOfType<func::FuncOp>();
- return outlinedFunctions.find(func) == outlinedFunctions.end();
+ return coros.find(func) == coros.end();
});
if (failed(applyPartialConversion(module, runtimeTarget,
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index d7ebfb9e77926..1551e55c90c08 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -433,3 +433,25 @@ func.func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK-SAME: ) -> !async.token
// CHECK: %[[CST:.*]] = arith.constant 0 : index
// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]
+
+// -----
+// Async functions should be non-blocking.
+
+// CHECK-LABEL: @async_func_await
+async.func @async_func_await(%arg0: f32, %arg1: !async.value<f32>)
+ -> !async.token {
+ %0 = async.await %arg1 : !async.value<f32>
+ return
+}
+// Create token for return op, and mark a function as a coroutine.
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[ID:.*]] = async.coro.id
+// CHECK: %[[HDL:.*]] = async.coro.begin
+// CHECK: cf.br ^[[ORIGIN_ENTRY:.*]]
+
+// CHECK: ^[[ORIGIN_ENTRY]]:
+// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK: async.runtime.await_and_resume %[[arg1:.*]], %[[HDL]] :
+// CHECK-SAME: !async.value<f32>
+// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
diff --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
new file mode 100644
index 0000000000000..8b3d728d4667f
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner \
+// RUN: -e main -entry-point-result=void -O0 \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_async_runtime%shlibext \
+// RUN: | FileCheck %s --dump-input=always
+
+// FIXME: https://github.com/llvm/llvm-project/issues/57231
+// UNSUPPORTED: hwasan
+
+async.func @async_func_empty() -> !async.token {
+ return
+}
+
+async.func @async_func_assert() -> !async.token {
+ %false = arith.constant 0 : i1
+ cf.assert %false, "error"
+ return
+}
+
+async.func @async_func_nested_assert() -> !async.token {
+ %token0 = async.call @async_func_assert() : () -> !async.token
+ async.await %token0 : !async.token
+ return
+}
+
+async.func @async_func_value_assert() -> !async.value<f32> {
+ %false = arith.constant 0 : i1
+ cf.assert %false, "error"
+ %0 = arith.constant 123.45 : f32
+ return %0 : f32
+}
+
+async.func @async_func_value_nested_assert() -> !async.value<f32> {
+ %value0 = async.call @async_func_value_assert() : () -> !async.value<f32>
+ %ret = async.await %value0 : !async.value<f32>
+ return %ret : f32
+}
+
+async.func @async_func_return_value() -> !async.value<f32> {
+ %0 = arith.constant 456.789 : f32
+ return %0 : f32
+}
+
+async.func @async_func_non_blocking_await() -> !async.value<f32> {
+ %value0 = async.call @async_func_return_value() : () -> !async.value<f32>
+ %1 = async.await %value0 : !async.value<f32>
+ return %1 : f32
+}
+
+async.func @async_func_inside_memref() -> !async.value<memref<f32>> {
+ %0 = memref.alloc() : memref<f32>
+ %c0 = arith.constant 0.25 : f32
+ memref.store %c0, %0[] : memref<f32>
+ return %0 : memref<f32>
+}
+
+async.func @async_func_passed_memref(%arg0 : !async.value<memref<f32>>) -> !async.token {
+ %unwrapped = async.await %arg0 : !async.value<memref<f32>>
+ %0 = memref.load %unwrapped[] : memref<f32>
+ %1 = arith.addf %0, %0 : f32
+ memref.store %1, %unwrapped[] : memref<f32>
+ return
+}
+
+
+func.func @main() {
+ %false = arith.constant 0 : i1
+
+ // ------------------------------------------------------------------------ //
+ // Check that simple async.func completes without errors.
+ // ------------------------------------------------------------------------ //
+ %token0 = async.call @async_func_empty() : () -> !async.token
+ async.runtime.await %token0 : !async.token
+
+ // CHECK: 0
+ %err0 = async.runtime.is_error %token0 : !async.token
+ vector.print %err0 : i1
+
+ // ------------------------------------------------------------------------ //
+  // Check that an assertion in the async.func is converted to an async error.
+ // ------------------------------------------------------------------------ //
+ %token1 = async.call @async_func_assert() : () -> !async.token
+ async.runtime.await %token1 : !async.token
+
+ // CHECK: 1
+ %err1 = async.runtime.is_error %token1 : !async.token
+ vector.print %err1 : i1
+
+ // ------------------------------------------------------------------------ //
+ // Check error propagation from the nested async.func.
+ // ------------------------------------------------------------------------ //
+ %token2 = async.call @async_func_nested_assert() : () -> !async.token
+ async.runtime.await %token2 : !async.token
+
+ // CHECK: 1
+ %err2 = async.runtime.is_error %token2 : !async.token
+ vector.print %err2 : i1
+
+ // ------------------------------------------------------------------------ //
+ // Check error propagation from the nested async.func with async values.
+ // ------------------------------------------------------------------------ //
+ %value3 = async.call @async_func_value_nested_assert() : () -> !async.value<f32>
+ async.runtime.await %value3 : !async.value<f32>
+
+ // CHECK: 1
+ %err3_0 = async.runtime.is_error %value3 : !async.value<f32>
+ vector.print %err3_0 : i1
+
+ // ------------------------------------------------------------------------ //
+ // Non-blocking async.await inside the async.func
+ // ------------------------------------------------------------------------ //
+ %result0 = async.call @async_func_non_blocking_await() : () -> !async.value<f32>
+ %4 = async.await %result0 : !async.value<f32>
+
+ // CHECK: 456.789
+ vector.print %4 : f32
+
+ // ------------------------------------------------------------------------ //
+ // Memref allocated inside async.func.
+ // ------------------------------------------------------------------------ //
+ %result1 = async.call @async_func_inside_memref() : () -> !async.value<memref<f32>>
+ %5 = async.await %result1 : !async.value<memref<f32>>
+ %6 = memref.cast %5 : memref<f32> to memref<*xf32>
+
+ // CHECK: Unranked Memref
+ // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+ // CHECK-NEXT: [0.25]
+ call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+
+ // ------------------------------------------------------------------------ //
+ // Memref passed as async.func parameter
+ // ------------------------------------------------------------------------ //
+ %token3 = async.call @async_func_passed_memref(%result1) : (!async.value<memref<f32>>) -> !async.token
+ async.await %token3 : !async.token
+
+ // CHECK: Unranked Memref
+ // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+ // CHECK-NEXT: [0.5]
+ call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+
+ memref.dealloc %5 : memref<f32>
+
+ return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
+ attributes { llvm.emit_c_interface }
More information about the Mlir-commits
mailing list