[Mlir-commits] [mlir] [mlir][transform] Do not maintain mappings for dead handles (PR #73558)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 27 11:16:13 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Do not maintain transform IR <-> payload IR mappings for dead handles, i.e., handles that do not have any further uses.
This change reduces the memory overhead of the transform dialect interpreter.
---
Full diff: https://github.com/llvm/llvm-project/pull/73558.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+63-34)
- (modified) mlir/test/Dialect/Transform/transform-state-extension.mlir (+3)
``````````diff
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index d0cd879d560c887..ab7280d7e2b27b6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -30,6 +30,33 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b) {
+ do {
+ if (a->isProperAncestor(b))
+ return false;
+ if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+ return a->isBeforeInBlock(bAncestor);
+ }
+ } while ((a = a->getParentOp()));
+ return false;
+}
+
+/// Return nullptr if `v` is dead (has no further uses) after `op`. Otherwise,
+/// return an arbitrary alive use. This return value is typically used in
+/// error messages or for debugging purposes.
+static OpOperand *getAliveUse(Value v, Operation *op) {
+ for (OpOperand &use : v.getUses())
+ if (use.getOwner() != op && !happensBefore(use.getOwner(), op))
+ return &use;
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
@@ -216,6 +243,10 @@ transform::TransformState::setPayloadOps(Value value,
if (failed(result.checkAndReport()))
return failure();
+ // Do not maintain mappings for dead handles.
+ if (value.getUses().empty())
+ return success();
+
// Setting new payload for the value without cleaning it first is a misuse of
// the API, assert here.
SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
@@ -252,6 +283,10 @@ transform::TransformState::setPayloadValues(Value handle,
if (failed(result.checkAndReport()))
return failure();
+ // Do not maintain mappings for dead handles.
+ if (handle.getUses().empty())
+ return success();
+
Mappings &mappings = getMapping(handle);
bool inserted =
mappings.values.insert({handle, std::move(payloadValueVector)}).second;
@@ -285,6 +320,10 @@ LogicalResult transform::TransformState::setParams(Value value,
if (failed(result.checkAndReport()))
return failure();
+ // Do not maintain mappings for dead handles.
+ if (value.getUses().empty())
+ return success();
+
Mappings &mappings = getMapping(value);
bool inserted =
mappings.params.insert({value, llvm::to_vector(params)}).second;
@@ -494,10 +533,10 @@ void transform::TransformState::recordOpHandleInvalidationOne(
unsigned operandNo = consumingHandle.getOperandNumber();
for (Operation *ancestor : potentialAncestors) {
// clang-format off
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
+ DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
{ (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
- { (DBGS() << "----of payload with name: "
+ DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
+ { (DBGS() << "----of payload with name: "
<< payloadOp->getName().getIdentifier() << "\n"); });
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
{ (DBGS() << "----of payload: " << *payloadOp << "\n"); });
@@ -994,13 +1033,18 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// Remove the mapping for the operand if it is consumed by the operation. This
// allows us to catch use-after-free with assertions later on.
- for (OpOperand *opOperand : consumedOperands) {
- Value operand = opOperand->get();
+ for (OpOperand &opOperand : transform->getOpOperands()) {
+ Value operand = opOperand.get();
+ if (getAliveUse(operand, transform) != nullptr)
+ continue;
+ bool wasConsumed = llvm::is_contained(consumedOperands, &opOperand);
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
- forgetMapping(operand, origOpFlatResults);
+ forgetMapping(operand,
+ wasConsumed ? ValueRange(origOpFlatResults) : ValueRange());
} else if (llvm::isa<TransformValueHandleTypeInterface>(
operand.getType())) {
- forgetValueMapping(operand, origAssociatedOps);
+ forgetValueMapping(operand, wasConsumed ? ArrayRef(origAssociatedOps)
+ : ArrayRef<Operation *>());
}
}
@@ -1369,19 +1413,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
});
}
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
- do {
- if (a->isProperAncestor(b))
- return false;
- if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
- return a->isBeforeInBlock(bAncestor);
- }
- } while ((a = a->getParentOp()));
- return false;
-}
-
void transform::TrackingListener::notifyOperationReplaced(
Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
@@ -1413,20 +1444,18 @@ void transform::TrackingListener::notifyOperationReplaced(
[&](Value h) { return consumedHandles.contains(h); });
};
- // Helper function to check if the handle is alive.
- auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
- for (Value v : opHandles) {
- for (OpOperand &use : v.getUses())
- if (use.getOwner() != transformOp &&
- !happensBefore(use.getOwner(), transformOp))
- return &use;
+ // Check if there are any live handles.
+ OpOperand *aliveUse = nullptr;
+ for (Value v : opHandles) {
+ if (OpOperand *use = getAliveUse(v, transformOp)) {
+ aliveUse = use;
+ break;
}
- return std::nullopt;
- }();
+ }
- if (!firstAliveUser.has_value() || handleWasConsumed()) {
- // The op is tracked but the corresponding handles are dead or were
- // consumed. Drop the op form the mapping.
+ if (!aliveUse || handleWasConsumed()) {
+ // The op is tracked but the corresponding handles are dead or were
+ // consumed. Drop the op from the mapping.
(void)replacePayloadOp(op, nullptr);
return;
}
@@ -1437,10 +1466,10 @@ void transform::TrackingListener::notifyOperationReplaced(
// If the op is tracked but no replacement op was found, send a
// notification.
if (!diag.succeeded()) {
- diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
+ diag.attachNote(aliveUse->getOwner()->getLoc())
<< "replacement is required because alive handle(s) exist "
<< "(first use in this op as operand number "
- << (*firstAliveUser)->getOperandNumber() << ")";
+ << aliveUse->getOperandNumber() << ")";
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
(void)replacePayloadOp(op, nullptr);
return;
diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
index a26293fbe51ca61..cd115027d0f0002 100644
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -76,6 +76,9 @@ transform.sequence failures(propagate) {
%dummy = test_remap_operand_to_self %arg0 : (!transform.any_op) -> !transform.any_op
%valuehandle = transform.get_result %dummy[0] : (!transform.any_op) -> !transform.any_value
test_remap_operand_to_self %dummy : (!transform.any_op) -> ()
+ // Use %valuehandle so that the SSA value is not dead. This prevents the
+ // transform dialect interpreter from discarding the handle.
+ test_print_number_of_associated_payload_ir_values %valuehandle : !transform.any_value
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/73558
More information about the Mlir-commits
mailing list