[Mlir-commits] [mlir] [mlir][transform] Improve error message of tracking listener. (PR #66987)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 21 00:27:21 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This PR extends the error message that the tracking listener emits when no replacement op can be found. That can happen if the applied patterns replace an op with an op of a different kind or with block arguments. However, this only matters if there are alive handles to the replaced op. The new error message states this explicitly and attaches a note at each transform op that still uses one of those alive handles.

---
Full diff: https://github.com/llvm/llvm-project/pull/66987.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+8-5) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+14-11) 
- (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+2-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 86af59142b77d9c..e1169f45d17f699 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -774,7 +774,8 @@ class TransformResults {
   /// corresponds to the given list of payload IR ops. Each result must be set
   /// by the transformation exactly once in case of transformation succeeding.
   /// The value must have a type implementing TransformHandleTypeInterface.
-  template <typename Range> void set(OpResult value, Range &&ops) {
+  template <typename Range>
+  void set(OpResult value, Range &&ops) {
     int64_t position = value.getResultNumber();
     assert(position < static_cast<int64_t>(operations.size()) &&
            "setting results for a non-existent handle");
@@ -942,8 +943,9 @@ class TrackingListener : public RewriterBase::Listener,
   /// This function is called when a tracked payload op is dropped because no
   /// replacement op was found. Derived classes can implement this function for
   /// custom error handling.
-  virtual void notifyPayloadReplacementNotFound(Operation *op,
-                                                ValueRange values) {}
+  virtual void
+  notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
+                                   ArrayRef<Operation *> aliveUsers) {}
 
   /// Return the single op that defines all given values (if any).
   static Operation *getCommonDefiningOp(ValueRange values);
@@ -983,8 +985,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {
   bool failed() const;
 
 protected:
-  void notifyPayloadReplacementNotFound(Operation *op,
-                                        ValueRange values) override;
+  void
+  notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
+                                   ArrayRef<Operation *> aliveUsers) override;
 
 private:
   /// The error state of this listener. "Success" indicates that no error
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9cac178d3c2b869..ab26ec66a9ac4e3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1396,14 +1396,13 @@ void transform::TrackingListener::notifyOperationReplaced(
   };
 
   // Helper function to check if the handle is alive.
-  auto hasAliveUser = [&]() {
-    for (Value v : opHandles) {
-      for (Operation *user : v.getUsers())
-        if (user != transformOp && !happensBefore(user, transformOp))
-          return true;
-    }
-    return false;
-  };
+  SmallVector<Operation *> aliveUsers;
+  for (Value v : opHandles) {
+    for (Operation *user : v.getUsers())
+      if (user != transformOp && !happensBefore(user, transformOp))
+        aliveUsers.push_back(user);
+  }
+  auto hasAliveUser = [&]() { return !aliveUsers.empty(); };
 
   if (!hasAliveUser() || handleWasConsumed()) {
     // The op is tracked but the corresponding handles are dead or were
@@ -1416,7 +1415,7 @@ void transform::TrackingListener::notifyOperationReplaced(
   // If the op is tracked but no replacement op was found, send a
   // notification.
   if (failed(replacement)) {
-    notifyPayloadReplacementNotFound(op, newValues);
+    notifyPayloadReplacementNotFound(op, newValues, aliveUsers);
     (void)replacePayloadOp(op, nullptr);
     return;
   }
@@ -1444,16 +1443,20 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
 }
 
 void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
-    Operation *op, ValueRange values) {
+    Operation *op, ValueRange values, ArrayRef<Operation *> aliveUsers) {
   if (status.succeeded()) {
     status = emitSilenceableFailure(
-        getTransformOp(), "tracking listener failed to find replacement op");
+        getTransformOp(), "op was replaced but replacement was of different "
+                          "kind, invalidating alive handles");
   }
 
   status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
   for (auto &&[index, value] : llvm::enumerate(values))
     status.attachNote(value.getLoc())
         << "[" << errorCounter << "] replacement value " << index;
+  for (auto &&[index, user] : llvm::enumerate(aliveUsers))
+    status.attachNote(user->getLoc())
+        << "[" << errorCounter << "] alive handle " << index;
 
   ++errorCounter;
 }
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index efbdab78d397faa..f6a0801204b80d7 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -37,12 +37,13 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{tracking listener failed to find replacement op}}
+  // expected-error @below {{op was replaced but replacement was of different kind, invalidating alive handles}}
   transform.apply_patterns to %0 {
     transform.apply_patterns.transform.test_patterns
   } : !transform.any_op
   // %1 must be used in some way. If no replacement payload op could be found,
   // an error is thrown only if the handle is not dead.
+  // expected-note @below {{[0] alive handle 0}}
   transform.annotate %1 "annotated" : !transform.any_op
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/66987


More information about the Mlir-commits mailing list