[llvm] r266581 - [ORC] Generalize the ORC RPC utils to support RPC function return values and
David Blaikie via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 18 09:13:39 PDT 2016
On Sun, Apr 17, 2016 at 6:06 PM, Lang Hames via llvm-commits
<llvm-commits at lists.llvm.org> wrote:
> Author: lhames
> Date: Sun Apr 17 20:06:49 2016
> New Revision: 266581
>
> URL: http://llvm.org/viewvc/llvm-project?rev=266581&view=rev
> Log:
> [ORC] Generalize the ORC RPC utils to support RPC function return values
> and asynchronous call/handle. Also updates the ORC remote JIT API to use
> the new scheme.
>
> The previous version of the RPC tools only supported void functions, and
> required the user to manually call a paired function to return results.
> This patch replaces the Procedure typedef (which only supported void
> functions) with the Function typedef, which supports return values, e.g.:
>
> Function<FooId, int32_t(std::string)> Foo;
>
> The RPC primitives and channel operations are also expanded. RPC channels
> must support four new operations: startSendMessage, endSendMessage,
> startReceiveMessage and endReceiveMessage, to handle channel locking. In
> addition, serialization support for tuples is added to RPCChannels to
> enable multiple return values.
>
> The RPC primitives are expanded from callAppend, call, expect and handle,
> to:
>
> appendCallAsync - Make an asynchronous call to the given function.
>
> callAsync - The same as appendCallAsync, but calls send on the channel
>             when done.
>
> callSTHandling - Blocking call for single-threaded code. Wraps a call to
>                  callAsync, then waits on the result, using a user-supplied
>                  handler to handle any callbacks from the remote.
>
> callST - The same as callSTHandling, except that it doesn't handle
>          callbacks; it expects the result to be the first return.
>
> expect and handle - as before.
>
> handleResponse - Handle a response from the remote.
>
> waitForResult - Wait for the response with the given sequence number
>                 to arrive.
>
>
> Modified:
> llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h
> llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
> llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
> llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
> llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h
> llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
> llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp
> llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp
> llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp
> llvm/trunk/tools/lli/RemoteJITUtils.h
> llvm/trunk/tools/lli/lli.cpp
> llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
>
> Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h (original)
> +++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h Sun Apr 17
> 20:06:49 2016
> @@ -26,7 +26,8 @@ enum class OrcErrorCode : int {
> RemoteMProtectAddrUnrecognized,
> RemoteIndirectStubsOwnerDoesNotExist,
> RemoteIndirectStubsOwnerIdAlreadyInUse,
> - UnexpectedRPCCall
> + UnexpectedRPCCall,
> + UnexpectedRPCResponse,
> };
>
> std::error_code orcError(OrcErrorCode ErrCode);
>
> Modified:
> llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
> (original)
> +++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
> Sun Apr 17 20:06:49 2016
> @@ -36,6 +36,7 @@ namespace remote {
> template <typename ChannelT>
> class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI {
> public:
> +
> /// Remote memory manager.
> class RCMemoryManager : public RuntimeDyld::MemoryManager {
> public:
> @@ -105,11 +106,13 @@ public:
> DEBUG(dbgs() << "Allocator " << Id << " reserved:\n");
>
> if (CodeSize != 0) {
> - std::error_code EC =
> Client.reserveMem(Unmapped.back().RemoteCodeAddr,
> - Id, CodeSize, CodeAlign);
> - // FIXME; Add error to poll.
> - assert(!EC && "Failed reserving remote memory.");
> - (void)EC;
> + if (auto AddrOrErr = Client.reserveMem(Id, CodeSize, CodeAlign))
> + Unmapped.back().RemoteCodeAddr = *AddrOrErr;
> + else {
> + // FIXME; Add error to poll.
> + assert(!AddrOrErr.getError() && "Failed reserving remote
> memory.");
> + }
> +
> DEBUG(dbgs() << " code: "
> << format("0x%016x", Unmapped.back().RemoteCodeAddr)
> << " (" << CodeSize << " bytes, alignment " <<
> CodeAlign
> @@ -117,11 +120,13 @@ public:
> }
>
> if (RODataSize != 0) {
> - std::error_code EC =
> Client.reserveMem(Unmapped.back().RemoteRODataAddr,
> - Id, RODataSize,
> RODataAlign);
> - // FIXME; Add error to poll.
> - assert(!EC && "Failed reserving remote memory.");
> - (void)EC;
> + if (auto AddrOrErr = Client.reserveMem(Id, RODataSize,
> RODataAlign))
> + Unmapped.back().RemoteRODataAddr = *AddrOrErr;
> + else {
> + // FIXME; Add error to poll.
> + assert(!AddrOrErr.getError() && "Failed reserving remote
> memory.");
> + }
> +
> DEBUG(dbgs() << " ro-data: "
> << format("0x%016x",
> Unmapped.back().RemoteRODataAddr)
> << " (" << RODataSize << " bytes, alignment "
> @@ -129,11 +134,13 @@ public:
> }
>
> if (RWDataSize != 0) {
> - std::error_code EC =
> Client.reserveMem(Unmapped.back().RemoteRWDataAddr,
> - Id, RWDataSize,
> RWDataAlign);
> - // FIXME; Add error to poll.
> - assert(!EC && "Failed reserving remote memory.");
> - (void)EC;
> + if (auto AddrOrErr = Client.reserveMem(Id, RWDataSize,
> RWDataAlign))
> + Unmapped.back().RemoteRWDataAddr = *AddrOrErr;
> + else {
> + // FIXME; Add error to poll.
> + assert(!AddrOrErr.getError() && "Failed reserving remote
> memory.");
> + }
> +
> DEBUG(dbgs() << " rw-data: "
> << format("0x%016x",
> Unmapped.back().RemoteRWDataAddr)
> << " (" << RWDataSize << " bytes, alignment "
> @@ -431,8 +438,10 @@ public:
> TargetAddress PtrBase;
> unsigned NumStubsEmitted;
>
> - Remote.emitIndirectStubs(StubBase, PtrBase, NumStubsEmitted, Id,
> - NewStubsRequired);
> + if (auto StubInfoOrErr = Remote.emitIndirectStubs(Id,
> NewStubsRequired))
> + std::tie(StubBase, PtrBase, NumStubsEmitted) = *StubInfoOrErr;
> + else
> + return StubInfoOrErr.getError();
>
> unsigned NewBlockId = RemoteIndirectStubsInfos.size();
> RemoteIndirectStubsInfos.push_back({StubBase, PtrBase,
> NumStubsEmitted});
> @@ -484,8 +493,12 @@ public:
> void grow() override {
> TargetAddress BlockAddr = 0;
> uint32_t NumTrampolines = 0;
> - auto EC = Remote.emitTrampolineBlock(BlockAddr, NumTrampolines);
> - assert(!EC && "Failed to create trampolines");
> + if (auto TrampolineInfoOrErr = Remote.emitTrampolineBlock())
> + std::tie(BlockAddr, NumTrampolines) = *TrampolineInfoOrErr;
> + else {
> + // FIXME: Return error.
> + llvm_unreachable("Failed to create trampolines");
> + }
>
> uint32_t TrampolineSize = Remote.getTrampolineSize();
> for (unsigned I = 0; I < NumTrampolines; ++I)
> @@ -503,53 +516,33 @@ public:
> OrcRemoteTargetClient H(Channel, EC);
> if (EC)
> return EC;
> - return H;
> + return ErrorOr<OrcRemoteTargetClient>(std::move(H));
> }
>
> /// Call the int(void) function at the given address in the target and
> return
> /// its result.
> - std::error_code callIntVoid(int &Result, TargetAddress Addr) {
> + ErrorOr<int> callIntVoid(TargetAddress Addr) {
> DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) <<
> "\n");
>
> - if (auto EC = call<CallIntVoid>(Channel, Addr))
> - return EC;
> -
> - unsigned NextProcId;
> - if (auto EC = listenForCompileRequests(NextProcId))
> - return EC;
> -
> - if (NextProcId != CallIntVoidResponseId)
> - return orcError(OrcErrorCode::UnexpectedRPCCall);
> -
> - return handle<CallIntVoidResponse>(Channel, [&](int R) {
> - Result = R;
> - DEBUG(dbgs() << "Result: " << R << "\n");
> - return std::error_code();
> - });
> + auto Listen =
> + [&](RPCChannel &C, uint32_t Id) {
> + return listenForCompileRequests(C, Id);
> + };
> + return callSTHandling<CallIntVoid>(Channel, Listen, Addr);
> }
>
> /// Call the int(int, char*[]) function at the given address in the
> target and
> /// return its result.
> - std::error_code callMain(int &Result, TargetAddress Addr,
> - const std::vector<std::string> &Args) {
> + ErrorOr<int> callMain(TargetAddress Addr,
> + const std::vector<std::string> &Args) {
> DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x",
> Addr)
> << "\n");
>
> - if (auto EC = call<CallMain>(Channel, Addr, Args))
> - return EC;
> -
> - unsigned NextProcId;
> - if (auto EC = listenForCompileRequests(NextProcId))
> - return EC;
> -
> - if (NextProcId != CallMainResponseId)
> - return orcError(OrcErrorCode::UnexpectedRPCCall);
> -
> - return handle<CallMainResponse>(Channel, [&](int R) {
> - Result = R;
> - DEBUG(dbgs() << "Result: " << R << "\n");
> - return std::error_code();
> - });
> + auto Listen =
> + [&](RPCChannel &C, uint32_t Id) {
> + return listenForCompileRequests(C, Id);
> + };
> + return callSTHandling<CallMain>(Channel, Listen, Addr, Args);
> }
>
> /// Call the void() function at the given address in the target and
> wait for
> @@ -558,17 +551,11 @@ public:
> DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr)
> << "\n");
>
> - if (auto EC = call<CallVoidVoid>(Channel, Addr))
> - return EC;
> -
> - unsigned NextProcId;
> - if (auto EC = listenForCompileRequests(NextProcId))
> - return EC;
> -
> - if (NextProcId != CallVoidVoidResponseId)
> - return orcError(OrcErrorCode::UnexpectedRPCCall);
> -
> - return handle<CallVoidVoidResponse>(Channel, doNothing);
> + auto Listen =
> + [&](RPCChannel &C, JITFuncId Id) {
> + return listenForCompileRequests(C, Id);
> + };
> + return callSTHandling<CallVoidVoid>(Channel, Listen, Addr);
> }
>
> /// Create an RCMemoryManager which will allocate its memory on the
> remote
> @@ -578,7 +565,7 @@ public:
> assert(!MM && "MemoryManager should be null before creation.");
>
> auto Id = AllocatorIds.getNext();
> - if (auto EC = call<CreateRemoteAllocator>(Channel, Id))
> + if (auto EC = callST<CreateRemoteAllocator>(Channel, Id))
> return EC;
> MM = llvm::make_unique<RCMemoryManager>(*this, Id);
> return std::error_code();
> @@ -590,7 +577,7 @@ public:
> createIndirectStubsManager(std::unique_ptr<RCIndirectStubsManager> &I) {
> assert(!I && "Indirect stubs manager should be null before
> creation.");
> auto Id = IndirectStubOwnerIds.getNext();
> - if (auto EC = call<CreateIndirectStubsOwner>(Channel, Id))
> + if (auto EC = callST<CreateIndirectStubsOwner>(Channel, Id))
> return EC;
> I = llvm::make_unique<RCIndirectStubsManager>(*this, Id);
> return std::error_code();
> @@ -599,45 +586,39 @@ public:
> /// Search for symbols in the remote process. Note: This should be used
> by
> /// symbol resolvers *after* they've searched the local symbol table in
> the
> /// JIT stack.
> - std::error_code getSymbolAddress(TargetAddress &Addr, StringRef Name) {
> + ErrorOr<TargetAddress> getSymbolAddress(StringRef Name) {
> // Check for an 'out-of-band' error, e.g. from an MM destructor.
> if (ExistingError)
> return ExistingError;
>
> - // Request remote symbol address.
> - if (auto EC = call<GetSymbolAddress>(Channel, Name))
> - return EC;
> -
> - return expect<GetSymbolAddressResponse>(Channel, [&](TargetAddress
> &A) {
> - Addr = A;
> - DEBUG(dbgs() << "Remote address lookup " << Name << " = "
> - << format("0x%016x", Addr) << "\n");
> - return std::error_code();
> - });
> + return callST<GetSymbolAddress>(Channel, Name);
> }
>
> /// Get the triple for the remote target.
> const std::string &getTargetTriple() const { return RemoteTargetTriple;
> }
>
> - std::error_code terminateSession() { return
> call<TerminateSession>(Channel); }
> + std::error_code terminateSession() {
> + return callST<TerminateSession>(Channel);
> + }
>
> private:
> OrcRemoteTargetClient(ChannelT &Channel, std::error_code &EC)
> : Channel(Channel) {
> - if ((EC = call<GetRemoteInfo>(Channel)))
> - return;
> -
> - EC = expect<GetRemoteInfoResponse>(
> - Channel, readArgs(RemoteTargetTriple, RemotePointerSize,
> RemotePageSize,
> - RemoteTrampolineSize, RemoteIndirectStubSize));
> + if (auto RIOrErr = callST<GetRemoteInfo>(Channel)) {
> + std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize,
> + RemoteTrampolineSize, RemoteIndirectStubSize) =
> + *RIOrErr;
> + EC = std::error_code();
> + } else
> + EC = RIOrErr.getError();
> }
>
> std::error_code deregisterEHFrames(TargetAddress Addr, uint32_t Size) {
> - return call<RegisterEHFrames>(Channel, Addr, Size);
> + return callST<RegisterEHFrames>(Channel, Addr, Size);
> }
>
> void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {
> - if (auto EC = call<DestroyRemoteAllocator>(Channel, Id)) {
> + if (auto EC = callST<DestroyRemoteAllocator>(Channel, Id)) {
> // FIXME: This will be triggered by a removeModuleSet call:
> Propagate
> // error return up through that.
> llvm_unreachable("Failed to destroy remote allocator.");
> @@ -647,19 +628,13 @@ private:
>
> std::error_code destroyIndirectStubsManager(ResourceIdMgr::ResourceId
> Id) {
> IndirectStubOwnerIds.release(Id);
> - return call<DestroyIndirectStubsOwner>(Channel, Id);
> + return callST<DestroyIndirectStubsOwner>(Channel, Id);
> }
>
> - std::error_code emitIndirectStubs(TargetAddress &StubBase,
> - TargetAddress &PtrBase,
> - uint32_t &NumStubsEmitted,
> - ResourceIdMgr::ResourceId Id,
> - uint32_t NumStubsRequired) {
> - if (auto EC = call<EmitIndirectStubs>(Channel, Id, NumStubsRequired))
> - return EC;
> -
> - return expect<EmitIndirectStubsResponse>(
> - Channel, readArgs(StubBase, PtrBase, NumStubsEmitted));
> + ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>>
> + emitIndirectStubs(ResourceIdMgr::ResourceId Id,
> + uint32_t NumStubsRequired) {
> + return callST<EmitIndirectStubs>(Channel, Id, NumStubsRequired);
> }
>
> std::error_code emitResolverBlock() {
> @@ -667,24 +642,16 @@ private:
> if (ExistingError)
> return ExistingError;
>
> - return call<EmitResolverBlock>(Channel);
> + return callST<EmitResolverBlock>(Channel);
> }
>
> - std::error_code emitTrampolineBlock(TargetAddress &BlockAddr,
> - uint32_t &NumTrampolines) {
> + ErrorOr<std::tuple<TargetAddress, uint32_t>>
> + emitTrampolineBlock() {
> // Check for an 'out-of-band' error, e.g. from an MM destructor.
> if (ExistingError)
> return ExistingError;
>
> - if (auto EC = call<EmitTrampolineBlock>(Channel))
> - return EC;
> -
> - return expect<EmitTrampolineBlockResponse>(
> - Channel, [&](TargetAddress BAddr, uint32_t NTrampolines) {
> - BlockAddr = BAddr;
> - NumTrampolines = NTrampolines;
> - return std::error_code();
> - });
> + return callST<EmitTrampolineBlock>(Channel);
> }
>
> uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; }
> @@ -693,67 +660,46 @@ private:
>
> uint32_t getTrampolineSize() const { return RemoteTrampolineSize; }
>
> - std::error_code listenForCompileRequests(uint32_t &NextId) {
> + std::error_code listenForCompileRequests(RPCChannel &C, uint32_t &Id) {
> // Check for an 'out-of-band' error, e.g. from an MM destructor.
> if (ExistingError)
> return ExistingError;
>
> - if (auto EC = getNextProcId(Channel, NextId))
> - return EC;
> -
> - while (NextId == RequestCompileId) {
> - TargetAddress TrampolineAddr = 0;
> - if (auto EC = handle<RequestCompile>(Channel,
> readArgs(TrampolineAddr)))
> - return EC;
> -
> - TargetAddress ImplAddr = CompileCallback(TrampolineAddr);
> - if (auto EC = call<RequestCompileResponse>(Channel, ImplAddr))
> - return EC;
> -
> - if (auto EC = getNextProcId(Channel, NextId))
> + if (Id == RequestCompileId) {
> + if (auto EC = handle<RequestCompile>(C, CompileCallback))
> return EC;
> + return std::error_code();
> }
> -
> - return std::error_code();
> + // else
> + return orcError(OrcErrorCode::UnexpectedRPCCall);
> }
>
> - std::error_code readMem(char *Dst, TargetAddress Src, uint64_t Size) {
> + ErrorOr<std::vector<char>> readMem(char *Dst, TargetAddress Src,
> uint64_t Size) {
> // Check for an 'out-of-band' error, e.g. from an MM destructor.
> if (ExistingError)
> return ExistingError;
>
> - if (auto EC = call<ReadMem>(Channel, Src, Size))
> - return EC;
> -
> - if (auto EC = expect<ReadMemResponse>(
> - Channel, [&]() { return Channel.readBytes(Dst, Size); }))
> - return EC;
> -
> - return std::error_code();
> + return callST<ReadMem>(Channel, Src, Size);
> }
>
> std::error_code registerEHFrames(TargetAddress &RAddr, uint32_t Size) {
> - return call<RegisterEHFrames>(Channel, RAddr, Size);
> + return callST<RegisterEHFrames>(Channel, RAddr, Size);
> }
>
> - std::error_code reserveMem(TargetAddress &RemoteAddr,
> - ResourceIdMgr::ResourceId Id, uint64_t Size,
> - uint32_t Align) {
> + ErrorOr<TargetAddress> reserveMem(ResourceIdMgr::ResourceId Id,
> uint64_t Size,
> + uint32_t Align) {
>
> // Check for an 'out-of-band' error, e.g. from an MM destructor.
> if (ExistingError)
> return ExistingError;
>
> - if (std::error_code EC = call<ReserveMem>(Channel, Id, Size, Align))
> - return EC;
> -
> - return expect<ReserveMemResponse>(Channel, readArgs(RemoteAddr));
> + return callST<ReserveMem>(Channel, Id, Size, Align);
> }
>
> std::error_code setProtections(ResourceIdMgr::ResourceId Id,
> TargetAddress RemoteSegAddr,
> unsigned ProtFlags) {
> - return call<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags);
> + return callST<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags);
> }
>
> std::error_code writeMem(TargetAddress Addr, const char *Src, uint64_t
> Size) {
> @@ -761,15 +707,7 @@ private:
> if (ExistingError)
> return ExistingError;
>
> - // Make the send call.
> - if (auto EC = call<WriteMem>(Channel, Addr, Size))
> - return EC;
> -
> - // Follow this up with the section contents.
> - if (auto EC = Channel.appendBytes(Src, Size))
> - return EC;
> -
> - return Channel.send();
> + return callST<WriteMem>(Channel, DirectBufferWriter(Src, Addr, Size));
> }
>
> std::error_code writePointer(TargetAddress Addr, TargetAddress PtrVal) {
> @@ -777,7 +715,7 @@ private:
> if (ExistingError)
> return ExistingError;
>
> - return call<WritePtr>(Channel, Addr, PtrVal);
> + return callST<WritePtr>(Channel, Addr, PtrVal);
> }
>
> static std::error_code doNothing() { return std::error_code(); }
>
> Modified:
> llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
> (original)
> +++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
> Sun Apr 17 20:06:49 2016
> @@ -24,8 +24,48 @@ namespace llvm {
> namespace orc {
> namespace remote {
>
> +class DirectBufferWriter {
> +public:
> + DirectBufferWriter() = default;
> + DirectBufferWriter(const char *Src, TargetAddress Dst, uint64_t Size)
> + : Src(Src), Dst(Dst), Size(Size) {}
> +
> + const char *getSrc() const { return Src; }
> + TargetAddress getDst() const { return Dst; }
> + uint64_t getSize() const { return Size; }
> +private:
> + const char *Src;
> + TargetAddress Dst;
> + uint64_t Size;
> +};
> +
> +inline std::error_code serialize(RPCChannel &C,
> + const DirectBufferWriter &DBW) {
> + if (auto EC = serialize(C, DBW.getDst()))
> + return EC;
> + if (auto EC = serialize(C, DBW.getSize()))
> + return EC;
> + return C.appendBytes(DBW.getSrc(), DBW.getSize());
> +}
> +
> +inline std::error_code deserialize(RPCChannel &C,
> + DirectBufferWriter &DBW) {
> + TargetAddress Dst;
> + if (auto EC = deserialize(C, Dst))
> + return EC;
> + uint64_t Size;
> + if (auto EC = deserialize(C, Size))
> + return EC;
> + char *Addr = reinterpret_cast<char*>(static_cast<uintptr_t>(Dst));
> +
> + DBW = DirectBufferWriter(0, Dst, Size);
> +
> + return C.readBytes(Addr, Size);
> +}
> +
> class OrcRemoteTargetRPCAPI : public RPC<RPCChannel> {
> protected:
> +
> class ResourceIdMgr {
> public:
> typedef uint64_t ResourceId;
> @@ -45,146 +85,111 @@ protected:
> };
>
> public:
> - enum JITProcId : uint32_t {
> - InvalidId = 0,
> - CallIntVoidId,
> - CallIntVoidResponseId,
> + enum JITFuncId : uint32_t {
> + InvalidId = RPCFunctionIdTraits<JITFuncId>::InvalidId,
> + CallIntVoidId = RPCFunctionIdTraits<JITFuncId>::FirstValidId,
> CallMainId,
> - CallMainResponseId,
> CallVoidVoidId,
> - CallVoidVoidResponseId,
> CreateRemoteAllocatorId,
> CreateIndirectStubsOwnerId,
> DeregisterEHFramesId,
> DestroyRemoteAllocatorId,
> DestroyIndirectStubsOwnerId,
> EmitIndirectStubsId,
> - EmitIndirectStubsResponseId,
> EmitResolverBlockId,
> EmitTrampolineBlockId,
> - EmitTrampolineBlockResponseId,
> GetSymbolAddressId,
> - GetSymbolAddressResponseId,
> GetRemoteInfoId,
> - GetRemoteInfoResponseId,
> ReadMemId,
> - ReadMemResponseId,
> RegisterEHFramesId,
> ReserveMemId,
> - ReserveMemResponseId,
> RequestCompileId,
> - RequestCompileResponseId,
> SetProtectionsId,
> TerminateSessionId,
> WriteMemId,
> WritePtrId
> };
>
> - static const char *getJITProcIdName(JITProcId Id);
> -
> - typedef Procedure<CallIntVoidId, void(TargetAddress Addr)> CallIntVoid;
> + static const char *getJITFuncIdName(JITFuncId Id);
>
> - typedef Procedure<CallIntVoidResponseId, void(int Result)>
> - CallIntVoidResponse;
> + typedef Function<CallIntVoidId, int32_t(TargetAddress Addr)>
> CallIntVoid;
>
> - typedef Procedure<CallMainId, void(TargetAddress Addr,
> - std::vector<std::string> Args)>
> + typedef Function<CallMainId, int32_t(TargetAddress Addr,
> + std::vector<std::string> Args)>
> CallMain;
>
> - typedef Procedure<CallMainResponseId, void(int Result)>
> CallMainResponse;
> -
> - typedef Procedure<CallVoidVoidId, void(TargetAddress FnAddr)>
> CallVoidVoid;
> -
> - typedef Procedure<CallVoidVoidResponseId, void()> CallVoidVoidResponse;
> + typedef Function<CallVoidVoidId, void(TargetAddress FnAddr)>
> CallVoidVoid;
>
> - typedef Procedure<CreateRemoteAllocatorId,
> - void(ResourceIdMgr::ResourceId AllocatorID)>
> + typedef Function<CreateRemoteAllocatorId,
> + void(ResourceIdMgr::ResourceId AllocatorID)>
> CreateRemoteAllocator;
>
> - typedef Procedure<CreateIndirectStubsOwnerId,
> - void(ResourceIdMgr::ResourceId StubOwnerID)>
> + typedef Function<CreateIndirectStubsOwnerId,
> + void(ResourceIdMgr::ResourceId StubOwnerID)>
> CreateIndirectStubsOwner;
>
> - typedef Procedure<DeregisterEHFramesId,
> - void(TargetAddress Addr, uint32_t Size)>
> + typedef Function<DeregisterEHFramesId,
> + void(TargetAddress Addr, uint32_t Size)>
> DeregisterEHFrames;
>
> - typedef Procedure<DestroyRemoteAllocatorId,
> - void(ResourceIdMgr::ResourceId AllocatorID)>
> + typedef Function<DestroyRemoteAllocatorId,
> + void(ResourceIdMgr::ResourceId AllocatorID)>
> DestroyRemoteAllocator;
>
> - typedef Procedure<DestroyIndirectStubsOwnerId,
> - void(ResourceIdMgr::ResourceId StubsOwnerID)>
> + typedef Function<DestroyIndirectStubsOwnerId,
> + void(ResourceIdMgr::ResourceId StubsOwnerID)>
> DestroyIndirectStubsOwner;
>
> - typedef Procedure<EmitIndirectStubsId,
> - void(ResourceIdMgr::ResourceId StubsOwnerID,
> - uint32_t NumStubsRequired)>
> + /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted).
> + typedef Function<EmitIndirectStubsId,
> + std::tuple<TargetAddress, TargetAddress, uint32_t>(
> + ResourceIdMgr::ResourceId StubsOwnerID,
> + uint32_t NumStubsRequired)>
> EmitIndirectStubs;
>
> - typedef Procedure<EmitIndirectStubsResponseId,
> - void(TargetAddress StubsBaseAddr,
> - TargetAddress PtrsBaseAddr,
> - uint32_t NumStubsEmitted)>
> - EmitIndirectStubsResponse;
> + typedef Function<EmitResolverBlockId, void()> EmitResolverBlock;
>
> - typedef Procedure<EmitResolverBlockId, void()> EmitResolverBlock;
> + /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines).
> + typedef Function<EmitTrampolineBlockId,
> + std::tuple<TargetAddress, uint32_t>()>
> EmitTrampolineBlock;
>
> - typedef Procedure<EmitTrampolineBlockId, void()> EmitTrampolineBlock;
> -
> - typedef Procedure<EmitTrampolineBlockResponseId,
> - void(TargetAddress BlockAddr, uint32_t
> NumTrampolines)>
> - EmitTrampolineBlockResponse;
> -
> - typedef Procedure<GetSymbolAddressId, void(std::string SymbolName)>
> + typedef Function<GetSymbolAddressId, TargetAddress(std::string
> SymbolName)>
> GetSymbolAddress;
>
> - typedef Procedure<GetSymbolAddressResponseId, void(uint64_t SymbolAddr)>
> - GetSymbolAddressResponse;
> -
> - typedef Procedure<GetRemoteInfoId, void()> GetRemoteInfo;
> -
> - typedef Procedure<GetRemoteInfoResponseId,
> - void(std::string Triple, uint32_t PointerSize,
> - uint32_t PageSize, uint32_t TrampolineSize,
> - uint32_t IndirectStubSize)>
> - GetRemoteInfoResponse;
> + /// GetRemoteInfo result is (Triple, PointerSize, PageSize,
> TrampolineSize,
> + /// IndirectStubsSize).
> + typedef Function<GetRemoteInfoId,
> + std::tuple<std::string, uint32_t, uint32_t, uint32_t,
> + uint32_t>()> GetRemoteInfo;
>
> - typedef Procedure<ReadMemId, void(TargetAddress Src, uint64_t Size)>
> + typedef Function<ReadMemId,
> + std::vector<char>(TargetAddress Src, uint64_t Size)>
> ReadMem;
>
> - typedef Procedure<ReadMemResponseId, void()> ReadMemResponse;
> -
> - typedef Procedure<RegisterEHFramesId,
> - void(TargetAddress Addr, uint32_t Size)>
> + typedef Function<RegisterEHFramesId,
> + void(TargetAddress Addr, uint32_t Size)>
> RegisterEHFrames;
>
> - typedef Procedure<ReserveMemId,
> - void(ResourceIdMgr::ResourceId AllocID, uint64_t Size,
> - uint32_t Align)>
> + typedef Function<ReserveMemId,
> + TargetAddress(ResourceIdMgr::ResourceId AllocID,
> + uint64_t Size, uint32_t Align)>
> ReserveMem;
>
> - typedef Procedure<ReserveMemResponseId, void(TargetAddress Addr)>
> - ReserveMemResponse;
> -
> - typedef Procedure<RequestCompileId, void(TargetAddress TrampolineAddr)>
> + typedef Function<RequestCompileId,
> + TargetAddress(TargetAddress TrampolineAddr)>
> RequestCompile;
>
> - typedef Procedure<RequestCompileResponseId, void(TargetAddress
> ImplAddr)>
> - RequestCompileResponse;
> -
> - typedef Procedure<SetProtectionsId,
> - void(ResourceIdMgr::ResourceId AllocID, TargetAddress
> Dst,
> - uint32_t ProtFlags)>
> + typedef Function<SetProtectionsId,
> + void(ResourceIdMgr::ResourceId AllocID, TargetAddress
> Dst,
> + uint32_t ProtFlags)>
> SetProtections;
>
> - typedef Procedure<TerminateSessionId, void()> TerminateSession;
> + typedef Function<TerminateSessionId, void()> TerminateSession;
>
> - typedef Procedure<WriteMemId,
> - void(TargetAddress Dst, uint64_t Size /* Data to
> follow */)>
> + typedef Function<WriteMemId, void(DirectBufferWriter DB)>
> WriteMem;
>
> - typedef Procedure<WritePtrId, void(TargetAddress Dst, TargetAddress
> Val)>
> + typedef Function<WritePtrId, void(TargetAddress Dst, TargetAddress Val)>
> WritePtr;
> };
>
>
> Modified:
> llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
> (original)
> +++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
> Sun Apr 17 20:06:49 2016
> @@ -45,14 +45,14 @@ public:
> EHFramesRegister(std::move(EHFramesRegister)),
> EHFramesDeregister(std::move(EHFramesDeregister)) {}
>
> - std::error_code getNextProcId(JITProcId &Id) {
> + std::error_code getNextFuncId(JITFuncId &Id) {
> return deserialize(Channel, Id);
> }
>
> - std::error_code handleKnownProcedure(JITProcId Id) {
> + std::error_code handleKnownFunction(JITFuncId Id) {
> typedef OrcRemoteTargetServer ThisT;
>
> - DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) <<
> "\n");
> + DEBUG(dbgs() << "Handling known proc: " << getJITFuncIdName(Id) <<
> "\n");
>
> switch (Id) {
> case CallIntVoidId:
> @@ -111,27 +111,17 @@ public:
> llvm_unreachable("Unhandled JIT RPC procedure Id.");
> }
>
> - std::error_code requestCompile(TargetAddress &CompiledFnAddr,
> - TargetAddress TrampolineAddr) {
> - if (auto EC = call<RequestCompile>(Channel, TrampolineAddr))
> - return EC;
> -
> - while (1) {
> - JITProcId Id = InvalidId;
> - if (auto EC = getNextProcId(Id))
> - return EC;
> + ErrorOr<TargetAddress> requestCompile(TargetAddress TrampolineAddr) {
> + auto Listen =
> + [&](RPCChannel &C, uint32_t Id) {
> + return handleKnownFunction(static_cast<JITFuncId>(Id));
> + };
>
> - switch (Id) {
> - case RequestCompileResponseId:
> - return handle<RequestCompileResponse>(Channel,
> - readArgs(CompiledFnAddr));
> - default:
> - if (auto EC = handleKnownProcedure(Id))
> - return EC;
> - }
> - }
> + return callSTHandling<RequestCompile>(Channel, Listen,
> TrampolineAddr);
> + }
>
> - llvm_unreachable("Fell through request-compile command loop.");
> + void handleTerminateSession() {
> + handle<TerminateSession>(Channel, [](){ return std::error_code(); });
> }
>
> private:
> @@ -175,18 +165,16 @@ private:
> static std::error_code doNothing() { return std::error_code(); }
>
> static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr)
> {
> - TargetAddress CompiledFnAddr = 0;
> -
> auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr);
> - auto EC = T->requestCompile(
> - CompiledFnAddr, static_cast<TargetAddress>(
> - reinterpret_cast<uintptr_t>(TrampolineAddr)));
> - assert(!EC && "Compile request failed");
> - (void)EC;
> - return CompiledFnAddr;
> + auto AddrOrErr = T->requestCompile(
> + static_cast<TargetAddress>(
> + reinterpret_cast<uintptr_t>(TrampolineAddr)));
> + // FIXME: Allow customizable failure substitution functions.
> + assert(AddrOrErr && "Compile request failed");
> + return *AddrOrErr;
> }
>
> - std::error_code handleCallIntVoid(TargetAddress Addr) {
> + ErrorOr<int32_t> handleCallIntVoid(TargetAddress Addr) {
> typedef int (*IntVoidFnTy)();
> IntVoidFnTy Fn =
> reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
> @@ -195,11 +183,11 @@ private:
> int Result = Fn();
> DEBUG(dbgs() << " Result = " << Result << "\n");
>
> - return call<CallIntVoidResponse>(Channel, Result);
> + return Result;
> }
>
> - std::error_code handleCallMain(TargetAddress Addr,
> - std::vector<std::string> Args) {
> + ErrorOr<int32_t> handleCallMain(TargetAddress Addr,
> + std::vector<std::string> Args) {
> typedef int (*MainFnTy)(int, const char *[]);
>
> MainFnTy Fn =
> reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
> @@ -214,7 +202,7 @@ private:
> int Result = Fn(ArgC, ArgV.get());
> DEBUG(dbgs() << " Result = " << Result << "\n");
>
> - return call<CallMainResponse>(Channel, Result);
> + return Result;
> }
>
> std::error_code handleCallVoidVoid(TargetAddress Addr) {
> @@ -226,7 +214,7 @@ private:
> Fn();
> DEBUG(dbgs() << " Complete.\n");
>
> - return call<CallVoidVoidResponse>(Channel);
> + return std::error_code();
> }
>
> std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId
> Id) {
> @@ -273,8 +261,9 @@ private:
> return std::error_code();
> }
>
> - std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
> - uint32_t NumStubsRequired) {
> + ErrorOr<std::tuple<TargetAddress, TargetAddress, uint32_t>>
> + handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
> + uint32_t NumStubsRequired) {
> DEBUG(dbgs() << " ISMgr " << Id << " request " << NumStubsRequired
> << " stubs.\n");
>
> @@ -296,8 +285,7 @@ private:
> auto &BlockList = StubOwnerItr->second;
> BlockList.push_back(std::move(IS));
>
> - return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase,
> - NumStubsEmitted);
> + return std::make_tuple(StubsBase, PtrsBase, NumStubsEmitted);
> }
>
> std::error_code handleEmitResolverBlock() {
> @@ -316,7 +304,8 @@ private:
> sys::Memory::MF_EXEC);
> }
>
> - std::error_code handleEmitTrampolineBlock() {
> + ErrorOr<std::tuple<TargetAddress, uint32_t>>
> + handleEmitTrampolineBlock() {
> std::error_code EC;
> auto TrampolineBlock =
> sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
> @@ -325,7 +314,7 @@ private:
> if (EC)
> return EC;
>
> - unsigned NumTrampolines =
> + uint32_t NumTrampolines =
> (sys::Process::getPageSize() - TargetT::PointerSize) /
> TargetT::TrampolineSize;
>
> @@ -339,20 +328,21 @@ private:
>
> TrampolineBlocks.push_back(std::move(TrampolineBlock));
>
> - return call<EmitTrampolineBlockResponse>(
> - Channel,
> -
> static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)),
> - NumTrampolines);
> + auto TrampolineBaseAddr =
> +
> static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem));
> +
> + return std::make_tuple(TrampolineBaseAddr, NumTrampolines);
> }
>
> - std::error_code handleGetSymbolAddress(const std::string &Name) {
> + ErrorOr<TargetAddress> handleGetSymbolAddress(const std::string &Name) {
> TargetAddress Addr = SymbolLookup(Name);
> DEBUG(dbgs() << " Symbol '" << Name << "' = " << format("0x%016x",
> Addr)
> << "\n");
> - return call<GetSymbolAddressResponse>(Channel, Addr);
> + return Addr;
> }
>
> - std::error_code handleGetRemoteInfo() {
> + ErrorOr<std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>>
> + handleGetRemoteInfo() {
> std::string ProcessTriple = sys::getProcessTriple();
> uint32_t PointerSize = TargetT::PointerSize;
> uint32_t PageSize = sys::Process::getPageSize();
> @@ -364,24 +354,23 @@ private:
> << " page size = " << PageSize << "\n"
> << " trampoline size = " << TrampolineSize << "\n"
> << " indirect stub size = " << IndirectStubSize <<
> "\n");
> - return call<GetRemoteInfoResponse>(Channel, ProcessTriple,
> PointerSize,
> - PageSize, TrampolineSize,
> - IndirectStubSize);
> + return std::make_tuple(ProcessTriple, PointerSize, PageSize
> ,TrampolineSize,
> + IndirectStubSize);
> }
>
> - std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) {
> + ErrorOr<std::vector<char>>
> + handleReadMem(TargetAddress RSrc, uint64_t Size) {
> char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
>
> DEBUG(dbgs() << " Reading " << Size << " bytes from "
> << format("0x%016x", RSrc) << "\n");
>
> - if (auto EC = call<ReadMemResponse>(Channel))
> - return EC;
> -
> - if (auto EC = Channel.appendBytes(Src, Size))
> - return EC;
> + std::vector<char> Buffer;
> + Buffer.resize(Size);
> + for (char *P = Src; Size != 0; --Size)
> + Buffer.push_back(*P++);
>
> - return Channel.send();
> + return Buffer;
> }
>
> std::error_code handleRegisterEHFrames(TargetAddress TAddr, uint32_t
> Size) {
> @@ -392,8 +381,9 @@ private:
> return std::error_code();
> }
>
> - std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t
> Size,
> - uint32_t Align) {
> + ErrorOr<TargetAddress>
> + handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
> + uint32_t Align) {
> auto I = Allocators.find(Id);
> if (I == Allocators.end())
> return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
> @@ -408,7 +398,7 @@ private:
> TargetAddress AllocAddr =
>
> static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
>
> - return call<ReserveMemResponse>(Channel, AllocAddr);
> + return AllocAddr;
> }
>
> std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id,
> @@ -425,11 +415,10 @@ private:
> return Allocator.setProtections(LocalAddr, Flags);
> }
>
> - std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) {
> - char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
> - DEBUG(dbgs() << " Writing " << Size << " bytes to "
> - << format("0x%016x", RDst) << "\n");
> - return Channel.readBytes(Dst, Size);
> + std::error_code handleWriteMem(DirectBufferWriter DBW) {
> + DEBUG(dbgs() << " Writing " << DBW.getSize() << " bytes to "
> + << format("0x%016x", DBW.getDst()) << "\n");
> + return std::error_code();
> }
>
> std::error_code handleWritePtr(TargetAddress Addr, TargetAddress
> PtrVal) {
>
> Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h (original)
> +++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCChannel.h Sun Apr 17
> 20:06:49 2016
> @@ -5,8 +5,10 @@
>
> #include "OrcError.h"
> #include "llvm/ADT/ArrayRef.h"
> +#include "llvm/ADT/STLExtras.h"
> #include "llvm/Support/Endian.h"
>
> +#include <mutex>
> #include <system_error>
>
> namespace llvm {
> @@ -26,31 +28,68 @@ public:
>
> /// Flush the stream if possible.
> virtual std::error_code send() = 0;
> +
> + /// Get the lock for stream reading.
> + std::mutex& getReadLock() { return readLock; }
> +
> + /// Get the lock for stream writing.
> + std::mutex& getWriteLock() { return writeLock; }
> +
> +private:
> + std::mutex readLock, writeLock;
> };
>
> +/// Notify the channel that we're starting a message send.
> +/// Locks the channel for writing.
> +inline std::error_code startSendMessage(RPCChannel &C) {
> + C.getWriteLock().lock();
> + return std::error_code();
> +}
> +
> +/// Notify the channel that we're ending a message send.
> +/// Unlocks the channel for writing.
> +inline std::error_code endSendMessage(RPCChannel &C) {
> + C.getWriteLock().unlock();
> + return std::error_code();
> +}
> +
> +/// Notify the channel that we're starting a message receive.
> +/// Locks the channel for reading.
> +inline std::error_code startReceiveMessage(RPCChannel &C) {
> + C.getReadLock().lock();
> + return std::error_code();
> +}
> +
> +/// Notify the channel that we're ending a message receive.
> +/// Unlocks the channel for reading.
> +inline std::error_code endReceiveMessage(RPCChannel &C) {
> + C.getReadLock().unlock();
> + return std::error_code();
> +}
> +
> /// RPC channel serialization for a variadic list of arguments.
> template <typename T, typename... Ts>
> -std::error_code serialize_seq(RPCChannel &C, const T &Arg, const Ts &...
> Args) {
> +std::error_code serializeSeq(RPCChannel &C, const T &Arg, const Ts &...
> Args) {
> if (auto EC = serialize(C, Arg))
> return EC;
> - return serialize_seq(C, Args...);
> + return serializeSeq(C, Args...);
> }
>
> /// RPC channel serialization for an (empty) variadic list of arguments.
> -inline std::error_code serialize_seq(RPCChannel &C) {
> +inline std::error_code serializeSeq(RPCChannel &C) {
> return std::error_code();
> }
>
> /// RPC channel deserialization for a variadic list of arguments.
> template <typename T, typename... Ts>
> -std::error_code deserialize_seq(RPCChannel &C, T &Arg, Ts &... Args) {
> +std::error_code deserializeSeq(RPCChannel &C, T &Arg, Ts &... Args) {
> if (auto EC = deserialize(C, Arg))
> return EC;
> - return deserialize_seq(C, Args...);
> + return deserializeSeq(C, Args...);
> }
>
> /// RPC channel serialization for an (empty) variadic list of arguments.
> -inline std::error_code deserialize_seq(RPCChannel &C) {
> +inline std::error_code deserializeSeq(RPCChannel &C) {
> return std::error_code();
> }
>
> @@ -138,6 +177,34 @@ inline std::error_code deserialize(RPCCh
> return C.readBytes(&S[0], Count);
> }
>
> +// Serialization helper for std::tuple.
> +template <typename TupleT, size_t... Is>
> +inline std::error_code serializeTupleHelper(RPCChannel &C,
> + const TupleT &V,
> + llvm::index_sequence<Is...> _)
> {
> + return serializeSeq(C, std::get<Is>(V)...);
> +}
> +
> +/// RPC channel serialization for std::tuple.
> +template <typename... ArgTs>
> +inline std::error_code serialize(RPCChannel &C, const
> std::tuple<ArgTs...> &V) {
> + return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
> +}
> +
> +// Serialization helper for std::tuple.
> +template <typename TupleT, size_t... Is>
> +inline std::error_code deserializeTupleHelper(RPCChannel &C,
> + TupleT &V,
> + llvm::index_sequence<Is...>
> _) {
> + return deserializeSeq(C, std::get<Is>(V)...);
> +}
> +
> +/// RPC channel deserialization for std::tuple.
> +template <typename... ArgTs>
> +inline std::error_code deserialize(RPCChannel &C, std::tuple<ArgTs...>
> &V) {
> + return deserializeTupleHelper(C, V,
> llvm::index_sequence_for<ArgTs...>());
> +}
> +
> /// RPC channel serialization for ArrayRef<T>.
> template <typename T>
> std::error_code serialize(RPCChannel &C, const ArrayRef<T> &A) {
>
> Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h (original)
> +++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h Sun Apr 17
> 20:06:49 2016
> @@ -14,46 +14,197 @@
> #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
> #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
>
> +#include "llvm/ADT/Optional.h"
> #include "llvm/ADT/STLExtras.h"
> #include "llvm/ExecutionEngine/Orc/OrcError.h"
> +#include "llvm/Support/ErrorOr.h"
> +#include <future>
> +#include <map>
>
> namespace llvm {
> namespace orc {
> namespace remote {
>
> +/// Describes reserved RPC Function Ids.
> +///
> +/// The default implementation will serve for integer and enum function id
> +/// types. If you want to use a custom type as your FunctionId you can
> +/// specialize this class and provide unique values for InvalidId,
> +/// ResponseId and FirstValidId.
> +
> +template <typename T>
> +class RPCFunctionIdTraits {
> +public:
> + constexpr static const T InvalidId = static_cast<T>(0);
> + constexpr static const T ResponseId = static_cast<T>(1);
> + constexpr static const T FirstValidId = static_cast<T>(2);
> +};
> +
> // Base class containing utilities that require partial specialization.
> // These cannot be included in RPC, as template class members cannot be
> // partially specialized.
> class RPCBase {
> protected:
> - template <typename ProcedureIdT, ProcedureIdT ProcId, typename FnT>
> - class ProcedureHelper {
> - public:
> - static const ProcedureIdT Id = ProcId;
> - };
>
> - template <typename ChannelT, typename Proc> class CallHelper;
> + // RPC Function description type.
> + //
> + // This class provides the information and operations needed to support
> the
> + // RPC primitive operations (call, expect, etc) for a given function. It
> + // is specialized for void and non-void functions to deal with the
> differences
> + // between the two. Both specializations have the same interface:
> + //
> + // Id - The function's unique identifier.
> + // OptionalReturn - The return type for asynchronous calls.
> + // ErrorReturn - The return type for synchronous calls.
> + // optionalToErrorReturn - Conversion from a valid OptionalReturn to an
> + // ErrorReturn.
> + // readResult - Deserialize a result from a channel.
> + // abandon - Abandon a promised (asynchronous) result.
> + // respond - Return a result on the channel.
> + template <typename FunctionIdT, FunctionIdT FuncId, typename FnT>
> + class FunctionHelper {};
>
> - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId,
> + // RPC Function description specialization for non-void functions.
> + template <typename FunctionIdT, FunctionIdT FuncId, typename RetT,
> typename... ArgTs>
> - class CallHelper<ChannelT,
> - ProcedureHelper<ProcedureIdT, ProcId, void(ArgTs...)>>
> {
> + class FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> {
> public:
> - static std::error_code call(ChannelT &C, const ArgTs &... Args) {
> - if (auto EC = serialize(C, ProcId))
> +
> + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId &&
> + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId,
> + "Cannot define custom function with InvalidId or
> ResponseId. "
> + "Please use
> RPCFunctionTraits<FunctionIdT>::FirstValidId.");
> +
> + static const FunctionIdT Id = FuncId;
> +
> + typedef Optional<RetT> OptionalReturn;
> +
> + typedef ErrorOr<RetT> ErrorReturn;
> +
> + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) {
> + assert(V && "Return value not available");
> + return std::move(*V);
> + }
> +
> + template <typename ChannelT>
> + static std::error_code readResult(ChannelT &C,
> + std::promise<OptionalReturn> &P) {
> + RetT Val;
> + auto EC = deserialize(C, Val);
> + // FIXME: Join error EC2 from endReceiveMessage with the deserialize
> + // error once we switch to using Error.
> + auto EC2 = endReceiveMessage(C);
> + (void)EC2;
> +
> + if (EC) {
> + P.set_value(OptionalReturn());
> return EC;
> - // If you see a compile-error on this line you're probably calling a
> - // function with the wrong signature.
> - return serialize_seq(C, Args...);
> + }
> + P.set_value(std::move(Val));
> + return std::error_code();
> + }
> +
> + static void abandon(std::promise<OptionalReturn> &P) {
> + P.set_value(OptionalReturn());
> + }
> +
> + template <typename ChannelT, typename SequenceNumberT>
> + static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo,
> + const ErrorReturn &Result) {
> + FunctionIdT ResponseId =
> + RPCFunctionIdTraits<FunctionIdT>::ResponseId;
> +
> + // If the handler returned an error then bail out with that.
> + if (!Result)
> + return Result.getError();
> +
> + // Otherwise open a new message on the channel and send the result.
> + if (auto EC = startSendMessage(C))
> + return EC;
> + if (auto EC = serializeSeq(C, ResponseId, SeqNo, *Result))
> + return EC;
> + return endSendMessage(C);
> }
> };
>
> - template <typename ChannelT, typename Proc> class HandlerHelper;
> + // RPC Function description specialization for void functions.
> + template <typename FunctionIdT, FunctionIdT FuncId, typename... ArgTs>
> + class FunctionHelper<FunctionIdT, FuncId, void(ArgTs...)> {
> + public:
>
> - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId,
> - typename... ArgTs>
> - class HandlerHelper<ChannelT,
> - ProcedureHelper<ProcedureIdT, ProcId,
> void(ArgTs...)>> {
> + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId &&
> + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId,
> + "Cannot define custom function with InvalidId or
> ResponseId. "
> + "Please use
> RPCFunctionTraits<FunctionIdT>::FirstValidId.");
> +
> + static const FunctionIdT Id = FuncId;
> +
> + typedef bool OptionalReturn;
> + typedef std::error_code ErrorReturn;
> +
> + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) {
> + assert(V && "Return value not available");
> + return std::error_code();
> + }
> +
> + template <typename ChannelT>
> + static std::error_code readResult(ChannelT &C,
> + std::promise<OptionalReturn> &P) {
> + // Void functions don't have anything to deserialize, so we're good.
> + P.set_value(true);
> + return endReceiveMessage(C);
> + }
> +
> + static void abandon(std::promise<OptionalReturn> &P) {
> + P.set_value(false);
> + }
> +
> + template <typename ChannelT, typename SequenceNumberT>
> + static std::error_code respond(ChannelT &C, SequenceNumberT SeqNo,
> + const ErrorReturn &Result) {
> + const FunctionIdT ResponseId =
> + RPCFunctionIdTraits<FunctionIdT>::ResponseId;
> +
> + // If the handler returned an error then bail out with that.
> + if (Result)
> + return Result;
> +
> + // Otherwise open a new message on the channel and send the result.
> + if (auto EC = startSendMessage(C))
> + return EC;
> + if (auto EC = serializeSeq(C, ResponseId, SeqNo))
> + return EC;
> + return endSendMessage(C);
> + }
> + };
> +
> + // Helper for the call primitive.
> + template <typename ChannelT, typename SequenceNumberT, typename Func>
> + class CallHelper;
> +
> + template <typename ChannelT, typename SequenceNumberT, typename
> FunctionIdT,
> + FunctionIdT FuncId, typename RetT, typename... ArgTs>
> + class CallHelper<ChannelT, SequenceNumberT,
> + FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> {
> + public:
> + static std::error_code call(ChannelT &C, SequenceNumberT SeqNo,
> + const ArgTs &... Args) {
> + if (auto EC = startSendMessage(C))
> + return EC;
> + if (auto EC = serializeSeq(C, FuncId, SeqNo, Args...))
> + return EC;
> + return endSendMessage(C);
> + }
> + };
> +
> + // Helper for handle primitive.
> + template <typename ChannelT, typename SequenceNumberT, typename Func>
> + class HandlerHelper;
> +
> + template <typename ChannelT, typename SequenceNumberT, typename
> FunctionIdT,
> + FunctionIdT FuncId, typename RetT, typename... ArgTs>
> + class HandlerHelper<ChannelT, SequenceNumberT,
> + FunctionHelper<FunctionIdT, FuncId,
> RetT(ArgTs...)>> {
> public:
> template <typename HandlerT>
> static std::error_code handle(ChannelT &C, HandlerT Handler) {
> @@ -61,34 +212,46 @@ protected:
> }
>
> private:
> +
> + typedef FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> Func;
> +
> template <typename HandlerT, size_t... Is>
> static std::error_code readAndHandle(ChannelT &C, HandlerT Handler,
> llvm::index_sequence<Is...> _) {
> std::tuple<ArgTs...> RPCArgs;
> + SequenceNumberT SeqNo;
> // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable
> warning
> // for RPCArgs. Void cast RPCArgs to work around this for now.
> // FIXME: Remove this workaround once we can assume a working GCC
> version.
> (void)RPCArgs;
> - if (auto EC = deserialize_seq(C, std::get<Is>(RPCArgs)...))
> + if (auto EC = deserializeSeq(C, SeqNo, std::get<Is>(RPCArgs)...))
> return EC;
> - return Handler(std::get<Is>(RPCArgs)...);
> +
> + // We've deserialized the arguments, so unlock the channel for
> reading
> + // before we call the handler. This allows recursive RPC calls.
> + if (auto EC = endReceiveMessage(C))
> + return EC;
> +
> + return Func::template respond<ChannelT, SequenceNumberT>(
> + C, SeqNo, Handler(std::get<Is>(RPCArgs)...));
> }
> +
> };
>
> - template <typename ClassT, typename... ArgTs> class MemberFnWrapper {
> + // Helper for wrapping member functions up as functors.
> + template <typename ClassT, typename RetT, typename... ArgTs>
> + class MemberFnWrapper {
> public:
> - typedef std::error_code (ClassT::*MethodT)(ArgTs...);
> + typedef RetT(ClassT::*MethodT)(ArgTs...);
> MemberFnWrapper(ClassT &Instance, MethodT Method)
> : Instance(Instance), Method(Method) {}
> - std::error_code operator()(ArgTs &... Args) {
> - return (Instance.*Method)(Args...);
> - }
> -
> + RetT operator()(ArgTs &... Args) { return
> (Instance.*Method)(Args...); }
> private:
> ClassT &Instance;
> MethodT Method;
> };
>
> + // Helper that provides a Functor for deserializing arguments.
> template <typename... ArgTs> class ReadArgs {
> public:
> std::error_code operator()() { return std::error_code(); }
> @@ -112,7 +275,7 @@ protected:
>
> /// Contains primitive utilities for defining, calling and handling calls
> to
> /// remote procedures. ChannelT is a bidirectional stream conforming to
> the
> -/// RPCChannel interface (see RPCChannel.h), and ProcedureIdT is a
> procedure
> +/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a
> procedure
> /// identifier type that must be serializable on ChannelT.
> ///
> /// These utilities support the construction of very primitive RPC
> utilities.
> @@ -129,120 +292,184 @@ protected:
> ///
> /// Overview (see comments on individual types/methods for details):
> ///
> -/// Procedure<Id, Args...> :
> +/// Function<Id, Args...> :
> ///
> /// associates a unique serializable id with an argument list.
> ///
> ///
> -/// call<Proc>(Channel, Args...) :
> +/// call<Func>(Channel, Args...) :
> ///
> -/// Calls the remote procedure 'Proc' by serializing Proc's id followed
> by its
> +/// Calls the remote procedure 'Func' by serializing Func's id followed
> by its
> /// arguments and sending the resulting bytes to 'Channel'.
> ///
> ///
> -/// handle<Proc>(Channel, <functor matching std::error_code(Args...)> :
> +/// handle<Func>(Channel, <functor matching std::error_code(Args...)> :
> ///
> -/// Handles a call to 'Proc' by deserializing its arguments and calling
> the
> -/// given functor. This assumes that the id for 'Proc' has already been
> +/// Handles a call to 'Func' by deserializing its arguments and calling
> the
> +/// given functor. This assumes that the id for 'Func' has already been
> /// deserialized.
> ///
> -/// expect<Proc>(Channel, <functor matching std::error_code(Args...)> :
> +/// expect<Func>(Channel, <functor matching std::error_code(Args...)> :
> ///
> /// The same as 'handle', except that the procedure id should not have
> been
> -/// read yet. Expect will deserialize the id and assert that it matches
> Proc's
> +/// read yet. Expect will deserialize the id and assert that it matches
> Func's
> /// id. If it does not, an unexpected RPC call error is returned.
> -
> -template <typename ChannelT, typename ProcedureIdT = uint32_t>
> +template <typename ChannelT, typename FunctionIdT = uint32_t,
> + typename SequenceNumberT = uint16_t>
> class RPC : public RPCBase {
> public:
> +
> + RPC() = default;
> + RPC(const RPC&) = delete;
> + RPC& operator=(const RPC&) = delete;
> + RPC(RPC &&Other) :
> SequenceNumberMgr(std::move(Other.SequenceNumberMgr)),
> OutstandingResults(std::move(Other.OutstandingResults)) {}
> + RPC& operator=(RPC&&) = default;
> +
> /// Utility class for defining/referring to RPC procedures.
> ///
> /// Typedefs of this utility are used when calling/handling remote
> procedures.
> ///
> - /// ProcId should be a unique value of ProcedureIdT (i.e. not used with
> any
> - /// other Procedure typedef in the RPC API being defined.
> + /// FuncId should be a unique value of FunctionIdT (i.e. not used with
> any
> + /// other Function typedef in the RPC API being defined.
> ///
> /// the template argument Ts... gives the argument list for the remote
> /// procedure.
> ///
> /// E.g.
> ///
> - /// typedef Procedure<0, bool> Proc1;
> - /// typedef Procedure<1, std::string, std::vector<int>> Proc2;
> + /// typedef Function<0, bool> Func1;
> + /// typedef Function<1, std::string, std::vector<int>> Func2;
> ///
> - /// if (auto EC = call<Proc1>(Channel, true))
> + /// if (auto EC = call<Func1>(Channel, true))
> /// /* handle EC */;
> ///
> - /// if (auto EC = expect<Proc2>(Channel,
> + /// if (auto EC = expect<Func2>(Channel,
> /// [](std::string &S, std::vector<int> &V) {
> /// // Stuff.
> /// return std::error_code();
> /// })
> /// /* handle EC */;
> ///
> - template <ProcedureIdT ProcId, typename FnT>
> - using Procedure = ProcedureHelper<ProcedureIdT, ProcId, FnT>;
> + template <FunctionIdT FuncId, typename FnT>
> + using Function = FunctionHelper<FunctionIdT, FuncId, FnT>;
> +
> + /// Return type for asynchronous call primitives.
> + template <typename Func>
> + using AsyncCallResult =
> + std::pair<std::future<typename Func::OptionalReturn>,
> SequenceNumberT>;
>
> /// Serialize Args... to channel C, but do not call C.send().
> ///
> - /// For buffered channels, this can be used to queue up several calls
> before
> - /// flushing the channel.
> - template <typename Proc, typename... ArgTs>
> - static std::error_code appendCall(ChannelT &C, const ArgTs &... Args) {
> - return CallHelper<ChannelT, Proc>::call(C, Args...);
> + /// For void functions the future holds a bool (false if the call was
> + /// abandoned). For functions that return an R, it holds an Optional<R>.
> + template <typename Func, typename... ArgTs>
> + ErrorOr<AsyncCallResult<Func>>
> + appendCallAsync(ChannelT &C, const ArgTs &... Args) {
> + auto SeqNo = SequenceNumberMgr.getSequenceNumber();
> + std::promise<typename Func::OptionalReturn> Promise;
> + auto Result = Promise.get_future();
> + OutstandingResults[SeqNo] = std::move(Promise);
> +
> + if (auto EC =
> + CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo,
> + Args...)) {
> + abandonOutstandingResults();
> + return EC;
> + } else
> + return AsyncCallResult<Func>(std::move(Result), SeqNo);
> }
>
> /// Serialize Args... to channel C and call C.send().
> - template <typename Proc, typename... ArgTs>
> - static std::error_code call(ChannelT &C, const ArgTs &... Args) {
> - if (auto EC = appendCall<Proc>(C, Args...))
> + template <typename Func, typename... ArgTs>
> + ErrorOr<AsyncCallResult<Func>>
> + callAsync(ChannelT &C, const ArgTs &... Args) {
> + auto SeqNo = SequenceNumberMgr.getSequenceNumber();
> + std::promise<typename Func::OptionalReturn> Promise;
> + auto Result = Promise.get_future();
> + OutstandingResults[SeqNo] =
> + createOutstandingResult<Func>(std::move(Promise));
> + if (auto EC =
> + CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo,
> Args...)) {
> + abandonOutstandingResults();
> return EC;
> - return C.send();
> + }
> + if (auto EC = C.send()) {
> + abandonOutstandingResults();
> + return EC;
> + }
> + return AsyncCallResult<Func>(std::move(Result), SeqNo);
> }
>
> - /// Deserialize and return an enum whose underlying type is
> ProcedureIdT.
> - static std::error_code getNextProcId(ChannelT &C, ProcedureIdT &Id) {
> + /// This can be used in single-threaded mode.
> + template <typename Func, typename HandleFtor, typename... ArgTs>
> + typename Func::ErrorReturn
> + callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &...
> Args) {
> + if (auto ResultAndSeqNoOrErr = callAsync<Func>(C, Args...)) {
> + auto &ResultAndSeqNo = *ResultAndSeqNoOrErr;
> + if (auto EC = waitForResult(C, ResultAndSeqNo.second, HandleOther))
> + return EC;
> + return Func::optionalToErrorReturn(ResultAndSeqNo.first.get());
> + } else
> + return ResultAndSeqNoOrErr.getError();
> + }
> +
> + // This can be used in single-threaded mode.
> + template <typename Func, typename... ArgTs>
> + typename Func::ErrorReturn
> + callST(ChannelT &C, const ArgTs &... Args) {
> + return callSTHandling<Func>(C, handleNone, Args...);
> + }
> +
> + /// Start receiving a new function call.
> + ///
> + /// Calls startReceiveMessage on the channel, then deserializes a
> FunctionId
> + /// into Id.
> + std::error_code startReceivingFunction(ChannelT &C, FunctionIdT &Id) {
> + if (auto EC = startReceiveMessage(C))
> + return EC;
> +
> return deserialize(C, Id);
> }
>
> - /// Deserialize args for Proc from C and call Handler. The signature of
> + /// Deserialize args for Func from C and call Handler. The signature of
> /// handler must conform to 'std::error_code(Args...)' where Args...
> matches
> - /// the arguments used in the Proc typedef.
> - template <typename Proc, typename HandlerT>
> + /// the arguments used in the Func typedef.
> + template <typename Func, typename HandlerT>
> static std::error_code handle(ChannelT &C, HandlerT Handler) {
> - return HandlerHelper<ChannelT, Proc>::handle(C, Handler);
> + return HandlerHelper<ChannelT, SequenceNumberT, Func>::handle(C,
> Handler);
> }
>
> /// Helper version of 'handle' for calling member functions.
> - template <typename Proc, typename ClassT, typename... ArgTs>
> + template <typename Func, typename ClassT, typename RetT, typename...
> ArgTs>
> static std::error_code
> handle(ChannelT &C, ClassT &Instance,
> - std::error_code (ClassT::*HandlerMethod)(ArgTs...)) {
> - return handle<Proc>(
> - C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod));
> + RetT (ClassT::*HandlerMethod)(ArgTs...)) {
> + return handle<Func>(
> + C,
> + MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance,
> HandlerMethod));
> }
>
> - /// Deserialize a ProcedureIdT from C and verify it matches the id for
> Proc.
> + /// Deserialize a FunctionIdT from C and verify it matches the id for
> Func.
> /// If the id does match, deserialize the arguments and call the handler
> /// (similarly to handle).
> /// If the id does not match, return an unexpected RPC call error and do
> not
> /// deserialize any further bytes.
> - template <typename Proc, typename HandlerT>
> - static std::error_code expect(ChannelT &C, HandlerT Handler) {
> - ProcedureIdT ProcId;
> - if (auto EC = getNextProcId(C, ProcId))
> - return EC;
> - if (ProcId != Proc::Id)
> + template <typename Func, typename HandlerT>
> + std::error_code expect(ChannelT &C, HandlerT Handler) {
> + FunctionIdT FuncId;
> + if (auto EC = startReceivingFunction(C, FuncId))
> + return std::move(EC);
> + if (FuncId != Func::Id)
> return orcError(OrcErrorCode::UnexpectedRPCCall);
> - return handle<Proc>(C, Handler);
> + return handle<Func>(C, Handler);
> }
>
> /// Helper version of expect for calling member functions.
> - template <typename Proc, typename ClassT, typename... ArgTs>
> + template <typename Func, typename ClassT, typename... ArgTs>
> static std::error_code
> expect(ChannelT &C, ClassT &Instance,
> std::error_code (ClassT::*HandlerMethod)(ArgTs...)) {
> - return expect<Proc>(
> + return expect<Func>(
> C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod));
> }
>
> @@ -251,18 +478,163 @@ public:
> /// channel.
> /// E.g.
> ///
> - /// typedef Procedure<0, bool, int> Proc1;
> + /// typedef Function<0, bool, int> Func1;
> ///
> /// ...
> /// bool B;
> /// int I;
> - /// if (auto EC = expect<Proc1>(Channel, readArgs(B, I)))
> + /// if (auto EC = expect<Func1>(Channel, readArgs(B, I)))
> /// /* Handle Args */ ;
> ///
> template <typename... ArgTs>
> static ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
> return ReadArgs<ArgTs...>(Args...);
> }
> +
> + /// Read a response from Channel.
> + /// This should be called from the receive loop to retrieve results.
> + std::error_code handleResponse(ChannelT &C, SequenceNumberT &SeqNo) {
> + if (auto EC = deserialize(C, SeqNo)) {
> + abandonOutstandingResults();
> + return EC;
> + }
> +
> + auto I = OutstandingResults.find(SeqNo);
> + if (I == OutstandingResults.end()) {
> + abandonOutstandingResults();
> + return orcError(OrcErrorCode::UnexpectedRPCResponse);
> + }
> +
> + if (auto EC = I->second->readResult(C)) {
> + abandonOutstandingResults();
> + // FIXME: Release sequence numbers?
> + return EC;
> + }
> +
> + OutstandingResults.erase(I);
> + SequenceNumberMgr.releaseSequenceNumber(SeqNo);
> +
> + return std::error_code();
> + }
> +
> + // Loop waiting for a result with the given sequence number.
> + // This can be used as a receive loop if the user doesn't have a
> default.
> + template <typename HandleOtherFtor>
> + std::error_code waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo,
> + HandleOtherFtor &HandleOther = handleNone)
> {
> + bool GotTgtResult = false;
> +
> + while (!GotTgtResult) {
> + FunctionIdT Id =
> + RPCFunctionIdTraits<FunctionIdT>::InvalidId;
> + if (auto EC = startReceivingFunction(C, Id))
> + return EC;
> + if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) {
> + SequenceNumberT SeqNo;
> + if (auto EC = handleResponse(C, SeqNo))
>
Indentation looks a bit off here ^ (the 'if' is not indented the same as
'SequenceNumberT'), and the same goes for the following two lines - and
maybe for the return in the else branch too?
> + return EC;
> + GotTgtResult = (SeqNo == TgtSeqNo);
+ } else if (auto EC = HandleOther(C, Id))
> + return EC;
> + }
> +
> + return std::error_code();
> + };
> +
> + // Default handler for 'other' (non-response) functions when waiting
> for a
> + // result from the channel.
> + static std::error_code handleNone(ChannelT&, FunctionIdT) {
> + return orcError(OrcErrorCode::UnexpectedRPCCall);
> + };
> +
> +private:
> +
> + // Manage sequence numbers.
> + class SequenceNumberManager {
> + public:
> +
> + SequenceNumberManager() = default;
> +
> + SequenceNumberManager(SequenceNumberManager &&Other)
> + : NextSequenceNumber(std::move(Other.NextSequenceNumber)),
> + FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {}
> +
> + SequenceNumberManager& operator=(SequenceNumberManager &&Other) {
> + NextSequenceNumber = std::move(Other.NextSequenceNumber);
> + FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers);
> + }
> +
> + void reset() {
> + std::lock_guard<std::mutex> Lock(SeqNoLock);
> + NextSequenceNumber = 0;
> + FreeSequenceNumbers.clear();
> + }
> +
> + SequenceNumberT getSequenceNumber() {
> + std::lock_guard<std::mutex> Lock(SeqNoLock);
> + if (FreeSequenceNumbers.empty())
> + return NextSequenceNumber++;
> + auto SequenceNumber = FreeSequenceNumbers.back();
> + FreeSequenceNumbers.pop_back();
> + return SequenceNumber;
> + }
> +
> + void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
> + std::lock_guard<std::mutex> Lock(SeqNoLock);
> + FreeSequenceNumbers.push_back(SequenceNumber);
> + }
> +
> + private:
> + std::mutex SeqNoLock;
> + SequenceNumberT NextSequenceNumber = 0;
> + std::vector<SequenceNumberT> FreeSequenceNumbers;
> + };
> +
> + // Base class for results that haven't been returned from the other end of the
> + // RPC connection yet.
> + class OutstandingResult {
> + public:
> + virtual ~OutstandingResult() {}
> + virtual std::error_code readResult(ChannelT &C) = 0;
> + virtual void abandon() = 0;
> + };
> +
> + // Outstanding results for a specific function.
> + template <typename Func>
> + class OutstandingResultImpl : public OutstandingResult {
> + private:
> + public:
> + OutstandingResultImpl(std::promise<typename Func::OptionalReturn> &&P)
> + : P(std::move(P)) {}
> +
> + std::error_code readResult(ChannelT &C) override {
> + return Func::readResult(C, P);
> + }
> +
> + void abandon() override { Func::abandon(P); }
> +
> + private:
> + std::promise<typename Func::OptionalReturn> P;
> + };
> +
> + // Create an outstanding result for the given function.
> + template <typename Func>
> + std::unique_ptr<OutstandingResult>
> + createOutstandingResult(std::promise<typename Func::OptionalReturn> &&P) {
> + return llvm::make_unique<OutstandingResultImpl<Func>>(std::move(P));
> + }
> +
> + // Abandon all outstanding results.
> + void abandonOutstandingResults() {
> + for (auto &KV : OutstandingResults)
> + KV.second->abandon();
> + OutstandingResults.clear();
> + SequenceNumberMgr.reset();
> + }
> +
> + SequenceNumberManager SequenceNumberMgr;
> + std::map<SequenceNumberT, std::unique_ptr<OutstandingResult>>
> + OutstandingResults;
> };
>
> } // end namespace remote
>
> Modified: llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp (original)
> +++ llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp Sun Apr 17 20:06:49 2016
> @@ -38,6 +38,8 @@ public:
> return "Remote indirect stubs owner Id already in use";
> case OrcErrorCode::UnexpectedRPCCall:
> return "Unexpected RPC call";
> + case OrcErrorCode::UnexpectedRPCResponse:
> + return "Unexpected RPC response";
> }
> llvm_unreachable("Unhandled error code");
> }
>
> Modified: llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp (original)
> +++ llvm/trunk/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp Sun Apr 17 20:06:49 2016
> @@ -13,50 +13,40 @@ namespace llvm {
> namespace orc {
> namespace remote {
>
> -#define PROCNAME(X) \
> +#define FUNCNAME(X) \
> case X ## Id: \
> return #X
>
> -const char *OrcRemoteTargetRPCAPI::getJITProcIdName(JITProcId Id) {
> +const char *OrcRemoteTargetRPCAPI::getJITFuncIdName(JITFuncId Id) {
> switch (Id) {
> case InvalidId:
> - return "*** Invalid JITProcId ***";
> - PROCNAME(CallIntVoid);
> - PROCNAME(CallIntVoidResponse);
> - PROCNAME(CallMain);
> - PROCNAME(CallMainResponse);
> - PROCNAME(CallVoidVoid);
> - PROCNAME(CallVoidVoidResponse);
> - PROCNAME(CreateRemoteAllocator);
> - PROCNAME(CreateIndirectStubsOwner);
> - PROCNAME(DeregisterEHFrames);
> - PROCNAME(DestroyRemoteAllocator);
> - PROCNAME(DestroyIndirectStubsOwner);
> - PROCNAME(EmitIndirectStubs);
> - PROCNAME(EmitIndirectStubsResponse);
> - PROCNAME(EmitResolverBlock);
> - PROCNAME(EmitTrampolineBlock);
> - PROCNAME(EmitTrampolineBlockResponse);
> - PROCNAME(GetSymbolAddress);
> - PROCNAME(GetSymbolAddressResponse);
> - PROCNAME(GetRemoteInfo);
> - PROCNAME(GetRemoteInfoResponse);
> - PROCNAME(ReadMem);
> - PROCNAME(ReadMemResponse);
> - PROCNAME(RegisterEHFrames);
> - PROCNAME(ReserveMem);
> - PROCNAME(ReserveMemResponse);
> - PROCNAME(RequestCompile);
> - PROCNAME(RequestCompileResponse);
> - PROCNAME(SetProtections);
> - PROCNAME(TerminateSession);
> - PROCNAME(WriteMem);
> - PROCNAME(WritePtr);
> + return "*** Invalid JITFuncId ***";
> + FUNCNAME(CallIntVoid);
> + FUNCNAME(CallMain);
> + FUNCNAME(CallVoidVoid);
> + FUNCNAME(CreateRemoteAllocator);
> + FUNCNAME(CreateIndirectStubsOwner);
> + FUNCNAME(DeregisterEHFrames);
> + FUNCNAME(DestroyRemoteAllocator);
> + FUNCNAME(DestroyIndirectStubsOwner);
> + FUNCNAME(EmitIndirectStubs);
> + FUNCNAME(EmitResolverBlock);
> + FUNCNAME(EmitTrampolineBlock);
> + FUNCNAME(GetSymbolAddress);
> + FUNCNAME(GetRemoteInfo);
> + FUNCNAME(ReadMem);
> + FUNCNAME(RegisterEHFrames);
> + FUNCNAME(ReserveMem);
> + FUNCNAME(RequestCompile);
> + FUNCNAME(SetProtections);
> + FUNCNAME(TerminateSession);
> + FUNCNAME(WriteMem);
> + FUNCNAME(WritePtr);
> };
> return nullptr;
> }
>
> -#undef PROCNAME
> +#undef FUNCNAME
>
> } // end namespace remote
> } // end namespace orc
>
> Modified: llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp (original)
> +++ llvm/trunk/tools/lli/ChildTarget/ChildTarget.cpp Sun Apr 17 20:06:49 2016
> @@ -54,8 +54,8 @@ int main(int argc, char *argv[]) {
> JITServer Server(Channel, SymbolLookup, RegisterEHFrames,
> DeregisterEHFrames);
>
> while (1) {
> - JITServer::JITProcId Id = JITServer::InvalidId;
> - if (auto EC = Server.getNextProcId(Id)) {
> + JITServer::JITFuncId Id = JITServer::InvalidId;
> + if (auto EC = Server.getNextFuncId(Id)) {
> errs() << "Error: " << EC.message() << "\n";
> return 1;
> }
> @@ -63,7 +63,7 @@ int main(int argc, char *argv[]) {
> case JITServer::TerminateSessionId:
> return 0;
> default:
> - if (auto EC = Server.handleKnownProcedure(Id)) {
> + if (auto EC = Server.handleKnownFunction(Id)) {
> errs() << "Error: " << EC.message() << "\n";
> return 1;
> }
>
> Modified: llvm/trunk/tools/lli/RemoteJITUtils.h
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/tools/lli/RemoteJITUtils.h?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/tools/lli/RemoteJITUtils.h (original)
> +++ llvm/trunk/tools/lli/RemoteJITUtils.h Sun Apr 17 20:06:49 2016
> @@ -16,6 +16,7 @@
>
> #include "llvm/ExecutionEngine/Orc/RPCChannel.h"
> #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
> +#include <mutex>
>
> #if !defined(_MSC_VER) && !defined(__MINGW32__)
> #include <unistd.h>
>
> Modified: llvm/trunk/tools/lli/lli.cpp
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/tools/lli/lli.cpp?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/tools/lli/lli.cpp (original)
> +++ llvm/trunk/tools/lli/lli.cpp Sun Apr 17 20:06:49 2016
> @@ -582,7 +582,7 @@ int main(int argc, char **argv, char * c
> // Reset errno to zero on entry to main.
> errno = 0;
>
> - int Result;
> + int Result = -1;
>
> // Sanity check use of remote-jit: LLI currently only supports use of the
> // remote JIT on Unix platforms.
> @@ -681,12 +681,13 @@ int main(int argc, char **argv, char * c
> static_cast<ForwardingMemoryManager*>(RTDyldMM)->setResolver(
> orc::createLambdaResolver(
> [&](const std::string &Name) {
> - orc::TargetAddress Addr = 0;
> - if (auto EC = R->getSymbolAddress(Addr, Name)) {
> - errs() << "Failure during symbol lookup: " << EC.message() << "\n";
> - exit(1);
> - }
> - return RuntimeDyld::SymbolInfo(Addr, JITSymbolFlags::Exported);
> + if (auto AddrOrErr = R->getSymbolAddress(Name))
> + return RuntimeDyld::SymbolInfo(*AddrOrErr, JITSymbolFlags::Exported);
> + else {
> + errs() << "Failure during symbol lookup: "
> + << AddrOrErr.getError().message() << "\n";
> + exit(1);
> + }
> },
> [](const std::string &Name) { return nullptr; }
> ));
> @@ -698,8 +699,10 @@ int main(int argc, char **argv, char * c
> EE->finalizeObject();
> DEBUG(dbgs() << "Executing '" << EntryFn->getName() << "' at 0x"
> << format("%llx", Entry) << "\n");
> - if (auto EC = R->callIntVoid(Result, Entry))
> - errs() << "ERROR: " << EC.message() << "\n";
> + if (auto ResultOrErr = R->callIntVoid(Entry))
> + Result = *ResultOrErr;
> + else
> + errs() << "ERROR: " << ResultOrErr.getError().message() << "\n";
>
> // Like static constructors, the remote target MCJIT support doesn't handle
> // this yet. It could. FIXME.
>
> Modified: llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
> URL:
> http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp?rev=266581&r1=266580&r2=266581&view=diff
>
> ==============================================================================
> --- llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp (original)
> +++ llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp Sun Apr 17 20:06:49 2016
> @@ -44,26 +44,25 @@ private:
> class DummyRPC : public testing::Test,
> public RPC<QueueChannel> {
> public:
> - typedef Procedure<1, void(bool)> Proc1;
> - typedef Procedure<2, void(int8_t, uint8_t, int16_t, uint16_t,
> - int32_t, uint32_t, int64_t, uint64_t,
> - bool, std::string, std::vector<int>)> AllTheTypes;
> + typedef Function<2, void(bool)> BasicVoid;
> + typedef Function<3, int32_t(bool)> BasicInt;
> + typedef Function<4, void(int8_t, uint8_t, int16_t, uint16_t,
> + int32_t, uint32_t, int64_t, uint64_t,
> + bool, std::string, std::vector<int>)> AllTheTypes;
> };
>
>
> -TEST_F(DummyRPC, TestBasic) {
> +TEST_F(DummyRPC, TestAsyncBasicVoid) {
> std::queue<char> Queue;
> QueueChannel C(Queue);
>
> - {
> - // Make a call to Proc1.
> - auto EC = call<Proc1>(C, true);
> - EXPECT_FALSE(EC) << "Simple call over queue failed";
> - }
> + // Make an async call.
> + auto ResOrErr = callAsync<BasicVoid>(C, true);
> + EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed";
>
> {
> // Expect a call to Proc1.
> - auto EC = expect<Proc1>(C,
> + auto EC = expect<BasicVoid>(C,
> [&](bool &B) {
> EXPECT_EQ(B, true)
> << "Bool serialization broken";
> @@ -71,31 +70,71 @@ TEST_F(DummyRPC, TestBasic) {
> });
> EXPECT_FALSE(EC) << "Simple expect over queue failed";
> }
> +
> + {
> + // Wait for the result.
> + auto EC = waitForResult(C, ResOrErr->second, handleNone);
> + EXPECT_FALSE(EC) << "Could not read result.";
> + }
> +
> + // Verify that the function returned ok.
> + auto Val = ResOrErr->first.get();
> + EXPECT_TRUE(Val) << "Remote void function failed to execute.";
> }
>
> -TEST_F(DummyRPC, TestSerialization) {
> +TEST_F(DummyRPC, TestAsyncBasicInt) {
> std::queue<char> Queue;
> QueueChannel C(Queue);
>
> + // Make an async call.
> + auto ResOrErr = callAsync<BasicInt>(C, false);
> + EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed";
> +
> {
> - // Make a call to Proc1.
> - std::vector<int> v({42, 7});
> - auto EC = call<AllTheTypes>(C,
> - -101,
> - 250,
> - -10000,
> - 10000,
> - -1000000000,
> - 1000000000,
> - -10000000000,
> - 10000000000,
> - true,
> - "foo",
> - v);
> - EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed";
> + // Expect a call to Proc1.
> + auto EC = expect<BasicInt>(C,
> + [&](bool &B) {
> + EXPECT_EQ(B, false)
> + << "Bool serialization broken";
> + return 42;
> + });
> + EXPECT_FALSE(EC) << "Simple expect over queue failed";
> }
>
> {
> + // Wait for the result.
> + auto EC = waitForResult(C, ResOrErr->second, handleNone);
> + EXPECT_FALSE(EC) << "Could not read result.";
> + }
> +
> + // Verify that the function returned ok.
> + auto Val = ResOrErr->first.get();
> + EXPECT_TRUE(!!Val) << "Remote int function failed to execute.";
> + EXPECT_EQ(*Val, 42) << "Remote int function return wrong value.";
> +}
> +
> +TEST_F(DummyRPC, TestSerialization) {
> + std::queue<char> Queue;
> + QueueChannel C(Queue);
> +
> + // Make a call to Proc1.
> + std::vector<int> v({42, 7});
> + auto ResOrErr = callAsync<AllTheTypes>(C,
> + -101,
> + 250,
> + -10000,
> + 10000,
> + -1000000000,
> + 1000000000,
> + -10000000000,
> + 10000000000,
> + true,
> + "foo",
> + v);
> + EXPECT_TRUE(!!ResOrErr)
> + << "Big (serialization test) call over queue failed";
> +
> + {
> // Expect a call to Proc1.
> auto EC = expect<AllTheTypes>(C,
> [&](int8_t &s8,
> @@ -136,4 +175,14 @@ TEST_F(DummyRPC, TestSerialization) {
> });
> EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed";
> }
> +
> + {
> + // Wait for the result.
> + auto EC = waitForResult(C, ResOrErr->second, handleNone);
> + EXPECT_FALSE(EC) << "Could not read result.";
> + }
> +
> + // Verify that the function returned ok.
> + auto Val = ResOrErr->first.get();
> + EXPECT_TRUE(Val) << "Remote void function failed to execute.";
> }
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> http://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits
>
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20160418/e861de41/attachment-0001.html>
More information about the llvm-commits mailing list