[Mlir-commits] [mlir] [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (PR #89225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 18 05:22:50 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: None (tomnatan30)

<details>
<summary>Changes</summary>

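If the Python callback raises an exception, pybind11 surfaces it in C++ as a `py::error_already_set`. That exception must be caught and handled in the C++ callback wrapper instead of being allowed to unwind through the C API function `mlirOperationWalk`. With this change, the wrapper records the exception, interrupts the walk, and throws a `std::runtime_error` carrying the original message once the walk has returned.

This change mirrors the existing approach in `PySymbolTable::walkSymbolTables`.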
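
The catch-record-interrupt-rethrow pattern described above can be sketched in plain C++ without pybind11. This is a minimal sketch under stated assumptions: `WalkResult`, `WalkCallback`, `walkItems`, `walkWithCallback`, and `runDemo` are hypothetical stand-ins for `MlirWalkResult`, `MlirOperationWalkCallback`, `mlirOperationWalk`, and `PyOperationBase::walk`. It also differs from the PR in one deliberate way: it stores the exception as a `std::exception_ptr` instead of copying its `what()` string.

```cpp
#include <cassert>
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for MlirWalkResult and MlirOperationWalkCallback.
enum WalkResult { WalkAdvance, WalkInterrupt };
using WalkCallback = WalkResult (*)(int item, void *userData);

// Hypothetical C-style walker standing in for mlirOperationWalk. It models a
// C API: it knows nothing about C++ exceptions, so none may unwind through it.
void walkItems(const int *items, int count, WalkCallback callback,
               void *userData) {
  for (int i = 0; i < count; ++i)
    if (callback(items[i], userData) == WalkInterrupt)
      return;
}

// Wrapper in the spirit of PyOperationBase::walk: the user callback may
// throw, so the trampoline catches, records, and interrupts the walk.
void walkWithCallback(const std::vector<int> &items,
                      const std::function<WalkResult(int)> &callback) {
  struct UserData {
    const std::function<WalkResult(int)> &callback;
    std::exception_ptr error; // set if the callback threw
  };
  UserData userData{callback, nullptr};

  // A captureless lambda converts to the C function-pointer type.
  WalkCallback trampoline = [](int item, void *raw) -> WalkResult {
    auto *data = static_cast<UserData *>(raw);
    try {
      return data->callback(item);
    } catch (...) {
      data->error = std::current_exception(); // keep the original exception
      return WalkInterrupt;                   // stop the walk cleanly
    }
  };

  walkItems(items.data(), static_cast<int>(items.size()), trampoline,
            &userData);

  // Back in C++-only frames: rethrow the original exception unchanged.
  if (userData.error)
    std::rethrow_exception(userData.error);
}

// Usage: the callback throws on the second item, so the walk stops there and
// the original std::invalid_argument reaches the caller.
std::string runDemo() {
  std::vector<int> items{1, 2, 3, 4};
  std::vector<int> visited;
  try {
    walkWithCallback(items, [&](int item) {
      visited.push_back(item);
      if (item == 2)
        throw std::invalid_argument("bad item");
      return WalkAdvance;
    });
  } catch (const std::invalid_argument &e) {
    return std::string(e.what()) + " after " + std::to_string(visited.size());
  }
  return "no exception";
}
```

Rethrowing the stored `std::exception_ptr` preserves the original exception type. The PR instead throws a new `std::runtime_error` with the copied message, which appears to be why the test now expects `RuntimeError` rather than `ValueError`.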

---
Full diff: https://github.com/llvm/llvm-project/pull/89225.diff


2 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+22-5) 
- (modified) mlir/test/python/ir/operation.py (+1-1) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d875f4eba2b139..0a12c53ac00abd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
     MlirWalkOrder walkOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
+  struct UserData {
+    std::function<MlirWalkResult(MlirOperation)> callback;
+    bool gotException;
+    std::string exceptionWhat;
+    py::object exceptionType;
+  };
+  UserData userData{.callback = callback};
   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
-    auto *fn =
-        static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
-    return (*fn)(op);
+    UserData *calleeUserData = static_cast<UserData *>(userData);
+    try {
+      return (calleeUserData->callback)(op);
+    } catch (py::error_already_set &e) {
+      calleeUserData->gotException = true;
+      calleeUserData->exceptionWhat = e.what();
+      calleeUserData->exceptionType = e.type();
+      return MlirWalkResult::MlirWalkResultInterrupt;
+    }
   };
-
-  mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(message);
+  }
 }
 
 py::object PyOperationBase::getAsm(bool binary,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9666e63bda1e0e..3a5d850b86e3a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1088,5 +1088,5 @@ def callback(op):
 
     try:
         module.operation.walk(callback)
-    except ValueError:
+    except RuntimeError:
         print("Exception raised")

``````````

</details>


https://github.com/llvm/llvm-project/pull/89225

