[Mlir-commits] [mlir] [mlir] Add fast walk-based pattern rewrite driver (PR #113825)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Oct 27 16:07:13 PDT 2024
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/113825
From 883902df7f742e7e6f5f6f32e44743f58793bdfc Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 13:45:09 -0400
Subject: [PATCH 1/4] [mlir] Add walk pattern rewrite driver
---
mlir/docs/PatternRewriter.md | 32 ++++-
.../Transforms/WalkPatternRewriteDriver.h | 37 +++++
.../Transforms/UnsignedWhenEquivalent.cpp | 12 +-
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 +
.../Utils/WalkPatternRewriteDriver.cpp | 86 ++++++++++++
mlir/test/IR/enum-attr-roundtrip.mlir | 2 +-
...eedy-pattern-rewrite-driver-bottom-up.mlir | 2 +-
...reedy-pattern-rewrite-driver-top-down.mlir | 2 +-
.../IR/test-walk-pattern-rewrite-driver.mlir | 107 ++++++++++++++
.../test-operation-folder-commutative.mlir | 2 +-
.../Transforms/test-operation-folder.mlir | 4 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 132 +++++++++++++-----
mlir/test/mlir-tblgen/pattern.mlir | 2 +-
13 files changed, 366 insertions(+), 55 deletions(-)
create mode 100644 mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
create mode 100644 mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
create mode 100644 mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 0ba76199874cc3..1d036dc8ce0f6a 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -320,15 +320,33 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
framework also provides support for type conversions. More information on this
driver can be found [here](DialectConversion.md).
+### Walk Pattern Rewrite Driver
+
+This is a fast and simple driver that walks the given op and applies patterns
+that locally have the most benefit. The benefit of a pattern is decided solely
+by the benefit specified on the pattern, and the relative order of the pattern
+within the pattern list (when two patterns have the same local benefit).
+
+This driver does not (re)visit modified or newly replaced ops, and does not
+allow for progressive rewrites of the same op. Op erasure is only supported for
+the currently matched op. It also does not fold ops or perform dead-code
+elimination, and it does not apply patterns to the root op itself. If your
+pattern set requires any of these, consider using the Greedy Pattern Rewrite
+Driver instead, at the expense of extra overhead.
+
+This driver is exposed using the `walkAndApplyPatterns` function.
+
+#### Debugging
+
+You can debug the Walk Pattern Rewrite Driver by passing the
+`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
+ops.
+
### Greedy Pattern Rewrite Driver
This driver processes ops in a worklist-driven fashion and greedily applies the
-patterns that locally have the most benefit. The benefit of a pattern is decided
-solely by the benefit specified on the pattern, and the relative order of the
-pattern within the pattern list (when two patterns have the same local benefit).
-Patterns are iteratively applied to operations until a fixed point is reached or
-until the configurable maximum number of iterations exhausted, at which point
-the driver finishes.
+patterns that locally have the most benefit (same as the Walk Pattern Rewrite
+Driver). Patterns are iteratively applied to operations until a fixed point is
+reached or until the configurable maximum number of iterations is exhausted, at
+which point the driver finishes.
This driver comes in two fashions:
@@ -368,7 +386,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize) in MLIR.
-### Debugging
+#### Debugging
To debug the execution of the greedy pattern rewrite driver,
`-debug-only=greedy-rewriter` may be used. This command line flag activates
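
The documentation above contrasts a single walk with the greedy driver's fixpoint iteration. The following is a minimal standalone sketch of that difference in plain C++ with no MLIR dependency. `ToyOp`, `incrementPattern`, `walkOnce`, and `greedyToFixpoint` are hypothetical names, and the capped increment only loosely mirrors the `IncrementIntAttribute<3>` test pattern used in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an op carrying an integer attribute (not MLIR).
struct ToyOp {
  int attr;
};

// Toy pattern: bump `attr` by one while it is below `cap`.
// Returns true if it rewrote the op (i.e. the match succeeded).
bool incrementPattern(ToyOp &op, int cap) {
  if (op.attr >= cap)
    return false;
  ++op.attr;
  return true;
}

// Walk-style driver: visit every op exactly once and try the pattern once.
// A rewritten op is never revisited, so no fixpoint is sought.
void walkOnce(std::vector<ToyOp> &ops, int cap) {
  for (ToyOp &op : ops)
    incrementPattern(op, cap);
}

// Greedy-style driver: sweep repeatedly until a sweep changes nothing
// (fixpoint) or `maxIterations` sweeps have run. Returns true on convergence.
bool greedyToFixpoint(std::vector<ToyOp> &ops, int cap,
                      std::size_t maxIterations) {
  for (std::size_t iter = 0; iter < maxIterations; ++iter) {
    bool changed = false;
    for (ToyOp &op : ops)
      changed |= incrementPattern(op, cap);
    if (!changed)
      return true;
  }
  return false;
}
```

With ops `{0, 1}` and a cap of 3, `walkOnce` yields `{1, 2}`, the same shape as the `@inplace_update` CHECK lines, while `greedyToFixpoint` keeps sweeping until both values reach 3 (or reports non-convergence if the iteration cap is hit first).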
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
new file mode 100644
index 00000000000000..6d62ae3dd43dc1
--- /dev/null
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -0,0 +1,37 @@
+//===- WalkPatternRewriteDriver.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares a helper function to walk the given op and apply rewrite patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
+
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+
+namespace mlir {
+
+/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
+/// given operation by walking it and applying the highest benefit patterns.
+/// This rewriter *does not* wait until a fixpoint is reached and *does not*
+/// visit modified or newly replaced ops. It also *does not* perform folding or
+/// dead-code elimination.
+///
+/// This is intended as the simplest and most lightweight pattern rewriter in
+/// cases when a simple walk gets the job done.
+///
+/// Note: Does not apply patterns to the given operation itself.
+void walkAndApplyPatterns(Operation *op,
+ const FrozenRewritePatternSet &patterns,
+ RewriterBase::Listener *listener = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
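
The header and the docs both say the driver applies the highest-benefit pattern, with ties broken by position in the pattern list. Below is a minimal standalone sketch of that selection rule, assuming patterns are tried in descending benefit order and the first successful match wins. `ToyPattern` and `applyBestPattern` are hypothetical names; MLIR's `PatternApplicator` cost model does more than this.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy pattern: a name, a static benefit, and a match-and-rewrite
// callback that returns true on success.
struct ToyPattern {
  std::string name;
  unsigned benefit;
  std::function<bool(int &)> matchAndRewrite;
};

// Try patterns in descending benefit order; std::stable_sort keeps the
// original list order among patterns with equal benefit. Returns the name of
// the pattern that applied, or an empty string if none matched.
std::string applyBestPattern(std::vector<ToyPattern> patterns, int &op) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &p : patterns)
    if (p.matchAndRewrite(op))
      return p.name;
  return "";
}
```

A higher-benefit pattern that fails to match simply falls through to the next candidate, which is why benefit alone does not guarantee which pattern fires.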
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index bebe0b5a7c0b61..ad455aaf987fc6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -14,7 +14,11 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace arith {
@@ -157,11 +161,7 @@ struct ArithUnsignedWhenEquivalentPass
RewritePatternSet patterns(ctx);
populateUnsignedWhenEquivalentPatterns(patterns, solver);
- GreedyRewriteConfig config;
- config.listener = &listener;
-
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
- signalPassFailure();
+ walkAndApplyPatterns(op, std::move(patterns), &listener);
}
};
} // end anonymous namespace
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index eb588640dbf83a..72eb34f36cf5f6 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
+ WalkPatternRewriteDriver.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
new file mode 100644
index 00000000000000..2d3aa5fc7d15c7
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -0,0 +1,86 @@
+//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements mlir::walkAndApplyPatterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "walk-rewriter"
+
+namespace mlir {
+
+namespace {
+// Forwarding listener to guard against unsupported erasures. Because we use
+// walk-based pattern application, erasing the op from the *next* iteration
+// (e.g., a user of the visited op) is not valid.
+struct ErasedOpsListener final : RewriterBase::ForwardingListener {
+ using RewriterBase::ForwardingListener::ForwardingListener;
+
+ void notifyOperationErased(Operation *op) override {
+ if (op != visitedOp)
+ llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
+ "erasure is only supported for matched ops");
+
+ ForwardingListener::notifyOperationErased(op);
+ }
+
+ Operation *visitedOp = nullptr;
+};
+} // namespace
+
+void walkAndApplyPatterns(Operation *op,
+ const FrozenRewritePatternSet &patterns,
+ RewriterBase::Listener *listener) {
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (failed(verify(op)))
+ llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+ PatternRewriter rewriter(op->getContext());
+ ErasedOpsListener erasedListener(listener);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ rewriter.setListener(&erasedListener);
+#else
+ (void)erasedListener;
+ rewriter.setListener(listener);
+#endif
+
+ PatternApplicator applicator(patterns);
+ applicator.applyDefaultCostModel();
+
+ op->walk([&](Operation *visitedOp) {
+ if (visitedOp == op)
+ return;
+
+ LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
+ visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
+ llvm::dbgs() << "\n";);
+ erasedListener.visitedOp = visitedOp;
+ if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+ }
+ });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (failed(verify(op)))
+ llvm::report_fatal_error(
+ "walk pattern rewriter result IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+}
+
+} // namespace mlir
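
The `ErasedOpsListener` above enforces the driver's erasure rule: only the op currently being visited may be erased. Below is a standalone model of that forwarding-listener guard. `ToyListener`, `RecordingListener`, and `ErasureGuard` are hypothetical types, ops are plain integer ids, and an exception stands in for `llvm::report_fatal_error` so the check is testable.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical listener interface (stand-in for RewriterBase::Listener).
struct ToyListener {
  virtual ~ToyListener() = default;
  virtual void notifyOperationErased(int opId) = 0;
};

// Records erasures; plays the role of a user-supplied listener.
struct RecordingListener : ToyListener {
  std::vector<int> erased;
  void notifyOperationErased(int opId) override { erased.push_back(opId); }
};

// Forwarding guard: rejects erasure of anything but the visited op, then
// forwards the notification to the wrapped listener (if any).
struct ErasureGuard : ToyListener {
  explicit ErasureGuard(ToyListener *inner) : inner(inner) {}

  void notifyOperationErased(int opId) override {
    if (opId != visitedOp)
      throw std::logic_error("erasure is only supported for matched ops");
    if (inner)
      inner->notifyOperationErased(opId);
  }

  ToyListener *inner;
  int visitedOp = -1;
};
```

As in the patch, the driver updates `visitedOp` before each match attempt; in the real code the guard is only installed when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled.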
diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 0b4d379cfb7d5f..36e605bdbff4dc 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s
// CHECK-LABEL: @test_enum_attr_roundtrip
func.func @test_enum_attr_roundtrip() -> () {
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
index f3da9a147fcb95..d619eefd721023 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s
// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
diff --git a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
index a362d6f99b9478..9f4a7924b725a2 100644
--- a/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
+++ b/mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
+// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
// RUN: --split-input-file | FileCheck %s
// Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
new file mode 100644
index 00000000000000..f7536ad3315870
--- /dev/null
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \
+// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// The following op is updated in-place and will not be revisited.
+// CHECK-LABEL: func.func @inplace_update()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+func.func @inplace_update() {
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ "test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> ()
+ return
+}
+
+// Check that the driver does not fold visited ops.
+// CHECK-LABEL: func.func @add_no_fold()
+// CHECK: arith.constant
+// CHECK: arith.constant
+// CHECK: %[[RES:.+]] = arith.addi
+// CHECK: return %[[RES]]
+func.func @add_no_fold() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %res = arith.addi %c0, %c1 : i32
+ return %res : i32
+}
+
+// Check that the driver handles rewriter.moveBefore.
+// CHECK-LABEL: func.func @move_before(
+// CHECK: "test.move_before_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: scf.if
+// CHECK: return
+func.func @move_before(%cond : i1) {
+ scf.if %cond {
+ "test.move_before_parent_op"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) : () -> ()
+ }
+ return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// CHECK-LABEL: func.func @move_after(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: return
+func.func @move_after(%cond : i1) {
+ scf.if %cond {
+ "test.move_after_parent_op"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) : () -> ()
+ }
+ return
+}
+
+// Check that the driver handles rewriter.moveAfter. In this case, we expect
+// the moved op to be visited twice since we advance its position to the next
+// node after the parent.
+// CHECK-LABEL: func.func @move_forward_and_revisit(
+// CHECK: scf.if
+// CHECK: }
+// CHECK: arith.addi
+// CHECK: "test.move_after_parent_op"
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
+// CHECK: arith.addi
+// CHECK: return
+func.func @move_forward_and_revisit(%cond : i1) {
+ scf.if %cond {
+ "test.move_after_parent_op"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) {advance = 1 : i32} : () -> ()
+ }
+ %a = arith.addi %cond, %cond : i1
+ %b = arith.addi %a, %cond : i1
+ return
+}
+
+// An operation inserted just after the currently visited one won't be visited.
+// CHECK-LABEL: func.func @insert_just_after
+// CHECK: "test.clone_me"() ({
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) {was_cloned} : () -> ()
+// CHECK: "test.clone_me"() ({
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: }) : () -> ()
+// CHECK: return
+func.func @insert_just_after(%cond : i1) {
+ "test.clone_me"() ({
+ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+ }) : () -> ()
+ return
+}
+
+// Check that we can replace the current operation with a new one.
+// Note that the new op won't be visited.
+// CHECK-LABEL: func.func @replace_with_new_op
+// CHECK: %[[NEW:.+]] = "test.new_op"
+// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]]
+// CHECK: return %[[RES]]
+func.func @replace_with_new_op() -> i32 {
+ %a = "test.replace_with_new_op"() : () -> (i32)
+ %res = arith.addi %a, %a : i32
+ return %res : i32
+}
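
Several of these tests hinge on the walk using early-increment iteration (`make_early_inc_range`, as the `@move_after` comment notes): the position of the next op is captured before the callback runs. The following is a standalone `std::list` model of that behavior with a hypothetical `earlyIncWalk` helper. It illustrates why an op inserted immediately after the visited one is skipped, while an op moved past the captured next position is visited again (compare `@move_forward_and_revisit`).

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <vector>

// Hypothetical model of an early-increment walk over one block: the iterator
// to the next element is advanced *before* the callback runs, so the callback
// may insert or move elements around the current one without invalidating
// the loop. Erasing the captured next element is NOT supported.
template <typename Callback>
std::vector<std::string> earlyIncWalk(std::list<std::string> &block,
                                      Callback cb) {
  std::vector<std::string> visited;
  for (auto it = block.begin(); it != block.end();) {
    auto cur = it++; // capture the next position first
    visited.push_back(*cur);
    cb(block, cur, it); // callback sees the current and next positions
  }
  return visited;
}
```

For example, a callback that inserts `"a.clone"` at `std::next(cur)` while visiting `"a"` places it before the captured next element, so the walk never reaches it.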
diff --git a/mlir/test/Transforms/test-operation-folder-commutative.mlir b/mlir/test/Transforms/test-operation-folder-commutative.mlir
index 8ffdeb54f399dc..55556c1ec58443 100644
--- a/mlir/test/Transforms/test-operation-folder-commutative.mlir
+++ b/mlir/test/Transforms/test-operation-folder-commutative.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s
// CHECK-LABEL: func @test_reorder_constants_and_match
func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 46ee07af993cc7..3c0cd15dc6c510 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
-// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
func.func @foo() -> i32 {
%c42 = arith.constant 42 : i32
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3eade0369f7654..c54e35b3e07be3 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -13,12 +13,16 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
+#include <cstdint>
using namespace mlir;
using namespace test;
@@ -214,6 +218,30 @@ struct MoveBeforeParentOp : public RewritePattern {
}
};
+/// This pattern moves "test.move_after_parent_op" after the parent op.
+struct MoveAfterParentOp : public RewritePattern {
+ MoveAfterParentOp(MLIRContext *context)
+ : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not hoist past functions.
+ if (isa<FunctionOpInterface>(op->getParentOp()))
+ return failure();
+
+ int64_t moveForwardBy = 0;
+ if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
+ moveForwardBy = advanceBy.getInt();
+
+ Operation *moveAfter = op->getParentOp();
+ for (int64_t i = 0; i < moveForwardBy; ++i)
+ moveAfter = moveAfter->getNextNode();
+
+ rewriter.moveOpAfter(op, moveAfter);
+ return success();
+ }
+};
+
/// This pattern inlines blocks that are nested in
/// "test.inline_blocks_into_parent" into the parent block.
struct InlineBlocksIntoParent : public RewritePattern {
@@ -286,14 +314,43 @@ struct CloneRegionBeforeOp : public RewritePattern {
}
};
-struct TestPatternDriver
- : public PassWrapper<TestPatternDriver, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
+/// Replacing an operation may introduce the re-visiting of its users.
+class ReplaceWithNewOp : public RewritePattern {
+public:
+ ReplaceWithNewOp(MLIRContext *context)
+ : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ Operation *newOp;
+ if (op->hasAttr("create_erase_op")) {
+ newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.erase_op", op->getContext()).getIdentifier(),
+ ValueRange(), TypeRange());
+ } else {
+ newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.new_op", op->getContext()).getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ }
+ // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+ // A "notifyOperationReplaced" callback is triggered in either case.
+ rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct TestGreedyPatternDriver
+ : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
- TestPatternDriver() = default;
- TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
+ TestGreedyPatternDriver() = default;
+ TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
+ : PassWrapper(other) {}
- StringRef getArgument() const final { return "test-patterns"; }
+ StringRef getArgument() const final { return "test-greedy-patterns"; }
StringRef getDescription() const final { return "Run test dialect patterns"; }
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
@@ -470,34 +527,6 @@ struct TestStrictPatternDriver
}
};
- // Replace an operation may introduce the re-visiting of its users.
- class ReplaceWithNewOp : public RewritePattern {
- public:
- ReplaceWithNewOp(MLIRContext *context)
- : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- Operation *newOp;
- if (op->hasAttr("create_erase_op")) {
- newOp = rewriter.create(
- op->getLoc(),
- OperationName("test.erase_op", op->getContext()).getIdentifier(),
- ValueRange(), TypeRange());
- } else {
- newOp = rewriter.create(
- op->getLoc(),
- OperationName("test.new_op", op->getContext()).getIdentifier(),
- op->getOperands(), op->getResultTypes());
- }
- // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
- // A "notifyOperationReplaced" callback is triggered in either case.
- rewriter.replaceAllOpUsesWith(op, newOp->getResults());
- rewriter.eraseOp(op);
- return success();
- }
- };
-
// Remove an operation may introduce the re-visiting of its operands.
class EraseOp : public RewritePattern {
public:
@@ -560,6 +589,38 @@ struct TestStrictPatternDriver
};
};
+struct TestWalkPatternDriver final
+ : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
+
+ TestWalkPatternDriver() = default;
+ TestWalkPatternDriver(const TestWalkPatternDriver &other)
+ : PassWrapper(other) {}
+
+ StringRef getArgument() const override {
+ return "test-walk-pattern-rewrite-driver";
+ }
+ StringRef getDescription() const override {
+ return "Run test walk pattern rewrite driver";
+ }
+ void runOnOperation() override {
+ mlir::RewritePatternSet patterns(&getContext());
+
+ // Patterns for testing the WalkPatternRewriteDriver.
+ patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
+ MoveAfterParentOp, CloneOp, ReplaceWithNewOp>(&getContext());
+
+ DumpNotifications dumpListener;
+ walkAndApplyPatterns(getOperation(), std::move(patterns),
+ dumpNotifications ? &dumpListener : nullptr);
+ }
+
+ Option<bool> dumpNotifications{
+ *this, "dump-notifications",
+ llvm::cl::desc("Print rewrite listener notifications"),
+ llvm::cl::init(false)};
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1978,8 +2039,9 @@ void registerPatternsTestPass() {
PassRegistration<TestDerivedAttributeDriver>();
- PassRegistration<TestPatternDriver>();
+ PassRegistration<TestGreedyPatternDriver>();
PassRegistration<TestStrictPatternDriver>();
+ PassRegistration<TestWalkPatternDriver>();
PassRegistration<TestLegalizePatternDriver>([] {
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5ff8710b937701..60d46e676d2a33 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
+// RUN: mlir-opt -test-greedy-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
// CHECK-LABEL: verifyFusedLocs
func.func @verifyFusedLocs(%arg0 : i32) -> i32 {
From 34d440ab2192895d64306938d4cd3929b7f83363 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 18:38:41 -0400
Subject: [PATCH 2/4] Iterate over regions
---
.../Utils/WalkPatternRewriteDriver.cpp | 23 +++++++++----------
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 2d3aa5fc7d15c7..864e86193a0d6c 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -63,18 +63,17 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
- op->walk([&](Operation *visitedOp) {
- if (visitedOp == op)
- return;
-
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
- visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
- erasedListener.visitedOp = visitedOp;
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ for (Region &region : op->getRegions()) {
+ region.walk([&](Operation *visitedOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
+ llvm::dbgs(), OpPrintingFlags().skipRegions());
+ llvm::dbgs() << "\n";);
+ erasedListener.visitedOp = visitedOp;
+ if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+ }
+ });
+ }
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
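
Patch 2 replaces `op->walk` with an explicit loop over `op->getRegions()`, so the root op is never handed to the pattern applicator (matching the header's note that patterns are not applied to the given op itself), and the `visitedOp == op` check goes away. Below is a standalone tree model of that traversal. `TreeOp`, `walkPostOrder`, and `walkNestedOnly` are hypothetical names, blocks are flattened away, and the model assumes post-order visitation, which is the default order of MLIR's `walk`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR: an op owns regions, each region is a list of ops
// (blocks are flattened away for brevity).
struct TreeOp {
  std::string name;
  std::vector<std::vector<TreeOp>> regions;
};

// Post-order walk: nested ops are visited before the op that contains them.
void walkPostOrder(const TreeOp &op, std::vector<std::string> &out) {
  for (const auto &region : op.regions)
    for (const TreeOp &nested : region)
      walkPostOrder(nested, out);
  out.push_back(op.name);
}

// Mirrors the patched driver: walk each region of the root, so the root
// itself is never visited.
std::vector<std::string> walkNestedOnly(const TreeOp &root) {
  std::vector<std::string> out;
  for (const auto &region : root.regions)
    for (const TreeOp &op : region)
      walkPostOrder(op, out);
  return out;
}
```

For a `func` containing an `scf.if` (which nests `test.a`) followed by `return`, `walkNestedOnly` visits `test.a`, `scf.if`, `return` and never `func`.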
From ccc4e08f84136180304e396af0c86deb99079662 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:03:10 -0400
Subject: [PATCH 3/4] Add action
---
mlir/docs/ActionTracing.md | 2 +-
.../Utils/WalkPatternRewriteDriver.cpp | 40 +++++++++++++------
2 files changed, 29 insertions(+), 13 deletions(-)
diff --git a/mlir/docs/ActionTracing.md b/mlir/docs/ActionTracing.md
index 978fdbbe54d81c..984516d5c5e7e2 100644
--- a/mlir/docs/ActionTracing.md
+++ b/mlir/docs/ActionTracing.md
@@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the
```c++
/// A custom Action can be defined minimally by deriving from
-/// `tracing::ActionImpl`. It can has any members!
+/// `tracing::ActionImpl`. It can have any members!
class MyCustomAction : public tracing::ActionImpl<MyCustomAction> {
public:
using Base = tracing::ActionImpl<MyCustomAction>;
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 864e86193a0d6c..eb450c4dacdd45 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -12,11 +12,13 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -25,9 +27,18 @@
namespace mlir {
namespace {
+struct WalkAndApplyPatternsAction final
+ : tracing::ActionImpl<WalkAndApplyPatternsAction> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
+ using ActionImpl::ActionImpl;
+ static constexpr StringLiteral tag = "walk-and-apply-patterns";
+ void print(raw_ostream &os) const override { os << tag; }
+};
+
// Forwarding listener to guard against unsupported erasures. Because we use
// walk-based pattern application, erasing the op from the *next* iteration
// (e.g., a user of the visited op) is not valid.
+// Note that this is only used with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
using RewriterBase::ForwardingListener::ForwardingListener;
@@ -51,7 +62,8 @@ void walkAndApplyPatterns(Operation *op,
llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- PatternRewriter rewriter(op->getContext());
+ MLIRContext *ctx = op->getContext();
+ PatternRewriter rewriter(ctx);
ErasedOpsListener erasedListener(listener);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
rewriter.setListener(&erasedListener);
@@ -63,17 +75,21 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
- for (Region &region : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
- erasedListener.visitedOp = visitedOp;
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
- }
+ ctx->executeAction<WalkAndApplyPatternsAction>(
+ [&] {
+ for (Region &region : op->getRegions()) {
+ region.walk([&](Operation *visitedOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
+ llvm::dbgs(), OpPrintingFlags().skipRegions());
+ llvm::dbgs() << "\n";);
+ erasedListener.visitedOp = visitedOp;
+ if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
+ }
+ });
+ }
+ },
+ {op});
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
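
Patch 3 wraps the whole walk in `ctx->executeAction<WalkAndApplyPatternsAction>(...)`, so an action handler registered on the context (see ActionTracing.md) sees the rewrite as one unit tagged `walk-and-apply-patterns`. Below is a standalone model of that idea with hypothetical `ToyContext` and `ToyHandler` names. It assumes, as in MLIR's action framework, that the handler receives the transformation callback and is responsible for invoking it, which also lets it skip the transformation.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical handler signature: receives the transformation to run and the
// action's tag, and decides whether (and how) to run it.
using ToyHandler = std::function<void(const std::function<void()> &transform,
                                      const std::string &tag)>;

struct ToyContext {
  ToyHandler handler; // empty means "no handler registered"

  // Without a handler the transformation simply runs; with one, the handler
  // is in charge of invoking it (or skipping it).
  void executeAction(const std::string &tag,
                     const std::function<void()> &transform) {
    if (!handler) {
      transform();
      return;
    }
    handler(transform, tag);
  }
};
```

Wrapping the entire region loop in a single action (rather than one action per op) keeps the tracing overhead constant per driver invocation.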
From 62cb821e2365a6b33bab21134f0cedfb6575465b Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 27 Oct 2024 19:07:03 -0400
Subject: [PATCH 4/4] Coding style
---
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index eb450c4dacdd45..0258064a3436b7 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -70,7 +70,7 @@ void walkAndApplyPatterns(Operation *op,
#else
(void)erasedListener;
rewriter.setListener(listener);
-#endif
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();