[Mlir-commits] [mlir] 78c8ab5 - Revert "[mlir][transform] Improve error message of tracking listener. (#66987)"

Vitaly Buka llvmlistbot at llvm.org
Mon Sep 25 09:07:31 PDT 2023


Author: Vitaly Buka
Date: 2023-09-25T09:07:17-07:00
New Revision: 78c8ab5844e618162c4cf3982d05102d4da10d23

URL: https://github.com/llvm/llvm-project/commit/78c8ab5844e618162c4cf3982d05102d4da10d23
DIFF: https://github.com/llvm/llvm-project/commit/78c8ab5844e618162c4cf3982d05102d4da10d23.diff

LOG: Revert "[mlir][transform] Improve error message of tracking listener. (#66987)"

Breaks https://lab.llvm.org/buildbot/#/builders/5/builds/36953

This reverts commit a7530452fd163c84e83e662b549ade7b0fae9edf.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/test-pattern-application.mlir
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index b45861e6190c18a..f87c0981bc4a9a1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -797,8 +797,7 @@ class TransformResults {
   /// corresponds to the given list of payload IR ops. Each result must be set
   /// by the transformation exactly once in case of transformation succeeding.
   /// The value must have a type implementing TransformHandleTypeInterface.
-  template <typename Range>
-  void set(OpResult value, Range &&ops) {
+  template <typename Range> void set(OpResult value, Range &&ops) {
     int64_t position = value.getResultNumber();
     assert(position < static_cast<int64_t>(operations.size()) &&
            "setting results for a non-existent handle");
@@ -973,9 +972,8 @@ class TrackingListener : public RewriterBase::Listener,
   ///
   /// Derived classes may override `findReplacementOp` to specify custom
   /// replacement rules.
-  virtual DiagnosedSilenceableFailure
-  findReplacementOp(Operation *&result, Operation *op,
-                    ValueRange newValues) const;
+  virtual FailureOr<Operation *> findReplacementOp(Operation *op,
+                                                   ValueRange newValues) const;
 
   /// Notify the listener that the pattern failed to match the given operation,
   /// and provide a callback to populate a diagnostic with the reason why the
@@ -987,9 +985,8 @@ class TrackingListener : public RewriterBase::Listener,
   /// This function is called when a tracked payload op is dropped because no
   /// replacement op was found. Derived classes can implement this function for
   /// custom error handling.
-  virtual void
-  notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
-                                   DiagnosedSilenceableFailure &&diag) {}
+  virtual void notifyPayloadReplacementNotFound(Operation *op,
+                                                ValueRange values) {}
 
   /// Return the single op that defines all given values (if any).
   static Operation *getCommonDefiningOp(ValueRange values);
@@ -1029,9 +1026,8 @@ class ErrorCheckingTrackingListener : public TrackingListener {
   bool failed() const;
 
 protected:
-  void
-  notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
-                                   DiagnosedSilenceableFailure &&diag) override;
+  void notifyPayloadReplacementNotFound(Operation *op,
+                                        ValueRange values) override;
 
 private:
   /// The error state of this listener. "Success" indicates that no error

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 483b0e7f7a4f998..fd2cf8816ae2162 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1291,36 +1291,27 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
   return defOp;
 }
 
-DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
-    Operation *&result, Operation *op, ValueRange newValues) const {
+FailureOr<Operation *>
+transform::TrackingListener::findReplacementOp(Operation *op,
+                                               ValueRange newValues) const {
   assert(op->getNumResults() == newValues.size() &&
          "invalid number of replacement values");
   SmallVector<Value> values(newValues.begin(), newValues.end());
 
-  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
-      getTransformOp(), "tracking listener failed to find replacement op "
-                        "during application of this transform op");
-
   do {
     // If the replacement values belong to different ops, drop the mapping.
     Operation *defOp = getCommonDefiningOp(values);
-    if (!defOp) {
-      diag.attachNote() << "replacement values belong to different ops";
-      return diag;
-    }
+    if (!defOp)
+      return failure();
 
     // If the defining op has the same type, we take it as a replacement.
-    if (op->getName() == defOp->getName()) {
-      result = defOp;
-      return DiagnosedSilenceableFailure::success();
-    }
+    if (op->getName() == defOp->getName())
+      return defOp;
 
     // Replacing an op with a constant-like equivalent is a common
     // canonicalization.
-    if (defOp->hasTrait<OpTrait::ConstantLike>()) {
-      result = defOp;
-      return DiagnosedSilenceableFailure::success();
-    }
+    if (defOp->hasTrait<OpTrait::ConstantLike>())
+      return defOp;
 
     values.clear();
 
@@ -1328,22 +1319,17 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
     if (auto findReplacementOpInterface =
             dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
       values.assign(findReplacementOpInterface.getNextOperands());
-      diag.attachNote(defOp->getLoc()) << "using operands provided by "
-                                          "'FindPayloadReplacementOpInterface'";
       continue;
     }
 
     // Skip through ops that implement CastOpInterface.
     if (isa<CastOpInterface>(defOp)) {
       values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
-      diag.attachNote(defOp->getLoc())
-          << "using output of 'CastOpInterface' op";
       continue;
     }
   } while (!values.empty());
 
-  diag.attachNote() << "ran out of suitable replacement values";
-  return diag;
+  return failure();
 }
 
 LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -1412,39 +1398,32 @@ void transform::TrackingListener::notifyOperationReplaced(
   };
 
   // Helper function to check if the handle is alive.
-  auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
+  auto hasAliveUser = [&]() {
     for (Value v : opHandles) {
-      for (OpOperand &use : v.getUses())
-        if (use.getOwner() != transformOp &&
-            !happensBefore(use.getOwner(), transformOp))
-          return &use;
+      for (Operation *user : v.getUsers())
+        if (user != transformOp && !happensBefore(user, transformOp))
+          return true;
     }
-    return std::nullopt;
-  }();
+    return false;
+  };
 
-  if (!firstAliveUser.has_value() || handleWasConsumed()) {
+  if (!hasAliveUser() || handleWasConsumed()) {
     // The op is tracked but the corresponding handles are dead or were
     // consumed. Drop the op form the mapping.
     (void)replacePayloadOp(op, nullptr);
     return;
   }
 
-  Operation *replacement;
-  DiagnosedSilenceableFailure diag =
-      findReplacementOp(replacement, op, newValues);
+  FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
   // If the op is tracked but no replacement op was found, send a
   // notification.
-  if (!diag.succeeded()) {
-    diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
-        << "replacement is required because alive handle(s) exist "
-        << "(first use in this op as operand number "
-        << (*firstAliveUser)->getOperandNumber() << ")";
-    notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
+  if (failed(replacement)) {
+    notifyPayloadReplacementNotFound(op, newValues);
     (void)replacePayloadOp(op, nullptr);
     return;
   }
 
-  (void)replacePayloadOp(op, replacement);
+  (void)replacePayloadOp(op, *replacement);
 }
 
 transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
@@ -1467,20 +1446,17 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
 }
 
 void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
-    Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
-
-  // Merge potentially existing diags and store the result in the listener.
-  SmallVector<Diagnostic> diags;
-  diag.takeDiagnostics(diags);
-  if (!status.succeeded())
-    status.takeDiagnostics(diags);
-  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
+    Operation *op, ValueRange values) {
+  if (status.succeeded()) {
+    status = emitSilenceableFailure(
+        getTransformOp(), "tracking listener failed to find replacement op");
+  }
 
-  // Report more details.
   status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
   for (auto &&[index, value] : llvm::enumerate(values))
     status.attachNote(value.getLoc())
         << "[" << errorCounter << "] replacement value " << index;
+
   ++errorCounter;
 }
 

diff  --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 2d57d4aa2547f2f..efbdab78d397faa 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -37,14 +37,12 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
-  // expected-note @below {{ran out of suitable replacement values}}
+  // expected-error @below {{tracking listener failed to find replacement op}}
   transform.apply_patterns to %0 {
     transform.apply_patterns.transform.test_patterns
   } : !transform.any_op
   // %1 must be used in some way. If no replacement payload op could be found,
   // an error is thrown only if the handle is not dead.
-  // expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
   transform.annotate %1 "annotated" : !transform.any_op
 }
 

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 366956c535cfa96..3c510c18996b0c2 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -32,8 +32,8 @@ struct TestTensorTransforms
   TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
-                    transform::TransformDialect>();
+    registry
+        .insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>();
   }
 
   StringRef getArgument() const final {
@@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener {
 
   // Expose `findReplacementOp` as a public function, so that it can be tested.
   Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
-    Operation *replacementOp;
-    if (!findReplacementOp(replacementOp, op, newValues).succeeded())
+    FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
+    if (failed(replacementOp))
       return nullptr;
-    return replacementOp;
+    return *replacementOp;
   }
 };
 } // namespace
@@ -352,17 +352,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
   transform::TransformState transformState =
       transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
                                                       /*payloadRoot=*/nullptr);
-  MLIRContext *context = rootOp->getContext();
-  OpBuilder builder(context);
-  auto transformOp = builder.create<transform::NamedSequenceOp>(
-      rootOp->getLoc(),
-      /*sym_name=*/"test_sequence",
-      /*function_type=*/
-      TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
-      /*sym_visibility*/ StringAttr::get(context, "public"),
-      /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
-      /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
-  DummyTrackingListener listener(transformState, transformOp);
+  DummyTrackingListener listener(transformState,
+                                 transform::TransformOpInterface());
   Operation *replacement = listener.getReplacementOp(replaced, replacements);
   if (!replacement) {
     replaced->emitError("listener could not find replacement op");


        


More information about the Mlir-commits mailing list