[llvm] r291380 - [Orc][RPC] Add an APICalls utility for grouping RPC functions for registration.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 7 17:13:47 PST 2017


Author: lhames
Date: Sat Jan  7 19:13:47 2017
New Revision: 291380

URL: http://llvm.org/viewvc/llvm-project?rev=291380&view=rev
Log:
[Orc][RPC] Add an APICalls utility for grouping RPC functions for registration.

APICalls allows groups of functions to be composed into an API that can be
registered as a unit with an RPC endpoint. Doing registration on a whole-API
basis (rather than per-function) allows missing API functions to be detected
early.

APICalls also allows Function membership to be tested at compile-time. This
allows clients to write static assertions that functions to be called are
members of registered APIs.


Modified:
    llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
    llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h?rev=291380&r1=291379&r2=291380&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h Sat Jan  7 19:13:47 2017
@@ -868,6 +868,10 @@ protected:
   // a function handler. See addHandlerImpl.
   using LaunchPolicy = std::function<Error(std::function<Error()>)>;
 
+  FunctionIdT getInvalidFunctionId() const {
+    return FnIdAllocator.getInvalidId();
+  }
+
   /// Add the given handler to the handler map and make it available for
   /// autonegotiation and execution.
   template <typename Func, typename HandlerT>
@@ -915,7 +919,7 @@ protected:
   FunctionIdT handleNegotiate(const std::string &Name) {
     auto I = LocalFunctionIds.find(Name);
     if (I == LocalFunctionIds.end())
-      return FnIdAllocator.getInvalidId();
+      return getInvalidFunctionId();
     return I->second;
   }
 
@@ -938,7 +942,7 @@ protected:
 
         // If autonegotiation indicates that the remote end doesn't support this
         // function, return an unknown function error.
-        if (RemoteId == FnIdAllocator.getInvalidId())
+        if (RemoteId == getInvalidFunctionId())
           return orcError(OrcErrorCode::UnknownRPCFunction);
 
         // Autonegotiation succeeded and returned a valid id. Update the map and
@@ -1072,29 +1076,31 @@ public:
   }
 
   /// Negotiate a function id for Func with the other end of the channel.
-  template <typename Func> Error negotiateFunction() {
+  template <typename Func> Error negotiateFunction(bool Retry = false) {
     using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
 
+    // Check if we already have a function id...
+    auto I = this->RemoteFunctionIds.find(Func::getPrototype());
+    if (I != this->RemoteFunctionIds.end()) {
+      // If it's valid there's nothing left to do.
+      if (I->second != this->getInvalidFunctionId())
+        return Error::success();
+      // If it's invalid and we can't re-attempt negotiation, throw an error.
+      if (!Retry)
+        return orcError(OrcErrorCode::UnknownRPCFunction);
+    }
+
+    // We don't have a function id for Func yet, call the remote to try to
+    // negotiate one.
     if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
       this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+      if (*RemoteIdOrErr == this->getInvalidFunctionId())
+        return orcError(OrcErrorCode::UnknownRPCFunction);
       return Error::success();
     } else
       return RemoteIdOrErr.takeError();
   }
 
-  /// Convenience method for negotiating multiple functions at once.
-  template <typename Func> Error negotiateFunctions() {
-    return negotiateFunction<Func>();
-  }
-
-  /// Convenience method for negotiating multiple functions at once.
-  template <typename Func1, typename Func2, typename... Funcs>
-  Error negotiateFunctions() {
-    if (auto Err = negotiateFunction<Func1>())
-      return Err;
-    return negotiateFunctions<Func2, Funcs...>();
-  }
-
   /// Return type for non-blocking call primitives.
   template <typename Func>
   using NonBlockingCallResult = typename detail::ResultTraits<
@@ -1208,29 +1214,31 @@ public:
   }
 
   /// Negotiate a function id for Func with the other end of the channel.
-  template <typename Func> Error negotiateFunction() {
+  template <typename Func> Error negotiateFunction(bool Retry = false) {
     using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
 
+    // Check if we already have a function id...
+    auto I = this->RemoteFunctionIds.find(Func::getPrototype());
+    if (I != this->RemoteFunctionIds.end()) {
+      // If it's valid there's nothing left to do.
+      if (I->second != this->getInvalidFunctionId())
+        return Error::success();
+      // If it's invalid and we can't re-attempt negotiation, throw an error.
+      if (!Retry)
+        return orcError(OrcErrorCode::UnknownRPCFunction);
+    }
+
+    // We don't have a function id for Func yet, call the remote to try to
+    // negotiate one.
     if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
       this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+      if (*RemoteIdOrErr == this->getInvalidFunctionId())
+        return orcError(OrcErrorCode::UnknownRPCFunction);
       return Error::success();
     } else
       return RemoteIdOrErr.takeError();
   }
 
-  /// Convenience method for negotiating multiple functions at once.
-  template <typename Func> Error negotiateFunctions() {
-    return negotiateFunction<Func>();
-  }
-
-  /// Convenience method for negotiating multiple functions at once.
-  template <typename Func1, typename Func2, typename... Funcs>
-  Error negotiateFunctions() {
-    if (auto Err = negotiateFunction<Func1>())
-      return Err;
-    return negotiateFunctions<Func2, Funcs...>();
-  }
-
   template <typename Func, typename... ArgTs,
             typename AltRetT = typename Func::ReturnType>
   typename detail::ResultTraits<AltRetT>::ErrorReturnType
@@ -1343,6 +1351,68 @@ private:
   uint32_t NumOutstandingCalls;
 };
 
+/// @brief Convenience class for grouping RPC Functions into APIs that can be
+///        negotiated as a block.
+///
+template <typename... Funcs>
+class APICalls {
+public:
+
+  /// @brief Test whether this API contains Function F.
+  template <typename F>
+  class Contains {
+  public:
+    static const bool value = false;
+  };
+
+  /// @brief Negotiate all functions in this API.
+  template <typename RPCEndpoint>
+  static Error negotiate(RPCEndpoint &R) {
+    return Error::success();
+  }
+};
+
+template <typename Func, typename... Funcs>
+class APICalls<Func, Funcs...> {
+public:
+
+  template <typename F>
+  class Contains {
+  public:
+    static const bool value = std::is_same<F, Func>::value |
+                              APICalls<Funcs...>::template Contains<F>::value;
+  };
+
+  template <typename RPCEndpoint>
+  static Error negotiate(RPCEndpoint &R) {
+    if (auto Err = R.template negotiateFunction<Func>())
+      return Err;
+    return APICalls<Funcs...>::negotiate(R);
+  }
+
+};
+
+template <typename... InnerFuncs, typename... Funcs>
+class APICalls<APICalls<InnerFuncs...>, Funcs...> {
+public:
+
+  template <typename F>
+  class Contains {
+  public:
+    static const bool value =
+      APICalls<InnerFuncs...>::template Contains<F>::value |
+      APICalls<Funcs...>::template Contains<F>::value;
+  };
+
+  template <typename RPCEndpoint>
+  static Error negotiate(RPCEndpoint &R) {
+    if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
+      return Err;
+    return APICalls<Funcs...>::negotiate(R);
+  }
+
+};
+
 } // end namespace rpc
 } // end namespace orc
 } // end namespace llvm

Modified: llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp?rev=291380&r1=291379&r2=291380&view=diff
==============================================================================
--- llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp (original)
+++ llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp Sat Jan  7 19:13:47 2017
@@ -108,8 +108,7 @@ namespace rpc {
 } // end namespace orc
 } // end namespace llvm
 
-class DummyRPCAPI {
-public:
+namespace DummyRPCAPI {
 
   class VoidBool : public Function<VoidBool, void(bool)> {
   public:
@@ -455,4 +454,53 @@ TEST(DummyRPC, TestParallelCallGroup) {
   }
 
   ServerThread.join();
+}
+
+TEST(DummyRPC, TestAPICalls) {
+
+  using DummyCalls1 = APICalls<DummyRPCAPI::VoidBool, DummyRPCAPI::IntInt>;
+  using DummyCalls2 = APICalls<DummyRPCAPI::AllTheTypes>;
+  using DummyCalls3 = APICalls<DummyCalls1, DummyRPCAPI::CustomType>;
+  using DummyCallsAll = APICalls<DummyCalls1, DummyCalls2, DummyRPCAPI::CustomType>;
+
+  static_assert(DummyCalls1::Contains<DummyRPCAPI::VoidBool>::value,
+                "Contains<Func> template should return true here");
+  static_assert(!DummyCalls1::Contains<DummyRPCAPI::CustomType>::value,
+                "Contains<Func> template should return false here");
+
+  Queue Q1, Q2;
+  DummyRPCEndpoint Client(Q1, Q2);
+  DummyRPCEndpoint Server(Q2, Q1);
+
+  std::thread ServerThread(
+    [&]() {
+      Server.addHandler<DummyRPCAPI::VoidBool>([](bool b) { });
+      Server.addHandler<DummyRPCAPI::IntInt>([](int x) { return x; });
+      Server.addHandler<DummyRPCAPI::CustomType>([](RPCFoo F) {});
+
+      for (unsigned I = 0; I < 4; ++I) {
+        auto Err = Server.handleOne();
+        (void)!!Err;
+      }
+    });
+
+  {
+    auto Err = DummyCalls1::negotiate(Client);
+    EXPECT_FALSE(!!Err) << "DummyCalls1::negotiate failed";
+  }
+
+  {
+    auto Err = DummyCalls3::negotiate(Client);
+    EXPECT_FALSE(!!Err) << "DummyCalls3::negotiate failed";
+  }
+
+  {
+    auto Err = DummyCallsAll::negotiate(Client);
+    EXPECT_EQ(errorToErrorCode(std::move(Err)).value(),
+              static_cast<int>(OrcErrorCode::UnknownRPCFunction))
+      << "Uxpected 'UnknownRPCFunction' error for attempted negotiate of "
+         "unsupported function";
+  }
+
+  ServerThread.join();
 }




More information about the llvm-commits mailing list