[llvm] r338305 - [ORC] Add SerializationTraits for std::set and std::map.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 30 14:08:07 PDT 2018


Author: lhames
Date: Mon Jul 30 14:08:06 2018
New Revision: 338305

URL: http://llvm.org/viewvc/llvm-project?rev=338305&view=rev
Log:
[ORC] Add SerializationTraits for std::set and std::map.

Also, make SerializationTraits for pairs forward the actual pair
template type arguments to the underlying serializer. This allows, for example,
std::pair<StringRef, bool> to be passed as an argument to an RPC call expecting
a std::pair<std::string, bool>, since there is an underlying serializer from
StringRef to std::string that can be used.

Modified:
    llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
    llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

Modified: llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h?rev=338305&r1=338304&r2=338305&view=diff
==============================================================================
--- llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h (original)
+++ llvm/trunk/include/llvm/ExecutionEngine/Orc/RPCSerialization.h Mon Jul 30 14:08:06 2018
@@ -14,7 +14,10 @@
 #include "llvm/Support/thread.h"
 #include <map>
 #include <mutex>
+#include <set>
 #include <sstream>
+#include <string>
+#include <vector>
 
 namespace llvm {
 namespace orc {
@@ -205,6 +208,42 @@ std::mutex RPCTypeName<std::vector<T>>::
 template <typename T>
 std::string RPCTypeName<std::vector<T>>::Name;
 
+template <typename T> class RPCTypeName<std::set<T>> { // RPC type name for std::set<T>.
+public:
+  static const char *getName() { // Builds the name on the first call only.
+    std::lock_guard<std::mutex> Lock(NameMutex);
+    if (Name.empty())
+      raw_string_ostream(Name) // Temporary stream; writes land in Name.
+          << "std::set<" << RPCTypeName<T>::getName() << ">";
+    return Name.data(); // Name is never modified again once non-empty.
+  }
+
+private:
+  static std::mutex NameMutex; // Serializes the lazy construction of Name.
+  static std::string Name;
+};
+
+template <typename T> std::mutex RPCTypeName<std::set<T>>::NameMutex;
+template <typename T> std::string RPCTypeName<std::set<T>>::Name;
+
+template <typename K, typename V> class RPCTypeName<std::map<K, V>> { // RPC type name for std::map<K, V>.
+public:
+  static const char *getName() { // Builds the name on the first call only.
+    std::lock_guard<std::mutex> Lock(NameMutex);
+    if (Name.empty())
+      raw_string_ostream(Name) // Temporary stream; writes land in Name.
+          << "std::map<" << RPCTypeNameSequence<K, V>() << ">"; // K, V names.
+    return Name.data(); // Name is never modified again once non-empty.
+  }
+
+private:
+  static std::mutex NameMutex; // Serializes the lazy construction of Name.
+  static std::string Name;
+};
+
+template <typename K, typename V>
+std::mutex RPCTypeName<std::map<K, V>>::NameMutex;
+template <typename K, typename V> std::string RPCTypeName<std::map<K, V>>::Name;
 
 /// The SerializationTraits<ChannelT, T> class describes how to serialize and
 /// deserialize an instance of type T to/from an abstract channel of type
@@ -527,15 +566,20 @@ public:
 };
 
 /// SerializationTraits default specialization for std::pair.
-template <typename ChannelT, typename T1, typename T2>
-class SerializationTraits<ChannelT, std::pair<T1, T2>> {
+template <typename ChannelT, typename T1, typename T2, typename T3, typename T4>
+class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> { // T1/T2: RPC signature types; T3/T4: argument types.
 public:
-  static Error serialize(ChannelT &C, const std::pair<T1, T2> &V) {
-    return serializeSeq(C, V.first, V.second);
+  static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { // first as T1, then second as T2.
+    if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first))
+      return Err;
+    return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second);
   }
 
-  static Error deserialize(ChannelT &C, std::pair<T1, T2> &V) {
-    return deserializeSeq(C, V.first, V.second);
+  static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { // first as T1, then second as T2.
+    if (auto Err =
+            SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first))
+      return Err;
+    return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second);
   }
 };
 
@@ -589,6 +633,9 @@ public:
 
   /// Deserialize a std::vector<T> to a std::vector<T>.
   static Error deserialize(ChannelT &C, std::vector<T> &V) {
+    assert(V.empty() &&
+           "Expected default-constructed vector to deserialize into");
+
     uint64_t Count = 0;
     if (auto Err = deserializeSeq(C, Count))
       return Err;
@@ -600,6 +647,92 @@ public:
 
     return Error::success();
   }
+};
+
+template <typename ChannelT, typename T, typename T2>
+class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> {
+public: // Elements go through SerializationTraits<ChannelT, T, T2>.
+  /// Serialize a std::set<T> from std::set<T2>.
+  static Error serialize(ChannelT &C, const std::set<T2> &S) {
+    if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size())))
+      return Err;
+
+    for (const auto &E : S)
+      if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E))
+        return Err;
+
+    return Error::success();
+  }
+
+  /// Deserialize a std::set<T> into a std::set<T2>.
+  static Error deserialize(ChannelT &C, std::set<T2> &S) {
+    assert(S.empty() && "Expected default-constructed set to deserialize into");
+
+    uint64_t Count = 0;
+    if (auto Err = deserializeSeq(C, Count))
+      return Err;
+
+    while (Count-- != 0) {
+      T2 Val;
+      if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val))
+        return Err;
+
+      auto Added = S.insert(std::move(Val)).second; // Val is dead after this.
+      if (!Added) // Reject rather than silently dropping a repeated element.
+        return make_error<StringError>("Duplicate element in deserialized set",
+                                       orcError(OrcErrorCode::UnknownORCError));
+    }
+
+    return Error::success();
+  }
+};
+
+template <typename ChannelT, typename K, typename V, typename K2, typename V2>
+class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> {
+public: // Keys use SerializationTraits<..., K, K2>; values <..., V, V2>.
+  /// Serialize a std::map<K, V> from std::map<K2, V2>.
+  static Error serialize(ChannelT &C, const std::map<K2, V2> &M) {
+    if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size())))
+      return Err;
+
+    for (const auto &E : M) {
+      if (auto Err =
+              SerializationTraits<ChannelT, K, K2>::serialize(C, E.first))
+        return Err;
+      if (auto Err =
+              SerializationTraits<ChannelT, V, V2>::serialize(C, E.second))
+        return Err;
+    }
+
+    return Error::success();
+  }
+
+  /// Deserialize a std::map<K, V> into a std::map<K2, V2>.
+  static Error deserialize(ChannelT &C, std::map<K2, V2> &M) {
+    assert(M.empty() && "Expected default-constructed map to deserialize into");
+
+    uint64_t Count = 0;
+    if (auto Err = deserializeSeq(C, Count))
+      return Err;
+
+    while (Count-- != 0) {
+      std::pair<K2, V2> Val;
+      if (auto Err =
+              SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first))
+        return Err;
+
+      if (auto Err =
+              SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second))
+        return Err;
+
+      auto Added = M.insert(std::move(Val)).second; // Val is dead after this.
+      if (!Added) // Reject rather than silently dropping a repeated key.
+        return make_error<StringError>("Duplicate element in deserialized map",
+                                       orcError(OrcErrorCode::UnknownORCError));
+    }
+
+    return Error::success();
+  }
 };
 
 } // end namespace rpc

Modified: llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp?rev=338305&r1=338304&r2=338305&view=diff
==============================================================================
--- llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp (original)
+++ llvm/trunk/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp Mon Jul 30 14:08:06 2018
@@ -133,10 +133,10 @@ namespace DummyRPCAPI {
   };
 
   class AllTheTypes
-    : public Function<AllTheTypes,
-                      void(int8_t, uint8_t, int16_t, uint16_t, int32_t,
-                           uint32_t, int64_t, uint64_t, bool, std::string,
-                           std::vector<int>)> {
+      : public Function<AllTheTypes, void(int8_t, uint8_t, int16_t, uint16_t,
+                                          int32_t, uint32_t, int64_t, uint64_t,
+                                          bool, std::string, std::vector<int>,
+                                          std::set<int>, std::map<int, bool>)> { // One parameter per type under test.
   public:
     static const char* getName() { return "AllTheTypes"; }
   };
@@ -451,43 +451,50 @@ TEST(DummyRPC, TestSerialization) {
   DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
-      Server.addHandler<DummyRPCAPI::AllTheTypes>(
-          [&](int8_t S8, uint8_t U8, int16_t S16, uint16_t U16,
-              int32_t S32, uint32_t U32, int64_t S64, uint64_t U64,
-              bool B, std::string S, std::vector<int> V) {
-
-            EXPECT_EQ(S8, -101) << "int8_t serialization broken";
-            EXPECT_EQ(U8, 250) << "uint8_t serialization broken";
-            EXPECT_EQ(S16, -10000) << "int16_t serialization broken";
-            EXPECT_EQ(U16, 10000) << "uint16_t serialization broken";
-            EXPECT_EQ(S32, -1000000000) << "int32_t serialization broken";
-            EXPECT_EQ(U32, 1000000000ULL) << "uint32_t serialization broken";
-            EXPECT_EQ(S64, -10000000000) << "int64_t serialization broken";
-            EXPECT_EQ(U64, 10000000000ULL) << "uint64_t serialization broken";
-            EXPECT_EQ(B, true) << "bool serialization broken";
-            EXPECT_EQ(S, "foo") << "std::string serialization broken";
-            EXPECT_EQ(V, std::vector<int>({42, 7}))
-              << "std::vector serialization broken";
-            return Error::success();
-          });
-
-      {
-        // Poke the server to handle the negotiate call.
-        auto Err = Server.handleOne();
-        EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
-      }
-
-      {
-        // Poke the server to handle the AllTheTypes call.
-        auto Err = Server.handleOne();
-        EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)";
-      }
+    Server.addHandler<DummyRPCAPI::AllTheTypes>([&](int8_t S8, uint8_t U8,
+                                                    int16_t S16, uint16_t U16,
+                                                    int32_t S32, uint32_t U32,
+                                                    int64_t S64, uint64_t U64,
+                                                    bool B, std::string S,
+                                                    std::vector<int> V,
+                                                    std::set<int> S2,
+                                                    std::map<int, bool> M) {
+      EXPECT_EQ(S8, -101) << "int8_t serialization broken";
+      EXPECT_EQ(U8, 250) << "uint8_t serialization broken";
+      EXPECT_EQ(S16, -10000) << "int16_t serialization broken";
+      EXPECT_EQ(U16, 10000) << "uint16_t serialization broken";
+      EXPECT_EQ(S32, -1000000000) << "int32_t serialization broken";
+      EXPECT_EQ(U32, 1000000000ULL) << "uint32_t serialization broken";
+      EXPECT_EQ(S64, -10000000000) << "int64_t serialization broken";
+      EXPECT_EQ(U64, 10000000000ULL) << "uint64_t serialization broken";
+      EXPECT_EQ(B, true) << "bool serialization broken";
+      EXPECT_EQ(S, "foo") << "std::string serialization broken";
+      EXPECT_EQ(V, std::vector<int>({42, 7}))
+          << "std::vector serialization broken";
+      EXPECT_EQ(S2, std::set<int>({7, 42})) << "std::set serialization broken";
+      EXPECT_EQ(M, (std::map<int, bool>({{7, false}, {42, true}})))
+          << "std::map serialization broken";
+      return Error::success();
     });
 
+    {
+      // Poke the server to handle the negotiate call.
+      auto Err = Server.handleOne();
+      EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
+    }
+
+    {
+      // Poke the server to handle the AllTheTypes call.
+      auto Err = Server.handleOne();
+      EXPECT_FALSE(!!Err) << "Server failed to handle call to AllTheTypes";
+    }
+  });
 
   {
     // Make an async call.
-    std::vector<int> v({42, 7});
+    std::vector<int> V({42, 7});
+    std::set<int> S({7, 42});
+    std::map<int, bool> M({{7, false}, {42, true}});
     auto Err = Client.callAsync<DummyRPCAPI::AllTheTypes>(
         [](Error Err) {
           EXPECT_FALSE(!!Err) << "Async AllTheTypes response handler failed";
@@ -497,7 +504,7 @@ TEST(DummyRPC, TestSerialization) {
         static_cast<int16_t>(-10000), static_cast<uint16_t>(10000),
         static_cast<int32_t>(-1000000000), static_cast<uint32_t>(1000000000),
         static_cast<int64_t>(-10000000000), static_cast<uint64_t>(10000000000),
-        true, std::string("foo"), v);
+        true, std::string("foo"), V, S, M);
     EXPECT_FALSE(!!Err) << "Client.callAsync failed for AllTheTypes";
   }
 




More information about the llvm-commits mailing list