[Mlir-commits] [mlir] bc55364 - [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (#89225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 18 07:09:34 PDT 2024


Author: tomnatan30
Date: 2024-04-18T16:09:31+02:00
New Revision: bc5536469d7854a043dbfe4c018e5b5dfc069d4f

URL: https://github.com/llvm/llvm-project/commit/bc5536469d7854a043dbfe4c018e5b5dfc069d4f
DIFF: https://github.com/llvm/llvm-project/commit/bc5536469d7854a043dbfe4c018e5b5dfc069d4f.diff

LOG: [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (#89225)

If the Python callback raises an error, the C++ code throws a
py::error_already_set, which must be caught and handled in the C++
code.

This change is inspired by the similar solution in
PySymbolTable::walkSymbolTables.
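
The pattern described above can be sketched in isolation, without pybind11 or MLIR. This is only an illustrative sketch under stated assumptions: `WalkResult`, `CWalkCallback`, `cWalk`, and `walk` are hypothetical stand-ins for the C API walk and its wrapper, and `std::exception` stands in for py::error_already_set. The idea is that the trampoline catches the exception, records it in the user data, interrupts the walk, and the wrapper rethrows once the C-style walker has returned.

```cpp
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-ins for the C API types (not the real MLIR declarations).
enum WalkResult { WalkAdvance, WalkInterrupt };
using CWalkCallback = WalkResult (*)(int op, void *userData);

// Hypothetical C-style walker: visits "ops" 0..count-1 until interrupted
// and returns how many ops were visited.
int cWalk(int count, CWalkCallback cb, void *userData) {
  int visited = 0;
  for (int op = 0; op < count; ++op) {
    ++visited;
    if (cb(op, userData) == WalkInterrupt)
      break;
  }
  return visited;
}

// State shared between the wrapper and the trampoline.
struct UserData {
  std::function<WalkResult(int)> callback;
  bool gotException;
  std::string exceptionWhat;
};

// Wrapper: no exception propagates through cWalk; a recorded one is
// rethrown as std::runtime_error after the walk has finished.
int walk(int count, std::function<WalkResult(int)> callback) {
  UserData userData{std::move(callback), false, {}};
  CWalkCallback trampoline = [](int op, void *raw) -> WalkResult {
    auto *data = static_cast<UserData *>(raw);
    try {
      return data->callback(op);
    } catch (const std::exception &e) {
      // Record the failure and stop the walk instead of unwinding through cWalk.
      data->gotException = true;
      data->exceptionWhat = e.what();
      return WalkInterrupt;
    }
  };
  int visited = cWalk(count, trampoline, &userData);
  if (userData.gotException)
    throw std::runtime_error("Exception raised in callback: " +
                             userData.exceptionWhat);
  return visited;
}
```

In this sketch, a callback that throws on some op stops the walk at that op, and the caller of `walk` sees a std::runtime_error whose message starts with "Exception raised in callback: ".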

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d875f4eba2b139..01678a9719f90f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
     MlirWalkOrder walkOrder) {
   PyOperation &operation = getOperation();
   operation.checkValid();
+  struct UserData {
+    std::function<MlirWalkResult(MlirOperation)> callback;
+    bool gotException;
+    std::string exceptionWhat;
+    py::object exceptionType;
+  };
+  UserData userData{callback, false, {}, {}};
   MlirOperationWalkCallback walkCallback = [](MlirOperation op,
                                               void *userData) {
-    auto *fn =
-        static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
-    return (*fn)(op);
+    UserData *calleeUserData = static_cast<UserData *>(userData);
+    try {
+      return (calleeUserData->callback)(op);
+    } catch (py::error_already_set &e) {
+      calleeUserData->gotException = true;
+      calleeUserData->exceptionWhat = e.what();
+      calleeUserData->exceptionType = e.type();
+      return MlirWalkResult::MlirWalkResultInterrupt;
+    }
   };
-
-  mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
+  mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(message);
+  }
 }
 
 py::object PyOperationBase::getAsm(bool binary,

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9666e63bda1e0e..3a5d850b86e3a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1088,5 +1088,5 @@ def callback(op):
 
     try:
         module.operation.walk(callback)
-    except ValueError:
+    except RuntimeError:
         print("Exception raised")


        

