[Mlir-commits] [mlir] [MLIR][Python] Add the ability to signal pass failures in python-defined passes (PR #157613)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Sep  8 22:09:05 PDT 2025
    
    
  
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/157613
This PR adds the ability to signal pass failures (`pass.signal_pass_failure()`) from Python-defined passes.
To achieve this, we store the `MlirExternalPass` C pointer in an attribute of the pass object for the duration of the pass run, and add a `signal_pass_failure` method to the pass object when it is added to the pass manager via `pm.add(..)`.
Note that `signal_pass_failure()` should always be called from the pass object's `__call__` method, since the `MlirExternalPass` is only available in that context (otherwise an exception with a descriptive message is raised).
>From 8d17e5b831de56237324c34aefabf608b6d639b2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 12:13:16 +0800
Subject: [PATCH 1/2] [MLIR][Python] Add the ability to signal pass failures in
 python-defined passes
---
 mlir/lib/Bindings/Python/Pass.cpp | 18 ++++++++++++++++--
 mlir/test/python/python_pass.py   | 17 +++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..fb7dc2705b3ce 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,8 @@ class PyPassManager {
 
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+  constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
@@ -182,10 +184,22 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
             callbacks.clone = [](void *) -> void * {
               throw std::runtime_error("Cloning Python passes not supported");
             };
-            callbacks.run = [](MlirOperation op, MlirExternalPass,
+            callbacks.run = [](MlirOperation op, MlirExternalPass pass,
                                void *userData) {
-              nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+              auto callable =
+                  nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
+              nb::setattr(callable, mlirExternalPassAttr,
+                          nb::capsule(pass.ptr));
+              callable(op);
+              // delete it to avoid that it is used after
+              // the external pass is freed by the pass manager
+              nb::delattr(callable, mlirExternalPassAttr);
             };
+            nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
+                          nb::capsule cap = run.attr(mlirExternalPassAttr);
+                          mlirExternalPassSignalFailure(
+                              MlirExternalPass{cap.data()});
+                        }));
             auto externalPass = mlirCreateExternalPass(
                 passID, mlirStringRefCreate(name->data(), name->length()),
                 mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..7734d76fcba94 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -86,3 +86,20 @@ def __call__(self, m):
         # CHECK: llvm.mul
         pm.add("convert-arith-to-llvm")
         pm.run(module)
+
+        # test signal_pass_failure
+        class CustomPassThatFails:
+            def __call__(self, m):
+                print("hello from pass that fails")
+                self.signal_pass_failure()
+
+        custom_pass_that_fails = CustomPassThatFails()
+
+        pm = PassManager("any")
+        pm.add(custom_pass_that_fails, "CustomPassThatFails")
+        # CHECK: hello from pass that fails
+        # CHECK: caught exception: Failure while executing pass pipeline
+        try:
+            pm.run(module)
+        except Exception as e:
+            print(f"caught exception: {e}")
>From 2241e278856ed01ab912da5eb5567b13f02358d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 13:03:13 +0800
Subject: [PATCH 2/2] refine the bad path
---
 mlir/lib/Bindings/Python/Pass.cpp | 9 ++++++++-
 mlir/test/python/python_pass.py   | 6 ++++++
 2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index fb7dc2705b3ce..c5fe7bda4a680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -196,7 +196,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
               nb::delattr(callable, mlirExternalPassAttr);
             };
             nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
-                          nb::capsule cap = run.attr(mlirExternalPassAttr);
+                          nb::capsule cap;
+                          try {
+                            cap = run.attr(mlirExternalPassAttr);
+                          } catch (nb::python_error &e) {
+                            throw std::runtime_error(
+                                "signal_pass_failure() should always be called "
+                                "from the __call__ method");
+                          }
                           mlirExternalPassSignalFailure(
                               MlirExternalPass{cap.data()});
                         }));
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 7734d76fcba94..4784e073fef0a 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -103,3 +103,9 @@ def __call__(self, m):
             pm.run(module)
         except Exception as e:
             print(f"caught exception: {e}")
+
+        # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
+        try:
+            custom_pass_that_fails.signal_pass_failure()
+        except Exception as e:
+            print(f"caught exception: {e}")
    
    
More information about the Mlir-commits
mailing list