[Mlir-commits] [mlir] a6628e5 - [mlir] Async: add automatic reference counting at async.runtime operations level
Eugene Zhulenev
llvmlistbot at llvm.org
Mon Apr 12 18:55:02 PDT 2021
Author: Eugene Zhulenev
Date: 2021-04-12T18:54:55-07:00
New Revision: a6628e596e70bf5c31058dec582c8a7907928e98
URL: https://github.com/llvm/llvm-project/commit/a6628e596e70bf5c31058dec582c8a7907928e98
DIFF: https://github.com/llvm/llvm-project/commit/a6628e596e70bf5c31058dec582c8a7907928e98.diff
LOG: [mlir] Async: add automatic reference counting at async.runtime operations level
Depends On D95311
Previous automatic-ref-counting pass worked with high level async operations (e.g. async.execute), however async values reference counting is a runtime implementation detail.
New pass mostly relies on the same liveness analysis to place drop_ref operations, and does better verification of CFG with different liveIn sets in block successors.
This is an almost NFC change. No new reference counting ideas, just a cleanup of the previous version.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D95390
Added:
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir
mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
Modified:
mlir/include/mlir/Dialect/Async/Passes.h
mlir/include/mlir/Dialect/Async/Passes.td
mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
mlir/test/mlir-cpu-runner/async-group.mlir
mlir/test/mlir-cpu-runner/async-value.mlir
mlir/test/mlir-cpu-runner/async.mlir
Removed:
mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
mlir/test/Dialect/Async/async-ref-counting.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 1f38136a2a30c..ddcfc8bdaeddf 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -22,11 +22,11 @@ std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
std::unique_ptr<OperationPass<FuncOp>>
createAsyncParallelForPass(int numWorkerThreads);
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
+std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
-std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index 755d13a7cd221..155e23572bf80 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -24,24 +24,35 @@ def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
}
-def AsyncRefCounting : FunctionPass<"async-ref-counting"> {
- let summary = "Automatic reference counting for Async dialect data types";
- let constructor = "mlir::createAsyncRefCountingPass()";
+def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
+ let summary = "Lower high level async operations (e.g. async.execute) to the"
+ "explicit async.runtime and async.coro operations";
+ let constructor = "mlir::createAsyncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect"];
}
-def AsyncRefCountingOptimization :
- FunctionPass<"async-ref-counting-optimization"> {
- let summary = "Optimize automatic reference counting operations for the"
- "Async dialect by removing redundant operations";
- let constructor = "mlir::createAsyncRefCountingOptimizationPass()";
+def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
+ let summary = "Automatic reference counting for Async runtime operations";
+ let description = [{
+ This pass works at the async runtime abstraction level, after all
+ `async.execute` and `async.await` operations are lowered to the async
+ runtime API calls, and async coroutine operations.
+
+ It relies on the LLVM coroutines switched-resume lowering semantics for
+ the correct placing of the reference counting operations.
+
+ See: https://llvm.org/docs/Coroutines.html#switched-resume-lowering
+ }];
+
+ let constructor = "mlir::createAsyncRuntimeRefCountingPass()";
let dependentDialects = ["async::AsyncDialect"];
}
-def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
- let summary = "Lower high level async operations (e.g. async.execute) to the"
- "explicit async.rutime and async.coro operations";
- let constructor = "mlir::createAsyncToAsyncRuntimePass()";
+def AsyncRuntimeRefCountingOpt :
+ FunctionPass<"async-runtime-ref-counting-opt"> {
+ let summary = "Optimize automatic reference counting operations for the"
+ "Async runtime by removing redundant operations";
+ let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";
let dependentDialects = ["async::AsyncDialect"];
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
deleted file mode 100644
index d28da61bb2688..0000000000000
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
+++ /dev/null
@@ -1,325 +0,0 @@
-//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements automatic reference counting for Async dialect data
-// types.
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Analysis/Liveness.h"
-#include "mlir/Dialect/Async/IR/Async.h"
-#include "mlir/Dialect/Async/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallSet.h"
-
-using namespace mlir;
-using namespace mlir::async;
-
-#define DEBUG_TYPE "async-ref-counting"
-
-namespace {
-
-class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
-public:
- AsyncRefCountingPass() = default;
- void runOnFunction() override;
-
-private:
- /// Adds an automatic reference counting to the `value`.
- ///
- /// All values are semantically created with a reference count of +1 and it is
- /// the responsibility of the last async value user to drop reference count.
- ///
- /// Async values created when:
- /// 1. Operation returns async result (e.g. the result of an
- /// `async.execute`).
- /// 2. Async value passed in as a block argument.
- ///
- /// To implement automatic reference counting, we must insert a +1 reference
- /// before each `async.execute` operation using the value, and drop it after
- /// the last use inside the async body region (we currently drop the reference
- /// before the `async.yield` terminator).
- ///
- /// Automatic reference counting algorithm outline:
- ///
- /// 1. `ReturnLike` operations forward the reference counted values without
- /// modifying the reference count.
- ///
- /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
- /// reference counted values ends, and insert `drop_ref` operations after
- /// the last use of the value.
- ///
- /// 3. Insert `add_ref` before the `async.execute` operation capturing the
- /// value, and pairing `drop_ref` before the async body region terminator,
- /// to release the captured reference counted value when execution
- /// completes.
- ///
- /// 4. If the reference counted value is passed only to some of the block
- /// successors, insert `drop_ref` operations in the beginning of the blocks
- /// that do not have reference counted value uses.
- ///
- ///
- /// Example:
- ///
- /// %token = ...
- /// async.execute {
- /// async.await %token : !async.token // await #1
- /// async.yield
- /// }
- /// async.await %token : !async.token // await #2
- ///
- /// Based on the liveness analysis await #2 is the last use of the %token,
- /// however the execution of the async region can be delayed, and to guarantee
- /// that the %token is still alive when await #1 executes we need to
- /// explicitly extend its lifetime using `add_ref` operation.
- ///
- /// After automatic reference counting:
- ///
- /// %token = ...
- ///
- /// // Make sure that %token is alive inside async.execute.
- /// async.add_ref %token {count = 1 : i32} : !async.token
- ///
- /// async.execute {
- /// async.await %token : !async.token // await #1
- ///
- /// // Drop the extra reference added to keep %token alive.
- /// async.drop_ref %token {count = 1 : i32} : !async.token
- ///
- /// async.yied
- /// }
- /// async.await %token : !async.token // await #2
- ///
- /// // Drop the reference after the last use of %token.
- /// async.drop_ref %token {count = 1 : i32} : !async.token
- ///
- LogicalResult addAutomaticRefCounting(Value value);
-};
-
-} // namespace
-
-LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
- MLIRContext *ctx = value.getContext();
- OpBuilder builder(ctx);
-
- // Set inserton point after the operation producing a value, or at the
- // beginning of the block if the value defined by the block argument.
- if (Operation *op = value.getDefiningOp())
- builder.setInsertionPointAfter(op);
- else
- builder.setInsertionPointToStart(value.getParentBlock());
-
- Location loc = value.getLoc();
- auto i32 = IntegerType::get(ctx, 32);
-
- // Drop the reference count immediately if the value has no uses.
- if (value.getUses().empty()) {
- builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
- return success();
- }
-
- // Use liveness analysis to find the placement of `drop_ref`operation.
- auto liveness = getAnalysis<Liveness>();
-
- // We analyse only the blocks of the region that defines the `value`, and do
- // not check nested blocks attached to operations.
- //
- // By analyzing only the `definingRegion` CFG we potentially loose an
- // opportunity to drop the reference count earlier and can extend the lifetime
- // of reference counted value longer then it is really required.
- //
- // We also assume that all nested regions finish their execution before the
- // completion of the owner operation. The only exception to this rule is
- // `async.execute` operation, which is handled explicitly below.
- Region *definingRegion = value.getParentRegion();
-
- // ------------------------------------------------------------------------ //
- // Find blocks where the `value` dies: the value is in `liveIn` set and not
- // in the `liveOut` set. We place `drop_ref` immediately after the last use
- // of the `value` in such regions.
- // ------------------------------------------------------------------------ //
-
- // Last users of the `value` inside all blocks where the value dies.
- llvm::SmallSet<Operation *, 4> lastUsers;
-
- for (Block &block : definingRegion->getBlocks()) {
- const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
-
- // Value in live input set or was defined in the block.
- bool liveIn = blockLiveness->isLiveIn(value) ||
- blockLiveness->getBlock() == value.getParentBlock();
- if (!liveIn)
- continue;
-
- // Value is in the live out set.
- bool liveOut = blockLiveness->isLiveOut(value);
- if (liveOut)
- continue;
-
- // We proved that `value` dies in the `block`. Now find the last use of the
- // `value` inside the `block`.
-
- // Find any user of the `value` inside the block (including uses in nested
- // regions attached to the operations in the block).
- Operation *userInTheBlock = nullptr;
- for (Operation *user : value.getUsers()) {
- userInTheBlock = block.findAncestorOpInBlock(*user);
- if (userInTheBlock)
- break;
- }
-
- // Values with zero users handled explicitly in the beginning, if the value
- // is in live out set it must have at least one use in the block.
- assert(userInTheBlock && "value must have a user in the block");
-
- // Find the last user of the `value` in the block;
- Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
- assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
- lastUsers.insert(lastUser);
- }
-
- // Process all the last users of the `value` inside each block where the value
- // dies.
- for (Operation *lastUser : lastUsers) {
- // Return like operations forward reference count.
- if (lastUser->hasTrait<OpTrait::ReturnLike>())
- continue;
-
- // We can't currently handle other types of terminators.
- if (lastUser->hasTrait<OpTrait::IsTerminator>())
- return lastUser->emitError() << "async reference counting can't handle "
- "terminators that are not ReturnLike";
-
- // Add a drop_ref immediately after the last user.
- builder.setInsertionPointAfter(lastUser);
- builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
- }
-
- // ------------------------------------------------------------------------ //
- // Find blocks where the `value` is in `liveOut` set, however it is not in
- // the `liveIn` set of all successors. If the `value` is not in the successor
- // `liveIn` set, we add a `drop_ref` to the beginning of it.
- // ------------------------------------------------------------------------ //
-
- // Successors that we'll need a `drop_ref` for the `value`.
- llvm::SmallSet<Block *, 4> dropRefSuccessors;
-
- for (Block &block : definingRegion->getBlocks()) {
- const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
-
- // Skip the block if value is not in the `liveOut` set.
- if (!blockLiveness->isLiveOut(value))
- continue;
-
- // Find successors that do not have `value` in the `liveIn` set.
- for (Block *successor : block.getSuccessors()) {
- const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
-
- if (!succLiveness->isLiveIn(value))
- dropRefSuccessors.insert(successor);
- }
- }
-
- // Drop reference in all successor blocks that do not have the `value` in
- // their `liveIn` set.
- for (Block *dropRefSuccessor : dropRefSuccessors) {
- builder.setInsertionPointToStart(dropRefSuccessor);
- builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
- }
-
- // ------------------------------------------------------------------------ //
- // Find all `async.execute` operation that take `value` as an operand
- // (dependency token or async value), or capture implicitly by the nested
- // region. Each `async.execute` operation will require `add_ref` operation
- // to keep all captured values alive until it will finish its execution.
- // ------------------------------------------------------------------------ //
-
- llvm::SmallSet<ExecuteOp, 4> executeOperations;
-
- auto trackAsyncExecute = [&](Operation *op) {
- if (auto execute = dyn_cast<ExecuteOp>(op))
- executeOperations.insert(execute);
- };
-
- for (Operation *user : value.getUsers()) {
- // Follow parent operations up until the operation in the `definingRegion`.
- while (user->getParentRegion() != definingRegion) {
- trackAsyncExecute(user);
- user = user->getParentOp();
- assert(user != nullptr && "value user lies outside of the value region");
- }
-
- // Don't forget to process the parent in the `definingRegion` (can be the
- // original user operation itself).
- trackAsyncExecute(user);
- }
-
- // Process all `async.execute` operations capturing `value`.
- for (ExecuteOp execute : executeOperations) {
- // Add a reference before the execute operation to keep the reference
- // counted alive before the async region completes execution.
- builder.setInsertionPoint(execute.getOperation());
- builder.create<RuntimeAddRefOp>(loc, value, IntegerAttr::get(i32, 1));
-
- // Drop the reference inside the async region before completion.
- OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
- executeBuilder.create<RuntimeDropRefOp>(loc, value,
- IntegerAttr::get(i32, 1));
- }
-
- return success();
-}
-
-void AsyncRefCountingPass::runOnFunction() {
- FuncOp func = getFunction();
-
- // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
- // because otherwise automatic reference counting will produce incorrect
- // results.
- WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
- if (isa<RuntimeAddRefOp, RuntimeDropRefOp>(op))
- return op->emitError() << "explicit reference counting is not supported";
- return WalkResult::advance();
- });
-
- if (refCountingWalk.wasInterrupted())
- signalPassFailure();
-
- // Add reference counting to block arguments.
- WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
- for (BlockArgument arg : block->getArguments())
- if (isRefCounted(arg.getType()))
- if (failed(addAutomaticRefCounting(arg)))
- return WalkResult::interrupt();
-
- return WalkResult::advance();
- });
-
- if (blockWalk.wasInterrupted())
- signalPassFailure();
-
- // Add reference counting to operation results.
- WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
- for (unsigned i = 0; i < op->getNumResults(); ++i)
- if (isRefCounted(op->getResultTypes()[i]))
- if (failed(addAutomaticRefCounting(op->getResult(i))))
- return WalkResult::interrupt();
-
- return WalkResult::advance();
- });
-
- if (opWalk.wasInterrupted())
- signalPassFailure();
-}
-
-std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
- return std::make_unique<AsyncRefCountingPass>();
-}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
deleted file mode 100644
index 6ac2fd12fa113..0000000000000
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
+++ /dev/null
@@ -1,218 +0,0 @@
-//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Optimize Async dialect reference counting operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Async/IR/Async.h"
-#include "mlir/Dialect/Async/Passes.h"
-#include "llvm/ADT/SmallSet.h"
-
-using namespace mlir;
-using namespace mlir::async;
-
-#define DEBUG_TYPE "async-ref-counting"
-
-namespace {
-
-class AsyncRefCountingOptimizationPass
- : public AsyncRefCountingOptimizationBase<
- AsyncRefCountingOptimizationPass> {
-public:
- AsyncRefCountingOptimizationPass() = default;
- void runOnFunction() override;
-
-private:
- LogicalResult optimizeReferenceCounting(Value value);
-};
-
-} // namespace
-
-LogicalResult
-AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
- Region *definingRegion = value.getParentRegion();
-
- // Find all users of the `value` inside each block, including operations that
- // do not use `value` directly, but have a direct use inside nested region(s).
- //
- // Example:
- //
- // ^bb1:
- // %token = ...
- // scf.if %cond {
- // ^bb2:
- // async.await %token : !async.token
- // }
- //
- // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
- //
- // In addition to the operation that uses the `value` we also keep track if
- // this user is an `async.execute` operation itself, or has `async.execute`
- // operations in the nested regions that do use the `value`.
-
- struct UserInfo {
- Operation *operation;
- bool hasExecuteUser;
- };
-
- struct BlockUsersInfo {
- llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
- llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
- llvm::SmallVector<UserInfo, 4> users;
- };
-
- llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
-
- auto updateBlockUsersInfo = [&](UserInfo user) {
- BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
- info.users.push_back(user);
-
- if (auto addRef = dyn_cast<RuntimeAddRefOp>(user.operation))
- info.addRefs.push_back(addRef);
- if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user.operation))
- info.dropRefs.push_back(dropRef);
- };
-
- for (Operation *user : value.getUsers()) {
- bool isAsyncUser = isa<ExecuteOp>(user);
-
- while (user->getParentRegion() != definingRegion) {
- updateBlockUsersInfo({user, isAsyncUser});
- user = user->getParentOp();
- isAsyncUser |= isa<ExecuteOp>(user);
- assert(user != nullptr && "value user lies outside of the value region");
- }
-
- updateBlockUsersInfo({user, isAsyncUser});
- }
-
- // Sort all operations found in the block.
- auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
- auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
- return a->isBeforeInBlock(b);
- };
- llvm::sort(info.addRefs, isBeforeInBlock);
- llvm::sort(info.dropRefs, isBeforeInBlock);
- llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
- return isBeforeInBlock(a.operation, b.operation);
- });
-
- return info;
- };
-
- // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
- // blocks that modify the reference count of the `value`.
- for (auto &kv : blockUsers) {
- BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
-
- // Find all cancellable pairs first and erase them later to keep all
- // pointers in the `info` valid until the end.
- //
- // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
- llvm::SmallDenseMap<Operation *, Operation *> cancellable;
-
- for (RuntimeAddRefOp addRef : info.addRefs) {
- for (RuntimeDropRefOp dropRef : info.dropRefs) {
- // `drop_ref` operation after the `add_ref` with matching count.
- if (dropRef.count() != addRef.count() ||
- dropRef->isBeforeInBlock(addRef.getOperation()))
- continue;
-
- // `drop_ref` was already marked for removal.
- if (cancellable.find(dropRef.getOperation()) != cancellable.end())
- continue;
-
- // Check `value` users between `addRef` and `dropRef` in the `block`.
- Operation *addRefOp = addRef.getOperation();
- Operation *dropRefOp = dropRef.getOperation();
-
- // If there is a "regular" user after the `async.execute` user it is
- // unsafe to erase cancellable reference counting operations pair,
- // because async region can complete before the "regular" user and
- // destroy the reference counted value.
- bool hasExecuteUser = false;
- bool unsafeToCancel = false;
-
- for (UserInfo &user : info.users) {
- Operation *op = user.operation;
-
- // `user` operation lies after `addRef` ...
- if (op == addRefOp || op->isBeforeInBlock(addRefOp))
- continue;
- // ... and before `dropRef`.
- if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
- break;
-
- bool isRegularUser = !user.hasExecuteUser;
- bool isExecuteUser = user.hasExecuteUser;
-
- // It is unsafe to cancel `addRef` / `dropRef` pair.
- if (isRegularUser && hasExecuteUser) {
- unsafeToCancel = true;
- break;
- }
-
- hasExecuteUser |= isExecuteUser;
- }
-
- // Mark the pair of reference counting operations for removal.
- if (!unsafeToCancel)
- cancellable[dropRef.getOperation()] = addRef.getOperation();
-
- // If it us unsafe to cancel `addRef <-> dropRef` pair at this point,
- // all the following pairs will be also unsafe.
- break;
- }
- }
-
- // Erase all cancellable `addRef <-> dropRef` operation pairs.
- for (auto &kv : cancellable) {
- kv.first->erase();
- kv.second->erase();
- }
- }
-
- return success();
-}
-
-void AsyncRefCountingOptimizationPass::runOnFunction() {
- FuncOp func = getFunction();
-
- // Optimize reference counting for values defined by block arguments.
- WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
- for (BlockArgument arg : block->getArguments())
- if (isRefCounted(arg.getType()))
- if (failed(optimizeReferenceCounting(arg)))
- return WalkResult::interrupt();
-
- return WalkResult::advance();
- });
-
- if (blockWalk.wasInterrupted())
- signalPassFailure();
-
- // Optimize reference counting for values defined by operation results.
- WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
- for (unsigned i = 0; i < op->getNumResults(); ++i)
- if (isRefCounted(op->getResultTypes()[i]))
- if (failed(optimizeReferenceCounting(op->getResult(i))))
- return WalkResult::interrupt();
-
- return WalkResult::advance();
- });
-
- if (opWalk.wasInterrupted())
- signalPassFailure();
-}
-
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRefCountingOptimizationPass() {
- return std::make_unique<AsyncRefCountingOptimizationPass>();
-}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
new file mode 100644
index 0000000000000..af443918df970
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -0,0 +1,377 @@
+//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements automatic reference counting for Async runtime
+// operations and types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-runtime-ref-counting"
+
+namespace {
+
+class AsyncRuntimeRefCountingPass
+ : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
+public:
+ AsyncRuntimeRefCountingPass() = default;
+ void runOnFunction() override;
+
+private:
+ /// Adds automatic reference counting to the `value`.
+ ///
+ /// All values (token, group or value) are semantically created with a
+ /// reference count of +1 and it is the responsibility of the async value user
+ /// to place the `add_ref` and `drop_ref` operations to ensure that the value
+ /// is destroyed after the last use.
+ ///
+ /// The function returns failure if it can't deduce the locations where
+ /// to place the reference counting operations.
+ ///
+ /// Async values are "semantically created" when:
+ /// 1. Operation returns async result (e.g. `async.runtime.create`)
+ /// 2. Async value passed in as a block argument (or function argument,
+ /// because function arguments are just entry block arguments)
+ ///
+ /// Passing async value as a function argument (or block argument) does not
+ /// really mean that a new async value is created, it only means that the
+ /// caller of a function transferred ownership of `+1` reference to the callee.
+ /// It is convenient to think that from the callee perspective async value was
+ /// "created" with `+1` reference by the block argument.
+ ///
+ /// Automatic reference counting algorithm outline:
+ ///
+ /// #1 Insert `drop_ref` operations after last use of the `value`.
+ /// #2 Insert `add_ref` operations before function calls with reference
+ /// counted `value` operand (newly created `+1` reference will be
+ /// transferred to the callee).
+ /// #3 Verify that divergent control flow does not lead to leaked reference
+ /// counted objects.
+ ///
+ /// Async runtime reference counting optimization pass will optimize away
+ /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
+ /// strategy (see `async-runtime-ref-counting-opt`).
+ LogicalResult addAutomaticRefCounting(Value value);
+
+ /// (#1) Adds the `drop_ref` operation after the last use of the `value`
+ /// relying on the liveness analysis.
+ ///
+ /// If the `value` is in the block `liveIn` set and it is not in the block
+ /// `liveOut` set, it means that it "dies" in the block. We find the last
+ /// use of the value in such block and:
+ ///
+ /// 1. `ReturnLike` last user: do nothing, ownership goes to the caller.
+ /// 2. Any other terminator: emit an error and return failure.
+ /// 3. Otherwise add a `drop_ref` operation immediately after the last
+ /// use.
+ LogicalResult addDropRefAfterLastUse(Value value);
+
+ /// (#2) Adds the `add_ref` operation before the function call taking `value`
+ /// operand to ensure that the value passed to the function entry block
+ /// has a `+1` reference count.
+ LogicalResult addAddRefBeforeFunctionCall(Value value);
+
+ /// (#3) Verifies that if a block has a value in the `liveOut` set, then the
+ /// value is in `liveIn` set in all successors.
+ ///
+ /// Example:
+ ///
+ /// ^entry:
+ /// %token = async.runtime.create : !async.token
+ /// cond_br %cond, ^bb1, ^bb2
+ /// ^bb1:
+ /// async.runtime.await %token
+ /// return
+ /// ^bb2:
+ /// return
+ ///
+ /// This CFG will be rejected because ^bb2 does not have `value` in the
+ /// `liveIn` set, and it will leak a reference counted object.
+ ///
+ /// Blocks with an `async.coro.suspend` terminator are exempt from this rule,
+ /// because in Async to LLVM lowering it is guaranteed that the control flow
+ /// will jump into the resume block, and then follow into the cleanup and
+ /// suspend blocks.
+ ///
+ /// Example:
+ ///
+ /// ^entry(%value: !async.value<f32>):
+ /// async.runtime.await_and_resume %value, %hdl : !async.value<f32>
+ /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
+ /// ^resume:
+ /// %0 = async.runtime.load %value
+ /// br ^cleanup
+ /// ^cleanup:
+ /// ...
+ /// ^suspend:
+ /// ...
+ ///
+ /// Although cleanup and suspend blocks do not have the `value` in the
+ /// `liveIn` set, it is guaranteed that execution will eventually continue in
+ /// the resume block (we never explicitly destroy coroutines).
+ LogicalResult verifySuccessors(Value value);
+};
+
+} // namespace
+
+/// Inserts `async.runtime.drop_ref` after the last use of `value` in every
+/// block of the defining region where the value dies (not in `liveOut`).
+///
+/// Returns failure (after emitting an error) if the last user in such a block
+/// is a terminator that is not `ReturnLike`; `ReturnLike` last users are
+/// skipped because they forward the reference.
+LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
+  OpBuilder builder(value.getContext());
+  Location loc = value.getLoc();
+
+  // Use liveness analysis to find the placement of `drop_ref` operations.
+  auto &liveness = getAnalysis<Liveness>();
+
+  // We analyse only the blocks of the region that defines the `value`, and do
+  // not check nested blocks attached to operations.
+  //
+  // By analyzing only the `definingRegion` CFG we potentially lose an
+  // opportunity to drop the reference count earlier and can extend the lifetime
+  // of reference counted value longer than it is really required.
+  //
+  // We also assume that all nested regions finish their execution before the
+  // completion of the owner operation. The only exception to this rule is
+  // `async.execute` operation, and we verify that they are lowered to the
+  // `async.runtime` operations before adding automatic reference counting.
+  Region *definingRegion = value.getParentRegion();
+
+  // Last users of the `value` inside all blocks where the value dies.
+  llvm::SmallSet<Operation *, 4> lastUsers;
+
+  // Find blocks in the `definingRegion` that have users of the `value` (if
+  // there are multiple users in the block, which one will be selected is
+  // undefined). User operation might be not the actual user of the value, but
+  // the operation in the block that has a "real user" in one of the attached
+  // regions.
+  llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
+
+  for (Operation *user : value.getUsers()) {
+    Block *userBlock = user->getBlock();
+
+    // Check each pointer *before* it is dereferenced, so that a violated
+    // invariant is reported by the assertion and not by a crash.
+    Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
+    assert(ancestor && "ancestor block must be not null");
+
+    Operation *ancestorOp = ancestor->findAncestorOpInBlock(*user);
+    assert(ancestorOp && "ancestor op must be not null");
+
+    usersInTheBlocks[ancestor] = ancestorOp;
+  }
+
+  // Find blocks where the `value` dies: the value is in `liveIn` set and not
+  // in the `liveOut` set. We place `drop_ref` immediately after the last use
+  // of the `value` in such regions (after handling few special cases).
+  //
+  // We do not traverse all the blocks in the `definingRegion`, because the
+  // `value` can be in the live in set only if it has users in the block, or it
+  // is defined in the block.
+  //
+  // Values with zero users (only definition) are handled by the caller
+  // (`addAutomaticRefCounting`) before this function is invoked.
+  for (auto &blockAndUser : usersInTheBlocks) {
+    Block *block = blockAndUser.getFirst();
+    Operation *userInTheBlock = blockAndUser.getSecond();
+
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
+
+    // Value must be in the live input set or defined in the block.
+    assert(blockLiveness->isLiveIn(value) ||
+           blockLiveness->getBlock() == value.getParentBlock());
+
+    // If value is in the live out set, it means it doesn't "die" in the block.
+    if (blockLiveness->isLiveOut(value))
+      continue;
+
+    // At this point we proved that `value` dies in the `block`. Find the last
+    // use of the `value` inside the `block`, this is where it "dies".
+    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
+    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
+    lastUsers.insert(lastUser);
+  }
+
+  // Process all the last users of the `value` inside each block where the value
+  // dies.
+  for (Operation *lastUser : lastUsers) {
+    // Return like operations forward reference count.
+    if (lastUser->hasTrait<OpTrait::ReturnLike>())
+      continue;
+
+    // We can't currently handle other types of terminators.
+    if (lastUser->hasTrait<OpTrait::IsTerminator>())
+      return lastUser->emitError() << "async reference counting can't handle "
+                                      "terminators that are not ReturnLike";
+
+    // Add a drop_ref immediately after the last user.
+    builder.setInsertionPointAfter(lastUser);
+    builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
+  }
+
+  return success();
+}
+
+/// Inserts an `async.runtime.add_ref` (count 1) right before every `CallOp`
+/// user of `value`, so the callee's entry block receives its own `+1`
+/// reference. Always returns success.
+LogicalResult
+AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
+  OpBuilder b(value.getContext());
+  Location valueLoc = value.getLoc();
+
+  for (Operation *candidate : value.getUsers()) {
+    // Only function calls transfer a reference to another owner.
+    if (isa<CallOp>(candidate)) {
+      b.setInsertionPoint(candidate);
+      b.create<RuntimeAddRefOp>(valueLoc, value, b.getI32IntegerAttr(1));
+    }
+  }
+
+  return success();
+}
+
+/// Checks every block of the region defining `value`: if the value is in the
+/// block's `liveOut` set, then either all successors or none of them must have
+/// it in their `liveIn` set. Blocks terminated by `async.coro.suspend` are
+/// exempt. Emits an error on the offending terminator and returns failure
+/// otherwise.
+LogicalResult AsyncRuntimeRefCountingPass::verifySuccessors(Value value) {
+  // Blocks with successors with different `liveIn` properties of the `value`.
+  llvm::SmallSet<Block *, 4> divergentLivenessBlocks;
+
+  // Use liveness analysis to compare the `liveOut` set of each block with the
+  // `liveIn` sets of its successors.
+  auto &liveness = getAnalysis<Liveness>();
+
+  // Because we only add `drop_ref` operations to the region that defines the
+  // `value` we can only process CFG for the same region.
+  Region *definingRegion = value.getParentRegion();
+
+  // Collect blocks with successors with mismatching `liveIn` sets.
+  for (Block &block : definingRegion->getBlocks()) {
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+    // Skip the block if value is not in the `liveOut` set.
+    if (!blockLiveness->isLiveOut(value))
+      continue;
+
+    // Track whether we have seen a successor with the value in its `liveIn`
+    // set, and a successor without it.
+    bool hasLiveInSuccessor = false;
+    bool hasNoLiveInSuccessor = false;
+
+    for (Block *successor : block.getSuccessors()) {
+      const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
+      if (succLiveness->isLiveIn(value))
+        hasLiveInSuccessor = true;
+      else
+        hasNoLiveInSuccessor = true;
+    }
+
+    // Block has successors with different `liveIn` property of the `value`.
+    if (hasLiveInSuccessor && hasNoLiveInSuccessor)
+      divergentLivenessBlocks.insert(&block);
+  }
+
+  // Verify that divergent `liveIn` property only present in blocks with
+  // async.coro.suspend terminator.
+  for (Block *block : divergentLivenessBlocks) {
+    Operation *terminator = block->getTerminator();
+    if (isa<CoroSuspendOp>(terminator))
+      continue;
+
+    return terminator->emitOpError("successor have different `liveIn` property "
+                                   "of the reference counted value: ");
+  }
+
+  return success();
+}
+
+/// Entry point for reference counting a single `value`.
+///
+/// A value without uses gets a `drop_ref` right after its definition (or at
+/// the start of its block for block arguments). Otherwise the three steps run
+/// in order and stop at the first failure: `drop_ref` placement, `add_ref`
+/// placement before calls, and successor liveness verification.
+LogicalResult
+AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
+  // Drop the reference count immediately if the value has no uses.
+  if (value.getUses().empty()) {
+    OpBuilder dropBuilder(value.getContext());
+
+    // Insert after the operation producing the value, or at the beginning of
+    // the block if the value is defined by a block argument.
+    if (Operation *producer = value.getDefiningOp())
+      dropBuilder.setInsertionPointAfter(producer);
+    else
+      dropBuilder.setInsertionPointToStart(value.getParentBlock());
+
+    dropBuilder.create<RuntimeDropRefOp>(value.getLoc(), value,
+                                         dropBuilder.getI32IntegerAttr(1));
+    return success();
+  }
+
+  // Add `drop_ref` operations based on the liveness analysis, then `add_ref`
+  // operations before function calls (short-circuits on the first failure).
+  if (failed(addDropRefAfterLastUse(value)) ||
+      failed(addAddRefBeforeFunctionCall(value)))
+    return failure();
+
+  // Verify that the `value` is in `liveIn` set of all successors.
+  return verifySuccessors(value);
+}
+
+/// Runs automatic reference counting over the function: rejects IR that still
+/// contains high level async operations, then processes reference counted
+/// block arguments, then reference counted operation results. The pass is
+/// marked as failed and stops at the first step that reports a failure.
+void AsyncRuntimeRefCountingPass::runOnFunction() {
+  FuncOp func = getFunction();
+
+  // High level async operations must be gone by now: reference counting
+  // computed in their presence would become incorrect once execute operations
+  // are lowered to `async.runtime`.
+  WalkResult highLevelOpsWalk = func.walk([&](Operation *op) -> WalkResult {
+    if (isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
+      return op->emitError()
+             << "async operations must be lowered to async runtime operations";
+
+    return WalkResult::advance();
+  });
+
+  if (highLevelOpsWalk.wasInterrupted())
+    return signalPassFailure();
+
+  // Add reference counting to block arguments.
+  WalkResult blockArgsWalk = func.walk([&](Block *block) -> WalkResult {
+    for (BlockArgument arg : block->getArguments()) {
+      if (!isRefCounted(arg.getType()))
+        continue;
+      if (failed(addAutomaticRefCounting(arg)))
+        return WalkResult::interrupt();
+    }
+
+    return WalkResult::advance();
+  });
+
+  if (blockArgsWalk.wasInterrupted())
+    return signalPassFailure();
+
+  // Add reference counting to operation results.
+  WalkResult opResultsWalk = func.walk([&](Operation *op) -> WalkResult {
+    for (Value result : op->getResults()) {
+      if (!isRefCounted(result.getType()))
+        continue;
+      if (failed(addAutomaticRefCounting(result)))
+        return WalkResult::interrupt();
+    }
+
+    return WalkResult::advance();
+  });
+
+  if (opResultsWalk.wasInterrupted())
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRuntimeRefCountingPass() {
+  // Factory: returns a newly constructed automatic reference counting pass.
+  using PassT = AsyncRuntimeRefCountingPass;
+  return std::make_unique<PassT>();
+}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
new file mode 100644
index 0000000000000..cb00d706ce0c8
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
@@ -0,0 +1,177 @@
+//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimize Async dialect reference counting operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+/// Function pass that removes matching `add_ref` / `drop_ref` operation pairs.
+class AsyncRuntimeRefCountingOptPass
+    : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
+  /// Maps a `drop_ref` operation to the `add_ref` operation it cancels.
+  using CancellablePairs = llvm::SmallDenseMap<Operation *, Operation *>;
+
+  /// Finds cancellable `add_ref` / `drop_ref` pairs operating on `value` and
+  /// records them in `cancellable`; nothing is erased here.
+  LogicalResult optimizeReferenceCounting(Value value,
+                                          CancellablePairs &cancellable);
+
+public:
+  AsyncRuntimeRefCountingOptPass() = default;
+
+  /// Collects cancellable pairs for all reference counted values in the
+  /// function and erases them.
+  void runOnFunction() override;
+};
+
+} // namespace
+
+/// Finds `add_ref` / `drop_ref` pairs on `value` that cancel each other and
+/// records them in `cancellable` (keyed by the `drop_ref` operation).
+///
+/// Within each block, an `add_ref` is paired with the first not-yet-paired
+/// `drop_ref` that has the same count and is not before it in the block.
+/// Nothing is erased here; the caller erases the recorded pairs. Always
+/// returns success.
+LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
+    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
+  // Reference counting operations using the `value`, grouped by their block.
+  // Other users of the `value` do not participate in the pairing.
+  struct BlockRefCountingOps {
+    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
+    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
+  };
+
+  llvm::DenseMap<Block *, BlockRefCountingOps> blockRefCountingOps;
+
+  for (Operation *user : value.getUsers()) {
+    assert(value.getParentRegion()->isAncestor(user->getParentRegion()) &&
+           "value user lies outside of the value region");
+
+    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
+      blockRefCountingOps[user->getBlock()].addRefs.push_back(addRef);
+    else if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
+      blockRefCountingOps[user->getBlock()].dropRefs.push_back(dropRef);
+  }
+
+  auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
+    return a->isBeforeInBlock(b);
+  };
+
+  // Find matching pairs of `add_ref` / `drop_ref` operations in the blocks
+  // that modify the reference count of the `value`.
+  for (auto &kv : blockRefCountingOps) {
+    BlockRefCountingOps &ops = kv.second;
+
+    // Process operations in their order inside the block, so each `add_ref`
+    // is matched with the closest following `drop_ref`.
+    llvm::sort(ops.addRefs, isBeforeInBlock);
+    llvm::sort(ops.dropRefs, isBeforeInBlock);
+
+    for (RuntimeAddRefOp addRef : ops.addRefs) {
+      for (RuntimeDropRefOp dropRef : ops.dropRefs) {
+        // `drop_ref` operation after the `add_ref` with matching count.
+        if (dropRef.count() != addRef.count() ||
+            dropRef->isBeforeInBlock(addRef.getOperation()))
+          continue;
+
+        // Try to cancel the pair. Insertion fails when this `drop_ref` was
+        // already marked for removal; in that case try the next `drop_ref`.
+        // On success move on to the next `add_ref`.
+        if (cancellable
+                .try_emplace(dropRef.getOperation(), addRef.getOperation())
+                .second)
+          break;
+      }
+    }
+  }
+
+  return success();
+}
+
+/// Collects cancellable `add_ref` / `drop_ref` pairs for all reference counted
+/// block arguments and operation results in the function, then erases them.
+/// If collection fails, the pass is marked as failed and nothing is erased.
+void AsyncRuntimeRefCountingOptPass::runOnFunction() {
+  FuncOp func = getFunction();
+
+  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
+  //
+  // Find all cancellable pairs of operation and erase them in the end to keep
+  // all iterators valid while we are walking the function operations.
+  llvm::SmallDenseMap<Operation *, Operation *> cancellable;
+
+  // Optimize reference counting for values defined by block arguments.
+  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+    for (BlockArgument arg : block->getArguments())
+      if (isRefCounted(arg.getType()))
+        if (failed(optimizeReferenceCounting(arg, cancellable)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  // Stop at the first failure instead of continuing with a partial result.
+  if (blockWalk.wasInterrupted()) {
+    signalPassFailure();
+    return;
+  }
+
+  // Optimize reference counting for values defined by operation results.
+  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+    for (unsigned i = 0; i < op->getNumResults(); ++i)
+      if (isRefCounted(op->getResultTypes()[i]))
+        if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (opWalk.wasInterrupted()) {
+    signalPassFailure();
+    return;
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Found " << cancellable.size()
+                 << " cancellable reference counting operations\n";
+  });
+
+  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
+  for (auto &kv : cancellable) {
+    kv.first->erase();
+    kv.second->erase();
+  }
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRuntimeRefCountingOptPass() {
+  // Factory: returns a newly constructed reference counting optimization pass.
+  using PassT = AsyncRuntimeRefCountingOptPass;
+  return std::make_unique<PassT>();
+}
diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
index 8056758ac308a..45fb77f443a00 100644
--- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRAsyncTransforms
AsyncParallelFor.cpp
- AsyncRefCounting.cpp
- AsyncRefCountingOptimization.cpp
+ AsyncRuntimeRefCounting.cpp
+ AsyncRuntimeRefCountingOpt.cpp
AsyncToAsyncRuntime.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
deleted file mode 100644
index fdf326ef50340..0000000000000
--- a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
+++ /dev/null
@@ -1,114 +0,0 @@
-// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s
-
-// CHECK-LABEL: @cancellable_operations_0
-func @cancellable_operations_0(%arg0: !async.token) {
- // CHECK-NOT: async.runtime.add_ref
- // CHECK-NOT: async.runtime.drop_ref
- async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @cancellable_operations_1
-func @cancellable_operations_1(%arg0: !async.token) {
- // CHECK-NOT: async.runtime.add_ref
- // CHECK: async.execute
- async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
- async.execute [%arg0] {
- // CHECK: async.runtime.drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- // CHECK-NEXT: async.yield
- async.yield
- }
- // CHECK-NOT: async.runtime.drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @cancellable_operations_2
-func @cancellable_operations_2(%arg0: !async.token) {
- // CHECK: async.await
- // CHECK-NEXT: async.await
- // CHECK-NEXT: async.await
- // CHECK-NEXT: return
- async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
- async.await %arg0 : !async.token
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- async.await %arg0 : !async.token
- async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
- async.await %arg0 : !async.token
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- return
-}
-
-// CHECK-LABEL: @cancellable_operations_3
-func @cancellable_operations_3(%arg0: !async.token) {
- // CHECK-NOT: add_ref
- async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
- %token = async.execute {
- async.await %arg0 : !async.token
- // CHECK: async.runtime.drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- async.yield
- }
- // CHECK-NOT: async.runtime.drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- // CHECK: async.await
- async.await %arg0 : !async.token
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @not_cancellable_operations_0
-func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
- // It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible
- // that the body of the `async.execute` operation will run before the await
- // operation in the function body, and will destroy the `%arg0` token.
- // CHECK: add_ref
- async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
- %token = async.execute {
- // CHECK: async.await
- async.await %arg0 : !async.token
- // CHECK: async.runtime.drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- // CHECK: async.yield
- async.yield
- }
- // CHECK: async.await
- async.await %arg0 : !async.token
- // CHECK: drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @not_cancellable_operations_1
-func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
- // Same reason as above, although `async.execute` is inside the nested
- // region or "regular" operation.
- //
- // NOTE: This test is not correct w.r.t. reference counting, and at runtime
- // would leak %arg0 value if %arg1 is false. IR like this will not be
- // constructed by automatic reference counting pass, because it would
- // place `async.runtime.add_ref` right before the `async.execute`
- // inside `scf.if`.
-
- // CHECK: async.runtime.add_ref
- async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
- scf.if %arg1 {
- %token = async.execute {
- async.await %arg0 : !async.token
- // CHECK: async.runtime.drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- async.yield
- }
- }
- // CHECK: async.await
- async.await %arg0 : !async.token
- // CHECK: async.runtime.drop_ref
- async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
- // CHECK: return
- return
-}
diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir
deleted file mode 100644
index 403747c8725b8..0000000000000
--- a/mlir/test/Dialect/Async/async-ref-counting.mlir
+++ /dev/null
@@ -1,253 +0,0 @@
-// RUN: mlir-opt %s -async-ref-counting | FileCheck %s
-
-// CHECK-LABEL: @cond
-func private @cond() -> i1
-
-// CHECK-LABEL: @token_arg_no_uses
-func @token_arg_no_uses(%arg0: !async.token) {
- // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
- return
-}
-
-// CHECK-LABEL: @token_arg_conditional_await
-func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) {
- cond_br %arg1, ^bb1, ^bb2
-^bb1:
- // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
- return
-^bb2:
- // CHECK: async.await %arg0
- // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
- async.await %arg0 : !async.token
- return
-}
-
-// CHECK-LABEL: @token_no_uses
-func @token_no_uses() {
- // CHECK: %[[TOKEN:.*]] = async.execute
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- %token = async.execute {
- async.yield
- }
- return
-}
-
-// CHECK-LABEL: @token_return
-func @token_return() -> !async.token {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
- // CHECK: return %[[TOKEN]]
- return %token : !async.token
-}
-
-// CHECK-LABEL: @token_await
-func @token_await() {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
- // CHECK: async.await %[[TOKEN]]
- async.await %token : !async.token
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @token_await_and_return
-func @token_await_and_return() -> !async.token {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
- // CHECK: async.await %[[TOKEN]]
- // CHECK-NOT: async.runtime.drop_ref
- async.await %token : !async.token
- // CHECK: return %[[TOKEN]]
- return %token : !async.token
-}
-
-// CHECK-LABEL: @token_await_inside_scf_if
-func @token_await_inside_scf_if(%arg0: i1) {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
- // CHECK: scf.if %arg0 {
- scf.if %arg0 {
- // CHECK: async.await %[[TOKEN]]
- async.await %token : !async.token
- }
- // CHECK: }
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @token_conditional_await
-func @token_conditional_await(%arg0: i1) {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
- cond_br %arg0, ^bb1, ^bb2
-^bb1:
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- return
-^bb2:
- // CHECK: async.await %[[TOKEN]]
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- async.await %token : !async.token
- return
-}
-
-// CHECK-LABEL: @token_await_in_the_loop
-func @token_await_in_the_loop() {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
- br ^bb1
-^bb1:
- // CHECK: async.await %[[TOKEN]]
- async.await %token : !async.token
- %0 = call @cond(): () -> (i1)
- cond_br %0, ^bb1, ^bb2
-^bb2:
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- return
-}
-
-// CHECK-LABEL: @token_defined_in_the_loop
-func @token_defined_in_the_loop() {
- br ^bb1
-^bb1:
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
- // CHECK: async.await %[[TOKEN]]
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- async.await %token : !async.token
- %0 = call @cond(): () -> (i1)
- cond_br %0, ^bb1, ^bb2
-^bb2:
- return
-}
-
-// CHECK-LABEL: @token_capture
-func @token_capture() {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
-
- // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: %[[TOKEN_0:.*]] = async.execute
- %token_0 = async.execute {
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK-NEXT: async.yield
- async.await %token : !async.token
- async.yield
- }
- // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @token_nested_capture
-func @token_nested_capture() {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
-
- // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: %[[TOKEN_0:.*]] = async.execute
- %token_0 = async.execute {
- // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: %[[TOKEN_1:.*]] = async.execute
- %token_1 = async.execute {
- // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: %[[TOKEN_2:.*]] = async.execute
- %token_2 = async.execute {
- // CHECK: async.await %[[TOKEN]]
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- async.await %token : !async.token
- async.yield
- }
- // CHECK: async.runtime.drop_ref %[[TOKEN_2]] {count = 1 : i32}
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- async.yield
- }
- // CHECK: async.runtime.drop_ref %[[TOKEN_1]] {count = 1 : i32}
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- async.yield
- }
- // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @token_dependency
-func @token_dependency() {
- // CHECK: %[[TOKEN:.*]] = async.execute
- %token = async.execute {
- async.yield
- }
-
- // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: %[[TOKEN_0:.*]] = async.execute
- %token_0 = async.execute[%token] {
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK-NEXT: async.yield
- async.yield
- }
-
- // CHECK: async.await %[[TOKEN]]
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- async.await %token : !async.token
- // CHECK: async.await %[[TOKEN_0]]
- // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
- async.await %token_0 : !async.token
-
- // CHECK: return
- return
-}
-
-// CHECK-LABEL: @value_operand
-func @value_operand() -> f32 {
- // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute
- %token, %results = async.execute -> !async.value<f32> {
- %0 = constant 0.0 : f32
- async.yield %0 : f32
- }
-
- // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: async.runtime.add_ref %[[RESULTS]] {count = 1 : i32}
- // CHECK: %[[TOKEN_0:.*]] = async.execute
- %token_0 = async.execute[%token](%results as %arg0 : !async.value<f32>) {
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
- // CHECK: async.yield
- async.yield
- }
-
- // CHECK: async.await %[[TOKEN]]
- // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
- async.await %token : !async.token
-
- // CHECK: async.await %[[TOKEN_0]]
- // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
- async.await %token_0 : !async.token
-
- // CHECK: async.await %[[RESULTS]]
- // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
- %0 = async.await %results : !async.value<f32>
-
- // CHECK: return
- return %0 : f32
-}
diff --git a/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir
new file mode 100644
index 0000000000000..9b6bb1a5e7515
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s -async-runtime-ref-counting-opt | FileCheck %s
+
+func private @consume_token(%arg0: !async.token)
+
+// CHECK-LABEL: @cancellable_operations_0
+func @cancellable_operations_0(%arg0: !async.token) {
+ // CHECK-NOT: async.runtime.add_ref
+ // CHECK-NOT: async.runtime.drop_ref
+ async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_1
+func @cancellable_operations_1(%arg0: !async.token) {
+ // CHECK-NOT: async.runtime.add_ref
+ async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: call @consume_token
+ call @consume_token(%arg0): (!async.token) -> ()
+ // CHECK-NOT: async.runtime.drop_ref
+ async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_2
+func @cancellable_operations_2(%arg0: !async.token) {
+ // CHECK: async.runtime.await
+ // CHECK-NEXT: async.runtime.await
+ // CHECK-NEXT: async.runtime.await
+ // CHECK-NEXT: return
+ async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.runtime.await %arg0 : !async.token
+ async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+ async.runtime.await %arg0 : !async.token
+ async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.runtime.await %arg0 : !async.token
+ async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_3
+func @cancellable_operations_3(%arg0: !async.token) {
+ // CHECK-NOT: add_ref
+ async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: call @consume_token
+ call @consume_token(%arg0): (!async.token) -> ()
+ // CHECK-NOT: async.runtime.drop_ref
+ async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: async.runtime.await
+ async.runtime.await %arg0 : !async.token
+ // CHECK: return
+ return
+}
diff --git a/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
new file mode 100644
index 0000000000000..40ac40a930097
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
@@ -0,0 +1,215 @@
+// RUN: mlir-opt %s -async-runtime-ref-counting | FileCheck %s
+
+// CHECK-LABEL: @token
+func private @token() -> !async.token
+
+// CHECK-LABEL: @cond
+func private @cond() -> i1
+
+// CHECK-LABEL: @take_token
+func private @take_token(%arg0: !async.token)
+
+// CHECK-LABEL: @token_arg_no_uses
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_arg_no_uses(%arg0: !async.token) {
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ return
+}
+
+// CHECK-LABEL: @token_value_no_uses
+func @token_value_no_uses() {
+ // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ %0 = async.runtime.create : !async.token
+ return
+}
+
+// CHECK-LABEL: @token_returned_no_uses
+func @token_returned_no_uses() {
+ // CHECK: %[[TOKEN:.*]] = call @token
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ %0 = call @token() : () -> !async.token
+ return
+}
+
+// CHECK-LABEL: @token_arg_to_func
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_arg_to_func(%arg0: !async.token) {
+ // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
+ call @take_token(%arg0): (!async.token) -> ()
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} : !async.token
+ return
+}
+
+// CHECK-LABEL: @token_value_to_func
+func @token_value_to_func() {
+ // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+ %0 = async.runtime.create : !async.token
+ // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
+ call @take_token(%0): (!async.token) -> ()
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ return
+}
+
+// CHECK-LABEL: @token_arg_cond_br_await_with_fallthough
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_arg_cond_br_await_with_fallthough(%arg0: !async.token, %arg1: i1) {
+ // CHECK: cond_br
+ // CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
+ cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ // CHECK: ^[[BB1]]:
+ // CHECK: br ^[[BB2]]
+ br ^bb2
+^bb2:
+ // CHECK: ^[[BB2]]:
+ // CHECK: async.runtime.await %[[TOKEN]]
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.runtime.await %arg0 : !async.token
+ return
+}
+
+// CHECK-LABEL: @token_simple_return
+func @token_simple_return() -> !async.token {
+ // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+ %token = async.runtime.create : !async.token
+ // CHECK: return %[[TOKEN]]
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @token_coro_return
+// CHECK-NOT: async.runtime.drop_ref
+// CHECK-NOT: async.runtime.add_ref
+func @token_coro_return() -> !async.token {
+ %token = async.runtime.create : !async.token
+ %id = async.coro.id
+ %hdl = async.coro.begin %id
+ %saved = async.coro.save %hdl
+ async.runtime.resume %hdl
+ async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
+^resume:
+ br ^cleanup
+^cleanup:
+ async.coro.free %id, %hdl
+ br ^suspend
+^suspend:
+ async.coro.end %hdl
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @token_coro_await_and_resume
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
+ %token = async.runtime.create : !async.token
+ %id = async.coro.id
+ %hdl = async.coro.begin %id
+ %saved = async.coro.save %hdl
+ // CHECK: async.runtime.await_and_resume %[[TOKEN]]
+ async.runtime.await_and_resume %arg0, %hdl : !async.token
+ // CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
+^resume:
+ br ^cleanup
+^cleanup:
+ async.coro.free %id, %hdl
+ br ^suspend
+^suspend:
+ async.coro.end %hdl
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @value_coro_await_and_resume
+// CHECK: %[[VALUE:.*]]: !async.value<f32>
+func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
+ %token = async.runtime.create : !async.token
+ %id = async.coro.id
+ %hdl = async.coro.begin %id
+ %saved = async.coro.save %hdl
+ // CHECK: async.runtime.await_and_resume %[[VALUE]]
+ async.runtime.await_and_resume %arg0, %hdl : !async.value<f32>
+ // CHECK: async.coro.suspend
+ // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+ async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
+^resume:
+ // CHECK: ^[[RESUME]]:
+ // CHECK: %[[LOADED:.*]] = async.runtime.load %[[VALUE]]
+ // CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32}
+ %0 = async.runtime.load %arg0 : !async.value<f32>
+ // CHECK: addf %[[LOADED]], %[[LOADED]]
+ %1 = addf %0, %0 : f32
+ br ^cleanup
+^cleanup:
+ async.coro.free %id, %hdl
+ br ^suspend
+^suspend:
+ async.coro.end %hdl
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @outlined_async_execute
+// CHECK: %[[TOKEN:.*]]: !async.token
+func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
+ %0 = async.runtime.create : !async.token
+ %1 = async.coro.id
+ %2 = async.coro.begin %1
+ %3 = async.coro.save %2
+ async.runtime.resume %2
+ // CHECK: async.coro.suspend
+ async.coro.suspend %3, ^suspend, ^resume, ^cleanup
+^resume:
+ // CHECK: ^[[RESUME:.*]]:
+ %4 = async.coro.save %2
+ async.runtime.await_and_resume %arg0, %2 : !async.token
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: async.coro.suspend
+ async.coro.suspend %4, ^suspend, ^resume_1, ^cleanup
+^resume_1:
+ // CHECK: ^[[RESUME_1:.*]]:
+ // CHECK: async.runtime.set_available
+ async.runtime.set_available %0 : !async.token
+ br ^cleanup
+^cleanup:
+ // CHECK: ^[[CLEANUP:.*]]:
+ // CHECK: async.coro.free
+ async.coro.free %1, %2
+ br ^suspend
+^suspend:
+ // CHECK: ^[[SUSPEND:.*]]:
+ // CHECK: async.coro.end
+ async.coro.end %2
+ return %0 : !async.token
+}
+
+// CHECK-LABEL: @token_await_inside_nested_region
+// CHECK: %[[ARG:.*]]: i1
+func @token_await_inside_nested_region(%arg0: i1) {
+ // CHECK: %[[TOKEN:.*]] = call @token()
+ %token = call @token() : () -> !async.token
+ // CHECK: scf.if %[[ARG]] {
+ scf.if %arg0 {
+ // CHECK: async.runtime.await %[[TOKEN]]
+ async.runtime.await %token : !async.token
+ }
+ // CHECK: }
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_defined_in_the_loop
+func @token_defined_in_the_loop() {
+ br ^bb1
+^bb1:
+ // CHECK: ^[[BB1:.*]]:
+ // CHECK: %[[TOKEN:.*]] = call @token()
+ %token = call @token() : () -> !async.token
+ // CHECK: async.runtime.await %[[TOKEN]]
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.runtime.await %token : !async.token
+ %0 = call @cond(): () -> (i1)
+ cond_br %0, ^bb1, ^bb2
+^bb2:
+ // CHECK: ^[[BB2:.*]]:
+ // CHECK: return
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
index 0e29209c4fab7..dee9b1cd62eac 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
@@ -1,8 +1,9 @@
// RUN: mlir-opt %s \
// RUN: -linalg-tile-to-parallel-loops="linalg-tile-sizes=256" \
// RUN: -async-parallel-for="num-concurrent-async-execute=4" \
-// RUN: -async-ref-counting \
// RUN: -async-to-async-runtime \
+// RUN: -async-runtime-ref-counting \
+// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -lower-affine \
// RUN: -convert-linalg-to-loops \
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index e4d19bf5c2f9f..9f05ec8065dc5 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -async-parallel-for \
-// RUN: -async-ref-counting \
// RUN: -async-to-async-runtime \
+// RUN: -async-runtime-ref-counting \
+// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index d0f688c61226b..883a0bc4fab7b 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-to-async-runtime \
-// RUN: -async-ref-counting \
+// RUN: -async-runtime-ref-counting \
+// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
index a9149aef67d26..8216d1558c639 100644
--- a/mlir/test/mlir-cpu-runner/async-group.mlir
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -async-ref-counting \
-// RUN: -async-to-async-runtime \
+// RUN: mlir-opt %s -async-to-async-runtime \
+// RUN: -async-runtime-ref-counting \
+// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
diff --git a/mlir/test/mlir-cpu-runner/async-value.mlir b/mlir/test/mlir-cpu-runner/async-value.mlir
index e58f9e06c1aee..878256ea3ad47 100644
--- a/mlir/test/mlir-cpu-runner/async-value.mlir
+++ b/mlir/test/mlir-cpu-runner/async-value.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -async-ref-counting \
-// RUN: -async-to-async-runtime \
+// RUN: mlir-opt %s -async-to-async-runtime \
+// RUN: -async-runtime-ref-counting \
+// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-std-to-llvm \
diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir
index 30967928fc62b..417713e383577 100644
--- a/mlir/test/mlir-cpu-runner/async.mlir
+++ b/mlir/test/mlir-cpu-runner/async.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -async-ref-counting \
-// RUN: -async-to-async-runtime \
+// RUN: mlir-opt %s -async-to-async-runtime \
+// RUN: -async-runtime-ref-counting \
+// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-scf-to-std \
More information about the Mlir-commits
mailing list