[Mlir-commits] [mlir] [MLIR][Python] Call `notifyOperationInserted` while constructing new op in rewrite patterns (PR #163694)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 16 10:15:13 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
In MLIR, both `PatternRewriter` and `OpBuilder` typically have corresponding *Listeners* that monitor changes to the IR. For example, when a new operation is created, the listener’s `notifyOperationInserted` method is invoked to inform the rewriter of this change. (This is also why, in C++ rewrite patterns, one should use `PatternRewriter::create` instead of `OpTy::create`, since the latter does not trigger `notifyOperationInserted`.)
However, in Python, the listener methods are currently not invoked when executing rewrite patterns. While this may not always affect the outcome, it violates the semantics defined in MLIR’s C++ implementation. At present, rewrite patterns in Python directly construct new operations using the TableGen-generated op constructors, such as `arith.addi(lhs, rhs)`.
Although we could introduce an API like `rewriter.create("arith.addi", operands=[lhs, rhs], results=...)`, it would be less intuitive, and users might still bypass it and instantiate ops directly, which would again skip listener notifications.
In this PR, we adopt an approach similar to how *context*, *location*, and *insertion point* are managed: we maintain a stack of listeners. When `arith.addi` or `Operation.create` is called, the top listener on the stack is retrieved, and `notifyOperationInserted` is invoked automatically. This allows users to construct operations in the usual way, while ensuring that listeners are properly notified.
---
Full diff: https://github.com/llvm/llvm-project/pull/163694.diff
6 Files Affected:
- (modified) mlir/include/mlir-c/Rewrite.h (+10)
- (modified) mlir/include/mlir/CAPI/Rewrite.h (+1)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+69-6)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+39-3)
- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+16-1)
- (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+17)
``````````diff
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 2db1d84cd1d89..f42aaff4d3c19 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -39,6 +39,7 @@ DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
+DEFINE_C_API_STRUCT(MlirRewriterBaseListener, void);
//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
@@ -48,6 +49,15 @@ DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
MLIR_CAPI_EXPORTED MlirContext
mlirRewriterBaseGetContext(MlirRewriterBase rewriter);
+/// Get the listener of the rewriter.
+/// The returned handle has a null `ptr` when the rewriter's listener is not a
+/// RewriterBase::Listener; check it before passing it to other functions.
+MLIR_CAPI_EXPORTED MlirRewriterBaseListener
+mlirRewriterBaseGetListener(MlirRewriterBase rewriter);
+
+/// Notify the listener that the specified operation was inserted.
+/// `listener` must be non-null. If `insertionPoint` is a non-null operation,
+/// the listener receives an insert point located immediately before it;
+/// otherwise (as used for newly created operations) it receives an unset
+/// insert point.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseListenerNotifyOperationInserted(
+ MlirRewriterBaseListener listener, MlirOperation op,
+ MlirOperation insertionPoint);
+
//===----------------------------------------------------------------------===//
/// Insertion points methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
index 9c96d354d4fc9..357f2ae5a4418 100644
--- a/mlir/include/mlir/CAPI/Rewrite.h
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -26,6 +26,7 @@ DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet,
mlir::FrozenRewritePatternSet)
DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter)
+DEFINE_C_API_PTR_METHODS(MlirRewriterBaseListener, mlir::RewriterBase::Listener)
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7b1710656243a..66d008445227b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -810,11 +810,11 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
}
void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
- nb::object insertionPoint,
- nb::object location) {
+ nb::object insertionPoint, nb::object location,
+ nb::object listener) {
auto &stack = getStack();
stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
- std::move(location));
+ std::move(location), std::move(listener));
// If the new stack has more than one entry and the context of the new top
// entry matches the previous, copy the insertionPoint and location from the
// previous entry if missing from the new top entry.
@@ -827,6 +827,8 @@ void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
current.insertionPoint = prev.insertionPoint;
if (!current.location)
current.location = prev.location;
+ // Likewise inherit the listener from the previous frame when missing.
+ if (!current.listener)
+ current.listener = prev.listener;
}
}
}
@@ -849,6 +851,12 @@ PyLocation *PyThreadContextEntry::getLocation() {
return nb::cast<PyLocation *>(location);
}
+// Returns this frame's listener (set directly or inherited in push()), or
+// nullptr if the frame has none.
+PyRewriterBaseListener *PyThreadContextEntry::getListener() {
+ if (!listener)
+ return nullptr;
+ return nb::cast<PyRewriterBaseListener *>(listener);
+}
+
PyMlirContext *PyThreadContextEntry::getDefaultContext() {
auto *tos = getTopOfStack();
return tos ? tos->getContext() : nullptr;
@@ -864,10 +872,16 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() {
return tos ? tos->getLocation() : nullptr;
}
+// Returns the listener of the top-of-stack frame, or nullptr if the stack is
+// empty or that frame has no listener.
+PyRewriterBaseListener *PyThreadContextEntry::getDefaultListener() {
+ auto *tos = getTopOfStack();
+ return tos ? tos->getListener() : nullptr;
+}
+
nb::object PyThreadContextEntry::pushContext(nb::object context) {
push(FrameKind::Context, /*context=*/context,
/*insertionPoint=*/nb::object(),
- /*location=*/nb::object());
+ /*location=*/nb::object(),
+ /*listener=*/nb::object());
return context;
}
@@ -890,7 +904,8 @@ PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
push(FrameKind::InsertionPoint,
/*context=*/contextObj,
/*insertionPoint=*/insertionPointObj,
- /*location=*/nb::object());
+ /*location=*/nb::object(),
+ /*listener=*/nb::object());
return insertionPointObj;
}
@@ -910,7 +925,8 @@ nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
nb::object contextObj = location.getContext().getObject();
push(FrameKind::Location, /*context=*/contextObj,
/*insertionPoint=*/nb::object(),
- /*location=*/locationObj);
+ /*location=*/locationObj,
+ /*listener=*/nb::object());
return locationObj;
}
@@ -924,6 +940,33 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
stack.pop_back();
}
+// Pushes a Listener frame for `listenerObj`, which must wrap a
+// PyRewriterBaseListener; the frame's context is the listener's context.
+nb::object PyThreadContextEntry::pushListener(nb::object listenerObj) {
+  PyRewriterBaseListener &listener =
+      nb::cast<PyRewriterBaseListener &>(listenerObj);
+  nb::object contextObj = listener.getContext().getObject();
+  push(FrameKind::Listener, /*context=*/contextObj,
+       /*insertionPoint=*/nb::object(),
+       /*location=*/nb::object(),
+       /*listener=*/listenerObj);
+  return listenerObj;
+}
+
+// Pops the frame pushed by pushListener(listener). Throws unless the top of
+// the stack is exactly that Listener frame, so that a frame left behind by an
+// unexited nested context manager (which inherits the same listener) is
+// reported rather than silently popped in its place.
+void PyThreadContextEntry::popListener(PyRewriterBaseListener &listener) {
+  auto &stack = getStack();
+  if (stack.empty())
+    throw std::runtime_error("Unbalanced Listener enter/exit");
+  auto &tos = stack.back();
+  if (tos.frameKind != FrameKind::Listener || tos.getListener() != &listener)
+    throw std::runtime_error("Unbalanced Listener enter/exit");
+  stack.pop_back();
+}
+
//------------------------------------------------------------------------------
// PyDiagnostic*
//------------------------------------------------------------------------------
@@ -1417,6 +1454,12 @@ static void maybeInsertOperation(PyOperationRef &op,
- if (ip)
- ip->insert(*op.get());
+    if (ip) {
+      ip->insert(*op.get());
+      // Only notify once the op has actually been inserted into a block; a
+      // detached op has no position to report (C++ OpBuilder behaves alike).
+      if (PyRewriterBaseListener *listener =
+              PyThreadContextEntry::getDefaultListener())
+        listener->notifyOperationInserted(*op.get());
+    }
}
}
nb::object PyOperation::create(std::string_view name,
@@ -2036,6 +2077,19 @@ PyOpView::PyOpView(const nb::object &operationObject)
: operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
operationObject(operation.getRef().getObject()) {}
+//------------------------------------------------------------------------------
+// PyRewriterBaseListener.
+//------------------------------------------------------------------------------
+
+// Implements Python `__enter__`: pushes `listener` onto the thread context
+// stack and returns it.
+nb::object PyRewriterBaseListener::contextEnter(nb::object listener) {
+ return PyThreadContextEntry::pushListener(std::move(listener));
+}
+
+// Implements Python `__exit__`: pops this listener. The exception arguments
+// are ignored and nothing is returned, so an in-flight exception propagates.
+void PyRewriterBaseListener::contextExit(nb::handle excType, nb::handle excVal,
+ nb::handle excTb) {
+ PyThreadContextEntry::popListener(*this);
+}
+
//------------------------------------------------------------------------------
// PyInsertionPoint.
//------------------------------------------------------------------------------
@@ -3961,6 +4015,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"Returns the list of Block predecessors.");
+ //----------------------------------------------------------------------------
+ // Mapping of PyRewriterBaseListener.
+ //----------------------------------------------------------------------------
+ // `with listener:` makes it the thread's default listener, which op
+ // construction notifies (see maybeInsertOperation).
+ nb::class_<PyRewriterBaseListener>(m, "RewriterBaseListener")
+ .def("__enter__", &PyRewriterBaseListener::contextEnter)
+ .def("__exit__", &PyRewriterBaseListener::contextExit,
+ nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+ nb::arg("traceback").none());
+
//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
//----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index e706be3b4d32a..9985af0a448a6 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -46,6 +46,7 @@ class PyOperationBase;
class PyType;
class PySymbolTable;
class PyValue;
+class PyRewriterBaseListener;
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
@@ -115,13 +116,15 @@ class PyThreadContextEntry {
Context,
InsertionPoint,
Location,
+ Listener,
};
PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
nanobind::object insertionPoint,
- nanobind::object location)
+ nanobind::object location, nanobind::object listener)
: context(std::move(context)), insertionPoint(std::move(insertionPoint)),
- location(std::move(location)), frameKind(frameKind) {}
+ location(std::move(location)), listener(std::move(listener)),
+ frameKind(frameKind) {}
/// Gets the top of stack context and return nullptr if not defined.
static PyMlirContext *getDefaultContext();
@@ -132,9 +135,12 @@ class PyThreadContextEntry {
/// Gets the top of stack location and returns nullptr if not defined.
static PyLocation *getDefaultLocation();
+ /// Gets the top of stack listener and returns nullptr if not defined.
+ static PyRewriterBaseListener *getDefaultListener();
+
PyMlirContext *getContext();
PyInsertionPoint *getInsertionPoint();
PyLocation *getLocation();
+ PyRewriterBaseListener *getListener();
FrameKind getFrameKind() { return frameKind; }
/// Stack management.
@@ -145,13 +151,16 @@ class PyThreadContextEntry {
static void popInsertionPoint(PyInsertionPoint &insertionPoint);
static nanobind::object pushLocation(nanobind::object location);
static void popLocation(PyLocation &location);
+ static nanobind::object pushListener(nanobind::object listener);
+ static void popListener(PyRewriterBaseListener &listener);
/// Gets the thread local stack.
static std::vector<PyThreadContextEntry> &getStack();
private:
static void push(FrameKind frameKind, nanobind::object context,
- nanobind::object insertionPoint, nanobind::object location);
+ nanobind::object insertionPoint, nanobind::object location,
+ nanobind::object listener);
/// An object reference to the PyContext.
nanobind::object context;
@@ -159,6 +168,8 @@ class PyThreadContextEntry {
nanobind::object insertionPoint;
/// An object reference to the current location.
nanobind::object location;
+ /// An object reference to the current listener.
+ nanobind::object listener;
// The kind of push that was performed.
FrameKind frameKind;
};
@@ -830,6 +841,39 @@ class PyBlock {
MlirBlock block;
};
+/// Wrapper around a MlirRewriterBaseListener. A null listener (e.g. when the
+/// rewriter's listener is not a RewriterBase::Listener) is tolerated and
+/// simply receives no notifications.
+class PyRewriterBaseListener {
+public:
+  PyRewriterBaseListener(MlirRewriterBaseListener listener,
+                         PyMlirContextRef ctx)
+      : listener(listener), ctx(std::move(ctx)) {}
+
+  MlirRewriterBaseListener get() { return listener; }
+
+  /// Notifies the wrapped listener that `op` was inserted, passing a null
+  /// operation as the insertion-point argument. No-op for a null listener.
+  void notifyOperationInserted(PyOperationBase &op) {
+    if (!listener.ptr)
+      return;
+    mlirRewriterBaseListenerNotifyOperationInserted(get(), op.getOperation(),
+                                                    MlirOperation{nullptr});
+  }
+
+  PyMlirContextRef getContext() { return ctx; }
+
+  /// Python context-manager protocol (`__enter__` / `__exit__`): pushes and
+  /// pops this listener on the thread context stack.
+  static nanobind::object contextEnter(nanobind::object listener);
+  void contextExit(nanobind::handle excType, nanobind::handle excVal,
+                   nanobind::handle excTb);
+
+private:
+  MlirRewriterBaseListener listener;
+  PyMlirContextRef ctx;
+};
+
/// An insertion point maintains a pointer to a Block and a reference operation.
/// Calls to insert() will insert a new operation before the
/// reference operation. If the reference operation is null, then appends to
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5ddb3fbbb1317..5512fb2377d60 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -18,6 +18,7 @@
// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"
+#include "llvm/ADT/ScopeExit.h"
namespace nb = nanobind;
using namespace mlir;
@@ -45,6 +46,10 @@ class PyPatternRewriter {
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
+ /// Wraps this rewriter's listener as returned by mlirRewriterBaseGetListener;
+ /// the wrapped handle may be null.
+ PyRewriterBaseListener getListener() {
+ return PyRewriterBaseListener(mlirRewriterBaseGetListener(base), ctx);
+ }
+
void replaceOp(MlirOperation op, MlirOperation newOp) {
mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
}
@@ -202,7 +207,31 @@ class PyRewritePatternSet {
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
- nb::object res = f(opView, PyPatternRewriter(rewriter));
+      PyPatternRewriter pyRewriter(rewriter);
+      PyRewriterBaseListener pyListener = pyRewriter.getListener();
+
+      nb::object res;
+      if (!pyListener.get().ptr) {
+        // No RewriterBase listener is attached, so there is nothing to notify.
+        res = f(opView, pyRewriter);
+      } else {
+        // Make the rewriter's listener the default one while the Python
+        // pattern runs, so ops it constructs are reported to the rewriter.
+        nb::object listener = nb::cast(std::move(pyListener));
+        listener.attr("__enter__")();
+        try {
+          res = f(opView, pyRewriter);
+        } catch (...) {
+          // Pop the listener frame before propagating. This is deliberately
+          // not done from a scope_exit destructor: __exit__ can throw (e.g.
+          // on an unbalanced stack), and throwing from a destructor calls
+          // std::terminate.
+          listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
+          throw;
+        }
+        listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
+      }
+
return logicalResultFromObject(res);
};
MlirRewritePattern pattern = mlirOpRewritePattenCreate(
@@ -234,6 +263,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
"The current insertion point of the PatternRewriter.")
+      // NOTE(review): the wrapper holds a raw listener pointer; presumably it
+      // must not be used after the pattern callback returns — confirm.
+      .def_prop_ro("listener", &PyPatternRewriter::getListener,
+                   "The rewrite listener of the PatternRewriter.")
.def(
"replace_op",
[](PyPatternRewriter &self, MlirOperation op,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 46c329d8433b4..41b48bdfe5d6b 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -29,6 +29,28 @@ MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
return wrap(unwrap(rewriter)->getContext());
}
+// Returns a null handle when the rewriter has no listener or its listener is
+// not a RewriterBase::Listener. dyn_cast_if_present is required because plain
+// dyn_cast does not accept a null input.
+MlirRewriterBaseListener
+mlirRewriterBaseGetListener(MlirRewriterBase rewriter) {
+  return wrap(dyn_cast_if_present<RewriterBase::Listener>(
+      unwrap(rewriter)->getListener()));
+}
+
+// A non-null `insertionPoint` is converted into an insert point located just
+// before that operation; a null one yields an unset insert point.
+void mlirRewriterBaseListenerNotifyOperationInserted(
+    MlirRewriterBaseListener listener, MlirOperation op,
+    MlirOperation insertionPoint) {
+  OpBuilder::InsertPoint ip;
+  if (!mlirOperationIsNull(insertionPoint)) {
+    Operation *ipOp = unwrap(insertionPoint);
+    ip = OpBuilder::InsertPoint(ipOp->getBlock(), Block::iterator(ipOp));
+  }
+  unwrap(listener)->notifyOperationInserted(unwrap(op), ip);
+}
+
//===----------------------------------------------------------------------===//
/// Insertion points methods
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/163694
More information about the Mlir-commits mailing list