[clang-tools-extra] r342888 - [clangd] Do bounds checks while reading data, otherwise var-length records are too painful. NFC
Sam McCall via cfe-commits
cfe-commits at lists.llvm.org
Mon Sep 24 07:51:15 PDT 2018
Author: sammccall
Date: Mon Sep 24 07:51:15 2018
New Revision: 342888
URL: http://llvm.org/viewvc/llvm-project?rev=342888&view=rev
Log:
[clangd] Do bounds checks while reading data, otherwise var-length records are too painful. NFC
Modified:
clang-tools-extra/trunk/clangd/index/Serialization.cpp
Modified: clang-tools-extra/trunk/clangd/index/Serialization.cpp
URL: http://llvm.org/viewvc/llvm-project/clang-tools-extra/trunk/clangd/index/Serialization.cpp?rev=342888&r1=342887&r2=342888&view=diff
==============================================================================
--- clang-tools-extra/trunk/clangd/index/Serialization.cpp (original)
+++ clang-tools-extra/trunk/clangd/index/Serialization.cpp Mon Sep 24 07:51:15 2018
@@ -23,24 +23,83 @@ Error makeError(const Twine &Msg) {
// IO PRIMITIVES
// We use little-endian 32 bit ints, sometimes with variable-length encoding.
+//
+// Variable-length int encoding (varint) uses the bottom 7 bits of each byte
+// to encode the number, and the top bit to indicate whether more bytes follow.
+// e.g. 9a 2f means [0x1a and keep reading, 0x2f and stop].
+// This represents 0x1a | 0x2f<<7 = 6042.
+// A 32-bit integer takes 1-5 bytes to encode; small numbers are more compact.
-StringRef consume(StringRef &Data, int N) {
- StringRef Ret = Data.take_front(N);
- Data = Data.drop_front(N);
- return Ret;
-}
+// Reads binary data from a StringRef, and keeps track of position.
+// Every read is bounds-checked: a read that would run past the end of the
+// buffer, or that decodes an invalid value, sets a sticky error bit and
+// returns a zero/empty value instead.
+class Reader {
+  const char *Begin, *End;
+  // Must start out false: err() and eof() read it before any read has failed.
+  bool Err = false;
-uint8_t consume8(StringRef &Data) {
- uint8_t Ret = Data.front();
- Data = Data.drop_front();
- return Ret;
-}
+
+  // Number of bytes not consumed yet. Comparing sizes (rather than computing
+  // Begin + N) avoids forming a pointer beyond the end of the buffer.
+  size_t remaining() const { return static_cast<size_t>(End - Begin); }
+
+public:
+  Reader(StringRef Data) : Begin(Data.begin()), End(Data.end()) {}
+  // The "error" bit is set by reading past EOF or reading invalid data.
+  // When in an error state, reads may return zero values: callers should check.
+  bool err() const { return Err; }
+  // Did we read all the data, or encounter an error?
+  bool eof() const { return Begin == End || Err; }
+  // All the data we didn't read yet.
+  StringRef rest() const { return StringRef(Begin, End - Begin); }
+
+  // Reads one byte. At EOF, sets the error bit and returns 0.
+  uint8_t consume8() {
+    if (LLVM_UNLIKELY(Begin == End)) {
+      Err = true;
+      return 0;
+    }
+    return *Begin++;
+  }
-uint32_t consume32(StringRef &Data) {
- auto Ret = support::endian::read32le(Data.bytes_begin());
- Data = Data.drop_front(4);
- return Ret;
-}
+  // Reads a little-endian 32-bit integer.
+  // If fewer than 4 bytes are left, sets the error bit and returns 0 without
+  // consuming anything.
+  uint32_t consume32() {
+    if (LLVM_UNLIKELY(remaining() < 4)) {
+      Err = true;
+      return 0;
+    }
+    auto Ret = support::endian::read32le(Begin);
+    Begin += 4;
+    return Ret;
+  }
+
+  // Reads the next N bytes as a StringRef pointing into the underlying data.
+  // If N is negative or exceeds the bytes left, sets the error bit and returns
+  // an empty StringRef without consuming anything.
+  StringRef consume(int N) {
+    if (LLVM_UNLIKELY(N < 0 || static_cast<size_t>(N) > remaining())) {
+      Err = true;
+      return StringRef();
+    }
+    StringRef Ret(Begin, N);
+    Begin += N;
+    return Ret;
+  }
+
+  // Reads a varint: 7 payload bits per byte, top bit set while more bytes
+  // follow. At most 5 bytes are read. Running out of data sets the error bit
+  // (via consume8) and ends the loop, since the 0 it returns has no "more" bit.
+  uint32_t consumeVar() {
+    constexpr static uint8_t More = 1 << 7;
+    uint8_t B = consume8();
+    if (LLVM_LIKELY(!(B & More)))
+      return B;
+    uint32_t Val = B & ~More;
+    for (int Shift = 7; B & More && Shift < 32; Shift += 7) {
+      B = consume8();
+      // Shift in unsigned arithmetic: a 7-bit payload shifted by 28 does not
+      // fit in a signed int, and malformed input can contain such a byte.
+      Val |= static_cast<uint32_t>(B & ~More) << Shift;
+    }
+    return Val;
+  }
+
+  // Reads a varint index and returns the corresponding entry of Strings.
+  // An out-of-range index sets the error bit and yields an empty StringRef.
+  StringRef consumeString(ArrayRef<StringRef> Strings) {
+    auto StringIndex = consumeVar();
+    if (LLVM_UNLIKELY(StringIndex >= Strings.size())) {
+      Err = true;
+      return StringRef();
+    }
+    return Strings[StringIndex];
+  }
+
+  // Reads SymbolID::RawSize bytes and decodes them as a SymbolID.
+  // On error a default-constructed SymbolID is returned.
+  SymbolID consumeID() {
+    StringRef Raw = consume(SymbolID::RawSize); // empty if truncated.
+    return LLVM_UNLIKELY(err()) ? SymbolID() : SymbolID::fromRaw(Raw);
+  }
+};
void write32(uint32_t I, raw_ostream &OS) {
char buf[4];
@@ -48,11 +107,6 @@ void write32(uint32_t I, raw_ostream &OS
OS.write(buf, sizeof(buf));
}
-// Variable-length int encoding (varint) uses the bottom 7 bits of each byte
-// to encode the number, and the top bit to indicate whether more bytes follow.
-// e.g. 9a 2f means [0x1a and keep reading, 0x2f and stop].
-// This represents 0x1a | 0x2f<<7 = 6042.
-// A 32-bit integer takes 1-5 bytes to encode; small numbers are more compact.
void writeVar(uint32_t I, raw_ostream &OS) {
constexpr static uint8_t More = 1 << 7;
if (LLVM_LIKELY(I < 1 << 7)) {
@@ -69,19 +123,6 @@ void writeVar(uint32_t I, raw_ostream &O
}
}
-uint32_t consumeVar(StringRef &Data) {
- constexpr static uint8_t More = 1 << 7;
- uint8_t B = consume8(Data);
- if (LLVM_LIKELY(!(B & More)))
- return B;
- uint32_t Val = B & ~More;
- for (int Shift = 7; B & More && Shift < 32; Shift += 7) {
- B = consume8(Data);
- Val |= (B & ~More) << Shift;
- }
- return Val;
-}
-
// STRING TABLE ENCODING
// Index data has many string fields, and many strings are identical.
// We store each string once, and refer to them by index.
@@ -146,30 +187,34 @@ struct StringTableIn {
};
Expected<StringTableIn> readStringTable(StringRef Data) {
- if (Data.size() < 4)
- return makeError("Bad string table: not enough metadata");
- size_t UncompressedSize = consume32(Data);
+ Reader R(Data);
+ size_t UncompressedSize = R.consume32();
+ if (R.err())
+ return makeError("Truncated string table");
StringRef Uncompressed;
SmallString<1> UncompressedStorage;
if (UncompressedSize == 0) // No compression
- Uncompressed = Data;
+ Uncompressed = R.rest();
else {
- if (Error E =
- llvm::zlib::uncompress(Data, UncompressedStorage, UncompressedSize))
+ if (Error E = llvm::zlib::uncompress(R.rest(), UncompressedStorage,
+ UncompressedSize))
return std::move(E);
Uncompressed = UncompressedStorage;
}
StringTableIn Table;
StringSaver Saver(Table.Arena);
- for (StringRef Rest = Uncompressed; !Rest.empty();) {
- auto Len = Rest.find(0);
+ R = Reader(Uncompressed);
+ for (Reader R(Uncompressed); !R.eof();) {
+ auto Len = R.rest().find(0);
if (Len == StringRef::npos)
return makeError("Bad string table: not null terminated");
- Table.Strings.push_back(Saver.save(consume(Rest, Len)));
- Rest = Rest.drop_front();
+ Table.Strings.push_back(Saver.save(R.consume(Len)));
+ R.consume8();
}
+ if (R.err())
+ return makeError("Truncated string table");
return std::move(Table);
}
@@ -179,27 +224,35 @@ Expected<StringTableIn> readStringTable(
// - enums encode as the underlying type
// - most numbers encode as varint
-// It's useful to the implementation to assume symbols have a bounded size.
-constexpr size_t SymbolSizeBound = 512;
-// To ensure the bounded size, restrict the number of include headers stored.
-constexpr unsigned MaxIncludes = 50;
+// Serializes a location as five varints: the string-table index of its file
+// URI, then line and column of the start, then line and column of the end.
+void writeLocation(const SymbolLocation &Loc, const StringTableOut &Strings,
+                   raw_ostream &OS) {
+  writeVar(Strings.index(Loc.FileURI), OS);
+  // Start position.
+  writeVar(Loc.Start.Line, OS);
+  writeVar(Loc.Start.Column, OS);
+  // End position.
+  writeVar(Loc.End.Line, OS);
+  writeVar(Loc.End.Column, OS);
+}
+
+// Deserializes a location: a string-table reference for the file URI, then
+// line and column of the start, then line and column of the end.
+// Read failures are recorded on Data; the caller is expected to check them.
+SymbolLocation readLocation(Reader &Data, ArrayRef<StringRef> Strings) {
+  SymbolLocation Result;
+  Result.FileURI = Data.consumeString(Strings);
+  // Start position.
+  Result.Start.Line = Data.consumeVar();
+  Result.Start.Column = Data.consumeVar();
+  // End position.
+  Result.End.Line = Data.consumeVar();
+  Result.End.Column = Data.consumeVar();
+  return Result;
+}
void writeSymbol(const Symbol &Sym, const StringTableOut &Strings,
raw_ostream &OS) {
- auto StartOffset = OS.tell();
OS << Sym.ID.raw(); // TODO: once we start writing xrefs and posting lists,
// symbol IDs should probably be in a string table.
OS.write(static_cast<uint8_t>(Sym.SymInfo.Kind));
OS.write(static_cast<uint8_t>(Sym.SymInfo.Lang));
writeVar(Strings.index(Sym.Name), OS);
writeVar(Strings.index(Sym.Scope), OS);
- for (const auto &Loc : {Sym.Definition, Sym.CanonicalDeclaration}) {
- writeVar(Strings.index(Loc.FileURI), OS);
- for (const auto &Endpoint : {Loc.Start, Loc.End}) {
- writeVar(Endpoint.Line, OS);
- writeVar(Endpoint.Column, OS);
- }
- }
+ writeLocation(Sym.Definition, Strings, OS);
+ writeLocation(Sym.CanonicalDeclaration, Strings, OS);
writeVar(Sym.References, OS);
OS.write(static_cast<uint8_t>(Sym.Flags));
OS.write(static_cast<uint8_t>(Sym.Origin));
@@ -212,86 +265,33 @@ void writeSymbol(const Symbol &Sym, cons
writeVar(Strings.index(Include.IncludeHeader), OS);
writeVar(Include.References, OS);
};
- // There are almost certainly few includes, so we can just write them.
- if (LLVM_LIKELY(Sym.IncludeHeaders.size() <= MaxIncludes)) {
- writeVar(Sym.IncludeHeaders.size(), OS);
- for (const auto &Include : Sym.IncludeHeaders)
- WriteInclude(Include);
- } else {
- // If there are too many, make sure we truncate the least important.
- using Pointer = const Symbol::IncludeHeaderWithReferences *;
- std::vector<Pointer> Pointers;
- for (const auto &Include : Sym.IncludeHeaders)
- Pointers.push_back(&Include);
- std::sort(Pointers.begin(), Pointers.end(), [](Pointer L, Pointer R) {
- return L->References > R->References;
- });
- Pointers.resize(MaxIncludes);
-
- writeVar(MaxIncludes, OS);
- for (Pointer P : Pointers)
- WriteInclude(*P);
- }
-
- assert(OS.tell() - StartOffset < SymbolSizeBound && "Symbol length unsafe!");
- (void)StartOffset; // Unused in NDEBUG;
+ writeVar(Sym.IncludeHeaders.size(), OS);
+ for (const auto &Include : Sym.IncludeHeaders)
+ WriteInclude(Include);
}
-Expected<Symbol> readSymbol(StringRef &Data, const StringTableIn &Strings) {
- // Usually we can skip bounds checks because the buffer is huge.
- // Near the end of the buffer, this would be unsafe. In this rare case, copy
- // the data into a bigger buffer so we can again skip the checks.
- if (LLVM_UNLIKELY(Data.size() < SymbolSizeBound)) {
- std::string Buf(Data);
- Buf.resize(SymbolSizeBound);
- StringRef ExtendedData = Buf;
- auto Ret = readSymbol(ExtendedData, Strings);
- unsigned BytesRead = Buf.size() - ExtendedData.size();
- if (BytesRead > Data.size())
- return makeError("read past end of data");
- Data = Data.drop_front(BytesRead);
- return Ret;
- }
-
-#define READ_STRING(Field) \
- do { \
- auto StringIndex = consumeVar(Data); \
- if (LLVM_UNLIKELY(StringIndex >= Strings.Strings.size())) \
- return makeError("Bad string index"); \
- Field = Strings.Strings[StringIndex]; \
- } while (0)
-
+// Decodes one symbol from Data, resolving string references through Strings.
+// Errors are not reported through the return value: a truncated or malformed
+// record sets the error bit on Data and leaves the affected fields with
+// zero/empty values, so callers must check Data.err() before trusting the
+// result.
+Symbol readSymbol(Reader &Data, ArrayRef<StringRef> Strings) {
Symbol Sym;
- Sym.ID = SymbolID::fromRaw(consume(Data, 20));
- Sym.SymInfo.Kind = static_cast<index::SymbolKind>(consume8(Data));
- Sym.SymInfo.Lang = static_cast<index::SymbolLanguage>(consume8(Data));
- READ_STRING(Sym.Name);
- READ_STRING(Sym.Scope);
- for (SymbolLocation *Loc : {&Sym.Definition, &Sym.CanonicalDeclaration}) {
- READ_STRING(Loc->FileURI);
- for (auto &Endpoint : {&Loc->Start, &Loc->End}) {
- Endpoint->Line = consumeVar(Data);
- Endpoint->Column = consumeVar(Data);
- }
- }
- Sym.References = consumeVar(Data);
- Sym.Flags = static_cast<Symbol::SymbolFlag>(consume8(Data));
- Sym.Origin = static_cast<SymbolOrigin>(consume8(Data));
- READ_STRING(Sym.Signature);
- READ_STRING(Sym.CompletionSnippetSuffix);
- READ_STRING(Sym.Documentation);
- READ_STRING(Sym.ReturnType);
- unsigned IncludeHeaderN = consumeVar(Data);
- if (IncludeHeaderN > MaxIncludes)
- return makeError("too many IncludeHeaders");
- Sym.IncludeHeaders.resize(IncludeHeaderN);
+  Sym.ID = Data.consumeID();
+  Sym.SymInfo.Kind = static_cast<index::SymbolKind>(Data.consume8());
+  Sym.SymInfo.Lang = static_cast<index::SymbolLanguage>(Data.consume8());
+  Sym.Name = Data.consumeString(Strings);
+  Sym.Scope = Data.consumeString(Strings);
+  Sym.Definition = readLocation(Data, Strings);
+  Sym.CanonicalDeclaration = readLocation(Data, Strings);
+  Sym.References = Data.consumeVar();
+  // writeSymbol emits Flags and Origin as single raw bytes, so they must be
+  // read as bytes too. Decoding them as varints would treat a value with the
+  // top bit set as "more bytes follow" and desynchronize the stream.
+  Sym.Flags = static_cast<Symbol::SymbolFlag>(Data.consume8());
+  Sym.Origin = static_cast<SymbolOrigin>(Data.consume8());
+  Sym.Signature = Data.consumeString(Strings);
+  Sym.CompletionSnippetSuffix = Data.consumeString(Strings);
+  Sym.Documentation = Data.consumeString(Strings);
+  Sym.ReturnType = Data.consumeString(Strings);
+  // The count comes from the input, so don't allocate more entries than the
+  // remaining bytes could possibly describe. Each entry consumes at least two
+  // bytes (two varints), so the cap below still holds one entry more than
+  // fits: reading a capped list is guaranteed to hit EOF and set the error bit.
+  size_t IncludeHeaderN = Data.consumeVar();
+  size_t IncludeHeaderCap = Data.rest().size() / 2 + 1;
+  Sym.IncludeHeaders.resize(IncludeHeaderN < IncludeHeaderCap
+                                ? IncludeHeaderN
+                                : IncludeHeaderCap);
for (auto &I : Sym.IncludeHeaders) {
- READ_STRING(I.IncludeHeader);
- I.References = consumeVar(Data);
+    I.IncludeHeader = Data.consumeString(Strings);
+    I.References = Data.consumeVar();
}
-
-#undef READ_STRING
- return std::move(Sym);
+  return Sym;
}
} // namespace
@@ -306,7 +306,7 @@ Expected<Symbol> readSymbol(StringRef &D
// The current versioning scheme is simple - non-current versions are rejected.
// If you make a breaking change, bump this version number to invalidate stored
// data. Later we may want to support some backward compatibility.
-constexpr static uint32_t Version = 3;
+constexpr static uint32_t Version = 4;
Expected<IndexFileIn> readIndexFile(StringRef Data) {
auto RIFF = riff::readFile(Data);
@@ -322,8 +322,8 @@ Expected<IndexFileIn> readIndexFile(Stri
if (!Chunks.count(RequiredChunk))
return makeError("missing required chunk " + RequiredChunk);
- StringRef Meta = Chunks.lookup("meta");
- if (Meta.size() < 4 || consume32(Meta) != Version)
+ Reader Meta(Chunks.lookup("meta"));
+ if (Meta.consume32() != Version)
return makeError("wrong version");
auto Strings = readStringTable(Chunks.lookup("stri"));
@@ -332,13 +332,12 @@ Expected<IndexFileIn> readIndexFile(Stri
IndexFileIn Result;
if (Chunks.count("symb")) {
- StringRef SymbolData = Chunks.lookup("symb");
+ Reader SymbolReader(Chunks.lookup("symb"));
SymbolSlab::Builder Symbols;
- while (!SymbolData.empty())
- if (auto Sym = readSymbol(SymbolData, *Strings))
- Symbols.insert(*Sym);
- else
- return Sym.takeError();
+ while (!SymbolReader.eof())
+ Symbols.insert(readSymbol(SymbolReader, Strings->Strings));
+ if (SymbolReader.err())
+ return makeError("malformed or truncated symbol");
Result.Symbols = std::move(Symbols).build();
}
return std::move(Result);
More information about the cfe-commits
mailing list