[Mlir-commits] [mlir] [MLIR][Python] Refine the support of `RewritePatternSet.add` (PR #173874)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 29 08:33:04 PST 2025


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/173874

This patch includes the following changes:
- `RewritePatternSet.add` now accepts an operation name (e.g. `.add("arith.addi", fn)`) in addition to an op class (e.g. `.add(arith.AddIOp, fn)`).
- Adds a concrete signature and a more complete docstring to `RewritePatternSet.add`.

>From f5183866210a4e78296e27f73c9ca4db7625a86f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 29 Dec 2025 23:47:07 +0800
Subject: [PATCH 1/4] [MLIR][Python] Refine the support of
 RewritePatternSet.add

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 28 +++++++++++++++++++++++-----
 mlir/test/python/rewrite.py          |  2 +-
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 0f0ed22c50fa9..1b640a6f9e10f 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -277,15 +277,33 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           "add",
           [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
              unsigned benefit) {
-            std::string opName =
-                nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            std::string opName;
+            if (root.is_type()) {
+              opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            } else if (nb::isinstance<nb::str>(root)) {
+              opName = nb::cast<std::string>(root);
+            } else {
+              throw nb::type_error(
+                  "the root argument must be a type or a string");
+            }
             self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
                      fn);
           },
           "root"_a, "fn"_a, "benefit"_a = 1,
-          "Add a new rewrite pattern on the given root operation with the "
-          "callable as the matching and rewriting function and the given "
-          "benefit.")
+          // clang-format off
+          nb::sig("def add(self, root: type | str, fn: Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], Any], benefit: int = 1) -> None"),
+          // clang-format on
+          R"(
+            Add a new rewrite pattern on the given root operation with the
+            callable as the matching and rewriting function and the given benefit.
+
+            Args:
+              root: The root operation to apply the pattern on,
+                    which can be an OpView class (type) or an operation name (str).
+              fn: The callable to use for matching and rewriting,
+                  which takes an operation and a pattern rewriter as arguments.
+                  The matching succeeds iff the callable returns a value castable to False (e.g. None).
+              benefit: The benefit of the pattern, defaults to 1.)")
       .def("freeze", &PyRewritePatternSet::freeze,
            "Freeze the pattern set into a frozen one.");
 
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 821e47085a5bd..a6027161f29db 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -32,7 +32,7 @@ def constant_1_to_2(op, rewriter):
     with Context():
         patterns = RewritePatternSet()
         patterns.add(arith.AddIOp, to_muli)
-        patterns.add(arith.ConstantOp, constant_1_to_2)
+        patterns.add("arith.constant", constant_1_to_2)
         frozen = patterns.freeze()
 
         module = ModuleOp.parse(

>From 7ce1975477cd51775389ee345535c56a735a2847 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 29 Dec 2025 23:52:50 +0800
Subject: [PATCH 2/4] fix

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 1b640a6f9e10f..ba341b76579f9 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -291,7 +291,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           },
           "root"_a, "fn"_a, "benefit"_a = 1,
           // clang-format off
-          nb::sig("def add(self, root: type | str, fn: Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], Any], benefit: int = 1) -> None"),
+          nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
           // clang-format on
           R"(
             Add a new rewrite pattern on the given root operation with the

>From 7ebba860bd1b6b134c79e9f8f222badbb548faeb Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 30 Dec 2025 00:11:17 +0800
Subject: [PATCH 3/4] fix

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index ba341b76579f9..6cc8434c8020f 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -299,10 +299,12 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
 
             Args:
               root: The root operation to apply the pattern on,
-                    which can be an OpView class (type) or an operation name (str).
+                    which can be an OpView subclass (e.g. arith.AddIOp)
+                    or an operation name string (e.g. "arith.addi").
               fn: The callable to use for matching and rewriting,
                   which takes an operation and a pattern rewriter as arguments.
-                  The matching succeeds iff the callable returns a value castable to False (e.g. None).
+                  The matching succeeds iff the callable returns
+                  a value castable to False (e.g. None).
               benefit: The benefit of the pattern, defaults to 1.)")
       .def("freeze", &PyRewritePatternSet::freeze,
            "Freeze the pattern set into a frozen one.");

>From 5acdb5bfc93015175697f39445c5f22db82a76c4 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 30 Dec 2025 00:22:40 +0800
Subject: [PATCH 4/4] fix

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 6cc8434c8020f..f7557c3f7f768 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -294,18 +294,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
           // clang-format on
           R"(
-            Add a new rewrite pattern on the given root operation with the
-            callable as the matching and rewriting function and the given benefit.
+            Add a new rewrite pattern on the specified root operation, using the provided callable
+            for matching and rewriting, and assign it the given benefit.
 
             Args:
-              root: The root operation to apply the pattern on,
-                    which can be an OpView subclass (e.g. arith.AddIOp)
-                    or an operation name string (e.g. "arith.addi").
+              root: The root operation to which this pattern applies.
+                    This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
+                    an operation name string (e.g., ``"arith.addi"``).
               fn: The callable to use for matching and rewriting,
                   which takes an operation and a pattern rewriter as arguments.
-                  The matching succeeds iff the callable returns
-                  a value castable to False (e.g. None).
-              benefit: The benefit of the pattern, defaults to 1.)")
+                  The match is considered successful iff the callable returns
+                  a value where ``bool(value)`` is ``False`` (e.g. ``None``).
+                  If possible, the operation is cast to its corresponding OpView subclass
+                  before being passed to the callable.
+              benefit: The benefit of the pattern, defaulting to 1.)")
       .def("freeze", &PyRewritePatternSet::freeze,
            "Freeze the pattern set into a frozen one.");
 



More information about the Mlir-commits mailing list