[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 9 19:36:01 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/162699

>From 4689bc244c266bdabb7c8416ff06f9face681c45 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:03:10 +0800
Subject: [PATCH 1/3] [MLIR][Python] Support Python-defined rewrite patterns

---
 mlir/include/mlir-c/Rewrite.h        | 33 +++++++++++
 mlir/lib/Bindings/Python/Rewrite.cpp | 81 +++++++++++++++++++++++++-
 mlir/lib/CAPI/Transforms/Rewrite.cpp | 87 +++++++++++++++++++++++++++-
 mlir/test/python/rewrite.py          | 49 ++++++++++++++++
 4 files changed, 245 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/python/rewrite.py

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5dd285ee076c4..68bb112404170 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -324,6 +325,38 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 MLIR_CAPI_EXPORTED MlirRewriterBase
 mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+typedef unsigned short MlirPatternBenefit;
+
+typedef struct {
+  void (*construct)(void *userData);
+  void (*destruct)(void *userData);
+  MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
+                                       MlirOperation op,
+                                       MlirPatternRewriter rewriter,
+                                       void *userData);
+} MlirRewritePatternCallbacks;
+
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetCreate(MlirContext context);
+
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                                                 MlirRewritePattern pattern);
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d9703c82e8..3740c59e62001 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,14 @@ class PyPatternRewriter {
     return PyInsertionPoint(PyOperation::forOperation(ctx, op));
   }
 
+  /// Replace `op` with `newOp` through the wrapped rewriter `base`
+  /// (mlirRewriterBaseReplaceOpWithOperation).
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  /// Replace `op` with `values` through the wrapped rewriter `base`
+  /// (mlirRewriterBaseReplaceOpWithValues). Presumably one value per result of
+  /// `op` is expected -- TODO confirm whether the C API checks the count.
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
 private:
   MlirRewriterBase base;
   PyMlirContextRef ctx;
@@ -165,13 +173,82 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+/// Python-side owner of an MlirRewritePatternSet whose patterns are
+/// implemented by Python callables.
+class PyRewritePatternSet {
+public:
+  // NOTE(review): `ctx` is a raw handle; nothing here keeps the owning Python
+  // context alive -- confirm callers guarantee it outlives this set.
+  explicit PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
+
+  // The destructor destroys `set`, so a copy would destroy it a second time.
+  PyRewritePatternSet(const PyRewritePatternSet &) = delete;
+  PyRewritePatternSet &operator=(const PyRewritePatternSet &) = delete;
+
+  /// Add a pattern rooted at the op named `rootName` whose rewrite is the
+  /// Python callable `matchAndRewrite`, called as
+  /// `matchAndRewrite(op, rewriter, pattern)`.
+  void add(MlirStringRef rootName, MlirPatternBenefit benefit,
+           const nb::callable &matchAndRewrite) {
+    MlirRewritePatternCallbacks callbacks;
+    // The pattern holds a strong reference to the callable while it lives.
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+      // This function is called through a C function pointer from inside the
+      // MLIR rewrite driver, so C++ exceptions must not escape it. Report the
+      // error via sys.unraisablehook and treat the pattern as not applied.
+      try {
+        nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
+        return logicalResultFromObject(res);
+      } catch (nb::python_error &e) {
+        e.restore();
+      } catch (const std::exception &e) {
+        PyErr_SetString(PyExc_RuntimeError, e.what());
+      }
+      PyErr_WriteUnraisable(f.ptr());
+      return mlirLogicalResultFailure();
+    };
+    MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    mlirRewritePatternSetAdd(set, pattern);
+  }
+
+  /// Freeze the patterns added so far; the set's contents are moved into the
+  /// returned frozen set.
+  PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
+
+private:
+  MlirRewritePatternSet set;
+  MlirContext ctx;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the PatternRewriter
+  //----------------------------------------------------------------------------
   nb::class_<PyPatternRewriter>(m, "PatternRewriter")
       .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                   "The current insertion point of the PatternRewriter.");
+                   "The current insertion point of the PatternRewriter.")
+      .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+                            MlirOperation newOp) { self.replaceOp(op, newOp); })
+      .def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
+                            const std::vector<MlirValue> &values) {
+        self.replaceOp(op, values);
+      });
+
+  //----------------------------------------------------------------------------
+  // Mapping of the RewritePatternSet
+  //----------------------------------------------------------------------------
+  nb::class_<MlirRewritePattern>(m, "RewritePattern");
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             unsigned benefit) {
+            std::string opName =
+                nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+                     fn);
+          },
+          "root"_a, "fn"_a, "benefit"_a = 1)
+      .def("freeze", &PyRewritePatternSet::freeze);
+
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
@@ -237,7 +314,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "freeze",
           [](PyPDLPatternModule &self) {
-            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
                 mlirRewritePatternSetFromPDLPatternModule(self.get())));
           },
           nb::keep_alive<0, 1>())
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b991f5d..f3430e2e78978 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/PDLPatternMatch.h.inc"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
@@ -270,9 +271,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
 /// RewritePatternSet and FrozenRewritePatternSet API
 //===----------------------------------------------------------------------===//
 
+/// Cast the opaque C handle back to the C++ RewritePatternSet. Returns a
+/// pointer rather than a reference so callers can also `delete` it.
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) {
   assert(module.ptr && "unexpected null module");
-  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+  return static_cast<mlir::RewritePatternSet *>(module.ptr);
 }
 
 static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
@@ -291,7 +292,7 @@ wrap(mlir::FrozenRewritePatternSet *module) {
 }
 
 MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
-  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
+  // NOTE(review): only the contents of the set are moved; the RewritePatternSet
+  // object itself is not deleted here, and since `op` is a by-value copy,
+  // clearing `op.ptr` below does not affect the caller's handle. Confirm that
+  // every caller destroys the moved-from set.
   op.ptr = nullptr;
   return wrap(m);
 }
@@ -332,6 +333,86 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
   return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
 }
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+/// Cast an opaque MlirRewritePattern handle back to the C++ pattern.
+/// `static` (like the other wrap/unwrap helpers in this file) keeps these
+/// overloads local to this translation unit.
+static inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) {
+  assert(pattern.ptr && "unexpected null pattern");
+  return static_cast<const mlir::RewritePattern *>(pattern.ptr);
+}
+
+/// Wrap a C++ pattern pointer into an opaque MlirRewritePattern handle.
+static inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) {
+  return {pattern};
+}
+
+namespace mlir {
+namespace {
+
+/// A RewritePattern whose matchAndRewrite is delegated to C callbacks.
+/// `callbacks.construct` runs on `userData` when the pattern is created and
+/// `callbacks.destruct` when it is destroyed, so the pair can manage the
+/// lifetime of whatever `userData` refers to.
+class ExternalRewritePattern final : public mlir::RewritePattern {
+public:
+  ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+                         StringRef rootName, PatternBenefit benefit,
+                         MLIRContext *context,
+                         ArrayRef<StringRef> generatedNames)
+      : RewritePattern(rootName, benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    assert(callbacks.matchAndRewrite &&
+           "matchAndRewrite callback must be provided");
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  ~ExternalRewritePattern() override {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  // Every instance pairs one `construct` with one `destruct` on `userData`;
+  // a copy would run `destruct` twice.
+  ExternalRewritePattern(const ExternalRewritePattern &) = delete;
+  ExternalRewritePattern &operator=(const ExternalRewritePattern &) = delete;
+
+  /// Forward to the user callback; its result is this pattern's result.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+        wrap(&rewriter), userData));
+  }
+
+private:
+  MlirRewritePatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace
+} // namespace mlir
+
+/// Build an ExternalRewritePattern from C arguments. The returned handle refers
+/// to a heap-allocated pattern that is handed over via mlirRewritePatternSetAdd.
+MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  // Translate the C string references into the StringRefs expected by the
+  // C++ pattern constructor.
+  std::vector<mlir::StringRef> names;
+  names.reserve(nGeneratedNames);
+  for (MlirStringRef name :
+       ArrayRef<MlirStringRef>(generatedNames, nGeneratedNames))
+    names.push_back(unwrap(name));
+
+  auto *created = new mlir::ExternalRewritePattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), names);
+  return wrap(created);
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+/// Allocate a new, empty RewritePatternSet bound to `context`.
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+  mlir::MLIRContext *mlirCtx = unwrap(context);
+  auto *patterns = new mlir::RewritePatternSet(mlirCtx);
+  return wrap(patterns);
+}
+
+/// Delete the RewritePatternSet behind `set`.
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+  mlir::RewritePatternSet *patterns = unwrap(set);
+  delete patterns;
+}
+
+/// Transfer ownership of `pattern` (e.g. one created by
+/// mlirOpRewritePattenCreate) into `set`; the caller must not use or free the
+/// pattern afterwards.
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                              MlirRewritePattern pattern) {
+  // The handle is const-qualified, but the pattern was heap-allocated as a
+  // non-const object, so reclaiming a mutable owning pointer is valid.
+  std::unique_ptr<mlir::RewritePattern> owned(
+      const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+  unwrap(set)->add(std::move(owned));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..6aed936f94d87
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def log(*args):
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
+
+def run(f):
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+    # Rewrite callback, invoked as fn(op, rewriter, pattern) for each matched
+    # arith.addi: build an arith.muli on the same operands and replace the addi.
+    # NOTE(review): returns None; this relies on the binding treating None as
+    # success -- confirm against logicalResultFromObject.
+    def to_muli(op, rewriter, pattern):
+        with rewriter.ip:
+            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+        rewriter.replace_op(op, new_op.owner)
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.AddIOp, to_muli)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)

>From 61b87af618652e26f71400d7b238f1597a2ca364 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 00:56:46 +0800
Subject: [PATCH 2/3] format

---
 mlir/test/python/rewrite.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 6aed936f94d87..c7b6c1f19991e 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -19,6 +19,7 @@ def run(f):
     gc.collect()
     assert Context._get_live_count() == 0
 
+
 # CHECK-LABEL: TEST: testRewritePattern
 @run
 def testRewritePattern():

>From 395627f8987187ca8f45dcefe6c9167b69a3f7d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 10 Oct 2025 10:35:15 +0800
Subject: [PATCH 3/3] add docs for C API

---
 mlir/include/mlir-c/Rewrite.h | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 68bb112404170..cc021bcfba889 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -329,17 +329,28 @@ mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
 /// RewritePattern API
 //===----------------------------------------------------------------------===//
 
+/// The benefit of a pattern match; converted to an mlir::PatternBenefit.
 typedef unsigned short MlirPatternBenefit;
 
+/// Callbacks that implement a rewrite pattern created through the C API.
 typedef struct {
+  /// Optional; called with `userData` once, when the pattern is created.
+  /// Set to nullptr to disable it.
   void (*construct)(void *userData);
+  /// Optional; called with `userData` once, when the pattern is destroyed.
+  /// Set to nullptr to disable it.
   void (*destruct)(void *userData);
+  /// Required. Matches code rooted at `op` and performs the rewrite on a
+  /// successful match; the result is returned as the pattern's
+  /// RewritePattern::matchAndRewrite result. `userData` is from creation.
   MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
                                        MlirOperation op,
                                        MlirPatternRewriter rewriter,
                                        void *userData);
 } MlirRewritePatternCallbacks;
 
+/// Create a RewritePattern whose root is the op named `rootName`.
+/// The caller owns the result until it is passed to mlirRewritePatternSetAdd.
 MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
     MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
     MlirRewritePatternCallbacks callbacks, void *userData,
@@ -349,11 +360,14 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
 /// RewritePatternSet API
 //===----------------------------------------------------------------------===//
 
+/// Create an empty MlirRewritePatternSet in the given context.
 MLIR_CAPI_EXPORTED MlirRewritePatternSet
 mlirRewritePatternSetCreate(MlirContext context);
 
+/// Destroy the given MlirRewritePatternSet.
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
 
+/// Add `pattern` to `set`; the set takes ownership of the pattern.
 MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
                                                  MlirRewritePattern pattern);
 



More information about the Mlir-commits mailing list