[Mlir-commits] [mlir] [MLIR][Python] Add a python function to apply patterns with MlirOperation (PR #157487)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 8 08:14:59 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/157487
>From 483b71cfbbfa7510add4e68b37af2f458ff0f7f8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 8 Sep 2025 22:49:39 +0800
Subject: [PATCH 1/2] [MLIR][Python] Add a python function to apply patterns
with MlirOperation
---
mlir/include/mlir-c/Rewrite.h | 4 ++++
mlir/lib/Bindings/Python/Rewrite.cpp | 30 ++++++++++++++++++----------
mlir/lib/CAPI/Transforms/Rewrite.cpp | 7 +++++++
3 files changed, 31 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig);
+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..feb22485c5609 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,24 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
- "apply_patterns_and_fold_greedily",
- [](MlirModule module, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
- if (mlirLogicalResultIsFailure(status))
- // FIXME: Not sure this is the right error to throw here.
- throw nb::value_error("pattern application failed to converge");
- },
- "module"_a, "set"_a,
- "Applys the given patterns to the given module greedily while folding "
- "results.");
+ "apply_patterns_and_fold_greedily",
+ [](MlirModule module, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error("pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ "Applys the given patterns to the given module greedily while folding "
+ "results.")
+ .def(
+ "apply_patterns_and_fold_greedily_with_op",
+ [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error(
+ "pattern application failed to converge");
+ },
+ "op"_a, "set"_a,
+ "Applys the given patterns to the given op greedily while folding "
+ "results.");
}
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig) {
+ return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
>From 27ca16036a081d66527cac4fe277067defcbe6e5 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 8 Sep 2025 23:14:46 +0800
Subject: [PATCH 2/2] add test case
---
mlir/test/python/integration/dialects/pdl.py | 49 ++++++++++++++------
1 file changed, 34 insertions(+), 15 deletions(-)
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 923af29a71ad7..42d3707017e17 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,20 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f
-
-# CHECK-LABEL: TEST: test_add_to_mul
-# CHECK: arith.muli
-@construct_and_print_in_module
-def test_add_to_mul(module_):
- index_type = IndexType.get()
-
- # Create a test case.
- @module(sym_name="ir")
- def ir():
- @func.func(index_type, index_type)
- def add_func(a, b):
- return arith.addi(a, b)
-
+def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
# - operands are index types.
@@ -61,7 +48,39 @@ def rew():
# not yet captured Python side/has sharp edges. So best to construct the
# module and PDL module in same scope.
# FIXME: This should be made more robust.
- frozen = PDLModule(m).freeze()
+ return PDLModule(m).freeze()
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul(module_):
+ index_type = IndexType.get()
+
+ # Create a test case.
+ @module(sym_name="ir")
+ def ir():
+ @func.func(index_type, index_type)
+ def add_func(a, b):
+ return arith.addi(a, b)
+
+ frozen = get_pdl_patterns()
# Could apply frozen pattern set multiple times.
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+# CHECK-LABEL: TEST: test_add_to_mul_with_op
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul_with_op(module_):
+ index_type = IndexType.get()
+
+ # Create a test case.
+ @module(sym_name="ir")
+ def ir():
+ @func.func(index_type, index_type)
+ def add_func(a, b):
+ return arith.addi(a, b)
+
+ frozen = get_pdl_patterns()
+ apply_patterns_and_fold_greedily_with_op(module_.operation, frozen)
+ return module_
More information about the Mlir-commits
mailing list