[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)

Maksim Levental llvmlistbot at llvm.org
Thu Oct 9 21:15:28 PDT 2025


================
@@ -165,13 +175,115 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+/// Python-facing builder for a set of rewrite patterns.
+///
+/// Owns an MlirRewritePatternSet created for `ctx`. Patterns are appended with
+/// `add`; `freeze` hands the underlying set to a PyFrozenRewritePatternSet and
+/// leaves this object empty, after which `add`/`freeze` raise ValueError
+/// instead of passing a null handle to the C API.
+class PyRewritePatternSet {
+public:
+  PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() {
+    if (set.ptr)
+      mlirRewritePatternSetDestroy(set);
+  }
+
+  // This object uniquely owns `set`: an implicit copy would destroy the same
+  // handle twice, so copying is disabled and moves transfer ownership.
+  PyRewritePatternSet(const PyRewritePatternSet &) = delete;
+  PyRewritePatternSet &operator=(const PyRewritePatternSet &) = delete;
+  PyRewritePatternSet(PyRewritePatternSet &&other) noexcept
+      : set(other.set), ctx(other.ctx) {
+    other.set.ptr = nullptr;
+  }
+  PyRewritePatternSet &operator=(PyRewritePatternSet &&other) noexcept {
+    if (this != &other) {
+      if (set.ptr)
+        mlirRewritePatternSetDestroy(set);
+      set = other.set;
+      ctx = other.ctx;
+      other.set.ptr = nullptr;
+    }
+    return *this;
+  }
+
+  /// Adds a pattern rooted at operations named `rootName`.
+  ///
+  /// `matchAndRewrite` is invoked as `matchAndRewrite(op, rewriter, pattern)`
+  /// and its result is converted with `logicalResultFromObject`. The callable
+  /// is handed to the C API as user data; the construct/destruct callbacks
+  /// below add and drop a Python reference to it.
+  void add(MlirStringRef rootName, MlirPatternBenefit benefit,
+           const nb::callable &matchAndRewrite) {
+    if (!set.ptr)
+      throw nb::value_error(
+          "cannot add a pattern to a RewritePatternSet that was frozen");
+    // Value-initialize so any callback slot not assigned below is null rather
+    // than an indeterminate function pointer.
+    MlirRewritePatternCallbacks callbacks{};
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      // NOTE(review): if the Python callable raises, the call below throws a
+      // C++ exception that unwinds through the C API / pattern driver frames;
+      // confirm that is acceptable or translate it into a failure result.
+      nb::handle f(static_cast<PyObject *>(userData));
+      nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
+      return logicalResultFromObject(res);
+    };
+    MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        /* nGeneratedNames */ 0,
+        /* generatedNames */ nullptr);
+    mlirRewritePatternSetAdd(set, pattern);
+  }
+
+  /// Freezes the collected patterns, transferring ownership of the underlying
+  /// set to the returned PyFrozenRewritePatternSet. Raises ValueError if this
+  /// set was already frozen.
+  PyFrozenRewritePatternSet freeze() {
+    if (!set.ptr)
+      throw nb::value_error("RewritePatternSet has already been frozen");
+    MlirRewritePatternSet s = set;
+    set.ptr = nullptr;
+    return mlirFreezeRewritePattern(s);
+  }
+
+private:
+  MlirRewritePatternSet set;
+  MlirContext ctx;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
-  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
-      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
-                   "The current insertion point of the PatternRewriter.");
+  //----------------------------------------------------------------------------
+  // Mapping of the PatternRewriter
+  //----------------------------------------------------------------------------
+  nb::
+      class_<PyPatternRewriter>(m, "PatternRewriter")
+          .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
+                       "The current insertion point of the PatternRewriter.")
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 MlirOperation newOp) { self.replaceOp(op, newOp); },
+              "Replace an operation with a new operation.",
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+                ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
+              // clang-format on
+              )
+          .def(
+              "replace_op",
+              [](PyPatternRewriter &self, MlirOperation op,
+                 const std::vector<MlirValue> &values) {
+                self.replaceOp(op, values);
+              },
+              "Replace an operation with a list of values.",
+              // clang-format off
+              nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
+                ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
+              // clang-format on
+              )
----------------
makslevental wrote:

```suggestion
              nb::arg("op"), nb::arg("values"))
```

https://github.com/llvm/llvm-project/pull/162699


More information about the Mlir-commits mailing list