[Mlir-commits] [mlir] [mlir] Add fast walk-based pattern rewrite driver (PR #113825)

Jakub Kuderski llvmlistbot at llvm.org
Sun Oct 27 10:57:00 PDT 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/113825

This is intended as a fast pattern rewrite driver for cases where a simple walk gets the job done, but we still want to implement the transformation in terms of rewrite patterns (which can also be used with the greedy pattern rewrite driver downstream).
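As a rough illustration of the pattern-selection behavior described here and in the PatternRewriter.md hunk, the following standalone C++ sketch models the driver on a toy op tree instead of real MLIR IR. `ToyOp`, `ToyPattern`, `addChild`, `walkPostOrder`, and `toyWalkAndApply` are hypothetical names made up for this sketch. It assumes, as the docs state, that patterns are tried in descending benefit order with ties broken by list position, that each nested op is visited once in post-order, and that the root op itself is skipped.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR: an op with a name, an integer "attribute", and nested
// ops standing in for regions. This is not MLIR's Operation class.
struct ToyOp {
  std::string name;
  int attr = 0;
  std::vector<std::unique_ptr<ToyOp>> children;
};

// Hypothetical helper: appends a nested op and returns a stable reference.
inline ToyOp &addChild(ToyOp &parent, std::string name) {
  parent.children.push_back(std::make_unique<ToyOp>());
  ToyOp &child = *parent.children.back();
  child.name = std::move(name);
  return child;
}

// Hypothetical pattern: anchored on an op name, with a benefit and a rewrite
// callback that returns true on success (like a succeeded matchAndRewrite).
struct ToyPattern {
  std::string rootName;
  int benefit;
  std::string label;
  std::function<bool(ToyOp &)> rewrite;
};

// Post-order traversal: nested ops are visited before their parent.
inline void walkPostOrder(ToyOp &op, const std::function<void(ToyOp &)> &fn) {
  for (auto &child : op.children)
    walkPostOrder(*child, fn);
  fn(op);
}

// Visits every op nested under `root` once and applies at most one pattern
// per op: the highest-benefit pattern that succeeds, ties broken by list order.
inline std::vector<std::string>
toyWalkAndApply(ToyOp &root, std::vector<ToyPattern> patterns) {
  // stable_sort preserves the original list order among equal benefits.
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  std::vector<std::string> fired;
  walkPostOrder(root, [&](ToyOp &op) {
    if (&op == &root)
      return; // the root op itself is not rewritten
    for (const ToyPattern &p : patterns) {
      if (p.rootName == op.name && p.rewrite(op)) {
        fired.push_back(op.name + ":" + p.label);
        break; // stop after the first successful pattern for this op
      }
    }
  });
  return fired;
}
```

With two patterns rooted on the same op name, the higher-benefit one fires even if it was added later. With equal benefits, the one added first wins. Ops are never revisited, so each op is rewritten at most once per call.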

The new driver is inspired by the discussion in https://github.com/llvm/llvm-project/pull/112454 and the LLVM Dev presentation from @matthias-springer earlier this week.

This driver comes with some limitations:
* It does not iterate until a fixpoint, and it does not revisit ops modified in place or newly created ops. In general, it only walks forward (in post-order).
* `matchAndRewrite` can only erase the matched op. This is verified under expensive checks.
* It does not perform folding / DCE.

We could probably relax some of these in the future without sacrificing too much performance.
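To make the first limitation concrete, here is a toy comparison in plain C++ (no MLIR) between a single forward walk and a crude fixpoint loop. `decrementPattern`, `singleForwardWalk`, and `iterateToFixpoint` are hypothetical. The fixpoint loop is only a simplified stand-in for the greedy driver, which actually uses a worklist rather than repeated whole sweeps.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical in-place "pattern": decrements a positive counter by one.
// Returns true if it changed the op (i.e., the pattern matched).
inline bool decrementPattern(int &counter) {
  if (counter <= 0)
    return false;
  --counter;
  return true;
}

// Walk-style: visit every op exactly once, in order. An op modified in place
// is not revisited, so the IR may not reach a fixpoint.
inline std::size_t singleForwardWalk(std::vector<int> &ops) {
  std::size_t rewrites = 0;
  for (int &op : ops)
    if (decrementPattern(op))
      ++rewrites;
  return rewrites;
}

// Greedy-style (simplified): repeat whole sweeps until nothing changes or
// maxIterations is exhausted.
inline std::size_t iterateToFixpoint(std::vector<int> &ops, int maxIterations) {
  std::size_t rewrites = 0;
  for (int iter = 0; iter < maxIterations; ++iter) {
    std::size_t changed = singleForwardWalk(ops);
    rewrites += changed;
    if (changed == 0)
      break;
  }
  return rewrites;
}
```

For `{3, 0, 1}`, the single walk performs two rewrites and leaves `{2, 0, 0}`, while the fixpoint loop keeps going until every counter reaches zero. This mirrors the `@inplace_update` test in the patch, where each op's attribute is incremented only once.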

From 883902df7f742e7e6f5f6f32e44743f58793bdfc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 13:45:09 -0400
Subject: [PATCH] [mlir] Add walk pattern rewrite driver

---
 mlir/docs/PatternRewriter.md                  |  32 ++++-
 .../Transforms/WalkPatternRewriteDriver.h     |  37 +++++
 .../Transforms/UnsignedWhenEquivalent.cpp     |  12 +-
 mlir/lib/Transforms/Utils/CMakeLists.txt      |   1 +
 .../Utils/WalkPatternRewriteDriver.cpp        |  86 ++++++++++++
 mlir/test/IR/enum-attr-roundtrip.mlir         |   2 +-
 ...eedy-pattern-rewrite-driver-bottom-up.mlir |   2 +-
 ...reedy-pattern-rewrite-driver-top-down.mlir |   2 +-
 .../IR/test-walk-pattern-rewrite-driver.mlir  | 107 ++++++++++++++
 .../test-operation-folder-commutative.mlir    |   2 +-
 .../Transforms/test-operation-folder.mlir     |   4 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 132 +++++++++++++-----
 mlir/test/mlir-tblgen/pattern.mlir            |   2 +-
 13 files changed, 366 insertions(+), 55 deletions(-)
 create mode 100644 mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
 create mode 100644 mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
 create mode 100644 mlir/test/IR/test-walk-pattern-rewrite-driver.mlir

diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 0ba76199874cc3..1d036dc8ce0f6a 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -320,15 +320,33 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
 framework also provides support for type conversions. More information on this
 driver can be found [here](DialectConversion.md).
 
+### Walk Pattern Rewrite Driver
+
+This is a fast and simple driver that walks the given op and applies patterns
+that locally have the most benefit. The benefit of a pattern is decided solely
+by the benefit specified on the pattern, and the relative order of the pattern
+within the pattern list (when two patterns have the same local benefit).
+
+This driver does not (re)visit modified or newly replaced ops, and does not
+allow for progressive rewrites of the same op. Op erasure is only supported for
+the currently matched op. If your pattern-set requires these, consider using the
+Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
+
+This driver is exposed using the `walkAndApplyPatterns` function.
+
+#### Debugging
+
+You can debug the Walk Pattern Rewrite Driver by passing the
+`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
+ops.
+
 ### Greedy Pattern Rewrite Driver
 
 This driver processes ops in a worklist-driven fashion and greedily applies the
-patterns that locally have the most benefit. The benefit of a pattern is decided
-solely by the benefit specified on the pattern, and the relative order of the
-pattern within the pattern list (when two patterns have the same local benefit).
-Patterns are iteratively applied to operations until a fixed point is reached or
-until the configurable maximum number of iterations exhausted, at which point
-the driver finishes.
+patterns that locally have the most benefit (same as the Walk Pattern Rewrite
+Driver). Patterns are iteratively applied to operations until a fixed point is
+reached or until the configurable maximum number of iterations is exhausted,
+at which point the driver finishes.
 
 This driver comes in two fashions:
 
@@ -368,7 +386,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
 Note: This driver is the one used by the [canonicalization](Canonicalization.md)
 [pass](Passes.md/#-canonicalize) in MLIR.
 
-### Debugging
+#### Debugging
 
 To debug the execution of the greedy pattern rewrite driver,
 `-debug-only=greedy-rewriter` may be used. This command line flag activates
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
new file mode 100644
index 00000000000000..6d62ae3dd43dc1
--- /dev/null
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -0,0 +1,37 @@
+//===- WalkPatternRewriteDriver.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares a helper function to walk the given op and apply rewrite patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+
+namespace mlir {
+
+/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
+/// given operation by walking it and applying the highest benefit patterns.
+/// This rewriter *does not* wait until a fixpoint is reached and *does not*
+/// visit modified or newly replaced ops. Also *does not* perform folding or
+/// dead-code elimination.
+///
+/// This is intended as the simplest and most lightweight pattern rewriter in
+/// cases when a simple walk gets the job done.
+///
+/// Note: Does not apply patterns to the given operation itself.
+void walkAndApplyPatterns(Operation *op,
+                          const FrozenRewritePatternSet &patterns,
+                          RewriterBase::Listener *listener = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index bebe0b5a7c0b61..ad455aaf987fc6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -14,7 +14,11 @@
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace arith {
@@ -157,11 +161,7 @@ struct ArithUnsignedWhenEquivalentPass
     RewritePatternSet patterns(ctx);
     populateUnsignedWhenEquivalentPatterns(patterns, solver);
 
-    GreedyRewriteConfig config;
-    config.listener = &listener;
-
-    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-      signalPassFailure();
+    walkAndApplyPatterns(op, std::move(patterns), &listener);
   }
 };
 } // end anonymous namespace
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index eb588640dbf83a..72eb34f36cf5f6 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
   LoopInvariantCodeMotionUtils.cpp
   OneToNTypeConversion.cpp
   RegionUtils.cpp
+  WalkPatternRewriteDriver.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
new file mode 100644
index 00000000000000..2d3aa5fc7d15c7
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -0,0 +1,86 @@
+//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements mlir::walkAndApplyPatterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "walk-rewriter"
+
+namespace mlir {
+
+namespace {
+// Forwarding listener to guard against unsupported erasures. Because we use
+// walk-based pattern application, erasing the op from the *next* iteration
+// (e.g., a user of the visited op) is not valid.
+struct ErasedOpsListener final : RewriterBase::ForwardingListener {
+  using RewriterBase::ForwardingListener::ForwardingListener;
+
+  void notifyOperationErased(Operation *op) override {
+    if (op != visitedOp)
+      llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
+                               "erasure is only supported for matched ops");
+
+    ForwardingListener::notifyOperationErased(op);
+  }
+
+  Operation *visitedOp = nullptr;
+};
+} // namespace
+
+void walkAndApplyPatterns(Operation *op,
+                          const FrozenRewritePatternSet &patterns,
+                          RewriterBase::Listener *listener) {
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (failed(verify(op)))
+    llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+  PatternRewriter rewriter(op->getContext());
+  ErasedOpsListener erasedListener(listener);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  rewriter.setListener(&erasedListener);
+#else
+  (void)erasedListener;
+  rewriter.setListener(listener);
+#endif
+
+  PatternApplicator applicator(patterns);
+  applicator.applyDefaultCostModel();
+
+  op->walk([&](Operation *visitedOp) {
+    if (visitedOp == op)
+      return;
+
+    LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
+               visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
+               llvm::dbgs() << "\n";);
+    erasedListener.visitedOp = visitedOp;
+    if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+      LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+    }
+  });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (failed(verify(op)))
+    llvm::report_fatal_error(
+        "walk pattern rewriter result IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+}
+
+} // namespace mlir
diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 0b4d379cfb7d5f..36e605bdbff4dc 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s
 
 // CHECK-LABEL: @test_enum_attr_roundtrip
 func.func @test_enum_attr_roundtrip() -> () {
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
index f3da9a147fcb95..d619eefd721023 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
 // RUN:     -allow-unregistered-dialect --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @add_to_worklist_after_inplace_update()
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
index a362d6f99b9478..9f4a7924b725a2 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
 // RUN:     --split-input-file | FileCheck %s
 
 // Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
new file mode 100644
index 00000000000000..f7536ad3315870
--- /dev/null
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \
+// RUN:   --allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// The following op is updated in place and will not be visited again.
+// CHECK-LABEL: func.func @inplace_update()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+func.func @inplace_update() {
+  "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+  "test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> ()
+  return
+}
+
+// Check that the driver does not fold visited ops.
+// CHECK-LABEL: func.func @add_no_fold()
+// CHECK: arith.constant
+// CHECK: arith.constant
+// CHECK: %[[RES:.+]] = arith.addi
+// CHECK: return %[[RES]]
+func.func @add_no_fold() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %res = arith.addi %c0, %c1 : i32
+  return %res : i32
+}
+
+// Check that the driver handles rewriter.moveBefore.
+// CHECK-LABEL: func.func @move_before(
+// CHECK: "test.move_before_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: scf.if
+// CHECK: return
+func.func @move_before(%cond : i1) {
+  scf.if %cond {
+    "test.move_before_parent_op"() ({
+      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+    }) : () -> ()
+  }
+  return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// CHECK-LABEL: func.func @move_after(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: return
+func.func @move_after(%cond : i1) {
+  scf.if %cond {
+    "test.move_after_parent_op"() ({
+      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+    }) : () -> ()
+  }
+  return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited twice since we advance its position to the next
+// node after the parent.
+// CHECK-LABEL: func.func @move_forward_and_revisit(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: arith.addi
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+// CHECK: arith.addi
+// CHECK: return
+func.func @move_forward_and_revisit(%cond : i1) {
+  scf.if %cond {
+    "test.move_after_parent_op"() ({
+      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+    }) {advance = 1 : i32} : () -> ()
+  }
+  %a = arith.addi %cond, %cond : i1
+  %b = arith.addi %a, %cond : i1
+  return
+}
+
+// An operation inserted just after the currently visited one is not visited.
+// CHECK-LABEL: func.func @insert_just_after
+// CHECK: "test.clone_me"() ({
+// CHECK:   "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) {was_cloned} : () -> ()
+// CHECK: "test.clone_me"() ({
+// CHECK:   "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) : () -> ()
+// CHECK: return
+func.func @insert_just_after(%cond : i1) {
+  "test.clone_me"() ({
+    "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+  }) : () -> ()
+  return
+}
+
+// Check that we can replace the current operation with a new one.
+// Note that the new op won't be visited.
+// CHECK-LABEL: func.func @replace_with_new_op
+// CHECK: %[[NEW:.+]] = "test.new_op"
+// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]]
+// CHECK: return %[[RES]]
+func.func @replace_with_new_op() -> i32 {
+  %a = "test.replace_with_new_op"() : () -> (i32)
+  %res = arith.addi %a, %a : i32
+  return %res : i32
+}
diff --git a/mlir/test/Transforms/test-operation-folder-commutative.mlir b/mlir/test/Transforms/test-operation-folder-commutative.mlir
index 8ffdeb54f399dc..55556c1ec58443 100644
--- a/mlir/test/Transforms/test-operation-folder-commutative.mlir
+++ b/mlir/test/Transforms/test-operation-folder-commutative.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s
 
 // CHECK-LABEL: func @test_reorder_constants_and_match
 func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 46ee07af993cc7..3c0cd15dc6c510 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
-// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
 
 func.func @foo() -> i32 {
   %c42 = arith.constant 42 : i32
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3eade0369f7654..c54e35b3e07be3 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -13,12 +13,16 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/ScopeExit.h"
+#include <cstdint>
 
 using namespace mlir;
 using namespace test;
@@ -214,6 +218,30 @@ struct MoveBeforeParentOp : public RewritePattern {
   }
 };
 
+/// This pattern moves "test.move_after_parent_op" after the parent op.
+struct MoveAfterParentOp : public RewritePattern {
+  MoveAfterParentOp(MLIRContext *context)
+      : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not hoist past functions.
+    if (isa<FunctionOpInterface>(op->getParentOp()))
+      return failure();
+
+    int64_t moveForwardBy = 0;
+    if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
+      moveForwardBy = advanceBy.getInt();
+
+    Operation *moveAfter = op->getParentOp();
+    for (int64_t i = 0; i < moveForwardBy; ++i)
+      moveAfter = moveAfter->getNextNode();
+
+    rewriter.moveOpAfter(op, moveAfter);
+    return success();
+  }
+};
+
 /// This pattern inlines blocks that are nested in
 /// "test.inline_blocks_into_parent" into the parent block.
 struct InlineBlocksIntoParent : public RewritePattern {
@@ -286,14 +314,43 @@ struct CloneRegionBeforeOp : public RewritePattern {
   }
 };
 
-struct TestPatternDriver
-    : public PassWrapper<TestPatternDriver, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
+/// Replacing an operation may cause its users to be revisited.
+class ReplaceWithNewOp : public RewritePattern {
+public:
+  ReplaceWithNewOp(MLIRContext *context)
+      : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    Operation *newOp;
+    if (op->hasAttr("create_erase_op")) {
+      newOp = rewriter.create(
+          op->getLoc(),
+          OperationName("test.erase_op", op->getContext()).getIdentifier(),
+          ValueRange(), TypeRange());
+    } else {
+      newOp = rewriter.create(
+          op->getLoc(),
+          OperationName("test.new_op", op->getContext()).getIdentifier(),
+          op->getOperands(), op->getResultTypes());
+    }
+    // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+    // A "notifyOperationReplaced" callback is triggered in either case.
+    rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct TestGreedyPatternDriver
+    : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
 
-  TestPatternDriver() = default;
-  TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
+  TestGreedyPatternDriver() = default;
+  TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
+      : PassWrapper(other) {}
 
-  StringRef getArgument() const final { return "test-patterns"; }
+  StringRef getArgument() const final { return "test-greedy-patterns"; }
   StringRef getDescription() const final { return "Run test dialect patterns"; }
   void runOnOperation() override {
     mlir::RewritePatternSet patterns(&getContext());
@@ -470,34 +527,6 @@ struct TestStrictPatternDriver
     }
   };
 
-  // Replace an operation may introduce the re-visiting of its users.
-  class ReplaceWithNewOp : public RewritePattern {
-  public:
-    ReplaceWithNewOp(MLIRContext *context)
-        : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
-
-    LogicalResult matchAndRewrite(Operation *op,
-                                  PatternRewriter &rewriter) const override {
-      Operation *newOp;
-      if (op->hasAttr("create_erase_op")) {
-        newOp = rewriter.create(
-            op->getLoc(),
-            OperationName("test.erase_op", op->getContext()).getIdentifier(),
-            ValueRange(), TypeRange());
-      } else {
-        newOp = rewriter.create(
-            op->getLoc(),
-            OperationName("test.new_op", op->getContext()).getIdentifier(),
-            op->getOperands(), op->getResultTypes());
-      }
-      // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
-      // A "notifyOperationReplaced" callback is triggered in either case.
-      rewriter.replaceAllOpUsesWith(op, newOp->getResults());
-      rewriter.eraseOp(op);
-      return success();
-    }
-  };
-
   // Remove an operation may introduce the re-visiting of its operands.
   class EraseOp : public RewritePattern {
   public:
@@ -560,6 +589,38 @@ struct TestStrictPatternDriver
   };
 };
 
+struct TestWalkPatternDriver final
+    : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
+
+  TestWalkPatternDriver() = default;
+  TestWalkPatternDriver(const TestWalkPatternDriver &other)
+      : PassWrapper(other) {}
+
+  StringRef getArgument() const override {
+    return "test-walk-pattern-rewrite-driver";
+  }
+  StringRef getDescription() const override {
+    return "Run test walk pattern rewrite driver";
+  }
+  void runOnOperation() override {
+    mlir::RewritePatternSet patterns(&getContext());
+
+    // Patterns for testing the WalkPatternRewriteDriver.
+    patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
+                 MoveAfterParentOp, CloneOp, ReplaceWithNewOp>(&getContext());
+
+    DumpNotifications dumpListener;
+    walkAndApplyPatterns(getOperation(), std::move(patterns),
+                         dumpNotifications ? &dumpListener : nullptr);
+  }
+
+  Option<bool> dumpNotifications{
+      *this, "dump-notifications",
+      llvm::cl::desc("Print rewrite listener notifications"),
+      llvm::cl::init(false)};
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1978,8 +2039,9 @@ void registerPatternsTestPass() {
 
   PassRegistration<TestDerivedAttributeDriver>();
 
-  PassRegistration<TestPatternDriver>();
+  PassRegistration<TestGreedyPatternDriver>();
   PassRegistration<TestStrictPatternDriver>();
+  PassRegistration<TestWalkPatternDriver>();
 
   PassRegistration<TestLegalizePatternDriver>([] {
     return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5ff8710b937701..60d46e676d2a33 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
 
 // CHECK-LABEL: verifyFusedLocs
 func.func @verifyFusedLocs(%arg0 : i32) -> i32 {

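The `ErasedOpsListener` in the patch only permits erasing the op currently being visited. The following toy C++ model illustrates why. The walk computes the next op before invoking the pattern (the `make_early_inc_range` behavior mentioned in the tests), so erasing the current op is safe, but erasing the next one would invalidate the traversal. `ToyBlock` and `walkEarlyInc` are hypothetical. The model throws `std::logic_error` where the real driver calls `llvm::report_fatal_error`, and the real check only runs when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <stdexcept>
#include <string>

// Hypothetical toy block: ops are strings in a std::list, plus a pointer to
// the op the walk is currently visiting (mirrors ErasedOpsListener::visitedOp).
struct ToyBlock {
  std::list<std::string> ops;
  const std::string *visited = nullptr;

  // Analogue of rewriter.eraseOp() with the guard installed: only the
  // currently visited op may be erased.
  void eraseOp(std::list<std::string>::iterator it) {
    if (&*it != visited)
      throw std::logic_error("erasure is only supported for matched ops");
    ops.erase(it);
  }
};

// Forward walk with early increment: the iterator is advanced *before* the
// callback runs, so erasing the current op does not break the traversal.
template <typename Fn>
void walkEarlyInc(ToyBlock &block, Fn &&fn) {
  for (auto it = block.ops.begin(); it != block.ops.end();) {
    auto current = it++;
    block.visited = &*current;
    fn(block, current);
  }
  block.visited = nullptr;
}
```

Erasing the matched op leaves the walk intact. An attempt to erase the following op is rejected before it can invalidate the already-advanced iterator.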