[Mlir-commits] [mlir] 572b171 - [mlir][transform] TrackingListener: Distinguish between failure and "should be dropped"
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 9 02:40:40 PDT 2023
Author: Matthias Springer
Date: 2023-06-09T11:40:32+02:00
New Revision: 572b171fb58b58f4c3f95dfde495dd81059e1451
URL: https://github.com/llvm/llvm-project/commit/572b171fb58b58f4c3f95dfde495dd81059e1451
DIFF: https://github.com/llvm/llvm-project/commit/572b171fb58b58f4c3f95dfde495dd81059e1451.diff
LOG: [mlir][transform] TrackingListener: Distinguish between failure and "should be dropped"
When looking for replacement ops (`findReplacementOp`), distinguish between "no replacement could be found" and "this op should be dropped from the mapping". The latter case will be utilized in a subsequent revision when a payload op is mapped to a consumed handle.
Differential Revision: https://reviews.llvm.org/D152375
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 5625d193b1508..d5fcdb31ab0c8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -53,6 +53,11 @@ class TrackingListener : public RewriterBase::Listener,
/// the same op, which also has the same type as the given op, that defining
/// op is used as a replacement.
///
+ /// A "failure" return value indicates that no replacement operation could be
+ /// found. A "nullptr" return value indicates that no replacement op is needed
+ /// (e.g., handle is dead or was consumed) and that the payload op should
+ /// be dropped from the mapping.
+ ///
/// Example: A tracked "linalg.generic" with two results is replaced with two
/// values defined by (another) "linalg.generic". It is reasonable to assume
/// that the replacement "linalg.generic" represents the same "computation".
@@ -91,8 +96,8 @@ class TrackingListener : public RewriterBase::Listener,
///
/// Derived classes may override `findReplacementOp` to specify custom
/// replacement rules.
- virtual Operation *findReplacementOp(Operation *op,
- ValueRange newValues) const;
+ virtual FailureOr<Operation *> findReplacementOp(Operation *op,
+ ValueRange newValues) const;
/// Notify the listener that the pattern failed to match the given operation,
/// and provide a callback to populate a diagnostic with the reason why the
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 69d0b2b20307b..a64b9d7aa365d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -70,7 +70,7 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
return defOp;
}
-Operation *
+FailureOr<Operation *>
transform::TrackingListener::findReplacementOp(Operation *op,
ValueRange newValues) const {
assert(op->getNumResults() == newValues.size() &&
@@ -81,7 +81,7 @@ transform::TrackingListener::findReplacementOp(Operation *op,
// If the replacement values belong to different ops, drop the mapping.
Operation *defOp = getCommonDefiningOp(values);
if (!defOp)
- return nullptr;
+ return failure();
// If the defining op has the same type, we take it as a replacement.
if (op->getName() == defOp->getName())
@@ -108,7 +108,7 @@ transform::TrackingListener::findReplacementOp(Operation *op,
}
} while (!values.empty());
- return nullptr;
+ return failure();
}
LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -173,12 +173,16 @@ void transform::TrackingListener::notifyOperationReplaced(
return;
}
- Operation *replacement = findReplacementOp(op, newValues);
+ FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
// If the op is tracked but no replacement op was found, send a
// notification.
- if (!replacement)
+ if (failed(replacement)) {
notifyPayloadReplacementNotFound(op, newValues);
- (void)replacePayloadOp(op, replacement);
+ (void)replacePayloadOp(op, nullptr);
+ return;
+ }
+
+ (void)replacePayloadOp(op, *replacement);
}
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index c0a1348c8010a..c6aad2c8fdb35 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -302,7 +302,10 @@ class DummyTrackingListener : public transform::TrackingListener {
// Expose `findReplacementOp` as a public function, so that it can be tested.
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
- return findReplacementOp(op, newValues);
+ FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
+ if (failed(replacementOp))
+ return nullptr;
+ return *replacementOp;
}
};
} // namespace
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index fc6f323a3b557..0c3697d1171ff 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -696,15 +696,15 @@ class TestTrackingListener : public transform::TrackingListener {
using transform::TrackingListener::TrackingListener;
protected:
- Operation *findReplacementOp(Operation *op,
- ValueRange newValues) const override {
+ FailureOr<Operation *>
+ findReplacementOp(Operation *op, ValueRange newValues) const override {
if (newValues.size() != 1)
- return nullptr;
+ return failure();
Operation *replacement = newValues[0].getDefiningOp();
if (!replacement)
- return nullptr;
+ return failure();
if (replacement->getName().getStringRef() != "test.update_mapping")
- return nullptr;
+ return failure();
return replacement;
}
};
More information about the Mlir-commits
mailing list