[llvm] [StaticDataLayout][PGO] Add profile format for static data layout, and the classes to operate on the profiles. (PR #138170)

Mingming Liu via llvm-commits llvm-commits at lists.llvm.org
Mon May 12 22:17:59 PDT 2025


https://github.com/mingmingl-llvm updated https://github.com/llvm/llvm-project/pull/138170

>From 6cd7d8dafd6a5cb438c5bd595e126f3cd863814e Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 1 May 2025 09:37:59 -0700
Subject: [PATCH 1/6] Add classes to construct and use data access profiles

---
 llvm/include/llvm/ADT/MapVector.h             |   2 +
 .../include/llvm/ProfileData/DataAccessProf.h | 156 +++++++++++
 llvm/include/llvm/ProfileData/InstrProf.h     |  16 +-
 llvm/lib/ProfileData/CMakeLists.txt           |   1 +
 llvm/lib/ProfileData/DataAccessProf.cpp       | 246 ++++++++++++++++++
 llvm/lib/ProfileData/InstrProf.cpp            |   8 +-
 llvm/unittests/ProfileData/MemProfTest.cpp    | 161 ++++++++++++
 7 files changed, 579 insertions(+), 11 deletions(-)
 create mode 100644 llvm/include/llvm/ProfileData/DataAccessProf.h
 create mode 100644 llvm/lib/ProfileData/DataAccessProf.cpp

diff --git a/llvm/include/llvm/ADT/MapVector.h b/llvm/include/llvm/ADT/MapVector.h
index c11617a81c97d..fe0d106795c34 100644
--- a/llvm/include/llvm/ADT/MapVector.h
+++ b/llvm/include/llvm/ADT/MapVector.h
@@ -57,6 +57,8 @@ class MapVector {
     return std::move(Vector);
   }
 
+  ArrayRef<value_type> getArrayRef() const { return Vector; }
+
   size_type size() const { return Vector.size(); }
 
   /// Grow the MapVector so that it can contain at least \p NumEntries items
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
new file mode 100644
index 0000000000000..2cce4945fddd5
--- /dev/null
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -0,0 +1,156 @@
+//===- DataAccessProf.h - Data access profile format support ---------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains support to construct and use data access profiles.
+//
+// For the original RFC of this pass please see
+// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_PROFILEDATA_DATAACCESSPROF_H_
+#define LLVM_PROFILEDATA_DATAACCESSPROF_H_
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+
+#include <cstdint>
+#include <variant>
+
+namespace llvm {
+
+namespace data_access_prof {
+// The location of data in the source code.
+struct DataLocation {
+  // The filename where the data is located.
+  StringRef FileName;
+  // The line number in the source code.
+  uint32_t Line;
+};
+
+// The data access profiles for a symbol.
+struct DataAccessProfRecord {
+  // Represents a data symbol. The semantic comes in two forms: a symbol index
+  // for symbol name if `IsStringLiteral` is false, or the hash of a string
+  // content if `IsStringLiteral` is true. Required.
+  uint64_t SymbolID;
+
+  // The access count of symbol. Required.
+  uint64_t AccessCount;
+
+  // True iff this is a record for string literal (symbols with name pattern
+  // `.str.*` in the symbol table). Required.
+  bool IsStringLiteral;
+
+  // The locations of data in the source code. Optional.
+  llvm::SmallVector<DataLocation> Locations;
+};
+
+/// Encapsulates the data access profile data and the methods to operate on it.
+/// This class provides profile look-up, serialization and deserialization.
+class DataAccessProfData {
+public:
+  // SymbolID is either a string representing symbol name, or a uint64_t
+  // representing the content hash of a string literal.
+  using SymbolID = std::variant<StringRef, uint64_t>;
+  using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
+
+  DataAccessProfData() : saver(Allocator) {}
+
+  /// Serialize profile data to the output stream.
+  /// Storage layout:
+  /// - Serialized strings.
+  /// - The encoded hashes.
+  /// - Records.
+  Error serialize(ProfOStream &OS) const;
+
+  /// Deserialize this class from the given buffer.
+  Error deserialize(const unsigned char *&Ptr);
+
+  /// Returns a pointer to the profile record for \p SymID, or nullptr if there
+  /// isn't a record. Internally, this function will canonicalize the symbol
+  /// name before the lookup.
+  const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const;
+
+  /// Returns true if \p SymID is seen in profiled binaries and cold.
+  bool isKnownColdSymbol(const SymbolID SymID) const;
+
+  /// Methods to add symbolized data access profile. Returns error if duplicated
+  /// symbol names or content hashes are seen. The user of this class should
+  /// aggregate counters that corresponds to the same symbol name or with the
+  /// same string literal hash before calling 'add*' methods.
+  Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+  Error addSymbolizedDataAccessProfile(
+      SymbolID SymbolID, uint64_t AccessCount,
+      const llvm::SmallVector<DataLocation> &Locations);
+  Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
+
+  /// Returns a iterable StringRef for strings in the order they are added.
+  auto getStrings() const {
+    ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
+        StrToIndexMap.begin(), StrToIndexMap.end());
+    return llvm::make_first_range(RefSymbolNames);
+  }
+
+  /// Returns array reference for various internal data structures.
+  inline ArrayRef<
+      std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
+  getRecords() const {
+    return Records.getArrayRef();
+  }
+  inline ArrayRef<StringRef> getKnownColdSymbols() const {
+    return KnownColdSymbols.getArrayRef();
+  }
+  inline ArrayRef<uint64_t> getKnownColdHashes() const {
+    return KnownColdHashes.getArrayRef();
+  }
+
+private:
+  /// Serialize the symbol strings into the output stream.
+  Error serializeStrings(ProfOStream &OS) const;
+
+  /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
+  /// start of the next payload.
+  Error deserializeStrings(const unsigned char *&Ptr,
+                           const uint64_t NumSampledSymbols,
+                           uint64_t NumColdKnownSymbols);
+
+  /// Decode the records and increment \p Ptr to the start of the next payload.
+  Error deserializeRecords(const unsigned char *&Ptr);
+
+  /// A helper function to compute a storage index for \p SymbolID.
+  uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+
+  // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
+  // its record index.
+  MapVector<SymbolID, DataAccessProfRecord> Records;
+
+  // Use MapVector to keep input order of strings for serialization and
+  // deserialization.
+  StringToIndexMap StrToIndexMap;
+  llvm::SetVector<uint64_t> KnownColdHashes;
+  llvm::SetVector<StringRef> KnownColdSymbols;
+  // Keeps owned copies of the input strings.
+  llvm::BumpPtrAllocator Allocator;
+  llvm::UniqueStringSaver saver;
+};
+
+} // namespace data_access_prof
+} // namespace llvm
+
+#endif // LLVM_PROFILEDATA_DATAACCESSPROF_H_
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 2d011c89f27cb..8a6be22bdb1a4 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -357,6 +357,12 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
 /// the duplicated profile variables for Comdat functions.
 bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
 
+/// \c NameStrings is a string composed of one or more possibly encoded
+/// sub-strings. The substrings are separated by 0 or more zero bytes. This
+/// method decodes the string and calls `NameCallback` for each substring.
+Error readAndDecodeStrings(StringRef NameStrings,
+                           std::function<Error(StringRef)> NameCallback);
+
 /// An enum describing the attributes of an instrumented profile.
 enum class InstrProfKind {
   Unknown = 0x0,
@@ -493,6 +499,11 @@ class InstrProfSymtab {
 public:
   using AddrHashMap = std::vector<std::pair<uint64_t, uint64_t>>;
 
+  // Returns the canonical name of the given PGOName. In a canonical name, all
+  // suffixes that begin with "." except ".__uniq." are stripped.
+  // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
+  static StringRef getCanonicalName(StringRef PGOName);
+
 private:
   using AddrIntervalMap =
       IntervalMap<uint64_t, uint64_t, 4, IntervalMapHalfOpenInfo<uint64_t>>;
@@ -528,11 +539,6 @@ class InstrProfSymtab {
 
   static StringRef getExternalSymbol() { return "** External Symbol **"; }
 
-  // Returns the canonial name of the given PGOName. In a canonical name, all
-  // suffixes that begins with "." except ".__uniq." are stripped.
-  // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
-  static StringRef getCanonicalName(StringRef PGOName);
-
   // Add the function into the symbol table, by creating the following
   // map entries:
   // name-set = {PGOFuncName} union {getCanonicalName(PGOFuncName)}
diff --git a/llvm/lib/ProfileData/CMakeLists.txt b/llvm/lib/ProfileData/CMakeLists.txt
index eb7c2a3c1a28a..67a69d7761b2c 100644
--- a/llvm/lib/ProfileData/CMakeLists.txt
+++ b/llvm/lib/ProfileData/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_component_library(LLVMProfileData
+  DataAccessProf.cpp
   GCOV.cpp
   IndexedMemProfData.cpp
   InstrProf.cpp
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
new file mode 100644
index 0000000000000..cf538d6a1b28e
--- /dev/null
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -0,0 +1,246 @@
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Compression.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <sys/types.h>
+
+namespace llvm {
+namespace data_access_prof {
+
+// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
+// creates an owned copy of `Str`, adds a map entry for it and returns the
+// iterator.
+static MapVector<StringRef, uint64_t>::iterator
+saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+                llvm::UniqueStringSaver &saver, StringRef Str) {
+  auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+  return Iter;
+}
+
+const DataAccessProfRecord *
+DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+  auto Key = SymbolID;
+  if (std::holds_alternative<StringRef>(SymbolID))
+    Key = InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymbolID));
+
+  auto It = Records.find(Key);
+  if (It != Records.end())
+    return &It->second;
+
+  return nullptr;
+}
+
+bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+  if (const auto *Name = std::get_if<StringRef>(&SymID)) // Keys are canonical.
+    return KnownColdSymbols.count(InstrProfSymtab::getCanonicalName(*Name));
+  return KnownColdHashes.count(std::get<uint64_t>(SymID));
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
+                                                         uint64_t AccessCount) {
+  uint64_t RecordID = -1;
+  bool IsStringLiteral = false;
+  SymbolID Key;
+  if (std::holds_alternative<uint64_t>(Symbol)) {
+    RecordID = std::get<uint64_t>(Symbol);
+    Key = RecordID;
+    IsStringLiteral = true;
+  } else {
+    StringRef SymbolName = std::get<StringRef>(Symbol);
+    if (SymbolName.empty())
+      return make_error<StringError>("Empty symbol name",
+                                     llvm::errc::invalid_argument);
+
+    StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
+    auto StrIter = saveStringToMap(StrToIndexMap, saver, CanonicalName);
+    Key = StrIter->first; // Owned copy, not the caller's buffer.
+    RecordID = StrIter->second;
+  }
+
+  auto [Iter, Inserted] = Records.try_emplace(
+      Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+  if (!Inserted)
+    return make_error<StringError>("Duplicate symbol or string literal added. "
+                                   "User of DataAccessProfData should "
+                                   "aggregate count for the same symbol. ",
+                                   llvm::errc::invalid_argument);
+
+  return Error::success();
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(
+    SymbolID SymbolID, uint64_t AccessCount,
+    const llvm::SmallVector<DataLocation> &Locations) {
+  if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+    return E;
+
+  auto &Record = Records.back().second;
+  for (const auto &Location : Locations)
+    Record.Locations.push_back(
+        {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+         Location.Line});
+
+  return Error::success();
+}
+
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+  if (std::holds_alternative<uint64_t>(SymbolID)) {
+    KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
+    return Error::success();
+  }
+  StringRef SymbolName = std::get<StringRef>(SymbolID);
+  if (SymbolName.empty())
+    return make_error<StringError>("Empty symbol name",
+                                   llvm::errc::invalid_argument);
+  StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName);
+  KnownColdSymbols.insert(saver.save(CanonicalSymName)); // Keep an owned copy.
+  return Error::success();
+}
+
+Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
+  uint64_t NumSampledSymbols =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  uint64_t NumColdKnownSymbols =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+    return E;
+
+  uint64_t Num =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  for (uint64_t I = 0; I < Num; ++I)
+    KnownColdHashes.insert(
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr));
+
+  return deserializeRecords(Ptr);
+}
+
+Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+  OS.write(StrToIndexMap.size());
+  OS.write(KnownColdSymbols.size());
+
+  std::vector<std::string> Strs;
+  Strs.reserve(StrToIndexMap.size() + KnownColdSymbols.size());
+  for (const auto &Str : StrToIndexMap)
+    Strs.push_back(Str.first.str());
+  for (const auto &Str : KnownColdSymbols)
+    Strs.push_back(Str.str());
+
+  std::string CompressedStrings;
+  if (!Strs.empty())
+    if (Error E = collectGlobalObjectNameStrings(
+            Strs, compression::zlib::isAvailable(), CompressedStrings))
+      return E;
+  const uint64_t CompressedStringLen = CompressedStrings.length();
+  // Record the length of compressed string.
+  OS.write(CompressedStringLen);
+  // Write the chars in compressed strings.
+  for (auto &c : CompressedStrings)
+    OS.writeByte(static_cast<uint8_t>(c));
+  // Pad up to a multiple of 8.
+  // InstrProfReader could read bytes according to 'CompressedStringLen'.
+  const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
+  for (uint64_t K = CompressedStringLen; K < PaddedLength; K++)
+    OS.writeByte(0);
+  return Error::success();
+}
+
+uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+  if (std::holds_alternative<uint64_t>(SymbolID))
+    return std::get<uint64_t>(SymbolID);
+
+  return StrToIndexMap.find(std::get<StringRef>(SymbolID))->second;
+}
+
+Error DataAccessProfData::serialize(ProfOStream &OS) const {
+  if (Error E = serializeStrings(OS))
+    return E;
+  OS.write(KnownColdHashes.size());
+  for (const auto &Hash : KnownColdHashes)
+    OS.write(Hash);
+  OS.write((uint64_t)(Records.size()));
+  for (const auto &[Key, Rec] : Records) {
+    OS.write(getEncodedIndex(Rec.SymbolID));
+    OS.writeByte(Rec.IsStringLiteral);
+    OS.write(Rec.AccessCount);
+    OS.write(Rec.Locations.size());
+    for (const auto &Loc : Rec.Locations) {
+      OS.write(getEncodedIndex(Loc.FileName));
+      OS.write32(Loc.Line);
+    }
+  }
+  return Error::success();
+}
+
+Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
+                                             uint64_t NumSampledSymbols,
+                                             uint64_t NumColdKnownSymbols) {
+  uint64_t Len =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+  // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
+  // symbols with samples, and next N strings are known cold symbols.
+  uint64_t StringCnt = 0;
+  std::function<Error(StringRef)> addName = [&](StringRef Name) {
+    if (StringCnt < NumSampledSymbols)
+      saveStringToMap(StrToIndexMap, saver, Name);
+    else
+      KnownColdSymbols.insert(saver.save(Name));
+    ++StringCnt;
+    return Error::success();
+  };
+  if (Error E =
+          readAndDecodeStrings(StringRef((const char *)Ptr, Len), addName))
+    return E;
+
+  Ptr += alignTo(Len, 8);
+  return Error::success();
+}
+
+Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
+  SmallVector<StringRef> Strings = llvm::to_vector(getStrings());
+
+  uint64_t NumRecords =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+  for (uint64_t I = 0; I < NumRecords; ++I) {
+    uint64_t ID =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+    bool IsStringLiteral =
+        support::endian::readNext<uint8_t, llvm::endianness::little>(Ptr);
+
+    uint64_t AccessCount =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+    SymbolID SymbolID;
+    if (IsStringLiteral)
+      SymbolID = ID;
+    else
+      SymbolID = Strings[ID];
+    if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+      return E;
+
+    auto &Record = Records.back().second;
+
+    uint64_t NumLocations =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+    Record.Locations.reserve(NumLocations);
+    for (uint64_t J = 0; J < NumLocations; ++J) {
+      uint64_t FileNameIndex =
+          support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+      uint32_t Line =
+          support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
+      Record.Locations.push_back({Strings[FileNameIndex], Line});
+    }
+  }
+  return Error::success();
+}
+} // namespace data_access_prof
+} // namespace llvm
diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp
index 88621787c1dd9..254f941acde82 100644
--- a/llvm/lib/ProfileData/InstrProf.cpp
+++ b/llvm/lib/ProfileData/InstrProf.cpp
@@ -573,12 +573,8 @@ Error InstrProfSymtab::addVTableWithName(GlobalVariable &VTable,
   return Error::success();
 }
 
-/// \c NameStrings is a string composed of one of more possibly encoded
-/// sub-strings. The substrings are separated by 0 or more zero bytes. This
-/// method decodes the string and calls `NameCallback` for each substring.
-static Error
-readAndDecodeStrings(StringRef NameStrings,
-                     std::function<Error(StringRef)> NameCallback) {
+Error readAndDecodeStrings(StringRef NameStrings,
+                           std::function<Error(StringRef)> NameCallback) {
   const uint8_t *P = NameStrings.bytes_begin();
   const uint8_t *EndP = NameStrings.bytes_end();
   while (P < EndP) {
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index 3e430aa4eae58..f6362448a5734 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -10,14 +10,17 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLForwardCompat.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/DebugInfo/DIContext.h"
 #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Object/ObjectFile.h"
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/MemProfData.inc"
 #include "llvm/ProfileData/MemProfReader.h"
 #include "llvm/ProfileData/MemProfYAML.h"
 #include "llvm/Support/raw_ostream.h"
+#include "gmock/gmock-more-matchers.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
@@ -36,6 +39,8 @@ using ::llvm::StringRef;
 using ::llvm::object::SectionedAddress;
 using ::llvm::symbolize::SymbolizableModule;
 using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::HasSubstr;
 using ::testing::IsEmpty;
 using ::testing::Pair;
 using ::testing::Return;
@@ -747,6 +752,162 @@ TEST(MemProf, YAMLParser) {
                                ElementsAre(0x3000)))));
 }
 
+static std::string ErrorToString(Error E) {
+  std::string ErrMsg;
+  llvm::raw_string_ostream OS(ErrMsg);
+  llvm::logAllUnhandledErrors(std::move(E), OS);
+  return ErrMsg;
+}
+
+// Test the various scenarios when DataAccessProfData should return error on
+// invalid input.
+TEST(MemProf, DataAccessProfileError) {
+  // Returns error if the input symbol name is empty.
+  llvm::data_access_prof::DataAccessProfData Data;
+  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+              HasSubstr("Empty symbol name"));
+
+  // Returns error when the same symbol gets added twice.
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
+  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+              HasSubstr("Duplicate symbol or string literal added"));
+
+  // Returns error when the same string content hash gets added twice.
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
+  EXPECT_THAT(ErrorToString(
+                  Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+              HasSubstr("Duplicate symbol or string literal added"));
+}
+
+// Test the following operations on DataAccessProfData:
+// - Profile record look up.
+// - Serialization and de-serialization.
+TEST(MemProf, DataAccessProfile) {
+  using namespace llvm::data_access_prof;
+  llvm::data_access_prof::DataAccessProfData Data;
+
+  // In the bool conversion, Error is true if it's in a failure state and false
+  // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
+                                                   {
+                                                       DataLocation{"file2", 3},
+                                                   }));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+      (uint64_t)135246, 1000,
+      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+
+  {
+    // Test that symbol names and file names are stored in the input order.
+    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles.
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
+
+    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
+    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+
+    EXPECT_THAT(
+        Data.getProfileRecord("foo.llvm.123"),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+              testing::Field(&DataAccessProfRecord::AccessCount, 100),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+              testing::Field(&DataAccessProfRecord::Locations,
+                             testing::IsEmpty())));
+    EXPECT_THAT(
+        *Data.getProfileRecord("bar.__uniq.321"),
+        AllOf(
+            testing::Field(&DataAccessProfRecord::SymbolID, 1),
+            testing::Field(&DataAccessProfRecord::AccessCount, 123),
+            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+            testing::Field(&DataAccessProfRecord::Locations,
+                           ElementsAre(AllOf(
+                               testing::Field(&DataLocation::FileName, "file2"),
+                               testing::Field(&DataLocation::Line, 3))))));
+    EXPECT_THAT(
+        *Data.getProfileRecord((uint64_t)135246),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+              testing::Field(
+                  &DataAccessProfRecord::Locations,
+                  ElementsAre(
+                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                            testing::Field(&DataLocation::Line, 1)),
+                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                            testing::Field(&DataLocation::Line, 2))))));
+  }
+
+  // Tests serialization and de-serialization.
+  llvm::data_access_prof::DataAccessProfData deserializedData;
+  {
+    std::string serializedData;
+    llvm::raw_string_ostream OS(serializedData);
+    llvm::ProfOStream POS(OS);
+
+    EXPECT_FALSE(Data.serialize(POS));
+
+    const unsigned char *p =
+        reinterpret_cast<const unsigned char *>(serializedData.data());
+    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                testing::IsEmpty());
+    EXPECT_FALSE(deserializedData.deserialize(p));
+
+    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(deserializedData.getKnownColdSymbols(),
+                ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles after deserialization.
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
+
+    auto Records =
+        llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
+
+    EXPECT_THAT(
+        Records,
+        ElementsAre(
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(&DataAccessProfRecord::Locations,
+                                 testing::IsEmpty())),
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(
+                      &DataAccessProfRecord::Locations,
+                      ElementsAre(AllOf(
+                          testing::Field(&DataLocation::FileName, "file2"),
+                          testing::Field(&DataLocation::Line, 3))))),
+            AllOf(
+                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+                testing::Field(
+                    &DataAccessProfRecord::Locations,
+                    ElementsAre(
+                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                              testing::Field(&DataLocation::Line, 1)),
+                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                              testing::Field(&DataLocation::Line, 2)))))));
+  }
+}
+
 // Verify that the YAML parser accepts a GUID expressed as a function name.
 TEST(MemProf, YAMLParserGUID) {
   StringRef YAMLData = R"YAML(

>From 47275299fd3e9ef242f108040df8ffe46f9cd7b0 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 5 May 2025 16:57:56 -0700
Subject: [PATCH 2/6] resolve review feedback

---
 .../include/llvm/ProfileData/DataAccessProf.h | 37 ++++++++++------
 llvm/lib/ProfileData/DataAccessProf.cpp       | 43 ++++++++++---------
 llvm/unittests/ProfileData/MemProfTest.cpp    | 23 +++++-----
 3 files changed, 57 insertions(+), 46 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 2cce4945fddd5..36648d6298ee5 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -45,6 +45,11 @@ struct DataLocation {
 
 // The data access profiles for a symbol.
 struct DataAccessProfRecord {
+  DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount,
+                       bool IsStringLiteral)
+      : SymbolID(SymbolID), AccessCount(AccessCount),
+        IsStringLiteral(IsStringLiteral) {}
+
   // Represents a data symbol. The semantic comes in two forms: a symbol index
   // for symbol name if `IsStringLiteral` is false, or the hash of a string
   // content if `IsStringLiteral` is true. Required.
@@ -58,7 +63,7 @@ struct DataAccessProfRecord {
   bool IsStringLiteral;
 
   // The locations of data in the source code. Optional.
-  llvm::SmallVector<DataLocation> Locations;
+  llvm::SmallVector<DataLocation, 0> Locations;
 };
 
 /// Encapsulates the data access profile data and the methods to operate on it.
@@ -70,7 +75,7 @@ class DataAccessProfData {
   using SymbolID = std::variant<StringRef, uint64_t>;
   using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
 
-  DataAccessProfData() : saver(Allocator) {}
+  DataAccessProfData() : Saver(Allocator) {}
 
   /// Serialize profile data to the output stream.
   /// Storage layout:
@@ -94,10 +99,13 @@ class DataAccessProfData {
   /// symbol names or content hashes are seen. The user of this class should
   /// aggregate counters that corresponds to the same symbol name or with the
   /// same string literal hash before calling 'add*' methods.
-  Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
-  Error addSymbolizedDataAccessProfile(
-      SymbolID SymbolID, uint64_t AccessCount,
-      const llvm::SmallVector<DataLocation> &Locations);
+  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+  /// Similar to the method above, for records with \p Locations representing
+  /// the `filename:line` where this symbol shows up. Note because of linker's
+  /// merge of identical symbols (e.g., unnamed_addr string literals), one
+  /// symbol is likely to have multiple locations.
+  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount,
+                             const llvm::SmallVector<DataLocation> &Locations);
   Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
 
   /// Returns a iterable StringRef for strings in the order they are added.
@@ -122,13 +130,13 @@ class DataAccessProfData {
 
 private:
   /// Serialize the symbol strings into the output stream.
-  Error serializeStrings(ProfOStream &OS) const;
+  Error serializeSymbolsAndFilenames(ProfOStream &OS) const;
 
   /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
   /// start of the next payload.
-  Error deserializeStrings(const unsigned char *&Ptr,
-                           const uint64_t NumSampledSymbols,
-                           uint64_t NumColdKnownSymbols);
+  Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr,
+                                       const uint64_t NumSampledSymbols,
+                                       uint64_t NumColdKnownSymbols);
 
   /// Decode the records and increment \p Ptr to the start of the next payload.
   Error deserializeRecords(const unsigned char *&Ptr);
@@ -136,6 +144,12 @@ class DataAccessProfData {
   /// A helper function to compute a storage index for \p SymbolID.
   uint64_t getEncodedIndex(const SymbolID SymbolID) const;
 
+  // Keeps owned copies of the input strings.
+  // NOTE: Keep `Saver` initialized before other class members that reference
+  // its string copies and destructed after they are destructed.
+  llvm::BumpPtrAllocator Allocator;
+  llvm::UniqueStringSaver Saver;
+
   // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
   // its record index.
   MapVector<SymbolID, DataAccessProfRecord> Records;
@@ -145,9 +159,6 @@ class DataAccessProfData {
   StringToIndexMap StrToIndexMap;
   llvm::SetVector<uint64_t> KnownColdHashes;
   llvm::SetVector<StringRef> KnownColdSymbols;
-  // Keeps owned copies of the input strings.
-  llvm::BumpPtrAllocator Allocator;
-  llvm::UniqueStringSaver saver;
 };
 
 } // namespace data_access_prof
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index cf538d6a1b28e..c52533c13919c 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -18,8 +18,8 @@ namespace data_access_prof {
 // iterator.
 static MapVector<StringRef, uint64_t>::iterator
 saveStringToMap(MapVector<StringRef, uint64_t> &Map,
-                llvm::UniqueStringSaver &saver, StringRef Str) {
-  auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+                llvm::UniqueStringSaver &Saver, StringRef Str) {
+  auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
   return Iter;
 }
 
@@ -38,12 +38,12 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
 
 bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
   if (std::holds_alternative<uint64_t>(SymID))
-    return KnownColdHashes.count(std::get<uint64_t>(SymID));
-  return KnownColdSymbols.count(std::get<StringRef>(SymID));
+    return KnownColdHashes.contains(std::get<uint64_t>(SymID));
+  return KnownColdSymbols.contains(std::get<StringRef>(SymID));
 }
 
-Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
-                                                         uint64_t AccessCount) {
+Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+                                               uint64_t AccessCount) {
   uint64_t RecordID = -1;
   bool IsStringLiteral = false;
   SymbolID Key;
@@ -59,12 +59,12 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
 
     StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
     Key = CanonicalName;
-    RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second;
+    RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
     IsStringLiteral = false;
   }
 
-  auto [Iter, Inserted] = Records.try_emplace(
-      Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+  auto [Iter, Inserted] =
+      Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral);
   if (!Inserted)
     return make_error<StringError>("Duplicate symbol or string literal added. "
                                    "User of DataAccessProfData should "
@@ -74,16 +74,16 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
   return Error::success();
 }
 
-Error DataAccessProfData::addSymbolizedDataAccessProfile(
+Error DataAccessProfData::setDataAccessProfile(
     SymbolID SymbolID, uint64_t AccessCount,
     const llvm::SmallVector<DataLocation> &Locations) {
-  if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+  if (Error E = setDataAccessProfile(SymbolID, AccessCount))
     return E;
 
   auto &Record = Records.back().second;
   for (const auto &Location : Locations)
     Record.Locations.push_back(
-        {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+        {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first,
          Location.Line});
 
   return Error::success();
@@ -108,7 +108,8 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
   uint64_t NumColdKnownSymbols =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
-  if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+  if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols,
+                                               NumColdKnownSymbols))
     return E;
 
   uint64_t Num =
@@ -120,7 +121,7 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
   return deserializeRecords(Ptr);
 }
 
-Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   OS.write(StrToIndexMap.size());
   OS.write(KnownColdSymbols.size());
 
@@ -158,7 +159,7 @@ uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
 }
 
 Error DataAccessProfData::serialize(ProfOStream &OS) const {
-  if (Error E = serializeStrings(OS))
+  if (Error E = serializeSymbolsAndFilenames(OS))
     return E;
   OS.write(KnownColdHashes.size());
   for (const auto &Hash : KnownColdHashes)
@@ -177,9 +178,9 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const {
   return Error::success();
 }
 
-Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
-                                             uint64_t NumSampledSymbols,
-                                             uint64_t NumColdKnownSymbols) {
+Error DataAccessProfData::deserializeSymbolsAndFilenames(
+    const unsigned char *&Ptr, uint64_t NumSampledSymbols,
+    uint64_t NumColdKnownSymbols) {
   uint64_t Len =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
@@ -188,9 +189,9 @@ Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
   uint64_t StringCnt = 0;
   std::function<Error(StringRef)> addName = [&](StringRef Name) {
     if (StringCnt < NumSampledSymbols)
-      saveStringToMap(StrToIndexMap, saver, Name);
+      saveStringToMap(StrToIndexMap, Saver, Name);
     else
-      KnownColdSymbols.insert(saver.save(Name));
+      KnownColdSymbols.insert(Saver.save(Name));
     ++StringCnt;
     return Error::success();
   };
@@ -223,7 +224,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
       SymbolID = ID;
     else
       SymbolID = Strings[ID];
-    if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+    if (Error E = setDataAccessProfile(SymbolID, AccessCount))
       return E;
 
     auto &Record = Records.back().second;
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index f6362448a5734..b7b8d642ad930 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -764,18 +764,17 @@ static std::string ErrorToString(Error E) {
 TEST(MemProf, DataAccessProfileError) {
   // Returns error if the input symbol name is empty.
   llvm::data_access_prof::DataAccessProfData Data;
-  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
               HasSubstr("Empty symbol name"));
 
   // Returns error when the same symbol gets added twice.
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
-  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+  ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
               HasSubstr("Duplicate symbol or string literal added"));
 
   // Returns error when the same string content hash gets added twice.
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
-  EXPECT_THAT(ErrorToString(
-                  Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+  ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
               HasSubstr("Duplicate symbol or string literal added"));
 }
 
@@ -788,16 +787,16 @@ TEST(MemProf, DataAccessProfile) {
 
   // In the bool conversion, Error is true if it's in a failure state and false
   // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+  ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
-                                                   {
-                                                       DataLocation{"file2", 3},
-                                                   }));
+  ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
+                                         {
+                                             DataLocation{"file2", 3},
+                                         }));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+  ASSERT_FALSE(Data.setDataAccessProfile(
       (uint64_t)135246, 1000,
       {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
 

>From 80249bce82307ffef1817747d1c2fca100a9ea53 Mon Sep 17 00:00:00 2001
From: Mingming Liu <mingmingl at google.com>
Date: Tue, 6 May 2025 10:39:53 -0700
Subject: [PATCH 3/6] Apply suggestions from code review

Co-authored-by: Kazu Hirata <kazu at google.com>
---
 llvm/include/llvm/ProfileData/DataAccessProf.h | 11 ++++++-----
 llvm/include/llvm/ProfileData/InstrProf.h      |  2 +-
 llvm/lib/ProfileData/DataAccessProf.cpp        |  2 +-
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 36648d6298ee5..81289b60b96ed 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -97,7 +97,7 @@ class DataAccessProfData {
 
   /// Methods to add symbolized data access profile. Returns error if duplicated
   /// symbol names or content hashes are seen. The user of this class should
-  /// aggregate counters that corresponds to the same symbol name or with the
+  /// aggregate counters that correspond to the same symbol name or with the
   /// same string literal hash before calling 'add*' methods.
   Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
   /// Similar to the method above, for records with \p Locations representing
@@ -108,7 +108,8 @@ class DataAccessProfData {
                              const llvm::SmallVector<DataLocation> &Locations);
   Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
 
-  /// Returns a iterable StringRef for strings in the order they are added.
+  /// Returns an iterable StringRef for strings in the order they are added.
+  /// Each string may be a symbol name or a file name.
   auto getStrings() const {
     ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
         StrToIndexMap.begin(), StrToIndexMap.end());
@@ -116,15 +117,15 @@ class DataAccessProfData {
   }
 
   /// Returns array reference for various internal data structures.
-  inline ArrayRef<
+  ArrayRef<
       std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
   getRecords() const {
     return Records.getArrayRef();
   }
-  inline ArrayRef<StringRef> getKnownColdSymbols() const {
+  ArrayRef<StringRef> getKnownColdSymbols() const {
     return KnownColdSymbols.getArrayRef();
   }
-  inline ArrayRef<uint64_t> getKnownColdHashes() const {
+  ArrayRef<uint64_t> getKnownColdHashes() const {
     return KnownColdHashes.getArrayRef();
   }
 
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 8a6be22bdb1a4..710b2f6836064 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -357,7 +357,7 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
 /// the duplicated profile variables for Comdat functions.
 bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
 
-/// \c NameStrings is a string composed of one of more possibly encoded
+/// \c NameStrings is a string composed of one or more possibly encoded
 /// sub-strings. The substrings are separated by 0 or more zero bytes. This
 /// method decodes the string and calls `NameCallback` for each substring.
 Error readAndDecodeStrings(StringRef NameStrings,
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index c52533c13919c..a42ee41b24358 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -141,7 +141,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   // Record the length of compressed string.
   OS.write(CompressedStringLen);
   // Write the chars in compressed strings.
-  for (auto &c : CompressedStrings)
+  for (char C : CompressedStrings)
     OS.writeByte(static_cast<uint8_t>(c));
   // Pad up to a multiple of 8.
   // InstrProfReader could read bytes according to 'CompressedStringLen'.

>From b69c9930b9584a8dede96b0a9605c03f4d504bb8 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 6 May 2025 10:50:59 -0700
Subject: [PATCH 4/6] resolve review feedback

---
 .../include/llvm/ProfileData/DataAccessProf.h |  43 ++---
 llvm/include/llvm/ProfileData/InstrProf.h     |   5 +-
 llvm/lib/ProfileData/DataAccessProf.cpp       |  72 ++++---
 llvm/unittests/ProfileData/CMakeLists.txt     |   1 +
 .../ProfileData/DataAccessProfTest.cpp        | 181 ++++++++++++++++++
 llvm/unittests/ProfileData/MemProfTest.cpp    | 160 ----------------
 6 files changed, 245 insertions(+), 217 deletions(-)
 create mode 100644 llvm/unittests/ProfileData/DataAccessProfTest.cpp

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 81289b60b96ed..a85a3320ae57c 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -70,9 +70,11 @@ struct DataAccessProfRecord {
 /// This class provides profile look-up, serialization and deserialization.
 class DataAccessProfData {
 public:
-  // SymbolID is either a string representing symbol name, or a uint64_t
-  // representing the content hash of a string literal.
-  using SymbolID = std::variant<StringRef, uint64_t>;
+  // SymbolHandle is either a string representing symbol name if the symbol has
+  // stable mangled name relative to source code, or a uint64_t representing the
+  // content hash of a string literal (with unstable name patterns like
+  // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
+  using SymbolHandle = std::variant<StringRef, uint64_t>;
   using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
 
   DataAccessProfData() : Saver(Allocator) {}
@@ -90,38 +92,32 @@ class DataAccessProfData {
   /// Returns a pointer of profile record for \p SymbolID, or nullptr if there
   /// isn't a record. Internally, this function will canonicalize the symbol
   /// name before the lookup.
-  const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const;
+  const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const;
 
   /// Returns true if \p SymID is seen in profiled binaries and cold.
-  bool isKnownColdSymbol(const SymbolID SymID) const;
+  bool isKnownColdSymbol(const SymbolHandle SymID) const;
 
-  /// Methods to add symbolized data access profile. Returns error if duplicated
+  /// Methods to set symbolized data access profile. Returns error if duplicated
   /// symbol names or content hashes are seen. The user of this class should
   /// aggregate counters that correspond to the same symbol name or with the
-  /// same string literal hash before calling 'add*' methods.
-  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+  /// same string literal hash before calling 'set*' methods.
+  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount);
   /// Similar to the method above, for records with \p Locations representing
   /// the `filename:line` where this symbol shows up. Note because of linker's
   /// merge of identical symbols (e.g., unnamed_addr string literals), one
   /// symbol is likely to have multiple locations.
-  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount,
-                             const llvm::SmallVector<DataLocation> &Locations);
-  Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
+  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount,
+                             ArrayRef<DataLocation> Locations);
+  Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID);
 
   /// Returns an iterable StringRef for strings in the order they are added.
   /// Each string may be a symbol name or a file name.
   auto getStrings() const {
-    ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
-        StrToIndexMap.begin(), StrToIndexMap.end());
-    return llvm::make_first_range(RefSymbolNames);
+    return llvm::make_first_range(StrToIndexMap.getArrayRef());
   }
 
   /// Returns array reference for various internal data structures.
-  ArrayRef<
-      std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
-  getRecords() const {
-    return Records.getArrayRef();
-  }
+  auto getRecords() const { return Records.getArrayRef(); }
   ArrayRef<StringRef> getKnownColdSymbols() const {
     return KnownColdSymbols.getArrayRef();
   }
@@ -137,13 +133,13 @@ class DataAccessProfData {
   /// start of the next payload.
   Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr,
                                        const uint64_t NumSampledSymbols,
-                                       uint64_t NumColdKnownSymbols);
+                                       const uint64_t NumColdKnownSymbols);
 
   /// Decode the records and increment \p Ptr to the start of the next payload.
   Error deserializeRecords(const unsigned char *&Ptr);
 
   /// A helper function to compute a storage index for \p SymbolID.
-  uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+  uint64_t getEncodedIndex(const SymbolHandle SymbolID) const;
 
   // Keeps owned copies of the input strings.
   // NOTE: Keep `Saver` initialized before other class members that reference
@@ -151,9 +147,8 @@ class DataAccessProfData {
   llvm::BumpPtrAllocator Allocator;
   llvm::UniqueStringSaver Saver;
 
-  // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
-  // its record index.
-  MapVector<SymbolID, DataAccessProfRecord> Records;
+  // `Records` stores the records.
+  MapVector<SymbolHandle, DataAccessProfRecord> Records;
 
   // Use MapVector to keep input order of strings for serialization and
   // deserialization.
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 710b2f6836064..33b93ea0a558a 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -358,8 +358,9 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
 bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
 
 /// \c NameStrings is a string composed of one or more possibly encoded
-/// sub-strings. The substrings are separated by 0 or more zero bytes. This
-/// method decodes the string and calls `NameCallback` for each substring.
+/// sub-strings. The substrings are separated by `\01` (returned by
+/// InstrProf.h:getInstrProfNameSeparator). This method decodes the string and
+/// calls `NameCallback` for each substring.
 Error readAndDecodeStrings(StringRef NameStrings,
                            std::function<Error(StringRef)> NameCallback);
 
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index a42ee41b24358..d1c034f639347 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -23,11 +23,22 @@ saveStringToMap(MapVector<StringRef, uint64_t> &Map,
   return Iter;
 }
 
+// Returns the canonical name or error.
+static Expected<StringRef> getCanonicalName(StringRef Name) {
+  if (Name.empty())
+    return make_error<StringError>("Empty symbol name",
+                                   llvm::errc::invalid_argument);
+  return InstrProfSymtab::getCanonicalName(Name);
+}
+
 const DataAccessProfRecord *
-DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
   auto Key = SymbolID;
-  if (std::holds_alternative<StringRef>(SymbolID))
-    Key = InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymbolID));
+  if (std::holds_alternative<StringRef>(SymbolID)) {
+    StringRef Name = std::get<StringRef>(SymbolID);
+    assert(!Name.empty() && "Empty symbol name");
+    Key = InstrProfSymtab::getCanonicalName(Name);
+  }
 
   auto It = Records.find(Key);
   if (It != Records.end())
@@ -36,30 +47,27 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
   return nullptr;
 }
 
-bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const {
   if (std::holds_alternative<uint64_t>(SymID))
     return KnownColdHashes.contains(std::get<uint64_t>(SymID));
   return KnownColdSymbols.contains(std::get<StringRef>(SymID));
 }
 
-Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
                                                uint64_t AccessCount) {
   uint64_t RecordID = -1;
   bool IsStringLiteral = false;
-  SymbolID Key;
+  SymbolHandle Key;
   if (std::holds_alternative<uint64_t>(Symbol)) {
     RecordID = std::get<uint64_t>(Symbol);
     Key = RecordID;
     IsStringLiteral = true;
   } else {
-    StringRef SymbolName = std::get<StringRef>(Symbol);
-    if (SymbolName.empty())
-      return make_error<StringError>("Empty symbol name",
-                                     llvm::errc::invalid_argument);
-
-    StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
-    Key = CanonicalName;
-    RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
+    auto CanonicalName = getCanonicalName(std::get<StringRef>(Symbol));
+    if (!CanonicalName)
+      return CanonicalName.takeError();
+    std::tie(Key, RecordID) =
+        *saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
     IsStringLiteral = false;
   }
 
@@ -75,8 +83,8 @@ Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
 }
 
 Error DataAccessProfData::setDataAccessProfile(
-    SymbolID SymbolID, uint64_t AccessCount,
-    const llvm::SmallVector<DataLocation> &Locations) {
+    SymbolHandle SymbolID, uint64_t AccessCount,
+    ArrayRef<DataLocation> Locations) {
   if (Error E = setDataAccessProfile(SymbolID, AccessCount))
     return E;
 
@@ -89,17 +97,15 @@ Error DataAccessProfData::setDataAccessProfile(
   return Error::success();
 }
 
-Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) {
   if (std::holds_alternative<uint64_t>(SymbolID)) {
     KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
     return Error::success();
   }
-  StringRef SymbolName = std::get<StringRef>(SymbolID);
-  if (SymbolName.empty())
-    return make_error<StringError>("Empty symbol name",
-                                   llvm::errc::invalid_argument);
-  StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName);
-  KnownColdSymbols.insert(CanonicalSymName);
+  auto CanonicalName = getCanonicalName(std::get<StringRef>(SymbolID));
+  if (!CanonicalName)
+    return CanonicalName.takeError();
+  KnownColdSymbols.insert(*CanonicalName);
   return Error::success();
 }
 
@@ -142,7 +148,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   OS.write(CompressedStringLen);
   // Write the chars in compressed strings.
   for (char C : CompressedStrings)
-    OS.writeByte(static_cast<uint8_t>(c));
+    OS.writeByte(static_cast<uint8_t>(C));
   // Pad up to a multiple of 8.
   // InstrProfReader could read bytes according to 'CompressedStringLen'.
   const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
@@ -151,11 +157,15 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   return Error::success();
 }
 
-uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+uint64_t
+DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const {
   if (std::holds_alternative<uint64_t>(SymbolID))
     return std::get<uint64_t>(SymbolID);
 
-  return StrToIndexMap.find(std::get<StringRef>(SymbolID))->second;
+  auto Iter = StrToIndexMap.find(std::get<StringRef>(SymbolID));
+  assert(Iter != StrToIndexMap.end() &&
+         "String literals not found in StrToIndexMap");
+  return Iter->second;
 }
 
 Error DataAccessProfData::serialize(ProfOStream &OS) const {
@@ -179,13 +189,13 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const {
 }
 
 Error DataAccessProfData::deserializeSymbolsAndFilenames(
-    const unsigned char *&Ptr, uint64_t NumSampledSymbols,
-    uint64_t NumColdKnownSymbols) {
+    const unsigned char *&Ptr, const uint64_t NumSampledSymbols,
+    const uint64_t NumColdKnownSymbols) {
   uint64_t Len =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
-  // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
-  // symbols with samples, and next N strings are known cold symbols.
+  // The first NumSampledSymbols strings are symbols with samples, and the next
+  // NumColdKnownSymbols strings are known cold symbols.
   uint64_t StringCnt = 0;
   std::function<Error(StringRef)> addName = [&](StringRef Name) {
     if (StringCnt < NumSampledSymbols)
@@ -219,7 +229,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
     uint64_t AccessCount =
         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
-    SymbolID SymbolID;
+    SymbolHandle SymbolID;
     if (IsStringLiteral)
       SymbolID = ID;
     else
diff --git a/llvm/unittests/ProfileData/CMakeLists.txt b/llvm/unittests/ProfileData/CMakeLists.txt
index 0a7f7da085950..29b9cb751dabe 100644
--- a/llvm/unittests/ProfileData/CMakeLists.txt
+++ b/llvm/unittests/ProfileData/CMakeLists.txt
@@ -10,6 +10,7 @@ set(LLVM_LINK_COMPONENTS
 add_llvm_unittest(ProfileDataTests
   BPFunctionNodeTest.cpp
   CoverageMappingTest.cpp
+  DataAccessProfTest.cpp
   InstrProfDataTest.cpp
   InstrProfTest.cpp
   ItaniumManglingCanonicalizerTest.cpp
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
new file mode 100644
index 0000000000000..50c4af49fe76b
--- /dev/null
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -0,0 +1,181 @@
+
+//===- unittests/ProfileData/DataAccessProfTest.cpp
+//----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gmock/gmock-more-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace llvm {
+namespace data_access_prof {
+namespace {
+
+using ::llvm::StringRef;
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+
+static std::string ErrorToString(Error E) {
+  std::string ErrMsg;
+  llvm::raw_string_ostream OS(ErrMsg);
+  llvm::logAllUnhandledErrors(std::move(E), OS);
+  return ErrMsg;
+}
+
+// Test the various scenarios when DataAccessProfData should return error on
+// invalid input.
+TEST(MemProf, DataAccessProfileError) {
+  // Returns error if the input symbol name is empty.
+  DataAccessProfData Data;
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
+              HasSubstr("Empty symbol name"));
+
+  // Returns error when the same symbol gets added twice.
+  ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
+              HasSubstr("Duplicate symbol or string literal added"));
+
+  // Returns error when the same string content hash gets added twice.
+  ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
+              HasSubstr("Duplicate symbol or string literal added"));
+}
+
+// Test the following operations on DataAccessProfData:
+// - Profile record look up.
+// - Serialization and de-serialization.
+TEST(MemProf, DataAccessProfile) {
+  DataAccessProfData Data;
+
+  // In the bool conversion, Error is true if it's in a failure state and false
+  // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
+  ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
+  ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
+                                         {
+                                             DataLocation{"file2", 3},
+                                         }));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
+  ASSERT_FALSE(Data.setDataAccessProfile(
+      (uint64_t)135246, 1000,
+      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+
+  {
+    // Test that symbol names and file names are stored in the input order.
+    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles.
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
+
+    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
+    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+
+    EXPECT_THAT(
+        *Data.getProfileRecord("foo.llvm.123"),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+              testing::Field(&DataAccessProfRecord::AccessCount, 100),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+              testing::Field(&DataAccessProfRecord::Locations,
+                             testing::IsEmpty())));
+    EXPECT_THAT(
+        *Data.getProfileRecord("bar.__uniq.321"),
+        AllOf(
+            testing::Field(&DataAccessProfRecord::SymbolID, 1),
+            testing::Field(&DataAccessProfRecord::AccessCount, 123),
+            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+            testing::Field(&DataAccessProfRecord::Locations,
+                           ElementsAre(AllOf(
+                               testing::Field(&DataLocation::FileName, "file2"),
+                               testing::Field(&DataLocation::Line, 3))))));
+    EXPECT_THAT(
+        *Data.getProfileRecord((uint64_t)135246),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+              testing::Field(
+                  &DataAccessProfRecord::Locations,
+                  ElementsAre(
+                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                            testing::Field(&DataLocation::Line, 1)),
+                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                            testing::Field(&DataLocation::Line, 2))))));
+  }
+
+  // Tests serialization and de-serialization.
+  DataAccessProfData deserializedData;
+  {
+    std::string serializedData;
+    llvm::raw_string_ostream OS(serializedData);
+    llvm::ProfOStream POS(OS);
+
+    EXPECT_FALSE(Data.serialize(POS));
+
+    const unsigned char *p =
+        reinterpret_cast<const unsigned char *>(serializedData.data());
+    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                testing::IsEmpty());
+    EXPECT_FALSE(deserializedData.deserialize(p));
+
+    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(deserializedData.getKnownColdSymbols(),
+                ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles after deserialization.
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
+
+    auto Records =
+        llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
+
+    EXPECT_THAT(
+        Records,
+        ElementsAre(
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(&DataAccessProfRecord::Locations,
+                                 testing::IsEmpty())),
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(
+                      &DataAccessProfRecord::Locations,
+                      ElementsAre(AllOf(
+                          testing::Field(&DataLocation::FileName, "file2"),
+                          testing::Field(&DataLocation::Line, 3))))),
+            AllOf(
+                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+                testing::Field(
+                    &DataAccessProfRecord::Locations,
+                    ElementsAre(
+                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                              testing::Field(&DataLocation::Line, 1)),
+                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                              testing::Field(&DataLocation::Line, 2)))))));
+  }
+}
+} // namespace
+} // namespace data_access_prof
+} // namespace llvm
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index b7b8d642ad930..3e430aa4eae58 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -10,17 +10,14 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLForwardCompat.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/DebugInfo/DIContext.h"
 #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Object/ObjectFile.h"
-#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/MemProfData.inc"
 #include "llvm/ProfileData/MemProfReader.h"
 #include "llvm/ProfileData/MemProfYAML.h"
 #include "llvm/Support/raw_ostream.h"
-#include "gmock/gmock-more-matchers.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
@@ -39,8 +36,6 @@ using ::llvm::StringRef;
 using ::llvm::object::SectionedAddress;
 using ::llvm::symbolize::SymbolizableModule;
 using ::testing::ElementsAre;
-using ::testing::ElementsAreArray;
-using ::testing::HasSubstr;
 using ::testing::IsEmpty;
 using ::testing::Pair;
 using ::testing::Return;
@@ -752,161 +747,6 @@ TEST(MemProf, YAMLParser) {
                                ElementsAre(0x3000)))));
 }
 
-static std::string ErrorToString(Error E) {
-  std::string ErrMsg;
-  llvm::raw_string_ostream OS(ErrMsg);
-  llvm::logAllUnhandledErrors(std::move(E), OS);
-  return ErrMsg;
-}
-
-// Test the various scenarios when DataAccessProfData should return error on
-// invalid input.
-TEST(MemProf, DataAccessProfileError) {
-  // Returns error if the input symbol name is empty.
-  llvm::data_access_prof::DataAccessProfData Data;
-  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
-              HasSubstr("Empty symbol name"));
-
-  // Returns error when the same symbol gets added twice.
-  ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
-  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
-              HasSubstr("Duplicate symbol or string literal added"));
-
-  // Returns error when the same string content hash gets added twice.
-  ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
-  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
-              HasSubstr("Duplicate symbol or string literal added"));
-}
-
-// Test the following operations on DataAccessProfData:
-// - Profile record look up.
-// - Serialization and de-serialization.
-TEST(MemProf, DataAccessProfile) {
-  using namespace llvm::data_access_prof;
-  llvm::data_access_prof::DataAccessProfData Data;
-
-  // In the bool conversion, Error is true if it's in a failure state and false
-  // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
-  ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
-  ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
-                                         {
-                                             DataLocation{"file2", 3},
-                                         }));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
-  ASSERT_FALSE(Data.setDataAccessProfile(
-      (uint64_t)135246, 1000,
-      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
-
-  {
-    // Test that symbol names and file names are stored in the input order.
-    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
-                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
-    EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
-    EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
-
-    // Look up profiles.
-    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
-    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
-    EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
-    EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
-
-    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
-    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
-
-    EXPECT_THAT(
-        Data.getProfileRecord("foo.llvm.123"),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-              testing::Field(&DataAccessProfRecord::AccessCount, 100),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-              testing::Field(&DataAccessProfRecord::Locations,
-                             testing::IsEmpty())));
-    EXPECT_THAT(
-        *Data.getProfileRecord("bar.__uniq.321"),
-        AllOf(
-            testing::Field(&DataAccessProfRecord::SymbolID, 1),
-            testing::Field(&DataAccessProfRecord::AccessCount, 123),
-            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-            testing::Field(&DataAccessProfRecord::Locations,
-                           ElementsAre(AllOf(
-                               testing::Field(&DataLocation::FileName, "file2"),
-                               testing::Field(&DataLocation::Line, 3))))));
-    EXPECT_THAT(
-        *Data.getProfileRecord((uint64_t)135246),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-              testing::Field(
-                  &DataAccessProfRecord::Locations,
-                  ElementsAre(
-                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                            testing::Field(&DataLocation::Line, 1)),
-                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                            testing::Field(&DataLocation::Line, 2))))));
-  }
-
-  // Tests serialization and de-serialization.
-  llvm::data_access_prof::DataAccessProfData deserializedData;
-  {
-    std::string serializedData;
-    llvm::raw_string_ostream OS(serializedData);
-    llvm::ProfOStream POS(OS);
-
-    EXPECT_FALSE(Data.serialize(POS));
-
-    const unsigned char *p =
-        reinterpret_cast<const unsigned char *>(serializedData.data());
-    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
-                testing::IsEmpty());
-    EXPECT_FALSE(deserializedData.deserialize(p));
-
-    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
-                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
-    EXPECT_THAT(deserializedData.getKnownColdSymbols(),
-                ElementsAre("sym2", "sym1"));
-    EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
-
-    // Look up profiles after deserialization.
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
-
-    auto Records =
-        llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
-
-    EXPECT_THAT(
-        Records,
-        ElementsAre(
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(&DataAccessProfRecord::Locations,
-                                 testing::IsEmpty())),
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(
-                      &DataAccessProfRecord::Locations,
-                      ElementsAre(AllOf(
-                          testing::Field(&DataLocation::FileName, "file2"),
-                          testing::Field(&DataLocation::Line, 3))))),
-            AllOf(
-                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-                testing::Field(
-                    &DataAccessProfRecord::Locations,
-                    ElementsAre(
-                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                              testing::Field(&DataLocation::Line, 1)),
-                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                              testing::Field(&DataLocation::Line, 2)))))));
-  }
-}
-
 // Verify that the YAML parser accepts a GUID expressed as a function name.
 TEST(MemProf, YAMLParserGUID) {
   StringRef YAMLData = R"YAML(

>From df080949a1ccc75fceb386b0232005f96397752d Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 6 May 2025 16:35:14 -0700
Subject: [PATCH 5/6] resolve feedback

---
 llvm/include/llvm/ProfileData/DataAccessProf.h |  6 +++++-
 llvm/lib/ProfileData/DataAccessProf.cpp        | 12 +++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index a85a3320ae57c..91de43fdf60ca 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -52,7 +52,11 @@ struct DataAccessProfRecord {
 
   // Represents a data symbol. The semantic comes in two forms: a symbol index
   // for symbol name if `IsStringLiteral` is false, or the hash of a string
-  // content if `IsStringLiteral` is true. Required.
+  // content if `IsStringLiteral` is true. For most of the symbolizable static
+  // data, the mangled symbol names remain stable relative to the source code
+  // and are therefore used to identify symbols across binary releases. String
+  // literals have unstable name patterns like `.str.N[.llvm.hash]`, so we use
+  // the content hash instead. This is a required field.
   uint64_t SymbolID;
 
   // The access count of symbol. Required.
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index d1c034f639347..d7e67f5f09cbe 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -35,9 +35,15 @@ const DataAccessProfRecord *
 DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
   auto Key = SymbolID;
   if (std::holds_alternative<StringRef>(SymbolID)) {
-    StringRef Name = std::get<StringRef>(SymbolID);
-    assert(!Name.empty() && "Empty symbol name");
-    Key = InstrProfSymtab::getCanonicalName(Name);
+    auto NameOrErr = getCanonicalName(std::get<StringRef>(SymbolID));
+    // If name canonicalization fails, suppress the error inside.
+    if (!NameOrErr) {
+      assert(
+          std::get<StringRef>(SymbolID).empty() &&
+          "Name canonicalization only fails when stringified string is empty.");
+      return nullptr;
+    }
+    Key = *NameOrErr;
   }
 
   auto It = Records.find(Key);

>From 6dd04e46542851b84bf26cd95245399204072085 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 12 May 2025 17:05:19 -0700
Subject: [PATCH 6/6] resolve comments

---
 .../include/llvm/ProfileData/DataAccessProf.h | 119 ++++++++++++------
 llvm/include/llvm/ProfileData/InstrProf.h     |   2 +-
 llvm/lib/ProfileData/DataAccessProf.cpp       |  49 ++++----
 .../ProfileData/DataAccessProfTest.cpp        | 116 ++++++++---------
 4 files changed, 167 insertions(+), 119 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 91de43fdf60ca..e8504102238d1 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -30,23 +30,40 @@
 #include "llvm/Support/StringSaver.h"
 
 #include <cstdint>
+#include <optional>
 #include <variant>
 
 namespace llvm {
 
 namespace data_access_prof {
-// The location of data in the source code.
-struct DataLocation {
+
+/// The location of data in the source code. Used by profile lookup API.
+struct SourceLocation {
+  SourceLocation(StringRef FileNameRef, uint32_t Line)
+      : FileName(FileNameRef.str()), Line(Line) {}
+  /// The filename where the data is located.
+  std::string FileName;
+  /// The line number in the source code.
+  uint32_t Line;
+};
+
+namespace internal {
+
+// Conceptually similar to SourceLocation, except that FileName is a StringRef
+// whose string is owned by `DataAccessProfData`. Used by `DataAccessProfData`
+// to represent data locations internally.
+struct SourceLocationRef {
   // The filename where the data is located.
   StringRef FileName;
   // The line number in the source code.
   uint32_t Line;
 };
 
-// The data access profiles for a symbol.
-struct DataAccessProfRecord {
-  DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount,
-                       bool IsStringLiteral)
+// The data access profiles for a symbol. Used by `DataAccessProfData`
+// to represent records internally.
+struct DataAccessProfRecordRef {
+  DataAccessProfRecordRef(uint64_t SymbolID, uint64_t AccessCount,
+                          bool IsStringLiteral)
       : SymbolID(SymbolID), AccessCount(AccessCount),
         IsStringLiteral(IsStringLiteral) {}
 
@@ -67,18 +84,43 @@ struct DataAccessProfRecord {
   bool IsStringLiteral;
 
   // The locations of data in the source code. Optional.
-  llvm::SmallVector<DataLocation, 0> Locations;
+  llvm::SmallVector<SourceLocationRef, 0> Locations;
 };
+} // namespace internal
+
+// SymbolHandleRef is either a StringRef holding the symbol name (when the
+// symbol has a stable mangled name relative to source code), or a uint64_t
+// content hash of a string literal (whose name pattern `.str.N[.llvm.hash]` is
+// unstable). The StringRef is owned by `DataAccessProfData`'s saver object.
+using SymbolHandleRef = std::variant<StringRef, uint64_t>;
 
-/// Encapsulates the data access profile data and the methods to operate on it.
-/// This class provides profile look-up, serialization and deserialization.
+// The semantics are the same as `SymbolHandleRef` above. The strings are owned.
+using SymbolHandle = std::variant<std::string, uint64_t>;
+
+/// The data access profiles for a symbol.
+struct DataAccessProfRecord {
+public:
+  DataAccessProfRecord(SymbolHandleRef SymHandleRef,
+                       ArrayRef<internal::SourceLocationRef> LocRefs) {
+    if (std::holds_alternative<StringRef>(SymHandleRef)) {
+      SymHandle = std::get<StringRef>(SymHandleRef).str();
+    } else
+      SymHandle = std::get<uint64_t>(SymHandleRef);
+
+    for (auto Loc : LocRefs)
+      Locations.push_back(SourceLocation(Loc.FileName, Loc.Line));
+  }
+  SymbolHandle SymHandle;
+
+  // The locations of data in the source code. Optional.
+  SmallVector<SourceLocation> Locations;
+};
+
+/// Encapsulates the data access profile data and the methods to operate on
+/// it. This class provides profile look-up, serialization and
+/// deserialization.
 class DataAccessProfData {
 public:
-  // SymbolID is either a string representing symbol name if the symbol has
-  // stable mangled name relative to source code, or a uint64_t representing the
-  // content hash of a string literal (with unstable name patterns like
-  // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
-  using SymbolHandle = std::variant<StringRef, uint64_t>;
   using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
 
   DataAccessProfData() : Saver(Allocator) {}
@@ -93,35 +135,39 @@ class DataAccessProfData {
   /// Deserialize this class from the given buffer.
   Error deserialize(const unsigned char *&Ptr);
 
-  /// Returns a pointer of profile record for \p SymbolID, or nullptr if there
+  /// Returns a profile record for \p SymID, or std::nullopt if there
   /// isn't a record. Internally, this function will canonicalize the symbol
   /// name before the lookup.
-  const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const;
+  std::optional<DataAccessProfRecord>
+  getProfileRecord(const SymbolHandleRef SymID) const;
 
   /// Returns true if \p SymID is seen in profiled binaries and cold.
-  bool isKnownColdSymbol(const SymbolHandle SymID) const;
+  bool isKnownColdSymbol(const SymbolHandleRef SymID) const;
 
-  /// Methods to set symbolized data access profile. Returns error if duplicated
-  /// symbol names or content hashes are seen. The user of this class should
-  /// aggregate counters that correspond to the same symbol name or with the
-  /// same string literal hash before calling 'set*' methods.
-  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount);
+  /// Methods to set symbolized data access profile. Returns error if
+  /// duplicated symbol names or content hashes are seen. The user of this
+  /// class should aggregate counters that correspond to the same symbol name
+  /// or with the same string literal hash before calling 'set*' methods.
+  Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount);
   /// Similar to the method above, for records with \p Locations representing
   /// the `filename:line` where this symbol shows up. Note because of linker's
   /// merge of identical symbols (e.g., unnamed_addr string literals), one
   /// symbol is likely to have multiple locations.
-  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount,
-                             ArrayRef<DataLocation> Locations);
-  Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID);
-
-  /// Returns an iterable StringRef for strings in the order they are added.
-  /// Each string may be a symbol name or a file name.
-  auto getStrings() const {
-    return llvm::make_first_range(StrToIndexMap.getArrayRef());
+  Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount,
+                             ArrayRef<SourceLocation> Locations);
+  /// Add a symbol that's seen in the profiled binary without samples.
+  Error addKnownSymbolWithoutSamples(SymbolHandleRef SymbolID);
+
+  /// The following methods return array reference for various internal data
+  /// structures.
+  ArrayRef<StringToIndexMap::value_type> getStrToIndexMapRef() const {
+    return StrToIndexMap.getArrayRef();
+  }
+  ArrayRef<
+      MapVector<SymbolHandleRef, internal::DataAccessProfRecordRef>::value_type>
+  getRecords() const {
+    return Records.getArrayRef();
   }
-
-  /// Returns array reference for various internal data structures.
-  auto getRecords() const { return Records.getArrayRef(); }
   ArrayRef<StringRef> getKnownColdSymbols() const {
     return KnownColdSymbols.getArrayRef();
   }
@@ -139,11 +185,12 @@ class DataAccessProfData {
                                        const uint64_t NumSampledSymbols,
                                        const uint64_t NumColdKnownSymbols);
 
-  /// Decode the records and increment \p Ptr to the start of the next payload.
+  /// Decode the records and increment \p Ptr to the start of the next
+  /// payload.
   Error deserializeRecords(const unsigned char *&Ptr);
 
   /// A helper function to compute a storage index for \p SymbolID.
-  uint64_t getEncodedIndex(const SymbolHandle SymbolID) const;
+  uint64_t getEncodedIndex(const SymbolHandleRef SymbolID) const;
 
   // Keeps owned copies of the input strings.
   // NOTE: Keep `Saver` initialized before other class members that reference
@@ -152,7 +199,7 @@ class DataAccessProfData {
   llvm::UniqueStringSaver Saver;
 
   // `Records` stores the records.
-  MapVector<SymbolHandle, DataAccessProfRecord> Records;
+  MapVector<SymbolHandleRef, internal::DataAccessProfRecordRef> Records;
 
   // Use MapVector to keep input order of strings for serialization and
   // deserialization.
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 33b93ea0a558a..544a59df43ed3 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -500,7 +500,7 @@ class InstrProfSymtab {
 public:
   using AddrHashMap = std::vector<std::pair<uint64_t, uint64_t>>;
 
-  // Returns the canonial name of the given PGOName. In a canonical name, all
+  // Returns the canonical name of the given PGOName. In a canonical name, all
   // suffixes that begins with "." except ".__uniq." are stripped.
   // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
   static StringRef getCanonicalName(StringRef PGOName);
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index d7e67f5f09cbe..c5d0099977cfa 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -16,11 +16,11 @@ namespace data_access_prof {
 // If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
 // creates an owned copy of `Str`, adds a map entry for it and returns the
 // iterator.
-static MapVector<StringRef, uint64_t>::iterator
-saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+static std::pair<StringRef, uint64_t>
+saveStringToMap(DataAccessProfData::StringToIndexMap &Map,
                 llvm::UniqueStringSaver &Saver, StringRef Str) {
   auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
-  return Iter;
+  return *Iter;
 }
 
 // Returns the canonical name or error.
@@ -31,8 +31,8 @@ static Expected<StringRef> getCanonicalName(StringRef Name) {
   return InstrProfSymtab::getCanonicalName(Name);
 }
 
-const DataAccessProfRecord *
-DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
+std::optional<DataAccessProfRecord>
+DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const {
   auto Key = SymbolID;
   if (std::holds_alternative<StringRef>(SymbolID)) {
     auto NameOrErr = getCanonicalName(std::get<StringRef>(SymbolID));
@@ -41,40 +41,39 @@ DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
       assert(
           std::get<StringRef>(SymbolID).empty() &&
           "Name canonicalization only fails when stringified string is empty.");
-      return nullptr;
+      return std::nullopt;
     }
     Key = *NameOrErr;
   }
 
   auto It = Records.find(Key);
-  if (It != Records.end())
-    return &It->second;
+  if (It != Records.end()) {
+    return DataAccessProfRecord(Key, It->second.Locations);
+  }
 
-  return nullptr;
+  return std::nullopt;
 }
 
-bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const {
+bool DataAccessProfData::isKnownColdSymbol(const SymbolHandleRef SymID) const {
   if (std::holds_alternative<uint64_t>(SymID))
     return KnownColdHashes.contains(std::get<uint64_t>(SymID));
   return KnownColdSymbols.contains(std::get<StringRef>(SymID));
 }
 
-Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
+Error DataAccessProfData::setDataAccessProfile(SymbolHandleRef Symbol,
                                                uint64_t AccessCount) {
   uint64_t RecordID = -1;
-  bool IsStringLiteral = false;
-  SymbolHandle Key;
-  if (std::holds_alternative<uint64_t>(Symbol)) {
+  const bool IsStringLiteral = std::holds_alternative<uint64_t>(Symbol);
+  SymbolHandleRef Key;
+  if (IsStringLiteral) {
     RecordID = std::get<uint64_t>(Symbol);
     Key = RecordID;
-    IsStringLiteral = true;
   } else {
     auto CanonicalName = getCanonicalName(std::get<StringRef>(Symbol));
     if (!CanonicalName)
       return CanonicalName.takeError();
     std::tie(Key, RecordID) =
-        *saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
-    IsStringLiteral = false;
+        saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
   }
 
   auto [Iter, Inserted] =
@@ -89,21 +88,22 @@ Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
 }
 
 Error DataAccessProfData::setDataAccessProfile(
-    SymbolHandle SymbolID, uint64_t AccessCount,
-    ArrayRef<DataLocation> Locations) {
+    SymbolHandleRef SymbolID, uint64_t AccessCount,
+    ArrayRef<SourceLocation> Locations) {
   if (Error E = setDataAccessProfile(SymbolID, AccessCount))
     return E;
 
   auto &Record = Records.back().second;
   for (const auto &Location : Locations)
     Record.Locations.push_back(
-        {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first,
+        {saveStringToMap(StrToIndexMap, Saver, Location.FileName).first,
          Location.Line});
 
   return Error::success();
 }
 
-Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) {
+Error DataAccessProfData::addKnownSymbolWithoutSamples(
+    SymbolHandleRef SymbolID) {
   if (std::holds_alternative<uint64_t>(SymbolID)) {
     KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
     return Error::success();
@@ -164,7 +164,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
 }
 
 uint64_t
-DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const {
+DataAccessProfData::getEncodedIndex(const SymbolHandleRef SymbolID) const {
   if (std::holds_alternative<uint64_t>(SymbolID))
     return std::get<uint64_t>(SymbolID);
 
@@ -220,7 +220,8 @@ Error DataAccessProfData::deserializeSymbolsAndFilenames(
 }
 
 Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
-  SmallVector<StringRef> Strings = llvm::to_vector(getStrings());
+  SmallVector<StringRef> Strings =
+      llvm::to_vector(llvm::make_first_range(getStrToIndexMapRef()));
 
   uint64_t NumRecords =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
@@ -235,7 +236,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
     uint64_t AccessCount =
         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
-    SymbolHandle SymbolID;
+    SymbolHandleRef SymbolID;
     if (IsStringLiteral)
       SymbolID = ID;
     else
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
index 50c4af49fe76b..127230d4805e7 100644
--- a/llvm/unittests/ProfileData/DataAccessProfTest.cpp
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -1,4 +1,3 @@
-
 //===- unittests/Support/DataAccessProfTest.cpp
 //----------------------------------===//
 //
@@ -10,6 +9,7 @@
 
 #include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
 #include "gmock/gmock-more-matchers.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
@@ -19,7 +19,9 @@ namespace data_access_prof {
 namespace {
 
 using ::llvm::StringRef;
+using llvm::ValueIs;
 using ::testing::ElementsAre;
+using ::testing::Field;
 using ::testing::HasSubstr;
 using ::testing::IsEmpty;
 
@@ -53,6 +55,8 @@ TEST(MemProf, DataAccessProfileError) {
 // - Profile record look up.
 // - Serialization and de-serialization.
 TEST(MemProf, DataAccessProfile) {
+  using internal::DataAccessProfRecordRef;
+  using internal::SourceLocationRef;
   DataAccessProfData Data;
 
   // In the bool conversion, Error is true if it's in a failure state and false
@@ -62,18 +66,19 @@ TEST(MemProf, DataAccessProfile) {
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
   ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
                                          {
-                                             DataLocation{"file2", 3},
+                                             SourceLocation{"file2", 3},
                                          }));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
   ASSERT_FALSE(Data.setDataAccessProfile(
       (uint64_t)135246, 1000,
-      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+      {SourceLocation{"file1", 1}, SourceLocation{"file2", 2}}));
 
   {
     // Test that symbol names and file names are stored in the input order.
-    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
-                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(
+        llvm::to_vector(llvm::make_first_range(Data.getStrToIndexMapRef())),
+        ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
     EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
     EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
 
@@ -83,38 +88,34 @@ TEST(MemProf, DataAccessProfile) {
     EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
     EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
 
-    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
-    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+    EXPECT_EQ(Data.getProfileRecord("non-existence"), std::nullopt);
+    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), std::nullopt);
 
     EXPECT_THAT(
-        *Data.getProfileRecord("foo.llvm.123"),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-              testing::Field(&DataAccessProfRecord::AccessCount, 100),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-              testing::Field(&DataAccessProfRecord::Locations,
-                             testing::IsEmpty())));
+        Data.getProfileRecord("foo.llvm.123"),
+        ValueIs(AllOf(
+            Field(&DataAccessProfRecord::SymHandle,
+                  testing::VariantWith<std::string>(testing::Eq("foo"))),
+            Field(&DataAccessProfRecord::Locations, testing::IsEmpty()))));
     EXPECT_THAT(
-        *Data.getProfileRecord("bar.__uniq.321"),
-        AllOf(
-            testing::Field(&DataAccessProfRecord::SymbolID, 1),
-            testing::Field(&DataAccessProfRecord::AccessCount, 123),
-            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-            testing::Field(&DataAccessProfRecord::Locations,
-                           ElementsAre(AllOf(
-                               testing::Field(&DataLocation::FileName, "file2"),
-                               testing::Field(&DataLocation::Line, 3))))));
+        Data.getProfileRecord("bar.__uniq.321"),
+        ValueIs(AllOf(
+            Field(&DataAccessProfRecord::SymHandle,
+                  testing::VariantWith<std::string>(
+                      testing::Eq("bar.__uniq.321"))),
+            Field(&DataAccessProfRecord::Locations,
+                  ElementsAre(AllOf(Field(&SourceLocation::FileName, "file2"),
+                                    Field(&SourceLocation::Line, 3)))))));
     EXPECT_THAT(
-        *Data.getProfileRecord((uint64_t)135246),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-              testing::Field(
-                  &DataAccessProfRecord::Locations,
-                  ElementsAre(
-                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                            testing::Field(&DataLocation::Line, 1)),
-                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                            testing::Field(&DataLocation::Line, 2))))));
+        Data.getProfileRecord((uint64_t)135246),
+        ValueIs(AllOf(
+            Field(&DataAccessProfRecord::SymHandle,
+                  testing::VariantWith<uint64_t>(testing::Eq(135246))),
+            Field(&DataAccessProfRecord::Locations,
+                  ElementsAre(AllOf(Field(&SourceLocation::FileName, "file1"),
+                                    Field(&SourceLocation::Line, 1)),
+                              AllOf(Field(&SourceLocation::FileName, "file2"),
+                                    Field(&SourceLocation::Line, 2)))))));
   }
 
   // Tests serialization and de-serialization.
@@ -128,11 +129,13 @@ TEST(MemProf, DataAccessProfile) {
 
     const unsigned char *p =
         reinterpret_cast<const unsigned char *>(serializedData.data());
-    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+    ASSERT_THAT(llvm::to_vector(llvm::make_first_range(
+                    deserializedData.getStrToIndexMapRef())),
                 testing::IsEmpty());
     EXPECT_FALSE(deserializedData.deserialize(p));
 
-    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+    EXPECT_THAT(llvm::to_vector(llvm::make_first_range(
+                    deserializedData.getStrToIndexMapRef())),
                 ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
     EXPECT_THAT(deserializedData.getKnownColdSymbols(),
                 ElementsAre("sym2", "sym1"));
@@ -150,30 +153,27 @@ TEST(MemProf, DataAccessProfile) {
     EXPECT_THAT(
         Records,
         ElementsAre(
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(&DataAccessProfRecord::Locations,
-                                 testing::IsEmpty())),
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(
-                      &DataAccessProfRecord::Locations,
-                      ElementsAre(AllOf(
-                          testing::Field(&DataLocation::FileName, "file2"),
-                          testing::Field(&DataLocation::Line, 3))))),
             AllOf(
-                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-                testing::Field(
-                    &DataAccessProfRecord::Locations,
-                    ElementsAre(
-                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                              testing::Field(&DataLocation::Line, 1)),
-                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                              testing::Field(&DataLocation::Line, 2)))))));
+                Field(&DataAccessProfRecordRef::SymbolID, 0),
+                Field(&DataAccessProfRecordRef::AccessCount, 100),
+                Field(&DataAccessProfRecordRef::IsStringLiteral, false),
+                Field(&DataAccessProfRecordRef::Locations, testing::IsEmpty())),
+            AllOf(Field(&DataAccessProfRecordRef::SymbolID, 1),
+                  Field(&DataAccessProfRecordRef::AccessCount, 123),
+                  Field(&DataAccessProfRecordRef::IsStringLiteral, false),
+                  Field(&DataAccessProfRecordRef::Locations,
+                        ElementsAre(
+                            AllOf(Field(&SourceLocationRef::FileName, "file2"),
+                                  Field(&SourceLocationRef::Line, 3))))),
+            AllOf(Field(&DataAccessProfRecordRef::SymbolID, 135246),
+                  Field(&DataAccessProfRecordRef::AccessCount, 1000),
+                  Field(&DataAccessProfRecordRef::IsStringLiteral, true),
+                  Field(&DataAccessProfRecordRef::Locations,
+                        ElementsAre(
+                            AllOf(Field(&SourceLocationRef::FileName, "file1"),
+                                  Field(&SourceLocationRef::Line, 1)),
+                            AllOf(Field(&SourceLocationRef::FileName, "file2"),
+                                  Field(&SourceLocationRef::Line, 2)))))));
   }
 }
 } // namespace



More information about the llvm-commits mailing list