[Mlir-commits] [mlir] [MLIR][Python] Add a python function to apply patterns with MlirOperation (PR #157487)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 8 08:14:59 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/157487

>From 483b71cfbbfa7510add4e68b37af2f458ff0f7f8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 8 Sep 2025 22:49:39 +0800
Subject: [PATCH 1/2] [MLIR][Python] Add a python function to apply patterns
 with MlirOperation

---
 mlir/include/mlir-c/Rewrite.h        |  7 +++++++
 mlir/lib/Bindings/Python/Rewrite.cpp | 30 ++++++++++++++++++----------
 mlir/lib/CAPI/Transforms/Rewrite.cpp |  9 +++++++++
 3 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,13 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
+/// Applies the given frozen pattern set greedily to `op` while folding
+/// results. Returns failure if the rewrite did not converge. The driver
+/// `config` argument is currently ignored; the default config is used.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+    MlirOperation op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig config);
+
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..feb22485c5609 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,24 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](MlirModule module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+         if (mlirLogicalResultIsFailure(status))
+           throw std::runtime_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applies the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily_with_op",
+          [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
+            if (mlirLogicalResultIsFailure(status))
+              throw std::runtime_error(
+                  "pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applies the given patterns to the given op greedily while folding "
+          "results.");
 }
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,15 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+// NOTE: The driver config is not forwarded yet (mirroring
+// mlirApplyPatternsAndFoldGreedily above); the default config is used.
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+                                       MlirFrozenRewritePatternSet patterns,
+                                       MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//

>From 27ca16036a081d66527cac4fe277067defcbe6e5 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 8 Sep 2025 23:14:46 +0800
Subject: [PATCH 2/2] add test case

---
 mlir/test/python/integration/dialects/pdl.py | 50 +++++++++++++------
 1 file changed, 36 insertions(+), 14 deletions(-)

diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 923af29a71ad7..42d3707017e17 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,20 +16,8 @@ def construct_and_print_in_module(f):
             print(module)
     return f
 
 
-# CHECK-LABEL: TEST: test_add_to_mul
-# CHECK: arith.muli
-@construct_and_print_in_module
-def test_add_to_mul(module_):
-    index_type = IndexType.get()
-
-    # Create a test case.
-    @module(sym_name="ir")
-    def ir():
-        @func.func(index_type, index_type)
-        def add_func(a, b):
-            return arith.addi(a, b)
-
+def get_pdl_patterns():
     # Create a rewrite from add to mul. This will match
     # - operation name is arith.addi
     # - operands are index types.
@@ -61,7 +49,41 @@ def rew():
     # not yet captured Python side/has sharp edges. So best to construct the
     # module and PDL module in same scope.
     # FIXME: This should be made more robust.
-    frozen = PDLModule(m).freeze()
+    return PDLModule(m).freeze()
+
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul(module_):
+    index_type = IndexType.get()
+
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    frozen = get_pdl_patterns()
     # Could apply frozen pattern set multiple times.
     apply_patterns_and_fold_greedily(module_, frozen)
     return module_
+
+
+# CHECK-LABEL: TEST: test_add_to_mul_with_op
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul_with_op(module_):
+    index_type = IndexType.get()
+
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    frozen = get_pdl_patterns()
+    apply_patterns_and_fold_greedily_with_op(module_.operation, frozen)
+    return module_



More information about the Mlir-commits mailing list