[llvm] 4d7cea3 - [ORC] Add optional RunPolicy to ExecutorProcessControl::callWrapperAsync.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 10 20:52:37 PDT 2021


Author: Lang Hames
Date: 2021-10-10T20:41:59-07:00
New Revision: 4d7cea3d2e833209d06e201a273f97342035c196

URL: https://github.com/llvm/llvm-project/commit/4d7cea3d2e833209d06e201a273f97342035c196
DIFF: https://github.com/llvm/llvm-project/commit/4d7cea3d2e833209d06e201a273f97342035c196.diff

LOG: [ORC] Add optional RunPolicy to ExecutorProcessControl::callWrapperAsync.

The callWrapperAsync and callSPSWrapperAsync methods take a handler object
that is run on the return value of the call when it is ready. The new RunPolicy
parameters allow clients to control how these handlers are run. If no policy is
specified then the handler will be packaged as a GenericNamedTask and dispatched
using the ExecutorProcessControl's TaskDispatch member. Callers can use the
ExecutorProcessControl::RunInPlace policy to cause the handler to be run
directly instead, which may be preferable for simple handlers, or they can
write their own policy object (e.g. to dispatch as some other kind of Task,
rather than GenericNamedTask).

Added: 
    

Modified: 
    llvm/include/llvm/ExecutionEngine/Orc/Core.h
    llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
    llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
    llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h
    llvm/lib/ExecutionEngine/Orc/Core.cpp
    llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
    llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
index d2761d645861..5cac65b49a05 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h
@@ -1284,13 +1284,16 @@ class ExecutionSession {
   /// For reporting errors.
   using ErrorReporter = std::function<void(Error)>;
 
+  /// Send a result to the remote.
+  using SendResultFunction = unique_function<void(shared::WrapperFunctionResult)>;
+
   /// For dispatching ORC tasks (typically materialization tasks).
   using DispatchTaskFunction = unique_function<void(std::unique_ptr<Task> T)>;
 
   /// An asynchronous wrapper-function callable from the executor via
   /// jit-dispatch.
   using JITDispatchHandlerFunction = unique_function<void(
-      ExecutorProcessControl::SendResultFunction SendResult,
+      SendResultFunction SendResult,
       const char *ArgData, size_t ArgSize)>;
 
   /// A map associating tag names with asynchronous wrapper function
@@ -1467,10 +1470,9 @@ class ExecutionSession {
   /// \endcode{.cpp}
   ///
   /// The given OnComplete function will be called to return the result.
-  void callWrapperAsync(ExecutorAddr WrapperFnAddr,
-                        ExecutorProcessControl::SendResultFunction OnComplete,
-                        ArrayRef<char> ArgBuffer) {
-    EPC->callWrapperAsync(WrapperFnAddr, std::move(OnComplete), ArgBuffer);
+  template <typename... ArgTs>
+  void callWrapperAsync(ArgTs &&... Args) {
+    EPC->callWrapperAsync(std::forward<ArgTs>(Args)...);
   }
 
   /// Run a wrapper function in the executor. The wrapper function should be
@@ -1515,7 +1517,7 @@ class ExecutionSession {
   template <typename SPSSignature, typename HandlerT>
   static JITDispatchHandlerFunction wrapAsyncWithSPS(HandlerT &&H) {
     return [H = std::forward<HandlerT>(H)](
-               ExecutorProcessControl::SendResultFunction SendResult,
+               SendResultFunction SendResult,
                const char *ArgData, size_t ArgSize) mutable {
       shared::WrapperFunction<SPSSignature>::handleAsync(ArgData, ArgSize, H,
                                                          std::move(SendResult));
@@ -1554,7 +1556,7 @@ class ExecutionSession {
   /// This should be called by the ExecutorProcessControl instance in response
   /// to incoming jit-dispatch requests from the executor.
   void
-  runJITDispatchHandler(ExecutorProcessControl::SendResultFunction SendResult,
+  runJITDispatchHandler(SendResultFunction SendResult,
                         JITTargetAddress HandlerFnTagAddr,
                         ArrayRef<char> ArgBuffer);
 

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h b/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
index 147d1d3dca1b..2786f76c26a7 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h
@@ -37,11 +37,64 @@ class SymbolLookupSet;
 /// ExecutorProcessControl supports interaction with a JIT target process.
 class ExecutorProcessControl {
   friend class ExecutionSession;
-
 public:
-  /// Sender to return the result of a WrapperFunction executed in the JIT.
-  using SendResultFunction =
-      unique_function<void(shared::WrapperFunctionResult)>;
+
+  /// A handler for incoming WrapperFunctionResults -- either return values
+  /// from callWrapper* calls, or incoming JIT-dispatch requests.
+  ///
+  /// The wrapping constructor is private: non-empty handlers are built from
+  /// void(shared::WrapperFunctionResult) function objects by run-policy
+  /// objects such as RunInPlace or RunAsTask.
+  class IncomingWFRHandler {
+    friend class ExecutorProcessControl;
+  public:
+    IncomingWFRHandler() = default;
+    void operator()(shared::WrapperFunctionResult WFR) { H(std::move(WFR)); }
+  private:
+    template <typename FnT> IncomingWFRHandler(FnT &&Fn)
+      : H(std::forward<FnT>(Fn)) {}
+
+    unique_function<void(shared::WrapperFunctionResult)> H;
+  };
+
+  /// Constructs an IncomingWFRHandler from a function object that is callable
+  /// as void(shared::WrapperFunctionResult). The function object will be called
+  /// directly. This should be used with care as it may block listener threads
+  /// in remote EPCs. It is only suitable for simple tasks (e.g. setting a
+  /// future), or for performing some quick analysis before dispatching "real"
+  /// work as a Task.
+  class RunInPlace {
+  public:
+    template <typename FnT>
+    IncomingWFRHandler operator()(FnT &&Fn) {
+      return IncomingWFRHandler(std::forward<FnT>(Fn));
+    }
+  };
+
+  /// Constructs an IncomingWFRHandler from a function object by creating a new
+  /// function object that dispatches the original using a TaskDispatcher,
+  /// wrapping the original as a GenericNamedTask. Fn is forwarded into the
+  /// handler: rvalues are moved, lvalues are copied (never moved-from).
+  /// This is the default approach for running WFR handlers.
+  class RunAsTask {
+  public:
+    RunAsTask(TaskDispatcher &D) : D(D) {}
+
+    template <typename FnT>
+    IncomingWFRHandler operator()(FnT &&Fn) {
+      return IncomingWFRHandler(
+          [&D = this->D, Fn = std::forward<FnT>(Fn)]
+          (shared::WrapperFunctionResult WFR) mutable {
+              D.dispatch(
+                makeGenericNamedTask(
+                    [Fn = std::move(Fn), WFR = std::move(WFR)]() mutable {
+                      Fn(std::move(WFR));
+                    }, "WFR handler task"));
+          });
+    }
+  private:
+    TaskDispatcher &D;
+  };
 
   /// APIs for manipulating memory in the target process.
   class MemoryAccess {
@@ -205,19 +258,36 @@ class ExecutorProcessControl {
   virtual Expected<int32_t> runAsMain(ExecutorAddr MainFnAddr,
                                       ArrayRef<std::string> Args) = 0;
 
-  /// Run a wrapper function in the executor.
+  /// Run a wrapper function in the executor. The given WFRHandler will be
+  /// called on the result when it is returned.
   ///
   /// The wrapper function should be callable as:
   ///
   /// \code{.cpp}
   ///   CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
   /// \endcode{.cpp}
-  ///
-  /// The given OnComplete function will be called to return the result.
   virtual void callWrapperAsync(ExecutorAddr WrapperFnAddr,
-                                SendResultFunction OnComplete,
+                                IncomingWFRHandler OnComplete,
                                 ArrayRef<char> ArgBuffer) = 0;
 
+  /// Run a wrapper function in the executor using the given Runner to dispatch
+  /// OnComplete when the result is ready.
+  template <typename RunPolicyT, typename FnT>
+  void callWrapperAsync(RunPolicyT &&Runner, ExecutorAddr WrapperFnAddr,
+                        FnT &&OnComplete, ArrayRef<char> ArgBuffer) {
+    callWrapperAsync(
+        WrapperFnAddr, Runner(std::forward<FnT>(OnComplete)), ArgBuffer);
+  }
+
+  /// Run a wrapper function in the executor. OnComplete will be dispatched
+  /// as a GenericNamedTask using this instance's TaskDispatch object.
+  template <typename FnT>
+  void callWrapperAsync(ExecutorAddr WrapperFnAddr, FnT &&OnComplete,
+                        ArrayRef<char> ArgBuffer) {
+    callWrapperAsync(RunAsTask(*D), WrapperFnAddr,
+                     std::forward<FnT>(OnComplete), ArgBuffer);
+  }
+
   /// Run a wrapper function in the executor. The wrapper function should be
   /// callable as:
   ///
@@ -229,25 +299,37 @@ class ExecutorProcessControl {
     std::promise<shared::WrapperFunctionResult> RP;
     auto RF = RP.get_future();
     callWrapperAsync(
-        WrapperFnAddr,
-        [&](shared::WrapperFunctionResult R) { RP.set_value(std::move(R)); },
-        ArgBuffer);
+        RunInPlace(), WrapperFnAddr,
+        [&](shared::WrapperFunctionResult R) {
+          RP.set_value(std::move(R));
+        }, ArgBuffer);
     return RF.get();
   }
 
+  /// Run a wrapper function using SPS to serialize the arguments and
+  /// deserialize the results. Runner controls how the result handler runs.
+  template <typename SPSSignature, typename RunPolicyT, typename SendResultT,
+            typename... ArgTs>
+  void callSPSWrapperAsync(RunPolicyT &&Runner, ExecutorAddr WrapperFnAddr,
+                           SendResultT &&SendResult, const ArgTs &...Args) {
+    shared::WrapperFunction<SPSSignature>::callAsync(
+        [this, WrapperFnAddr, Runner = std::forward<RunPolicyT>(Runner)]
+        (auto &&SendResult, const char *ArgData, size_t ArgSize) mutable {
+          this->callWrapperAsync(std::move(Runner), WrapperFnAddr,
+                                 std::forward<decltype(SendResult)>(SendResult),
+                                 ArrayRef<char>(ArgData, ArgSize));
+        },
+        std::forward<SendResultT>(SendResult), Args...);
+  }
+
   /// Run a wrapper function using SPS to serialize the arguments and
   /// deserialize the results.
   template <typename SPSSignature, typename SendResultT, typename... ArgTs>
   void callSPSWrapperAsync(ExecutorAddr WrapperFnAddr, SendResultT &&SendResult,
                            const ArgTs &...Args) {
-    shared::WrapperFunction<SPSSignature>::callAsync(
-        [this,
-         WrapperFnAddr](ExecutorProcessControl::SendResultFunction SendResult,
-                        const char *ArgData, size_t ArgSize) {
-          callWrapperAsync(WrapperFnAddr, std::move(SendResult),
-                           ArrayRef<char>(ArgData, ArgSize));
-        },
-        std::move(SendResult), Args...);
+    callSPSWrapperAsync<SPSSignature>(RunAsTask(*D), WrapperFnAddr,
+                                      std::forward<SendResultT>(SendResult),
+                                      Args...);
   }
 
   /// Run a wrapper function using SPS to serialize the arguments and
@@ -315,7 +397,7 @@ class UnsupportedExecutorProcessControl : public ExecutorProcessControl {
   }
 
   void callWrapperAsync(ExecutorAddr WrapperFnAddr,
-                        SendResultFunction OnComplete,
+                        IncomingWFRHandler OnComplete,
                         ArrayRef<char> ArgBuffer) override {
     llvm_unreachable("Unsupported");
   }
@@ -352,7 +434,7 @@ class SelfExecutorProcessControl
                               ArrayRef<std::string> Args) override;
 
   void callWrapperAsync(ExecutorAddr WrapperFnAddr,
-                        SendResultFunction OnComplete,
+                        IncomingWFRHandler OnComplete,
                         ArrayRef<char> ArgBuffer) override;
 
   Error disconnect() override;

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
index 37a2792b72e0..0dd581b45d8d 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h
@@ -555,7 +555,7 @@ class WrapperFunction<void(SPSTagTs...)>
                         SendDeserializedResultFn &&SendDeserializedResult,
                         const ArgTs &...Args) {
     WrapperFunction<SPSEmpty(SPSTagTs...)>::callAsync(
-        Caller,
+        std::forward<AsyncCallerFn>(Caller),
         [SDR = std::move(SendDeserializedResult)](Error SerializeErr,
                                                   SPSEmpty E) mutable {
           SDR(std::move(SerializeErr));

diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h b/llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h
index 55449f9be621..46f33813be51 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h
@@ -64,7 +64,7 @@ class SimpleRemoteEPC : public ExecutorProcessControl,
                               ArrayRef<std::string> Args) override;
 
   void callWrapperAsync(ExecutorAddr WrapperFnAddr,
-                        SendResultFunction OnComplete,
+                        IncomingWFRHandler OnComplete,
                         ArrayRef<char> ArgBuffer) override;
 
   Error disconnect() override;
@@ -100,7 +100,8 @@ class SimpleRemoteEPC : public ExecutorProcessControl,
   uint64_t getNextSeqNo() { return NextSeqNo++; }
   void releaseSeqNo(uint64_t SeqNo) {}
 
-  using PendingCallWrapperResultsMap = DenseMap<uint64_t, SendResultFunction>;
+  using PendingCallWrapperResultsMap =
+    DenseMap<uint64_t, IncomingWFRHandler>;
 
   std::mutex SimpleRemoteEPCMutex;
   std::condition_variable DisconnectCV;

diff  --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp
index c29593bfe0e9..64e5090e4c53 100644
--- a/llvm/lib/ExecutionEngine/Orc/Core.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp
@@ -2089,8 +2089,8 @@ Error ExecutionSession::registerJITDispatchHandlers(
 }
 
 void ExecutionSession::runJITDispatchHandler(
-    ExecutorProcessControl::SendResultFunction SendResult,
-    JITTargetAddress HandlerFnTagAddr, ArrayRef<char> ArgBuffer) {
+    SendResultFunction SendResult, JITTargetAddress HandlerFnTagAddr,
+    ArrayRef<char> ArgBuffer) {
 
   std::shared_ptr<JITDispatchHandlerFunction> F;
   {

diff  --git a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
index 1485789e287b..6fb8b52e581f 100644
--- a/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp
@@ -121,7 +121,7 @@ SelfExecutorProcessControl::runAsMain(ExecutorAddr MainFnAddr,
 }
 
 void SelfExecutorProcessControl::callWrapperAsync(ExecutorAddr WrapperFnAddr,
-                                                  SendResultFunction SendResult,
+                                                  IncomingWFRHandler SendResult,
                                                   ArrayRef<char> ArgBuffer) {
   using WrapperFnTy =
       shared::detail::CWrapperFunctionResult (*)(const char *Data, size_t Size);

diff  --git a/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp b/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp
index 31886294d3ab..dfb3832cb03b 100644
--- a/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp
@@ -55,7 +55,7 @@ Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,
 }
 
 void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,
-                                       SendResultFunction OnComplete,
+                                       IncomingWFRHandler OnComplete,
                                        ArrayRef<char> ArgBuffer) {
   uint64_t SeqNo;
   {
@@ -246,6 +246,7 @@ Error SimpleRemoteEPC::setup() {
 
   // Prepare a handler for the setup packet.
   PendingCallWrapperResults[0] =
+    RunInPlace()(
       [&](shared::WrapperFunctionResult SetupMsgBytes) {
         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
           EIP.set_value(
@@ -261,7 +262,7 @@ Error SimpleRemoteEPC::setup() {
         else
           EIP.set_value(make_error<StringError>(
               "Could not deserialize setup message", inconvertibleErrorCode()));
-      };
+      });
 
   // Start the transport.
   if (auto Err = T->start())
@@ -316,7 +317,7 @@ Error SimpleRemoteEPC::setup() {
 
 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
-  SendResultFunction SendResult;
+  IncomingWFRHandler SendResult;
 
   if (TagAddr)
     return make_error<StringError>("Unexpected TagAddr in result message",


        


More information about the llvm-commits mailing list