[Mlir-commits] [mlir] [MLIR][Python] Call `notifyOperationInserted` while constructing new op in rewrite patterns (PR #163694)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 16 10:15:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

In MLIR, both `PatternRewriter` and `OpBuilder` typically have corresponding *Listeners* that monitor changes to the IR. For example, when a new operation is created, the listener’s `notifyOperationInserted` method is invoked to inform the rewriter of this change. (This is also why, in C++ rewrite patterns, one should use `PatternRewriter::create` instead of `OpTy::create`, since the latter does not trigger `notifyOperationInserted`.)

However, in Python, the listener methods are currently not invoked when executing rewrite patterns. While this may not always affect the outcome, it violates the semantics defined in MLIR’s C++ implementation. At present, rewrite patterns in Python directly construct new operations using the TableGen-generated op constructors, such as `arith.addi(lhs, rhs)`.
Although we could introduce an API like `rewriter.create("arith.addi", operands=[lhs, rhs], results=...)`, it would be less intuitive, and users might still bypass it and instantiate ops directly, which would again skip listener notifications.

In this PR, we adopt an approach similar to how *context*, *location*, and *insertion point* are managed: we maintain a stack of listeners. When `arith.addi` or `Operation.create` is called, the top listener on the stack is retrieved, and `notifyOperationInserted` is invoked automatically. This allows users to construct operations in the usual way, while ensuring that listeners are properly notified.


---
Full diff: https://github.com/llvm/llvm-project/pull/163694.diff


6 Files Affected:

- (modified) mlir/include/mlir-c/Rewrite.h (+10) 
- (modified) mlir/include/mlir/CAPI/Rewrite.h (+1) 
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+69-6) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+39-3) 
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+16-1) 
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+17) 


``````````diff
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 2db1d84cd1d89..f42aaff4d3c19 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -39,6 +39,7 @@ DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
+DEFINE_C_API_STRUCT(MlirRewriterBaseListener, void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -48,6 +49,15 @@ DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
 MLIR_CAPI_EXPORTED MlirContext
 mlirRewriterBaseGetContext(MlirRewriterBase rewriter);
 
+/// Get the listener of the rewriter. The returned listener may be null.
+MLIR_CAPI_EXPORTED MlirRewriterBaseListener
+mlirRewriterBaseGetListener(MlirRewriterBase rewriter);
+
+/// Notify the listener that `op` was inserted; `insertionPoint` may be null.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseListenerNotifyOperationInserted(
+    MlirRewriterBaseListener listener, MlirOperation op,
+    MlirOperation insertionPoint);
+
 //===----------------------------------------------------------------------===//
 /// Insertion points methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
index 9c96d354d4fc9..357f2ae5a4418 100644
--- a/mlir/include/mlir/CAPI/Rewrite.h
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -26,6 +26,7 @@ DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
 DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet,
                          mlir::FrozenRewritePatternSet)
 DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter)
+DEFINE_C_API_PTR_METHODS(MlirRewriterBaseListener, mlir::RewriterBase::Listener)
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7b1710656243a..66d008445227b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -810,11 +810,11 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
 }

 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
-                                nb::object insertionPoint,
-                                nb::object location) {
+                                nb::object insertionPoint, nb::object location,
+                                nb::object listener) {
   auto &stack = getStack();
   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
-                     std::move(location));
+                     std::move(location), std::move(listener));
   // If the new stack has more than one entry and the context of the new top
   // entry matches the previous, copy the insertionPoint and location from the
   // previous entry if missing from the new top entry.
@@ -827,6 +827,8 @@ void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
         current.insertionPoint = prev.insertionPoint;
       if (!current.location)
         current.location = prev.location;
+      if (!current.listener)
+        current.listener = prev.listener;
     }
   }
 }
@@ -849,6 +851,12 @@ PyLocation *PyThreadContextEntry::getLocation() {
   return nb::cast<PyLocation *>(location);
 }
 
+PyRewriterBaseListener *PyThreadContextEntry::getListener() {
+  if (!listener)
+    return nullptr;
+  return nb::cast<PyRewriterBaseListener *>(listener);
+}
+
 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
   auto *tos = getTopOfStack();
   return tos ? tos->getContext() : nullptr;
@@ -864,10 +872,16 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() {
   return tos ? tos->getLocation() : nullptr;
 }
 
+PyRewriterBaseListener *PyThreadContextEntry::getDefaultListener() {
+  auto *tos = getTopOfStack();
+  return tos ? tos->getListener() : nullptr;
+}
+
 nb::object PyThreadContextEntry::pushContext(nb::object context) {
   push(FrameKind::Context, /*context=*/context,
        /*insertionPoint=*/nb::object(),
-       /*location=*/nb::object());
+       /*location=*/nb::object(),
+       /*listener=*/nb::object());
   return context;
 }
 
@@ -890,7 +904,8 @@ PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
   push(FrameKind::InsertionPoint,
        /*context=*/contextObj,
        /*insertionPoint=*/insertionPointObj,
-       /*location=*/nb::object());
+       /*location=*/nb::object(),
+       /*listener=*/nb::object());
   return insertionPointObj;
 }
 
@@ -910,7 +925,8 @@ nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
   nb::object contextObj = location.getContext().getObject();
   push(FrameKind::Location, /*context=*/contextObj,
        /*insertionPoint=*/nb::object(),
-       /*location=*/locationObj);
+       /*location=*/locationObj,
+       /*listener=*/nb::object());
   return locationObj;
 }
 
@@ -924,6 +940,27 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
   stack.pop_back();
 }
 
+nb::object PyThreadContextEntry::pushListener(nb::object listenerObj) {
+  PyRewriterBaseListener &listener =
+      nb::cast<PyRewriterBaseListener &>(listenerObj);
+  nb::object contextObj = listener.getContext().getObject();
+  push(FrameKind::Listener, /*context=*/contextObj,
+       /*insertionPoint=*/nb::object(),
+       /*location=*/nb::object(),
+       /*listener=*/listenerObj);
+  return listenerObj;
+}
+
+void PyThreadContextEntry::popListener(PyRewriterBaseListener &listener) {
+  auto &stack = getStack();
+  if (stack.empty())
+    throw std::runtime_error("Unbalanced Listener enter/exit");
+  auto &tos = stack.back();
+  if (tos.frameKind != FrameKind::Listener || tos.getListener() != &listener)
+    throw std::runtime_error("Unbalanced Listener enter/exit");
+  stack.pop_back();
+}
+
 //------------------------------------------------------------------------------
 // PyDiagnostic*
 //------------------------------------------------------------------------------
@@ -1417,6 +1454,10 @@ static void maybeInsertOperation(PyOperationRef &op,
     if (ip)
       ip->insert(*op.get());
   }
+  PyRewriterBaseListener *listener = PyThreadContextEntry::getDefaultListener();
+  // Like OpBuilder::insert, only notify about ops that were actually inserted.
+  if (listener && !mlirBlockIsNull(mlirOperationGetBlock(op->get())))
+    listener->notifyOperationInserted(*op.get());
 }
 
 nb::object PyOperation::create(std::string_view name,
@@ -2036,6 +2077,19 @@ PyOpView::PyOpView(const nb::object &operationObject)
     : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
       operationObject(operation.getRef().getObject()) {}
 
+//------------------------------------------------------------------------------
+// PyRewriterBaseListener.
+//------------------------------------------------------------------------------
+
+nb::object PyRewriterBaseListener::contextEnter(nb::object listener) {
+  return PyThreadContextEntry::pushListener(std::move(listener));
+}
+
+void PyRewriterBaseListener::contextExit(nb::handle excType, nb::handle excVal,
+                                         nb::handle excTb) {
+  PyThreadContextEntry::popListener(*this);
+}
+
 //------------------------------------------------------------------------------
 // PyInsertionPoint.
 //------------------------------------------------------------------------------
@@ -3961,6 +4015,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           "Returns the list of Block predecessors.");
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyRewriterBaseListener.
+  //----------------------------------------------------------------------------
+  nb::class_<PyRewriterBaseListener>(m, "RewriterBaseListener")
+      .def("__enter__", &PyRewriterBaseListener::contextEnter)
+      .def("__exit__", &PyRewriterBaseListener::contextExit,
+           nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+           nb::arg("traceback").none());
+
   //----------------------------------------------------------------------------
   // Mapping of PyInsertionPoint.
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index e706be3b4d32a..9985af0a448a6 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -46,6 +46,7 @@ class PyOperationBase;
 class PyType;
 class PySymbolTable;
 class PyValue;
+class PyRewriterBaseListener;
 
 /// Template for a reference to a concrete type which captures a python
 /// reference to its underlying python object.
@@ -115,13 +116,15 @@ class PyThreadContextEntry {
     Context,
     InsertionPoint,
     Location,
+    Listener,
   };
 
   PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
                        nanobind::object insertionPoint,
-                       nanobind::object location)
+                       nanobind::object location, nanobind::object listener)
       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
-        location(std::move(location)), frameKind(frameKind) {}
+        location(std::move(location)), listener(std::move(listener)),
+        frameKind(frameKind) {}
 
   /// Gets the top of stack context and return nullptr if not defined.
   static PyMlirContext *getDefaultContext();
@@ -132,9 +135,13 @@ class PyThreadContextEntry {
   /// Gets the top of stack location and returns nullptr if not defined.
   static PyLocation *getDefaultLocation();
 
+  /// Gets the top of stack listener and returns nullptr if not defined.
+  static PyRewriterBaseListener *getDefaultListener();
+
   PyMlirContext *getContext();
   PyInsertionPoint *getInsertionPoint();
   PyLocation *getLocation();
+  PyRewriterBaseListener *getListener();
   FrameKind getFrameKind() { return frameKind; }
 
   /// Stack management.
@@ -145,13 +152,16 @@ class PyThreadContextEntry {
   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
   static nanobind::object pushLocation(nanobind::object location);
   static void popLocation(PyLocation &location);
+  static nanobind::object pushListener(nanobind::object listener);
+  static void popListener(PyRewriterBaseListener &listener);
 
   /// Gets the thread local stack.
   static std::vector<PyThreadContextEntry> &getStack();
 
 private:
   static void push(FrameKind frameKind, nanobind::object context,
-                   nanobind::object insertionPoint, nanobind::object location);
+                   nanobind::object insertionPoint, nanobind::object location,
+                   nanobind::object listener);
 
   /// An object reference to the PyContext.
   nanobind::object context;
@@ -159,6 +169,8 @@ class PyThreadContextEntry {
   nanobind::object insertionPoint;
   /// An object reference to the current location.
   nanobind::object location;
+  /// An object reference to the current listener.
+  nanobind::object listener;
   // The kind of push that was performed.
   FrameKind frameKind;
 };
@@ -830,6 +842,38 @@ class PyBlock {
   MlirBlock block;
 };
 
+/// Wrapper around a MlirRewriterBaseListener.
+class PyRewriterBaseListener {
+public:
+  PyRewriterBaseListener(MlirRewriterBaseListener listener,
+                         PyMlirContextRef ctx)
+      : listener(listener), ctx(std::move(ctx)) {}
+
+  MlirRewriterBaseListener get() { return listener; }
+
+  /// Notifies the wrapped listener that `op` was inserted. No previous
+  /// insertion point is passed, i.e. `op` is reported as newly created.
+  /// Does nothing if the wrapped listener is null.
+  void notifyOperationInserted(PyOperationBase &op) {
+    if (!listener.ptr)
+      return;
+    mlirRewriterBaseListenerNotifyOperationInserted(get(), op.getOperation(),
+                                                    MlirOperation{nullptr});
+  }
+
+  PyMlirContextRef getContext() { return ctx; }
+
+  /// Context manager support: pushes / pops this listener on the
+  /// thread-local context stack.
+  static nanobind::object contextEnter(nanobind::object listener);
+  void contextExit(nanobind::handle excType, nanobind::handle excVal,
+                   nanobind::handle excTb);
+
+private:
+  MlirRewriterBaseListener listener;
+  PyMlirContextRef ctx;
+};
+
 /// An insertion point maintains a pointer to a Block and a reference operation.
 /// Calls to insert() will insert a new operation before the
 /// reference operation. If the reference operation is null, then appends to
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5ddb3fbbb1317..5512fb2377d60 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -45,6 +45,10 @@ class PyPatternRewriter {
     return PyInsertionPoint(PyOperation::forOperation(ctx, op));
   }
 
+  PyRewriterBaseListener getListener() {
+    return PyRewriterBaseListener(mlirRewriterBaseGetListener(base), ctx);
+  }
+
   void replaceOp(MlirOperation op, MlirOperation newOp) {
     mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
   }
@@ -202,7 +206,23 @@ class PyRewritePatternSet {
           PyMlirContext::forContext(mlirOperationGetContext(op));
       nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
 
-      nb::object res = f(opView, PyPatternRewriter(rewriter));
+      PyPatternRewriter pyRewriter(rewriter);
+      // Make the rewriter's listener the default one while `f` runs, so that
+      // ops created from Python notify it about their insertion.
+      nb::object listener = nb::cast(pyRewriter.getListener());
+      listener.attr("__enter__")();
+      nb::object res;
+      try {
+        res = f(opView, pyRewriter);
+      } catch (...) {
+        // Pop the listener on failure too. This is deliberately not done in a
+        // destructor: `__exit__` may raise, and a throwing destructor would
+        // terminate the process.
+        listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
+        throw;
+      }
+      listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
+
       return logicalResultFromObject(res);
     };
     MlirRewritePattern pattern = mlirOpRewritePattenCreate(
@@ -234,6 +254,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       class_<PyPatternRewriter>(m, "PatternRewriter")
           .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
                        "The current insertion point of the PatternRewriter.")
+          .def_prop_ro("listener", &PyPatternRewriter::getListener,
+                       "The rewrite listener of the PatternRewriter.")
           .def(
               "replace_op",
               [](PyPatternRewriter &self, MlirOperation op,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 46c329d8433b4..41b48bdfe5d6b 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -29,6 +29,25 @@ MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
   return wrap(unwrap(rewriter)->getContext());
 }
 
+MlirRewriterBaseListener
+mlirRewriterBaseGetListener(MlirRewriterBase rewriter) {
+  // Null if the rewriter has no listener or it is not a RewriterBase one.
+  return wrap(llvm::dyn_cast_if_present<RewriterBase::Listener>(
+      unwrap(rewriter)->getListener()));
+}
+
+void mlirRewriterBaseListenerNotifyOperationInserted(
+    MlirRewriterBaseListener listener, MlirOperation op,
+    MlirOperation insertionPoint) {
+  OpBuilder::InsertPoint ip;
+  if (!mlirOperationIsNull(insertionPoint))
+    ip = OpBuilder::InsertPoint(unwrap(insertionPoint)->getBlock(),
+                                Block::iterator(unwrap(insertionPoint)));
+  // A null listener handle (see mlirRewriterBaseGetListener) is a no-op.
+  if (RewriterBase::Listener *l = unwrap(listener))
+    l->notifyOperationInserted(unwrap(op), ip);
+}
+
 //===----------------------------------------------------------------------===//
 /// Insertion points methods
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/163694


More information about the Mlir-commits mailing list