[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 9 22:09:33 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/162699

>From 4689bc244c266bdabb7c8416ff06f9face681c45 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:03:10 +0800
Subject: [PATCH 01/10] [MLIR][Python] Support Python-defined rewrite patterns

---
 mlir/include/mlir-c/Rewrite.h        | 33 +++++++++++
 mlir/lib/Bindings/Python/Rewrite.cpp | 81 +++++++++++++++++++++++++-
 mlir/lib/CAPI/Transforms/Rewrite.cpp | 87 +++++++++++++++++++++++++++-
 mlir/test/python/rewrite.py          | 49 ++++++++++++++++
 4 files changed, 245 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/python/rewrite.py

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5dd285ee076c4..68bb112404170 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -324,6 +325,38 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 MLIR_CAPI_EXPORTED MlirRewriterBase
 mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+typedef unsigned short MlirPatternBenefit;
+
+typedef struct {
+  void (*construct)(void *userData);
+  void (*destruct)(void *userData);
+  MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
+                                       MlirOperation op,
+                                       MlirPatternRewriter rewriter,
+                                       void *userData);
+} MlirRewritePatternCallbacks;
+
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetCreate(MlirContext context);
+
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                                                 MlirRewritePattern pattern);
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d9703c82e8..3740c59e62001 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,14 @@ class PyPatternRewriter {
     return PyInsertionPoint(PyOperation::forOperation(ctx, op));
   }
 
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
 private:
   MlirRewriterBase base;
   PyMlirContextRef ctx;
@@ -165,13 +173,96 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+class PyRewritePatternSet {
+public:
+  PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
+
+  void add(MlirStringRef rootName, MlirPatternBenefit benefit,
+           const nb::callable &matchAndRewrite) {
+    // After freeze() the underlying set belongs to the frozen set and
+    // `set.ptr` is cleared; adding to it would dereference null.
+    if (!set.ptr)
+      throw nb::value_error("cannot add patterns to a RewritePatternSet "
+                            "that has already been frozen");
+    MlirRewritePatternCallbacks callbacks;
+    // The pattern holds a strong reference to the Python callable from its
+    // construction until its destruction.
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+      // Do not let a Python exception unwind through the C API and the MLIR
+      // rewrite driver; report it via sys.unraisablehook and fail the match.
+      try {
+        nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
+        return logicalResultFromObject(res);
+      } catch (nb::python_error &e) {
+        e.discard_as_unraisable(f);
+        return mlirLogicalResultFailure();
+      }
+    };
+    MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    mlirRewritePatternSetAdd(set, pattern);
+  }
+
+  PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
+
+private:
+  MlirRewritePatternSet set;
+  MlirContext ctx;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the PatternRewriter
+  //----------------------------------------------------------------------------
   nb::class_<PyPatternRewriter>(m, "PatternRewriter")
       .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                   "The current insertion point of the PatternRewriter.");
+                   "The current insertion point of the PatternRewriter.")
+      .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+                            MlirOperation newOp) { self.replaceOp(op, newOp); })
+      .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+                            const std::vector<MlirValue> &values) {
+        self.replaceOp(op, values);
+      });
+
+  //----------------------------------------------------------------------------
+  // Mapping of the RewritePatternSet
+  //----------------------------------------------------------------------------
+  nb::class_<MlirRewritePattern>(m, "RewritePattern");
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             unsigned benefit) {
+            std::string opName =
+                nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+                     fn);
+          },
+          "root"_a, "fn"_a, "benefit"_a = 1)
+      .def("freeze", &PyRewritePatternSet::freeze);
+
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
@@ -237,7 +314,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "freeze",
           [](PyPDLPatternModule &self) {
-            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
                 mlirRewritePatternSetFromPDLPatternModule(self.get())));
           },
           nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b991f5d..f3430e2e78978 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
@@ -270,9 +271,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
 /// RewritePatternSet and FrozenRewritePatternSet API
 //===----------------------------------------------------------------------===//
 
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) {
   assert(module.ptr && "unexpected null module");
-  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+  return static_cast<mlir::RewritePatternSet *>(module.ptr);
 }
 
 static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
@@ -291,7 +292,7 @@ wrap(mlir::FrozenRewritePatternSet *module) {
 }
 
 MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
-  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
   op.ptr = nullptr;
   return wrap(m);
 }
@@ -332,6 +333,94 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
   return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
 }
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+static inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) {
+  assert(pattern.ptr && "unexpected null pattern");
+  return static_cast<const mlir::RewritePattern *>(pattern.ptr);
+}
+
+static inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) {
+  return {pattern};
+}
+
+namespace {
+
+/// A RewritePattern whose matchAndRewrite is forwarded to a C callback.
+/// The optional construct/destruct callbacks run once when the pattern is
+/// created and once when it is destroyed, bracketing the pattern's use of
+/// `userData`.
+class ExternalRewritePattern final : public mlir::RewritePattern {
+public:
+  ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+                         StringRef rootName, PatternBenefit benefit,
+                         MLIRContext *context,
+                         ArrayRef<StringRef> generatedNames)
+      : RewritePattern(rootName, benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  // A copy would run `destruct` twice on the same user data.
+  ExternalRewritePattern(const ExternalRewritePattern &) = delete;
+  ExternalRewritePattern &operator=(const ExternalRewritePattern &) = delete;
+
+  ~ExternalRewritePattern() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+        wrap(&rewriter), userData));
+  }
+
+private:
+  MlirRewritePatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  SmallVector<StringRef> generatedNamesVec;
+  generatedNamesVec.reserve(nGeneratedNames);
+  for (size_t i = 0; i < nGeneratedNames; ++i)
+    generatedNamesVec.push_back(unwrap(generatedNames[i]));
+  return wrap(new ExternalRewritePattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), generatedNamesVec));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+  return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+  delete unwrap(set);
+}
+
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                              MlirRewritePattern pattern) {
+  // The set takes ownership of the heap-allocated pattern.
+  std::unique_ptr<mlir::RewritePattern> patternPtr(
+      const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+  pattern.ptr = nullptr;
+  unwrap(set)->add(std::move(patternPtr));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..6aed936f94d87
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def log(*args):
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+    def to_muli(op, rewriter, pattern):
+        with rewriter.ip:
+            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+        rewriter.replace_op(op, new_op.owner)
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.AddIOp, to_muli)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)

>From 61b87af618652e26f71400d7b238f1597a2ca364 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:56:46 +0800
Subject: [PATCH 02/10] format

---
 mlir/test/python/rewrite.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 6aed936f94d87..c7b6c1f19991e 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -19,6 +19,7 @@ def run(f):
     gc.collect()
     assert Context._get_live_count() == 0
 
+
 # CHECK-LABEL: TEST: testRewritePattern
 @run
 def testRewritePattern():

>From 395627f8987187ca8f45dcefe6c9167b69a3f7d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 10:35:15 +0800
Subject: [PATCH 03/10] add docs for C API

---
 mlir/include/mlir-c/Rewrite.h | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 68bb112404170..cc021bcfba889 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -329,17 +329,28 @@ mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
 /// RewritePattern API
 //===----------------------------------------------------------------------===//
 
+/// PatternBenefit represents the benefit of a pattern match.
 typedef unsigned short MlirPatternBenefit;
 
+/// Callbacks to construct a rewrite pattern.
 typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
   void (*construct)(void *userData);
+  /// Optional destructor for the user data.
+  /// Set to nullptr to disable it.
   void (*destruct)(void *userData);
+  /// The callback function to match against code rooted at the specified
+  /// operation, and perform the rewrite if the match is successful,
+  /// corresponding to RewritePattern::matchAndRewrite.
   MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
                                        MlirOperation op,
                                        MlirPatternRewriter rewriter,
                                        void *userData);
 } MlirRewritePatternCallbacks;
 
+/// Create a rewrite pattern that matches the operation
+/// with the given rootName, corresponding to mlir::OpRewritePattern.
 MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
     MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
     MlirRewritePatternCallbacks callbacks, void *userData,
@@ -349,11 +360,14 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
 /// RewritePatternSet API
 //===----------------------------------------------------------------------===//
 
+/// Create an empty MlirRewritePatternSet.
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetCreate(MlirContext context);
 
+/// Destruct the given MlirRewritePatternSet.
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
 
+/// Add the given MlirRewritePattern into a MlirRewritePatternSet.
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
                                                  MlirRewritePattern pattern);
 

>From 0ddd081a3eb27348c7b87058edcf8eb437c796a0 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 10:45:09 +0800
Subject: [PATCH 04/10] add more docs and fix some names

---
 mlir/include/mlir-c/Rewrite.h        | 10 ++++++++--
 mlir/lib/Bindings/Python/Rewrite.cpp |  6 +++++-
 mlir/lib/CAPI/Transforms/Rewrite.cpp | 13 +++++++------
 3 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index cc021bcfba889..66a9a5de1669d 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -303,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
 /// FrozenRewritePatternSet API
 //===----------------------------------------------------------------------===//
 
+/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet.
+/// Note that the ownership of the input set is transferred into the frozen set
+/// after this call.
 MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
-mlirFreezeRewritePattern(MlirRewritePatternSet op);
+mlirFreezeRewritePattern(MlirRewritePatternSet set);
 
+/// Destroy the given MlirFrozenRewritePatternSet.
 MLIR_CAPI_EXPORTED void
-mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set);
 
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
     MlirOperation op, MlirFrozenRewritePatternSet patterns,
@@ -368,6 +372,8 @@ mlirRewritePatternSetCreate(MlirContext context);
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
 
 /// Add the given MlirRewritePattern into a MlirRewritePatternSet.
+/// Note that the ownership of the pattern is transferred to the set after this
+/// call.
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
                                                  MlirRewritePattern pattern);
 
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 3740c59e62001..9c99c6a4366b5 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -202,7 +202,15 @@ class PyRewritePatternSet {
     mlirRewritePatternSetAdd(set, pattern);
   }
 
-  PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
+  /// Freeze the patterns. Ownership of the underlying set is transferred to
+  /// the returned frozen set, so this object is left empty afterwards.
+  PyFrozenRewritePatternSet freeze() {
+    if (!set.ptr)
+      throw nb::value_error("RewritePatternSet has already been frozen");
+    MlirRewritePatternSet s = set;
+    set.ptr = nullptr;
+    return mlirFreezeRewritePattern(s);
+  }
 
 private:
   MlirRewritePatternSet set;
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index f3430e2e78978..7e7a4f7715bb4 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -291,15 +291,19 @@ wrap(mlir::FrozenRewritePatternSet *module) {
   return {module};
 }
 
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
-  auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
-  op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+  // The patterns now live in the frozen set; release the moved-from
+  // RewritePatternSet so that transferring ownership does not leak it.
+  delete unwrap(set);
+  set.ptr = nullptr;
   return wrap(m);
 }
 
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
-  delete unwrap(op);
-  op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+  delete unwrap(set);
+  set.ptr = nullptr;
 }
 
 MlirLogicalResult

>From da4bb8b560b3bc49d5064281ef407c618d24787c Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 11:28:45 +0800
Subject: [PATCH 05/10] add nb::sigs and python api docs

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 55 ++++++++++++++++++++++------
 1 file changed, 43 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9c99c6a4366b5..07559457f2f2f 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -53,6 +53,8 @@ class PyPatternRewriter {
     mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
   }
 
+  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
 private:
   MlirRewriterBase base;
   PyMlirContextRef ctx;
@@ -177,7 +179,10 @@ class PyRewritePatternSet {
 public:
   PyRewritePatternSet(MlirContext ctx)
       : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
-  ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
+  ~PyRewritePatternSet() {
+    if (set.ptr)
+      mlirRewritePatternSetDestroy(set);
+  }
 
   void add(MlirStringRef rootName, MlirPatternBenefit benefit,
            const nb::callable &matchAndRewrite) {
@@ -220,15 +225,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
   // Mapping of the PatternRewriter
   //----------------------------------------------------------------------------
-  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
-      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                   "The current insertion point of the PatternRewriter.")
-      .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
-                            MlirOperation newOp) { self.replaceOp(op, newOp); })
-      .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
-                            const std::vector<MlirValue> &values) {
-        self.replaceOp(op, values);
-      });
+  nb::
+      class_<PyPatternRewriter>(m, "PatternRewriter")
+          .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+                       "The current insertion point of the PatternRewriter.")
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 MlirOperation newOp) { self.replaceOp(op, newOp); },
+              "Replace an operation with a new operation.",
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+                ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+              // clang-format on
+              )
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 const std::vector<MlirValue> &values) {
+                self.replaceOp(op, values);
+              },
+              "Replace an operation with a list of values.",
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+                ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+              // clang-format on
+              )
+          .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+               // clang-format off
+                nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+               // clang-format on
+          );
 
   //----------------------------------------------------------------------------
   // Mapping of the RewritePatternSet
@@ -250,8 +277,12 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
             self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
                      fn);
           },
-          "root"_a, "fn"_a, "benefit"_a = 1)
-      .def("freeze", &PyRewritePatternSet::freeze);
+          "root"_a, "fn"_a, "benefit"_a = 1,
+          "Add a new rewrite pattern on the given root operation with the "
+          "callable as the matching and rewriting function and the given "
+          "benefit.")
+      .def("freeze", &PyRewritePatternSet::freeze,
+           "Freeze the pattern set into a frozen one.");
 
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule

>From 5333a6ef08a3286b83494089b568f0f4087f77a7 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 11:42:21 +0800
Subject: [PATCH 06/10] add more examples

---
 mlir/lib/CAPI/Transforms/Rewrite.cpp |  1 -
 mlir/test/python/rewrite.py          | 27 +++++++++++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 7e7a4f7715bb4..d7c8e53f2bba6 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -17,7 +17,6 @@
 #include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index c7b6c1f19991e..cbc3a4043f96c 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -28,9 +28,18 @@ def to_muli(op, rewriter, pattern):
             new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
         rewriter.replace_op(op, new_op.owner)
 
+    def constant_1_to_2(op, rewriter, pattern):
+        c = op.attributes["value"].value
+        if c != 1:
+            return True # failed to match
+        with rewriter.ip:
+            new_op = arith.constant(op.result.type, 2, loc=op.location)
+        rewriter.replace_op(op, [new_op])
+
     with Context():
         patterns = RewritePatternSet()
         patterns.add(arith.AddIOp, to_muli)
+        patterns.add(arith.ConstantOp, constant_1_to_2)
         frozen = patterns.freeze()
 
         module = ModuleOp.parse(
@@ -48,3 +57,21 @@ def to_muli(op, rewriter, pattern):
         # CHECK: %0 = arith.muli %arg0, %arg1 : i64
         # CHECK: return %0 : i64
         print(module)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @const() -> (i64, i64) {
+                %0 = arith.constant 1 : i64
+                %1 = arith.constant 3 : i64
+                return %0, %1 : i64, i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: %c3_i64 = arith.constant 3 : i64
+        # CHECK: return %c2_i64, %c3_i64 : i64, i64
+        print(module)

>From a57961fc66c12529e957086869e008e835b70a54 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 11:47:17 +0800
Subject: [PATCH 07/10] fix format

---
 mlir/test/python/rewrite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index cbc3a4043f96c..4537068a5b9d5 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -31,7 +31,7 @@ def to_muli(op, rewriter, pattern):
     def constant_1_to_2(op, rewriter, pattern):
         c = op.attributes["value"].value
         if c != 1:
-            return True # failed to match
+            return True  # failed to match
         with rewriter.ip:
             new_op = arith.constant(op.result.type, 2, loc=op.location)
         rewriter.replace_op(op, [new_op])

>From 43da9a2cbe6d074fa863e6d500b18fc1d0a62894 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Fri, 10 Oct 2025 12:59:58 +0800
Subject: [PATCH 08/10] Update mlir/lib/Bindings/Python/Rewrite.cpp

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
 mlir/lib/Bindings/Python/Rewrite.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 07559457f2f2f..c938360756f03 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -250,7 +250,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
               nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
                 ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
               // clang-format on
-              )
+              nb::arg("op"), nb::arg("values"))
           .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
                // clang-format off
                 nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")

>From 75c2dd90ae36c92ef184dda1a27150f5ace66aaf Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 13:06:55 +0800
Subject: [PATCH 09/10] reformat nb::sigs and add nb::args

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index c938360756f03..078593955bf9c 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -233,10 +233,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
               "replace_op",
               [](PyPatternRewriter &self, MlirOperation op,
                  MlirOperation newOp) { self.replaceOp(op, newOp); },
-              "Replace an operation with a new operation.",
+              "Replace an operation with a new operation.", nb::arg("op"),
+              nb::arg("new_op"),
               // clang-format off
-              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
-                ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
               // clang-format on
               )
           .def(
@@ -245,13 +245,14 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
                  const std::vector<MlirValue> &values) {
                 self.replaceOp(op, values);
               },
-              "Replace an operation with a list of values.",
+              "Replace an operation with a list of values.", nb::arg("op"),
+              nb::arg("values"),
               // clang-format off
-              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
-                ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
               // clang-format on
-              nb::arg("op"), nb::arg("values"))
+              )
           .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+               nb::arg("op"),
                // clang-format off
                 nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
                // clang-format on

>From 64d98e42960545330ee4842f8d81d12664f12784 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 13:09:16 +0800
Subject: [PATCH 10/10] remove log()

---
 mlir/test/python/rewrite.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 4537068a5b9d5..546a4fb720a98 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -8,13 +8,8 @@
 from mlir.rewrite import *
 
 
-def log(*args):
-    print(*args, file=sys.stderr)
-    sys.stderr.flush()
-
-
 def run(f):
-    log("\nTEST:", f.__name__)
+    print("\nTEST:", f.__name__)
     f()
     gc.collect()
     assert Context._get_live_count() == 0



More information about the Mlir-commits mailing list