[Mlir-commits] [mlir] [MLIR][Python] Support Python-defined passes in MLIR (PR #156000)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 29 02:57:19 PDT 2025


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/156000

It tries to close #155996.

This PR exports a class `mlir.passmanager.Pass` that Python code can subclass to define new MLIR passes.

This is a simple example of a Python-defined pass.
```python
from mlir.passmanager import Pass, PassManager

class DemoPass(Pass):
  def run(self, op):
    # do something with op
    pass

pm = PassManager(context=ctx)
pm.add(DemoPass())
pm.run(..)
```

TODO list:
- [ ] Python interface stub files
- [ ] tests for this change
- [ ] interop with PDL rewriting
- [ ] support for cloning passes? (not sure)

>From 8386c87c431585c4412a52d07287b7423e34f602 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 29 Aug 2025 17:43:04 +0800
Subject: [PATCH] [MLIR][Python] Support Python-defined passes in MLIR

---
 mlir/lib/Bindings/Python/MainModule.cpp |  1 +
 mlir/lib/Bindings/Python/Pass.cpp       | 80 ++++++++++++++++++++++++-
 mlir/lib/Bindings/Python/Pass.h         |  1 +
 3 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..590e862a8d358 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
+  populatePassSubmodule(passModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..4aa93df938295 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,9 +9,11 @@
 #include "Pass.h"
 
 #include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/Pass.h"
 #include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+#include "nanobind/trampoline.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -20,6 +22,93 @@ using namespace mlir::python;

 namespace {

+// A base class for defining passes in Python.
+// Users are expected to subclass this and implement the `run` method, e.g.
+// ```
+// class MyPass(mlir.passmanager.Pass):
+//   def run(self, operation):
+//     # do something with operation
+//     pass
+// ```
+// Each instance is handed to MLIR as an "external pass" (see make()); MLIR
+// then drives it through the C callbacks installed by the constructor.
+class PyPassBase {
+public:
+  PyPassBase() : callbacks{} {
+    // The lifetime of `*this` is managed from the Python side (see make()),
+    // so there is nothing to set up or tear down when MLIR creates or
+    // destroys its wrapper pass.
+    callbacks.construct = [](void *) {};
+    callbacks.destruct = [](void *) {};
+    callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *obj) {
+      // This is called from MLIR C++ code, which must not be unwound by C++
+      // exceptions. Any error raised by the Python override (or by nanobind
+      // while dispatching to it) is reported via Python's unraisable-exception
+      // hook and turned into a pass failure instead.
+      try {
+        static_cast<PyPassBase *>(obj)->run(op);
+      } catch (nb::python_error &e) {
+        e.discard_as_unraisable("running a Python-defined pass");
+        mlirExternalPassSignalFailure(pass);
+      } catch (const std::exception &e) {
+        PyErr_SetString(PyExc_RuntimeError, e.what());
+        PyErr_WriteUnraisable(nullptr);
+        mlirExternalPassSignalFailure(pass);
+      }
+    };
+    // TODO: currently we don't support pass cloning in python
+    // due to lifetime management issues.
+    callbacks.clone = [](void *) -> void * {
+      // The caller is MLIR C++ code, so we must not throw (e.g.
+      // nb::value_error(...)). llvm_unreachable is not an option either: this
+      // path is reachable whenever MLIR clones the pass, and reaching
+      // llvm_unreachable is undefined behavior in release builds.
+      llvm::report_fatal_error(
+          "cloning of python-defined passes is not supported");
+    };
+  }
+
+  // Runs the pass on `op`. This method must be overridden by subclasses in
+  // Python (dispatch goes through the PyPass trampoline).
+  virtual void run(MlirOperation op) = 0;
+
+  virtual ~PyPassBase() = default;
+
+  // Make an MlirPass instance on-the-fly that wraps this object.
+  // Note that the pass manager takes ownership of the returned pass and
+  // releases it when appropriate, while `*this` must remain alive as long as
+  // the returned pass is alive.
+  // The pass TypeID is derived from this object's address, so every instance
+  // gets a distinct TypeID.
+  // TODO: all Python-defined passes currently share the same name and
+  // description; consider deriving them from the Python subclass.
+  MlirPass make() {
+    return mlirCreateExternalPass(
+        /*passID=*/mlirTypeIDCreate(this),
+        /*name=*/mlirStringRefCreateFromCString("python-example-pass"),
+        /*argument=*/mlirStringRefCreateFromCString(""),
+        /*description=*/mlirStringRefCreateFromCString("Python Example Pass"),
+        /*opName=*/mlirStringRefCreateFromCString(""),
+        /*nDependentDialects=*/0, /*dependentDialects=*/nullptr, callbacks,
+        /*userData=*/this);
+  }
+
+private:
+  // C callbacks through which MLIR drives this pass; filled in by the
+  // constructor.
+  MlirExternalPassCallbacks callbacks;
+};
+
+// nanobind trampoline that dispatches PyPassBase::run to a Python override.
+// Inheritance must be public so nanobind can upcast PyPass to PyPassBase; see
+// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
+class PyPass : public PyPassBase {
+public:
+  NB_TRAMPOLINE(PyPassBase, 1); // One overridable virtual method: run.
+
+  void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
+};
+
 /// Owning Wrapper around a PassManager.
 class PyPassManager {
 public:
@@ -52,6 +111,20 @@ class PyPassManager {

 } // namespace

+/// Register the Python-subclassable `Pass` base class into `m`.
+void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the Python-defined Pass interface
+  //----------------------------------------------------------------------------
+  nb::class_<PyPassBase, PyPass>(
+      m, "Pass",
+      "Base class for passes implemented in Python. Subclass it, override "
+      "`run(self, operation)`, and add instances with `PassManager.add`.")
+      .def(nb::init<>(), "Create a new Pass.")
+      .def("run", &PyPassBase::run, "operation"_a,
+           "Run the pass on the provided operation.");
+}
+
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
   //----------------------------------------------------------------------------
@@ -157,6 +226,15 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
+      .def(
+          "add",
+          [](PyPassManager &passManager, PyPassBase &pass) {
+            mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
+          },
+          "pass"_a, "Add a python-defined pass to the pass manager.",
+          // NOTE that we should keep the pass object alive as long as the
+          // passManager to prevent dangling objects.
+          nb::keep_alive<1, 2>())
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op,
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
index bc40943521829..ba3fbb707fed7 100644
--- a/mlir/lib/Bindings/Python/Pass.h
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -15,6 +15,7 @@ namespace mlir {
 namespace python {

 void populatePassManagerSubmodule(nanobind::module_ &m);
+void populatePassSubmodule(nanobind::module_ &m); // Registers the Pass class.

 } // namespace python
 } // namespace mlir



More information about the Mlir-commits mailing list