[Mlir-commits] [mlir] 3ed1e9c - [MLIR][Python] Add support of the walk pattern rewrite driver (#173562)

llvmlistbot at llvm.org
Fri Dec 26 00:11:11 PST 2025


Author: Twice
Date: 2025-12-26T16:11:06+08:00
New Revision: 3ed1e9c85dc19e81f8ade20503e917580233900d

URL: https://github.com/llvm/llvm-project/commit/3ed1e9c85dc19e81f8ade20503e917580233900d
DIFF: https://github.com/llvm/llvm-project/commit/3ed1e9c85dc19e81f8ade20503e917580233900d.diff

LOG: [MLIR][Python] Add support of the walk pattern rewrite driver (#173562)

MLIR currently has three main pattern rewrite drivers (see
https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers):

* Dialect Conversion Driver
* Walk Pattern Rewrite Driver
* Greedy Pattern Rewrite Driver

The greedy pattern rewrite driver is already supported in the C API and
the Python bindings. This PR adds support for the walk pattern rewrite
driver. Unlike the greedy driver, this lightweight driver does not apply
patterns repeatedly; it walks the IR once. API-wise, the main change is
the new Python function `walk_and_apply_patterns`, backed by the new C
API function `mlirWalkAndApplyPatterns`.
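The difference between "walk the IR once" and "apply patterns repeatedly" can be shown with a toy model in plain Python. This is a hypothetical sketch, not the MLIR Python API. Ops are plain strings and patterns are ordinary functions. Op names such as `foo.a` are made up. The greedy driver is simplified to "repeat full passes until nothing changes", while the real greedy driver uses a worklist and also folds. In this model, a rewrite whose result another pattern would match is not picked up again by the single-walk driver.

```python
# Toy model of the two driver strategies. This is NOT the MLIR API:
# "ops" are plain strings and "patterns" are plain functions that return a
# replacement op name, or None when they do not match. Op names are hypothetical.
from typing import Callable, List, Optional

Pattern = Callable[[str], Optional[str]]


def walk_and_apply(ops: List[str], patterns: List[Pattern]) -> List[str]:
    """One pass over the ops: each op is visited exactly once, and the
    replacement produced for it is not matched again in this pass."""
    result = []
    for op in ops:
        for pattern in patterns:
            replacement = pattern(op)
            if replacement is not None:
                op = replacement
                break  # toy rule: first matching pattern wins (MLIR uses benefits)
        result.append(op)
    return result


def apply_greedily(ops: List[str], patterns: List[Pattern],
                   max_iterations: int = 10) -> List[str]:
    """Simplified greedy driver: repeat full passes until nothing changes
    (a fixpoint) or the iteration limit is reached."""
    for _ in range(max_iterations):
        new_ops = walk_and_apply(ops, patterns)
        if new_ops == ops:
            break
        ops = new_ops
    return ops


# Two hypothetical patterns: the output of the first is the input of the second.
def a_to_b(op: str) -> Optional[str]:
    return "foo.b" if op == "foo.a" else None


def b_to_c(op: str) -> Optional[str]:
    return "foo.c" if op == "foo.b" else None


PATTERNS = [a_to_b, b_to_c]

if __name__ == "__main__":
    print(walk_and_apply(["foo.a", "foo.x"], PATTERNS))  # ['foo.b', 'foo.x']
    print(apply_greedily(["foo.a", "foo.x"], PATTERNS))  # ['foo.c', 'foo.x']
```

In the real Python bindings the call is simply `walk_and_apply_patterns(module, frozen)`, as in the test added by this commit.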

Note that the listener parameter of the underlying C++ `walkAndApplyPatterns` is not supported yet.

Modified: 
    mlir/include/mlir-c/Rewrite.h
    mlir/lib/Bindings/Python/Rewrite.cpp
    mlir/lib/CAPI/Transforms/Rewrite.cpp
    mlir/test/python/rewrite.py



################################################################################
diff  --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index fe42a20e73482..9e2685719fe4d 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -321,6 +321,12 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
 
+/// Applies the given patterns to the given op by a fast walk-based pattern
+/// rewrite driver.
+MLIR_CAPI_EXPORTED void
+mlirWalkAndApplyPatterns(MlirOperation op,
+                         MlirFrozenRewritePatternSet patterns);
+
 //===----------------------------------------------------------------------===//
 /// PatternRewriter API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0f0ed22c50fa9..0df9d0cbc7ffc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -437,5 +437,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
           // clang-format on
           "Applys the given patterns to the given op greedily while folding "
-          "results.");
+          "results.")
+      .def(
+          "walk_and_apply_patterns",
+          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
+            mlirWalkAndApplyPatterns(op.getOperation(), set.get());
+          },
+          "op"_a, "set"_a,
+          // clang-format off
+          nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
+          // clang-format on
+          "Applies the given patterns to the given op by a fast walk-based "
+          "driver.");
 }

diff  --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 41ceb1580a4e8..fd4ae6ffc72be 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -296,6 +297,11 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+void mlirWalkAndApplyPatterns(MlirOperation op,
+                              MlirFrozenRewritePatternSet patterns) {
+  mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns));
+}
+
 //===----------------------------------------------------------------------===//
 /// PatternRewriter API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 821e47085a5bd..e40d5eb92b86f 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -68,3 +68,19 @@ def constant_1_to_2(op, rewriter):
         # CHECK: %c3_i64 = arith.constant 3 : i64
         # CHECK: return %c2_i64, %c3_i64 : i64, i64
         print(module)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        walk_and_apply_patterns(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)
