[Mlir-commits] [mlir] b688360 - [MLIR][Python] Refine the support of `RewritePatternSet.add` (#173874)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 29 18:16:32 PST 2025
Author: Twice
Date: 2025-12-30T10:16:27+08:00
New Revision: b6883607c036e4a539555f3849b63da4eda7c956
URL: https://github.com/llvm/llvm-project/commit/b6883607c036e4a539555f3849b63da4eda7c956
DIFF: https://github.com/llvm/llvm-project/commit/b6883607c036e4a539555f3849b63da4eda7c956.diff
LOG: [MLIR][Python] Refine the support of `RewritePatternSet.add` (#173874)
This patch includes the following changes:
- `RewritePatternSet.add` now accepts an op name (e.g. `.add("arith.addi",
fn)`) in addition to an op class (e.g. `.add(arith.AddIOp, fn)`)
- add a concrete signature and a more complete docstring to
`RewritePatternSet.add`.
Added:
Modified:
mlir/lib/Bindings/Python/Rewrite.cpp
mlir/test/python/rewrite.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0df9d0cbc7ffc..2a5129b7f4ab1 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -277,15 +277,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
"add",
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
unsigned benefit) {
- std::string opName =
- nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ std::string opName;
+ if (root.is_type()) {
+ opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+ } else if (nb::isinstance<nb::str>(root)) {
+ opName = nb::cast<std::string>(root);
+ } else {
+ throw nb::type_error(
+ "the root argument must be a type or a string");
+ }
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
fn);
},
"root"_a, "fn"_a, "benefit"_a = 1,
- "Add a new rewrite pattern on the given root operation with the "
- "callable as the matching and rewriting function and the given "
- "benefit.")
+ // clang-format off
+ nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
+ // clang-format on
+ R"(
+ Add a new rewrite pattern on the specified root operation, using the provided callable
+ for matching and rewriting, and assign it the given benefit.
+
+ Args:
+ root: The root operation to which this pattern applies.
+ This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
+ an operation name string (e.g., ``"arith.addi"``).
+ fn: The callable to use for matching and rewriting,
+ which takes an operation and a pattern rewriter as arguments.
+ The match is considered successful iff the callable returns
+ a value where ``bool(value)`` is ``False`` (e.g. ``None``).
+ If possible, the operation is cast to its corresponding OpView subclass
+ before being passed to the callable.
+ benefit: The benefit of the pattern, defaulting to 1.)")
.def("freeze", &PyRewritePatternSet::freeze,
"Freeze the pattern set into a frozen one.");
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index e40d5eb92b86f..80f076f36a416 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -32,7 +32,7 @@ def constant_1_to_2(op, rewriter):
with Context():
patterns = RewritePatternSet()
patterns.add(arith.AddIOp, to_muli)
- patterns.add(arith.ConstantOp, constant_1_to_2)
+ patterns.add("arith.constant", constant_1_to_2)
frozen = patterns.freeze()
module = ModuleOp.parse(
More information about the Mlir-commits
mailing list