[Mlir-commits] [mlir] ebf4ab1 - [mlir][transform][NFC] Move TrackingListener to TransformInterfaces.h
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 9 02:52:58 PDT 2023
Author: Matthias Springer
Date: 2023-06-09T11:52:49+02:00
New Revision: ebf4ab1dc3cafcecb3d30683f5cbb0fc82b08d0c
URL: https://github.com/llvm/llvm-project/commit/ebf4ab1dc3cafcecb3d30683f5cbb0fc82b08d0c
DIFF: https://github.com/llvm/llvm-project/commit/ebf4ab1dc3cafcecb3d30683f5cbb0fc82b08d0c.diff
LOG: [mlir][transform][NFC] Move TrackingListener to TransformInterfaces.h
A TransformRewriter (with attached TrackingListener) will be added to an interface method in a subsequent revision.
Differential Revision: https://reviews.llvm.org/D152426
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 5605c4804c680..28972f1b59fe5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"
#include "mlir/Dialect/Transform/Utils/RaggedArray.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
@@ -69,6 +70,13 @@ TransformState makeTransformStateForTesting(Region *region,
SmallVector<OpOperand *>
getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
} // namespace detail
+} // namespace transform
+} // namespace mlir
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
+
+namespace mlir {
+namespace transform {
/// Options controlling the application of transform operations by the
/// TransformState.
@@ -839,6 +847,125 @@ TransformState::make_isolated_region_scope(Region &region) {
return RegionScope(*this, region, RegionScope::Isolated());
}
+/// A listener that updates a TransformState based on IR modifications. This
+/// listener can be used during a greedy pattern rewrite to keep the transform
+/// state up-to-date.
+class TrackingListener : public RewriterBase::Listener,
+ public TransformState::Extension {
+public:
+ /// Create a new TrackingListener for usage in the specified transform op.
+ explicit TrackingListener(TransformState &state, TransformOpInterface op)
+ : TransformState::Extension(state), transformOp(op) {}
+
+protected:
+ /// Return a replacement payload op for the given op, which is going to be
+ /// replaced with the given values. By default, if all values are defined by
+ /// the same op, which also has the same type as the given op, that defining
+ /// op is used as a replacement.
+ ///
+ /// A "failure" return value indicates that no replacement operation could be
+ /// found. A "nullptr" return value indicates that no replacement op is needed
+ /// (e.g., handle is dead or was consumed) and that the payload op should
+ /// be dropped from the mapping.
+ ///
+ /// Example: A tracked "linalg.generic" with two results is replaced with two
+ /// values defined by (another) "linalg.generic". It is reasonable to assume
+ /// that the replacement "linalg.generic" represents the same "computation".
+ /// Therefore, the payload op mapping is updated to the defining op of the
+ /// replacement values.
+ ///
+ /// Counter Example: A "linalg.generic" is replaced with values defined by an
+ /// "scf.for". Without further investigation, the relationship between the
+ /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
+ /// same computation; e.g., there may be tiled "linalg.generic" inside the
+ /// loop body that represents the original computation. Therefore, the
+ /// TrackingListener is conservative by default: it drops the mapping and
+ /// triggers the "payload replacement not found" notification.
+ ///
+ /// If no replacement op could be found according to the rules mentioned
+ /// above, this function tries to skip over cast-like ops that implement
+ /// `CastOpInterface`.
+ ///
+ /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
+ /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
+ /// reasonable to assume that the wrapped "linalg.generic" represents the same
+ /// computation as the original "linalg.generic". The mapping is updated
+ /// accordingly.
+ ///
+ /// Certain ops (typically also metadata-only ops) are not considered casts,
+ /// but should be skipped nonetheless. Such ops should implement
+ /// `FindPayloadReplacementOpInterface` to specify with which operands the
+ /// lookup should continue.
+ ///
+ /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
+ /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
+ /// not cast. (Implementing `CastOpInterface` would be incorrect and cause
+ /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
+ /// implementation, the replacement op lookup continues with the wrapped
+ /// "linalg.generic" and the mapping is updated accordingly.
+ ///
+ /// Derived classes may override `findReplacementOp` to specify custom
+ /// replacement rules.
+ virtual FailureOr<Operation *> findReplacementOp(Operation *op,
+ ValueRange newValues) const;
+
+ /// Notify the listener that the pattern failed to match the given operation,
+ /// and provide a callback to populate a diagnostic with the reason why the
+ /// failure occurred.
+ LogicalResult
+ notifyMatchFailure(Location loc,
+ function_ref<void(Diagnostic &)> reasonCallback) override;
+
+ /// This function is called when a tracked payload op is dropped because no
+ /// replacement op was found. Derived classes can implement this function for
+ /// custom error handling.
+ virtual void notifyPayloadReplacementNotFound(Operation *op,
+ ValueRange values) {}
+
+ /// Return the single op that defines all given values (if any).
+ static Operation *getCommonDefiningOp(ValueRange values);
+
+ /// Return the transform op in which this TrackingListener is used.
+ TransformOpInterface getTransformOp() const { return transformOp; }
+
+private:
+ void notifyOperationRemoved(Operation *op) override;
+
+ void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+
+ /// The transform op in which this TrackingListener is used.
+ TransformOpInterface transformOp;
+};
+
+/// A specialized listener that keeps track of cases in which no replacement
+/// payload could be found. The error state of this listener must be checked
+/// before the end of its lifetime.
+class ErrorCheckingTrackingListener : public TrackingListener {
+public:
+ using transform::TrackingListener::TrackingListener;
+
+ ~ErrorCheckingTrackingListener() override;
+
+ /// Check and return the current error state of this listener. Afterwards,
+ /// resets the error state to "success".
+ DiagnosedSilenceableFailure checkAndResetError();
+
+ /// Return "true" if this tracking listener had a failure.
+ bool failed() const;
+
+protected:
+ void notifyPayloadReplacementNotFound(Operation *op,
+ ValueRange values) override;
+
+private:
+ /// The error state of this listener. "Success" indicates that no error
+ /// happened so far.
+ DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
+
+ /// The number of errors that have been encountered.
+ int64_t errorCounter = 0;
+};
+
/// This trait is supposed to be attached to Transform dialect operations that
/// can be standalone top-level transforms. Such operations typically contain
/// other Transform dialect operations that can be executed following some
@@ -1084,14 +1211,6 @@ class ParamProducerTransformOpTrait
}
};
-} // namespace transform
-} // namespace mlir
-
-#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
-
-namespace mlir {
-namespace transform {
-
/// A single result of applying a transform op with `ApplyEachOpTrait` to a
/// single payload operation.
using ApplyToEachResult = MappedValue;
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index d5fcdb31ab0c8..b8ac86bc43db2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -37,125 +37,6 @@ using SequenceBodyBuilderArgsFn =
::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location,
::mlir::BlockArgument, ::mlir::ValueRange)>;
-/// A listener that updates a TransformState based on IR modifications. This
-/// listener can be used during a greedy pattern rewrite to keep the transform
-/// state up-to-date.
-class TrackingListener : public RewriterBase::Listener,
- public TransformState::Extension {
-public:
- /// Create a new TrackingListener for usage in the specified transform op.
- explicit TrackingListener(TransformState &state, TransformOpInterface op)
- : TransformState::Extension(state), transformOp(op) {}
-
-protected:
- /// Return a replacement payload op for the given op, which is going to be
- /// replaced with the given values. By default, if all values are defined by
- /// the same op, which also has the same type as the given op, that defining
- /// op is used as a replacement.
- ///
- /// A "failure" return value indicates that no replacement operation could be
- /// found. A "nullptr" return value indicates that no replacement op is needed
- /// (e.g., handle is dead or was consumed) and that the payload op should
- /// be dropped from the mapping.
- ///
- /// Example: A tracked "linalg.generic" with two results is replaced with two
- /// values defined by (another) "linalg.generic". It is reasonable to assume
- /// that the replacement "linalg.generic" represents the same "computation".
- /// Therefore, the payload op mapping is updated to the defining op of the
- /// replacement values.
- ///
- /// Counter Example: A "linalg.generic" is replaced with values defined by an
- /// "scf.for". Without further investigation, the relationship between the
- /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
- /// same computation; e.g., there may be tiled "linalg.generic" inside the
- /// loop body that represents the original computation. Therefore, the
- /// TrackingListener is conservative by default: it drops the mapping and
- /// triggers the "payload replacement not found" notification.
- ///
- /// If no replacement op could be found according to the rules mentioned
- /// above, this function tries to skip over cast-like ops that implement
- /// `CastOpInterface`.
- ///
- /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
- /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
- /// reasonable to assume that the wrapped "linalg.generic" represents the same
- /// computation as the original "linalg.generic". The mapping is updated
- /// accordingly.
- ///
- /// Certain ops (typically also metadata-only ops) are not considered casts,
- /// but should be skipped nonetheless. Such ops should implement
- /// `FindPayloadReplacementOpInterface` to specify with which operands the
- /// lookup should continue.
- ///
- /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
- /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
- /// not cast. (Implementing `CastOpInterface` would be incorrect and cause
- /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
- /// implementation, the replacement op lookup continues with the wrapped
- /// "linalg.generic" and the mapping is updated accordingly.
- ///
- /// Derived classes may override `findReplacementOp` to specify custom
- /// replacement rules.
- virtual FailureOr<Operation *> findReplacementOp(Operation *op,
- ValueRange newValues) const;
-
- /// Notify the listener that the pattern failed to match the given operation,
- /// and provide a callback to populate a diagnostic with the reason why the
- /// failure occurred.
- LogicalResult
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) override;
-
- /// This function is called when a tracked payload op is dropped because no
- /// replacement op was found. Derived classes can implement this function for
- /// custom error handling.
- virtual void notifyPayloadReplacementNotFound(Operation *op,
- ValueRange values) {}
-
- /// Return the single op that defines all given values (if any).
- static Operation *getCommonDefiningOp(ValueRange values);
-
- /// Return the transform op in which this TrackingListener is used.
- TransformOpInterface getTransformOp() const { return transformOp; }
-
-private:
- void notifyOperationRemoved(Operation *op) override;
-
- void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
-
- /// The transform op in which this TrackingListener is used.
- TransformOpInterface transformOp;
-};
-
-/// A specialized listener that keeps track of cases in which no replacement
-/// payload could be found. The error state of this listener must be checked
-/// before the end of its lifetime.
-class ErrorCheckingTrackingListener : public TrackingListener {
-public:
- using transform::TrackingListener::TrackingListener;
-
- ~ErrorCheckingTrackingListener() override;
-
- /// Check and return the current error state of this listener. Afterwards,
- /// resets the error state to "success".
- DiagnosedSilenceableFailure checkAndResetError();
-
- /// Return "true" if this tracking listener had a failure.
- bool failed() const;
-
-protected:
- void notifyPayloadReplacementNotFound(Operation *op,
- ValueRange values) override;
-
-private:
- /// The error state of this listener. "Success" indicates that no error
- /// happened so far.
- DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
-
- /// The number of errors that have been encountered.
- int64_t errorCounter = 0;
-};
-
} // namespace transform
} // namespace mlir
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index e69e56436322c..068fa0bec328e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -7,10 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
@@ -1155,6 +1157,175 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
values[resultNumber].data() != nullptr;
}
+//===----------------------------------------------------------------------===//
+// TrackingListener
+//===----------------------------------------------------------------------===//
+
+Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
+ Operation *defOp = nullptr;
+ for (Value v : values) {
+ // Skip empty values.
+ if (!v)
+ continue;
+ if (!defOp) {
+ defOp = v.getDefiningOp();
+ continue;
+ }
+ if (defOp != v.getDefiningOp())
+ return nullptr;
+ }
+ return defOp;
+}
+
+FailureOr<Operation *>
+transform::TrackingListener::findReplacementOp(Operation *op,
+ ValueRange newValues) const {
+ assert(op->getNumResults() == newValues.size() &&
+ "invalid number of replacement values");
+ SmallVector<Value> values(newValues.begin(), newValues.end());
+
+ do {
+ // If the replacement values belong to different ops, drop the mapping.
+ Operation *defOp = getCommonDefiningOp(values);
+ if (!defOp)
+ return failure();
+
+ // If the defining op has the same type, we take it as a replacement.
+ if (op->getName() == defOp->getName())
+ return defOp;
+
+ // Replacing an op with a constant-like equivalent is a common
+ // canonicalization.
+ if (defOp->hasTrait<OpTrait::ConstantLike>())
+ return defOp;
+
+ values.clear();
+
+ // Skip through ops that implement FindPayloadReplacementOpInterface.
+ if (auto findReplacementOpInterface =
+ dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
+ values.assign(findReplacementOpInterface.getNextOperands());
+ continue;
+ }
+
+ // Skip through ops that implement CastOpInterface.
+ if (isa<CastOpInterface>(defOp)) {
+ values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
+ continue;
+ }
+ } while (!values.empty());
+
+ return failure();
+}
+
+LogicalResult transform::TrackingListener::notifyMatchFailure(
+ Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
+ LLVM_DEBUG({
+ Diagnostic diag(loc, DiagnosticSeverity::Remark);
+ reasonCallback(diag);
+ DBGS() << "Match Failure : " << diag.str() << "\n";
+ });
+ return failure();
+}
+
+void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
+ // TODO: Walk can be removed when D144193 has landed.
+ op->walk([&](Operation *op) {
+ // Remove mappings for result values.
+ for (OpResult value : op->getResults())
+ (void)replacePayloadValue(value, nullptr);
+ // Remove mapping for op.
+ (void)replacePayloadOp(op, nullptr);
+ });
+}
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b) {
+ do {
+ if (a->isProperAncestor(b))
+ return false;
+ if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+ return a->isBeforeInBlock(bAncestor);
+ }
+ } while ((a = a->getParentOp()));
+ return false;
+}
+
+void transform::TrackingListener::notifyOperationReplaced(
+ Operation *op, ValueRange newValues) {
+ assert(op->getNumResults() == newValues.size() &&
+ "invalid number of replacement values");
+
+ // Replace value handles.
+ for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
+ (void)replacePayloadValue(oldValue, newValue);
+
+ // Replace op handle.
+ SmallVector<Value> opHandles;
+ if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
+ // Op is not tracked.
+ return;
+ }
+ auto hasAliveUser = [&]() {
+ for (Value v : opHandles)
+ for (Operation *user : v.getUsers())
+ if (!happensBefore(user, transformOp))
+ return true;
+ return false;
+ };
+ if (!hasAliveUser()) {
+ // The op is tracked but the corresponding handles are dead.
+ (void)replacePayloadOp(op, nullptr);
+ return;
+ }
+
+ FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
+ // If the op is tracked but no replacement op was found, send a
+ // notification.
+ if (failed(replacement)) {
+ notifyPayloadReplacementNotFound(op, newValues);
+ (void)replacePayloadOp(op, nullptr);
+ return;
+ }
+
+ (void)replacePayloadOp(op, *replacement);
+}
+
+transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
+ // The state of the ErrorCheckingTrackingListener must be checked and reset
+ // if there was an error. This is to prevent errors from accidentally being
+ // missed.
+ assert(status.succeeded() && "listener state was not checked");
+}
+
+DiagnosedSilenceableFailure
+transform::ErrorCheckingTrackingListener::checkAndResetError() {
+ DiagnosedSilenceableFailure s = std::move(status);
+ status = DiagnosedSilenceableFailure::success();
+ errorCounter = 0;
+ return s;
+}
+
+bool transform::ErrorCheckingTrackingListener::failed() const {
+ return !status.succeeded();
+}
+
+void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
+ Operation *op, ValueRange values) {
+ if (status.succeeded()) {
+ status = emitSilenceableFailure(
+ getTransformOp(), "tracking listener failed to find replacement op");
+ }
+
+ status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
+ for (auto &&[index, value] : llvm::enumerate(values))
+ status.attachNote(value.getLoc())
+ << "[" << errorCounter << "] replacement value " << index;
+
+ ++errorCounter;
+}
+
//===----------------------------------------------------------------------===//
// Utilities for TransformEachOpTrait.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a64b9d7aa365d..fe07fb3160726 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -50,175 +50,6 @@ static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
-//===----------------------------------------------------------------------===//
-// TrackingListener
-//===----------------------------------------------------------------------===//
-
-Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
- Operation *defOp = nullptr;
- for (Value v : values) {
- // Skip empty values.
- if (!v)
- continue;
- if (!defOp) {
- defOp = v.getDefiningOp();
- continue;
- }
- if (defOp != v.getDefiningOp())
- return nullptr;
- }
- return defOp;
-}
-
-FailureOr<Operation *>
-transform::TrackingListener::findReplacementOp(Operation *op,
- ValueRange newValues) const {
- assert(op->getNumResults() == newValues.size() &&
- "invalid number of replacement values");
- SmallVector<Value> values(newValues.begin(), newValues.end());
-
- do {
- // If the replacement values belong to different ops, drop the mapping.
- Operation *defOp = getCommonDefiningOp(values);
- if (!defOp)
- return failure();
-
- // If the defining op has the same type, we take it as a replacement.
- if (op->getName() == defOp->getName())
- return defOp;
-
- // Replacing an op with a constant-like equivalent is a common
- // canonicalization.
- if (defOp->hasTrait<OpTrait::ConstantLike>())
- return defOp;
-
- values.clear();
-
- // Skip through ops that implement FindPayloadReplacementOpInterface.
- if (auto findReplacementOpInterface =
- dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
- values.assign(findReplacementOpInterface.getNextOperands());
- continue;
- }
-
- // Skip through ops that implement CastOpInterface.
- if (isa<CastOpInterface>(defOp)) {
- values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
- continue;
- }
- } while (!values.empty());
-
- return failure();
-}
-
-LogicalResult transform::TrackingListener::notifyMatchFailure(
- Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
- LLVM_DEBUG({
- Diagnostic diag(loc, DiagnosticSeverity::Remark);
- reasonCallback(diag);
- DBGS() << "Match Failure : " << diag.str() << "\n";
- });
- return failure();
-}
-
-void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
- // TODO: Walk can be removed when D144193 has landed.
- op->walk([&](Operation *op) {
- // Remove mappings for result values.
- for (OpResult value : op->getResults())
- (void)replacePayloadValue(value, nullptr);
- // Remove mapping for op.
- (void)replacePayloadOp(op, nullptr);
- });
-}
-
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
- do {
- if (a->isProperAncestor(b))
- return false;
- if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
- return a->isBeforeInBlock(bAncestor);
- }
- } while ((a = a->getParentOp()));
- return false;
-}
-
-void transform::TrackingListener::notifyOperationReplaced(
- Operation *op, ValueRange newValues) {
- assert(op->getNumResults() == newValues.size() &&
- "invalid number of replacement values");
-
- // Replace value handles.
- for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
- (void)replacePayloadValue(oldValue, newValue);
-
- // Replace op handle.
- SmallVector<Value> opHandles;
- if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
- // Op is not tracked.
- return;
- }
- auto hasAliveUser = [&]() {
- for (Value v : opHandles)
- for (Operation *user : v.getUsers())
- if (!happensBefore(user, transformOp))
- return true;
- return false;
- };
- if (!hasAliveUser()) {
- // The op is tracked but the corresponding handles are dead.
- (void)replacePayloadOp(op, nullptr);
- return;
- }
-
- FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
- // If the op is tracked but no replacement op was found, send a
- // notification.
- if (failed(replacement)) {
- notifyPayloadReplacementNotFound(op, newValues);
- (void)replacePayloadOp(op, nullptr);
- return;
- }
-
- (void)replacePayloadOp(op, *replacement);
-}
-
-transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
- // The state of the ErrorCheckingTrackingListener must be checked and reset
- // if there was an error. This is to prevent errors from accidentally being
- // missed.
- assert(status.succeeded() && "listener state was not checked");
-}
-
-DiagnosedSilenceableFailure
-transform::ErrorCheckingTrackingListener::checkAndResetError() {
- DiagnosedSilenceableFailure s = std::move(status);
- status = DiagnosedSilenceableFailure::success();
- errorCounter = 0;
- return s;
-}
-
-bool transform::ErrorCheckingTrackingListener::failed() const {
- return !status.succeeded();
-}
-
-void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
- Operation *op, ValueRange values) {
- if (status.succeeded()) {
- status = emitSilenceableFailure(
- getTransformOp(), "tracking listener failed to find replacement op");
- }
-
- status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
- for (auto &&[index, value] : llvm::enumerate(values))
- status.attachNote(value.getLoc())
- << "[" << errorCounter << "] replacement value " << index;
-
- ++errorCounter;
-}
-
//===----------------------------------------------------------------------===//
// AlternativesOp
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list