[Mlir-commits] [mlir] [mlir][transform] TrackingListener: Improve dead handles detection (PR #74290)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 4 00:50:37 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The tracking listener should not report op replacement errors for payload ops that are not mapped to any live handles. However, the handle liveness analysis did not work properly with transform IR that contains named sequences.

A handle is live if it has a user after the transform op that is currently being applied. With named sequences, we need to maintain a stack of currently applied transform ops. That stack already exists (`regionStack`), but only in builds with ABI-breaking checks enabled. This commit makes it unconditional and adds the one missing piece: the current transform op for each stack frame.

This commit fixes #72931.

---
Full diff: https://github.com/llvm/llvm-project/pull/74290.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+19-11) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+61-43) 
- (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+29-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 2fdc15db9ad85..35de8a2e1fa5f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -310,10 +310,8 @@ class TransformState {
   /// with the type of the handle value.
   LogicalResult mapBlockArguments(BlockArgument argument,
                                   ArrayRef<Operation *> operations) {
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-    assert(argument.getParentRegion() == regionStack.back() &&
+    assert(argument.getParentRegion() == regionStack.back()->region &&
            "mapping block arguments from a region other than the active one");
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     return setPayloadOps(argument, operations);
   }
   LogicalResult mapBlockArgument(BlockArgument argument,
@@ -350,9 +348,7 @@ class TransformState {
           std::make_pair(&region, std::make_unique<Mappings>()));
       assert(res.second && "the region scope is already present");
       (void)res;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      state.regionStack.push_back(&region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+      state.regionStack.push_back(this);
     }
 
     /// Back-reference to the transform state.
@@ -361,7 +357,9 @@ class TransformState {
     /// The region this scope is associated with.
     Region *region;
 
-    friend RegionScope TransformState::make_region_scope(Region &);
+    TransformOpInterface currentTransform;
+
+    friend class transform::TransformState;
   };
   friend class RegionScope;
 
@@ -784,12 +782,12 @@ class TransformState {
   /// location.
   InvalidatedHandleMap invalidatedHandles;
 
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// A stack of nested regions that are being processed in the transform IR.
   /// Each region must be an ancestor of the following regions in this list.
   /// These are also the keys for "mappings".
-  SmallVector<Region *> regionStack;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  SmallVector<RegionScope *> regionStack;
+
+  std::unique_ptr<RegionScope> topLevelRegionScope;
 };
 
 /// Local mapping between values defined by a specific op implementing the
@@ -926,8 +924,14 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
 class TrackingListener : public RewriterBase::Listener,
                          public TransformState::Extension {
 public:
+  /// A function that returns "true" for handles that do not have to be updated.
+  using SkipHandleFn = std::function<bool(Value)>;
+
   /// Create a new TrackingListener for usage in the specified transform op.
-  TrackingListener(TransformState &state, TransformOpInterface op);
+  /// Optionally, a function can be specified to identify handles that do not
+  /// have to be updated.
+  TrackingListener(TransformState &state, TransformOpInterface op,
+                   SkipHandleFn skipHandleFn = nullptr);
 
 protected:
   /// Return a replacement payload op for the given op, which is going to be
@@ -1015,6 +1019,10 @@ class TrackingListener : public RewriterBase::Listener,
 
   /// The handles that are consumed by the transform op.
   DenseSet<Value> consumedHandles;
+
+  /// Handles for which this function evaluates to "true" do not have to be
+  /// updated. These are typically dead or consumed handles.
+  SkipHandleFn skipHandleFn;
 };
 
 /// A specialized listener that keeps track of cases in which no replacement
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index de5b7a81286bc..cd66a0e566f6c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -30,6 +30,23 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b) {
+  do {
+    if (a->isProperAncestor(b))
+      return false;
+    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+      return a->isBeforeInBlock(bAncestor);
+    }
+  } while ((a = a->getParentOp()));
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TransformState
 //===----------------------------------------------------------------------===//
@@ -44,14 +61,10 @@ transform::TransformState::TransformState(
   topLevelMappedValues.reserve(extraMappings.size());
   for (ArrayRef<MappedValue> mapping : extraMappings)
     topLevelMappedValues.push_back(mapping);
-
-  auto result =
-      mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
-  assert(result.second && "the region scope is already present");
-  (void)result;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  regionStack.push_back(region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  if (region) {
+    RegionScope *scope = new RegionScope(*this, *region);
+    topLevelRegionScope.reset(scope);
+  }
 }
 
 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
@@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
         llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
   });
+
+  // Set current transform op.
+  regionStack.back()->currentTransform = transform;
+
+  // Expensive checks to detect invalid transform IR.
   if (options.getExpensiveChecksEnabled()) {
     FULL_LDBG("ExpensiveChecksEnabled\n");
     if (failed(checkAndRecordHandleInvalidation(transform)))
@@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 
   // Prepare rewriter and listener.
-  transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
+  TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
+    // Skip handle if it is dead.
+    auto scopeIt =
+        llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
+          return handle.getParentRegion() == scope->region;
+        });
+    assert(scopeIt != regionStack.rend() &&
+           "could not find region scope for handle");
+    RegionScope *scope = *scopeIt;
+    for (Operation *user : handle.getUsers()) {
+      if (user != scope->currentTransform &&
+          !happensBefore(user, scope->currentTransform))
+        return false;
+    }
+    return true;
+  };
+  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
+                                                            skipHandleFn);
   transform::TransformRewriter rewriter(transform->getContext(),
                                         &trackingListener);
 
@@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   state.mappings.erase(region);
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   state.regionStack.pop_back();
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 }
 
 //===----------------------------------------------------------------------===//
@@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
 //===----------------------------------------------------------------------===//
 
 transform::TrackingListener::TrackingListener(TransformState &state,
-                                              TransformOpInterface op)
-    : TransformState::Extension(state), transformOp(op) {
+                                              TransformOpInterface op,
+                                              SkipHandleFn skipHandleFn)
+    : TransformState::Extension(state), transformOp(op),
+      skipHandleFn(skipHandleFn) {
   if (op) {
     for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
       consumedHandles.insert(opOperand->get());
@@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
   });
 }
 
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
-  do {
-    if (a->isProperAncestor(b))
-      return false;
-    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
-      return a->isBeforeInBlock(bAncestor);
-    }
-  } while ((a = a->getParentOp()));
-  return false;
-}
-
 void transform::TrackingListener::notifyOperationReplaced(
     Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
@@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
                         [&](Value h) { return consumedHandles.contains(h); });
   };
 
-  // Helper function to check if the handle is alive.
-  auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
-    for (Value v : opHandles) {
-      for (OpOperand &use : v.getUses())
-        if (use.getOwner() != transformOp &&
-            !happensBefore(use.getOwner(), transformOp))
-          return &use;
-    }
-    return std::nullopt;
-  }();
-
-  if (!firstAliveUser.has_value() || handleWasConsumed()) {
+  // Check if there are any handles that must be updated.
+  Value aliveHandle;
+  if (skipHandleFn) {
+    auto it =
+        llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
+    if (it != opHandles.end())
+      aliveHandle = *it;
+  } else if (!opHandles.empty()) {
+    aliveHandle = opHandles.front();
+  }
+  if (!aliveHandle || handleWasConsumed()) {
     // The op is tracked but the corresponding handles are dead or were
     // consumed. Drop the op form the mapping.
     (void)replacePayloadOp(op, nullptr);
@@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
   // If the op is tracked but no replacement op was found, send a
   // notification.
   if (!diag.succeeded()) {
-    diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
-        << "replacement is required because alive handle(s) exist "
-        << "(first use in this op as operand number "
-        << (*firstAliveUser)->getOperandNumber() << ")";
+    diag.attachNote(aliveHandle.getLoc())
+        << "replacement is required because this handle must be updated";
     notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
     (void)replacePayloadOp(op, nullptr);
     return;
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 2d57d4aa2547f..2fd47c6bae396 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -36,6 +36,7 @@ func.func @replacement_op_not_found() {
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{replacement is required because this handle must be updated}}
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
   // expected-note @below {{ran out of suitable replacement values}}
@@ -44,7 +45,6 @@ transform.sequence failures(propagate) {
   } : !transform.any_op
   // %1 must be used in some way. If no replacement payload op could be found,
   // an error is thrown only if the handle is not dead.
-  // expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
   transform.annotate %1 "annotated" : !transform.any_op
 }
 
@@ -363,3 +363,31 @@ transform.sequence failures(propagate) {
      legal_ops = ["func.func", "func.return", "test.new_op"]}
       : !transform.any_op
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+func.func @replacement_op_not_found() {
+  // No op replacement can be found, but there are no handles that must be
+  // updated. No error should be reported.
+  "test.container"() ({
+    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.named_sequence @patterns(%container: !transform.any_op {transform.readonly}) {
+  transform.apply_patterns to %container {
+    transform.apply_patterns.transform.test_patterns
+  } : !transform.any_op
+  transform.yield
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.annotate %1 "annotated" : !transform.any_op
+  transform.include @patterns failures(propagate) (%0) : (!transform.any_op) -> ()
+}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/74290


More information about the Mlir-commits mailing list