[Mlir-commits] [mlir] [mlir] Make C/Python ExecutionEngine constructible with an Operation. (PR #86329)

Stella Laurenzo llvmlistbot at llvm.org
Fri Mar 22 12:24:02 PDT 2024


https://github.com/stellaraccident created https://github.com/llvm/llvm-project/pull/86329

This continues the long-running effort to deprivilege mlir.ir.Module so that it carries no special semantic meaning. Because changing an existing C API signature risks silent or fatal failures for existing callers, I added a new C API entry point with a new name and marked the original as deprecated.

The `ExecutionEngine()` constructor was extended to accept either a `Module` or an `Operation`, so there should be no user-level API breakage. A test was added to verify this.

The Python ExecutionEngine tests were modernized to use `Operation.parse` with an explicit outer `builtin.module`.

>From 9b3a14e425150a42b08e1f046b4646e6e5939920 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident at gmail.com>
Date: Fri, 22 Mar 2024 11:41:40 -0700
Subject: [PATCH] [mlir] Make C/Python ExecutionEngine constructible with an
 Operation.

This continues the long-running effort to deprivilege mlir.ir.Module so that it carries no special semantic meaning. Because changing an existing C API signature risks silent or fatal failures for existing callers, I added a new C API entry point with a new name and marked the original as deprecated.

The `ExecutionEngine()` constructor was extended to accept either a `Module` or an `Operation`, so there should be no user-level API breakage. A test was added to verify this.

The Python ExecutionEngine tests were modernized to use `Operation.parse` with an explicit outer `builtin.module`.
---
 mlir/include/mlir-c/ExecutionEngine.h         |   9 +-
 .../Bindings/Python/ExecutionEngineModule.cpp |  25 ++++-
 .../CAPI/ExecutionEngine/ExecutionEngine.cpp  |  16 ++-
 .../mlir/_mlir_libs/_mlirExecutionEngine.pyi  |   4 +-
 mlir/test/python/execution_engine.py          | 105 +++++++++++++-----
 5 files changed, 119 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
index 99cddc5c2598d4..311451a029181b 100644
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -42,8 +42,15 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void);
 /// that will be loaded are specified via `numPaths` and `sharedLibPaths`
 /// respectively.
 /// TODO: figure out other options.
+MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreateFromOp(
+    MlirOperation op, int optLevel, int numPaths,
+    const MlirStringRef *sharedLibPaths, bool enableObjectDump);
+
+/// Deprecated variant that takes an MlirModule instead of an MlirOperation.
+/// Preserved as of 2024-Mar for short-term compatibility; prefer
+/// mlirExecutionEngineCreateFromOp. This entry point should be dropped soon.
 MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
-    MlirModule op, int optLevel, int numPaths,
+    MlirModule module, int optLevel, int numPaths,
     const MlirStringRef *sharedLibPaths, bool enableObjectDump);
 
 /// Destroy an ExecutionEngine instance.
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index b3df30583fc963..9ed5ee80f97f8b 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -71,15 +71,34 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
   py::class_<PyExecutionEngine>(m, "ExecutionEngine", py::module_local())
-      .def(py::init<>([](MlirModule module, int optLevel,
+      .def(py::init<>([](py::object operation_or_module, int optLevel,
                          const std::vector<std::string> &sharedLibPaths,
                          bool enableObjectDump) {
+             // Manually type cast from either a Module or Operation. The
+             // automatic type casters do not handle such cascades well,
+             // so be explicit.
+             py::object capsule = mlirApiObjectToCapsule(operation_or_module);
+             MlirOperation module_op =
+                 mlirPythonCapsuleToOperation(capsule.ptr());
+             if (mlirOperationIsNull(module_op)) {
+               // If null, then a PyErr_Set has set an exception, which we must
+               // clear.
+               PyErr_Clear();
+               MlirModule mod = mlirPythonCapsuleToModule(capsule.ptr());
+               if (mlirModuleIsNull(mod)) {
+                 throw py::type_error(
+                     "ExecutionEngine expects a Module or Operation");
+               }
+               module_op = mlirModuleGetOperation(mod);
+             }
+
              llvm::SmallVector<MlirStringRef, 4> libPaths;
              for (const std::string &path : sharedLibPaths)
                libPaths.push_back({path.c_str(), path.length()});
              MlirExecutionEngine executionEngine =
-                 mlirExecutionEngineCreate(module, optLevel, libPaths.size(),
-                                           libPaths.data(), enableObjectDump);
+                 mlirExecutionEngineCreateFromOp(
+                     module_op, optLevel, libPaths.size(), libPaths.data(),
+                     enableObjectDump);
              if (mlirExecutionEngineIsNull(executionEngine))
                throw std::runtime_error(
                    "Failure while creating the ExecutionEngine.");
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 507be9171d328d..8bd7e8b354f341 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -20,9 +20,18 @@
 using namespace mlir;
 
 extern "C" MlirExecutionEngine
-mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
+mlirExecutionEngineCreate(MlirModule module, int optLevel, int numPaths,
                           const MlirStringRef *sharedLibPaths,
                           bool enableObjectDump) {
+  return mlirExecutionEngineCreateFromOp(mlirModuleGetOperation(module),
+                                         optLevel, numPaths, sharedLibPaths,
+                                         enableObjectDump);
+}
+
+extern "C" MlirExecutionEngine
+mlirExecutionEngineCreateFromOp(MlirOperation op, int optLevel, int numPaths,
+                                const MlirStringRef *sharedLibPaths,
+                                bool enableObjectDump) {
   static bool initOnce = [] {
     llvm::InitializeNativeTarget();
     llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm
@@ -104,9 +113,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
                                                   void *sym) {
   unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
     llvm::orc::SymbolMap symbolMap;
-    symbolMap[interner(unwrap(name))] =
-        { llvm::orc::ExecutorAddr::fromPtr(sym),
-          llvm::JITSymbolFlags::Exported };
+    symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
+                                         llvm::JITSymbolFlags::Exported};
     return symbolMap;
   });
 }
diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
index 893dab8a431fd1..c32b5db13241c0 100644
--- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
@@ -4,7 +4,7 @@
 #   * Relative imports for cross-module references.
 #   * Add __all__
 
-from typing import List, Sequence
+from typing import List, Sequence,Union
 
 from ._mlir import ir as _ir
 
@@ -13,7 +13,7 @@ __all__ = [
 ]
 
 class ExecutionEngine:
-    def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ...
+    def __init__(self, module: Union[_ir.Operation, _ir.Module], opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ...
     def _CAPICreate(self) -> object: ...
     def _testing_release(self) -> None: ...
     def dump_to_object_file(self, file_name: str) -> None: ...
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e8b47007a8907d..647e6667b69a34 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -21,17 +21,46 @@ def run(f):
     assert Context._get_live_count() == 0
 
 
-# Verify capsule interop.
-# CHECK-LABEL: TEST: testCapsule
-def testCapsule():
+# Verify capsule interop for passing an Operation.
+# CHECK-LABEL: TEST: testAcceptsOperation
+def testAcceptsOperation():
+    with Context():
+        module = Operation.parse(
+            r"""
+builtin.module {
+llvm.func @none() {
+llvm.return
+}
+}
+    """
+        )
+        execution_engine = ExecutionEngine(module)
+        execution_engine_capsule = execution_engine._CAPIPtr
+        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
+        log(repr(execution_engine_capsule))
+        execution_engine._testing_release()
+        execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
+        # CHECK: _mlirExecutionEngine.ExecutionEngine
+        log(repr(execution_engine1))
+
+
+run(testAcceptsOperation)
+
+
+# Verify capsule interop for passing a Module.
+# CHECK-LABEL: TEST: testAcceptsModule
+def testAcceptsModule():
     with Context():
         module = Module.parse(
             r"""
+builtin.module {
 llvm.func @none() {
-  llvm.return
+llvm.return
+}
 }
     """
         )
+        print("MODULE:", type(module))
         execution_engine = ExecutionEngine(module)
         execution_engine_capsule = execution_engine._CAPIPtr
         # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
@@ -42,7 +71,7 @@ def testCapsule():
         log(repr(execution_engine1))
 
 
-run(testCapsule)
+run(testAcceptsModule)
 
 
 # Test invalid ExecutionEngine creation
@@ -50,9 +79,11 @@ def testCapsule():
 def testInvalidModule():
     with Context():
         # Builtin function
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+    builtin.module {
     func.func @foo() { return }
+    }
     """
         )
         # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
@@ -69,7 +100,7 @@ def lowerToLLVM(module):
     pm = PassManager.parse(
         "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)"
     )
-    pm.run(module.operation)
+    pm.run(module)
     return module
 
 
@@ -77,10 +108,12 @@ def lowerToLLVM(module):
 # CHECK-LABEL: TEST: testInvokeVoid
 def testInvokeVoid():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+builtin.module {
 func.func @void() attributes { llvm.emit_c_interface } {
   return
+}
 }
     """
         )
@@ -96,11 +129,13 @@ def testInvokeVoid():
 # CHECK-LABEL: TEST: testInvokeFloatAdd
 def testInvokeFloatAdd():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+builtin.module {
 func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
   %add = arith.addf %arg0, %arg1 : f32
   return %add : f32
+}
 }
     """
         )
@@ -129,13 +164,15 @@ def callback(a, b):
 
     with Context():
         # The module just forwards to a runtime function known as "some_callback_into_python".
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+builtin.module {
 func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
   %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
   return %resf : f32
 }
 func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
+}
     """
         )
         execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -168,13 +205,15 @@ def callback(a):
 
     with Context():
         # The module just forwards to a runtime function known as "some_callback_into_python".
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+builtin.module {
 func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
   call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
   return
 }
 func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
+}
 """
         )
         execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -221,13 +260,15 @@ def callback(a):
 
     with Context():
         # The module just forwards to a runtime function known as "some_callback_into_python".
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+builtin.module {
 func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
   call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
   return
 }
 func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
+}
 """
         )
         execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -262,8 +303,9 @@ def callback(a):
 
     with Context():
         # The module takes a subview of the argument memref and calls the callback with it
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+builtin.module {
 func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
   %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
   %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
@@ -272,6 +314,7 @@ def callback(a):
   return
 }
 func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
+}
 """
         )
         execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -301,8 +344,9 @@ def callback(a):
     with Context():
         # The module takes a subview of the argument memref, casts it to an unranked memref and 
         # calls the callback with it.
-        module = Module.parse(
+        module = Operation.parse(
             r"""
+builtin.module {
 func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
     %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
     %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
@@ -311,6 +355,7 @@ def callback(a):
     return
 }
 func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
+}
 """
         )
         execution_engine = ExecutionEngine(lowerToLLVM(module))
@@ -330,9 +375,9 @@ def callback(a):
 # CHECK-LABEL: TEST: testMemrefAdd
 def testMemrefAdd():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             """
-    module  {
+    builtin.module  {
       func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
         %0 = arith.constant 0 : index
         %1 = memref.load %arg0[%0] : memref<1xf32>
@@ -372,9 +417,9 @@ def testMemrefAdd():
 # CHECK-LABEL: TEST: testF16MemrefAdd
 def testF16MemrefAdd():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             """
-    module  {
+    builtin.module  {
       func.func @main(%arg0: memref<1xf16>,
                       %arg1: memref<1xf16>,
                       %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
@@ -422,9 +467,9 @@ def testF16MemrefAdd():
 # CHECK-LABEL: TEST: testComplexMemrefAdd
 def testComplexMemrefAdd():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             """
-    module  {
+    builtin.module  {
       func.func @main(%arg0: memref<1xcomplex<f64>>,
                       %arg1: memref<1xcomplex<f64>>,
                       %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
@@ -472,9 +517,9 @@ def testComplexMemrefAdd():
 # CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
 def testComplexUnrankedMemrefAdd():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             """
-    module  {
+    builtin.module  {
       func.func @main(%arg0: memref<*xcomplex<f32>>,
                       %arg1: memref<*xcomplex<f32>>,
                       %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
@@ -525,9 +570,9 @@ def testComplexUnrankedMemrefAdd():
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             """
-      module  {
+      builtin.module  {
         func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
           %c0 = arith.constant 0 : index
           %c2 = arith.constant 2 : index
@@ -589,9 +634,9 @@ def testDynamicMemrefAdd2D():
 # CHECK-LABEL: TEST: testSharedLibLoad
 def testSharedLibLoad():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             """
-      module  {
+      builtin.module  {
       func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
         %c0 = arith.constant 0 : index
         %cst42 = arith.constant 42.0 : f32
@@ -640,9 +685,9 @@ def testSharedLibLoad():
 # CHECK-LABEL: TEST: testNanoTime
 def testNanoTime():
     with Context():
-        module = Module.parse(
+        module = Operation.parse(
             """
-      module {
+      builtin.module {
       func.func @main() attributes { llvm.emit_c_interface } {
         %now = call @nanoTime() : () -> i64
         %memref = memref.alloca() : memref<1xi64>
@@ -686,9 +731,9 @@ def testDumpToObjectFile():
 
     try:
         with Context():
-            module = Module.parse(
+            module = Operation.parse(
                 """
-        module {
+        builtin.module {
         func.func @main() attributes { llvm.emit_c_interface } {
           return
         }



More information about the Mlir-commits mailing list