[Mlir-commits] [mlir] aac4eb5 - [MLIR][Python] Add a python function to apply patterns with MlirOperation (#157487)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 8 09:05:50 PDT 2025


Author: Twice
Date: 2025-09-08T16:05:45Z
New Revision: aac4eb5c3c393efb63ec1a24fe3dba2afd3e092c

URL: https://github.com/llvm/llvm-project/commit/aac4eb5c3c393efb63ec1a24fe3dba2afd3e092c
DIFF: https://github.com/llvm/llvm-project/commit/aac4eb5c3c393efb63ec1a24fe3dba2afd3e092c.diff

LOG: [MLIR][Python] Add a python function to apply patterns with MlirOperation (#157487)

In https://github.com/llvm/llvm-project/pull/94714, we added a Python
function `apply_patterns_and_fold_greedily` that accepts an `MlirModule`
as its argument. However, sometimes we want to apply patterns to an
`MlirOperation`, and there is currently no Python API to convert an
`MlirOperation` to an `MlirModule`.

This change therefore overloads `apply_patterns_and_fold_greedily` to
also accept an operation, and adds a corresponding new C API,
`mlirApplyPatternsAndFoldGreedilyWithOp`.

Added: 
    

Modified: 
    mlir/include/mlir-c/Rewrite.h
    mlir/lib/Bindings/Python/Rewrite.cpp
    mlir/lib/CAPI/Transforms/Rewrite.cpp
    mlir/test/python/integration/dialects/pdl.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
 MLIR_CAPI_EXPORTED void
 mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
 
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+    MlirOperation op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig);
+
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
     MlirGreedyRewriteDriverConfig);

diff  --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..5b7de50f02e6a 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](PyModule &module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
+         if (mlirLogicalResultIsFailure(status))
+           throw std::runtime_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applys the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily",
+          [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
+                op.getOperation(), set, {});
+            if (mlirLogicalResultIsFailure(status))
+              throw std::runtime_error(
+                  "pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applys the given patterns to the given op greedily while folding "
+          "results.");
 }

diff  --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
 }
 
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+                                       MlirFrozenRewritePatternSet patterns,
+                                       MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 923af29a71ad7..dd6c74ce622c8 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,20 +16,7 @@ def construct_and_print_in_module(f):
             print(module)
     return f
 
-
-# CHECK-LABEL: TEST: test_add_to_mul
-# CHECK: arith.muli
- at construct_and_print_in_module
-def test_add_to_mul(module_):
-    index_type = IndexType.get()
-
-    # Create a test case.
-    @module(sym_name="ir")
-    def ir():
-        @func.func(index_type, index_type)
-        def add_func(a, b):
-            return arith.addi(a, b)
-
+def get_pdl_patterns():
     # Create a rewrite from add to mul. This will match
     # - operation name is arith.addi
     # - operands are index types.
@@ -61,7 +48,41 @@ def rew():
     # not yet captured Python side/has sharp edges. So best to construct the
     # module and PDL module in same scope.
     # FIXME: This should be made more robust.
-    frozen = PDLModule(m).freeze()
+    return PDLModule(m).freeze()
+
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+ at construct_and_print_in_module
+def test_add_to_mul(module_):
+    index_type = IndexType.get()
+
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    frozen = get_pdl_patterns()
     # Could apply frozen pattern set multiple times.
     apply_patterns_and_fold_greedily(module_, frozen)
     return module_
+
+
+# CHECK-LABEL: TEST: test_add_to_mul_with_op
+# CHECK: arith.muli
+ at construct_and_print_in_module
+def test_add_to_mul_with_op(module_):
+    index_type = IndexType.get()
+
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    frozen = get_pdl_patterns()
+    apply_patterns_and_fold_greedily(module_.operation, frozen)
+    return module_


        


More information about the Mlir-commits mailing list