[Mlir-commits] [mlir] 774416b - [mlir] GreedyPatternRewriteDriver: Keep track of surviving ops

Matthias Springer llvmlistbot at llvm.org
Thu Jan 26 00:22:00 PST 2023


Author: Matthias Springer
Date: 2023-01-26T09:21:51+01:00
New Revision: 774416bdb396e79e19d9c7ed49a49203c841c7c3

URL: https://github.com/llvm/llvm-project/commit/774416bdb396e79e19d9c7ed49a49203c841c7c3
DIFF: https://github.com/llvm/llvm-project/commit/774416bdb396e79e19d9c7ed49a49203c841c7c3.diff

LOG: [mlir] GreedyPatternRewriteDriver: Keep track of surviving ops

This change adds `allErased` to the `applyOpPatternsAndFold(ArrayRef<Operation *>, ...)` overload. This overload now supports all functionality that is also supported by `applyOpPatternsAndFold(Operation *, ...)` and can be used as a replacement.

This change has no performance implications when `allErased = nullptr`.

The single-operation overload is removed in a subsequent NFC change.

Differential Revision: https://reviews.llvm.org/D141920

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/Transforms/test-strict-pattern-driver.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 72b24754fac27..f72dbb7ff2986 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -118,10 +118,12 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
 ///
 /// Returns success if the iterative process converged and no more patterns can
 /// be matched. `changed` is set to true if the IR was modified at all.
+/// `allErased` is set to true if all ops in `ops` were erased.
 LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                                      const FrozenRewritePatternSet &patterns,
                                      GreedyRewriteStrictness strictMode,
-                                     bool *changed = nullptr);
+                                     bool *changed = nullptr,
+                                     bool *allErased = nullptr);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 56a94663e43e9..2b3a796dee93f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ScopedPrinter.h"
@@ -584,8 +585,11 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   ///
   /// Note that ops in `ops` could be erased as a result of folding, becoming
   /// dead, or via pattern rewrites. The return value indicates convergence.
-  LogicalResult simplifyLocally(ArrayRef<Operation *> op,
-                                bool *changed = nullptr);
+  ///
+  /// If `surviving` is non-null, the `ops` that survived are stored in it.
+  LogicalResult
+  simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
+                  llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr);
 
   void addToWorklist(Operation *op) override {
     if (strictMode == GreedyRewriteStrictness::AnyOp ||
@@ -602,6 +606,8 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 
   void notifyOperationRemoved(Operation *op) override {
     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
+    if (survivingOps)
+      survivingOps->erase(op);
     if (strictMode != GreedyRewriteStrictness::AnyOp)
       strictModeFilteredOps.erase(op);
   }
@@ -615,13 +621,25 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   /// depending on `strictMode`. This set is not maintained when `strictMode`
   /// is GreedyRewriteStrictness::AnyOp.
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+
+  /// An optional set of ops that survived the rewrite. This set is populated
+  /// at the beginning of `simplifyLocally` with the initially provided list
+  /// of ops.
+  llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr;
 };
 
 } // namespace
 
-LogicalResult
-MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
-                                             bool *changed) {
+LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
+    ArrayRef<Operation *> ops, bool *changed,
+    llvm::SmallDenseSet<Operation *, 4> *surviving) {
+  auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
+  if (surviving) {
+    survivingOps = surviving;
+    survivingOps->clear();
+    survivingOps->insert(ops.begin(), ops.end());
+  }
+
   if (strictMode != GreedyRewriteStrictness::AnyOp) {
     strictModeFilteredOps.clear();
     strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -726,15 +744,22 @@ LogicalResult mlir::applyOpPatternsAndFold(
 
 LogicalResult mlir::applyOpPatternsAndFold(
     ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteStrictness strictMode, bool *changed) {
+    GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
+    if (allErased)
+      *allErased = true;
     return success();
   }
 
   // Start the pattern driver.
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
                                      strictMode);
-  return driver.simplifyLocally(ops, changed);
+  llvm::SmallDenseSet<Operation *, 4> surviving;
+  LogicalResult converged =
+      driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr);
+  if (allErased)
+    *allErased = surviving.empty();
+  return converged;
 }

diff  --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index ad6f6a5f01e20..9dbaea18967f8 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -7,6 +7,7 @@
 // RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX
 
 // CHECK-EN-LABEL: func @test_erase
+//  CHECK-EN-SAME:     pattern_driver_all_erased = true, pattern_driver_changed = true}
 //       CHECK-EN:   test.arg0
 //       CHECK-EN:   test.arg1
 //   CHECK-EN-NOT:   test.erase_op
@@ -20,6 +21,7 @@ func.func @test_erase() {
 // -----
 
 // CHECK-EN-LABEL: func @test_insert_same_op
+//  CHECK-EN-SAME:     {pattern_driver_all_erased = false, pattern_driver_changed = true}
 //       CHECK-EN:   "test.insert_same_op"() {skip = true}
 //       CHECK-EN:   "test.insert_same_op"() {skip = true}
 func.func @test_insert_same_op() {
@@ -30,6 +32,7 @@ func.func @test_insert_same_op() {
 // -----
 
 // CHECK-EN-LABEL: func @test_replace_with_new_op
+//  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //       CHECK-EN:   %[[n:.*]] = "test.new_op"
 //       CHECK-EN:   "test.dummy_user"(%[[n]])
 //       CHECK-EN:   "test.dummy_user"(%[[n]])
@@ -43,10 +46,12 @@ func.func @test_replace_with_new_op() {
 // -----
 
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
+//  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //   CHECK-EN-NOT:   test.replace_with_new_op
 //   CHECK-EN-NOT:   test.erase_op
 
 // CHECK-EX-LABEL: func @test_replace_with_erase_op
+//  CHECK-EX-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //   CHECK-EX-NOT:   test.replace_with_new_op
 //       CHECK-EX:   test.erase_op
 func.func @test_replace_with_erase_op() {

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 726d8100c4c65..98896c736a3cb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -254,8 +254,9 @@ struct TestStrictPatternDriver
   }
 
   void runOnOperation() override {
-    mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(&getContext());
+    MLIRContext *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
@@ -279,7 +280,14 @@ struct TestStrictPatternDriver
     // Check if these transformations introduce visiting of operations that
     // are not in the `ops` set (The new created ops are valid). An invalid
     // operation will trigger the assertion while processing.
-    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode);
+    bool changed = false;
+    bool allErased = false;
+    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
+                                 &changed, &allErased);
+    Builder b(ctx);
+    getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
+    getOperation()->setAttr("pattern_driver_all_erased",
+                            b.getBoolAttr(allErased));
   }
 
   Option<std::string> strictMode{


        


More information about the Mlir-commits mailing list