[Mlir-commits] [mlir] [mlir][transform] Improve error message of tracking listener. (PR #66987)

Ingo Müller llvmlistbot at llvm.org
Thu Sep 21 07:25:46 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/66987

>From aa5c6e5c78ece6548deb7f56f69baa6b4cb7ef8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 21 Sep 2023 07:22:50 +0000
Subject: [PATCH 1/3] [mlir][transform] Improve error message of tracking
 listener.

This PR extends the error message of the tracking listener when
replacement ops cannot be found. That may happen if the applied patterns
replace an op by an op of a different kind or by block arguments.
However, this only matters if there are alive handles to the replaced
op. The new error message mentions that explicitly and reports the alive
handles.
---
 .../Transform/IR/TransformInterfaces.h        | 13 ++++++----
 .../Transform/IR/TransformInterfaces.cpp      | 25 +++++++++++--------
 .../Transform/test-pattern-application.mlir   |  3 ++-
 3 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 86af59142b77d9c..e1169f45d17f699 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -774,7 +774,8 @@ class TransformResults {
   /// corresponds to the given list of payload IR ops. Each result must be set
   /// by the transformation exactly once in case of transformation succeeding.
   /// The value must have a type implementing TransformHandleTypeInterface.
-  template <typename Range> void set(OpResult value, Range &&ops) {
+  template <typename Range>
+  void set(OpResult value, Range &&ops) {
     int64_t position = value.getResultNumber();
     assert(position < static_cast<int64_t>(operations.size()) &&
            "setting results for a non-existent handle");
@@ -942,8 +943,9 @@ class TrackingListener : public RewriterBase::Listener,
   /// This function is called when a tracked payload op is dropped because no
   /// replacement op was found. Derived classes can implement this function for
   /// custom error handling.
-  virtual void notifyPayloadReplacementNotFound(Operation *op,
-                                                ValueRange values) {}
+  virtual void
+  notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
+                                   ArrayRef<Operation *> aliveUsers) {}
 
   /// Return the single op that defines all given values (if any).
   static Operation *getCommonDefiningOp(ValueRange values);
@@ -983,8 +985,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {
   bool failed() const;
 
 protected:
-  void notifyPayloadReplacementNotFound(Operation *op,
-                                        ValueRange values) override;
+  void
+  notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
+                                   ArrayRef<Operation *> aliveUsers) override;
 
 private:
   /// The error state of this listener. "Success" indicates that no error
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9cac178d3c2b869..ab26ec66a9ac4e3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1396,14 +1396,13 @@ void transform::TrackingListener::notifyOperationReplaced(
   };
 
   // Helper function to check if the handle is alive.
-  auto hasAliveUser = [&]() {
-    for (Value v : opHandles) {
-      for (Operation *user : v.getUsers())
-        if (user != transformOp && !happensBefore(user, transformOp))
-          return true;
-    }
-    return false;
-  };
+  SmallVector<Operation *> aliveUsers;
+  for (Value v : opHandles) {
+    for (Operation *user : v.getUsers())
+      if (user != transformOp && !happensBefore(user, transformOp))
+        aliveUsers.push_back(user);
+  }
+  auto hasAliveUser = [&]() { return !aliveUsers.empty(); };
 
   if (!hasAliveUser() || handleWasConsumed()) {
     // The op is tracked but the corresponding handles are dead or were
@@ -1416,7 +1415,7 @@ void transform::TrackingListener::notifyOperationReplaced(
   // If the op is tracked but no replacement op was found, send a
   // notification.
   if (failed(replacement)) {
-    notifyPayloadReplacementNotFound(op, newValues);
+    notifyPayloadReplacementNotFound(op, newValues, aliveUsers);
     (void)replacePayloadOp(op, nullptr);
     return;
   }
@@ -1444,16 +1443,20 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
 }
 
 void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
-    Operation *op, ValueRange values) {
+    Operation *op, ValueRange values, ArrayRef<Operation *> aliveUsers) {
   if (status.succeeded()) {
     status = emitSilenceableFailure(
-        getTransformOp(), "tracking listener failed to find replacement op");
+        getTransformOp(), "op was replaced but replacement was of different "
+                          "kind, invalidating alive handles");
   }
 
   status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
   for (auto &&[index, value] : llvm::enumerate(values))
     status.attachNote(value.getLoc())
         << "[" << errorCounter << "] replacement value " << index;
+  for (auto &&[index, user] : llvm::enumerate(aliveUsers))
+    status.attachNote(user->getLoc())
+        << "[" << errorCounter << "] alive handle " << index;
 
   ++errorCounter;
 }
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index efbdab78d397faa..f6a0801204b80d7 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -37,12 +37,13 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{tracking listener failed to find replacement op}}
+  // expected-error @below {{op was replaced but replacement was of different kind, invalidating alive handles}}
   transform.apply_patterns to %0 {
     transform.apply_patterns.transform.test_patterns
   } : !transform.any_op
   // %1 must be used in some way. If no replacement payload op could be found,
   // an error is thrown only if the handle is not dead.
+  // expected-note @below {{[0] alive handle 0}}
   transform.annotate %1 "annotated" : !transform.any_op
 }
 

>From 9ed343aa9cfe04c0e4118b7846d30a93c6feb88c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 21 Sep 2023 12:49:53 +0000
Subject: [PATCH 2/3] Address @ftynse's and @matthias-springer's comments.

In particular:

* Change `TrackingListener::findReplacementOp` such that it returns a
  `DiagnosedSilenceableFailure` rather than just a `FailureOr`.
* Have `TrackingListener::notifyPayloadReplacementNotFound` accept a
  `DiagnosedSilenceableFailure` with previous diagnostics.
* Create initial diagnostics in `findReplacementOp` and give information
  about why finding the replacement failed.
* Forward that result into `notifyPayloadReplacementNotFound` and
  adapt how the `ErrorCheckingTrackingListener` deals with the
  pre-existing diagnostics.
* Only report the first alive user.
* Adapt error messages according to the reviewers' suggestions.
---
 .../Transform/IR/TransformInterfaces.h        |  9 +-
 .../Transform/IR/TransformInterfaces.cpp      | 85 ++++++++++++-------
 .../Transform/test-pattern-application.mlir   |  5 +-
 .../Dialect/Tensor/TestTensorTransforms.cpp   |  6 +-
 4 files changed, 64 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index e1169f45d17f699..2c1775b3b462cf8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -930,8 +930,9 @@ class TrackingListener : public RewriterBase::Listener,
   ///
   /// Derived classes may override `findReplacementOp` to specify custom
   /// replacement rules.
-  virtual FailureOr<Operation *> findReplacementOp(Operation *op,
-                                                   ValueRange newValues) const;
+  virtual DiagnosedSilenceableFailure
+  findReplacementOp(Operation *&result, Operation *op,
+                    ValueRange newValues) const;
 
   /// Notify the listener that the pattern failed to match the given operation,
   /// and provide a callback to populate a diagnostic with the reason why the
@@ -945,7 +946,7 @@ class TrackingListener : public RewriterBase::Listener,
   /// custom error handling.
   virtual void
   notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
-                                   ArrayRef<Operation *> aliveUsers) {}
+                                   DiagnosedSilenceableFailure &&diag) {}
 
   /// Return the single op that defines all given values (if any).
   static Operation *getCommonDefiningOp(ValueRange values);
@@ -987,7 +988,7 @@ class ErrorCheckingTrackingListener : public TrackingListener {
 protected:
   void
   notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
-                                   ArrayRef<Operation *> aliveUsers) override;
+                                   DiagnosedSilenceableFailure &&diag) override;
 
 private:
   /// The error state of this listener. "Success" indicates that no error
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ab26ec66a9ac4e3..099408399a996fd 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1289,27 +1289,36 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
   return defOp;
 }
 
-FailureOr<Operation *>
-transform::TrackingListener::findReplacementOp(Operation *op,
-                                               ValueRange newValues) const {
+DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
+    Operation *&result, Operation *op, ValueRange newValues) const {
   assert(op->getNumResults() == newValues.size() &&
          "invalid number of replacement values");
   SmallVector<Value> values(newValues.begin(), newValues.end());
 
+  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
+      getTransformOp(), "tracking listener failed to find replacement op "
+                        "during application of this transform op");
+
   do {
     // If the replacement values belong to different ops, drop the mapping.
     Operation *defOp = getCommonDefiningOp(values);
-    if (!defOp)
-      return failure();
+    if (!defOp) {
+      diag.attachNote() << "replacement values belong to different ops";
+      return diag;
+    }
 
     // If the defining op has the same type, we take it as a replacement.
-    if (op->getName() == defOp->getName())
-      return defOp;
+    if (op->getName() == defOp->getName()) {
+      result = defOp;
+      return DiagnosedSilenceableFailure::success();
+    }
 
     // Replacing an op with a constant-like equivalent is a common
     // canonicalization.
-    if (defOp->hasTrait<OpTrait::ConstantLike>())
-      return defOp;
+    if (defOp->hasTrait<OpTrait::ConstantLike>()) {
+      result = defOp;
+      return DiagnosedSilenceableFailure::success();
+    }
 
     values.clear();
 
@@ -1317,17 +1326,22 @@ transform::TrackingListener::findReplacementOp(Operation *op,
     if (auto findReplacementOpInterface =
             dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
       values.assign(findReplacementOpInterface.getNextOperands());
+      diag.attachNote(defOp->getLoc()) << "using operands provided by "
+                                          "'FindPayloadReplacementOpInterface'";
       continue;
     }
 
     // Skip through ops that implement CastOpInterface.
     if (isa<CastOpInterface>(defOp)) {
       values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
+      diag.attachNote(defOp->getLoc())
+          << "using output of 'CastOpInterface' op";
       continue;
     }
   } while (!values.empty());
 
-  return failure();
+  diag.attachNote() << "ran out of suitable replacement values";
+  return diag;
 }
 
 LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -1396,31 +1410,39 @@ void transform::TrackingListener::notifyOperationReplaced(
   };
 
   // Helper function to check if the handle is alive.
-  SmallVector<Operation *> aliveUsers;
-  for (Value v : opHandles) {
-    for (Operation *user : v.getUsers())
-      if (user != transformOp && !happensBefore(user, transformOp))
-        aliveUsers.push_back(user);
-  }
-  auto hasAliveUser = [&]() { return !aliveUsers.empty(); };
+  auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
+    for (Value v : opHandles) {
+      for (OpOperand &use : v.getUses())
+        if (use.getOwner() != transformOp &&
+            !happensBefore(use.getOwner(), transformOp))
+          return &use;
+    }
+    return std::nullopt;
+  }();
 
-  if (!hasAliveUser() || handleWasConsumed()) {
+  if (!firstAliveUser.has_value() || handleWasConsumed()) {
     // The op is tracked but the corresponding handles are dead or were
     // consumed. Drop the op form the mapping.
     (void)replacePayloadOp(op, nullptr);
     return;
   }
 
-  FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
+  Operation *replacement;
+  DiagnosedSilenceableFailure diag =
+      findReplacementOp(replacement, op, newValues);
   // If the op is tracked but no replacement op was found, send a
   // notification.
-  if (failed(replacement)) {
-    notifyPayloadReplacementNotFound(op, newValues, aliveUsers);
+  if (!diag.succeeded()) {
+    diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
+        << "replacement is required because alive handle(s) exist "
+        << "(first use in this op as operand number "
+        << (*firstAliveUser)->getOperandNumber() << ")";
+    notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
     (void)replacePayloadOp(op, nullptr);
     return;
   }
 
-  (void)replacePayloadOp(op, *replacement);
+  (void)replacePayloadOp(op, replacement);
 }
 
 transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
@@ -1443,21 +1465,20 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
 }
 
 void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
-    Operation *op, ValueRange values, ArrayRef<Operation *> aliveUsers) {
-  if (status.succeeded()) {
-    status = emitSilenceableFailure(
-        getTransformOp(), "op was replaced but replacement was of different "
-                          "kind, invalidating alive handles");
-  }
+    Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
+
+  // Merge potentially existing diags and store the result in the listener.
+  SmallVector<Diagnostic> diags;
+  diag.takeDiagnostics(diags);
+  if (!status.succeeded())
+    status.takeDiagnostics(diags);
+  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
 
+  // Report more details.
   status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
   for (auto &&[index, value] : llvm::enumerate(values))
     status.attachNote(value.getLoc())
         << "[" << errorCounter << "] replacement value " << index;
-  for (auto &&[index, user] : llvm::enumerate(aliveUsers))
-    status.attachNote(user->getLoc())
-        << "[" << errorCounter << "] alive handle " << index;
-
   ++errorCounter;
 }
 
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index f6a0801204b80d7..2d57d4aa2547f2f 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -37,13 +37,14 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{op was replaced but replacement was of different kind, invalidating alive handles}}
+  // expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
+  // expected-note @below {{ran out of suitable replacement values}}
   transform.apply_patterns to %0 {
     transform.apply_patterns.transform.test_patterns
   } : !transform.any_op
   // %1 must be used in some way. If no replacement payload op could be found,
   // an error is thrown only if the handle is not dead.
-  // expected-note @below {{[0] alive handle 0}}
+  // expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
   transform.annotate %1 "annotated" : !transform.any_op
 }
 
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 3c510c18996b0c2..35619f2816f5905 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener {
 
   // Expose `findReplacementOp` as a public function, so that it can be tested.
   Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
-    FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
-    if (failed(replacementOp))
+    Operation *replacementOp;
+    if (!findReplacementOp(replacementOp, op, newValues).succeeded())
       return nullptr;
-    return *replacementOp;
+    return replacementOp;
   }
 };
 } // namespace

>From c25f34714ad40059d93099200d0eef3c6de2b03a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 21 Sep 2023 14:23:20 +0000
Subject: [PATCH 3/3] Fix test for tensor transform ops.

The test created an empty (and thus invalid) transform op with
`TransformOpInterface()` for testing the tracking listeners. The
previous listeners never accessed the op created this way, so this
wasn't a problem. However, the new version does access it, so we need
to create a real op, which is what this commit does.
---
 .../lib/Dialect/Tensor/TestTensorTransforms.cpp | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 35619f2816f5905..366956c535cfa96 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -32,8 +32,8 @@ struct TestTensorTransforms
   TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>();
+    registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
+                    transform::TransformDialect>();
   }
 
   StringRef getArgument() const final {
@@ -352,8 +352,17 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
   transform::TransformState transformState =
       transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
                                                       /*payloadRoot=*/nullptr);
-  DummyTrackingListener listener(transformState,
-                                 transform::TransformOpInterface());
+  MLIRContext *context = rootOp->getContext();
+  OpBuilder builder(context);
+  auto transformOp = builder.create<transform::NamedSequenceOp>(
+      rootOp->getLoc(),
+      /*sym_name=*/"test_sequence",
+      /*function_type=*/
+      TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
+      /*sym_visibility*/ StringAttr::get(context, "public"),
+      /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
+      /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
+  DummyTrackingListener listener(transformState, transformOp);
   Operation *replacement = listener.getReplacementOp(replaced, replacements);
   if (!replacement) {
     replaced->emitError("listener could not find replacement op");



More information about the Mlir-commits mailing list