[Mlir-commits] [mlir] [MLIR][Python] Call `notifyOperationInserted` while constructing new op in rewrite patterns (PR #163694)

Maksim Levental llvmlistbot at llvm.org
Thu Oct 16 20:03:59 PDT 2025


================
@@ -202,7 +207,15 @@ class PyRewritePatternSet {
           PyMlirContext::forContext(mlirOperationGetContext(op));
       nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
 
-      nb::object res = f(opView, PyPatternRewriter(rewriter));
+      PyPatternRewriter pyRewriter(rewriter);
+      nb::object listener = nb::cast(pyRewriter.getListener());
+
+      listener.attr("__enter__")();
+      auto exit = llvm::make_scope_exit([listener] {
+        listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
+      });
----------------
makslevental wrote:

Here's a sketch of what I'm talking about (on top of your current commit):

```diff
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5512fb2377d6..e41a99be1473 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -189,37 +189,44 @@ public:
       mlirRewritePatternSetDestroy(set);
   }
 
+  struct UserData {
+    const nb::callable &matchAndRewriteCb;
+    nb::object listener;
+  };
+
   void add(MlirStringRef rootName, unsigned benefit,
-           const nb::callable &matchAndRewrite) {
+           const nb::callable &matchAndRewriteCb, nb::object listener) {
     MlirRewritePatternCallbacks callbacks;
     callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+      UserData *userData_ = static_cast<UserData *>(userData);
+      userData_->matchAndRewriteCb.inc_ref();
+      userData_->listener.inc_ref();
     };
     callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+      UserData *userData_ = static_cast<UserData *>(userData);
+      userData_->matchAndRewriteCb.dec_ref();
+      userData_->listener.dec_ref();
     };
     callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
                                    MlirPatternRewriter rewriter,
                                    void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
+      UserData *userData_ = static_cast<UserData *>(userData);
+      nb::handle f(userData_->matchAndRewriteCb);
 
       PyMlirContextRef ctx =
           PyMlirContext::forContext(mlirOperationGetContext(op));
       nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
 
       PyPatternRewriter pyRewriter(rewriter);
-      nb::object listener = nb::cast(pyRewriter.getListener());
-
-      listener.attr("__enter__")();
-      auto exit = llvm::make_scope_exit([listener] {
-        listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
-      });
-      nb::object res = f(opView, pyRewriter);
+      nb::object listener = userData_->listener;
+      nb::object res = f(opView, pyRewriter, listener);
 
       return logicalResultFromObject(res);
     };
+
+    UserData *userData_ = new UserData{matchAndRewriteCb, listener};
     MlirRewritePattern pattern = mlirOpRewritePattenCreate(
-        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        rootName, benefit, ctx, callbacks, static_cast<void *>(userData_),
         /* nGeneratedNames */ 0,
         /* generatedNames */ nullptr);
     mlirRewritePatternSetAdd(set, pattern);
@@ -291,13 +298,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "add",
           [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
-             unsigned benefit) {
+             const nb::object &listener, unsigned benefit) {
             std::string opName =
                 nb::cast<std::string>(root.attr("OPERATION_NAME"));
             self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
-                     fn);
+                     fn, listener);
           },
-          "root"_a, "fn"_a, "benefit"_a = 1,
+          "root"_a, "fn"_a, "listener"_a, "benefit"_a = 1,
           "Add a new rewrite pattern on the given root operation with the "
           "callable as the matching and rewriting function and the given "
           "benefit.")
```

https://github.com/llvm/llvm-project/pull/163694


More information about the Mlir-commits mailing list