[Mlir-commits] [mlir] 7123463 - [MLIR][Python] Add the ability to signal pass failures in python-defined passes (#157613)
llvmlistbot at llvm.org
Tue Sep 9 08:05:43 PDT 2025
Author: Twice
Date: 2025-09-09T08:05:39-07:00
New Revision: 7123463ef98bd478ff068dc94bfeee1b13c551c3
URL: https://github.com/llvm/llvm-project/commit/7123463ef98bd478ff068dc94bfeee1b13c551c3
DIFF: https://github.com/llvm/llvm-project/commit/7123463ef98bd478ff068dc94bfeee1b13c551c3.diff
LOG: [MLIR][Python] Add the ability to signal pass failures in python-defined passes (#157613)
This is a follow-up to #156000.
This PR adds the ability to signal pass failures
(`signal_pass_failure()`) from Python-defined passes.
To achieve this, we expose `MlirExternalPass` to Python via `nb::class_`
(as `ExternalPass`) with a `signal_pass_failure()` method, and the callable
passed to `pm.add(..)` now takes two arguments (`op: MlirOperation,
pass_: MlirExternalPass`).
For example:
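```python
def custom_pass_that_fails(op, pass_):
    # `some_condition` is a placeholder for your own check on `op`.
    if some_condition:
        pass_.signal_pass_failure()
        # Signaling only marks the failure; return explicitly to stop work.
        return
    # do something
```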
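Because the callable passed to `pm.add(..)` is now invoked with two positional arguments, an existing pass that accepts only `op` would fail when called. A minimal migration sketch, assuming you cannot update the old callables directly; `adapt_legacy_pass` is a hypothetical helper, not part of the MLIR Python bindings:

```python
import functools
from typing import Any, Callable


def adapt_legacy_pass(fn: Callable[[Any], None]) -> Callable[[Any, Any], None]:
    """Wrap a single-argument pass callable (hypothetical helper).

    The wrapper matches the new `(op, pass_)` signature and ignores `pass_`,
    so a wrapped pass can never call `pass_.signal_pass_failure()`.
    """

    @functools.wraps(fn)
    def wrapper(op, pass_):
        fn(op)

    return wrapper
```

Usage would look like `pm.add(adapt_legacy_pass(old_pass), "OldPass")`. Updating the callable's signature directly is preferable, since only then can the pass report failure.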
Added:
Modified:
mlir/lib/Bindings/Python/Pass.cpp
mlir/test/python/python_pass.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..47ef5d8e9dd3b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,13 @@ class PyPassManager {
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+ //----------------------------------------------------------------------------
+ // Mapping of MlirExternalPass
+ //----------------------------------------------------------------------------
+ nb::class_<MlirExternalPass>(m, "ExternalPass")
+ .def("signal_pass_failure",
+ [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
@@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
callbacks.clone = [](void *) -> void * {
throw std::runtime_error("Cloning Python passes not supported");
};
- callbacks.run = [](MlirOperation op, MlirExternalPass,
+ callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+ nb::handle(static_cast<PyObject *>(userData))(op, pass);
};
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..50c42102f66d3 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -64,12 +64,12 @@ def testCustomPass():
"""
)
- def custom_pass_1(op):
+ def custom_pass_1(op, pass_):
print("hello from pass 1!!!", file=sys.stderr)
class CustomPass2:
- def __call__(self, m):
- apply_patterns_and_fold_greedily(m, frozen)
+ def __call__(self, op, pass_):
+ apply_patterns_and_fold_greedily(op, frozen)
custom_pass_2 = CustomPass2()
@@ -86,3 +86,17 @@ def __call__(self, m):
# CHECK: llvm.mul
pm.add("convert-arith-to-llvm")
pm.run(module)
+
+ # test signal_pass_failure
+ def custom_pass_that_fails(op, pass_):
+ print("hello from pass that fails")
+ pass_.signal_pass_failure()
+
+ pm = PassManager("any")
+ pm.add(custom_pass_that_fails, "CustomPassThatFails")
+ # CHECK: hello from pass that fails
+ # CHECK: caught exception: Failure while executing pass pipeline
+ try:
+ pm.run(module)
+ except Exception as e:
+ print(f"caught exception: {e}")
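The new test checks that a signaled failure surfaces from `pm.run(module)` as an exception whose message is "Failure while executing pass pipeline". The pure-Python model below sketches that control flow under stated assumptions: `ToyExternalPass` and `ToyPassManager` are hypothetical stand-ins, not the real `mlir.passmanager` classes. In the real bindings the failure flag is set in C++ through `mlirExternalPassSignalFailure`, and `pm.run` raises the bindings' own exception type (the test only catches `Exception`).

```python
from typing import Any, Callable, List, Tuple

PassFn = Callable[[Any, "ToyExternalPass"], None]


class ToyExternalPass:
    """Hypothetical stand-in for the `ExternalPass` handle given to callbacks."""

    def __init__(self) -> None:
        self.failed = False

    def signal_pass_failure(self) -> None:
        # Only records the failure; the callback keeps running until it returns.
        self.failed = True


class ToyPassManager:
    """Hypothetical model of a pass manager driving `(op, pass_)` callables."""

    def __init__(self) -> None:
        self._passes: List[Tuple[str, PassFn]] = []

    def add(self, fn: PassFn, name: str) -> None:
        self._passes.append((name, fn))

    def run(self, op: Any) -> None:
        for name, fn in self._passes:
            handle = ToyExternalPass()
            fn(op, handle)  # same two-argument call shape as `callbacks.run`
            if handle.failed:
                # Stop the pipeline; the message mirrors the one the test checks.
                raise RuntimeError("Failure while executing pass pipeline")


def custom_pass_that_fails(op, pass_):
    print("hello from pass that fails")
    pass_.signal_pass_failure()


pm = ToyPassManager()
pm.add(custom_pass_that_fails, "CustomPassThatFails")
try:
    pm.run("module")
except Exception as e:
    print(f"caught exception: {e}")
```

Running this prints the same two lines the `CHECK` directives expect, which makes it easy to reason about the ordering: the pass body runs to completion, and only afterwards does the manager turn the recorded failure into an exception.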