[Mlir-commits] [mlir] [MLIR][Python] Add support of the walk pattern rewriter driver (PR #173562)

llvmlistbot at llvm.org
Thu Dec 25 05:08:37 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

MLIR currently has three main pattern rewrite drivers (see [https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers](https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers)):

* Dialect Conversion Driver
* Walk Pattern Rewrite Driver
* Greedy Pattern Rewrite Driver

The C API and Python bindings already support the greedy pattern rewrite driver. This PR adds support for the walk pattern rewrite driver. Unlike the greedy driver, this lightweight driver does not apply patterns repeatedly until a fixpoint is reached. Instead, it walks the IR once and does not revisit ops that a rewrite modified or newly created. API-wise, the main change is a new C API function, `mlirWalkAndApplyPatterns`, exposed in Python as `walk_and_apply_patterns`.

Note that the optional listener argument of the C++ `walkAndApplyPatterns` is not supported yet.
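The following toy Python model makes the difference between the two drivers concrete. It is a hypothetical sketch and not MLIR's implementation. `Op`, `Pattern`, `walk_driver` and `greedy_driver` are invented names, and the IR is a flat list rather than nested regions. The greedy model also simply re-sweeps the list instead of using a worklist and folding, as MLIR's greedy driver does. Both models pick the highest-benefit matching pattern and break ties by list order.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass(frozen=True)
class Op:
    """Toy stand-in for an IR operation (hypothetical, not an MLIR type)."""
    name: str
    operands: Tuple[str, ...] = ()


@dataclass(frozen=True)
class Pattern:
    """Toy rewrite pattern: matches ops named `root`; `rewrite` returns a replacement or None."""
    root: str
    benefit: int
    rewrite: Callable[[Op], Optional[Op]]


def _apply_best(op: Op, patterns: List[Pattern]) -> Optional[Op]:
    # Try higher-benefit patterns first; sorted() is stable, so ties keep list order.
    for pattern in sorted(patterns, key=lambda p: -p.benefit):
        if pattern.root == op.name:
            replacement = pattern.rewrite(op)
            if replacement is not None:
                return replacement
    return None


def walk_driver(ops: List[Op], patterns: List[Pattern]) -> List[Op]:
    """One pass: each original op is visited once; replacements are never revisited."""
    result = []
    for op in ops:
        replacement = _apply_best(op, patterns)
        result.append(replacement if replacement is not None else op)
    return result


def greedy_driver(ops: List[Op], patterns: List[Pattern],
                  max_iterations: int = 10) -> List[Op]:
    """Re-sweep until no pattern applies (a fixpoint) or the iteration cap is hit."""
    for _ in range(max_iterations):
        changed = False
        next_ops = []
        for op in ops:
            replacement = _apply_best(op, patterns)
            if replacement is not None:
                changed = True
                op = replacement
            next_ops.append(op)
        ops = next_ops
        if not changed:
            break
    return ops


# Two chained (hypothetical) patterns: "foo" -> "bar", then "bar" -> "baz".
patterns = [
    Pattern("foo", 1, lambda op: Op("bar", op.operands)),
    Pattern("bar", 1, lambda op: Op("baz", op.operands)),
]
ops = [Op("foo", ("x",)), Op("other")]

print(walk_driver(ops, patterns))    # the new "bar" op is not revisited
print(greedy_driver(ops, patterns))  # keeps sweeping until "bar" becomes "baz"
```

With the two chained patterns, the walk model stops at `bar` because the replacement op is never revisited, while the greedy model keeps sweeping until it reaches `baz`. Pattern sets that rely on such chained rewrites are a better fit for `apply_patterns_and_fold_greedily`, at the cost of extra overhead.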

---
Full diff: https://github.com/llvm/llvm-project/pull/173562.diff


4 Files Affected:

- (modified) mlir/include/mlir-c/Rewrite.h (+4) 
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+12-1) 
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+8) 
- (modified) mlir/test/python/rewrite.py (+16) 


``````````diff
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index fe42a20e73482..d035110c63d5c 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -321,6 +321,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
 
+MLIR_CAPI_EXPORTED void
+mlirWalkAndApplyPatterns(MlirOperation op,
+                         MlirFrozenRewritePatternSet patterns);
+
 //===----------------------------------------------------------------------===//
 /// PatternRewriter API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0f0ed22c50fa9..0df9d0cbc7ffc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -437,5 +437,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
           // clang-format on
           "Applys the given patterns to the given op greedily while folding "
-          "results.");
+          "results.")
+      .def(
+          "walk_and_apply_patterns",
+          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
+            mlirWalkAndApplyPatterns(op.getOperation(), set.get());
+          },
+          "op"_a, "set"_a,
+          // clang-format off
+          nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
+          // clang-format on
+          "Applies the given patterns to the given op by a fast walk-based "
+          "driver.");
 }
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 41ceb1580a4e8..7413c9791da12 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -296,6 +297,13 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+/// Applies the given patterns to the given op by a fast walk-based pattern
+/// rewrite driver.
+void mlirWalkAndApplyPatterns(MlirOperation op,
+                              MlirFrozenRewritePatternSet patterns) {
+  mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns));
+}
+
 //===----------------------------------------------------------------------===//
 /// PatternRewriter API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 821e47085a5bd..e40d5eb92b86f 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -68,3 +68,19 @@ def constant_1_to_2(op, rewriter):
         # CHECK: %c3_i64 = arith.constant 3 : i64
         # CHECK: return %c2_i64, %c3_i64 : i64, i64
         print(module)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        walk_and_apply_patterns(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)

``````````

</details>


https://github.com/llvm/llvm-project/pull/173562

