[Mlir-commits] [mlir] f57b242 - [mlir:Async] Add an async reference counting pass based on the user defined policy
Eugene Zhulenev
llvmlistbot at llvm.org
Tue Jun 29 12:53:17 PDT 2021
Author: Eugene Zhulenev
Date: 2021-06-29T12:53:09-07:00
New Revision: f57b2420b2235eca00d5c085a7ef084433140452
URL: https://github.com/llvm/llvm-project/commit/f57b2420b2235eca00d5c085a7ef084433140452
DIFF: https://github.com/llvm/llvm-project/commit/f57b2420b2235eca00d5c085a7ef084433140452.diff
LOG: [mlir:Async] Add an async reference counting pass based on the user defined policy
Depends On D104999
Automatic reference counting based on the liveness analysis can add a lot of reference counting overhead at runtime. If the IR is known to be constrained to a few particular "shapes", it's much more efficient to provide a custom reference counting policy that specifies where the async value reference count must be updated.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D105037
Added:
mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
Modified:
mlir/include/mlir/Dialect/Async/Passes.h
mlir/include/mlir/Dialect/Async/Passes.td
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 5d0a7f66cf774..ce85ffa296472 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -29,6 +29,8 @@ std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();
+// Creates a pass that adds async runtime reference counting operations
+// according to a reference counting policy instead of liveness analysis.
+std::unique_ptr<Pass> createAsyncRuntimePolicyBasedRefCountingPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index e321747d4ec66..913ecee43097c 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -66,4 +66,36 @@ def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
let dependentDialects = ["async::AsyncDialect"];
}
+def AsyncRuntimePolicyBasedRefCounting
+ : Pass<"async-runtime-policy-based-ref-counting"> {
+ let summary = "Policy based reference counting for Async runtime operations";
+ let description = [{
+ This pass works at the async runtime abstraction level, after all
+ `async.execute` and `async.await` operations are lowered to the async
+ runtime API calls, and async coroutine operations.
+
+ This pass doesn't rely on the liveness analysis of reference counted
+ values, and instead uses a simple policy to create reference counting
+ operations. If the program violates any of the assumptions, then this pass
+ might lead to memory leaks or runtime errors.
+
+ The default reference counting policy assumptions:
+ 1. Async token can be awaited or added to the group only once.
+ 2. Async value or group can be awaited only once.
+
+ Under these assumptions reference counting only needs to drop references:
+ 1. After `async.runtime.await` operation for async tokens and groups
+ (as long as error handling is not implemented for the sync await).
+ 2. After `async.runtime.is_error` operation for async tokens and groups
+ (this is the last operation in the coroutine resume function).
+ 3. After `async.runtime.load` operation for async values.
+ 4. After `async.runtime.add_to_group` operation for async tokens.
+
+ This pass introduces significantly less runtime overhead compared to the
+ automatic reference counting.
+ }];
+
+ let constructor = "mlir::createAsyncRuntimePolicyBasedRefCountingPass()";
+ let dependentDialects = ["async::AsyncDialect"];
+}
+
#endif // MLIR_DIALECT_ASYNC_PASSES
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 15fd4f3f87650..17e768cee74ba 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -26,6 +26,79 @@ using namespace mlir::async;
#define DEBUG_TYPE "async-runtime-ref-counting"
+//===----------------------------------------------------------------------===//
+// Utility functions shared by reference counting passes.
+//===----------------------------------------------------------------------===//
+
+// Drops `count` references of `value` immediately if the value has no uses.
+// Returns failure (and creates no operations) if the value has at least one
+// use, so callers can fall back to a use-based reference counting strategy.
+static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
+ if (!value.getUses().empty())
+ return failure();
+
+ OpBuilder b(value.getContext());
+
+ // Set insertion point after the operation producing a value, or at the
+ // beginning of the block if the value is defined by a block argument.
+ if (Operation *op = value.getDefiningOp())
+ b.setInsertionPointAfter(op);
+ else
+ b.setInsertionPointToStart(value.getParentBlock());
+
+ // Drop exactly the caller-requested number of references.
+ b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI32IntegerAttr(count));
+ return success();
+}
+
+// Calls `addRefCounting` for every reference counted value defined by the
+// operation `op` (block arguments and values defined in nested regions).
+// Fails, emitting an error, if `op` still contains high level async
+// operations; also fails if `addRefCounting` fails for any value.
+static LogicalResult walkReferenceCountedValues(
+ Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
+ // Check that we do not have high level async operations in the IR because
+ // otherwise reference counting will produce incorrect results after high
+ // level async operations are lowered to `async.runtime`.
+ WalkResult checkNoAsyncWalk = op->walk([&](Operation *nested) -> WalkResult {
+ if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(nested))
+ return WalkResult::advance();
+
+ return nested->emitError()
+ << "async operations must be lowered to async runtime operations";
+ });
+
+ if (checkNoAsyncWalk.wasInterrupted())
+ return failure();
+
+ // Add reference counting to block arguments.
+ WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
+ for (BlockArgument arg : block->getArguments())
+ if (isRefCounted(arg.getType()))
+ if (failed(addRefCounting(arg)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (blockWalk.wasInterrupted())
+ return failure();
+
+ // Add reference counting to operation results. The lambda parameter is
+ // named `nested` so it does not shadow the function parameter `op`.
+ WalkResult opWalk = op->walk([&](Operation *nested) -> WalkResult {
+ for (Value result : nested->getResults())
+ if (isRefCounted(result.getType()))
+ if (failed(addRefCounting(result)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ if (opWalk.wasInterrupted())
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Automatic reference counting based on the liveness analysis.
+//===----------------------------------------------------------------------===//
+
namespace {
class AsyncRuntimeRefCountingPass
@@ -356,21 +429,9 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
LogicalResult
AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
- OpBuilder builder(value.getContext());
- Location loc = value.getLoc();
-
- // Set inserton point after the operation producing a value, or at the
- // beginning of the block if the value defined by the block argument.
- if (Operation *op = value.getDefiningOp())
- builder.setInsertionPointAfter(op);
- else
- builder.setInsertionPointToStart(value.getParentBlock());
-
- // Drop the reference count immediately if the value has no uses.
- if (value.getUses().empty()) {
- builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
+ // Short-circuit reference counting for values without uses.
+ if (succeeded(dropRefIfNoUses(value)))
return success();
- }
// Add `drop_ref` operations based on the liveness analysis.
if (failed(addDropRefAfterLastUse(value)))
@@ -388,53 +449,114 @@ AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
}
void AsyncRuntimeRefCountingPass::runOnOperation() {
- Operation *op = getOperation();
+ // Apply liveness-based automatic reference counting to every reference
+ // counted value defined under the pass operation.
+ auto functor = [&](Value value) { return addAutomaticRefCounting(value); };
+ if (failed(walkReferenceCountedValues(getOperation(), functor)))
+ signalPassFailure();
+}
- // Check that we do not have high level async operations in the IR because
- // otherwise automatic reference counting will produce incorrect results after
- // execute operations will be lowered to `async.runtime`
- WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
- if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
- return WalkResult::advance();
+//===----------------------------------------------------------------------===//
+// Reference counting based on the user defined policy.
+//===----------------------------------------------------------------------===//
- return op->emitError()
- << "async operations must be lowered to async runtime operations";
- });
+namespace {
- if (executeOpWalk.wasInterrupted()) {
- signalPassFailure();
- return;
- }
+// Reference counting pass driven by a list of policy functions (see `policy`)
+// instead of the liveness analysis used by AsyncRuntimeRefCountingPass.
+class AsyncRuntimePolicyBasedRefCountingPass
+ : public AsyncRuntimePolicyBasedRefCountingBase<
+ AsyncRuntimePolicyBasedRefCountingPass> {
+public:
+ // Installs the default reference counting policy.
+ AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
- // Add reference counting to block arguments.
- WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
- for (BlockArgument arg : block->getArguments())
- if (isRefCounted(arg.getType()))
- if (failed(addAutomaticRefCounting(arg)))
- return WalkResult::interrupt();
+ void runOnOperation() override;
- return WalkResult::advance();
- });
+private:
+ // Adds reference counting operations for all uses of the `value` according
+ // to the reference counting policy.
+ LogicalResult addRefCounting(Value value);
- if (blockWalk.wasInterrupted()) {
- signalPassFailure();
- return;
+ // Appends the default policy function to `policy`.
+ void initializeDefaultPolicy();
+
+ // Policy functions consulted for every use of a reference counted value.
+ // Each returns a reference count delta for that use: a positive delta adds
+ // references before the user operation, a negative delta drops references
+ // after it, and zero leaves the count unchanged. A failure result makes the
+ // pass fail.
+ llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
+};
+
+} // namespace
+
+// Consults every policy function for every use of `value` and materializes
+// the returned deltas as `add_ref`/`drop_ref` operations around the user.
+// A value without uses gets one reference dropped immediately instead.
+LogicalResult
+AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
+ // Short-circuit reference counting for values without uses.
+ if (succeeded(dropRefIfNoUses(value)))
+ return success();
+
+ OpBuilder b(value.getContext());
+
+ // Consult the user defined policy for every value use.
+ // NOTE(review): the `add_ref`/`drop_ref` ops created in this loop add new
+ // uses of `value` while its use list is being iterated; this relies on newly
+ // created uses not being visited by the in-flight iterator. Consider
+ // snapshotting the uses before creating any operations.
+ for (OpOperand &operand : value.getUses()) {
+ Location loc = operand.getOwner()->getLoc();
+
+ for (auto &func : policy) {
+ FailureOr<int> refCount = func(operand);
+ if (failed(refCount))
+ return failure();
+
+ int cnt = refCount.getValue();
+
+ // Create `add_ref` operation before the operand owner.
+ if (cnt > 0) {
+ b.setInsertionPoint(operand.getOwner());
+ b.create<RuntimeAddRefOp>(loc, value, b.getI32IntegerAttr(cnt));
+ }
+
+ // Create `drop_ref` operation after the operand owner.
+ if (cnt < 0) {
+ b.setInsertionPointAfter(operand.getOwner());
+ b.create<RuntimeDropRefOp>(loc, value, b.getI32IntegerAttr(-cnt));
+ }
+ }
}
- // Add reference counting to operation results.
- WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
- for (unsigned i = 0; i < op->getNumResults(); ++i)
- if (isRefCounted(op->getResultTypes()[i]))
- if (failed(addAutomaticRefCounting(op->getResult(i))))
- return WalkResult::interrupt();
+ return success();
+}
- return WalkResult::advance();
+// Installs the default reference counting policy. It maps each use of a
+// reference counted value to a reference count delta: -1 ("drop one reference
+// after the user") for the consuming operations listed below, 0 otherwise.
+void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
+ policy.push_back([](OpOperand &operand) -> FailureOr<int> {
+ Operation *op = operand.getOwner();
+ Type type = operand.get().getType();
+
+ bool isToken = type.isa<TokenType>();
+ bool isGroup = type.isa<GroupType>();
+ bool isValue = type.isa<ValueType>();
+
+ // Drop reference after async token or group await (sync await), and after
+ // async token or group error check (coro await).
+ if (isa<RuntimeAwaitOp, RuntimeIsErrorOp>(op))
+ return (isToken || isGroup) ? -1 : 0;
+
+ // Drop reference after async value load.
+ if (isa<RuntimeLoadOp>(op))
+ return isValue ? -1 : 0;
+
+ // Drop reference after async token added to the group.
+ if (isa<RuntimeAddToGroupOp>(op))
+ return isToken ? -1 : 0;
+
+ return 0;
});
+}
- if (opWalk.wasInterrupted())
+void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
+ // Apply the reference counting policy to every reference counted value
+ // defined under the pass operation; fail the pass if any step fails.
+ auto functor = [&](Value value) { return addRefCounting(value); };
+ if (failed(walkReferenceCountedValues(getOperation(), functor)))
signalPassFailure();
}
+//----------------------------------------------------------------------------//
+
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
return std::make_unique<AsyncRuntimeRefCountingPass>();
}
+
+// Creates the policy-based reference counting pass with the default policy.
+std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
+ return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
+}
diff --git a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
new file mode 100644
index 0000000000000..54640f552798d
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -async-runtime-policy-based-ref-counting | FileCheck %s
+
+// Sync await of a token: one token reference is dropped right after the await.
+// CHECK-LABEL: @token_await
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @token_await(%arg0: !async.token) {
+ // CHECK: async.runtime.await %[[TOKEN]]
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.runtime.await %arg0 : !async.token
+ return
+}
+
+// Sync await of a group: one group reference is dropped right after the await.
+// CHECK-LABEL: @group_await
+// CHECK: %[[GROUP:.*]]: !async.group
+func @group_await(%arg0: !async.group) {
+ // CHECK: async.runtime.await %[[GROUP]]
+ // CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32}
+ async.runtime.await %arg0 : !async.group
+ return
+}
+
+// Adding a token to a group drops one token reference; the group operand of
+// the same operation is left untouched by the default policy.
+// CHECK-LABEL: @add_token_to_group
+// CHECK: %[[GROUP:.*]]: !async.group
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @add_token_to_group(%arg0: !async.group, %arg1: !async.token) {
+ // CHECK: async.runtime.add_to_group %[[TOKEN]], %[[GROUP]]
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ async.runtime.add_to_group %arg1, %arg0 : !async.token
+ return
+}
+
+// Loading from an async value drops one value reference right after the load.
+// CHECK-LABEL: @value_load
+// CHECK: %[[VALUE:.*]]: !async.value<f32>
+func @value_load(%arg0: !async.value<f32>) {
+ // CHECK: async.runtime.load %[[VALUE]]
+ // CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32}
+ %0 = async.runtime.load %arg0 : !async.value<f32>
+ return
+}
+
+// Error check of a token (coroutine await) drops one token reference after it.
+// CHECK-LABEL: @error_check
+// CHECK: %[[TOKEN:.*]]: !async.token
+func @error_check(%arg0: !async.token) {
+ // CHECK: async.runtime.is_error %[[TOKEN]]
+ // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+ %0 = async.runtime.is_error %arg0 : !async.token
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index 12b2be2627131..6c2758c484f79 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -11,6 +11,18 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
// RUN: | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -async-to-async-runtime \
+// RUN: -async-runtime-policy-based-ref-counting \
+// RUN: -convert-async-to-llvm \
+// RUN: -convert-scf-to-std \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void -O0 \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
// RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \
// RUN: num-workers=20 \
// RUN: target-block-size=1" \
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index b294b9ce4d26e..d8f99d061b7d4 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -11,6 +11,18 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
// RUN: | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -async-parallel-for \
+// RUN: -async-to-async-runtime \
+// RUN: -async-runtime-policy-based-ref-counting \
+// RUN: -convert-async-to-llvm \
+// RUN: -convert-scf-to-std \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void -O0 \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
// RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \
// RUN: num-workers=20 \
// RUN: target-block-size=1" \
More information about the Mlir-commits
mailing list