[llvm] [StaticDataLayout][PGO] Add profile format for static data layout, and the classes to operate on the profiles. (PR #138170)
Mingming Liu via llvm-commits
llvm-commits at lists.llvm.org
Mon May 5 16:58:39 PDT 2025
https://github.com/mingmingl-llvm updated https://github.com/llvm/llvm-project/pull/138170
>From 6cd7d8dafd6a5cb438c5bd595e126f3cd863814e Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 1 May 2025 09:37:59 -0700
Subject: [PATCH 1/2] Add classes to construct and use data access profiles
---
llvm/include/llvm/ADT/MapVector.h | 2 +
.../include/llvm/ProfileData/DataAccessProf.h | 156 +++++++++++
llvm/include/llvm/ProfileData/InstrProf.h | 16 +-
llvm/lib/ProfileData/CMakeLists.txt | 1 +
llvm/lib/ProfileData/DataAccessProf.cpp | 246 ++++++++++++++++++
llvm/lib/ProfileData/InstrProf.cpp | 8 +-
llvm/unittests/ProfileData/MemProfTest.cpp | 161 ++++++++++++
7 files changed, 579 insertions(+), 11 deletions(-)
create mode 100644 llvm/include/llvm/ProfileData/DataAccessProf.h
create mode 100644 llvm/lib/ProfileData/DataAccessProf.cpp
diff --git a/llvm/include/llvm/ADT/MapVector.h b/llvm/include/llvm/ADT/MapVector.h
index c11617a81c97d..fe0d106795c34 100644
--- a/llvm/include/llvm/ADT/MapVector.h
+++ b/llvm/include/llvm/ADT/MapVector.h
@@ -57,6 +57,8 @@ class MapVector {
return std::move(Vector);
}
+ ArrayRef<value_type> getArrayRef() const { return Vector; }
+
size_type size() const { return Vector.size(); }
/// Grow the MapVector so that it can contain at least \p NumEntries items
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
new file mode 100644
index 0000000000000..2cce4945fddd5
--- /dev/null
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -0,0 +1,156 @@
+//===- DataAccessProf.h - Data access profile format support ----*- C++ -*-===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains support to construct and use data access profiles.
+//
+// For the original RFC of this feature, please see
+// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_PROFILEDATA_DATAACCESSPROF_H_
+#define LLVM_PROFILEDATA_DATAACCESSPROF_H_
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+
+#include <cstdint>
+#include <variant>
+
+namespace llvm {
+
+namespace data_access_prof {
+// The location of data in the source code.
+struct DataLocation {
+ // The filename where the data is located.
+ StringRef FileName;
+ // The line number in the source code.
+ uint32_t Line;
+};
+
+// The data access profiles for a symbol.
+struct DataAccessProfRecord {
+ // Represents a data symbol. The semantic comes in two forms: a symbol index
+ // for symbol name if `IsStringLiteral` is false, or the hash of a string
+ // content if `IsStringLiteral` is true. Required.
+ uint64_t SymbolID;
+
+ // The access count of symbol. Required.
+ uint64_t AccessCount;
+
+ // True iff this is a record for string literal (symbols with name pattern
+ // `.str.*` in the symbol table). Required.
+ bool IsStringLiteral;
+
+ // The locations of data in the source code. Optional.
+ llvm::SmallVector<DataLocation> Locations;
+};
+
+/// Encapsulates the data access profile data and the methods to operate on it.
+/// This class provides profile look-up, serialization and deserialization.
+class DataAccessProfData {
+public:
+ // SymbolID is either a string representing symbol name, or a uint64_t
+ // representing the content hash of a string literal.
+ using SymbolID = std::variant<StringRef, uint64_t>;
+ using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
+
+ DataAccessProfData() : saver(Allocator) {}
+
+ /// Serialize profile data to the output stream.
+ /// Storage layout:
+ /// - Serialized strings.
+ /// - The encoded hashes.
+ /// - Records.
+ Error serialize(ProfOStream &OS) const;
+
+ /// Deserialize this class from the given buffer.
+ Error deserialize(const unsigned char *&Ptr);
+
+ /// Returns a pointer to the profile record for \p SymID, or nullptr if there
+ /// isn't one. If \p SymID is a symbol name, it is canonicalized (see
+ /// InstrProfSymtab::getCanonicalName) before the lookup.
+ const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const;
+
+ /// Returns true if \p SymID is seen in profiled binaries and cold.
+ bool isKnownColdSymbol(const SymbolID SymID) const;
+
+ /// Methods to add symbolized data access profile. Returns error if duplicated
+ /// symbol names or content hashes are seen. The user of this class should
+ /// aggregate counters that corresponds to the same symbol name or with the
+ /// same string literal hash before calling 'add*' methods.
+ Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+ Error addSymbolizedDataAccessProfile(
+ SymbolID SymbolID, uint64_t AccessCount,
+ const llvm::SmallVector<DataLocation> &Locations);
+ Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
+
+ /// Returns a iterable StringRef for strings in the order they are added.
+ auto getStrings() const {
+ ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
+ StrToIndexMap.begin(), StrToIndexMap.end());
+ return llvm::make_first_range(RefSymbolNames);
+ }
+
+ /// Returns array reference for various internal data structures.
+ inline ArrayRef<
+ std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
+ getRecords() const {
+ return Records.getArrayRef();
+ }
+ inline ArrayRef<StringRef> getKnownColdSymbols() const {
+ return KnownColdSymbols.getArrayRef();
+ }
+ inline ArrayRef<uint64_t> getKnownColdHashes() const {
+ return KnownColdHashes.getArrayRef();
+ }
+
+private:
+ /// Serialize the symbol strings into the output stream.
+ Error serializeStrings(ProfOStream &OS) const;
+
+ /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
+ /// start of the next payload.
+ Error deserializeStrings(const unsigned char *&Ptr,
+ const uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols);
+
+ /// Decode the records and increment \p Ptr to the start of the next payload.
+ Error deserializeRecords(const unsigned char *&Ptr);
+
+ /// A helper function to compute a storage index for \p SymbolID.
+ uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+
+ // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
+ // its record index.
+ MapVector<SymbolID, DataAccessProfRecord> Records;
+
+ // Use MapVector to keep input order of strings for serialization and
+ // deserialization.
+ StringToIndexMap StrToIndexMap;
+ llvm::SetVector<uint64_t> KnownColdHashes;
+ llvm::SetVector<StringRef> KnownColdSymbols;
+ // Keeps owned copies of the input strings.
+ llvm::BumpPtrAllocator Allocator;
+ llvm::UniqueStringSaver saver;
+};
+
+} // namespace data_access_prof
+} // namespace llvm
+
+#endif // LLVM_PROFILEDATA_DATAACCESSPROF_H_
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 2d011c89f27cb..8a6be22bdb1a4 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -357,6 +357,12 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
/// the duplicated profile variables for Comdat functions.
bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
+/// \c NameStrings is a string composed of one or more possibly encoded
+/// sub-strings. The substrings are separated by 0 or more zero bytes. This
+/// method decodes the string and calls `NameCallback` for each substring.
+Error readAndDecodeStrings(StringRef NameStrings,
+ std::function<Error(StringRef)> NameCallback);
+
/// An enum describing the attributes of an instrumented profile.
enum class InstrProfKind {
Unknown = 0x0,
@@ -493,6 +499,11 @@ class InstrProfSymtab {
public:
using AddrHashMap = std::vector<std::pair<uint64_t, uint64_t>>;
+ // Returns the canonical name of the given PGOName. In a canonical name, all
+ // suffixes that begin with "." except ".__uniq." are stripped.
+ // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
+ static StringRef getCanonicalName(StringRef PGOName);
+
private:
using AddrIntervalMap =
IntervalMap<uint64_t, uint64_t, 4, IntervalMapHalfOpenInfo<uint64_t>>;
@@ -528,11 +539,6 @@ class InstrProfSymtab {
static StringRef getExternalSymbol() { return "** External Symbol **"; }
- // Returns the canonial name of the given PGOName. In a canonical name, all
- // suffixes that begins with "." except ".__uniq." are stripped.
- // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
- static StringRef getCanonicalName(StringRef PGOName);
-
// Add the function into the symbol table, by creating the following
// map entries:
// name-set = {PGOFuncName} union {getCanonicalName(PGOFuncName)}
diff --git a/llvm/lib/ProfileData/CMakeLists.txt b/llvm/lib/ProfileData/CMakeLists.txt
index eb7c2a3c1a28a..67a69d7761b2c 100644
--- a/llvm/lib/ProfileData/CMakeLists.txt
+++ b/llvm/lib/ProfileData/CMakeLists.txt
@@ -1,4 +1,5 @@
add_llvm_component_library(LLVMProfileData
+ DataAccessProf.cpp
GCOV.cpp
IndexedMemProfData.cpp
InstrProf.cpp
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
new file mode 100644
index 0000000000000..cf538d6a1b28e
--- /dev/null
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -0,0 +1,246 @@
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Compression.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <sys/types.h>
+
+namespace llvm {
+namespace data_access_prof {
+
+// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
+// creates an owned copy of `Str`, adds a map entry for it and returns the
+// iterator.
+static MapVector<StringRef, uint64_t>::iterator
+saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+ llvm::UniqueStringSaver &saver, StringRef Str) {
+ auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+ return Iter;
+}
+
+const DataAccessProfRecord *
+DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+ auto Key = SymbolID;
+ if (std::holds_alternative<StringRef>(SymbolID))
+ Key = InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymbolID));
+
+ auto It = Records.find(Key);
+ if (It != Records.end())
+ return &It->second;
+
+ return nullptr;
+}
+
+bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+ if (std::holds_alternative<uint64_t>(SymID))
+ return KnownColdHashes.count(std::get<uint64_t>(SymID));
+ return KnownColdSymbols.count(std::get<StringRef>(SymID));
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
+ uint64_t AccessCount) {
+ uint64_t RecordID = -1;
+ bool IsStringLiteral = false;
+ SymbolID Key;
+ if (std::holds_alternative<uint64_t>(Symbol)) {
+ RecordID = std::get<uint64_t>(Symbol);
+ Key = RecordID;
+ IsStringLiteral = true;
+ } else {
+ StringRef SymbolName = std::get<StringRef>(Symbol);
+ if (SymbolName.empty())
+ return make_error<StringError>("Empty symbol name",
+ llvm::errc::invalid_argument);
+
+ StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
+ Key = CanonicalName;
+ RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second;
+ IsStringLiteral = false;
+ }
+
+ auto [Iter, Inserted] = Records.try_emplace(
+ Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+ if (!Inserted)
+ return make_error<StringError>("Duplicate symbol or string literal added. "
+ "User of DataAccessProfData should "
+ "aggregate count for the same symbol. ",
+ llvm::errc::invalid_argument);
+
+ return Error::success();
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(
+ SymbolID SymbolID, uint64_t AccessCount,
+ const llvm::SmallVector<DataLocation> &Locations) {
+ if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ return E;
+
+ auto &Record = Records.back().second;
+ for (const auto &Location : Locations)
+ Record.Locations.push_back(
+ {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+ Location.Line});
+
+ return Error::success();
+}
+
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+ if (std::holds_alternative<uint64_t>(SymbolID)) {
+ KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
+ return Error::success();
+ }
+ StringRef SymbolName = std::get<StringRef>(SymbolID);
+ if (SymbolName.empty())
+ return make_error<StringError>("Empty symbol name",
+ llvm::errc::invalid_argument);
+ StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName);
+ KnownColdSymbols.insert(CanonicalSymName);
+ return Error::success();
+}
+
+Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
+ uint64_t NumSampledSymbols =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ uint64_t NumColdKnownSymbols =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+ return E;
+
+ uint64_t Num =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ for (uint64_t I = 0; I < Num; ++I)
+ KnownColdHashes.insert(
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr));
+
+ return deserializeRecords(Ptr);
+}
+
+Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+ OS.write(StrToIndexMap.size());
+ OS.write(KnownColdSymbols.size());
+
+ std::vector<std::string> Strs;
+ Strs.reserve(StrToIndexMap.size() + KnownColdSymbols.size());
+ for (const auto &Str : StrToIndexMap)
+ Strs.push_back(Str.first.str());
+ for (const auto &Str : KnownColdSymbols)
+ Strs.push_back(Str.str());
+
+ std::string CompressedStrings;
+ if (!Strs.empty())
+ if (Error E = collectGlobalObjectNameStrings(
+ Strs, compression::zlib::isAvailable(), CompressedStrings))
+ return E;
+ const uint64_t CompressedStringLen = CompressedStrings.length();
+ // Record the length of compressed string.
+ OS.write(CompressedStringLen);
+ // Write the chars in compressed strings.
+ for (auto &c : CompressedStrings)
+ OS.writeByte(static_cast<uint8_t>(c));
+ // Pad up to a multiple of 8.
+ // InstrProfReader could read bytes according to 'CompressedStringLen'.
+ const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
+ for (uint64_t K = CompressedStringLen; K < PaddedLength; K++)
+ OS.writeByte(0);
+ return Error::success();
+}
+
+uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+ if (std::holds_alternative<uint64_t>(SymbolID))
+ return std::get<uint64_t>(SymbolID);
+
+ return StrToIndexMap.find(std::get<StringRef>(SymbolID))->second;
+}
+
+Error DataAccessProfData::serialize(ProfOStream &OS) const {
+ if (Error E = serializeStrings(OS))
+ return E;
+ OS.write(KnownColdHashes.size());
+ for (const auto &Hash : KnownColdHashes)
+ OS.write(Hash);
+ OS.write((uint64_t)(Records.size()));
+ for (const auto &[Key, Rec] : Records) {
+ OS.write(getEncodedIndex(Rec.SymbolID));
+ OS.writeByte(Rec.IsStringLiteral);
+ OS.write(Rec.AccessCount);
+ OS.write(Rec.Locations.size());
+ for (const auto &Loc : Rec.Locations) {
+ OS.write(getEncodedIndex(Loc.FileName));
+ OS.write32(Loc.Line);
+ }
+ }
+ return Error::success();
+}
+
+Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
+ uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols) {
+ uint64_t Len =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
+ // symbols with samples, and next N strings are known cold symbols.
+ uint64_t StringCnt = 0;
+ std::function<Error(StringRef)> addName = [&](StringRef Name) {
+ if (StringCnt < NumSampledSymbols)
+ saveStringToMap(StrToIndexMap, saver, Name);
+ else
+ KnownColdSymbols.insert(saver.save(Name));
+ ++StringCnt;
+ return Error::success();
+ };
+ if (Error E =
+ readAndDecodeStrings(StringRef((const char *)Ptr, Len), addName))
+ return E;
+
+ Ptr += alignTo(Len, 8);
+ return Error::success();
+}
+
+Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
+ SmallVector<StringRef> Strings = llvm::to_vector(getStrings());
+
+ uint64_t NumRecords =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ for (uint64_t I = 0; I < NumRecords; ++I) {
+ uint64_t ID =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ bool IsStringLiteral =
+ support::endian::readNext<uint8_t, llvm::endianness::little>(Ptr);
+
+ uint64_t AccessCount =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ SymbolID SymbolID;
+ if (IsStringLiteral)
+ SymbolID = ID;
+ else
+ SymbolID = Strings[ID];
+ if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ return E;
+
+ auto &Record = Records.back().second;
+
+ uint64_t NumLocations =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ Record.Locations.reserve(NumLocations);
+ for (uint64_t J = 0; J < NumLocations; ++J) {
+ uint64_t FileNameIndex =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+ uint32_t Line =
+ support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
+ Record.Locations.push_back({Strings[FileNameIndex], Line});
+ }
+ }
+ return Error::success();
+}
+} // namespace data_access_prof
+} // namespace llvm
diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp
index 88621787c1dd9..254f941acde82 100644
--- a/llvm/lib/ProfileData/InstrProf.cpp
+++ b/llvm/lib/ProfileData/InstrProf.cpp
@@ -573,12 +573,8 @@ Error InstrProfSymtab::addVTableWithName(GlobalVariable &VTable,
return Error::success();
}
-/// \c NameStrings is a string composed of one of more possibly encoded
-/// sub-strings. The substrings are separated by 0 or more zero bytes. This
-/// method decodes the string and calls `NameCallback` for each substring.
-static Error
-readAndDecodeStrings(StringRef NameStrings,
- std::function<Error(StringRef)> NameCallback) {
+Error readAndDecodeStrings(StringRef NameStrings,
+ std::function<Error(StringRef)> NameCallback) {
const uint8_t *P = NameStrings.bytes_begin();
const uint8_t *EndP = NameStrings.bytes_end();
while (P < EndP) {
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index 3e430aa4eae58..f6362448a5734 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -10,14 +10,17 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLForwardCompat.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/DebugInfo/DIContext.h"
#include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
#include "llvm/IR/Value.h"
#include "llvm/Object/ObjectFile.h"
+#include "llvm/ProfileData/DataAccessProf.h"
#include "llvm/ProfileData/MemProfData.inc"
#include "llvm/ProfileData/MemProfReader.h"
#include "llvm/ProfileData/MemProfYAML.h"
#include "llvm/Support/raw_ostream.h"
+#include "gmock/gmock-more-matchers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -36,6 +39,8 @@ using ::llvm::StringRef;
using ::llvm::object::SectionedAddress;
using ::llvm::symbolize::SymbolizableModule;
using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Pair;
using ::testing::Return;
@@ -747,6 +752,162 @@ TEST(MemProf, YAMLParser) {
ElementsAre(0x3000)))));
}
+static std::string ErrorToString(Error E) {
+ std::string ErrMsg;
+ llvm::raw_string_ostream OS(ErrMsg);
+ llvm::logAllUnhandledErrors(std::move(E), OS);
+ return ErrMsg;
+}
+
+// Test the various scenarios when DataAccessProfData should return error on
+// invalid input.
+TEST(MemProf, DataAccessProfileError) {
+ // Returns error if the input symbol name is empty.
+ llvm::data_access_prof::DataAccessProfData Data;
+ EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+ HasSubstr("Empty symbol name"));
+
+ // Returns error when the same symbol gets added twice.
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
+ EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+ HasSubstr("Duplicate symbol or string literal added"));
+
+ // Returns error when the same string content hash gets added twice.
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
+ EXPECT_THAT(ErrorToString(
+ Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+ HasSubstr("Duplicate symbol or string literal added"));
+}
+
+// Test the following operations on DataAccessProfData:
+// - Profile record look up.
+// - Serialization and de-serialization.
+TEST(MemProf, DataAccessProfile) {
+ using namespace llvm::data_access_prof;
+ llvm::data_access_prof::DataAccessProfData Data;
+
+ // In the bool conversion, Error is true if it's in a failure state and false
+ // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
+ {
+ DataLocation{"file2", 3},
+ }));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
+ ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
+ ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+ (uint64_t)135246, 1000,
+ {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+
+ {
+ // Test that symbol names and file names are stored in the input order.
+ EXPECT_THAT(llvm::to_vector(Data.getStrings()),
+ ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
+ EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
+
+ // Look up profiles.
+ EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
+ EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
+ EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
+ EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
+
+ EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
+ EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+
+ EXPECT_THAT(
+ Data.getProfileRecord("foo.llvm.123"),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+ testing::Field(&DataAccessProfRecord::AccessCount, 100),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ testing::IsEmpty())));
+ EXPECT_THAT(
+ *Data.getProfileRecord("bar.__uniq.321"),
+ AllOf(
+ testing::Field(&DataAccessProfRecord::SymbolID, 1),
+ testing::Field(&DataAccessProfRecord::AccessCount, 123),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(
+ testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 3))))));
+ EXPECT_THAT(
+ *Data.getProfileRecord((uint64_t)135246),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+ testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(
+ AllOf(testing::Field(&DataLocation::FileName, "file1"),
+ testing::Field(&DataLocation::Line, 1)),
+ AllOf(testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 2))))));
+ }
+
+ // Tests serialization and de-serialization.
+ llvm::data_access_prof::DataAccessProfData deserializedData;
+ {
+ std::string serializedData;
+ llvm::raw_string_ostream OS(serializedData);
+ llvm::ProfOStream POS(OS);
+
+ EXPECT_FALSE(Data.serialize(POS));
+
+ const unsigned char *p =
+ reinterpret_cast<const unsigned char *>(serializedData.data());
+ ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ testing::IsEmpty());
+ EXPECT_FALSE(deserializedData.deserialize(p));
+
+ EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+ ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+ EXPECT_THAT(deserializedData.getKnownColdSymbols(),
+ ElementsAre("sym2", "sym1"));
+ EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
+
+ // Look up profiles after deserialization.
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
+ EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
+
+ auto Records =
+ llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
+
+ EXPECT_THAT(
+ Records,
+ ElementsAre(
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+ testing::Field(&DataAccessProfRecord::AccessCount, 100),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(&DataAccessProfRecord::Locations,
+ testing::IsEmpty())),
+ AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
+ testing::Field(&DataAccessProfRecord::AccessCount, 123),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(AllOf(
+ testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 3))))),
+ AllOf(
+ testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+ testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+ testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+ testing::Field(
+ &DataAccessProfRecord::Locations,
+ ElementsAre(
+ AllOf(testing::Field(&DataLocation::FileName, "file1"),
+ testing::Field(&DataLocation::Line, 1)),
+ AllOf(testing::Field(&DataLocation::FileName, "file2"),
+ testing::Field(&DataLocation::Line, 2)))))));
+ }
+}
+
// Verify that the YAML parser accepts a GUID expressed as a function name.
TEST(MemProf, YAMLParserGUID) {
StringRef YAMLData = R"YAML(
>From 47275299fd3e9ef242f108040df8ffe46f9cd7b0 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 5 May 2025 16:57:56 -0700
Subject: [PATCH 2/2] resolve review feedback
---
.../include/llvm/ProfileData/DataAccessProf.h | 37 ++++++++++------
llvm/lib/ProfileData/DataAccessProf.cpp | 43 ++++++++++---------
llvm/unittests/ProfileData/MemProfTest.cpp | 23 +++++-----
3 files changed, 57 insertions(+), 46 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 2cce4945fddd5..36648d6298ee5 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -45,6 +45,11 @@ struct DataLocation {
// The data access profiles for a symbol.
struct DataAccessProfRecord {
+ DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount,
+ bool IsStringLiteral)
+ : SymbolID(SymbolID), AccessCount(AccessCount),
+ IsStringLiteral(IsStringLiteral) {}
+
// Represents a data symbol. The semantic comes in two forms: a symbol index
// for symbol name if `IsStringLiteral` is false, or the hash of a string
// content if `IsStringLiteral` is true. Required.
@@ -58,7 +63,7 @@ struct DataAccessProfRecord {
bool IsStringLiteral;
// The locations of data in the source code. Optional.
- llvm::SmallVector<DataLocation> Locations;
+ llvm::SmallVector<DataLocation, 0> Locations;
};
/// Encapsulates the data access profile data and the methods to operate on it.
@@ -70,7 +75,7 @@ class DataAccessProfData {
using SymbolID = std::variant<StringRef, uint64_t>;
using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
- DataAccessProfData() : saver(Allocator) {}
+ DataAccessProfData() : Saver(Allocator) {}
/// Serialize profile data to the output stream.
/// Storage layout:
@@ -94,10 +99,13 @@ class DataAccessProfData {
/// symbol names or content hashes are seen. The user of this class should
/// aggregate counters that corresponds to the same symbol name or with the
/// same string literal hash before calling 'add*' methods.
- Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
- Error addSymbolizedDataAccessProfile(
- SymbolID SymbolID, uint64_t AccessCount,
- const llvm::SmallVector<DataLocation> &Locations);
+ Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+ /// Similar to the method above, for records with \p Locations representing
+ /// the `filename:line` where this symbol shows up. Note because of linker's
+ /// merge of identical symbols (e.g., unnamed_addr string literals), one
+ /// symbol is likely to have multiple locations.
+ Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount,
+ const llvm::SmallVector<DataLocation> &Locations);
Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
/// Returns a iterable StringRef for strings in the order they are added.
@@ -122,13 +130,13 @@ class DataAccessProfData {
private:
/// Serialize the symbol strings into the output stream.
- Error serializeStrings(ProfOStream &OS) const;
+ Error serializeSymbolsAndFilenames(ProfOStream &OS) const;
/// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
/// start of the next payload.
- Error deserializeStrings(const unsigned char *&Ptr,
- const uint64_t NumSampledSymbols,
- uint64_t NumColdKnownSymbols);
+ Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr,
+ const uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols);
/// Decode the records and increment \p Ptr to the start of the next payload.
Error deserializeRecords(const unsigned char *&Ptr);
@@ -136,6 +144,12 @@ class DataAccessProfData {
/// A helper function to compute a storage index for \p SymbolID.
uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+ // Keeps owned copies of the input strings.
+ // NOTE: Keep `Saver` initialized before other class members that reference
+ // its string copies and destructed after they are destructed.
+ llvm::BumpPtrAllocator Allocator;
+ llvm::UniqueStringSaver Saver;
+
// `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
// its record index.
MapVector<SymbolID, DataAccessProfRecord> Records;
@@ -145,9 +159,6 @@ class DataAccessProfData {
StringToIndexMap StrToIndexMap;
llvm::SetVector<uint64_t> KnownColdHashes;
llvm::SetVector<StringRef> KnownColdSymbols;
- // Keeps owned copies of the input strings.
- llvm::BumpPtrAllocator Allocator;
- llvm::UniqueStringSaver saver;
};
} // namespace data_access_prof
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index cf538d6a1b28e..c52533c13919c 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -18,8 +18,8 @@ namespace data_access_prof {
// iterator.
static MapVector<StringRef, uint64_t>::iterator
saveStringToMap(MapVector<StringRef, uint64_t> &Map,
- llvm::UniqueStringSaver &saver, StringRef Str) {
- auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+ llvm::UniqueStringSaver &Saver, StringRef Str) {
+ auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
return Iter;
}
@@ -38,12 +38,12 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
if (std::holds_alternative<uint64_t>(SymID))
- return KnownColdHashes.count(std::get<uint64_t>(SymID));
- return KnownColdSymbols.count(std::get<StringRef>(SymID));
+ return KnownColdHashes.contains(std::get<uint64_t>(SymID));
+ return KnownColdSymbols.contains(std::get<StringRef>(SymID));
}
-Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
- uint64_t AccessCount) {
+Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+ uint64_t AccessCount) {
uint64_t RecordID = -1;
bool IsStringLiteral = false;
SymbolID Key;
@@ -59,12 +59,12 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
Key = CanonicalName;
- RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second;
+ RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
IsStringLiteral = false;
}
- auto [Iter, Inserted] = Records.try_emplace(
- Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+ auto [Iter, Inserted] =
+ Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral);
if (!Inserted)
return make_error<StringError>("Duplicate symbol or string literal added. "
"User of DataAccessProfData should "
@@ -74,16 +74,16 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
return Error::success();
}
-Error DataAccessProfData::addSymbolizedDataAccessProfile(
+Error DataAccessProfData::setDataAccessProfile(
SymbolID SymbolID, uint64_t AccessCount,
const llvm::SmallVector<DataLocation> &Locations) {
- if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ if (Error E = setDataAccessProfile(SymbolID, AccessCount))
return E;
auto &Record = Records.back().second;
for (const auto &Location : Locations)
Record.Locations.push_back(
- {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+ {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first,
Location.Line});
return Error::success();
@@ -108,7 +108,8 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
uint64_t NumColdKnownSymbols =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
- if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+ if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols,
+ NumColdKnownSymbols))
return E;
uint64_t Num =
@@ -120,7 +121,7 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
return deserializeRecords(Ptr);
}
-Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
OS.write(StrToIndexMap.size());
OS.write(KnownColdSymbols.size());
@@ -158,7 +159,7 @@ uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
}
Error DataAccessProfData::serialize(ProfOStream &OS) const {
- if (Error E = serializeStrings(OS))
+ if (Error E = serializeSymbolsAndFilenames(OS))
return E;
OS.write(KnownColdHashes.size());
for (const auto &Hash : KnownColdHashes)
@@ -177,9 +178,9 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const {
return Error::success();
}
-Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
- uint64_t NumSampledSymbols,
- uint64_t NumColdKnownSymbols) {
+Error DataAccessProfData::deserializeSymbolsAndFilenames(
+ const unsigned char *&Ptr, uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols) {
uint64_t Len =
support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
@@ -188,9 +189,9 @@ Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
uint64_t StringCnt = 0;
std::function<Error(StringRef)> addName = [&](StringRef Name) {
if (StringCnt < NumSampledSymbols)
- saveStringToMap(StrToIndexMap, saver, Name);
+ saveStringToMap(StrToIndexMap, Saver, Name);
else
- KnownColdSymbols.insert(saver.save(Name));
+ KnownColdSymbols.insert(Saver.save(Name));
++StringCnt;
return Error::success();
};
@@ -223,7 +224,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
SymbolID = ID;
else
SymbolID = Strings[ID];
- if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+ if (Error E = setDataAccessProfile(SymbolID, AccessCount))
return E;
auto &Record = Records.back().second;
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index f6362448a5734..b7b8d642ad930 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -764,18 +764,17 @@ static std::string ErrorToString(Error E) {
TEST(MemProf, DataAccessProfileError) {
// Returns error if the input symbol name is empty.
llvm::data_access_prof::DataAccessProfData Data;
- EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
HasSubstr("Empty symbol name"));
// Returns error when the same symbol gets added twice.
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
- EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+ ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
HasSubstr("Duplicate symbol or string literal added"));
// Returns error when the same string content hash gets added twice.
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
- EXPECT_THAT(ErrorToString(
- Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+ ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
+ EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
HasSubstr("Duplicate symbol or string literal added"));
}
@@ -788,16 +787,16 @@ TEST(MemProf, DataAccessProfile) {
// In the bool conversion, Error is true if it's in a failure state and false
// if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+ ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
- {
- DataLocation{"file2", 3},
- }));
+ ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
+ {
+ DataLocation{"file2", 3},
+ }));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
- ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+ ASSERT_FALSE(Data.setDataAccessProfile(
(uint64_t)135246, 1000,
{DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
More information about the llvm-commits
mailing list