[Mlir-commits] [mlir] [MLIR][Python] Add the ability to signal pass failures in python-defined passes (PR #157613)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 9 06:20:28 PDT 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/157613

>From 8d17e5b831de56237324c34aefabf608b6d639b2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 12:13:16 +0800
Subject: [PATCH 1/3] [MLIR][Python] Add the ability to signal pass failures in
 python-defined passes

---
 mlir/lib/Bindings/Python/Pass.cpp | 18 ++++++++++++++++--
 mlir/test/python/python_pass.py   | 17 +++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..fb7dc2705b3ce 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,8 @@ class PyPassManager {
 
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+  constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
@@ -182,10 +184,22 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
             callbacks.clone = [](void *) -> void * {
               throw std::runtime_error("Cloning Python passes not supported");
             };
-            callbacks.run = [](MlirOperation op, MlirExternalPass,
+            callbacks.run = [](MlirOperation op, MlirExternalPass pass,
                                void *userData) {
-              nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+              auto callable =
+                  nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
+              nb::setattr(callable, mlirExternalPassAttr,
+                          nb::capsule(pass.ptr));
+              callable(op);
+              // delete it to avoid that it is used after
+              // the external pass is freed by the pass manager
+              nb::delattr(callable, mlirExternalPassAttr);
             };
+            nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
+                          nb::capsule cap = run.attr(mlirExternalPassAttr);
+                          mlirExternalPassSignalFailure(
+                              MlirExternalPass{cap.data()});
+                        }));
             auto externalPass = mlirCreateExternalPass(
                 passID, mlirStringRefCreate(name->data(), name->length()),
                 mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..7734d76fcba94 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -86,3 +86,20 @@ def __call__(self, m):
         # CHECK: llvm.mul
         pm.add("convert-arith-to-llvm")
         pm.run(module)
+
+        # test signal_pass_failure
+        class CustomPassThatFails:
+            def __call__(self, m):
+                print("hello from pass that fails")
+                self.signal_pass_failure()
+
+        custom_pass_that_fails = CustomPassThatFails()
+
+        pm = PassManager("any")
+        pm.add(custom_pass_that_fails, "CustomPassThatFails")
+        # CHECK: hello from pass that fails
+        # CHECK: caught exception: Failure while executing pass pipeline
+        try:
+            pm.run(module)
+        except Exception as e:
+            print(f"caught exception: {e}")

>From 2241e278856ed01ab912da5eb5567b13f02358d8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 13:03:13 +0800
Subject: [PATCH 2/3] Refine the error path when signal_pass_failure() is
 called outside __call__

---
 mlir/lib/Bindings/Python/Pass.cpp | 9 ++++++++-
 mlir/test/python/python_pass.py   | 6 ++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index fb7dc2705b3ce..c5fe7bda4a680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -196,7 +196,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
               nb::delattr(callable, mlirExternalPassAttr);
             };
             nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
-                          nb::capsule cap = run.attr(mlirExternalPassAttr);
+                          nb::capsule cap;
+                          try {
+                            cap = run.attr(mlirExternalPassAttr);
+                          } catch (nb::python_error &e) {
+                            throw std::runtime_error(
+                                "signal_pass_failure() should always be called "
+                                "from the __call__ method");
+                          }
                           mlirExternalPassSignalFailure(
                               MlirExternalPass{cap.data()});
                         }));
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 7734d76fcba94..4784e073fef0a 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -103,3 +103,9 @@ def __call__(self, m):
             pm.run(module)
         except Exception as e:
             print(f"caught exception: {e}")
+
+        # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
+        try:
+            custom_pass_that_fails.signal_pass_failure()
+        except Exception as e:
+            print(f"caught exception: {e}")

>From 28638ab495cd0cde1487d5540819aea64f57d014 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Sep 2025 21:19:11 +0800
Subject: [PATCH 3/3] drop the setattr design

---
 mlir/lib/Bindings/Python/Pass.cpp | 28 +++++++---------------------
 mlir/test/python/python_pass.py   | 16 +++++-----------
 2 files changed, 12 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index c5fe7bda4a680..ef606431fbd5e 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,7 +56,12 @@ class PyPassManager {
 
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
-  constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+  //----------------------------------------------------------------------------
+  // Mapping of MlirExternalPass
+  //----------------------------------------------------------------------------
+  nb::class_<MlirExternalPass>(m, "ExternalPass")
+      .def("signal_failure",
+           [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
 
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
@@ -186,27 +191,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
             };
             callbacks.run = [](MlirOperation op, MlirExternalPass pass,
                                void *userData) {
-              auto callable =
-                  nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
-              nb::setattr(callable, mlirExternalPassAttr,
-                          nb::capsule(pass.ptr));
-              callable(op);
-              // delete it to avoid that it is used after
-              // the external pass is freed by the pass manager
-              nb::delattr(callable, mlirExternalPassAttr);
+              nb::handle(static_cast<PyObject *>(userData))(op, pass);
             };
-            nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
-                          nb::capsule cap;
-                          try {
-                            cap = run.attr(mlirExternalPassAttr);
-                          } catch (nb::python_error &e) {
-                            throw std::runtime_error(
-                                "signal_pass_failure() should always be called "
-                                "from the __call__ method");
-                          }
-                          mlirExternalPassSignalFailure(
-                              MlirExternalPass{cap.data()});
-                        }));
             auto externalPass = mlirCreateExternalPass(
                 passID, mlirStringRefCreate(name->data(), name->length()),
                 mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 4784e073fef0a..10b449f9b1ef8 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -64,12 +64,12 @@ def testCustomPass():
         """
         )
 
-        def custom_pass_1(op):
+        def custom_pass_1(op, pass_):
             print("hello from pass 1!!!", file=sys.stderr)
 
         class CustomPass2:
-            def __call__(self, m):
-                apply_patterns_and_fold_greedily(m, frozen)
+            def __call__(self, op, pass_):
+                apply_patterns_and_fold_greedily(op, frozen)
 
         custom_pass_2 = CustomPass2()
 
@@ -89,9 +89,9 @@ def __call__(self, m):
 
         # test signal_pass_failure
         class CustomPassThatFails:
-            def __call__(self, m):
+            def __call__(self, op, pass_):
                 print("hello from pass that fails")
-                self.signal_pass_failure()
+                pass_.signal_failure()
 
         custom_pass_that_fails = CustomPassThatFails()
 
@@ -103,9 +103,3 @@ def __call__(self, m):
             pm.run(module)
         except Exception as e:
             print(f"caught exception: {e}")
-
-        # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
-        try:
-            custom_pass_that_fails.signal_pass_failure()
-        except Exception as e:
-            print(f"caught exception: {e}")



More information about the Mlir-commits mailing list