[Mlir-commits] [mlir] 244a4e7 - [mlir][transform] Add optional error checking to TrackingListener
Matthias Springer
llvmlistbot at llvm.org
Thu Mar 30 02:18:27 PDT 2023
Author: Matthias Springer
Date: 2023-03-30T11:10:38+02:00
New Revision: 244a4e75aee42126209a6950fa12b345af5edf3d
URL: https://github.com/llvm/llvm-project/commit/244a4e75aee42126209a6950fa12b345af5edf3d
DIFF: https://github.com/llvm/llvm-project/commit/244a4e75aee42126209a6950fa12b345af5edf3d.diff
LOG: [mlir][transform] Add optional error checking to TrackingListener
Derived classes can implement `notifyPayloadReplacementNotFound` for custom error checking.
Differential Revision: https://reviews.llvm.org/D147206
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index f0f209f071a1b..eb55f7f133e91 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -50,6 +50,12 @@ class TrackingListener : public RewriterBase::Listener,
virtual Operation *findReplacementOp(Operation *op,
ValueRange newValues) const;
+ /// This function is called when a tracked payload op is dropped because no
+ /// replacement op was found. Derived classes can implement this function for
+ /// custom error handling.
+ virtual void notifyPayloadReplacementNotFound(Operation *op,
+ ValueRange values) const {}
+
/// Return "true" if the given op is a new op.
bool isNewOp(Operation *op) const;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 63df020f2eb79..a37822d9d0998 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -204,15 +204,19 @@ void transform::TrackingListener::notifyOperationReplaced(
Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
"invalid number of replacement values");
- if (op->getNumResults() == 0)
- return;
// Replace value handles.
for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
(void)replacePayloadValue(oldValue, newValue);
// Replace op handle.
- (void)replacePayloadOp(op, findReplacementOp(op, newValues));
+ Operation *replacement = findReplacementOp(op, newValues);
+ if (succeeded(replacePayloadOp(op, replacement))) {
+ // If the op is tracked but no replacement op was found, send a
+ // notification.
+ if (!replacement)
+ notifyPayloadReplacementNotFound(op, newValues);
+ }
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list.