[Mlir-commits] [mlir] 905e932 - [mlir][transform] TrackingListener: Drop mappings of tracked ops when all handles are dead
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 11 23:56:30 PDT 2023
Author: Matthias Springer
Date: 2023-04-12T15:56:22+09:00
New Revision: 905e93244187a88614fc866b6089ecc3f7b16105
URL: https://github.com/llvm/llvm-project/commit/905e93244187a88614fc866b6089ecc3f7b16105
DIFF: https://github.com/llvm/llvm-project/commit/905e93244187a88614fc866b6089ecc3f7b16105.diff
LOG: [mlir][transform] TrackingListener: Drop mappings of tracked ops when all handles are dead
No replacement ops are needed for tracked ops whose handles are all dead; their mappings are dropped instead.
Differential Revision: https://reviews.llvm.org/D147510
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
index 69fcbd8ae3ea2..d1b14d206cd77 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
@@ -22,8 +22,7 @@ namespace tensor {
/// replacements.
class TrackingListener : public transform::TrackingListener {
public:
- explicit TrackingListener(transform::TransformState &state)
- : transform::TrackingListener(state) {}
+ using transform::TrackingListener::TrackingListener;
protected:
Operation *findReplacementOp(Operation *op,
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 58e9d8506808a..af812a735b77f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -39,8 +39,9 @@ using SequenceBodyBuilderArgsFn =
class TrackingListener : public RewriterBase::Listener,
public TransformState::Extension {
public:
- explicit TrackingListener(TransformState &state)
- : TransformState::Extension(state) {}
+ /// Create a new TrackingListener for usage in the specified transform op.
+ explicit TrackingListener(TransformState &state, TransformOpInterface op)
+ : TransformState::Extension(state), transformOp(op) {}
protected:
/// Return a replacement payload op for the given op, which is going to be
@@ -78,6 +79,9 @@ class TrackingListener : public RewriterBase::Listener,
/// Ops that were newly created during the transform.
DenseMap<OperationName, DenseSet<Operation *>> newOps;
+
+ /// The transform op in which this TrackingListener is used.
+ TransformOpInterface transformOp;
};
} // namespace transform
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1f4ac08f1259e..a92ffa196a40f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1818,7 +1818,7 @@ transform::HoistPadOp::applyToOne(tensor::PadOp target,
transform::TransformState &state) {
tensor::PadOp hoistedPadOp;
SmallVector<GenericOp> transposeOps;
- TrackingListener listener(state);
+ TrackingListener listener(state, *this);
IRRewriter rewriter(target->getContext(), &listener);
FailureOr<Value> result =
hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
@@ -3068,7 +3068,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
if (getVectorizePadding())
linalg::populatePadOpVectorizationPatterns(patterns);
- TrackingListener listener(state);
+ TrackingListener listener(state, *this);
GreedyRewriteConfig config;
config.listener = &listener;
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 15583315cc9b3..27b00cf6c6c9c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -210,6 +210,19 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
});
}
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// precedes, within a shared block, the op in that block that contains `b`,
+/// and `b` is not nested inside `a` or any ancestor examined along the way.
+///
+/// If `a` and `b` are only related through different blocks or regions of a
+/// common ancestor op, no ordering is established and the result is false.
+static bool happensBefore(Operation *a, Operation *b) {
+  do {
+    // `b` is nested inside the current `a`, so it cannot come after it.
+    if (a->isProperAncestor(b))
+      return false;
+    // NOTE(review): a->getBlock() is dereferenced unconditionally; this
+    // presumably assumes every visited op sits in a block - confirm.
+    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+      // `a` and the ancestor of `b` share a block: compare their positions.
+      return a->isBeforeInBlock(bAncestor);
+    }
+  } while ((a = a->getParentOp())); // No shared block yet: climb one level.
+  // Ran out of ancestors without finding a block that contains `b`.
+  return false;
+}
+
void transform::TrackingListener::notifyOperationReplaced(
Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
@@ -220,13 +233,30 @@ void transform::TrackingListener::notifyOperationReplaced(
(void)replacePayloadValue(oldValue, newValue);
// Replace op handle.
- Operation *replacement = findReplacementOp(op, newValues);
- if (succeeded(replacePayloadOp(op, replacement))) {
- // If the op is tracked but no replacement op was found, send a
- // notification.
- if (!replacement)
- notifyPayloadReplacementNotFound(op, newValues);
+ SmallVector<Value> opHandles;
+ if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
+ // Op is not tracked.
+ return;
+ }
+ auto hasAliveUser = [&]() {
+ for (Value v : opHandles)
+ for (Operation *user : v.getUsers())
+ if (!happensBefore(user, transformOp))
+ return true;
+ return false;
+ };
+ if (!hasAliveUser()) {
+ // The op is tracked but the corresponding handles are dead.
+ (void)replacePayloadOp(op, nullptr);
+ return;
}
+
+ Operation *replacement = findReplacementOp(op, newValues);
+ // If the op is tracked but no replacement op was found, send a
+ // notification.
+ if (!replacement)
+ notifyPayloadReplacementNotFound(op, newValues);
+ (void)replacePayloadOp(op, replacement);
}
//===----------------------------------------------------------------------===//
@@ -349,7 +379,7 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
if (!failed) {
// We will be using the clones, so cancel their scheduled deletion.
deleteClones.release();
- TrackingListener listener(state);
+ TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
for (const auto &kvp : llvm::zip(originals, clones)) {
Operation *original = std::get<0>(kvp);
More information about the Mlir-commits
mailing list