[Mlir-commits] [mlir] aac4eb5 - [MLIR][Python] Add a python function to apply patterns with MlirOperation (#157487)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 8 09:05:50 PDT 2025
Author: Twice
Date: 2025-09-08T16:05:45Z
New Revision: aac4eb5c3c393efb63ec1a24fe3dba2afd3e092c
URL: https://github.com/llvm/llvm-project/commit/aac4eb5c3c393efb63ec1a24fe3dba2afd3e092c
DIFF: https://github.com/llvm/llvm-project/commit/aac4eb5c3c393efb63ec1a24fe3dba2afd3e092c.diff
LOG: [MLIR][Python] Add a python function to apply patterns with MlirOperation (#157487)
In https://github.com/llvm/llvm-project/pull/94714, we added a Python
function `apply_patterns_and_fold_greedily` that accepts an
`MlirModule` as its argument. However, sometimes we want to apply
patterns to an `MlirOperation` instead, and there is currently no
Python API to convert an `MlirOperation` to an `MlirModule`.
So this change overloads `apply_patterns_and_fold_greedily` to also
accept an operation, and adds a corresponding new C API,
`mlirApplyPatternsAndFoldGreedilyWithOp`.
Added:
Modified:
mlir/include/mlir-c/Rewrite.h
mlir/lib/Bindings/Python/Rewrite.cpp
mlir/lib/CAPI/Transforms/Rewrite.cpp
mlir/test/python/integration/dialects/pdl.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 61d3446317550..374d2fb78de88 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
+ MlirOperation op, MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig);
+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0373f9c7affe9..5b7de50f02e6a 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
- "apply_patterns_and_fold_greedily",
- [](MlirModule module, MlirFrozenRewritePatternSet set) {
- auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
- if (mlirLogicalResultIsFailure(status))
- // FIXME: Not sure this is the right error to throw here.
- throw nb::value_error("pattern application failed to converge");
- },
- "module"_a, "set"_a,
- "Applys the given patterns to the given module greedily while folding "
- "results.");
+ "apply_patterns_and_fold_greedily",
+ [](PyModule &module, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error("pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ "Applys the given patterns to the given module greedily while folding "
+ "results.")
+ .def(
+ "apply_patterns_and_fold_greedily",
+ [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
+ op.getOperation(), set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error(
+ "pattern application failed to converge");
+ },
+ "op"_a, "set"_a,
+ "Applys the given patterns to the given op greedily while folding "
+ "results.");
}
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index a4df97f7beace..6f85357a14a18 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
+ MlirFrozenRewritePatternSet patterns,
+ MlirGreedyRewriteDriverConfig) {
+ return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index 923af29a71ad7..dd6c74ce622c8 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -16,20 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f
-
-# CHECK-LABEL: TEST: test_add_to_mul
-# CHECK: arith.muli
- at construct_and_print_in_module
-def test_add_to_mul(module_):
- index_type = IndexType.get()
-
- # Create a test case.
- @module(sym_name="ir")
- def ir():
- @func.func(index_type, index_type)
- def add_func(a, b):
- return arith.addi(a, b)
-
+def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
# - operands are index types.
@@ -61,7 +48,41 @@ def rew():
# not yet captured Python side/has sharp edges. So best to construct the
# module and PDL module in same scope.
# FIXME: This should be made more robust.
- frozen = PDLModule(m).freeze()
+ return PDLModule(m).freeze()
+
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+ at construct_and_print_in_module
+def test_add_to_mul(module_):
+ index_type = IndexType.get()
+
+ # Create a test case.
+ @module(sym_name="ir")
+ def ir():
+ @func.func(index_type, index_type)
+ def add_func(a, b):
+ return arith.addi(a, b)
+
+ frozen = get_pdl_patterns()
# Could apply frozen pattern set multiple times.
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+
+# CHECK-LABEL: TEST: test_add_to_mul_with_op
+# CHECK: arith.muli
+ at construct_and_print_in_module
+def test_add_to_mul_with_op(module_):
+ index_type = IndexType.get()
+
+ # Create a test case.
+ @module(sym_name="ir")
+ def ir():
+ @func.func(index_type, index_type)
+ def add_func(a, b):
+ return arith.addi(a, b)
+
+ frozen = get_pdl_patterns()
+ apply_patterns_and_fold_greedily(module_.operation, frozen)
+ return module_
More information about the Mlir-commits
mailing list