[clang-tools-extra] 3c09103 - [clangd] Sanity-check array sizes read from disk before allocating them.
Sam McCall via cfe-commits
cfe-commits at lists.llvm.org
Wed Nov 11 14:17:04 PST 2020
Author: Sam McCall
Date: 2020-11-11T23:16:53+01:00
New Revision: 3c09103291686630564c1ff3f78c0f8dc69d069f
URL: https://github.com/llvm/llvm-project/commit/3c09103291686630564c1ff3f78c0f8dc69d069f
DIFF: https://github.com/llvm/llvm-project/commit/3c09103291686630564c1ff3f78c0f8dc69d069f.diff
LOG: [clangd] Sanity-check array sizes read from disk before allocating them.
Previously a corrupted index shard could cause us to resize arrays to an
arbitrary int32. This tends to be a huge number, and can render the
system unresponsive.
Instead, cap this at the amount of data that might reasonably be read
(e.g. the #bytes in the file). If the specified length is more than that,
assume the data is corrupt.
Differential Revision: https://reviews.llvm.org/D91258
Added:
Modified:
clang-tools-extra/clangd/index/Serialization.cpp
clang-tools-extra/clangd/unittests/SerializationTests.cpp
Removed:
################################################################################
diff --git a/clang-tools-extra/clangd/index/Serialization.cpp b/clang-tools-extra/clangd/index/Serialization.cpp
index dcea8e902fe4..a817758d7a54 100644
--- a/clang-tools-extra/clangd/index/Serialization.cpp
+++ b/clang-tools-extra/clangd/index/Serialization.cpp
@@ -16,6 +16,7 @@
#include "support/Trace.h"
#include "clang/Tooling/CompilationDatabase.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Compiler.h"
#include "llvm/Support/Compression.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Error.h"
@@ -104,6 +105,20 @@ class Reader {
llvm::StringRef Raw = consume(SymbolID::RawSize); // short if truncated.
return LLVM_UNLIKELY(err()) ? SymbolID() : SymbolID::fromRaw(Raw);
}
+
+  // Decode a varint element count (via consumeVar) and size Container to it.
+  // A count exceeding the number of unread bytes can only come from corrupt
+  // data: Container is left untouched, Err is raised, and false is returned.
+  template <typename T> LLVM_NODISCARD bool consumeSize(T &Container) {
+    auto Count = consumeVar();
+    // Each element needs at least one byte, so this bound never rejects valid data.
+    const bool Fits = !(Count > (End - Begin));
+    if (Fits)
+      Container.resize(Count);
+    else
+      Err = true;
+    return Fits; // Callers must stop decoding on false.
+  }
};
void write32(uint32_t I, llvm::raw_ostream &OS) {
@@ -257,7 +272,8 @@ IncludeGraphNode readIncludeGraphNode(Reader &Data,
IGN.URI = Data.consumeString(Strings);
llvm::StringRef Digest = Data.consume(IGN.Digest.size());
std::copy(Digest.bytes_begin(), Digest.bytes_end(), IGN.Digest.begin());
- IGN.DirectIncludes.resize(Data.consumeVar());
+ if (!Data.consumeSize(IGN.DirectIncludes))
+ return IGN;
for (llvm::StringRef &Include : IGN.DirectIncludes)
Include = Data.consumeString(Strings);
return IGN;
@@ -323,7 +339,8 @@ Symbol readSymbol(Reader &Data, llvm::ArrayRef<llvm::StringRef> Strings) {
Sym.Documentation = Data.consumeString(Strings);
Sym.ReturnType = Data.consumeString(Strings);
Sym.Type = Data.consumeString(Strings);
- Sym.IncludeHeaders.resize(Data.consumeVar());
+ if (!Data.consumeSize(Sym.IncludeHeaders))
+ return Sym;
for (auto &I : Sym.IncludeHeaders) {
I.IncludeHeader = Data.consumeString(Strings);
I.References = Data.consumeVar();
@@ -353,7 +370,8 @@ std::pair<SymbolID, std::vector<Ref>>
readRefs(Reader &Data, llvm::ArrayRef<llvm::StringRef> Strings) {
std::pair<SymbolID, std::vector<Ref>> Result;
Result.first = Data.consumeID();
- Result.second.resize(Data.consumeVar());
+ if (!Data.consumeSize(Result.second))
+ return Result;
for (auto &Ref : Result.second) {
Ref.Kind = static_cast<RefKind>(Data.consume8());
Ref.Location = readLocation(Data, Strings);
@@ -400,7 +418,8 @@ InternedCompileCommand
readCompileCommand(Reader CmdReader, llvm::ArrayRef<llvm::StringRef> Strings) {
InternedCompileCommand Cmd;
Cmd.Directory = CmdReader.consumeString(Strings);
- Cmd.CommandLine.resize(CmdReader.consumeVar());
+ if (!CmdReader.consumeSize(Cmd.CommandLine))
+ return Cmd;
for (llvm::StringRef &C : Cmd.CommandLine)
C = CmdReader.consumeString(Strings);
return Cmd;
diff --git a/clang-tools-extra/clangd/unittests/SerializationTests.cpp b/clang-tools-extra/clangd/unittests/SerializationTests.cpp
index 1d2c1db1ee98..94db6c9127a8 100644
--- a/clang-tools-extra/clangd/unittests/SerializationTests.cpp
+++ b/clang-tools-extra/clangd/unittests/SerializationTests.cpp
@@ -7,15 +7,21 @@
//===----------------------------------------------------------------------===//
#include "Headers.h"
+#include "RIFF.h"
#include "index/Index.h"
#include "index/Serialization.h"
+#include "support/Logger.h"
#include "clang/Tooling/CompilationDatabase.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/ScopedPrinter.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#ifdef LLVM_ON_UNIX
+#include <sys/resource.h>
+#endif
-using ::testing::_;
-using ::testing::AllOf;
using ::testing::ElementsAre;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;
@@ -297,6 +303,86 @@ TEST(SerializationTest, CmdlTest) {
EXPECT_NE(SerializedCmd.Output, Cmd.Output);
}
}
+
+#ifdef LLVM_ON_UNIX // rlimit is part of POSIX
+// Lowers the soft RLIMIT_AS limit for the enclosing scope and restores the
+// original limit on destruction (only if it was actually changed).
+class ScopedMemoryLimit {
+  struct rlimit OriginalLimit{};
+  bool Succeeded = false; // True iff setrlimit applied our new limit.
+
+public:
+  explicit ScopedMemoryLimit(rlim_t Bytes) {
+    if (!getrlimit(RLIMIT_AS, &OriginalLimit)) {
+      struct rlimit NewLimit = OriginalLimit;
+      NewLimit.rlim_cur = Bytes;
+      Succeeded = !setrlimit(RLIMIT_AS, &NewLimit);
+    }
+    if (!Succeeded)
+      log("Failed to set rlimit");
+  }
+  // Non-copyable: a copy would restore the original limit a second time.
+  ScopedMemoryLimit(const ScopedMemoryLimit &) = delete;
+  ScopedMemoryLimit &operator=(const ScopedMemoryLimit &) = delete;
+
+  ~ScopedMemoryLimit() {
+    if (Succeeded)
+      setrlimit(RLIMIT_AS, &OriginalLimit);
+  }
+};
+#else
+// Fallback for platforms without setrlimit: logs and enforces no limit.
+class ScopedMemoryLimit {
+public:
+  explicit ScopedMemoryLimit(unsigned /*Bytes*/) { log("rlimit unsupported"); }
+};
+#endif
+
+// Test that our deserialization detects invalid array sizes without allocating.
+// If this detection fails, the test should allocate a huge array and crash.
+TEST(SerializationTest, NoCrashOnBadArraySize) {
+  // This test is tricky because we need to construct a subtly invalid file.
+  // First, create a valid serialized file.
+  auto In = readIndexFile(YAML);
+  ASSERT_FALSE(!In) << In.takeError();
+  IndexFileOut Out(*In);
+  Out.Format = IndexFileFormat::RIFF;
+  std::string Serialized = llvm::to_string(Out);
+
+  // Low-level parse it again and find the `srcs` chunk we're going to corrupt.
+  auto Parsed = riff::readFile(Serialized);
+  ASSERT_FALSE(!Parsed) << Parsed.takeError();
+  auto Srcs = llvm::find_if(Parsed->Chunks, [](const riff::Chunk &C) {
+    return C.ID == riff::fourCC("srcs");
+  });
+  ASSERT_NE(Srcs, Parsed->Chunks.end());
+
+  // Srcs consists of a sequence of IncludeGraphNodes. In our case, just one.
+  // The node has:
+  //  - 1 byte: flags (1)
+  //  - varint(stringID): URI
+  //  - 8 byte: file digest
+  //  - varint: DirectIncludes.length
+  //  - repeated varint(stringID): DirectIncludes
+  // We want to set DirectIncludes.length to a huge number.
+  // The offset isn't trivial to find, so we search for the digest substring.
+  std::string FileDigest = llvm::fromHex("EED8F5EAF25C453C");
+  size_t Pos = Srcs->Data.find(FileDigest); // size_t so npos survives.
+  ASSERT_NE(Pos, StringRef::npos) << "Couldn't locate file digest";
+  Pos += FileDigest.size();
+
+  // Varints are little-endian base-128 numbers, where the top-bit of each byte
+  // indicates whether there are more. ffffffff0f -> 0xffffffff.
+  std::string CorruptSrcs =
+      (Srcs->Data.take_front(Pos) + llvm::fromHex("ffffffff0f") +
+       "some_random_garbage")
+          .str();
+  Srcs->Data = CorruptSrcs;
+
+  // Try to crash rather than hang on large allocation.
+  ScopedMemoryLimit MemLimit(1000 * 1024 * 1024); // 1GB
+
+  std::string CorruptFile = llvm::to_string(*Parsed);
+  auto CorruptParsed = readIndexFile(CorruptFile);
+  ASSERT_TRUE(!CorruptParsed);
+  EXPECT_EQ(llvm::toString(CorruptParsed.takeError()),
+            "malformed or truncated include uri");
+}
+
} // namespace
} // namespace clangd
} // namespace clang
More information about the cfe-commits
mailing list