[Mlir-commits] [mlir] [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (PR #89225)
llvmlistbot at llvm.org
Thu Apr 18 05:22:50 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: None (tomnatan30)
<details>
<summary>Changes</summary>
If the Python callback raises an exception, the C++ code throws a `py::error_already_set`, which must be caught and handled on the C++ side.
This change is inspired by a similar solution in `PySymbolTable::walkSymbolTables`.
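Letting a C++ exception unwind through the C API frames of `mlirOperationWalk` is generally unsafe. The fix instead traps it in the user-data struct, interrupts the walk, and raises again once the C call has returned. Below is a minimal, self-contained sketch of that pattern. It assumes a hypothetical C-style walker: `cWalk`, `WalkResult`, `WalkCallback`, and `safeWalk` are illustrative names, not MLIR APIs. Unlike the patch, it catches any exception with `catch (...)` and stores it in a `std::exception_ptr`. That is a swapped-in technique which keeps the original exception type instead of only its message.

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for a C walk API such as mlirOperationWalk; the
// names WalkResult, WalkCallback and cWalk are illustrative only.
enum WalkResult { WalkAdvance, WalkInterrupt };
using WalkCallback = WalkResult (*)(int item, void *userData);

// A C-style walker: a C++ exception must never unwind through its frames.
void cWalk(const std::vector<int> &items, WalkCallback cb, void *userData) {
  assert(cb != nullptr);
  for (int item : items)
    if (cb(item, userData) == WalkInterrupt)
      return;
}

// Runs `fn` on every item through cWalk. Any exception thrown by `fn` is
// caught inside the trampoline, stored, and rethrown once cWalk returns.
template <typename Fn>
void safeWalk(const std::vector<int> &items, Fn fn) {
  struct UserData {
    Fn *fn;
    std::exception_ptr error; // stays empty unless the callback throws
  };
  UserData data{&fn, nullptr};
  // A captureless lambda converts to a plain function pointer.
  WalkCallback trampoline = [](int item, void *raw) -> WalkResult {
    auto *ud = static_cast<UserData *>(raw);
    try {
      return (*ud->fn)(item);
    } catch (...) {
      ud->error = std::current_exception(); // preserve the original exception
      return WalkInterrupt;                 // stop the walk without unwinding
    }
  };
  cWalk(items, trampoline, &data);
  if (data.error)
    std::rethrow_exception(data.error);
}
```

Suppose the callback passed to `safeWalk` throws `std::invalid_argument` on the third of four items. The walk then visits two items and stops. The same `std::invalid_argument` reaches the caller, which can catch it by its own type. The patch behaves differently: it records only `e.what()` and throws a `std::runtime_error`, which pybind11 surfaces in Python as `RuntimeError`. That is why the test now expects `RuntimeError` instead of `ValueError`.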
---
Full diff: https://github.com/llvm/llvm-project/pull/89225.diff
2 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+22-5)
- (modified) mlir/test/python/ir/operation.py (+1-1)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d875f4eba2b139..0a12c53ac00abd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
MlirWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
+ struct UserData {
+ std::function<MlirWalkResult(MlirOperation)> callback;
+ bool gotException;
+ std::string exceptionWhat;
+ py::object exceptionType;
+ };
+ UserData userData{.callback = callback};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
- auto *fn =
- static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
- return (*fn)(op);
+ UserData *calleeUserData = static_cast<UserData *>(userData);
+ try {
+ return (calleeUserData->callback)(op);
+ } catch (py::error_already_set &e) {
+ calleeUserData->gotException = true;
+ calleeUserData->exceptionWhat = e.what();
+ calleeUserData->exceptionType = e.type();
+ return MlirWalkResult::MlirWalkResultInterrupt;
+ }
};
-
- mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+ mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+ if (userData.gotException) {
+ std::string message("Exception raised in callback: ");
+ message.append(userData.exceptionWhat);
+ throw std::runtime_error(message);
+ }
}
py::object PyOperationBase::getAsm(bool binary,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9666e63bda1e0e..3a5d850b86e3a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1088,5 +1088,5 @@ def callback(op):
try:
module.operation.walk(callback)
- except ValueError:
+ except RuntimeError:
print("Exception raised")
``````````
</details>
https://github.com/llvm/llvm-project/pull/89225