[Mlir-commits] [mlir] f57b242 - [mlir:Async] Add an async reference counting pass based on the user defined policy

Eugene Zhulenev llvmlistbot at llvm.org
Tue Jun 29 12:53:17 PDT 2021


Author: Eugene Zhulenev
Date: 2021-06-29T12:53:09-07:00
New Revision: f57b2420b2235eca00d5c085a7ef084433140452

URL: https://github.com/llvm/llvm-project/commit/f57b2420b2235eca00d5c085a7ef084433140452
DIFF: https://github.com/llvm/llvm-project/commit/f57b2420b2235eca00d5c085a7ef084433140452.diff

LOG: [mlir:Async] Add an async reference counting pass based on the user defined policy

Depends On D104999

Automatic reference counting based on liveness analysis can add a lot of reference counting overhead at runtime. If the IR is known to be constrained to a few particular "shapes", it is much more efficient to provide a custom reference counting policy that specifies where the async value reference count must be updated.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D105037

Added: 
    mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/include/mlir/Dialect/Async/Passes.td
    mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 5d0a7f66cf774..ce85ffa296472 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -29,6 +29,8 @@ std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
 
 std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();
 
+std::unique_ptr<Pass> createAsyncRuntimePolicyBasedRefCountingPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index e321747d4ec66..913ecee43097c 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -66,4 +66,36 @@ def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
   let dependentDialects = ["async::AsyncDialect"];
 }
 
+def AsyncRuntimePolicyBasedRefCounting
+    : Pass<"async-runtime-policy-based-ref-counting"> {
+  let summary = "Policy based reference counting for Async runtime operations";
+  let description = [{
+    This pass works at the async runtime abstraction level, after all
+    `async.execute` and `async.await` operations are lowered to the async
+    runtime API calls, and async coroutine operations.
+
+    Instead of liveness analysis, this pass uses a simple policy to create
+    reference counting operations. If the program violates any of the policy
+    assumptions, this pass might cause memory leaks or runtime errors.
+
+    The default reference counting policy assumptions:
+      1. Async token can be awaited or added to the group only once.
+      2. Async value or group can be awaited only once.
+
+    Under these assumptions reference counting only needs to drop a reference:
+      1. After `async.runtime.await` operation for async tokens and groups
+         (as long as error handling is not implemented for the sync await).
+      2. After `async.runtime.is_error` operation for async tokens and groups
+         (this is the last operation in the coroutine resume function).
+      3. After `async.runtime.load` operation for async values.
+      4. After `async.runtime.add_to_group` operation for async tokens.
+
+    This pass introduces significantly less runtime overhead compared to the
+    automatic reference counting.
+  }];
+
+  let constructor = "mlir::createAsyncRuntimePolicyBasedRefCountingPass()";
+  let dependentDialects = ["async::AsyncDialect"];
+}
+
 #endif // MLIR_DIALECT_ASYNC_PASSES

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 15fd4f3f87650..17e768cee74ba 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -26,6 +26,79 @@ using namespace mlir::async;
 
 #define DEBUG_TYPE "async-runtime-ref-counting"
 
+//===----------------------------------------------------------------------===//
+// Utility functions shared by reference counting passes.
+//===----------------------------------------------------------------------===//
+
+// If `value` has no uses, drops `count` references to it right after its
+// definition (or at the start of the owning block for a block argument) and
+// returns success. Otherwise returns failure and leaves the IR unchanged.
+static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
+  if (!value.getUses().empty())
+    return failure();
+
+  OpBuilder b(value.getContext());
+  if (Operation *op = value.getDefiningOp())
+    b.setInsertionPointAfter(op);
+  else
+    b.setInsertionPointToStart(value.getParentBlock());
+
+  // Release all `count` references with a single `drop_ref` operation.
+  b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI32IntegerAttr(count));
+  return success();
+}
+
+// Invokes `addRefCounting` on every reference counted value defined inside
+// `op`: first on all block arguments, then on all operation results (both
+// including nested regions). Fails if `op` still contains high level async
+// operations, or as soon as one `addRefCounting` call fails.
+static LogicalResult walkReferenceCountedValues(
+    Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
+  // High level async operations must already be lowered to `async.runtime`;
+  // otherwise the reference counting operations created here would become
+  // incorrect once those operations get lowered.
+  WalkResult asyncOpsCheck = op->walk([&](Operation *nested) -> WalkResult {
+    if (isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(nested))
+      return nested->emitError()
+             << "async operations must be lowered to async runtime operations";
+    return WalkResult::advance();
+  });
+
+  if (asyncOpsCheck.wasInterrupted())
+    return failure();
+
+  // Block arguments first.
+  WalkResult argsWalk = op->walk([&](Block *block) -> WalkResult {
+    for (BlockArgument arg : block->getArguments()) {
+      if (!isRefCounted(arg.getType()))
+        continue;
+      if (failed(addRefCounting(arg)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  if (argsWalk.wasInterrupted())
+    return failure();
+
+  // Then operation results.
+  WalkResult resultsWalk = op->walk([&](Operation *nested) -> WalkResult {
+    for (OpResult result : nested->getResults()) {
+      if (!isRefCounted(result.getType()))
+        continue;
+      if (failed(addRefCounting(result)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(resultsWalk.wasInterrupted());
+}
+
+//===----------------------------------------------------------------------===//
+// Automatic reference counting based on the liveness analysis.
+//===----------------------------------------------------------------------===//
+
 namespace {
 
 class AsyncRuntimeRefCountingPass
@@ -356,21 +429,9 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
 
 LogicalResult
 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
-  OpBuilder builder(value.getContext());
-  Location loc = value.getLoc();
-
-  // Set inserton point after the operation producing a value, or at the
-  // beginning of the block if the value defined by the block argument.
-  if (Operation *op = value.getDefiningOp())
-    builder.setInsertionPointAfter(op);
-  else
-    builder.setInsertionPointToStart(value.getParentBlock());
-
-  // Drop the reference count immediately if the value has no uses.
-  if (value.getUses().empty()) {
-    builder.create<RuntimeDropRefOp>(loc, value, builder.getI32IntegerAttr(1));
+  // Short-circuit reference counting for values without uses.
+  if (succeeded(dropRefIfNoUses(value)))
     return success();
-  }
 
   // Add `drop_ref` operations based on the liveness analysis.
   if (failed(addDropRefAfterLastUse(value)))
@@ -388,53 +449,114 @@ AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
 }
 
 void AsyncRuntimeRefCountingPass::runOnOperation() {
-  Operation *op = getOperation();
+  auto addRefs = [this](Value v) { return addAutomaticRefCounting(v); };
+  if (failed(walkReferenceCountedValues(getOperation(), addRefs)))
+    signalPassFailure();
+}
 
-  // Check that we do not have high level async operations in the IR because
-  // otherwise automatic reference counting will produce incorrect results after
-  // execute operations will be lowered to `async.runtime`
-  WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
-    if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
-      return WalkResult::advance();
+//===----------------------------------------------------------------------===//
+// Reference counting based on the user defined policy.
+//===----------------------------------------------------------------------===//
 
-    return op->emitError()
-           << "async operations must be lowered to async runtime operations";
-  });
+namespace {
 
-  if (executeOpWalk.wasInterrupted()) {
-    signalPassFailure();
-    return;
-  }
+class AsyncRuntimePolicyBasedRefCountingPass
+    : public AsyncRuntimePolicyBasedRefCountingBase<
+          AsyncRuntimePolicyBasedRefCountingPass> {
+public:
+  AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
 
-  // Add reference counting to block arguments.
-  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
-    for (BlockArgument arg : block->getArguments())
-      if (isRefCounted(arg.getType()))
-        if (failed(addAutomaticRefCounting(arg)))
-          return WalkResult::interrupt();
+  void runOnOperation() override;
 
-    return WalkResult::advance();
-  });
+private:
+  // Adds reference counting operations for all uses of `value` according
+  // to the reference counting policy; fails if any policy function fails.
+  LogicalResult addRefCounting(Value value);
 
-  if (blockWalk.wasInterrupted()) {
-    signalPassFailure();
-    return;
+  void initializeDefaultPolicy();
+  // Policy functions: ref count delta (>0 add_ref, <0 drop_ref) per use.
+  llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
+};
+
+} // namespace
+
+LogicalResult
+AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
+  // Short-circuit reference counting for values without uses.
+  if (succeeded(dropRefIfNoUses(value)))
+    return success();
+
+  // Snapshot the uses: each created add_ref/drop_ref adds a new use of `value`.
+  llvm::SmallVector<OpOperand *> uses;
+  for (OpOperand &use : value.getUses())
+    uses.push_back(&use);
+
+  // Consult the user defined policy for every original use of the value.
+  OpBuilder b(value.getContext());
+  for (OpOperand *operand : uses) {
+    Operation *owner = operand->getOwner();
+    Location loc = owner->getLoc();
+    for (auto &func : policy) {
+      FailureOr<int> refCount = func(*operand);
+      if (failed(refCount))
+        return failure();
+
+      // cnt > 0: add_ref before the owner; cnt < 0: drop_ref after the owner.
+      int cnt = refCount.getValue();
+      if (cnt > 0) {
+        b.setInsertionPoint(owner);
+        b.create<RuntimeAddRefOp>(loc, value, b.getI32IntegerAttr(cnt));
+      } else if (cnt < 0) {
+        b.setInsertionPointAfter(owner);
+        b.create<RuntimeDropRefOp>(loc, value, b.getI32IntegerAttr(-cnt));
+      }
+    }
+  }

-  // Add reference counting to operation results.
-  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
-    for (unsigned i = 0; i < op->getNumResults(); ++i)
-      if (isRefCounted(op->getResultTypes()[i]))
-        if (failed(addAutomaticRefCounting(op->getResult(i))))
-          return WalkResult::interrupt();
+  return success();
+}
 
-    return WalkResult::advance();
+void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
+  // Default policy: drop one reference after each use that the pass
+  // assumptions treat as the last use of a value; never add references.
+  policy.push_back([](OpOperand &operand) -> FailureOr<int> {
+    Operation *op = operand.getOwner();
+    Type type = operand.get().getType();
+
+    bool isToken = type.isa<TokenType>();
+    bool isGroup = type.isa<GroupType>();
+    bool isValue = type.isa<ValueType>();
+
+    // Drop reference after async token or group await (sync await) and after
+    // async token or group error check (coro await).
+    if (isa<RuntimeAwaitOp, RuntimeIsErrorOp>(op))
+      return (isToken || isGroup) ? -1 : 0;
+
+    // Drop reference after async value load.
+    if (isa<RuntimeLoadOp>(op))
+      return isValue ? -1 : 0;
+
+    // Drop reference after async token added to the group.
+    if (isa<RuntimeAddToGroupOp>(op))
+      return isToken ? -1 : 0;
+
+    // Any other use leaves the reference count unchanged.
+    return 0;
   });
+}
 
-  if (opWalk.wasInterrupted())
+void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
+  auto addPolicyRefCounting = [this](Value v) { return addRefCounting(v); };
+  if (failed(walkReferenceCountedValues(getOperation(), addPolicyRefCounting)))
     signalPassFailure();
 }
 
+//----------------------------------------------------------------------------//
+
 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
   return std::make_unique<AsyncRuntimeRefCountingPass>();
 }
+
+std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
+  return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
+}

diff  --git a/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
new file mode 100644
index 0000000000000..54640f552798d
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-runtime-policy-based-ref-counting.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -async-runtime-policy-based-ref-counting | FileCheck %s
+
+// CHECK-LABEL: @token_await
+// CHECK:         %[[TOKEN:.*]]: !async.token
+func @token_await(%arg0: !async.token) {
+  // CHECK: async.runtime.await %[[TOKEN]]
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.runtime.await %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @group_await
+// CHECK:         %[[GROUP:.*]]: !async.group
+func @group_await(%arg0: !async.group) {
+  // CHECK: async.runtime.await %[[GROUP]]
+  // CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32}
+  async.runtime.await %arg0 : !async.group
+  return
+}
+
+// CHECK-LABEL: @add_token_to_group
+// CHECK:         %[[GROUP:.*]]: !async.group
+// CHECK:         %[[TOKEN:.*]]: !async.token
+func @add_token_to_group(%arg0: !async.group, %arg1: !async.token) {
+  // CHECK: async.runtime.add_to_group %[[TOKEN]], %[[GROUP]]
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  async.runtime.add_to_group %arg1, %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @value_load
+// CHECK:         %[[VALUE:.*]]: !async.value<f32>
+func @value_load(%arg0: !async.value<f32>) {
+  // CHECK: async.runtime.load %[[VALUE]]
+  // CHECK: async.runtime.drop_ref %[[VALUE]] {count = 1 : i32}
+  %0 = async.runtime.load %arg0 : !async.value<f32>
+  return
+}
+
+// CHECK-LABEL: @error_check
+// CHECK:         %[[TOKEN:.*]]: !async.token
+func @error_check(%arg0: !async.token) {
+  // CHECK: async.runtime.is_error %[[TOKEN]]
+  // CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
+  %0 = async.runtime.is_error %arg0 : !async.token
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index 12b2be2627131..6c2758c484f79 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -11,6 +11,18 @@
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
 // RUN: | FileCheck %s --dump-input=always
 
+// RUN:   mlir-opt %s -async-parallel-for                                      \
+// RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-policy-based-ref-counting                 \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:  -e entry -entry-point-result=void -O0                                 \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
 // RUN:   mlir-opt %s -async-parallel-for="async-dispatch=false                \
 // RUN:                                    num-workers=20                      \
 // RUN:                                    target-block-size=1"                \

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index b294b9ce4d26e..d8f99d061b7d4 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -11,6 +11,18 @@
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
 // RUN: | FileCheck %s --dump-input=always
 
+// RUN:   mlir-opt %s -async-parallel-for                                      \
+// RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-policy-based-ref-counting                 \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:  -e entry -entry-point-result=void -O0                                 \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
 // RUN:   mlir-opt %s -async-parallel-for="async-dispatch=false                \
 // RUN:                                    num-workers=20                      \
 // RUN:                                    target-block-size=1"                \


        


More information about the Mlir-commits mailing list