[llvm] r266711 - [Orc] Tidy up some of the RPC primitives, add a unit-test for the callST

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 18 21:43:10 PDT 2016


Author: lhames
Date: Mon Apr 18 23:43:09 2016
New Revision: 266711

URL: http://llvm.org/viewvc/llvm-project?rev=266711&view=rev
Log:
[Orc] Tidy up some of the RPC primitives, add a unit-test for the callST
(synchronous call) primitive.


Modified:
    llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
    llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h?rev=266711&r1=266710&r2=266711&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h Mon Apr 18 23:43:09 2016
@@ -358,54 +358,83 @@ public:
 
   /// Return type for asynchronous call primitives.
   template <typename Func>
-  using AsyncCallResult =
+  using AsyncCallResult = std::future<typename Func::OptionalReturn>;
+
+  /// Return type for asynchronous call-with-seq primitives.
+  template <typename Func>
+  using AsyncCallWithSeqResult =
       std::pair<std::future<typename Func::OptionalReturn>, SequenceNumberT>;
 
   /// Serialize Args... to channel C, but do not call C.send().
   ///
-  /// For void functions returns a std::future<Error>. For functions that
-  /// return an R, returns a std::future<Optional<R>>.
+  /// Returns an error (on serialization failure) or a pair of:
+  /// (1) A future Optional<T> (or future<bool> for void functions), and
+  /// (2) A sequence number.
+  ///
+  /// This utility function is primarily used for single-threaded mode support,
+  /// where the sequence number can be used to wait for the corresponding
+  /// result. In multi-threaded mode the appendCallAsync method, which does not
+  /// return the sequence number, should be preferred.
   template <typename Func, typename... ArgTs>
-  ErrorOr<AsyncCallResult<Func>> appendCallAsync(ChannelT &C,
-                                                 const ArgTs &... Args) {
+  ErrorOr<AsyncCallWithSeqResult<Func>>
+  appendCallAsyncWithSeq(ChannelT &C, const ArgTs &... Args) {
     auto SeqNo = SequenceNumberMgr.getSequenceNumber();
     std::promise<typename Func::OptionalReturn> Promise;
     auto Result = Promise.get_future();
-    OutstandingResults[SeqNo] = std::move(Promise);
+    OutstandingResults[SeqNo] =
+      createOutstandingResult<Func>(std::move(Promise));
 
     if (auto EC = CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo,
                                                                     Args...)) {
       abandonOutstandingResults();
       return EC;
     } else
-      return AsyncCallResult<Func>(std::move(Result), SeqNo);
+      return AsyncCallWithSeqResult<Func>(std::move(Result), SeqNo);
   }
 
-  /// Serialize Args... to channel C and call C.send().
+  /// The same as appendCallAsyncWithSeq, except that it calls C.send() to
+  /// flush the channel after serializing the call.
   template <typename Func, typename... ArgTs>
-  ErrorOr<AsyncCallResult<Func>> callAsync(ChannelT &C, const ArgTs &... Args) {
-    auto SeqNo = SequenceNumberMgr.getSequenceNumber();
-    std::promise<typename Func::OptionalReturn> Promise;
-    auto Result = Promise.get_future();
-    OutstandingResults[SeqNo] =
-        createOutstandingResult<Func>(std::move(Promise));
-    if (auto EC = CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo,
-                                                                    Args...)) {
-      abandonOutstandingResults();
-      return EC;
-    }
+  ErrorOr<AsyncCallWithSeqResult<Func>>
+  callAsyncWithSeq(ChannelT &C, const ArgTs &... Args) {
+    auto Result = appendCallAsyncWithSeq<Func>(C, Args...);
+    if (!Result)
+      return Result;
     if (auto EC = C.send()) {
       abandonOutstandingResults();
       return EC;
     }
-    return AsyncCallResult<Func>(std::move(Result), SeqNo);
+    return Result;
+  }
+
+  /// Serialize Args... to channel C, but do not call send.
+  /// Returns an error if serialization fails, otherwise returns a
+  /// std::future<Optional<T>> (or a future<bool> for void functions).
+  template <typename Func, typename... ArgTs>
+  ErrorOr<AsyncCallResult<Func>>
+  appendCallAsync(ChannelT &C, const ArgTs &... Args) {
+    auto ResAndSeqOrErr = appendCallAsyncWithSeq<Func>(C, Args...);
+    if (ResAndSeqOrErr)
+      return std::move(ResAndSeqOrErr->first);
+    return ResAndSeqOrErr.getError();
+  }
+
+  /// The same as appendCallAsync, except that it calls C.send to flush the
+  /// channel after serializing the call.
+  template <typename Func, typename... ArgTs>
+  ErrorOr<AsyncCallResult<Func>>
+  callAsync(ChannelT &C, const ArgTs &... Args) {
+    auto ResAndSeqOrErr = callAsyncWithSeq<Func>(C, Args...);
+    if (ResAndSeqOrErr)
+      return std::move(ResAndSeqOrErr->first);
+    return ResAndSeqOrErr.getError();
   }
 
   /// This can be used in single-threaded mode.
   template <typename Func, typename HandleFtor, typename... ArgTs>
   typename Func::ErrorReturn
   callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) {
-    if (auto ResultAndSeqNoOrErr = callAsync<Func>(C, Args...)) {
+    if (auto ResultAndSeqNoOrErr = callAsyncWithSeq<Func>(C, Args...)) {
       auto &ResultAndSeqNo = *ResultAndSeqNoOrErr;
       if (auto EC = waitForResult(C, ResultAndSeqNo.second, HandleOther))
         return EC;
@@ -491,12 +520,17 @@ public:
 
   /// Read a response from Channel.
   /// This should be called from the receive loop to retrieve results.
-  std::error_code handleResponse(ChannelT &C, SequenceNumberT &SeqNo) {
+  std::error_code handleResponse(ChannelT &C,
+                                 SequenceNumberT *SeqNoRet = nullptr) {
+    SequenceNumberT SeqNo;
     if (auto EC = deserialize(C, SeqNo)) {
       abandonOutstandingResults();
       return EC;
     }
 
+    if (SeqNoRet)
+      *SeqNoRet = SeqNo;
+
     auto I = OutstandingResults.find(SeqNo);
     if (I == OutstandingResults.end()) {
       abandonOutstandingResults();
@@ -528,7 +562,7 @@ public:
         return EC;
       if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) {
         SequenceNumberT SeqNo;
-        if (auto EC = handleResponse(C, SeqNo))
+        if (auto EC = handleResponse(C, &SeqNo))
           return EC;
         GotTgtResult = (SeqNo == TgtSeqNo);
       } else if (auto EC = HandleOther(C, Id))

Modified: llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp?rev=266711&r1=266710&r2=266711&view=diff
==============================================================================
--- llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp (original)
+++ llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp Mon Apr 18 23:43:09 2016
@@ -17,52 +17,81 @@ using namespace llvm;
 using namespace llvm::orc;
 using namespace llvm::orc::remote;
 
+class Queue : public std::queue<char> {
+public:
+  std::mutex& getLock() { return Lock; }
+private:
+  std::mutex Lock;
+};
+
 class QueueChannel : public RPCChannel {
 public:
-  QueueChannel(std::queue<char> &Queue) : Queue(Queue) {}
+  QueueChannel(Queue &InQueue, Queue &OutQueue)
+    : InQueue(InQueue), OutQueue(OutQueue) {}
 
   std::error_code readBytes(char *Dst, unsigned Size) override {
-    while (Size--) {
-      *Dst++ = Queue.front();
-      Queue.pop();
+    while (Size != 0) {
+      // If there's nothing to read then yield.
+      while (InQueue.empty())
+        std::this_thread::yield();
+
+      // Lock the channel and read what we can.
+      std::lock_guard<std::mutex> Lock(InQueue.getLock());
+      while (!InQueue.empty() && Size) {
+        *Dst++ = InQueue.front();
+        --Size;
+        InQueue.pop();
+      }
     }
     return std::error_code();
   }
 
   std::error_code appendBytes(const char *Src, unsigned Size) override {
+    std::lock_guard<std::mutex> Lock(OutQueue.getLock());
     while (Size--)
-      Queue.push(*Src++);
+      OutQueue.push(*Src++);
     return std::error_code();
   }
 
   std::error_code send() override { return std::error_code(); }
 
 private:
-  std::queue<char> &Queue;
+  Queue &InQueue;
+  Queue &OutQueue;
 };
 
 class DummyRPC : public testing::Test,
                  public RPC<QueueChannel> {
 public:
-  typedef Function<2, void(bool)> BasicVoid;
-  typedef Function<3, int32_t(bool)> BasicInt;
-  typedef Function<4, void(int8_t, uint8_t, int16_t, uint16_t,
-                           int32_t, uint32_t, int64_t, uint64_t,
-                           bool, std::string, std::vector<int>)> AllTheTypes;
+
+  enum FuncId : uint32_t {
+    VoidBoolId = RPCFunctionIdTraits<FuncId>::FirstValidId,
+    IntIntId,
+    AllTheTypesId
+  };
+
+  typedef Function<VoidBoolId, void(bool)> VoidBool;
+  typedef Function<IntIntId, int32_t(int32_t)> IntInt;
+  typedef Function<AllTheTypesId, void(int8_t, uint8_t, int16_t, uint16_t,
+                                       int32_t, uint32_t, int64_t, uint64_t,
+                                       bool, std::string, std::vector<int>)>
+    AllTheTypes;
+
 };
 
 
-TEST_F(DummyRPC, TestAsyncBasicVoid) {
-  std::queue<char> Queue;
-  QueueChannel C(Queue);
+TEST_F(DummyRPC, TestAsyncVoidBool) {
+  Queue Q1, Q2;
+  QueueChannel C1(Q1, Q2);
+  QueueChannel C2(Q2, Q1);
 
   // Make an async call.
-  auto ResOrErr = callAsync<BasicVoid>(C, true);
+  auto ResOrErr = callAsyncWithSeq<VoidBool>(C1, true);
   EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed";
 
   {
     // Expect a call to Proc1.
-    auto EC = expect<BasicVoid>(C,
+    auto EC = expect<VoidBool>(C2,
                 [&](bool &B) {
                   EXPECT_EQ(B, true)
                     << "Bool serialization broken";
@@ -73,7 +102,7 @@ TEST_F(DummyRPC, TestAsyncBasicVoid) {
 
   {
     // Wait for the result.
-    auto EC = waitForResult(C, ResOrErr->second, handleNone);
+    auto EC = waitForResult(C1, ResOrErr->second, handleNone);
     EXPECT_FALSE(EC) << "Could not read result.";
   }
 
@@ -82,28 +111,29 @@ TEST_F(DummyRPC, TestAsyncBasicVoid) {
   EXPECT_TRUE(Val) << "Remote void function failed to execute.";
 }
 
-TEST_F(DummyRPC, TestAsyncBasicInt) {
-  std::queue<char> Queue;
-  QueueChannel C(Queue);
+TEST_F(DummyRPC, TestAsyncIntInt) {
+  Queue Q1, Q2;
+  QueueChannel C1(Q1, Q2);
+  QueueChannel C2(Q2, Q1);
 
   // Make an async call.
-  auto ResOrErr = callAsync<BasicInt>(C, false);
+  auto ResOrErr = callAsyncWithSeq<IntInt>(C1, 21);
   EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed";
 
   {
     // Expect a call to Proc1.
-    auto EC = expect<BasicInt>(C,
-                [&](bool &B) {
-                  EXPECT_EQ(B, false)
+    auto EC = expect<IntInt>(C2,
+                [&](int32_t I) {
+                  EXPECT_EQ(I, 21)
                     << "Bool serialization broken";
-                  return 42;
+                  return 2 * I;
                 });
     EXPECT_FALSE(EC) << "Simple expect over queue failed";
   }
 
   {
     // Wait for the result.
-    auto EC = waitForResult(C, ResOrErr->second, handleNone);
+    auto EC = waitForResult(C1, ResOrErr->second, handleNone);
     EXPECT_FALSE(EC) << "Could not read result.";
   }
 
@@ -114,29 +144,30 @@ TEST_F(DummyRPC, TestAsyncBasicInt) {
 }
 
 TEST_F(DummyRPC, TestSerialization) {
-  std::queue<char> Queue;
-  QueueChannel C(Queue);
+  Queue Q1, Q2;
+  QueueChannel C1(Q1, Q2);
+  QueueChannel C2(Q2, Q1);
 
   // Make a call to Proc1.
   std::vector<int> v({42, 7});
-  auto ResOrErr = callAsync<AllTheTypes>(C,
-                                         -101,
-                                         250,
-                                         -10000,
-                                         10000,
-                                         -1000000000,
-                                         1000000000,
-                                         -10000000000,
-                                         10000000000,
-                                         true,
-                                         "foo",
-                                         v);
+  auto ResOrErr = callAsyncWithSeq<AllTheTypes>(C1,
+                                                -101,
+                                                250,
+                                                -10000,
+                                                10000,
+                                                -1000000000,
+                                                1000000000,
+                                                -10000000000,
+                                                10000000000,
+                                                true,
+                                                "foo",
+                                                v);
   EXPECT_TRUE(!!ResOrErr)
     << "Big (serialization test) call over queue failed";
 
   {
     // Expect a call to Proc1.
-    auto EC = expect<AllTheTypes>(C,
+    auto EC = expect<AllTheTypes>(C2,
                 [&](int8_t &s8,
                     uint8_t &u8,
                     int16_t &s16,
@@ -178,7 +209,7 @@ TEST_F(DummyRPC, TestSerialization) {
 
   {
     // Wait for the result.
-    auto EC = waitForResult(C, ResOrErr->second, handleNone);
+    auto EC = waitForResult(C1, ResOrErr->second, handleNone);
     EXPECT_FALSE(EC) << "Could not read result.";
   }
 
@@ -186,3 +217,25 @@ TEST_F(DummyRPC, TestSerialization) {
   auto Val = ResOrErr->first.get();
   EXPECT_TRUE(Val) << "Remote void function failed to execute.";
 }
+
+// Test the synchronous call API.
+TEST_F(DummyRPC, TestSynchronousCall) {
+  Queue Q1, Q2;
+  QueueChannel C1(Q1, Q2);
+  QueueChannel C2(Q2, Q1);
+
+  auto ServerResult =
+    std::async(std::launch::async,
+      [&]() {
+        return expect<IntInt>(C2, [&](int32_t V) { return V; });
+      });
+
+  auto ValOrErr = callST<IntInt>(C1, 42);
+
+  EXPECT_FALSE(!!ServerResult.get())
+    << "Server returned an error.";
+  EXPECT_TRUE(!!ValOrErr)
+    << "callST returned an error.";
+  EXPECT_EQ(*ValOrErr, 42)
+    << "Incorrect callST<IntInt> result";
+}




More information about the llvm-commits mailing list