[Mlir-commits] [mlir] 7aec3f2 - [MLIR][Python] Support Python-defined rewrite patterns (#162699)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 10 20:28:49 PDT 2025


Author: Twice
Date: 2025-10-11T11:28:45+08:00
New Revision: 7aec3f2864e8ea968e7d164e784f9d7038807a5d

URL: https://github.com/llvm/llvm-project/commit/7aec3f2864e8ea968e7d164e784f9d7038807a5d
DIFF: https://github.com/llvm/llvm-project/commit/7aec3f2864e8ea968e7d164e784f9d7038807a5d.diff

LOG: [MLIR][Python] Support Python-defined rewrite patterns (#162699)

This PR adds support for defining custom **`RewritePattern`**
implementations directly in the Python bindings.

Previously, users could define similar patterns using the PDL dialect’s
bindings. However, for more complex patterns, this often required
writing multiple Python callbacks as PDL native constraints or rewrite
functions, which made the overall logic less intuitive—although the
PDL-based approach can be more performant than a pure Python
implementation (especially for simple patterns).

With this change, we introduce an additional, straightforward way to
define patterns purely in Python, complementing the existing PDL-based
approach.

### Example

```python
def to_muli(op, rewriter):
    """Replace an ``arith.addi`` op with an ``arith.muli`` on the same operands."""
    lhs, rhs = op.operands[0], op.operands[1]
    # New ops are created at the rewriter's current insertion point.
    with rewriter.ip:
        product = arith.muli(lhs, rhs, loc=op.location)
    # arith.muli yields the result Value; replace_op takes its defining op.
    rewriter.replace_op(op, product.owner)

with Context():
    # Register a pattern rooted at arith.addi that rewrites it to arith.muli.
    pattern_set = RewritePatternSet()
    pattern_set.add(arith.AddIOp, to_muli)
    # freeze() consumes the pattern set; use the returned frozen set from here on.
    frozen = pattern_set.freeze()

    module = ...
    apply_patterns_and_fold_greedily(module, frozen)
```

---------

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>

Added: 
    mlir/test/python/rewrite.py

Modified: 
    mlir/include/mlir-c/Rewrite.h
    mlir/include/mlir/CAPI/Rewrite.h
    mlir/lib/Bindings/Python/Rewrite.cpp
    mlir/lib/CAPI/Transforms/Rewrite.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 5dd285ee076c4..2db1d84cd1d89 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
+DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -302,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter);
 /// FrozenRewritePatternSet API
 //===----------------------------------------------------------------------===//
 
+/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet.
+/// Note that the ownership of the input set is transferred into the frozen set
+/// after this call.
 MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
-mlirFreezeRewritePattern(MlirRewritePatternSet op);
+mlirFreezeRewritePattern(MlirRewritePatternSet set);
 
+/// Destroy the given MlirFrozenRewritePatternSet.
 MLIR_CAPI_EXPORTED void
-mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set);
 
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
     MlirOperation op, MlirFrozenRewritePatternSet patterns,
@@ -324,6 +329,51 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
 MLIR_CAPI_EXPORTED MlirRewriterBase
 mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+/// Callbacks to construct a rewrite pattern.
+typedef struct {
+  /// Optional constructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*construct)(void *userData);
+  /// Optional destructor for the user data.
+  /// Set to nullptr to disable it.
+  void (*destruct)(void *userData);
+  /// The callback function to match against code rooted at the specified
+  /// operation, and perform the rewrite if the match is successful,
+  /// corresponding to RewritePattern::matchAndRewrite.
+  MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
+                                       MlirOperation op,
+                                       MlirPatternRewriter rewriter,
+                                       void *userData);
+} MlirRewritePatternCallbacks;
+
+/// Create a rewrite pattern that matches the operation
+/// with the given rootName, corresponding to mlir::OpRewritePattern.
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, unsigned benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames);
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+/// Create an empty MlirRewritePatternSet.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetCreate(MlirContext context);
+
+/// Destruct the given MlirRewritePatternSet.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
+
+/// Add the given MlirRewritePattern into a MlirRewritePatternSet.
+/// Note that the ownership of the pattern is transferred to the set after this
+/// call.
+MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                                                 MlirRewritePattern pattern);
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
index 1038c0a575cf2..8cd51edf0837b 100644
--- a/mlir/include/mlir/CAPI/Rewrite.h
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -20,5 +20,7 @@
 #include "mlir/IR/PatternMatch.h"
 
 DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase)
+DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern)
+DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
 
 #endif // MLIR_CAPIREWRITER_H

diff  --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9e3d9703c82e8..d506b7fc9bc7b 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,16 @@ class PyPatternRewriter {
     return PyInsertionPoint(PyOperation::forOperation(ctx, op));
   }
 
+  void replaceOp(MlirOperation op, MlirOperation newOp) {
+    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
+  }
+
+  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
+    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
+  }
+
+  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
+
 private:
   MlirRewriterBase base;
   PyMlirContextRef ctx;
@@ -165,13 +175,116 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+class PyRewritePatternSet {
+public:
+  PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() {
+    if (set.ptr)
+      mlirRewritePatternSetDestroy(set);
+  }
+
+  void add(MlirStringRef rootName, unsigned benefit,
+           const nb::callable &matchAndRewrite) {
+    MlirRewritePatternCallbacks callbacks;
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+      nb::object res = f(op, PyPatternRewriter(rewriter));
+      return logicalResultFromObject(res);
+    };
+    MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    mlirRewritePatternSetAdd(set, pattern);
+  }
+
+  PyFrozenRewritePatternSet freeze() {
+    MlirRewritePatternSet s = set;
+    set.ptr = nullptr;
+    return mlirFreezeRewritePattern(s);
+  }
+
+private:
+  MlirRewritePatternSet set;
+  MlirContext ctx;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
-  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
-      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                   "The current insertion point of the PatternRewriter.");
+  //----------------------------------------------------------------------------
+  // Mapping of the PatternRewriter
+  //----------------------------------------------------------------------------
+  nb::
+      class_<PyPatternRewriter>(m, "PatternRewriter")
+          .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+                       "The current insertion point of the PatternRewriter.")
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 MlirOperation newOp) { self.replaceOp(op, newOp); },
+              "Replace an operation with a new operation.", nb::arg("op"),
+              nb::arg("new_op"),
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+              // clang-format on
+              )
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 const std::vector<MlirValue> &values) {
+                self.replaceOp(op, values);
+              },
+              "Replace an operation with a list of values.", nb::arg("op"),
+              nb::arg("values"),
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+              // clang-format on
+              )
+          .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
+               nb::arg("op"),
+               // clang-format off
+                nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+               // clang-format on
+          );
+
+  //----------------------------------------------------------------------------
+  // Mapping of the RewritePatternSet
+  //----------------------------------------------------------------------------
+  nb::class_<MlirRewritePattern>(m, "RewritePattern");
+  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
+      .def(
+          "__init__",
+          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
+            new (&self) PyRewritePatternSet(context.get()->get());
+          },
+          "context"_a = nb::none())
+      .def(
+          "add",
+          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
+             unsigned benefit) {
+            std::string opName =
+                nb::cast<std::string>(root.attr("OPERATION_NAME"));
+            self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
+                     fn);
+          },
+          "root"_a, "fn"_a, "benefit"_a = 1,
+          "Add a new rewrite pattern on the given root operation with the "
+          "callable as the matching and rewriting function and the given "
+          "benefit.")
+      .def("freeze", &PyRewritePatternSet::freeze,
+           "Freeze the pattern set into a frozen one.");
+
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
   //----------------------------------------------------------------------------
@@ -237,7 +350,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "freeze",
           [](PyPDLPatternModule &self) {
-            return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
                 mlirRewritePatternSetFromPDLPatternModule(self.get())));
           },
           nb::keep_alive<0, 1>())

diff  --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index c15a73b991f5d..70dee598c9535 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -270,15 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
 /// RewritePatternSet and FrozenRewritePatternSet API
 //===----------------------------------------------------------------------===//
 
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
-  assert(module.ptr && "unexpected null module");
-  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
-}
-
-static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
-  return {module};
-}
-
 static inline mlir::FrozenRewritePatternSet *
 unwrap(MlirFrozenRewritePatternSet module) {
   assert(module.ptr && "unexpected null module");
@@ -290,15 +281,16 @@ wrap(mlir::FrozenRewritePatternSet *module) {
   return {module};
 }
 
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
-  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
-  op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+  set.ptr = nullptr;
   return wrap(m);
 }
 
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
-  delete unwrap(op);
-  op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+  delete unwrap(set);
+  set.ptr = nullptr;
 }
 
 MlirLogicalResult
@@ -332,6 +324,77 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
   return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
 }
 
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+
+class ExternalRewritePattern : public mlir::RewritePattern {
+public:
+  ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+                         StringRef rootName, PatternBenefit benefit,
+                         MLIRContext *context,
+                         ArrayRef<StringRef> generatedNames)
+      : RewritePattern(rootName, benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  ~ExternalRewritePattern() {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+        wrap(&rewriter), userData));
+  }
+
+private:
+  MlirRewritePatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, unsigned benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  std::vector<mlir::StringRef> generatedNamesVec;
+  generatedNamesVec.reserve(nGeneratedNames);
+  for (size_t i = 0; i < nGeneratedNames; ++i) {
+    generatedNamesVec.push_back(unwrap(generatedNames[i]));
+  }
+  return wrap(new mlir::ExternalRewritePattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), generatedNamesVec));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+  return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+  delete unwrap(set);
+}
+
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                              MlirRewritePattern pattern) {
+  std::unique_ptr<mlir::RewritePattern> patternPtr(
+      const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+  pattern.ptr = nullptr;
+  unwrap(set)->add(std::move(patternPtr));
+}
+
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..acf7db23db914
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,69 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    f()
+
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+    def to_muli(op, rewriter):
+        with rewriter.ip:
+            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+        rewriter.replace_op(op, new_op.owner)
+
+    def constant_1_to_2(op, rewriter):
+        c = op.attributes["value"].value
+        if c != 1:
+            return True  # failed to match
+        with rewriter.ip:
+            new_op = arith.constant(op.result.type, 2, loc=op.location)
+        rewriter.replace_op(op, [new_op])
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.AddIOp, to_muli)
+        patterns.add(arith.ConstantOp, constant_1_to_2)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @const() -> (i64, i64) {
+                %0 = arith.constant 1 : i64
+                %1 = arith.constant 3 : i64
+                return %0, %1 : i64, i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: %c3_i64 = arith.constant 3 : i64
+        # CHECK: return %c2_i64, %c3_i64 : i64, i64
+        print(module)


        


More information about the Mlir-commits mailing list