[llvm] r346884 - [Support] Teach YAMLIO about polymorphic types

Scott Linder via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 14 11:40:00 PST 2018


Author: scott.linder
Date: Wed Nov 14 11:39:59 2018
New Revision: 346884

URL: http://llvm.org/viewvc/llvm-project?rev=346884&view=rev
Log:
[Support] Teach YAMLIO about polymorphic types

Add support for "polymorphic" types to YAMLIO.

PolymorphicTraits can dynamically switch between other traits (Scalar, Map, or
Sequence). When inputting, the PolymorphicTraits type is told which type to
become, and when outputting the PolymorphicTraits type is asked which type it
currently is.

Also add support for TaggedScalarTraits to allow dynamically differentiating
between multiple scalar types using YAML tags.

Serialize empty maps as "{}" and empty sequences as "[]", so that types
are preserved when round-tripping PolymorphicTraits. This change has
equivalent semantics, but may break e.g. tests which compare output
verbatim.

Differential Revision: https://reviews.llvm.org/D48144

Modified:
    llvm/trunk/include/llvm/Support/YAMLTraits.h
    llvm/trunk/lib/Support/YAMLTraits.cpp
    llvm/trunk/unittests/Support/YAMLIOTest.cpp

Modified: llvm/trunk/include/llvm/Support/YAMLTraits.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Support/YAMLTraits.h?rev=346884&r1=346883&r2=346884&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Support/YAMLTraits.h (original)
+++ llvm/trunk/include/llvm/Support/YAMLTraits.h Wed Nov 14 11:39:59 2018
@@ -39,6 +39,12 @@
 namespace llvm {
 namespace yaml {
 
+enum class NodeKind : uint8_t {
+  Scalar,
+  Map,
+  Sequence,
+};
+
 struct EmptyContext {};
 
 /// This class should be specialized by any type that needs to be converted
@@ -145,14 +151,14 @@ struct ScalarTraits {
   // Must provide:
   //
   // Function to write the value as a string:
-  //static void output(const T &value, void *ctxt, llvm::raw_ostream &out);
+  // static void output(const T &value, void *ctxt, llvm::raw_ostream &out);
   //
   // Function to convert a string to a value.  Returns the empty
   // StringRef on success or an error string if string is malformed:
-  //static StringRef input(StringRef scalar, void *ctxt, T &value);
+  // static StringRef input(StringRef scalar, void *ctxt, T &value);
   //
   // Function to determine if the value should be quoted.
-  //static QuotingType mustQuote(StringRef);
+  // static QuotingType mustQuote(StringRef);
 };
 
 /// This class should be specialized by type that requires custom conversion
@@ -163,7 +169,7 @@ struct ScalarTraits {
 ///      static void output(const MyType &Value, void*, llvm::raw_ostream &Out)
 ///      {
 ///        // stream out custom formatting
-///        Out << Val;
+///        Out << Value;
 ///      }
 ///      static StringRef input(StringRef Scalar, void*, MyType &Value) {
 ///        // parse scalar and set `value`
@@ -181,6 +187,47 @@ struct BlockScalarTraits {
   // Function to convert a string to a value.  Returns the empty
   // StringRef on success or an error string if string is malformed:
   // static StringRef input(StringRef Scalar, void *ctxt, T &Value);
+  //
+  // Optional:
+  // static StringRef inputTag(T &Val, std::string Tag);
+  // static void outputTag(const T &Val, raw_ostream &Out);
+};
+
+/// This class should be specialized by type that requires custom conversion
+/// to/from a YAML scalar with optional tags. For example:
+///
+///    template <>
+///    struct TaggedScalarTraits<MyType> {
+///      static void output(const MyType &Value, void*,
+///          llvm::raw_ostream &ScalarOut, llvm::raw_ostream &TagOut)
+///      {
+///        TagOut << "!mytype"; // optional; write nothing to omit the tag
+///        ScalarOut << Value;  // stream out custom formatting
+///      }
+///      static StringRef input(StringRef Scalar, StringRef Tag, void*,
+///                             MyType &Value) {
+///        // parse scalar and set `value`
+///        // return empty string on success, or error string
+///        return StringRef();
+///      }
+///      static QuotingType mustQuote(const MyType &Value, StringRef) {
+///        return QuotingType::Single;
+///      }
+///    };
+template <typename T> struct TaggedScalarTraits {
+  // Must provide:
+  //
+  // Function to write the value and tag as strings:
+  // static void output(const T &Value, void *ctx, llvm::raw_ostream &ScalarOut,
+  //                    llvm::raw_ostream &TagOut);
+  //
+  // Function to convert a string and its tag to a value.  Returns the empty
+  // StringRef on success or an error string if string is malformed:
+  // static StringRef input(StringRef Scalar, StringRef Tag, void *ctxt,
+  //                        T &Value);
+  //
+  // Function to determine if the value should be quoted.
+  // static QuotingType mustQuote(const T &Value, StringRef Scalar);
 };
 
 /// This class should be specialized by any type that needs to be converted
@@ -234,6 +281,31 @@ struct CustomMappingTraits {
   // static void output(IO &io, T &elem);
 };
 
+/// This class should be specialized by any type that can be represented as
+/// a scalar, map, or sequence, decided dynamically. For example:
+///
+///    typedef std::unique_ptr<MyBase> MyPoly;
+///
+///    template<>
+///    struct PolymorphicTraits<MyPoly> {
+///      static NodeKind getKind(const MyPoly &poly) {
+///        return poly->getKind();
+///      }
+///      static MyScalar& getAsScalar(MyPoly &poly) {
+///        if (!poly || !isa<MyScalar>(poly))
+///          poly.reset(new MyScalar());
+///        return *cast<MyScalar>(poly.get());
+///      }
+///      // ...
+///    };
+template <typename T> struct PolymorphicTraits {
+  // Must provide:
+  // static NodeKind getKind(const T &poly);
+  // static scalar_type &getAsScalar(T &poly);
+  // static map_type &getAsMap(T &poly);
+  // static sequence_type &getAsSequence(T &poly);
+};
+
 // Only used for better diagnostics of missing traits
 template <typename T>
 struct MissingTrait;
@@ -307,6 +379,24 @@ struct has_BlockScalarTraits
       (sizeof(test<BlockScalarTraits<T>>(nullptr, nullptr)) == 1);
 };
 
+// Test if TaggedScalarTraits<T> is defined on type T.
+template <class T> struct has_TaggedScalarTraits {
+  using Signature_input = StringRef (*)(StringRef, StringRef, void *, T &);
+  using Signature_output = void (*)(const T &, void *, raw_ostream &,
+                                    raw_ostream &);
+  using Signature_mustQuote = QuotingType (*)(const T &, StringRef);
+
+  template <typename U>
+  static char test(SameType<Signature_input, &U::input> *,
+                   SameType<Signature_output, &U::output> *,
+                   SameType<Signature_mustQuote, &U::mustQuote> *);
+
+  template <typename U> static double test(...);
+
+  static bool const value =
+      (sizeof(test<TaggedScalarTraits<T>>(nullptr, nullptr, nullptr)) == 1);
+};
+
 // Test if MappingContextTraits<T> is defined on type T.
 template <class T, class Context> struct has_MappingTraits {
   using Signature_mapping = void (*)(class IO &, T &, Context &);
@@ -438,6 +528,17 @@ struct has_DocumentListTraits
   static bool const value = (sizeof(test<DocumentListTraits<T>>(nullptr))==1);
 };
 
+template <class T> struct has_PolymorphicTraits {
+  using Signature_getKind = NodeKind (*)(const T &);
+
+  template <typename U>
+  static char test(SameType<Signature_getKind, &U::getKind> *);
+
+  template <typename U> static double test(...);
+
+  static bool const value = (sizeof(test<PolymorphicTraits<T>>(nullptr)) == 1);
+};
+
 inline bool isNumeric(StringRef S) {
   const static auto skipDigits = [](StringRef Input) {
     return Input.drop_front(
@@ -626,10 +727,12 @@ struct missingTraits
                                         !has_ScalarBitSetTraits<T>::value &&
                                         !has_ScalarTraits<T>::value &&
                                         !has_BlockScalarTraits<T>::value &&
+                                        !has_TaggedScalarTraits<T>::value &&
                                         !has_MappingTraits<T, Context>::value &&
                                         !has_SequenceTraits<T>::value &&
                                         !has_CustomMappingTraits<T>::value &&
-                                        !has_DocumentListTraits<T>::value> {};
+                                        !has_DocumentListTraits<T>::value &&
+                                        !has_PolymorphicTraits<T>::value> {};
 
 template <typename T, typename Context>
 struct validatedMappingTraits
@@ -683,6 +786,9 @@ public:
 
   virtual void scalarString(StringRef &, QuotingType) = 0;
   virtual void blockScalarString(StringRef &) = 0;
+  virtual void scalarTag(std::string &) = 0;
+
+  virtual NodeKind getNodeKind() = 0;
 
   virtual void setError(const Twine &) = 0;
 
@@ -917,6 +1023,31 @@ yamlize(IO &YamlIO, T &Val, bool, EmptyC
   }
 }
 
+template <typename T>
+typename std::enable_if<has_TaggedScalarTraits<T>::value, void>::type
+yamlize(IO &io, T &Val, bool, EmptyContext &Ctx) {
+  if (io.outputting()) {
+    std::string ScalarStorage, TagStorage;
+    raw_string_ostream ScalarBuffer(ScalarStorage), TagBuffer(TagStorage);
+    TaggedScalarTraits<T>::output(Val, io.getContext(), ScalarBuffer,
+                                  TagBuffer);
+    io.scalarTag(TagBuffer.str());
+    StringRef ScalarStr = ScalarBuffer.str();
+    io.scalarString(ScalarStr,
+                    TaggedScalarTraits<T>::mustQuote(Val, ScalarStr));
+  } else {
+    std::string Tag;
+    io.scalarTag(Tag);
+    StringRef Str;
+    io.scalarString(Str, QuotingType::None);
+    StringRef Result =
+        TaggedScalarTraits<T>::input(Str, Tag, io.getContext(), Val);
+    if (!Result.empty()) {
+      io.setError(Twine(Result));
+    }
+  }
+}
+
 template <typename T, typename Context>
 typename std::enable_if<validatedMappingTraits<T, Context>::value, void>::type
 yamlize(IO &io, T &Val, bool, Context &Ctx) {
@@ -973,6 +1104,20 @@ yamlize(IO &io, T &Val, bool, EmptyConte
 }
 
 template <typename T>
+typename std::enable_if<has_PolymorphicTraits<T>::value, void>::type
+yamlize(IO &io, T &Val, bool, EmptyContext &Ctx) {
+  switch (io.outputting() ? PolymorphicTraits<T>::getKind(Val)
+                          : io.getNodeKind()) {
+  case NodeKind::Scalar:
+    return yamlize(io, PolymorphicTraits<T>::getAsScalar(Val), true, Ctx);
+  case NodeKind::Map:
+    return yamlize(io, PolymorphicTraits<T>::getAsMap(Val), true, Ctx);
+  case NodeKind::Sequence:
+    return yamlize(io, PolymorphicTraits<T>::getAsSequence(Val), true, Ctx);
+  }
+}
+
+template <typename T>
 typename std::enable_if<missingTraits<T, EmptyContext>::value, void>::type
 yamlize(IO &io, T &Val, bool, EmptyContext &Ctx) {
   char missing_yaml_trait_for_type[sizeof(MissingTrait<T>)];
@@ -1250,6 +1395,8 @@ private:
   void endBitSetScalar() override;
   void scalarString(StringRef &, QuotingType) override;
   void blockScalarString(StringRef &) override;
+  void scalarTag(std::string &) override;
+  NodeKind getNodeKind() override;
   void setError(const Twine &message) override;
   bool canElideEmptySequence() override;
 
@@ -1395,6 +1542,8 @@ public:
   void endBitSetScalar() override;
   void scalarString(StringRef &, QuotingType) override;
   void blockScalarString(StringRef &) override;
+  void scalarTag(std::string &) override;
+  NodeKind getNodeKind() override;
   void setError(const Twine &message) override;
   bool canElideEmptySequence() override;
 
@@ -1414,14 +1563,21 @@ private:
   void flowKey(StringRef Key);
 
   enum InState {
-    inSeq,
-    inFlowSeq,
+    inSeqFirstElement,
+    inSeqOtherElement,
+    inFlowSeqFirstElement,
+    inFlowSeqOtherElement,
     inMapFirstKey,
     inMapOtherKey,
     inFlowMapFirstKey,
     inFlowMapOtherKey
   };
 
+  static bool inSeqAnyElement(InState State);
+  static bool inFlowSeqAnyElement(InState State);
+  static bool inMapAnyKey(InState State);
+  static bool inFlowMapAnyKey(InState State);
+
   raw_ostream &Out;
   int WrapColumn;
   SmallVector<InState, 8> StateStack;
@@ -1557,6 +1713,16 @@ operator>>(Input &In, T &Val) {
   return In;
 }
 
+// Define non-member operator>> so that Input can stream in a polymorphic type.
+template <typename T>
+inline typename std::enable_if<has_PolymorphicTraits<T>::value, Input &>::type
+operator>>(Input &In, T &Val) {
+  EmptyContext Ctx;
+  if (In.setCurrentDocument())
+    yamlize(In, Val, true, Ctx);
+  return In;
+}
+
 // Provide better error message about types missing a trait specialization
 template <typename T>
 inline typename std::enable_if<missingTraits<T, EmptyContext>::value,
@@ -1641,6 +1807,24 @@ operator<<(Output &Out, T &Val) {
     yamlize(Out, Val, true, Ctx);
     Out.postflightDocument();
   }
+  Out.endDocuments();
+  return Out;
+}
+
+// Define non-member operator<< so that Output can stream out a polymorphic
+// type.
+template <typename T>
+inline typename std::enable_if<has_PolymorphicTraits<T>::value, Output &>::type
+operator<<(Output &Out, T &Val) {
+  EmptyContext Ctx;
+  Out.beginDocuments();
+  if (Out.preflightDocument(0)) {
+    // FIXME: The parser does not support explicit documents terminated with a
+    // plain scalar; the end-marker is included as part of the scalar token.
+    assert(PolymorphicTraits<T>::getKind(Val) != NodeKind::Scalar && "plain scalar documents are not supported");
+    yamlize(Out, Val, true, Ctx);
+    Out.postflightDocument();
+  }
   Out.endDocuments();
   return Out;
 }

Modified: llvm/trunk/lib/Support/YAMLTraits.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Support/YAMLTraits.cpp?rev=346884&r1=346883&r2=346884&view=diff
==============================================================================
--- llvm/trunk/lib/Support/YAMLTraits.cpp (original)
+++ llvm/trunk/lib/Support/YAMLTraits.cpp Wed Nov 14 11:39:59 2018
@@ -341,11 +341,25 @@ void Input::scalarString(StringRef &S, Q
 
 void Input::blockScalarString(StringRef &S) { scalarString(S, QuotingType::None); }
 
+void Input::scalarTag(std::string &Tag) {
+  Tag = CurrentNode->_node->getVerbatimTag();
+}
+
 void Input::setError(HNode *hnode, const Twine &message) {
   assert(hnode && "HNode must not be NULL");
   setError(hnode->_node, message);
 }
 
+NodeKind Input::getNodeKind() {
+  if (isa<ScalarHNode>(CurrentNode))
+    return NodeKind::Scalar;
+  else if (isa<MapHNode>(CurrentNode))
+    return NodeKind::Map;
+  else if (isa<SequenceHNode>(CurrentNode))
+    return NodeKind::Sequence;
+  llvm_unreachable("Unsupported node kind");
+}
+
 void Input::setError(Node *node, const Twine &message) {
   Strm->printError(node, message);
   EC = make_error_code(errc::invalid_argument);
@@ -436,9 +450,11 @@ bool Output::mapTag(StringRef Tag, bool
     // If this tag is being written inside a sequence we should write the start
     // of the sequence before writing the tag, otherwise the tag won't be
     // attached to the element in the sequence, but rather the sequence itself.
-    bool SequenceElement =
-        StateStack.size() > 1 && (StateStack[StateStack.size() - 2] == inSeq ||
-          StateStack[StateStack.size() - 2] == inFlowSeq);
+    bool SequenceElement = false;
+    if (StateStack.size() > 1) {
+      auto &E = StateStack[StateStack.size() - 2];
+      SequenceElement = inSeqAnyElement(E) || inFlowSeqAnyElement(E);
+    }
     if (SequenceElement && StateStack.back() == inMapFirstKey) {
       newLineCheck();
     } else {
@@ -461,6 +477,9 @@ bool Output::mapTag(StringRef Tag, bool
 }
 
 void Output::endMapping() {
+  // If we did not map anything, we should explicitly emit an empty map
+  if (StateStack.back() == inMapFirstKey)
+    output("{}");
   StateStack.pop_back();
 }
 
@@ -524,12 +543,15 @@ void Output::endDocuments() {
 }
 
 unsigned Output::beginSequence() {
-  StateStack.push_back(inSeq);
+  StateStack.push_back(inSeqFirstElement);
   NeedsNewLine = true;
   return 0;
 }
 
 void Output::endSequence() {
+  // If we did not emit anything, we should explicitly emit an empty sequence
+  if (StateStack.back() == inSeqFirstElement)
+    output("[]");
   StateStack.pop_back();
 }
 
@@ -538,10 +560,17 @@ bool Output::preflightElement(unsigned,
 }
 
 void Output::postflightElement(void *) {
+  if (StateStack.back() == inSeqFirstElement) {
+    StateStack.pop_back();
+    StateStack.push_back(inSeqOtherElement);
+  } else if (StateStack.back() == inFlowSeqFirstElement) {
+    StateStack.pop_back();
+    StateStack.push_back(inFlowSeqOtherElement);
+  }
 }
 
 unsigned Output::beginFlowSequence() {
-  StateStack.push_back(inFlowSeq);
+  StateStack.push_back(inFlowSeqFirstElement);
   newLineCheck();
   ColumnAtFlowStart = Column;
   output("[ ");
@@ -680,6 +709,14 @@ void Output::blockScalarString(StringRef
   }
 }
 
+void Output::scalarTag(std::string &Tag) {
+  if (Tag.empty())
+    return;
+  newLineCheck();
+  output(Tag);
+  output(" ");
+}
+
 void Output::setError(const Twine &message) {
 }
 
@@ -693,7 +730,7 @@ bool Output::canElideEmptySequence() {
     return true;
   if (StateStack.back() != inMapFirstKey)
     return true;
-  return (StateStack[StateStack.size()-2] != inSeq);
+  return !inSeqAnyElement(StateStack[StateStack.size() - 2]);
 }
 
 void Output::output(StringRef s) {
@@ -703,9 +740,8 @@ void Output::output(StringRef s) {
 
 void Output::outputUpToEndOfLine(StringRef s) {
   output(s);
-  if (StateStack.empty() || (StateStack.back() != inFlowSeq &&
-                             StateStack.back() != inFlowMapFirstKey &&
-                             StateStack.back() != inFlowMapOtherKey))
+  if (StateStack.empty() || (!inFlowSeqAnyElement(StateStack.back()) &&
+                             !inFlowMapAnyKey(StateStack.back())))
     NeedsNewLine = true;
 }
 
@@ -725,16 +761,20 @@ void Output::newLineCheck() {
 
   outputNewLine();
 
-  assert(StateStack.size() > 0);
+  if (StateStack.size() == 0)
+    return;
+
   unsigned Indent = StateStack.size() - 1;
   bool OutputDash = false;
 
-  if (StateStack.back() == inSeq) {
+  if (StateStack.back() == inSeqFirstElement ||
+      StateStack.back() == inSeqOtherElement) {
     OutputDash = true;
-  } else if ((StateStack.size() > 1) && ((StateStack.back() == inMapFirstKey) ||
-             (StateStack.back() == inFlowSeq) ||
-             (StateStack.back() == inFlowMapFirstKey)) &&
-             (StateStack[StateStack.size() - 2] == inSeq)) {
+  } else if ((StateStack.size() > 1) &&
+             ((StateStack.back() == inMapFirstKey) ||
+              inFlowSeqAnyElement(StateStack.back()) ||
+              (StateStack.back() == inFlowMapFirstKey)) &&
+             inSeqAnyElement(StateStack[StateStack.size() - 2])) {
     --Indent;
     OutputDash = true;
   }
@@ -772,6 +812,24 @@ void Output::flowKey(StringRef Key) {
   output(": ");
 }
 
+NodeKind Output::getNodeKind() { report_fatal_error("invalid call"); }
+
+bool Output::inSeqAnyElement(InState State) {
+  return State == inSeqFirstElement || State == inSeqOtherElement;
+}
+
+bool Output::inFlowSeqAnyElement(InState State) {
+  return State == inFlowSeqFirstElement || State == inFlowSeqOtherElement;
+}
+
+bool Output::inMapAnyKey(InState State) {
+  return State == inMapFirstKey || State == inMapOtherKey;
+}
+
+bool Output::inFlowMapAnyKey(InState State) {
+  return State == inFlowMapFirstKey || State == inFlowMapOtherKey;
+}
+
 //===----------------------------------------------------------------------===//
 //  traits for built-in types
 //===----------------------------------------------------------------------===//

Modified: llvm/trunk/unittests/Support/YAMLIOTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Support/YAMLIOTest.cpp?rev=346884&r1=346883&r2=346884&view=diff
==============================================================================
--- llvm/trunk/unittests/Support/YAMLIOTest.cpp (original)
+++ llvm/trunk/unittests/Support/YAMLIOTest.cpp Wed Nov 14 11:39:59 2018
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Casting.h"
@@ -2642,3 +2643,235 @@ TEST(YAMLIO, Numeric) {
   EXPECT_FALSE(isNumeric("-inf"));
   EXPECT_FALSE(isNumeric("1,230.15"));
 }
+
+//===----------------------------------------------------------------------===//
+//  Test PolymorphicTraits and TaggedScalarTraits
+//===----------------------------------------------------------------------===//
+
+struct Poly {
+  enum NodeKind {
+    NK_Scalar,
+    NK_Seq,
+    NK_Map,
+  } Kind;
+
+  Poly(NodeKind Kind) : Kind(Kind) {}
+
+  virtual ~Poly() = default;
+
+  NodeKind getKind() const { return Kind; }
+};
+
+struct Scalar : Poly {
+  enum ScalarKind {
+    SK_Unknown,
+    SK_Double,
+    SK_Bool,
+  } SKind;
+
+  union {
+    double DoubleValue;
+    bool BoolValue;
+  };
+
+  Scalar() : Poly(NK_Scalar), SKind(SK_Unknown) {}
+  Scalar(double DoubleValue)
+      : Poly(NK_Scalar), SKind(SK_Double), DoubleValue(DoubleValue) {}
+  Scalar(bool BoolValue)
+      : Poly(NK_Scalar), SKind(SK_Bool), BoolValue(BoolValue) {}
+
+  static bool classof(const Poly *N) { return N->getKind() == NK_Scalar; }
+};
+
+struct Seq : Poly, std::vector<std::unique_ptr<Poly>> {
+  Seq() : Poly(NK_Seq) {}
+
+  static bool classof(const Poly *N) { return N->getKind() == NK_Seq; }
+};
+
+struct Map : Poly, llvm::StringMap<std::unique_ptr<Poly>> {
+  Map() : Poly(NK_Map) {}
+
+  static bool classof(const Poly *N) { return N->getKind() == NK_Map; }
+};
+
+namespace llvm {
+namespace yaml {
+
+template <> struct PolymorphicTraits<std::unique_ptr<Poly>> {
+  static NodeKind getKind(const std::unique_ptr<Poly> &N) {
+    if (isa<Scalar>(*N))
+      return NodeKind::Scalar;
+    if (isa<Seq>(*N))
+      return NodeKind::Sequence;
+    if (isa<Map>(*N))
+      return NodeKind::Map;
+    llvm_unreachable("unsupported node type");
+  }
+
+  static Scalar &getAsScalar(std::unique_ptr<Poly> &N) {
+    if (!N || !isa<Scalar>(*N))
+      N = llvm::make_unique<Scalar>();
+    return *cast<Scalar>(N.get());
+  }
+
+  static Seq &getAsSequence(std::unique_ptr<Poly> &N) {
+    if (!N || !isa<Seq>(*N))
+      N = llvm::make_unique<Seq>();
+    return *cast<Seq>(N.get());
+  }
+
+  static Map &getAsMap(std::unique_ptr<Poly> &N) {
+    if (!N || !isa<Map>(*N))
+      N = llvm::make_unique<Map>();
+    return *cast<Map>(N.get());
+  }
+};
+
+template <> struct TaggedScalarTraits<Scalar> {
+  static void output(const Scalar &S, void *Ctxt, raw_ostream &ScalarOS,
+                     raw_ostream &TagOS) {
+    switch (S.SKind) {
+    case Scalar::SK_Unknown:
+      report_fatal_error("output unknown scalar");
+      break;
+    case Scalar::SK_Double:
+      TagOS << "!double";
+      ScalarTraits<double>::output(S.DoubleValue, Ctxt, ScalarOS);
+      break;
+    case Scalar::SK_Bool:
+      TagOS << "!bool";
+      ScalarTraits<bool>::output(S.BoolValue, Ctxt, ScalarOS);
+      break;
+    }
+  }
+
+  static StringRef input(StringRef ScalarStr, StringRef Tag, void *Ctxt,
+                         Scalar &S) {
+    S.SKind = StringSwitch<Scalar::ScalarKind>(Tag)
+                  .Case("!double", Scalar::SK_Double)
+                  .Case("!bool", Scalar::SK_Bool)
+                  .Default(Scalar::SK_Unknown);
+    switch (S.SKind) {
+    case Scalar::SK_Unknown:
+      return StringRef("unknown scalar tag");
+    case Scalar::SK_Double:
+      return ScalarTraits<double>::input(ScalarStr, Ctxt, S.DoubleValue);
+    case Scalar::SK_Bool:
+      return ScalarTraits<bool>::input(ScalarStr, Ctxt, S.BoolValue);
+    }
+    llvm_unreachable("unknown scalar kind");
+  }
+
+  static QuotingType mustQuote(const Scalar &S, StringRef Str) {
+    switch (S.SKind) {
+    case Scalar::SK_Unknown:
+      report_fatal_error("quote unknown scalar");
+    case Scalar::SK_Double:
+      return ScalarTraits<double>::mustQuote(Str);
+    case Scalar::SK_Bool:
+      return ScalarTraits<bool>::mustQuote(Str);
+    }
+    llvm_unreachable("unknown scalar kind");
+  }
+};
+
+template <> struct CustomMappingTraits<Map> {
+  static void inputOne(IO &IO, StringRef Key, Map &M) {
+    IO.mapRequired(Key.str().c_str(), M[Key]);
+  }
+
+  static void output(IO &IO, Map &M) {
+    for (auto &N : M)
+      IO.mapRequired(N.getKey().str().c_str(), N.getValue());
+  }
+};
+
+template <> struct SequenceTraits<Seq> {
+  static size_t size(IO &IO, Seq &A) { return A.size(); }
+
+  static std::unique_ptr<Poly> &element(IO &IO, Seq &A, size_t Index) {
+    if (Index >= A.size())
+      A.resize(Index + 1);
+    return A[Index];
+  }
+};
+
+} // namespace yaml
+} // namespace llvm
+
+TEST(YAMLIO, TestReadWritePolymorphicScalar) {
+  std::string intermediate;
+  std::unique_ptr<Poly> node = llvm::make_unique<Scalar>(true);
+
+  llvm::raw_string_ostream ostr(intermediate);
+  Output yout(ostr);
+#ifdef GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+  EXPECT_DEATH(yout << node, "plain scalar documents are not supported");
+#endif
+#endif
+}
+
+TEST(YAMLIO, TestReadWritePolymorphicSeq) {
+  std::string intermediate;
+  {
+    auto seq = llvm::make_unique<Seq>();
+    seq->push_back(llvm::make_unique<Scalar>(true));
+    seq->push_back(llvm::make_unique<Scalar>(1.0));
+    auto node = llvm::unique_dyn_cast<Poly>(seq);
+
+    llvm::raw_string_ostream ostr(intermediate);
+    Output yout(ostr);
+    yout << node;
+  }
+  {
+    Input yin(intermediate);
+    std::unique_ptr<Poly> node;
+    yin >> node;
+
+    EXPECT_FALSE(yin.error());
+    auto seq = llvm::dyn_cast<Seq>(node.get());
+    ASSERT_TRUE(seq);
+    ASSERT_EQ(seq->size(), 2u);
+    auto first = llvm::dyn_cast<Scalar>((*seq)[0].get());
+    ASSERT_TRUE(first);
+    EXPECT_EQ(first->SKind, Scalar::SK_Bool);
+    EXPECT_TRUE(first->BoolValue);
+    auto second = llvm::dyn_cast<Scalar>((*seq)[1].get());
+    ASSERT_TRUE(second);
+    EXPECT_EQ(second->SKind, Scalar::SK_Double);
+    EXPECT_EQ(second->DoubleValue, 1.0);
+  }
+}
+
+TEST(YAMLIO, TestReadWritePolymorphicMap) {
+  std::string intermediate;
+  {
+    auto map = llvm::make_unique<Map>();
+    (*map)["foo"] = llvm::make_unique<Scalar>(false);
+    (*map)["bar"] = llvm::make_unique<Scalar>(2.0);
+    std::unique_ptr<Poly> node = llvm::unique_dyn_cast<Poly>(map);
+
+    llvm::raw_string_ostream ostr(intermediate);
+    Output yout(ostr);
+    yout << node;
+  }
+  {
+    Input yin(intermediate);
+    std::unique_ptr<Poly> node;
+    yin >> node;
+
+    EXPECT_FALSE(yin.error());
+    auto map = llvm::dyn_cast<Map>(node.get());
+    ASSERT_TRUE(map);
+    auto foo = llvm::dyn_cast<Scalar>((*map)["foo"].get());
+    ASSERT_TRUE(foo);
+    EXPECT_EQ(foo->SKind, Scalar::SK_Bool);
+    EXPECT_FALSE(foo->BoolValue);
+    auto bar = llvm::dyn_cast<Scalar>((*map)["bar"].get());
+    ASSERT_TRUE(bar);
+    EXPECT_EQ(bar->SKind, Scalar::SK_Double);
+    EXPECT_EQ(bar->DoubleValue, 2.0);
+  }
+}




More information about the llvm-commits mailing list