[Mlir-commits] [mlir] [MLIR][Python] Add a function to register python-defined passes (PR #157850)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 10 05:38:20 PDT 2025
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/157850
This PR adds a new function, `register_pass`, to the MLIR Python bindings. It registers a Python-defined pass in MLIR's global pass registry (the Python counterpart of `mlir::registerPass` in C++), so the pass can then be referenced by name in textual pass pipelines.
An example:
```python
def demo_pass(op, pass_):
    pass  # do something
register_pass("my-python-demo-pass", demo_pass)
pm = PassManager('any')
pm.add("my-python-demo-pass, some-cpp-defined-pass ...")
pm.run(module)
```
>From 09870a61b110e4f3a37a513f7d0fc4e688519583 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 10 Sep 2025 20:27:17 +0800
Subject: [PATCH] [MLIR][Python] Add a function to register python-defined
passes
---
mlir/include/mlir-c/Pass.h | 7 ++++
mlir/lib/Bindings/Python/Pass.cpp | 61 +++++++++++++++++++++++--------
mlir/lib/CAPI/IR/Pass.cpp | 26 +++++++++++++
mlir/test/python/python_pass.py | 44 +++++++++++++++++++++-
4 files changed, 121 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 0d2e19ee7fb0a..1328b4de5d4eb 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -184,6 +184,13 @@ MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks, void *userData);
+/// Registers an external pass with the global pass registry so that it can be
+/// referenced by `argument` in textual pass pipelines (the C counterpart of
+/// `mlir::registerPass`). `name`, `argument`, `description`, `opName` and the
+/// `dependentDialects` array are copied, so they only need to stay valid for
+/// the duration of this call. `callbacks` and `userData` are retained by the
+/// registry and reused for every pass instance it creates; `userData` must
+/// therefore remain valid for as long as the pass may be instantiated. An
+/// empty `opName` registers a pass that is not anchored on a specific
+/// operation.
+MLIR_CAPI_EXPORTED void
+mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+                         MlirStringRef argument, MlirStringRef description,
+                         MlirStringRef opName, intptr_t nDependentDialects,
+                         MlirDialectHandle *dependentDialects,
+                         MlirExternalPassCallbacks callbacks, void *userData);
+
/// This signals that the pass has failed. This is only valid to call during
/// the `run` callback of `MlirExternalPassCallbacks`.
/// See Pass::signalPassFailure().
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 47ef5d8e9dd3b..558ab6a43d87b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -52,6 +52,24 @@ class PyPassManager {
MlirPassManager passManager;
};
+/// Reinterprets the opaque `userData` payload of an external-pass callback as
+/// a handle to the Python object it carries.
+nb::handle pyCallableFromUserData(void *userData) {
+  return nb::handle(static_cast<PyObject *>(userData));
+}
+
+/// Builds the callback table shared by every Python-defined external pass.
+/// `userData` passed to these callbacks is expected to be the `PyObject *` of
+/// the Python callable implementing the pass; `run` invokes it as
+/// `callable(op, pass)`.
+MlirExternalPassCallbacks createExternalPassCallbacksForPythonCallable() {
+  MlirExternalPassCallbacks table;
+  // No per-pass initialization hook is provided.
+  table.initialize = nullptr;
+  // Balanced refcount hooks: `construct` takes a strong reference on the
+  // callable and `destruct` releases it.
+  table.construct = [](void *userData) {
+    (void)pyCallableFromUserData(userData).inc_ref();
+  };
+  table.destruct = [](void *userData) {
+    (void)pyCallableFromUserData(userData).dec_ref();
+  };
+  table.clone = [](void *) -> void * {
+    throw std::runtime_error("Cloning Python passes not supported");
+  };
+  table.run = [](MlirOperation op, MlirExternalPass pass, void *userData) {
+    pyCallableFromUserData(userData)(op, pass);
+  };
+  return table;
+}
+
} // namespace
/// Create the `mlir.passmanager` here.
@@ -63,6 +81,33 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
.def("signal_pass_failure",
[](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+  //----------------------------------------------------------------------------
+  // Mapping of register_pass
+  //----------------------------------------------------------------------------
+  // Registers a Python callable as a pass in the global MLIR pass registry so
+  // that it can be referenced by `argument` in textual pipelines, e.g.
+  // `PassManager.add("my-arg, other-pass")`. `name` defaults to the
+  // callable's `__name__`; an empty `op_name` yields an op-agnostic pass.
+  m.def(
+      "register_pass",
+      [](const std::string &argument, const nb::callable &run,
+         std::optional<std::string> &name, const std::string &description,
+         const std::string &opName) {
+        // mlir::registerPass rejects passes without a pipeline argument with a
+        // fatal error, which would abort the interpreter; surface it as a
+        // Python ValueError instead.
+        if (argument.empty())
+          throw nb::value_error("register_pass: 'argument' must not be empty");
+        if (!name.has_value()) {
+          name =
+              nb::cast<std::string>(nb::borrow<nb::str>(run.attr("__name__")));
+        }
+        // The registered pass keeps using this TypeID for the rest of the
+        // process, so the allocator that presumably owns its storage is
+        // deliberately never destroyed.
+        MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+        MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+        auto callbacks = createExternalPassCallbacksForPythonCallable();
+        mlirRegisterExternalPass(
+            passID, mlirStringRefCreate(name->data(), name->length()),
+            mlirStringRefCreate(argument.data(), argument.length()),
+            mlirStringRefCreate(description.data(), description.length()),
+            mlirStringRefCreate(opName.data(), opName.size()),
+            /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, callbacks,
+            /*userData*/ run.ptr());
+        // The registry keeps `run.ptr()` as a raw pointer and may instantiate
+        // the pass at any later time (e.g. when a pipeline string is parsed),
+        // long after this call returns. Take a strong reference that is never
+        // released so the callable cannot be garbage-collected while it is
+        // registered (passes are never unregistered).
+        run.inc_ref();
+      },
+      "argument"_a, "run"_a, "name"_a.none() = nb::none(),
+      "description"_a.none() = "", "op_name"_a.none() = "",
+      "Register a python-defined pass.");
+
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
@@ -178,21 +223,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
MlirTypeID passID =
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
- MlirExternalPassCallbacks callbacks;
- callbacks.construct = [](void *obj) {
- (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
- };
- callbacks.destruct = [](void *obj) {
- (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
- };
- callbacks.initialize = nullptr;
- callbacks.clone = [](void *) -> void * {
- throw std::runtime_error("Cloning Python passes not supported");
- };
- callbacks.run = [](MlirOperation op, MlirExternalPass pass,
- void *userData) {
- nb::handle(static_cast<PyObject *>(userData))(op, pass);
- };
+ auto callbacks = createExternalPassCallbacksForPythonCallable();
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index b0a6ec1ace3cc..8924f6d9ec6a9 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -216,6 +216,32 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
userData)));
}
+/// Registers an external pass with the global pass registry.
+///
+/// The registry stores the allocator lambda below and invokes it whenever the
+/// pass is instantiated, which can be long after this function returns. The
+/// lambda therefore owns copies of all strings and of the dependent-dialect
+/// array. `callbacks` and `userData` are captured by value; the caller must
+/// keep `userData` valid for as long as the pass may be instantiated. An
+/// empty `opName` creates passes with no anchor operation (`std::nullopt`).
+void mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+                              MlirStringRef argument, MlirStringRef description,
+                              MlirStringRef opName, intptr_t nDependentDialects,
+                              MlirDialectHandle *dependentDialects,
+                              MlirExternalPassCallbacks callbacks,
+                              void *userData) {
+  mlir::registerPass(
+      // Owned copies are built directly in the init-captures so the caller's
+      // buffers are not referenced after this call.
+      [passID, callbacks, userData, nameStr = unwrap(name).str(),
+       argumentStr = unwrap(argument).str(),
+       descriptionStr = unwrap(description).str(),
+       opNameStr = unwrap(opName).str(),
+       dialects = std::vector<MlirDialectHandle>(
+           dependentDialects, dependentDialects + nDependentDialects)]()
+          -> std::unique_ptr<mlir::Pass> {
+        std::optional<StringRef> anchorOp =
+            opNameStr.empty() ? std::nullopt
+                              : std::optional<StringRef>(opNameStr);
+        return std::make_unique<mlir::ExternalPass>(
+            unwrap(passID), nameStr, argumentStr, descriptionStr, anchorOp,
+            dialects, callbacks, userData);
+      });
+}
+
void mlirExternalPassSignalFailure(MlirExternalPass pass) {
unwrap(pass)->signalPassFailure();
}
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 50c42102f66d3..0fbd96ec71ddc 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -89,7 +89,7 @@ def __call__(self, op, pass_):
# test signal_pass_failure
def custom_pass_that_fails(op, pass_):
- print("hello from pass that fails")
+ print("hello from pass that fails", file=sys.stderr)
pass_.signal_pass_failure()
pm = PassManager("any")
@@ -99,4 +99,44 @@ def custom_pass_that_fails(op, pass_):
try:
pm.run(module)
except Exception as e:
- print(f"caught exception: {e}")
+ print(f"caught exception: {e}", file=sys.stderr)
+
+
+# CHECK-LABEL: TEST: testRegisterPass
+ at run
+def testRegisterPass():
+ with Context():
+ pdl_module = make_pdl_module()
+ frozen = PDLModule(pdl_module).freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ def custom_pass_3(op, pass_):
+ print("hello from pass 3!!!", file=sys.stderr)
+
+ def custom_pass_4(op, pass_):
+ apply_patterns_and_fold_greedily(op, frozen)
+
+ register_pass("custom-pass-one", custom_pass_3)
+ register_pass("custom-pass-two", custom_pass_4)
+
+ pm = PassManager("any")
+ pm.enable_ir_printing()
+
+ # CHECK: hello from pass 3!!!
+ # CHECK-LABEL: Dump After custom_pass_3
+ # CHECK-LABEL: Dump After custom_pass_4
+ # CHECK: arith.muli
+ # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+ # CHECK: llvm.mul
+ pm.add("custom-pass-one, custom-pass-two, convert-arith-to-llvm")
+ pm.run(module)
More information about the Mlir-commits
mailing list