[Mlir-commits] [mlir] e8aaf75 - [mlir] specify the values when notifying about op replacement

Alex Zinenko llvmlistbot at llvm.org
Tue Sep 27 09:22:44 PDT 2022


Author: Alex Zinenko
Date: 2022-09-27T16:22:35Z
New Revision: e8aaf75810575e389f5191e69ae3ab387d57f61a

URL: https://github.com/llvm/llvm-project/commit/e8aaf75810575e389f5191e69ae3ab387d57f61a
DIFF: https://github.com/llvm/llvm-project/commit/e8aaf75810575e389f5191e69ae3ab387d57f61a.diff

LOG: [mlir] specify the values when notifying about op replacement

It is useful for PatternRewriter listeners to know which values are
replacing the op, not merely that the op is being replaced, so that they
can keep track of changes or aid debugging.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D134748

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8d62d3109991b..fa21d0c696f36 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -546,9 +546,9 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
   /// they would like to be notified about certain types of mutations.
 
   /// Notify the rewriter that the specified operation is about to be replaced
-  /// with another set of operations. This is called before the uses of the
-  /// operation have been changed.
-  virtual void notifyRootReplaced(Operation *op) {}
+  /// with the set of values potentially produced by new operations. This is
+  /// called before the uses of the operation have been changed.
+  virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {}
 
   /// This is called on an operation that a rewrite is removing, right before
   /// the operation is deleted. At this point, the operation has zero uses.

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 56063d05e0e14..494d90f304bdd 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -216,7 +216,7 @@ void RewriterBase::replaceOpWithIf(
          "incorrect number of values to replace operation");
 
   // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootReplaced(op);
+  notifyRootReplaced(op, newValues);
 
   // Replace each use of the results when the functor is true.
   bool replacedAllUses = true;
@@ -244,7 +244,7 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
 /// the operation.
 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootReplaced(op);
+  notifyRootReplaced(op, newValues);
 
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7305a376449d4..9c62d61fe5291 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -69,7 +69,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   // When the root of a pattern is about to be replaced, it can trigger
   // simplifications to its users - make sure to add them to the worklist
   // before the root is changed.
-  void notifyRootReplaced(Operation *op) override;
+  void notifyRootReplaced(Operation *op, ValueRange replacement) override;
 
   /// PatternRewriter hook for erasing a dead operation.
   void eraseOp(Operation *op) override;
@@ -348,7 +348,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
   });
 }
 
-void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
+void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
+                                                    ValueRange replacement) {
   LLVM_DEBUG({
     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
                        << ")\n";
@@ -437,7 +438,7 @@ class OpPatternRewriteDriver : public PatternRewriter {
 
   // When a root is going to be replaced, its removal will be notified as well.
   // So there is nothing to do here.
-  void notifyRootReplaced(Operation *op) override {}
+  void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
 
 private:
   /// The low-level pattern applicator.


        


More information about the Mlir-commits mailing list