[llvm] 39f64c4 - [ORC] Add wrapper-function support methods to ExecutorProcessControl.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 1 01:24:13 PDT 2021


Author: Lang Hames
Date: 2021-07-01T18:21:49+10:00
New Revision: 39f64c4c83754b4e436d7fffa31bd70f11d7a657

URL: https://github.com/llvm/llvm-project/commit/39f64c4c83754b4e436d7fffa31bd70f11d7a657
DIFF: https://github.com/llvm/llvm-project/commit/39f64c4c83754b4e436d7fffa31bd70f11d7a657.diff

LOG: [ORC] Add wrapper-function support methods to ExecutorProcessControl.

Adds support for both synchronous and asynchronous calls to wrapper functions
using SPS (Simple Packed Serialization). Also adds support for wrapping
functions on the JIT side in SPS-based wrappers that can be called from the
executor.

These new methods simplify calls between the JIT and Executor, and will be used
in upcoming ORC runtime patches to enable communication between ORC and the
runtime.

Added: 
    llvm/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp

Modified: 
    llvm/include/llvm/ExecutionEngine/Orc/Core.h
    llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
    llvm/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h
    llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
    llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
    llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
    llvm/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
index ae826912d629f..42bcffd36b25a 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
@@ -216,6 +216,18 @@ class SymbolLookupSet {
       add(Name, Flags);
   }
 
+  /// Construct a SymbolLookupSet from the keys of M, each added with Flags.
+  template <typename KeyT>
+  static SymbolLookupSet
+  fromMapKeys(const DenseMap<SymbolStringPtr, KeyT> &M,
+              SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) {
+    SymbolLookupSet Result;
+    Result.Symbols.reserve(M.size());
+    for (const auto &KV : M)
+      Result.add(KV.first, Flags);
+    return Result;
+  }
+
   /// Add an element to the set. The client is responsible for checking that
   /// duplicates are not added.
   SymbolLookupSet &

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h b/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
index 7969a8398c952..566637e104456 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
@@ -24,6 +24,7 @@
 #include "llvm/Support/MSVCErrorWorkarounds.h"
 
 #include <future>
+#include <mutex>
 #include <vector>
 
 namespace llvm {
@@ -32,6 +33,19 @@ namespace orc {
 /// ExecutorProcessControl supports interaction with a JIT target process.
 class ExecutorProcessControl {
 public:
+  /// Callback that receives the (serialized) result of a wrapper function.
+  using SendResultFunction =
+      unique_function<void(shared::WrapperFunctionResult)>;
+
+  /// An async wrapper function, invoked with a result sender and arg bytes.
+  using AsyncWrapperFunction = unique_function<void(
+      SendResultFunction SendResult, const char *ArgData, size_t ArgSize)>;
+
+  /// A map associating tag names with asynchronous wrapper function
+  /// implementations in the JIT.
+  using WrapperFunctionAssociationMap =
+      DenseMap<SymbolStringPtr, AsyncWrapperFunction>;
+
   /// APIs for manipulating memory in the target process.
   class MemoryAccess {
   public:
@@ -138,14 +152,91 @@ class ExecutorProcessControl {
   virtual Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
                                       ArrayRef<std::string> Args) = 0;
 
-  /// Run a wrapper function in the executor.
+  /// Run a wrapper function in the executor (async version).
+  ///
+  /// The wrapper function should be callable as:
+  ///
+  /// \code{.cpp}
+  ///   CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
+  /// \endcode{.cpp}
+  ///
+  /// The given OnComplete function will be called to return the result.
+  virtual void runWrapperAsync(SendResultFunction OnComplete,
+                               JITTargetAddress WrapperFnAddr,
+                               ArrayRef<char> ArgBuffer) = 0;
+
+  /// Run a wrapper function in the executor, blocking until it completes.
+  /// The wrapper function should be callable as:
   ///
   /// \code{.cpp}
   ///   CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
   /// \endcode{.cpp}
+  shared::WrapperFunctionResult runWrapper(JITTargetAddress WrapperFnAddr,
+                                           ArrayRef<char> ArgBuffer) {
+    std::promise<shared::WrapperFunctionResult> RP;
+    auto RF = RP.get_future();
+    runWrapperAsync(
+        [&](shared::WrapperFunctionResult R) { RP.set_value(std::move(R)); },
+        WrapperFnAddr, ArgBuffer);
+    return RF.get();
+  }
+
+  /// Run a wrapper function asynchronously, using SPS to serialize the
+  /// arguments and deserialize the result before passing it to SendResult.
+  template <typename SPSSignature, typename SendResultT, typename... ArgTs>
+  void runSPSWrapperAsync(SendResultT &&SendResult,
+                          JITTargetAddress WrapperFnAddr,
+                          const ArgTs &...Args) {
+    shared::WrapperFunction<SPSSignature>::callAsync(
+        [this, WrapperFnAddr](SendResultFunction OnComplete,
+                              const char *ArgData, size_t ArgSize) {
+          runWrapperAsync(std::move(OnComplete), WrapperFnAddr,
+                          ArrayRef<char>(ArgData, ArgSize));
+        },
+        std::forward<SendResultT>(SendResult), Args...);
+  }
+
+  /// Run a wrapper function using SPS to serialize the arguments and
+  /// deserialize the results.
+  template <typename SPSSignature, typename RetT, typename... ArgTs>
+  Error runSPSWrapper(JITTargetAddress WrapperFnAddr, RetT &RetVal,
+                      const ArgTs &...Args) {
+    return shared::WrapperFunction<SPSSignature>::call(
+        [this, WrapperFnAddr](const char *ArgData, size_t ArgSize) {
+          return runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
+        },
+        RetVal, Args...);
+  }
+
+  /// Wrap a handler that takes concrete argument types (and a sender for a
+  /// concrete return type) to produce an AsyncWrapperFunction. Uses SPS to
+  /// unpack the arguments and pack the result.
   ///
-  virtual Expected<shared::WrapperFunctionResult>
-  runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) = 0;
+  /// This function is usually used when building association maps.
+  template <typename SPSSignature, typename HandlerT>
+  static AsyncWrapperFunction wrapAsyncWithSPS(HandlerT &&H) {
+    return [H = std::forward<HandlerT>(H)](SendResultFunction SendResult,
+                                           const char *ArgData,
+                                           size_t ArgSize) mutable {
+      shared::WrapperFunction<SPSSignature>::handleAsync(ArgData, ArgSize, H,
+                                                         std::move(SendResult));
+    };
+  }
+
+  /// For each symbol name, associate the AsyncWrapperFunction implementation
+  /// value with the address of that symbol.
+  ///
+  /// Symbols will be looked up using LookupKind::Static,
+  /// JITDylibLookupFlags::MatchAllSymbols (hidden tags will be found), and
+  /// LookupFlags::WeaklyReferencedSymbol (missing tags will not cause an
+  /// error, the implementations will simply be dropped).
+  Error associateJITSideWrapperFunctions(JITDylib &JD,
+                                         WrapperFunctionAssociationMap WFs);
+
+  /// Run a registered jit-side wrapper function.
+  void runJITSideWrapperFunction(SendResultFunction SendResult,
+                                 JITTargetAddress TagAddr,
+                                 ArrayRef<char> ArgBuffer);
 
   /// Disconnect from the target process.
   ///
@@ -161,6 +252,9 @@ class ExecutorProcessControl {
   unsigned PageSize = 0;
   MemoryAccess *MemAccess = nullptr;
   jitlink::JITLinkMemoryManager *MemMgr = nullptr;
+
+  std::mutex TagToFuncMapMutex;
+  DenseMap<JITTargetAddress, std::shared_ptr<AsyncWrapperFunction>> TagToFunc;
 };
 
 /// Call a wrapper function via ExecutorProcessControl::runWrapper.
@@ -168,8 +262,8 @@ class EPCCaller {
 public:
   EPCCaller(ExecutorProcessControl &EPC, JITTargetAddress WrapperFnAddr)
       : EPC(EPC), WrapperFnAddr(WrapperFnAddr) {}
-  Expected<shared::WrapperFunctionResult> operator()(const char *ArgData,
-                                                     size_t ArgSize) const {
+  shared::WrapperFunctionResult operator()(const char *ArgData,
+                                           size_t ArgSize) const {
     return EPC.runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
   }
 
@@ -202,8 +296,9 @@ class SelfExecutorProcessControl
   Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
                               ArrayRef<std::string> Args) override;
 
-  Expected<shared::WrapperFunctionResult>
-  runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) override;
+  void runWrapperAsync(SendResultFunction OnComplete,
+                       JITTargetAddress WrapperFnAddr,
+                       ArrayRef<char> ArgBuffer) override;
 
   Error disconnect() override;
 

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h b/llvm/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h
index 0b5ee262bb706..69e37f9af9e43 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h
@@ -354,9 +354,9 @@ class OrcRPCExecutorProcessControlBase : public ExecutorProcessControl {
     return Result;
   }
 
-  Expected<shared::WrapperFunctionResult>
-  runWrapper(JITTargetAddress WrapperFnAddr,
-             ArrayRef<char> ArgBuffer) override {
+  void runWrapperAsync(SendResultFunction OnComplete,
+                       JITTargetAddress WrapperFnAddr,
+                       ArrayRef<char> ArgBuffer) override {
     DEBUG_WITH_TYPE("orc", {
       dbgs() << "Running as wrapper function "
              << formatv("{0:x16}", WrapperFnAddr) << " with "
@@ -366,7 +366,11 @@ class OrcRPCExecutorProcessControlBase : public ExecutorProcessControl {
         WrapperFnAddr,
         ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(ArgBuffer.data()),
                           ArgBuffer.size()));
-    return Result;
+    // Report RPC failures out-of-band; otherwise forward the wrapper's result.
+    if (!Result)
+      return OnComplete(shared::WrapperFunctionResult::createOutOfBandError(
+          toString(Result.takeError())));
+    OnComplete(std::move(*Result));
   }
 
   Error closeConnection(OnCloseConnectionFunction OnCloseConnection) {

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
index 0fc8af770233c..ceaea1d2b20f2 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
@@ -172,17 +172,16 @@ class WrapperFunctionResult {
 namespace detail {
 
 template <typename SPSArgListT, typename... ArgTs>
-Expected<WrapperFunctionResult>
+WrapperFunctionResult
 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
   WrapperFunctionResult Result;
   char *DataPtr =
       WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
   SPSOutputBuffer OB(DataPtr, Result.size());
   if (!SPSArgListT::serialize(OB, Args...))
-    return make_error<StringError>(
-        "Error serializing arguments to blob in call",
-        inconvertibleErrorCode());
-  return std::move(Result);
+    return WrapperFunctionResult::createOutOfBandError(
+        "Error serializing arguments to blob in call");
+  return Result;
 }
 
 template <typename RetT> class WrapperFunctionHandlerCaller {
@@ -230,12 +229,8 @@ class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
         std::forward<HandlerT>(H), Args, ArgIndices{});
 
-    if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
-            std::move(HandlerResult)))
-      return std::move(*Result);
-    else
-      return WrapperFunctionResult::createOutOfBandError(
-          toString(Result.takeError()));
+    return ResultSerializer<decltype(HandlerResult)>::serialize(
+        std::move(HandlerResult));
   }
 
 private:
@@ -247,10 +242,10 @@ class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
   }
 };
 
-// Map function references to function types.
+// Map function pointers to function types.
 template <typename RetT, typename... ArgTs,
           template <typename> class ResultSerializer, typename... SPSTagTs>
-class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
+class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
                                    SPSTagTs...>
     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                           SPSTagTs...> {};
@@ -271,9 +266,87 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                           SPSTagTs...> {};
 
+template <typename WrapperFunctionImplT,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper
+    : public WrapperFunctionAsyncHandlerHelper<
+          decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
+          ResultSerializer, SPSTagTs...> {};
+// Specialization for handlers shaped RetT(SendResultT, ArgTs...).
+template <typename RetT, typename SendResultT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
+                                        ResultSerializer, SPSTagTs...> {
+public:
+  using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
+  using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
+
+  template <typename HandlerT, typename SendWrapperFunctionResultT>
+  static void applyAsync(HandlerT &&H,
+                         SendWrapperFunctionResultT &&SendWrapperFunctionResult,
+                         const char *ArgData, size_t ArgSize) {
+    ArgTuple Args;
+    if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
+      SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
+          "Could not deserialize arguments for wrapper function call"));
+      return;
+    }
+    // Wrap the raw sender so each handler result is SPS-serialized first.
+    auto SendResult =
+        [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
+          using ResultT = decltype(Result);
+          SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
+        };
+
+    callAsync(std::forward<HandlerT>(H), std::move(SendResult), Args,
+              ArgIndices{});
+  }
+
+private:
+  template <std::size_t... I>
+  static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
+                          std::index_sequence<I...>) {
+    SPSInputBuffer IB(ArgData, ArgSize);
+    return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
+  }
+
+  template <typename HandlerT, typename SerializeAndSendResultT,
+            typename ArgTupleT, std::size_t... I>
+  static void callAsync(HandlerT &&H,
+                        SerializeAndSendResultT &&SerializeAndSendResult,
+                        ArgTupleT &Args, std::index_sequence<I...>) {
+    return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
+                                     std::get<I>(Args)...);
+  }
+};
+
+// Map function pointers to function types.
+template <typename RetT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
+                                        SPSTagTs...>
+    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+                                               SPSTagTs...> {};
+
+// Map non-const member function types to function types.
+template <typename ClassT, typename RetT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
+                                        ResultSerializer, SPSTagTs...>
+    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+                                               SPSTagTs...> {};
+
+// Map const member function types to function types.
+template <typename ClassT, typename RetT, typename... ArgTs,
+          template <typename> class ResultSerializer, typename... SPSTagTs>
+class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
+                                        ResultSerializer, SPSTagTs...>
+    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
+                                               SPSTagTs...> {};
+
 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
 public:
-  static Expected<WrapperFunctionResult> serialize(RetT Result) {
+  static WrapperFunctionResult serialize(RetT Result) {
     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
         Result);
   }
@@ -281,7 +354,7 @@ template <typename SPSRetTagT, typename RetT> class ResultSerializer {
 
 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
 public:
-  static Expected<WrapperFunctionResult> serialize(Error Err) {
+  static WrapperFunctionResult serialize(Error Err) {
     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
         toSPSSerializable(std::move(Err)));
   }
@@ -290,7 +363,7 @@ template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
 template <typename SPSRetTagT, typename T>
 class ResultSerializer<SPSRetTagT, Expected<T>> {
 public:
-  static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
+  static WrapperFunctionResult serialize(Expected<T> E) {
     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
         toSPSSerializable(std::move(E)));
   }
@@ -298,6 +371,7 @@ class ResultSerializer<SPSRetTagT, Expected<T>> {
 
 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
 public:
+  static RetT makeValue() { return RetT(); }
   static void makeSafe(RetT &Result) {}
 
   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
@@ -312,6 +386,7 @@ template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
 
 template <> class ResultDeserializer<SPSError, Error> {
 public:
+  static Error makeValue() { return Error::success(); }
   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
 
   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
@@ -329,6 +404,7 @@ template <> class ResultDeserializer<SPSError, Error> {
 template <typename SPSTagT, typename T>
 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
 public:
+  static Expected<T> makeValue() { return T(); }
   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
 
   static Error deserialize(Expected<T> &E, const char *ArgData,
@@ -344,6 +420,10 @@ class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
   }
 };
 
+template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
+  // Did you forget to use Error / Expected in your handler?
+};
+
 } // end namespace detail
 
 template <typename SPSSignature> class WrapperFunction;
@@ -355,7 +435,7 @@ class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
 
 public:
-  /// Call a wrapper function. Callere should be callable as
+  /// Call a wrapper function. Caller should be callable as
   /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize);
   template <typename CallerFn, typename RetT, typename... ArgTs>
   static Error call(const CallerFn &Caller, RetT &Result,
@@ -369,18 +449,56 @@ class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
     auto ArgBuffer =
         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
             Args...);
-    if (!ArgBuffer)
-      return ArgBuffer.takeError();
-
-    Expected<WrapperFunctionResult> ResultBuffer =
-        Caller(ArgBuffer->data(), ArgBuffer->size());
-    if (!ResultBuffer)
-      return ResultBuffer.takeError();
-    if (auto ErrMsg = ResultBuffer->getOutOfBandError())
+    if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
+      return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
+
+    WrapperFunctionResult ResultBuffer =
+        Caller(ArgBuffer.data(), ArgBuffer.size());
+    if (auto ErrMsg = ResultBuffer.getOutOfBandError())
       return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
 
     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
-        Result, ResultBuffer->data(), ResultBuffer->size());
+        Result, ResultBuffer.data(), ResultBuffer.size());
+  }
+
+  /// Call an async wrapper function. Caller should be callable as
+  /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult,
+  ///         const char *ArgData, size_t ArgSize);
+  /// SendDeserializedResult is then called once as (Error, RetT).
+  template <typename AsyncCallerFn, typename SendDeserializedResultFn,
+            typename... ArgTs>
+  static void callAsync(AsyncCallerFn &&Caller,
+                        SendDeserializedResultFn &&SendDeserializedResult,
+                        const ArgTs &...Args) {
+    using RetT = typename std::tuple_element<
+        1, typename detail::WrapperFunctionHandlerHelper<
+               std::remove_reference_t<SendDeserializedResultFn>,
+               ResultSerializer, SPSRetTagT>::ArgTuple>::type;
+
+    auto ArgBuffer =
+        detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
+            Args...);
+    if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
+      SendDeserializedResult(
+          make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
+          detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
+      return;
+    }
+
+    auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
+                                    WrapperFunctionResult R) mutable {
+      RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
+      detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);
+      // Call SDR exactly once, with any out-of-band or deserialization error.
+      const char *ErrMsg = R.getOutOfBandError();
+      Error Err =
+          ErrMsg ? make_error<StringError>(ErrMsg, inconvertibleErrorCode())
+                 : detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
+                       RetVal, R.data(), R.size());
+      SDR(std::move(Err), std::move(RetVal));
+    };
+
+    Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
   }
 
   /// Handle a call to a wrapper function.
@@ -388,11 +506,21 @@ class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
                                       HandlerT &&Handler) {
     using WFHH =
-        detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
-                                             SPSTagTs...>;
+        detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
+                                             ResultSerializer, SPSTagTs...>;
     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
   }
 
+  /// Handle a call to an async wrapper function, replying via SendResult.
+  template <typename HandlerT, typename SendResultT>
+  static void handleAsync(const char *ArgData, size_t ArgSize,
+                          HandlerT &&Handler, SendResultT &&SendResult) {
+    using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
+        std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
+    WFAHH::applyAsync(std::forward<HandlerT>(Handler),
+                      std::forward<SendResultT>(SendResult), ArgData, ArgSize);
+  }
+
 private:
   template <typename T> static const T &makeSerializable(const T &Value) {
     return Value;
@@ -411,6 +539,7 @@ class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
 template <typename... SPSTagTs>
 class WrapperFunction<void(SPSTagTs...)>
     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
+  // Void-result wrappers reuse the SPSEmpty-returning implementation.
 public:
   template <typename CallerFn, typename... ArgTs>
   static Error call(const CallerFn &Caller, const ArgTs &...Args) {
@@ -419,6 +548,7 @@ class WrapperFunction<void(SPSTagTs...)>
   }
 
   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
+  using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
 };
 
 } // end namespace shared

diff  --git a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
index f8bd74eabc9b4..12fa42ccdef6b 100644
--- a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
@@ -10,11 +10,10 @@
 
 #include "llvm/ExecutionEngine/Orc/Core.h"
 #include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Host.h"
 #include "llvm/Support/Process.h"
 
-#include <mutex>
-
 namespace llvm {
 namespace orc {
 
@@ -22,6 +21,56 @@ ExecutorProcessControl::MemoryAccess::~MemoryAccess() {}
 
 ExecutorProcessControl::~ExecutorProcessControl() {}
 
+Error ExecutorProcessControl::associateJITSideWrapperFunctions(
+    JITDylib &JD, WrapperFunctionAssociationMap WFs) {
+  // Not atomic: tags registered before a duplicate-tag error stay registered.
+  // Look up tag addresses (missing tags are dropped, not reported).
+  auto &ES = JD.getExecutionSession();
+  auto TagAddrs =
+      ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}},
+                SymbolLookupSet::fromMapKeys(
+                    WFs, SymbolLookupFlags::WeaklyReferencedSymbol));
+  if (!TagAddrs)
+    return TagAddrs.takeError();
+
+  // Associate tag addresses with implementations, holding the map lock.
+  std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
+  for (auto &KV : *TagAddrs) {
+    auto TagAddr = KV.second.getAddress();
+    if (TagToFunc.count(TagAddr))
+      return make_error<StringError>("Tag " + formatv("{0:x16}", TagAddr) +
+                                         " (for " + *KV.first +
+                                         ") already registered",
+                                     inconvertibleErrorCode());
+    auto I = WFs.find(KV.first);
+    assert(I != WFs.end() && I->second &&
+           "AsyncWrapperFunction implementation missing");
+    TagToFunc[KV.second.getAddress()] =
+        std::make_shared<AsyncWrapperFunction>(std::move(I->second));
+  }
+  return Error::success();
+}
+
+void ExecutorProcessControl::runJITSideWrapperFunction(
+    SendResultFunction SendResult, JITTargetAddress TagAddr,
+    ArrayRef<char> ArgBuffer) {
+  // Copy the implementation out under the lock; invoke it after unlocking.
+  std::shared_ptr<AsyncWrapperFunction> F;
+  {
+    std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
+    auto I = TagToFunc.find(TagAddr);
+    if (I != TagToFunc.end())
+      F = I->second;
+  }
+
+  if (F)
+    (*F)(std::move(SendResult), ArgBuffer.data(), ArgBuffer.size());
+  else
+    SendResult(shared::WrapperFunctionResult::createOutOfBandError(
+        ("No function registered for tag " + formatv("{0:x16}", TagAddr))
+            .str()));
+}
+
 SelfExecutorProcessControl::SelfExecutorProcessControl(
     std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
     unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
@@ -102,13 +151,13 @@ SelfExecutorProcessControl::runAsMain(JITTargetAddress MainFnAddr,
   return orc::runAsMain(jitTargetAddressToFunction<MainTy>(MainFnAddr), Args);
 }
 
-Expected<shared::WrapperFunctionResult>
-SelfExecutorProcessControl::runWrapper(JITTargetAddress WrapperFnAddr,
-                                       ArrayRef<char> ArgBuffer) {
-  using WrapperFnTy = shared::detail::CWrapperFunctionResult (*)(
-      const char *Data, uint64_t Size);
+void SelfExecutorProcessControl::runWrapperAsync(SendResultFunction SendResult,
+                                                 JITTargetAddress WrapperFnAddr,
+                                                 ArrayRef<char> ArgBuffer) {
+  using WrapperFnTy =
+      shared::detail::CWrapperFunctionResult (*)(const char *Data, size_t Size);
   auto *WrapperFn = jitTargetAddressToFunction<WrapperFnTy>(WrapperFnAddr);
-  return WrapperFn(ArgBuffer.data(), ArgBuffer.size());
+  SendResult(WrapperFn(ArgBuffer.data(), ArgBuffer.size()));
 }
 
 Error SelfExecutorProcessControl::disconnect() { return Error::success(); }

diff  --git a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
index b1cfd18e5d4e5..b544cfa1864e8 100644
--- a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
+++ b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt
@@ -16,6 +16,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_unittest(OrcJITTests
   CoreAPIsTest.cpp
+  ExecutorProcessControlTest.cpp
   IndirectionUtilsTest.cpp
   JITTargetMachineBuilderTest.cpp
   LazyCallThroughAndReexportsTest.cpp

diff  --git a/llvm/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp b/llvm/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp
new file mode 100644
index 0000000000000..23096c86f4d33
--- /dev/null
+++ b/llvm/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp
@@ -0,0 +1,105 @@
+//===- ExecutorProcessControlTest.cpp - Test ExecutorProcessControl utils -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
+#include "llvm/Support/MSVCErrorWorkarounds.h"
+#include "llvm/Testing/Support/Error.h"
+#include "gtest/gtest.h"
+
+#include <future>
+
+using namespace llvm;
+using namespace llvm::orc;
+using namespace llvm::orc::shared;
+// Synchronous SPS wrapper function: deserializes (X, Y) and returns X + Y.
+static llvm::orc::shared::detail::CWrapperFunctionResult
+addWrapper(const char *ArgData, size_t ArgSize) {
+  return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
+             ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; })
+      .release();
+}
+// Asynchronous handler: reports X + Y through SendResult.
+static void addAsyncWrapper(unique_function<void(int32_t)> SendResult,
+                            int32_t X, int32_t Y) {
+  SendResult(X + Y);
+}
+
+TEST(ExecutorProcessControl, RunWrapperTemplate) {
+  auto EPC = cantFail(
+      SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+  // Call addWrapper by address via SPS and check the deserialized sum.
+  int32_t Result;
+  EXPECT_THAT_ERROR(EPC->runSPSWrapper<int32_t(int32_t, int32_t)>(
+                        pointerToJITTargetAddress(addWrapper), Result, 2, 3),
+                    Succeeded());
+  EXPECT_EQ(Result, 5);
+}
+
+TEST(ExecutorProcessControl, RunWrapperAsyncTemplate) {
+  auto EPC = cantFail(
+      SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+  // The callback must fulfil RP exactly once, so return after an error.
+  std::promise<MSVCPExpected<int32_t>> RP;
+  using Sig = int32_t(int32_t, int32_t);
+  EPC->runSPSWrapperAsync<Sig>(
+      [&](Error SerializationErr, int32_t R) {
+        if (SerializationErr)
+          return RP.set_value(std::move(SerializationErr));
+        RP.set_value(std::move(R));
+      },
+      pointerToJITTargetAddress(addWrapper), 2, 3);
+  Expected<int32_t> Result = RP.get_future().get();
+  EXPECT_THAT_EXPECTED(Result, HasValue(5));
+}
+
+TEST(ExecutorProcessControl, RegisterAsyncHandlerAndRun) {
+  // Associate addAsyncWrapper with a tag symbol, then run it by tag address.
+  constexpr JITTargetAddress AddAsyncTagAddr = 0x01;
+
+  auto EPC = cantFail(
+      SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
+  ExecutionSession ES(EPC->getSymbolStringPool());
+  auto &JD = ES.createBareJITDylib("JD");
+
+  auto AddAsyncTag = ES.intern("addAsync_tag");
+  cantFail(JD.define(absoluteSymbols(
+      {{AddAsyncTag,
+        JITEvaluatedSymbol(AddAsyncTagAddr, JITSymbolFlags::Exported)}})));
+  // Map the tag symbol to an SPS-wrapped async implementation.
+  ExecutorProcessControl::WrapperFunctionAssociationMap Associations;
+
+  Associations[AddAsyncTag] =
+      EPC->wrapAsyncWithSPS<int32_t(int32_t, int32_t)>(addAsyncWrapper);
+
+  cantFail(EPC->associateJITSideWrapperFunctions(JD, std::move(Associations)));
+
+  std::promise<int32_t> RP;
+  auto RF = RP.get_future();
+  // Serialize the arguments (1, 2) by hand into an SPS argument buffer.
+  using ArgSerialization = SPSArgList<int32_t, int32_t>;
+  size_t ArgBufferSize = ArgSerialization::size(1, 2);
+  WrapperFunctionResult ArgBuffer;
+  char *ArgBufferData =
+      WrapperFunctionResult::allocate(ArgBuffer, ArgBufferSize);
+  SPSOutputBuffer OB(ArgBufferData, ArgBufferSize);
+  EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2));
+  // Run the handler by tag address and decode the SPS int32_t result.
+  EPC->runJITSideWrapperFunction(
+      [&](WrapperFunctionResult ResultBuffer) {
+        int32_t Result;
+        SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size());
+        EXPECT_TRUE(SPSArgList<int32_t>::deserialize(IB, Result));
+        RP.set_value(Result);
+      },
+      AddAsyncTagAddr, ArrayRef<char>(ArgBuffer.data(), ArgBuffer.size()));
+
+  EXPECT_EQ(RF.get(), (int32_t)3);
+
+  cantFail(ES.endSession());
+}

diff  --git a/llvm/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp
index 1f177b4c2d143..42051836506fb 100644
--- a/llvm/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp
@@ -7,8 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
+#include "llvm/ADT/FunctionExtras.h"
 #include "gtest/gtest.h"
 
+#include <future>
+
 using namespace llvm;
 using namespace llvm::orc::shared;
 
@@ -65,13 +68,65 @@ static WrapperFunctionResult addWrapper(const char *ArgData, size_t ArgSize) {
       ArgData, ArgSize, [](int32_t X, int32_t Y) -> int32_t { return X + Y; });
 }
 
-TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleVoid) {
   EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopWrapper));
 }
 
-TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandle) {
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleRet) {
   int32_t Result;
   EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
       addWrapper, Result, 1, 2));
   EXPECT_EQ(Result, (int32_t)3);
 }
+
+// Async handler for a void() wrapper: replies immediately with an empty
+// (SPSEmpty) result.
+static void voidNoopAsync(unique_function<void(SPSEmpty)> SendResult) {
+  SendResult(SPSEmpty());
+}
+
+// Synchronous wrapper-function entry point that dispatches to voidNoopAsync
+// through handleAsync, then blocks on a future until the result buffer is
+// handed to the send-result callback.
+static WrapperFunctionResult voidNoopAsyncWrapper(const char *ArgData,
+                                                  size_t ArgSize) {
+  std::promise<WrapperFunctionResult> RP;
+  auto RF = RP.get_future();
+
+  WrapperFunction<void()>::handleAsync(
+      ArgData, ArgSize, voidNoopAsync,
+      [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
+
+  return RF.get();
+}
+
+// Synchronous wrapper-function entry point for int32_t(int32_t, int32_t):
+// handleAsync passes the decoded arguments to the async handler, which sends
+// back X + Y; the resulting buffer is collected through a future.
+static WrapperFunctionResult addAsyncWrapper(const char *ArgData,
+                                             size_t ArgSize) {
+  std::promise<WrapperFunctionResult> RP;
+  auto RF = RP.get_future();
+
+  WrapperFunction<int32_t(int32_t, int32_t)>::handleAsync(
+      ArgData, ArgSize,
+      [](unique_function<void(int32_t)> SendResult, int32_t X, int32_t Y) {
+        SendResult(X + Y);
+      },
+      [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
+
+  return RF.get();
+}
+
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncVoid) {
+  EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopAsyncWrapper));
+}
+
+TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncRet) {
+  // Zero-initialize: if the call fails, Result may never be written and
+  // EXPECT_EQ would otherwise read an indeterminate value.
+  int32_t Result = 0;
+  EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
+      addAsyncWrapper, Result, 1, 2));
+  EXPECT_EQ(Result, (int32_t)3);
+}


        


More information about the llvm-commits mailing list