[Mlir-commits] [mlir] [mlir][transform] Do not maintain mappings for dead handles (PR #73558)

Matthias Springer llvmlistbot at llvm.org
Mon Nov 27 11:15:45 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/73558

Do not maintain transform IR <-> payload IR mappings for dead handles, i.e., handles that do not have any further uses.

This change reduces the memory overhead of the transform dialect interpreter.

>From db7720612710cc066eadf2227ccdac9ac8828475 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 27 Nov 2023 20:10:52 +0100
Subject: [PATCH] [mlir][transform] Do not maintain mappings for dead handles

Do not maintain transform IR <-> payload IR mappings for dead handles,
i.e., handles that do not have any further uses.

This change reduces the memory overhead of the transform dialect
interpreter.
---
 .../Transform/IR/TransformInterfaces.cpp      | 97 ++++++++++++-------
 .../Transform/transform-state-extension.mlir  |  3 +
 2 files changed, 66 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index d0cd879d560c887..ab7280d7e2b27b6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -30,6 +30,33 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b) {
+  do {
+    if (a->isProperAncestor(b))
+      return false;
+    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+      return a->isBeforeInBlock(bAncestor);
+    }
+  } while ((a = a->getParentOp()));
+  return false;
+}
+
+/// Return nullptr if `v` is dead (has no further uses) after `op`. Otherwise,
+/// return an arbitrary alive use. This return value is typically used in
+/// error messages or for debugging purposes.
+static OpOperand *getAliveUse(Value v, Operation *op) {
+  for (OpOperand &use : v.getUses())
+    if (use.getOwner() != op && !happensBefore(use.getOwner(), op))
+      return &use;
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // TransformState
 //===----------------------------------------------------------------------===//
@@ -216,6 +243,10 @@ transform::TransformState::setPayloadOps(Value value,
   if (failed(result.checkAndReport()))
     return failure();
 
+  // Do not maintain mappings for dead handles.
+  if (value.getUses().empty())
+    return success();
+
   // Setting new payload for the value without cleaning it first is a misuse of
   // the API, assert here.
   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
@@ -252,6 +283,10 @@ transform::TransformState::setPayloadValues(Value handle,
   if (failed(result.checkAndReport()))
     return failure();
 
+  // Do not maintain mappings for dead handles.
+  if (handle.getUses().empty())
+    return success();
+
   Mappings &mappings = getMapping(handle);
   bool inserted =
       mappings.values.insert({handle, std::move(payloadValueVector)}).second;
@@ -285,6 +320,10 @@ LogicalResult transform::TransformState::setParams(Value value,
   if (failed(result.checkAndReport()))
     return failure();
 
+  // Do not maintain mappings for dead handles.
+  if (value.getUses().empty())
+    return success();
+
   Mappings &mappings = getMapping(value);
   bool inserted =
       mappings.params.insert({value, llvm::to_vector(params)}).second;
@@ -494,10 +533,10 @@ void transform::TransformState::recordOpHandleInvalidationOne(
   unsigned operandNo = consumingHandle.getOperandNumber();
   for (Operation *ancestor : potentialAncestors) {
     // clang-format off
-    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 
+    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
       { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
-    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 
-      { (DBGS() << "----of payload with name: " 
+    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
+      { (DBGS() << "----of payload with name: "
                 << payloadOp->getName().getIdentifier() << "\n"); });
     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
       { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
@@ -994,13 +1033,18 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
 
   // Remove the mapping for the operand if it is consumed by the operation. This
   // allows us to catch use-after-free with assertions later on.
-  for (OpOperand *opOperand : consumedOperands) {
-    Value operand = opOperand->get();
+  for (OpOperand &opOperand : transform->getOpOperands()) {
+    Value operand = opOperand.get();
+    if (getAliveUse(operand, transform) != nullptr)
+      continue;
+    bool wasConsumed = llvm::is_contained(consumedOperands, &opOperand);
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
-      forgetMapping(operand, origOpFlatResults);
+      forgetMapping(operand,
+                    wasConsumed ? ValueRange(origOpFlatResults) : ValueRange());
     } else if (llvm::isa<TransformValueHandleTypeInterface>(
                    operand.getType())) {
-      forgetValueMapping(operand, origAssociatedOps);
+      forgetValueMapping(operand, wasConsumed ? ArrayRef(origAssociatedOps)
+                                              : ArrayRef<Operation *>());
     }
   }
 
@@ -1369,19 +1413,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
   });
 }
 
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
-  do {
-    if (a->isProperAncestor(b))
-      return false;
-    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
-      return a->isBeforeInBlock(bAncestor);
-    }
-  } while ((a = a->getParentOp()));
-  return false;
-}
-
 void transform::TrackingListener::notifyOperationReplaced(
     Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
@@ -1413,20 +1444,18 @@ void transform::TrackingListener::notifyOperationReplaced(
                         [&](Value h) { return consumedHandles.contains(h); });
   };
 
-  // Helper function to check if the handle is alive.
-  auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
-    for (Value v : opHandles) {
-      for (OpOperand &use : v.getUses())
-        if (use.getOwner() != transformOp &&
-            !happensBefore(use.getOwner(), transformOp))
-          return &use;
+  // Check if there are any live handles.
+  OpOperand *aliveUse = nullptr;
+  for (Value v : opHandles) {
+    if (OpOperand *use = getAliveUse(v, transformOp)) {
+      aliveUse = use;
+      break;
     }
-    return std::nullopt;
-  }();
+  }
 
-  if (!firstAliveUser.has_value() || handleWasConsumed()) {
-    // The op is tracked but the corresponding handles are dead or were
-    // consumed. Drop the op form the mapping.
+  if (!aliveUse || handleWasConsumed()) {
+    // The op is tracked but the corresponding handles are dead or were
+    // consumed. Drop the op from the mapping.
     (void)replacePayloadOp(op, nullptr);
     return;
   }
@@ -1437,10 +1466,10 @@ void transform::TrackingListener::notifyOperationReplaced(
   // If the op is tracked but no replacement op was found, send a
   // notification.
   if (!diag.succeeded()) {
-    diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
+    diag.attachNote(aliveUse->getOwner()->getLoc())
         << "replacement is required because alive handle(s) exist "
         << "(first use in this op as operand number "
-        << (*firstAliveUser)->getOperandNumber() << ")";
+        << aliveUse->getOperandNumber() << ")";
     notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
     (void)replacePayloadOp(op, nullptr);
     return;
diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
index a26293fbe51ca61..cd115027d0f0002 100644
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -76,6 +76,9 @@ transform.sequence failures(propagate) {
   %dummy = test_remap_operand_to_self %arg0 : (!transform.any_op) -> !transform.any_op
   %valuehandle = transform.get_result %dummy[0] : (!transform.any_op) -> !transform.any_value
   test_remap_operand_to_self %dummy : (!transform.any_op) -> ()
+  // Use %valuehandle so that the SSA value is not dead. This prevents the
+  // transform dialect interpreter from discarding the handle.
+  test_print_number_of_associated_payload_ir_values %valuehandle : !transform.any_value
 }
 
 // -----



More information about the Mlir-commits mailing list