[clang-tools-extra] r342888 - [clangd] Do bounds checks while reading data, otherwise var-length records are too painful. NFC

Sam McCall via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 24 07:51:15 PDT 2018


Author: sammccall
Date: Mon Sep 24 07:51:15 2018
New Revision: 342888

URL: http://llvm.org/viewvc/llvm-project?rev=342888&view=rev
Log:
[clangd] Do bounds checks while reading data, otherwise var-length records are too painful. NFC

Modified:
    clang-tools-extra/trunk/clangd/index/Serialization.cpp

Modified: clang-tools-extra/trunk/clangd/index/Serialization.cpp
URL: http://llvm.org/viewvc/llvm-project/clang-tools-extra/trunk/clangd/index/Serialization.cpp?rev=342888&r1=342887&r2=342888&view=diff
==============================================================================
--- clang-tools-extra/trunk/clangd/index/Serialization.cpp (original)
+++ clang-tools-extra/trunk/clangd/index/Serialization.cpp Mon Sep 24 07:51:15 2018
@@ -23,24 +23,83 @@ Error makeError(const Twine &Msg) {
 
 // IO PRIMITIVES
 // We use little-endian 32 bit ints, sometimes with variable-length encoding.
+//
+// Variable-length int encoding (varint) uses the bottom 7 bits of each byte
+// to encode the number, and the top bit to indicate whether more bytes follow.
+// e.g. 9a 2f means [0x1a and keep reading, 0x2f and stop].
+// This represents 0x1a | 0x2f<<7 = 6042.
+// A 32-bit integer takes 1-5 bytes to encode; small numbers are more compact.
 
-StringRef consume(StringRef &Data, int N) {
-  StringRef Ret = Data.take_front(N);
-  Data = Data.drop_front(N);
-  return Ret;
-}
+// Reads binary data from a StringRef, and keeps track of position.
+class Reader {
+  const char *Begin, *End;
+  bool Err = false; // Sticky: no read ever clears it.
 
-uint8_t consume8(StringRef &Data) {
-  uint8_t Ret = Data.front();
-  Data = Data.drop_front();
-  return Ret;
-}
+public:
+  Reader(StringRef Data) : Begin(Data.begin()), End(Data.end()) {}
+  // The "error" bit is set by reading past EOF or reading invalid data.
+  // When in an error state, reads may return zero values: callers should check.
+  bool err() const { return Err; }
+  // Did we read all the data, or encounter an error?
+  bool eof() const { return Begin == End || Err; }
+  // All the data we didn't read yet.
+  StringRef rest() const { return StringRef(Begin, End - Begin); }
+
+  uint8_t consume8() {
+    if (LLVM_UNLIKELY(Begin == End)) {
+      Err = true;
+      return 0;
+    }
+    return *Begin++;
+  }
 
-uint32_t consume32(StringRef &Data) {
-  auto Ret = support::endian::read32le(Data.bytes_begin());
-  Data = Data.drop_front(4);
-  return Ret;
-}
+  uint32_t consume32() {
+    if (LLVM_UNLIKELY(End - Begin < 4)) { // Never form a pointer past End.
+      Err = true;
+      return 0;
+    }
+    auto Ret = support::endian::read32le(Begin);
+    Begin += 4;
+    return Ret;
+  }
+
+  StringRef consume(size_t N) {
+    if (LLVM_UNLIKELY(static_cast<size_t>(End - Begin) < N)) {
+      Err = true;
+      return StringRef();
+    }
+    StringRef Ret(Begin, N);
+    Begin += N;
+    return Ret;
+  }
+
+  uint32_t consumeVar() {
+    constexpr static uint8_t More = 1 << 7;
+    uint8_t B = consume8();
+    if (LLVM_LIKELY(!(B & More)))
+      return B;
+    uint32_t Val = B & ~More;
+    for (int Shift = 7; B & More && Shift < 32; Shift += 7) {
+      B = consume8();
+      Val |= static_cast<uint32_t>(B & ~More) << Shift; // Shift as unsigned.
+    }
+    return Val;
+  }
+
+  StringRef consumeString(ArrayRef<StringRef> Strings) {
+    auto StringIndex = consumeVar();
+    if (LLVM_UNLIKELY(StringIndex >= Strings.size())) {
+      Err = true;
+      return StringRef();
+    }
+    return Strings[StringIndex];
+  }
+
+  SymbolID consumeID() {
+    StringRef Raw = consume(SymbolID::RawSize); // short if truncated.
+    return LLVM_UNLIKELY(err()) ? SymbolID() : SymbolID::fromRaw(Raw);
+  }
+};
 
 void write32(uint32_t I, raw_ostream &OS) {
   char buf[4];
@@ -48,11 +107,6 @@ void write32(uint32_t I, raw_ostream &OS
   OS.write(buf, sizeof(buf));
 }
 
-// Variable-length int encoding (varint) uses the bottom 7 bits of each byte
-// to encode the number, and the top bit to indicate whether more bytes follow.
-// e.g. 9a 2f means [0x1a and keep reading, 0x2f and stop].
-// This represents 0x1a | 0x2f<<7 = 6042.
-// A 32-bit integer takes 1-5 bytes to encode; small numbers are more compact.
 void writeVar(uint32_t I, raw_ostream &OS) {
   constexpr static uint8_t More = 1 << 7;
   if (LLVM_LIKELY(I < 1 << 7)) {
@@ -69,19 +123,6 @@ void writeVar(uint32_t I, raw_ostream &O
     }
   }
 
-uint32_t consumeVar(StringRef &Data) {
-  constexpr static uint8_t More = 1 << 7;
-  uint8_t B = consume8(Data);
-  if (LLVM_LIKELY(!(B & More)))
-    return B;
-  uint32_t Val = B & ~More;
-  for (int Shift = 7; B & More && Shift < 32; Shift += 7) {
-    B = consume8(Data);
-    Val |= (B & ~More) << Shift;
-  }
-  return Val;
-}
-
 // STRING TABLE ENCODING
 // Index data has many string fields, and many strings are identical.
 // We store each string once, and refer to them by index.
@@ -146,30 +187,34 @@ struct StringTableIn {
 };
 
 Expected<StringTableIn> readStringTable(StringRef Data) {
-  if (Data.size() < 4)
-    return makeError("Bad string table: not enough metadata");
-  size_t UncompressedSize = consume32(Data);
+  Reader R(Data);
+  size_t UncompressedSize = R.consume32();
+  if (R.err())
+    return makeError("Truncated string table");
 
   StringRef Uncompressed;
   SmallString<1> UncompressedStorage;
   if (UncompressedSize == 0) // No compression
-    Uncompressed = Data;
+    Uncompressed = R.rest();
   else {
-    if (Error E =
-            llvm::zlib::uncompress(Data, UncompressedStorage, UncompressedSize))
+    if (Error E = llvm::zlib::uncompress(R.rest(), UncompressedStorage,
+                                         UncompressedSize))
       return std::move(E);
     Uncompressed = UncompressedStorage;
   }
 
   StringTableIn Table;
   StringSaver Saver(Table.Arena);
-  for (StringRef Rest = Uncompressed; !Rest.empty();) {
-    auto Len = Rest.find(0);
+  R = Reader(Uncompressed); // Now parse the (decompressed) string data.
+  while (!R.eof()) { // Reusing R lets the err() check below see these reads.
+    auto Len = R.rest().find(0);
     if (Len == StringRef::npos)
       return makeError("Bad string table: not null terminated");
-    Table.Strings.push_back(Saver.save(consume(Rest, Len)));
-    Rest = Rest.drop_front();
+    Table.Strings.push_back(Saver.save(R.consume(Len)));
+    R.consume8(); // Skip the NUL terminator.
   }
+  if (R.err())
+    return makeError("Truncated string table");
   return std::move(Table);
 }
 
@@ -179,27 +224,35 @@ Expected<StringTableIn> readStringTable(
 //  - enums encode as the underlying type
 //  - most numbers encode as varint
 
-// It's useful to the implementation to assume symbols have a bounded size.
-constexpr size_t SymbolSizeBound = 512;
-// To ensure the bounded size, restrict the number of include headers stored.
-constexpr unsigned MaxIncludes = 50;
+void writeLocation(const SymbolLocation &Loc, const StringTableOut &Strings,
+                   raw_ostream &OS) {
+  writeVar(Strings.index(Loc.FileURI), OS);
+  for (const auto &Endpoint : {Loc.Start, Loc.End}) {
+    writeVar(Endpoint.Line, OS);
+    writeVar(Endpoint.Column, OS);
+  }
+}
+
+SymbolLocation readLocation(Reader &Data, ArrayRef<StringRef> Strings) {
+  SymbolLocation Loc;
+  Loc.FileURI = Data.consumeString(Strings);
+  for (auto *Endpoint : {&Loc.Start, &Loc.End}) {
+    Endpoint->Line = Data.consumeVar();
+    Endpoint->Column = Data.consumeVar();
+  }
+  return Loc;
+}
 
 void writeSymbol(const Symbol &Sym, const StringTableOut &Strings,
                  raw_ostream &OS) {
-  auto StartOffset = OS.tell();
   OS << Sym.ID.raw(); // TODO: once we start writing xrefs and posting lists,
                       // symbol IDs should probably be in a string table.
   OS.write(static_cast<uint8_t>(Sym.SymInfo.Kind));
   OS.write(static_cast<uint8_t>(Sym.SymInfo.Lang));
   writeVar(Strings.index(Sym.Name), OS);
   writeVar(Strings.index(Sym.Scope), OS);
-  for (const auto &Loc : {Sym.Definition, Sym.CanonicalDeclaration}) {
-    writeVar(Strings.index(Loc.FileURI), OS);
-    for (const auto &Endpoint : {Loc.Start, Loc.End}) {
-      writeVar(Endpoint.Line, OS);
-      writeVar(Endpoint.Column, OS);
-    }
-  }
+  writeLocation(Sym.Definition, Strings, OS);
+  writeLocation(Sym.CanonicalDeclaration, Strings, OS);
   writeVar(Sym.References, OS);
   OS.write(static_cast<uint8_t>(Sym.Flags));
   OS.write(static_cast<uint8_t>(Sym.Origin));
@@ -212,86 +265,40 @@ void writeSymbol(const Symbol &Sym, cons
     writeVar(Strings.index(Include.IncludeHeader), OS);
     writeVar(Include.References, OS);
   };
-  // There are almost certainly few includes, so we can just write them.
-  if (LLVM_LIKELY(Sym.IncludeHeaders.size() <= MaxIncludes)) {
-    writeVar(Sym.IncludeHeaders.size(), OS);
-    for (const auto &Include : Sym.IncludeHeaders)
-      WriteInclude(Include);
-  } else {
-    // If there are too many, make sure we truncate the least important.
-    using Pointer = const Symbol::IncludeHeaderWithReferences *;
-    std::vector<Pointer> Pointers;
-    for (const auto &Include : Sym.IncludeHeaders)
-      Pointers.push_back(&Include);
-    std::sort(Pointers.begin(), Pointers.end(), [](Pointer L, Pointer R) {
-      return L->References > R->References;
-    });
-    Pointers.resize(MaxIncludes);
-
-    writeVar(MaxIncludes, OS);
-    for (Pointer P : Pointers)
-      WriteInclude(*P);
-  }
-
-  assert(OS.tell() - StartOffset < SymbolSizeBound && "Symbol length unsafe!");
-  (void)StartOffset; // Unused in NDEBUG;
+  writeVar(Sym.IncludeHeaders.size(), OS);
+  for (const auto &Include : Sym.IncludeHeaders)
+    WriteInclude(Include);
 }
 
-Expected<Symbol> readSymbol(StringRef &Data, const StringTableIn &Strings) {
-  // Usually we can skip bounds checks because the buffer is huge.
-  // Near the end of the buffer, this would be unsafe. In this rare case, copy
-  // the data into a bigger buffer so we can again skip the checks.
-  if (LLVM_UNLIKELY(Data.size() < SymbolSizeBound)) {
-    std::string Buf(Data);
-    Buf.resize(SymbolSizeBound);
-    StringRef ExtendedData = Buf;
-    auto Ret = readSymbol(ExtendedData, Strings);
-    unsigned BytesRead = Buf.size() - ExtendedData.size();
-    if (BytesRead > Data.size())
-      return makeError("read past end of data");
-    Data = Data.drop_front(BytesRead);
-    return Ret;
-  }
-
-#define READ_STRING(Field)                                                     \
-  do {                                                                         \
-    auto StringIndex = consumeVar(Data);                                       \
-    if (LLVM_UNLIKELY(StringIndex >= Strings.Strings.size()))                  \
-      return makeError("Bad string index");                                    \
-    Field = Strings.Strings[StringIndex];                                      \
-  } while (0)
-
+// Reads a symbol in the format emitted by writeSymbol(). Malformed or truncated
+// input sets Data.err(); the returned Symbol is then unreliable.
+Symbol readSymbol(Reader &Data, ArrayRef<StringRef> Strings) {
   Symbol Sym;
-  Sym.ID = SymbolID::fromRaw(consume(Data, 20));
-  Sym.SymInfo.Kind = static_cast<index::SymbolKind>(consume8(Data));
-  Sym.SymInfo.Lang = static_cast<index::SymbolLanguage>(consume8(Data));
-  READ_STRING(Sym.Name);
-  READ_STRING(Sym.Scope);
-  for (SymbolLocation *Loc : {&Sym.Definition, &Sym.CanonicalDeclaration}) {
-    READ_STRING(Loc->FileURI);
-    for (auto &Endpoint : {&Loc->Start, &Loc->End}) {
-      Endpoint->Line = consumeVar(Data);
-      Endpoint->Column = consumeVar(Data);
-    }
-  }
-  Sym.References = consumeVar(Data);
-  Sym.Flags = static_cast<Symbol::SymbolFlag>(consume8(Data));
-  Sym.Origin = static_cast<SymbolOrigin>(consume8(Data));
-  READ_STRING(Sym.Signature);
-  READ_STRING(Sym.CompletionSnippetSuffix);
-  READ_STRING(Sym.Documentation);
-  READ_STRING(Sym.ReturnType);
-  unsigned IncludeHeaderN = consumeVar(Data);
-  if (IncludeHeaderN > MaxIncludes)
-    return makeError("too many IncludeHeaders");
-  Sym.IncludeHeaders.resize(IncludeHeaderN);
+  Sym.ID = Data.consumeID();
+  Sym.SymInfo.Kind = static_cast<index::SymbolKind>(Data.consume8());
+  Sym.SymInfo.Lang = static_cast<index::SymbolLanguage>(Data.consume8());
+  Sym.Name = Data.consumeString(Strings);
+  Sym.Scope = Data.consumeString(Strings);
+  Sym.Definition = readLocation(Data, Strings);
+  Sym.CanonicalDeclaration = readLocation(Data, Strings);
+  Sym.References = Data.consumeVar();
+  // writeSymbol() emits Flags and Origin as single bytes, not varints.
+  Sym.Flags = static_cast<Symbol::SymbolFlag>(Data.consume8());
+  Sym.Origin = static_cast<SymbolOrigin>(Data.consume8());
+  Sym.Signature = Data.consumeString(Strings);
+  Sym.CompletionSnippetSuffix = Data.consumeString(Strings);
+  Sym.Documentation = Data.consumeString(Strings);
+  Sym.ReturnType = Data.consumeString(Strings);
+  uint32_t NumIncludes = Data.consumeVar();
+  // A valid include takes >= 2 bytes, so clamp corrupt counts instead of
+  // allocating billions of entries; reading the excess then sets err().
+  Sym.IncludeHeaders.resize(
+      std::min<size_t>(NumIncludes, Data.rest().size() + 1));
   for (auto &I : Sym.IncludeHeaders) {
-    READ_STRING(I.IncludeHeader);
-    I.References = consumeVar(Data);
+    I.IncludeHeader = Data.consumeString(Strings);
+    I.References = Data.consumeVar();
   }
-
-#undef READ_STRING
-  return std::move(Sym);
+  return Sym;
 }
 
 } // namespace
@@ -306,7 +306,7 @@ Expected<Symbol> readSymbol(StringRef &D
 // The current versioning scheme is simple - non-current versions are rejected.
 // If you make a breaking change, bump this version number to invalidate stored
 // data. Later we may want to support some backward compatibility.
-constexpr static uint32_t Version = 3;
+constexpr static uint32_t Version = 4;
 
 Expected<IndexFileIn> readIndexFile(StringRef Data) {
   auto RIFF = riff::readFile(Data);
@@ -322,8 +322,8 @@ Expected<IndexFileIn> readIndexFile(Stri
     if (!Chunks.count(RequiredChunk))
       return makeError("missing required chunk " + RequiredChunk);
 
-  StringRef Meta = Chunks.lookup("meta");
-  if (Meta.size() < 4 || consume32(Meta) != Version)
+  Reader Meta(Chunks.lookup("meta"));
+  if (Meta.consume32() != Version) // Truncated chunk reads 0: also rejected.
     return makeError("wrong version");
 
   auto Strings = readStringTable(Chunks.lookup("stri"));
@@ -332,13 +332,12 @@ Expected<IndexFileIn> readIndexFile(Stri
 
   IndexFileIn Result;
   if (Chunks.count("symb")) {
-    StringRef SymbolData = Chunks.lookup("symb");
+    Reader SymbolReader(Chunks.lookup("symb"));
     SymbolSlab::Builder Symbols;
-    while (!SymbolData.empty())
-      if (auto Sym = readSymbol(SymbolData, *Strings))
-        Symbols.insert(*Sym);
-      else
-        return Sym.takeError();
+    while (!SymbolReader.eof()) // Also stops as soon as err() is set.
+      Symbols.insert(readSymbol(SymbolReader, Strings->Strings));
+    if (SymbolReader.err()) // Truncated data or a bad string index.
+      return makeError("malformed or truncated symbol");
     Result.Symbols = std::move(Symbols).build();
   }
   return std::move(Result);




More information about the cfe-commits mailing list