[llvm] r300167 - [ORC] Add RPC and serialization support for Errors and Expecteds.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 12 20:51:35 PDT 2017


Author: lhames
Date: Wed Apr 12 22:51:35 2017
New Revision: 300167

URL: http://llvm.org/viewvc/llvm-project?rev=300167&view=rev
Log:
[ORC] Add RPC and serialization support for Errors and Expecteds.

This patch allows Error and Expected types to be passed to and returned from
RPC functions.

Serializers and deserializers for custom error types (types deriving from the
ErrorInfo class template) can be registered with the SerializationTraits for
a given channel type (see registerStringError in RPCSerialization.h for an
example), allowing a given custom type to be sent/received. Unregistered types
will be serialized/deserialized as StringErrors using the custom type's log
message as the error string (this requires StringError serialization itself to
be registered, e.g. via registerStringError).


Modified:
    llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h
    llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
    llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
    llvm/trunk/include/llvm/Support/Error.h
    llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp
    llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h?rev=300167&r1=300166&r2=300167&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/OrcError.h Wed Apr 12 22:51:35 2017
@@ -32,6 +32,7 @@ enum class OrcErrorCode : int {
   RPCResponseAbandoned,
   UnexpectedRPCCall,
   UnexpectedRPCResponse,
+  UnknownErrorCodeFromRemote
 };
 
 std::error_code orcError(OrcErrorCode ErrCode);

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h?rev=300167&r1=300166&r2=300167&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h Wed Apr 12 22:51:35 2017
@@ -12,6 +12,7 @@
 
 #include "OrcError.h"
 #include "llvm/Support/thread.h"
+#include <map>
 #include <mutex>
 #include <sstream>
 
@@ -114,6 +115,35 @@ public:
   static const char* getName() { return "std::string"; }
 };
 
+template <>
+class RPCTypeName<Error> {
+public:
+  static const char* getName() { return "Error"; }
+};
+
+/// RPC type name for Expected<T>: "Expected<", then the name(s) produced by
+/// RPCTypeNameSequence<T>, then ">". The string is built on the first call and
+/// cached in Name; NameMutex serializes that lazy construction.
+template <typename T>
+class RPCTypeName<Expected<T>> {
+public:
+  static const char* getName() {
+    std::lock_guard<std::mutex> Guard(NameMutex);
+    if (!Name.empty())
+      return Name.data();
+    {
+      // Scoped so the stream is flushed into Name before it is returned.
+      raw_string_ostream NameStream(Name);
+      NameStream << "Expected<" << RPCTypeNameSequence<T>() << ">";
+    }
+    return Name.data();
+  }
+
+private:
+  static std::mutex NameMutex;
+  static std::string Name;
+};
+
+template <typename T>
+std::mutex RPCTypeName<Expected<T>>::NameMutex;
+
+template <typename T>
+std::string RPCTypeName<Expected<T>>::Name;
+
 template <typename T1, typename T2>
 class RPCTypeName<std::pair<T1, T2>> {
 public:
@@ -243,8 +273,10 @@ class SequenceSerialization<ChannelT, Ar
 public:
 
   template <typename CArgT>
-  static Error serialize(ChannelT &C, const CArgT &CArg) {
-    return SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg);
+  static Error serialize(ChannelT &C, CArgT &&CArg) {
+    // Forward CArg so rvalue arguments (e.g. an Error, whose traits take
+    // Error&&) reach SerializationTraits with their value category intact.
+    return SerializationTraits<ChannelT, ArgT,
+                               typename std::decay<CArgT>::type>::
+             serialize(C, std::forward<CArgT>(CArg));
   }
 
   template <typename CArgT>
@@ -258,19 +290,21 @@ class SequenceSerialization<ChannelT, Ar
 public:
 
   template <typename CArgT, typename... CArgTs>
-  static Error serialize(ChannelT &C, const CArgT &CArg,
-                         const CArgTs&... CArgs) {
+  static Error serialize(ChannelT &C, CArgT &&CArg,
+                         CArgTs &&... CArgs) {
+    // Serialize the head argument (forwarded, so rvalues such as Error stay
+    // rvalues), then a separator, then recurse on the remaining arguments.
     if (auto Err =
-        SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg))
+        SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>::
+          serialize(C, std::forward<CArgT>(CArg)))
       return Err;
     if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C))
       return Err;
-    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
+    return SequenceSerialization<ChannelT, ArgTs...>::
+             serialize(C, std::forward<CArgTs>(CArgs)...);
   }
 
   template <typename CArgT, typename... CArgTs>
   static Error deserialize(ChannelT &C, CArgT &CArg,
-                           CArgTs&... CArgs) {
+                           CArgTs &... CArgs) {
     if (auto Err =
         SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg))
       return Err;
@@ -281,8 +315,9 @@ public:
 };
 
 template <typename ChannelT, typename... ArgTs>
-Error serializeSeq(ChannelT &C, const ArgTs &... Args) {
-  return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, Args...);
+Error serializeSeq(ChannelT &C, ArgTs &&... Args) {
+  // Args are perfect-forwarded so rvalue-only arguments (such as Error) can be
+  // serialized; the wire types are the decayed argument types.
+  return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>::
+           serialize(C, std::forward<ArgTs>(Args)...);
 }
 
 template <typename ChannelT, typename... ArgTs>
@@ -290,6 +325,186 @@ Error deserializeSeq(ChannelT &C, ArgTs
   return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...);
 }
 
+template <typename ChannelT>
+class SerializationTraits<ChannelT, Error> {
+public:
+
+  using WrappedErrorSerializer =
+    std::function<Error(ChannelT &C, const ErrorInfoBase&)>;
+
+  using WrappedErrorDeserializer =
+    std::function<Error(ChannelT &C, Error &Err)>;
+
+  template <typename ErrorInfoT, typename SerializeFtor,
+            typename DeserializeFtor>
+  static void registerErrorType(std::string Name, SerializeFtor Serialize,
+                                DeserializeFtor Deserialize) {
+    assert(!Name.empty() &&
+           "The empty string is reserved for the Success value");
+
+    std::lock_guard<std::mutex> Lock(SerializersMutex);
+
+    // We're abusing the stability of std::map here: We take a reference to the
+    // key of the deserializers map to save us from duplicating the string in
+    // the serializer. This should be changed to use a stringpool if we switch
+    // to a map type that may move keys in memory.
+    auto I =
+      Deserializers.insert(Deserializers.begin(),
+                           std::make_pair(std::move(Name),
+                                          std::move(Deserialize)));
+
+    const std::string &KeyName = I->first;
+    // FIXME: Move capture Serialize once we have C++14.
+    Serializers[ErrorInfoT::classID()] =
+      [&KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error {
+        assert(EIB.dynamicClassID() == ErrorInfoT::classID() &&
+               "Serializer called for wrong error type");
+        if (auto Err = serializeSeq(C, KeyName))
+          return Err;
+        return Serialize(C, static_cast<const ErrorInfoT&>(EIB));
+      };
+  }
+
+  static Error serialize(ChannelT &C, Error &&Err) {
+    std::lock_guard<std::mutex> Lock(SerializersMutex);
+    if (!Err)
+      return serializeSeq(C, std::string());
+
+    return handleErrors(std::move(Err),
+                        [&C](const ErrorInfoBase &EIB) {
+                          auto SI = Serializers.find(EIB.dynamicClassID());
+                          if (SI == Serializers.end())
+                            return serializeAsStringError(C, EIB);
+                          return (SI->second)(C, EIB);
+                        });
+  }
+
+  static Error deserialize(ChannelT &C, Error &Err) {
+    std::lock_guard<std::mutex> Lock(SerializersMutex);
+
+    std::string Key;
+    if (auto Err = deserializeSeq(C, Key))
+      return Err;
+
+    if (Key.empty()) {
+      ErrorAsOutParameter EAO(&Err);
+      Err = Error::success();
+      return Error::success();
+    }
+
+    auto DI = Deserializers.find(Key);
+    assert(DI != Deserializers.end() && "No deserializer for error type");
+    return (DI->second)(C, Err);
+  }
+
+private:
+
+  static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) {
+    assert(EIB.dynamicClassID() != StringError::classID() &&
+           "StringError serialization not registered");
+    std::string ErrMsg;
+    {
+      raw_string_ostream ErrMsgStream(ErrMsg);
+      EIB.log(ErrMsgStream);
+    }
+    return serialize(C, make_error<StringError>(std::move(ErrMsg),
+                                                inconvertibleErrorCode()));
+  }
+
+  static std::mutex SerializersMutex;
+  static std::map<const void*, WrappedErrorSerializer> Serializers;
+  static std::map<std::string, WrappedErrorDeserializer> Deserializers;
+};
+
+template <typename ChannelT>
+std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex;
+
+template <typename ChannelT>
+std::map<const void*,
+         typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer>
+SerializationTraits<ChannelT, Error>::Serializers;
+
+template <typename ChannelT>
+std::map<std::string,
+         typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer>
+SerializationTraits<ChannelT, Error>::Deserializers;
+
+/// Register StringError serialization for ChannelT. Safe to call repeatedly:
+/// registration happens exactly once per ChannelT.
+///
+/// Only the message is transmitted. On the receiving side the StringError is
+/// rebuilt with OrcErrorCode::UnknownErrorCodeFromRemote as its error code.
+template <typename ChannelT>
+void registerStringError() {
+  // Initialization of a function-local static runs exactly once even under
+  // concurrent calls; a check-then-set on a plain static bool would race.
+  static bool AlreadyRegistered = [] {
+    SerializationTraits<ChannelT, Error>::
+      template registerErrorType<StringError>(
+        "StringError",
+        [](ChannelT &C, const StringError &SE) {
+          return serializeSeq(C, SE.getMessage());
+        },
+        [](ChannelT &C, Error &Err) {
+          ErrorAsOutParameter EAO(&Err);
+          std::string Msg;
+          if (auto E2 = deserializeSeq(C, Msg))
+            return E2;
+          Err =
+            make_error<StringError>(std::move(Msg),
+                                    orcError(
+                                      OrcErrorCode::UnknownErrorCodeFromRemote));
+          return Error::success();
+        });
+    return true;
+  }();
+  (void)AlreadyRegistered;
+}
+
+/// SerializationTraits for Expected<T1> from an Expected<T2>.
+///
+/// Wire format: a bool (true = value present), followed by either the value
+/// (encoded as T1 from a T2) or the serialized Error.
+template <typename ChannelT, typename T1, typename T2>
+class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> {
+public:
+
+  static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) {
+    if (ValOrErr) {
+      if (auto Err = serializeSeq(C, true))
+        return Err;
+      return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr);
+    }
+    // Take the error up front so it is never left unchecked in ValOrErr if the
+    // channel fails before the error can be sent.
+    Error ValErr = ValOrErr.takeError();
+    if (auto Err = serializeSeq(C, false))
+      return joinErrors(std::move(Err), std::move(ValErr));
+    return serializeSeq(C, std::move(ValErr));
+  }
+
+  static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) {
+    // Checks a value-holding ValOrErr so it can be assigned to here, and
+    // marks it unchecked again on exit so the caller must inspect it.
+    ExpectedAsOutParameter<T2> EAO(&ValOrErr);
+    bool HasValue;
+    if (auto Err = deserializeSeq(C, HasValue))
+      return Err;
+    // NOTE(review): assumes the caller passes ValOrErr in the value state
+    // (e.g. holding a default-constructed T2) so it can be filled in place.
+    if (HasValue)
+      return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr);
+    Error Err = Error::success();
+    if (auto E2 = deserializeSeq(C, Err)) {
+      // Err is being discarded; consume it so it is not destroyed unchecked.
+      consumeError(std::move(Err));
+      return E2;
+    }
+    ValOrErr = std::move(Err);
+    return Error::success();
+  }
+};
+
+/// SerializationTraits for Expected<T1> from a T2.
+///
+/// The value is wrapped in an Expected<T2> and sent through the
+/// Expected<T1>-from-Expected<T2> traits. This encodes the value with T1's
+/// wire format (SerializationTraits<ChannelT, T1, T2>), which is what the
+/// receiver reads, rather than with T2's own format.
+template <typename ChannelT, typename T1, typename T2>
+class SerializationTraits<ChannelT, Expected<T1>, T2> {
+public:
+
+  static Error serialize(ChannelT &C, T2 &&Val) {
+    return SerializationTraits<ChannelT, Expected<T1>, Expected<T2>>::serialize(
+        C, Expected<T2>(std::forward<T2>(Val)));
+  }
+};
+
+/// SerializationTraits for Expected<T1> from an Error.
+template <typename ChannelT, typename T>
+class SerializationTraits<ChannelT, Expected<T>, Error> {
+public:
+
+  static Error serialize(ChannelT &C, Error &&Err) {
+    return serializeSeq(C, Expected<T>(std::move(Err)));
+  }
+};
+
 /// SerializationTraits default specialization for std::pair.
 template <typename ChannelT, typename T1, typename T2>
 class SerializationTraits<ChannelT, std::pair<T1, T2>> {

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h?rev=300167&r1=300166&r2=300167&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCUtils.h Wed Apr 12 22:51:35 2017
@@ -129,7 +129,7 @@ public:
 
   CouldNotNegotiate(std::string Signature);
   std::error_code convertToErrorCode() const override;
-  void log(raw_ostream &OS) const override;  
+  void log(raw_ostream &OS) const override;
   const std::string &getSignature() const { return Signature; }
 private:
   std::string Signature;
@@ -362,30 +362,122 @@ template <> class ResultTraits<Error> :
 template <typename RetT>
 class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
 
+// Trait: SupportsErrorReturn<T>::value is true when an RPC function's declared
+// return type T can carry an error back to the caller (Error or Expected<U>),
+// and false for every other type.
+template <typename T>
+class SupportsErrorReturn : public std::false_type {};
+
+template <>
+class SupportsErrorReturn<Error> : public std::true_type {};
+
+template <typename T>
+class SupportsErrorReturn<Expected<T>> : public std::true_type {};
+
+// RespondHelper packages return values based on whether or not the declared
+// RPC function return type supports error returns.
+template <bool FuncSupportsErrorReturn>
+class RespondHelper;
+
+// RespondHelper specialization for functions that support error returns.
+//
+// Non-fatal handler errors are serialized into the response so the caller
+// receives them; RPCFatalErrors are returned locally instead. If the channel
+// fails before a handler error could be sent, that error is joined into the
+// returned channel error so it is never destroyed unchecked.
+template <>
+class RespondHelper<true> {
+public:
+
+  // Send Expected<T>.
+  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
+            typename FunctionIdT, typename SequenceNumberT>
+  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+                          SequenceNumberT SeqNo,
+                          Expected<HandlerRetT> ResultOrErr) {
+    if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
+      return ResultOrErr.takeError();
+
+    // Open the response message. On failure, also hand back any handler
+    // error (takeError() yields success when ResultOrErr holds a value).
+    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
+      return joinErrors(std::move(Err), ResultOrErr.takeError());
+
+    // Serialize the result.
+    if (auto Err =
+        SerializationTraits<ChannelT, WireRetT,
+                            Expected<HandlerRetT>>::serialize(
+                                                     C, std::move(ResultOrErr)))
+      return Err;
+
+    // Close the response message.
+    return C.endSendMessage();
+  }
+
+  // Send Error.
+  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
+  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+                          SequenceNumberT SeqNo, Error Err) {
+    if (Err && Err.isA<RPCFatalError>())
+      return Err;
+    // If the message cannot be opened, Err (possibly a non-fatal failure)
+    // must still be handed back rather than dropped unchecked.
+    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
+      return joinErrors(std::move(Err2), std::move(Err));
+    if (auto Err2 = serializeSeq(C, std::move(Err)))
+      return Err2;
+    return C.endSendMessage();
+  }
+
+};
+
+// RespondHelper specialization for functions that do not support error returns.
+//
+// The response has no slot for an error, so a failing handler's error is
+// returned locally and no response message is sent.
+template <>
+class RespondHelper<false> {
+public:
+
+  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
+            typename FunctionIdT, typename SequenceNumberT>
+  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+                          SequenceNumberT SeqNo,
+                          Expected<HandlerRetT> ResultOrErr) {
+    if (auto Err = ResultOrErr.takeError())
+      return Err;
+
+    // Open the response message.
+    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
+      return Err;
+
+    // Serialize the result.
+    if (auto Err =
+        SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
+                                                               C, *ResultOrErr))
+      return Err;
+
+    // Close the response message.
+    return C.endSendMessage();
+  }
+
+  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
+  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+                          SequenceNumberT SeqNo, Error Err) {
+    if (Err)
+      return Err;
+    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
+      return Err2;
+    return C.endSendMessage();
+  }
+
+};
+
+
 // Send a response of the given wire return type (WireRetT) over the
 // channel, with the given sequence number.
 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
           typename FunctionIdT, typename SequenceNumberT>
-static Error respond(ChannelT &C, const FunctionIdT &ResponseId,
-                     SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
-  // If this was an error bail out.
-  // FIXME: Send an "error" message to the client if this is not a channel
-  //        failure?
-  if (auto Err = ResultOrErr.takeError())
-    return Err;
-
-  // Open the response message.
-  if (auto Err = C.startSendMessage(ResponseId, SeqNo))
-    return Err;
-
-  // Serialize the result.
-  if (auto Err =
-          SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
-              C, *ResultOrErr))
-    return Err;
-
-  // Close the response message.
-  return C.endSendMessage();
+Error respond(ChannelT &C, const FunctionIdT &ResponseId,
+              SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
+  // Delegate to the RespondHelper specialization selected by whether the wire
+  // return type (Error or Expected<T>) can carry an error back to the caller.
+  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
+    template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
 }
 
 // Send an empty response message on the given channel to indicate that
@@ -394,11 +486,8 @@ template <typename WireRetT, typename Ch
           typename SequenceNumberT>
 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
               Error Err) {
-  if (Err)
-    return Err;
-  if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
-    return Err2;
-  return C.endSendMessage();
+  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
+    sendResult(C, ResponseId, SeqNo, std::move(Err));
 }
 
 // Converts a given type to the equivalent error return type.
@@ -658,6 +747,72 @@ public:
   }
 
   // Abandon this response by calling the handler with an 'abandoned response'
+  // error.
+  void abandon() override {
+    // Deliver the abandoned-response error to the user handler; an error
+    // returned by the handler is treated as fatal.
+    if (auto Err = Handler(this->createAbandonedResponseError())) {
+      // Handlers should not fail when passed an abandoned response error.
+      report_fatal_error(std::move(Err));
+    }
+  }
+
+private:
+  HandlerT Handler;
+};
+
+/// ResponseHandlerImpl for functions whose declared return type is
+/// Expected<FuncRetT>. The response carries either a value or a serialized
+/// Error, and both are delivered to the user handler.
+template <typename ChannelT, typename FuncRetT, typename HandlerT>
+class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
+    : public ResponseHandler<ChannelT> {
+public:
+  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
+
+  // Handle the result by deserializing it from the channel then passing it
+  // to the user defined handler.
+  Error handleResponse(ChannelT &C) override {
+    using HandlerArgType = typename ResponseHandlerArg<
+        typename HandlerTraits<HandlerT>::Type>::ArgType;
+    // Start in the value state so the deserializer can fill the value in place.
+    HandlerArgType Result((typename HandlerArgType::value_type()));
+
+    if (auto Err =
+            SerializationTraits<ChannelT, Expected<FuncRetT>,
+                                HandlerArgType>::deserialize(C, Result)) {
+      // Result is discarded on a channel failure; consume it so it is not
+      // destroyed unchecked.
+      consumeError(Result.takeError());
+      return Err;
+    }
+    if (auto Err = C.endReceiveMessage()) {
+      consumeError(Result.takeError());
+      return Err;
+    }
+    return Handler(std::move(Result));
+  }
+
+  // Abandon this response by calling the handler with an 'abandoned response'
+  // error.
+  void abandon() override {
+    if (auto Err = Handler(this->createAbandonedResponseError())) {
+      // Handlers should not fail when passed an abandoned response error.
+      report_fatal_error(std::move(Err));
+    }
+  }
+
+private:
+  HandlerT Handler;
+};
+
+template <typename ChannelT, typename HandlerT>
+class ResponseHandlerImpl<ChannelT, Error, HandlerT>
+    : public ResponseHandler<ChannelT> {
+public:
+  // Takes ownership of the user-supplied handler that will receive the result.
+  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
+
+  // Handle the result by deserializing it from the channel then passing it
+  // to the user defined handler.
+  Error handleResponse(ChannelT &C) override {
+    // Receives the remote function's Error (success or failure).
+    Error Result = Error::success();
+    if (auto Err =
+            SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result)) {
+      // Result is discarded on a channel failure; consume it so it is not
+      // destroyed unchecked.
+      consumeError(std::move(Result));
+      return Err;
+    }
+    if (auto Err = C.endReceiveMessage()) {
+      consumeError(std::move(Result));
+      return Err;
+    }
+    return Handler(std::move(Result));
+  }
+
+  // Abandon this response by calling the handler with an 'abandoned response'
   // error.
   void abandon() override {
     if (auto Err = Handler(this->createAbandonedResponseError())) {

Modified: llvm/trunk/include/llvm/Support/Error.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Support/Error.h?rev=300167&r1=300166&r2=300167&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Support/Error.h (original)
+++ llvm/trunk/include/llvm/Support/Error.h Wed Apr 12 22:51:35 2017
@@ -236,6 +236,14 @@ public:
     return getPtr() && getPtr()->isA(ErrT::classID());
   }
 
+  /// Returns the dynamic class id of this error, or null if this is a success
+  /// value.
+  const void* dynamicClassID() const {
+    if (!getPtr())
+      return nullptr;
+    return getPtr()->dynamicClassID();
+  }
+
 private:
   void assertIsChecked() {
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -635,6 +643,7 @@ private:
 /// takeError(). It also adds an bool errorIsA<ErrT>() method for testing the
 /// error class type.
 template <class T> class LLVM_NODISCARD Expected {
+  template <class T1> friend class ExpectedAsOutParameter;
   template <class OtherT> friend class Expected;
   static const bool isRef = std::is_reference<T>::value;
   typedef ReferenceStorage<typename std::remove_reference<T>::type> wrap;
@@ -743,7 +752,7 @@ public:
 
   /// \brief Check that this Expected<T> is an error of type ErrT.
   template <typename ErrT> bool errorIsA() const {
-    return HasError && getErrorStorage()->template isA<ErrT>();
+    return HasError && (*getErrorStorage())->template isA<ErrT>();
   }
 
   /// \brief Take ownership of the stored error.
@@ -838,6 +847,18 @@ private:
     return reinterpret_cast<error_type *>(ErrorStorage.buffer);
   }
 
+  // Const accessor for the in-place error storage. Only valid while this
+  // Expected holds an error (asserted).
+  const error_type *getErrorStorage() const {
+    assert(HasError && "Cannot get error when a value exists!");
+    return reinterpret_cast<const error_type *>(ErrorStorage.buffer);
+  }
+
+  // Used by ExpectedAsOutParameter to reset the checked flag, so the caller
+  // must inspect the Expected after it has been filled in. This is a no-op
+  // unless LLVM_ENABLE_ABI_BREAKING_CHECKS is enabled.
+  void setUnchecked() {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    Unchecked = true;
+#endif
+  }
+
   void assertIsChecked() {
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
     if (Unchecked) {
@@ -864,6 +885,28 @@ private:
 #endif
 };
 
+/// Helper for Expected<T>s used as out-parameters.
+///
+/// On construction, a value-holding pointee is marked checked so the callee
+/// may assign to it. On destruction, the pointee is marked unchecked again,
+/// forcing the caller to inspect the result. A null pointer is ignored. See
+/// ErrorAsOutParameter.
+template <typename T>
+class ExpectedAsOutParameter {
+public:
+
+  ExpectedAsOutParameter(Expected<T> *ValOrErr)
+    : ValOrErr(ValOrErr) {
+    if (!ValOrErr)
+      return;
+    // Evaluating the boolean state marks a value-holding Expected as checked.
+    (void)static_cast<bool>(*ValOrErr);
+  }
+
+  ~ExpectedAsOutParameter() {
+    if (ValOrErr != nullptr)
+      ValOrErr->setUnchecked();
+  }
+
+private:
+  Expected<T> *ValOrErr;
+};
+
 /// This class wraps a std::error_code in a Error.
 ///
 /// This is useful if you're writing an interface that returns a Error

Modified: llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp?rev=300167&r1=300166&r2=300167&view=diff
==============================================================================
--- llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp (original)
+++ llvm/trunk/lib/ExecutionEngine/Orc/OrcError.cpp Wed Apr 12 22:51:35 2017
@@ -49,6 +49,9 @@ public:
       return "Unexpected RPC call";
     case OrcErrorCode::UnexpectedRPCResponse:
       return "Unexpected RPC response";
+    case OrcErrorCode::UnknownErrorCodeFromRemote:
+      return "Unknown error returned from remote RPC function "
+             "(Use StringError to get error message)";
     }
     llvm_unreachable("Unhandled error code");
   }

Modified: llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp?rev=300167&r1=300166&r2=300167&view=diff
==============================================================================
--- llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp (original)
+++ llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp Wed Apr 12 22:51:35 2017
@@ -47,6 +47,54 @@ namespace rpc {
 
 class RPCBar {};
 
+class DummyError : public ErrorInfo<DummyError> {
+public:
+
+  static char ID;
+
+  DummyError(uint32_t Val) : Val(Val) {}
+
+  std::error_code convertToErrorCode() const override {
+    // Use a nonsense error code - we want to verify that errors
+    // transmitted over the network are replaced with
+    // OrcErrorCode::UnknownErrorCodeFromRemote.
+    return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+  }
+
+  void log(raw_ostream &OS) const override {
+    OS << "Dummy error " << Val;
+  }
+
+  uint32_t getValue() const { return Val; }
+
+public:
+  uint32_t Val;
+};
+
+char DummyError::ID = 0;
+
+/// Register DummyError serialization for ChannelT, once per ChannelT. A
+/// DummyError is sent under the key "DummyError" as its uint32_t value.
+template <typename ChannelT>
+void registerDummyErrorSerialization() {
+  static bool AlreadyRegistered = false;
+  if (!AlreadyRegistered) {
+    SerializationTraits<ChannelT, Error>::
+      template registerErrorType<DummyError>(
+        "DummyError",
+        [](ChannelT &C, const DummyError &DE) {
+          return serializeSeq(C, DE.getValue());
+        },
+        [](ChannelT &C, Error &Err) -> Error {
+          ErrorAsOutParameter EAO(&Err);
+          uint32_t Val;
+          // Named E2 so it does not shadow the Err out-parameter.
+          if (auto E2 = deserializeSeq(C, Val))
+            return E2;
+          Err = make_error<DummyError>(Val);
+          return Error::success();
+        });
+    AlreadyRegistered = true;
+  }
+}
+
 namespace llvm {
 namespace orc {
 namespace rpc {
@@ -98,6 +146,16 @@ namespace DummyRPCAPI {
     static const char* getName() { return "CustomType"; }
   };
 
+  // Test RPC function with signature Error(); exercises Error return values.
+  class ErrorFunc : public Function<ErrorFunc, Error()> {
+  public:
+    static const char* getName() { return "ErrorFunc"; }
+  };
+
+  // Test RPC function with signature Expected<uint32_t>(); exercises Expected
+  // return values.
+  class ExpectedFunc : public Function<ExpectedFunc, Expected<uint32_t>()> {
+  public:
+    static const char* getName() { return "ExpectedFunc"; }
+  };
+
 }
 
 class DummyRPCEndpoint : public SingleThreadedRPCEndpoint<QueueChannel> {
@@ -493,6 +551,140 @@ TEST(DummyRPC, TestWithAltCustomType) {
   ServerThread.join();
 }
 
+TEST(DummyRPC, ReturnErrorSuccess) {
+  // A handler for an Error() function that succeeds should deliver a success
+  // value to the client's callback.
+  registerDummyErrorSerialization<QueueChannel>();
+
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
+
+  std::thread ServerThread([&]() {
+    Server.addHandler<DummyRPCAPI::ErrorFunc>([]() { return Error::success(); });
+
+    // Two messages arrive: the negotiation request and the call itself.
+    cantFail(Server.handleOne());
+    cantFail(Server.handleOne());
+  });
+
+  cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>([&](Error Err) {
+    EXPECT_FALSE(!!Err) << "Expected success value";
+    return Error::success();
+  }));
+
+  cantFail(Client.handleOne());
+
+  ServerThread.join();
+}
+
+TEST(DummyRPC, ReturnErrorFailure) {
+  // A DummyError returned by the server-side handler should arrive at the
+  // client as a DummyError with the same value.
+  registerDummyErrorSerialization<QueueChannel>();
+
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
+
+  std::thread ServerThread([&]() {
+    Server.addHandler<DummyRPCAPI::ErrorFunc>(
+        []() { return make_error<DummyError>(42); });
+
+    // Two messages arrive: the negotiation request and the call itself.
+    cantFail(Server.handleOne());
+    cantFail(Server.handleOne());
+  });
+
+  cantFail(Client.callAsync<DummyRPCAPI::ErrorFunc>([&](Error Err) {
+    EXPECT_TRUE(Err.isA<DummyError>()) << "Incorrect error type";
+    return handleErrors(std::move(Err), [](const DummyError &DE) {
+      EXPECT_EQ(DE.getValue(), 42ULL) << "Incorrect DummyError serialization";
+    });
+  }));
+
+  cantFail(Client.handleOne());
+
+  ServerThread.join();
+}
+
+TEST(DummyRPC, RPCExpectedSuccess) {
+  // A handler returning a plain uint32_t for an Expected<uint32_t> function
+  // should deliver that value to the client's callback.
+  registerDummyErrorSerialization<QueueChannel>();
+
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
+
+  std::thread ServerThread([&]() {
+      Server.addHandler<DummyRPCAPI::ExpectedFunc>(
+        []() -> uint32_t {
+          return 42;
+        });
+
+      // Handle the negotiate plus one call.
+      for (unsigned I = 0; I != 2; ++I)
+        cantFail(Server.handleOne());
+    });
+
+  cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>(
+               [&](Expected<uint32_t> ValOrErr) {
+                 if (!ValOrErr) {
+                   // Record the failure and consume the error instead of
+                   // dereferencing an Expected in the error state below.
+                   ADD_FAILURE() << "Expected success value";
+                   consumeError(ValOrErr.takeError());
+                   return Error::success();
+                 }
+                 EXPECT_EQ(*ValOrErr, 42ULL)
+                   << "Incorrect Expected<uint32_t> deserialization";
+                 return Error::success();
+               }));
+
+  cantFail(Client.handleOne());
+
+  ServerThread.join();
+}
+
+TEST(DummyRPC, RPCExpectedFailure) {
+  // A DummyError returned for an Expected<uint32_t> function should arrive at
+  // the client as an Expected in the error state holding that DummyError.
+  registerDummyErrorSerialization<QueueChannel>();
+
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
+
+  std::thread ServerThread([&]() {
+      Server.addHandler<DummyRPCAPI::ExpectedFunc>(
+        []() -> Expected<uint32_t> {
+          return make_error<DummyError>(7);
+        });
+
+      // Handle the negotiate plus one call.
+      for (unsigned I = 0; I != 2; ++I)
+        cantFail(Server.handleOne());
+    });
+
+  cantFail(Client.callAsync<DummyRPCAPI::ExpectedFunc>(
+               [&](Expected<uint32_t> ValOrErr) {
+                 EXPECT_FALSE(!!ValOrErr)
+                   << "Expected failure value";
+                 auto Err = ValOrErr.takeError();
+                 EXPECT_TRUE(Err.isA<DummyError>())
+                   << "Incorrect error type";
+                 return handleErrors(
+                          std::move(Err),
+                          [](const DummyError &DE) {
+                            EXPECT_EQ(DE.getValue(), 7ULL)
+                              << "Incorrect DummyError serialization";
+                          });
+               }));
+
+  cantFail(Client.handleOne());
+
+  ServerThread.join();
+}
+
 TEST(DummyRPC, TestParallelCallGroup) {
   auto Channels = createPairedQueueChannels();
   DummyRPCEndpoint Client(*Channels.first);




More information about the llvm-commits mailing list