[Mlir-commits] [mlir] a6628e5 - [mlir] Async: add automatic reference counting at async.runtime operations level

Eugene Zhulenev llvmlistbot at llvm.org
Mon Apr 12 18:55:02 PDT 2021


Author: Eugene Zhulenev
Date: 2021-04-12T18:54:55-07:00
New Revision: a6628e596e70bf5c31058dec582c8a7907928e98

URL: https://github.com/llvm/llvm-project/commit/a6628e596e70bf5c31058dec582c8a7907928e98
DIFF: https://github.com/llvm/llvm-project/commit/a6628e596e70bf5c31058dec582c8a7907928e98.diff

LOG: [mlir] Async: add automatic reference counting at async.runtime operations level

Depends On D95311

The previous automatic-ref-counting pass worked with high-level async operations (e.g. async.execute); however, reference counting of async values is a runtime implementation detail.

The new pass mostly relies on the same liveness analysis to place drop_ref operations, and does better verification of CFGs whose block successors have different liveIn sets.

This is an almost-NFC change: no new reference counting ideas, just a cleanup of the previous version.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D95390

Added: 
    mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
    mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir
    mlir/test/Dialect/Async/async-runtime-ref-counting.mlir

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/include/mlir/Dialect/Async/Passes.td
    mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
    mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
    mlir/test/mlir-cpu-runner/async-group.mlir
    mlir/test/mlir-cpu-runner/async-value.mlir
    mlir/test/mlir-cpu-runner/async.mlir

Removed: 
    mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
    mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
    mlir/test/Dialect/Async/async-ref-counting.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 1f38136a2a30c..ddcfc8bdaeddf 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -22,11 +22,11 @@ std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
 std::unique_ptr<OperationPass<FuncOp>>
 createAsyncParallelForPass(int numWorkerThreads);
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
+std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
 
-std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
 
 //===----------------------------------------------------------------------===//
 // Registration

diff  --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index 755d13a7cd221..155e23572bf80 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -24,24 +24,35 @@ def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
   let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
 }
 
-def AsyncRefCounting : FunctionPass<"async-ref-counting"> {
-  let summary = "Automatic reference counting for Async dialect data types";
-  let constructor = "mlir::createAsyncRefCountingPass()";
+def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
+  let summary = "Lower high level async operations (e.g. async.execute) to the"
+                " explicit async.runtime and async.coro operations";
+  let constructor = "mlir::createAsyncToAsyncRuntimePass()";
   let dependentDialects = ["async::AsyncDialect"];
 }
 
-def AsyncRefCountingOptimization :
-    FunctionPass<"async-ref-counting-optimization"> {
-  let summary = "Optimize automatic reference counting operations for the"
-                "Async dialect by removing redundant operations";
-  let constructor = "mlir::createAsyncRefCountingOptimizationPass()";
+def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
+  let summary = "Automatic reference counting for Async runtime operations";
+  let description = [{
+    This pass works at the async runtime abstraction level, after all
+    `async.execute` and `async.await` operations are lowered to the async
+    runtime API calls, and async coroutine operations.
+
+    It relies on the LLVM coroutines switched-resume lowering semantics for
+    the correct placing of the reference counting operations.
+
+    See: https://llvm.org/docs/Coroutines.html#switched-resume-lowering
+  }];
+
+  let constructor = "mlir::createAsyncRuntimeRefCountingPass()";
   let dependentDialects = ["async::AsyncDialect"];
 }
 
-def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
-  let summary = "Lower high level async operations (e.g. async.execute) to the"
-                "explicit async.rutime and async.coro operations";
-  let constructor = "mlir::createAsyncToAsyncRuntimePass()";
+def AsyncRuntimeRefCountingOpt :
+    FunctionPass<"async-runtime-ref-counting-opt"> {
+  let summary = "Optimize automatic reference counting operations for the"
+                " Async runtime by removing redundant operations";
+  let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";
   let dependentDialects = ["async::AsyncDialect"];
 }
 

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
deleted file mode 100644
index d28da61bb2688..0000000000000
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
+++ /dev/null
@@ -1,325 +0,0 @@
-//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements automatic reference counting for Async dialect data
-// types.
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Analysis/Liveness.h"
-#include "mlir/Dialect/Async/IR/Async.h"
-#include "mlir/Dialect/Async/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallSet.h"
-
-using namespace mlir;
-using namespace mlir::async;
-
-#define DEBUG_TYPE "async-ref-counting"
-
-namespace {
-
-class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
-public:
-  AsyncRefCountingPass() = default;
-  void runOnFunction() override;
-
-private:
-  /// Adds an automatic reference counting to the `value`.
-  ///
-  /// All values are semantically created with a reference count of +1 and it is
-  /// the responsibility of the last async value user to drop reference count.
-  ///
-  /// Async values created when:
-  ///   1. Operation returns async result (e.g. the result of an
-  ///      `async.execute`).
-  ///   2. Async value passed in as a block argument.
-  ///
-  /// To implement automatic reference counting, we must insert a +1 reference
-  /// before each `async.execute` operation using the value, and drop it after
-  /// the last use inside the async body region (we currently drop the reference
-  /// before the `async.yield` terminator).
-  ///
-  /// Automatic reference counting algorithm outline:
-  ///
-  /// 1. `ReturnLike` operations forward the reference counted values without
-  ///     modifying the reference count.
-  ///
-  /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
-  ///    reference counted values ends, and insert `drop_ref` operations after
-  ///    the last use of the value.
-  ///
-  /// 3. Insert `add_ref` before the `async.execute` operation capturing the
-  ///    value, and pairing `drop_ref` before the async body region terminator,
-  ///    to release the captured reference counted value when execution
-  ///    completes.
-  ///
-  /// 4. If the reference counted value is passed only to some of the block
-  ///    successors, insert `drop_ref` operations in the beginning of the blocks
-  ///    that do not have reference counted value uses.
-  ///
-  ///
-  /// Example:
-  ///
-  ///   %token = ...
-  ///   async.execute {
-  ///     async.await %token : !async.token   // await #1
-  ///     async.yield
-  ///   }
-  ///   async.await %token : !async.token     // await #2
-  ///
-  /// Based on the liveness analysis await #2 is the last use of the %token,
-  /// however the execution of the async region can be delayed, and to guarantee
-  /// that the %token is still alive when await #1 executes we need to
-  /// explicitly extend its lifetime using `add_ref` operation.
-  ///
-  /// After automatic reference counting:
-  ///
-  ///   %token = ...
-  ///
-  ///   // Make sure that %token is alive inside async.execute.
-  ///   async.add_ref %token {count = 1 : i32} : !async.token
-  ///
-  ///   async.execute {
-  ///     async.await %token : !async.token   // await #1
-  ///
-  ///     // Drop the extra reference added to keep %token alive.
-  ///     async.drop_ref %token {count = 1 : i32} : !async.token
-  ///
-  ///     async.yied
-  ///   }
-  ///   async.await %token : !async.token     // await #2
-  ///
-  ///   // Drop the reference after the last use of %token.
-  ///   async.drop_ref %token {count = 1 : i32} : !async.token
-  ///
-  LogicalResult addAutomaticRefCounting(Value value);
-};
-
-} // namespace
-
-LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
-  MLIRContext *ctx = value.getContext();
-  OpBuilder builder(ctx);
-
-  // Set inserton point after the operation producing a value, or at the
-  // beginning of the block if the value defined by the block argument.
-  if (Operation *op = value.getDefiningOp())
-    builder.setInsertionPointAfter(op);
-  else
-    builder.setInsertionPointToStart(value.getParentBlock());
-
-  Location loc = value.getLoc();
-  auto i32 = IntegerType::get(ctx, 32);
-
-  // Drop the reference count immediately if the value has no uses.
-  if (value.getUses().empty()) {
-    builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
-    return success();
-  }
-
-  // Use liveness analysis to find the placement of `drop_ref`operation.
-  auto liveness = getAnalysis<Liveness>();
-
-  // We analyse only the blocks of the region that defines the `value`, and do
-  // not check nested blocks attached to operations.
-  //
-  // By analyzing only the `definingRegion` CFG we potentially loose an
-  // opportunity to drop the reference count earlier and can extend the lifetime
-  // of reference counted value longer then it is really required.
-  //
-  // We also assume that all nested regions finish their execution before the
-  // completion of the owner operation. The only exception to this rule is
-  // `async.execute` operation, which is handled explicitly below.
-  Region *definingRegion = value.getParentRegion();
-
-  // ------------------------------------------------------------------------ //
-  // Find blocks where the `value` dies: the value is in `liveIn` set and not
-  // in the `liveOut` set. We place `drop_ref` immediately after the last use
-  // of the `value` in such regions.
-  // ------------------------------------------------------------------------ //
-
-  // Last users of the `value` inside all blocks where the value dies.
-  llvm::SmallSet<Operation *, 4> lastUsers;
-
-  for (Block &block : definingRegion->getBlocks()) {
-    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
-
-    // Value in live input set or was defined in the block.
-    bool liveIn = blockLiveness->isLiveIn(value) ||
-                  blockLiveness->getBlock() == value.getParentBlock();
-    if (!liveIn)
-      continue;
-
-    // Value is in the live out set.
-    bool liveOut = blockLiveness->isLiveOut(value);
-    if (liveOut)
-      continue;
-
-    // We proved that `value` dies in the `block`. Now find the last use of the
-    // `value` inside the `block`.
-
-    // Find any user of the `value` inside the block (including uses in nested
-    // regions attached to the operations in the block).
-    Operation *userInTheBlock = nullptr;
-    for (Operation *user : value.getUsers()) {
-      userInTheBlock = block.findAncestorOpInBlock(*user);
-      if (userInTheBlock)
-        break;
-    }
-
-    // Values with zero users handled explicitly in the beginning, if the value
-    // is in live out set it must have at least one use in the block.
-    assert(userInTheBlock && "value must have a user in the block");
-
-    // Find the last user of the `value` in the block;
-    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
-    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
-    lastUsers.insert(lastUser);
-  }
-
-  // Process all the last users of the `value` inside each block where the value
-  // dies.
-  for (Operation *lastUser : lastUsers) {
-    // Return like operations forward reference count.
-    if (lastUser->hasTrait<OpTrait::ReturnLike>())
-      continue;
-
-    // We can't currently handle other types of terminators.
-    if (lastUser->hasTrait<OpTrait::IsTerminator>())
-      return lastUser->emitError() << "async reference counting can't handle "
-                                      "terminators that are not ReturnLike";
-
-    // Add a drop_ref immediately after the last user.
-    builder.setInsertionPointAfter(lastUser);
-    builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
-  }
-
-  // ------------------------------------------------------------------------ //
-  // Find blocks where the `value` is in `liveOut` set, however it is not in
-  // the `liveIn` set of all successors. If the `value` is not in the successor
-  // `liveIn` set, we add a `drop_ref` to the beginning of it.
-  // ------------------------------------------------------------------------ //
-
-  // Successors that we'll need a `drop_ref` for the `value`.
-  llvm::SmallSet<Block *, 4> dropRefSuccessors;
-
-  for (Block &block : definingRegion->getBlocks()) {
-    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
-
-    // Skip the block if value is not in the `liveOut` set.
-    if (!blockLiveness->isLiveOut(value))
-      continue;
-
-    // Find successors that do not have `value` in the `liveIn` set.
-    for (Block *successor : block.getSuccessors()) {
-      const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
-
-      if (!succLiveness->isLiveIn(value))
-        dropRefSuccessors.insert(successor);
-    }
-  }
-
-  // Drop reference in all successor blocks that do not have the `value` in
-  // their `liveIn` set.
-  for (Block *dropRefSuccessor : dropRefSuccessors) {
-    builder.setInsertionPointToStart(dropRefSuccessor);
-    builder.create<RuntimeDropRefOp>(loc, value, IntegerAttr::get(i32, 1));
-  }
-
-  // ------------------------------------------------------------------------ //
-  // Find all `async.execute` operation that take `value` as an operand
-  // (dependency token or async value), or capture implicitly by the nested
-  // region. Each `async.execute` operation will require `add_ref` operation
-  // to keep all captured values alive until it will finish its execution.
-  // ------------------------------------------------------------------------ //
-
-  llvm::SmallSet<ExecuteOp, 4> executeOperations;
-
-  auto trackAsyncExecute = [&](Operation *op) {
-    if (auto execute = dyn_cast<ExecuteOp>(op))
-      executeOperations.insert(execute);
-  };
-
-  for (Operation *user : value.getUsers()) {
-    // Follow parent operations up until the operation in the `definingRegion`.
-    while (user->getParentRegion() != definingRegion) {
-      trackAsyncExecute(user);
-      user = user->getParentOp();
-      assert(user != nullptr && "value user lies outside of the value region");
-    }
-
-    // Don't forget to process the parent in the `definingRegion` (can be the
-    // original user operation itself).
-    trackAsyncExecute(user);
-  }
-
-  // Process all `async.execute` operations capturing `value`.
-  for (ExecuteOp execute : executeOperations) {
-    // Add a reference before the execute operation to keep the reference
-    // counted alive before the async region completes execution.
-    builder.setInsertionPoint(execute.getOperation());
-    builder.create<RuntimeAddRefOp>(loc, value, IntegerAttr::get(i32, 1));
-
-    // Drop the reference inside the async region before completion.
-    OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
-    executeBuilder.create<RuntimeDropRefOp>(loc, value,
-                                            IntegerAttr::get(i32, 1));
-  }
-
-  return success();
-}
-
-void AsyncRefCountingPass::runOnFunction() {
-  FuncOp func = getFunction();
-
-  // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
-  // because otherwise automatic reference counting will produce incorrect
-  // results.
-  WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
-    if (isa<RuntimeAddRefOp, RuntimeDropRefOp>(op))
-      return op->emitError() << "explicit reference counting is not supported";
-    return WalkResult::advance();
-  });
-
-  if (refCountingWalk.wasInterrupted())
-    signalPassFailure();
-
-  // Add reference counting to block arguments.
-  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
-    for (BlockArgument arg : block->getArguments())
-      if (isRefCounted(arg.getType()))
-        if (failed(addAutomaticRefCounting(arg)))
-          return WalkResult::interrupt();
-
-    return WalkResult::advance();
-  });
-
-  if (blockWalk.wasInterrupted())
-    signalPassFailure();
-
-  // Add reference counting to operation results.
-  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
-    for (unsigned i = 0; i < op->getNumResults(); ++i)
-      if (isRefCounted(op->getResultTypes()[i]))
-        if (failed(addAutomaticRefCounting(op->getResult(i))))
-          return WalkResult::interrupt();
-
-    return WalkResult::advance();
-  });
-
-  if (opWalk.wasInterrupted())
-    signalPassFailure();
-}
-
-std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
-  return std::make_unique<AsyncRefCountingPass>();
-}

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
deleted file mode 100644
index 6ac2fd12fa113..0000000000000
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
+++ /dev/null
@@ -1,218 +0,0 @@
-//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Optimize Async dialect reference counting operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Async/IR/Async.h"
-#include "mlir/Dialect/Async/Passes.h"
-#include "llvm/ADT/SmallSet.h"
-
-using namespace mlir;
-using namespace mlir::async;
-
-#define DEBUG_TYPE "async-ref-counting"
-
-namespace {
-
-class AsyncRefCountingOptimizationPass
-    : public AsyncRefCountingOptimizationBase<
-          AsyncRefCountingOptimizationPass> {
-public:
-  AsyncRefCountingOptimizationPass() = default;
-  void runOnFunction() override;
-
-private:
-  LogicalResult optimizeReferenceCounting(Value value);
-};
-
-} // namespace
-
-LogicalResult
-AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
-  Region *definingRegion = value.getParentRegion();
-
-  // Find all users of the `value` inside each block, including operations that
-  // do not use `value` directly, but have a direct use inside nested region(s).
-  //
-  // Example:
-  //
-  //  ^bb1:
-  //    %token = ...
-  //    scf.if %cond {
-  //      ^bb2:
-  //      async.await %token : !async.token
-  //    }
-  //
-  // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
-  //
-  // In addition to the operation that uses the `value` we also keep track if
-  // this user is an `async.execute` operation itself, or has `async.execute`
-  // operations in the nested regions that do use the `value`.
-
-  struct UserInfo {
-    Operation *operation;
-    bool hasExecuteUser;
-  };
-
-  struct BlockUsersInfo {
-    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
-    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
-    llvm::SmallVector<UserInfo, 4> users;
-  };
-
-  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
-
-  auto updateBlockUsersInfo = [&](UserInfo user) {
-    BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
-    info.users.push_back(user);
-
-    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user.operation))
-      info.addRefs.push_back(addRef);
-    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user.operation))
-      info.dropRefs.push_back(dropRef);
-  };
-
-  for (Operation *user : value.getUsers()) {
-    bool isAsyncUser = isa<ExecuteOp>(user);
-
-    while (user->getParentRegion() != definingRegion) {
-      updateBlockUsersInfo({user, isAsyncUser});
-      user = user->getParentOp();
-      isAsyncUser |= isa<ExecuteOp>(user);
-      assert(user != nullptr && "value user lies outside of the value region");
-    }
-
-    updateBlockUsersInfo({user, isAsyncUser});
-  }
-
-  // Sort all operations found in the block.
-  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
-    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
-      return a->isBeforeInBlock(b);
-    };
-    llvm::sort(info.addRefs, isBeforeInBlock);
-    llvm::sort(info.dropRefs, isBeforeInBlock);
-    llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
-      return isBeforeInBlock(a.operation, b.operation);
-    });
-
-    return info;
-  };
-
-  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
-  // blocks that modify the reference count of the `value`.
-  for (auto &kv : blockUsers) {
-    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
-
-    // Find all cancellable pairs first and erase them later to keep all
-    // pointers in the `info` valid until the end.
-    //
-    // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
-    llvm::SmallDenseMap<Operation *, Operation *> cancellable;
-
-    for (RuntimeAddRefOp addRef : info.addRefs) {
-      for (RuntimeDropRefOp dropRef : info.dropRefs) {
-        // `drop_ref` operation after the `add_ref` with matching count.
-        if (dropRef.count() != addRef.count() ||
-            dropRef->isBeforeInBlock(addRef.getOperation()))
-          continue;
-
-        // `drop_ref` was already marked for removal.
-        if (cancellable.find(dropRef.getOperation()) != cancellable.end())
-          continue;
-
-        // Check `value` users between `addRef` and `dropRef` in the `block`.
-        Operation *addRefOp = addRef.getOperation();
-        Operation *dropRefOp = dropRef.getOperation();
-
-        // If there is a "regular" user after the `async.execute` user it is
-        // unsafe to erase cancellable reference counting operations pair,
-        // because async region can complete before the "regular" user and
-        // destroy the reference counted value.
-        bool hasExecuteUser = false;
-        bool unsafeToCancel = false;
-
-        for (UserInfo &user : info.users) {
-          Operation *op = user.operation;
-
-          // `user` operation lies after `addRef` ...
-          if (op == addRefOp || op->isBeforeInBlock(addRefOp))
-            continue;
-          // ... and before `dropRef`.
-          if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
-            break;
-
-          bool isRegularUser = !user.hasExecuteUser;
-          bool isExecuteUser = user.hasExecuteUser;
-
-          // It is unsafe to cancel `addRef` / `dropRef` pair.
-          if (isRegularUser && hasExecuteUser) {
-            unsafeToCancel = true;
-            break;
-          }
-
-          hasExecuteUser |= isExecuteUser;
-        }
-
-        // Mark the pair of reference counting operations for removal.
-        if (!unsafeToCancel)
-          cancellable[dropRef.getOperation()] = addRef.getOperation();
-
-        // If it us unsafe to cancel `addRef <-> dropRef` pair at this point,
-        // all the following pairs will be also unsafe.
-        break;
-      }
-    }
-
-    // Erase all cancellable `addRef <-> dropRef` operation pairs.
-    for (auto &kv : cancellable) {
-      kv.first->erase();
-      kv.second->erase();
-    }
-  }
-
-  return success();
-}
-
-void AsyncRefCountingOptimizationPass::runOnFunction() {
-  FuncOp func = getFunction();
-
-  // Optimize reference counting for values defined by block arguments.
-  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
-    for (BlockArgument arg : block->getArguments())
-      if (isRefCounted(arg.getType()))
-        if (failed(optimizeReferenceCounting(arg)))
-          return WalkResult::interrupt();
-
-    return WalkResult::advance();
-  });
-
-  if (blockWalk.wasInterrupted())
-    signalPassFailure();
-
-  // Optimize reference counting for values defined by operation results.
-  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
-    for (unsigned i = 0; i < op->getNumResults(); ++i)
-      if (isRefCounted(op->getResultTypes()[i]))
-        if (failed(optimizeReferenceCounting(op->getResult(i))))
-          return WalkResult::interrupt();
-
-    return WalkResult::advance();
-  });
-
-  if (opWalk.wasInterrupted())
-    signalPassFailure();
-}
-
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRefCountingOptimizationPass() {
-  return std::make_unique<AsyncRefCountingOptimizationPass>();
-}

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
new file mode 100644
index 0000000000000..af443918df970
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -0,0 +1,377 @@
+//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements automatic reference counting for Async runtime
+// operations and types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-runtime-ref-counting"
+
+namespace {
+
+/// Function pass that inserts `async.runtime.add_ref` and
+/// `async.runtime.drop_ref` operations for every reference counted value
+/// (block arguments and operation results) of the function.
+class AsyncRuntimeRefCountingPass
+    : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
+public:
+  AsyncRuntimeRefCountingPass() = default;
+  void runOnFunction() override;
+
+private:
+  /// Adds automatic reference counting to the `value`.
+  ///
+  /// All values (token, group or value) are semantically created with a
+  /// reference count of +1 and it is the responsibility of the async value user
+  /// to place the `add_ref` and `drop_ref` operations to ensure that the value
+  /// is destroyed after the last use.
+  ///
+  /// The function returns failure if it can't deduce the locations where
+  /// to place the reference counting operations.
+  ///
+  /// Async values are "semantically created" when:
+  ///   1. Operation returns async result (e.g. `async.runtime.create`)
+  ///   2. Async value passed in as a block argument (or function argument,
+  ///      because function arguments are just entry block arguments)
+  ///
+  /// Passing async value as a function argument (or block argument) does not
+  /// really mean that a new async value is created, it only means that the
+  /// caller of a function transferred ownership of a `+1` reference to the
+  /// callee. It is convenient to think that from the callee perspective async
+  /// value was "created" with `+1` reference by the block argument.
+  ///
+  /// Automatic reference counting algorithm outline:
+  ///
+  /// #1 Insert `drop_ref` operations after last use of the `value`.
+  /// #2 Insert `add_ref` operations before functions calls with reference
+  ///    counted `value` operand (newly created `+1` reference will be
+  ///    transferred to the callee).
+  /// #3 Verify that divergent control flow does not lead to leaked reference
+  ///    counted objects.
+  ///
+  /// Async runtime reference counting optimization pass will optimize away
+  /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
+  /// strategy (see `async-runtime-ref-counting-opt`).
+  LogicalResult addAutomaticRefCounting(Value value);
+
+  /// (#1) Adds the `drop_ref` operation after the last use of the `value`
+  /// relying on the liveness analysis.
+  ///
+  /// If the `value` is in the block `liveIn` set and it is not in the block
+  /// `liveOut` set, it means that it "dies" in the block. We find the last
+  /// use of the value in such block and:
+  ///
+  ///   1. If the last user is a `ReturnLike` operation we do nothing, because
+  ///      it forwards the ownership to the caller.
+  ///   2. Otherwise we add a `drop_ref` operation immediately after the last
+  ///      use.
+  ///
+  /// Returns failure if the last user is a terminator that is not
+  /// `ReturnLike`.
+  LogicalResult addDropRefAfterLastUse(Value value);
+
+  /// (#2) Adds the `add_ref` operation before the function call taking `value`
+  /// operand to ensure that the value passed to the function entry block
+  /// has a `+1` reference count.
+  LogicalResult addAddRefBeforeFunctionCall(Value value);
+
+  /// (#3) Verifies that if a block has a value in the `liveOut` set, then the
+  /// value is in `liveIn` set in all successors.
+  ///
+  /// Example:
+  ///
+  ///   ^entry:
+  ///     %token = async.runtime.create : !async.token
+  ///     cond_br %cond, ^bb1, ^bb2
+  ///   ^bb1:
+  ///     async.runtime.await %token
+  ///     return
+  ///   ^bb2:
+  ///     return
+  ///
+  /// This CFG will be rejected because ^bb2 does not have `%token` in the
+  /// `liveIn` set, and it will leak a reference counted object.
+  ///
+  /// Exceptions to this rule are blocks with an `async.coro.suspend`
+  /// terminator, because in Async to LLVM lowering it is guaranteed that the
+  /// control flow will jump into the resume block, and then follow into the
+  /// cleanup and suspend blocks.
+  ///
+  /// Example:
+  ///
+  ///   ^entry(%value: !async.value<f32>):
+  ///     async.runtime.await_and_resume %value, %hdl : !async.value<f32>
+  ///     async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
+  ///   ^resume:
+  ///     %0 = async.runtime.load %value
+  ///     br ^cleanup
+  ///   ^cleanup:
+  ///     ...
+  ///   ^suspend:
+  ///     ...
+  ///
+  /// Although cleanup and suspend blocks do not have the `value` in the
+  /// `liveIn` set, it is guaranteed that execution will eventually continue in
+  /// the resume block (we never explicitly destroy coroutines).
+  LogicalResult verifySuccessors(Value value);
+};
+
+} // namespace
+
+LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
+  OpBuilder builder(value.getContext());
+  Location loc = value.getLoc();
+
+  // Use liveness analysis to find the placement of `drop_ref` operation.
+  auto &liveness = getAnalysis<Liveness>();
+
+  // We analyse only the blocks of the region that defines the `value`, and do
+  // not check nested blocks attached to operations.
+  //
+  // By analyzing only the `definingRegion` CFG we potentially lose an
+  // opportunity to drop the reference count earlier and can extend the lifetime
+  // of reference counted value longer than it is really required.
+  //
+  // We also assume that all nested regions finish their execution before the
+  // completion of the owner operation. The only exception to this rule is
+  // `async.execute` operation, and we verify that they are lowered to the
+  // `async.runtime` operations before adding automatic reference counting.
+  Region *definingRegion = value.getParentRegion();
+
+  // Last users of the `value` inside all blocks where the value dies.
+  llvm::SmallSet<Operation *, 4> lastUsers;
+
+  // Find blocks in the `definingRegion` that have users of the `value` (if
+  // there are multiple users in the block, which one will be selected is
+  // undefined). User operation might be not the actual user of the value, but
+  // the operation in the block that has a "real user" in one of the attached
+  // regions.
+  llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
+
+  for (Operation *user : value.getUsers()) {
+    Block *userBlock = user->getBlock();
+    Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
+    // Check the ancestor block before it is dereferenced below.
+    assert(ancestor && "ancestor block must be not null");
+    Operation *ancestorOp = ancestor->findAncestorOpInBlock(*user);
+    assert(ancestorOp && "ancestor op must be not null");
+    usersInTheBlocks[ancestor] = ancestorOp;
+  }
+
+  // Find blocks where the `value` dies: the value is in `liveIn` set and not
+  // in the `liveOut` set. We place `drop_ref` immediately after the last use
+  // of the `value` in such blocks (after handling few special cases).
+  //
+  // We do not traverse all the blocks in the `definingRegion`, because the
+  // `value` can be in the live in set only if it has users in the block, or it
+  // is defined in the block.
+  //
+  // Values with zero users (only definition) are handled in
+  // `addAutomaticRefCounting` before this function is called.
+  for (auto &blockAndUser : usersInTheBlocks) {
+    Block *block = blockAndUser.getFirst();
+    Operation *userInTheBlock = blockAndUser.getSecond();
+
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
+
+    // Value must be in the live input set or defined in the block.
+    assert((blockLiveness->isLiveIn(value) ||
+            blockLiveness->getBlock() == value.getParentBlock()) &&
+           "value must be live-in or defined in the block");
+
+    // If value is in the live out set, it means it doesn't "die" in the block.
+    if (blockLiveness->isLiveOut(value))
+      continue;
+
+    // At this point we proved that `value` dies in the `block`. Find the last
+    // use of the `value` inside the `block`, this is where it "dies".
+    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
+    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
+    lastUsers.insert(lastUser);
+  }
+
+  // Process all the last users of the `value` inside each block where the value
+  // dies.
+  for (Operation *lastUser : lastUsers) {
+    // Return like operations forward reference count.
+    if (lastUser->hasTrait<OpTrait::ReturnLike>())
+      continue;
+
+    // We can't currently handle other types of terminators.
+    if (lastUser->hasTrait<OpTrait::IsTerminator>())
+      return lastUser->emitError() << "async reference counting can't handle "
+                                      "terminators that are not ReturnLike";
+
+    // Add a drop_ref immediately after the last user.
+    builder.setInsertionPointAfter(lastUser);
+    builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
+  }
+
+  return success();
+}
+
+LogicalResult
+AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
+  // A call takes ownership of a `+1` reference for every operand, so each use
+  // of the `value` by a call gets its own `add_ref` placed right before it.
+  OpBuilder builder(value.getContext());
+  for (OpOperand &use : value.getUses()) {
+    Operation *owner = use.getOwner();
+    if (isa<CallOp>(owner)) {
+      builder.setInsertionPoint(owner);
+      builder.create<RuntimeAddRefOp>(value.getLoc(), value,
+                                      builder.getI32IntegerAttr(1));
+    }
+  }
+  return success();
+}
+
+LogicalResult AsyncRuntimeRefCountingPass::verifySuccessors(Value value) {
+  // Liveness analysis tells whether the `value` is live-in / live-out of each
+  // block of its defining region.
+  auto &liveness = getAnalysis<Liveness>();
+
+  // Because we only add `drop_ref` operations to the region that defines the
+  // `value` we can only process CFG for the same region.
+  Region *definingRegion = value.getParentRegion();
+
+  // Check blocks in region order so that the reported error is deterministic.
+  for (Block &block : definingRegion->getBlocks()) {
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+    // Skip the block if value is not in the `liveOut` set.
+    if (!blockLiveness->isLiveOut(value))
+      continue;
+
+    // Classify successors by whether the `value` is in their `liveIn` set.
+    bool hasLiveInSuccessor = false;
+    bool hasNotLiveInSuccessor = false;
+    for (Block *successor : block.getSuccessors()) {
+      if (liveness.getLiveness(successor)->isLiveIn(value))
+        hasLiveInSuccessor = true;
+      else
+        hasNotLiveInSuccessor = true;
+    }
+
+    // All successors agree on the `liveIn` property of the `value`.
+    if (!hasLiveInSuccessor || !hasNotLiveInSuccessor)
+      continue;
+
+    // Divergent `liveIn` property is only allowed for blocks with the
+    // `async.coro.suspend` terminator.
+    Operation *terminator = block.getTerminator();
+    if (isa<CoroSuspendOp>(terminator))
+      continue;
+
+    return terminator->emitOpError("successor have different `liveIn` property "
+                                   "of the reference counted value: ");
+  }
+
+  return success();
+}
+
+LogicalResult
+AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
+  // A value without uses dies right away: drop its `+1` reference immediately
+  // after the defining operation, or at the beginning of the block if the
+  // value is a block argument.
+  if (value.use_empty()) {
+    OpBuilder builder(value.getContext());
+    if (Operation *op = value.getDefiningOp())
+      builder.setInsertionPointAfter(op);
+    else
+      builder.setInsertionPointToStart(value.getParentBlock());
+    builder.create<RuntimeDropRefOp>(value.getLoc(), value,
+                                     builder.getI32IntegerAttr(1));
+    return success();
+  }
+
+  // Verify that the `value` is in `liveIn` set of all successors. This check
+  // only reads the liveness analysis, so run it first: a rejected CFG then
+  // leaves the IR without partially inserted reference counting operations.
+  if (failed(verifySuccessors(value)))
+    return failure();
+
+  // Add `drop_ref` operations based on the liveness analysis.
+  if (failed(addDropRefAfterLastUse(value)))
+    return failure();
+
+  // Add `add_ref` operations before function calls.
+  return addAddRefBeforeFunctionCall(value);
+}
+
+void AsyncRuntimeRefCountingPass::runOnFunction() {
+  FuncOp func = getFunction();
+
+  // Reference counting has to run after `async.execute` and friends are
+  // lowered to `async.runtime` operations: counting references on the high
+  // level operations would give wrong results once they are lowered later.
+  WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
+    bool isHighLevelAsyncOp = isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op);
+    if (isHighLevelAsyncOp)
+      return op->emitError()
+             << "async operations must be lowered to async runtime operations";
+    return WalkResult::advance();
+  });
+  if (executeOpWalk.wasInterrupted())
+    return signalPassFailure();
+
+  // Adds reference counting to every reference counted value in `values` and
+  // interrupts the walk on the first failure.
+  auto countReferences = [&](auto values) -> WalkResult {
+    for (Value value : values) {
+      if (!isRefCounted(value.getType()))
+        continue;
+      if (failed(addAutomaticRefCounting(value)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  };
+
+  // Block arguments first (function arguments are entry block arguments).
+  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+    return countReferences(block->getArguments());
+  });
+  if (blockWalk.wasInterrupted())
+    return signalPassFailure();
+
+  // Then the results of every operation.
+  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+    return countReferences(op->getResults());
+  });
+  if (opWalk.wasInterrupted())
+    signalPassFailure();
+}
+
+/// Creates the pass that adds automatic reference counting for async runtime
+/// values (tokens, groups and values) in a function.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRuntimeRefCountingPass() {
+  return std::make_unique<AsyncRuntimeRefCountingPass>();
+}

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
new file mode 100644
index 0000000000000..cb00d706ce0c8
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
@@ -0,0 +1,177 @@
+//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimize Async dialect reference counting operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+/// Function pass that removes redundant pairs of `async.runtime.add_ref` and
+/// `async.runtime.drop_ref` operations of reference counted values.
+class AsyncRuntimeRefCountingOptPass
+    : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
+public:
+  AsyncRuntimeRefCountingOptPass() = default;
+  void runOnFunction() override;
+
+private:
+  /// Finds `add_ref` / `drop_ref` pairs of the `value` that can be removed and
+  /// records them in `cancellable` as a `drop_ref` -> `add_ref` mapping. The
+  /// operations are not erased here; `runOnFunction` erases them at the end.
+  LogicalResult optimizeReferenceCounting(
+      Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
+};
+
+} // namespace
+
+LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
+    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
+  Region *definingRegion = value.getParentRegion();
+
+  // Find all users of the `value` inside each block, including operations that
+  // do not use `value` directly, but have a direct use inside nested region(s).
+  //
+  // Example:
+  //
+  //  ^bb1:
+  //    %token = ...
+  //    scf.if %cond {
+  //      ^bb2:
+  //      async.runtime.await %token : !async.token
+  //    }
+  //
+  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
+  // (`scf.if`).
+
+  struct BlockUsersInfo {
+    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
+    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
+    llvm::SmallVector<Operation *, 4> users;
+  };
+
+  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
+
+  auto updateBlockUsersInfo = [&](Operation *user) {
+    BlockUsersInfo &info = blockUsers[user->getBlock()];
+    info.users.push_back(user);
+
+    if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
+      info.addRefs.push_back(addRef);
+    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
+      info.dropRefs.push_back(dropRef);
+  };
+
+  for (Operation *user : value.getUsers()) {
+    while (user->getParentRegion() != definingRegion) {
+      updateBlockUsersInfo(user);
+      user = user->getParentOp();
+      assert(user != nullptr && "value user lies outside of the value region");
+    }
+
+    updateBlockUsersInfo(user);
+  }
+
+  // Sort all operations found in the block.
+  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
+    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
+      return a->isBeforeInBlock(b);
+    };
+    llvm::sort(info.addRefs, isBeforeInBlock);
+    llvm::sort(info.dropRefs, isBeforeInBlock);
+    llvm::sort(info.users, isBeforeInBlock);
+
+    return info;
+  };
+
+  // Returns true if erasing the `addRef` / `dropRef` pair can't change what
+  // the operations between them observe.
+  //
+  // Removing the pair lowers the reference count by one for every operation
+  // in between. That is harmless only while nothing in between releases a
+  // reference: a function call takes ownership of a `+1` reference (the
+  // `async-runtime-ref-counting` pass inserts an `add_ref` before each call),
+  // and a `drop_ref` releases one. After such an operation the value might
+  // already be destroyed once the pair is gone, so no other user may follow.
+  //
+  // Example:
+  //
+  //   async.runtime.add_ref %token {count = 1 : i32} : !async.token
+  //   call @consume_token(%token) : (!async.token) -> ()
+  //   async.runtime.await %token : !async.token
+  //   async.runtime.drop_ref %token {count = 1 : i32} : !async.token
+  //
+  // Cancelling this pair could destroy `%token` inside the call, and the
+  // following `async.runtime.await` would read a deallocated token.
+  //
+  // `info.users` must be sorted (see `preprocessBlockUsersInfo`).
+  auto isSafeToCancel = [](const BlockUsersInfo &info, Operation *addRef,
+                           Operation *dropRef) -> bool {
+    Operation *firstRelease = nullptr;
+    for (Operation *user : info.users) {
+      // Only users strictly between `addRef` and `dropRef` are relevant.
+      if (user == addRef || user->isBeforeInBlock(addRef))
+        continue;
+      if (user == dropRef || dropRef->isBeforeInBlock(user))
+        break;
+
+      // The same operation can appear more than once (one entry per use).
+      if (firstRelease && user != firstRelease)
+        return false;
+      if (!firstRelease && isa<CallOp, RuntimeDropRefOp>(user))
+        firstRelease = user;
+    }
+    return true;
+  };
+
+  // Find matching pairs of `add_ref` / `drop_ref` operations in the blocks
+  // that modify the reference count of the `value`, and record them in
+  // `cancellable`. The caller erases them.
+  for (auto &kv : blockUsers) {
+    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
+
+    for (RuntimeAddRefOp addRef : info.addRefs) {
+      for (RuntimeDropRefOp dropRef : info.dropRefs) {
+        // `drop_ref` operation after the `add_ref` with matching count.
+        if (dropRef.count() != addRef.count() ||
+            dropRef->isBeforeInBlock(addRef.getOperation()))
+          continue;
+
+        // Keep the pair if it protects users that follow a released reference.
+        if (!isSafeToCancel(info, addRef.getOperation(),
+                            dropRef.getOperation()))
+          continue;
+
+        // The `drop_ref` might already be paired with another `add_ref`; in
+        // that case try the next `drop_ref`.
+        auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
+                                                addRef.getOperation());
+        if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
+          break;             // go to the next `add_ref`
+      }
+    }
+  }
+
+  return success();
+}
+
+void AsyncRuntimeRefCountingOptPass::runOnFunction() {
+  FuncOp func = getFunction();
+
+  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
+  //
+  // Find all cancellable pairs of operation and erase them in the end to keep
+  // all iterators valid while we are walking the function operations.
+  llvm::SmallDenseMap<Operation *, Operation *> cancellable;
+
+  // Optimize reference counting for values defined by block arguments.
+  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+    for (BlockArgument arg : block->getArguments())
+      if (isRefCounted(arg.getType()))
+        if (failed(optimizeReferenceCounting(arg, cancellable)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  // On failure the collected pairs may be incomplete; do not modify the IR.
+  if (blockWalk.wasInterrupted()) {
+    signalPassFailure();
+    return;
+  }
+
+  // Optimize reference counting for values defined by operation results.
+  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+    for (unsigned i = 0; i < op->getNumResults(); ++i)
+      if (isRefCounted(op->getResultTypes()[i]))
+        if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (opWalk.wasInterrupted()) {
+    signalPassFailure();
+    return;
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Found " << cancellable.size()
+                 << " cancellable reference counting operations\n";
+  });
+
+  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
+  for (auto &kv : cancellable) {
+    kv.first->erase();
+    kv.second->erase();
+  }
+}
+
+/// Creates the pass that removes redundant `add_ref` / `drop_ref` pairs, such
+/// as the ones inserted by the `async-runtime-ref-counting` pass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRuntimeRefCountingOptPass() {
+  return std::make_unique<AsyncRuntimeRefCountingOptPass>();
+}

diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
index 8056758ac308a..45fb77f443a00 100644
--- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRAsyncTransforms
   AsyncParallelFor.cpp
-  AsyncRefCounting.cpp
-  AsyncRefCountingOptimization.cpp
+  AsyncRuntimeRefCounting.cpp
+  AsyncRuntimeRefCountingOpt.cpp
   AsyncToAsyncRuntime.cpp
 
   ADDITIONAL_HEADER_DIRS

diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
deleted file mode 100644
index fdf326ef50340..0000000000000
--- a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
+++ /dev/null
@@ -1,114 +0,0 @@
-// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s
-
-// CHECK-LABEL: @cancellable_operations_0
-func @cancellable_operations_0(%arg0: !async.token) {
-  // CHECK-NOT: async.runtime.add_ref
-  // CHECK-NOT: async.runtime.drop_ref
-  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
-  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @cancellable_operations_1
-func @cancellable_operations_1(%arg0: !async.token) {
-  // CHECK-NOT: async.runtime.add_ref
-  // CHECK: async.execute
-  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
-  async.execute [%arg0] {
-    // CHECK: async.runtime.drop_ref
-    async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-    // CHECK-NEXT: async.yield
-    async.yield
-  }
-  // CHECK-NOT: async.runtime.drop_ref
-  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @cancellable_operations_2
-func @cancellable_operations_2(%arg0: !async.token) {
-  // CHECK: async.await
-  // CHECK-NEXT: async.await
-  // CHECK-NEXT: async.await
-  // CHECK-NEXT: return
-  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
-  async.await %arg0 : !async.token
-  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-  async.await %arg0 : !async.token
-  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
-  async.await %arg0 : !async.token
-  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-  return
-}
-
-// CHECK-LABEL: @cancellable_operations_3
-func @cancellable_operations_3(%arg0: !async.token) {
-  // CHECK-NOT: add_ref
-  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
-  %token = async.execute {
-    async.await %arg0 : !async.token
-    // CHECK: async.runtime.drop_ref
-    async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-    async.yield
-  }
-  // CHECK-NOT: async.runtime.drop_ref
-  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-  // CHECK: async.await
-  async.await %arg0 : !async.token
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @not_cancellable_operations_0
-func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
-  // It is unsafe to cancel `add_ref` / `drop_ref` pair because it is possible
-  // that the body of the `async.execute` operation will run before the await
-  // operation in the function body, and will destroy the `%arg0` token.
-  // CHECK: add_ref
-  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
-  %token = async.execute {
-    // CHECK: async.await
-    async.await %arg0 : !async.token
-    // CHECK: async.runtime.drop_ref
-    async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-    // CHECK: async.yield
-    async.yield
-  }
-  // CHECK: async.await
-  async.await %arg0 : !async.token
-  // CHECK: drop_ref
-  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @not_cancellable_operations_1
-func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
-  // Same reason as above, although `async.execute` is inside the nested
-  // region or "regular" operation.
-  //
-  // NOTE: This test is not correct w.r.t. reference counting, and at runtime
-  // would leak %arg0 value if %arg1 is false. IR like this will not be
-  // constructed by automatic reference counting pass, because it would
-  // place `async.runtime.add_ref` right before the `async.execute`
-  // inside `scf.if`.
-  
-  // CHECK: async.runtime.add_ref
-  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
-  scf.if %arg1 {
-    %token = async.execute {
-      async.await %arg0 : !async.token
-      // CHECK: async.runtime.drop_ref
-      async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-      async.yield
-    }
-  }
-  // CHECK: async.await
-  async.await %arg0 : !async.token
-  // CHECK: async.runtime.drop_ref
-  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
-  // CHECK: return
-  return
-}

diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir
deleted file mode 100644
index 403747c8725b8..0000000000000
--- a/mlir/test/Dialect/Async/async-ref-counting.mlir
+++ /dev/null
@@ -1,253 +0,0 @@
-// RUN: mlir-opt %s -async-ref-counting | FileCheck %s
-
-// CHECK-LABEL: @cond
-func private @cond() -> i1
-
-// CHECK-LABEL: @token_arg_no_uses
-func @token_arg_no_uses(%arg0: !async.token) {
-  // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
-  return
-}
-
-// CHECK-LABEL: @token_arg_conditional_await
-func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) {
-  cond_br %arg1, ^bb1, ^bb2
-^bb1:
-  // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
-  return
-^bb2:
-  // CHECK: async.await %arg0
-  // CHECK: async.runtime.drop_ref %arg0 {count = 1 : i32}
-  async.await %arg0 : !async.token
-  return
-}
-
-// CHECK-LABEL: @token_no_uses
-func @token_no_uses() {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  %token = async.execute {
-    async.yield
-  }
-  return
-}
-
-// CHECK-LABEL: @token_return
-func @token_return() -> !async.token {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-  // CHECK: return %[[TOKEN]]
-  return %token : !async.token
-}
-
-// CHECK-LABEL: @token_await
-func @token_await() {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-  // CHECK: async.await %[[TOKEN]]
-  async.await %token : !async.token
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @token_await_and_return
-func @token_await_and_return() -> !async.token {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-  // CHECK: async.await %[[TOKEN]]
-  // CHECK-NOT: async.runtime.drop_ref
-  async.await %token : !async.token
-  // CHECK: return %[[TOKEN]]
-  return %token : !async.token
-}
-
-// CHECK-LABEL: @token_await_inside_scf_if
-func @token_await_inside_scf_if(%arg0: i1) {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-  // CHECK: scf.if %arg0 {
-  scf.if %arg0 {
-    // CHECK: async.await %[[TOKEN]]
-    async.await %token : !async.token
-  }
-  // CHECK: }
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @token_conditional_await
-func @token_conditional_await(%arg0: i1) {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-  cond_br %arg0, ^bb1, ^bb2
-^bb1:
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  return
-^bb2:
-  // CHECK: async.await %[[TOKEN]]
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  async.await %token : !async.token
-  return
-}
-
-// CHECK-LABEL: @token_await_in_the_loop
-func @token_await_in_the_loop() {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-  br ^bb1
-^bb1:
-  // CHECK: async.await %[[TOKEN]]
-  async.await %token : !async.token
-  %0 = call @cond(): () -> (i1)
-  cond_br %0, ^bb1, ^bb2
-^bb2:
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  return
-}
-
-// CHECK-LABEL: @token_defined_in_the_loop
-func @token_defined_in_the_loop() {
-  br ^bb1
-^bb1:
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-  // CHECK: async.await %[[TOKEN]]
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  async.await %token : !async.token
-  %0 = call @cond(): () -> (i1)
-  cond_br %0, ^bb1, ^bb2
-^bb2:
-  return
-}
-
-// CHECK-LABEL: @token_capture
-func @token_capture() {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-
-  // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: %[[TOKEN_0:.*]] = async.execute
-  %token_0 = async.execute {
-    // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-    // CHECK-NEXT: async.yield
-    async.await %token : !async.token
-    async.yield
-  }
-  // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @token_nested_capture
-func @token_nested_capture() {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-
-  // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: %[[TOKEN_0:.*]] = async.execute
-  %token_0 = async.execute {
-    // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
-    // CHECK: %[[TOKEN_1:.*]] = async.execute
-    %token_1 = async.execute {
-      // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
-      // CHECK: %[[TOKEN_2:.*]] = async.execute
-      %token_2 = async.execute {
-        // CHECK: async.await %[[TOKEN]]
-        // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-        async.await %token : !async.token
-        async.yield
-      }
-      // CHECK: async.runtime.drop_ref %[[TOKEN_2]] {count = 1 : i32}
-      // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-      async.yield
-    }
-    // CHECK: async.runtime.drop_ref %[[TOKEN_1]] {count = 1 : i32}
-    // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-    async.yield
-  }
-  // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @token_dependency
-func @token_dependency() {
-  // CHECK: %[[TOKEN:.*]] = async.execute
-  %token = async.execute {
-    async.yield
-  }
-
-  // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: %[[TOKEN_0:.*]] = async.execute
-  %token_0 = async.execute[%token] {
-    // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-    // CHECK-NEXT: async.yield
-    async.yield
-  }
-
-  // CHECK: async.await %[[TOKEN]]
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  async.await %token : !async.token
-  // CHECK: async.await %[[TOKEN_0]]
-  // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
-  async.await %token_0 : !async.token
-
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: @value_operand
-func @value_operand() -> f32 {
-  // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute
-  %token, %results = async.execute -> !async.value<f32> {
-    %0 = constant 0.0 : f32
-    async.yield %0 : f32
-  }
-
-  // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32}
-  // CHECK: async.runtime.add_ref %[[RESULTS]] {count = 1 : i32}
-  // CHECK: %[[TOKEN_0:.*]] = async.execute
-  %token_0 = async.execute[%token](%results as %arg0 : !async.value<f32>)  {
-    // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-    // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
-    // CHECK: async.yield
-    async.yield
-  }
-
-  // CHECK: async.await %[[TOKEN]]
-  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
-  async.await %token : !async.token
-
-  // CHECK: async.await %[[TOKEN_0]]
-  // CHECK: async.runtime.drop_ref %[[TOKEN_0]] {count = 1 : i32}
-  async.await %token_0 : !async.token
-
-  // CHECK: async.await %[[RESULTS]]
-  // CHECK: async.runtime.drop_ref %[[RESULTS]] {count = 1 : i32}
-  %0 = async.await %results : !async.value<f32>
-
-  // CHECK: return
-  return %0 : f32
-}

diff --git a/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir
new file mode 100644
index 0000000000000..9b6bb1a5e7515
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-runtime-ref-counting-opt.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -async-runtime-ref-counting-opt | FileCheck %s
+
+// Tests for -async-runtime-ref-counting-opt. Each function brackets uses of
+// %arg0 with matching async.runtime.add_ref / async.runtime.drop_ref ops (same
+// count); the checks expect the pair to be erased and all other ops kept.
+
+func private @consume_token(%arg0: !async.token)
+
+// The pair is adjacent, with nothing in between.
+// CHECK-LABEL: @cancellable_operations_0
+func @cancellable_operations_0(%arg0: !async.token) {
+  // CHECK-NOT: async.runtime.add_ref
+  // CHECK-NOT: async.runtime.drop_ref
+  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: return
+  return
+}
+
+// A call taking the token sits between the pair; the call must remain.
+// CHECK-LABEL: @cancellable_operations_1
+func @cancellable_operations_1(%arg0: !async.token) {
+  // CHECK-NOT: async.runtime.add_ref
+  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: call @consume_token
+  call @consume_token(%arg0): (!async.token) -> ()
+  // CHECK-NOT: async.runtime.drop_ref
+  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: return
+  return
+}
+
+// Two pairs, each enclosing an async.runtime.await; only the three awaits are
+// expected to remain, consecutively, before the return.
+// CHECK-LABEL: @cancellable_operations_2
+func @cancellable_operations_2(%arg0: !async.token) {
+  // CHECK: async.runtime.await
+  // CHECK-NEXT: async.runtime.await
+  // CHECK-NEXT: async.runtime.await
+  // CHECK-NEXT: return
+  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.runtime.await %arg0 : !async.token
+  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+  async.runtime.await %arg0 : !async.token
+  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.runtime.await %arg0 : !async.token
+  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+  return
+}
+
+// Like @cancellable_operations_1, but the token is also awaited after the
+// drop_ref; the await must remain.
+// CHECK-LABEL: @cancellable_operations_3
+func @cancellable_operations_3(%arg0: !async.token) {
+  // CHECK-NOT: async.runtime.add_ref
+  async.runtime.add_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: call @consume_token
+  call @consume_token(%arg0): (!async.token) -> ()
+  // CHECK-NOT: async.runtime.drop_ref
+  async.runtime.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: async.runtime.await
+  async.runtime.await %arg0 : !async.token
+  // CHECK: return
+  return
+}

diff  --git a/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
new file mode 100644
index 0000000000000..40ac40a930097
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-runtime-ref-counting.mlir
@@ -0,0 +1,239 @@
+// RUN: mlir-opt %s -async-runtime-ref-counting | FileCheck %s
+
+// Tests for -async-runtime-ref-counting. The test functions below pin where
+// the pass is expected to place async.runtime.add_ref / async.runtime.drop_ref
+// for !async.token and !async.value arguments and SSA values.
+
+// CHECK-LABEL: @token
+func private @token() -> !async.token
+
+// CHECK-LABEL: @cond
+func private @cond() -> i1
+
+// CHECK-LABEL: @take_token
+func private @take_token(%arg0: !async.token)
+
+// An unused token argument gets a single drop_ref.
+// CHECK-LABEL: @token_arg_no_uses
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_arg_no_uses(%arg0: !async.token) {
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  return
+}
+
+// A created token with no uses is dropped after its definition.
+// CHECK-LABEL: @token_value_no_uses
+func @token_value_no_uses() {
+  // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  %0 = async.runtime.create : !async.token
+  return
+}
+
+// A token returned by a call, with no uses, is dropped after the call.
+// CHECK-LABEL: @token_returned_no_uses
+func @token_returned_no_uses() {
+  // CHECK: %[[TOKEN:.*]] = call @token
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  %0 = call @token() : () -> !async.token
+  return
+}
+
+// Passing the token argument to a call is preceded by an add_ref, and a
+// drop_ref follows the call.
+// CHECK-LABEL: @token_arg_to_func
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_arg_to_func(%arg0: !async.token) {
+  // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
+  call @take_token(%arg0): (!async.token) -> ()
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32} : !async.token
+  return
+}
+
+// Same as @token_arg_to_func, for a token created inside the function.
+// CHECK-LABEL: @token_value_to_func
+func @token_value_to_func() {
+  // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+  %0 = async.runtime.create : !async.token
+  // CHECK: async.runtime.add_ref %[[TOKEN]] {count = 1 : i32} : !async.token
+  call @take_token(%0): (!async.token) -> ()
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  return
+}
+
+// The token argument reaches ^bb2 both directly and through ^bb1; its
+// drop_ref is expected in ^bb2, after the await.
+// CHECK-LABEL: @token_arg_cond_br_await_with_fallthrough
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_arg_cond_br_await_with_fallthrough(%arg0: !async.token, %arg1: i1) {
+  // CHECK: cond_br
+  // CHECK-SAME: ^[[BB1:.*]], ^[[BB2:.*]]
+  cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  // CHECK: ^[[BB1]]:
+  // CHECK:   br ^[[BB2]]
+  br ^bb2
+^bb2:
+  // CHECK: ^[[BB2]]:
+  // CHECK:   async.runtime.await %[[TOKEN]]
+  // CHECK:   async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.runtime.await %arg0 : !async.token
+  return
+}
+
+// A created token that escapes through the return.
+// CHECK-LABEL: @token_simple_return
+func @token_simple_return() -> !async.token {
+  // CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+  %token = async.runtime.create : !async.token
+  // CHECK: return %[[TOKEN]]
+  return %token : !async.token
+}
+
+// A coroutine returning its token: no add_ref or drop_ref is expected.
+// CHECK-LABEL: @token_coro_return
+// CHECK-NOT: async.runtime.drop_ref
+// CHECK-NOT: async.runtime.add_ref
+func @token_coro_return() -> !async.token {
+  %token = async.runtime.create : !async.token
+  %id = async.coro.id
+  %hdl = async.coro.begin %id
+  %saved = async.coro.save %hdl
+  async.runtime.resume %hdl
+  async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
+^resume:
+  br ^cleanup
+^cleanup:
+  async.coro.free %id, %hdl
+  br ^suspend
+^suspend:
+  async.coro.end %hdl
+  return %token : !async.token
+}
+
+// The TOKEN capture below is the argument %arg0, not the local %token; its
+// drop_ref is expected right after async.runtime.await_and_resume.
+// CHECK-LABEL: @token_coro_await_and_resume
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_coro_await_and_resume(%arg0: !async.token) -> !async.token {
+  %token = async.runtime.create : !async.token
+  %id = async.coro.id
+  %hdl = async.coro.begin %id
+  %saved = async.coro.save %hdl
+  // CHECK: async.runtime.await_and_resume %[[TOKEN]]
+  async.runtime.await_and_resume %arg0, %hdl : !async.token
+  // CHECK-NEXT: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
+^resume:
+  br ^cleanup
+^cleanup:
+  async.coro.free %id, %hdl
+  br ^suspend
+^suspend:
+  async.coro.end %hdl
+  return %token : !async.token
+}
+
+// For an !async.value argument, the drop_ref is expected in the resume
+// block after async.runtime.load reads the value.
+// CHECK-LABEL: @value_coro_await_and_resume
+// CHECK: %[[VALUE:.*]]: !async.value<f32>
+func @value_coro_await_and_resume(%arg0: !async.value<f32>) -> !async.token {
+  %token = async.runtime.create : !async.token
+  %id = async.coro.id
+  %hdl = async.coro.begin %id
+  %saved = async.coro.save %hdl
+  // CHECK: async.runtime.await_and_resume %[[VALUE]]
+  async.runtime.await_and_resume %arg0, %hdl : !async.value<f32>
+  // CHECK: async.coro.suspend
+  // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+  async.coro.suspend %saved, ^suspend, ^resume, ^cleanup
+^resume:
+  // CHECK: ^[[RESUME]]:
+  // CHECK:   %[[LOADED:.*]] = async.runtime.load %[[VALUE]]
+  // CHECK:   async.runtime.drop_ref %[[VALUE]] {count = 1 : i32}
+  %0 = async.runtime.load %arg0 : !async.value<f32>
+  // CHECK:  addf %[[LOADED]], %[[LOADED]]
+  %1 = addf %0, %0 : f32
+  br ^cleanup
+^cleanup:
+  async.coro.free %id, %hdl
+  br ^suspend
+^suspend:
+  async.coro.end %hdl
+  return %token : !async.token
+}
+
+// The token argument is awaited in ^resume; its drop_ref is expected there,
+// before the next async.coro.suspend.
+// CHECK-LABEL: @outlined_async_execute
+// CHECK: %[[TOKEN:.*]]: !async.token
+func private @outlined_async_execute(%arg0: !async.token) -> !async.token {
+  %0 = async.runtime.create : !async.token
+  %1 = async.coro.id
+  %2 = async.coro.begin %1
+  %3 = async.coro.save %2
+  async.runtime.resume %2
+  // CHECK: async.coro.suspend
+  async.coro.suspend %3, ^suspend, ^resume, ^cleanup
+^resume:
+  // CHECK: ^[[RESUME:.*]]:
+  %4 = async.coro.save %2
+  async.runtime.await_and_resume %arg0, %2 : !async.token
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: async.coro.suspend
+  async.coro.suspend %4, ^suspend, ^resume_1, ^cleanup
+^resume_1:
+  // CHECK: ^[[RESUME_1:.*]]:
+  // CHECK:   async.runtime.set_available
+  async.runtime.set_available %0 : !async.token
+  br ^cleanup
+^cleanup:
+  // CHECK: ^[[CLEANUP:.*]]:
+  // CHECK:   async.coro.free
+  async.coro.free %1, %2
+  br ^suspend
+^suspend:
+  // CHECK: ^[[SUSPEND:.*]]:
+  // CHECK:   async.coro.end
+  async.coro.end %2
+  return %0 : !async.token
+}
+
+// A token awaited only inside an scf.if region is dropped after the scf.if,
+// in the enclosing block.
+// CHECK-LABEL: @token_await_inside_nested_region
+// CHECK: %[[ARG:.*]]: i1
+func @token_await_inside_nested_region(%arg0: i1) {
+  // CHECK: %[[TOKEN:.*]] = call @token()
+  %token = call @token() : () -> !async.token
+  // CHECK: scf.if %[[ARG]] {
+  scf.if %arg0 {
+    // CHECK: async.runtime.await %[[TOKEN]]
+    async.runtime.await %token : !async.token
+  }
+  // CHECK: }
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  // CHECK: return
+  return
+}
+
+// A token defined in a loop block (^bb1 branches back to itself) is dropped
+// in that block after its await.
+// CHECK-LABEL: @token_defined_in_the_loop
+func @token_defined_in_the_loop() {
+  br ^bb1
+^bb1:
+  // CHECK: ^[[BB1:.*]]:
+  // CHECK:   %[[TOKEN:.*]] = call @token()
+  %token = call @token() : () -> !async.token
+  // CHECK:   async.runtime.await %[[TOKEN]]
+  // CHECK:   async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.runtime.await %token : !async.token
+  %0 = call @cond(): () -> (i1)
+  cond_br %0, ^bb1, ^bb2
+^bb2:
+  // CHECK: ^[[BB2:.*]]:
+  // CHECK:   return
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
index 0e29209c4fab7..dee9b1cd62eac 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
@@ -1,8 +1,9 @@
 // RUN:   mlir-opt %s                                                          \
 // RUN:               -linalg-tile-to-parallel-loops="linalg-tile-sizes=256"   \
 // RUN:               -async-parallel-for="num-concurrent-async-execute=4"     \
-// RUN:               -async-ref-counting                                      \
 // RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// RUN:               -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -lower-affine                                            \
 // RUN:               -convert-linalg-to-loops                                 \

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index e4d19bf5c2f9f..9f05ec8065dc5 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -1,6 +1,7 @@
 // RUN:   mlir-opt %s -async-parallel-for                                      \
-// RUN:               -async-ref-counting                                      \
 // RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// RUN:               -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-std-to-llvm                                     \

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index d0f688c61226b..883a0bc4fab7b 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -1,6 +1,7 @@
 // RUN:   mlir-opt %s -async-parallel-for                                      \
 // RUN:               -async-to-async-runtime                                  \
-// RUN:               -async-ref-counting                                      \
+// RUN:               -async-runtime-ref-counting                              \
+// RUN:               -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-std-to-llvm                                     \

diff  --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
index a9149aef67d26..8216d1558c639 100644
--- a/mlir/test/mlir-cpu-runner/async-group.mlir
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -1,5 +1,6 @@
-// RUN:   mlir-opt %s -async-ref-counting                                      \
-// RUN:               -async-to-async-runtime                                  \
+// RUN:   mlir-opt %s -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// RUN:               -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-std-to-llvm                                     \
 // RUN: | mlir-cpu-runner                                                      \

diff  --git a/mlir/test/mlir-cpu-runner/async-value.mlir b/mlir/test/mlir-cpu-runner/async-value.mlir
index e58f9e06c1aee..878256ea3ad47 100644
--- a/mlir/test/mlir-cpu-runner/async-value.mlir
+++ b/mlir/test/mlir-cpu-runner/async-value.mlir
@@ -1,5 +1,6 @@
-// RUN:   mlir-opt %s -async-ref-counting                                      \
-// RUN:               -async-to-async-runtime                                  \
+// RUN:   mlir-opt %s -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// RUN:               -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-vector-to-llvm                                  \
 // RUN:               -convert-std-to-llvm                                     \

diff  --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir
index 30967928fc62b..417713e383577 100644
--- a/mlir/test/mlir-cpu-runner/async.mlir
+++ b/mlir/test/mlir-cpu-runner/async.mlir
@@ -1,5 +1,6 @@
-// RUN:   mlir-opt %s -async-ref-counting                                      \
-// RUN:               -async-to-async-runtime                                  \
+// RUN:   mlir-opt %s -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// RUN:               -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-linalg-to-loops                                 \
 // RUN:               -convert-scf-to-std                                      \


        


More information about the Mlir-commits mailing list