[Mlir-commits] [mlir] [MLIR][Python] Add the ability to signal pass failures in python-defined passes (PR #157613)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 9 06:20:28 PDT 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/157613
>From 8d17e5b831de56237324c34aefabf608b6d639b2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 12:13:16 +0800
Subject: [PATCH 1/3] [MLIR][Python] Add the ability to signal pass failures in
python-defined passes
---
mlir/lib/Bindings/Python/Pass.cpp | 18 ++++++++++++++++--
mlir/test/python/python_pass.py | 17 +++++++++++++++++
2 files changed, 33 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..fb7dc2705b3ce 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,8 @@ class PyPassManager {
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+ constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
@@ -182,10 +184,22 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
callbacks.clone = [](void *) -> void * {
throw std::runtime_error("Cloning Python passes not supported");
};
- callbacks.run = [](MlirOperation op, MlirExternalPass,
+ callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+ auto callable =
+ nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
+ nb::setattr(callable, mlirExternalPassAttr,
+ nb::capsule(pass.ptr));
+ callable(op);
+ // delete it to avoid that it is used after
+ // the external pass is freed by the pass manager
+ nb::delattr(callable, mlirExternalPassAttr);
};
+ nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
+ nb::capsule cap = run.attr(mlirExternalPassAttr);
+ mlirExternalPassSignalFailure(
+ MlirExternalPass{cap.data()});
+ }));
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..7734d76fcba94 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -86,3 +86,20 @@ def __call__(self, m):
# CHECK: llvm.mul
pm.add("convert-arith-to-llvm")
pm.run(module)
+
+ # test signal_pass_failure
+ class CustomPassThatFails:
+ def __call__(self, m):
+ print("hello from pass that fails")
+ self.signal_pass_failure()
+
+ custom_pass_that_fails = CustomPassThatFails()
+
+ pm = PassManager("any")
+ pm.add(custom_pass_that_fails, "CustomPassThatFails")
+ # CHECK: hello from pass that fails
+ # CHECK: caught exception: Failure while executing pass pipeline
+ try:
+ pm.run(module)
+ except Exception as e:
+ print(f"caught exception: {e}")
>From 2241e278856ed01ab912da5eb5567b13f02358d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 13:03:13 +0800
Subject: [PATCH 2/3] refine the bad path
---
mlir/lib/Bindings/Python/Pass.cpp | 9 ++++++++-
mlir/test/python/python_pass.py | 6 ++++++
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index fb7dc2705b3ce..c5fe7bda4a680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -196,7 +196,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
nb::delattr(callable, mlirExternalPassAttr);
};
nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
- nb::capsule cap = run.attr(mlirExternalPassAttr);
+ nb::capsule cap;
+ try {
+ cap = run.attr(mlirExternalPassAttr);
+ } catch (nb::python_error &e) {
+ throw std::runtime_error(
+ "signal_pass_failure() should always be called "
+ "from the __call__ method");
+ }
mlirExternalPassSignalFailure(
MlirExternalPass{cap.data()});
}));
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 7734d76fcba94..4784e073fef0a 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -103,3 +103,9 @@ def __call__(self, m):
pm.run(module)
except Exception as e:
print(f"caught exception: {e}")
+
+ # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
+ try:
+ custom_pass_that_fails.signal_pass_failure()
+ except Exception as e:
+ print(f"caught exception: {e}")
>From 28638ab495cd0cde1487d5540819aea64f57d014 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 21:19:11 +0800
Subject: [PATCH 3/3] drop the setattr design
---
mlir/lib/Bindings/Python/Pass.cpp | 28 +++++++---------------------
mlir/test/python/python_pass.py | 16 +++++-----------
2 files changed, 12 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index c5fe7bda4a680..ef606431fbd5e 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,7 +56,12 @@ class PyPassManager {
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
- constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+ //----------------------------------------------------------------------------
+ // Mapping of MlirExternalPass
+ //----------------------------------------------------------------------------
+ nb::class_<MlirExternalPass>(m, "ExternalPass")
+ .def("signal_failure",
+ [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
@@ -186,27 +191,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
};
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- auto callable =
- nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
- nb::setattr(callable, mlirExternalPassAttr,
- nb::capsule(pass.ptr));
- callable(op);
- // delete it to avoid that it is used after
- // the external pass is freed by the pass manager
- nb::delattr(callable, mlirExternalPassAttr);
+ nb::handle(static_cast<PyObject *>(userData))(op, pass);
};
- nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
- nb::capsule cap;
- try {
- cap = run.attr(mlirExternalPassAttr);
- } catch (nb::python_error &e) {
- throw std::runtime_error(
- "signal_pass_failure() should always be called "
- "from the __call__ method");
- }
- mlirExternalPassSignalFailure(
- MlirExternalPass{cap.data()});
- }));
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 4784e073fef0a..10b449f9b1ef8 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -64,12 +64,12 @@ def testCustomPass():
"""
)
- def custom_pass_1(op):
+ def custom_pass_1(op, pass_):
print("hello from pass 1!!!", file=sys.stderr)
class CustomPass2:
- def __call__(self, m):
- apply_patterns_and_fold_greedily(m, frozen)
+ def __call__(self, op, pass_):
+ apply_patterns_and_fold_greedily(op, frozen)
custom_pass_2 = CustomPass2()
@@ -89,9 +89,9 @@ def __call__(self, m):
# test signal_pass_failure
class CustomPassThatFails:
- def __call__(self, m):
+ def __call__(self, op, pass_):
print("hello from pass that fails")
- self.signal_pass_failure()
+ pass_.signal_failure()
custom_pass_that_fails = CustomPassThatFails()
@@ -103,9 +103,3 @@ def __call__(self, m):
pm.run(module)
except Exception as e:
print(f"caught exception: {e}")
-
- # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
- try:
- custom_pass_that_fails.signal_pass_failure()
- except Exception as e:
- print(f"caught exception: {e}")
More information about the Mlir-commits mailing list