[Mlir-commits] [mlir] 7123463 - [MLIR][Python] Add the ability to signal pass failures in python-defined passes (#157613)

llvmlistbot at llvm.org
Tue Sep 9 08:05:43 PDT 2025


Author: Twice
Date: 2025-09-09T08:05:39-07:00
New Revision: 7123463ef98bd478ff068dc94bfeee1b13c551c3

URL: https://github.com/llvm/llvm-project/commit/7123463ef98bd478ff068dc94bfeee1b13c551c3
DIFF: https://github.com/llvm/llvm-project/commit/7123463ef98bd478ff068dc94bfeee1b13c551c3.diff

LOG: [MLIR][Python] Add the ability to signal pass failures in python-defined passes (#157613)

This is a follow-up to #156000.

In this PR, we add the ability for Python-defined passes to signal pass
failure via `signal_pass_failure()`.

To achieve this, we expose `MlirExternalPass` to Python as `ExternalPass`
via `nb::class_`, with a single method, `signal_pass_failure()`. The
callable passed to `pm.add(...)` now receives two arguments
(`op: MlirOperation, pass_: MlirExternalPass`), so existing one-argument
pass callables must be updated to accept the second parameter.

For example:
```python
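def custom_pass_that_fails(op, pass_):
    if some_condition:  # placeholder for a pass-specific check on `op`
        pass_.signal_pass_failure()
        return  # signal_pass_failure() only marks the failure; it does not raise
    # ... otherwise transform `op` ...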
```
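The calling convention can be illustrated without an MLIR build using a small pure-Python model. This is a sketch only: `ToyExternalPass` and `ToyPassManager` are hypothetical stand-ins, not part of the `mlir` package. They mimic the binding's behavior as described above: every registered callable is invoked as `callable(op, pass_)`, and a failure signalled through `pass_` makes `run()` raise. Raising `RuntimeError` is an assumption of the model; the message text mirrors the `CHECK` line in the updated test.

```python
from typing import Any, Callable, List, Tuple


class ToyExternalPass:
    """Hypothetical stand-in for the `ExternalPass` handle passed as `pass_`."""

    def __init__(self) -> None:
        self.failed = False

    def signal_pass_failure(self) -> None:
        # Only records the failure; the pass manager checks it after the callback returns.
        self.failed = True


class ToyPassManager:
    """Hypothetical model of a pipeline of Python-defined passes (not mlir.passmanager)."""

    def __init__(self) -> None:
        self._passes: List[Tuple[str, Callable[[Any, ToyExternalPass], None]]] = []

    def add(self, run: Callable[[Any, ToyExternalPass], None], name: str) -> None:
        self._passes.append((name, run))

    def run(self, op: Any) -> None:
        for _name, run in self._passes:
            handle = ToyExternalPass()
            run(op, handle)  # two arguments, matching the new calling convention
            if handle.failed:
                # Raising here also skips any passes registered after the failing one.
                raise RuntimeError("Failure while executing pass pipeline")


def custom_pass_that_fails(op, pass_):
    if not op:  # hypothetical check: treat an empty list as invalid input
        pass_.signal_pass_failure()
        return
    op.append("visited")


pm = ToyPassManager()
pm.add(custom_pass_that_fails, "CustomPassThatFails")

module = ["func"]
pm.run(module)  # succeeds
print(module)  # ['func', 'visited']

try:
    pm.run([])
except RuntimeError as e:
    print(f"caught exception: {e}")  # caught exception: Failure while executing pass pipeline
```

Note that `signal_pass_failure()` only records the failure, in the model as in MLIR itself; the callable keeps executing until it returns, which is why the sketch returns immediately after signalling.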


Modified: 
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/test/python/python_pass.py



################################################################################
diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..47ef5d8e9dd3b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,13 @@ class PyPassManager {
 
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of MlirExternalPass
+  //----------------------------------------------------------------------------
+  nb::class_<MlirExternalPass>(m, "ExternalPass")
+      .def("signal_pass_failure",
+           [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
@@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
             callbacks.clone = [](void *) -> void * {
               throw std::runtime_error("Cloning Python passes not supported");
             };
-            callbacks.run = [](MlirOperation op, MlirExternalPass,
+            callbacks.run = [](MlirOperation op, MlirExternalPass pass,
                                void *userData) {
-              nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+              nb::handle(static_cast<PyObject *>(userData))(op, pass);
             };
             auto externalPass = mlirCreateExternalPass(
                 passID, mlirStringRefCreate(name->data(), name->length()),

diff  --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..50c42102f66d3 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -64,12 +64,12 @@ def testCustomPass():
         """
         )
 
-        def custom_pass_1(op):
+        def custom_pass_1(op, pass_):
             print("hello from pass 1!!!", file=sys.stderr)
 
         class CustomPass2:
-            def __call__(self, m):
-                apply_patterns_and_fold_greedily(m, frozen)
+            def __call__(self, op, pass_):
+                apply_patterns_and_fold_greedily(op, frozen)
 
         custom_pass_2 = CustomPass2()
 
@@ -86,3 +86,17 @@ def __call__(self, m):
         # CHECK: llvm.mul
         pm.add("convert-arith-to-llvm")
         pm.run(module)
+
+        # test signal_pass_failure
+        def custom_pass_that_fails(op, pass_):
+            print("hello from pass that fails")
+            pass_.signal_pass_failure()
+
+        pm = PassManager("any")
+        pm.add(custom_pass_that_fails, "CustomPassThatFails")
+        # CHECK: hello from pass that fails
+        # CHECK: caught exception: Failure while executing pass pipeline
+        try:
+            pm.run(module)
+        except Exception as e:
+            print(f"caught exception: {e}")
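Because the binding now always calls the registered callable with two positional arguments, a callable written against the previous one-argument convention no longer matches the call. The updated test simply adds the `pass_` parameter; where that is inconvenient, a small wrapper can adapt an old callable. `adapt_legacy_pass` below is a hypothetical helper, not part of the `mlir` package, and is exercised here with plain Python objects rather than real MLIR operations.

```python
import functools
from typing import Any, Callable


def adapt_legacy_pass(fn: Callable[[Any], None]) -> Callable[[Any, Any], None]:
    """Hypothetical helper: wrap a one-argument pass so it accepts (op, pass_).

    The wrapper ignores `pass_`, so the adapted pass can never signal failure.
    """

    @functools.wraps(fn)
    def run(op: Any, pass_: Any) -> None:
        fn(op)

    return run


calls = []


def legacy_pass(op):
    calls.append(("legacy", op))


class LegacyClassPass:
    def __call__(self, m):
        calls.append(("class", m))


for p in (adapt_legacy_pass(legacy_pass), adapt_legacy_pass(LegacyClassPass())):
    p("module", object())  # invoked the way the new binding calls passes: (op, pass_)

print(calls)  # [('legacy', 'module'), ('class', 'module')]
print(adapt_legacy_pass(legacy_pass).__name__)  # legacy_pass
```

The trade-off is that an adapted pass has no access to the handle, so passes that need to report failure should take `pass_` directly, as `custom_pass_that_fails` does in the test.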


        

