[llvm] [StaticDataLayout][PGO]Implement reader and writer change for data access profiles (PR #139997)
Mingming Liu via llvm-commits
llvm-commits at lists.llvm.org
Wed May 21 16:24:02 PDT 2025
https://github.com/mingmingl-llvm updated https://github.com/llvm/llvm-project/pull/139997
>From 6cd7d8dafd6a5cb438c5bd595e126f3cd863814e Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 1 May 2025 09:37:59 -0700
Subject: [PATCH 01/10] Add classes to construct and use data access profiles
---
llvm/include/llvm/ADT/MapVector.h | 2 +
.../include/llvm/ProfileData/DataAccessProf.h | 156 +++++++++++
llvm/include/llvm/ProfileData/InstrProf.h | 16 +-
llvm/lib/ProfileData/CMakeLists.txt | 1 +
llvm/lib/ProfileData/DataAccessProf.cpp | 246 ++++++++++++++++++
llvm/lib/ProfileData/InstrProf.cpp | 8 +-
llvm/unittests/ProfileData/MemProfTest.cpp | 161 ++++++++++++
7 files changed, 579 insertions(+), 11 deletions(-)
create mode 100644 llvm/include/llvm/ProfileData/DataAccessProf.h
create mode 100644 llvm/lib/ProfileData/DataAccessProf.cpp
diff --git a/llvm/include/llvm/ADT/MapVector.h b/llvm/include/llvm/ADT/MapVector.h
index c11617a81c97d..fe0d106795c34 100644
--- a/llvm/include/llvm/ADT/MapVector.h
+++ b/llvm/include/llvm/ADT/MapVector.h
@@ -57,6 +57,8 @@ class MapVector {
return std::move(Vector);
}
+ ArrayRef<value_type> getArrayRef() const { return Vector; }
+
size_type size() const { return Vector.size(); }
/// Grow the MapVector so that it can contain at least \p NumEntries items
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
new file mode 100644
index 0000000000000..2cce4945fddd5
--- /dev/null
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -0,0 +1,156 @@
+//===- DataAccessProf.h - Data access profile format support ---------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains support to construct and use data access profiles.
+//
+// For the original RFC of this pass please see
+// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_PROFILEDATA_DATAACCESSPROF_H_
+#define LLVM_PROFILEDATA_DATAACCESSPROF_H_
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+
+#include <cstdint>
+#include <variant>
+
+namespace llvm {
+
+namespace data_access_prof {
+// The location of data in the source code.
+struct DataLocation {
+ // The filename where the data is located.
+ StringRef FileName;
+ // The line number in the source code.
+ uint32_t Line;
+};
+
+// The data access profiles for a symbol.
+struct DataAccessProfRecord {
+ // Represents a data symbol. The semantic comes in two forms: a symbol index
+ // for symbol name if `IsStringLiteral` is false, or the hash of a string
+ // content if `IsStringLiteral` is true. Required.
+ uint64_t SymbolID;
+
+ // The access count of symbol. Required.
+ uint64_t AccessCount;
+
+ // True iff this is a record for string literal (symbols with name pattern
+ // `.str.*` in the symbol table). Required.
+ bool IsStringLiteral;
+
+ // The locations of data in the source code. Optional.
+ llvm::SmallVector<DataLocation> Locations;
+};
+
+/// Encapsulates the data access profile data and the methods to operate on it.
+/// This class provides profile look-up, serialization and deserialization.
+class DataAccessProfData {
+public:
+ // SymbolID is either a string representing symbol name, or a uint64_t
+ // representing the content hash of a string literal.
+ using SymbolID = std::variant<StringRef, uint64_t>;
+ using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
+
+ DataAccessProfData() : saver(Allocator) {}
+
+ /// Serialize profile data to the output stream.
+ /// Storage layout:
+ /// - Serialized strings.
+ /// - The encoded hashes.
+ /// - Records.
+ Error serialize(ProfOStream &OS) const;
+
+ /// Deserialize this class from the given buffer.
+ Error deserialize(const unsigned char *&Ptr);
+
+ /// Returns a pointer to the profile record for \p SymID, or nullptr if there
+ /// isn't a record. Internally, this function will canonicalize the symbol
+ /// name before the lookup.
+ const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const;
+
+ /// Returns true if \p SymID is seen in profiled binaries and cold.
+ bool isKnownColdSymbol(const SymbolID SymID) const;
+
+ /// Methods to add symbolized data access profile. Returns error if duplicated
+ /// symbol names or content hashes are seen. The user of this class should
+ /// aggregate counters that corresponds to the same symbol name or with the
+ /// same string literal hash before calling 'add*' methods.
+ Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+ Error addSymbolizedDataAccessProfile(
+ SymbolID SymbolID, uint64_t AccessCount,
+ const llvm::SmallVector<DataLocation> &Locations);
+ Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
+
+ /// Returns a iterable StringRef for strings in the order they are added.
+ auto getStrings() const {
+ ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
+ StrToIndexMap.begin(), StrToIndexMap.end());
+ return llvm::make_first_range(RefSymbolNames);
+ }
+
+ /// Returns array reference for various internal data structures.
+ inline ArrayRef<
+ std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
+ getRecords() const {
+ return Records.getArrayRef();
+ }
+ inline ArrayRef<StringRef> getKnownColdSymbols() const {
+ return KnownColdSymbols.getArrayRef();
+ }
+ inline ArrayRef<uint64_t> getKnownColdHashes() const {
+ return KnownColdHashes.getArrayRef();
+ }
+
+private:
+ /// Serialize the symbol strings into the output stream.
+ Error serializeStrings(ProfOStream &OS) const;
+
+ /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
+ /// start of the next payload.
+ Error deserializeStrings(const unsigned char *&Ptr,
+ const uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols);
+
+ /// Decode the records and increment \p Ptr to the start of the next payload.
+ Error deserializeRecords(const unsigned char *&Ptr);
+
+ /// A helper function to compute a storage index for \p SymbolID.
+ uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+
+ // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
+ // its record index.
+ MapVector<SymbolID, DataAccessProfRecord> Records;
+
+ // Use MapVector to keep input order of strings for serialization and
+ // deserialization.
+ StringToIndexMap StrToIndexMap;
+ llvm::SetVector<uint64_t> KnownColdHashes;
+ llvm::SetVector<StringRef> KnownColdSymbols;
+ // Keeps owned copies of the input strings.
+ llvm::BumpPtrAllocator Allocator;
+ llvm::UniqueStringSaver saver;
+};
+
+} // namespace data_access_prof
+} // namespace llvm
+
+#endif // LLVM_PROFILEDATA_DATAACCESSPROF_H_
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 2d011c89f27cb..8a6be22bdb1a4 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -357,6 +357,12 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
/// the duplicated profile variables for Comdat functions.
bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
+/// \c NameStrings is a string composed of one or more possibly encoded
+/// sub-strings. The substrings are separated by 0 or more zero bytes. This
+/// method decodes the string and calls `NameCallback` for each substring.
+Error readAndDecodeStrings(StringRef NameStrings,
+ std::function<Error(StringRef)> NameCallback);
+
/// An enum describing the attributes of an instrumented profile.
enum class InstrProfKind {
Unknown = 0x0,
@@ -493,6 +499,11 @@ class InstrProfSymtab {
public:
using AddrHashMap = std::vector<std::pair<uint64_t, uint64_t>>;
+ // Returns the canonical name of the given PGOName. In a canonical name, all
+ // suffixes that begin with "." except ".__uniq." are stripped.
+ // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
+ static StringRef getCanonicalName(StringRef PGOName);
+
private:
using AddrIntervalMap =
IntervalMap<uint64_t, uint64_t, 4, IntervalMapHalfOpenInfo<uint64_t>>;
@@ -528,11 +539,6 @@ class InstrProfSymtab {
static StringRef getExternalSymbol() { return "** External Symbol **"; }
- // Returns the canonial name of the given PGOName. In a canonical name, all
- // suffixes that begins with "." except ".__uniq." are stripped.
- // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
- static StringRef getCanonicalName(StringRef PGOName);
-
// Add the function into the symbol table, by creating the following
// map entries:
// name-set = {PGOFuncName} union {getCanonicalName(PGOFuncName)}
diff --git a/llvm/lib/ProfileData/CMakeLists.txt b/llvm/lib/ProfileData/CMakeLists.txt
index eb7c2a3c1a28a..67a69d7761b2c 100644
--- a/llvm/lib/ProfileData/CMakeLists.txt
+++ b/llvm/lib/ProfileData/CMakeLists.txt
@@ -1,4 +1,5 @@
add_llvm_component_library(LLVMProfileData
+ DataAccessProf.cpp
GCOV.cpp
IndexedMemProfData.cpp
InstrProf.cpp
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
new file mode 100644
index 0000000000000..cf538d6a1b28e
--- /dev/null
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -0,0 +1,246 @@
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Compression.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <sys/types.h>
+
+namespace llvm {
+namespace data_access_prof {
+
+// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
+// creates an owned copy of `Str`, adds a map entry for it and returns the
+// iterator.
+static MapVector<StringRef, uint64_t>::iterator
+saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+ llvm::UniqueStringSaver &saver, StringRef Str) {
+ auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+ return Iter;
+}
+
+const DataAccessProfRecord *
+DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+ auto Key = SymbolID;
+ if (std::holds_alternative<StringRef>(SymbolID))
+ Key = InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymbolID));
+
+ auto It = Records.find(Key);
+ if (It != Records.end())
+ return &It->second;
+
+ return nullptr;
+}
+
+bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+ if (std::holds_alternative<uint64_t>(SymID))
+ return KnownColdHashes.count(std::get<uint64_t>(SymID));
+ return KnownColdSymbols.count(std::get<StringRef>(SymID));
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
+ uint64_t AccessCount) {
+ uint64_t RecordID = -1;
+ bool IsStringLiteral = false;
+ SymbolID Key;
+ if (std::holds_alternative<uint64_t>(Symbol)) {
+ RecordID = std::get<uint64_t>(Symbol);
+ Key = RecordID;
+ IsStringLiteral = true;
+ } else {
+ StringRef SymbolName = std::get<StringRef>(Symbol);
+ if (SymbolName.empty())
+ return make_error<StringError>("Empty symbol name",
+ llvm::errc::invalid_argument);
+
+ StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
+ Key = CanonicalName;
+ RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second;
+ IsStringLiteral = false;
+ }
+
+ auto [Iter, Inserted] = Records.try_emplace(
+ Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+ if (!Inserted)
+ return make_error<StringError>("Duplicate symbol or string literal added. "
+ "User of DataAccessProfData should "
+ "aggregate count for the same symbol. ",
+ llvm::errc::invalid_argument);
+
+ return Error::success();
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(
+ SymbolID SymbolID, uint64_t AccessCount,
+ const llvm::SmallVector<DataLocation> &Locations) {
+ if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ return E;
+
+ auto &Record = Records.back().second;
+ for (const auto &Location : Locations)
+ Record.Locations.push_back(
+ {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+ Location.Line});
+
+ return Error::success();
+}
+
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+ if (std::holds_alternative<uint64_t>(SymbolID)) {
+ KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
+ return Error::success();
+ }
+ StringRef SymbolName = std::get<StringRef>(SymbolID);
+ if (SymbolName.empty())
+ return make_error<StringError>("Empty symbol name",
+ llvm::errc::invalid_argument);
+ StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName);
+ KnownColdSymbols.insert(CanonicalSymName);
+ return Error::success();
+}
+
+Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
+ uint64_t NumSampledSymbols =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ uint64_t NumColdKnownSymbols =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+ return E;
+
+ uint64_t Num =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ for (uint64_t I = 0; I < Num; ++I)
+ KnownColdHashes.insert(
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr));
+
+ return deserializeRecords(Ptr);
+}
+
+Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+ OS.write(StrToIndexMap.size());
+ OS.write(KnownColdSymbols.size());
+
+ std::vector<std::string> Strs;
+ Strs.reserve(StrToIndexMap.size() + KnownColdSymbols.size());
+ for (const auto &Str : StrToIndexMap)
+ Strs.push_back(Str.first.str());
+ for (const auto &Str : KnownColdSymbols)
+ Strs.push_back(Str.str());
+
+ std::string CompressedStrings;
+ if (!Strs.empty())
+ if (Error E = collectGlobalObjectNameStrings(
+ Strs, compression::zlib::isAvailable(), CompressedStrings))
+ return E;
+ const uint64_t CompressedStringLen = CompressedStrings.length();
+ // Record the length of compressed string.
+ OS.write(CompressedStringLen);
+ // Write the chars in compressed strings.
+ for (auto &c : CompressedStrings)
+ OS.writeByte(static_cast<uint8_t>(c));
+ // Pad up to a multiple of 8.
+ // InstrProfReader could read bytes according to 'CompressedStringLen'.
+ const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
+ for (uint64_t K = CompressedStringLen; K < PaddedLength; K++)
+ OS.writeByte(0);
+ return Error::success();
+}
+
+uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+ if (std::holds_alternative<uint64_t>(SymbolID))
+ return std::get<uint64_t>(SymbolID);
+
+ return StrToIndexMap.find(std::get<StringRef>(SymbolID))->second;
+}
+
+Error DataAccessProfData::serialize(ProfOStream &OS) const {
+ if (Error E = serializeStrings(OS))
+ return E;
+ OS.write(KnownColdHashes.size());
+ for (const auto &Hash : KnownColdHashes)
+ OS.write(Hash);
+ OS.write((uint64_t)(Records.size()));
+ for (const auto &[Key, Rec] : Records) {
+ OS.write(getEncodedIndex(Rec.SymbolID));
+ OS.writeByte(Rec.IsStringLiteral);
+ OS.write(Rec.AccessCount);
+ OS.write(Rec.Locations.size());
+ for (const auto &Loc : Rec.Locations) {
+ OS.write(getEncodedIndex(Loc.FileName));
+ OS.write32(Loc.Line);
+ }
+ }
+ return Error::success();
+}
+
+Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
+ uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols) {
+ uint64_t Len =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
+ // symbols with samples, and next N strings are known cold symbols.
+ uint64_t StringCnt = 0;
+ std::function<Error(StringRef)> addName = [&](StringRef Name) {
+ if (StringCnt < NumSampledSymbols)
+ saveStringToMap(StrToIndexMap, saver, Name);
+ else
+ KnownColdSymbols.insert(saver.save(Name));
+ ++StringCnt;
+ return Error::success();
+ };
+ if (Error E =
+ readAndDecodeStrings(StringRef((const char *)Ptr, Len), addName))
+ return E;
+
+ Ptr += alignTo(Len, 8);
+ return Error::success();
+}
+
+Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
+ SmallVector<StringRef> Strings = llvm::to_vector(getStrings());
+
+ uint64_t NumRecords =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ for (uint64_t I = 0; I < NumRecords; ++I) {
+ uint64_t ID =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ bool IsStringLiteral =
+ support::endian::readNext<uint8_t, llvm::endianness::little>(Ptr);
+
+ uint64_t AccessCount =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ SymbolID SymbolID;
+ if (IsStringLiteral)
+ SymbolID = ID;
+ else
+ SymbolID = Strings[ID];
+ if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ return E;
+
+ auto &Record = Records.back().second;
+
+ uint64_t NumLocations =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ Record.Locations.reserve(NumLocations);
+ for (uint64_t J = 0; J < NumLocations; ++J) {
+ uint64_t FileNameIndex =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ uint32_t Line =
+ support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
+ Record.Locations.push_back({Strings[FileNameIndex], Line});
+ }
+ }
+ return Error::success();
+}
+} // namespace data_access_prof
+} // namespace llvm
diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp
index 88621787c1dd9..254f941acde82 100644
--- a/llvm/lib/ProfileData/InstrProf.cpp
+++ b/llvm/lib/ProfileData/InstrProf.cpp
@@ -573,12 +573,8 @@ Error InstrProfSymtab::addVTableWithName(GlobalVariable &VTable,
return Error::success();
}
-/// \c NameStrings is a string composed of one of more possibly encoded
-/// sub-strings. The substrings are separated by 0 or more zero bytes. This
-/// method decodes the string and calls `NameCallback` for each substring.
-static Error
-readAndDecodeStrings(StringRef NameStrings,
- std::function<Error(StringRef)> NameCallback) {
+Error readAndDecodeStrings(StringRef NameStrings,
+ std::function<Error(StringRef)> NameCallback) {
const uint8_t *P = NameStrings.bytes_begin();
const uint8_t *EndP = NameStrings.bytes_end();
while (P < EndP) {
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index 3e430aa4eae58..f6362448a5734 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -10,14 +10,17 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLForwardCompat.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/DebugInfo/DIContext.h"
#include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
#include "llvm/IR/Value.h"
#include "llvm/Object/ObjectFile.h"
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/MemProfData.inc"
#include "llvm/ProfileData/MemProfReader.h"
#include "llvm/ProfileData/MemProfYAML.h"
#include "llvm/Support/raw_ostream.h"
+#include "gmock/gmock-more-matchers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -36,6 +39,8 @@ using ::llvm::StringRef;
using ::llvm::object::SectionedAddress;
using ::llvm::symbolize::SymbolizableModule;
using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Pair;
using ::testing::Return;
@@ -747,6 +752,162 @@ TEST(MemProf, YAMLParser) {
ElementsAre(0x3000)))));
}
+static std::string ErrorToString(Error E) {
+ std::string ErrMsg;
+ llvm::raw_string_ostream OS(ErrMsg);
+ llvm::logAllUnhandledErrors(std::move(E), OS);
+ return ErrMsg;
+}
+
+// Test the various scenarios when DataAccessProfData should return error on
+// invalid input.
+TEST(MemProf, DataAccessProfileError) {
+ // Returns error if the input symbol name is empty.
+ llvm::data_access_prof::DataAccessProfData Data;
+ EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+ HasSubstr("Empty symbol name"));
+
+ // Returns error when the same symbol gets added twice.
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
+ EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+ HasSubstr("Duplicate symbol or string literal added"));
+
+ // Returns error when the same string content hash gets added twice.
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
+ EXPECT_THAT(ErrorToString(
+ Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+ HasSubstr("Duplicate symbol or string literal added"));
+}
+
+// Test the following operations on DataAccessProfData:
+// - Profile record look up.
+// - Serialization and de-serialization.
+TEST(MemProf, DataAccessProfile) {
+ using namespace llvm::data_access_prof;
+ llvm::data_access_prof::DataAccessProfData Data;
+
+ // In the bool conversion, Error is true if it's in a failure state and false
+ // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
+ {
+ DataLocation{"file2", 3},
+ }));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+ (uint64_t)135246, 1000,
+ {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+
+ {
+ // Test that symbol names and file names are stored in the input order.
+ EXPECT_THAT(llvm::to_vector(Data.getStrings()),
+ ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
+ EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
+
+ // Look up profiles.
+ EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
+ EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
+ EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
+ EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
+
+ EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
+ EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+
+ EXPECT_THAT(
+ Data.getProfileRecord("foo.llvm.123"),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+ testing::Field(&DataAccessProfRecord::AccessCount, 100),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ testing::IsEmpty())));
+ EXPECT_THAT(
+ *Data.getProfileRecord("bar.__uniq.321"),
+ AllOf(
+ testing::Field(&DataAccessProfRecord::SymbolID, 1),
+ testing::Field(&DataAccessProfRecord::AccessCount, 123),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(
+ testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 3))))));
+ EXPECT_THAT(
+ *Data.getProfileRecord((uint64_t)135246),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+ testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(
+ AllOf(testing::Field(&DataLocation::FileName, "file1"),
+ testing::Field(&DataLocation::Line, 1)),
+ AllOf(testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 2))))));
+ }
+
+ // Tests serialization and de-serialization.
+ llvm::data_access_prof::DataAccessProfData deserializedData;
+ {
+ std::string serializedData;
+ llvm::raw_string_ostream OS(serializedData);
+ llvm::ProfOStream POS(OS);
+
+ EXPECT_FALSE(Data.serialize(POS));
+
+ const unsigned char *p =
+ reinterpret_cast<const unsigned char *>(serializedData.data());
+ ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ testing::IsEmpty());
+ EXPECT_FALSE(deserializedData.deserialize(p));
+
+ EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(deserializedData.getKnownColdSymbols(),
+ ElementsAre("sym2", "sym1"));
+ EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
+
+ // Look up profiles after deserialization.
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
+
+ auto Records =
+ llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
+
+ EXPECT_THAT(
+ Records,
+ ElementsAre(
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+ testing::Field(&DataAccessProfRecord::AccessCount, 100),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ testing::IsEmpty())),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
+ testing::Field(&DataAccessProfRecord::AccessCount, 123),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(
+ testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 3))))),
+ AllOf(
+ testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+ testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(
+ AllOf(testing::Field(&DataLocation::FileName, "file1"),
+ testing::Field(&DataLocation::Line, 1)),
+ AllOf(testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 2)))))));
+ }
+}
+
// Verify that the YAML parser accepts a GUID expressed as a function name.
TEST(MemProf, YAMLParserGUID) {
StringRef YAMLData = R"YAML(
>From 47275299fd3e9ef242f108040df8ffe46f9cd7b0 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 5 May 2025 16:57:56 -0700
Subject: [PATCH 02/10] resolve review feedback
---
.../include/llvm/ProfileData/DataAccessProf.h | 37 ++++++++++------
llvm/lib/ProfileData/DataAccessProf.cpp | 43 ++++++++++---------
llvm/unittests/ProfileData/MemProfTest.cpp | 23 +++++-----
3 files changed, 57 insertions(+), 46 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 2cce4945fddd5..36648d6298ee5 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -45,6 +45,11 @@ struct DataLocation {
// The data access profiles for a symbol.
struct DataAccessProfRecord {
+ DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount,
+ bool IsStringLiteral)
+ : SymbolID(SymbolID), AccessCount(AccessCount),
+ IsStringLiteral(IsStringLiteral) {}
+
// Represents a data symbol. The semantic comes in two forms: a symbol index
// for symbol name if `IsStringLiteral` is false, or the hash of a string
// content if `IsStringLiteral` is true. Required.
@@ -58,7 +63,7 @@ struct DataAccessProfRecord {
bool IsStringLiteral;
// The locations of data in the source code. Optional.
- llvm::SmallVector<DataLocation> Locations;
+ llvm::SmallVector<DataLocation, 0> Locations;
};
/// Encapsulates the data access profile data and the methods to operate on it.
@@ -70,7 +75,7 @@ class DataAccessProfData {
using SymbolID = std::variant<StringRef, uint64_t>;
using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
- DataAccessProfData() : saver(Allocator) {}
+ DataAccessProfData() : Saver(Allocator) {}
/// Serialize profile data to the output stream.
/// Storage layout:
@@ -94,10 +99,13 @@ class DataAccessProfData {
/// symbol names or content hashes are seen. The user of this class should
/// aggregate counters that corresponds to the same symbol name or with the
/// same string literal hash before calling 'add*' methods.
- Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
- Error addSymbolizedDataAccessProfile(
- SymbolID SymbolID, uint64_t AccessCount,
- const llvm::SmallVector<DataLocation> &Locations);
+ Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+ /// Similar to the method above, for records with \p Locations representing
+ /// the `filename:line` where this symbol shows up. Note because of linker's
+ /// merge of identical symbols (e.g., unnamed_addr string literals), one
+ /// symbol is likely to have multiple locations.
+ Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount,
+ const llvm::SmallVector<DataLocation> &Locations);
Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
/// Returns a iterable StringRef for strings in the order they are added.
@@ -122,13 +130,13 @@ class DataAccessProfData {
private:
/// Serialize the symbol strings into the output stream.
- Error serializeStrings(ProfOStream &OS) const;
+ Error serializeSymbolsAndFilenames(ProfOStream &OS) const;
/// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
/// start of the next payload.
- Error deserializeStrings(const unsigned char *&Ptr,
- const uint64_t NumSampledSymbols,
- uint64_t NumColdKnownSymbols);
+ Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr,
+ const uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols);
/// Decode the records and increment \p Ptr to the start of the next payload.
Error deserializeRecords(const unsigned char *&Ptr);
@@ -136,6 +144,12 @@ class DataAccessProfData {
/// A helper function to compute a storage index for \p SymbolID.
uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+ // Keeps owned copies of the input strings.
+ // NOTE: Keep `Saver` initialized before other class members that reference
+ // its string copies and destructed after they are destructed.
+ llvm::BumpPtrAllocator Allocator;
+ llvm::UniqueStringSaver Saver;
+
// `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
// its record index.
MapVector<SymbolID, DataAccessProfRecord> Records;
@@ -145,9 +159,6 @@ class DataAccessProfData {
StringToIndexMap StrToIndexMap;
llvm::SetVector<uint64_t> KnownColdHashes;
llvm::SetVector<StringRef> KnownColdSymbols;
- // Keeps owned copies of the input strings.
- llvm::BumpPtrAllocator Allocator;
- llvm::UniqueStringSaver saver;
};
} // namespace data_access_prof
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index cf538d6a1b28e..c52533c13919c 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -18,8 +18,8 @@ namespace data_access_prof {
// iterator.
static MapVector<StringRef, uint64_t>::iterator
saveStringToMap(MapVector<StringRef, uint64_t> &Map,
- llvm::UniqueStringSaver &saver, StringRef Str) {
- auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+ llvm::UniqueStringSaver &Saver, StringRef Str) {
+ auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
return Iter;
}
@@ -38,12 +38,12 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
if (std::holds_alternative<uint64_t>(SymID))
- return KnownColdHashes.count(std::get<uint64_t>(SymID));
- return KnownColdSymbols.count(std::get<StringRef>(SymID));
+ return KnownColdHashes.contains(std::get<uint64_t>(SymID));
+ return KnownColdSymbols.contains(std::get<StringRef>(SymID));
}
-Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
- uint64_t AccessCount) {
+Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+ uint64_t AccessCount) {
uint64_t RecordID = -1;
bool IsStringLiteral = false;
SymbolID Key;
@@ -59,12 +59,12 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
Key = CanonicalName;
- RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second;
+ RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
IsStringLiteral = false;
}
- auto [Iter, Inserted] = Records.try_emplace(
- Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+ auto [Iter, Inserted] =
+ Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral);
if (!Inserted)
return make_error<StringError>("Duplicate symbol or string literal added. "
"User of DataAccessProfData should "
@@ -74,16 +74,16 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
return Error::success();
}
-Error DataAccessProfData::addSymbolizedDataAccessProfile(
+Error DataAccessProfData::setDataAccessProfile(
SymbolID SymbolID, uint64_t AccessCount,
const llvm::SmallVector<DataLocation> &Locations) {
- if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ if (Error E = setDataAccessProfile(SymbolID, AccessCount))
return E;
auto &Record = Records.back().second;
for (const auto &Location : Locations)
Record.Locations.push_back(
- {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+ {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first,
Location.Line});
return Error::success();
@@ -108,7 +108,8 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
uint64_t NumColdKnownSymbols =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
- if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+ if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols,
+ NumColdKnownSymbols))
return E;
uint64_t Num =
@@ -120,7 +121,7 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
return deserializeRecords(Ptr);
}
-Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
OS.write(StrToIndexMap.size());
OS.write(KnownColdSymbols.size());
@@ -158,7 +159,7 @@ uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
}
Error DataAccessProfData::serialize(ProfOStream &OS) const {
- if (Error E = serializeStrings(OS))
+ if (Error E = serializeSymbolsAndFilenames(OS))
return E;
OS.write(KnownColdHashes.size());
for (const auto &Hash : KnownColdHashes)
@@ -177,9 +178,9 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const {
return Error::success();
}
-Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
- uint64_t NumSampledSymbols,
- uint64_t NumColdKnownSymbols) {
+Error DataAccessProfData::deserializeSymbolsAndFilenames(
+ const unsigned char *&Ptr, uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols) {
uint64_t Len =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
@@ -188,9 +189,9 @@ Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
uint64_t StringCnt = 0;
std::function<Error(StringRef)> addName = [&](StringRef Name) {
if (StringCnt < NumSampledSymbols)
- saveStringToMap(StrToIndexMap, saver, Name);
+ saveStringToMap(StrToIndexMap, Saver, Name);
else
- KnownColdSymbols.insert(saver.save(Name));
+ KnownColdSymbols.insert(Saver.save(Name));
++StringCnt;
return Error::success();
};
@@ -223,7 +224,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
SymbolID = ID;
else
SymbolID = Strings[ID];
- if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ if (Error E = setDataAccessProfile(SymbolID, AccessCount))
return E;
auto &Record = Records.back().second;
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index f6362448a5734..b7b8d642ad930 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -764,18 +764,17 @@ static std::string ErrorToString(Error E) {
TEST(MemProf, DataAccessProfileError) {
// Returns error if the input symbol name is empty.
llvm::data_access_prof::DataAccessProfData Data;
- EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
HasSubstr("Empty symbol name"));
// Returns error when the same symbol gets added twice.
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
- EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+ ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
HasSubstr("Duplicate symbol or string literal added"));
// Returns error when the same string content hash gets added twice.
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
- EXPECT_THAT(ErrorToString(
- Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+ ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
HasSubstr("Duplicate symbol or string literal added"));
}
@@ -788,16 +787,16 @@ TEST(MemProf, DataAccessProfile) {
// In the bool conversion, Error is true if it's in a failure state and false
// if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+ ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
- {
- DataLocation{"file2", 3},
- }));
+ ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
+ {
+ DataLocation{"file2", 3},
+ }));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+ ASSERT_FALSE(Data.setDataAccessProfile(
(uint64_t)135246, 1000,
{DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
>From 80249bce82307ffef1817747d1c2fca100a9ea53 Mon Sep 17 00:00:00 2001
From: Mingming Liu <mingmingl at google.com>
Date: Tue, 6 May 2025 10:39:53 -0700
Subject: [PATCH 03/10] Apply suggestions from code review
Co-authored-by: Kazu Hirata <kazu at google.com>
---
llvm/include/llvm/ProfileData/DataAccessProf.h | 11 ++++++-----
llvm/include/llvm/ProfileData/InstrProf.h | 2 +-
llvm/lib/ProfileData/DataAccessProf.cpp | 2 +-
3 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 36648d6298ee5..81289b60b96ed 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -97,7 +97,7 @@ class DataAccessProfData {
/// Methods to add symbolized data access profile. Returns error if duplicated
/// symbol names or content hashes are seen. The user of this class should
- /// aggregate counters that corresponds to the same symbol name or with the
+ /// aggregate counters that correspond to the same symbol name or with the
/// same string literal hash before calling 'add*' methods.
Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
/// Similar to the method above, for records with \p Locations representing
@@ -108,7 +108,8 @@ class DataAccessProfData {
const llvm::SmallVector<DataLocation> &Locations);
Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
- /// Returns a iterable StringRef for strings in the order they are added.
+ /// Returns an iterable StringRef for strings in the order they are added.
+ /// Each string may be a symbol name or a file name.
auto getStrings() const {
ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
StrToIndexMap.begin(), StrToIndexMap.end());
@@ -116,15 +117,15 @@ class DataAccessProfData {
}
/// Returns array reference for various internal data structures.
- inline ArrayRef<
+ ArrayRef<
std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
getRecords() const {
return Records.getArrayRef();
}
- inline ArrayRef<StringRef> getKnownColdSymbols() const {
+ ArrayRef<StringRef> getKnownColdSymbols() const {
return KnownColdSymbols.getArrayRef();
}
- inline ArrayRef<uint64_t> getKnownColdHashes() const {
+ ArrayRef<uint64_t> getKnownColdHashes() const {
return KnownColdHashes.getArrayRef();
}
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 8a6be22bdb1a4..710b2f6836064 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -357,7 +357,7 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
/// the duplicated profile variables for Comdat functions.
bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
-/// \c NameStrings is a string composed of one of more possibly encoded
+/// \c NameStrings is a string composed of one or more possibly encoded
/// sub-strings. The substrings are separated by 0 or more zero bytes. This
/// method decodes the string and calls `NameCallback` for each substring.
Error readAndDecodeStrings(StringRef NameStrings,
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index c52533c13919c..a42ee41b24358 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -141,7 +141,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
// Record the length of compressed string.
OS.write(CompressedStringLen);
// Write the chars in compressed strings.
- for (auto &c : CompressedStrings)
+ for (char C : CompressedStrings)
OS.writeByte(static_cast<uint8_t>(c));
// Pad up to a multiple of 8.
// InstrProfReader could read bytes according to 'CompressedStringLen'.
>From b69c9930b9584a8dede96b0a9605c03f4d504bb8 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 6 May 2025 10:50:59 -0700
Subject: [PATCH 04/10] resolve review feedback
---
.../include/llvm/ProfileData/DataAccessProf.h | 43 ++---
llvm/include/llvm/ProfileData/InstrProf.h | 5 +-
llvm/lib/ProfileData/DataAccessProf.cpp | 72 ++++---
llvm/unittests/ProfileData/CMakeLists.txt | 1 +
.../ProfileData/DataAccessProfTest.cpp | 181 ++++++++++++++++++
llvm/unittests/ProfileData/MemProfTest.cpp | 160 ----------------
6 files changed, 245 insertions(+), 217 deletions(-)
create mode 100644 llvm/unittests/ProfileData/DataAccessProfTest.cpp
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 81289b60b96ed..a85a3320ae57c 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -70,9 +70,11 @@ struct DataAccessProfRecord {
/// This class provides profile look-up, serialization and deserialization.
class DataAccessProfData {
public:
- // SymbolID is either a string representing symbol name, or a uint64_t
- // representing the content hash of a string literal.
- using SymbolID = std::variant<StringRef, uint64_t>;
+ // SymbolHandle is either a string representing symbol name if the symbol has
+ // stable mangled name relative to source code, or a uint64_t representing the
+ // content hash of a string literal (with unstable name patterns like
+ // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
+ using SymbolHandle = std::variant<StringRef, uint64_t>;
using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
DataAccessProfData() : Saver(Allocator) {}
@@ -90,38 +92,32 @@ class DataAccessProfData {
/// Returns a pointer of profile record for \p SymbolID, or nullptr if there
/// isn't a record. Internally, this function will canonicalize the symbol
/// name before the lookup.
- const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const;
+ const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const;
/// Returns true if \p SymID is seen in profiled binaries and cold.
- bool isKnownColdSymbol(const SymbolID SymID) const;
+ bool isKnownColdSymbol(const SymbolHandle SymID) const;
- /// Methods to add symbolized data access profile. Returns error if duplicated
+ /// Methods to set symbolized data access profile. Returns error if duplicated
/// symbol names or content hashes are seen. The user of this class should
/// aggregate counters that correspond to the same symbol name or with the
- /// same string literal hash before calling 'add*' methods.
- Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+ /// same string literal hash before calling 'set*' methods.
+ Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount);
/// Similar to the method above, for records with \p Locations representing
/// the `filename:line` where this symbol shows up. Note because of linker's
/// merge of identical symbols (e.g., unnamed_addr string literals), one
/// symbol is likely to have multiple locations.
- Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount,
- const llvm::SmallVector<DataLocation> &Locations);
- Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
+ Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount,
+ ArrayRef<DataLocation> Locations);
+ Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID);
/// Returns an iterable StringRef for strings in the order they are added.
/// Each string may be a symbol name or a file name.
auto getStrings() const {
- ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
- StrToIndexMap.begin(), StrToIndexMap.end());
- return llvm::make_first_range(RefSymbolNames);
+ return llvm::make_first_range(StrToIndexMap.getArrayRef());
}
/// Returns array reference for various internal data structures.
- ArrayRef<
- std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
- getRecords() const {
- return Records.getArrayRef();
- }
+ auto getRecords() const { return Records.getArrayRef(); }
ArrayRef<StringRef> getKnownColdSymbols() const {
return KnownColdSymbols.getArrayRef();
}
@@ -137,13 +133,13 @@ class DataAccessProfData {
/// start of the next payload.
Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr,
const uint64_t NumSampledSymbols,
- uint64_t NumColdKnownSymbols);
+ const uint64_t NumColdKnownSymbols);
/// Decode the records and increment \p Ptr to the start of the next payload.
Error deserializeRecords(const unsigned char *&Ptr);
/// A helper function to compute a storage index for \p SymbolID.
- uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+ uint64_t getEncodedIndex(const SymbolHandle SymbolID) const;
// Keeps owned copies of the input strings.
// NOTE: Keep `Saver` initialized before other class members that reference
@@ -151,9 +147,8 @@ class DataAccessProfData {
llvm::BumpPtrAllocator Allocator;
llvm::UniqueStringSaver Saver;
- // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
- // its record index.
- MapVector<SymbolID, DataAccessProfRecord> Records;
+ // `Records` stores the records.
+ MapVector<SymbolHandle, DataAccessProfRecord> Records;
// Use MapVector to keep input order of strings for serialization and
// deserialization.
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 710b2f6836064..33b93ea0a558a 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -358,8 +358,9 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
/// \c NameStrings is a string composed of one or more possibly encoded
-/// sub-strings. The substrings are separated by 0 or more zero bytes. This
-/// method decodes the string and calls `NameCallback` for each substring.
+/// sub-strings. The substrings are separated by `\01` (returned by
+/// InstrProf.h:getInstrProfNameSeparator). This method decodes the string and
+/// calls `NameCallback` for each substring.
Error readAndDecodeStrings(StringRef NameStrings,
std::function<Error(StringRef)> NameCallback);
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index a42ee41b24358..d1c034f639347 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -23,11 +23,22 @@ saveStringToMap(MapVector<StringRef, uint64_t> &Map,
return Iter;
}
+// Returns the canonical name or error.
+static Expected<StringRef> getCanonicalName(StringRef Name) {
+ if (Name.empty())
+ return make_error<StringError>("Empty symbol name",
+ llvm::errc::invalid_argument);
+ return InstrProfSymtab::getCanonicalName(Name);
+}
+
const DataAccessProfRecord *
-DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
auto Key = SymbolID;
- if (std::holds_alternative<StringRef>(SymbolID))
- Key = InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymbolID));
+ if (std::holds_alternative<StringRef>(SymbolID)) {
+ StringRef Name = std::get<StringRef>(SymbolID);
+ assert(!Name.empty() && "Empty symbol name");
+ Key = InstrProfSymtab::getCanonicalName(Name);
+ }
auto It = Records.find(Key);
if (It != Records.end())
@@ -36,30 +47,27 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
return nullptr;
}
-bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const {
if (std::holds_alternative<uint64_t>(SymID))
return KnownColdHashes.contains(std::get<uint64_t>(SymID));
return KnownColdSymbols.contains(std::get<StringRef>(SymID));
}
-Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
uint64_t AccessCount) {
uint64_t RecordID = -1;
bool IsStringLiteral = false;
- SymbolID Key;
+ SymbolHandle Key;
if (std::holds_alternative<uint64_t>(Symbol)) {
RecordID = std::get<uint64_t>(Symbol);
Key = RecordID;
IsStringLiteral = true;
} else {
- StringRef SymbolName = std::get<StringRef>(Symbol);
- if (SymbolName.empty())
- return make_error<StringError>("Empty symbol name",
- llvm::errc::invalid_argument);
-
- StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
- Key = CanonicalName;
- RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
+ auto CanonicalName = getCanonicalName(std::get<StringRef>(Symbol));
+ if (!CanonicalName)
+ return CanonicalName.takeError();
+ std::tie(Key, RecordID) =
+ *saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
IsStringLiteral = false;
}
@@ -75,8 +83,8 @@ Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
}
Error DataAccessProfData::setDataAccessProfile(
- SymbolID SymbolID, uint64_t AccessCount,
- const llvm::SmallVector<DataLocation> &Locations) {
+ SymbolHandle SymbolID, uint64_t AccessCount,
+ ArrayRef<DataLocation> Locations) {
if (Error E = setDataAccessProfile(SymbolID, AccessCount))
return E;
@@ -89,17 +97,15 @@ Error DataAccessProfData::setDataAccessProfile(
return Error::success();
}
-Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) {
if (std::holds_alternative<uint64_t>(SymbolID)) {
KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
return Error::success();
}
- StringRef SymbolName = std::get<StringRef>(SymbolID);
- if (SymbolName.empty())
- return make_error<StringError>("Empty symbol name",
- llvm::errc::invalid_argument);
- StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName);
- KnownColdSymbols.insert(CanonicalSymName);
+ auto CanonicalName = getCanonicalName(std::get<StringRef>(SymbolID));
+ if (!CanonicalName)
+ return CanonicalName.takeError();
+ KnownColdSymbols.insert(*CanonicalName);
return Error::success();
}
@@ -142,7 +148,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
OS.write(CompressedStringLen);
// Write the chars in compressed strings.
for (char C : CompressedStrings)
- OS.writeByte(static_cast<uint8_t>(c));
+ OS.writeByte(static_cast<uint8_t>(C));
// Pad up to a multiple of 8.
// InstrProfReader could read bytes according to 'CompressedStringLen'.
const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
@@ -151,11 +157,15 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
return Error::success();
}
-uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+uint64_t
+DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const {
if (std::holds_alternative<uint64_t>(SymbolID))
return std::get<uint64_t>(SymbolID);
- return StrToIndexMap.find(std::get<StringRef>(SymbolID))->second;
+ auto Iter = StrToIndexMap.find(std::get<StringRef>(SymbolID));
+ assert(Iter != StrToIndexMap.end() &&
+ "String literals not found in StrToIndexMap");
+ return Iter->second;
}
Error DataAccessProfData::serialize(ProfOStream &OS) const {
@@ -179,13 +189,13 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const {
}
Error DataAccessProfData::deserializeSymbolsAndFilenames(
- const unsigned char *&Ptr, uint64_t NumSampledSymbols,
- uint64_t NumColdKnownSymbols) {
+ const unsigned char *&Ptr, const uint64_t NumSampledSymbols,
+ const uint64_t NumColdKnownSymbols) {
uint64_t Len =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
- // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
- // symbols with samples, and next N strings are known cold symbols.
+ // The first NumSampledSymbols strings are symbols with samples, and next
+ // NumColdKnownSymbols strings are known cold symbols.
uint64_t StringCnt = 0;
std::function<Error(StringRef)> addName = [&](StringRef Name) {
if (StringCnt < NumSampledSymbols)
@@ -219,7 +229,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
uint64_t AccessCount =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
- SymbolID SymbolID;
+ SymbolHandle SymbolID;
if (IsStringLiteral)
SymbolID = ID;
else
diff --git a/llvm/unittests/ProfileData/CMakeLists.txt b/llvm/unittests/ProfileData/CMakeLists.txt
index 0a7f7da085950..29b9cb751dabe 100644
--- a/llvm/unittests/ProfileData/CMakeLists.txt
+++ b/llvm/unittests/ProfileData/CMakeLists.txt
@@ -10,6 +10,7 @@ set(LLVM_LINK_COMPONENTS
add_llvm_unittest(ProfileDataTests
BPFunctionNodeTest.cpp
CoverageMappingTest.cpp
+ DataAccessProfTest.cpp
InstrProfDataTest.cpp
InstrProfTest.cpp
ItaniumManglingCanonicalizerTest.cpp
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
new file mode 100644
index 0000000000000..50c4af49fe76b
--- /dev/null
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -0,0 +1,181 @@
+
+//===- unittests/ProfileData/DataAccessProfTest.cpp
+//----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gmock/gmock-more-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace llvm {
+namespace data_access_prof {
+namespace {
+
+using ::llvm::StringRef;
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+
+static std::string ErrorToString(Error E) {
+ std::string ErrMsg;
+ llvm::raw_string_ostream OS(ErrMsg);
+ llvm::logAllUnhandledErrors(std::move(E), OS);
+ return ErrMsg;
+}
+
+// Test the various scenarios when DataAccessProfData should return error on
+// invalid input.
+TEST(MemProf, DataAccessProfileError) {
+ // Returns error if the input symbol name is empty.
+ DataAccessProfData Data;
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
+ HasSubstr("Empty symbol name"));
+
+ // Returns error when the same symbol gets added twice.
+ ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
+ HasSubstr("Duplicate symbol or string literal added"));
+
+ // Returns error when the same string content hash gets added twice.
+ ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
+ HasSubstr("Duplicate symbol or string literal added"));
+}
+
+// Test the following operations on DataAccessProfData:
+// - Profile record look up.
+// - Serialization and de-serialization.
+TEST(MemProf, DataAccessProfile) {
+ DataAccessProfData Data;
+
+ // In the bool conversion, Error is true if it's in a failure state and false
+ // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
+ ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
+ ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
+ {
+ DataLocation{"file2", 3},
+ }));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
+ ASSERT_FALSE(Data.setDataAccessProfile(
+ (uint64_t)135246, 1000,
+ {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+
+ {
+ // Test that symbol names and file names are stored in the input order.
+ EXPECT_THAT(llvm::to_vector(Data.getStrings()),
+ ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
+ EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
+
+ // Look up profiles.
+ EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
+ EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
+ EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
+ EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
+
+ EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
+ EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+
+ EXPECT_THAT(
+ *Data.getProfileRecord("foo.llvm.123"),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+ testing::Field(&DataAccessProfRecord::AccessCount, 100),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ testing::IsEmpty())));
+ EXPECT_THAT(
+ *Data.getProfileRecord("bar.__uniq.321"),
+ AllOf(
+ testing::Field(&DataAccessProfRecord::SymbolID, 1),
+ testing::Field(&DataAccessProfRecord::AccessCount, 123),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(
+ testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 3))))));
+ EXPECT_THAT(
+ *Data.getProfileRecord((uint64_t)135246),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+ testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(
+ AllOf(testing::Field(&DataLocation::FileName, "file1"),
+ testing::Field(&DataLocation::Line, 1)),
+ AllOf(testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 2))))));
+ }
+
+ // Tests serialization and de-serialization.
+ DataAccessProfData deserializedData;
+ {
+ std::string serializedData;
+ llvm::raw_string_ostream OS(serializedData);
+ llvm::ProfOStream POS(OS);
+
+ EXPECT_FALSE(Data.serialize(POS));
+
+ const unsigned char *p =
+ reinterpret_cast<const unsigned char *>(serializedData.data());
+ ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ testing::IsEmpty());
+ EXPECT_FALSE(deserializedData.deserialize(p));
+
+ EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(deserializedData.getKnownColdSymbols(),
+ ElementsAre("sym2", "sym1"));
+ EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
+
+ // Look up profiles after deserialization.
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
+
+ auto Records =
+ llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
+
+ EXPECT_THAT(
+ Records,
+ ElementsAre(
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+ testing::Field(&DataAccessProfRecord::AccessCount, 100),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ testing::IsEmpty())),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
+ testing::Field(&DataAccessProfRecord::AccessCount, 123),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(
+ testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 3))))),
+ AllOf(
+ testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+ testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(
+ AllOf(testing::Field(&DataLocation::FileName, "file1"),
+ testing::Field(&DataLocation::Line, 1)),
+ AllOf(testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 2)))))));
+ }
+}
+} // namespace
+} // namespace data_access_prof
+} // namespace llvm
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index b7b8d642ad930..3e430aa4eae58 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -10,17 +10,14 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLForwardCompat.h"
-#include "llvm/ADT/SmallVector.h"
#include "llvm/DebugInfo/DIContext.h"
#include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
#include "llvm/IR/Value.h"
#include "llvm/Object/ObjectFile.h"
-#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/MemProfData.inc"
#include "llvm/ProfileData/MemProfReader.h"
#include "llvm/ProfileData/MemProfYAML.h"
#include "llvm/Support/raw_ostream.h"
-#include "gmock/gmock-more-matchers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -39,8 +36,6 @@ using ::llvm::StringRef;
using ::llvm::object::SectionedAddress;
using ::llvm::symbolize::SymbolizableModule;
using ::testing::ElementsAre;
-using ::testing::ElementsAreArray;
-using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Pair;
using ::testing::Return;
@@ -752,161 +747,6 @@ TEST(MemProf, YAMLParser) {
ElementsAre(0x3000)))));
}
-static std::string ErrorToString(Error E) {
- std::string ErrMsg;
- llvm::raw_string_ostream OS(ErrMsg);
- llvm::logAllUnhandledErrors(std::move(E), OS);
- return ErrMsg;
-}
-
-// Test the various scenarios when DataAccessProfData should return error on
-// invalid input.
-TEST(MemProf, DataAccessProfileError) {
- // Returns error if the input symbol name is empty.
- llvm::data_access_prof::DataAccessProfData Data;
- EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
- HasSubstr("Empty symbol name"));
-
- // Returns error when the same symbol gets added twice.
- ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
- EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
- HasSubstr("Duplicate symbol or string literal added"));
-
- // Returns error when the same string content hash gets added twice.
- ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
- EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
- HasSubstr("Duplicate symbol or string literal added"));
-}
-
-// Test the following operations on DataAccessProfData:
-// - Profile record look up.
-// - Serialization and de-serialization.
-TEST(MemProf, DataAccessProfile) {
- using namespace llvm::data_access_prof;
- llvm::data_access_prof::DataAccessProfData Data;
-
- // In the bool conversion, Error is true if it's in a failure state and false
- // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
- ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
- ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
- ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
- ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
- {
- DataLocation{"file2", 3},
- }));
- ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
- ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
- ASSERT_FALSE(Data.setDataAccessProfile(
- (uint64_t)135246, 1000,
- {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
-
- {
- // Test that symbol names and file names are stored in the input order.
- EXPECT_THAT(llvm::to_vector(Data.getStrings()),
- ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
- EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
- EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
-
- // Look up profiles.
- EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
- EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
- EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
- EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
-
- EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
- EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
-
- EXPECT_THAT(
- Data.getProfileRecord("foo.llvm.123"),
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
- testing::Field(&DataAccessProfRecord::AccessCount, 100),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(&DataAccessProfRecord::Locations,
- testing::IsEmpty())));
- EXPECT_THAT(
- *Data.getProfileRecord("bar.__uniq.321"),
- AllOf(
- testing::Field(&DataAccessProfRecord::SymbolID, 1),
- testing::Field(&DataAccessProfRecord::AccessCount, 123),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(&DataAccessProfRecord::Locations,
- ElementsAre(AllOf(
- testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 3))))));
- EXPECT_THAT(
- *Data.getProfileRecord((uint64_t)135246),
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
- testing::Field(&DataAccessProfRecord::AccessCount, 1000),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
- testing::Field(
- &DataAccessProfRecord::Locations,
- ElementsAre(
- AllOf(testing::Field(&DataLocation::FileName, "file1"),
- testing::Field(&DataLocation::Line, 1)),
- AllOf(testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 2))))));
- }
-
- // Tests serialization and de-serialization.
- llvm::data_access_prof::DataAccessProfData deserializedData;
- {
- std::string serializedData;
- llvm::raw_string_ostream OS(serializedData);
- llvm::ProfOStream POS(OS);
-
- EXPECT_FALSE(Data.serialize(POS));
-
- const unsigned char *p =
- reinterpret_cast<const unsigned char *>(serializedData.data());
- ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
- testing::IsEmpty());
- EXPECT_FALSE(deserializedData.deserialize(p));
-
- EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
- ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
- EXPECT_THAT(deserializedData.getKnownColdSymbols(),
- ElementsAre("sym2", "sym1"));
- EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
-
- // Look up profiles after deserialization.
- EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
- EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
- EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
- EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
-
- auto Records =
- llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
-
- EXPECT_THAT(
- Records,
- ElementsAre(
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
- testing::Field(&DataAccessProfRecord::AccessCount, 100),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(&DataAccessProfRecord::Locations,
- testing::IsEmpty())),
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
- testing::Field(&DataAccessProfRecord::AccessCount, 123),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(
- &DataAccessProfRecord::Locations,
- ElementsAre(AllOf(
- testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 3))))),
- AllOf(
- testing::Field(&DataAccessProfRecord::SymbolID, 135246),
- testing::Field(&DataAccessProfRecord::AccessCount, 1000),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
- testing::Field(
- &DataAccessProfRecord::Locations,
- ElementsAre(
- AllOf(testing::Field(&DataLocation::FileName, "file1"),
- testing::Field(&DataLocation::Line, 1)),
- AllOf(testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 2)))))));
- }
-}
-
// Verify that the YAML parser accepts a GUID expressed as a function name.
TEST(MemProf, YAMLParserGUID) {
StringRef YAMLData = R"YAML(
>From df080949a1ccc75fceb386b0232005f96397752d Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 6 May 2025 16:35:14 -0700
Subject: [PATCH 05/10] resolve feedback
---
llvm/include/llvm/ProfileData/DataAccessProf.h | 6 +++++-
llvm/lib/ProfileData/DataAccessProf.cpp | 12 +++++++++---
2 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index a85a3320ae57c..91de43fdf60ca 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -52,7 +52,11 @@ struct DataAccessProfRecord {
// Represents a data symbol. The semantic comes in two forms: a symbol index
// for symbol name if `IsStringLiteral` is false, or the hash of a string
- // content if `IsStringLiteral` is true. Required.
+ // content if `IsStringLiteral` is true. For most of the symbolizable static
+ // data, the mangled symbol names remain stable relative to the source code
+ // and are therefore used to identify symbols across binary releases. String
+ // literals have unstable name patterns like `.str.N[.llvm.hash]`, so we use
+ // the content hash instead. This is a required field.
uint64_t SymbolID;
// The access count of symbol. Required.
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index d1c034f639347..d7e67f5f09cbe 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -35,9 +35,15 @@ const DataAccessProfRecord *
DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
auto Key = SymbolID;
if (std::holds_alternative<StringRef>(SymbolID)) {
- StringRef Name = std::get<StringRef>(SymbolID);
- assert(!Name.empty() && "Empty symbol name");
- Key = InstrProfSymtab::getCanonicalName(Name);
+ auto NameOrErr = getCanonicalName(std::get<StringRef>(SymbolID));
+ // If name canonicalization fails, suppress the error inside.
+ if (!NameOrErr) {
+ assert(
+ std::get<StringRef>(SymbolID).empty() &&
+ "Name canonicalization only fails when stringified string is empty.");
+ return nullptr;
+ }
+ Key = *NameOrErr;
}
auto It = Records.find(Key);
>From 6dd04e46542851b84bf26cd95245399204072085 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 12 May 2025 17:05:19 -0700
Subject: [PATCH 06/10] resolve comments
---
.../include/llvm/ProfileData/DataAccessProf.h | 119 ++++++++++++------
llvm/include/llvm/ProfileData/InstrProf.h | 2 +-
llvm/lib/ProfileData/DataAccessProf.cpp | 49 ++++----
.../ProfileData/DataAccessProfTest.cpp | 116 ++++++++---------
4 files changed, 167 insertions(+), 119 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 91de43fdf60ca..e8504102238d1 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -30,23 +30,40 @@
#include "llvm/Support/StringSaver.h"
#include <cstdint>
+#include <optional>
#include <variant>
namespace llvm {
namespace data_access_prof {
-// The location of data in the source code.
-struct DataLocation {
+
+/// The location of data in the source code. Used by profile lookup API.
+struct SourceLocation {
+ SourceLocation(StringRef FileNameRef, uint32_t Line)
+ : FileName(FileNameRef.str()), Line(Line) {}
+ /// The filename where the data is located.
+ std::string FileName;
+ /// The line number in the source code.
+ uint32_t Line;
+};
+
+namespace internal {
+
+// Conceptually similar to SourceLocation except that FileNames are StringRef of
+// which strings are owned by `DataAccessProfData`. Used by `DataAccessProfData`
+// to represent data locations internally.
+struct SourceLocationRef {
// The filename where the data is located.
StringRef FileName;
// The line number in the source code.
uint32_t Line;
};
-// The data access profiles for a symbol.
-struct DataAccessProfRecord {
- DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount,
- bool IsStringLiteral)
+// The data access profiles for a symbol. Used by `DataAccessProfData`
+// to represent records internally.
+struct DataAccessProfRecordRef {
+ DataAccessProfRecordRef(uint64_t SymbolID, uint64_t AccessCount,
+ bool IsStringLiteral)
: SymbolID(SymbolID), AccessCount(AccessCount),
IsStringLiteral(IsStringLiteral) {}
@@ -67,18 +84,43 @@ struct DataAccessProfRecord {
bool IsStringLiteral;
// The locations of data in the source code. Optional.
- llvm::SmallVector<DataLocation, 0> Locations;
+ llvm::SmallVector<SourceLocationRef, 0> Locations;
};
+} // namespace internal
+
+// SymbolID is either a string representing symbol name if the symbol has
+// stable mangled name relative to source code, or a uint64_t representing the
+// content hash of a string literal (with unstable name patterns like
+// `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
+using SymbolHandleRef = std::variant<StringRef, uint64_t>;
-/// Encapsulates the data access profile data and the methods to operate on it.
-/// This class provides profile look-up, serialization and deserialization.
+// The semantics are the same as `SymbolHandleRef` above. The strings are owned.
+using SymbolHandle = std::variant<std::string, uint64_t>;
+
+/// The data access profiles for a symbol.
+struct DataAccessProfRecord {
+public:
+ DataAccessProfRecord(SymbolHandleRef SymHandleRef,
+ ArrayRef<internal::SourceLocationRef> LocRefs) {
+ if (std::holds_alternative<StringRef>(SymHandleRef)) {
+ SymHandle = std::get<StringRef>(SymHandleRef).str();
+ } else
+ SymHandle = std::get<uint64_t>(SymHandleRef);
+
+ for (auto Loc : LocRefs)
+ Locations.push_back(SourceLocation(Loc.FileName, Loc.Line));
+ }
+ SymbolHandle SymHandle;
+
+ // The locations of data in the source code. Optional.
+ SmallVector<SourceLocation> Locations;
+};
+
+/// Encapsulates the data access profile data and the methods to operate on
+/// it. This class provides profile look-up, serialization and
+/// deserialization.
class DataAccessProfData {
public:
- // SymbolID is either a string representing symbol name if the symbol has
- // stable mangled name relative to source code, or a uint64_t representing the
- // content hash of a string literal (with unstable name patterns like
- // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
- using SymbolHandle = std::variant<StringRef, uint64_t>;
using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
DataAccessProfData() : Saver(Allocator) {}
@@ -93,35 +135,39 @@ class DataAccessProfData {
/// Deserialize this class from the given buffer.
Error deserialize(const unsigned char *&Ptr);
- /// Returns a pointer of profile record for \p SymbolID, or nullptr if there
+ /// Returns a profile record for \p SymID, or std::nullopt if there
/// isn't a record. Internally, this function will canonicalize the symbol
/// name before the lookup.
- const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const;
+ std::optional<DataAccessProfRecord>
+ getProfileRecord(const SymbolHandleRef SymID) const;
/// Returns true if \p SymID is seen in profiled binaries and cold.
- bool isKnownColdSymbol(const SymbolHandle SymID) const;
+ bool isKnownColdSymbol(const SymbolHandleRef SymID) const;
- /// Methods to set symbolized data access profile. Returns error if duplicated
- /// symbol names or content hashes are seen. The user of this class should
- /// aggregate counters that correspond to the same symbol name or with the
- /// same string literal hash before calling 'set*' methods.
- Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount);
+ /// Methods to set symbolized data access profile. Returns error if
+ /// duplicated symbol names or content hashes are seen. The user of this
+ /// class should aggregate counters that correspond to the same symbol name
+ /// or with the same string literal hash before calling 'set*' methods.
+ Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount);
/// Similar to the method above, for records with \p Locations representing
/// the `filename:line` where this symbol shows up. Note because of linker's
/// merge of identical symbols (e.g., unnamed_addr string literals), one
/// symbol is likely to have multiple locations.
- Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount,
- ArrayRef<DataLocation> Locations);
- Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID);
-
- /// Returns an iterable StringRef for strings in the order they are added.
- /// Each string may be a symbol name or a file name.
- auto getStrings() const {
- return llvm::make_first_range(StrToIndexMap.getArrayRef());
+ Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount,
+ ArrayRef<SourceLocation> Locations);
+ /// Add a symbol that's seen in the profiled binary without samples.
+ Error addKnownSymbolWithoutSamples(SymbolHandleRef SymbolID);
+
+ /// The following methods return array references for various internal data
+ /// structures.
+ ArrayRef<StringToIndexMap::value_type> getStrToIndexMapRef() const {
+ return StrToIndexMap.getArrayRef();
+ }
+ ArrayRef<
+ MapVector<SymbolHandleRef, internal::DataAccessProfRecordRef>::value_type>
+ getRecords() const {
+ return Records.getArrayRef();
}
-
- /// Returns array reference for various internal data structures.
- auto getRecords() const { return Records.getArrayRef(); }
ArrayRef<StringRef> getKnownColdSymbols() const {
return KnownColdSymbols.getArrayRef();
}
@@ -139,11 +185,12 @@ class DataAccessProfData {
const uint64_t NumSampledSymbols,
const uint64_t NumColdKnownSymbols);
- /// Decode the records and increment \p Ptr to the start of the next payload.
+ /// Decode the records and increment \p Ptr to the start of the next
+ /// payload.
Error deserializeRecords(const unsigned char *&Ptr);
/// A helper function to compute a storage index for \p SymbolID.
- uint64_t getEncodedIndex(const SymbolHandle SymbolID) const;
+ uint64_t getEncodedIndex(const SymbolHandleRef SymbolID) const;
// Keeps owned copies of the input strings.
// NOTE: Keep `Saver` initialized before other class members that reference
@@ -152,7 +199,7 @@ class DataAccessProfData {
llvm::UniqueStringSaver Saver;
// `Records` stores the records.
- MapVector<SymbolHandle, DataAccessProfRecord> Records;
+ MapVector<SymbolHandleRef, internal::DataAccessProfRecordRef> Records;
// Use MapVector to keep input order of strings for serialization and
// deserialization.
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 33b93ea0a558a..544a59df43ed3 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -500,7 +500,7 @@ class InstrProfSymtab {
public:
using AddrHashMap = std::vector<std::pair<uint64_t, uint64_t>>;
- // Returns the canonial name of the given PGOName. In a canonical name, all
+ // Returns the canonical name of the given PGOName. In a canonical name, all
// suffixes that begins with "." except ".__uniq." are stripped.
// FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
static StringRef getCanonicalName(StringRef PGOName);
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index d7e67f5f09cbe..c5d0099977cfa 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -16,11 +16,11 @@ namespace data_access_prof {
// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
// creates an owned copy of `Str`, adds a map entry for it and returns the
// iterator.
-static MapVector<StringRef, uint64_t>::iterator
-saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+static std::pair<StringRef, uint64_t>
+saveStringToMap(DataAccessProfData::StringToIndexMap &Map,
llvm::UniqueStringSaver &Saver, StringRef Str) {
auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
- return Iter;
+ return *Iter;
}
// Returns the canonical name or error.
@@ -31,8 +31,8 @@ static Expected<StringRef> getCanonicalName(StringRef Name) {
return InstrProfSymtab::getCanonicalName(Name);
}
-const DataAccessProfRecord *
-DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
+std::optional<DataAccessProfRecord>
+DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const {
auto Key = SymbolID;
if (std::holds_alternative<StringRef>(SymbolID)) {
auto NameOrErr = getCanonicalName(std::get<StringRef>(SymbolID));
@@ -41,40 +41,39 @@ DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
assert(
std::get<StringRef>(SymbolID).empty() &&
"Name canonicalization only fails when stringified string is empty.");
- return nullptr;
+ return std::nullopt;
}
Key = *NameOrErr;
}
auto It = Records.find(Key);
- if (It != Records.end())
- return &It->second;
+ if (It != Records.end()) {
+ return DataAccessProfRecord(Key, It->second.Locations);
+ }
- return nullptr;
+ return std::nullopt;
}
-bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const {
+bool DataAccessProfData::isKnownColdSymbol(const SymbolHandleRef SymID) const {
if (std::holds_alternative<uint64_t>(SymID))
return KnownColdHashes.contains(std::get<uint64_t>(SymID));
return KnownColdSymbols.contains(std::get<StringRef>(SymID));
}
-Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
+Error DataAccessProfData::setDataAccessProfile(SymbolHandleRef Symbol,
uint64_t AccessCount) {
uint64_t RecordID = -1;
- bool IsStringLiteral = false;
- SymbolHandle Key;
- if (std::holds_alternative<uint64_t>(Symbol)) {
+ const bool IsStringLiteral = std::holds_alternative<uint64_t>(Symbol);
+ SymbolHandleRef Key;
+ if (IsStringLiteral) {
RecordID = std::get<uint64_t>(Symbol);
Key = RecordID;
- IsStringLiteral = true;
} else {
auto CanonicalName = getCanonicalName(std::get<StringRef>(Symbol));
if (!CanonicalName)
return CanonicalName.takeError();
std::tie(Key, RecordID) =
- *saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
- IsStringLiteral = false;
+ saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
}
auto [Iter, Inserted] =
@@ -89,21 +88,22 @@ Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
}
Error DataAccessProfData::setDataAccessProfile(
- SymbolHandle SymbolID, uint64_t AccessCount,
- ArrayRef<DataLocation> Locations) {
+ SymbolHandleRef SymbolID, uint64_t AccessCount,
+ ArrayRef<SourceLocation> Locations) {
if (Error E = setDataAccessProfile(SymbolID, AccessCount))
return E;
auto &Record = Records.back().second;
for (const auto &Location : Locations)
Record.Locations.push_back(
- {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first,
+ {saveStringToMap(StrToIndexMap, Saver, Location.FileName).first,
Location.Line});
return Error::success();
}
-Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) {
+Error DataAccessProfData::addKnownSymbolWithoutSamples(
+ SymbolHandleRef SymbolID) {
if (std::holds_alternative<uint64_t>(SymbolID)) {
KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
return Error::success();
@@ -164,7 +164,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
}
uint64_t
-DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const {
+DataAccessProfData::getEncodedIndex(const SymbolHandleRef SymbolID) const {
if (std::holds_alternative<uint64_t>(SymbolID))
return std::get<uint64_t>(SymbolID);
@@ -220,7 +220,8 @@ Error DataAccessProfData::deserializeSymbolsAndFilenames(
}
Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
- SmallVector<StringRef> Strings = llvm::to_vector(getStrings());
+ SmallVector<StringRef> Strings =
+ llvm::to_vector(llvm::make_first_range(getStrToIndexMapRef()));
uint64_t NumRecords =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
@@ -235,7 +236,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
uint64_t AccessCount =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
- SymbolHandle SymbolID;
+ SymbolHandleRef SymbolID;
if (IsStringLiteral)
SymbolID = ID;
else
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
index 50c4af49fe76b..127230d4805e7 100644
--- a/llvm/unittests/ProfileData/DataAccessProfTest.cpp
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -1,4 +1,3 @@
-
//===- unittests/Support/DataAccessProfTest.cpp
//----------------------------------===//
//
@@ -10,6 +9,7 @@
#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
#include "gmock/gmock-more-matchers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -19,7 +19,9 @@ namespace data_access_prof {
namespace {
using ::llvm::StringRef;
+using llvm::ValueIs;
using ::testing::ElementsAre;
+using ::testing::Field;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
@@ -53,6 +55,8 @@ TEST(MemProf, DataAccessProfileError) {
// - Profile record look up.
// - Serialization and de-serialization.
TEST(MemProf, DataAccessProfile) {
+ using internal::DataAccessProfRecordRef;
+ using internal::SourceLocationRef;
DataAccessProfData Data;
// In the bool conversion, Error is true if it's in a failure state and false
@@ -62,18 +66,19 @@ TEST(MemProf, DataAccessProfile) {
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
{
- DataLocation{"file2", 3},
+ SourceLocation{"file2", 3},
}));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
ASSERT_FALSE(Data.setDataAccessProfile(
(uint64_t)135246, 1000,
- {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+ {SourceLocation{"file1", 1}, SourceLocation{"file2", 2}}));
{
// Test that symbol names and file names are stored in the input order.
- EXPECT_THAT(llvm::to_vector(Data.getStrings()),
- ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(
+ llvm::to_vector(llvm::make_first_range(Data.getStrToIndexMapRef())),
+ ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
@@ -83,38 +88,34 @@ TEST(MemProf, DataAccessProfile) {
EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
- EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
- EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+ EXPECT_EQ(Data.getProfileRecord("non-existence"), std::nullopt);
+ EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), std::nullopt);
EXPECT_THAT(
- *Data.getProfileRecord("foo.llvm.123"),
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
- testing::Field(&DataAccessProfRecord::AccessCount, 100),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(&DataAccessProfRecord::Locations,
- testing::IsEmpty())));
+ Data.getProfileRecord("foo.llvm.123"),
+ ValueIs(AllOf(
+ Field(&DataAccessProfRecord::SymHandle,
+ testing::VariantWith<std::string>(testing::Eq("foo"))),
+ Field(&DataAccessProfRecord::Locations, testing::IsEmpty()))));
EXPECT_THAT(
- *Data.getProfileRecord("bar.__uniq.321"),
- AllOf(
- testing::Field(&DataAccessProfRecord::SymbolID, 1),
- testing::Field(&DataAccessProfRecord::AccessCount, 123),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(&DataAccessProfRecord::Locations,
- ElementsAre(AllOf(
- testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 3))))));
+ Data.getProfileRecord("bar.__uniq.321"),
+ ValueIs(AllOf(
+ Field(&DataAccessProfRecord::SymHandle,
+ testing::VariantWith<std::string>(
+ testing::Eq("bar.__uniq.321"))),
+ Field(&DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(Field(&SourceLocation::FileName, "file2"),
+ Field(&SourceLocation::Line, 3)))))));
EXPECT_THAT(
- *Data.getProfileRecord((uint64_t)135246),
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
- testing::Field(&DataAccessProfRecord::AccessCount, 1000),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
- testing::Field(
- &DataAccessProfRecord::Locations,
- ElementsAre(
- AllOf(testing::Field(&DataLocation::FileName, "file1"),
- testing::Field(&DataLocation::Line, 1)),
- AllOf(testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 2))))));
+ Data.getProfileRecord((uint64_t)135246),
+ ValueIs(AllOf(
+ Field(&DataAccessProfRecord::SymHandle,
+ testing::VariantWith<uint64_t>(testing::Eq(135246))),
+ Field(&DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(Field(&SourceLocation::FileName, "file1"),
+ Field(&SourceLocation::Line, 1)),
+ AllOf(Field(&SourceLocation::FileName, "file2"),
+ Field(&SourceLocation::Line, 2)))))));
}
// Tests serialization and de-serialization.
@@ -128,11 +129,13 @@ TEST(MemProf, DataAccessProfile) {
const unsigned char *p =
reinterpret_cast<const unsigned char *>(serializedData.data());
- ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ ASSERT_THAT(llvm::to_vector(llvm::make_first_range(
+ deserializedData.getStrToIndexMapRef())),
testing::IsEmpty());
EXPECT_FALSE(deserializedData.deserialize(p));
- EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ EXPECT_THAT(llvm::to_vector(llvm::make_first_range(
+ deserializedData.getStrToIndexMapRef())),
ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
EXPECT_THAT(deserializedData.getKnownColdSymbols(),
ElementsAre("sym2", "sym1"));
@@ -150,30 +153,27 @@ TEST(MemProf, DataAccessProfile) {
EXPECT_THAT(
Records,
ElementsAre(
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
- testing::Field(&DataAccessProfRecord::AccessCount, 100),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(&DataAccessProfRecord::Locations,
- testing::IsEmpty())),
- AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
- testing::Field(&DataAccessProfRecord::AccessCount, 123),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
- testing::Field(
- &DataAccessProfRecord::Locations,
- ElementsAre(AllOf(
- testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 3))))),
AllOf(
- testing::Field(&DataAccessProfRecord::SymbolID, 135246),
- testing::Field(&DataAccessProfRecord::AccessCount, 1000),
- testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
- testing::Field(
- &DataAccessProfRecord::Locations,
- ElementsAre(
- AllOf(testing::Field(&DataLocation::FileName, "file1"),
- testing::Field(&DataLocation::Line, 1)),
- AllOf(testing::Field(&DataLocation::FileName, "file2"),
- testing::Field(&DataLocation::Line, 2)))))));
+ Field(&DataAccessProfRecordRef::SymbolID, 0),
+ Field(&DataAccessProfRecordRef::AccessCount, 100),
+ Field(&DataAccessProfRecordRef::IsStringLiteral, false),
+ Field(&DataAccessProfRecordRef::Locations, testing::IsEmpty())),
+ AllOf(Field(&DataAccessProfRecordRef::SymbolID, 1),
+ Field(&DataAccessProfRecordRef::AccessCount, 123),
+ Field(&DataAccessProfRecordRef::IsStringLiteral, false),
+ Field(&DataAccessProfRecordRef::Locations,
+ ElementsAre(
+ AllOf(Field(&SourceLocationRef::FileName, "file2"),
+ Field(&SourceLocationRef::Line, 3))))),
+ AllOf(Field(&DataAccessProfRecordRef::SymbolID, 135246),
+ Field(&DataAccessProfRecordRef::AccessCount, 1000),
+ Field(&DataAccessProfRecordRef::IsStringLiteral, true),
+ Field(&DataAccessProfRecordRef::Locations,
+ ElementsAre(
+ AllOf(Field(&SourceLocationRef::FileName, "file1"),
+ Field(&SourceLocationRef::Line, 1)),
+ AllOf(Field(&SourceLocationRef::FileName, "file2"),
+ Field(&SourceLocationRef::Line, 2)))))));
}
}
} // namespace
>From 4045c943966609cc9a92693752af0e29a19e1ef9 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 15 May 2025 09:56:52 -0700
Subject: [PATCH 07/10] Support reading and writing data access profiles in
memprof v4.
---
.../include/llvm/ProfileData/DataAccessProf.h | 12 +++-
.../llvm/ProfileData/IndexedMemProfData.h | 12 +++-
.../llvm/ProfileData/InstrProfReader.h | 6 +-
.../llvm/ProfileData/InstrProfWriter.h | 6 ++
llvm/include/llvm/ProfileData/MemProfReader.h | 15 +++++
llvm/include/llvm/ProfileData/MemProfYAML.h | 58 ++++++++++++++++++
llvm/lib/ProfileData/DataAccessProf.cpp | 6 +-
llvm/lib/ProfileData/IndexedMemProfData.cpp | 61 +++++++++++++++----
llvm/lib/ProfileData/InstrProfReader.cpp | 14 +++++
llvm/lib/ProfileData/InstrProfWriter.cpp | 20 ++++--
llvm/lib/ProfileData/MemProfReader.cpp | 30 +++++++++
.../tools/llvm-profdata/memprof-yaml.test | 11 ++++
llvm/tools/llvm-profdata/llvm-profdata.cpp | 4 ++
.../ProfileData/DataAccessProfTest.cpp | 11 ++--
14 files changed, 235 insertions(+), 31 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index e8504102238d1..f5f6abf0a2817 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -41,6 +41,8 @@ namespace data_access_prof {
struct SourceLocation {
SourceLocation(StringRef FileNameRef, uint32_t Line)
: FileName(FileNameRef.str()), Line(Line) {}
+
+ SourceLocation() {}
/// The filename where the data is located.
std::string FileName;
/// The line number in the source code.
@@ -53,6 +55,8 @@ namespace internal {
// which strings are owned by `DataAccessProfData`. Used by `DataAccessProfData`
// to represent data locations internally.
struct SourceLocationRef {
+ SourceLocationRef(StringRef FileNameRef, uint32_t Line)
+ : FileName(FileNameRef), Line(Line) {}
// The filename where the data is located.
StringRef FileName;
// The line number in the source code.
@@ -100,8 +104,9 @@ using SymbolHandle = std::variant<std::string, uint64_t>;
/// The data access profiles for a symbol.
struct DataAccessProfRecord {
public:
- DataAccessProfRecord(SymbolHandleRef SymHandleRef,
- ArrayRef<internal::SourceLocationRef> LocRefs) {
+ DataAccessProfRecord(SymbolHandleRef SymHandleRef, uint64_t AccessCount,
+ ArrayRef<internal::SourceLocationRef> LocRefs)
+ : AccessCount(AccessCount) {
if (std::holds_alternative<StringRef>(SymHandleRef)) {
SymHandle = std::get<StringRef>(SymHandleRef).str();
} else
@@ -110,8 +115,9 @@ struct DataAccessProfRecord {
for (auto Loc : LocRefs)
Locations.push_back(SourceLocation(Loc.FileName, Loc.Line));
}
+ DataAccessProfRecord() {}
 SymbolHandle SymHandle;
-
+ // Zero by default: the default constructor above sets no members, so
+ // without an initializer this field would be left indeterminate.
+ uint64_t AccessCount = 0;
// The locations of data in the source code. Optional.
SmallVector<SourceLocation> Locations;
};
diff --git a/llvm/include/llvm/ProfileData/IndexedMemProfData.h b/llvm/include/llvm/ProfileData/IndexedMemProfData.h
index 3c6c329d1c49d..66fa38472059b 100644
--- a/llvm/include/llvm/ProfileData/IndexedMemProfData.h
+++ b/llvm/include/llvm/ProfileData/IndexedMemProfData.h
@@ -10,14 +10,20 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/MemProf.h"
+#include <functional>
+#include <optional>
+
namespace llvm {
// Write the MemProf data to OS.
-Error writeMemProf(ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
- memprof::IndexedVersion MemProfVersionRequested,
- bool MemProfFullSchema);
+Error writeMemProf(
+ ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
+ memprof::IndexedVersion MemProfVersionRequested, bool MemProfFullSchema,
+ std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+ DataAccessProfileData);
} // namespace llvm
diff --git a/llvm/include/llvm/ProfileData/InstrProfReader.h b/llvm/include/llvm/ProfileData/InstrProfReader.h
index c250a9ede39bc..a3436e1dfe711 100644
--- a/llvm/include/llvm/ProfileData/InstrProfReader.h
+++ b/llvm/include/llvm/ProfileData/InstrProfReader.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/ProfileSummary.h"
#include "llvm/Object/BuildID.h"
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/InstrProfCorrelator.h"
#include "llvm/ProfileData/MemProf.h"
@@ -703,10 +704,13 @@ class IndexedMemProfReader {
const unsigned char *CallStackBase = nullptr;
// The number of elements in the radix tree array.
unsigned RadixTreeSize = 0;
+ /// The data access profiles, deserialized from binary data.
+ std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfileData;
Error deserializeV2(const unsigned char *Start, const unsigned char *Ptr);
Error deserializeRadixTreeBased(const unsigned char *Start,
- const unsigned char *Ptr);
+ const unsigned char *Ptr,
+ memprof::IndexedVersion Version);
public:
IndexedMemProfReader() = default;
diff --git a/llvm/include/llvm/ProfileData/InstrProfWriter.h b/llvm/include/llvm/ProfileData/InstrProfWriter.h
index 67d85daa81623..cf1cec25c3cac 100644
--- a/llvm/include/llvm/ProfileData/InstrProfWriter.h
+++ b/llvm/include/llvm/ProfileData/InstrProfWriter.h
@@ -19,6 +19,7 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/Object/BuildID.h"
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/MemProf.h"
#include "llvm/Support/Error.h"
@@ -81,6 +82,8 @@ class InstrProfWriter {
// Whether to generated random memprof hotness for testing.
bool MemprofGenerateRandomHotness;
+ std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfileData;
+
public:
// For memprof testing, random hotness can be assigned to the contexts if
// MemprofGenerateRandomHotness is enabled. The random seed can be either
@@ -122,6 +125,9 @@ class InstrProfWriter {
// Add a binary id to the binary ids list.
void addBinaryIds(ArrayRef<llvm::object::BuildID> BIs);
+ void addDataAccessProfData(
+ std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfile);
+
/// Merge existing function counts from the given writer.
void mergeRecordsFromWriter(InstrProfWriter &&IPW,
function_ref<void(Error)> Warn);
diff --git a/llvm/include/llvm/ProfileData/MemProfReader.h b/llvm/include/llvm/ProfileData/MemProfReader.h
index 29d9e57cae3e3..02defa189ea7e 100644
--- a/llvm/include/llvm/ProfileData/MemProfReader.h
+++ b/llvm/include/llvm/ProfileData/MemProfReader.h
@@ -228,6 +228,21 @@ class YAMLMemProfReader final : public MemProfReader {
create(std::unique_ptr<MemoryBuffer> Buffer);
void parse(StringRef YAMLData);
+
+  /// Hands ownership of the stored data access profiles to the caller. The
+  /// reader holds no profile data afterwards; the result is null when nothing
+  /// was stored.
+  std::unique_ptr<data_access_prof::DataAccessProfData>
+  takeDataAccessProfData() {
+    auto Taken = std::move(DataAccessProfileData);
+    return Taken;
+  }
+
+private:
+ // Called by `parse` to store the data access profiles it built from the
+ // YAML input; ownership of `Data` moves into this reader.
+ void setDataAccessProfileData(
+ std::unique_ptr<data_access_prof::DataAccessProfData> Data) {
+ DataAccessProfileData = std::move(Data);
+ }
+
+ std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfileData;
};
} // namespace memprof
} // namespace llvm
diff --git a/llvm/include/llvm/ProfileData/MemProfYAML.h b/llvm/include/llvm/ProfileData/MemProfYAML.h
index 08dee253f615a..634edc4aa0122 100644
--- a/llvm/include/llvm/ProfileData/MemProfYAML.h
+++ b/llvm/include/llvm/ProfileData/MemProfYAML.h
@@ -2,6 +2,7 @@
#define LLVM_PROFILEDATA_MEMPROFYAML_H_
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/MemProf.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/YAMLTraits.h"
@@ -20,9 +21,24 @@ struct GUIDMemProfRecordPair {
MemProfRecord Record;
};
+// YAML-facing mirror of data_access_prof::DataAccessProfData. Every member
+// owns its contents (including strings) to keep the YAML code simple, on the
+// assumption that real-world users mostly do look-ups and that YAML is only
+// exercised at regression-test scale.
+struct YamlDataAccessProfData {
+  std::vector<data_access_prof::DataAccessProfRecord> Records;
+  std::vector<uint64_t> KnownColdHashes;
+  std::vector<std::string> KnownColdSymbols;
+
+  // Returns true iff none of the three lists holds any entry.
+  bool isEmpty() const {
+    const bool HasAnyEntry = !Records.empty() || !KnownColdHashes.empty() ||
+                             !KnownColdSymbols.empty();
+    return !HasAnyEntry;
+  }
+};
+
// The top-level data structure, only used with YAML for now.
struct AllMemProfData {
std::vector<GUIDMemProfRecordPair> HeapProfileRecords;
+ YamlDataAccessProfData YamlifiedDataAccessProfiles;
};
} // namespace memprof
@@ -206,9 +222,49 @@ template <> struct MappingTraits<memprof::GUIDMemProfRecordPair> {
}
};
+// YAML mapping for a data location. Both keys are optional and are visited in
+// the order "FileName", then "Line".
+template <> struct MappingTraits<data_access_prof::SourceLocation> {
+  static void mapping(IO &YamlIO,
+                      data_access_prof::SourceLocation &Location) {
+    YamlIO.mapOptional("FileName", Location.FileName);
+    YamlIO.mapOptional("Line", Location.Line);
+  }
+};
+
+// YAML mapping for one data access record.
+//
+// A record is identified by either a "Symbol" (string) or a "Hash" (64-bit
+// integer) key, mirroring the alternative held by Rec.SymHandle.
+template <> struct MappingTraits<data_access_prof::DataAccessProfRecord> {
+  static void mapping(IO &Io, data_access_prof::DataAccessProfRecord &Rec) {
+    if (Io.outputting()) {
+      // Emit only the key for the alternative the handle currently holds.
+      if (std::holds_alternative<std::string>(Rec.SymHandle)) {
+        Io.mapOptional("Symbol", std::get<std::string>(Rec.SymHandle));
+      } else {
+        Io.mapOptional("Hash", std::get<uint64_t>(Rec.SymHandle));
+      }
+    } else {
+      // Probe both keys: a non-empty "Symbol" wins, otherwise "Hash" is used
+      // (which stays 0 when neither key is present).
+      std::string SymName;
+      uint64_t Hash = 0;
+      Io.mapOptional("Symbol", SymName);
+      Io.mapOptional("Hash", Hash);
+      if (!SymName.empty()) {
+        Rec.SymHandle = SymName;
+      } else {
+        Rec.SymHandle = Hash;
+      }
+    }
+
+    // Map the access count, defaulting to 0 when the key is absent so the
+    // value read back is always well-defined. A count equal to the default is
+    // omitted on output, so inputs without the key still round-trip.
+    Io.mapOptional("AccessCount", Rec.AccessCount, static_cast<uint64_t>(0));
+    Io.mapOptional("Locations", Rec.Locations);
+  }
+};
+
+// YAML mapping for the data access profile section. Every key is optional and
+// the keys are visited in the order listed here.
+template <> struct MappingTraits<memprof::YamlDataAccessProfData> {
+  static void mapping(IO &YamlIO, memprof::YamlDataAccessProfData &Profiles) {
+    YamlIO.mapOptional("SampledRecords", Profiles.Records);
+    YamlIO.mapOptional("KnownColdSymbols", Profiles.KnownColdSymbols);
+    YamlIO.mapOptional("KnownColdHashes", Profiles.KnownColdHashes);
+  }
+};
+
template <> struct MappingTraits<memprof::AllMemProfData> {
static void mapping(IO &Io, memprof::AllMemProfData &Data) {
Io.mapRequired("HeapProfileRecords", Data.HeapProfileRecords);
+ Io.mapOptional("DataAccessProfiles", Data.YamlifiedDataAccessProfiles);
}
};
@@ -234,5 +290,7 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::AllocationInfo)
LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::CallSiteInfo)
LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::GUIDMemProfRecordPair)
LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::GUIDHex64) // Used for CalleeGuids
+LLVM_YAML_IS_SEQUENCE_VECTOR(data_access_prof::DataAccessProfRecord)
+LLVM_YAML_IS_SEQUENCE_VECTOR(data_access_prof::SourceLocation)
#endif // LLVM_PROFILEDATA_MEMPROFYAML_H_
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index c5d0099977cfa..61a73fab7269f 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -48,7 +48,8 @@ DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const {
auto It = Records.find(Key);
if (It != Records.end()) {
- return DataAccessProfRecord(Key, It->second.Locations);
+ return DataAccessProfRecord(Key, It->second.AccessCount,
+ It->second.Locations);
}
return std::nullopt;
@@ -111,7 +112,8 @@ Error DataAccessProfData::addKnownSymbolWithoutSamples(
auto CanonicalName = getCanonicalName(std::get<StringRef>(SymbolID));
if (!CanonicalName)
return CanonicalName.takeError();
- KnownColdSymbols.insert(*CanonicalName);
+ KnownColdSymbols.insert(
+ saveStringToMap(StrToIndexMap, Saver, *CanonicalName).first);
return Error::success();
}
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index 3d20f7a7a5778..cc1b03101c880 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/InstrProfReader.h"
#include "llvm/ProfileData/MemProf.h"
@@ -216,7 +217,9 @@ static Error writeMemProfV2(ProfOStream &OS,
static Error writeMemProfRadixTreeBased(
ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
- memprof::IndexedVersion Version, bool MemProfFullSchema) {
+ memprof::IndexedVersion Version, bool MemProfFullSchema,
+ std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+ DataAccessProfileData) {
assert((Version == memprof::Version3 || Version == memprof::Version4) &&
"Unsupported version for radix tree format");
@@ -225,6 +228,8 @@ static Error writeMemProfRadixTreeBased(
OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
OS.write(0ULL); // Reserve space for the memprof record payload offset.
OS.write(0ULL); // Reserve space for the memprof record table offset.
+ if (Version == memprof::Version4)
+ OS.write(0ULL); // Reserve space for the data access profile offset.
auto Schema = memprof::getHotColdSchema();
if (MemProfFullSchema)
@@ -251,17 +256,26 @@ static Error writeMemProfRadixTreeBased(
uint64_t RecordTableOffset = writeMemProfRecords(
OS, MemProfData.Records, &Schema, Version, &MemProfCallStackIndexes);
+ uint64_t DataAccessProfOffset = 0;
+ if (DataAccessProfileData.has_value()) {
+ DataAccessProfOffset = OS.tell();
+ if (Error E = (*DataAccessProfileData).get().serialize(OS))
+ return E;
+ }
+
// Verify that the computation for the number of elements in the call stack
// array works.
assert(CallStackPayloadOffset +
NumElements * sizeof(memprof::LinearFrameId) ==
RecordPayloadOffset);
- uint64_t Header[] = {
+ SmallVector<uint64_t, 4> Header = {
CallStackPayloadOffset,
RecordPayloadOffset,
RecordTableOffset,
};
+ if (Version == memprof::Version4)
+ Header.push_back(DataAccessProfOffset);
OS.patch({{HeaderUpdatePos, Header}});
return Error::success();
@@ -272,28 +286,33 @@ static Error writeMemProfV3(ProfOStream &OS,
memprof::IndexedMemProfData &MemProfData,
bool MemProfFullSchema) {
return writeMemProfRadixTreeBased(OS, MemProfData, memprof::Version3,
- MemProfFullSchema);
+ MemProfFullSchema, std::nullopt);
}
// Write out MemProf Version4
-static Error writeMemProfV4(ProfOStream &OS,
- memprof::IndexedMemProfData &MemProfData,
- bool MemProfFullSchema) {
+static Error writeMemProfV4(
+ ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
+ bool MemProfFullSchema,
+ std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+ DataAccessProfileData) {
return writeMemProfRadixTreeBased(OS, MemProfData, memprof::Version4,
- MemProfFullSchema);
+ MemProfFullSchema, DataAccessProfileData);
}
// Write out the MemProf data in a requested version.
-Error writeMemProf(ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
- memprof::IndexedVersion MemProfVersionRequested,
- bool MemProfFullSchema) {
+Error writeMemProf(
+ ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
+ memprof::IndexedVersion MemProfVersionRequested, bool MemProfFullSchema,
+ std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+ DataAccessProfileData) {
switch (MemProfVersionRequested) {
case memprof::Version2:
return writeMemProfV2(OS, MemProfData, MemProfFullSchema);
case memprof::Version3:
return writeMemProfV3(OS, MemProfData, MemProfFullSchema);
case memprof::Version4:
- return writeMemProfV4(OS, MemProfData, MemProfFullSchema);
+ return writeMemProfV4(OS, MemProfData, MemProfFullSchema,
+ DataAccessProfileData);
}
return make_error<InstrProfError>(
@@ -357,7 +376,10 @@ Error IndexedMemProfReader::deserializeV2(const unsigned char *Start,
}
Error IndexedMemProfReader::deserializeRadixTreeBased(
- const unsigned char *Start, const unsigned char *Ptr) {
+ const unsigned char *Start, const unsigned char *Ptr,
+ memprof::IndexedVersion Version) {
+ assert((Version == memprof::Version3 || Version == memprof::Version4) &&
+ "Unsupported version for radix tree format");
// The offset in the stream right before invoking
// CallStackTableGenerator.Emit.
const uint64_t CallStackPayloadOffset =
@@ -369,6 +391,11 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
const uint64_t RecordTableOffset =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ uint64_t DataAccessProfOffset = 0;
+ if (Version == memprof::Version4)
+ DataAccessProfOffset =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
// Read the schema.
auto SchemaOr = memprof::readMemProfSchema(Ptr);
if (!SchemaOr)
@@ -390,6 +417,14 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
/*Payload=*/Start + RecordPayloadOffset,
/*Base=*/Start, memprof::RecordLookupTrait(Version, Schema)));
+ if (DataAccessProfOffset > RecordTableOffset) {
+ DataAccessProfileData =
+ std::make_unique<data_access_prof::DataAccessProfData>();
+ const unsigned char *DAPPtr = Start + DataAccessProfOffset;
+ if (Error E = DataAccessProfileData->deserialize(DAPPtr))
+ return E;
+ }
+
return Error::success();
}
@@ -423,7 +458,7 @@ Error IndexedMemProfReader::deserialize(const unsigned char *Start,
case memprof::Version3:
case memprof::Version4:
// V3 and V4 share the same high-level structure (radix tree, linear IDs).
- if (Error E = deserializeRadixTreeBased(Start, Ptr))
+ if (Error E = deserializeRadixTreeBased(Start, Ptr, Version))
return E;
break;
}
diff --git a/llvm/lib/ProfileData/InstrProfReader.cpp b/llvm/lib/ProfileData/InstrProfReader.cpp
index e6c83430cd8e9..78aba992fcd65 100644
--- a/llvm/lib/ProfileData/InstrProfReader.cpp
+++ b/llvm/lib/ProfileData/InstrProfReader.cpp
@@ -1551,6 +1551,20 @@ memprof::AllMemProfData IndexedMemProfReader::getAllMemProfData() const {
Pair.Record = std::move(*Record);
AllMemProfData.HeapProfileRecords.push_back(std::move(Pair));
}
+ // Populate the data access profiles for yaml output.
+ if (DataAccessProfileData != nullptr) {
+ for (const auto &[SymHandleRef, RecordRef] :
+ DataAccessProfileData->getRecords())
+ AllMemProfData.YamlifiedDataAccessProfiles.Records.push_back(
+ data_access_prof::DataAccessProfRecord(
+ SymHandleRef, RecordRef.AccessCount, RecordRef.Locations));
+ for (StringRef ColdSymbol : DataAccessProfileData->getKnownColdSymbols())
+ AllMemProfData.YamlifiedDataAccessProfiles.KnownColdSymbols.push_back(
+ ColdSymbol.str());
+ for (uint64_t Hash : DataAccessProfileData->getKnownColdHashes())
+ AllMemProfData.YamlifiedDataAccessProfiles.KnownColdHashes.push_back(
+ Hash);
+ }
return AllMemProfData;
}
diff --git a/llvm/lib/ProfileData/InstrProfWriter.cpp b/llvm/lib/ProfileData/InstrProfWriter.cpp
index 2759346935b14..b0012ccc4f4bf 100644
--- a/llvm/lib/ProfileData/InstrProfWriter.cpp
+++ b/llvm/lib/ProfileData/InstrProfWriter.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/ProfileSummary.h"
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/IndexedMemProfData.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/MemProf.h"
@@ -29,6 +30,7 @@
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <ctime>
+#include <functional>
#include <memory>
#include <string>
#include <tuple>
@@ -152,9 +154,7 @@ void InstrProfWriter::setValueProfDataEndianness(llvm::endianness Endianness) {
InfoObj->ValueProfDataEndianness = Endianness;
}
-void InstrProfWriter::setOutputSparse(bool Sparse) {
- this->Sparse = Sparse;
-}
+void InstrProfWriter::setOutputSparse(bool Sparse) { this->Sparse = Sparse; }
void InstrProfWriter::addRecord(NamedInstrProfRecord &&I, uint64_t Weight,
function_ref<void(Error)> Warn) {
@@ -329,6 +329,12 @@ void InstrProfWriter::addBinaryIds(ArrayRef<llvm::object::BuildID> BIs) {
llvm::append_range(BinaryIds, BIs);
}
+// Takes ownership of the given data access profiles, replacing any that this
+// writer already held.
+void InstrProfWriter::addDataAccessProfData(
+    std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfile) {
+  DataAccessProfileData = std::move(DataAccessProfile);
+}
+
void InstrProfWriter::addTemporalProfileTrace(TemporalProfTraceTy Trace) {
assert(Trace.FunctionNameRefs.size() <= MaxTemporalProfTraceLength);
assert(!Trace.FunctionNameRefs.empty());
@@ -614,8 +620,14 @@ Error InstrProfWriter::writeImpl(ProfOStream &OS) {
uint64_t MemProfSectionStart = 0;
if (static_cast<bool>(ProfileKind & InstrProfKind::MemProf)) {
MemProfSectionStart = OS.tell();
+ // Hand the data access profiles to writeMemProf by reference when this
+ // writer holds some; DAP stays empty otherwise.
+ std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+ DAP;
+ if (DataAccessProfileData)
+ DAP = std::ref(*DataAccessProfileData);
+
 if (auto E = writeMemProf(OS, MemProfData, MemProfVersionRequested,
- MemProfFullSchema))
+ MemProfFullSchema, DAP))
 return E;
}
diff --git a/llvm/lib/ProfileData/MemProfReader.cpp b/llvm/lib/ProfileData/MemProfReader.cpp
index e0f280b9eb2f6..2969029682e74 100644
--- a/llvm/lib/ProfileData/MemProfReader.cpp
+++ b/llvm/lib/ProfileData/MemProfReader.cpp
@@ -37,6 +37,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
@@ -823,6 +824,35 @@ void YAMLMemProfReader::parse(StringRef YAMLData) {
MemProfData.Records.try_emplace(GUID, std::move(IndexedRecord));
}
+
+ if (Doc.YamlifiedDataAccessProfiles.isEmpty())
+ return;
+
+ auto ToSymHandleRef = [](const data_access_prof::SymbolHandle &Handle)
+ -> data_access_prof::SymbolHandleRef {
+ if (std::holds_alternative<std::string>(Handle))
+ return StringRef(std::get<std::string>(Handle));
+ return std::get<uint64_t>(Handle);
+ };
+
+ auto DataAccessProfileData =
+ std::make_unique<data_access_prof::DataAccessProfData>();
+ for (const auto &Record : Doc.YamlifiedDataAccessProfiles.Records)
+ if (Error E = DataAccessProfileData->setDataAccessProfile(
+ ToSymHandleRef(Record.SymHandle), Record.AccessCount,
+ Record.Locations))
+ reportFatalInternalError(std::move(E));
+
+ for (const uint64_t Hash : Doc.YamlifiedDataAccessProfiles.KnownColdHashes)
+ if (Error E = DataAccessProfileData->addKnownSymbolWithoutSamples(Hash))
+ reportFatalInternalError(std::move(E));
+
+ for (const std::string &Sym :
+ Doc.YamlifiedDataAccessProfiles.KnownColdSymbols)
+ if (Error E = DataAccessProfileData->addKnownSymbolWithoutSamples(Sym))
+ reportFatalInternalError(std::move(E));
+
+ setDataAccessProfileData(std::move(DataAccessProfileData));
}
} // namespace memprof
} // namespace llvm
diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index 9766cc50f37d7..5e0c7fb3ea1d8 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -35,4 +35,15 @@ HeapProfileRecords:
- { Function: 0x7777777777777777, LineOffset: 77, Column: 70, IsInlineFrame: true }
- { Function: 0x8888888888888888, LineOffset: 88, Column: 80, IsInlineFrame: false }
CalleeGuids: [ 0x300 ]
+DataAccessProfiles:
+ SampledRecords:
+ - Symbol: abcde
+ - Hash: 101010
+ Locations:
+ - FileName: file
+ Line: 233
+ KnownColdSymbols:
+ - foo
+ - bar
+ KnownColdHashes: [ 999, 1001 ]
...
diff --git a/llvm/tools/llvm-profdata/llvm-profdata.cpp b/llvm/tools/llvm-profdata/llvm-profdata.cpp
index 885e06df6c390..8660eed6be2bf 100644
--- a/llvm/tools/llvm-profdata/llvm-profdata.cpp
+++ b/llvm/tools/llvm-profdata/llvm-profdata.cpp
@@ -16,6 +16,7 @@
#include "llvm/Debuginfod/HTTPClient.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Object/Binary.h"
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/InstrProfCorrelator.h"
#include "llvm/ProfileData/InstrProfReader.h"
#include "llvm/ProfileData/InstrProfWriter.h"
@@ -756,6 +757,8 @@ loadInput(const WeightedFile &Input, SymbolRemapper *Remapper,
auto MemProfData = Reader->takeMemProfData();
+ auto DataAccessProfData = Reader->takeDataAccessProfData();
+
// Check for the empty input in case the YAML file is invalid.
if (MemProfData.Records.empty()) {
WC->Errors.emplace_back(
@@ -764,6 +767,7 @@ loadInput(const WeightedFile &Input, SymbolRemapper *Remapper,
}
WC->Writer.addMemProfData(std::move(MemProfData), MemProfError);
+ WC->Writer.addDataAccessProfData(std::move(DataAccessProfData));
return;
}
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
index 127230d4805e7..084a8e96cdafe 100644
--- a/llvm/unittests/ProfileData/DataAccessProfTest.cpp
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -78,7 +78,7 @@ TEST(MemProf, DataAccessProfile) {
// Test that symbol names and file names are stored in the input order.
EXPECT_THAT(
llvm::to_vector(llvm::make_first_range(Data.getStrToIndexMapRef())),
- ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ ElementsAre("foo", "sym2", "bar.__uniq.321", "file2", "sym1", "file1"));
EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
@@ -134,9 +134,10 @@ TEST(MemProf, DataAccessProfile) {
testing::IsEmpty());
EXPECT_FALSE(deserializedData.deserialize(p));
- EXPECT_THAT(llvm::to_vector(llvm::make_first_range(
- deserializedData.getStrToIndexMapRef())),
- ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(
+ llvm::to_vector(
+ llvm::make_first_range(deserializedData.getStrToIndexMapRef())),
+ ElementsAre("foo", "sym2", "bar.__uniq.321", "file2", "sym1", "file1"));
EXPECT_THAT(deserializedData.getKnownColdSymbols(),
ElementsAre("sym2", "sym1"));
EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
@@ -158,7 +159,7 @@ TEST(MemProf, DataAccessProfile) {
Field(&DataAccessProfRecordRef::AccessCount, 100),
Field(&DataAccessProfRecordRef::IsStringLiteral, false),
Field(&DataAccessProfRecordRef::Locations, testing::IsEmpty())),
- AllOf(Field(&DataAccessProfRecordRef::SymbolID, 1),
+ AllOf(Field(&DataAccessProfRecordRef::SymbolID, 2),
Field(&DataAccessProfRecordRef::AccessCount, 123),
Field(&DataAccessProfRecordRef::IsStringLiteral, false),
Field(&DataAccessProfRecordRef::Locations,
>From c5d8fcf23c5db83cfaa73ce6af6329b50a180fc1 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 16 May 2025 12:26:11 -0700
Subject: [PATCH 08/10] resolve comments
---
llvm/include/llvm/ProfileData/MemProfYAML.h | 5 +-
llvm/lib/ProfileData/IndexedMemProfData.cpp | 6 ++-
.../tools/llvm-profdata/memprof-yaml.test | 47 ++++++++++++++++++-
3 files changed, 53 insertions(+), 5 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/MemProfYAML.h b/llvm/include/llvm/ProfileData/MemProfYAML.h
index 634edc4aa0122..cdfa69b400ee6 100644
--- a/llvm/include/llvm/ProfileData/MemProfYAML.h
+++ b/llvm/include/llvm/ProfileData/MemProfYAML.h
@@ -264,7 +264,10 @@ template <> struct MappingTraits<memprof::YamlDataAccessProfData> {
template <> struct MappingTraits<memprof::AllMemProfData> {
static void mapping(IO &Io, memprof::AllMemProfData &Data) {
Io.mapRequired("HeapProfileRecords", Data.HeapProfileRecords);
- Io.mapOptional("DataAccessProfiles", Data.YamlifiedDataAccessProfiles);
+ // Map data access profiles if reading input, or if writing output &&
+ // the struct is populated.
+ if (!Io.outputting() || !Data.YamlifiedDataAccessProfiles.isEmpty())
+ Io.mapOptional("DataAccessProfiles", Data.YamlifiedDataAccessProfiles);
}
};
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index cc1b03101c880..dc351771f4153 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -228,7 +228,7 @@ static Error writeMemProfRadixTreeBased(
OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
OS.write(0ULL); // Reserve space for the memprof record payload offset.
OS.write(0ULL); // Reserve space for the memprof record table offset.
- if (Version == memprof::Version4)
+ if (Version >= memprof::Version4)
OS.write(0ULL); // Reserve space for the data access profile offset.
auto Schema = memprof::getHotColdSchema();
@@ -258,6 +258,8 @@ static Error writeMemProfRadixTreeBased(
uint64_t DataAccessProfOffset = 0;
if (DataAccessProfileData.has_value()) {
+ assert(Version >= memprof::Version4 &&
+ "Data access profiles are added starting from v4");
DataAccessProfOffset = OS.tell();
if (Error E = (*DataAccessProfileData).get().serialize(OS))
return E;
@@ -274,7 +276,7 @@ static Error writeMemProfRadixTreeBased(
RecordPayloadOffset,
RecordTableOffset,
};
- if (Version == memprof::Version4)
+ if (Version >= memprof::Version4)
Header.push_back(DataAccessProfOffset);
OS.patch({{HeaderUpdatePos, Header}});
diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index 5e0c7fb3ea1d8..f5b6cf86d2922 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -1,10 +1,17 @@
; RUN: split-file %s %t
; COM: The text format only supports the latest version.
+
+; Verify that the YAML output is identical to the YAML input.
+; memprof-in.yaml has both heap profile records and data access profiles.
; RUN: llvm-profdata merge --memprof-version=4 %t/memprof-in.yaml -o %t/memprof-out.indexed
; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out.yaml
; RUN: diff -b %t/memprof-in.yaml %t/memprof-out.yaml
-; Verify that the YAML output is identical to the YAML input.
+; memprof-in-no-dap has empty data access profiles.
+; RUN: llvm-profdata merge --memprof-version=4 %t/memprof-in-no-dap.yaml -o %t/memprof-out.indexed
+; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out-no-dap.yaml
+; RUN: diff -b %t/memprof-in-no-dap.yaml %t/memprof-out-no-dap.yaml
+
;--- memprof-in.yaml
---
HeapProfileRecords:
@@ -38,12 +45,48 @@ HeapProfileRecords:
DataAccessProfiles:
SampledRecords:
- Symbol: abcde
+ Locations:
+ - FileName: file2.h
+ Line: 123
+ - FileName: file3.cpp
+ Line: 456
- Hash: 101010
Locations:
- - FileName: file
+ - FileName: file.cpp
Line: 233
KnownColdSymbols:
- foo
- bar
KnownColdHashes: [ 999, 1001 ]
...
+;--- memprof-in-no-dap.yaml
+---
+HeapProfileRecords:
+ - GUID: 0xdeadbeef12345678
+ AllocSites:
+ - Callstack:
+ - { Function: 0x1111111111111111, LineOffset: 11, Column: 10, IsInlineFrame: true }
+ - { Function: 0x2222222222222222, LineOffset: 22, Column: 20, IsInlineFrame: false }
+ MemInfoBlock:
+ AllocCount: 111
+ TotalSize: 222
+ TotalLifetime: 333
+ TotalLifetimeAccessDensity: 444
+ - Callstack:
+ - { Function: 0x3333333333333333, LineOffset: 33, Column: 30, IsInlineFrame: false }
+ - { Function: 0x4444444444444444, LineOffset: 44, Column: 40, IsInlineFrame: true }
+ MemInfoBlock:
+ AllocCount: 555
+ TotalSize: 666
+ TotalLifetime: 777
+ TotalLifetimeAccessDensity: 888
+ CallSites:
+ - Frames:
+ - { Function: 0x5555555555555555, LineOffset: 55, Column: 50, IsInlineFrame: true }
+ - { Function: 0x6666666666666666, LineOffset: 66, Column: 60, IsInlineFrame: false }
+ CalleeGuids: [ 0x100, 0x200 ]
+ - Frames:
+ - { Function: 0x7777777777777777, LineOffset: 77, Column: 70, IsInlineFrame: true }
+ - { Function: 0x8888888888888888, LineOffset: 88, Column: 80, IsInlineFrame: false }
+ CalleeGuids: [ 0x300 ]
+...
>From 0fb3e6a243415e3514e5904de48d5ca880026765 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 16 May 2025 16:24:09 -0700
Subject: [PATCH 09/10] record data access profile payload length
---
llvm/lib/ProfileData/IndexedMemProfData.cpp | 20 +++++++++++++++-----
1 file changed, 15 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index dc351771f4153..3ef568025c60c 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -228,8 +228,10 @@ static Error writeMemProfRadixTreeBased(
OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
OS.write(0ULL); // Reserve space for the memprof record payload offset.
OS.write(0ULL); // Reserve space for the memprof record table offset.
- if (Version >= memprof::Version4)
+ if (Version >= memprof::Version4) {
OS.write(0ULL); // Reserve space for the data access profile offset.
+ OS.write(0ULL); // Reserve space for the size of data access profiles.
+ }
auto Schema = memprof::getHotColdSchema();
if (MemProfFullSchema)
@@ -257,12 +259,14 @@ static Error writeMemProfRadixTreeBased(
OS, MemProfData.Records, &Schema, Version, &MemProfCallStackIndexes);
uint64_t DataAccessProfOffset = 0;
+ uint64_t DataAccessProfLength = 0;
if (DataAccessProfileData.has_value()) {
assert(Version >= memprof::Version4 &&
"Data access profiles are added starting from v4");
DataAccessProfOffset = OS.tell();
if (Error E = (*DataAccessProfileData).get().serialize(OS))
return E;
+ DataAccessProfLength = OS.tell() - DataAccessProfOffset;
}
// Verify that the computation for the number of elements in the call stack
@@ -271,13 +275,15 @@ static Error writeMemProfRadixTreeBased(
NumElements * sizeof(memprof::LinearFrameId) ==
RecordPayloadOffset);
- SmallVector<uint64_t, 4> Header = {
+ SmallVector<uint64_t, 8> Header = {
CallStackPayloadOffset,
RecordPayloadOffset,
RecordTableOffset,
};
- if (Version >= memprof::Version4)
+ if (Version >= memprof::Version4) {
Header.push_back(DataAccessProfOffset);
+ Header.push_back(DataAccessProfLength);
+ }
OS.patch({{HeaderUpdatePos, Header}});
return Error::success();
@@ -394,9 +400,13 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
uint64_t DataAccessProfOffset = 0;
- if (Version == memprof::Version4)
+ uint64_t DataAccessProfLength = 0;
+ if (Version == memprof::Version4) {
DataAccessProfOffset =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ DataAccessProfLength =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ }
// Read the schema.
auto SchemaOr = memprof::readMemProfSchema(Ptr);
@@ -419,7 +429,7 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
/*Payload=*/Start + RecordPayloadOffset,
/*Base=*/Start, memprof::RecordLookupTrait(Version, Schema)));
- if (DataAccessProfOffset > RecordTableOffset) {
+ if (DataAccessProfLength > 0) {
DataAccessProfileData =
std::make_unique<data_access_prof::DataAccessProfData>();
const unsigned char *DAPPtr = Start + DataAccessProfOffset;
>From d3443c752d77d4ddefea472a4787c501c97bb092 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 19 May 2025 16:44:05 -0700
Subject: [PATCH 10/10] undo length record and check
---
llvm/lib/ProfileData/IndexedMemProfData.cpp | 23 +++++++------------
.../tools/llvm-profdata/memprof-yaml.test | 2 +-
2 files changed, 9 insertions(+), 16 deletions(-)
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index 3ef568025c60c..d5f134629529c 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -228,10 +228,8 @@ static Error writeMemProfRadixTreeBased(
OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
OS.write(0ULL); // Reserve space for the memprof record payload offset.
OS.write(0ULL); // Reserve space for the memprof record table offset.
- if (Version >= memprof::Version4) {
+ if (Version >= memprof::Version4)
OS.write(0ULL); // Reserve space for the data access profile offset.
- OS.write(0ULL); // Reserve space for the size of data access profiles.
- }
auto Schema = memprof::getHotColdSchema();
if (MemProfFullSchema)
@@ -259,14 +257,12 @@ static Error writeMemProfRadixTreeBased(
OS, MemProfData.Records, &Schema, Version, &MemProfCallStackIndexes);
uint64_t DataAccessProfOffset = 0;
- uint64_t DataAccessProfLength = 0;
if (DataAccessProfileData.has_value()) {
assert(Version >= memprof::Version4 &&
"Data access profiles are added starting from v4");
DataAccessProfOffset = OS.tell();
if (Error E = (*DataAccessProfileData).get().serialize(OS))
return E;
- DataAccessProfLength = OS.tell() - DataAccessProfOffset;
}
// Verify that the computation for the number of elements in the call stack
@@ -275,15 +271,14 @@ static Error writeMemProfRadixTreeBased(
NumElements * sizeof(memprof::LinearFrameId) ==
RecordPayloadOffset);
- SmallVector<uint64_t, 8> Header = {
+ SmallVector<uint64_t, 4> Header = {
CallStackPayloadOffset,
RecordPayloadOffset,
RecordTableOffset,
};
- if (Version >= memprof::Version4) {
+ if (Version >= memprof::Version4)
Header.push_back(DataAccessProfOffset);
- Header.push_back(DataAccessProfLength);
- }
+
OS.patch({{HeaderUpdatePos, Header}});
return Error::success();
@@ -400,13 +395,9 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
uint64_t DataAccessProfOffset = 0;
- uint64_t DataAccessProfLength = 0;
- if (Version == memprof::Version4) {
+ if (Version == memprof::Version4)
DataAccessProfOffset =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
- DataAccessProfLength =
- support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
- }
// Read the schema.
auto SchemaOr = memprof::readMemProfSchema(Ptr);
@@ -429,7 +420,9 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
/*Payload=*/Start + RecordPayloadOffset,
/*Base=*/Start, memprof::RecordLookupTrait(Version, Schema)));
- if (DataAccessProfLength > 0) {
+ assert((!DataAccessProfOffset || DataAccessProfOffset > RecordTableOffset) &&
+ "Data access profile is either empty or after the record table");
+ if (DataAccessProfOffset > RecordTableOffset) {
DataAccessProfileData =
std::make_unique<data_access_prof::DataAccessProfData>();
const unsigned char *DAPPtr = Start + DataAccessProfOffset;
diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index f5b6cf86d2922..ff0d82d92e1f6 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -7,7 +7,7 @@
; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out.yaml
; RUN: diff -b %t/memprof-in.yaml %t/memprof-out.yaml
-; memprof-in-no-dap has empty data access profiles.
+; memprof-in-no-dap.yaml has empty data access profiles.
; RUN: llvm-profdata merge --memprof-version=4 %t/memprof-in-no-dap.yaml -o %t/memprof-out.indexed
; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out-no-dap.yaml
; RUN: diff -b %t/memprof-in-no-dap.yaml %t/memprof-out-no-dap.yaml
More information about the llvm-commits mailing list