[Mlir-commits] [mlir] [mlir] Add fast walk-based pattern rewrite driver (PR #113825)

Jakub Kuderski llvmlistbot at llvm.org
Thu Oct 31 08:09:03 PDT 2024


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/113825

>From 883902df7f742e7e6f5f6f32e44743f58793bdfc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 13:45:09 -0400
Subject: [PATCH 01/12] [mlir] Add walk pattern rewrite driver

---
 mlir/docs/PatternRewriter.md                  |  32 ++++-
 .../Transforms/WalkPatternRewriteDriver.h     |  37 +++++
 .../Transforms/UnsignedWhenEquivalent.cpp     |  12 +-
 mlir/lib/Transforms/Utils/CMakeLists.txt      |   1 +
 .../Utils/WalkPatternRewriteDriver.cpp        |  86 ++++++++++++
 mlir/test/IR/enum-attr-roundtrip.mlir         |   2 +-
 ...eedy-pattern-rewrite-driver-bottom-up.mlir |   2 +-
 ...reedy-pattern-rewrite-driver-top-down.mlir |   2 +-
 .../IR/test-walk-pattern-rewrite-driver.mlir  | 107 ++++++++++++++
 .../test-operation-folder-commutative.mlir    |   2 +-
 .../Transforms/test-operation-folder.mlir     |   4 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 132 +++++++++++++-----
 mlir/test/mlir-tblgen/pattern.mlir            |   2 +-
 13 files changed, 366 insertions(+), 55 deletions(-)
 create mode 100644 mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
 create mode 100644 mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
 create mode 100644 mlir/test/IR/test-walk-pattern-rewrite-driver.mlir

diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 0ba76199874cc3..1d036dc8ce0f6a 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -320,15 +320,33 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
 framework also provides support for type conversions. More information on this
 driver can be found [here](DialectConversion.md).
 
+### Walk Pattern Rewrite Driver
+
+This is a fast and simple driver that walks the ops nested under the given op
+and applies the patterns that locally have the most benefit. The benefit of a
+pattern is decided solely by the benefit specified on the pattern and, when two
+patterns have the same benefit, by their relative order in the pattern list.
+
+This driver does not (re)visit modified or newly replaced ops, and does not
+allow for progressive rewrites of the same op. Op erasure is only supported for
+the currently matched op. If your pattern set requires these, consider using the
+Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
+
+This driver is exposed using the `walkAndApplyPatterns` function.
+
+#### Debugging
+
+You can debug the Walk Pattern Rewrite Driver by passing the
+`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
+ops.
+
 ### Greedy Pattern Rewrite Driver
 
 This driver processes ops in a worklist-driven fashion and greedily applies the
-patterns that locally have the most benefit. The benefit of a pattern is decided
-solely by the benefit specified on the pattern, and the relative order of the
-pattern within the pattern list (when two patterns have the same local benefit).
-Patterns are iteratively applied to operations until a fixed point is reached or
-until the configurable maximum number of iterations exhausted, at which point
-the driver finishes.
+patterns that locally have the most benefit (same as the Walk Pattern Rewrite
+Driver). Patterns are iteratively applied to operations until a fixed point is
+reached or until the configurable maximum number of iterations is exhausted,
+at which point the driver finishes.
 
 This driver comes in two fashions:
 
@@ -368,7 +386,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
 Note: This driver is the one used by the [canonicalization](Canonicalization.md)
 [pass](Passes.md/#-canonicalize) in MLIR.
 
-### Debugging
+#### Debugging
 
 To debug the execution of the greedy pattern rewrite driver,
 `-debug-only=greedy-rewriter` may be used. This command line flag activates
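The selection rule described in the new docs (highest benefit first, ties broken by position in the pattern list) can be sketched with a stable sort. This is a toy model, not MLIR's `PatternApplicator`; `ToyPattern` and `selectPattern` are hypothetical names, and the sketch assumes that a pattern which fails to match simply falls through to the next candidate.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy model (not MLIR's PatternApplicator): each pattern has a benefit and a
// match predicate over an op name.
struct ToyPattern {
  std::string name;
  int benefit;
  std::function<bool(const std::string &opName)> matches;
};

// Returns the name of the pattern that would be applied to `opName`, or an
// empty string if no pattern matches. Candidates are tried in decreasing
// benefit order; std::stable_sort keeps list order among equal benefits.
std::string selectPattern(std::vector<ToyPattern> patterns,
                          const std::string &opName) {
  assert(!opName.empty() && "expected an op name");
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &p : patterns)
    if (p.matches(opName))
      return p.name;
  return "";
}
```

With two equal-benefit patterns that both match, the one listed first wins. A higher-benefit pattern that does not match the op is skipped without affecting that tie-break.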
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
new file mode 100644
index 00000000000000..6d62ae3dd43dc1
--- /dev/null
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -0,0 +1,37 @@
+//===- WalkPatternRewriteDriver.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares a helper function to walk the given op and apply rewrite patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+
+namespace mlir {
+
+/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
+/// given operation by walking it and applying the highest benefit patterns.
+/// This rewriter *does not* wait until a fixpoint is reached, *does not*
+/// visit modified or newly replaced ops, and *does not* perform folding or
+/// dead-code elimination.
+///
+/// This is intended as the simplest and most lightweight pattern rewriter in
+/// cases when a simple walk gets the job done.
+///
+/// Note: Does not apply patterns to the given operation itself.
+void walkAndApplyPatterns(Operation *op,
+                          const FrozenRewritePatternSet &patterns,
+                          RewriterBase::Listener *listener = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
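The header comment contrasts a single walk with a fixpoint-driven rewrite. The toy below makes that difference concrete on plain integers. It is not MLIR code: the "pattern" `halveIfEven` and both driver functions are hypothetical, and real greedy rewriting also folds and erases dead ops, which this sketch ignores.

```cpp
#include <cassert>
#include <vector>

// Toy "pattern": rewrite a non-zero even value x to x / 2. Returns true if it
// fired.
bool halveIfEven(int &value) {
  if (value == 0 || value % 2 != 0)
    return false;
  value /= 2;
  return true;
}

// Walk-style driver (toy): visit each element once and apply at most one
// rewrite to it; rewritten values are not revisited.
void applyOnce(std::vector<int> &ops) {
  for (int &op : ops)
    halveIfEven(op);
}

// Greedy-style driver (toy): sweep repeatedly until no rewrite fires (a
// fixpoint) or the iteration budget is exhausted. Returns true on convergence.
bool applyToFixpoint(std::vector<int> &ops, int maxIterations) {
  assert(maxIterations > 0 && "expected a positive iteration budget");
  for (int iter = 0; iter < maxIterations; ++iter) {
    bool changed = false;
    for (int &op : ops)
      changed |= halveIfEven(op);
    if (!changed)
      return true;
  }
  return false;
}
```

Starting from `{8, 3, 6}`, one walk yields `{4, 3, 3}`, while the fixpoint loop keeps going until it reaches `{1, 3, 3}`. The single pass is cheaper, but only correct when one application per op is enough.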
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index bebe0b5a7c0b61..ad455aaf987fc6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -14,7 +14,11 @@
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace arith {
@@ -157,11 +161,7 @@ struct ArithUnsignedWhenEquivalentPass
     RewritePatternSet patterns(ctx);
     populateUnsignedWhenEquivalentPatterns(patterns, solver);
 
-    GreedyRewriteConfig config;
-    config.listener = &listener;
-
-    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-      signalPassFailure();
+    walkAndApplyPatterns(op, std::move(patterns), &listener);
   }
 };
 } // end anonymous namespace
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index eb588640dbf83a..72eb34f36cf5f6 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
   LoopInvariantCodeMotionUtils.cpp
   OneToNTypeConversion.cpp
   RegionUtils.cpp
+  WalkPatternRewriteDriver.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
new file mode 100644
index 00000000000000..2d3aa5fc7d15c7
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -0,0 +1,86 @@
+//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements mlir::walkAndApplyPatterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "walk-rewriter"
+
+namespace mlir {
+
+namespace {
+// Forwarding listener to guard against unsupported erasures. Because we use
+// walk-based pattern application, erasing the op from the *next* iteration
+// (e.g., a user of the visited op) is not valid.
+struct ErasedOpsListener final : RewriterBase::ForwardingListener {
+  using RewriterBase::ForwardingListener::ForwardingListener;
+
+  void notifyOperationErased(Operation *op) override {
+    if (op != visitedOp)
+      llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
+                               "erasure is only supported for matched ops");
+
+    ForwardingListener::notifyOperationErased(op);
+  }
+
+  Operation *visitedOp = nullptr;
+};
+} // namespace
+
+void walkAndApplyPatterns(Operation *op,
+                          const FrozenRewritePatternSet &patterns,
+                          RewriterBase::Listener *listener) {
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (failed(verify(op)))
+    llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+  PatternRewriter rewriter(op->getContext());
+  ErasedOpsListener erasedListener(listener);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  rewriter.setListener(&erasedListener);
+#else
+  (void)erasedListener;
+  rewriter.setListener(listener);
+#endif
+
+  PatternApplicator applicator(patterns);
+  applicator.applyDefaultCostModel();
+
+  op->walk([&](Operation *visitedOp) {
+    if (visitedOp == op)
+      return;
+
+    LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
+               visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
+               llvm::dbgs() << "\n";);
+    erasedListener.visitedOp = visitedOp;
+    if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+      LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+    }
+  });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (failed(verify(op)))
+    llvm::report_fatal_error(
+        "walk pattern rewriter result IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+}
+
+} // namespace mlir
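The `ErasedOpsListener` in this patch records the op currently being visited and rejects erasure notifications for any other op. Below is a standalone sketch of that guard. `ToyOp` and `ToyErasedOpsGuard` are hypothetical stand-ins, and the sketch throws `std::logic_error` where the real listener calls `llvm::report_fatal_error`.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Toy stand-in for an operation; only its identity and name matter here.
struct ToyOp {
  std::string name;
};

// Toy version of the erasure guard: the driver sets `visitedOp` before
// matching, and any erasure notification for a different op is rejected.
struct ToyErasedOpsGuard {
  const ToyOp *visitedOp = nullptr;

  void notifyOperationErased(const ToyOp *op) const {
    assert(op && "expected a non-null op");
    if (op != visitedOp)
      throw std::logic_error("unsupported op erased: " + op->name +
                             "; erasure is only supported for matched ops");
  }
};
```

The design relies on the driver updating `visitedOp` before every `matchAndRewrite` call, as the walk callback above does. In this patch the guard is only installed when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is set.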
diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 0b4d379cfb7d5f..36e605bdbff4dc 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s
 
 // CHECK-LABEL: @test_enum_attr_roundtrip
 func.func @test_enum_attr_roundtrip() -> () {
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
index f3da9a147fcb95..d619eefd721023 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
 // RUN:     -allow-unregistered-dialect --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @add_to_worklist_after_inplace_update()
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
index a362d6f99b9478..9f4a7924b725a2 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
 // RUN:     --split-input-file | FileCheck %s
 
 // Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
new file mode 100644
index 00000000000000..f7536ad3315870
--- /dev/null
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \
+// RUN:   --allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// Each op is updated in place once; the driver does not revisit modified ops.
+// CHECK-LABEL: func.func @inplace_update()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+func.func @inplace_update() {
+  "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+  "test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> ()
+  return
+}
+
+// Check that the driver does not fold visited ops.
+// CHECK-LABEL: func.func @add_no_fold()
+// CHECK: arith.constant
+// CHECK: arith.constant
+// CHECK: %[[RES:.+]] = arith.addi
+// CHECK: return %[[RES]]
+func.func @add_no_fold() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %res = arith.addi %c0, %c1 : i32
+  return %res : i32
+}
+
+// Check that the driver handles rewriter.moveBefore.
+// CHECK-LABEL: func.func @move_before(
+// CHECK: "test.move_before_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: scf.if
+// CHECK: return
+func.func @move_before(%cond : i1) {
+  scf.if %cond {
+    "test.move_before_parent_op"() ({
+      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+    }) : () -> ()
+  }
+  return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// CHECK-LABEL: func.func @move_after(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: return
+func.func @move_after(%cond : i1) {
+  scf.if %cond {
+    "test.move_after_parent_op"() ({
+      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+    }) : () -> ()
+  }
+  return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited twice since it is moved after the op following
+// its parent, which the walk has not visited yet.
+// CHECK-LABEL: func.func @move_forward_and_revisit(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: arith.addi
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+// CHECK: arith.addi
+// CHECK: return
+func.func @move_forward_and_revisit(%cond : i1) {
+  scf.if %cond {
+    "test.move_after_parent_op"() ({
+      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+    }) {advance = 1 : i32} : () -> ()
+  }
+  %a = arith.addi %cond, %cond : i1
+  %b = arith.addi %a, %cond : i1
+  return
+}
+
+// An operation inserted just after the currently visited one is not visited.
+// CHECK-LABEL: func.func @insert_just_after
+// CHECK: "test.clone_me"() ({
+// CHECK:   "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) {was_cloned} : () -> ()
+// CHECK: "test.clone_me"() ({
+// CHECK:   "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) : () -> ()
+// CHECK: return
+func.func @insert_just_after(%cond : i1) {
+  "test.clone_me"() ({
+    "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+  }) : () -> ()
+  return
+}
+
+// Check that we can replace the current operation with a new one.
+// Note that the new op won't be visited.
+// CHECK-LABEL: func.func @replace_with_new_op
+// CHECK: %[[NEW:.+]] = "test.new_op"
+// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]]
+// CHECK: return %[[RES]]
+func.func @replace_with_new_op() -> i32 {
+  %a = "test.replace_with_new_op"() : () -> (i32)
+  %res = arith.addi %a, %a : i32
+  return %res : i32
+}
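The `move_after`, `move_forward_and_revisit`, and `insert_just_after` tests all depend on the walk capturing the next op *before* invoking the callback, which is the `make_early_inc_range` behavior mentioned in the test comments. The flat, single-block toy below reproduces that effect with `std::list`, whose iterators, like MLIR's intrusive op lists, stay valid across insertions and splices. `ToyBlock` and `earlyIncWalk` are hypothetical names, and nested regions are not modeled.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <vector>

// Toy model (not MLIR): a block is a list of op names.
using ToyBlock = std::list<std::string>;

// Visits every op in `block`, advancing the iterator *before* invoking the
// callback, so the callback may mutate the block around the current op.
// Returns the op names in the order they were visited.
template <typename Callback>
std::vector<std::string> earlyIncWalk(ToyBlock &block, Callback &&visit) {
  std::vector<std::string> visited;
  for (auto it = block.begin(); it != block.end();) {
    auto current = it++; // the successor is captured here
    visited.push_back(*current);
    visit(block, current);
  }
  return visited;
}
```

Inserting a clone directly after the current op is invisible to the walk, because the saved successor still points at the old next op. Splicing the current op past a not-yet-visited op makes the walk reach it a second time.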
diff --git a/mlir/test/Transforms/test-operation-folder-commutative.mlir b/mlir/test/Transforms/test-operation-folder-commutative.mlir
index 8ffdeb54f399dc..55556c1ec58443 100644
--- a/mlir/test/Transforms/test-operation-folder-commutative.mlir
+++ b/mlir/test/Transforms/test-operation-folder-commutative.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s
 
 // CHECK-LABEL: func @test_reorder_constants_and_match
 func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 46ee07af993cc7..3c0cd15dc6c510 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
-// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
 
 func.func @foo() -> i32 {
   %c42 = arith.constant 42 : i32
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3eade0369f7654..c54e35b3e07be3 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -13,12 +13,16 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/ScopeExit.h"
+#include <cstdint>
 
 using namespace mlir;
 using namespace test;
@@ -214,6 +218,30 @@ struct MoveBeforeParentOp : public RewritePattern {
   }
 };
 
+/// This pattern moves "test.move_after_parent_op" after the parent op.
+struct MoveAfterParentOp : public RewritePattern {
+  MoveAfterParentOp(MLIRContext *context)
+      : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not hoist past functions.
+    if (isa<FunctionOpInterface>(op->getParentOp()))
+      return failure();
+
+    int64_t moveForwardBy = 0;
+    if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
+      moveForwardBy = advanceBy.getInt();
+
+    Operation *moveAfter = op->getParentOp();
+    for (int64_t i = 0; i < moveForwardBy; ++i)
+      moveAfter = moveAfter->getNextNode();
+
+    rewriter.moveOpAfter(op, moveAfter);
+    return success();
+  }
+};
+
 /// This pattern inlines blocks that are nested in
 /// "test.inline_blocks_into_parent" into the parent block.
 struct InlineBlocksIntoParent : public RewritePattern {
@@ -286,14 +314,43 @@ struct CloneRegionBeforeOp : public RewritePattern {
   }
 };
 
-struct TestPatternDriver
-    : public PassWrapper<TestPatternDriver, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
+/// Replacing an operation may cause its users to be revisited.
+class ReplaceWithNewOp : public RewritePattern {
+public:
+  ReplaceWithNewOp(MLIRContext *context)
+      : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    Operation *newOp;
+    if (op->hasAttr("create_erase_op")) {
+      newOp = rewriter.create(
+          op->getLoc(),
+          OperationName("test.erase_op", op->getContext()).getIdentifier(),
+          ValueRange(), TypeRange());
+    } else {
+      newOp = rewriter.create(
+          op->getLoc(),
+          OperationName("test.new_op", op->getContext()).getIdentifier(),
+          op->getOperands(), op->getResultTypes());
+    }
+    // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+    // A "notifyOperationReplaced" callback is triggered in either case.
+    rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct TestGreedyPatternDriver
+    : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
 
-  TestPatternDriver() = default;
-  TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
+  TestGreedyPatternDriver() = default;
+  TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
+      : PassWrapper(other) {}
 
-  StringRef getArgument() const final { return "test-patterns"; }
+  StringRef getArgument() const final { return "test-greedy-patterns"; }
   StringRef getDescription() const final { return "Run test dialect patterns"; }
   void runOnOperation() override {
     mlir::RewritePatternSet patterns(&getContext());
@@ -470,34 +527,6 @@ struct TestStrictPatternDriver
     }
   };
 
-  // Replace an operation may introduce the re-visiting of its users.
-  class ReplaceWithNewOp : public RewritePattern {
-  public:
-    ReplaceWithNewOp(MLIRContext *context)
-        : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
-
-    LogicalResult matchAndRewrite(Operation *op,
-                                  PatternRewriter &rewriter) const override {
-      Operation *newOp;
-      if (op->hasAttr("create_erase_op")) {
-        newOp = rewriter.create(
-            op->getLoc(),
-            OperationName("test.erase_op", op->getContext()).getIdentifier(),
-            ValueRange(), TypeRange());
-      } else {
-        newOp = rewriter.create(
-            op->getLoc(),
-            OperationName("test.new_op", op->getContext()).getIdentifier(),
-            op->getOperands(), op->getResultTypes());
-      }
-      // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
-      // A "notifyOperationReplaced" callback is triggered in either case.
-      rewriter.replaceAllOpUsesWith(op, newOp->getResults());
-      rewriter.eraseOp(op);
-      return success();
-    }
-  };
-
   // Remove an operation may introduce the re-visiting of its operands.
   class EraseOp : public RewritePattern {
   public:
@@ -560,6 +589,38 @@ struct TestStrictPatternDriver
   };
 };
 
+struct TestWalkPatternDriver final
+    : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
+
+  TestWalkPatternDriver() = default;
+  TestWalkPatternDriver(const TestWalkPatternDriver &other)
+      : PassWrapper(other) {}
+
+  StringRef getArgument() const override {
+    return "test-walk-pattern-rewrite-driver";
+  }
+  StringRef getDescription() const override {
+    return "Run test walk pattern rewrite driver";
+  }
+  void runOnOperation() override {
+    mlir::RewritePatternSet patterns(&getContext());
+
+    // Patterns for testing the WalkPatternRewriteDriver.
+    patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
+                 MoveAfterParentOp, CloneOp, ReplaceWithNewOp>(&getContext());
+
+    DumpNotifications dumpListener;
+    walkAndApplyPatterns(getOperation(), std::move(patterns),
+                         dumpNotifications ? &dumpListener : nullptr);
+  }
+
+  Option<bool> dumpNotifications{
+      *this, "dump-notifications",
+      llvm::cl::desc("Print rewrite listener notifications"),
+      llvm::cl::init(false)};
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1978,8 +2039,9 @@ void registerPatternsTestPass() {
 
   PassRegistration<TestDerivedAttributeDriver>();
 
-  PassRegistration<TestPatternDriver>();
+  PassRegistration<TestGreedyPatternDriver>();
   PassRegistration<TestStrictPatternDriver>();
+  PassRegistration<TestWalkPatternDriver>();
 
   PassRegistration<TestLegalizePatternDriver>([] {
     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
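`MoveAfterParentOp` follows `getNextNode()` `advance` times without checking for the end of the block. The lit test only uses `advance = 1`, where a next op exists. The sketch below shows a bounds-checked version of that anchor search on a toy list. `findMoveAfterAnchor` is a hypothetical helper, not part of the patch: it returns `std::nullopt` where MLIR's `getNextNode()` would yield `nullptr`, so a caller could fail the match instead of passing a null anchor to `moveOpAfter`.

```cpp
#include <cassert>
#include <cstdint>
#include <iterator>
#include <list>
#include <optional>
#include <string>

// Toy model (not MLIR): a block is a list of op names.
using ToyBlock = std::list<std::string>;

// Starting from `parent`, step forward `advance` ops within the same block and
// return the op to move after, or std::nullopt if that would run past the end.
std::optional<ToyBlock::iterator> findMoveAfterAnchor(ToyBlock &block,
                                                      ToyBlock::iterator parent,
                                                      int64_t advance) {
  assert(advance >= 0 && "expected a non-negative advance");
  auto anchor = parent;
  for (int64_t i = 0; i < advance; ++i) {
    if (std::next(anchor) == block.end())
      return std::nullopt;
    ++anchor;
  }
  return anchor;
}
```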
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5ff8710b937701..60d46e676d2a33 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
 
 // CHECK-LABEL: verifyFusedLocs
 func.func @verifyFusedLocs(%arg0 : i32) -> i32 {

>From 34d440ab2192895d64306938d4cd3929b7f83363 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 18:38:41 -0400
Subject: [PATCH 02/12] Iterate over regions

---
 .../Utils/WalkPatternRewriteDriver.cpp        | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 2d3aa5fc7d15c7..864e86193a0d6c 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -63,18 +63,17 @@ void walkAndApplyPatterns(Operation *op,
   PatternApplicator applicator(patterns);
   applicator.applyDefaultCostModel();
 
-  op->walk([&](Operation *visitedOp) {
-    if (visitedOp == op)
-      return;
-
-    LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
-               visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
-               llvm::dbgs() << "\n";);
-    erasedListener.visitedOp = visitedOp;
-    if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
-      LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
-    }
-  });
+  for (Region &region : op->getRegions()) {
+    region.walk([&](Operation *visitedOp) {
+      LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
+          llvm::dbgs(), OpPrintingFlags().skipRegions());
+                 llvm::dbgs() << "\n";);
+      erasedListener.visitedOp = visitedOp;
+      if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+        LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+      }
+    });
+  }
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   if (failed(verify(op)))
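This patch replaces `op->walk(...)` plus an explicit root check with a loop that walks each of the root's regions. Because MLIR walks are post-order by default, nested ops are still visited before the ops that contain them, and the root is never visited at all. The toy IR below illustrates that traversal. `ToyOp`, `walkPostOrder`, and `visitNestedOps` are hypothetical names, and each region is simplified to a single block.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy IR (not MLIR): an op owns regions; each region is a flat list of ops.
struct ToyOp {
  std::string name;
  std::vector<std::vector<ToyOp>> regions;
};

// Post-order walk: ops nested in `op` are recorded before `op` itself.
void walkPostOrder(const ToyOp &op, std::vector<std::string> &out) {
  for (const auto &region : op.regions)
    for (const ToyOp &nested : region)
      walkPostOrder(nested, out);
  out.push_back(op.name);
}

// Mirrors the shape of the patched driver: walk each region of `root`
// separately, so `root` itself is never recorded.
std::vector<std::string> visitNestedOps(const ToyOp &root) {
  std::vector<std::string> out;
  for (const auto &region : root.regions)
    for (const ToyOp &op : region)
      walkPostOrder(op, out);
  return out;
}
```

For a module containing one function with two ops, the visit order is the two inner ops, then the function. The module never appears, which matches the header's note that patterns are not applied to the given operation itself.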

>From ccc4e08f84136180304e396af0c86deb99079662 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:03:10 -0400
Subject: [PATCH 03/12] Add action

---
 mlir/docs/ActionTracing.md                    |  2 +-
 .../Utils/WalkPatternRewriteDriver.cpp        | 40 +++++++++++++------
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/mlir/docs/ActionTracing.md b/mlir/docs/ActionTracing.md
index 978fdbbe54d81c..984516d5c5e7e2 100644
--- a/mlir/docs/ActionTracing.md
+++ b/mlir/docs/ActionTracing.md
@@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the
 
 ```c++
 /// A custom Action can be defined minimally by deriving from
-/// `tracing::ActionImpl`. It can has any members!
+/// `tracing::ActionImpl`. It can have any members!
 class MyCustomAction : public tracing::ActionImpl<MyCustomAction> {
 public:
   using Base = tracing::ActionImpl<MyCustomAction>;
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 864e86193a0d6c..eb450c4dacdd45 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -12,11 +12,13 @@
 
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 
@@ -25,9 +27,18 @@
 namespace mlir {
 
 namespace {
+struct WalkAndApplyPatternsAction final
+    : tracing::ActionImpl<WalkAndApplyPatternsAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
+  using ActionImpl::ActionImpl;
+  static constexpr StringLiteral tag = "walk-and-apply-patterns";
+  void print(raw_ostream &os) const override { os << tag; }
+};
+
 // Forwarding listener to guard against unsupported erasures. Because we use
 // walk-based pattern application, erasing the op from the *next* iteration
 // (e.g., a user of the visited op) is not valid.
+// Note that this is only used with expensive pattern API checks.
 struct ErasedOpsListener final : RewriterBase::ForwardingListener {
   using RewriterBase::ForwardingListener::ForwardingListener;
 
@@ -51,7 +62,8 @@ void walkAndApplyPatterns(Operation *op,
     llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
-  PatternRewriter rewriter(op->getContext());
+  MLIRContext *ctx = op->getContext();
+  PatternRewriter rewriter(ctx);
   ErasedOpsListener erasedListener(listener);
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   rewriter.setListener(&erasedListener);
@@ -63,17 +75,21 @@ void walkAndApplyPatterns(Operation *op,
   PatternApplicator applicator(patterns);
   applicator.applyDefaultCostModel();
 
-  for (Region &region : op->getRegions()) {
-    region.walk([&](Operation *visitedOp) {
-      LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
-          llvm::dbgs(), OpPrintingFlags().skipRegions());
-                 llvm::dbgs() << "\n";);
-      erasedListener.visitedOp = visitedOp;
-      if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
-        LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
-      }
-    });
-  }
+  ctx->executeAction<WalkAndApplyPatternsAction>(
+      [&] {
+        for (Region &region : op->getRegions()) {
+          region.walk([&](Operation *visitedOp) {
+            LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
+                llvm::dbgs(), OpPrintingFlags().skipRegions());
+                       llvm::dbgs() << "\n";);
+            erasedListener.visitedOp = visitedOp;
+            if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+              LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+            }
+          });
+        }
+      },
+      {op});
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   if (failed(verify(op)))
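Patch 3 wraps the walk in `ctx->executeAction<WalkAndApplyPatternsAction>(...)` so that an action handler registered on the context, as described in ActionTracing.md, can observe or intercept the rewrite. The sketch below models only the control flow of that hook; it is not MLIR's tracing API. `ToyContext` and its `handler` member are hypothetical. Without a handler the transformation just runs, and with one the handler decides whether to invoke it.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Toy model of action execution (not MLIR's tracing API).
struct ToyContext {
  // Receives the action tag and a callable performing the transformation; it
  // decides whether (and when) to invoke that callable.
  std::function<void(const std::string &tag, const std::function<void()> &)>
      handler;

  void executeAction(const std::string &tag,
                     const std::function<void()> &transform) {
    assert(transform && "expected a transformation callback");
    if (handler)
      handler(tag, transform);
    else
      transform();
  }
};
```

A handler that records the tag but never calls the callable effectively skips the rewrite. This is the kind of control an action-based debugger can exercise over the `walk-and-apply-patterns` action.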

>From 62cb821e2365a6b33bab21134f0cedfb6575465b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:07:03 -0400
Subject: [PATCH 04/12] Coding style

---
 mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index eb450c4dacdd45..0258064a3436b7 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -70,7 +70,7 @@ void walkAndApplyPatterns(Operation *op,
 #else
   (void)erasedListener;
   rewriter.setListener(listener);
-#endif
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   PatternApplicator applicator(patterns);
   applicator.applyDefaultCostModel();

>From aaa9f00802663f7a35ff7bd0bd7eef9161018533 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:40:22 -0400
Subject: [PATCH 05/12] Cleanup

---
 mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index ad455aaf987fc6..986450fb128aed 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,10 +13,6 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Visitors.h"
-#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 

>From 2b1205f6ce7c9f70c8ec7a58ccaf2ed7aa5a2874 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:41:11 -0400
Subject: [PATCH 06/12] Cleanup

---
 mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index 986450fb128aed..8922e93e399f9f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,8 +13,8 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace arith {

>From d70b289e7acd88b12a502d4c053c1e2ef775f5f9 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:42:38 -0400
Subject: [PATCH 07/12] Typo

---
 mlir/test/lib/Dialect/Test/TestPatterns.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c54e35b3e07be3..dd801513adbe66 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -601,7 +601,7 @@ struct TestWalkPatternDriver final
     return "test-walk-pattern-rewrite-driver";
   }
   StringRef getDescription() const override {
-    return "Run test greedy pattern rewrite driver";
+    return "Run test walk pattern rewrite driver";
   }
   void runOnOperation() override {
     mlir::RewritePatternSet patterns(&getContext());

>From 9d295d840b8010977ad42f46653c20cd888a7b6b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 20:51:31 -0400
Subject: [PATCH 08/12] Clarify docs

---
 mlir/docs/PatternRewriter.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 1d036dc8ce0f6a..315a8dc0d1620d 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -327,9 +327,12 @@ that locally have the most benefit. The benefit of a pattern is decided solely
 by the benefit specified on the pattern, and the relative order of the pattern
 within the pattern list (when two patterns have the same local benefit).
 
+The driver performs a post-order traversal. Note that it walks regions of the
+given op but does not visit the op.
+
 This driver does not (re)visit modified or newly replaced ops, and does not
 allow for progressive rewrites of the same op. Op erasure is only supported for
-the currently matched op. If your pattern-set requires these, consider using the
+the currently matched op. If your pattern set requires these, consider using the
 Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
 
 This driver is exposed using the `walkAndApplyPatterns` function.

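The traversal described in the docs is a post-order walk over the regions of the given op, and the op itself is never visited. A toy tree makes this concrete. This is a minimal sketch that assumes a hypothetical `ToyOp` owning its nested ops directly. It is not how MLIR stores regions and blocks.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR: an op directly owns the ops nested inside it.
struct ToyOp {
  std::string name;
  std::vector<std::unique_ptr<ToyOp>> children;
};

// Post-order visit of every op nested under `op`, excluding `op` itself:
// each child's descendants are visited before the child.
template <typename Fn>
void walkNested(ToyOp &op, Fn &&visit) {
  for (auto &child : op.children) {
    walkNested(*child, visit);
    visit(*child);
  }
}
```

Children are visited before their parent, so by the time a pattern could fire on an enclosing op such as `scf.for`, its body has already been visited. The root (here a function) is never handed to the patterns.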
>From 63d0a9b72ff1b9fa31e274acbc04738181979a17 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 30 Oct 2024 19:55:05 -0400
Subject: [PATCH 09/12] Guard forwarding listener

---
 mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 0258064a3436b7..b1ce47805bf435 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -18,7 +18,6 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 
@@ -35,6 +34,7 @@ struct WalkAndApplyPatternsAction final
   void print(raw_ostream &os) const override { os << tag; }
 };
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 // Forwarding listener to guard against unsupported erasures. Because we use
 // walk-based pattern application, erasing the op from the *next* iteration
 // (e.g., a user of the visited op) is not valid.
@@ -52,6 +52,7 @@ struct ErasedOpsListener final : RewriterBase::ForwardingListener {
 
   Operation *visitedOp = nullptr;
 };
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 } // namespace
 
 void walkAndApplyPatterns(Operation *op,
@@ -64,11 +65,10 @@ void walkAndApplyPatterns(Operation *op,
 
   MLIRContext *ctx = op->getContext();
   PatternRewriter rewriter(ctx);
-  ErasedOpsListener erasedListener(listener);
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  ErasedOpsListener erasedListener(listener);
   rewriter.setListener(&erasedListener);
 #else
-  (void)erasedListener;
   rewriter.setListener(listener);
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
@@ -82,7 +82,9 @@ void walkAndApplyPatterns(Operation *op,
             LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
                 llvm::dbgs(), OpPrintingFlags().skipRegions());
                        llvm::dbgs() << "\n";);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
             erasedListener.visitedOp = visitedOp;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
             if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
               LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
             }

>From 4a524aec9e1af1b086dd4e797153fc9d351401c2 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 30 Oct 2024 22:15:19 -0400
Subject: [PATCH 10/12] Guard against invalid block erasure. Support forwarding
 to null listeners.

---
 mlir/docs/PatternRewriter.md                  | 11 +++++--
 mlir/include/mlir/IR/PatternMatch.h           | 28 +++++++++++-------
 .../Utils/WalkPatternRewriteDriver.cpp        | 29 ++++++++++++++-----
 .../IR/test-walk-pattern-rewrite-driver.mlir  | 14 +++++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 25 +++++++++++++++-
 5 files changed, 84 insertions(+), 23 deletions(-)

diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 315a8dc0d1620d..c61ceaf81681e2 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -331,12 +331,17 @@ The driver performs a post-order traversal. Note that it walks regions of the
 given op but does not visit the op.
 
 This driver does not (re)visit modified or newly replaced ops, and does not
-allow for progressive rewrites of the same op. Op erasure is only supported for
-the currently matched op. If your pattern set requires these, consider using the
-Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
+allow for progressive rewrites of the same op. Op and block erasure is only
+supported for the currently matched op and its descendants. If your pattern
+set requires these, consider using the Greedy Pattern Rewrite Driver instead,
+at the expense of extra overhead.
 
 This driver is exposed using the `walkAndApplyPatterns` function.
 
+Note: This driver listens for IR changes via the callbacks provided by
+`RewriterBase`. It is important that patterns announce all IR changes to the
+rewriter and do not bypass the rewriter API by modifying ops directly.
+
 #### Debugging
 
 You can debug the Walk Pattern Rewrite Driver by passing the
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 896fdf1c899e3d..2ab0405043a546 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -461,54 +461,60 @@ class RewriterBase : public OpBuilder {
   /// struct can be used as a base to create listener chains, so that multiple
   /// listeners can be notified of IR changes.
   struct ForwardingListener : public RewriterBase::Listener {
-    ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
+    ForwardingListener(OpBuilder::Listener *listener)
+        : listener(listener),
+          rewriteListener(
+              dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
 
     void notifyOperationInserted(Operation *op, InsertPoint previous) override {
-      listener->notifyOperationInserted(op, previous);
+      if (listener)
+        listener->notifyOperationInserted(op, previous);
     }
     void notifyBlockInserted(Block *block, Region *previous,
                              Region::iterator previousIt) override {
-      listener->notifyBlockInserted(block, previous, previousIt);
+      if (listener)
+        listener->notifyBlockInserted(block, previous, previousIt);
     }
     void notifyBlockErased(Block *block) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyBlockErased(block);
     }
     void notifyOperationModified(Operation *op) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyOperationModified(op);
     }
     void notifyOperationReplaced(Operation *op, Operation *newOp) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyOperationReplaced(op, newOp);
     }
     void notifyOperationReplaced(Operation *op,
                                  ValueRange replacement) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyOperationReplaced(op, replacement);
     }
     void notifyOperationErased(Operation *op) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyOperationErased(op);
     }
     void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyPatternBegin(pattern, op);
     }
     void notifyPatternEnd(const Pattern &pattern,
                           LogicalResult status) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyPatternEnd(pattern, status);
     }
     void notifyMatchFailure(
         Location loc,
         function_ref<void(Diagnostic &)> reasonCallback) override {
-      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+      if (rewriteListener)
         rewriteListener->notifyMatchFailure(loc, reasonCallback);
     }
 
   private:
     OpBuilder::Listener *listener;
+    RewriterBase::Listener *rewriteListener;
   };
 
   /// Move the blocks that belong to "region" before the given position in
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index b1ce47805bf435..efbc646f2ef276 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -35,21 +35,34 @@ struct WalkAndApplyPatternsAction final
 };
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-// Forwarding listener to guard against unsupported erasures. Because we use
-// walk-based pattern application, erasing the op from the *next* iteration
-// (e.g., a user of the visited op) is not valid.
-// Note that this is only used with expensive pattern API checks.
+// Forwarding listener to guard against unsupported erasures of non-descendant
+// ops/blocks. Because we use walk-based pattern application, erasing the
+// op/block from the *next* iteration (e.g., a user of the visited op) is not
+// valid. Note that this is only used with expensive pattern API checks.
 struct ErasedOpsListener final : RewriterBase::ForwardingListener {
   using RewriterBase::ForwardingListener::ForwardingListener;
 
   void notifyOperationErased(Operation *op) override {
-    if (op != visitedOp)
-      llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
-                               "erasure is only supported for matched ops");
-
+    checkErasure(op);
     ForwardingListener::notifyOperationErased(op);
   }
 
+  void notifyBlockErased(Block *block) override {
+    checkErasure(block->getParentOp());
+    ForwardingListener::notifyBlockErased(block);
+  }
+
+  void checkErasure(Operation *op) const {
+    Operation *ancestorOp = op;
+    while (ancestorOp && ancestorOp != visitedOp)
+      ancestorOp = ancestorOp->getParentOp();
+
+    if (ancestorOp != visitedOp)
+      llvm::report_fatal_error(
+          "unsupported erased in WalkPatternRewriter; "
+          "erasure is only supported for matched ops and their descendants");
+  }
+
   Operation *visitedOp = nullptr;
 };
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
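The ancestor check in `checkErasure` can be modeled on its own. This is a minimal sketch that assumes a hypothetical `ToyOp` carrying only a parent pointer, standing in for `Operation::getParentOp()`. For block erasure the listener passes the block's parent op, so erasing a block owned by the matched op is accepted.

```cpp
#include <cassert>

// Hypothetical toy op with only a parent link (stand-in for getParentOp()).
struct ToyOp {
  ToyOp *parent = nullptr;
};

// Returns true if `op` is `visited` or is nested at any depth inside it.
// Mirrors the loop in checkErasure: climb parents until reaching `visited`
// or running out of ancestors.
bool isErasureAllowed(const ToyOp *op, const ToyOp *visited) {
  const ToyOp *ancestor = op;
  while (ancestor && ancestor != visited)
    ancestor = ancestor->parent;
  return ancestor == visited;
}
```

The check costs O(nesting depth) per erasure. That cost is acceptable here because the listener is only installed when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled.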
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index f7536ad3315870..5423552a8ef1dd 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -105,3 +105,17 @@ func.func @replace_with_new_op() -> i32 {
   %res = arith.addi %a, %a : i32
   return %res : i32
 }
+
+// Check that we can erase nested blocks.
+// CHECK-LABEL: func.func @erase_nested_block
+// CHECK:         %[[RES:.+]] = "test.erase_first_block"
+// CHECK-NEXT:    foo.bar
+// CHECK:         return %[[RES]]
+func.func @erase_nested_block() -> i32 {
+  %a = "test.erase_first_block"() ({
+    "foo.foo"() : () -> ()
+    ^bb1:
+    "foo.bar"() : () -> ()
+  }): () -> (i32)
+  return %a : i32
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index dd801513adbe66..d97f3b41f2ef29 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -342,6 +342,28 @@ class ReplaceWithNewOp : public RewritePattern {
   }
 };
 
+/// Erases the first child block of the matched "test.erase_first_block"
+/// operation.
+class EraseFirstBlock : public RewritePattern {
+public:
+  EraseFirstBlock(MLIRContext *context)
+      : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    llvm::errs() << "Num regions: " << op->getNumRegions() << "\n";
+    for (Region &r : op->getRegions()) {
+      for (Block &b : r.getBlocks()) {
+        llvm::errs() << "Erasing block: " << b << "\n";
+        rewriter.eraseBlock(&b);
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+
 struct TestGreedyPatternDriver
     : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
@@ -608,7 +630,8 @@ struct TestWalkPatternDriver final
 
     // Patterns for testing the WalkPatternRewriteDriver.
     patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
-                 MoveAfterParentOp, CloneOp, ReplaceWithNewOp>(&getContext());
+                 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
+        &getContext());
 
     DumpNotifications dumpListener;
     walkAndApplyPatterns(getOperation(), std::move(patterns),

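The `ForwardingListener` change in this patch does two things. It makes forwarding safe when the wrapped listener is null, as the patch subject indicates. It also computes the downcast once in the constructor instead of in every callback. In LLVM, `dyn_cast` asserts that its argument is non-null, whereas `dyn_cast_if_present` returns null for a null input. Below is a standalone sketch of the same idea using standard `dynamic_cast`. All type and method names are hypothetical simplifications of `OpBuilder::Listener` and `RewriterBase::Listener`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical two-level listener hierarchy (illustrative names only).
struct BuilderListener {
  virtual ~BuilderListener() = default;
  virtual void notifyInserted(const std::string &) {}
};

struct RewriteListener : BuilderListener {
  virtual void notifyErased(const std::string &) {}
};

// Forwards to an optional inner listener. The downcast is done once; a
// dynamic_cast of a null pointer yields null, so both members may be null.
struct ForwardingListener : RewriteListener {
  explicit ForwardingListener(BuilderListener *inner)
      : listener(inner),
        rewriteListener(dynamic_cast<RewriteListener *>(inner)) {}

  void notifyInserted(const std::string &name) override {
    if (listener)
      listener->notifyInserted(name);
  }
  void notifyErased(const std::string &name) override {
    if (rewriteListener)
      rewriteListener->notifyErased(name);
  }

private:
  BuilderListener *listener;
  RewriteListener *rewriteListener;
};

// Records forwarded events so the behavior can be observed.
struct Recorder : RewriteListener {
  std::vector<std::string> events;
  void notifyInserted(const std::string &n) override {
    events.push_back("insert " + n);
  }
  void notifyErased(const std::string &n) override {
    events.push_back("erase " + n);
  }
};
```

Builder-level events reach any inner listener. Rewrite-level events are forwarded only when the inner listener actually is a rewrite listener.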
>From 5e85a8fadcd67915d5409f7cbecc9145290f3a37 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 30 Oct 2024 22:25:15 -0400
Subject: [PATCH 11/12] Add missing newline

---
 mlir/test/IR/test-walk-pattern-rewrite-driver.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 5423552a8ef1dd..02f7e60671c9b3 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -118,4 +118,4 @@ func.func @erase_nested_block() -> i32 {
     "foo.bar"() : () -> ()
   }): () -> (i32)
   return %a : i32
-}
\ No newline at end of file
+}

>From dd4ed14676d274c4672ee1f72cd09870b52148d7 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 31 Oct 2024 11:08:19 -0400
Subject: [PATCH 12/12] Fix typo

---
 mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index efbc646f2ef276..ee5c642c943c45 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -59,7 +59,7 @@ struct ErasedOpsListener final : RewriterBase::ForwardingListener {
 
     if (ancestorOp != visitedOp)
       llvm::report_fatal_error(
-          "unsupported erased in WalkPatternRewriter; "
+          "unsupported erasure in WalkPatternRewriter; "
           "erasure is only supported for matched ops and their descendants");
   }
 


