[Mlir-commits] [mlir] ebf4ab1 - [mlir][transform][NFC] Move TrackingListener to TransformInterfaces.h

Matthias Springer llvmlistbot at llvm.org
Fri Jun 9 02:52:58 PDT 2023


Author: Matthias Springer
Date: 2023-06-09T11:52:49+02:00
New Revision: ebf4ab1dc3cafcecb3d30683f5cbb0fc82b08d0c

URL: https://github.com/llvm/llvm-project/commit/ebf4ab1dc3cafcecb3d30683f5cbb0fc82b08d0c
DIFF: https://github.com/llvm/llvm-project/commit/ebf4ab1dc3cafcecb3d30683f5cbb0fc82b08d0c.diff

LOG: [mlir][transform][NFC] Move TrackingListener to TransformInterfaces.h

A TransformRewriter (with attached TrackingListener) will be added to an interface method in a subsequent revision.

Differential Revision: https://reviews.llvm.org/D152426

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 5605c4804c680..28972f1b59fe5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"
 #include "mlir/Dialect/Transform/Utils/RaggedArray.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -69,6 +70,13 @@ TransformState makeTransformStateForTesting(Region *region,
 SmallVector<OpOperand *>
 getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
 } // namespace detail
+} // namespace transform
+} // namespace mlir
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
+
+namespace mlir {
+namespace transform {
 
 /// Options controlling the application of transform operations by the
 /// TransformState.
@@ -839,6 +847,125 @@ TransformState::make_isolated_region_scope(Region &region) {
   return RegionScope(*this, region, RegionScope::Isolated());
 }
 
+/// A listener that updates a TransformState based on IR modifications. This
+/// listener can be used during a greedy pattern rewrite to keep the transform
+/// state up-to-date.
+class TrackingListener : public RewriterBase::Listener,
+                         public TransformState::Extension {
+public:
+  /// Create a new TrackingListener for usage in the specified transform op.
+  explicit TrackingListener(TransformState &state, TransformOpInterface op)
+      : TransformState::Extension(state), transformOp(op) {}
+
+protected:
+  /// Return a replacement payload op for the given op, which is going to be
+  /// replaced with the given values. By default, if all values are defined by
+  /// the same op, which also has the same type as the given op, that defining
+  /// op is used as a replacement.
+  ///
+  /// A "failure" return value indicates that no replacement operation could be
+  /// found. A "nullptr" return value indicates that no replacement op is needed
+  /// (e.g., handle is dead or was consumed) and that the payload op should
+  /// be dropped from the mapping.
+  ///
+  /// Example: A tracked "linalg.generic" with two results is replaced with two
+  /// values defined by (another) "linalg.generic". It is reasonable to assume
+  /// that the replacement "linalg.generic" represents the same "computation".
+  /// Therefore, the payload op mapping is updated to the defining op of the
+  /// replacement values.
+  ///
+  /// Counter Example: A "linalg.generic" is replaced with values defined by an
+  /// "scf.for". Without further investigation, the relationship between the
+  /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
+  /// same computation; e.g., there may be tiled "linalg.generic" inside the
+  /// loop body that represents the original computation. Therefore, the
+  /// TrackingListener is conservative by default: it drops the mapping and
+  /// triggers the "payload replacement not found" notification.
+  ///
+  /// If no replacement op could be found according to the rules mentioned
+  /// above, this function tries to skip over cast-like ops that implement
+  /// `CastOpInterface`.
+  ///
+  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
+  /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
+  /// reasonable to assume that the wrapped "linalg.generic" represents the same
+  /// computation as the original "linalg.generic". The mapping is updated
+  /// accordingly.
+  ///
+  /// Certain ops (typically also metadata-only ops) are not considered casts,
+  /// but should be skipped nonetheless. Such ops should implement
+  /// `FindPayloadReplacementOpInterface` to specify with which operands the
+  /// lookup should continue.
+  ///
+  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
+  /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
+  /// not cast. (Implementing `CastOpInterface` would be incorrect and cause
+  /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
+  /// implementation, the replacement op lookup continues with the wrapped
+  /// "linalg.generic" and the mapping is updated accordingly.
+  ///
+  /// Derived classes may override `findReplacementOp` to specify custom
+  /// replacement rules.
+  virtual FailureOr<Operation *> findReplacementOp(Operation *op,
+                                                   ValueRange newValues) const;
+
+  /// Notify the listener that the pattern failed to match the given operation,
+  /// and provide a callback to populate a diagnostic with the reason why the
+  /// failure occurred.
+  LogicalResult
+  notifyMatchFailure(Location loc,
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
+
+  /// This function is called when a tracked payload op is dropped because no
+  /// replacement op was found. Derived classes can implement this function for
+  /// custom error handling.
+  virtual void notifyPayloadReplacementNotFound(Operation *op,
+                                                ValueRange values) {}
+
+  /// Return the single op that defines all given values (if any).
+  static Operation *getCommonDefiningOp(ValueRange values);
+
+  /// Return the transform op in which this TrackingListener is used.
+  TransformOpInterface getTransformOp() const { return transformOp; }
+
+private:
+  void notifyOperationRemoved(Operation *op) override;
+
+  void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+
+  /// The transform op in which this TrackingListener is used.
+  TransformOpInterface transformOp;
+};
+
+/// A specialized listener that keeps track of cases in which no replacement
+/// payload could be found. The error state of this listener must be checked
+/// before the end of its lifetime.
+class ErrorCheckingTrackingListener : public TrackingListener {
+public:
+  using transform::TrackingListener::TrackingListener;
+
+  ~ErrorCheckingTrackingListener() override;
+
+  /// Check and return the current error state of this listener. Afterwards,
+  /// resets the error state to "success".
+  DiagnosedSilenceableFailure checkAndResetError();
+
+  /// Return "true" if this tracking listener had a failure.
+  bool failed() const;
+
+protected:
+  void notifyPayloadReplacementNotFound(Operation *op,
+                                        ValueRange values) override;
+
+private:
+  /// The error state of this listener. "Success" indicates that no error
+  /// happened so far.
+  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
+
+  /// The number of errors that have been encountered.
+  int64_t errorCounter = 0;
+};
+
 /// This trait is supposed to be attached to Transform dialect operations that
 /// can be standalone top-level transforms. Such operations typically contain
 /// other Transform dialect operations that can be executed following some
@@ -1084,14 +1211,6 @@ class ParamProducerTransformOpTrait
   }
 };
 
-} // namespace transform
-} // namespace mlir
-
-#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
-
-namespace mlir {
-namespace transform {
-
 /// A single result of applying a transform op with `ApplyEachOpTrait` to a
 /// single payload operation.
 using ApplyToEachResult = MappedValue;

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index d5fcdb31ab0c8..b8ac86bc43db2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -37,125 +37,6 @@ using SequenceBodyBuilderArgsFn =
     ::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location,
                               ::mlir::BlockArgument, ::mlir::ValueRange)>;
 
-/// A listener that updates a TransformState based on IR modifications. This
-/// listener can be used during a greedy pattern rewrite to keep the transform
-/// state up-to-date.
-class TrackingListener : public RewriterBase::Listener,
-                         public TransformState::Extension {
-public:
-  /// Create a new TrackingListener for usage in the specified transform op.
-  explicit TrackingListener(TransformState &state, TransformOpInterface op)
-      : TransformState::Extension(state), transformOp(op) {}
-
-protected:
-  /// Return a replacement payload op for the given op, which is going to be
-  /// replaced with the given values. By default, if all values are defined by
-  /// the same op, which also has the same type as the given op, that defining
-  /// op is used as a replacement.
-  ///
-  /// A "failure" return value indicates that no replacement operation could be
-  /// found. A "nullptr" return value indicates that no replacement op is needed
-  /// (e.g., handle is dead or was consumed) and that the payload op should
-  /// be dropped from the mapping.
-  ///
-  /// Example: A tracked "linalg.generic" with two results is replaced with two
-  /// values defined by (another) "linalg.generic". It is reasonable to assume
-  /// that the replacement "linalg.generic" represents the same "computation".
-  /// Therefore, the payload op mapping is updated to the defining op of the
-  /// replacement values.
-  ///
-  /// Counter Example: A "linalg.generic" is replaced with values defined by an
-  /// "scf.for". Without further investigation, the relationship between the
-  /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
-  /// same computation; e.g., there may be tiled "linalg.generic" inside the
-  /// loop body that represents the original computation. Therefore, the
-  /// TrackingListener is conservative by default: it drops the mapping and
-  /// triggers the "payload replacement not found" notification.
-  ///
-  /// If no replacement op could be found according to the rules mentioned
-  /// above, this function tries to skip over cast-like ops that implement
-  /// `CastOpInterface`.
-  ///
-  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
-  /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
-  /// reasonable to assume that the wrapped "linalg.generic" represents the same
-  /// computation as the original "linalg.generic". The mapping is updated
-  /// accordingly.
-  ///
-  /// Certain ops (typically also metadata-only ops) are not considered casts,
-  /// but should be skipped nonetheless. Such ops should implement
-  /// `FindPayloadReplacementOpInterface` to specify with which operands the
-  /// lookup should continue.
-  ///
-  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
-  /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
-  /// not cast. (Implementing `CastOpInterface` would be incorrect and cause
-  /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
-  /// implementation, the replacement op lookup continues with the wrapped
-  /// "linalg.generic" and the mapping is updated accordingly.
-  ///
-  /// Derived classes may override `findReplacementOp` to specify custom
-  /// replacement rules.
-  virtual FailureOr<Operation *> findReplacementOp(Operation *op,
-                                                   ValueRange newValues) const;
-
-  /// Notify the listener that the pattern failed to match the given operation,
-  /// and provide a callback to populate a diagnostic with the reason why the
-  /// failure occurred.
-  LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) override;
-
-  /// This function is called when a tracked payload op is dropped because no
-  /// replacement op was found. Derived classes can implement this function for
-  /// custom error handling.
-  virtual void notifyPayloadReplacementNotFound(Operation *op,
-                                                ValueRange values) {}
-
-  /// Return the single op that defines all given values (if any).
-  static Operation *getCommonDefiningOp(ValueRange values);
-
-  /// Return the transform op in which this TrackingListener is used.
-  TransformOpInterface getTransformOp() const { return transformOp; }
-
-private:
-  void notifyOperationRemoved(Operation *op) override;
-
-  void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
-
-  /// The transform op in which this TrackingListener is used.
-  TransformOpInterface transformOp;
-};
-
-/// A specialized listener that keeps track of cases in which no replacement
-/// payload could be found. The error state of this listener must be checked
-/// before the end of its lifetime.
-class ErrorCheckingTrackingListener : public TrackingListener {
-public:
-  using transform::TrackingListener::TrackingListener;
-
-  ~ErrorCheckingTrackingListener() override;
-
-  /// Check and return the current error state of this listener. Afterwards,
-  /// resets the error state to "success".
-  DiagnosedSilenceableFailure checkAndResetError();
-
-  /// Return "true" if this tracking listener had a failure.
-  bool failed() const;
-
-protected:
-  void notifyPayloadReplacementNotFound(Operation *op,
-                                        ValueRange values) override;
-
-private:
-  /// The error state of this listener. "Success" indicates that no error
-  /// happened so far.
-  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
-
-  /// The number of errors that have been encountered.
-  int64_t errorCounter = 0;
-};
-
 } // namespace transform
 } // namespace mlir
 

diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index e69e56436322c..068fa0bec328e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -7,10 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
@@ -1155,6 +1157,175 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
          values[resultNumber].data() != nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// TrackingListener
+//===----------------------------------------------------------------------===//
+
+Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
+  Operation *defOp = nullptr;
+  for (Value v : values) {
+    // Skip empty values; a block argument means there is no common op.
+    if (!v)
+      continue;
+    Operation *current = v.getDefiningOp();
+    if (!current)
+      return nullptr;
+    if (defOp && defOp != current)
+      return nullptr;
+    defOp = current;
+  }
+  return defOp;
+}
+
+FailureOr<Operation *>
+transform::TrackingListener::findReplacementOp(Operation *op,
+                                               ValueRange newValues) const {
+  assert(op->getNumResults() == newValues.size() &&
+         "invalid number of replacement values");
+  SmallVector<Value> values(newValues.begin(), newValues.end());
+
+  do {
+    // If the replacement values belong to different ops, drop the mapping.
+    Operation *defOp = getCommonDefiningOp(values);
+    if (!defOp)
+      return failure();
+
+    // If the defining op has the same type, we take it as a replacement.
+    if (op->getName() == defOp->getName())
+      return defOp;
+
+    // Replacing an op with a constant-like equivalent is a common
+    // canonicalization.
+    if (defOp->hasTrait<OpTrait::ConstantLike>())
+      return defOp;
+
+    values.clear();
+
+    // Skip through ops that implement FindPayloadReplacementOpInterface.
+    if (auto findReplacementOpInterface =
+            dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
+      values.assign(findReplacementOpInterface.getNextOperands());
+      continue;
+    }
+
+    // Skip through ops that implement CastOpInterface.
+    if (isa<CastOpInterface>(defOp)) {
+      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
+      continue;
+    }
+  } while (!values.empty());
+
+  return failure();
+}
+
+LogicalResult transform::TrackingListener::notifyMatchFailure(
+    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
+  LLVM_DEBUG({
+    Diagnostic diag(loc, DiagnosticSeverity::Remark);
+    reasonCallback(diag);
+    DBGS() << "Match Failure : " << diag.str() << "\n";
+  });
+  return failure();
+}
+
+void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
+  // TODO: Walk can be removed when D144193 has landed.
+  op->walk([&](Operation *op) {
+    // Remove mappings for result values.
+    for (OpResult value : op->getResults())
+      (void)replacePayloadValue(value, nullptr);
+    // Remove mapping for op.
+    (void)replacePayloadOp(op, nullptr);
+  });
+}
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`. No common block: false.
+static bool happensBefore(Operation *a, Operation *b) {
+  for (; a; a = a->getParentOp()) {
+    Block *block = a->getBlock();
+    if (!block || a->isProperAncestor(b))
+      return false;
+    if (Operation *bAncestor = block->findAncestorOpInBlock(*b))
+      return a->isBeforeInBlock(bAncestor);
+  }
+  return false;
+}
+
+void transform::TrackingListener::notifyOperationReplaced(
+    Operation *op, ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size() &&
+         "invalid number of replacement values");
+
+  // Replace value handles.
+  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
+    (void)replacePayloadValue(oldValue, newValue);
+
+  // Replace op handle.
+  SmallVector<Value> opHandles;
+  if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
+    // Op is not tracked.
+    return;
+  }
+  auto hasAliveUser = [&]() {
+    for (Value v : opHandles)
+      for (Operation *user : v.getUsers())
+        if (!happensBefore(user, transformOp))
+          return true;
+    return false;
+  };
+  if (!hasAliveUser()) {
+    // The op is tracked but the corresponding handles are dead.
+    (void)replacePayloadOp(op, nullptr);
+    return;
+  }
+
+  FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
+  // If the op is tracked but no replacement op was found, send a
+  // notification.
+  if (failed(replacement)) {
+    notifyPayloadReplacementNotFound(op, newValues);
+    (void)replacePayloadOp(op, nullptr);
+    return;
+  }
+
+  (void)replacePayloadOp(op, *replacement);
+}
+
+transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
+  // The state of the ErrorCheckingTrackingListener must be checked and reset
+  // if there was an error. This is to prevent errors from accidentally being
+  // missed.
+  assert(status.succeeded() && "listener state was not checked");
+}
+
+DiagnosedSilenceableFailure
+transform::ErrorCheckingTrackingListener::checkAndResetError() {
+  DiagnosedSilenceableFailure s = std::move(status);
+  status = DiagnosedSilenceableFailure::success();
+  errorCounter = 0;
+  return s;
+}
+
+bool transform::ErrorCheckingTrackingListener::failed() const {
+  return !status.succeeded();
+}
+
+void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
+    Operation *op, ValueRange values) {
+  if (status.succeeded()) {
+    status = emitSilenceableFailure(
+        getTransformOp(), "tracking listener failed to find replacement op");
+  }
+
+  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
+  for (auto &&[index, value] : llvm::enumerate(values))
+    if (value) // Null values have no location; see getCommonDefiningOp.
+      status.attachNote(value.getLoc())
+          << "[" << errorCounter << "] replacement value " << index;
+  ++errorCounter;
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for TransformEachOpTrait.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a64b9d7aa365d..fe07fb3160726 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -50,175 +50,6 @@ static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
-//===----------------------------------------------------------------------===//
-// TrackingListener
-//===----------------------------------------------------------------------===//
-
-Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
-  Operation *defOp = nullptr;
-  for (Value v : values) {
-    // Skip empty values.
-    if (!v)
-      continue;
-    if (!defOp) {
-      defOp = v.getDefiningOp();
-      continue;
-    }
-    if (defOp != v.getDefiningOp())
-      return nullptr;
-  }
-  return defOp;
-}
-
-FailureOr<Operation *>
-transform::TrackingListener::findReplacementOp(Operation *op,
-                                               ValueRange newValues) const {
-  assert(op->getNumResults() == newValues.size() &&
-         "invalid number of replacement values");
-  SmallVector<Value> values(newValues.begin(), newValues.end());
-
-  do {
-    // If the replacement values belong to different ops, drop the mapping.
-    Operation *defOp = getCommonDefiningOp(values);
-    if (!defOp)
-      return failure();
-
-    // If the defining op has the same type, we take it as a replacement.
-    if (op->getName() == defOp->getName())
-      return defOp;
-
-    // Replacing an op with a constant-like equivalent is a common
-    // canonicalization.
-    if (defOp->hasTrait<OpTrait::ConstantLike>())
-      return defOp;
-
-    values.clear();
-
-    // Skip through ops that implement FindPayloadReplacementOpInterface.
-    if (auto findReplacementOpInterface =
-            dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
-      values.assign(findReplacementOpInterface.getNextOperands());
-      continue;
-    }
-
-    // Skip through ops that implement CastOpInterface.
-    if (isa<CastOpInterface>(defOp)) {
-      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
-      continue;
-    }
-  } while (!values.empty());
-
-  return failure();
-}
-
-LogicalResult transform::TrackingListener::notifyMatchFailure(
-    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  LLVM_DEBUG({
-    Diagnostic diag(loc, DiagnosticSeverity::Remark);
-    reasonCallback(diag);
-    DBGS() << "Match Failure : " << diag.str() << "\n";
-  });
-  return failure();
-}
-
-void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
-  // TODO: Walk can be removed when D144193 has landed.
-  op->walk([&](Operation *op) {
-    // Remove mappings for result values.
-    for (OpResult value : op->getResults())
-      (void)replacePayloadValue(value, nullptr);
-    // Remove mapping for op.
-    (void)replacePayloadOp(op, nullptr);
-  });
-}
-
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
-  do {
-    if (a->isProperAncestor(b))
-      return false;
-    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
-      return a->isBeforeInBlock(bAncestor);
-    }
-  } while ((a = a->getParentOp()));
-  return false;
-}
-
-void transform::TrackingListener::notifyOperationReplaced(
-    Operation *op, ValueRange newValues) {
-  assert(op->getNumResults() == newValues.size() &&
-         "invalid number of replacement values");
-
-  // Replace value handles.
-  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
-    (void)replacePayloadValue(oldValue, newValue);
-
-  // Replace op handle.
-  SmallVector<Value> opHandles;
-  if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
-    // Op is not tracked.
-    return;
-  }
-  auto hasAliveUser = [&]() {
-    for (Value v : opHandles)
-      for (Operation *user : v.getUsers())
-        if (!happensBefore(user, transformOp))
-          return true;
-    return false;
-  };
-  if (!hasAliveUser()) {
-    // The op is tracked but the corresponding handles are dead.
-    (void)replacePayloadOp(op, nullptr);
-    return;
-  }
-
-  FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
-  // If the op is tracked but no replacement op was found, send a
-  // notification.
-  if (failed(replacement)) {
-    notifyPayloadReplacementNotFound(op, newValues);
-    (void)replacePayloadOp(op, nullptr);
-    return;
-  }
-
-  (void)replacePayloadOp(op, *replacement);
-}
-
-transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
-  // The state of the ErrorCheckingTrackingListener must be checked and reset
-  // if there was an error. This is to prevent errors from accidentally being
-  // missed.
-  assert(status.succeeded() && "listener state was not checked");
-}
-
-DiagnosedSilenceableFailure
-transform::ErrorCheckingTrackingListener::checkAndResetError() {
-  DiagnosedSilenceableFailure s = std::move(status);
-  status = DiagnosedSilenceableFailure::success();
-  errorCounter = 0;
-  return s;
-}
-
-bool transform::ErrorCheckingTrackingListener::failed() const {
-  return !status.succeeded();
-}
-
-void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
-    Operation *op, ValueRange values) {
-  if (status.succeeded()) {
-    status = emitSilenceableFailure(
-        getTransformOp(), "tracking listener failed to find replacement op");
-  }
-
-  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
-  for (auto &&[index, value] : llvm::enumerate(values))
-    status.attachNote(value.getLoc())
-        << "[" << errorCounter << "] replacement value " << index;
-
-  ++errorCounter;
-}
-
 //===----------------------------------------------------------------------===//
 // AlternativesOp
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list