[Mlir-commits] [mlir] [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (PR #89225)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 18 05:46:41 PDT 2024
https://github.com/tomnatan30 updated https://github.com/llvm/llvm-project/pull/89225
>From 40b2e7b51513f0dad94da8ec1e4a2bcc088600e5 Mon Sep 17 00:00:00 2001
From: tomnatan <tomnatan at google.com>
Date: Thu, 18 Apr 2024 12:19:18 +0000
Subject: [PATCH] Fix PyOperationBase::walk not catching exception in python
callback (py::error_already_set)
---
mlir/lib/Bindings/Python/IRCore.cpp | 27 ++++++++++++++++++++++-----
mlir/test/python/ir/operation.py | 2 +-
2 files changed, 23 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d875f4eba2b139..0a12c53ac00abd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
     MlirWalkOrder walkOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
+  struct UserData {
+    std::function<MlirWalkResult(MlirOperation)> callback;
+    bool gotException;
+    std::string exceptionWhat;
+    py::object exceptionType;
+  };
+  UserData userData{callback, /*gotException=*/false, {}, {}};
   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
-    auto *fn =
-        static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
-    return (*fn)(op);
+    UserData *calleeUserData = static_cast<UserData *>(userData);
+    try {
+      return (calleeUserData->callback)(op);
+    } catch (py::error_already_set &e) {
+      calleeUserData->gotException = true;
+      calleeUserData->exceptionWhat = e.what();
+      calleeUserData->exceptionType = e.type();
+      return MlirWalkResult::MlirWalkResultInterrupt;
+    }
   };
-
-  mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(message);
+  }
 }
 
 py::object PyOperationBase::getAsm(bool binary,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9666e63bda1e0e..3a5d850b86e3a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1088,5 +1088,5 @@ def callback(op):
 
     try:
         module.operation.walk(callback)
-    except ValueError:
+    except RuntimeError:
         print("Exception raised")
More information about the Mlir-commits
mailing list