[Mlir-commits] [mlir] 3ed1e9c - [MLIR][Python] Add support of the walk pattern rewrite driver (#173562)
llvmlistbot at llvm.org
Fri Dec 26 00:11:11 PST 2025
Author: Twice
Date: 2025-12-26T16:11:06+08:00
New Revision: 3ed1e9c85dc19e81f8ade20503e917580233900d
URL: https://github.com/llvm/llvm-project/commit/3ed1e9c85dc19e81f8ade20503e917580233900d
DIFF: https://github.com/llvm/llvm-project/commit/3ed1e9c85dc19e81f8ade20503e917580233900d.diff
LOG: [MLIR][Python] Add support of the walk pattern rewrite driver (#173562)
MLIR currently has three main pattern rewrite drivers (see
https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers):
* Dialect Conversion Driver
* Walk Pattern Rewrite Driver
* Greedy Pattern Rewrite Driver
The C API and Python bindings already support the greedy pattern rewrite
driver. This PR adds support for the walk pattern rewrite driver. Unlike
the greedy driver, this lightweight driver does not reapply patterns until
a fixpoint is reached; it walks the IR once and does not revisit ops that
a pattern has modified or created. API-wise, the main change is a new C API
function, `mlirWalkAndApplyPatterns`, exposed in Python as
`walk_and_apply_patterns`.
The listener parameter is not supported yet.
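To make the "walk once" versus "apply repeatedly" difference concrete, here is a toy model in plain Python. It is not MLIR's API: the IR is a hypothetical flat list of integer "constant ops", and `walk_once`, `greedy_fixpoint`, `one_to_two` and `two_to_three` are illustrative names only. With the patterns 1 -> 2 and 2 -> 3, a single walk rewrites each op at most once, while a fixpoint loop keeps going until nothing matches.

```python
from typing import Callable, List, Optional

# Hypothetical toy IR: each "op" is just an integer constant.
# A pattern returns a replacement value, or None if it does not match.
Pattern = Callable[[int], Optional[int]]


def one_to_two(v: int) -> Optional[int]:
    return 2 if v == 1 else None


def two_to_three(v: int) -> Optional[int]:
    return 3 if v == 2 else None


def apply_first_match(v: int, patterns: List[Pattern]) -> Optional[int]:
    """Return the result of the first matching pattern, or None."""
    for pattern in patterns:
        result = pattern(v)
        if result is not None:
            return result
    return None


def walk_once(ops: List[int], patterns: List[Pattern]) -> List[int]:
    """Walk-style toy driver: one pass, at most one rewrite per op,
    and the replacement value is never revisited."""
    out = []
    for v in ops:
        result = apply_first_match(v, patterns)
        out.append(v if result is None else result)
    return out


def greedy_fixpoint(ops: List[int], patterns: List[Pattern]) -> List[int]:
    """Greedy-style toy driver: keep sweeping until no pattern matches anywhere."""
    ops = list(ops)
    changed = True
    while changed:
        changed = False
        for i, v in enumerate(ops):
            result = apply_first_match(v, patterns)
            if result is not None:
                ops[i] = result
                changed = True
    return ops


patterns = [one_to_two, two_to_three]
print(walk_once([1, 2], patterns))        # [2, 3]
print(greedy_fixpoint([1, 2], patterns))  # [3, 3]
```

The unbounded `while` loop is a simplification: a cyclic pattern set (say 1 -> 2 and 2 -> 1) would never terminate in this toy. MLIR's greedy driver instead caps its iterations and reports whether it converged, which is consistent with `mlirApplyPatternsAndFoldGreedily` returning `MlirLogicalResult` while the single-pass `mlirWalkAndApplyPatterns` returns `void`.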
Added:
Modified:
mlir/include/mlir-c/Rewrite.h
mlir/lib/Bindings/Python/Rewrite.cpp
mlir/lib/CAPI/Transforms/Rewrite.cpp
mlir/test/python/rewrite.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index fe42a20e73482..9e2685719fe4d 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -321,6 +321,12 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
+/// Applies the given patterns to the given op by a fast walk-based pattern
+/// rewrite driver.
+MLIR_CAPI_EXPORTED void
+mlirWalkAndApplyPatterns(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns);
+
//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0f0ed22c50fa9..0df9d0cbc7ffc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -437,5 +437,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
- "results.");
+ "results.")
+ .def(
+ "walk_and_apply_patterns",
+ [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
+ mlirWalkAndApplyPatterns(op.getOperation(), set.get());
+ },
+ "op"_a, "set"_a,
+ // clang-format off
+ nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
+ // clang-format on
+ "Applies the given patterns to the given op by a fast walk-based "
+ "driver.");
}
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 41ceb1580a4e8..fd4ae6ffc72be 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
using namespace mlir;
@@ -296,6 +297,11 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+void mlirWalkAndApplyPatterns(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns) {
+ mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns));
+}
+
//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 821e47085a5bd..e40d5eb92b86f 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -68,3 +68,19 @@ def constant_1_to_2(op, rewriter):
# CHECK: %c3_i64 = arith.constant 3 : i64
# CHECK: return %c2_i64, %c3_i64 : i64, i64
print(module)
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ walk_and_apply_patterns(module, frozen)
+ # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+ # CHECK: return %0 : i64
+ print(module)
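The new test expects the `arith.addi` inside a `func.func`, itself nested in a module, to come out as `arith.muli` after `walk_and_apply_patterns`. The sketch below is a toy model of a single walk over nested regions, not MLIR's Python API: `Op`, `addi_to_muli` and `walk_and_apply` are hypothetical names. Nested ops are visited once, each op is replaced at most once, and a replacement is not revisited. The post-order traversal and leaving the root op untouched are choices of this sketch; consult the `walkAndApplyPatterns` documentation for the real driver's traversal rules.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class Op:
    """Hypothetical toy op: a name, operand names, and nested regions of ops."""
    name: str
    operands: List[str] = field(default_factory=list)
    regions: List[List["Op"]] = field(default_factory=list)


# A pattern returns a replacement op, or None if it does not match.
Pattern = Callable[[Op], Optional[Op]]


def addi_to_muli(op: Op) -> Optional[Op]:
    """Toy pattern mirroring the test: addi(a, b) becomes muli(a, b)."""
    if op.name == "arith.addi":
        return Op("arith.muli", list(op.operands))
    return None


def walk_and_apply(root: Op, patterns: List[Pattern]) -> None:
    """Single post-order walk over ops nested under `root` (root itself is
    not rewritten). Each op is rewritten at most once; replacements are
    stored in place and not visited again."""
    for region in root.regions:
        for i, op in enumerate(region):
            walk_and_apply(op, patterns)  # visit nested ops first
            for pattern in patterns:
                replacement = pattern(op)
                if replacement is not None:
                    region[i] = replacement
                    break


module = Op("builtin.module", regions=[[
    Op("func.func", regions=[[
        Op("arith.addi", ["%a", "%b"]),
        Op("func.return", ["%sum"]),
    ]]),
]])
walk_and_apply(module, [addi_to_muli])
print([op.name for op in module.regions[0][0].regions[0]])
# ['arith.muli', 'func.return']
```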