[llvm] [StaticDataLayout][PGO] Add profile format for static data layout, and the classes to operate on the profiles. (PR #138170)
Teresa Johnson via llvm-commits
llvm-commits at lists.llvm.org
Tue May 6 09:23:24 PDT 2025
================
@@ -0,0 +1,247 @@
+#include "llvm/ProfileData/DataAccessProf.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Compression.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <sys/types.h>
+
+namespace llvm {
+namespace data_access_prof {
+
+// Returns the iterator of `Str`'s entry in `Map`, creating the entry if it is
+// missing. The key stored in `Map` is the copy of `Str` interned in `Saver`,
+// so it does not depend on the caller's buffer. A newly created entry's value
+// is the map size before insertion, so entries get indices in insertion
+// order. `Str` is interned in `Saver` even when `Map` already has the entry.
+static MapVector<StringRef, uint64_t>::iterator
+saveStringToMap(MapVector<StringRef, uint64_t> &Map,
+                llvm::UniqueStringSaver &Saver, StringRef Str) {
+  return Map.try_emplace(Saver.save(Str), Map.size()).first;
+}
+
+// Looks up the record for `SymbolID` and returns nullptr when there is none.
+// Names are canonicalized first because records for named symbols are keyed
+// by canonical name. Hash identifiers are used as-is.
+const DataAccessProfRecord *
+DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const {
+  auto Key = SymbolID;
+  if (const auto *Name = std::get_if<StringRef>(&SymbolID))
+    Key = InstrProfSymtab::getCanonicalName(*Name);
+
+  const auto It = Records.find(Key);
+  return It == Records.end() ? nullptr : &It->second;
+}
+
+// Returns true if `SymID` is in the known-cold set.
+//
+// Hash identifiers are looked up verbatim in KnownColdHashes.
+// addKnownSymbolWithoutSamples stores names in KnownColdSymbols only in
+// canonical form, so the queried name is canonicalized the same way before the
+// lookup (getProfileRecord does the same). Without this step, a name that
+// differs from its canonical form would never be found, even after it was
+// added.
+bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const {
+  if (std::holds_alternative<uint64_t>(SymID))
+    return KnownColdHashes.contains(std::get<uint64_t>(SymID));
+  return KnownColdSymbols.contains(
+      InstrProfSymtab::getCanonicalName(std::get<StringRef>(SymID)));
+}
+
+// Records `AccessCount` for `Symbol` as a new profile record.
+//
+// A uint64_t identifier is used verbatim as the record ID, and the record is
+// flagged as a string literal. A symbol name must be non-empty. It is
+// canonicalized with InstrProfSymtab::getCanonicalName and interned in
+// StrToIndexMap, and its index there becomes the record ID.
+//
+// Returns an invalid_argument error for an empty name, or when a record with
+// the same key already exists. Callers must aggregate counts for a symbol
+// before calling this.
+Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol,
+                                               uint64_t AccessCount) {
+  SymbolID Key;
+  uint64_t RecordID = -1;
+  bool IsStringLiteral = false;
+
+  if (const auto *Name = std::get_if<StringRef>(&Symbol)) {
+    if (Name->empty())
+      return make_error<StringError>("Empty symbol name",
+                                     llvm::errc::invalid_argument);
+    const StringRef CanonicalName = InstrProfSymtab::getCanonicalName(*Name);
+    Key = CanonicalName;
+    RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second;
+  } else {
+    RecordID = std::get<uint64_t>(Symbol);
+    Key = RecordID;
+    IsStringLiteral = true;
+  }
+
+  const bool Inserted =
+      Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral).second;
+  if (Inserted)
+    return Error::success();
+
+  return make_error<StringError>("Duplicate symbol or string literal added. "
+                                 "User of DataAccessProfData should "
+                                 "aggregate count for the same symbol. ",
+                                 llvm::errc::invalid_argument);
+}
+
+// Does the same as the two-argument overload, then attaches `Locations` to
+// the newly created record. Each location's filename is interned in
+// StrToIndexMap so that it can later be encoded by index.
+Error DataAccessProfData::setDataAccessProfile(
+    SymbolID SymbolID, uint64_t AccessCount,
+    const llvm::SmallVector<DataLocation> &Locations) {
+  if (Error E = setDataAccessProfile(SymbolID, AccessCount))
+    return E;
+
+  // On success the new record was appended, so it is the last entry in
+  // Records.
+  auto &NewRecord = Records.back().second;
+  for (const DataLocation &Loc : Locations) {
+    const StringRef InternedFile =
+        saveStringToMap(StrToIndexMap, Saver, Loc.FileName)->first;
+    NewRecord.Locations.push_back({InternedFile, Loc.Line});
+  }
+  return Error::success();
+}
+
+// Marks `SymbolID` as known cold (present, but with no access samples).
+// Hash identifiers go into KnownColdHashes. Names must be non-empty and are
+// stored in KnownColdSymbols in canonical form. Returns an invalid_argument
+// error for an empty name.
+Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) {
+  if (const auto *Hash = std::get_if<uint64_t>(&SymbolID)) {
+    KnownColdHashes.insert(*Hash);
+    return Error::success();
+  }
+
+  const StringRef Name = std::get<StringRef>(SymbolID);
+  if (Name.empty())
+    return make_error<StringError>("Empty symbol name",
+                                   llvm::errc::invalid_argument);
+  KnownColdSymbols.insert(InstrProfSymtab::getCanonicalName(Name));
+  return Error::success();
+}
+
+// Reads a data-access profile section starting at `Ptr` and advances `Ptr`
+// past everything consumed. Reads, in order: the two string counts, the
+// string table, the known-cold hash list, and then the records.
+// NOTE(review): no bounds checking is done here; the caller presumably
+// validates the buffer size.
+Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
+  // Every fixed-width field read here is a little-endian uint64_t.
+  auto ReadU64 = [&Ptr]() {
+    return support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+  };
+
+  const uint64_t NumSampledSymbols = ReadU64();
+  const uint64_t NumColdKnownSymbols = ReadU64();
+  if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols,
+                                               NumColdKnownSymbols))
+    return E;
+
+  // Count-prefixed list of known-cold hashes.
+  for (uint64_t Remaining = ReadU64(); Remaining != 0; --Remaining)
+    KnownColdHashes.insert(ReadU64());
+
+  return deserializeRecords(Ptr);
+}
+
+// Writes the string table: the count of interned strings (symbol names and
+// filenames from StrToIndexMap), the count of known-cold symbol names, and
+// then all of those strings passed through collectGlobalObjectNameStrings
+// (zlib-compressed when zlib is available). The string bytes are prefixed by
+// their length and zero-padded to a multiple of 8 bytes.
+Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
+  OS.write(StrToIndexMap.size());
+  OS.write(KnownColdSymbols.size());
+
+  // Interned strings are emitted in insertion order, so each string's position
+  // equals its StrToIndexMap index. The known-cold names follow them.
+  std::vector<std::string> Names;
+  Names.reserve(StrToIndexMap.size() + KnownColdSymbols.size());
+  for (const auto &Entry : StrToIndexMap)
+    Names.push_back(Entry.first.str());
+  for (const auto &ColdName : KnownColdSymbols)
+    Names.push_back(ColdName.str());
+
+  std::string Encoded;
+  if (!Names.empty()) {
+    if (Error E = collectGlobalObjectNameStrings(
+            Names, compression::zlib::isAvailable(), Encoded))
+      return E;
+  }
+
+  // The reader consumes exactly `EncodedLen` bytes of string data; the zero
+  // padding that follows only rounds the section up to a multiple of 8.
+  const uint64_t EncodedLen = Encoded.length();
+  OS.write(EncodedLen);
+  for (const char Byte : Encoded)
+    OS.writeByte(static_cast<uint8_t>(Byte));
+  for (uint64_t Pad = alignTo(EncodedLen, 8) - EncodedLen; Pad != 0; --Pad)
+    OS.writeByte(0);
+  return Error::success();
+}
+
+// Returns the on-disk encoding of `SymbolID`. A uint64_t identifier is
+// returned unchanged. A name is encoded as its index in StrToIndexMap.
+//
+// Every name passed here must already have been interned with
+// saveStringToMap. Dereferencing the end() iterator of a failed lookup would
+// be undefined behavior, so the invariant is asserted.
+uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const {
+  if (std::holds_alternative<uint64_t>(SymbolID))
+    return std::get<uint64_t>(SymbolID);
+
+  auto It = StrToIndexMap.find(std::get<StringRef>(SymbolID));
+  assert(It != StrToIndexMap.end() &&
+         "String must be saved in StrToIndexMap before it is encoded");
+  return It->second;
+}
+
+// Writes the whole profile section: the string table, then the count-prefixed
+// known-cold hashes, then the count-prefixed records. Each record holds its
+// encoded symbol, a one-byte string-literal flag, the access count, and a
+// count-prefixed list of (encoded filename, 32-bit line) locations.
+Error DataAccessProfData::serialize(ProfOStream &OS) const {
+  if (Error E = serializeSymbolsAndFilenames(OS))
+    return E;
+
+  OS.write(KnownColdHashes.size());
+  for (const auto &ColdHash : KnownColdHashes)
+    OS.write(ColdHash);
+
+  OS.write(static_cast<uint64_t>(Records.size()));
+  for (const auto &Entry : Records) {
+    const auto &Record = Entry.second;
+    OS.write(getEncodedIndex(Record.SymbolID));
+    OS.writeByte(Record.IsStringLiteral);
+    OS.write(Record.AccessCount);
+    OS.write(Record.Locations.size());
+    for (const auto &Location : Record.Locations) {
+      OS.write(getEncodedIndex(Location.FileName));
+      OS.write32(Location.Line);
+    }
+  }
+  return Error::success();
+}
+
+Error DataAccessProfData::deserializeSymbolsAndFilenames(
+ const unsigned char *&Ptr, uint64_t NumSampledSymbols,
+ uint64_t NumColdKnownSymbols) {
+ uint64_t Len =
+ support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
+
+ // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are
----------------
teresajohnson wrote:
Simpler to directly say "The first NumSampledSymbols strings are symbols with samples, and the next NumColdKnownSymbols are known cold symbols." ?
https://github.com/llvm/llvm-project/pull/138170
More information about the llvm-commits
mailing list