[Mlir-commits] [mlir] a86a9b5 - [mlir] Automatic reference counting for Async values + runtime support for ref counted objects
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Nov 20 03:08:54 PST 2020
Author: Eugene Zhulenev
Date: 2020-11-20T03:08:44-08:00
New Revision: a86a9b5ef777552d1683e2b9031e2045d39de2f0
URL: https://github.com/llvm/llvm-project/commit/a86a9b5ef777552d1683e2b9031e2045d39de2f0
DIFF: https://github.com/llvm/llvm-project/commit/a86a9b5ef777552d1683e2b9031e2045d39de2f0.diff
LOG: [mlir] Automatic reference counting for Async values + runtime support for ref counted objects
Depends On D89963
**Automatic reference counting algorithm outline:**
1. `ReturnLike` operations forward the reference counted values without
modifying the reference count.
2. Use liveness analysis to find blocks in the CFG where the lifetime of
reference counted values ends, and insert `drop_ref` operations after
the last use of the value.
3. Insert `add_ref` before the `async.execute` operation capturing the
value, and pairing `drop_ref` before the async body region terminator,
to release the captured reference counted value when execution
completes.
4. If the reference counted value is passed only to some of the block
successors, insert `drop_ref` operations in the beginning of the blocks
that do not have reference counted value uses.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D90716
Added:
mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
mlir/test/Dialect/Async/async-ref-counting.mlir
Modified:
mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/Async/Passes.h
mlir/include/mlir/Dialect/Async/Passes.td
mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/Async/ops.mlir
mlir/test/mlir-cpu-runner/async-group.mlir
mlir/test/mlir-cpu-runner/async.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index ad5a8aa03098..d0664b08c0fb 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -53,6 +53,16 @@ class GroupType : public Type::TypeBase<GroupType, Type, TypeStorage> {
using Base::Base;
};
+// -------------------------------------------------------------------------- //
+// Helper functions of Async dialect transformations.
+// -------------------------------------------------------------------------- //
+
+/// Returns true if `type` is reference counted at runtime. This holds for all
+/// async dialect types: tokens, values and groups.
+inline bool isRefCounted(Type type) {
+ return type.isa<TokenType, ValueType, GroupType>();
+}
+
} // namespace async
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
index e7a5e90298da..e33a9e286b7f 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
@@ -73,4 +73,8 @@ def Async_AnyValueType : DialectType<AsyncDialect,
def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
Async_TokenType]>;
+def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
+ Async_TokenType,
+ Async_GroupType]>; // Any async dialect type (value, token or group).
+
#endif // ASYNC_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index cc987856a28e..80aeabf5f904 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -227,4 +227,62 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> {
let assemblyFormat = "$operand attr-dict";
}
+//===----------------------------------------------------------------------===//
+// Async Dialect Automatic Reference Counting Operations.
+//===----------------------------------------------------------------------===//
+
+// All async values (values, tokens, groups) are reference counted at runtime
+// and automatically destructed when reference count drops to 0.
+//
+// All values are semantically created with a reference count of +1 and it is
+// the responsibility of the last async value user to drop reference count.
+//
+// Async values are created when:
+// 1. Operation returns async result (e.g. the result of an `async.execute`).
+// 2. Async value is passed in as a block argument.
+//
+// It is the responsibility of the async value user to extend the lifetime by
+// adding a +1 reference, if the reference counted value is captured by the
+// asynchronously executed region (`async.execute` operation), and drop it after
+// the last nested use.
+//
+// Reference counting operations can be added to the IR using the automatic
+// reference counting pass, which relies on liveness analysis to find the last
+// uses of all reference counted values and automatically inserts
+// `drop_ref` operations.
+//
+// See `AsyncRefCountingPass` documentation for the implementation details.
+
+def Async_AddRefOp : Async_Op<"add_ref"> {
+ let summary = "adds a reference to async value";
+ let description = [{
+ The `async.add_ref` operation adds `count` references to an async value
+ (token, value or group).
+ }];
+
+ let arguments = (ins Async_AnyAsyncType:$operand,
+ Confined<I32Attr, [IntPositive]>:$count);
+ let results = (outs );
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand)
+ }];
+}
+
+def Async_DropRefOp : Async_Op<"drop_ref"> {
+ let summary = "drops a reference to async value";
+ let description = [{
+ The `async.drop_ref` operation drops `count` references from an async value
+ (token, value or group).
+ }];
+
+ let arguments = (ins Async_AnyAsyncType:$operand,
+ Confined<I32Attr, [IntPositive]>:$count);
+ let results = (outs );
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand)
+ }];
+}
+
#endif // ASYNC_OPS
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index d5a8a82dab49..9716bde76593 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -19,6 +19,10 @@ namespace mlir {
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
+
+std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index 51fd4e32c78e..140a3b41162a 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -24,4 +24,18 @@ def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
}
+def AsyncRefCounting : FunctionPass<"async-ref-counting"> {
+ let summary = "Automatic reference counting for Async dialect data types";
+ let constructor = "mlir::createAsyncRefCountingPass()";
+ let dependentDialects = ["async::AsyncDialect"];
+}
+
+def AsyncRefCountingOptimization :
+ FunctionPass<"async-ref-counting-optimization"> {
+ let summary = "Optimize automatic reference counting operations for the "
+ "Async dialect by removing redundant operations";
+ let constructor = "mlir::createAsyncRefCountingOptimizationPass()";
+ let dependentDialects = ["async::AsyncDialect"];
+}
+
#endif // MLIR_DIALECT_ASYNC_PASSES
diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index 12beffe9dd1c..26b0a236f0d3 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -48,6 +48,18 @@ typedef struct AsyncGroup MLIR_AsyncGroup;
using CoroHandle = void *; // coroutine handle
using CoroResume = void (*)(void *); // coroutine resume function
+// Async runtime uses reference counting to manage the lifetime of async values
+// (values of async types like tokens, values and groups).
+using RefCountedObjPtr = void *;
+
+// Adds the given number of references (2nd argument) to a ref counted object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+ mlirAsyncRuntimeAddRef(RefCountedObjPtr, int32_t);
+
+// Drops the given number of references (2nd argument) from a ref counted object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+ mlirAsyncRuntimeDropRef(RefCountedObjPtr, int32_t);
+
// Create a new `async.token` in not-ready state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index 061877105283..74c0556c4bd0 100644
--- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -async-ref-counting \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index 79fa4c2e2c3c..196ab89b59e0 100644
--- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -async-ref-counting \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-std-to-llvm \
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 0cbf3debd894..b08f7e4c45b7 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -33,6 +33,8 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===//
+static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
+static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
@@ -49,6 +51,12 @@ static constexpr const char *kAwaitAllAndExecute =
namespace {
// Async Runtime API function types.
struct AsyncAPI {
+ static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
+ Type objPtr = LLVM::LLVMType::getInt8PtrTy(ctx); // opaque ref counted object
+ Type numRefs = IntegerType::get(32, ctx); // number of references
+ return FunctionType::get(/*inputs=*/{objPtr, numRefs}, /*results=*/{}, ctx);
+ }
+
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
}
@@ -113,6 +121,8 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
};
MLIRContext *ctx = module.getContext();
+ addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+ addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
@@ -121,7 +131,8 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
- addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+ addFuncDecl(kAwaitAllAndExecute,
+ AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
}
//===----------------------------------------------------------------------===//
@@ -588,6 +599,55 @@ class CallOpOpConversion : public ConversionPattern {
};
} // namespace
+//===----------------------------------------------------------------------===//
+// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
+// to the corresponding API calls).
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Lowers `RefCountingOp` to a call `apiFunctionName(operand, count)`.
+template <typename RefCountingOp>
+class RefCountingOpLowering : public ConversionPattern {
+public:
+ explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
+ : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
+ apiFunctionName(apiFunctionName) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ RefCountingOp refCountingOp = cast<RefCountingOp>(op);
+ // The runtime API takes the number of references as an i32 operand.
+ auto count = rewriter.create<ConstantOp>(
+ op->getLoc(), rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(refCountingOp.count()));
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
+ ValueRange({operands[0], count}));
+
+ return success();
+ }
+
+private:
+ StringRef apiFunctionName; // Not owned: must outlive the pattern.
+};
+
+// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call.
+class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
+public:
+ explicit AddRefOpLowering(MLIRContext *ctx)
+ : RefCountingOpLowering(ctx, kAddRef) {}
+};
+
+// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
+class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
+public:
+ explicit DropRefOpLowering(MLIRContext *ctx)
+ : RefCountingOpLowering(ctx, kDropRef) {}
+};
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
//===----------------------------------------------------------------------===//
@@ -794,10 +854,12 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
patterns.insert<CallOpOpConversion>(ctx);
+ patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
ConversionTarget target(*ctx);
+ target.addLegalOp<ConstantOp>();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addIllegalDialect<AsyncDialect>();
target.addDynamicallyLegalOp<FuncOp>(
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
new file mode 100644
index 000000000000..ea1da590aeea
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
@@ -0,0 +1,324 @@
+//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements automatic reference counting for Async dialect data
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
+public:
+ AsyncRefCountingPass() = default;
+ void runOnFunction() override;
+
+private:
+ /// Adds automatic reference counting to the `value`.
+ ///
+ /// All values are semantically created with a reference count of +1 and it is
+ /// the responsibility of the last async value user to drop reference count.
+ ///
+ /// Async values are created when:
+ /// 1. Operation returns async result (e.g. the result of an
+ /// `async.execute`).
+ /// 2. Async value is passed in as a block argument.
+ ///
+ /// To implement automatic reference counting, we must insert a +1 reference
+ /// before each `async.execute` operation using the value, and drop it after
+ /// the last use inside the async body region (we currently drop the reference
+ /// before the `async.yield` terminator).
+ ///
+ /// Automatic reference counting algorithm outline:
+ ///
+ /// 1. `ReturnLike` operations forward the reference counted values without
+ /// modifying the reference count.
+ ///
+ /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
+ /// reference counted values ends, and insert `drop_ref` operations after
+ /// the last use of the value.
+ ///
+ /// 3. Insert `add_ref` before the `async.execute` operation capturing the
+ /// value, and pairing `drop_ref` before the async body region terminator,
+ /// to release the captured reference counted value when execution
+ /// completes.
+ ///
+ /// 4. If the reference counted value is passed only to some of the block
+ /// successors, insert `drop_ref` operations in the beginning of the blocks
+ /// that do not have reference counted value uses.
+ ///
+ ///
+ /// Example:
+ ///
+ /// %token = ...
+ /// async.execute {
+ /// async.await %token : !async.token // await #1
+ /// async.yield
+ /// }
+ /// async.await %token : !async.token // await #2
+ ///
+ /// Based on the liveness analysis await #2 is the last use of the %token,
+ /// however the execution of the async region can be delayed, and to guarantee
+ /// that the %token is still alive when await #1 executes we need to
+ /// explicitly extend its lifetime using `add_ref` operation.
+ ///
+ /// After automatic reference counting:
+ ///
+ /// %token = ...
+ ///
+ /// // Make sure that %token is alive inside async.execute.
+ /// async.add_ref %token {count = 1 : i32} : !async.token
+ ///
+ /// async.execute {
+ /// async.await %token : !async.token // await #1
+ ///
+ /// // Drop the extra reference added to keep %token alive.
+ /// async.drop_ref %token {count = 1 : i32} : !async.token
+ ///
+ /// async.yield
+ /// }
+ /// async.await %token : !async.token // await #2
+ ///
+ /// // Drop the reference after the last use of %token.
+ /// async.drop_ref %token {count = 1 : i32} : !async.token
+ ///
+ LogicalResult addAutomaticRefCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
+ MLIRContext *ctx = value.getContext();
+ OpBuilder builder(ctx);
+
+ // Set insertion point after the operation producing a value, or at the
+ // beginning of the block if the value is a block argument.
+ if (Operation *op = value.getDefiningOp())
+ builder.setInsertionPointAfter(op);
+ else
+ builder.setInsertionPointToStart(value.getParentBlock());
+
+ Location loc = value.getLoc();
+ auto i32 = IntegerType::get(32, ctx);
+
+ // Drop the reference count immediately if the value has no uses.
+ if (value.getUses().empty()) {
+ builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ return success();
+ }
+
+ // Use the cached liveness analysis (bound by reference, not copied).
+ Liveness &liveness = getAnalysis<Liveness>();
+
+ // We analyse only the blocks of the region that defines the `value`, and do
+ // not check nested blocks attached to operations.
+ //
+ // By analyzing only the `definingRegion` CFG we potentially lose an
+ // opportunity to drop the reference count earlier and can extend the lifetime
+ // of reference counted value longer than it is really required.
+ //
+ // We also assume that all nested regions finish their execution before the
+ // completion of the owner operation. The only exception to this rule is
+ // `async.execute` operation, which is handled explicitly below.
+ Region *definingRegion = value.getParentRegion();
+
+ // ------------------------------------------------------------------------ //
+ // Find blocks where the `value` dies: the value is in `liveIn` set and not
+ // in the `liveOut` set. We place `drop_ref` immediately after the last use
+ // of the `value` in such blocks.
+ // ------------------------------------------------------------------------ //
+
+ // Last users of the `value` inside all blocks where the value dies.
+ llvm::SmallSet<Operation *, 4> lastUsers;
+
+ for (Block &block : definingRegion->getBlocks()) {
+ const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+ // Value in live input set or was defined in the block.
+ bool liveIn = blockLiveness->isLiveIn(value) ||
+ blockLiveness->getBlock() == value.getParentBlock();
+ if (!liveIn)
+ continue;
+
+ // Value is in the live out set.
+ bool liveOut = blockLiveness->isLiveOut(value);
+ if (liveOut)
+ continue;
+
+ // We proved that `value` dies in the `block`. Now find the last use of the
+ // `value` inside the `block`.
+
+ // Find any user of the `value` inside the block (including uses in nested
+ // regions attached to the operations in the block).
+ Operation *userInTheBlock = nullptr;
+ for (Operation *user : value.getUsers()) {
+ userInTheBlock = block.findAncestorOpInBlock(*user);
+ if (userInTheBlock)
+ break;
+ }
+
+ // Values with zero users are handled explicitly above; a value that dies
+ // in this block must have at least one user in the block.
+ assert(userInTheBlock && "value must have a user in the block");
+
+ // Find the last user of the `value` in the block.
+ Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
+ assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
+ lastUsers.insert(lastUser);
+ }
+
+ // Process all the last users of the `value` inside each block where the value
+ // dies.
+ for (Operation *lastUser : lastUsers) {
+ // Return like operations forward reference count.
+ if (lastUser->hasTrait<OpTrait::ReturnLike>())
+ continue;
+
+ // We can't currently handle other types of terminators.
+ if (lastUser->hasTrait<OpTrait::IsTerminator>())
+ return lastUser->emitError() << "async reference counting can't handle "
+ "terminators that are not ReturnLike";
+
+ // Add a drop_ref immediately after the last user.
+ builder.setInsertionPointAfter(lastUser);
+ builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ }
+
+ // ------------------------------------------------------------------------ //
+ // Find blocks where the `value` is in `liveOut` set, however it is not in
+ // the `liveIn` set of all successors. If the `value` is not in the successor
+ // `liveIn` set, we add a `drop_ref` to the beginning of it.
+ // ------------------------------------------------------------------------ //
+
+ // Successors that we'll need a `drop_ref` for the `value`.
+ llvm::SmallSet<Block *, 4> dropRefSuccessors;
+
+ for (Block &block : definingRegion->getBlocks()) {
+ const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+ // Skip the block if value is not in the `liveOut` set.
+ if (!blockLiveness->isLiveOut(value))
+ continue;
+
+ // Find successors that do not have `value` in the `liveIn` set.
+ for (Block *successor : block.getSuccessors()) {
+ const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
+
+ if (!succLiveness->isLiveIn(value))
+ dropRefSuccessors.insert(successor);
+ }
+ }
+
+ // Drop reference at the start of successors without `value` in `liveIn`.
+ // NOTE(review): assumes the `value` definition dominates those successors.
+ for (Block *dropRefSuccessor : dropRefSuccessors) {
+ builder.setInsertionPointToStart(dropRefSuccessor);
+ builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ }
+
+ // ------------------------------------------------------------------------ //
+ // Find all `async.execute` operations that take `value` as an operand
+ // (dependency token or async value), or capture it implicitly in the nested
+ // region. Each `async.execute` operation will require `add_ref` operation
+ // to keep all captured values alive until it finishes its execution.
+ // ------------------------------------------------------------------------ //
+
+ llvm::SmallSet<ExecuteOp, 4> executeOperations;
+
+ auto trackAsyncExecute = [&](Operation *op) {
+ if (auto execute = dyn_cast<ExecuteOp>(op))
+ executeOperations.insert(execute);
+ };
+
+ for (Operation *user : value.getUsers()) {
+ // Follow parent operations up until the operation in the `definingRegion`.
+ while (user->getParentRegion() != definingRegion) {
+ trackAsyncExecute(user);
+ user = user->getParentOp();
+ assert(user != nullptr && "value user lies outside of the value region");
+ }
+
+ // Don't forget to process the parent in the `definingRegion` (can be the
+ // original user operation itself).
+ trackAsyncExecute(user);
+ }
+
+ // Process all `async.execute` operations capturing `value`.
+ for (ExecuteOp execute : executeOperations) {
+ // Add a reference before the execute operation to keep the reference
+ // counted value alive until the async region completes execution.
+ builder.setInsertionPoint(execute.getOperation());
+ builder.create<AddRefOp>(loc, value, IntegerAttr::get(i32, 1));
+
+ // Drop the reference inside the async region before completion.
+ OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
+ executeBuilder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+ }
+
+ return success();
+}
+
+void AsyncRefCountingPass::runOnFunction() {
+ FuncOp func = getFunction();
+
+ // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
+ // because otherwise automatic reference counting will produce incorrect
+ // results; bail out before modifying the IR if any are found.
+ WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
+ if (isa<AddRefOp, DropRefOp>(op))
+ return op->emitError() << "explicit reference counting is not supported";
+ return WalkResult::advance();
+ });
+
+ if (refCountingWalk.wasInterrupted())
+ return signalPassFailure();
+
+ // Add reference counting to block arguments.
+ WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+ for (BlockArgument arg : block->getArguments())
+ if (isRefCounted(arg.getType()))
+ if (failed(addAutomaticRefCounting(arg)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (blockWalk.wasInterrupted())
+ return signalPassFailure();
+
+ // Add reference counting to operation results.
+ WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+ for (unsigned i = 0; i < op->getNumResults(); ++i)
+ if (isRefCounted(op->getResultTypes()[i]))
+ if (failed(addAutomaticRefCounting(op->getResult(i))))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (opWalk.wasInterrupted())
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
+ return std::make_unique<AsyncRefCountingPass>();
+}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
new file mode 100644
index 000000000000..cbcb30c5276a
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
@@ -0,0 +1,218 @@
+//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimize Async dialect reference counting operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingOptimizationPass
+ : public AsyncRefCountingOptimizationBase<
+ AsyncRefCountingOptimizationPass> {
+public:
+ AsyncRefCountingOptimizationPass() = default;
+ void runOnFunction() override;
+
+private:
+ LogicalResult optimizeReferenceCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult
+AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
+ Region *definingRegion = value.getParentRegion();
+
+ // Find all users of the `value` inside each block, including operations that
+ // do not use `value` directly, but have a direct use inside nested region(s).
+ //
+ // Example:
+ //
+ // ^bb1:
+ // %token = ...
+ // scf.if %cond {
+ // ^bb2:
+ // async.await %token : !async.token
+ // }
+ //
+ // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
+ //
+ // In addition to the operation that uses the `value` we also keep track if
+ // this user is an `async.execute` operation itself, or has `async.execute`
+ // operations in the nested regions that do use the `value`.
+
+ struct UserInfo {
+ Operation *operation;
+ bool hasExecuteUser;
+ };
+
+ struct BlockUsersInfo {
+ llvm::SmallVector<AddRefOp, 4> addRefs;
+ llvm::SmallVector<DropRefOp, 4> dropRefs;
+ llvm::SmallVector<UserInfo, 4> users;
+ };
+
+ llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
+
+ auto updateBlockUsersInfo = [&](UserInfo user) {
+ BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
+ info.users.push_back(user);
+
+ if (auto addRef = dyn_cast<AddRefOp>(user.operation))
+ info.addRefs.push_back(addRef);
+ if (auto dropRef = dyn_cast<DropRefOp>(user.operation))
+ info.dropRefs.push_back(dropRef);
+ };
+
+ for (Operation *user : value.getUsers()) {
+ bool isAsyncUser = isa<ExecuteOp>(user);
+
+ while (user->getParentRegion() != definingRegion) {
+ updateBlockUsersInfo({user, isAsyncUser});
+ user = user->getParentOp();
+ assert(user != nullptr && "value user lies outside of the value region");
+ isAsyncUser |= isa<ExecuteOp>(user);
+ }
+
+ updateBlockUsersInfo({user, isAsyncUser});
+ }
+
+ // Sort all operations found in the block.
+ auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
+ auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
+ return a->isBeforeInBlock(b);
+ };
+ llvm::sort(info.addRefs, isBeforeInBlock);
+ llvm::sort(info.dropRefs, isBeforeInBlock);
+ llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
+ return isBeforeInBlock(a.operation, b.operation);
+ });
+
+ return info;
+ };
+
+ // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
+ // blocks that modify the reference count of the `value`.
+ for (auto &kv : blockUsers) {
+ BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
+
+ // Find all cancellable pairs first and erase them later to keep all
+ // pointers in the `info` valid until the end.
+ //
+ // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
+ llvm::SmallDenseMap<Operation *, Operation *> cancellable;
+
+ for (AddRefOp addRef : info.addRefs) {
+ for (DropRefOp dropRef : info.dropRefs) {
+ // `drop_ref` operation after the `add_ref` with matching count.
+ if (dropRef.count() != addRef.count() ||
+ dropRef.getOperation()->isBeforeInBlock(addRef.getOperation()))
+ continue;
+
+ // `drop_ref` was already marked for removal.
+ if (cancellable.find(dropRef.getOperation()) != cancellable.end())
+ continue;
+
+ // Check `value` users between `addRef` and `dropRef` in the `block`.
+ Operation *addRefOp = addRef.getOperation();
+ Operation *dropRefOp = dropRef.getOperation();
+
+ // If there is a "regular" user after the `async.execute` user it is
+ // unsafe to erase cancellable reference counting operations pair,
+ // because async region can complete before the "regular" user and
+ // destroy the reference counted value.
+ bool hasExecuteUser = false;
+ bool unsafeToCancel = false;
+
+ for (UserInfo &user : info.users) {
+ Operation *op = user.operation;
+
+ // `user` operation lies after `addRef` ...
+ if (op == addRefOp || op->isBeforeInBlock(addRefOp))
+ continue;
+ // ... and before `dropRef`.
+ if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
+ break;
+
+ bool isRegularUser = !user.hasExecuteUser;
+ bool isExecuteUser = user.hasExecuteUser;
+
+ // It is unsafe to cancel `addRef` / `dropRef` pair.
+ if (isRegularUser && hasExecuteUser) {
+ unsafeToCancel = true;
+ break;
+ }
+
+ hasExecuteUser |= isExecuteUser;
+ }
+
+ // Mark the pair of reference counting operations for removal.
+ if (!unsafeToCancel)
+ cancellable[dropRef.getOperation()] = addRef.getOperation();
+
+ // Either `addRef` is now paired, or it is unsafe to cancel it with this
+ // `dropRef` and all the following candidate pairs would be unsafe too.
+ break;
+ }
+ }
+
+ // Erase all cancellable `addRef <-> dropRef` operation pairs.
+ for (auto &kv : cancellable) {
+ kv.first->erase();
+ kv.second->erase();
+ }
+ }
+
+ return success();
+}
+
+void AsyncRefCountingOptimizationPass::runOnFunction() {
+ FuncOp func = getFunction();
+
+ // Optimize reference counting for values defined by block arguments.
+ WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+ for (BlockArgument arg : block->getArguments())
+ if (isRefCounted(arg.getType()))
+ if (failed(optimizeReferenceCounting(arg)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (blockWalk.wasInterrupted())
+ return signalPassFailure();
+
+ // Optimize reference counting for values defined by operation results.
+ WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+ for (unsigned i = 0; i < op->getNumResults(); ++i)
+ if (isRefCounted(op->getResultTypes()[i]))
+ if (failed(optimizeReferenceCounting(op->getResult(i))))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (opWalk.wasInterrupted())
+ signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRefCountingOptimizationPass() {
+ return std::make_unique<AsyncRefCountingOptimizationPass>();
+}
diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
index 9de43873039d..dccae73d9bee 100644
--- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRAsyncTransforms
AsyncParallelFor.cpp
+ AsyncRefCounting.cpp
+ AsyncRefCountingOptimization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 332c7ff1e2b9..f769965b26ec 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -16,6 +16,7 @@
#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
#include <atomic>
+#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
@@ -27,30 +28,141 @@
// Async runtime API.
//===----------------------------------------------------------------------===//
-struct AsyncToken {
- bool ready = false;
+namespace {
+
+// Forward declare class defined below.
+class RefCounted;
+
+// -------------------------------------------------------------------------- //
+// AsyncRuntime orchestrates all async operations and Async runtime API is built
+// on top of the default runtime instance.
+// -------------------------------------------------------------------------- //
+
+class AsyncRuntime {
+public:
+  AsyncRuntime() : numRefCountedObjects(0) {}
+
+  // With assertions enabled, verifies that every reference counted object
+  // registered with this runtime was destroyed before the runtime itself.
+  ~AsyncRuntime() {
+    assert(getNumRefCountedObjects() == 0 &&
+           "all ref counted objects must be destroyed");
+  }
+
+  // Returns the number of live reference counted objects registered with
+  // this runtime. The load is relaxed, so the value is for debugging only.
+  int32_t getNumRefCountedObjects() const {
+    return numRefCountedObjects.load(std::memory_order_relaxed);
+  }
+
+private:
+  friend class RefCounted;
+
+  // Count the total number of reference counted objects in this instance
+  // of an AsyncRuntime. For debugging purposes only.
+  void addNumRefCountedObjects() {
+    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
+  }
+  void dropNumRefCountedObjects() {
+    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
+  }
+
+  std::atomic<int32_t> numRefCountedObjects;
+};
+
+// Returns the default per-process instance of an async runtime.
+AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+  static auto runtime = std::make_unique<AsyncRuntime>();
+  return runtime.get();
+}
+
+// -------------------------------------------------------------------------- //
+// A base class for all reference counted objects created by the async runtime.
+// -------------------------------------------------------------------------- //
+
+class RefCounted {
+public:
+  // Registers the object with `runtime`, starting at `refCount` references.
+  // `explicit`: because `refCount` is defaulted, this would otherwise be an
+  // implicit conversion from `AsyncRuntime *` to `RefCounted`.
+  explicit RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
+      : runtime(runtime), refCount(refCount) {
+    runtime->addNumRefCountedObjects();
+  }
+
+  virtual ~RefCounted() {
+    assert(refCount.load() == 0 && "reference count must be zero");
+    runtime->dropNumRefCountedObjects();
+  }
+
+  RefCounted(const RefCounted &) = delete;
+  RefCounted &operator=(const RefCounted &) = delete;
+
+  // Adds `count` references to this object.
+  void addRef(int32_t count = 1) { refCount.fetch_add(count); }
+
+  // Drops `count` references from this object and destroys it when the
+  // reference count reaches zero.
+  void dropRef(int32_t count = 1) {
+    int32_t previous = refCount.fetch_sub(count);
+    assert(previous >= count && "reference count should not go below zero");
+    if (previous == count)
+      destroy();
+  }
+
+protected:
+  // Called when the last reference is dropped; deletes `this` by default.
+  virtual void destroy() { delete this; }
+
+private:
+  AsyncRuntime *runtime;
+  std::atomic<int32_t> refCount;
+};
+
+} // namespace
+
+struct AsyncToken : public RefCounted {
+  // An AsyncToken is created with a reference count of 2: one reference is
+  // returned to the `async.execute` caller and the other one is held by the
+  // asynchronously executed task until it emplaces the token. This keeps the
+  // token alive until the asynchronous operation completes, even if the
+  // caller drops its reference immediately.
+  explicit AsyncToken(AsyncRuntime *runtime)
+      : RefCounted(runtime, /*refCount=*/2) {}
+
+  // Internal state below guarded by a mutex.
std::mutex mu;
std::condition_variable cv;
+
+  bool ready = false;
std::vector<std::function<void()>> awaiters;
};
-struct AsyncGroup {
- std::atomic<int> pendingTokens{0};
- std::atomic<int> rank{0};
+struct AsyncGroup : public RefCounted {
+  // A group starts with the default reference count of 1.
+  explicit AsyncGroup(AsyncRuntime *runtime)
+      : RefCounted(runtime), pendingTokens(0), rank(0) {}
+
+  // Number of tokens added to the group that are not yet ready.
+  std::atomic<int> pendingTokens;
+  // Number of tokens added to the group so far; the value before each
+  // increment is the rank of the token being added.
+  std::atomic<int> rank;
+
+  // Internal state below guarded by a mutex.
std::mutex mu;
std::condition_variable cv;
+
std::vector<std::function<void()>> awaiters;
};
+// Increments the reference count of the runtime object behind `ptr` (for
+// example an async token or group) by `count`.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
+  static_cast<RefCounted *>(ptr)->addRef(count);
+}
+
+// Decrements the reference count of the runtime object behind `ptr` by
+// `count`; the object is destroyed once its count reaches zero.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
+  auto *object = static_cast<RefCounted *>(ptr);
+  object->dropRef(count);
+}
+
// Create a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
- AsyncToken *token = new AsyncToken;
+ // The token starts with a reference count of 2 (see the AsyncToken
+ // constructor): one reference for the caller and one that is dropped by
+ // `mlirAsyncRuntimeEmplaceToken`.
+ AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
return token;
}
// Create a new `async.group` in empty state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
- AsyncGroup *group = new AsyncGroup;
+ // The group starts with the default reference count of 1.
+ AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
return group;
}
@@ -59,23 +171,34 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
std::unique_lock<std::mutex> lockToken(token->mu);
std::unique_lock<std::mutex> lockGroup(group->mu);
+ // Get the rank of the token inside the group before we drop the reference.
+ int rank = group->rank.fetch_add(1);
group->pendingTokens.fetch_add(1);
- auto onTokenReady = [group]() {
+ // Runs once `token` is ready. `dropRef` is true only when invoked as a token
+ // awaiter, in which case it releases the references taken below.
+ auto onTokenReady = [group, token](bool dropRef) {
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
for (auto &awaiter : group->awaiters)
awaiter();
}
+
+ // We no longer need the token or the group, drop references on them.
+ if (dropRef) {
+ group->dropRef();
+ token->dropRef();
+ }
};
- if (token->ready)
- onTokenReady();
- else
- token->awaiters.push_back([onTokenReady]() { onTokenReady(); });
+ if (token->ready) {
+ onTokenReady(false);
+ } else {
+ // Keep the group and the token alive until the awaiter runs;
+ // `onTokenReady(true)` drops these references.
+ group->addRef();
+ token->addRef();
+ token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
+ }
- return group->rank.fetch_add(1);
+ return rank;
}
// Switches `async.token` to ready state and runs all awaiters.
@@ -85,6 +208,10 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
token->cv.notify_all();
for (auto &awaiter : token->awaiters)
awaiter();
+
+  // Unlock `token->mu` before dropping the reference. If this is the last
+  // reference, `dropRef` deletes the token together with its mutex, and the
+  // lock would otherwise unlock an already destroyed mutex when it goes out
+  // of scope.
+  // NOTE(review): assumes the lock taken at the top of this function is the
+  // `std::unique_lock` named `lock`, as in the sibling runtime functions.
+  lock.unlock();
+
+  // Async tokens are created with a ref count `2` to keep the token alive
+  // until the async task completes. Drop this reference explicitly now that
+  // the token is emplaced.
+  token->dropRef();
}
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
@@ -114,14 +241,18 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
CoroResume resume) {
std::unique_lock<std::mutex> lock(token->mu);
- auto execute = [handle, resume]() {
+ auto execute = [handle, resume, token](bool dropRef) {
+ // When invoked as a token awaiter, release the reference taken below
+ // before calling `mlirAsyncRuntimeExecute`.
+ if (dropRef)
+ token->dropRef();
mlirAsyncRuntimeExecute(handle, resume);
};
- if (token->ready)
- execute();
- else
- token->awaiters.push_back([execute]() { execute(); });
+ if (token->ready) {
+ execute(false);
+ } else {
+ // Keep the token alive until the awaiter runs; `execute(true)` drops this
+ // reference.
+ token->addRef();
+ token->awaiters.push_back([execute]() { execute(true); });
+ }
}
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
@@ -129,14 +260,18 @@ mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
CoroResume resume) {
std::unique_lock<std::mutex> lock(group->mu);
- auto execute = [handle, resume]() {
+ auto execute = [handle, resume, group](bool dropRef) {
+ // When invoked as a group awaiter, release the reference taken below
+ // before calling `mlirAsyncRuntimeExecute`.
+ if (dropRef)
+ group->dropRef();
mlirAsyncRuntimeExecute(handle, resume);
};
- if (group->pendingTokens == 0)
- execute();
- else
- group->awaiters.push_back([execute]() { execute(); });
+ if (group->pendingTokens == 0) {
+ execute(false);
+ } else {
+ // Keep the group alive until the awaiter runs; `execute(true)` drops this
+ // reference.
+ group->addRef();
+ group->awaiters.push_back([execute]() { execute(true); });
+ }
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index 1fd71a65379e..dadb28dbc082 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -1,5 +1,20 @@
// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s
+// CHECK-LABEL: reference_counting
+func @reference_counting(%arg0: !async.token) {
+ // Lowering of `async.add_ref` / `async.drop_ref`: the `count` attribute is
+ // materialized as an i32 constant passed to the runtime entry points.
+ // CHECK: %[[C2:.*]] = constant 2 : i32
+ // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]])
+ async.add_ref %arg0 {count = 2 : i32} : !async.token
+
+ // CHECK: %[[C1:.*]] = constant 1 : i32
+ // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]])
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: execute_no_async_args
func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1)
diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
new file mode 100644
index 000000000000..6500fa0b1d8a
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s
+
+// CHECK-LABEL: @cancellable_operations_0
+func @cancellable_operations_0(%arg0: !async.token) {
+ // An `add_ref` immediately followed by a matching `drop_ref` is removed.
+ // CHECK-NOT: async.add_ref
+ // CHECK-NOT: async.drop_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_1
+func @cancellable_operations_1(%arg0: !async.token) {
+ // The `add_ref` before the `async.execute` and the `drop_ref` after it
+ // cancel out; the `drop_ref` inside the execute body is kept.
+ // CHECK-NOT: async.add_ref
+ // CHECK: async.execute
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.execute [%arg0] {
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK-NEXT: async.yield
+ async.yield
+ }
+ // CHECK-NOT: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_2
+func @cancellable_operations_2(%arg0: !async.token) {
+ // Both `add_ref` / `drop_ref` pairs around the awaits are removed, leaving
+ // only the three `async.await` operations.
+ // CHECK: async.await
+ // CHECK-NEXT: async.await
+ // CHECK-NEXT: async.await
+ // CHECK-NEXT: return
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.await %arg0 : !async.token
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ async.await %arg0 : !async.token
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ async.await %arg0 : !async.token
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ return
+}
+
+// CHECK-LABEL: @cancellable_operations_3
+func @cancellable_operations_3(%arg0: !async.token) {
+ // The outer pair around the `async.execute` is removed; the `async.await`
+ // in the function body comes after the `drop_ref` (contrast with
+ // @not_cancellable_operations_0, where it sits between the pair).
+ // CHECK-NOT: add_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ %token = async.execute {
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ async.yield
+ }
+ // CHECK-NOT: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_0
+func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
+ // It is unsafe to cancel the `add_ref` / `drop_ref` pair because the body of
+ // the `async.execute` operation may run before the await operation in the
+ // function body and destroy the `%arg0` token. (`%arg1` is unused.)
+ // CHECK: add_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ %token = async.execute {
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: async.yield
+ async.yield
+ }
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_1
+func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
+ // Same reason as in @not_cancellable_operations_0, even though the
+ // `async.execute` is nested inside the region of a "regular" operation.
+ //
+ // NOTE: This test is not correct w.r.t. reference counting, and at runtime
+ // would leak the %arg0 value if %arg1 is false. IR like this will not be
+ // constructed by the automatic reference counting pass, because it would
+ // place `async.add_ref` right before the `async.execute` inside `scf.if`.
+
+ // CHECK: async.add_ref
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ scf.if %arg1 {
+ %token = async.execute {
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ async.yield
+ }
+ }
+ // CHECK: async.await
+ async.await %arg0 : !async.token
+ // CHECK: async.drop_ref
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ // CHECK: return
+ return
+}
diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir
new file mode 100644
index 000000000000..504a18fba990
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-ref-counting.mlir
@@ -0,0 +1,253 @@
+// RUN: mlir-opt %s -async-ref-counting | FileCheck %s
+
+// CHECK-LABEL: @cond
+// External predicate used as the loop condition by the loop tests below.
+func private @cond() -> i1
+
+// CHECK-LABEL: @token_arg_no_uses
+func @token_arg_no_uses(%arg0: !async.token) {
+ // A token argument without uses still gets a `drop_ref` for its reference.
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ return
+}
+
+// CHECK-LABEL: @token_arg_conditional_await
+func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) {
+ // The reference is dropped in each successor: in ^bb1, which has no uses,
+ // and after the await in ^bb2.
+ cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ return
+^bb2:
+ // CHECK: async.await %arg0
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ async.await %arg0 : !async.token
+ return
+}
+
+// CHECK-LABEL: @token_no_uses
+func @token_no_uses() {
+ // A token result without uses is dropped after its defining `async.execute`.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ %token = async.execute {
+ async.yield
+ }
+ return
+}
+
+// CHECK-LABEL: @token_return
+func @token_return() -> !async.token {
+ // Returning a token forwards its reference to the caller: ReturnLike
+ // operations do not change the reference count.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: return %[[TOKEN]]
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @token_await
+func @token_await() {
+ // The reference is dropped after the last use of the token, the await.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: async.await %[[TOKEN]]
+ async.await %token : !async.token
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_await_and_return
+func @token_await_and_return() -> !async.token {
+ // The token is returned after the await, so no `drop_ref` is inserted.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK-NOT: async.drop_ref
+ async.await %token : !async.token
+ // CHECK: return %[[TOKEN]]
+ return %token : !async.token
+}
+
+// CHECK-LABEL: @token_await_inside_scf_if
+func @token_await_inside_scf_if(%arg0: i1) {
+ // The last use is nested inside `scf.if`; the `drop_ref` is placed after the
+ // `scf.if` operation in the parent block.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: scf.if %arg0 {
+ scf.if %arg0 {
+ // CHECK: async.await %[[TOKEN]]
+ async.await %token : !async.token
+ }
+ // CHECK: }
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_conditional_await
+func @token_conditional_await(%arg0: i1) {
+ // Same as @token_arg_conditional_await, but for a token defined by
+ // `async.execute`.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ cond_br %arg0, ^bb1, ^bb2
+^bb1:
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ return
+^bb2:
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ return
+}
+
+// CHECK-LABEL: @token_await_in_the_loop
+func @token_await_in_the_loop() {
+ // The token is awaited on every loop iteration, so its reference is dropped
+ // only in the loop exit block ^bb2.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ br ^bb1
+^bb1:
+ // CHECK: async.await %[[TOKEN]]
+ async.await %token : !async.token
+ %0 = call @cond(): () -> (i1)
+ cond_br %0, ^bb1, ^bb2
+^bb2:
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ return
+}
+
+// CHECK-LABEL: @token_defined_in_the_loop
+func @token_defined_in_the_loop() {
+ // A new token is defined on every iteration, and each one is dropped after
+ // its await inside the loop body.
+ br ^bb1
+^bb1:
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ %0 = call @cond(): () -> (i1)
+ cond_br %0, ^bb1, ^bb2
+^bb2:
+ return
+}
+
+// CHECK-LABEL: @token_capture
+func @token_capture() {
+ // Capturing `%token` in an `async.execute` body adds a reference before the
+ // execute and drops it at the end of the body.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute {
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK-NEXT: async.yield
+ async.await %token : !async.token
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_nested_capture
+func @token_nested_capture() {
+ // Each level of nested `async.execute` takes its own reference to `%token`
+ // and drops it before its terminator.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute {
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_1:.*]] = async.execute
+ %token_1 = async.execute {
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_2:.*]] = async.execute
+ %token_2 = async.execute {
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_2]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_1]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.yield
+ }
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @token_dependency
+func @token_dependency() {
+ // A token passed as an `async.execute` dependency gets an `add_ref` before
+ // the execute and a `drop_ref` inside its body, like a captured value.
+ // CHECK: %[[TOKEN:.*]] = async.execute
+ %token = async.execute {
+ async.yield
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute[%token] {
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK-NEXT: async.yield
+ async.yield
+ }
+
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+ // CHECK: async.await %[[TOKEN_0]]
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ async.await %token_0 : !async.token
+
+ // CHECK: return
+ return
+}
+
+// CHECK-LABEL: @value_operand
+func @value_operand() -> f32 {
+ // Both the dependency token and the `!async.value` operand get an `add_ref`
+ // before the execute and a `drop_ref` inside its body.
+ // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute
+ %token, %results = async.execute -> !async.value<f32> {
+ %0 = constant 0.0 : f32
+ async.yield %0 : f32
+ }
+
+ // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: async.add_ref %[[RESULTS]] {count = 1 : i32}
+ // CHECK: %[[TOKEN_0:.*]] = async.execute
+ %token_0 = async.execute[%token](%results as %arg0 : !async.value<f32>) {
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32}
+ // CHECK: async.yield
+ async.yield
+ }
+
+ // CHECK: async.await %[[TOKEN]]
+ // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.await %token : !async.token
+
+ // CHECK: async.await %[[TOKEN_0]]
+ // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32}
+ async.await %token_0 : !async.token
+
+ // CHECK: async.await %[[RESULTS]]
+ // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32}
+ %0 = async.await %results : !async.value<f32>
+
+ // CHECK: return
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
index a95be650eff7..54dc6736b4dd 100644
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -134,3 +134,17 @@ func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>)
%3 = addi %1, %2 : index
return %3 : index
}
+
+// CHECK-LABEL: @add_ref
+func @add_ref(%arg0: !async.token) {
+ // Custom syntax of `async.add_ref`: operand, `count` attribute and type.
+ // CHECK: async.add_ref %arg0 {count = 1 : i32}
+ async.add_ref %arg0 {count = 1 : i32} : !async.token
+ return
+}
+
+// CHECK-LABEL: @drop_ref
+func @drop_ref(%arg0: !async.token) {
+ // Custom syntax of `async.drop_ref`: operand, `count` attribute and type.
+ // CHECK: async.drop_ref %arg0 {count = 1 : i32}
+ async.drop_ref %arg0 {count = 1 : i32} : !async.token
+ return
+}
diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
index 87004ff7b381..50f85ff54609 100644
--- a/mlir/test/mlir-cpu-runner/async-group.mlir
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-async-to-llvm \
+// RUN: mlir-opt %s -async-ref-counting \
+// RUN: -convert-async-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: -e main -entry-point-result=void -O0 \
diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir
index fd0268e7ac56..5f06dd17ed61 100644
--- a/mlir/test/mlir-cpu-runner/async.mlir
+++ b/mlir/test/mlir-cpu-runner/async.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-async-to-llvm \
+// RUN: mlir-opt %s -async-ref-counting \
+// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-linalg-to-llvm \
// RUN: -convert-std-to-llvm \
More information about the Mlir-commits
mailing list