[Mlir-commits] [mlir] 905e932 - [mlir][transform] TrackingListener: Drop mappings of tracked ops when all handles are dead

Matthias Springer llvmlistbot at llvm.org
Tue Apr 11 23:56:30 PDT 2023


Author: Matthias Springer
Date: 2023-04-12T15:56:22+09:00
New Revision: 905e93244187a88614fc866b6089ecc3f7b16105

URL: https://github.com/llvm/llvm-project/commit/905e93244187a88614fc866b6089ecc3f7b16105
DIFF: https://github.com/llvm/llvm-project/commit/905e93244187a88614fc866b6089ecc3f7b16105.diff

LOG: [mlir][transform] TrackingListener: Drop mappings of tracked ops when all handles are dead

No replacement ops are needed for tracked ops whose handles are all dead; their mappings are dropped instead.

Differential Revision: https://reviews.llvm.org/D147510

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
index 69fcbd8ae3ea2..d1b14d206cd77 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
@@ -22,8 +22,7 @@ namespace tensor {
 /// replacements.
 class TrackingListener : public transform::TrackingListener {
 public:
-  explicit TrackingListener(transform::TransformState &state)
-      : transform::TrackingListener(state) {}
+  using transform::TrackingListener::TrackingListener;
 
 protected:
   Operation *findReplacementOp(Operation *op,

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 58e9d8506808a..af812a735b77f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -39,8 +39,9 @@ using SequenceBodyBuilderArgsFn =
 class TrackingListener : public RewriterBase::Listener,
                          public TransformState::Extension {
 public:
-  explicit TrackingListener(TransformState &state)
-      : TransformState::Extension(state) {}
+  /// Create a new TrackingListener for usage in the specified transform op.
+  explicit TrackingListener(TransformState &state, TransformOpInterface op)
+      : TransformState::Extension(state), transformOp(op) {}
 
 protected:
   /// Return a replacement payload op for the given op, which is going to be
@@ -78,6 +79,9 @@ class TrackingListener : public RewriterBase::Listener,
 
   /// Ops that were newly created during the transform.
   DenseMap<OperationName, DenseSet<Operation *>> newOps;
+
+  /// The transform op in which this TrackingListener is used.
+  TransformOpInterface transformOp;
 };
 
 } // namespace transform

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1f4ac08f1259e..a92ffa196a40f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1818,7 +1818,7 @@ transform::HoistPadOp::applyToOne(tensor::PadOp target,
                                   transform::TransformState &state) {
   tensor::PadOp hoistedPadOp;
   SmallVector<GenericOp> transposeOps;
-  TrackingListener listener(state);
+  TrackingListener listener(state, *this);
   IRRewriter rewriter(target->getContext(), &listener);
   FailureOr<Value> result =
       hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
@@ -3068,7 +3068,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);
 
-  TrackingListener listener(state);
+  TrackingListener listener(state, *this);
   GreedyRewriteConfig config;
   config.listener = &listener;
   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 15583315cc9b3..27b00cf6c6c9c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -210,6 +210,19 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
   });
 }
 
+/// Return true if `a` happens before `b`: `b` is not nested in `a`, and in the
+/// innermost block holding ancestors-or-self of both, `a`'s one comes first.
+static bool happensBefore(Operation *a, Operation *b) {
+  do {
+    if (a->isProperAncestor(b))
+      return false; // `b` is nested inside (an ancestor of) the original `a`.
+    if (Block *block = a->getBlock()) // Null for top-level / detached ops.
+      if (Operation *bAncestor = block->findAncestorOpInBlock(*b))
+        return a->isBeforeInBlock(bAncestor);
+  } while ((a = a->getParentOp()));
+  return false;
+}
+
 void transform::TrackingListener::notifyOperationReplaced(
     Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
@@ -220,13 +233,30 @@ void transform::TrackingListener::notifyOperationReplaced(
     (void)replacePayloadValue(oldValue, newValue);
 
   // Replace op handle.
-  Operation *replacement = findReplacementOp(op, newValues);
-  if (succeeded(replacePayloadOp(op, replacement))) {
-    // If the op is tracked but no replacement op was found, send a
-    // notification.
-    if (!replacement)
-      notifyPayloadReplacementNotFound(op, newValues);
+  SmallVector<Value> opHandles;
+  if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
+    // Op is not tracked.
+    return;
+  }
+  auto hasAliveUser = [&]() {
+    for (Value v : opHandles)
+      for (Operation *user : v.getUsers())
+        if (!happensBefore(user, transformOp))
+          return true;
+    return false;
+  };
+  if (!hasAliveUser()) {
+    // The op is tracked but the corresponding handles are dead.
+    (void)replacePayloadOp(op, nullptr);
+    return;
   }
+
+  Operation *replacement = findReplacementOp(op, newValues);
+  // If the op is tracked but no replacement op was found, send a
+  // notification.
+  if (!replacement)
+    notifyPayloadReplacementNotFound(op, newValues);
+  (void)replacePayloadOp(op, replacement);
 }
 
 //===----------------------------------------------------------------------===//
@@ -349,7 +379,7 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
     if (!failed) {
       // We will be using the clones, so cancel their scheduled deletion.
       deleteClones.release();
-      TrackingListener listener(state);
+      TrackingListener listener(state, *this);
       IRRewriter rewriter(getContext(), &listener);
       for (const auto &kvp : llvm::zip(originals, clones)) {
         Operation *original = std::get<0>(kvp);


        


More information about the Mlir-commits mailing list