[Mlir-commits] [mlir] [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (PR #89225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 18 05:46:41 PDT 2024


https://github.com/tomnatan30 updated https://github.com/llvm/llvm-project/pull/89225

>From 40b2e7b51513f0dad94da8ec1e4a2bcc088600e5 Mon Sep 17 00:00:00 2001
From: tomnatan <tomnatan at google.com>
Date: Thu, 18 Apr 2024 12:19:18 +0000
Subject: [PATCH] Fix PyOperationBase::walk not catching exception in python
 callback (py::error_already_set)

---
 mlir/lib/Bindings/Python/IRCore.cpp | 27 ++++++++++++++++++++++-----
 mlir/test/python/ir/operation.py    |  2 +-
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d875f4eba2b139..0a12c53ac00abd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
     MlirWalkOrder walkOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
+  struct UserData {
+    std::function<MlirWalkResult(MlirOperation)> callback;
+    bool gotException;
+    std::string exceptionWhat;
+  };
+  // Positional aggregate init: designated initializers require C++20.
+  UserData userData{callback, /*gotException=*/false, /*exceptionWhat=*/{}};
   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
-    auto *fn =
-        static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
-    return (*fn)(op);
+    UserData *calleeUserData = static_cast<UserData *>(userData);
+    // No exception may unwind through mlirOperationWalk; record it instead.
+    try {
+      return (calleeUserData->callback)(op);
+    } catch (const std::exception &e) {
+      calleeUserData->gotException = true;
+      calleeUserData->exceptionWhat = e.what();
+      return MlirWalkResult::MlirWalkResultInterrupt;
+    }
   };
-
-  mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(message);
+  }
 }
 
 py::object PyOperationBase::getAsm(bool binary,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9666e63bda1e0e..3a5d850b86e3a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1088,5 +1088,5 @@ def callback(op):
 
     try:
         module.operation.walk(callback)
-    except ValueError:
+    except RuntimeError:  # walk() re-raises callback errors as RuntimeError
         print("Exception raised")



More information about the Mlir-commits mailing list