[Mlir-commits] [mlir] [mlir] Add fast walk-based pattern rewrite driver (PR #113825)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Oct 31 08:09:03 PDT 2024
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/113825
>From 883902df7f742e7e6f5f6f32e44743f58793bdfc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 13:45:09 -0400
Subject: [PATCH 01/12] [mlir] Add walk pattern rewrite driver
---
mlir/docs/PatternRewriter.md | 32 ++++-
.../Transforms/WalkPatternRewriteDriver.h | 37 +++++
.../Transforms/UnsignedWhenEquivalent.cpp | 12 +-
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 +
.../Utils/WalkPatternRewriteDriver.cpp | 86 ++++++++++++
mlir/test/IR/enum-attr-roundtrip.mlir | 2 +-
...eedy-pattern-rewrite-driver-bottom-up.mlir | 2 +-
...reedy-pattern-rewrite-driver-top-down.mlir | 2 +-
.../IR/test-walk-pattern-rewrite-driver.mlir | 107 ++++++++++++++
.../test-operation-folder-commutative.mlir | 2 +-
.../Transforms/test-operation-folder.mlir | 4 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 132 +++++++++++++-----
mlir/test/mlir-tblgen/pattern.mlir | 2 +-
13 files changed, 366 insertions(+), 55 deletions(-)
create mode 100644 mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
create mode 100644 mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
create mode 100644 mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 0ba76199874cc3..1d036dc8ce0f6a 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -320,15 +320,33 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
framework also provides support for type conversions. More information on this
driver can be found [here](DialectConversion.md).
+### Walk Pattern Rewrite Driver
+
+This is a fast and simple driver that walks the given op and applies patterns
+that locally have the most benefit. The benefit of a pattern is decided solely
+by the benefit specified on the pattern, and the relative order of the pattern
+within the pattern list (when two patterns have the same local benefit).
+
+This driver does not (re)visit modified or newly replaced ops, and does not
+allow for progressive rewrites of the same op. Op erasure is only supported for
+the currently matched op. If your pattern-set requires these, consider using the
+Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
+
+This driver is exposed using the `walkAndApplyPatterns` function.
+
+#### Debugging
+
+You can debug the Walk Pattern Rewrite Driver by passing the
+`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
+ops.
+
### Greedy Pattern Rewrite Driver
This driver processes ops in a worklist-driven fashion and greedily applies the
-patterns that locally have the most benefit. The benefit of a pattern is decided
-solely by the benefit specified on the pattern, and the relative order of the
-pattern within the pattern list (when two patterns have the same local benefit).
-Patterns are iteratively applied to operations until a fixed point is reached or
-until the configurable maximum number of iterations exhausted, at which point
-the driver finishes.
+patterns that locally have the most benefit (same as the Walk Pattern Rewrite
+Driver). Patterns are iteratively applied to operations until a fixed point is
+reached or until the configurable maximum number of iterations exhausted, at
+which point the driver finishes.
This driver comes in two fashions:
@@ -368,7 +386,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize) in MLIR.
-### Debugging
+#### Debugging
To debug the execution of the greedy pattern rewrite driver,
`-debug-only=greedy-rewriter` may be used. This command line flag activates
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
new file mode 100644
index 00000000000000..6d62ae3dd43dc1
--- /dev/null
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -0,0 +1,37 @@
+//===- WALKPATTERNREWRITEDRIVER.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares a helper function to walk the given op and apply rewrite patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+
+namespace mlir {
+
+/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
+/// given operation by walking it and applying the highest benefit patterns.
+/// This rewriter *does not* wait until a fixpoint is reached and *does not*
+/// visit modified or newly replaced ops. Also *does not* perform folding or
+/// dead-code elimination.
+///
+/// This is intended as the simplest and most lightweight pattern rewriter in
+/// cases when a simple walk gets the job done.
+///
+/// Note: Does not apply patterns to the given operation itself.
+void walkAndApplyPatterns(Operation *op,
+ const FrozenRewritePatternSet &patterns,
+ RewriterBase::Listener *listener = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index bebe0b5a7c0b61..ad455aaf987fc6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -14,7 +14,11 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace arith {
@@ -157,11 +161,7 @@ struct ArithUnsignedWhenEquivalentPass
RewritePatternSet patterns(ctx);
populateUnsignedWhenEquivalentPatterns(patterns, solver);
- GreedyRewriteConfig config;
- config.listener = &listener;
-
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
- signalPassFailure();
+ walkAndApplyPatterns(op, std::move(patterns), &listener);
}
};
} // end anonymous namespace
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index eb588640dbf83a..72eb34f36cf5f6 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
+ WalkPatternRewriteDriver.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
new file mode 100644
index 00000000000000..2d3aa5fc7d15c7
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -0,0 +1,86 @@
+//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements mlir::walkAndApplyPatterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "walk-rewriter"
+
+namespace mlir {
+
+namespace {
+// Forwarding listener to guard against unsupported erasures. Because we use
+// walk-based pattern application, erasing the op from the *next* iteration
+// (e.g., a user of the visited op) is not valid.
+struct ErasedOpsListener final : RewriterBase::ForwardingListener {
+ using RewriterBase::ForwardingListener::ForwardingListener;
+
+ void notifyOperationErased(Operation *op) override {
+ if (op != visitedOp)
+ llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
+ "erasure is only supported for matched ops");
+
+ ForwardingListener::notifyOperationErased(op);
+ }
+
+ Operation *visitedOp = nullptr;
+};
+} // namespace
+
+void walkAndApplyPatterns(Operation *op,
+ const FrozenRewritePatternSet &patterns,
+ RewriterBase::Listener *listener) {
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (failed(verify(op)))
+ llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+ PatternRewriter rewriter(op->getContext());
+ ErasedOpsListener erasedListener(listener);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ rewriter.setListener(&erasedListener);
+#else
+ (void)erasedListener;
+ rewriter.setListener(listener);
+#endif
+
+ PatternApplicator applicator(patterns);
+ applicator.applyDefaultCostModel();
+
+ op->walk([&](Operation *visitedOp) {
+ if (visitedOp == op)
+ return;
+
+ LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
+ visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
+ llvm::dbgs() << "\n";);
+ erasedListener.visitedOp = visitedOp;
+ if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+ }
+ });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (failed(verify(op)))
+ llvm::report_fatal_error(
+ "walk pattern rewriter result IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+}
+
+} // namespace mlir
diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 0b4d379cfb7d5f..36e605bdbff4dc 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s
// CHECK-LABEL: @test_enum_attr_roundtrip
func.func @test_enum_attr_roundtrip() -> () {
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
index f3da9a147fcb95..d619eefd721023 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s
// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
index a362d6f99b9478..9f4a7924b725a2 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
// RUN: --split-input-file | FileCheck %s
// Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
new file mode 100644
index 00000000000000..f7536ad3315870
--- /dev/null
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \
+// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// The following op is updated in-place and will not be added back to the worklist.
+// CHECK-LABEL: func.func @inplace_update()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+func.func @inplace_update() {
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ "test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> ()
+ return
+}
+
+// Check that the driver does not fold visited ops.
+// CHECK-LABEL: func.func @add_no_fold()
+// CHECK: arith.constant
+// CHECK: arith.constant
+// CHECK: %[[RES:.+]] = arith.addi
+// CHECK: return %[[RES]]
+func.func @add_no_fold() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %res = arith.addi %c0, %c1 : i32
+ return %res : i32
+}
+
+// Check that the driver handles rewriter.moveBefore.
+// CHECK-LABEL: func.func @move_before(
+// CHECK: "test.move_before_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: scf.if
+// CHECK: return
+func.func @move_before(%cond : i1) {
+ scf.if %cond {
+ "test.move_before_parent_op"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) : () -> ()
+ }
+ return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// CHECK-LABEL: func.func @move_after(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: return
+func.func @move_after(%cond : i1) {
+ scf.if %cond {
+ "test.move_after_parent_op"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) : () -> ()
+ }
+ return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited twice since we advance its position to the next
+// node after the parent.
+// CHECK-LABEL: func.func @move_forward_and_revisit(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: arith.addi
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+// CHECK: arith.addi
+// CHECK: return
+func.func @move_forward_and_revisit(%cond : i1) {
+ scf.if %cond {
+ "test.move_after_parent_op"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) {advance = 1 : i32} : () -> ()
+ }
+ %a = arith.addi %cond, %cond : i1
+ %b = arith.addi %a, %cond : i1
+ return
+}
+
+// Operation inserted just after the currently visited one won't be visited.
+// CHECK-LABEL: func.func @insert_just_after
+// CHECK: "test.clone_me"() ({
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) {was_cloned} : () -> ()
+// CHECK: "test.clone_me"() ({
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) : () -> ()
+// CHECK: return
+func.func @insert_just_after(%cond : i1) {
+ "test.clone_me"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) : () -> ()
+ return
+}
+
+// Check that we can replace the current operation with a new one.
+// Note that the new op won't be visited.
+// CHECK-LABEL: func.func @replace_with_new_op
+// CHECK: %[[NEW:.+]] = "test.new_op"
+// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]]
+// CHECK: return %[[RES]]
+func.func @replace_with_new_op() -> i32 {
+ %a = "test.replace_with_new_op"() : () -> (i32)
+ %res = arith.addi %a, %a : i32
+ return %res : i32
+}
diff --git a/mlir/test/Transforms/test-operation-folder-commutative.mlir b/mlir/test/Transforms/test-operation-folder-commutative.mlir
index 8ffdeb54f399dc..55556c1ec58443 100644
--- a/mlir/test/Transforms/test-operation-folder-commutative.mlir
+++ b/mlir/test/Transforms/test-operation-folder-commutative.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s
// CHECK-LABEL: func @test_reorder_constants_and_match
func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 46ee07af993cc7..3c0cd15dc6c510 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
-// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
func.func @foo() -> i32 {
%c42 = arith.constant 42 : i32
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3eade0369f7654..c54e35b3e07be3 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -13,12 +13,16 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
+#include <cstdint>
using namespace mlir;
using namespace test;
@@ -214,6 +218,30 @@ struct MoveBeforeParentOp : public RewritePattern {
}
};
+/// This pattern moves "test.move_after_parent_op" after the parent op.
+struct MoveAfterParentOp : public RewritePattern {
+ MoveAfterParentOp(MLIRContext *context)
+ : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not hoist past functions.
+ if (isa<FunctionOpInterface>(op->getParentOp()))
+ return failure();
+
+ int64_t moveForwardBy = 0;
+ if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
+ moveForwardBy = advanceBy.getInt();
+
+ Operation *moveAfter = op->getParentOp();
+ for (int64_t i = 0; i < moveForwardBy; ++i)
+ moveAfter = moveAfter->getNextNode();
+
+ rewriter.moveOpAfter(op, moveAfter);
+ return success();
+ }
+};
+
/// This pattern inlines blocks that are nested in
/// "test.inline_blocks_into_parent" into the parent block.
struct InlineBlocksIntoParent : public RewritePattern {
@@ -286,14 +314,43 @@ struct CloneRegionBeforeOp : public RewritePattern {
}
};
-struct TestPatternDriver
- : public PassWrapper<TestPatternDriver, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
+/// Replace an operation may introduce the re-visiting of its users.
+class ReplaceWithNewOp : public RewritePattern {
+public:
+ ReplaceWithNewOp(MLIRContext *context)
+ : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ Operation *newOp;
+ if (op->hasAttr("create_erase_op")) {
+ newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.erase_op", op->getContext()).getIdentifier(),
+ ValueRange(), TypeRange());
+ } else {
+ newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.new_op", op->getContext()).getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ }
+ // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+ // A "notifyOperationReplaced" callback is triggered in either case.
+ rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct TestGreedyPatternDriver
+ : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
- TestPatternDriver() = default;
- TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
+ TestGreedyPatternDriver() = default;
+ TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
+ : PassWrapper(other) {}
- StringRef getArgument() const final { return "test-patterns"; }
+ StringRef getArgument() const final { return "test-greedy-patterns"; }
StringRef getDescription() const final { return "Run test dialect patterns"; }
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
@@ -470,34 +527,6 @@ struct TestStrictPatternDriver
}
};
- // Replace an operation may introduce the re-visiting of its users.
- class ReplaceWithNewOp : public RewritePattern {
- public:
- ReplaceWithNewOp(MLIRContext *context)
- : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- Operation *newOp;
- if (op->hasAttr("create_erase_op")) {
- newOp = rewriter.create(
- op->getLoc(),
- OperationName("test.erase_op", op->getContext()).getIdentifier(),
- ValueRange(), TypeRange());
- } else {
- newOp = rewriter.create(
- op->getLoc(),
- OperationName("test.new_op", op->getContext()).getIdentifier(),
- op->getOperands(), op->getResultTypes());
- }
- // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
- // A "notifyOperationReplaced" callback is triggered in either case.
- rewriter.replaceAllOpUsesWith(op, newOp->getResults());
- rewriter.eraseOp(op);
- return success();
- }
- };
-
// Remove an operation may introduce the re-visiting of its operands.
class EraseOp : public RewritePattern {
public:
@@ -560,6 +589,38 @@ struct TestStrictPatternDriver
};
};
+struct TestWalkPatternDriver final
+ : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
+
+ TestWalkPatternDriver() = default;
+ TestWalkPatternDriver(const TestWalkPatternDriver &other)
+ : PassWrapper(other) {}
+
+ StringRef getArgument() const override {
+ return "test-walk-pattern-rewrite-driver";
+ }
+ StringRef getDescription() const override {
+ return "Run test greedy pattern rewrite driver";
+ }
+ void runOnOperation() override {
+ mlir::RewritePatternSet patterns(&getContext());
+
+ // Patterns for testing the WalkPatternRewriteDriver.
+ patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
+ MoveAfterParentOp, CloneOp, ReplaceWithNewOp>(&getContext());
+
+ DumpNotifications dumpListener;
+ walkAndApplyPatterns(getOperation(), std::move(patterns),
+ dumpNotifications ? &dumpListener : nullptr);
+ }
+
+ Option<bool> dumpNotifications{
+ *this, "dump-notifications",
+ llvm::cl::desc("Print rewrite listener notifications"),
+ llvm::cl::init(false)};
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1978,8 +2039,9 @@ void registerPatternsTestPass() {
PassRegistration<TestDerivedAttributeDriver>();
- PassRegistration<TestPatternDriver>();
+ PassRegistration<TestGreedyPatternDriver>();
PassRegistration<TestStrictPatternDriver>();
+ PassRegistration<TestWalkPatternDriver>();
PassRegistration<TestLegalizePatternDriver>([] {
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5ff8710b937701..60d46e676d2a33 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
// CHECK-LABEL: verifyFusedLocs
func.func @verifyFusedLocs(%arg0 : i32) -> i32 {
>From 34d440ab2192895d64306938d4cd3929b7f83363 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 18:38:41 -0400
Subject: [PATCH 02/12] Iterate over regions
---
.../Utils/WalkPatternRewriteDriver.cpp | 23 +++++++++----------
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 2d3aa5fc7d15c7..864e86193a0d6c 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -63,18 +63,17 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
- op->walk([&](Operation *visitedOp) {
- if (visitedOp == op)
- return;
-
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
- visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
- erasedListener.visitedOp = visitedOp;
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ for (Region &region : op->getRegions()) {
+ region.walk([&](Operation *visitedOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
+ llvm::dbgs(), OpPrintingFlags().skipRegions());
+ llvm::dbgs() << "\n";);
+ erasedListener.visitedOp = visitedOp;
+ if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+ }
+ });
+ }
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
>From ccc4e08f84136180304e396af0c86deb99079662 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:03:10 -0400
Subject: [PATCH 03/12] Add action
---
mlir/docs/ActionTracing.md | 2 +-
.../Utils/WalkPatternRewriteDriver.cpp | 40 +++++++++++++------
2 files changed, 29 insertions(+), 13 deletions(-)
diff --git a/mlir/docs/ActionTracing.md b/mlir/docs/ActionTracing.md
index 978fdbbe54d81c..984516d5c5e7e2 100644
--- a/mlir/docs/ActionTracing.md
+++ b/mlir/docs/ActionTracing.md
@@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the
```c++
/// A custom Action can be defined minimally by deriving from
-/// `tracing::ActionImpl`. It can has any members!
+/// `tracing::ActionImpl`. It can have any members!
class MyCustomAction : public tracing::ActionImpl<MyCustomAction> {
public:
using Base = tracing::ActionImpl<MyCustomAction>;
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 864e86193a0d6c..eb450c4dacdd45 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -12,11 +12,13 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -25,9 +27,18 @@
namespace mlir {
namespace {
+struct WalkAndApplyPatternsAction final
+ : tracing::ActionImpl<WalkAndApplyPatternsAction> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
+ using ActionImpl::ActionImpl;
+ static constexpr StringLiteral tag = "walk-and-apply-patterns";
+ void print(raw_ostream &os) const override { os << tag; }
+};
+
// Forwarding listener to guard against unsupported erasures. Because we use
// walk-based pattern application, erasing the op from the *next* iteration
// (e.g., a user of the visited op) is not valid.
+// Note that this is only used with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
using RewriterBase::ForwardingListener::ForwardingListener;
@@ -51,7 +62,8 @@ void walkAndApplyPatterns(Operation *op,
llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- PatternRewriter rewriter(op->getContext());
+ MLIRContext *ctx = op->getContext();
+ PatternRewriter rewriter(ctx);
ErasedOpsListener erasedListener(listener);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
rewriter.setListener(&erasedListener);
@@ -63,17 +75,21 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
- for (Region &region : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
- erasedListener.visitedOp = visitedOp;
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
- }
+ ctx->executeAction<WalkAndApplyPatternsAction>(
+ [&] {
+ for (Region &region : op->getRegions()) {
+ region.walk([&](Operation *visitedOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
+ llvm::dbgs(), OpPrintingFlags().skipRegions());
+ llvm::dbgs() << "\n";);
+ erasedListener.visitedOp = visitedOp;
+ if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+ }
+ });
+ }
+ },
+ {op});
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
>From 62cb821e2365a6b33bab21134f0cedfb6575465b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:07:03 -0400
Subject: [PATCH 04/12] Coding style
---
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index eb450c4dacdd45..0258064a3436b7 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -70,7 +70,7 @@ void walkAndApplyPatterns(Operation *op,
#else
(void)erasedListener;
rewriter.setListener(listener);
-#endif
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
>From aaa9f00802663f7a35ff7bd0bd7eef9161018533 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:40:22 -0400
Subject: [PATCH 05/12] Cleanup
---
mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp | 4 ----
1 file changed, 4 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index ad455aaf987fc6..986450fb128aed 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,10 +13,6 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Visitors.h"
-#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
>From 2b1205f6ce7c9f70c8ec7a58ccaf2ed7aa5a2874 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:41:11 -0400
Subject: [PATCH 06/12] Cleanup
---
mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index 986450fb128aed..8922e93e399f9f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,8 +13,8 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace arith {
>From d70b289e7acd88b12a502d4c053c1e2ef775f5f9 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:42:38 -0400
Subject: [PATCH 07/12] Typo
---
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c54e35b3e07be3..dd801513adbe66 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -601,7 +601,7 @@ struct TestWalkPatternDriver final
return "test-walk-pattern-rewrite-driver";
}
StringRef getDescription() const override {
- return "Run test greedy pattern rewrite driver";
+ return "Run test walk pattern rewrite driver";
}
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
>From 9d295d840b8010977ad42f46653c20cd888a7b6b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 20:51:31 -0400
Subject: [PATCH 08/12] Clarify docs
---
mlir/docs/PatternRewriter.md | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 1d036dc8ce0f6a..315a8dc0d1620d 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -327,9 +327,12 @@ that locally have the most benefit. The benefit of a pattern is decided solely
by the benefit specified on the pattern, and the relative order of the pattern
within the pattern list (when two patterns have the same local benefit).
+The driver performs a post-order traversal. Note that it walks regions of the
+given op but does not visit the op.
+
This driver does not (re)visit modified or newly replaced ops, and does not
allow for progressive rewrites of the same op. Op erasure is only supported for
-the currently matched op. If your pattern-set requires these, consider using the
+the currently matched op. If your pattern set requires these, consider using the
Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
This driver is exposed using the `walkAndApplyPatterns` function.
>From 63d0a9b72ff1b9fa31e274acbc04738181979a17 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 30 Oct 2024 19:55:05 -0400
Subject: [PATCH 09/12] Guard forwarding listener
---
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 0258064a3436b7..b1ce47805bf435 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -18,7 +18,6 @@
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -35,6 +34,7 @@ struct WalkAndApplyPatternsAction final
void print(raw_ostream &os) const override { os << tag; }
};
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Forwarding listener to guard against unsupported erasures. Because we use
// walk-based pattern application, erasing the op from the *next* iteration
// (e.g., a user of the visited op) is not valid.
@@ -52,6 +52,7 @@ struct ErasedOpsListener final : RewriterBase::ForwardingListener {
Operation *visitedOp = nullptr;
};
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
} // namespace
void walkAndApplyPatterns(Operation *op,
@@ -64,11 +65,10 @@ void walkAndApplyPatterns(Operation *op,
MLIRContext *ctx = op->getContext();
PatternRewriter rewriter(ctx);
- ErasedOpsListener erasedListener(listener);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ ErasedOpsListener erasedListener(listener);
rewriter.setListener(&erasedListener);
#else
- (void)erasedListener;
rewriter.setListener(listener);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -82,7 +82,9 @@ void walkAndApplyPatterns(Operation *op,
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
erasedListener.visitedOp = visitedOp;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
}
>From 4a524aec9e1af1b086dd4e797153fc9d351401c2 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 30 Oct 2024 22:15:19 -0400
Subject: [PATCH 10/12] Guard against invalid block erasure. Support forwarding
to null listeners.
---
mlir/docs/PatternRewriter.md | 11 +++++--
mlir/include/mlir/IR/PatternMatch.h | 28 +++++++++++-------
.../Utils/WalkPatternRewriteDriver.cpp | 29 ++++++++++++++-----
.../IR/test-walk-pattern-rewrite-driver.mlir | 14 +++++++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 25 +++++++++++++++-
5 files changed, 84 insertions(+), 23 deletions(-)
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 315a8dc0d1620d..c61ceaf81681e2 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -331,12 +331,17 @@ The driver performs a post-order traversal. Note that it walks regions of the
given op but does not visit the op.
This driver does not (re)visit modified or newly replaced ops, and does not
-allow for progressive rewrites of the same op. Op erasure is only supported for
-the currently matched op. If your pattern set requires these, consider using the
-Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
+allow for progressive rewrites of the same op. Op and block erasure is only
+supported for the currently matched op and its descendant. If your pattern
+set requires these, consider using the Greedy Pattern Rewrite Driver instead,
+at the expense of extra overhead.
This driver is exposed using the `walkAndApplyPatterns` function.
+Note: This driver listens for IR changes via the callbacks provided by
+`RewriterBase`. It is important that patterns announce all IR changes to the
+rewriter and do not bypass the rewriter API by modifying ops directly.
+
#### Debugging
You can debug the Walk Pattern Rewrite Driver by passing the
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 896fdf1c899e3d..2ab0405043a546 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -461,54 +461,60 @@ class RewriterBase : public OpBuilder {
/// struct can be used as a base to create listener chains, so that multiple
/// listeners can be notified of IR changes.
struct ForwardingListener : public RewriterBase::Listener {
- ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
+ ForwardingListener(OpBuilder::Listener *listener)
+ : listener(listener),
+ rewriteListener(
+ dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
- listener->notifyOperationInserted(op, previous);
+ if (listener)
+ listener->notifyOperationInserted(op, previous);
}
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override {
- listener->notifyBlockInserted(block, previous, previousIt);
+ if (listener)
+ listener->notifyBlockInserted(block, previous, previousIt);
}
void notifyBlockErased(Block *block) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyBlockErased(block);
}
void notifyOperationModified(Operation *op) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyOperationModified(op);
}
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyOperationReplaced(op, newOp);
}
void notifyOperationReplaced(Operation *op,
ValueRange replacement) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyOperationReplaced(op, replacement);
}
void notifyOperationErased(Operation *op) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyOperationErased(op);
}
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyPatternBegin(pattern, op);
}
void notifyPatternEnd(const Pattern &pattern,
LogicalResult status) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyPatternEnd(pattern, status);
}
void notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
- if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ if (rewriteListener)
rewriteListener->notifyMatchFailure(loc, reasonCallback);
}
private:
OpBuilder::Listener *listener;
+ RewriterBase::Listener *rewriteListener;
};
/// Move the blocks that belong to "region" before the given position in
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index b1ce47805bf435..efbc646f2ef276 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -35,21 +35,34 @@ struct WalkAndApplyPatternsAction final
};
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-// Forwarding listener to guard against unsupported erasures. Because we use
-// walk-based pattern application, erasing the op from the *next* iteration
-// (e.g., a user of the visited op) is not valid.
-// Note that this is only used with expensive pattern API checks.
+// Forwarding listener to guard against unsupported erasures of non-descendant
+// ops/blocks. Because we use walk-based pattern application, erasing the
+// op/block from the *next* iteration (e.g., a user of the visited op) is not
+// valid. Note that this is only used with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
using RewriterBase::ForwardingListener::ForwardingListener;
void notifyOperationErased(Operation *op) override {
- if (op != visitedOp)
- llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
- "erasure is only supported for matched ops");
-
+ checkErasure(op);
ForwardingListener::notifyOperationErased(op);
}
+ void notifyBlockErased(Block *block) override {
+ checkErasure(block->getParentOp());
+ ForwardingListener::notifyBlockErased(block);
+ }
+
+ void checkErasure(Operation *op) const {
+ Operation *ancestorOp = op;
+ while (ancestorOp && ancestorOp != visitedOp)
+ ancestorOp = ancestorOp->getParentOp();
+
+ if (ancestorOp != visitedOp)
+ llvm::report_fatal_error(
+ "unsupported erased in WalkPatternRewriter; "
+ "erasure is only supported for matched ops and their descendants");
+ }
+
Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index f7536ad3315870..5423552a8ef1dd 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -105,3 +105,17 @@ func.func @replace_with_new_op() -> i32 {
%res = arith.addi %a, %a : i32
return %res : i32
}
+
+// Check that we can erase nested blocks.
+// CHECK-LABEL: func.func @erase_nested_block
+// CHECK: %[[RES:.+]] = "test.erase_first_block"
+// CHECK-NEXT: foo.bar
+// CHECK: return %[[RES]]
+func.func @erase_nested_block() -> i32 {
+ %a = "test.erase_first_block"() ({
+ "foo.foo"() : () -> ()
+ ^bb1:
+ "foo.bar"() : () -> ()
+ }): () -> (i32)
+ return %a : i32
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index dd801513adbe66..d97f3b41f2ef29 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -342,6 +342,28 @@ class ReplaceWithNewOp : public RewritePattern {
}
};
+/// Erases the first child block of the matched "test.erase_first_block"
+/// operation.
+class EraseFirstBlock : public RewritePattern {
+public:
+ EraseFirstBlock(MLIRContext *context)
+ : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ llvm::errs() << "Num regions: " << op->getNumRegions() << "\n";
+ for (Region &r : op->getRegions()) {
+ for (Block &b : r.getBlocks()) {
+ rewriter.eraseBlock(&b);
+ llvm::errs() << "Erasing block: " << b << "\n";
+ return success();
+ }
+ }
+
+ return failure();
+ }
+};
+
struct TestGreedyPatternDriver
: public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
@@ -608,7 +630,8 @@ struct TestWalkPatternDriver final
// Patterns for testing the WalkPatternRewriteDriver.
patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
- MoveAfterParentOp, CloneOp, ReplaceWithNewOp>(&getContext());
+ MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
+ &getContext());
DumpNotifications dumpListener;
walkAndApplyPatterns(getOperation(), std::move(patterns),
>From 5e85a8fadcd67915d5409f7cbecc9145290f3a37 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 30 Oct 2024 22:25:15 -0400
Subject: [PATCH 11/12] Add missing newline
---
mlir/test/IR/test-walk-pattern-rewrite-driver.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 5423552a8ef1dd..02f7e60671c9b3 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -118,4 +118,4 @@ func.func @erase_nested_block() -> i32 {
"foo.bar"() : () -> ()
}): () -> (i32)
return %a : i32
-}
\ No newline at end of file
+}
>From dd4ed14676d274c4672ee1f72cd09870b52148d7 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 31 Oct 2024 11:08:19 -0400
Subject: [PATCH 12/12] Fix typo
---
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index efbc646f2ef276..ee5c642c943c45 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -59,7 +59,7 @@ struct ErasedOpsListener final : RewriterBase::ForwardingListener {
if (ancestorOp != visitedOp)
llvm::report_fatal_error(
- "unsupported erased in WalkPatternRewriter; "
+ "unsupported erasure in WalkPatternRewriter; "
"erasure is only supported for matched ops and their descendants");
}
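Note on the erasure guard in patch 10: checkErasure accepts an erasure only when the erased op is the visited op or is nested somewhere beneath it. The standalone sketch below models that ancestor walk. It is a toy under stated assumptions, not the MLIR implementation: `Node` and `isVisitedOrDescendant` are hypothetical names, and `Node` stands in for `Operation`, carrying only a parent link.

```cpp
#include <cassert>

// Hypothetical stand-in for mlir::Operation; only the parent link matters.
struct Node {
  Node *parent = nullptr;
};

// Returns true if `op` is `visited` itself or nested at any depth under it.
// Same shape as the loop in checkErasure: climb the parent chain until we
// reach `visited` or run off the top of the tree (nullptr).
bool isVisitedOrDescendant(const Node *op, const Node *visited) {
  const Node *ancestor = op;
  while (ancestor && ancestor != visited)
    ancestor = ancestor->parent;
  return ancestor == visited;
}
```

In the patch, a failed check ends in llvm::report_fatal_error; the sketch returns a bool so the condition can be exercised in isolation. A sibling or parent of the visited op fails the check, which matches the documented restriction that only the matched op and its descendants may be erased.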