[Mlir-commits] [mlir] [MLIR][Python] Add a function to register python-defined passes (PR #157850)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 10 05:38:20 PDT 2025


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/157850

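This PR adds a new function, `register_pass`, to the MLIR Python bindings. It registers a Python-defined pass with MLIR's global pass registry (the Python counterpart of `mlir::registerPass` in C++), so the pass can then be referenced by name in textual pass pipelines alongside C++-defined passes.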

An example:
```python
def demo_pass(op, pass_):
    pass  # inspect or transform `op`; call pass_.signal_pass_failure() on error

register_pass("my-python-demo-pass", demo_pass)

pm = PassManager("any")
pm.add("my-python-demo-pass, some-cpp-defined-pass")
pm.run(module)
```

From 09870a61b110e4f3a37a513f7d0fc4e688519583 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 10 Sep 2025 20:27:17 +0800
Subject: [PATCH] [MLIR][Python] Add a function to register python-defined
 passes

---
 mlir/include/mlir-c/Pass.h        |  7 ++++
 mlir/lib/Bindings/Python/Pass.cpp | 61 +++++++++++++++++++++++--------
 mlir/lib/CAPI/IR/Pass.cpp         | 26 +++++++++++++
 mlir/test/python/python_pass.py   | 44 +++++++++++++++++++++-
 4 files changed, 121 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 0d2e19ee7fb0a..1328b4de5d4eb 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -184,6 +184,22 @@ MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
     intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
     MlirExternalPassCallbacks callbacks, void *userData);
 
+/// Registers an external pass with the global pass registry so that it can be
+/// referenced by `argument` in textual pass pipelines. Each time the pass is
+/// instantiated from the registry, a new external pass is created with the
+/// given `passID`, `name`, `argument`, `description`, `opName`, dependent
+/// dialects, `callbacks` and `userData`. The strings and the dialect handle
+/// array are copied, so the caller may release them once this returns.
+/// `callbacks` and `userData` are shared by every instance, so `userData` must
+/// stay valid for as long as the pass can be instantiated. An empty `opName`
+/// creates the pass without an operation name.
+MLIR_CAPI_EXPORTED void
+mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+                         MlirStringRef argument, MlirStringRef description,
+                         MlirStringRef opName, intptr_t nDependentDialects,
+                         MlirDialectHandle *dependentDialects,
+                         MlirExternalPassCallbacks callbacks, void *userData);
+
 /// This signals that the pass has failed. This is only valid to call during
 /// the `run` callback of `MlirExternalPassCallbacks`.
 /// See Pass::signalPassFailure().
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 47ef5d8e9dd3b..558ab6a43d87b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -52,6 +52,29 @@ class PyPassManager {
   MlirPassManager passManager;
 };
 
+/// Returns the callbacks used to run a Python callable as an external MLIR
+/// pass; `userData` is the callable's `PyObject *`. `construct` and `destruct`
+/// add and drop a reference to the callable, `run` calls it as
+/// `callable(op, pass)`, and `clone` throws because cloning Python passes is
+/// not supported. No `initialize` callback is set.
+MlirExternalPassCallbacks createExternalPassCallbacksForPythonCallable() {
+  MlirExternalPassCallbacks callbacks;
+  callbacks.construct = [](void *obj) {
+    (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+  };
+  callbacks.destruct = [](void *obj) {
+    (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+  };
+  callbacks.initialize = nullptr;
+  callbacks.clone = [](void *) -> void * {
+    throw std::runtime_error("Cloning Python passes not supported");
+  };
+  callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) {
+    nb::handle(static_cast<PyObject *>(userData))(op, pass);
+  };
+  return callbacks;
+}
+
 } // namespace
 
 /// Create the `mlir.passmanager` here.
@@ -63,6 +81,41 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
       .def("signal_pass_failure",
            [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
 
+  //----------------------------------------------------------------------------
+  // Mapping of register_pass
+  //----------------------------------------------------------------------------
+  m.def(
+      "register_pass",
+      [](const std::string &argument, const nb::callable &run,
+         std::optional<std::string> &name, const std::string &description,
+         const std::string &opName) {
+        if (!name.has_value()) {
+          name =
+              nb::cast<std::string>(nb::borrow<nb::str>(run.attr("__name__")));
+        }
+        // The allocator is intentionally never destroyed (as in the other
+        // external-pass binding in this file): destroying it would also free
+        // the TypeID shared by every pass created from this registration.
+        MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+        MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+        auto callbacks = createExternalPassCallbacksForPythonCallable();
+        // The registry keeps only a raw pointer to `run`, and the callbacks
+        // hold a reference just while a pass instance is alive, so take one
+        // here to keep the callable alive until the pass is instantiated. It
+        // is deliberately never released: the pass is never unregistered.
+        (void)run.inc_ref();
+        mlirRegisterExternalPass(
+            passID, mlirStringRefCreate(name->data(), name->length()),
+            mlirStringRefCreate(argument.data(), argument.length()),
+            mlirStringRefCreate(description.data(), description.length()),
+            mlirStringRefCreate(opName.data(), opName.size()),
+            /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, callbacks,
+            /*userData*/ run.ptr());
+      },
+      "argument"_a, "run"_a, "name"_a.none() = nb::none(),
+      "description"_a.none() = "", "op_name"_a.none() = "",
+      "Register a python-defined pass.");
+
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
@@ -178,21 +223,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
             MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
             MlirTypeID passID =
                 mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
-            MlirExternalPassCallbacks callbacks;
-            callbacks.construct = [](void *obj) {
-              (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
-            };
-            callbacks.destruct = [](void *obj) {
-              (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
-            };
-            callbacks.initialize = nullptr;
-            callbacks.clone = [](void *) -> void * {
-              throw std::runtime_error("Cloning Python passes not supported");
-            };
-            callbacks.run = [](MlirOperation op, MlirExternalPass pass,
-                               void *userData) {
-              nb::handle(static_cast<PyObject *>(userData))(op, pass);
-            };
+            auto callbacks = createExternalPassCallbacksForPythonCallable();
             auto externalPass = mlirCreateExternalPass(
                 passID, mlirStringRefCreate(name->data(), name->length()),
                 mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index b0a6ec1ace3cc..8924f6d9ec6a9 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -216,6 +216,35 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
       userData)));
 }
 
+void mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+                              MlirStringRef argument, MlirStringRef description,
+                              MlirStringRef opName, intptr_t nDependentDialects,
+                              MlirDialectHandle *dependentDialects,
+                              MlirExternalPassCallbacks callbacks,
+                              void *userData) {
+  // The factory registered below is kept by the global pass registry and
+  // outlives this call, so it captures owned copies of the string and dialect
+  // arguments instead of referring to memory owned by the caller.
+  std::string nameStr = unwrap(name).str();
+  std::string argumentStr = unwrap(argument).str();
+  std::string descriptionStr = unwrap(description).str();
+  std::string opNameStr = unwrap(opName).str();
+  std::vector<MlirDialectHandle> dependentDialectVec(
+      dependentDialects, dependentDialects + nDependentDialects);
+
+  mlir::registerPass([passID, nameStr, argumentStr, descriptionStr, opNameStr,
+                      dependentDialectVec, callbacks,
+                      userData]() -> std::unique_ptr<mlir::Pass> {
+    // An empty `opName` means the pass gets no operation name.
+    std::optional<StringRef> passOpName;
+    if (!opNameStr.empty())
+      passOpName = StringRef(opNameStr);
+    return std::make_unique<mlir::ExternalPass>(
+        unwrap(passID), nameStr, argumentStr, descriptionStr, passOpName,
+        dependentDialectVec, callbacks, userData);
+  });
+}
+
 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
   unwrap(pass)->signalPassFailure();
 }
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 50c42102f66d3..0fbd96ec71ddc 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -89,7 +89,7 @@ def __call__(self, op, pass_):
 
         # test signal_pass_failure
         def custom_pass_that_fails(op, pass_):
-            print("hello from pass that fails")
+            print("hello from pass that fails", file=sys.stderr)
             pass_.signal_pass_failure()
 
         pm = PassManager("any")
@@ -99,4 +99,47 @@ def custom_pass_that_fails(op, pass_):
         try:
             pm.run(module)
         except Exception as e:
-            print(f"caught exception: {e}")
+            print(f"caught exception: {e}", file=sys.stderr)
+
+
+# CHECK-LABEL: TEST: testRegisterPass
+@run
+def testRegisterPass():
+    # Python callables registered with register_pass can be referenced by
+    # their argument in a textual pipeline next to C++ passes; the name shown
+    # in IR-printing headers defaults to the callable's __name__.
+    with Context():
+        pdl_module = make_pdl_module()
+        frozen = PDLModule(pdl_module).freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+        """
+        )
+
+        def custom_pass_3(op, pass_):
+            print("hello from pass 3!!!", file=sys.stderr)
+
+        def custom_pass_4(op, pass_):
+            apply_patterns_and_fold_greedily(op, frozen)
+
+        register_pass("custom-pass-one", custom_pass_3)
+        register_pass("custom-pass-two", custom_pass_4)
+
+        pm = PassManager("any")
+        pm.enable_ir_printing()
+
+        # CHECK: hello from pass 3!!!
+        # CHECK-LABEL: Dump After custom_pass_3
+        # CHECK-LABEL: Dump After custom_pass_4
+        # CHECK: arith.muli
+        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+        # CHECK: llvm.mul
+        pm.add("custom-pass-one, custom-pass-two, convert-arith-to-llvm")
+        pm.run(module)



More information about the Mlir-commits mailing list