[Mlir-commits] [mlir] [MLIR][Python] Add `GreedyRewriteDriverConfig` parameter to `apply_patterns_and_fold_greedily` (PR #174785)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 7 07:22:07 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

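We already have `GreedyRewriteDriverConfig` on the Python side, but it hasn’t yet been exposed as a parameter of `apply_patterns_and_fold_greedily`. This PR adds an optional `config` parameter to both overloads (module and operation); when it is omitted or `None`, a default-constructed `GreedyRewriteDriverConfig` is used.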

Before:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None
```

After:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet,
                                     config: GreedyRewriteDriverConfig | None = None) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet,
                                     config: GreedyRewriteDriverConfig | None = None) -> None
```

---
Full diff: https://github.com/llvm/llvm-project/pull/174785.diff


2 Files Affected:

- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+22-6) 
- (modified) mlir/test/python/rewrite.py (+44) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9830c277ac147..faab66d5ce4e5 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -509,26 +509,42 @@ void populateRewriteSubmodule(nb::module_ &m) {
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
        "apply_patterns_and_fold_greedily",
-       [](PyModule &module, PyFrozenRewritePatternSet &set) {
+       [](PyModule &module, PyFrozenRewritePatternSet &set, nb::object config) {
+         if (config.is_none()) {
+           config = nb::cast(PyGreedyRewriteDriverConfig());
+         }
+
          auto status = mlirApplyPatternsAndFoldGreedily(
-             module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
+             module.get(), set.get(),
+             nb::cast<PyGreedyRewriteDriverConfig &>(config).get());
          if (mlirLogicalResultIsFailure(status))
            throw std::runtime_error("pattern application failed to converge");
        },
-       "module"_a, "set"_a,
+       "module"_a, "set"_a, "config"_a = nb::none(),
+       // clang-format off
+       nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None"),
+       // clang-format on
        "Applys the given patterns to the given module greedily while folding "
        "results.")
       .def(
           "apply_patterns_and_fold_greedily",
-          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
+          [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
+             nb::object config) {
+            if (config.is_none()) {
+              config = nb::cast(PyGreedyRewriteDriverConfig());
+            }
+
             auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
                 op.getOperation(), set.get(),
-                mlirGreedyRewriteDriverConfigCreate());
+                nb::cast<PyGreedyRewriteDriverConfig &>(config).get());
             if (mlirLogicalResultIsFailure(status))
               throw std::runtime_error(
                   "pattern application failed to converge");
           },
-          "op"_a, "set"_a,
+          "op"_a, "set"_a, "config"_a = nb::none(),
+          // clang-format off
+          nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None"),
+          // clang-format on
           "Applys the given patterns to the given op greedily while folding "
           "results.")
       .def(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index a2fbbde38b8c0..43e9b761a0ea2 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
     config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
     level = config.region_simplification_level
     print(f"region_level AGGRESSIVE: {level}")
+
+
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+@run
+def testRewriteWithGreedyRewriteDriverConfig():
+    def constant_1_to_2(op, rewriter):
+        c = op.value.value
+        if c != 1:
+            return True  # failed to match
+        with rewriter.ip:
+            new_op = arith.constant(op.type, 2, loc=op.location)
+        rewriter.replace_op(op, [new_op])
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.ConstantOp, constant_1_to_2)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @const() -> (i64, i64) {
+                %0 = arith.constant 1 : i64
+                %1 = arith.constant 1 : i64
+                return %0, %1 : i64, i64
+              }
+            }
+            """
+        )
+
+        config = GreedyRewriteDriverConfig()
+        config.enable_constant_cse = False
+        apply_patterns_and_fold_greedily(module, frozen, config)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: %c2_i64_0 = arith.constant 2 : i64
+        # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
+        print(module)
+
+        config = GreedyRewriteDriverConfig()
+        config.enable_constant_cse = True
+        apply_patterns_and_fold_greedily(module, frozen, config)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: return %c2_i64, %c2_i64 : i64
+        print(module)

``````````

</details>


https://github.com/llvm/llvm-project/pull/174785


More information about the Mlir-commits mailing list