[Mlir-commits] [mlir] 11a9c05 - [mlir] GreedyPatternRewriteDriver: Fix termination criteria in OpPatternRewriteDriver

Matthias Springer llvmlistbot at llvm.org
Wed Jan 18 06:15:41 PST 2023


Author: Matthias Springer
Date: 2023-01-18T15:11:06+01:00
New Revision: 11a9c05bcbbdc42e182aea0c502a74f4bc626b79

URL: https://github.com/llvm/llvm-project/commit/11a9c05bcbbdc42e182aea0c502a74f4bc626b79
DIFF: https://github.com/llvm/llvm-project/commit/11a9c05bcbbdc42e182aea0c502a74f4bc626b79.diff

LOG: [mlir] GreedyPatternRewriteDriver: Fix termination criteria in OpPatternRewriteDriver

This driver should iterate until convergence or until the specified op has been erased. However, it used to stop as soon as any op was erased.

Differential Revision: https://reviews.llvm.org/D141921
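The termination criteria can be modeled without MLIR. The following C++ sketch is a simplified, hypothetical model of the driver loop, not the actual `simplifyLocally` implementation. `Step`, `Outcome` and `runDriverLoop` are illustrative names, and each `Step` summarizes what one round of folding and pattern application is assumed to have done. The loop stops at a fixed point, when the rewrite budget is exhausted, or when the processed op is erased. Setting `stopOnAnyErasure` reproduces the behavior before this fix.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical summary of one round of folding/pattern application on the
// op being processed (the "root").
struct Step {
  bool changed;        // Did the folder or any pattern modify the IR?
  bool erasedRoot;     // Was the op being processed erased?
  bool erasedOtherOp;  // Was some other op (e.g. a dead producer) erased?
};

// Hypothetical result of running the modeled loop.
struct Outcome {
  int64_t rounds;       // Number of rounds that were executed.
  bool reportedErased;  // Whether the driver reported the root as erased.
  bool converged;       // Whether the loop stopped at a fixed point.
};

// Simplified model of the driver loop (an assumption, not MLIR code).
// stopOnAnyErasure == true models the driver before the fix;
// stopOnAnyErasure == false models the fixed driver.
Outcome runDriverLoop(const std::vector<Step> &steps, int64_t maxNumRewrites,
                      bool stopOnAnyErasure) {
  int64_t rounds = 0;
  int64_t numRewrites = 0;
  for (const Step &step : steps) {
    ++rounds;
    if (step.changed)
      ++numRewrites;
    // Before the fix, erasing *any* op set the "erased" flag.
    bool erasedFlag =
        step.erasedRoot || (stopOnAnyErasure && step.erasedOtherOp);
    if (erasedFlag)
      return {rounds, true, true};    // Erasing the root counts as convergence.
    if (!step.changed)
      return {rounds, false, true};   // Fixed point reached.
    if (numRewrites >= maxNumRewrites)
      return {rounds, false, false};  // Budget exhausted before convergence.
  }
  return {rounds, false, false};      // Ran out of modeled steps.
}
```

For the steps `{changed, erased another op}`, `{changed}`, `{no change}`, the old model stops after the first round and wrongly reports the root as erased, while the fixed model runs all three rounds and reports convergence.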

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 6bd3994c43137..b7ea592bfcc7d 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -481,7 +481,8 @@ class OpPatternRewriteDriver : public PatternRewriter {
   /// If an operation is about to be removed, mark it so that we can let clients
   /// know.
   void notifyOperationRemoved(Operation *op) override {
-    opErasedViaPatternRewrites = true;
+    if (this->op == op)
+      opErasedViaPatternRewrites = true;
   }
 
   // When a root is going to be replaced, its removal will be notified as well.
@@ -495,6 +496,9 @@ class OpPatternRewriteDriver : public PatternRewriter {
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
+  /// Op that is being processed.
+  Operation *op = nullptr;
+
   /// Set to true if the operation has been erased via pattern rewrites.
   bool opErasedViaPatternRewrites = false;
 };
@@ -509,6 +513,7 @@ class OpPatternRewriteDriver : public PatternRewriter {
 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
                                                       int64_t maxNumRewrites,
                                                       bool &erased) {
+  this->op = op;
   bool changed = false;
   erased = false;
   opErasedViaPatternRewrites = false;
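The bookkeeping added by the patch can be isolated in a few lines. This is a standalone sketch, assuming a stand-in `Operation` struct whose only relevant property is pointer identity. `OpErasureTracker`, `beginSimplify` and `wasErased` are hypothetical names, while `op`, `opErasedViaPatternRewrites` and `notifyOperationRemoved` mirror the members in the diff.

```cpp
#include <cassert>

// Hypothetical stand-in for mlir::Operation; only pointer identity matters.
struct Operation {};

// Standalone model of the bookkeeping this patch adds to
// OpPatternRewriteDriver. Member names mirror the diff, but this is not the
// real MLIR class and it has no PatternRewriter base.
class OpErasureTracker {
public:
  // Mirrors the start of simplifyLocally: remember the op being processed
  // and clear the flag left over from any previous run.
  void beginSimplify(Operation *root) {
    op = root;
    opErasedViaPatternRewrites = false;
  }

  // Mirrors notifyOperationRemoved: it is invoked for every erased op, so
  // the flag is set only when the erased op is the one being processed.
  void notifyOperationRemoved(Operation *erased) {
    if (op == erased)
      opErasedViaPatternRewrites = true;
  }

  bool wasErased() const { return opErasedViaPatternRewrites; }

private:
  Operation *op = nullptr;  // Op that is being processed.
  bool opErasedViaPatternRewrites = false;
};
```

The tracked pointer has to be assigned before any pattern runs, which is why the patch sets `this->op` at the top of `simplifyLocally`: while `op` is still `nullptr`, no erased op compares equal to it, so erasing the processed op would go unreported.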