[Mlir-commits] [mlir] 7d436d5 - [mlir][transform] TrackingListener: Allow existing ops as replacements

Matthias Springer llvmlistbot at llvm.org
Fri May 12 06:11:55 PDT 2023


Author: Matthias Springer
Date: 2023-05-12T15:07:20+02:00
New Revision: 7d436d56b60b36508b94e39d08761f1405a9c770

URL: https://github.com/llvm/llvm-project/commit/7d436d56b60b36508b94e39d08761f1405a9c770
DIFF: https://github.com/llvm/llvm-project/commit/7d436d56b60b36508b94e39d08761f1405a9c770.diff

LOG: [mlir][transform] TrackingListener: Allow existing ops as replacements

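The TrackingListener was unnecessarily strict: a replacement op was only accepted if it had been newly created during the transform. Existing ops are now also allowed as replacements when the TrackingListener updates payload op mappings in response to `replaceOp`.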

Differential Revision: https://reviews.llvm.org/D150429

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index b6bc094d8ba55..7a0f80200cc47 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -48,8 +48,8 @@ class TrackingListener : public RewriterBase::Listener,
 protected:
   /// Return a replacement payload op for the given op, which is going to be
   /// replaced with the given values. By default, if all values are defined by
-  /// the same newly-created op, which also has the same type as the given op,
-  /// that defining op is used as a replacement.
+  /// the same op, which also has the same type as the given op, that defining
+  /// op is used as a replacement.
   virtual Operation *findReplacementOp(Operation *op,
                                        ValueRange newValues) const;
 
@@ -66,22 +66,14 @@ class TrackingListener : public RewriterBase::Listener,
   virtual void notifyPayloadReplacementNotFound(Operation *op,
                                                 ValueRange values) {}
 
-  /// Return "true" if the given op is a new op.
-  bool isNewOp(Operation *op) const;
-
   /// Return the single op that defines all given values (if any).
   static Operation *getCommonDefiningOp(ValueRange values);
 
 private:
-  void notifyOperationInserted(Operation *op) override;
-
   void notifyOperationRemoved(Operation *op) override;
 
   void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
 
-  /// Ops that were newly created during the transform.
-  DenseMap<OperationName, DenseSet<Operation *>> newOps;
-
   /// The transform op in which this TrackingListener is used.
   TransformOpInterface transformOp;
 };

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7451193d51fd0..ecd5d2a915ab6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -180,20 +180,9 @@ transform::TrackingListener::findReplacementOp(Operation *op,
   if (op->getName() != defOp->getName())
     return nullptr;
 
-  // If the replacement op is not a new op, drop the mapping.
-  if (!isNewOp(defOp))
-    return nullptr;
-
   return defOp;
 }
 
-bool transform::TrackingListener::isNewOp(Operation *op) const {
-  auto it = newOps.find(op->getName());
-  if (it == newOps.end())
-    return false;
-  return it->second.contains(op);
-}
-
 LogicalResult transform::TrackingListener::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
@@ -204,17 +193,9 @@ LogicalResult transform::TrackingListener::notifyMatchFailure(
   return failure();
 }
 
-void transform::TrackingListener::notifyOperationInserted(Operation *op) {
-  newOps[op->getName()].insert(op);
-}
-
 void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
   // TODO: Walk can be removed when D144193 has landed.
   op->walk([&](Operation *op) {
-    // Keep set of new ops up-to-date.
-    auto it = newOps.find(op->getName());
-    if (it != newOps.end())
-      it->second.erase(op);
     // Remove mappings for result values.
     for (OpResult value : op->getResults())
       (void)replacePayloadValue(value, nullptr);


        


More information about the Mlir-commits mailing list