[Mlir-commits] [mlir] [MLIR][Python] Call `notifyOperationInserted` while constructing new op in rewrite patterns (PR #163694)
Maksim Levental
llvmlistbot at llvm.org
Thu Oct 16 20:03:59 PDT 2025
================
@@ -202,7 +207,15 @@ class PyRewritePatternSet {
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
- nb::object res = f(opView, PyPatternRewriter(rewriter));
+ PyPatternRewriter pyRewriter(rewriter);
+ nb::object listener = nb::cast(pyRewriter.getListener());
+
+ listener.attr("__enter__")();
+ auto exit = llvm::make_scope_exit([listener] {
+ listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
+ });
----------------
makslevental wrote:
Here's a sketch of what I'm talking about (on top of your current commit):
```diff
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5512fb2377d6..e41a99be1473 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -189,37 +189,44 @@ public:
mlirRewritePatternSetDestroy(set);
}
+ struct UserData {
+ const nb::callable &matchAndRewriteCb;
+ nb::object listener;
+ };
+
void add(MlirStringRef rootName, unsigned benefit,
- const nb::callable &matchAndRewrite) {
+ const nb::callable &matchAndRewriteCb, nb::object listener) {
MlirRewritePatternCallbacks callbacks;
callbacks.construct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+ UserData *userData_ = static_cast<UserData *>(userData);
+ userData_->matchAndRewriteCb.inc_ref();
+ userData_->listener.inc_ref();
};
callbacks.destruct = [](void *userData) {
- nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+ UserData *userData_ = static_cast<UserData *>(userData);
+ userData_->matchAndRewriteCb.dec_ref();
+ userData_->listener.dec_ref();
};
callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
MlirPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
- nb::handle f(static_cast<PyObject *>(userData));
+ UserData *userData_ = static_cast<UserData *>(userData);
+ nb::handle f(userData_->matchAndRewriteCb);
PyMlirContextRef ctx =
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
PyPatternRewriter pyRewriter(rewriter);
- nb::object listener = nb::cast(pyRewriter.getListener());
-
- listener.attr("__enter__")();
- auto exit = llvm::make_scope_exit([listener] {
- listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
- });
- nb::object res = f(opView, pyRewriter);
+ nb::object listener = userData_->listener;
+ nb::object res = f(opView, pyRewriter, listener);
return logicalResultFromObject(res);
};
+
+ UserData *userData_ = new UserData{matchAndRewriteCb, listener};
MlirRewritePattern pattern = mlirOpRewritePattenCreate(
- rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+ rootName, benefit, ctx, callbacks, static_cast<void *>(userData_),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
mlirRewritePatternSetAdd(set, pattern);
@@ -291,13 +298,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(
"add",
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
- unsigned benefit) {
+ const nb::object &listener, unsigned benefit) {
std::string opName =
nb::cast<std::string>(root.attr("OPERATION_NAME"));
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
- fn);
+ fn, listener);
},
- "root"_a, "fn"_a, "benefit"_a = 1,
+ "root"_a, "fn"_a, "listener"_a, "benefit"_a = 1,
"Add a new rewrite pattern on the given root operation with the "
"callable as the matching and rewriting function and the given "
"benefit.")
```
https://github.com/llvm/llvm-project/pull/163694
More information about the Mlir-commits mailing list