[llvm] r266581 - [ORC] Generalize the ORC RPC utils to support RPC function return values and

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 17 18:06:49 PDT 2016


Author: lhames
Date: Sun Apr 17 20:06:49 2016
New Revision: 266581

URL: http://llvm.org/viewvc/llvm-project?rev=266581&view=rev
Log:
[ORC] Generalize the ORC RPC utils to support RPC function return values and
asynchronous call/handle. Also updates the ORC remote JIT API to use the new
scheme.

The previous version of the RPC tools only supported void functions, and
required the user to manually call a paired function to return results. This
patch replaces the Procedure typedef (which only supported void functions) with
the Function typedef which supports return values, e.g.:

  Function<FooId, int32_t(std::string)> Foo;

The RPC primitives and channel operations are also expanded. RPC channels must
support four new operations: startSendMessage, endSendMessage,
startReceiveMessage and endReceiveMessage, to handle channel locking. In
addition, serialization support for tuples over RPCChannels is added to enable
multiple return values.

The RPC primitives are expanded from callAppend, call, expect and handle, to:

appendCallAsync - Make an asynchronous call to the given function.

callAsync - The same as appendCallAsync, but calls send on the channel when
            done.

callSTHandling - Blocking call for single-threaded code. Wraps a call to
                 callAsync then waits on the result, using a user-supplied
                 handler to handle any callbacks from the remote.

callST - The same as callSTHandling, except that it doesn't handle
         callbacks: it expects the result to be the first response received.

expect and handle - as before.

handleResponse - Handle a response from the remote.

waitForResult - Wait for the response with the given sequence number to arrive.


Modified:
    llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h
    llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
    llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
    llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
    llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h
    llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
    llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp
    llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp
    llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp
    llvm/trunk/tools/lli/RemoteJITUtils.h
    llvm/trunk/tools/lli/lli.cpp
    llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h Sun Apr 17 20:06:49 2016
@@ -26,7 +26,8 @@ enum class OrcErrorCode : int {
   RemoteMProtectAddrUnrecognized,
   RemoteIndirectStubsOwnerDoesNotExist,
   RemoteIndirectStubsOwnerIdAlreadyInUse,
-  UnexpectedRPCCall
+  UnexpectedRPCCall,
+  UnexpectedRPCResponse,
 };
 
 std::error_code orcError(OrcErrorCode ErrCode);

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h Sun Apr 17 20:06:49 2016
@@ -36,6 +36,7 @@ namespace remote {
 template <typename ChannelT>
 class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI {
 public:
+
   /// Remote memory manager.
   class RCMemoryManager : public RuntimeDyld::MemoryManager {
   public:
@@ -105,11 +106,13 @@ public:
       DEBUG(dbgs() << "Allocator " << Id << " reserved:\n");
 
       if (CodeSize != 0) {
-        std::error_code EC = Client.reserveMem(Unmapped.back().RemoteCodeAddr,
-                                               Id, CodeSize, CodeAlign);
-        // FIXME; Add error to poll.
-        assert(!EC && "Failed reserving remote memory.");
-        (void)EC;
+	if (auto AddrOrErr = Client.reserveMem(Id, CodeSize, CodeAlign))
+	  Unmapped.back().RemoteCodeAddr = *AddrOrErr;
+	else {
+	  // FIXME; Add error to poll.
+	  assert(!AddrOrErr.getError() && "Failed reserving remote memory.");
+	}
+
         DEBUG(dbgs() << "  code: "
                      << format("0x%016x", Unmapped.back().RemoteCodeAddr)
                      << " (" << CodeSize << " bytes, alignment " << CodeAlign
@@ -117,11 +120,13 @@ public:
       }
 
       if (RODataSize != 0) {
-        std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRODataAddr,
-                                               Id, RODataSize, RODataAlign);
-        // FIXME; Add error to poll.
-        assert(!EC && "Failed reserving remote memory.");
-        (void)EC;
+        if (auto AddrOrErr = Client.reserveMem(Id, RODataSize, RODataAlign))
+	  Unmapped.back().RemoteRODataAddr = *AddrOrErr;
+	else {
+	  // FIXME; Add error to poll.
+	  assert(!AddrOrErr.getError() && "Failed reserving remote memory.");
+	}
+                                               
         DEBUG(dbgs() << "  ro-data: "
                      << format("0x%016x", Unmapped.back().RemoteRODataAddr)
                      << " (" << RODataSize << " bytes, alignment "
@@ -129,11 +134,13 @@ public:
       }
 
       if (RWDataSize != 0) {
-        std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRWDataAddr,
-                                               Id, RWDataSize, RWDataAlign);
-        // FIXME; Add error to poll.
-        assert(!EC && "Failed reserving remote memory.");
-        (void)EC;
+        if (auto AddrOrErr = Client.reserveMem(Id, RWDataSize, RWDataAlign))
+	  Unmapped.back().RemoteRWDataAddr = *AddrOrErr;
+	else {
+	  // FIXME; Add error to poll.
+	  assert(!AddrOrErr.getError() && "Failed reserving remote memory.");
+        }
+
         DEBUG(dbgs() << "  rw-data: "
                      << format("0x%016x", Unmapped.back().RemoteRWDataAddr)
                      << " (" << RWDataSize << " bytes, alignment "
@@ -431,8 +438,10 @@ public:
       TargetAddress PtrBase;
       unsigned NumStubsEmitted;
 
-      Remote.emitIndirectStubs(StubBase, PtrBase, NumStubsEmitted, Id,
-                               NewStubsRequired);
+      if (auto StubInfoOrErr = Remote.emitIndirectStubs(Id, NewStubsRequired))
+	std::tie(StubBase, PtrBase, NumStubsEmitted) = *StubInfoOrErr;
+      else
+	return StubInfoOrErr.getError();
 
       unsigned NewBlockId = RemoteIndirectStubsInfos.size();
       RemoteIndirectStubsInfos.push_back({StubBase, PtrBase, NumStubsEmitted});
@@ -484,8 +493,12 @@ public:
     void grow() override {
       TargetAddress BlockAddr = 0;
       uint32_t NumTrampolines = 0;
-      auto EC = Remote.emitTrampolineBlock(BlockAddr, NumTrampolines);
-      assert(!EC && "Failed to create trampolines");
+      if (auto TrampolineInfoOrErr = Remote.emitTrampolineBlock())
+	std::tie(BlockAddr, NumTrampolines) = *TrampolineInfoOrErr;
+      else {
+	// FIXME: Return error.
+	llvm_unreachable("Failed to create trampolines");
+      }
 
       uint32_t TrampolineSize = Remote.getTrampolineSize();
       for (unsigned I = 0; I < NumTrampolines; ++I)
@@ -503,53 +516,33 @@ public:
     OrcRemoteTargetClient H(Channel, EC);
     if (EC)
       return EC;
-    return H;
+    return ErrorOr<OrcRemoteTargetClient>(std::move(H));
   }
 
   /// Call the int(void) function at the given address in the target and return
   /// its result.
-  std::error_code callIntVoid(int &Result, TargetAddress Addr) {
+  ErrorOr<int> callIntVoid(TargetAddress Addr) {
     DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) << "\n");
 
-    if (auto EC = call<CallIntVoid>(Channel, Addr))
-      return EC;
-
-    unsigned NextProcId;
-    if (auto EC = listenForCompileRequests(NextProcId))
-      return EC;
-
-    if (NextProcId != CallIntVoidResponseId)
-      return orcError(OrcErrorCode::UnexpectedRPCCall);
-
-    return handle<CallIntVoidResponse>(Channel, [&](int R) {
-      Result = R;
-      DEBUG(dbgs() << "Result: " << R << "\n");
-      return std::error_code();
-    });
+    auto Listen =
+      [&](RPCChannel &C, uint32_t Id) {
+        return listenForCompileRequests(C, Id);
+      };
+    return callSTHandling<CallIntVoid>(Channel, Listen, Addr);
   }
 
   /// Call the int(int, char*[]) function at the given address in the target and
   /// return its result.
-  std::error_code callMain(int &Result, TargetAddress Addr,
-                           const std::vector<std::string> &Args) {
+  ErrorOr<int> callMain(TargetAddress Addr,
+			const std::vector<std::string> &Args) {
     DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x", Addr)
                  << "\n");
 
-    if (auto EC = call<CallMain>(Channel, Addr, Args))
-      return EC;
-
-    unsigned NextProcId;
-    if (auto EC = listenForCompileRequests(NextProcId))
-      return EC;
-
-    if (NextProcId != CallMainResponseId)
-      return orcError(OrcErrorCode::UnexpectedRPCCall);
-
-    return handle<CallMainResponse>(Channel, [&](int R) {
-      Result = R;
-      DEBUG(dbgs() << "Result: " << R << "\n");
-      return std::error_code();
-    });
+    auto Listen =
+      [&](RPCChannel &C, uint32_t Id) {
+        return listenForCompileRequests(C, Id);
+      };
+    return callSTHandling<CallMain>(Channel, Listen, Addr, Args);
   }
 
   /// Call the void() function at the given address in the target and wait for
@@ -558,17 +551,11 @@ public:
     DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr)
                  << "\n");
 
-    if (auto EC = call<CallVoidVoid>(Channel, Addr))
-      return EC;
-
-    unsigned NextProcId;
-    if (auto EC = listenForCompileRequests(NextProcId))
-      return EC;
-
-    if (NextProcId != CallVoidVoidResponseId)
-      return orcError(OrcErrorCode::UnexpectedRPCCall);
-
-    return handle<CallVoidVoidResponse>(Channel, doNothing);
+    auto Listen =
+      [&](RPCChannel &C, JITFuncId Id) {
+        return listenForCompileRequests(C, Id);
+      };
+    return callSTHandling<CallVoidVoid>(Channel, Listen, Addr);
   }
 
   /// Create an RCMemoryManager which will allocate its memory on the remote
@@ -578,7 +565,7 @@ public:
     assert(!MM && "MemoryManager should be null before creation.");
 
     auto Id = AllocatorIds.getNext();
-    if (auto EC = call<CreateRemoteAllocator>(Channel, Id))
+    if (auto EC = callST<CreateRemoteAllocator>(Channel, Id))
       return EC;
     MM = llvm::make_unique<RCMemoryManager>(*this, Id);
     return std::error_code();
@@ -590,7 +577,7 @@ public:
   createIndirectStubsManager(std::unique_ptr<RCIndirectStubsManager> &I) {
     assert(!I && "Indirect stubs manager should be null before creation.");
     auto Id = IndirectStubOwnerIds.getNext();
-    if (auto EC = call<CreateIndirectStubsOwner>(Channel, Id))
+    if (auto EC = callST<CreateIndirectStubsOwner>(Channel, Id))
       return EC;
     I = llvm::make_unique<RCIndirectStubsManager>(*this, Id);
     return std::error_code();
@@ -599,45 +586,39 @@ public:
   /// Search for symbols in the remote process. Note: This should be used by
   /// symbol resolvers *after* they've searched the local symbol table in the
   /// JIT stack.
-  std::error_code getSymbolAddress(TargetAddress &Addr, StringRef Name) {
+  ErrorOr<TargetAddress> getSymbolAddress(StringRef Name) {
     // Check for an 'out-of-band' error, e.g. from an MM destructor.
     if (ExistingError)
       return ExistingError;
 
-    // Request remote symbol address.
-    if (auto EC = call<GetSymbolAddress>(Channel, Name))
-      return EC;
-
-    return expect<GetSymbolAddressResponse>(Channel, [&](TargetAddress &A) {
-      Addr = A;
-      DEBUG(dbgs() << "Remote address lookup " << Name << " = "
-                   << format("0x%016x", Addr) << "\n");
-      return std::error_code();
-    });
+    return callST<GetSymbolAddress>(Channel, Name);
   }
 
   /// Get the triple for the remote target.
   const std::string &getTargetTriple() const { return RemoteTargetTriple; }
 
-  std::error_code terminateSession() { return call<TerminateSession>(Channel); }
+  std::error_code terminateSession() {
+    return callST<TerminateSession>(Channel);
+  }
 
 private:
   OrcRemoteTargetClient(ChannelT &Channel, std::error_code &EC)
       : Channel(Channel) {
-    if ((EC = call<GetRemoteInfo>(Channel)))
-      return;
-
-    EC = expect<GetRemoteInfoResponse>(
-        Channel, readArgs(RemoteTargetTriple, RemotePointerSize, RemotePageSize,
-                          RemoteTrampolineSize, RemoteIndirectStubSize));
+    if (auto RIOrErr = callST<GetRemoteInfo>(Channel)) {
+      std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize,
+	       RemoteTrampolineSize, RemoteIndirectStubSize) =
+	*RIOrErr;
+      EC = std::error_code();
+    } else
+      EC = RIOrErr.getError();
   }
 
   std::error_code deregisterEHFrames(TargetAddress Addr, uint32_t Size) {
-    return call<RegisterEHFrames>(Channel, Addr, Size);
+    return callST<RegisterEHFrames>(Channel, Addr, Size);
   }
 
   void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {
-    if (auto EC = call<DestroyRemoteAllocator>(Channel, Id)) {
+    if (auto EC = callST<DestroyRemoteAllocator>(Channel, Id)) {
       // FIXME: This will be triggered by a removeModuleSet call: Propagate
       //        error return up through that.
       llvm_unreachable("Failed to destroy remote allocator.");
@@ -647,19 +628,13 @@ private:
 
   std::error_code destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) {
     IndirectStubOwnerIds.release(Id);
-    return call<DestroyIndirectStubsOwner>(Channel, Id);
+    return callST<DestroyIndirectStubsOwner>(Channel, Id);
   }
 
-  std::error_code emitIndirectStubs(TargetAddress &StubBase,
-                                    TargetAddress &PtrBase,
-                                    uint32_t &NumStubsEmitted,
-                                    ResourceIdMgr::ResourceId Id,
-                                    uint32_t NumStubsRequired) {
-    if (auto EC = call<EmitIndirectStubs>(Channel, Id, NumStubsRequired))
-      return EC;
-
-    return expect<EmitIndirectStubsResponse>(
-        Channel, readArgs(StubBase, PtrBase, NumStubsEmitted));
+  ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>>
+  emitIndirectStubs(ResourceIdMgr::ResourceId Id,
+		    uint32_t NumStubsRequired) {
+    return callST<EmitIndirectStubs>(Channel, Id, NumStubsRequired);
   }
 
   std::error_code emitResolverBlock() {
@@ -667,24 +642,16 @@ private:
     if (ExistingError)
       return ExistingError;
 
-    return call<EmitResolverBlock>(Channel);
+    return callST<EmitResolverBlock>(Channel);
   }
 
-  std::error_code emitTrampolineBlock(TargetAddress &BlockAddr,
-                                      uint32_t &NumTrampolines) {
+  ErrorOr<std::tuple<TargetAddress, uint32_t>>
+  emitTrampolineBlock() {
     // Check for an 'out-of-band' error, e.g. from an MM destructor.
     if (ExistingError)
       return ExistingError;
 
-    if (auto EC = call<EmitTrampolineBlock>(Channel))
-      return EC;
-
-    return expect<EmitTrampolineBlockResponse>(
-        Channel, [&](TargetAddress BAddr, uint32_t NTrampolines) {
-          BlockAddr = BAddr;
-          NumTrampolines = NTrampolines;
-          return std::error_code();
-        });
+    return callST<EmitTrampolineBlock>(Channel);
   }
 
   uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; }
@@ -693,67 +660,46 @@ private:
 
   uint32_t getTrampolineSize() const { return RemoteTrampolineSize; }
 
-  std::error_code listenForCompileRequests(uint32_t &NextId) {
+  std::error_code listenForCompileRequests(RPCChannel &C, uint32_t &Id) {
     // Check for an 'out-of-band' error, e.g. from an MM destructor.
     if (ExistingError)
       return ExistingError;
 
-    if (auto EC = getNextProcId(Channel, NextId))
-      return EC;
-
-    while (NextId == RequestCompileId) {
-      TargetAddress TrampolineAddr = 0;
-      if (auto EC = handle<RequestCompile>(Channel, readArgs(TrampolineAddr)))
-        return EC;
-
-      TargetAddress ImplAddr = CompileCallback(TrampolineAddr);
-      if (auto EC = call<RequestCompileResponse>(Channel, ImplAddr))
-        return EC;
-
-      if (auto EC = getNextProcId(Channel, NextId))
+    if (Id == RequestCompileId) {
+      if (auto EC = handle<RequestCompile>(C, CompileCallback))
         return EC;
+      return std::error_code();
     }
-
-    return std::error_code();
+    // else
+    return orcError(OrcErrorCode::UnexpectedRPCCall);
   }
 
-  std::error_code readMem(char *Dst, TargetAddress Src, uint64_t Size) {
+  ErrorOr<std::vector<char>> readMem(char *Dst, TargetAddress Src, uint64_t Size) {
     // Check for an 'out-of-band' error, e.g. from an MM destructor.
     if (ExistingError)
       return ExistingError;
 
-    if (auto EC = call<ReadMem>(Channel, Src, Size))
-      return EC;
-
-    if (auto EC = expect<ReadMemResponse>(
-            Channel, [&]() { return Channel.readBytes(Dst, Size); }))
-      return EC;
-
-    return std::error_code();
+    return callST<ReadMem>(Channel, Src, Size);
   }
 
   std::error_code registerEHFrames(TargetAddress &RAddr, uint32_t Size) {
-    return call<RegisterEHFrames>(Channel, RAddr, Size);
+    return callST<RegisterEHFrames>(Channel, RAddr, Size);
   }
 
-  std::error_code reserveMem(TargetAddress &RemoteAddr,
-                             ResourceIdMgr::ResourceId Id, uint64_t Size,
-                             uint32_t Align) {
+  ErrorOr<TargetAddress> reserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
+				    uint32_t Align) {
 
     // Check for an 'out-of-band' error, e.g. from an MM destructor.
     if (ExistingError)
       return ExistingError;
 
-    if (std::error_code EC = call<ReserveMem>(Channel, Id, Size, Align))
-      return EC;
-
-    return expect<ReserveMemResponse>(Channel, readArgs(RemoteAddr));
+    return callST<ReserveMem>(Channel, Id, Size, Align);
   }
 
   std::error_code setProtections(ResourceIdMgr::ResourceId Id,
                                  TargetAddress RemoteSegAddr,
                                  unsigned ProtFlags) {
-    return call<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags);
+    return callST<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags);
   }
 
   std::error_code writeMem(TargetAddress Addr, const char *Src, uint64_t Size) {
@@ -761,15 +707,7 @@ private:
     if (ExistingError)
       return ExistingError;
 
-    // Make the send call.
-    if (auto EC = call<WriteMem>(Channel, Addr, Size))
-      return EC;
-
-    // Follow this up with the section contents.
-    if (auto EC = Channel.appendBytes(Src, Size))
-      return EC;
-
-    return Channel.send();
+    return callST<WriteMem>(Channel, DirectBufferWriter(Src, Addr, Size));
   }
 
   std::error_code writePointer(TargetAddress Addr, TargetAddress PtrVal) {
@@ -777,7 +715,7 @@ private:
     if (ExistingError)
       return ExistingError;
 
-    return call<WritePtr>(Channel, Addr, PtrVal);
+    return callST<WritePtr>(Channel, Addr, PtrVal);
   }
 
   static std::error_code doNothing() { return std::error_code(); }

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h Sun Apr 17 20:06:49 2016
@@ -24,8 +24,48 @@ namespace llvm {
 namespace orc {
 namespace remote {
 
+class DirectBufferWriter {
+public:
+  DirectBufferWriter() = default;
+  DirectBufferWriter(const char *Src, TargetAddress Dst, uint64_t Size)
+    : Src(Src), Dst(Dst), Size(Size) {}
+  
+  const char *getSrc() const { return Src; }
+  TargetAddress getDst() const { return Dst; }
+  uint64_t getSize() const { return Size; }
+private:
+  const char *Src;
+  TargetAddress Dst;
+  uint64_t Size;
+};
+
+inline std::error_code serialize(RPCChannel &C,
+				 const DirectBufferWriter &DBW) {
+  if (auto EC = serialize(C, DBW.getDst()))
+    return EC;
+  if (auto EC = serialize(C, DBW.getSize()))
+    return EC;
+  return C.appendBytes(DBW.getSrc(), DBW.getSize());
+}
+  
+inline std::error_code deserialize(RPCChannel &C,
+				   DirectBufferWriter &DBW) {
+  TargetAddress Dst;
+  if (auto EC = deserialize(C, Dst))
+    return EC;
+  uint64_t Size;
+  if (auto EC = deserialize(C, Size))
+    return EC;
+  char *Addr = reinterpret_cast<char*>(static_cast<uintptr_t>(Dst));
+  
+  DBW = DirectBufferWriter(0, Dst, Size);
+  
+  return C.readBytes(Addr, Size);
+}
+
 class OrcRemoteTargetRPCAPI : public RPC<RPCChannel> {
 protected:
+
   class ResourceIdMgr {
   public:
     typedef uint64_t ResourceId;
@@ -45,146 +85,111 @@ protected:
   };
 
 public:
-  enum JITProcId : uint32_t {
-    InvalidId = 0,
-    CallIntVoidId,
-    CallIntVoidResponseId,
+  enum JITFuncId : uint32_t {
+    InvalidId = RPCFunctionIdTraits<JITFuncId>::InvalidId,
+    CallIntVoidId = RPCFunctionIdTraits<JITFuncId>::FirstValidId,
     CallMainId,
-    CallMainResponseId,
     CallVoidVoidId,
-    CallVoidVoidResponseId,
     CreateRemoteAllocatorId,
     CreateIndirectStubsOwnerId,
     DeregisterEHFramesId,
     DestroyRemoteAllocatorId,
     DestroyIndirectStubsOwnerId,
     EmitIndirectStubsId,
-    EmitIndirectStubsResponseId,
     EmitResolverBlockId,
     EmitTrampolineBlockId,
-    EmitTrampolineBlockResponseId,
     GetSymbolAddressId,
-    GetSymbolAddressResponseId,
     GetRemoteInfoId,
-    GetRemoteInfoResponseId,
     ReadMemId,
-    ReadMemResponseId,
     RegisterEHFramesId,
     ReserveMemId,
-    ReserveMemResponseId,
     RequestCompileId,
-    RequestCompileResponseId,
     SetProtectionsId,
     TerminateSessionId,
     WriteMemId,
     WritePtrId
   };
 
-  static const char *getJITProcIdName(JITProcId Id);
-
-  typedef Procedure<CallIntVoidId, void(TargetAddress Addr)> CallIntVoid;
+  static const char *getJITFuncIdName(JITFuncId Id);
 
-  typedef Procedure<CallIntVoidResponseId, void(int Result)>
-    CallIntVoidResponse;
+  typedef Function<CallIntVoidId, int32_t(TargetAddress Addr)> CallIntVoid;
 
-  typedef Procedure<CallMainId, void(TargetAddress Addr,
-                                     std::vector<std::string> Args)>
+  typedef Function<CallMainId, int32_t(TargetAddress Addr,
+				       std::vector<std::string> Args)>
       CallMain;
 
-  typedef Procedure<CallMainResponseId, void(int Result)> CallMainResponse;
-
-  typedef Procedure<CallVoidVoidId, void(TargetAddress FnAddr)> CallVoidVoid;
-
-  typedef Procedure<CallVoidVoidResponseId, void()> CallVoidVoidResponse;
+  typedef Function<CallVoidVoidId, void(TargetAddress FnAddr)> CallVoidVoid;
 
-  typedef Procedure<CreateRemoteAllocatorId,
-                    void(ResourceIdMgr::ResourceId AllocatorID)>
+  typedef Function<CreateRemoteAllocatorId,
+		   void(ResourceIdMgr::ResourceId AllocatorID)>
       CreateRemoteAllocator;
 
-  typedef Procedure<CreateIndirectStubsOwnerId,
-                    void(ResourceIdMgr::ResourceId StubOwnerID)>
+  typedef Function<CreateIndirectStubsOwnerId,
+		   void(ResourceIdMgr::ResourceId StubOwnerID)>
     CreateIndirectStubsOwner;
 
-  typedef Procedure<DeregisterEHFramesId,
-                    void(TargetAddress Addr, uint32_t Size)>
+  typedef Function<DeregisterEHFramesId,
+		   void(TargetAddress Addr, uint32_t Size)>
       DeregisterEHFrames;
 
-  typedef Procedure<DestroyRemoteAllocatorId,
-                    void(ResourceIdMgr::ResourceId AllocatorID)>
+  typedef Function<DestroyRemoteAllocatorId,
+		   void(ResourceIdMgr::ResourceId AllocatorID)>
       DestroyRemoteAllocator;
 
-  typedef Procedure<DestroyIndirectStubsOwnerId,
-                    void(ResourceIdMgr::ResourceId StubsOwnerID)>
+  typedef Function<DestroyIndirectStubsOwnerId,
+		   void(ResourceIdMgr::ResourceId StubsOwnerID)>
       DestroyIndirectStubsOwner;
 
-  typedef Procedure<EmitIndirectStubsId,
-                    void(ResourceIdMgr::ResourceId StubsOwnerID,
-                         uint32_t NumStubsRequired)>
+  /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted).
+  typedef Function<EmitIndirectStubsId,
+		   std::tuple<TargetAddress, TargetAddress, uint32_t>(
+                        ResourceIdMgr::ResourceId StubsOwnerID,
+			uint32_t NumStubsRequired)>
       EmitIndirectStubs;
 
-  typedef Procedure<EmitIndirectStubsResponseId,
-                    void(TargetAddress StubsBaseAddr,
-                         TargetAddress PtrsBaseAddr,
-                         uint32_t NumStubsEmitted)>
-      EmitIndirectStubsResponse;
+  typedef Function<EmitResolverBlockId, void()> EmitResolverBlock;
 
-  typedef Procedure<EmitResolverBlockId, void()> EmitResolverBlock;
+  /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines).
+  typedef Function<EmitTrampolineBlockId,
+		   std::tuple<TargetAddress, uint32_t>()> EmitTrampolineBlock;
 
-  typedef Procedure<EmitTrampolineBlockId, void()> EmitTrampolineBlock;
-
-  typedef Procedure<EmitTrampolineBlockResponseId,
-                    void(TargetAddress BlockAddr, uint32_t NumTrampolines)>
-      EmitTrampolineBlockResponse;
-
-  typedef Procedure<GetSymbolAddressId, void(std::string SymbolName)>
+  typedef Function<GetSymbolAddressId, TargetAddress(std::string SymbolName)>
       GetSymbolAddress;
 
-  typedef Procedure<GetSymbolAddressResponseId, void(uint64_t SymbolAddr)>
-      GetSymbolAddressResponse;
-
-  typedef Procedure<GetRemoteInfoId, void()> GetRemoteInfo;
-
-  typedef Procedure<GetRemoteInfoResponseId,
-                    void(std::string Triple, uint32_t PointerSize,
-                         uint32_t PageSize, uint32_t TrampolineSize,
-                         uint32_t IndirectStubSize)>
-      GetRemoteInfoResponse;
+  /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize,
+  ///                          IndirectStubsSize).
+  typedef Function<GetRemoteInfoId,
+		   std::tuple<std::string, uint32_t, uint32_t, uint32_t,
+			      uint32_t>()> GetRemoteInfo;
 
-  typedef Procedure<ReadMemId, void(TargetAddress Src, uint64_t Size)>
+  typedef Function<ReadMemId,
+		   std::vector<char>(TargetAddress Src, uint64_t Size)>
       ReadMem;
 
-  typedef Procedure<ReadMemResponseId, void()> ReadMemResponse;
-
-  typedef Procedure<RegisterEHFramesId,
-                    void(TargetAddress Addr, uint32_t Size)>
+  typedef Function<RegisterEHFramesId,
+		   void(TargetAddress Addr, uint32_t Size)>
       RegisterEHFrames;
 
-  typedef Procedure<ReserveMemId,
-                    void(ResourceIdMgr::ResourceId AllocID, uint64_t Size,
-                         uint32_t Align)>
+  typedef Function<ReserveMemId,
+		   TargetAddress(ResourceIdMgr::ResourceId AllocID,
+				 uint64_t Size, uint32_t Align)>
       ReserveMem;
 
-  typedef Procedure<ReserveMemResponseId, void(TargetAddress Addr)>
-      ReserveMemResponse;
-
-  typedef Procedure<RequestCompileId, void(TargetAddress TrampolineAddr)>
+  typedef Function<RequestCompileId,
+		   TargetAddress(TargetAddress TrampolineAddr)>
       RequestCompile;
 
-  typedef Procedure<RequestCompileResponseId, void(TargetAddress ImplAddr)>
-      RequestCompileResponse;
-
-  typedef Procedure<SetProtectionsId,
-                    void(ResourceIdMgr::ResourceId AllocID, TargetAddress Dst,
-                         uint32_t ProtFlags)>
+  typedef Function<SetProtectionsId,
+		   void(ResourceIdMgr::ResourceId AllocID, TargetAddress Dst,
+			uint32_t ProtFlags)>
       SetProtections;
 
-  typedef Procedure<TerminateSessionId, void()> TerminateSession;
+  typedef Function<TerminateSessionId, void()> TerminateSession;
 
-  typedef Procedure<WriteMemId,
-                    void(TargetAddress Dst, uint64_t Size /* Data to follow */)>
+  typedef Function<WriteMemId, void(DirectBufferWriter DB)>
       WriteMem;
 
-  typedef Procedure<WritePtrId, void(TargetAddress Dst, TargetAddress Val)>
+  typedef Function<WritePtrId, void(TargetAddress Dst, TargetAddress Val)>
       WritePtr;
 };
 

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h Sun Apr 17 20:06:49 2016
@@ -45,14 +45,14 @@ public:
         EHFramesRegister(std::move(EHFramesRegister)),
         EHFramesDeregister(std::move(EHFramesDeregister)) {}
 
-  std::error_code getNextProcId(JITProcId &Id) {
+  std::error_code getNextFuncId(JITFuncId &Id) {
     return deserialize(Channel, Id);
   }
 
-  std::error_code handleKnownProcedure(JITProcId Id) {
+  std::error_code handleKnownFunction(JITFuncId Id) {
     typedef OrcRemoteTargetServer ThisT;
 
-    DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n");
+    DEBUG(dbgs() << "Handling known proc: " << getJITFuncIdName(Id) << "\n");
 
     switch (Id) {
     case CallIntVoidId:
@@ -111,27 +111,17 @@ public:
     llvm_unreachable("Unhandled JIT RPC procedure Id.");
   }
 
-  std::error_code requestCompile(TargetAddress &CompiledFnAddr,
-                                 TargetAddress TrampolineAddr) {
-    if (auto EC = call<RequestCompile>(Channel, TrampolineAddr))
-      return EC;
-
-    while (1) {
-      JITProcId Id = InvalidId;
-      if (auto EC = getNextProcId(Id))
-        return EC;
+  ErrorOr<TargetAddress> requestCompile(TargetAddress TrampolineAddr) {
+    auto Listen =
+      [&](RPCChannel &C, uint32_t Id) {
+        return handleKnownFunction(static_cast<JITFuncId>(Id));
+      };
 
-      switch (Id) {
-      case RequestCompileResponseId:
-        return handle<RequestCompileResponse>(Channel,
-                                              readArgs(CompiledFnAddr));
-      default:
-        if (auto EC = handleKnownProcedure(Id))
-          return EC;
-      }
-    }
+    return callSTHandling<RequestCompile>(Channel, Listen, TrampolineAddr);
+  }
 
-    llvm_unreachable("Fell through request-compile command loop.");
+  void handleTerminateSession() {
+    handle<TerminateSession>(Channel, [](){ return std::error_code(); });
   }
 
 private:
@@ -175,18 +165,16 @@ private:
   static std::error_code doNothing() { return std::error_code(); }
 
   static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) {
-    TargetAddress CompiledFnAddr = 0;
-
     auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr);
-    auto EC = T->requestCompile(
-        CompiledFnAddr, static_cast<TargetAddress>(
-                            reinterpret_cast<uintptr_t>(TrampolineAddr)));
-    assert(!EC && "Compile request failed");
-    (void)EC;
-    return CompiledFnAddr;
+    auto AddrOrErr = T->requestCompile(
+		       static_cast<TargetAddress>(
+		         reinterpret_cast<uintptr_t>(TrampolineAddr)));
+    // FIXME: Allow customizable failure substitution functions.
+    assert(AddrOrErr && "Compile request failed");
+    return *AddrOrErr;
   }
 
-  std::error_code handleCallIntVoid(TargetAddress Addr) {
+  ErrorOr<int32_t> handleCallIntVoid(TargetAddress Addr) {
     typedef int (*IntVoidFnTy)();
     IntVoidFnTy Fn =
         reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
@@ -195,11 +183,11 @@ private:
     int Result = Fn();
     DEBUG(dbgs() << "  Result = " << Result << "\n");
 
-    return call<CallIntVoidResponse>(Channel, Result);
+    return Result;
   }
 
-  std::error_code handleCallMain(TargetAddress Addr,
-                                 std::vector<std::string> Args) {
+  ErrorOr<int32_t> handleCallMain(TargetAddress Addr,
+				  std::vector<std::string> Args) {
     typedef int (*MainFnTy)(int, const char *[]);
 
     MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
@@ -214,7 +202,7 @@ private:
     int Result = Fn(ArgC, ArgV.get());
     DEBUG(dbgs() << "  Result = " << Result << "\n");
 
-    return call<CallMainResponse>(Channel, Result);
+    return Result;
   }
 
   std::error_code handleCallVoidVoid(TargetAddress Addr) {
@@ -226,7 +214,7 @@ private:
     Fn();
     DEBUG(dbgs() << "  Complete.\n");
 
-    return call<CallVoidVoidResponse>(Channel);
+    return std::error_code();
   }
 
   std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) {
@@ -273,8 +261,9 @@ private:
     return std::error_code();
   }
 
-  std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
-                                          uint32_t NumStubsRequired) {
+  ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>>
+  handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
+			  uint32_t NumStubsRequired) {
     DEBUG(dbgs() << "  ISMgr " << Id << " request " << NumStubsRequired
                  << " stubs.\n");
 
@@ -296,8 +285,7 @@ private:
     auto &BlockList = StubOwnerItr->second;
     BlockList.push_back(std::move(IS));
 
-    return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase,
-                                           NumStubsEmitted);
+    return std::make_tuple(StubsBase, PtrsBase, NumStubsEmitted);
   }
 
   std::error_code handleEmitResolverBlock() {
@@ -316,7 +304,8 @@ private:
                                                 sys::Memory::MF_EXEC);
   }
 
-  std::error_code handleEmitTrampolineBlock() {
+  ErrorOr<std::tuple<TargetAddress, uint32_t>>
+  handleEmitTrampolineBlock() {
     std::error_code EC;
     auto TrampolineBlock =
         sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
@@ -325,7 +314,7 @@ private:
     if (EC)
       return EC;
 
-    unsigned NumTrampolines =
+    uint32_t NumTrampolines =
         (sys::Process::getPageSize() - TargetT::PointerSize) /
         TargetT::TrampolineSize;
 
@@ -339,20 +328,21 @@ private:
 
     TrampolineBlocks.push_back(std::move(TrampolineBlock));
 
-    return call<EmitTrampolineBlockResponse>(
-        Channel,
-        static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)),
-        NumTrampolines);
+    auto TrampolineBaseAddr =
+      static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem));
+
+    return std::make_tuple(TrampolineBaseAddr, NumTrampolines);
   }
 
-  std::error_code handleGetSymbolAddress(const std::string &Name) {
+  ErrorOr<TargetAddress> handleGetSymbolAddress(const std::string &Name) {
     TargetAddress Addr = SymbolLookup(Name);
     DEBUG(dbgs() << "  Symbol '" << Name << "' =  " << format("0x%016x", Addr)
                  << "\n");
-    return call<GetSymbolAddressResponse>(Channel, Addr);
+    return Addr; // Presumably sent back as the RPC reply by handle<>.
   }
 
-  std::error_code handleGetRemoteInfo() {
+  ErrorOr<std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>>
+  handleGetRemoteInfo() {
     std::string ProcessTriple = sys::getProcessTriple();
     uint32_t PointerSize = TargetT::PointerSize;
     uint32_t PageSize = sys::Process::getPageSize();
@@ -364,24 +354,23 @@ private:
                  << "    page size          = " << PageSize << "\n"
                  << "    trampoline size    = " << TrampolineSize << "\n"
                  << "    indirect stub size = " << IndirectStubSize << "\n");
-    return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize,
-                                       PageSize, TrampolineSize,
-                                       IndirectStubSize);
+    return std::make_tuple(ProcessTriple, PointerSize, PageSize ,TrampolineSize,
+			   IndirectStubSize);
   }
 
-  std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) {
+  ErrorOr<std::vector<char>>
+  handleReadMem(TargetAddress RSrc, uint64_t Size) {
     char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
 
     DEBUG(dbgs() << "  Reading " << Size << " bytes from "
                  << format("0x%016x", RSrc) << "\n");
 
-    if (auto EC = call<ReadMemResponse>(Channel))
-      return EC;
-
-    if (auto EC = Channel.appendBytes(Src, Size))
-      return EC;
+    // Copy exactly Size bytes starting at Src. Building the vector from the
+    // range avoids the resize()-then-push_back() trap, which would yield
+    // 2 * Size bytes with Size leading zeros.
+    std::vector<char> Buffer(Src, Src + Size);
 
-    return Channel.send();
+    return Buffer;
   }
 
   std::error_code handleRegisterEHFrames(TargetAddress TAddr, uint32_t Size) {
@@ -392,8 +381,9 @@ private:
     return std::error_code();
   }
 
-  std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
-                                   uint32_t Align) {
+  ErrorOr<TargetAddress>
+  handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
+		   uint32_t Align) {
     auto I = Allocators.find(Id);
     if (I == Allocators.end())
       return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
@@ -408,7 +398,7 @@ private:
     TargetAddress AllocAddr =
         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
 
-    return call<ReserveMemResponse>(Channel, AllocAddr);
+    return AllocAddr;
   }
 
   std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id,
@@ -425,11 +415,10 @@ private:
     return Allocator.setProtections(LocalAddr, Flags);
   }
 
-  std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) {
-    char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
-    DEBUG(dbgs() << "  Writing " << Size << " bytes to "
-                 << format("0x%016x", RDst) << "\n");
-    return Channel.readBytes(Dst, Size);
+  std::error_code handleWriteMem(DirectBufferWriter DBW) {
+    DEBUG(dbgs() << "  Writing " << DBW.getSize() << " bytes to "
+                 << format("0x%016x", DBW.getDst()) << "\n");
+    return std::error_code(); // NOTE: DBW presumably wrote the data already.
   }
 
   std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) {

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h Sun Apr 17 20:06:49 2016
@@ -5,8 +5,10 @@
 
 #include "OrcError.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Endian.h"
 
+#include <mutex>
 #include <system_error>
 
 namespace llvm {
@@ -26,31 +28,68 @@ public:
 
   /// Flush the stream if possible.
   virtual std::error_code send() = 0;
+
+  /// Get the lock held while reading a message (see startReceiveMessage).
+  std::mutex &getReadLock() { return ReadLock; }
+
+  /// Get the lock held while writing a message (see startSendMessage).
+  std::mutex &getWriteLock() { return WriteLock; }
+
+private:
+  std::mutex ReadLock, WriteLock;
 };
 
+/// Begin sending a message on C by acquiring its write lock.
+/// Pair with endSendMessage; always reports success.
+inline std::error_code startSendMessage(RPCChannel &C) {
+  C.getWriteLock().lock();
+  return {};
+}
+
+/// Finish sending a message on C by releasing its write lock.
+/// Pairs with startSendMessage; always reports success.
+inline std::error_code endSendMessage(RPCChannel &C) {
+  C.getWriteLock().unlock();
+  return {};
+}
+
+/// Begin receiving a message on C by acquiring its read lock.
+/// Pair with endReceiveMessage; always reports success.
+inline std::error_code startReceiveMessage(RPCChannel &C) {
+  C.getReadLock().lock();
+  return {};
+}
+
+/// Finish receiving a message on C by releasing its read lock.
+/// Pairs with startReceiveMessage; always reports success.
+inline std::error_code endReceiveMessage(RPCChannel &C) {
+  C.getReadLock().unlock();
+  return {};
+}
+
 /// RPC channel serialization for a variadic list of arguments.
 template <typename T, typename... Ts>
-std::error_code serialize_seq(RPCChannel &C, const T &Arg, const Ts &... Args) {
+std::error_code serializeSeq(RPCChannel &C, const T &Arg, const Ts &... Args) {
   if (auto EC = serialize(C, Arg))
     return EC;
-  return serialize_seq(C, Args...);
+  return serializeSeq(C, Args...);
 }
 
 /// RPC channel serialization for an (empty) variadic list of arguments.
-inline std::error_code serialize_seq(RPCChannel &C) {
+inline std::error_code serializeSeq(RPCChannel &C) {
   return std::error_code();
 }
 
 /// RPC channel deserialization for a variadic list of arguments.
 template <typename T, typename... Ts>
-std::error_code deserialize_seq(RPCChannel &C, T &Arg, Ts &... Args) {
+std::error_code deserializeSeq(RPCChannel &C, T &Arg, Ts &... Args) {
   if (auto EC = deserialize(C, Arg))
     return EC;
-  return deserialize_seq(C, Args...);
+  return deserializeSeq(C, Args...);
 }
 
-/// RPC channel serialization for an (empty) variadic list of arguments.
-inline std::error_code deserialize_seq(RPCChannel &C) {
+/// RPC channel deserialization for an (empty) variadic list of arguments.
+inline std::error_code deserializeSeq(RPCChannel &C) {
   return std::error_code();
 }
 
@@ -138,6 +177,34 @@ inline std::error_code deserialize(RPCCh
   return C.readBytes(&S[0], Count);
 }
 
+// Serialization helper for std::tuple: serializes get<Is>(V)... in order.
+template <typename TupleT, size_t... Is>
+inline std::error_code serializeTupleHelper(RPCChannel &C,
+                                            const TupleT &V,
+                                            llvm::index_sequence<Is...> _) {
+  return serializeSeq(C, std::get<Is>(V)...);
+}
+
+/// RPC channel serialization for std::tuple.
+template <typename... ArgTs>
+inline std::error_code serialize(RPCChannel &C, const std::tuple<ArgTs...> &V) {
+  return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
+}
+
+// Deserialization helper for std::tuple: fills get<Is>(V)... in order.
+template <typename TupleT, size_t... Is>
+inline std::error_code deserializeTupleHelper(RPCChannel &C,
+                                              TupleT &V,
+                                              llvm::index_sequence<Is...> _) {
+  return deserializeSeq(C, std::get<Is>(V)...);
+}
+
+/// RPC channel deserialization for std::tuple.
+template <typename... ArgTs>
+inline std::error_code deserialize(RPCChannel &C, std::tuple<ArgTs...> &V) {
+  return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
+}
+
 /// RPC channel serialization for ArrayRef<T>.
 template <typename T>
 std::error_code serialize(RPCChannel &C, const ArrayRef<T> &A) {

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h Sun Apr 17 20:06:49 2016
@@ -14,46 +14,197 @@
 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
 #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
 
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ExecutionEngine/Orc/OrcError.h"
+#include "llvm/Support/ErrorOr.h"
+#include <future>
+#include <map>
 
 namespace llvm {
 namespace orc {
 namespace remote {
 
+/// Describes reserved RPC Function Ids.
+///
+/// The default implementation serves integer and enum function id types,
+/// using 0 for InvalidId, 1 for ResponseId and 2 for FirstValidId. To use a
+/// custom FunctionId type, specialize this class and provide unique values
+/// for InvalidId, ResponseId and FirstValidId.
+
+template <typename T>
+class RPCFunctionIdTraits {
+public:
+  constexpr static const T InvalidId = static_cast<T>(0);
+  constexpr static const T ResponseId = static_cast<T>(1);
+  constexpr static const T FirstValidId = static_cast<T>(2);
+};
+
 // Base class containing utilities that require partial specialization.
 // These cannot be included in RPC, as template class members cannot be
 // partially specialized.
 class RPCBase {
 protected:
-  template <typename ProcedureIdT, ProcedureIdT ProcId, typename FnT>
-  class ProcedureHelper {
-  public:
-    static const ProcedureIdT Id = ProcId;
-  };
 
-  template <typename ChannelT, typename Proc> class CallHelper;
+  // RPC Function description type.
+  //
+  // This class provides the information and operations needed to support the
+  // RPC primitive operations (call, expect, etc) for a given function. It
+  // is specialized for void and non-void functions to deal with the differences
+  // between the two. Both specializations have the same interface:
+  //
+  // Id - The function's unique identifier.
+  // OptionalReturn - Value type of the future returned by asynchronous calls.
+  // ErrorReturn - The return type for synchronous calls.
+  // optionalToErrorReturn - Conversion from a valid OptionalReturn to an
+  //                         ErrorReturn.
+  // readResult - Deserialize a result from a channel.
+  // abandon - Abandon a promised (asynchronous) result.
+  // respond - Return a result on the channel.
+  template <typename FunctionIdT, FunctionIdT FuncId, typename FnT>
+  class FunctionHelper {};
 
-  template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId,
+  // RPC Function description specialization for non-void functions.
+  template <typename FunctionIdT, FunctionIdT FuncId, typename RetT,
             typename... ArgTs>
-  class CallHelper<ChannelT,
-                   ProcedureHelper<ProcedureIdT, ProcId, void(ArgTs...)>> {
+  class FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> {
   public:
-    static std::error_code call(ChannelT &C, const ArgTs &... Args) {
-      if (auto EC = serialize(C, ProcId))
+
+    static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId &&
+                  FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId,
+                  "Cannot define custom function with InvalidId or ResponseId. "
+                  "Please use RPCFunctionIdTraits<FunctionIdT>::FirstValidId.");
+
+    static const FunctionIdT Id = FuncId;
+
+    typedef Optional<RetT> OptionalReturn;
+
+    typedef ErrorOr<RetT> ErrorReturn;
+
+    static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) {
+      assert(V && "Return value not available");
+      return std::move(*V);
+    }
+
+    template <typename ChannelT>
+    static std::error_code readResult(ChannelT &C,
+                                      std::promise<OptionalReturn> &P) {
+      RetT Val;
+      auto EC = deserialize(C, Val);
+      // FIXME: Join error EC2 from endReceiveMessage with the deserialize
+      //        error once we switch to using Error.
+      auto EC2 = endReceiveMessage(C);
+      (void)EC2;
+
+      if (EC) {
+        P.set_value(OptionalReturn());
         return EC;
-      // If you see a compile-error on this line you're probably calling a
-      // function with the wrong signature.
-      return serialize_seq(C, Args...);
+      }
+      P.set_value(std::move(Val));
+      return std::error_code();
+    }
+
+    static void abandon(std::promise<OptionalReturn> &P) {
+      P.set_value(OptionalReturn());
+    }
+
+    template <typename ChannelT, typename SequenceNumberT>
+    static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo,
+                                   const ErrorReturn &Result) {
+      FunctionIdT ResponseId =
+        RPCFunctionIdTraits<FunctionIdT>::ResponseId;
+
+      // If the handler returned an error then bail out with that.
+      if (!Result)
+        return Result.getError();
+
+      // Otherwise open a new message on the channel and send the result.
+      if (auto EC = startSendMessage(C))
+        return EC;
+      auto EC = serializeSeq(C, ResponseId, SeqNo, *Result);
+      auto EndEC = endSendMessage(C); // Unlock even if serialization failed.
+      return EC ? EC : EndEC;
     }
   };
 
-  template <typename ChannelT, typename Proc> class HandlerHelper;
+  // RPC Function description specialization for void functions.
+  template <typename FunctionIdT, FunctionIdT FuncId, typename... ArgTs>
+  class FunctionHelper<FunctionIdT, FuncId, void(ArgTs...)> {
+  public:
 
-  template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId,
-            typename... ArgTs>
-  class HandlerHelper<ChannelT,
-                      ProcedureHelper<ProcedureIdT, ProcId, void(ArgTs...)>> {
+    static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId &&
+                  FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId,
+                  "Cannot define custom function with InvalidId or ResponseId. "
+                  "Please use RPCFunctionIdTraits<FunctionIdT>::FirstValidId.");
+
+    static const FunctionIdT Id = FuncId;
+
+    typedef bool OptionalReturn;
+    typedef std::error_code ErrorReturn;
+
+    static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) {
+      assert(V && "Return value not available");
+      return std::error_code();
+    }
+
+    template <typename ChannelT>
+    static std::error_code readResult(ChannelT &C,
+                                      std::promise<OptionalReturn> &P) {
+      // Void functions don't have anything to deserialize, so we're good.
+      P.set_value(true);
+      return endReceiveMessage(C);
+    }
+
+    static void abandon(std::promise<OptionalReturn> &P) {
+      P.set_value(false);
+    }
+
+    template <typename ChannelT, typename SequenceNumberT>
+    static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo,
+                                   const ErrorReturn &Result) {
+      const FunctionIdT ResponseId =
+        RPCFunctionIdTraits<FunctionIdT>::ResponseId;
+
+      // If the handler returned an error then bail out with that.
+      if (Result)
+        return Result;
+
+      // Otherwise open a new message on the channel and send the result.
+      if (auto EC = startSendMessage(C))
+        return EC;
+      auto EC = serializeSeq(C, ResponseId, SeqNo);
+      auto EndEC = endSendMessage(C); // Unlock even if serialization failed.
+      return EC ? EC : EndEC;
+    }
+  };
+
+  // Helper for the call primitives: sends FuncId, SeqNo, Args as one message.
+  template <typename ChannelT, typename SequenceNumberT, typename Func>
+  class CallHelper;
+
+  template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT,
+            FunctionIdT FuncId, typename RetT, typename... ArgTs>
+  class CallHelper<ChannelT, SequenceNumberT,
+                   FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> {
+  public:
+    static std::error_code call(ChannelT &C, SequenceNumberT SeqNo,
+                                const ArgTs &... Args) {
+      if (auto EC = startSendMessage(C))
+        return EC;
+      auto EC = serializeSeq(C, FuncId, SeqNo, Args...);
+      auto EndEC = endSendMessage(C); // Unlock even if serialization failed.
+      return EC ? EC : EndEC;
+    }
+  };
+
+  // Helper for handle primitive.
+  template <typename ChannelT, typename SequenceNumberT, typename Func>
+  class HandlerHelper;
+
+  template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT,
+	    FunctionIdT FuncId, typename RetT, typename... ArgTs>
+  class HandlerHelper<ChannelT, SequenceNumberT,
+                      FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> {
   public:
     template <typename HandlerT>
     static std::error_code handle(ChannelT &C, HandlerT Handler) {
@@ -61,34 +212,46 @@ protected:
     }
 
   private:
+
+    typedef FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> Func;
+
     template <typename HandlerT, size_t... Is>
     static std::error_code readAndHandle(ChannelT &C, HandlerT Handler,
                                          llvm::index_sequence<Is...> _) {
       std::tuple<ArgTs...> RPCArgs;
+      SequenceNumberT SeqNo;
       // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
       // for RPCArgs. Void cast RPCArgs to work around this for now.
       // FIXME: Remove this workaround once we can assume a working GCC version.
       (void)RPCArgs;
-      if (auto EC = deserialize_seq(C, std::get<Is>(RPCArgs)...))
-        return EC;
-      return Handler(std::get<Is>(RPCArgs)...);
+      auto EC = deserializeSeq(C, SeqNo, std::get<Is>(RPCArgs)...);
+
+      // Unlock the channel for reading before we call the handler (this
+      // allows recursive RPC calls), and also if deserialization failed.
+      auto EndEC = endReceiveMessage(C);
+      if (EC || EndEC)
+        return EC ? EC : EndEC;
+
+      return Func::template respond<ChannelT, SequenceNumberT>(
+                     C, SeqNo, Handler(std::get<Is>(RPCArgs)...));
     }
+
   };
 
-  template <typename ClassT, typename... ArgTs> class MemberFnWrapper {
+  // Adapts (Instance, Method) into a functor, e.g. for handle<> handlers.
+  template <typename ClassT, typename RetT, typename... ArgTs>
+  class MemberFnWrapper {
   public:
-    typedef std::error_code (ClassT::*MethodT)(ArgTs...);
+    typedef RetT(ClassT::*MethodT)(ArgTs...);
     MemberFnWrapper(ClassT &Instance, MethodT Method)
         : Instance(Instance), Method(Method) {}
-    std::error_code operator()(ArgTs &... Args) {
-      return (Instance.*Method)(Args...);
-    }
-
+    RetT operator()(ArgTs &... Args) { return (Instance.*Method)(Args...); }
   private:
     ClassT &Instance;
     MethodT Method;
   };
 
+  // Helper that provides a Functor for deserializing arguments.
   template <typename... ArgTs> class ReadArgs {
   public:
     std::error_code operator()() { return std::error_code(); }
@@ -112,7 +275,7 @@ protected:
 
 /// Contains primitive utilities for defining, calling and handling calls to
 /// remote procedures. ChannelT is a bidirectional stream conforming to the
-/// RPCChannel interface (see RPCChannel.h), and ProcedureIdT is a procedure
+/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a procedure
 /// identifier type that must be serializable on ChannelT.
 ///
 /// These utilities support the construction of very primitive RPC utilities.
@@ -129,120 +292,184 @@ protected:
 ///
 /// Overview (see comments individual types/methods for details):
 ///
-/// Procedure<Id, Args...> :
+/// Function<Id, RetT(Args...)> :
 ///
 ///   associates a unique serializable id with an argument list.
 ///
 ///
-/// call<Proc>(Channel, Args...) :
+/// callAsync<Func>(Channel, Args...) / callST<Func>(Channel, Args...) :
 ///
-///   Calls the remote procedure 'Proc' by serializing Proc's id followed by its
+///   Calls the remote procedure 'Func' by serializing Func's id followed by its
 /// arguments and sending the resulting bytes to 'Channel'.
 ///
 ///
-/// handle<Proc>(Channel, <functor matching std::error_code(Args...)> :
+/// handle<Func>(Channel, <functor matching std::error_code(Args...)> :
 ///
-///   Handles a call to 'Proc' by deserializing its arguments and calling the
-/// given functor. This assumes that the id for 'Proc' has already been
+///   Handles a call to 'Func' by deserializing its arguments and calling the
+/// given functor. This assumes that the id for 'Func' has already been
 /// deserialized.
 ///
-/// expect<Proc>(Channel, <functor matching std::error_code(Args...)> :
+/// expect<Func>(Channel, <functor matching std::error_code(Args...)> :
 ///
 ///   The same as 'handle', except that the procedure id should not have been
-/// read yet. Expect will deserialize the id and assert that it matches Proc's
+/// read yet. Expect will deserialize the id and assert that it matches Func's
 /// id. If it does not, and unexpected RPC call error is returned.
-
-template <typename ChannelT, typename ProcedureIdT = uint32_t>
+template <typename ChannelT, typename FunctionIdT = uint32_t,
+          typename SequenceNumberT = uint16_t>
 class RPC : public RPCBase {
 public:
+  // Move-only: moving transfers pending results and sequence-number state.
+  RPC() = default;
+  RPC(const RPC&) = delete;
+  RPC& operator=(const RPC&) = delete;
+  RPC(RPC &&Other) : SequenceNumberMgr(std::move(Other.SequenceNumberMgr)), OutstandingResults(std::move(Other.OutstandingResults)) {}
+  RPC& operator=(RPC&&) = default;
+
   /// Utility class for defining/referring to RPC procedures.
   ///
   /// Typedefs of this utility are used when calling/handling remote procedures.
   ///
-  /// ProcId should be a unique value of ProcedureIdT (i.e. not used with any
-  /// other Procedure typedef in the RPC API being defined.
+  /// FuncId should be a unique value of FunctionIdT (not InvalidId/ResponseId,
+  /// and not used with any other Function typedef in the API being defined).
   ///
-  /// the template argument Ts... gives the argument list for the remote
-  /// procedure.
+  /// The template argument FnT gives the function's signature, RetT(ArgTs...),
+  /// where RetT may be void.
   ///
   /// E.g.
   ///
-  ///   typedef Procedure<0, bool> Proc1;
-  ///   typedef Procedure<1, std::string, std::vector<int>> Proc2;
+  ///   typedef Function<2, void(bool)> Func1;
+  ///   typedef Function<3, void(std::string, std::vector<int>)> Func2;
   ///
-  ///   if (auto EC = call<Proc1>(Channel, true))
+  ///   if (auto EC = callST<Func1>(Channel, true))
   ///     /* handle EC */;
   ///
-  ///   if (auto EC = expect<Proc2>(Channel,
+  ///   if (auto EC = expect<Func2>(Channel,
   ///         [](std::string &S, std::vector<int> &V) {
   ///           // Stuff.
   ///           return std::error_code();
   ///         })
   ///     /* handle EC */;
   ///
-  template <ProcedureIdT ProcId, typename FnT>
-  using Procedure = ProcedureHelper<ProcedureIdT, ProcId, FnT>;
+  template <FunctionIdT FuncId, typename FnT>
+  using Function = FunctionHelper<FunctionIdT, FuncId, FnT>;
+
+  /// Async call result: a future for the return value, plus the call's SeqNo.
+  template <typename Func>
+  using AsyncCallResult =
+    std::pair<std::future<typename Func::OptionalReturn>, SequenceNumberT>;
 
   /// Serialize Args... to channel C, but do not call C.send().
   ///
-  /// For buffered channels, this can be used to queue up several calls before
-  /// flushing the channel.
-  template <typename Proc, typename... ArgTs>
-  static std::error_code appendCall(ChannelT &C, const ArgTs &... Args) {
-    return CallHelper<ChannelT, Proc>::call(C, Args...);
+  /// Returns the result future (std::future<bool> for void functions,
+  /// std::future<Optional<R>> for R-returning ones) and the call's SeqNo.
+  template <typename Func, typename... ArgTs>
+  ErrorOr<AsyncCallResult<Func>>
+  appendCallAsync(ChannelT &C, const ArgTs &... Args) {
+    auto SeqNo = SequenceNumberMgr.getSequenceNumber();
+    std::promise<typename Func::OptionalReturn> Promise;
+    auto Result = Promise.get_future();
+    OutstandingResults[SeqNo] =
+      createOutstandingResult<Func>(std::move(Promise));
+    if (auto EC =
+          CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo,
+                                                            Args...)) {
+      abandonOutstandingResults();
+      return EC;
+    }
+    return AsyncCallResult<Func>(std::move(Result), SeqNo);
   }
 
   /// Serialize Args... to channel C and call C.send().
-  template <typename Proc, typename... ArgTs>
-  static std::error_code call(ChannelT &C, const ArgTs &... Args) {
-    if (auto EC = appendCall<Proc>(C, Args...))
+  template <typename Func, typename... ArgTs>
+  ErrorOr<AsyncCallResult<Func>>
+  callAsync(ChannelT &C, const ArgTs &... Args) {
+    auto SeqNo = SequenceNumberMgr.getSequenceNumber(); // Tags the response.
+    std::promise<typename Func::OptionalReturn> Promise;
+    auto Result = Promise.get_future();
+    OutstandingResults[SeqNo] =
+      createOutstandingResult<Func>(std::move(Promise));
+    if (auto EC =
+        CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, Args...)) {
+      abandonOutstandingResults();
       return EC;
-    return C.send();
+    }
+    if (auto EC = C.send()) { // Unlike appendCallAsync, flush now.
+      abandonOutstandingResults();
+      return EC;
+    }
+    return AsyncCallResult<Func>(std::move(Result), SeqNo);
   }
 
-  /// Deserialize and return an enum whose underlying type is ProcedureIdT.
-  static std::error_code getNextProcId(ChannelT &C, ProcedureIdT &Id) {
+  /// Blocking call for single-threaded code; other calls go to HandleOther.
+  template <typename Func, typename HandleFtor, typename... ArgTs>
+  typename Func::ErrorReturn
+  callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) {
+    auto ResultAndSeqNoOrErr = callAsync<Func>(C, Args...);
+    if (!ResultAndSeqNoOrErr)
+      return ResultAndSeqNoOrErr.getError();
+    auto &ResultAndSeqNo = *ResultAndSeqNoOrErr;
+    if (auto EC = waitForResult(C, ResultAndSeqNo.second, HandleOther))
+      return EC;
+    return Func::optionalToErrorReturn(ResultAndSeqNo.first.get());
+  }
+
+  /// As callSTHandling, but any other incoming call is passed to handleNone.
+  template <typename Func, typename... ArgTs>
+  typename Func::ErrorReturn
+  callST(ChannelT &C, const ArgTs &... Args) {
+    return callSTHandling<Func>(C, handleNone, Args...);
+  }
+
+  /// Start receiving a new function call: locks the channel for reading, then
+  /// deserializes a FunctionId into Id (releasing the lock if that fails).
+  std::error_code startReceivingFunction(ChannelT &C, FunctionIdT &Id) {
+    if (auto EC = startReceiveMessage(C))
+      return EC;
-    return deserialize(C, Id);
+    auto EC = deserialize(C, Id);
+    if (EC)
+      endReceiveMessage(C);
+    return EC;
   }
 
-  /// Deserialize args for Proc from C and call Handler. The signature of
+  /// Deserialize args for Func from C and call Handler. The signature of
   /// handler must conform to 'std::error_code(Args...)' where Args... matches
-  /// the arguments used in the Proc typedef.
-  template <typename Proc, typename HandlerT>
+  /// the Func typedef's arguments; a non-void Func's handler returns ErrorOr<R>.
+  template <typename Func, typename HandlerT>
   static std::error_code handle(ChannelT &C, HandlerT Handler) {
-    return HandlerHelper<ChannelT, Proc>::handle(C, Handler);
+    return HandlerHelper<ChannelT, SequenceNumberT, Func>::handle(C, Handler);
   }
 
   /// Helper version of 'handle' for calling member functions.
-  template <typename Proc, typename ClassT, typename... ArgTs>
+  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
   static std::error_code
   handle(ChannelT &C, ClassT &Instance,
-         std::error_code (ClassT::*HandlerMethod)(ArgTs...)) {
-    return handle<Proc>(
-        C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod));
+         RetT (ClassT::*HandlerMethod)(ArgTs...)) {
+    // Bind Instance to HandlerMethod so the generic handle<> can invoke it.
+    return handle<Func>(
+        C, MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod));
   }
 
-  /// Deserialize a ProcedureIdT from C and verify it matches the id for Proc.
-  /// If the id does match, deserialize the arguments and call the handler
-  /// (similarly to handle).
-  /// If the id does not match, return an unexpect RPC call error and do not
-  /// deserialize any further bytes.
-  template <typename Proc, typename HandlerT>
-  static std::error_code expect(ChannelT &C, HandlerT Handler) {
-    ProcedureIdT ProcId;
-    if (auto EC = getNextProcId(C, ProcId))
-      return EC;
-    if (ProcId != Proc::Id)
+  /// Deserialize a FunctionIdT from C and verify it matches the id for Func.
+  /// If it does, deserialize the arguments and call the handler (similarly to
+  /// handle). If not, unlock the channel and return an UnexpectedRPCCall error.
+  template <typename Func, typename HandlerT>
+  std::error_code expect(ChannelT &C, HandlerT Handler) {
+    FunctionIdT FuncId;
+    if (auto EC = startReceivingFunction(C, FuncId))
+      return EC;
+    if (FuncId != Func::Id) {
+      endReceiveMessage(C); // Don't leave the channel locked for reading.
       return orcError(OrcErrorCode::UnexpectedRPCCall);
+    }
-    return handle<Proc>(C, Handler);
+    return handle<Func>(C, Handler);
   }
 
   /// Helper version of expect for calling member functions.
-  template <typename Proc, typename ClassT, typename... ArgTs>
+  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
-  static std::error_code
-  expect(ChannelT &C, ClassT &Instance,
-         std::error_code (ClassT::*HandlerMethod)(ArgTs...)) {
-    return expect<Proc>(
-        C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod));
+  std::error_code
+  expect(ChannelT &C, ClassT &Instance,
+         RetT (ClassT::*HandlerMethod)(ArgTs...)) {
+    return expect<Func>(
+        C, MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod));
   }
 
@@ -251,18 +478,163 @@ public:
   /// channel.
   /// E.g.
   ///
-  ///   typedef Procedure<0, bool, int> Proc1;
+  ///   typedef Function<2, void(bool, int)> Func1;
   ///
   ///   ...
   ///   bool B;
   ///   int I;
-  ///   if (auto EC = expect<Proc1>(Channel, readArgs(B, I)))
+  ///   if (auto EC = expect<Func1>(Channel, readArgs(B, I)))
   ///     /* Handle Args */ ;
   ///
   template <typename... ArgTs>
   static ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
     return ReadArgs<ArgTs...>(Args...);
   }
+
+  /// Read a response (sequence number, then result) from C. On any failure,
+  /// every outstanding result is abandoned. Call this from the receive loop.
+  std::error_code handleResponse(ChannelT &C, SequenceNumberT &SeqNo) {
+    if (auto EC = deserialize(C, SeqNo)) {
+      abandonOutstandingResults();
+      return EC;
+    }
+
+    auto I = OutstandingResults.find(SeqNo);
+    if (I == OutstandingResults.end()) {
+      abandonOutstandingResults();
+      return orcError(OrcErrorCode::UnexpectedRPCResponse);
+    }
+
+    if (auto EC = I->second->readResult(C)) {
+      abandonOutstandingResults();
+      // abandonOutstandingResults() already reset SequenceNumberMgr.
+      return EC;
+    }
+    // Success: retire the entry and recycle its sequence number.
+    OutstandingResults.erase(I);
+    SequenceNumberMgr.releaseSequenceNumber(SeqNo);
+
+    return std::error_code();
+  }
+
+  // Receive loop that runs until the response for TgtSeqNo has been handled.
+  // Non-response calls are dispatched to HandleOther (default: handleNone).
+  template <typename HandleOtherFtor =
+                std::error_code (*)(ChannelT &, FunctionIdT)>
+  std::error_code waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo,
+                                HandleOtherFtor &&HandleOther = handleNone) {
+    bool GotTgtResult = false;
+
+    while (!GotTgtResult) {
+      FunctionIdT Id = RPCFunctionIdTraits<FunctionIdT>::InvalidId;
+      if (auto EC = startReceivingFunction(C, Id))
+        return EC;
+      if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) {
+        SequenceNumberT SeqNo;
+        if (auto EC = handleResponse(C, SeqNo))
+          return EC;
+        GotTgtResult = (SeqNo == TgtSeqNo);
+      } else if (auto EC = HandleOther(C, Id))
+        return EC;
+    }
+
+    return std::error_code();
+  }
+
+  // Default 'other' handler for waitForResult: rejects any non-response
+  // call with an UnexpectedRPCCall error.
+  static std::error_code handleNone(ChannelT&, FunctionIdT) {
+    return orcError(OrcErrorCode::UnexpectedRPCCall);
+  }
+
+private:
+  // Allocates sequence numbers, reusing released ones (LIFO). The accessors
+  // lock SeqNoLock; the move operations do not.
+  class SequenceNumberManager {
+  public:
+    SequenceNumberManager() = default;
+
+    SequenceNumberManager(SequenceNumberManager &&Other)
+      : NextSequenceNumber(std::move(Other.NextSequenceNumber)),
+        FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {}
+
+    SequenceNumberManager& operator=(SequenceNumberManager &&Other) {
+      NextSequenceNumber = std::move(Other.NextSequenceNumber);
+      FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers);
+      return *this;
+    }
+
+    void reset() {
+      std::lock_guard<std::mutex> Lock(SeqNoLock);
+      NextSequenceNumber = 0;
+      FreeSequenceNumbers.clear();
+    }
+
+    SequenceNumberT getSequenceNumber() {
+      std::lock_guard<std::mutex> Lock(SeqNoLock);
+      if (FreeSequenceNumbers.empty())
+        return NextSequenceNumber++;
+      auto SequenceNumber = FreeSequenceNumbers.back();
+      FreeSequenceNumbers.pop_back();
+      return SequenceNumber;
+    }
+
+    void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
+      std::lock_guard<std::mutex> Lock(SeqNoLock);
+      FreeSequenceNumbers.push_back(SequenceNumber);
+    }
+
+  private:
+    std::mutex SeqNoLock;
+    SequenceNumberT NextSequenceNumber = 0;
+    std::vector<SequenceNumberT> FreeSequenceNumbers;
+  };
+
+  // Type-erased handle for a call whose result has not yet arrived from the
+  // other end of the RPC connection.
+  class OutstandingResult {
+  public:
+    virtual ~OutstandingResult() = default;
+    virtual std::error_code readResult(ChannelT &C) = 0;
+    virtual void abandon() = 0;
+  };
+
+  // Outstanding result for a specific function; owns the caller's promise.
+  template <typename Func>
+  class OutstandingResultImpl : public OutstandingResult {
+  public:
+    explicit OutstandingResultImpl(
+        std::promise<typename Func::OptionalReturn> &&Promise)
+      : P(std::move(Promise)) {}
+
+    std::error_code readResult(ChannelT &C) override {
+      return Func::readResult(C, P);
+    }
+
+    void abandon() override { Func::abandon(P); }
+
+  private:
+    std::promise<typename Func::OptionalReturn> P;
+  };
+
+  // Wrap a promise for Func's result in a heap-allocated OutstandingResult.
+  template <typename Func>
+  auto createOutstandingResult(std::promise<typename Func::OptionalReturn> &&P)
+      -> std::unique_ptr<OutstandingResult> {
+    return llvm::make_unique<OutstandingResultImpl<Func>>(std::move(P));
+  }
+
+  // Abandon all outstanding results and restart sequence number allocation.
+  void abandonOutstandingResults() {
+    for (auto &Entry : OutstandingResults)
+      Entry.second->abandon();
+    OutstandingResults.clear();
+    SequenceNumberMgr.reset();
+  }
+
+  SequenceNumberManager SequenceNumberMgr;
+  std::map<SequenceNumberT, std::unique_ptr<OutstandingResult>>
+      OutstandingResults;
 };
 
 } // end namespace remote

Modified: llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp (original)
+++ llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp Sun Apr 17 20:06:49 2016
@@ -38,6 +38,8 @@ public:
       return "Remote indirect stubs owner Id already in use";
     case OrcErrorCode::UnexpectedRPCCall:
       return "Unexpected RPC call";
+    case OrcErrorCode::UnexpectedRPCResponse:
+      return "Unexpected RPC response";
     }
     llvm_unreachable("Unhandled error code");
   }

Modified: llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp (original)
+++ llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp Sun Apr 17 20:06:49 2016
@@ -13,50 +13,41 @@ namespace llvm {
 namespace orc {
 namespace remote {
 
-#define PROCNAME(X) \
+#define FUNCNAME(X) \
   case X ## Id: \
   return #X
 
-const char *OrcRemoteTargetRPCAPI::getJITProcIdName(JITProcId Id) {
+/// Return a printable name for Id, or nullptr for an unrecognized Id.
+const char *OrcRemoteTargetRPCAPI::getJITFuncIdName(JITFuncId Id) {
   switch (Id) {
   case InvalidId:
-    return "*** Invalid JITProcId ***";
-  PROCNAME(CallIntVoid);
-  PROCNAME(CallIntVoidResponse);
-  PROCNAME(CallMain);
-  PROCNAME(CallMainResponse);
-  PROCNAME(CallVoidVoid);
-  PROCNAME(CallVoidVoidResponse);
-  PROCNAME(CreateRemoteAllocator);
-  PROCNAME(CreateIndirectStubsOwner);
-  PROCNAME(DeregisterEHFrames);
-  PROCNAME(DestroyRemoteAllocator);
-  PROCNAME(DestroyIndirectStubsOwner);
-  PROCNAME(EmitIndirectStubs);
-  PROCNAME(EmitIndirectStubsResponse);
-  PROCNAME(EmitResolverBlock);
-  PROCNAME(EmitTrampolineBlock);
-  PROCNAME(EmitTrampolineBlockResponse);
-  PROCNAME(GetSymbolAddress);
-  PROCNAME(GetSymbolAddressResponse);
-  PROCNAME(GetRemoteInfo);
-  PROCNAME(GetRemoteInfoResponse);
-  PROCNAME(ReadMem);
-  PROCNAME(ReadMemResponse);
-  PROCNAME(RegisterEHFrames);
-  PROCNAME(ReserveMem);
-  PROCNAME(ReserveMemResponse);
-  PROCNAME(RequestCompile);
-  PROCNAME(RequestCompileResponse);
-  PROCNAME(SetProtections);
-  PROCNAME(TerminateSession);
-  PROCNAME(WriteMem);
-  PROCNAME(WritePtr);
+    return "*** Invalid JITFuncId ***";
+  FUNCNAME(CallIntVoid);
+  FUNCNAME(CallMain);
+  FUNCNAME(CallVoidVoid);
+  FUNCNAME(CreateRemoteAllocator);
+  FUNCNAME(CreateIndirectStubsOwner);
+  FUNCNAME(DeregisterEHFrames);
+  FUNCNAME(DestroyRemoteAllocator);
+  FUNCNAME(DestroyIndirectStubsOwner);
+  FUNCNAME(EmitIndirectStubs);
+  FUNCNAME(EmitResolverBlock);
+  FUNCNAME(EmitTrampolineBlock);
+  FUNCNAME(GetSymbolAddress);
+  FUNCNAME(GetRemoteInfo);
+  FUNCNAME(ReadMem);
+  FUNCNAME(RegisterEHFrames);
+  FUNCNAME(ReserveMem);
+  FUNCNAME(RequestCompile);
+  FUNCNAME(SetProtections);
+  FUNCNAME(TerminateSession);
+  FUNCNAME(WriteMem);
+  FUNCNAME(WritePtr);
-  };
+  }
   return nullptr;
 }
 
-#undef PROCNAME
+#undef FUNCNAME
 
 } // end namespace remote
 } // end namespace orc

Modified: llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp (original)
+++ llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp Sun Apr 17 20:06:49 2016
@@ -54,8 +54,8 @@ int main(int argc, char *argv[]) {
   JITServer Server(Channel, SymbolLookup, RegisterEHFrames, DeregisterEHFrames);
 
   while (1) {
-    JITServer::JITProcId Id = JITServer::InvalidId;
-    if (auto EC = Server.getNextProcId(Id)) {
+    JITServer::JITFuncId Id = JITServer::InvalidId;
+    if (auto EC = Server.getNextFuncId(Id)) {
       errs() << "Error: " << EC.message() << "\n";
       return 1;
     }
@@ -63,7 +63,7 @@ int main(int argc, char *argv[]) {
     case JITServer::TerminateSessionId:
       return 0;
     default:
-      if (auto EC = Server.handleKnownProcedure(Id)) {
+      if (auto EC = Server.handleKnownFunction(Id)) {
         errs() << "Error: " << EC.message() << "\n";
         return 1;
       }

Modified: llvm/trunk/tools/lli/RemoteJITUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/tools/lli/RemoteJITUtils.h?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/tools/lli/RemoteJITUtils.h (original)
+++ llvm/trunk/tools/lli/RemoteJITUtils.h Sun Apr 17 20:06:49 2016
@@ -16,6 +16,7 @@
 
 #include "llvm/ExecutionEngine/Orc/RPCChannel.h"
 #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
+#include <mutex>
 
 #if !defined(_MSC_VER) && !defined(__MINGW32__)
 #include <unistd.h>

Modified: llvm/trunk/tools/lli/lli.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/tools/lli/lli.cpp?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/tools/lli/lli.cpp (original)
+++ llvm/trunk/tools/lli/lli.cpp Sun Apr 17 20:06:49 2016
@@ -582,7 +582,7 @@ int main(int argc, char **argv, char * c
   // Reset errno to zero on entry to main.
   errno = 0;
 
-  int Result;
+  int Result = -1;
 
   // Sanity check use of remote-jit: LLI currently only supports use of the
   // remote JIT on Unix platforms.
@@ -681,12 +681,13 @@ int main(int argc, char **argv, char * c
     static_cast<ForwardingMemoryManager*>(RTDyldMM)->setResolver(
       orc::createLambdaResolver(
         [&](const std::string &Name) {
-          orc::TargetAddress Addr = 0;
-          if (auto EC = R->getSymbolAddress(Addr, Name)) {
-            errs() << "Failure during symbol lookup: " << EC.message() << "\n";
-            exit(1);
-          }
-          return RuntimeDyld::SymbolInfo(Addr, JITSymbolFlags::Exported);
+          if (auto AddrOrErr = R->getSymbolAddress(Name))
+	    return RuntimeDyld::SymbolInfo(*AddrOrErr, JITSymbolFlags::Exported);
+	  else {
+	    errs() << "Failure during symbol lookup: "
+		   << AddrOrErr.getError().message() << "\n";
+	    exit(1);
+	  }
         },
         [](const std::string &Name) { return nullptr; }
       ));
@@ -698,8 +699,10 @@ int main(int argc, char **argv, char * c
     EE->finalizeObject();
     DEBUG(dbgs() << "Executing '" << EntryFn->getName() << "' at 0x"
                  << format("%llx", Entry) << "\n");
-    if (auto EC = R->callIntVoid(Result, Entry))
-      errs() << "ERROR: " << EC.message() << "\n";
+    if (auto ResultOrErr = R->callIntVoid(Entry))
+      Result = *ResultOrErr;
+    else
+      errs() << "ERROR: " << ResultOrErr.getError().message() << "\n";
 
     // Like static constructors, the remote target MCJIT support doesn't handle
     // this yet. It could. FIXME.

Modified: llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp?rev=266581&r1=266580&r2=266581&view=diff
==============================================================================
--- llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp (original)
+++ llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp Sun Apr 17 20:06:49 2016
@@ -44,26 +44,25 @@ private:
 class DummyRPC : public testing::Test,
                  public RPC<QueueChannel> {
 public:
-  typedef Procedure<1, void(bool)> Proc1;
-  typedef Procedure<2, void(int8_t, uint8_t, int16_t, uint16_t,
-                            int32_t, uint32_t, int64_t, uint64_t,
-                            bool, std::string, std::vector<int>)> AllTheTypes;
+  using BasicVoid = Function<2, void(bool)>;
+  using BasicInt = Function<3, int32_t(bool)>;
+  using AllTheTypes = Function<4, void(int8_t, uint8_t, int16_t, uint16_t,
+                                       int32_t, uint32_t, int64_t, uint64_t,
+                                       bool, std::string, std::vector<int>)>;
 };
 
 
-TEST_F(DummyRPC, TestBasic) {
+TEST_F(DummyRPC, TestAsyncBasicVoid) {
   std::queue<char> Queue;
   QueueChannel C(Queue);
 
-  {
-    // Make a call to Proc1.
-    auto EC = call<Proc1>(C, true);
-    EXPECT_FALSE(EC) << "Simple call over queue failed";
-  }
+  // Make an async call.
+  auto ResOrErr = callAsync<BasicVoid>(C, true);
+  EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed";
 
   {
     // Expect a call to Proc1.
-    auto EC = expect<Proc1>(C,
+    auto EC = expect<BasicVoid>(C,
                 [&](bool &B) {
                   EXPECT_EQ(B, true)
                     << "Bool serialization broken";
@@ -71,31 +70,71 @@ TEST_F(DummyRPC, TestBasic) {
                 });
     EXPECT_FALSE(EC) << "Simple expect over queue failed";
   }
+
+  {
+    // Wait for the result.
+    auto EC = waitForResult(C, ResOrErr->second, handleNone);
+    EXPECT_FALSE(EC) << "Could not read result.";
+  }
+
+  // Verify that the function returned ok.
+  auto Val = ResOrErr->first.get();
+  EXPECT_TRUE(Val) << "Remote void function failed to execute.";
 }
 
-TEST_F(DummyRPC, TestSerialization) {
+TEST_F(DummyRPC, TestAsyncBasicInt) {
   std::queue<char> Queue;
   QueueChannel C(Queue);
 
+  // Make an async call to BasicInt with argument 'false'.
+  auto ResOrErr = callAsync<BasicInt>(C, false);
+  EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed";
+
   {
-    // Make a call to Proc1.
-    std::vector<int> v({42, 7});
-    auto EC = call<AllTheTypes>(C,
-                                -101,
-                                250,
-                                -10000,
-                                10000,
-                                -1000000000,
-                                1000000000,
-                                -10000000000,
-                                10000000000,
-                                true,
-                                "foo",
-                                v);
-    EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed";
+    // Expect a call to BasicInt and reply with 42.
+    auto EC = expect<BasicInt>(C,
+                [&](bool &B) {
+                  EXPECT_EQ(B, false)
+                    << "Bool serialization broken";
+                  return 42;
+                });
+    EXPECT_FALSE(EC) << "Simple expect over queue failed";
   }
 
   {
+    // Wait for the result.
+    auto EC = waitForResult(C, ResOrErr->second, handleNone);
+    EXPECT_FALSE(EC) << "Could not read result.";
+  }
+
+  // Verify that the function returned ok.
+  auto Val = ResOrErr->first.get();
+  EXPECT_TRUE(!!Val) << "Remote int function failed to execute.";
+  EXPECT_EQ(*Val, 42) << "Remote int function return wrong value.";
+}
+
+TEST_F(DummyRPC, TestSerialization) {
+  std::queue<char> Queue;
+  QueueChannel C(Queue);
+
+  // Make a call to Proc1.
+  std::vector<int> v({42, 7});
+  auto ResOrErr = callAsync<AllTheTypes>(C,
+                                         -101,
+                                         250,
+                                         -10000,
+                                         10000,
+                                         -1000000000,
+                                         1000000000,
+                                         -10000000000,
+                                         10000000000,
+                                         true,
+                                         "foo",
+                                         v);
+  EXPECT_TRUE(!!ResOrErr)
+    << "Big (serialization test) call over queue failed";
+
+  {
     // Expect a call to Proc1.
     auto EC = expect<AllTheTypes>(C,
                 [&](int8_t &s8,
@@ -136,4 +175,14 @@ TEST_F(DummyRPC, TestSerialization) {
                   });
     EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed";
   }
+
+  {
+    // Wait for the result.
+    auto EC = waitForResult(C, ResOrErr->second, handleNone);
+    EXPECT_FALSE(EC) << "Could not read result.";
+  }
+
+  // Verify that the function returned ok.
+  auto Val = ResOrErr->first.get();
+  EXPECT_TRUE(Val) << "Remote void function failed to execute.";
 }




More information about the llvm-commits mailing list