[Mlir-commits] [mlir] [MLIR][Python] Add the ability to signal pass failures in python-defined passes (PR #157613)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 8 22:09:05 PDT 2025
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/157613
This PR adds the ability to signal pass failures (`pass.signal_pass_failure()`) from Python-defined passes.
To achieve this, we store the `MlirExternalPass` C pointer in an attribute of the pass object, and add a `signal_pass_failure` method to the pass object when it is added to the pass manager via `pm.add(...)`.
Note that `signal_pass_failure()` should always be called from the pass object's `__call__` method, because the `MlirExternalPass` is only available in that context. If it is called anywhere else, an exception with a descriptive message is raised.
>From 8d17e5b831de56237324c34aefabf608b6d639b2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 12:13:16 +0800
Subject: [PATCH 1/2] [MLIR][Python] Add the ability to signal pass failures in
python-defined passes
---
mlir/lib/Bindings/Python/Pass.cpp | 18 ++++++++++++++++--
mlir/test/python/python_pass.py | 17 +++++++++++++++++
2 files changed, 33 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..fb7dc2705b3ce 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,8 @@ class PyPassManager {
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+ constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
@@ -182,10 +184,22 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
callbacks.clone = [](void *) -> void * {
throw std::runtime_error("Cloning Python passes not supported");
};
- callbacks.run = [](MlirOperation op, MlirExternalPass,
+ callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+ auto callable =
+ nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
+ nb::setattr(callable, mlirExternalPassAttr,
+ nb::capsule(pass.ptr));
+ callable(op);
+ // delete it to avoid that it is used after
+ // the external pass is freed by the pass manager
+ nb::delattr(callable, mlirExternalPassAttr);
};
+ nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
+ nb::capsule cap = run.attr(mlirExternalPassAttr);
+ mlirExternalPassSignalFailure(
+ MlirExternalPass{cap.data()});
+ }));
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..7734d76fcba94 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -86,3 +86,20 @@ def __call__(self, m):
# CHECK: llvm.mul
pm.add("convert-arith-to-llvm")
pm.run(module)
+
+ # test signal_pass_failure
+ class CustomPassThatFails:
+ def __call__(self, m):
+ print("hello from pass that fails")
+ self.signal_pass_failure()
+
+ custom_pass_that_fails = CustomPassThatFails()
+
+ pm = PassManager("any")
+ pm.add(custom_pass_that_fails, "CustomPassThatFails")
+ # CHECK: hello from pass that fails
+ # CHECK: caught exception: Failure while executing pass pipeline
+ try:
+ pm.run(module)
+ except Exception as e:
+ print(f"caught exception: {e}")
>From 2241e278856ed01ab912da5eb5567b13f02358d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 13:03:13 +0800
Subject: [PATCH 2/2] refine the bad path
---
mlir/lib/Bindings/Python/Pass.cpp | 9 ++++++++-
mlir/test/python/python_pass.py | 6 ++++++
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index fb7dc2705b3ce..c5fe7bda4a680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -196,7 +196,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
nb::delattr(callable, mlirExternalPassAttr);
};
nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
- nb::capsule cap = run.attr(mlirExternalPassAttr);
+ nb::capsule cap;
+ try {
+ cap = run.attr(mlirExternalPassAttr);
+ } catch (nb::python_error &e) {
+ throw std::runtime_error(
+ "signal_pass_failure() should always be called "
+ "from the __call__ method");
+ }
mlirExternalPassSignalFailure(
MlirExternalPass{cap.data()});
}));
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 7734d76fcba94..4784e073fef0a 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -103,3 +103,9 @@ def __call__(self, m):
pm.run(module)
except Exception as e:
print(f"caught exception: {e}")
+
+ # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
+ try:
+ custom_pass_that_fails.signal_pass_failure()
+ except Exception as e:
+ print(f"caught exception: {e}")
More information about the Mlir-commits
mailing list