[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined rewrite patterns (PR #162699)

Maksim Levental llvmlistbot at llvm.org
Fri Oct 10 10:58:39 PDT 2025


================
@@ -165,13 +175,116 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+class PyRewritePatternSet {
+public:
+  PyRewritePatternSet(MlirContext ctx)
+      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
+  ~PyRewritePatternSet() {
+    if (set.ptr)
+      mlirRewritePatternSetDestroy(set);
+  }
+
+  void add(MlirStringRef rootName, unsigned benefit,
+           const nb::callable &matchAndRewrite) {
+    MlirRewritePatternCallbacks callbacks;
+    callbacks.construct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+    };
+    callbacks.destruct = [](void *userData) {
+      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+    };
+    callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
+                                   MlirPatternRewriter rewriter,
+                                   void *userData) -> MlirLogicalResult {
+      nb::handle f(static_cast<PyObject *>(userData));
+      nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
----------------
makslevental wrote:

So is there value in passing `pattern` to the callback? Even in C++ I've rarely seen any of the `Pattern` members (e.g. `getGeneratedOps`) called.
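
For concreteness, here is a rough sketch of the two callback shapes in question, in plain Python with stand-in objects. Nothing here is the actual MLIR Python API: `FakeOp`, `FakeRewriter`, `FakePattern`, the two `dispatch_*` helpers, and the `True`-means-success return convention are all hypothetical and only illustrate the difference in signature.

```python
# Hypothetical stand-ins for the objects the binding would hand to Python.
class FakeOp:
    name = "arith.addi"


class FakeRewriter:
    def __init__(self):
        self.erased = []

    def erase_op(self, op):
        # Record the erased op's name so the effect is observable.
        self.erased.append(op.name)


class FakePattern:
    benefit = 1


def dispatch_with_pattern(fn, op, rewriter, pattern):
    # Shape used in the diff: the callback receives (op, rewriter, pattern).
    return fn(op, rewriter, pattern)


def dispatch_without_pattern(fn, op, rewriter, pattern):
    # Alternative shape: the pattern handle is not forwarded to Python.
    return fn(op, rewriter)


def erase_with_pattern(op, rewriter, pattern):
    # `pattern` is accepted but never used, which is the common case
    # the comment above is pointing at.
    rewriter.erase_op(op)
    return True  # assumed convention: truthy means the rewrite succeeded


def erase_without_pattern(op, rewriter):
    rewriter.erase_op(op)
    return True
```

Both callbacks do the same work; the only difference is the extra parameter every Python-defined pattern would have to declare in the three-argument form.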

https://github.com/llvm/llvm-project/pull/162699

