[Mlir-commits] [mlir] 7a4d630 - Add a "register_runtime" method to the mlir.execution_engine and show calling back from MLIR into Python

Mehdi Amini llvmlistbot at llvm.org
Mon Apr 12 02:55:45 PDT 2021


Author: Mehdi Amini
Date: 2021-03-30T17:04:38Z
New Revision: 7a4d630764829ad40738ae4e5944a411529728ef

URL: https://github.com/llvm/llvm-project/commit/7a4d630764829ad40738ae4e5944a411529728ef
DIFF: https://github.com/llvm/llvm-project/commit/7a4d630764829ad40738ae4e5944a411529728ef.diff

LOG: Add a "register_runtime" method to the mlir.execution_engine and show calling back from MLIR into Python

This exposes the ability to register Python functions with the JIT and makes
them accessible to the MLIR jitted code. The provided test case illustrates
the mechanism.

Differential Revision: https://reviews.llvm.org/D99562

Added: 
    

Modified: 
    mlir/include/mlir-c/ExecutionEngine.h
    mlir/lib/Bindings/Python/ExecutionEngine.cpp
    mlir/lib/Bindings/Python/mlir/execution_engine.py
    mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
    mlir/test/Bindings/Python/execution_engine.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
index c2563577160a4..5210f108ee6b0 100644
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -61,6 +61,12 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked(
 MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
                                                    MlirStringRef name);
 
+/// Register a symbol with the jit: `sym` becomes accessible to the jitted
+/// code under `name` (the name is mangled by the jit before registration).
+MLIR_CAPI_EXPORTED void
+mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, MlirStringRef name,
+                                  void *sym);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
index 5ca9b1f681286..0e8ae8b38b3cb 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
@@ -81,7 +81,17 @@ void mlir::python::populateExecutionEngineSubmodule(py::module &m) {
             auto *res = mlirExecutionEngineLookup(
                 executionEngine.get(),
                 mlirStringRefCreate(func.c_str(), func.size()));
-            return (int64_t)res;
+            return reinterpret_cast<uintptr_t>(res);
+          },
+          "Lookup function `func` in the ExecutionEngine.")
+      .def(
+          "raw_register_runtime",
+          [](PyExecutionEngine &executionEngine, const std::string &name,
+             uintptr_t sym) {
+            mlirExecutionEngineRegisterSymbol(
+                executionEngine.get(),
+                mlirStringRefCreate(name.c_str(), name.size()),
+                reinterpret_cast<void *>(sym));
           },
-          "Lookup function `func` in the ExecutionEngine.");
+          "Register `sym` as the runtime symbol `name` for jitted code.");
 }

diff  --git a/mlir/lib/Bindings/Python/mlir/execution_engine.py b/mlir/lib/Bindings/Python/mlir/execution_engine.py
index 89bd4aad56582..39d9501d9c8bc 100644
--- a/mlir/lib/Bindings/Python/mlir/execution_engine.py
+++ b/mlir/lib/Bindings/Python/mlir/execution_engine.py
@@ -29,3 +29,11 @@ def invoke(self, name, *ctypes_args):
     for argNum in range(len(ctypes_args)):
       packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
     func(packed_args)
+
+  def register_runtime(self, name, ctypes_callback):
+    """Register a runtime function that jitted code can call as `name`.
+    The symbol is registered as `_mlir_ciface_<name>`. `ctypes_callback`
+    must be a ctypes function (e.g. from `ctypes.CFUNCTYPE`) that outlives the
+    execution engine, since only its raw address is passed on here."""
+    callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value
+    self.raw_register_runtime("_mlir_ciface_" + name, callback)

diff  --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 68137c0674c53..345eac2193d8f 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -11,6 +11,7 @@
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "llvm/ExecutionEngine/Orc/Mangling.h"
 #include "llvm/Support/TargetSelect.h"
 
 using namespace mlir;
@@ -54,3 +55,14 @@ extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
     return nullptr;
   return reinterpret_cast<void *>(*expectedFPtr);
 }
+
+extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
+                                                  MlirStringRef name,
+                                                  void *sym) {
+  unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
+    llvm::orc::SymbolMap symbolMap; // Mangled name -> symbol address.
+    symbolMap[interner(unwrap(name))] =
+        llvm::JITEvaluatedSymbol::fromPointer(sym);
+    return symbolMap;
+  });
+}

diff  --git a/mlir/test/Bindings/Python/execution_engine.py b/mlir/test/Bindings/Python/execution_engine.py
index 0706ea4e2e443..9ef4dceea2a81 100644
--- a/mlir/test/Bindings/Python/execution_engine.py
+++ b/mlir/test/Bindings/Python/execution_engine.py
@@ -97,3 +97,37 @@ def testInvokeFloatAdd():
     log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
 
 run(testInvokeFloatAdd)
+
+
+# Test calling back from jitted code into Python.
+# CHECK-LABEL: TEST: testBasicCallback
+def testBasicCallback():
+  # Define a callback that takes a float and an integer and returns half their sum.
+  @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
+  def callback(a, b):
+    return a/2 + b/2
+
+  with Context():
+    # The module just forwards to a runtime function known as "some_callback_into_python".
+    module = Module.parse(r"""
+func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
+  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
+  return %resf : f32
+}
+func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
+    """)
+    execution_engine = ExecutionEngine(lowerToLLVM(module))
+    execution_engine.register_runtime("some_callback_into_python", callback)
+
+    # Prepare arguments: one input float, one input int, and one result.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    c_int_p = ctypes.c_int * 1
+    arg0 = c_float_p(42.)
+    arg1 = c_int_p(2)
+    res = c_float_p(-1.)
+    execution_engine.invoke("add", arg0, arg1, res)
+    # CHECK: 42.0 + 2 = 44.0
+    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))  # Undo the callback's halving.
+
+run(testBasicCallback)


        


More information about the Mlir-commits mailing list