[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined passes in MLIR (PR #156000)

Maksim Levental llvmlistbot at llvm.org
Mon Sep 8 06:18:10 PDT 2025


================
@@ -20,6 +21,79 @@ using namespace mlir::python;
 
 namespace {
 
+// A base class for defining passes in Python.
+// Users are expected to subclass this and implement the `run` method, e.g.
+// ```
+// class MyPass(mlir.passmanager.Pass):
+//   def run(self, operation):
+//     # do something with operation
+//     pass
+// ```
+class PyPassBase {
+public:
+  PyPassBase(std::string name, std::string argument, std::string description,
+             std::string opName)
+      : name(std::move(name)), argument(std::move(argument)),
+        description(std::move(description)), opName(std::move(opName)) {
+    callbacks.construct = [](void *obj) {};
+    callbacks.destruct = [](void *obj) {
+      nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+    };
+    callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
+      auto handle = nb::handle(static_cast<PyObject *>(obj));
+      nb::cast<PyPassBase *>(handle)->run(op);
+    };
+    callbacks.clone = [](void *obj) -> void * {
+      nb::object copy = nb::module_::import_("copy");
+      nb::object deepcopy = copy.attr("deepcopy");
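+      // Wrap the raw pointer in a handle so that deepcopy receives the
+      // Python object itself rather than an opaque void pointer.
+      nb::handle self(static_cast<PyObject *>(obj));
+      return deepcopy(self).release().ptr();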
+    };
+    callbacks.initialize = nullptr;
+  }
+
+  // This method should be overridden by subclasses in Python.
+  virtual void run(MlirOperation op) = 0;
+
+  virtual ~PyPassBase() = default;
+
+  // Make an MlirPass instance on the fly that wraps this object.
+  // Note that the pass manager takes ownership of the returned
+  // object and releases it when appropriate.
+  // Also, `*this` must remain alive as long as the returned object is alive.
+  MlirPass make() {
+    auto *obj = nb::find(this).release().ptr();
+    return mlirCreateExternalPass(
+        mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
+        mlirStringRefCreate(argument.data(), argument.length()),
+        mlirStringRefCreate(description.data(), description.length()),
+        mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
----------------
makslevental wrote:

> then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass

Yes, that's exactly the follow-up I had in mind.
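The follow-up described above, where the pass metadata is passed as arguments to the registration API so that the only per-pass state kept alive is the callback, can be sketched in plain C++ without MLIR. This is a minimal sketch under stated assumptions: `CallbackPassRegistry`, `PassInfo`, `OperationStub`, and `demoRegistry` are hypothetical names (not MLIR or PR APIs), and `std::function` stands in for the Python callable.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for an MLIR operation handle.
struct OperationStub {
  std::string name;
};

// The metadata that the hunk currently stores on the PyPassBase object.
struct PassInfo {
  std::string name, argument, description, opName;
};

// Hypothetical registry: metadata arrives as arguments at registration time,
// and the only thing kept alive per pass is the run callback itself.
class CallbackPassRegistry {
public:
  using RunFn = std::function<void(OperationStub &)>;

  void registerPass(PassInfo info, RunFn run) {
    std::string key = info.argument; // copy the key before moving `info`
    entries[key] = Entry{std::move(info), std::move(run)};
  }

  // Runs the pass registered under `argument`; returns false if unknown.
  bool runPass(const std::string &argument, OperationStub &op) const {
    auto it = entries.find(argument);
    if (it == entries.end())
      return false;
    it->second.run(op);
    return true;
  }

  // Looks up the metadata registered under `argument`, or nullptr.
  const PassInfo *info(const std::string &argument) const {
    auto it = entries.find(argument);
    return it == entries.end() ? nullptr : &it->second.info;
  }

private:
  struct Entry {
    PassInfo info;
    RunFn run;
  };
  std::map<std::string, Entry> entries;
};

// Registers one hypothetical pass and returns how often its callback ran.
int demoRegistry() {
  CallbackPassRegistry registry;
  int calls = 0;
  registry.registerPass(
      {"MyPass", "my-pass", "An example pass", "builtin.module"},
      [&calls](OperationStub &) { ++calls; });
  OperationStub module{"builtin.module"};
  registry.runPass("my-pass", module);      // invokes the callback
  registry.runPass("unknown-pass", module); // not registered: no call
  return calls;
}
```

In this shape the callable carries no name, argument, or description of its own, which is what makes it possible to hand a bare function to the registration API.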

> The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.

It's the lifetime management that is the issue: the C++ APIs take [ownership of the `Pass` object](https://github.com/llvm/llvm-project/blob/968b50b347f028c02b55657e64c27f720d0f9a20/mlir/lib/CAPI/IR/Pass.cpp#L93), but there's simply no way to express "unique ownership" in Python. That's why I rewrote @PragmaTwice's original PR (which isn't very different from what you propose) so that it only manages the lifetime of a single Python object, the `run` callback.
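The ownership mismatch can be illustrated without MLIR. In the sketch below, a uniquely owned C++ wrapper (playing the role of the pass the pass manager owns) holds exactly one reference to a shared, refcounted object and drops it in its destructor, while cloning produces an independent object, much like the `deepcopy`-based `clone` callback in the hunk. All names are hypothetical: `RefCounted` stands in for a Python object (it is not the CPython API), and `Callbacks` and `ExternalPassSketch` only mirror the shape of the callback table passed to `mlirCreateExternalPass`; they are not MLIR types.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a refcounted Python object (not the CPython API).
struct RefCounted {
  int refcount = 1; // references currently held
  int runs = 0;     // how many times the run callback fired on this object
};

// Same shape as the callback table in the hunk: plain function pointers
// that receive an opaque userData pointer.
struct Callbacks {
  void (*destruct)(void *userData);
  void (*run)(void *userData);
  void *(*clone)(void *userData);
};

// Hypothetical uniquely owned wrapper. It holds exactly one reference to the
// shared object and gives it back in its destructor.
class ExternalPassSketch {
public:
  ExternalPassSketch(Callbacks callbacks, void *data)
      : cbs(callbacks), userData(data) {}
  ~ExternalPassSketch() { cbs.destruct(userData); }
  ExternalPassSketch(const ExternalPassSketch &) = delete;
  ExternalPassSketch &operator=(const ExternalPassSketch &) = delete;

  void run() { cbs.run(userData); }

  // Cloning asks the callbacks for a fresh object with its own reference.
  std::unique_ptr<ExternalPassSketch> clone() const {
    return std::make_unique<ExternalPassSketch>(cbs, cbs.clone(userData));
  }

private:
  Callbacks cbs;
  void *userData;
};

Callbacks makeCallbacks() {
  Callbacks cbs{};
  cbs.destruct = [](void *p) {
    auto *obj = static_cast<RefCounted *>(p);
    if (--obj->refcount == 0)
      delete obj;
  };
  cbs.run = [](void *p) { ++static_cast<RefCounted *>(p)->runs; };
  cbs.clone = [](void *p) -> void * {
    // Analogous to copy.deepcopy: copy the state, start with one reference.
    auto *copy = new RefCounted(*static_cast<RefCounted *>(p));
    copy->refcount = 1;
    return copy;
  };
  return cbs;
}

struct DemoResult {
  int refcountAfter; // refcount of the original once all wrappers are gone
  int originalRuns;  // runs recorded on the original object
};

DemoResult demoOwnership() {
  auto *obj = new RefCounted(); // the "Python side" holds one reference
  ++obj->refcount;              // a second reference handed to the wrapper
  {
    auto pass = std::make_unique<ExternalPassSketch>(makeCallbacks(), obj);
    pass->run();               // runs on the original object
    auto copy = pass->clone(); // independent object, independent lifetime
    copy->run();               // runs on the copy only
  } // the copy is freed; the wrapper drops its reference to the original
  DemoResult result{obj->refcount, obj->runs};
  delete obj; // the "Python side" releases its own reference
  return result;
}
```

The C++ owner never needs to know about refcounting: it calls `destruct` exactly once per wrapper, which is what lets a single shared object sit behind a uniquely owned pass.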

https://github.com/llvm/llvm-project/pull/156000

