[llvm] Re-apply "[StaticDataLayout][PGO]Implement reader and writer change for data access profiles" (PR #141275)

Mingming Liu via llvm-commits llvm-commits at lists.llvm.org
Fri May 23 11:41:11 PDT 2025


https://github.com/mingmingl-llvm created https://github.com/llvm/llvm-project/pull/141275

Fix the use-of-uninitialized-memory error (https://lab.llvm.org/buildbot/#/builders/94/builds/7373).

Tested: The error is reproduced with https://github.com/llvm/llvm-zorg/blob/main/zorg/buildbot/builders/sanitizers/buildbot_bootstrap_msan.sh without the fix, and the test passes with the fix.


**Original commit message:**

https://github.com/llvm/llvm-project/pull/138170 introduces classes to operate on data access profiles. This change supports reading and writing `DataAccessProfData` in the indexed format of MemProf (v4) as well as in its text (YAML) format.

For indexed format:
* InstrProfWriter owns (by `std::unique_ptr<DataAccessProfData>`) the data access profiles, and passes a non-owning pointer to them when it calls `writeMemProf`.
  * MemProf v4 header has a new `uint64_t` to record the byte offset of data access profiles. This `uint64_t` field is zero if data access profile is not set (nullptr).
* MemProfReader reads the offset from v4 header and de-serializes in-memory bytes into class `DataAccessProfData`.

For textual format:
* MemProfYAML.h adds the mapping for the DAP (data access profile) class, and makes DAP optional for both reading and writing.

099a0fa (by @snehasish) introduces v4, which contains CalleeGuids in CallSiteInfo, and this change modifies the v4 format in place to add data access profiles. The current plan is to bump the version and enable v4 profiles with both features, assuming that waiting for this change won't delay the callsite change too long.

>From 6cd7d8dafd6a5cb438c5bd595e126f3cd863814e Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 1 May 2025 09:37:59 -0700
Subject: [PATCH 01/14] Add classes to construct and use data access profiles

---
 llvm/include/llvm/ADT/MapVector.h             |   2 +
 .../include/llvm/ProfileData/DataAccessProf.h | 156 +++++++++++
 llvm/include/llvm/ProfileData/InstrProf.h     |  16 +-
 llvm/lib/ProfileData/CMakeLists.txt           |   1 +
 llvm/lib/ProfileData/DataAccessProf.cpp       | 246 ++++++++++++++++++
 llvm/lib/ProfileData/InstrProf.cpp            |   8 +-
 llvm/unittests/ProfileData/MemProfTest.cpp    | 161 ++++++++++++
 7 files changed, 579 insertions(+), 11 deletions(-)
 create mode 100644 llvm/include/llvm/ProfileData/DataAccessProf.h
 create mode 100644 llvm/lib/ProfileData/DataAccessProf.cpp

diff --git a/llvm/include/llvm/ADT/MapVector.h b/llvm/include/llvm/ADT/MapVector.h
index c11617a81c97d..fe0d106795c34 100644
--- a/llvm/include/llvm/ADT/MapVector.h
+++ b/llvm/include/llvm/ADT/MapVector.h
@@ -57,6 +57,8 @@ class MapVector {
     return std::move(Vector);
   }
 
+  ArrayRef<value_type> getArrayRef() const { return Vector; }
+
   size_type size() const { return Vector.size(); }
 
   /// Grow the MapVector so that it can contain at least \p NumEntries items
diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
new file mode 100644
index 0000000000000..2cce4945fddd5
--- /dev/null
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -0,0 +1,156 @@
+//===- DataAccessProf.h - Data access profile format support ---------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains support to construct and use data access profiles.
+//
+// For the original RFC of this pass please see
+// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_PROFILEDATA_DATAACCESSPROF_H_
+#define LLVM_PROFILEDATA_DATAACCESSPROF_H_
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+
+#include <cstdint>
+#include <variant>
+
+namespace llvm {
+
+namespace data_access_prof {
+// The location of data in the source code.
+struct DataLocation {
+  // The filename where the data is located.
+  StringRef FileName;
+  // The line number in the source code.
+  uint32_t Line;
+};
+
+// The data access profiles for a symbol.
+struct DataAccessProfRecord {
+  // Represents a data symbol. The semantic comes in two forms: a symbol index
+  // for symbol name if `IsStringLiteral` is false, or the hash of a string
+  // content if `IsStringLiteral` is true. Required.
+  uint64_t SymbolID;
+
+  // The access count of symbol. Required.
+  uint64_t AccessCount;
+
+  // True iff this is a record for string literal (symbols with name pattern
+  // `.str.*` in the symbol table). Required.
+  bool IsStringLiteral;
+
+  // The locations of data in the source code. Optional.
+  llvm::SmallVector<DataLocation> Locations;
+};
+
+/// Encapsulates the data access profile data and the methods to operate on it.
+/// This class provides profile look-up, serialization and deserialization.
+class DataAccessProfData {
+public:
+  // SymbolID is either a string representing symbol name, or a uint64_t
+  // representing the content hash of a string literal.
+  using SymbolID = std::variant<StringRef, uint64_t>;
+  using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
+
+  DataAccessProfData() : saver(Allocator) {}
+
+  /// Serialize profile data to the output stream.
+  /// Storage layout:
+  /// - Serialized strings.
+  /// - The encoded hashes.
+  /// - Records.
+  Error serialize(ProfOStream &OS) const;
+
+  /// Deserialize this class from the given buffer.
+  Error deserialize(const unsigned char *&Ptr);
+
+  /// Returns a pointer of profile record for \p SymbolID, or nullptr if there
+  /// isn't a record. Internally, this function will canonicalize the symbol
+  /// name before the lookup.
+  const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const;
+
+  /// Returns true if \p SymID is seen in profiled binaries and cold.
+  bool isKnownColdSymbol(const SymbolID SymID) const;
+
+  /// Methods to add symbolized data access profile. Returns error if duplicated
+  /// symbol names or content hashes are seen. The user of this class should
+  /// aggregate counters that corresponds to the same symbol name or with the
+  /// same string literal hash before calling 'add*' methods.
+  Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+  Error addSymbolizedDataAccessProfile(
+      SymbolID SymbolID, uint64_t AccessCount,
+      const llvm::SmallVector<DataLocation> &Locations);
+  Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
+
+  /// Returns a iterable StringRef for strings in the order they are added.
+  auto getStrings() const {
+    ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
+        StrToIndexMap.begin(), StrToIndexMap.end());
+    return llvm::make_first_range(RefSymbolNames);
+  }
+
+  /// Returns array reference for various internal data structures.
+  inline ArrayRef<
+      std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
+  getRecords() const {
+    return Records.getArrayRef();
+  }
+  inline ArrayRef<StringRef> getKnownColdSymbols() const {
+    return KnownColdSymbols.getArrayRef();
+  }
+  inline ArrayRef<uint64_t> getKnownColdHashes() const {
+    return KnownColdHashes.getArrayRef();
+  }
+
+private:
+  /// Serialize the symbol strings into the output stream.
+  Error serializeStrings(ProfOStream &OS) const;
+
+  /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
+  /// start of the next payload.
+  Error deserializeStrings(const unsigned char *&Ptr,
+                           const uint64_t NumSampledSymbols,
+                           uint64_t NumColdKnownSymbols);
+
+  /// Decode the records and increment \p Ptr to the start of the next payload.
+  Error deserializeRecords(const unsigned char *&Ptr);
+
+  /// A helper function to compute a storage index for \p SymbolID.
+  uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+
+  // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
+  // its record index.
+  MapVector<SymbolID, DataAccessProfRecord> Records;
+
+  // Use MapVector to keep input order of strings for serialization and
+  // deserialization.
+  StringToIndexMap StrToIndexMap;
+  llvm::SetVector<uint64_t> KnownColdHashes;
+  llvm::SetVector<StringRef> KnownColdSymbols;
+  // Keeps owned copies of the input strings.
+  llvm::BumpPtrAllocator Allocator;
+  llvm::UniqueStringSaver saver;
+};
+
+} // namespace data_access_prof
+} // namespace llvm
+
+#endif // LLVM_PROFILEDATA_DATAACCESSPROF_H_
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 2d011c89f27cb..8a6be22bdb1a4 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -357,6 +357,12 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
 /// the duplicated profile variables for Comdat functions.
 bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
 
+/// \c NameStrings is a string composed of one of more possibly encoded
+/// sub-strings. The substrings are separated by 0 or more zero bytes. This
+/// method decodes the string and calls `NameCallback` for each substring.
+Error readAndDecodeStrings(StringRef NameStrings,
+                           std::function<Error(StringRef)> NameCallback);
+
 /// An enum describing the attributes of an instrumented profile.
 enum class InstrProfKind {
   Unknown = 0x0,
@@ -493,6 +499,11 @@ class InstrProfSymtab {
 public:
   using AddrHashMap = std::vector<std::pair<uint64_t, uint64_t>>;
 
+  // Returns the canonial name of the given PGOName. In a canonical name, all
+  // suffixes that begins with "." except ".__uniq." are stripped.
+  // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
+  static StringRef getCanonicalName(StringRef PGOName);
+
 private:
   using AddrIntervalMap =
       IntervalMap<uint64_t, uint64_t, 4, IntervalMapHalfOpenInfo<uint64_t>>;
@@ -528,11 +539,6 @@ class InstrProfSymtab {
 
   static StringRef getExternalSymbol() { return "** External Symbol **"; }
 
-  // Returns the canonial name of the given PGOName. In a canonical name, all
-  // suffixes that begins with "." except ".__uniq." are stripped.
-  // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
-  static StringRef getCanonicalName(StringRef PGOName);
-
   // Add the function into the symbol table, by creating the following
   // map entries:
   // name-set = {PGOFuncName} union {getCanonicalName(PGOFuncName)}
diff --git a/llvm/lib/ProfileData/CMakeLists.txt b/llvm/lib/ProfileData/CMakeLists.txt
index eb7c2a3c1a28a..67a69d7761b2c 100644
--- a/llvm/lib/ProfileData/CMakeLists.txt
+++ b/llvm/lib/ProfileData/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_component_library(LLVMProfileData
+  DataAccessProf.cpp
   GCOV.cpp
   IndexedMemProfData.cpp
   InstrProf.cpp
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
new file mode 100644
index 0000000000000..cf538d6a1b28e
--- /dev/null
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -0,0 +1,246 @@
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Compression.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <sys/types.h>
+
+namespace llvm {
+namespace data_access_prof {
+
+// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
+// creates an owned copy of `Str`, adds a map entry for it and returns the
+// iterator.
+static MapVector<StringRef, uint64_t>::iterator
+saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+                llvm::UniqueStringSaver &saver, StringRef Str) {
+  auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+  return Iter;
+}
+
+const DataAccessProfRecord *
+DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+  auto Key = SymbolID;
+  if (std::holds_alternative<StringRef>(SymbolID))
+    Key = InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymbolID));
+
+  auto It = Records.find(Key);
+  if (It != Records.end())
+    return &It->second;
+
+  return nullptr;
+}
+
+bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+  if (std::holds_alternative<uint64_t>(SymID))
+    return KnownColdHashes.count(std::get<uint64_t>(SymID));
+  return KnownColdSymbols.count(std::get<StringRef>(SymID));
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
+                                                         uint64_t AccessCount) {
+  uint64_t RecordID = -1;
+  bool IsStringLiteral = false;
+  SymbolID Key;
+  if (std::holds_alternative<uint64_t>(Symbol)) {
+    RecordID = std::get<uint64_t>(Symbol);
+    Key = RecordID;
+    IsStringLiteral = true;
+  } else {
+    StringRef SymbolName = std::get<StringRef>(Symbol);
+    if (SymbolName.empty())
+      return make_error<StringError>("Empty symbol name",
+                                     llvm::errc::invalid_argument);
+
+    StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
+    Key = CanonicalName;
+    RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second;
+    IsStringLiteral = false;
+  }
+
+  auto [Iter, Inserted] = Records.try_emplace(
+      Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+  if (!Inserted)
+    return make_error<StringError>("Duplicate symbol or string literal added. "
+                                   "User of DataAccessProfData should "
+                                   "aggregate count for the same symbol. ",
+                                   llvm::errc::invalid_argument);
+
+  return Error::success();
+}
+
+Error DataAccessProfData::addSymbolizedDataAccessProfile(
+    SymbolID SymbolID, uint64_t AccessCount,
+    const llvm::SmallVector<DataLocation> &Locations) {
+  if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+    return E;
+
+  auto &Record = Records.back().second;
+  for (const auto &Location : Locations)
+    Record.Locations.push_back(
+        {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+         Location.Line});
+
+  return Error::success();
+}
+
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+  if (std::holds_alternative<uint64_t>(SymbolID)) {
+    KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
+    return Error::success();
+  }
+  StringRef SymbolName = std::get<StringRef>(SymbolID);
+  if (SymbolName.empty())
+    return make_error<StringError>("Empty symbol name",
+                                   llvm::errc::invalid_argument);
+  StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName);
+  KnownColdSymbols.insert(CanonicalSymName);
+  return Error::success();
+}
+
+Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
+  uint64_t NumSampledSymbols =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  uint64_t NumColdKnownSymbols =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+    return E;
+
+  uint64_t Num =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  for (uint64_t I = 0; I < Num; ++I)
+    KnownColdHashes.insert(
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr));
+
+  return deserializeRecords(Ptr);
+}
+
+Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+  OS.write(StrToIndexMap.size());
+  OS.write(KnownColdSymbols.size());
+
+  std::vector<std::string> Strs;
+  Strs.reserve(StrToIndexMap.size() + KnownColdSymbols.size());
+  for (const auto &Str : StrToIndexMap)
+    Strs.push_back(Str.first.str());
+  for (const auto &Str : KnownColdSymbols)
+    Strs.push_back(Str.str());
+
+  std::string CompressedStrings;
+  if (!Strs.empty())
+    if (Error E = collectGlobalObjectNameStrings(
+            Strs, compression::zlib::isAvailable(), CompressedStrings))
+      return E;
+  const uint64_t CompressedStringLen = CompressedStrings.length();
+  // Record the length of compressed string.
+  OS.write(CompressedStringLen);
+  // Write the chars in compressed strings.
+  for (auto &c : CompressedStrings)
+    OS.writeByte(static_cast<uint8_t>(c));
+  // Pad up to a multiple of 8.
+  // InstrProfReader could read bytes according to 'CompressedStringLen'.
+  const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
+  for (uint64_t K = CompressedStringLen; K < PaddedLength; K++)
+    OS.writeByte(0);
+  return Error::success();
+}
+
+uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+  if (std::holds_alternative<uint64_t>(SymbolID))
+    return std::get<uint64_t>(SymbolID);
+
+  return StrToIndexMap.find(std::get<StringRef>(SymbolID))->second;
+}
+
+Error DataAccessProfData::serialize(ProfOStream &OS) const {
+  if (Error E = serializeStrings(OS))
+    return E;
+  OS.write(KnownColdHashes.size());
+  for (const auto &Hash : KnownColdHashes)
+    OS.write(Hash);
+  OS.write((uint64_t)(Records.size()));
+  for (const auto &[Key, Rec] : Records) {
+    OS.write(getEncodedIndex(Rec.SymbolID));
+    OS.writeByte(Rec.IsStringLiteral);
+    OS.write(Rec.AccessCount);
+    OS.write(Rec.Locations.size());
+    for (const auto &Loc : Rec.Locations) {
+      OS.write(getEncodedIndex(Loc.FileName));
+      OS.write32(Loc.Line);
+    }
+  }
+  return Error::success();
+}
+
+Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
+                                             uint64_t NumSampledSymbols,
+                                             uint64_t NumColdKnownSymbols) {
+  uint64_t Len =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+  // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
+  // symbols with samples, and next N strings are known cold symbols.
+  uint64_t StringCnt = 0;
+  std::function<Error(StringRef)> addName = [&](StringRef Name) {
+    if (StringCnt < NumSampledSymbols)
+      saveStringToMap(StrToIndexMap, saver, Name);
+    else
+      KnownColdSymbols.insert(saver.save(Name));
+    ++StringCnt;
+    return Error::success();
+  };
+  if (Error E =
+          readAndDecodeStrings(StringRef((const char *)Ptr, Len), addName))
+    return E;
+
+  Ptr += alignTo(Len, 8);
+  return Error::success();
+}
+
+Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
+  SmallVector<StringRef> Strings = llvm::to_vector(getStrings());
+
+  uint64_t NumRecords =
+      support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+  for (uint64_t I = 0; I < NumRecords; ++I) {
+    uint64_t ID =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+    bool IsStringLiteral =
+        support::endian::readNext<uint8_t, llvm::endianness::little>(Ptr);
+
+    uint64_t AccessCount =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+    SymbolID SymbolID;
+    if (IsStringLiteral)
+      SymbolID = ID;
+    else
+      SymbolID = Strings[ID];
+    if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+      return E;
+
+    auto &Record = Records.back().second;
+
+    uint64_t NumLocations =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+    Record.Locations.reserve(NumLocations);
+    for (uint64_t J = 0; J < NumLocations; ++J) {
+      uint64_t FileNameIndex =
+          support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+      uint32_t Line =
+          support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
+      Record.Locations.push_back({Strings[FileNameIndex], Line});
+    }
+  }
+  return Error::success();
+}
+} // namespace data_access_prof
+} // namespace llvm
diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp
index 88621787c1dd9..254f941acde82 100644
--- a/llvm/lib/ProfileData/InstrProf.cpp
+++ b/llvm/lib/ProfileData/InstrProf.cpp
@@ -573,12 +573,8 @@ Error InstrProfSymtab::addVTableWithName(GlobalVariable &VTable,
   return Error::success();
 }
 
-/// \c NameStrings is a string composed of one of more possibly encoded
-/// sub-strings. The substrings are separated by 0 or more zero bytes. This
-/// method decodes the string and calls `NameCallback` for each substring.
-static Error
-readAndDecodeStrings(StringRef NameStrings,
-                     std::function<Error(StringRef)> NameCallback) {
+Error readAndDecodeStrings(StringRef NameStrings,
+                           std::function<Error(StringRef)> NameCallback) {
   const uint8_t *P = NameStrings.bytes_begin();
   const uint8_t *EndP = NameStrings.bytes_end();
   while (P < EndP) {
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index 3e430aa4eae58..f6362448a5734 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -10,14 +10,17 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLForwardCompat.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/DebugInfo/DIContext.h"
 #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Object/ObjectFile.h"
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/MemProfData.inc"
 #include "llvm/ProfileData/MemProfReader.h"
 #include "llvm/ProfileData/MemProfYAML.h"
 #include "llvm/Support/raw_ostream.h"
+#include "gmock/gmock-more-matchers.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
@@ -36,6 +39,8 @@ using ::llvm::StringRef;
 using ::llvm::object::SectionedAddress;
 using ::llvm::symbolize::SymbolizableModule;
 using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::HasSubstr;
 using ::testing::IsEmpty;
 using ::testing::Pair;
 using ::testing::Return;
@@ -747,6 +752,162 @@ TEST(MemProf, YAMLParser) {
                                ElementsAre(0x3000)))));
 }
 
+static std::string ErrorToString(Error E) {
+  std::string ErrMsg;
+  llvm::raw_string_ostream OS(ErrMsg);
+  llvm::logAllUnhandledErrors(std::move(E), OS);
+  return ErrMsg;
+}
+
+// Test the various scenarios when DataAccessProfData should return error on
+// invalid input.
+TEST(MemProf, DataAccessProfileError) {
+  // Returns error if the input symbol name is empty.
+  llvm::data_access_prof::DataAccessProfData Data;
+  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+              HasSubstr("Empty symbol name"));
+
+  // Returns error when the same symbol gets added twice.
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
+  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+              HasSubstr("Duplicate symbol or string literal added"));
+
+  // Returns error when the same string content hash gets added twice.
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
+  EXPECT_THAT(ErrorToString(
+                  Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+              HasSubstr("Duplicate symbol or string literal added"));
+}
+
+// Test the following operations on DataAccessProfData:
+// - Profile record look up.
+// - Serialization and de-serialization.
+TEST(MemProf, DataAccessProfile) {
+  using namespace llvm::data_access_prof;
+  llvm::data_access_prof::DataAccessProfData Data;
+
+  // In the bool conversion, Error is true if it's in a failure state and false
+  // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
+                                                   {
+                                                       DataLocation{"file2", 3},
+                                                   }));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
+  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+      (uint64_t)135246, 1000,
+      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+
+  {
+    // Test that symbol names and file names are stored in the input order.
+    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles.
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
+
+    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
+    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+
+    EXPECT_THAT(
+        Data.getProfileRecord("foo.llvm.123"),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+              testing::Field(&DataAccessProfRecord::AccessCount, 100),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+              testing::Field(&DataAccessProfRecord::Locations,
+                             testing::IsEmpty())));
+    EXPECT_THAT(
+        *Data.getProfileRecord("bar.__uniq.321"),
+        AllOf(
+            testing::Field(&DataAccessProfRecord::SymbolID, 1),
+            testing::Field(&DataAccessProfRecord::AccessCount, 123),
+            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+            testing::Field(&DataAccessProfRecord::Locations,
+                           ElementsAre(AllOf(
+                               testing::Field(&DataLocation::FileName, "file2"),
+                               testing::Field(&DataLocation::Line, 3))))));
+    EXPECT_THAT(
+        *Data.getProfileRecord((uint64_t)135246),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+              testing::Field(
+                  &DataAccessProfRecord::Locations,
+                  ElementsAre(
+                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                            testing::Field(&DataLocation::Line, 1)),
+                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                            testing::Field(&DataLocation::Line, 2))))));
+  }
+
+  // Tests serialization and de-serialization.
+  llvm::data_access_prof::DataAccessProfData deserializedData;
+  {
+    std::string serializedData;
+    llvm::raw_string_ostream OS(serializedData);
+    llvm::ProfOStream POS(OS);
+
+    EXPECT_FALSE(Data.serialize(POS));
+
+    const unsigned char *p =
+        reinterpret_cast<const unsigned char *>(serializedData.data());
+    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                testing::IsEmpty());
+    EXPECT_FALSE(deserializedData.deserialize(p));
+
+    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(deserializedData.getKnownColdSymbols(),
+                ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles after deserialization.
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
+
+    auto Records =
+        llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
+
+    EXPECT_THAT(
+        Records,
+        ElementsAre(
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(&DataAccessProfRecord::Locations,
+                                 testing::IsEmpty())),
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(
+                      &DataAccessProfRecord::Locations,
+                      ElementsAre(AllOf(
+                          testing::Field(&DataLocation::FileName, "file2"),
+                          testing::Field(&DataLocation::Line, 3))))),
+            AllOf(
+                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+                testing::Field(
+                    &DataAccessProfRecord::Locations,
+                    ElementsAre(
+                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                              testing::Field(&DataLocation::Line, 1)),
+                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                              testing::Field(&DataLocation::Line, 2)))))));
+  }
+}
+
 // Verify that the YAML parser accepts a GUID expressed as a function name.
 TEST(MemProf, YAMLParserGUID) {
   StringRef YAMLData = R"YAML(

>From 47275299fd3e9ef242f108040df8ffe46f9cd7b0 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 5 May 2025 16:57:56 -0700
Subject: [PATCH 02/14] resolve review feedback

---
 .../include/llvm/ProfileData/DataAccessProf.h | 37 ++++++++++------
 llvm/lib/ProfileData/DataAccessProf.cpp       | 43 ++++++++++---------
 llvm/unittests/ProfileData/MemProfTest.cpp    | 23 +++++-----
 3 files changed, 57 insertions(+), 46 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 2cce4945fddd5..36648d6298ee5 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -45,6 +45,11 @@ struct DataLocation {
 
 // The data access profiles for a symbol.
 struct DataAccessProfRecord {
+  DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount,
+                       bool IsStringLiteral)
+      : SymbolID(SymbolID), AccessCount(AccessCount),
+        IsStringLiteral(IsStringLiteral) {}
+
   // Represents a data symbol. The semantic comes in two forms: a symbol index
   // for symbol name if `IsStringLiteral` is false, or the hash of a string
   // content if `IsStringLiteral` is true. Required.
@@ -58,7 +63,7 @@ struct DataAccessProfRecord {
   bool IsStringLiteral;
 
   // The locations of data in the source code. Optional.
-  llvm::SmallVector<DataLocation> Locations;
+  llvm::SmallVector<DataLocation, 0> Locations;
 };
 
 /// Encapsulates the data access profile data and the methods to operate on it.
@@ -70,7 +75,7 @@ class DataAccessProfData {
   using SymbolID = std::variant<StringRef, uint64_t>;
   using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
 
-  DataAccessProfData() : saver(Allocator) {}
+  DataAccessProfData() : Saver(Allocator) {}
 
   /// Serialize profile data to the output stream.
   /// Storage layout:
@@ -94,10 +99,13 @@ class DataAccessProfData {
   /// symbol names or content hashes are seen. The user of this class should
   /// aggregate counters that corresponds to the same symbol name or with the
   /// same string literal hash before calling 'add*' methods.
-  Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
-  Error addSymbolizedDataAccessProfile(
-      SymbolID SymbolID, uint64_t AccessCount,
-      const llvm::SmallVector<DataLocation> &Locations);
+  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+  /// Similar to the method above, for records with \p Locations representing
+  /// the `filename:line` where this symbol shows up. Note because of linker's
+  /// merge of identical symbols (e.g., unnamed_addr string literals), one
+  /// symbol is likely to have multiple locations.
+  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount,
+                             const llvm::SmallVector<DataLocation> &Locations);
   Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
 
   /// Returns a iterable StringRef for strings in the order they are added.
@@ -122,13 +130,13 @@ class DataAccessProfData {
 
 private:
   /// Serialize the symbol strings into the output stream.
-  Error serializeStrings(ProfOStream &OS) const;
+  Error serializeSymbolsAndFilenames(ProfOStream &OS) const;
 
   /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the
   /// start of the next payload.
-  Error deserializeStrings(const unsigned char *&Ptr,
-                           const uint64_t NumSampledSymbols,
-                           uint64_t NumColdKnownSymbols);
+  Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr,
+                                       const uint64_t NumSampledSymbols,
+                                       uint64_t NumColdKnownSymbols);
 
   /// Decode the records and increment \p Ptr to the start of the next payload.
   Error deserializeRecords(const unsigned char *&Ptr);
@@ -136,6 +144,12 @@ class DataAccessProfData {
   /// A helper function to compute a storage index for \p SymbolID.
   uint64_t getEncodedIndex(const SymbolID SymbolID) const;
 
+  // Keeps owned copies of the input strings.
+  // NOTE: Keep `Saver` initialized before other class members that reference
+  // its string copies and destructed after they are destructed.
+  llvm::BumpPtrAllocator Allocator;
+  llvm::UniqueStringSaver Saver;
+
   // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
   // its record index.
   MapVector<SymbolID, DataAccessProfRecord> Records;
@@ -145,9 +159,6 @@ class DataAccessProfData {
   StringToIndexMap StrToIndexMap;
   llvm::SetVector<uint64_t> KnownColdHashes;
   llvm::SetVector<StringRef> KnownColdSymbols;
-  // Keeps owned copies of the input strings.
-  llvm::BumpPtrAllocator Allocator;
-  llvm::UniqueStringSaver saver;
 };
 
 } // namespace data_access_prof
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index cf538d6a1b28e..c52533c13919c 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -18,8 +18,8 @@ namespace data_access_prof {
 // iterator.
 static MapVector<StringRef, uint64_t>::iterator
 saveStringToMap(MapVector<StringRef, uint64_t> &Map,
-                llvm::UniqueStringSaver &saver, StringRef Str) {
-  auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size());
+                llvm::UniqueStringSaver &Saver, StringRef Str) {
+  auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
   return Iter;
 }
 
@@ -38,12 +38,12 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
 
 bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
   if (std::holds_alternative<uint64_t>(SymID))
-    return KnownColdHashes.count(std::get<uint64_t>(SymID));
-  return KnownColdSymbols.count(std::get<StringRef>(SymID));
+    return KnownColdHashes.contains(std::get<uint64_t>(SymID));
+  return KnownColdSymbols.contains(std::get<StringRef>(SymID));
 }
 
-Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
-                                                         uint64_t AccessCount) {
+Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+                                               uint64_t AccessCount) {
   uint64_t RecordID = -1;
   bool IsStringLiteral = false;
   SymbolID Key;
@@ -59,12 +59,12 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
 
     StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
     Key = CanonicalName;
-    RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second;
+    RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
     IsStringLiteral = false;
   }
 
-  auto [Iter, Inserted] = Records.try_emplace(
-      Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral});
+  auto [Iter, Inserted] =
+      Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral);
   if (!Inserted)
     return make_error<StringError>("Duplicate symbol or string literal added. "
                                    "User of DataAccessProfData should "
@@ -74,16 +74,16 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol,
   return Error::success();
 }
 
-Error DataAccessProfData::addSymbolizedDataAccessProfile(
+Error DataAccessProfData::setDataAccessProfile(
     SymbolID SymbolID, uint64_t AccessCount,
     const llvm::SmallVector<DataLocation> &Locations) {
-  if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+  if (Error E = setDataAccessProfile(SymbolID, AccessCount))
     return E;
 
   auto &Record = Records.back().second;
   for (const auto &Location : Locations)
     Record.Locations.push_back(
-        {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first,
+        {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first,
          Location.Line});
 
   return Error::success();
@@ -108,7 +108,8 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
   uint64_t NumColdKnownSymbols =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
-  if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols))
+  if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols,
+                                               NumColdKnownSymbols))
     return E;
 
   uint64_t Num =
@@ -120,7 +121,7 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
   return deserializeRecords(Ptr);
 }
 
-Error DataAccessProfData::serializeStrings(ProfOStream &OS) const {
+Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   OS.write(StrToIndexMap.size());
   OS.write(KnownColdSymbols.size());
 
@@ -158,7 +159,7 @@ uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
 }
 
 Error DataAccessProfData::serialize(ProfOStream &OS) const {
-  if (Error E = serializeStrings(OS))
+  if (Error E = serializeSymbolsAndFilenames(OS))
     return E;
   OS.write(KnownColdHashes.size());
   for (const auto &Hash : KnownColdHashes)
@@ -177,9 +178,9 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const {
   return Error::success();
 }
 
-Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
-                                             uint64_t NumSampledSymbols,
-                                             uint64_t NumColdKnownSymbols) {
+Error DataAccessProfData::deserializeSymbolsAndFilenames(
+    const unsigned char *&Ptr, uint64_t NumSampledSymbols,
+    uint64_t NumColdKnownSymbols) {
   uint64_t Len =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
@@ -188,9 +189,9 @@ Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr,
   uint64_t StringCnt = 0;
   std::function<Error(StringRef)> addName = [&](StringRef Name) {
     if (StringCnt < NumSampledSymbols)
-      saveStringToMap(StrToIndexMap, saver, Name);
+      saveStringToMap(StrToIndexMap, Saver, Name);
     else
-      KnownColdSymbols.insert(saver.save(Name));
+      KnownColdSymbols.insert(Saver.save(Name));
     ++StringCnt;
     return Error::success();
   };
@@ -223,7 +224,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
       SymbolID = ID;
     else
       SymbolID = Strings[ID];
-    if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount))
+    if (Error E = setDataAccessProfile(SymbolID, AccessCount))
       return E;
 
     auto &Record = Records.back().second;
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index f6362448a5734..b7b8d642ad930 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -764,18 +764,17 @@ static std::string ErrorToString(Error E) {
 TEST(MemProf, DataAccessProfileError) {
   // Returns error if the input symbol name is empty.
   llvm::data_access_prof::DataAccessProfData Data;
-  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)),
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
               HasSubstr("Empty symbol name"));
 
   // Returns error when the same symbol gets added twice.
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100));
-  EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)),
+  ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
               HasSubstr("Duplicate symbol or string literal added"));
 
   // Returns error when the same string content hash gets added twice.
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000));
-  EXPECT_THAT(ErrorToString(
-                  Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)),
+  ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
               HasSubstr("Duplicate symbol or string literal added"));
 }
 
@@ -788,16 +787,16 @@ TEST(MemProf, DataAccessProfile) {
 
   // In the bool conversion, Error is true if it's in a failure state and false
   // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100));
+  ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123,
-                                                   {
-                                                       DataLocation{"file2", 3},
-                                                   }));
+  ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
+                                         {
+                                             DataLocation{"file2", 3},
+                                         }));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
-  ASSERT_FALSE(Data.addSymbolizedDataAccessProfile(
+  ASSERT_FALSE(Data.setDataAccessProfile(
       (uint64_t)135246, 1000,
       {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
 

>From 80249bce82307ffef1817747d1c2fca100a9ea53 Mon Sep 17 00:00:00 2001
From: Mingming Liu <mingmingl at google.com>
Date: Tue, 6 May 2025 10:39:53 -0700
Subject: [PATCH 03/14] Apply suggestions from code review

Co-authored-by: Kazu Hirata <kazu at google.com>
---
 llvm/include/llvm/ProfileData/DataAccessProf.h | 11 ++++++-----
 llvm/include/llvm/ProfileData/InstrProf.h      |  2 +-
 llvm/lib/ProfileData/DataAccessProf.cpp        |  2 +-
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 36648d6298ee5..81289b60b96ed 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -97,7 +97,7 @@ class DataAccessProfData {
 
   /// Methods to add symbolized data access profile. Returns error if duplicated
   /// symbol names or content hashes are seen. The user of this class should
-  /// aggregate counters that corresponds to the same symbol name or with the
+  /// aggregate counters that correspond to the same symbol name or with the
   /// same string literal hash before calling 'add*' methods.
   Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
   /// Similar to the method above, for records with \p Locations representing
@@ -108,7 +108,8 @@ class DataAccessProfData {
                              const llvm::SmallVector<DataLocation> &Locations);
   Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
 
-  /// Returns a iterable StringRef for strings in the order they are added.
+  /// Returns an iterable StringRef for strings in the order they are added.
+  /// Each string may be a symbol name or a file name.
   auto getStrings() const {
     ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
         StrToIndexMap.begin(), StrToIndexMap.end());
@@ -116,15 +117,15 @@ class DataAccessProfData {
   }
 
   /// Returns array reference for various internal data structures.
-  inline ArrayRef<
+  ArrayRef<
       std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
   getRecords() const {
     return Records.getArrayRef();
   }
-  inline ArrayRef<StringRef> getKnownColdSymbols() const {
+  ArrayRef<StringRef> getKnownColdSymbols() const {
     return KnownColdSymbols.getArrayRef();
   }
-  inline ArrayRef<uint64_t> getKnownColdHashes() const {
+  ArrayRef<uint64_t> getKnownColdHashes() const {
     return KnownColdHashes.getArrayRef();
   }
 
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 8a6be22bdb1a4..710b2f6836064 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -357,7 +357,7 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
 /// the duplicated profile variables for Comdat functions.
 bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
 
-/// \c NameStrings is a string composed of one of more possibly encoded
+/// \c NameStrings is a string composed of one or more possibly encoded
 /// sub-strings. The substrings are separated by 0 or more zero bytes. This
 /// method decodes the string and calls `NameCallback` for each substring.
 Error readAndDecodeStrings(StringRef NameStrings,
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index c52533c13919c..a42ee41b24358 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -141,7 +141,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   // Record the length of compressed string.
   OS.write(CompressedStringLen);
   // Write the chars in compressed strings.
-  for (auto &c : CompressedStrings)
+  for (char C : CompressedStrings)
     OS.writeByte(static_cast<uint8_t>(c));
   // Pad up to a multiple of 8.
   // InstrProfReader could read bytes according to 'CompressedStringLen'.

>From b69c9930b9584a8dede96b0a9605c03f4d504bb8 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 6 May 2025 10:50:59 -0700
Subject: [PATCH 04/14] resolve review feedback

---
 .../include/llvm/ProfileData/DataAccessProf.h |  43 ++---
 llvm/include/llvm/ProfileData/InstrProf.h     |   5 +-
 llvm/lib/ProfileData/DataAccessProf.cpp       |  72 ++++---
 llvm/unittests/ProfileData/CMakeLists.txt     |   1 +
 .../ProfileData/DataAccessProfTest.cpp        | 181 ++++++++++++++++++
 llvm/unittests/ProfileData/MemProfTest.cpp    | 160 ----------------
 6 files changed, 245 insertions(+), 217 deletions(-)
 create mode 100644 llvm/unittests/ProfileData/DataAccessProfTest.cpp

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 81289b60b96ed..a85a3320ae57c 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -70,9 +70,11 @@ struct DataAccessProfRecord {
 /// This class provides profile look-up, serialization and deserialization.
 class DataAccessProfData {
 public:
-  // SymbolID is either a string representing symbol name, or a uint64_t
-  // representing the content hash of a string literal.
-  using SymbolID = std::variant<StringRef, uint64_t>;
+  // SymbolID is either a string representing symbol name if the symbol has
+  // stable mangled name relative to source code, or a uint64_t representing the
+  // content hash of a string literal (with unstable name patterns like
+  // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
+  using SymbolHandle = std::variant<StringRef, uint64_t>;
   using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
 
   DataAccessProfData() : Saver(Allocator) {}
@@ -90,38 +92,32 @@ class DataAccessProfData {
   /// Returns a pointer of profile record for \p SymbolID, or nullptr if there
   /// isn't a record. Internally, this function will canonicalize the symbol
   /// name before the lookup.
-  const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const;
+  const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const;
 
   /// Returns true if \p SymID is seen in profiled binaries and cold.
-  bool isKnownColdSymbol(const SymbolID SymID) const;
+  bool isKnownColdSymbol(const SymbolHandle SymID) const;
 
-  /// Methods to add symbolized data access profile. Returns error if duplicated
+  /// Methods to set symbolized data access profile. Returns error if duplicated
   /// symbol names or content hashes are seen. The user of this class should
   /// aggregate counters that correspond to the same symbol name or with the
-  /// same string literal hash before calling 'add*' methods.
-  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount);
+  /// same string literal hash before calling 'set*' methods.
+  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount);
   /// Similar to the method above, for records with \p Locations representing
   /// the `filename:line` where this symbol shows up. Note because of linker's
   /// merge of identical symbols (e.g., unnamed_addr string literals), one
   /// symbol is likely to have multiple locations.
-  Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount,
-                             const llvm::SmallVector<DataLocation> &Locations);
-  Error addKnownSymbolWithoutSamples(SymbolID SymbolID);
+  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount,
+                             ArrayRef<DataLocation> Locations);
+  Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID);
 
   /// Returns an iterable StringRef for strings in the order they are added.
   /// Each string may be a symbol name or a file name.
   auto getStrings() const {
-    ArrayRef<std::pair<StringRef, uint64_t>> RefSymbolNames(
-        StrToIndexMap.begin(), StrToIndexMap.end());
-    return llvm::make_first_range(RefSymbolNames);
+    return llvm::make_first_range(StrToIndexMap.getArrayRef());
   }
 
   /// Returns array reference for various internal data structures.
-  ArrayRef<
-      std::pair<std::variant<StringRef, uint64_t>, DataAccessProfRecord>>
-  getRecords() const {
-    return Records.getArrayRef();
-  }
+  auto getRecords() const { return Records.getArrayRef(); }
   ArrayRef<StringRef> getKnownColdSymbols() const {
     return KnownColdSymbols.getArrayRef();
   }
@@ -137,13 +133,13 @@ class DataAccessProfData {
   /// start of the next payload.
   Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr,
                                        const uint64_t NumSampledSymbols,
-                                       uint64_t NumColdKnownSymbols);
+                                       const uint64_t NumColdKnownSymbols);
 
   /// Decode the records and increment \p Ptr to the start of the next payload.
   Error deserializeRecords(const unsigned char *&Ptr);
 
   /// A helper function to compute a storage index for \p SymbolID.
-  uint64_t getEncodedIndex(const SymbolID SymbolID) const;
+  uint64_t getEncodedIndex(const SymbolHandle SymbolID) const;
 
   // Keeps owned copies of the input strings.
   // NOTE: Keep `Saver` initialized before other class members that reference
@@ -151,9 +147,8 @@ class DataAccessProfData {
   llvm::BumpPtrAllocator Allocator;
   llvm::UniqueStringSaver Saver;
 
-  // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to
-  // its record index.
-  MapVector<SymbolID, DataAccessProfRecord> Records;
+  // `Records` stores the records.
+  MapVector<SymbolHandle, DataAccessProfRecord> Records;
 
   // Use MapVector to keep input order of strings for serialization and
   // deserialization.
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 710b2f6836064..33b93ea0a558a 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -358,8 +358,9 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName);
 bool needsComdatForCounter(const GlobalObject &GV, const Module &M);
 
 /// \c NameStrings is a string composed of one or more possibly encoded
-/// sub-strings. The substrings are separated by 0 or more zero bytes. This
-/// method decodes the string and calls `NameCallback` for each substring.
+/// sub-strings. The substrings are separated by `\01` (returned by
+/// InstrProf.h:getInstrProfNameSeparator). This method decodes the string and
+/// calls `NameCallback` for each substring.
 Error readAndDecodeStrings(StringRef NameStrings,
                            std::function<Error(StringRef)> NameCallback);
 
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index a42ee41b24358..d1c034f639347 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -23,11 +23,22 @@ saveStringToMap(MapVector<StringRef, uint64_t> &Map,
   return Iter;
 }
 
+// Returns the canonical name or error.
+static Expected<StringRef> getCanonicalName(StringRef Name) {
+  if (Name.empty())
+    return make_error<StringError>("Empty symbol name",
+                                   llvm::errc::invalid_argument);
+  return InstrProfSymtab::getCanonicalName(Name);
+}
+
 const DataAccessProfRecord *
-DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
   auto Key = SymbolID;
-  if (std::holds_alternative<StringRef>(SymbolID))
-    Key = InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymbolID));
+  if (std::holds_alternative<StringRef>(SymbolID)) {
+    StringRef Name = std::get<StringRef>(SymbolID);
+    assert(!Name.empty() && "Empty symbol name");
+    Key = InstrProfSymtab::getCanonicalName(Name);
+  }
 
   auto It = Records.find(Key);
   if (It != Records.end())
@@ -36,30 +47,27 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
   return nullptr;
 }
 
-bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const {
   if (std::holds_alternative<uint64_t>(SymID))
     return KnownColdHashes.contains(std::get<uint64_t>(SymID));
   return KnownColdSymbols.contains(std::get<StringRef>(SymID));
 }
 
-Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
                                                uint64_t AccessCount) {
   uint64_t RecordID = -1;
   bool IsStringLiteral = false;
-  SymbolID Key;
+  SymbolHandle Key;
   if (std::holds_alternative<uint64_t>(Symbol)) {
     RecordID = std::get<uint64_t>(Symbol);
     Key = RecordID;
     IsStringLiteral = true;
   } else {
-    StringRef SymbolName = std::get<StringRef>(Symbol);
-    if (SymbolName.empty())
-      return make_error<StringError>("Empty symbol name",
-                                     llvm::errc::invalid_argument);
-
-    StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName);
-    Key = CanonicalName;
-    RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
+    auto CanonicalName = getCanonicalName(std::get<StringRef>(Symbol));
+    if (!CanonicalName)
+      return CanonicalName.takeError();
+    std::tie(Key, RecordID) =
+        *saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
     IsStringLiteral = false;
   }
 
@@ -75,8 +83,8 @@ Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
 }
 
 Error DataAccessProfData::setDataAccessProfile(
-    SymbolID SymbolID, uint64_t AccessCount,
-    const llvm::SmallVector<DataLocation> &Locations) {
+    SymbolHandle SymbolID, uint64_t AccessCount,
+    ArrayRef<DataLocation> Locations) {
   if (Error E = setDataAccessProfile(SymbolID, AccessCount))
     return E;
 
@@ -89,17 +97,15 @@ Error DataAccessProfData::setDataAccessProfile(
   return Error::success();
 }
 
-Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) {
   if (std::holds_alternative<uint64_t>(SymbolID)) {
     KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
     return Error::success();
   }
-  StringRef SymbolName = std::get<StringRef>(SymbolID);
-  if (SymbolName.empty())
-    return make_error<StringError>("Empty symbol name",
-                                   llvm::errc::invalid_argument);
-  StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName);
-  KnownColdSymbols.insert(CanonicalSymName);
+  auto CanonicalName = getCanonicalName(std::get<StringRef>(SymbolID));
+  if (!CanonicalName)
+    return CanonicalName.takeError();
+  KnownColdSymbols.insert(*CanonicalName);
   return Error::success();
 }
 
@@ -142,7 +148,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   OS.write(CompressedStringLen);
   // Write the chars in compressed strings.
   for (char C : CompressedStrings)
-    OS.writeByte(static_cast<uint8_t>(c));
+    OS.writeByte(static_cast<uint8_t>(C));
   // Pad up to a multiple of 8.
   // InstrProfReader could read bytes according to 'CompressedStringLen'.
   const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
@@ -151,11 +157,15 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
   return Error::success();
 }
 
-uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+uint64_t
+DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const {
   if (std::holds_alternative<uint64_t>(SymbolID))
     return std::get<uint64_t>(SymbolID);
 
-  return StrToIndexMap.find(std::get<StringRef>(SymbolID))->second;
+  auto Iter = StrToIndexMap.find(std::get<StringRef>(SymbolID));
+  assert(Iter != StrToIndexMap.end() &&
+         "String literals not found in StrToIndexMap");
+  return Iter->second;
 }
 
 Error DataAccessProfData::serialize(ProfOStream &OS) const {
@@ -179,13 +189,13 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const {
 }
 
 Error DataAccessProfData::deserializeSymbolsAndFilenames(
-    const unsigned char *&Ptr, uint64_t NumSampledSymbols,
-    uint64_t NumColdKnownSymbols) {
+    const unsigned char *&Ptr, const uint64_t NumSampledSymbols,
+    const uint64_t NumColdKnownSymbols) {
   uint64_t Len =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
-  // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
-  // symbols with samples, and next N strings are known cold symbols.
+  // The first NumSampledSymbols strings are symbols with samples, and next
+  // NumColdKnownSymbols strings are known cold symbols.
   uint64_t StringCnt = 0;
   std::function<Error(StringRef)> addName = [&](StringRef Name) {
     if (StringCnt < NumSampledSymbols)
@@ -219,7 +229,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
     uint64_t AccessCount =
         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
-    SymbolID SymbolID;
+    SymbolHandle SymbolID;
     if (IsStringLiteral)
       SymbolID = ID;
     else
diff --git a/llvm/unittests/ProfileData/CMakeLists.txt b/llvm/unittests/ProfileData/CMakeLists.txt
index 0a7f7da085950..29b9cb751dabe 100644
--- a/llvm/unittests/ProfileData/CMakeLists.txt
+++ b/llvm/unittests/ProfileData/CMakeLists.txt
@@ -10,6 +10,7 @@ set(LLVM_LINK_COMPONENTS
 add_llvm_unittest(ProfileDataTests
   BPFunctionNodeTest.cpp
   CoverageMappingTest.cpp
+  DataAccessProfTest.cpp
   InstrProfDataTest.cpp
   InstrProfTest.cpp
   ItaniumManglingCanonicalizerTest.cpp
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
new file mode 100644
index 0000000000000..50c4af49fe76b
--- /dev/null
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -0,0 +1,181 @@
+
+//===- unittests/ProfileData/DataAccessProfTest.cpp
+//----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gmock/gmock-more-matchers.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace llvm {
+namespace data_access_prof {
+namespace {
+
+using ::llvm::StringRef;
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+
+static std::string ErrorToString(Error E) {
+  std::string ErrMsg;
+  llvm::raw_string_ostream OS(ErrMsg);
+  llvm::logAllUnhandledErrors(std::move(E), OS);
+  return ErrMsg;
+}
+
+// Test the various scenarios when DataAccessProfData should return error on
+// invalid input.
+TEST(MemProf, DataAccessProfileError) {
+  // Returns error if the input symbol name is empty.
+  DataAccessProfData Data;
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
+              HasSubstr("Empty symbol name"));
+
+  // Returns error when the same symbol gets added twice.
+  ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
+              HasSubstr("Duplicate symbol or string literal added"));
+
+  // Returns error when the same string content hash gets added twice.
+  ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
+  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
+              HasSubstr("Duplicate symbol or string literal added"));
+}
+
+// Test the following operations on DataAccessProfData:
+// - Profile record look up.
+// - Serialization and de-serialization.
+TEST(MemProf, DataAccessProfile) {
+  DataAccessProfData Data;
+
+  // In the bool conversion, Error is true if it's in a failure state and false
+  // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
+  ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
+  ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
+                                         {
+                                             DataLocation{"file2", 3},
+                                         }));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
+  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
+  ASSERT_FALSE(Data.setDataAccessProfile(
+      (uint64_t)135246, 1000,
+      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+
+  {
+    // Test that symbol names and file names are stored in the input order.
+    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles.
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
+
+    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
+    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+
+    EXPECT_THAT(
+        *Data.getProfileRecord("foo.llvm.123"),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+              testing::Field(&DataAccessProfRecord::AccessCount, 100),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+              testing::Field(&DataAccessProfRecord::Locations,
+                             testing::IsEmpty())));
+    EXPECT_THAT(
+        *Data.getProfileRecord("bar.__uniq.321"),
+        AllOf(
+            testing::Field(&DataAccessProfRecord::SymbolID, 1),
+            testing::Field(&DataAccessProfRecord::AccessCount, 123),
+            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+            testing::Field(&DataAccessProfRecord::Locations,
+                           ElementsAre(AllOf(
+                               testing::Field(&DataLocation::FileName, "file2"),
+                               testing::Field(&DataLocation::Line, 3))))));
+    EXPECT_THAT(
+        *Data.getProfileRecord((uint64_t)135246),
+        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+              testing::Field(
+                  &DataAccessProfRecord::Locations,
+                  ElementsAre(
+                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                            testing::Field(&DataLocation::Line, 1)),
+                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                            testing::Field(&DataLocation::Line, 2))))));
+  }
+
+  // Tests serialization and de-serialization.
+  DataAccessProfData deserializedData;
+  {
+    std::string serializedData;
+    llvm::raw_string_ostream OS(serializedData);
+    llvm::ProfOStream POS(OS);
+
+    EXPECT_FALSE(Data.serialize(POS));
+
+    const unsigned char *p =
+        reinterpret_cast<const unsigned char *>(serializedData.data());
+    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                testing::IsEmpty());
+    EXPECT_FALSE(deserializedData.deserialize(p));
+
+    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(deserializedData.getKnownColdSymbols(),
+                ElementsAre("sym2", "sym1"));
+    EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
+
+    // Look up profiles after deserialization.
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
+    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
+
+    auto Records =
+        llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
+
+    EXPECT_THAT(
+        Records,
+        ElementsAre(
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(&DataAccessProfRecord::Locations,
+                                 testing::IsEmpty())),
+            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
+                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
+                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
+                  testing::Field(
+                      &DataAccessProfRecord::Locations,
+                      ElementsAre(AllOf(
+                          testing::Field(&DataLocation::FileName, "file2"),
+                          testing::Field(&DataLocation::Line, 3))))),
+            AllOf(
+                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
+                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
+                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
+                testing::Field(
+                    &DataAccessProfRecord::Locations,
+                    ElementsAre(
+                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
+                              testing::Field(&DataLocation::Line, 1)),
+                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
+                              testing::Field(&DataLocation::Line, 2)))))));
+  }
+}
+} // namespace
+} // namespace data_access_prof
+} // namespace llvm
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index b7b8d642ad930..3e430aa4eae58 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -10,17 +10,14 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLForwardCompat.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/DebugInfo/DIContext.h"
 #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Object/ObjectFile.h"
-#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/MemProfData.inc"
 #include "llvm/ProfileData/MemProfReader.h"
 #include "llvm/ProfileData/MemProfYAML.h"
 #include "llvm/Support/raw_ostream.h"
-#include "gmock/gmock-more-matchers.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
@@ -39,8 +36,6 @@ using ::llvm::StringRef;
 using ::llvm::object::SectionedAddress;
 using ::llvm::symbolize::SymbolizableModule;
 using ::testing::ElementsAre;
-using ::testing::ElementsAreArray;
-using ::testing::HasSubstr;
 using ::testing::IsEmpty;
 using ::testing::Pair;
 using ::testing::Return;
@@ -752,161 +747,6 @@ TEST(MemProf, YAMLParser) {
                                ElementsAre(0x3000)))));
 }
 
-static std::string ErrorToString(Error E) {
-  std::string ErrMsg;
-  llvm::raw_string_ostream OS(ErrMsg);
-  llvm::logAllUnhandledErrors(std::move(E), OS);
-  return ErrMsg;
-}
-
-// Test the various scenarios when DataAccessProfData should return error on
-// invalid input.
-TEST(MemProf, DataAccessProfileError) {
-  // Returns error if the input symbol name is empty.
-  llvm::data_access_prof::DataAccessProfData Data;
-  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)),
-              HasSubstr("Empty symbol name"));
-
-  // Returns error when the same symbol gets added twice.
-  ASSERT_FALSE(Data.setDataAccessProfile("foo", 100));
-  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)),
-              HasSubstr("Duplicate symbol or string literal added"));
-
-  // Returns error when the same string content hash gets added twice.
-  ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000));
-  EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)),
-              HasSubstr("Duplicate symbol or string literal added"));
-}
-
-// Test the following operations on DataAccessProfData:
-// - Profile record look up.
-// - Serialization and de-serialization.
-TEST(MemProf, DataAccessProfile) {
-  using namespace llvm::data_access_prof;
-  llvm::data_access_prof::DataAccessProfData Data;
-
-  // In the bool conversion, Error is true if it's in a failure state and false
-  // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error.
-  ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
-  ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
-                                         {
-                                             DataLocation{"file2", 3},
-                                         }));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
-  ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
-  ASSERT_FALSE(Data.setDataAccessProfile(
-      (uint64_t)135246, 1000,
-      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
-
-  {
-    // Test that symbol names and file names are stored in the input order.
-    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
-                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
-    EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
-    EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
-
-    // Look up profiles.
-    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789));
-    EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678));
-    EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
-    EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
-
-    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
-    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
-
-    EXPECT_THAT(
-        Data.getProfileRecord("foo.llvm.123"),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-              testing::Field(&DataAccessProfRecord::AccessCount, 100),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-              testing::Field(&DataAccessProfRecord::Locations,
-                             testing::IsEmpty())));
-    EXPECT_THAT(
-        *Data.getProfileRecord("bar.__uniq.321"),
-        AllOf(
-            testing::Field(&DataAccessProfRecord::SymbolID, 1),
-            testing::Field(&DataAccessProfRecord::AccessCount, 123),
-            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-            testing::Field(&DataAccessProfRecord::Locations,
-                           ElementsAre(AllOf(
-                               testing::Field(&DataLocation::FileName, "file2"),
-                               testing::Field(&DataLocation::Line, 3))))));
-    EXPECT_THAT(
-        *Data.getProfileRecord((uint64_t)135246),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-              testing::Field(
-                  &DataAccessProfRecord::Locations,
-                  ElementsAre(
-                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                            testing::Field(&DataLocation::Line, 1)),
-                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                            testing::Field(&DataLocation::Line, 2))))));
-  }
-
-  // Tests serialization and de-serialization.
-  llvm::data_access_prof::DataAccessProfData deserializedData;
-  {
-    std::string serializedData;
-    llvm::raw_string_ostream OS(serializedData);
-    llvm::ProfOStream POS(OS);
-
-    EXPECT_FALSE(Data.serialize(POS));
-
-    const unsigned char *p =
-        reinterpret_cast<const unsigned char *>(serializedData.data());
-    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
-                testing::IsEmpty());
-    EXPECT_FALSE(deserializedData.deserialize(p));
-
-    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
-                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
-    EXPECT_THAT(deserializedData.getKnownColdSymbols(),
-                ElementsAre("sym2", "sym1"));
-    EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
-
-    // Look up profiles after deserialization.
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789));
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678));
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2"));
-    EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1"));
-
-    auto Records =
-        llvm::to_vector(llvm::make_second_range(deserializedData.getRecords()));
-
-    EXPECT_THAT(
-        Records,
-        ElementsAre(
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(&DataAccessProfRecord::Locations,
-                                 testing::IsEmpty())),
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(
-                      &DataAccessProfRecord::Locations,
-                      ElementsAre(AllOf(
-                          testing::Field(&DataLocation::FileName, "file2"),
-                          testing::Field(&DataLocation::Line, 3))))),
-            AllOf(
-                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-                testing::Field(
-                    &DataAccessProfRecord::Locations,
-                    ElementsAre(
-                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                              testing::Field(&DataLocation::Line, 1)),
-                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                              testing::Field(&DataLocation::Line, 2)))))));
-  }
-}
-
 // Verify that the YAML parser accepts a GUID expressed as a function name.
 TEST(MemProf, YAMLParserGUID) {
   StringRef YAMLData = R"YAML(

>From df080949a1ccc75fceb386b0232005f96397752d Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 6 May 2025 16:35:14 -0700
Subject: [PATCH 05/14] resolve feedback

---
 llvm/include/llvm/ProfileData/DataAccessProf.h |  6 +++++-
 llvm/lib/ProfileData/DataAccessProf.cpp        | 12 +++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index a85a3320ae57c..91de43fdf60ca 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -52,7 +52,11 @@ struct DataAccessProfRecord {
 
   // Represents a data symbol. The semantic comes in two forms: a symbol index
   // for symbol name if `IsStringLiteral` is false, or the hash of a string
-  // content if `IsStringLiteral` is true. Required.
+  // content if `IsStringLiteral` is true. For most of the symbolizable static
+  // data, the mangled symbol names remain stable relative to the source code
+  // and are therefore used to identify symbols across binary releases. String
+  // literals have unstable name patterns like `.str.N[.llvm.hash]`, so we use
+  // the content hash instead. This is a required field.
   uint64_t SymbolID;
 
   // The access count of symbol. Required.
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index d1c034f639347..d7e67f5f09cbe 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -35,9 +35,15 @@ const DataAccessProfRecord *
 DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
   auto Key = SymbolID;
   if (std::holds_alternative<StringRef>(SymbolID)) {
-    StringRef Name = std::get<StringRef>(SymbolID);
-    assert(!Name.empty() && "Empty symbol name");
-    Key = InstrProfSymtab::getCanonicalName(Name);
+    auto NameOrErr = getCanonicalName(std::get<StringRef>(SymbolID));
+    // If name canonicalization fails, suppress the error inside.
+    if (!NameOrErr) {
+      assert(
+          std::get<StringRef>(SymbolID).empty() &&
+          "Name canonicalization only fails when stringified string is empty.");
+      return nullptr;
+    }
+    Key = *NameOrErr;
   }
 
   auto It = Records.find(Key);

>From 6dd04e46542851b84bf26cd95245399204072085 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 12 May 2025 17:05:19 -0700
Subject: [PATCH 06/14] resolve comments

---
 .../include/llvm/ProfileData/DataAccessProf.h | 119 ++++++++++++------
 llvm/include/llvm/ProfileData/InstrProf.h     |   2 +-
 llvm/lib/ProfileData/DataAccessProf.cpp       |  49 ++++----
 .../ProfileData/DataAccessProfTest.cpp        | 116 ++++++++---------
 4 files changed, 167 insertions(+), 119 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index 91de43fdf60ca..e8504102238d1 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -30,23 +30,40 @@
 #include "llvm/Support/StringSaver.h"
 
 #include <cstdint>
+#include <optional>
 #include <variant>
 
 namespace llvm {
 
 namespace data_access_prof {
-// The location of data in the source code.
-struct DataLocation {
+
+/// The location of data in the source code. Used by profile lookup API.
+struct SourceLocation {
+  SourceLocation(StringRef FileNameRef, uint32_t Line)
+      : FileName(FileNameRef.str()), Line(Line) {}
+  /// The filename where the data is located.
+  std::string FileName;
+  /// The line number in the source code.
+  uint32_t Line;
+};
+
+namespace internal {
+
+// Conceptually similar to SourceLocation, except that FileName is a StringRef
+// whose underlying string is owned by `DataAccessProfData`. Used by
+// `DataAccessProfData` to represent data locations internally.
+struct SourceLocationRef {
   // The filename where the data is located.
   StringRef FileName;
   // The line number in the source code.
   uint32_t Line;
 };
 
-// The data access profiles for a symbol.
-struct DataAccessProfRecord {
-  DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount,
-                       bool IsStringLiteral)
+// The data access profiles for a symbol. Used by `DataAccessProfData`
+// to represent records internally.
+struct DataAccessProfRecordRef {
+  DataAccessProfRecordRef(uint64_t SymbolID, uint64_t AccessCount,
+                          bool IsStringLiteral)
       : SymbolID(SymbolID), AccessCount(AccessCount),
         IsStringLiteral(IsStringLiteral) {}
 
@@ -67,18 +84,43 @@ struct DataAccessProfRecord {
   bool IsStringLiteral;
 
   // The locations of data in the source code. Optional.
-  llvm::SmallVector<DataLocation, 0> Locations;
+  llvm::SmallVector<SourceLocationRef, 0> Locations;
 };
+} // namespace internal
+
+// SymbolID is either a string representing symbol name if the symbol has
+// stable mangled name relative to source code, or a uint64_t representing the
+// content hash of a string literal (with unstable name patterns like
+// `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
+using SymbolHandleRef = std::variant<StringRef, uint64_t>;
 
-/// Encapsulates the data access profile data and the methods to operate on it.
-/// This class provides profile look-up, serialization and deserialization.
+// The semantics are the same as `SymbolHandleRef` above. The strings are owned.
+using SymbolHandle = std::variant<std::string, uint64_t>;
+
+/// The data access profiles for a symbol.
+struct DataAccessProfRecord {
+public:
+  DataAccessProfRecord(SymbolHandleRef SymHandleRef,
+                       ArrayRef<internal::SourceLocationRef> LocRefs) {
+    if (std::holds_alternative<StringRef>(SymHandleRef)) {
+      SymHandle = std::get<StringRef>(SymHandleRef).str();
+    } else
+      SymHandle = std::get<uint64_t>(SymHandleRef);
+
+    for (auto Loc : LocRefs)
+      Locations.push_back(SourceLocation(Loc.FileName, Loc.Line));
+  }
+  SymbolHandle SymHandle;
+
+  // The locations of data in the source code. Optional.
+  SmallVector<SourceLocation> Locations;
+};
+
+/// Encapsulates the data access profile data and the methods to operate on
+/// it. This class provides profile look-up, serialization and
+/// deserialization.
 class DataAccessProfData {
 public:
-  // SymbolID is either a string representing symbol name if the symbol has
-  // stable mangled name relative to source code, or a uint64_t representing the
-  // content hash of a string literal (with unstable name patterns like
-  // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object.
-  using SymbolHandle = std::variant<StringRef, uint64_t>;
   using StringToIndexMap = llvm::MapVector<StringRef, uint64_t>;
 
   DataAccessProfData() : Saver(Allocator) {}
@@ -93,35 +135,39 @@ class DataAccessProfData {
   /// Deserialize this class from the given buffer.
   Error deserialize(const unsigned char *&Ptr);
 
-  /// Returns a pointer of profile record for \p SymbolID, or nullptr if there
+  /// Returns a profile record for \p SymID, or std::nullopt if there
   /// isn't a record. Internally, this function will canonicalize the symbol
   /// name before the lookup.
-  const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const;
+  std::optional<DataAccessProfRecord>
+  getProfileRecord(const SymbolHandleRef SymID) const;
 
   /// Returns true if \p SymID is seen in profiled binaries and cold.
-  bool isKnownColdSymbol(const SymbolHandle SymID) const;
+  bool isKnownColdSymbol(const SymbolHandleRef SymID) const;
 
-  /// Methods to set symbolized data access profile. Returns error if duplicated
-  /// symbol names or content hashes are seen. The user of this class should
-  /// aggregate counters that correspond to the same symbol name or with the
-  /// same string literal hash before calling 'set*' methods.
-  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount);
+  /// Methods to set symbolized data access profile. Returns error if
+  /// duplicated symbol names or content hashes are seen. The user of this
+  /// class should aggregate counters that correspond to the same symbol name
+  /// or with the same string literal hash before calling 'set*' methods.
+  Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount);
   /// Similar to the method above, for records with \p Locations representing
   /// the `filename:line` where this symbol shows up. Note because of linker's
   /// merge of identical symbols (e.g., unnamed_addr string literals), one
   /// symbol is likely to have multiple locations.
-  Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount,
-                             ArrayRef<DataLocation> Locations);
-  Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID);
-
-  /// Returns an iterable StringRef for strings in the order they are added.
-  /// Each string may be a symbol name or a file name.
-  auto getStrings() const {
-    return llvm::make_first_range(StrToIndexMap.getArrayRef());
+  Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount,
+                             ArrayRef<SourceLocation> Locations);
+  /// Add a symbol that's seen in the profiled binary without samples.
+  Error addKnownSymbolWithoutSamples(SymbolHandleRef SymbolID);
+
+  /// The following methods return array references to various internal data
+  /// structures.
+  ArrayRef<StringToIndexMap::value_type> getStrToIndexMapRef() const {
+    return StrToIndexMap.getArrayRef();
+  }
+  ArrayRef<
+      MapVector<SymbolHandleRef, internal::DataAccessProfRecordRef>::value_type>
+  getRecords() const {
+    return Records.getArrayRef();
   }
-
-  /// Returns array reference for various internal data structures.
-  auto getRecords() const { return Records.getArrayRef(); }
   ArrayRef<StringRef> getKnownColdSymbols() const {
     return KnownColdSymbols.getArrayRef();
   }
@@ -139,11 +185,12 @@ class DataAccessProfData {
                                        const uint64_t NumSampledSymbols,
                                        const uint64_t NumColdKnownSymbols);
 
-  /// Decode the records and increment \p Ptr to the start of the next payload.
+  /// Decode the records and increment \p Ptr to the start of the next
+  /// payload.
   Error deserializeRecords(const unsigned char *&Ptr);
 
   /// A helper function to compute a storage index for \p SymbolID.
-  uint64_t getEncodedIndex(const SymbolHandle SymbolID) const;
+  uint64_t getEncodedIndex(const SymbolHandleRef SymbolID) const;
 
   // Keeps owned copies of the input strings.
   // NOTE: Keep `Saver` initialized before other class members that reference
@@ -152,7 +199,7 @@ class DataAccessProfData {
   llvm::UniqueStringSaver Saver;
 
   // `Records` stores the records.
-  MapVector<SymbolHandle, DataAccessProfRecord> Records;
+  MapVector<SymbolHandleRef, internal::DataAccessProfRecordRef> Records;
 
   // Use MapVector to keep input order of strings for serialization and
   // deserialization.
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index 33b93ea0a558a..544a59df43ed3 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -500,7 +500,7 @@ class InstrProfSymtab {
 public:
   using AddrHashMap = std::vector<std::pair<uint64_t, uint64_t>>;
 
-  // Returns the canonial name of the given PGOName. In a canonical name, all
+  // Returns the canonical name of the given PGOName. In a canonical name, all
   // suffixes that begins with "." except ".__uniq." are stripped.
   // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`.
   static StringRef getCanonicalName(StringRef PGOName);
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index d7e67f5f09cbe..c5d0099977cfa 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -16,11 +16,11 @@ namespace data_access_prof {
 // If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
 // creates an owned copy of `Str`, adds a map entry for it and returns the
 // iterator.
-static MapVector<StringRef, uint64_t>::iterator
-saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+static std::pair<StringRef, uint64_t>
+saveStringToMap(DataAccessProfData::StringToIndexMap &Map,
                 llvm::UniqueStringSaver &Saver, StringRef Str) {
   auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
-  return Iter;
+  return *Iter;
 }
 
 // Returns the canonical name or error.
@@ -31,8 +31,8 @@ static Expected<StringRef> getCanonicalName(StringRef Name) {
   return InstrProfSymtab::getCanonicalName(Name);
 }
 
-const DataAccessProfRecord *
-DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
+std::optional<DataAccessProfRecord>
+DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const {
   auto Key = SymbolID;
   if (std::holds_alternative<StringRef>(SymbolID)) {
     auto NameOrErr = getCanonicalName(std::get<StringRef>(SymbolID));
@@ -41,40 +41,39 @@ DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const {
       assert(
           std::get<StringRef>(SymbolID).empty() &&
           "Name canonicalization only fails when stringified string is empty.");
-      return nullptr;
+      return std::nullopt;
     }
     Key = *NameOrErr;
   }
 
   auto It = Records.find(Key);
-  if (It != Records.end())
-    return &It->second;
+  if (It != Records.end()) {
+    return DataAccessProfRecord(Key, It->second.Locations);
+  }
 
-  return nullptr;
+  return std::nullopt;
 }
 
-bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const {
+bool DataAccessProfData::isKnownColdSymbol(const SymbolHandleRef SymID) const {
   if (std::holds_alternative<uint64_t>(SymID))
     return KnownColdHashes.contains(std::get<uint64_t>(SymID));
   return KnownColdSymbols.contains(std::get<StringRef>(SymID));
 }
 
-Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
+Error DataAccessProfData::setDataAccessProfile(SymbolHandleRef Symbol,
                                                uint64_t AccessCount) {
   uint64_t RecordID = -1;
-  bool IsStringLiteral = false;
-  SymbolHandle Key;
-  if (std::holds_alternative<uint64_t>(Symbol)) {
+  const bool IsStringLiteral = std::holds_alternative<uint64_t>(Symbol);
+  SymbolHandleRef Key;
+  if (IsStringLiteral) {
     RecordID = std::get<uint64_t>(Symbol);
     Key = RecordID;
-    IsStringLiteral = true;
   } else {
     auto CanonicalName = getCanonicalName(std::get<StringRef>(Symbol));
     if (!CanonicalName)
       return CanonicalName.takeError();
     std::tie(Key, RecordID) =
-        *saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
-    IsStringLiteral = false;
+        saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
   }
 
   auto [Iter, Inserted] =
@@ -89,21 +88,22 @@ Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol,
 }
 
 Error DataAccessProfData::setDataAccessProfile(
-    SymbolHandle SymbolID, uint64_t AccessCount,
-    ArrayRef<DataLocation> Locations) {
+    SymbolHandleRef SymbolID, uint64_t AccessCount,
+    ArrayRef<SourceLocation> Locations) {
   if (Error E = setDataAccessProfile(SymbolID, AccessCount))
     return E;
 
   auto &Record = Records.back().second;
   for (const auto &Location : Locations)
     Record.Locations.push_back(
-        {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first,
+        {saveStringToMap(StrToIndexMap, Saver, Location.FileName).first,
          Location.Line});
 
   return Error::success();
 }
 
-Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) {
+Error DataAccessProfData::addKnownSymbolWithoutSamples(
+    SymbolHandleRef SymbolID) {
   if (std::holds_alternative<uint64_t>(SymbolID)) {
     KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
     return Error::success();
@@ -164,7 +164,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
 }
 
 uint64_t
-DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const {
+DataAccessProfData::getEncodedIndex(const SymbolHandleRef SymbolID) const {
   if (std::holds_alternative<uint64_t>(SymbolID))
     return std::get<uint64_t>(SymbolID);
 
@@ -220,7 +220,8 @@ Error DataAccessProfData::deserializeSymbolsAndFilenames(
 }
 
 Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
-  SmallVector<StringRef> Strings = llvm::to_vector(getStrings());
+  SmallVector<StringRef> Strings =
+      llvm::to_vector(llvm::make_first_range(getStrToIndexMapRef()));
 
   uint64_t NumRecords =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
@@ -235,7 +236,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
     uint64_t AccessCount =
         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
-    SymbolHandle SymbolID;
+    SymbolHandleRef SymbolID;
     if (IsStringLiteral)
       SymbolID = ID;
     else
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
index 50c4af49fe76b..127230d4805e7 100644
--- a/llvm/unittests/ProfileData/DataAccessProfTest.cpp
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -1,4 +1,3 @@
-
 //===- unittests/Support/DataAccessProfTest.cpp
 //----------------------------------===//
 //
@@ -10,6 +9,7 @@
 
 #include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
 #include "gmock/gmock-more-matchers.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
@@ -19,7 +19,9 @@ namespace data_access_prof {
 namespace {
 
 using ::llvm::StringRef;
+using llvm::ValueIs;
 using ::testing::ElementsAre;
+using ::testing::Field;
 using ::testing::HasSubstr;
 using ::testing::IsEmpty;
 
@@ -53,6 +55,8 @@ TEST(MemProf, DataAccessProfileError) {
 // - Profile record look up.
 // - Serialization and de-serialization.
 TEST(MemProf, DataAccessProfile) {
+  using internal::DataAccessProfRecordRef;
+  using internal::SourceLocationRef;
   DataAccessProfData Data;
 
   // In the bool conversion, Error is true if it's in a failure state and false
@@ -62,18 +66,19 @@ TEST(MemProf, DataAccessProfile) {
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2"));
   ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123,
                                          {
-                                             DataLocation{"file2", 3},
+                                             SourceLocation{"file2", 3},
                                          }));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1"));
   ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678));
   ASSERT_FALSE(Data.setDataAccessProfile(
       (uint64_t)135246, 1000,
-      {DataLocation{"file1", 1}, DataLocation{"file2", 2}}));
+      {SourceLocation{"file1", 1}, SourceLocation{"file2", 2}}));
 
   {
     // Test that symbol names and file names are stored in the input order.
-    EXPECT_THAT(llvm::to_vector(Data.getStrings()),
-                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(
+        llvm::to_vector(llvm::make_first_range(Data.getStrToIndexMapRef())),
+        ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
     EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
     EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
 
@@ -83,38 +88,34 @@ TEST(MemProf, DataAccessProfile) {
     EXPECT_TRUE(Data.isKnownColdSymbol("sym2"));
     EXPECT_TRUE(Data.isKnownColdSymbol("sym1"));
 
-    EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr);
-    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr);
+    EXPECT_EQ(Data.getProfileRecord("non-existence"), std::nullopt);
+    EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), std::nullopt);
 
     EXPECT_THAT(
-        *Data.getProfileRecord("foo.llvm.123"),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-              testing::Field(&DataAccessProfRecord::AccessCount, 100),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-              testing::Field(&DataAccessProfRecord::Locations,
-                             testing::IsEmpty())));
+        Data.getProfileRecord("foo.llvm.123"),
+        ValueIs(AllOf(
+            Field(&DataAccessProfRecord::SymHandle,
+                  testing::VariantWith<std::string>(testing::Eq("foo"))),
+            Field(&DataAccessProfRecord::Locations, testing::IsEmpty()))));
     EXPECT_THAT(
-        *Data.getProfileRecord("bar.__uniq.321"),
-        AllOf(
-            testing::Field(&DataAccessProfRecord::SymbolID, 1),
-            testing::Field(&DataAccessProfRecord::AccessCount, 123),
-            testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-            testing::Field(&DataAccessProfRecord::Locations,
-                           ElementsAre(AllOf(
-                               testing::Field(&DataLocation::FileName, "file2"),
-                               testing::Field(&DataLocation::Line, 3))))));
+        Data.getProfileRecord("bar.__uniq.321"),
+        ValueIs(AllOf(
+            Field(&DataAccessProfRecord::SymHandle,
+                  testing::VariantWith<std::string>(
+                      testing::Eq("bar.__uniq.321"))),
+            Field(&DataAccessProfRecord::Locations,
+                  ElementsAre(AllOf(Field(&SourceLocation::FileName, "file2"),
+                                    Field(&SourceLocation::Line, 3)))))));
     EXPECT_THAT(
-        *Data.getProfileRecord((uint64_t)135246),
-        AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-              testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-              testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-              testing::Field(
-                  &DataAccessProfRecord::Locations,
-                  ElementsAre(
-                      AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                            testing::Field(&DataLocation::Line, 1)),
-                      AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                            testing::Field(&DataLocation::Line, 2))))));
+        Data.getProfileRecord((uint64_t)135246),
+        ValueIs(AllOf(
+            Field(&DataAccessProfRecord::SymHandle,
+                  testing::VariantWith<uint64_t>(testing::Eq(135246))),
+            Field(&DataAccessProfRecord::Locations,
+                  ElementsAre(AllOf(Field(&SourceLocation::FileName, "file1"),
+                                    Field(&SourceLocation::Line, 1)),
+                              AllOf(Field(&SourceLocation::FileName, "file2"),
+                                    Field(&SourceLocation::Line, 2)))))));
   }
 
   // Tests serialization and de-serialization.
@@ -128,11 +129,13 @@ TEST(MemProf, DataAccessProfile) {
 
     const unsigned char *p =
         reinterpret_cast<const unsigned char *>(serializedData.data());
-    ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()),
+    ASSERT_THAT(llvm::to_vector(llvm::make_first_range(
+                    deserializedData.getStrToIndexMapRef())),
                 testing::IsEmpty());
     EXPECT_FALSE(deserializedData.deserialize(p));
 
-    EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()),
+    EXPECT_THAT(llvm::to_vector(llvm::make_first_range(
+                    deserializedData.getStrToIndexMapRef())),
                 ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
     EXPECT_THAT(deserializedData.getKnownColdSymbols(),
                 ElementsAre("sym2", "sym1"));
@@ -150,30 +153,27 @@ TEST(MemProf, DataAccessProfile) {
     EXPECT_THAT(
         Records,
         ElementsAre(
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 100),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(&DataAccessProfRecord::Locations,
-                                 testing::IsEmpty())),
-            AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1),
-                  testing::Field(&DataAccessProfRecord::AccessCount, 123),
-                  testing::Field(&DataAccessProfRecord::IsStringLiteral, false),
-                  testing::Field(
-                      &DataAccessProfRecord::Locations,
-                      ElementsAre(AllOf(
-                          testing::Field(&DataLocation::FileName, "file2"),
-                          testing::Field(&DataLocation::Line, 3))))),
             AllOf(
-                testing::Field(&DataAccessProfRecord::SymbolID, 135246),
-                testing::Field(&DataAccessProfRecord::AccessCount, 1000),
-                testing::Field(&DataAccessProfRecord::IsStringLiteral, true),
-                testing::Field(
-                    &DataAccessProfRecord::Locations,
-                    ElementsAre(
-                        AllOf(testing::Field(&DataLocation::FileName, "file1"),
-                              testing::Field(&DataLocation::Line, 1)),
-                        AllOf(testing::Field(&DataLocation::FileName, "file2"),
-                              testing::Field(&DataLocation::Line, 2)))))));
+                Field(&DataAccessProfRecordRef::SymbolID, 0),
+                Field(&DataAccessProfRecordRef::AccessCount, 100),
+                Field(&DataAccessProfRecordRef::IsStringLiteral, false),
+                Field(&DataAccessProfRecordRef::Locations, testing::IsEmpty())),
+            AllOf(Field(&DataAccessProfRecordRef::SymbolID, 1),
+                  Field(&DataAccessProfRecordRef::AccessCount, 123),
+                  Field(&DataAccessProfRecordRef::IsStringLiteral, false),
+                  Field(&DataAccessProfRecordRef::Locations,
+                        ElementsAre(
+                            AllOf(Field(&SourceLocationRef::FileName, "file2"),
+                                  Field(&SourceLocationRef::Line, 3))))),
+            AllOf(Field(&DataAccessProfRecordRef::SymbolID, 135246),
+                  Field(&DataAccessProfRecordRef::AccessCount, 1000),
+                  Field(&DataAccessProfRecordRef::IsStringLiteral, true),
+                  Field(&DataAccessProfRecordRef::Locations,
+                        ElementsAre(
+                            AllOf(Field(&SourceLocationRef::FileName, "file1"),
+                                  Field(&SourceLocationRef::Line, 1)),
+                            AllOf(Field(&SourceLocationRef::FileName, "file2"),
+                                  Field(&SourceLocationRef::Line, 2)))))));
   }
 }
 } // namespace

>From 4045c943966609cc9a92693752af0e29a19e1ef9 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 15 May 2025 09:56:52 -0700
Subject: [PATCH 07/14] Support reading and writing data access profiles in
 memprof v4.

---
 .../include/llvm/ProfileData/DataAccessProf.h | 12 +++-
 .../llvm/ProfileData/IndexedMemProfData.h     | 12 +++-
 .../llvm/ProfileData/InstrProfReader.h        |  6 +-
 .../llvm/ProfileData/InstrProfWriter.h        |  6 ++
 llvm/include/llvm/ProfileData/MemProfReader.h | 15 +++++
 llvm/include/llvm/ProfileData/MemProfYAML.h   | 58 ++++++++++++++++++
 llvm/lib/ProfileData/DataAccessProf.cpp       |  6 +-
 llvm/lib/ProfileData/IndexedMemProfData.cpp   | 61 +++++++++++++++----
 llvm/lib/ProfileData/InstrProfReader.cpp      | 14 +++++
 llvm/lib/ProfileData/InstrProfWriter.cpp      | 20 ++++--
 llvm/lib/ProfileData/MemProfReader.cpp        | 30 +++++++++
 .../tools/llvm-profdata/memprof-yaml.test     | 11 ++++
 llvm/tools/llvm-profdata/llvm-profdata.cpp    |  4 ++
 .../ProfileData/DataAccessProfTest.cpp        | 11 ++--
 14 files changed, 235 insertions(+), 31 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index e8504102238d1..f5f6abf0a2817 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -41,6 +41,8 @@ namespace data_access_prof {
 struct SourceLocation {
   SourceLocation(StringRef FileNameRef, uint32_t Line)
       : FileName(FileNameRef.str()), Line(Line) {}
+
+  SourceLocation() {}
   /// The filename where the data is located.
   std::string FileName;
   /// The line number in the source code.
@@ -53,6 +55,8 @@ namespace internal {
 // which strings are owned by `DataAccessProfData`. Used by `DataAccessProfData`
 // to represent data locations internally.
 struct SourceLocationRef {
+  SourceLocationRef(StringRef FileNameRef, uint32_t Line)
+      : FileName(FileNameRef), Line(Line) {}
   // The filename where the data is located.
   StringRef FileName;
   // The line number in the source code.
@@ -100,8 +104,9 @@ using SymbolHandle = std::variant<std::string, uint64_t>;
 /// The data access profiles for a symbol.
 struct DataAccessProfRecord {
 public:
-  DataAccessProfRecord(SymbolHandleRef SymHandleRef,
-                       ArrayRef<internal::SourceLocationRef> LocRefs) {
+  DataAccessProfRecord(SymbolHandleRef SymHandleRef, uint64_t AccessCount,
+                       ArrayRef<internal::SourceLocationRef> LocRefs)
+      : AccessCount(AccessCount) {
     if (std::holds_alternative<StringRef>(SymHandleRef)) {
       SymHandle = std::get<StringRef>(SymHandleRef).str();
     } else
@@ -110,8 +115,9 @@ struct DataAccessProfRecord {
     for (auto Loc : LocRefs)
       Locations.push_back(SourceLocation(Loc.FileName, Loc.Line));
   }
+  DataAccessProfRecord() {}
   SymbolHandle SymHandle;
-
+  uint64_t AccessCount;
   // The locations of data in the source code. Optional.
   SmallVector<SourceLocation> Locations;
 };
diff --git a/llvm/include/llvm/ProfileData/IndexedMemProfData.h b/llvm/include/llvm/ProfileData/IndexedMemProfData.h
index 3c6c329d1c49d..66fa38472059b 100644
--- a/llvm/include/llvm/ProfileData/IndexedMemProfData.h
+++ b/llvm/include/llvm/ProfileData/IndexedMemProfData.h
@@ -10,14 +10,20 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/ProfileData/MemProf.h"
 
+#include <functional>
+#include <optional>
+
 namespace llvm {
 
 // Write the MemProf data to OS.
-Error writeMemProf(ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
-                   memprof::IndexedVersion MemProfVersionRequested,
-                   bool MemProfFullSchema);
+Error writeMemProf(
+    ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
+    memprof::IndexedVersion MemProfVersionRequested, bool MemProfFullSchema,
+    std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+        DataAccessProfileData);
 
 } // namespace llvm
diff --git a/llvm/include/llvm/ProfileData/InstrProfReader.h b/llvm/include/llvm/ProfileData/InstrProfReader.h
index c250a9ede39bc..a3436e1dfe711 100644
--- a/llvm/include/llvm/ProfileData/InstrProfReader.h
+++ b/llvm/include/llvm/ProfileData/InstrProfReader.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/Object/BuildID.h"
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/ProfileData/InstrProfCorrelator.h"
 #include "llvm/ProfileData/MemProf.h"
@@ -703,10 +704,13 @@ class IndexedMemProfReader {
   const unsigned char *CallStackBase = nullptr;
   // The number of elements in the radix tree array.
   unsigned RadixTreeSize = 0;
+  /// The data access profiles, deserialized from binary data.
+  std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfileData;
 
   Error deserializeV2(const unsigned char *Start, const unsigned char *Ptr);
   Error deserializeRadixTreeBased(const unsigned char *Start,
-                                  const unsigned char *Ptr);
+                                  const unsigned char *Ptr,
+                                  memprof::IndexedVersion Version);
 
 public:
   IndexedMemProfReader() = default;
diff --git a/llvm/include/llvm/ProfileData/InstrProfWriter.h b/llvm/include/llvm/ProfileData/InstrProfWriter.h
index 67d85daa81623..cf1cec25c3cac 100644
--- a/llvm/include/llvm/ProfileData/InstrProfWriter.h
+++ b/llvm/include/llvm/ProfileData/InstrProfWriter.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/Object/BuildID.h"
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/ProfileData/MemProf.h"
 #include "llvm/Support/Error.h"
@@ -81,6 +82,8 @@ class InstrProfWriter {
   // Whether to generated random memprof hotness for testing.
   bool MemprofGenerateRandomHotness;
 
+  std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfileData;
+
 public:
   // For memprof testing, random hotness can be assigned to the contexts if
   // MemprofGenerateRandomHotness is enabled. The random seed can be either
@@ -122,6 +125,9 @@ class InstrProfWriter {
   // Add a binary id to the binary ids list.
   void addBinaryIds(ArrayRef<llvm::object::BuildID> BIs);
 
+  void addDataAccessProfData(
+      std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfile);
+
   /// Merge existing function counts from the given writer.
   void mergeRecordsFromWriter(InstrProfWriter &&IPW,
                               function_ref<void(Error)> Warn);
diff --git a/llvm/include/llvm/ProfileData/MemProfReader.h b/llvm/include/llvm/ProfileData/MemProfReader.h
index 29d9e57cae3e3..02defa189ea7e 100644
--- a/llvm/include/llvm/ProfileData/MemProfReader.h
+++ b/llvm/include/llvm/ProfileData/MemProfReader.h
@@ -228,6 +228,21 @@ class YAMLMemProfReader final : public MemProfReader {
   create(std::unique_ptr<MemoryBuffer> Buffer);
 
   void parse(StringRef YAMLData);
+
+  std::unique_ptr<data_access_prof::DataAccessProfData>
+  takeDataAccessProfData() {
+    return std::move(DataAccessProfileData);
+  }
+
+private:
+  // Called by `parse` to set data access profiles after parsing them from YAML
+  // files.
+  void setDataAccessProfileData(
+      std::unique_ptr<data_access_prof::DataAccessProfData> Data) {
+    DataAccessProfileData = std::move(Data);
+  }
+
+  std::unique_ptr<data_access_prof::DataAccessProfData> DataAccessProfileData;
 };
 } // namespace memprof
 } // namespace llvm
diff --git a/llvm/include/llvm/ProfileData/MemProfYAML.h b/llvm/include/llvm/ProfileData/MemProfYAML.h
index 08dee253f615a..634edc4aa0122 100644
--- a/llvm/include/llvm/ProfileData/MemProfYAML.h
+++ b/llvm/include/llvm/ProfileData/MemProfYAML.h
@@ -2,6 +2,7 @@
 #define LLVM_PROFILEDATA_MEMPROFYAML_H_
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/MemProf.h"
 #include "llvm/Support/Format.h"
 #include "llvm/Support/YAMLTraits.h"
@@ -20,9 +21,24 @@ struct GUIDMemProfRecordPair {
   MemProfRecord Record;
 };
 
+// Helper struct to yamlify data_access_prof::DataAccessProfData. The struct
+// members use owned strings for simplicity, on the assumption that real-world
+// use cases mostly do look-ups and that regression tests are small in scale.
+struct YamlDataAccessProfData {
+  std::vector<data_access_prof::DataAccessProfRecord> Records;
+  std::vector<uint64_t> KnownColdHashes;
+  std::vector<std::string> KnownColdSymbols;
+
+  bool isEmpty() const {
+    return Records.empty() && KnownColdHashes.empty() &&
+           KnownColdSymbols.empty();
+  }
+};
+
 // The top-level data structure, only used with YAML for now.
 struct AllMemProfData {
   std::vector<GUIDMemProfRecordPair> HeapProfileRecords;
+  YamlDataAccessProfData YamlifiedDataAccessProfiles;
 };
 } // namespace memprof
 
@@ -206,9 +222,49 @@ template <> struct MappingTraits<memprof::GUIDMemProfRecordPair> {
   }
 };
 
+template <> struct MappingTraits<data_access_prof::SourceLocation> {
+  static void mapping(IO &Io, data_access_prof::SourceLocation &Loc) {
+    Io.mapOptional("FileName", Loc.FileName);
+    Io.mapOptional("Line", Loc.Line);
+  }
+};
+
+template <> struct MappingTraits<data_access_prof::DataAccessProfRecord> {
+  static void mapping(IO &Io, data_access_prof::DataAccessProfRecord &Rec) {
+    if (Io.outputting()) {
+      if (std::holds_alternative<std::string>(Rec.SymHandle)) {
+        Io.mapOptional("Symbol", std::get<std::string>(Rec.SymHandle));
+      } else {
+        Io.mapOptional("Hash", std::get<uint64_t>(Rec.SymHandle));
+      }
+    } else {
+      std::string SymName;
+      uint64_t Hash = 0;
+      Io.mapOptional("Symbol", SymName);
+      Io.mapOptional("Hash", Hash);
+      if (!SymName.empty()) {
+        Rec.SymHandle = SymName;
+      } else {
+        Rec.SymHandle = Hash;
+      }
+    }
+
+    Io.mapOptional("Locations", Rec.Locations);
+  }
+};
+
+template <> struct MappingTraits<memprof::YamlDataAccessProfData> {
+  static void mapping(IO &Io, memprof::YamlDataAccessProfData &Data) {
+    Io.mapOptional("SampledRecords", Data.Records);
+    Io.mapOptional("KnownColdSymbols", Data.KnownColdSymbols);
+    Io.mapOptional("KnownColdHashes", Data.KnownColdHashes);
+  }
+};
+
 template <> struct MappingTraits<memprof::AllMemProfData> {
   static void mapping(IO &Io, memprof::AllMemProfData &Data) {
     Io.mapRequired("HeapProfileRecords", Data.HeapProfileRecords);
+    Io.mapOptional("DataAccessProfiles", Data.YamlifiedDataAccessProfiles);
   }
 };
 
@@ -234,5 +290,7 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::AllocationInfo)
 LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::CallSiteInfo)
 LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::GUIDMemProfRecordPair)
 LLVM_YAML_IS_SEQUENCE_VECTOR(memprof::GUIDHex64) // Used for CalleeGuids
+LLVM_YAML_IS_SEQUENCE_VECTOR(data_access_prof::DataAccessProfRecord)
+LLVM_YAML_IS_SEQUENCE_VECTOR(data_access_prof::SourceLocation)
 
 #endif // LLVM_PROFILEDATA_MEMPROFYAML_H_
diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp
index c5d0099977cfa..61a73fab7269f 100644
--- a/llvm/lib/ProfileData/DataAccessProf.cpp
+++ b/llvm/lib/ProfileData/DataAccessProf.cpp
@@ -48,7 +48,8 @@ DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const {
 
   auto It = Records.find(Key);
   if (It != Records.end()) {
-    return DataAccessProfRecord(Key, It->second.Locations);
+    return DataAccessProfRecord(Key, It->second.AccessCount,
+                                It->second.Locations);
   }
 
   return std::nullopt;
@@ -111,7 +112,8 @@ Error DataAccessProfData::addKnownSymbolWithoutSamples(
   auto CanonicalName = getCanonicalName(std::get<StringRef>(SymbolID));
   if (!CanonicalName)
     return CanonicalName.takeError();
-  KnownColdSymbols.insert(*CanonicalName);
+  KnownColdSymbols.insert(
+      saveStringToMap(StrToIndexMap, Saver, *CanonicalName).first);
   return Error::success();
 }
 
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index 3d20f7a7a5778..cc1b03101c880 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/ProfileData/InstrProfReader.h"
 #include "llvm/ProfileData/MemProf.h"
@@ -216,7 +217,9 @@ static Error writeMemProfV2(ProfOStream &OS,
 
 static Error writeMemProfRadixTreeBased(
     ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
-    memprof::IndexedVersion Version, bool MemProfFullSchema) {
+    memprof::IndexedVersion Version, bool MemProfFullSchema,
+    std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+        DataAccessProfileData) {
   assert((Version == memprof::Version3 || Version == memprof::Version4) &&
          "Unsupported version for radix tree format");
 
@@ -225,6 +228,8 @@ static Error writeMemProfRadixTreeBased(
   OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
   OS.write(0ULL); // Reserve space for the memprof record payload offset.
   OS.write(0ULL); // Reserve space for the memprof record table offset.
+  if (Version == memprof::Version4)
+    OS.write(0ULL); // Reserve space for the data access profile offset.
 
   auto Schema = memprof::getHotColdSchema();
   if (MemProfFullSchema)
@@ -251,17 +256,26 @@ static Error writeMemProfRadixTreeBased(
   uint64_t RecordTableOffset = writeMemProfRecords(
       OS, MemProfData.Records, &Schema, Version, &MemProfCallStackIndexes);
 
+  uint64_t DataAccessProfOffset = 0;
+  if (DataAccessProfileData.has_value()) {
+    DataAccessProfOffset = OS.tell();
+    if (Error E = (*DataAccessProfileData).get().serialize(OS))
+      return E;
+  }
+
   // Verify that the computation for the number of elements in the call stack
   // array works.
   assert(CallStackPayloadOffset +
              NumElements * sizeof(memprof::LinearFrameId) ==
          RecordPayloadOffset);
 
-  uint64_t Header[] = {
+  SmallVector<uint64_t, 4> Header = {
       CallStackPayloadOffset,
       RecordPayloadOffset,
       RecordTableOffset,
   };
+  if (Version == memprof::Version4)
+    Header.push_back(DataAccessProfOffset);
   OS.patch({{HeaderUpdatePos, Header}});
 
   return Error::success();
@@ -272,28 +286,33 @@ static Error writeMemProfV3(ProfOStream &OS,
                             memprof::IndexedMemProfData &MemProfData,
                             bool MemProfFullSchema) {
   return writeMemProfRadixTreeBased(OS, MemProfData, memprof::Version3,
-                                    MemProfFullSchema);
+                                    MemProfFullSchema, std::nullopt);
 }
 
 // Write out MemProf Version4
-static Error writeMemProfV4(ProfOStream &OS,
-                            memprof::IndexedMemProfData &MemProfData,
-                            bool MemProfFullSchema) {
+static Error writeMemProfV4(
+    ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
+    bool MemProfFullSchema,
+    std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+        DataAccessProfileData) {
   return writeMemProfRadixTreeBased(OS, MemProfData, memprof::Version4,
-                                    MemProfFullSchema);
+                                    MemProfFullSchema, DataAccessProfileData);
 }
 
 // Write out the MemProf data in a requested version.
-Error writeMemProf(ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
-                   memprof::IndexedVersion MemProfVersionRequested,
-                   bool MemProfFullSchema) {
+Error writeMemProf(
+    ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
+    memprof::IndexedVersion MemProfVersionRequested, bool MemProfFullSchema,
+    std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+        DataAccessProfileData) {
   switch (MemProfVersionRequested) {
   case memprof::Version2:
     return writeMemProfV2(OS, MemProfData, MemProfFullSchema);
   case memprof::Version3:
     return writeMemProfV3(OS, MemProfData, MemProfFullSchema);
   case memprof::Version4:
-    return writeMemProfV4(OS, MemProfData, MemProfFullSchema);
+    return writeMemProfV4(OS, MemProfData, MemProfFullSchema,
+                          DataAccessProfileData);
   }
 
   return make_error<InstrProfError>(
@@ -357,7 +376,10 @@ Error IndexedMemProfReader::deserializeV2(const unsigned char *Start,
 }
 
 Error IndexedMemProfReader::deserializeRadixTreeBased(
-    const unsigned char *Start, const unsigned char *Ptr) {
+    const unsigned char *Start, const unsigned char *Ptr,
+    memprof::IndexedVersion Version) {
+  assert((Version == memprof::Version3 || Version == memprof::Version4) &&
+         "Unsupported version for radix tree format");
   // The offset in the stream right before invoking
   // CallStackTableGenerator.Emit.
   const uint64_t CallStackPayloadOffset =
@@ -369,6 +391,11 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
   const uint64_t RecordTableOffset =
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
+  uint64_t DataAccessProfOffset = 0;
+  if (Version == memprof::Version4)
+    DataAccessProfOffset =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
   // Read the schema.
   auto SchemaOr = memprof::readMemProfSchema(Ptr);
   if (!SchemaOr)
@@ -390,6 +417,14 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
       /*Payload=*/Start + RecordPayloadOffset,
       /*Base=*/Start, memprof::RecordLookupTrait(Version, Schema)));
 
+  if (DataAccessProfOffset > RecordTableOffset) {
+    DataAccessProfileData =
+        std::make_unique<data_access_prof::DataAccessProfData>();
+    const unsigned char *DAPPtr = Start + DataAccessProfOffset;
+    if (Error E = DataAccessProfileData->deserialize(DAPPtr))
+      return E;
+  }
+
   return Error::success();
 }
 
@@ -423,7 +458,7 @@ Error IndexedMemProfReader::deserialize(const unsigned char *Start,
   case memprof::Version3:
   case memprof::Version4:
     // V3 and V4 share the same high-level structure (radix tree, linear IDs).
-    if (Error E = deserializeRadixTreeBased(Start, Ptr))
+    if (Error E = deserializeRadixTreeBased(Start, Ptr, Version))
       return E;
     break;
   }
diff --git a/llvm/lib/ProfileData/InstrProfReader.cpp b/llvm/lib/ProfileData/InstrProfReader.cpp
index e6c83430cd8e9..78aba992fcd65 100644
--- a/llvm/lib/ProfileData/InstrProfReader.cpp
+++ b/llvm/lib/ProfileData/InstrProfReader.cpp
@@ -1551,6 +1551,20 @@ memprof::AllMemProfData IndexedMemProfReader::getAllMemProfData() const {
     Pair.Record = std::move(*Record);
     AllMemProfData.HeapProfileRecords.push_back(std::move(Pair));
   }
+  // Populate the data access profiles for yaml output.
+  if (DataAccessProfileData != nullptr) {
+    for (const auto &[SymHandleRef, RecordRef] :
+         DataAccessProfileData->getRecords())
+      AllMemProfData.YamlifiedDataAccessProfiles.Records.push_back(
+          data_access_prof::DataAccessProfRecord(
+              SymHandleRef, RecordRef.AccessCount, RecordRef.Locations));
+    for (StringRef ColdSymbol : DataAccessProfileData->getKnownColdSymbols())
+      AllMemProfData.YamlifiedDataAccessProfiles.KnownColdSymbols.push_back(
+          ColdSymbol.str());
+    for (uint64_t Hash : DataAccessProfileData->getKnownColdHashes())
+      AllMemProfData.YamlifiedDataAccessProfiles.KnownColdHashes.push_back(
+          Hash);
+  }
   return AllMemProfData;
 }
 
diff --git a/llvm/lib/ProfileData/InstrProfWriter.cpp b/llvm/lib/ProfileData/InstrProfWriter.cpp
index 2759346935b14..b0012ccc4f4bf 100644
--- a/llvm/lib/ProfileData/InstrProfWriter.cpp
+++ b/llvm/lib/ProfileData/InstrProfWriter.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/ProfileSummary.h"
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/IndexedMemProfData.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/ProfileData/MemProf.h"
@@ -29,6 +30,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 #include <ctime>
+#include <functional>
 #include <memory>
 #include <string>
 #include <tuple>
@@ -152,9 +154,7 @@ void InstrProfWriter::setValueProfDataEndianness(llvm::endianness Endianness) {
   InfoObj->ValueProfDataEndianness = Endianness;
 }
 
-void InstrProfWriter::setOutputSparse(bool Sparse) {
-  this->Sparse = Sparse;
-}
+void InstrProfWriter::setOutputSparse(bool Sparse) { this->Sparse = Sparse; }
 
 void InstrProfWriter::addRecord(NamedInstrProfRecord &&I, uint64_t Weight,
                                 function_ref<void(Error)> Warn) {
@@ -329,6 +329,12 @@ void InstrProfWriter::addBinaryIds(ArrayRef<llvm::object::BuildID> BIs) {
   llvm::append_range(BinaryIds, BIs);
 }
 
+void InstrProfWriter::addDataAccessProfData(
+    std::unique_ptr<data_access_prof::DataAccessProfData>
+        DataAccessProfDataIn) {
+  DataAccessProfileData = std::move(DataAccessProfDataIn);
+}
+
 void InstrProfWriter::addTemporalProfileTrace(TemporalProfTraceTy Trace) {
   assert(Trace.FunctionNameRefs.size() <= MaxTemporalProfTraceLength);
   assert(!Trace.FunctionNameRefs.empty());
@@ -614,8 +620,14 @@ Error InstrProfWriter::writeImpl(ProfOStream &OS) {
   uint64_t MemProfSectionStart = 0;
   if (static_cast<bool>(ProfileKind & InstrProfKind::MemProf)) {
     MemProfSectionStart = OS.tell();
+    std::optional<std::reference_wrapper<data_access_prof::DataAccessProfData>>
+        DAP = std::nullopt;
+    if (DataAccessProfileData.get() != nullptr)
+      DAP = std::ref(*DataAccessProfileData.get());
+
     if (auto E = writeMemProf(OS, MemProfData, MemProfVersionRequested,
-                              MemProfFullSchema))
+                              MemProfFullSchema, DAP))
+
       return E;
   }
 
diff --git a/llvm/lib/ProfileData/MemProfReader.cpp b/llvm/lib/ProfileData/MemProfReader.cpp
index e0f280b9eb2f6..2969029682e74 100644
--- a/llvm/lib/ProfileData/MemProfReader.cpp
+++ b/llvm/lib/ProfileData/MemProfReader.cpp
@@ -37,6 +37,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/Path.h"
 
@@ -823,6 +824,35 @@ void YAMLMemProfReader::parse(StringRef YAMLData) {
 
     MemProfData.Records.try_emplace(GUID, std::move(IndexedRecord));
   }
+
+  if (Doc.YamlifiedDataAccessProfiles.isEmpty())
+    return;
+
+  auto ToSymHandleRef = [](const data_access_prof::SymbolHandle &Handle)
+      -> data_access_prof::SymbolHandleRef {
+    if (std::holds_alternative<std::string>(Handle))
+      return StringRef(std::get<std::string>(Handle));
+    return std::get<uint64_t>(Handle);
+  };
+
+  auto DataAccessProfileData =
+      std::make_unique<data_access_prof::DataAccessProfData>();
+  for (const auto &Record : Doc.YamlifiedDataAccessProfiles.Records)
+    if (Error E = DataAccessProfileData->setDataAccessProfile(
+            ToSymHandleRef(Record.SymHandle), Record.AccessCount,
+            Record.Locations))
+      reportFatalInternalError(std::move(E));
+
+  for (const uint64_t Hash : Doc.YamlifiedDataAccessProfiles.KnownColdHashes)
+    if (Error E = DataAccessProfileData->addKnownSymbolWithoutSamples(Hash))
+      reportFatalInternalError(std::move(E));
+
+  for (const std::string &Sym :
+       Doc.YamlifiedDataAccessProfiles.KnownColdSymbols)
+    if (Error E = DataAccessProfileData->addKnownSymbolWithoutSamples(Sym))
+      reportFatalInternalError(std::move(E));
+
+  setDataAccessProfileData(std::move(DataAccessProfileData));
 }
 } // namespace memprof
 } // namespace llvm
diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index 9766cc50f37d7..5e0c7fb3ea1d8 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -35,4 +35,15 @@ HeapProfileRecords:
         - { Function: 0x7777777777777777, LineOffset: 77, Column: 70, IsInlineFrame: true }
         - { Function: 0x8888888888888888, LineOffset: 88, Column: 80, IsInlineFrame: false }
         CalleeGuids: [ 0x300 ]
+DataAccessProfiles:
+  SampledRecords:
+    - Symbol:          abcde
+    - Hash:            101010
+      Locations:
+        - FileName:        file
+          Line:            233
+  KnownColdSymbols:
+    - foo
+    - bar
+  KnownColdHashes: [ 999, 1001 ]
 ...
diff --git a/llvm/tools/llvm-profdata/llvm-profdata.cpp b/llvm/tools/llvm-profdata/llvm-profdata.cpp
index 885e06df6c390..8660eed6be2bf 100644
--- a/llvm/tools/llvm-profdata/llvm-profdata.cpp
+++ b/llvm/tools/llvm-profdata/llvm-profdata.cpp
@@ -16,6 +16,7 @@
 #include "llvm/Debuginfod/HTTPClient.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/Object/Binary.h"
+#include "llvm/ProfileData/DataAccessProf.h"
 #include "llvm/ProfileData/InstrProfCorrelator.h"
 #include "llvm/ProfileData/InstrProfReader.h"
 #include "llvm/ProfileData/InstrProfWriter.h"
@@ -756,6 +757,8 @@ loadInput(const WeightedFile &Input, SymbolRemapper *Remapper,
 
     auto MemProfData = Reader->takeMemProfData();
 
+    auto DataAccessProfData = Reader->takeDataAccessProfData();
+
     // Check for the empty input in case the YAML file is invalid.
     if (MemProfData.Records.empty()) {
       WC->Errors.emplace_back(
@@ -764,6 +767,7 @@ loadInput(const WeightedFile &Input, SymbolRemapper *Remapper,
     }
 
     WC->Writer.addMemProfData(std::move(MemProfData), MemProfError);
+    WC->Writer.addDataAccessProfData(std::move(DataAccessProfData));
     return;
   }
 
diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
index 127230d4805e7..084a8e96cdafe 100644
--- a/llvm/unittests/ProfileData/DataAccessProfTest.cpp
+++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp
@@ -78,7 +78,7 @@ TEST(MemProf, DataAccessProfile) {
     // Test that symbol names and file names are stored in the input order.
     EXPECT_THAT(
         llvm::to_vector(llvm::make_first_range(Data.getStrToIndexMapRef())),
-        ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+        ElementsAre("foo", "sym2", "bar.__uniq.321", "file2", "sym1", "file1"));
     EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1"));
     EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678));
 
@@ -134,9 +134,10 @@ TEST(MemProf, DataAccessProfile) {
                 testing::IsEmpty());
     EXPECT_FALSE(deserializedData.deserialize(p));
 
-    EXPECT_THAT(llvm::to_vector(llvm::make_first_range(
-                    deserializedData.getStrToIndexMapRef())),
-                ElementsAre("foo", "bar.__uniq.321", "file2", "file1"));
+    EXPECT_THAT(
+        llvm::to_vector(
+            llvm::make_first_range(deserializedData.getStrToIndexMapRef())),
+        ElementsAre("foo", "sym2", "bar.__uniq.321", "file2", "sym1", "file1"));
     EXPECT_THAT(deserializedData.getKnownColdSymbols(),
                 ElementsAre("sym2", "sym1"));
     EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678));
@@ -158,7 +159,7 @@ TEST(MemProf, DataAccessProfile) {
                 Field(&DataAccessProfRecordRef::AccessCount, 100),
                 Field(&DataAccessProfRecordRef::IsStringLiteral, false),
                 Field(&DataAccessProfRecordRef::Locations, testing::IsEmpty())),
-            AllOf(Field(&DataAccessProfRecordRef::SymbolID, 1),
+            AllOf(Field(&DataAccessProfRecordRef::SymbolID, 2),
                   Field(&DataAccessProfRecordRef::AccessCount, 123),
                   Field(&DataAccessProfRecordRef::IsStringLiteral, false),
                   Field(&DataAccessProfRecordRef::Locations,

>From c5d8fcf23c5db83cfaa73ce6af6329b50a180fc1 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 16 May 2025 12:26:11 -0700
Subject: [PATCH 08/14] resolve comments

---
 llvm/include/llvm/ProfileData/MemProfYAML.h   |  5 +-
 llvm/lib/ProfileData/IndexedMemProfData.cpp   |  6 ++-
 .../tools/llvm-profdata/memprof-yaml.test     | 47 ++++++++++++++++++-
 3 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/MemProfYAML.h b/llvm/include/llvm/ProfileData/MemProfYAML.h
index 634edc4aa0122..cdfa69b400ee6 100644
--- a/llvm/include/llvm/ProfileData/MemProfYAML.h
+++ b/llvm/include/llvm/ProfileData/MemProfYAML.h
@@ -264,7 +264,10 @@ template <> struct MappingTraits<memprof::YamlDataAccessProfData> {
 template <> struct MappingTraits<memprof::AllMemProfData> {
   static void mapping(IO &Io, memprof::AllMemProfData &Data) {
     Io.mapRequired("HeapProfileRecords", Data.HeapProfileRecords);
-    Io.mapOptional("DataAccessProfiles", Data.YamlifiedDataAccessProfiles);
+    // Map data access profiles if reading input, or if writing output &&
+    // the struct is populated.
+    if (!Io.outputting() || !Data.YamlifiedDataAccessProfiles.isEmpty())
+      Io.mapOptional("DataAccessProfiles", Data.YamlifiedDataAccessProfiles);
   }
 };
 
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index cc1b03101c880..dc351771f4153 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -228,7 +228,7 @@ static Error writeMemProfRadixTreeBased(
   OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
   OS.write(0ULL); // Reserve space for the memprof record payload offset.
   OS.write(0ULL); // Reserve space for the memprof record table offset.
-  if (Version == memprof::Version4)
+  if (Version >= memprof::Version4)
     OS.write(0ULL); // Reserve space for the data access profile offset.
 
   auto Schema = memprof::getHotColdSchema();
@@ -258,6 +258,8 @@ static Error writeMemProfRadixTreeBased(
 
   uint64_t DataAccessProfOffset = 0;
   if (DataAccessProfileData.has_value()) {
+    assert(Version >= memprof::Version4 &&
+           "Data access profiles are added starting from v4");
     DataAccessProfOffset = OS.tell();
     if (Error E = (*DataAccessProfileData).get().serialize(OS))
       return E;
@@ -274,7 +276,7 @@ static Error writeMemProfRadixTreeBased(
       RecordPayloadOffset,
       RecordTableOffset,
   };
-  if (Version == memprof::Version4)
+  if (Version >= memprof::Version4)
     Header.push_back(DataAccessProfOffset);
   OS.patch({{HeaderUpdatePos, Header}});
 
diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index 5e0c7fb3ea1d8..f5b6cf86d2922 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -1,10 +1,17 @@
 ; RUN: split-file %s %t
 ; COM: The text format only supports the latest version.
+
+; Verify that the YAML output is identical to the YAML input.
+; memprof-in.yaml has both heap profile records and data access profiles.
 ; RUN: llvm-profdata merge --memprof-version=4 %t/memprof-in.yaml -o %t/memprof-out.indexed
 ; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out.yaml
 ; RUN: diff -b %t/memprof-in.yaml %t/memprof-out.yaml
 
-; Verify that the YAML output is identical to the YAML input.
+; memprof-in-no-dap has empty data access profiles.
+; RUN: llvm-profdata merge --memprof-version=4 %t/memprof-in-no-dap.yaml -o %t/memprof-out.indexed
+; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out-no-dap.yaml
+; RUN: diff -b %t/memprof-in-no-dap.yaml %t/memprof-out-no-dap.yaml
+
 ;--- memprof-in.yaml
 ---
 HeapProfileRecords:
@@ -38,12 +45,48 @@ HeapProfileRecords:
 DataAccessProfiles:
   SampledRecords:
     - Symbol:          abcde
+      Locations:
+      - FileName:      file2.h
+        Line:          123
+      - FileName:      file3.cpp
+        Line:          456
     - Hash:            101010
       Locations:
-        - FileName:        file
+        - FileName:        file.cpp
           Line:            233
   KnownColdSymbols:
     - foo
     - bar
   KnownColdHashes: [ 999, 1001 ]
 ...
+;--- memprof-in-no-dap.yaml
+---
+HeapProfileRecords:
+  - GUID:            0xdeadbeef12345678
+    AllocSites:
+      - Callstack:
+          - { Function: 0x1111111111111111, LineOffset: 11, Column: 10, IsInlineFrame: true }
+          - { Function: 0x2222222222222222, LineOffset: 22, Column: 20, IsInlineFrame: false }
+        MemInfoBlock:
+          AllocCount:      111
+          TotalSize:       222
+          TotalLifetime:   333
+          TotalLifetimeAccessDensity: 444
+      - Callstack:
+          - { Function: 0x3333333333333333, LineOffset: 33, Column: 30, IsInlineFrame: false }
+          - { Function: 0x4444444444444444, LineOffset: 44, Column: 40, IsInlineFrame: true }
+        MemInfoBlock:
+          AllocCount:      555
+          TotalSize:       666
+          TotalLifetime:   777
+          TotalLifetimeAccessDensity: 888
+    CallSites:
+      - Frames:
+        - { Function: 0x5555555555555555, LineOffset: 55, Column: 50, IsInlineFrame: true }
+        - { Function: 0x6666666666666666, LineOffset: 66, Column: 60, IsInlineFrame: false }
+        CalleeGuids: [ 0x100, 0x200 ]
+      - Frames:
+        - { Function: 0x7777777777777777, LineOffset: 77, Column: 70, IsInlineFrame: true }
+        - { Function: 0x8888888888888888, LineOffset: 88, Column: 80, IsInlineFrame: false }
+        CalleeGuids: [ 0x300 ]
+...

>From 0fb3e6a243415e3514e5904de48d5ca880026765 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 16 May 2025 16:24:09 -0700
Subject: [PATCH 09/14] record data access profile payload length

---
 llvm/lib/ProfileData/IndexedMemProfData.cpp | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index dc351771f4153..3ef568025c60c 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -228,8 +228,10 @@ static Error writeMemProfRadixTreeBased(
   OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
   OS.write(0ULL); // Reserve space for the memprof record payload offset.
   OS.write(0ULL); // Reserve space for the memprof record table offset.
-  if (Version >= memprof::Version4)
+  if (Version >= memprof::Version4) {
     OS.write(0ULL); // Reserve space for the data access profile offset.
+    OS.write(0ULL); // Reserve space for the size of data access profiles.
+  }
 
   auto Schema = memprof::getHotColdSchema();
   if (MemProfFullSchema)
@@ -257,12 +259,14 @@ static Error writeMemProfRadixTreeBased(
       OS, MemProfData.Records, &Schema, Version, &MemProfCallStackIndexes);
 
   uint64_t DataAccessProfOffset = 0;
+  uint64_t DataAccessProfLength = 0;
   if (DataAccessProfileData.has_value()) {
     assert(Version >= memprof::Version4 &&
            "Data access profiles are added starting from v4");
     DataAccessProfOffset = OS.tell();
     if (Error E = (*DataAccessProfileData).get().serialize(OS))
       return E;
+    DataAccessProfLength = OS.tell() - DataAccessProfOffset;
   }
 
   // Verify that the computation for the number of elements in the call stack
@@ -271,13 +275,15 @@ static Error writeMemProfRadixTreeBased(
              NumElements * sizeof(memprof::LinearFrameId) ==
          RecordPayloadOffset);
 
-  SmallVector<uint64_t, 4> Header = {
+  SmallVector<uint64_t, 8> Header = {
       CallStackPayloadOffset,
       RecordPayloadOffset,
       RecordTableOffset,
   };
-  if (Version >= memprof::Version4)
+  if (Version >= memprof::Version4) {
     Header.push_back(DataAccessProfOffset);
+    Header.push_back(DataAccessProfLength);
+  }
   OS.patch({{HeaderUpdatePos, Header}});
 
   return Error::success();
@@ -394,9 +400,13 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
   uint64_t DataAccessProfOffset = 0;
-  if (Version == memprof::Version4)
+  uint64_t DataAccessProfLength = 0;
+  if (Version == memprof::Version4) {
     DataAccessProfOffset =
         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+    DataAccessProfLength =
+        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  }
 
   // Read the schema.
   auto SchemaOr = memprof::readMemProfSchema(Ptr);
@@ -419,7 +429,7 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
       /*Payload=*/Start + RecordPayloadOffset,
       /*Base=*/Start, memprof::RecordLookupTrait(Version, Schema)));
 
-  if (DataAccessProfOffset > RecordTableOffset) {
+  if (DataAccessProfLength > 0) {
     DataAccessProfileData =
         std::make_unique<data_access_prof::DataAccessProfData>();
     const unsigned char *DAPPtr = Start + DataAccessProfOffset;

>From d3443c752d77d4ddefea472a4787c501c97bb092 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 19 May 2025 16:44:05 -0700
Subject: [PATCH 10/14] undo length record and check

---
 llvm/lib/ProfileData/IndexedMemProfData.cpp   | 23 +++++++------------
 .../tools/llvm-profdata/memprof-yaml.test     |  2 +-
 2 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index 3ef568025c60c..d5f134629529c 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -228,10 +228,8 @@ static Error writeMemProfRadixTreeBased(
   OS.write(0ULL); // Reserve space for the memprof call stack payload offset.
   OS.write(0ULL); // Reserve space for the memprof record payload offset.
   OS.write(0ULL); // Reserve space for the memprof record table offset.
-  if (Version >= memprof::Version4) {
+  if (Version >= memprof::Version4)
     OS.write(0ULL); // Reserve space for the data access profile offset.
-    OS.write(0ULL); // Reserve space for the size of data access profiles.
-  }
 
   auto Schema = memprof::getHotColdSchema();
   if (MemProfFullSchema)
@@ -259,14 +257,12 @@ static Error writeMemProfRadixTreeBased(
       OS, MemProfData.Records, &Schema, Version, &MemProfCallStackIndexes);
 
   uint64_t DataAccessProfOffset = 0;
-  uint64_t DataAccessProfLength = 0;
   if (DataAccessProfileData.has_value()) {
     assert(Version >= memprof::Version4 &&
            "Data access profiles are added starting from v4");
     DataAccessProfOffset = OS.tell();
     if (Error E = (*DataAccessProfileData).get().serialize(OS))
       return E;
-    DataAccessProfLength = OS.tell() - DataAccessProfOffset;
   }
 
   // Verify that the computation for the number of elements in the call stack
@@ -275,15 +271,14 @@ static Error writeMemProfRadixTreeBased(
              NumElements * sizeof(memprof::LinearFrameId) ==
          RecordPayloadOffset);
 
-  SmallVector<uint64_t, 8> Header = {
+  SmallVector<uint64_t, 4> Header = {
       CallStackPayloadOffset,
       RecordPayloadOffset,
       RecordTableOffset,
   };
-  if (Version >= memprof::Version4) {
+  if (Version >= memprof::Version4)
     Header.push_back(DataAccessProfOffset);
-    Header.push_back(DataAccessProfLength);
-  }
+
   OS.patch({{HeaderUpdatePos, Header}});
 
   return Error::success();
@@ -400,13 +395,9 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
 
   uint64_t DataAccessProfOffset = 0;
-  uint64_t DataAccessProfLength = 0;
-  if (Version == memprof::Version4) {
+  if (Version == memprof::Version4)
     DataAccessProfOffset =
         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
-    DataAccessProfLength =
-        support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
-  }
 
   // Read the schema.
   auto SchemaOr = memprof::readMemProfSchema(Ptr);
@@ -429,7 +420,9 @@ Error IndexedMemProfReader::deserializeRadixTreeBased(
       /*Payload=*/Start + RecordPayloadOffset,
       /*Base=*/Start, memprof::RecordLookupTrait(Version, Schema)));
 
-  if (DataAccessProfLength > 0) {
+  assert((!DataAccessProfOffset || DataAccessProfOffset > RecordTableOffset) &&
+         "Data access profile is either empty or after the record table");
+  if (DataAccessProfOffset > RecordTableOffset) {
     DataAccessProfileData =
         std::make_unique<data_access_prof::DataAccessProfData>();
     const unsigned char *DAPPtr = Start + DataAccessProfOffset;
diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index f5b6cf86d2922..ff0d82d92e1f6 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -7,7 +7,7 @@
 ; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out.yaml
 ; RUN: diff -b %t/memprof-in.yaml %t/memprof-out.yaml
 
-; memprof-in-no-dap has empty data access profiles.
+; memprof-in-no-dap.yaml has empty data access profiles.
 ; RUN: llvm-profdata merge --memprof-version=4 %t/memprof-in-no-dap.yaml -o %t/memprof-out.indexed
 ; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out-no-dap.yaml
 ; RUN: diff -b %t/memprof-in-no-dap.yaml %t/memprof-out-no-dap.yaml

>From 5d0e236a0dad61b4df185e275919978af059393c Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Wed, 21 May 2025 22:22:43 -0700
Subject: [PATCH 11/14] resolve comments

---
 llvm/include/llvm/ProfileData/DataAccessProf.h     | 2 +-
 llvm/include/llvm/ProfileData/IndexedMemProfData.h | 3 +--
 llvm/lib/ProfileData/IndexedMemProfData.cpp        | 5 +++--
 llvm/lib/ProfileData/InstrProfWriter.cpp           | 1 -
 4 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index ee6805b677380..cd4b200486a3f 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -112,7 +112,7 @@ struct DataAccessProfRecord {
       SymHandle = std::get<uint64_t>(SymHandleRef);
 
     for (auto Loc : LocRefs)
-      Locations.emplace_back(SourceLocation(Loc.FileName, Loc.Line));
+      Locations.emplace_back(Loc.FileName, Loc.Line);
   }
   // Empty constructor is used in yaml conversion.
   DataAccessProfRecord() {}
diff --git a/llvm/include/llvm/ProfileData/IndexedMemProfData.h b/llvm/include/llvm/ProfileData/IndexedMemProfData.h
index 18be88ab742c5..2b40094a9bc21 100644
--- a/llvm/include/llvm/ProfileData/IndexedMemProfData.h
+++ b/llvm/include/llvm/ProfileData/IndexedMemProfData.h
@@ -89,8 +89,7 @@ struct IndexedMemProfData {
 Error writeMemProf(
     ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
     memprof::IndexedVersion MemProfVersionRequested, bool MemProfFullSchema,
-    std::unique_ptr<memprof::DataAccessProfData>
-        DataAccessProfileData);
+    std::unique_ptr<memprof::DataAccessProfData> DataAccessProfileData);
 
 } // namespace llvm
 #endif
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index 99d39c3234e02..7398e4c468bbe 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -219,7 +219,8 @@ static Error writeMemProfV2(ProfOStream &OS,
 static Error writeMemProfRadixTreeBased(
     ProfOStream &OS, memprof::IndexedMemProfData &MemProfData,
     memprof::IndexedVersion Version, bool MemProfFullSchema,
-    std::unique_ptr<memprof::DataAccessProfData> DataAccessProfileData) {
+    std::unique_ptr<memprof::DataAccessProfData> DataAccessProfileData =
+        nullptr) {
   assert((Version == memprof::Version3 || Version == memprof::Version4) &&
          "Unsupported version for radix tree format");
 
@@ -289,7 +290,7 @@ static Error writeMemProfV3(ProfOStream &OS,
                             memprof::IndexedMemProfData &MemProfData,
                             bool MemProfFullSchema) {
   return writeMemProfRadixTreeBased(OS, MemProfData, memprof::Version3,
-                                    MemProfFullSchema, nullptr);
+                                    MemProfFullSchema);
 }
 
 // Write out MemProf Version4
diff --git a/llvm/lib/ProfileData/InstrProfWriter.cpp b/llvm/lib/ProfileData/InstrProfWriter.cpp
index f13b858ebbc12..039e1bc955cd4 100644
--- a/llvm/lib/ProfileData/InstrProfWriter.cpp
+++ b/llvm/lib/ProfileData/InstrProfWriter.cpp
@@ -29,7 +29,6 @@
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 #include <ctime>
-#include <functional>
 #include <memory>
 #include <string>
 #include <tuple>

>From b14f1e2c592ae9defc6283c7fb42ed028aff9749 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 22 May 2025 09:38:37 -0700
Subject: [PATCH 12/14] Test backward compatibility

---
 .../tools/llvm-profdata/memprof-yaml.test     | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index 8683da8249153..302ce3a1aea87 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -7,6 +7,12 @@
 ; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out.yaml
 ; RUN: diff -b %t/memprof-in.yaml %t/memprof-out.yaml
 
+; Merge text profile as v3 binary profile. Test that the merged v3 profile
+; are identical to memprof-in-v3.ymal, and doesn't have callee guids or dap.
+; RUN: llvm-profdata merge --memprof-version=3 %t/memprof-in.yaml -o %t/memprof-out-v3.indexed
+; RUN: llvm-profdata show --memory %t/memprof-out-v3.indexed > %t/memprof-out-v3.yaml
+; RUN: diff -b %t/memprof-out-v3.yaml %t/memprof-in-v3.yaml
+
 ; memprof-in-no-dap.yaml has empty data access profiles.
 ; RUN: llvm-profdata merge --memprof-version=4 %t/memprof-in-no-dap.yaml -o %t/memprof-out.indexed
 ; RUN: llvm-profdata show --memory %t/memprof-out.indexed > %t/memprof-out-no-dap.yaml
@@ -59,6 +65,35 @@ DataAccessProfiles:
     - bar
   KnownColdStrHashes: [ 999, 1001 ]
 ...
+;--- memprof-in-v3.yaml
+---
+HeapProfileRecords:
+  - GUID:            0xdeadbeef12345678
+    AllocSites:
+      - Callstack:
+          - { Function: 0x1111111111111111, LineOffset: 11, Column: 10, IsInlineFrame: true }
+          - { Function: 0x2222222222222222, LineOffset: 22, Column: 20, IsInlineFrame: false }
+        MemInfoBlock:
+          AllocCount:      111
+          TotalSize:       222
+          TotalLifetime:   333
+          TotalLifetimeAccessDensity: 444
+      - Callstack:
+          - { Function: 0x3333333333333333, LineOffset: 33, Column: 30, IsInlineFrame: false }
+          - { Function: 0x4444444444444444, LineOffset: 44, Column: 40, IsInlineFrame: true }
+        MemInfoBlock:
+          AllocCount:      555
+          TotalSize:       666
+          TotalLifetime:   777
+          TotalLifetimeAccessDensity: 888
+    CallSites:
+      - Frames:
+        - { Function: 0x5555555555555555, LineOffset: 55, Column: 50, IsInlineFrame: true }
+        - { Function: 0x6666666666666666, LineOffset: 66, Column: 60, IsInlineFrame: false }
+      - Frames:
+        - { Function: 0x7777777777777777, LineOffset: 77, Column: 70, IsInlineFrame: true }
+        - { Function: 0x8888888888888888, LineOffset: 88, Column: 80, IsInlineFrame: false }
+...
 ;--- memprof-in-no-dap.yaml
 ---
 HeapProfileRecords:

>From 09eab646fb13763c747de9a42e0d9260520381f1 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 22 May 2025 11:12:03 -0700
Subject: [PATCH 13/14] fix typo

---
 llvm/test/tools/llvm-profdata/memprof-yaml.test | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index 302ce3a1aea87..5beda8c036a12 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -8,7 +8,7 @@
 ; RUN: diff -b %t/memprof-in.yaml %t/memprof-out.yaml
 
 ; Merge text profile as v3 binary profile. Test that the merged v3 profile
-; are identical to memprof-in-v3.ymal, and doesn't have callee guids or dap.
+; are identical to memprof-in-v3.yaml, and doesn't have callee guids or dap.
 ; RUN: llvm-profdata merge --memprof-version=3 %t/memprof-in.yaml -o %t/memprof-out-v3.indexed
 ; RUN: llvm-profdata show --memory %t/memprof-out-v3.indexed > %t/memprof-out-v3.yaml
 ; RUN: diff -b %t/memprof-out-v3.yaml %t/memprof-in-v3.yaml

>From 6b63f73a72030b3107c39b9a7f0ad04c921677fb Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 23 May 2025 11:36:50 -0700
Subject: [PATCH 14/14] fix use-of-uninitialized-memory error

---
 llvm/include/llvm/ProfileData/DataAccessProf.h  | 2 +-
 llvm/include/llvm/ProfileData/MemProfYAML.h     | 2 +-
 llvm/test/tools/llvm-profdata/memprof-yaml.test | 2 ++
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h
index cd4b200486a3f..c0f0c6d9c9fc1 100644
--- a/llvm/include/llvm/ProfileData/DataAccessProf.h
+++ b/llvm/include/llvm/ProfileData/DataAccessProf.h
@@ -115,7 +115,7 @@ struct DataAccessProfRecord {
       Locations.emplace_back(Loc.FileName, Loc.Line);
   }
   // Empty constructor is used in yaml conversion.
-  DataAccessProfRecord() {}
+  DataAccessProfRecord() : AccessCount(0) {}
   SymbolHandle SymHandle;
   uint64_t AccessCount;
   // The locations of data in the source code. Optional.
diff --git a/llvm/include/llvm/ProfileData/MemProfYAML.h b/llvm/include/llvm/ProfileData/MemProfYAML.h
index f8dc659f66662..ad5d7c0e22751 100644
--- a/llvm/include/llvm/ProfileData/MemProfYAML.h
+++ b/llvm/include/llvm/ProfileData/MemProfYAML.h
@@ -248,7 +248,7 @@ template <> struct MappingTraits<memprof::DataAccessProfRecord> {
         Rec.SymHandle = Hash;
       }
     }
-
+    Io.mapRequired("AccessCount", Rec.AccessCount);
     Io.mapOptional("Locations", Rec.Locations);
   }
 };
diff --git a/llvm/test/tools/llvm-profdata/memprof-yaml.test b/llvm/test/tools/llvm-profdata/memprof-yaml.test
index 5beda8c036a12..0caa1fe5d9fd3 100644
--- a/llvm/test/tools/llvm-profdata/memprof-yaml.test
+++ b/llvm/test/tools/llvm-profdata/memprof-yaml.test
@@ -51,12 +51,14 @@ HeapProfileRecords:
 DataAccessProfiles:
   SampledRecords:
     - Symbol:          abcde
+      AccessCount:     100
       Locations:
       - FileName:      file2.h
         Line:          123
       - FileName:      file3.cpp
         Line:          456
     - Hash:            101010
+      AccessCount:     200
       Locations:
         - FileName:        file.cpp
           Line:            233



More information about the llvm-commits mailing list