[llvm] [memprof] Add access checks to PortableMemInfoBlock::get* (PR #90121)

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 26 01:17:56 PDT 2024


https://github.com/kazutakahirata updated https://github.com/llvm/llvm-project/pull/90121

From 472f06baac53e105cb9af860507e5935a698e510 Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Wed, 24 Apr 2024 12:54:04 -0700
Subject: [PATCH 1/2] [memprof] Add access checks to PortableMemInfoBlock::get*

  commit 4c8ec8f8bc3fb4dda4fd36c3b2ad745bd3451970
  Author: Kazu Hirata <kazu at google.com>
  Date:   Wed Apr 24 16:25:35 2024 -0700

introduced the idea of serializing/deserializing a subset of the
fields in PortableMemInfoBlock.  While this reduces the size of the
indexed MemProf profile file, we could now inadvertently access
unavailable fields without noticing.

To guard against this risk, this patch adds access checks to the
PortableMemInfoBlock::get* methods by embedding a bit set of the
available fields in PortableMemInfoBlock.
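
For context, a minimal, self-contained sketch of the pattern being added
(simplified; in the actual patch the fields and getters are generated from
MIBEntryDef.inc, the bit positions come from llvm::to_underlying(Meta::Name),
and the constructor takes a MemProfSchema rather than an initializer_list):

  #include <bitset>
  #include <cassert>
  #include <cstdint>
  #include <initializer_list>

  // Simplified stand-in for the MemProf field enum.
  enum class Meta : uint64_t { AllocCount = 0, TotalSize, Size };

  struct Block {
    // One bit per field, set only for fields present in the incoming schema.
    std::bitset<static_cast<size_t>(Meta::Size)> Schema;
    uint64_t AllocCount = 0;
    uint64_t TotalSize = 0;

    explicit Block(std::initializer_list<Meta> IncomingSchema) {
      for (Meta Id : IncomingSchema)
        Schema.set(static_cast<size_t>(Id));
    }

    // Each getter asserts that its field was part of the schema, so an
    // accidental read of an unavailable field is caught in +Asserts builds.
    uint64_t getAllocCount() const {
      assert(Schema[static_cast<size_t>(Meta::AllocCount)]);
      return AllocCount;
    }
  };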
---
 llvm/include/llvm/ProfileData/MemProf.h      | 34 ++++++++++++++++----
 llvm/unittests/ProfileData/InstrProfTest.cpp |  8 ++---
 llvm/unittests/ProfileData/MemProfTest.cpp   |  2 +-
 3 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/MemProf.h b/llvm/include/llvm/ProfileData/MemProf.h
index d378c3696f8d0b..e59c6b6b02f141 100644
--- a/llvm/include/llvm/ProfileData/MemProf.h
+++ b/llvm/include/llvm/ProfileData/MemProf.h
@@ -2,6 +2,7 @@
 #define LLVM_PROFILEDATA_MEMPROF_H_
 
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/GlobalValue.h"
@@ -10,6 +11,7 @@
 #include "llvm/Support/EndianStream.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include <bitset>
 #include <cstdint>
 #include <optional>
 
@@ -55,7 +57,10 @@ MemProfSchema getHotColdSchema();
 // deserialize methods.
 struct PortableMemInfoBlock {
   PortableMemInfoBlock() = default;
-  explicit PortableMemInfoBlock(const MemInfoBlock &Block) {
+  explicit PortableMemInfoBlock(const MemInfoBlock &Block,
+                                const MemProfSchema &IncomingSchema) {
+    for (const Meta Id : IncomingSchema)
+      Schema.set(llvm::to_underlying(Id));
 #define MIBEntryDef(NameTag, Name, Type) Name = Block.Name;
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
@@ -67,10 +72,12 @@ struct PortableMemInfoBlock {
 
   // Read the contents of \p Ptr based on the \p Schema to populate the
   // MemInfoBlock member.
-  void deserialize(const MemProfSchema &Schema, const unsigned char *Ptr) {
+  void deserialize(const MemProfSchema &IncomingSchema,
+                   const unsigned char *Ptr) {
     using namespace support;
 
-    for (const Meta Id : Schema) {
+    Schema.reset();
+    for (const Meta Id : IncomingSchema) {
       switch (Id) {
 #define MIBEntryDef(NameTag, Name, Type)                                       \
   case Meta::Name: {                                                           \
@@ -82,6 +89,8 @@ struct PortableMemInfoBlock {
         llvm_unreachable("Unknown meta type id, is the profile collected from "
                          "a newer version of the runtime?");
       }
+
+      Schema.set(llvm::to_underlying(Id));
     }
   }
 
@@ -116,15 +125,22 @@ struct PortableMemInfoBlock {
 
   // Define getters for each type which can be called by analyses.
 #define MIBEntryDef(NameTag, Name, Type)                                       \
-  Type get##Name() const { return Name; }
+  Type get##Name() const {                                                     \
+    assert(Schema[llvm::to_underlying(Meta::Name)]);                           \
+    return Name;                                                               \
+  }
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
 
   void clear() { *this = PortableMemInfoBlock(); }
 
   bool operator==(const PortableMemInfoBlock &Other) const {
+    if (Other.Schema != Schema)
+      return false;
+
 #define MIBEntryDef(NameTag, Name, Type)                                       \
-  if (Other.get##Name() != get##Name())                                        \
+  if (Schema[llvm::to_underlying(Meta::Name)] &&                               \
+      Other.get##Name() != get##Name())                                        \
     return false;
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
@@ -155,6 +171,9 @@ struct PortableMemInfoBlock {
   }
 
 private:
+  // The set of available fields, indexed by Meta::Name.
+  std::bitset<llvm::to_underlying(Meta::Size)> Schema;
+
 #define MIBEntryDef(NameTag, Name, Type) Type Name = Type();
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
@@ -296,8 +315,9 @@ struct IndexedAllocationInfo {
 
   IndexedAllocationInfo() = default;
   IndexedAllocationInfo(ArrayRef<FrameId> CS, CallStackId CSId,
-                        const MemInfoBlock &MB)
-      : CallStack(CS.begin(), CS.end()), CSId(CSId), Info(MB) {}
+                        const MemInfoBlock &MB,
+                        const MemProfSchema &Schema = getFullSchema())
+      : CallStack(CS.begin(), CS.end()), CSId(CSId), Info(MB, Schema) {}
 
   // Returns the size in bytes when this allocation info struct is serialized.
   size_t serializedSize(const MemProfSchema &Schema,
diff --git a/llvm/unittests/ProfileData/InstrProfTest.cpp b/llvm/unittests/ProfileData/InstrProfTest.cpp
index edc427dcbc4540..1b0ee6b8cdab98 100644
--- a/llvm/unittests/ProfileData/InstrProfTest.cpp
+++ b/llvm/unittests/ProfileData/InstrProfTest.cpp
@@ -407,13 +407,13 @@ IndexedMemProfRecord makeRecord(
 IndexedMemProfRecord
 makeRecordV2(std::initializer_list<::llvm::memprof::CallStackId> AllocFrames,
              std::initializer_list<::llvm::memprof::CallStackId> CallSiteFrames,
-             const MemInfoBlock &Block) {
+             const MemInfoBlock &Block, const memprof::MemProfSchema &Schema) {
   llvm::memprof::IndexedMemProfRecord MR;
   for (const auto &CSId : AllocFrames)
     // We don't populate IndexedAllocationInfo::CallStack because we use it only
     // in Version0 and Version1.
     MR.AllocSites.emplace_back(::llvm::SmallVector<memprof::FrameId>(), CSId,
-                               Block);
+                               Block, Schema);
   for (const auto &CSId : CallSiteFrames)
     MR.CallSiteIds.push_back(CSId);
   return MR;
@@ -544,7 +544,7 @@ TEST_F(InstrProfTest, test_memprof_v2_full_schema) {
 
   const IndexedMemProfRecord IndexedMR = makeRecordV2(
       /*AllocFrames=*/{0x111, 0x222},
-      /*CallSiteFrames=*/{0x333}, MIB);
+      /*CallSiteFrames=*/{0x333}, MIB, memprof::getFullSchema());
   const FrameIdMapTy IdToFrameMap = getFrameMapping();
   const auto CSIdToCallStackMap = getCallStackMapping();
   for (const auto &I : IdToFrameMap) {
@@ -584,7 +584,7 @@ TEST_F(InstrProfTest, test_memprof_v2_partial_schema) {
 
   const IndexedMemProfRecord IndexedMR = makeRecordV2(
       /*AllocFrames=*/{0x111, 0x222},
-      /*CallSiteFrames=*/{0x333}, MIB);
+      /*CallSiteFrames=*/{0x333}, MIB, memprof::getHotColdSchema());
   const FrameIdMapTy IdToFrameMap = getFrameMapping();
   const auto CSIdToCallStackMap = getCallStackMapping();
   for (const auto &I : IdToFrameMap) {
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index 503901094ba9a5..2e881adcefcec5 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -241,7 +241,7 @@ TEST(MemProf, PortableWrapper) {
                     /*dealloc_cpu=*/4);
 
   const auto Schema = llvm::memprof::getFullSchema();
-  PortableMemInfoBlock WriteBlock(Info);
+  PortableMemInfoBlock WriteBlock(Info, Schema);
 
   std::string Buffer;
   llvm::raw_string_ostream OS(Buffer);

From 193ce8a7694b8c098116201c22829d472115f3af Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Thu, 25 Apr 2024 15:19:03 -0700
Subject: [PATCH 2/2] Add unit tests.

---
 llvm/include/llvm/ProfileData/MemProf.h    |  5 ++
 llvm/unittests/ProfileData/MemProfTest.cpp | 60 ++++++++++++++++++++++
 2 files changed, 65 insertions(+)

diff --git a/llvm/include/llvm/ProfileData/MemProf.h b/llvm/include/llvm/ProfileData/MemProf.h
index e59c6b6b02f141..9f296b1ceee71c 100644
--- a/llvm/include/llvm/ProfileData/MemProf.h
+++ b/llvm/include/llvm/ProfileData/MemProf.h
@@ -123,6 +123,11 @@ struct PortableMemInfoBlock {
 #undef MIBEntryDef
   }
 
+  // Return the schema, only for unit tests.
+  std::bitset<llvm::to_underlying(Meta::Size)> getSchema() const {
+    return Schema;
+  }
+
   // Define getters for each type which can be called by analyses.
 #define MIBEntryDef(NameTag, Name, Type)                                       \
   Type get##Name() const {                                                     \
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index 2e881adcefcec5..2b3360468e30b8 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -1,6 +1,7 @@
 #include "llvm/ProfileData/MemProf.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/DebugInfo/DIContext.h"
 #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
 #include "llvm/IR/Value.h"
@@ -326,6 +327,65 @@ TEST(MemProf, RecordSerializationRoundTripVerion2) {
   EXPECT_EQ(Record, GotRecord);
 }
 
+TEST(MemProf, RecordSerializationRoundTripVersion2HotColdSchema) {
+  const auto Schema = llvm::memprof::getHotColdSchema();
+
+  MemInfoBlock Info;
+  Info.AllocCount = 11;
+  Info.TotalSize = 22;
+  Info.TotalLifetime = 33;
+  Info.TotalLifetimeAccessDensity = 44;
+
+  llvm::SmallVector<llvm::memprof::CallStackId> CallStackIds = {0x123, 0x456};
+
+  llvm::SmallVector<llvm::memprof::CallStackId> CallSiteIds = {0x333, 0x444};
+
+  IndexedMemProfRecord Record;
+  for (const auto &CSId : CallStackIds) {
+    // Use the same info block for both allocation sites.
+    Record.AllocSites.emplace_back(llvm::SmallVector<FrameId>(), CSId, Info,
+                                   Schema);
+  }
+  Record.CallSiteIds.assign(CallSiteIds);
+
+  std::bitset<llvm::to_underlying(Meta::Size)> SchemaBitSet;
+  for (auto Id : Schema)
+    SchemaBitSet.set(llvm::to_underlying(Id));
+
+  // Verify that SchemaBitSet has the fields we expect and nothing else, which
+  // we check with count().
+  EXPECT_EQ(SchemaBitSet.count(), 4U);
+  EXPECT_TRUE(SchemaBitSet[llvm::to_underlying(Meta::AllocCount)]);
+  EXPECT_TRUE(SchemaBitSet[llvm::to_underlying(Meta::TotalSize)]);
+  EXPECT_TRUE(SchemaBitSet[llvm::to_underlying(Meta::TotalLifetime)]);
+  EXPECT_TRUE(
+      SchemaBitSet[llvm::to_underlying(Meta::TotalLifetimeAccessDensity)]);
+
+  // Verify that Schema has propagated all the way to the Info field in each
+  // IndexedAllocationInfo.
+  ASSERT_THAT(Record.AllocSites, ::SizeIs(2));
+  EXPECT_EQ(Record.AllocSites[0].Info.getSchema(), SchemaBitSet);
+  EXPECT_EQ(Record.AllocSites[1].Info.getSchema(), SchemaBitSet);
+
+  std::string Buffer;
+  llvm::raw_string_ostream OS(Buffer);
+  Record.serialize(Schema, OS, llvm::memprof::Version2);
+  OS.flush();
+
+  const IndexedMemProfRecord GotRecord = IndexedMemProfRecord::deserialize(
+      Schema, reinterpret_cast<const unsigned char *>(Buffer.data()),
+      llvm::memprof::Version2);
+
+  // Verify that Schema comes back correctly after deserialization. Technically,
+  // the comparison between Record and GotRecord below includes the comparison
+  // of their Schemas, but we'll verify the Schemas on our own.
+  ASSERT_THAT(GotRecord.AllocSites, ::SizeIs(2));
+  EXPECT_EQ(GotRecord.AllocSites[0].Info.getSchema(), SchemaBitSet);
+  EXPECT_EQ(GotRecord.AllocSites[1].Info.getSchema(), SchemaBitSet);
+
+  EXPECT_EQ(Record, GotRecord);
+}
+
 TEST(MemProf, SymbolizationFilter) {
   std::unique_ptr<MockSymbolizer> Symbolizer(new MockSymbolizer());
 



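For reference, a rough usage sketch of the updated constructor, based on the
interfaces visible in the patch above. The getter name getAllocCpuId is only
an assumed example of a field that lies outside the hot/cold schema:

  #include "llvm/ProfileData/MemProf.h"

  void exampleHotColdBlock() {
    using namespace llvm::memprof;

    // Populate only the fields covered by the hot/cold schema.
    MemInfoBlock Info;
    Info.AllocCount = 11;
    Info.TotalSize = 22;
    Info.TotalLifetime = 33;
    Info.TotalLifetimeAccessDensity = 44;

    // The constructor now records which fields are available.
    PortableMemInfoBlock Block(Info, getHotColdSchema());

    // Getters for fields listed in the schema behave as before.
    (void)Block.getTotalSize();

    // Calling a getter for a field outside the hot/cold schema (for example,
    // getAllocCpuId()) would now trip the new assert in a +Asserts build.
  }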