[Mlir-commits] [mlir] [MLIR][Python] Add support of the walk pattern rewriter driver (PR #173562)
llvmlistbot at llvm.org
Thu Dec 25 05:08:06 PST 2025
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/173562
MLIR currently has three main pattern rewrite drivers (see [https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers](https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers)):
* Dialect Conversion Driver
* Walk Pattern Rewrite Driver
* Greedy Pattern Rewrite Driver
The greedy pattern rewrite driver is already supported in the C API and the Python bindings. This PR adds support for the walk pattern rewrite driver. Unlike the greedy driver, this lightweight driver does not repeatedly apply patterns until a fixpoint is reached; instead, it walks the IR once. On the API side, the PR adds `mlirWalkAndApplyPatterns` to the C API and the corresponding `walk_and_apply_patterns` function to the Python bindings.
Note that the listener argument of the underlying C++ `walkAndApplyPatterns` is not supported yet.
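To make the single-walk vs. fixpoint contrast above concrete, here is a toy Python model. It is not the MLIR API: `Op`, `walk_once`, `greedy`, and the two patterns are hypothetical names used only for illustration. As a simplifying assumption, the walk driver is modeled as trying the patterns once per op in post-order without revisiting the replacement, and the greedy driver is modeled as repeating such walks until nothing changes (the real greedy driver is worklist-based and also folds).

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class Op:
    """Hypothetical toy op: a name plus operand ops (a tree, not real MLIR IR)."""
    name: str
    operands: List["Op"] = field(default_factory=list)


# A toy "pattern" returns a replacement op, or None if it does not match.
Pattern = Callable[[Op], Optional[Op]]


def add_to_mul(op: Op) -> Optional[Op]:
    # Hypothetical pattern: rewrite addi into muli (mirrors the PR's test).
    return Op("muli", op.operands) if op.name == "addi" else None


def mul_to_shl(op: Op) -> Optional[Op]:
    # Hypothetical follow-up pattern that only fires on the result of add_to_mul.
    return Op("shli", op.operands) if op.name == "muli" else None


def first_match(op: Op, patterns: List[Pattern]) -> Optional[Op]:
    for pattern in patterns:
        replacement = pattern(op)
        if replacement is not None:
            return replacement
    return None


def walk_once(op: Op, patterns: List[Pattern]) -> Op:
    # Post-order: rebuild the operands first, then try the patterns on this op once.
    visited = Op(op.name, [walk_once(o, patterns) for o in op.operands])
    replacement = first_match(visited, patterns)
    # The replacement is not revisited, so chained rewrites stop after one step.
    return replacement if replacement is not None else visited


def greedy(op: Op, patterns: List[Pattern], max_iterations: int = 10) -> Op:
    # Simplified fixpoint loop: repeat walks until the tree stops changing.
    for _ in range(max_iterations):
        new_op = walk_once(op, patterns)
        if new_op == op:  # dataclass equality compares the trees structurally
            return new_op
        op = new_op
    return op


ir = Op("return", [Op("addi", [Op("arg0"), Op("arg1")])])
patterns = [add_to_mul, mul_to_shl]
print(walk_once(ir, patterns).operands[0].name)  # muli
print(greedy(ir, patterns).operands[0].name)     # shli
```

In this model a single walk stops at `muli` because the newly created op is never visited again, while the fixpoint loop keeps going until `mul_to_shl` has also fired.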
From faa5f79a26b667fa8e13104a32a4cac0694b23c3 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 25 Dec 2025 20:57:53 +0800
Subject: [PATCH] [MLIR][Python] Add support of the walk pattern rewriter
driver
---
mlir/include/mlir-c/Rewrite.h | 4 ++++
mlir/lib/Bindings/Python/Rewrite.cpp | 13 ++++++++++++-
mlir/lib/CAPI/Transforms/Rewrite.cpp | 8 ++++++++
mlir/test/python/rewrite.py | 16 ++++++++++++++++
4 files changed, 40 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index fe42a20e73482..d035110c63d5c 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -321,6 +321,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
+MLIR_CAPI_EXPORTED void
+mlirWalkAndApplyPatterns(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns);
+
//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0f0ed22c50fa9..0df9d0cbc7ffc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -437,5 +437,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
- "results.");
+ "results.")
+ .def(
+ "walk_and_apply_patterns",
+ [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
+ mlirWalkAndApplyPatterns(op.getOperation(), set.get());
+ },
+ "op"_a, "set"_a,
+ // clang-format off
+ nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
+ // clang-format on
+ "Applies the given patterns to the given op by a fast walk-based "
+ "driver.");
}
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 41ceb1580a4e8..7413c9791da12 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
using namespace mlir;
@@ -296,6 +297,13 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+/// Applies the given patterns to the given op by a fast walk-based pattern
+/// rewrite driver.
+void mlirWalkAndApplyPatterns(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns) {
+ mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns));
+}
+
//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 821e47085a5bd..e40d5eb92b86f 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -68,3 +68,19 @@ def constant_1_to_2(op, rewriter):
# CHECK: %c3_i64 = arith.constant 3 : i64
# CHECK: return %c2_i64, %c3_i64 : i64, i64
print(module)
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ walk_and_apply_patterns(module, frozen)
+ # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+ # CHECK: return %0 : i64
+ print(module)