[Mlir-commits] [mlir] bc55364 - [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (#89225)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 18 07:09:34 PDT 2024
Author: tomnatan30
Date: 2024-04-18T16:09:31+02:00
New Revision: bc5536469d7854a043dbfe4c018e5b5dfc069d4f
URL: https://github.com/llvm/llvm-project/commit/bc5536469d7854a043dbfe4c018e5b5dfc069d4f
DIFF: https://github.com/llvm/llvm-project/commit/bc5536469d7854a043dbfe4c018e5b5dfc069d4f.diff
LOG: [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (#89225)
If the Python callback raises an exception, the C++ code throws a
py::error_already_set, which needs to be caught and handled in the C++
code.
This change is inspired by the similar solution in
PySymbolTable::walkSymbolTables.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d875f4eba2b139..01678a9719f90f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
MlirWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
+ struct UserData {
+ std::function<MlirWalkResult(MlirOperation)> callback;
+ bool gotException;
+ std::string exceptionWhat;
+ py::object exceptionType;
+ };
+ UserData userData{callback, false, {}, {}};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
- auto *fn =
- static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
- return (*fn)(op);
+ UserData *calleeUserData = static_cast<UserData *>(userData);
+ try {
+ return (calleeUserData->callback)(op);
+ } catch (py::error_already_set &e) {
+ calleeUserData->gotException = true;
+ calleeUserData->exceptionWhat = e.what();
+ calleeUserData->exceptionType = e.type();
+ return MlirWalkResult::MlirWalkResultInterrupt;
+ }
};
-
- mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+ mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+ if (userData.gotException) {
+ std::string message("Exception raised in callback: ");
+ message.append(userData.exceptionWhat);
+ throw std::runtime_error(message);
+ }
}
py::object PyOperationBase::getAsm(bool binary,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9666e63bda1e0e..3a5d850b86e3a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1088,5 +1088,5 @@ def callback(op):
try:
module.operation.walk(callback)
- except ValueError:
+ except RuntimeError:
print("Exception raised")
More information about the Mlir-commits
mailing list