[llvm] [SampleFDO][TypeProf] Support vtable type profiling in ext-binary and text format. (PR #141649)

Mingming Liu via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 1 19:14:50 PDT 2025


https://github.com/mingmingl-llvm updated https://github.com/llvm/llvm-project/pull/141649

>From 2664d700b0f2c681f28161e11ea967e58a88df01 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 14 Apr 2025 05:12:23 +0000
Subject: [PATCH 1/3] profile format change

---
 llvm/include/llvm/ProfileData/SampleProf.h    | 48 +++++++++++--
 .../llvm/ProfileData/SampleProfReader.h       | 19 +++++
 .../llvm/ProfileData/SampleProfWriter.h       | 38 +++++++++-
 llvm/lib/ProfileData/SampleProfReader.cpp     | 70 ++++++++++++++++++-
 llvm/lib/ProfileData/SampleProfWriter.cpp     | 65 ++++++++++++++++-
 5 files changed, 230 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index e7b154dff0697..d3997f9b88c7f 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -201,6 +201,9 @@ enum class SecProfSummaryFlags : uint32_t {
   /// SecFlagIsPreInlined means this profile contains ShouldBeInlined
   /// contexts thus this is CS preinliner computed.
   SecFlagIsPreInlined = (1 << 4),
+
+  /// SecFlagHasVTableTypeProf means this profile contains vtable type profiles.
+  SecFlagHasVTableTypeProf = (1 << 5),
 };
 
 enum class SecFuncMetadataFlags : uint32_t {
@@ -312,16 +315,19 @@ struct LineLocationHash {
 
 raw_ostream &operator<<(raw_ostream &OS, const LineLocation &Loc);
 
+using TypeMap = std::map<FunctionId, uint64_t>;
+
 /// Representation of a single sample record.
 ///
 /// A sample record is represented by a positive integer value, which
 /// indicates how frequently was the associated line location executed.
 ///
 /// Additionally, if the associated location contains a function call,
-/// the record will hold a list of all the possible called targets. For
-/// direct calls, this will be the exact function being invoked. For
-/// indirect calls (function pointers, virtual table dispatch), this
-/// will be a list of one or more functions.
+/// the record will hold a list of all the possible called targets and the types
+/// for virtual table dispatches. For direct calls, this will be the exact
+/// function being invoked. For indirect calls (function pointers, virtual table
+/// dispatch), this will be a list of one or more functions. For virtual table
+/// dispatches, this record will also hold the type of the object.
 class SampleRecord {
 public:
   using CallTarget = std::pair<FunctionId, uint64_t>;
@@ -336,6 +342,7 @@ class SampleRecord {
 
   using SortedCallTargetSet = std::set<CallTarget, CallTargetComparator>;
   using CallTargetMap = std::unordered_map<FunctionId, uint64_t>;
+
   SampleRecord() = default;
 
   /// Increment the number of samples for this record by \p S.
@@ -374,6 +381,14 @@ class SampleRecord {
                       : sampleprof_error::success;
   }
 
+  sampleprof_error addTypeCount(FunctionId F, uint64_t S, uint64_t Weight = 1) {
+    uint64_t &Samples = TypeCounts[F];
+    bool Overflowed;
+    Samples = SaturatingMultiplyAdd(S, Weight, Samples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
+  }
+
   /// Remove called function from the call target map. Return the target sample
   /// count of the called function.
   uint64_t removeCalledTarget(FunctionId F) {
@@ -391,6 +406,7 @@ class SampleRecord {
 
   uint64_t getSamples() const { return NumSamples; }
   const CallTargetMap &getCallTargets() const { return CallTargets; }
+  const TypeMap &getTypes() const { return TypeCounts; }
   const SortedCallTargetSet getSortedCallTargets() const {
     return sortCallTargets(CallTargets);
   }
@@ -439,6 +455,7 @@ class SampleRecord {
 private:
   uint64_t NumSamples = 0;
   CallTargetMap CallTargets;
+  TypeMap TypeCounts;
 };
 
 raw_ostream &operator<<(raw_ostream &OS, const SampleRecord &Sample);
@@ -734,6 +751,7 @@ using BodySampleMap = std::map<LineLocation, SampleRecord>;
 // memory, which is *very* significant for large profiles.
 using FunctionSamplesMap = std::map<FunctionId, FunctionSamples>;
 using CallsiteSampleMap = std::map<LineLocation, FunctionSamplesMap>;
+using CallsiteTypeMap = std::map<LineLocation, TypeMap>;
 using LocToLocMap =
     std::unordered_map<LineLocation, LineLocation, LineLocationHash>;
 
@@ -791,6 +809,11 @@ class FunctionSamples {
         Func, Num, Weight);
   }
 
+  sampleprof_error addTypeSamples(const LineLocation &Loc, FunctionId Func,
+                                  uint64_t Num, uint64_t Weight = 1) {
+    return BodySamples[Loc].addTypeCount(Func, Num, Weight);
+  }
+
   sampleprof_error addSampleRecord(LineLocation Location,
                                    const SampleRecord &SampleRecord,
                                    uint64_t Weight = 1) {
@@ -916,6 +939,13 @@ class FunctionSamples {
     return &Iter->second;
   }
 
+  const TypeMap *findTypeSamplesAt(const LineLocation &Loc) const {
+    auto Iter = VirtualCallsiteTypes.find(mapIRLocToProfileLoc(Loc));
+    if (Iter == VirtualCallsiteTypes.end())
+      return nullptr;
+    return &Iter->second;
+  }
+
   /// Returns a pointer to FunctionSamples at the given callsite location
   /// \p Loc with callee \p CalleeName. If no callsite can be found, relax
   /// the restriction to return the FunctionSamples at callsite location
@@ -977,6 +1007,14 @@ class FunctionSamples {
     return CallsiteSamples;
   }
 
+  const CallsiteTypeMap &getCallsiteTypes() const {
+    return VirtualCallsiteTypes;
+  }
+
+  TypeMap& getTypeSamplesAt(const LineLocation &Loc) {
+    return VirtualCallsiteTypes[mapIRLocToProfileLoc(Loc)];
+  }
+
   /// Return the maximum of sample counts in a function body. When SkipCallSite
   /// is false, which is the default, the return count includes samples in the
   /// inlined functions. When SkipCallSite is true, the return count only
@@ -1274,6 +1312,8 @@ class FunctionSamples {
   /// collected in the call to baz() at line offset 8.
   CallsiteSampleMap CallsiteSamples;
 
+  CallsiteTypeMap VirtualCallsiteTypes;
+
   /// IR to profile location map generated by stale profile matching.
   ///
   /// Each entry is a mapping from the location on current build to the matched
diff --git a/llvm/include/llvm/ProfileData/SampleProfReader.h b/llvm/include/llvm/ProfileData/SampleProfReader.h
index 76c7cecded629..6ab3b650c0750 100644
--- a/llvm/include/llvm/ProfileData/SampleProfReader.h
+++ b/llvm/include/llvm/ProfileData/SampleProfReader.h
@@ -701,6 +701,17 @@ class SampleProfileReaderBinary : public SampleProfileReader {
   /// otherwise same as readStringFromTable, also return its hash value.
   ErrorOr<std::pair<SampleContext, uint64_t>> readSampleContextFromTable();
 
+  /// Overridden by SampleProfileReaderExtBinary to read the vtable profile.
+  virtual std::error_code readVTableProf(const LineLocation &Loc,
+                                         FunctionSamples &FProfile) {
+    return sampleprof_error::success;
+  }
+
+  virtual std::error_code readCallsiteVTableProf(FunctionSamples &FProfile) {
+    return sampleprof_error::success;
+  }
+
+
   /// Points to the current location in the buffer.
   const uint8_t *Data = nullptr;
 
@@ -814,6 +825,8 @@ class SampleProfileReaderExtBinaryBase : public SampleProfileReaderBinary {
   /// The set containing the functions to use when compiling a module.
   DenseSet<StringRef> FuncsToUse;
 
+  bool ReadVTableProf = false;
+
 public:
   SampleProfileReaderExtBinaryBase(std::unique_ptr<MemoryBuffer> B,
                                    LLVMContext &C, SampleProfileFormat Format)
@@ -854,6 +867,12 @@ class SampleProfileReaderExtBinary : public SampleProfileReaderExtBinaryBase {
     return sampleprof_error::success;
   };
 
+  std::error_code readVTableProf(const LineLocation &Loc,
+                                 FunctionSamples &FProfile) override;
+
+  std::error_code readCallsiteVTableProf(FunctionSamples &FProfile) override;
+
+   std::error_code readTypeMap(TypeMap& M);
 public:
   SampleProfileReaderExtBinary(std::unique_ptr<MemoryBuffer> B, LLVMContext &C,
                                SampleProfileFormat Format = SPF_Ext_Binary)
diff --git a/llvm/include/llvm/ProfileData/SampleProfWriter.h b/llvm/include/llvm/ProfileData/SampleProfWriter.h
index 4b659eaf950b3..63082b95807db 100644
--- a/llvm/include/llvm/ProfileData/SampleProfWriter.h
+++ b/llvm/include/llvm/ProfileData/SampleProfWriter.h
@@ -209,6 +209,17 @@ class SampleProfileWriterBinary : public SampleProfileWriter {
   virtual MapVector<FunctionId, uint32_t> &getNameTable() { return NameTable; }
   virtual std::error_code writeMagicIdent(SampleProfileFormat Format);
   virtual std::error_code writeNameTable();
+  virtual std::error_code
+  writeSampleRecordVTableProf(const SampleRecord &Record, raw_ostream &OS) {
+    // TODO: This is not pure virtual because SampleProfWriter may create
+    // objects of type SampleProfileWriterRawBinary.
+    return sampleprof_error::success;
+  }
+  virtual std::error_code
+  writeCallsiteType(const FunctionSamples &FunctionSample, raw_ostream &OS) {
+    return sampleprof_error::success;
+  }
+  virtual void addTypeNames(const TypeMap &M) {}
   std::error_code writeHeader(const SampleProfileMap &ProfileMap) override;
   std::error_code writeSummary();
   virtual std::error_code writeContextIdx(const SampleContext &Context);
@@ -218,8 +229,9 @@ class SampleProfileWriterBinary : public SampleProfileWriter {
                                 std::set<FunctionId> &V);
   
   MapVector<FunctionId, uint32_t> NameTable;
-  
+
   void addName(FunctionId FName);
+  void addTypeName(FunctionId TypeName);
   virtual void addContext(const SampleContext &Context);
   void addNames(const FunctionSamples &S);
 
@@ -409,8 +421,23 @@ class SampleProfileWriterExtBinaryBase : public SampleProfileWriterBinary {
 
 class SampleProfileWriterExtBinary : public SampleProfileWriterExtBinaryBase {
 public:
-  SampleProfileWriterExtBinary(std::unique_ptr<raw_ostream> &OS)
-      : SampleProfileWriterExtBinaryBase(OS) {}
+  SampleProfileWriterExtBinary(std::unique_ptr<raw_ostream> &OS);
+
+protected:
+  std::error_code writeSampleRecordVTableProf(const SampleRecord &Record,
+                                              raw_ostream &OS) override;
+
+  std::error_code writeCallsiteType(const FunctionSamples &FunctionSample,
+                                    raw_ostream &OS) override;
+
+  void addTypeNames(const TypeMap &M) override {
+    if (WriteVTableProf)
+      return;
+    // Add type name to TypeNameTable.
+    for (const auto &[Type, Cnt] : M) {
+      addName(Type);
+    }
+  }
 
 private:
   std::error_code writeDefaultLayout(const SampleProfileMap &ProfileMap);
@@ -426,6 +453,11 @@ class SampleProfileWriterExtBinary : public SampleProfileWriterExtBinaryBase {
     assert((SL == DefaultLayout || SL == CtxSplitLayout) &&
            "Unsupported layout");
   }
+
+  std::error_code writeTypeMap(const TypeMap &Map, raw_ostream& OS);
+
+  // TODO: This should be configurable by flag.
+  bool WriteVTableProf = false;
 };
 
 } // end namespace sampleprof
diff --git a/llvm/lib/ProfileData/SampleProfReader.cpp b/llvm/lib/ProfileData/SampleProfReader.cpp
index d97cc479442e4..f5fd711d4dd70 100644
--- a/llvm/lib/ProfileData/SampleProfReader.cpp
+++ b/llvm/lib/ProfileData/SampleProfReader.cpp
@@ -594,6 +594,67 @@ SampleProfileReaderBinary::readSampleContextFromTable() {
   return std::make_pair(Context, Hash);
 }
 
+std::error_code SampleProfileReaderExtBinary::readTypeMap(TypeMap &M) {
+  auto NumVTableTypes = readNumber<uint32_t>();
+  if (std::error_code EC = NumVTableTypes.getError())
+    return EC;
+
+  for (uint32_t I = 0; I < *NumVTableTypes; ++I) {
+    auto VTableType(readStringFromTable());
+    if (std::error_code EC = VTableType.getError())
+      return EC;
+
+    auto VTableSamples = readNumber<uint64_t>();
+    if (std::error_code EC = VTableSamples.getError())
+      return EC;
+
+    M.insert(std::make_pair(*VTableType, *VTableSamples));
+  }
+  return sampleprof_error::success;
+}
+
+std::error_code
+SampleProfileReaderExtBinary::readVTableProf(const LineLocation &Loc,
+                                             FunctionSamples &FProfile) {
+  if (!ReadVTableProf)
+    return sampleprof_error::success;
+
+  return readTypeMap(FProfile.getTypeSamplesAt(Loc));
+}
+
+std::error_code SampleProfileReaderExtBinary::readCallsiteVTableProf(
+    FunctionSamples &FProfile) {
+  if (!ReadVTableProf)
+    return sampleprof_error::success;
+
+  // Read the vtable type profile for the callsite.
+  auto NumCallsites = readNumber<uint32_t>();
+  if (std::error_code EC = NumCallsites.getError())
+    return EC;
+
+  for (uint32_t I = 0; I < *NumCallsites; ++I) {
+    auto LineOffset = readNumber<uint64_t>();
+    if (std::error_code EC = LineOffset.getError())
+      return EC;
+
+    auto Discriminator = readNumber<uint64_t>();
+    if (std::error_code EC = Discriminator.getError())
+      return EC;
+
+    // Here we handle FS discriminators:
+    uint32_t DiscriminatorVal = (*Discriminator) & getDiscriminatorMask();
+
+    if (!isOffsetLegal(*LineOffset)) {
+      return std::error_code();
+    }
+
+    if (std::error_code EC = readVTableProf(
+            LineLocation(*LineOffset, DiscriminatorVal), FProfile))
+      return EC;
+  }
+  return sampleprof_error::success;
+}
+
 std::error_code
 SampleProfileReaderBinary::readProfile(FunctionSamples &FProfile) {
   auto NumSamples = readNumber<uint64_t>();
@@ -643,6 +704,11 @@ SampleProfileReaderBinary::readProfile(FunctionSamples &FProfile) {
                                       *CalledFunction, *CalledFunctionSamples);
     }
 
+    // read vtable type profiles.
+    if (std::error_code EC = readVTableProf(
+            LineLocation(*LineOffset, DiscriminatorVal), FProfile))
+      return EC;
+
     FProfile.addBodySamples(*LineOffset, DiscriminatorVal, *NumSamples);
   }
 
@@ -674,7 +740,7 @@ SampleProfileReaderBinary::readProfile(FunctionSamples &FProfile) {
       return EC;
   }
 
-  return sampleprof_error::success;
+  return readCallsiteVTableProf(FProfile);
 }
 
 std::error_code
@@ -736,6 +802,8 @@ std::error_code SampleProfileReaderExtBinaryBase::readOneSection(
       FunctionSamples::ProfileIsPreInlined = ProfileIsPreInlined = true;
     if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagFSDiscriminator))
       FunctionSamples::ProfileIsFS = ProfileIsFS = true;
+    if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagHasVTableTypeProf))
+      ReadVTableProf = true;
     break;
   case SecNameTable: {
     bool FixedLengthMD5 =
diff --git a/llvm/lib/ProfileData/SampleProfWriter.cpp b/llvm/lib/ProfileData/SampleProfWriter.cpp
index 6fc16d9effdd6..9af0571915557 100644
--- a/llvm/lib/ProfileData/SampleProfWriter.cpp
+++ b/llvm/lib/ProfileData/SampleProfWriter.cpp
@@ -41,6 +41,10 @@
 using namespace llvm;
 using namespace sampleprof;
 
+static cl::opt<bool> ExtBinaryWriteVTableTypeProf(
+    "extbinary-write-vtable-type-prof", cl::init(true), cl::Hidden,
+    cl::desc("Write vtable type profile in ext-binary sample profile writer"));
+
 namespace llvm {
 namespace support {
 namespace endian {
@@ -435,6 +439,9 @@ std::error_code SampleProfileWriterExtBinaryBase::writeOneSection(
     addSectionFlag(SecProfSummary, SecProfSummaryFlags::SecFlagIsPreInlined);
   if (Type == SecProfSummary && FunctionSamples::ProfileIsFS)
     addSectionFlag(SecProfSummary, SecProfSummaryFlags::SecFlagFSDiscriminator);
+  if (Type == SecProfSummary && ExtBinaryWriteVTableTypeProf)
+    addSectionFlag(SecProfSummary,
+                   SecProfSummaryFlags::SecFlagHasVTableTypeProf);
 
   uint64_t SectionStart = markSectionStart(Type, LayoutIdx);
   switch (Type) {
@@ -478,6 +485,13 @@ std::error_code SampleProfileWriterExtBinaryBase::writeOneSection(
   return sampleprof_error::success;
 }
 
+SampleProfileWriterExtBinary::SampleProfileWriterExtBinary(
+    std::unique_ptr<raw_ostream> &OS)
+    : SampleProfileWriterExtBinaryBase(OS) {
+  // Initialize the section header layout.
+  WriteVTableProf = ExtBinaryWriteVTableTypeProf;
+}
+
 std::error_code SampleProfileWriterExtBinary::writeDefaultLayout(
     const SampleProfileMap &ProfileMap) {
   // The const indices passed to writeOneSection below are specifying the
@@ -658,8 +672,13 @@ void SampleProfileWriterBinary::addNames(const FunctionSamples &S) {
     const SampleRecord &Sample = I.second;
     for (const auto &J : Sample.getCallTargets())
       addName(J.first);
+    addTypeNames(Sample.getTypes());
   }
 
+  // Add all the names in callsite types.
+  for (const auto &CallsiteTypeSamples : S.getCallsiteTypes())
+    addTypeNames(CallsiteTypeSamples.second);
+
   // Recursively add all the names for inlined callsites.
   for (const auto &J : S.getCallsiteSamples())
     for (const auto &FS : J.second) {
@@ -805,6 +824,45 @@ std::error_code SampleProfileWriterExtBinaryBase::writeHeader(
   return sampleprof_error::success;
 }
 
+std::error_code SampleProfileWriterExtBinary::writeTypeMap(const TypeMap &Map,
+                                                           raw_ostream &OS) {
+  encodeULEB128(Map.size(), OS);
+  for (const auto &[TypeName, TypeSamples] : Map) {
+    if (std::error_code EC = writeNameIdx(TypeName))
+      return EC;
+    encodeULEB128(TypeSamples, OS);
+  }
+  return sampleprof_error::success;
+}
+
+std::error_code SampleProfileWriterExtBinary::writeSampleRecordVTableProf(
+    const SampleRecord &Record, raw_ostream &OS) {
+  if (!WriteVTableProf)
+    return sampleprof_error::success;
+
+  // TODO: Unify this with SampleProfileWriterBinary::writeBody with a
+  // pre-commit refactor.
+  return writeTypeMap(Record.getTypes(), OS);
+}
+
+std::error_code SampleProfileWriterExtBinary::writeCallsiteType(
+    const FunctionSamples &FunctionSample, raw_ostream &OS) {
+  if (!WriteVTableProf)
+    return sampleprof_error::success;
+
+  const CallsiteTypeMap &CallsiteTypeMap = FunctionSample.getCallsiteTypes();
+
+  encodeULEB128(CallsiteTypeMap.size(), OS);
+  for (const auto &[Loc, TypeMap] : CallsiteTypeMap) {
+    encodeULEB128(Loc.LineOffset, OS);
+    encodeULEB128(Loc.Discriminator, OS);
+    if (std::error_code EC = writeTypeMap(TypeMap, OS))
+      return EC;
+  }
+
+  return sampleprof_error::success;
+}
+
 std::error_code SampleProfileWriterBinary::writeSummary() {
   auto &OS = *OutputStream;
   encodeULEB128(Summary->getTotalCount(), OS);
@@ -844,6 +902,8 @@ std::error_code SampleProfileWriterBinary::writeBody(const FunctionSamples &S) {
         return EC;
       encodeULEB128(CalleeSamples, OS);
     }
+    if (std::error_code EC = writeSampleRecordVTableProf(Sample, OS))
+      return EC;
   }
 
   // Recursively emit all the callsite samples.
@@ -851,7 +911,7 @@ std::error_code SampleProfileWriterBinary::writeBody(const FunctionSamples &S) {
   for (const auto &J : S.getCallsiteSamples())
     NumCallsites += J.second.size();
   encodeULEB128(NumCallsites, OS);
-  for (const auto &J : S.getCallsiteSamples())
+  for (const auto &J : S.getCallsiteSamples()) {
     for (const auto &FS : J.second) {
       LineLocation Loc = J.first;
       const FunctionSamples &CalleeSamples = FS.second;
@@ -860,8 +920,9 @@ std::error_code SampleProfileWriterBinary::writeBody(const FunctionSamples &S) {
       if (std::error_code EC = writeBody(CalleeSamples))
         return EC;
     }
+  }
 
-  return sampleprof_error::success;
+  return writeCallsiteType(S, OS);
 }
 
 /// Write samples of a top-level function to a binary file.

>From ee2e45b090e53785869d687910e553cd0828b30e Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 26 May 2025 17:37:11 -0700
Subject: [PATCH 2/3] text format

TODO: Support vtable profiles in binary SPGO format
---
 llvm/include/llvm/ProfileData/SampleProf.h    | 25 ++++++
 .../llvm/ProfileData/SampleProfWriter.h       |  2 +-
 llvm/lib/ProfileData/SampleProf.cpp           | 24 +++++-
 llvm/lib/ProfileData/SampleProfReader.cpp     | 79 +++++++++++++++++--
 llvm/lib/ProfileData/SampleProfWriter.cpp     | 35 ++++++--
 .../Inputs/profile-symbol-list.expected       |  3 +-
 .../Inputs/sample-profile.proftext            |  2 +
 llvm/test/tools/llvm-profdata/roundtrip.test  | 10 +--
 llvm/tools/llvm-profdata/llvm-profdata.cpp    |  5 ++
 9 files changed, 163 insertions(+), 22 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index d3997f9b88c7f..88551b8bb2982 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -407,6 +407,7 @@ class SampleRecord {
   uint64_t getSamples() const { return NumSamples; }
   const CallTargetMap &getCallTargets() const { return CallTargets; }
   const TypeMap &getTypes() const { return TypeCounts; }
+  TypeMap &getTypes() { return TypeCounts; }
   const SortedCallTargetSet getSortedCallTargets() const {
     return sortCallTargets(CallTargets);
   }
@@ -809,11 +810,22 @@ class FunctionSamples {
         Func, Num, Weight);
   }
 
+  sampleprof_error addTypeSamples(uint32_t LineOffset, uint32_t Discriminator,
+                                  FunctionId Func, uint64_t Num,
+                                  uint64_t Weight = 1) {
+    return BodySamples[LineLocation(LineOffset, Discriminator)].addTypeCount(
+        Func, Num, Weight);
+  }
+
   sampleprof_error addTypeSamples(const LineLocation &Loc, FunctionId Func,
                                   uint64_t Num, uint64_t Weight = 1) {
     return BodySamples[Loc].addTypeCount(Func, Num, Weight);
   }
 
+  TypeMap &getTypeSamples(const LineLocation &Loc) {
+    return BodySamples[Loc].getTypes();
+  }
+
   sampleprof_error addSampleRecord(LineLocation Location,
                                    const SampleRecord &SampleRecord,
                                    uint64_t Weight = 1) {
@@ -1061,7 +1073,13 @@ class FunctionSamples {
       const LineLocation &Loc = I.first;
       const SampleRecord &Rec = I.second;
       mergeSampleProfErrors(Result, BodySamples[Loc].merge(Rec, Weight));
+      // const auto &OtherTypeCountMap = Rec.getTypes();
+      // for (const auto &[Type, Count] : OtherTypeCountMap) {
+      //   mergeSampleProfErrors(Result, addTypeSamples(Loc, Type, Count,
+      //   Weight));
+      // }
     }
+
     for (const auto &I : Other.getCallsiteSamples()) {
       const LineLocation &Loc = I.first;
       FunctionSamplesMap &FSMap = functionSamplesAt(Loc);
@@ -1069,6 +1087,13 @@ class FunctionSamples {
         mergeSampleProfErrors(Result,
                               FSMap[Rec.first].merge(Rec.second, Weight));
     }
+    for (const auto &[Loc, TypeCountMap] : Other.getCallsiteTypes()) {
+      TypeMap &TypeCounts = getTypeSamplesAt(Loc);
+      for (const auto &[Type, Count] : TypeCountMap) {
+        TypeCounts[Type] =
+            SaturatingMultiplyAdd(Count, Weight, TypeCounts[Type]);
+      }
+    }
     return Result;
   }
 
diff --git a/llvm/include/llvm/ProfileData/SampleProfWriter.h b/llvm/include/llvm/ProfileData/SampleProfWriter.h
index 63082b95807db..fe3121de87415 100644
--- a/llvm/include/llvm/ProfileData/SampleProfWriter.h
+++ b/llvm/include/llvm/ProfileData/SampleProfWriter.h
@@ -431,7 +431,7 @@ class SampleProfileWriterExtBinary : public SampleProfileWriterExtBinaryBase {
                                     raw_ostream &OS) override;
 
   void addTypeNames(const TypeMap &M) override {
-    if (WriteVTableProf)
+    if (!WriteVTableProf)
       return;
     // Add type name to TypeNameTable.
     for (const auto &[Type, Cnt] : M) {
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index 4d48de9bc7d63..e4c6d0d138167 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -123,6 +123,9 @@ sampleprof_error SampleRecord::merge(const SampleRecord &Other,
   for (const auto &I : Other.getCallTargets()) {
     mergeSampleProfErrors(Result, addCalledTarget(I.first, I.second, Weight));
   }
+  for (const auto &[TypeName, Count] : Other.getTypes()) {
+    mergeSampleProfErrors(Result, addTypeCount(TypeName, Count, Weight));
+  }
   return Result;
 }
 
@@ -138,6 +141,11 @@ void SampleRecord::print(raw_ostream &OS, unsigned Indent) const {
     for (const auto &I : getSortedCallTargets())
       OS << " " << I.first << ":" << I.second;
   }
+  if (!TypeCounts.empty()) {
+    OS << ", types:";
+    for (const auto &I : TypeCounts)
+      OS << " " << I.first << ":" << I.second;
+  }
   OS << "\n";
 }
 
@@ -179,11 +187,21 @@ void FunctionSamples::print(raw_ostream &OS, unsigned Indent) const {
     SampleSorter<LineLocation, FunctionSamplesMap> SortedCallsiteSamples(
         CallsiteSamples);
     for (const auto &CS : SortedCallsiteSamples.get()) {
-      for (const auto &FS : CS->second) {
+      for (const auto &[FuncId, FuncSample] : CS->second) {
         OS.indent(Indent + 2);
-        OS << CS->first << ": inlined callee: " << FS.second.getFunction()
+        OS << CS->first << ": inlined callee: " << FuncSample.getFunction()
            << ": ";
-        FS.second.print(OS, Indent + 4);
+        FuncSample.print(OS, Indent + 4);
+      }
+      const LineLocation &Loc = CS->first;
+      auto TypeSamplesIter = VirtualCallsiteTypes.find(Loc);
+      if (TypeSamplesIter != VirtualCallsiteTypes.end()) {
+        OS.indent(Indent + 2);
+        OS << Loc << ": vtables: ";
+        for (const auto &TypeSample : TypeSamplesIter->second) {
+          OS << TypeSample.first << ":" << TypeSample.second << " ";
+        }
+        OS << "\n";
       }
     }
     OS.indent(Indent);
diff --git a/llvm/lib/ProfileData/SampleProfReader.cpp b/llvm/lib/ProfileData/SampleProfReader.cpp
index f5fd711d4dd70..6b3766d6ebd91 100644
--- a/llvm/lib/ProfileData/SampleProfReader.cpp
+++ b/llvm/lib/ProfileData/SampleProfReader.cpp
@@ -197,8 +197,31 @@ enum class LineType {
   CallSiteProfile,
   BodyProfile,
   Metadata,
+  CallTargetTypeProfile,
+  CallSiteTypeProfile,
 };
 
+static bool parseTypeCountMap(StringRef Input,
+                              DenseMap<StringRef, uint64_t> &TypeCountMap) {
+  for (size_t Index = Input.find_first_not_of(' '); Index != StringRef::npos;) {
+    size_t n1 = Input.find(':', Index);
+    if (n1 == StringRef::npos)
+      return false; // No colon found, invalid format.
+    StringRef TypeName = Input.substr(Index, n1 - Index);
+    // n2 is the start index of count.
+    size_t n2 = n1 + 1;
+    // n3 is the index of the first space after the 'type:count' pair.
+    size_t n3 = Input.find_first_of(' ', n2);
+    uint64_t Count;
+    if (Input.substr(n2, n3 - n2).getAsInteger(10, Count))
+      return false; // Invalid count.
+    TypeCountMap[TypeName] = Count;
+    Index = (n3 == StringRef::npos) ? StringRef::npos
+                                    : Input.find_first_not_of(' ', n3);
+  }
+  return true;
+}
+
 /// Parse \p Input as line sample.
 ///
 /// \param Input input line.
@@ -215,6 +238,7 @@ static bool ParseLine(const StringRef &Input, LineType &LineTy, uint32_t &Depth,
                       uint64_t &NumSamples, uint32_t &LineOffset,
                       uint32_t &Discriminator, StringRef &CalleeName,
                       DenseMap<StringRef, uint64_t> &TargetCountMap,
+                      DenseMap<StringRef, uint64_t> &TypeCountMap,
                       uint64_t &FunctionHash, uint32_t &Attributes,
                       bool &IsFlat) {
   for (Depth = 0; Input[Depth] == ' '; Depth++)
@@ -289,6 +313,7 @@ static bool ParseLine(const StringRef &Input, LineType &LineTy, uint32_t &Depth,
         n4 = AfterColon.find_first_of(' ');
         n4 = (n4 != StringRef::npos) ? n3 + n4 + 1 : Rest.size();
         StringRef WordAfterColon = Rest.substr(n3 + 1, n4 - n3 - 1);
+        // Break the loop if parsing integer succeeded.
         if (!WordAfterColon.getAsInteger(10, count))
           break;
 
@@ -306,6 +331,16 @@ static bool ParseLine(const StringRef &Input, LineType &LineTy, uint32_t &Depth,
       // Change n3 to the next blank space after colon + integer pair.
       n3 = n4;
     }
+  } else if (Rest.ends_with("// CallTargetVtables")) {
+    LineTy = LineType::CallTargetTypeProfile;
+    return parseTypeCountMap(
+        Rest.substr(0, Rest.size() - strlen("// CallTargetVtables")),
+        TypeCountMap);
+  } else if (Rest.ends_with("// CallSiteVtables")) {
+    LineTy = LineType::CallSiteTypeProfile;
+    return parseTypeCountMap(
+        Rest.substr(0, Rest.size() - strlen("// CallSiteVtables")),
+        TypeCountMap);
   } else {
     LineTy = LineType::CallSiteProfile;
     size_t n3 = Rest.find_last_of(':');
@@ -374,14 +409,15 @@ std::error_code SampleProfileReaderText::readImpl() {
       uint64_t NumSamples;
       StringRef FName;
       DenseMap<StringRef, uint64_t> TargetCountMap;
+      DenseMap<StringRef, uint64_t> TypeCountMap;
       uint32_t Depth, LineOffset, Discriminator;
       LineType LineTy;
       uint64_t FunctionHash = 0;
       uint32_t Attributes = 0;
       bool IsFlat = false;
       if (!ParseLine(*LineIt, LineTy, Depth, NumSamples, LineOffset,
-                     Discriminator, FName, TargetCountMap, FunctionHash,
-                     Attributes, IsFlat)) {
+                     Discriminator, FName, TargetCountMap, TypeCountMap,
+                     FunctionHash, Attributes, IsFlat)) {
         reportError(LineIt.line_number(),
                     "Expected 'NUM[.NUM]: NUM[ mangled_name:NUM]*', found " +
                         *LineIt);
@@ -410,6 +446,29 @@ std::error_code SampleProfileReaderText::readImpl() {
         DepthMetadata = 0;
         break;
       }
+
+      case LineType::CallSiteTypeProfile: {
+        TypeMap &Map = InlineStack.back()->getTypeSamplesAt(
+            LineLocation(LineOffset, Discriminator));
+        for (const auto [Type, Count] : TypeCountMap)
+          Map[FunctionId(Type)] += Count;
+        break;
+      }
+
+      case LineType::CallTargetTypeProfile: {
+        while (InlineStack.size() > Depth) {
+          InlineStack.pop_back();
+        }
+        FunctionSamples &FProfile = *InlineStack.back();
+        for (const auto &name_count : TypeCountMap) {
+          mergeSampleProfErrors(
+              Result, FProfile.addTypeSamples(LineOffset, Discriminator,
+                                              FunctionId(name_count.first),
+                                              name_count.second));
+        }
+        break;
+      }
+
       case LineType::BodyProfile: {
         while (InlineStack.size() > Depth) {
           InlineStack.pop_back();
@@ -608,6 +667,7 @@ std::error_code SampleProfileReaderExtBinary::readTypeMap(TypeMap &M) {
     if (std::error_code EC = VTableSamples.getError())
       return EC;
 
+    errs() << "readTypeMap\t" << *VTableType << "\t" << *VTableSamples << "\n";
     M.insert(std::make_pair(*VTableType, *VTableSamples));
   }
   return sampleprof_error::success;
@@ -619,7 +679,7 @@ SampleProfileReaderExtBinary::readVTableProf(const LineLocation &Loc,
   if (!ReadVTableProf)
     return sampleprof_error::success;
 
-  return readTypeMap(FProfile.getTypeSamplesAt(Loc));
+  return readTypeMap(FProfile.getTypeSamples(Loc));
 }
 
 std::error_code SampleProfileReaderExtBinary::readCallsiteVTableProf(
@@ -648,8 +708,8 @@ std::error_code SampleProfileReaderExtBinary::readCallsiteVTableProf(
       return std::error_code();
     }
 
-    if (std::error_code EC = readVTableProf(
-            LineLocation(*LineOffset, DiscriminatorVal), FProfile))
+    if (std::error_code EC = readTypeMap(FProfile.getTypeSamplesAt(
+            LineLocation(*LineOffset, DiscriminatorVal))))
       return EC;
   }
   return sampleprof_error::success;
@@ -740,7 +800,10 @@ SampleProfileReaderBinary::readProfile(FunctionSamples &FProfile) {
       return EC;
   }
 
-  return readCallsiteVTableProf(FProfile);
+  std::error_code EC = readCallsiteVTableProf(FProfile);
+  errs() << "readFunctionSample\t";
+  FProfile.print(errs(), 2);
+  return EC;
 }
 
 std::error_code
@@ -802,8 +865,10 @@ std::error_code SampleProfileReaderExtBinaryBase::readOneSection(
       FunctionSamples::ProfileIsPreInlined = ProfileIsPreInlined = true;
     if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagFSDiscriminator))
       FunctionSamples::ProfileIsFS = ProfileIsFS = true;
-    if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagHasVTableTypeProf))
+    if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagHasVTableTypeProf)) {
+      errs() << "SampleProfileReaderExtBinaryBase::readVTableProf\n";
       ReadVTableProf = true;
+    }
     break;
   case SecNameTable: {
     bool FixedLengthMD5 =
diff --git a/llvm/lib/ProfileData/SampleProfWriter.cpp b/llvm/lib/ProfileData/SampleProfWriter.cpp
index 9af0571915557..dde9173565b50 100644
--- a/llvm/lib/ProfileData/SampleProfWriter.cpp
+++ b/llvm/lib/ProfileData/SampleProfWriter.cpp
@@ -489,7 +489,10 @@ SampleProfileWriterExtBinary::SampleProfileWriterExtBinary(
     std::unique_ptr<raw_ostream> &OS)
     : SampleProfileWriterExtBinaryBase(OS) {
   // Initialize the section header layout.
+
   WriteVTableProf = ExtBinaryWriteVTableTypeProf;
+
+  errs() << "writeVTableProf value: " << WriteVTableProf << "\n";
 }
 
 std::error_code SampleProfileWriterExtBinary::writeDefaultLayout(
@@ -604,15 +607,25 @@ std::error_code SampleProfileWriterText::writeSample(const FunctionSamples &S) {
     for (const auto &J : Sample.getSortedCallTargets())
       OS << " " << J.first << ":" << J.second;
     OS << "\n";
+
+    if (!Sample.getTypes().empty()) {
+      OS.indent(Indent + 1);
+      Loc.print(OS);
+      OS << ": ";
+      for (const auto &Type : Sample.getTypes()) {
+        OS << Type.first << ":" << Type.second << " ";
+      }
+      OS << " // CallTargetVtables\n";
+    }
     LineCount++;
   }
 
   SampleSorter<LineLocation, FunctionSamplesMap> SortedCallsiteSamples(
       S.getCallsiteSamples());
   Indent += 1;
-  for (const auto &I : SortedCallsiteSamples.get())
+  for (const auto &I : SortedCallsiteSamples.get()) {
+    LineLocation Loc = I->first;
     for (const auto &FS : I->second) {
-      LineLocation Loc = I->first;
       const FunctionSamples &CalleeSamples = FS.second;
       OS.indent(Indent);
       if (Loc.Discriminator == 0)
@@ -622,6 +635,19 @@ std::error_code SampleProfileWriterText::writeSample(const FunctionSamples &S) {
       if (std::error_code EC = writeSample(CalleeSamples))
         return EC;
     }
+
+    if (const TypeMap *Map = S.findTypeSamplesAt(Loc); Map && !Map->empty()) {
+      OS.indent(Indent);
+      Loc.print(OS);
+      OS << ": ";
+      for (const auto &Type : *Map) {
+        OS << Type.first << ":" << Type.second << " ";
+      }
+      OS << " // CallSiteVtables\n";
+      LineCount++;
+    }
+  }
+
   Indent -= 1;
 
   if (FunctionSamples::ProfileIsProbeBased) {
@@ -828,6 +854,8 @@ std::error_code SampleProfileWriterExtBinary::writeTypeMap(const TypeMap &Map,
                                                            raw_ostream &OS) {
   encodeULEB128(Map.size(), OS);
   for (const auto &[TypeName, TypeSamples] : Map) {
+    errs() << "TypeName: " << TypeName << "\t" << "TypeSamples: " << TypeSamples
+           << "\n";
     if (std::error_code EC = writeNameIdx(TypeName))
       return EC;
     encodeULEB128(TypeSamples, OS);
@@ -840,8 +868,6 @@ std::error_code SampleProfileWriterExtBinary::writeSampleRecordVTableProf(
   if (!WriteVTableProf)
     return sampleprof_error::success;
 
-  // TODO: Unify this with SampleProfileWriterBinary::writeBody with a
-  // pre-commit refactor.
   return writeTypeMap(Record.getTypes(), OS);
 }
 
@@ -851,7 +877,6 @@ std::error_code SampleProfileWriterExtBinary::writeCallsiteType(
     return sampleprof_error::success;
 
   const CallsiteTypeMap &CallsiteTypeMap = FunctionSample.getCallsiteTypes();
-
   encodeULEB128(CallsiteTypeMap.size(), OS);
   for (const auto &[Loc, TypeMap] : CallsiteTypeMap) {
     encodeULEB128(Loc.LineOffset, OS);
diff --git a/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected b/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected
index bd528b44b81c4..a7912eadbc888 100644
--- a/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected
+++ b/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected
@@ -6,7 +6,7 @@ Samples collected in the function's body {
   5.1: 2150
   6: 4160
   7: 1068
-  9: 4128, calls: _Z3bari:2942 _Z3fooi:1262
+  9: 4128, calls: _Z3bari:2942 _Z3fooi:1262, types: vtable_bar:2942 vtable_foo:1260
 }
 Samples collected in inlined callsites {
   10: inlined callee: inline1: 2000, 0, 1 sampled lines
@@ -19,6 +19,7 @@ Samples collected in inlined callsites {
       1: 4000
     }
     No inlined callsites in this function
+  10: vtables: inline1_vtable:2000 inline2_vtable:4000
 }
 Function: _Z3bari: 40602, 2874, 1 sampled lines
 Samples collected in the function's body {
diff --git a/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext b/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext
index f9f87dfd661f8..0a12512161a8c 100644
--- a/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext
+++ b/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext
@@ -6,10 +6,12 @@ main:184019:0
  6: 2080
  7: 534
  9: 2064 _Z3bari:1471 _Z3fooi:631
+ 9: vtable_bar:1471 vtable_foo:630  // CallTargetVtables
  10: inline1:1000
   1: 1000
  10: inline2:2000
   1: 2000
+ 10: inline1_vtable:1000 inline2_vtable:2000  // CallSiteVtables
 _Z3bari:20301:1437
  1: 1437
 _Z3fooi:7711:610
diff --git a/llvm/test/tools/llvm-profdata/roundtrip.test b/llvm/test/tools/llvm-profdata/roundtrip.test
index 7af76e0a58224..3746ac191e875 100644
--- a/llvm/test/tools/llvm-profdata/roundtrip.test
+++ b/llvm/test/tools/llvm-profdata/roundtrip.test
@@ -6,13 +6,13 @@ RUN: llvm-profdata show -o %t.1.proftext -all-functions -text %t.1.profdata
 RUN: diff -b %t.1.proftext %S/Inputs/IR_profile.proftext
 RUN: llvm-profdata merge --sample --binary -output=%t.2.profdata %S/Inputs/sample-profile.proftext
 RUN: llvm-profdata merge --sample --text -output=%t.2.proftext %t.2.profdata
-RUN: diff -b %t.2.proftext %S/Inputs/sample-profile.proftext
+COM: diff -b %t.2.proftext %S/Inputs/sample-profile.proftext
 # Round trip from text --> extbinary --> text
 RUN: llvm-profdata merge --sample --extbinary -output=%t.3.profdata %S/Inputs/sample-profile.proftext
 RUN: llvm-profdata merge --sample --text -output=%t.3.proftext %t.3.profdata
 RUN: diff -b %t.3.proftext %S/Inputs/sample-profile.proftext
 # Round trip from text --> binary --> extbinary --> text
-RUN: llvm-profdata merge --sample --binary -output=%t.4.profdata %S/Inputs/sample-profile.proftext
-RUN: llvm-profdata merge --sample --extbinary -output=%t.5.profdata %t.4.profdata
-RUN: llvm-profdata merge --sample --text -output=%t.4.proftext %t.5.profdata
-RUN: diff -b %t.4.proftext %S/Inputs/sample-profile.proftext
+COM: llvm-profdata merge --sample --binary -output=%t.4.profdata %S/Inputs/sample-profile.proftext
+COM: llvm-profdata merge --sample --extbinary -output=%t.5.profdata %t.4.profdata
+COM: llvm-profdata merge --sample --text -output=%t.4.proftext %t.5.profdata
+COM: diff -b %t.4.proftext %S/Inputs/sample-profile.proftext
diff --git a/llvm/tools/llvm-profdata/llvm-profdata.cpp b/llvm/tools/llvm-profdata/llvm-profdata.cpp
index 9a5d3f91d6256..122b7e17b705d 100644
--- a/llvm/tools/llvm-profdata/llvm-profdata.cpp
+++ b/llvm/tools/llvm-profdata/llvm-profdata.cpp
@@ -1614,6 +1614,7 @@ static void mergeSampleProfile(const WeightedFileVector &Inputs,
     }
 
     SampleProfileMap &Profiles = Reader->getProfiles();
+
     if (ProfileIsProbeBased &&
         ProfileIsProbeBased != FunctionSamples::ProfileIsProbeBased)
       exitWithError(
@@ -1629,9 +1630,13 @@ static void mergeSampleProfile(const WeightedFileVector &Inputs,
           Remapper ? remapSamples(I->second, *Remapper, Result)
                    : FunctionSamples();
       FunctionSamples &Samples = Remapper ? Remapped : I->second;
+      errs() << "llvm-profdata.cpp\tfunction samples:\t";
+      Samples.print(errs(), 2);
       SampleContext FContext = Samples.getContext();
       mergeSampleProfErrors(Result,
                             ProfileMap[FContext].merge(Samples, Input.Weight));
+      errs() << "llvm-profdata.cpp\tmerged samples:\t";
+      ProfileMap[FContext].print(errs(), 2);
       if (Result != sampleprof_error::success) {
         std::error_code EC = make_error_code(Result);
         handleMergeWriterError(errorCodeToError(EC), Input.Filename,

>From f4b332dcfb610703a7637141bf070c4d5a85369b Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Sat, 31 May 2025 18:27:26 -0700
Subject: [PATCH 3/3] Update test; non-ext binary format not supported

---
 llvm/include/llvm/ProfileData/SampleProf.h    | 133 +++++++++++-------
 .../llvm/ProfileData/SampleProfReader.h       |  28 ++--
 .../llvm/ProfileData/SampleProfWriter.h       |  48 ++-----
 llvm/lib/ProfileData/SampleProf.cpp           |  43 +++++-
 llvm/lib/ProfileData/SampleProfReader.cpp     |  74 ++++------
 llvm/lib/ProfileData/SampleProfWriter.cpp     |  85 +++++------
 .../Inputs/profile-symbol-list-ext.expected   |  43 ++++++
 .../Inputs/profile-symbol-list.expected       |   3 +-
 .../Inputs/sample-profile-ext.proftext        |  18 +++
 .../Inputs/sample-profile.proftext            |   2 -
 .../profile-symbol-list-compress.test         |   6 +
 .../llvm-profdata/profile-symbol-list.test    |   6 +
 llvm/test/tools/llvm-profdata/roundtrip.test  |  15 +-
 llvm/tools/llvm-profdata/llvm-profdata.cpp    |   5 -
 14 files changed, 295 insertions(+), 214 deletions(-)
 create mode 100644 llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list-ext.expected
 create mode 100644 llvm/test/tools/llvm-profdata/Inputs/sample-profile-ext.proftext

diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index 88551b8bb2982..2a026e0673b94 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -21,11 +21,11 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/ProfileData/FunctionId.h"
+#include "llvm/ProfileData/HashKeyMap.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/MathExtras.h"
-#include "llvm/ProfileData/HashKeyMap.h"
 #include <algorithm>
 #include <cstdint>
 #include <list>
@@ -59,7 +59,9 @@ enum class sampleprof_error {
   ostream_seek_unsupported,
   uncompress_failed,
   zlib_unavailable,
-  hash_mismatch
+  hash_mismatch,
+  illegal_line_offset,
+  duplicate_vtable_type,
 };
 
 inline std::error_code make_error_code(sampleprof_error E) {
@@ -88,6 +90,9 @@ struct is_error_code_enum<llvm::sampleprof_error> : std::true_type {};
 namespace llvm {
 namespace sampleprof {
 
+constexpr char kBodySampleVTableProfPrefix[] = "<vt-call> ";
+constexpr char kInlinedCallsiteVTablerofPrefix[] = "<vt-inline> ";
+
 enum SampleProfileFormat {
   SPF_None = 0,
   SPF_Text = 0x1,
@@ -286,6 +291,9 @@ struct LineLocation {
   void print(raw_ostream &OS) const;
   void dump() const;
 
+  /// Serialize the line location to \p OS using ULEB128 encoding.
+  void serialize(raw_ostream &OS) const;
+
   bool operator<(const LineLocation &O) const {
     return LineOffset < O.LineOffset ||
            (LineOffset == O.LineOffset && Discriminator < O.Discriminator);
@@ -315,7 +323,9 @@ struct LineLocationHash {
 
 raw_ostream &operator<<(raw_ostream &OS, const LineLocation &Loc);
 
-using TypeMap = std::map<FunctionId, uint64_t>;
+/// Key represents the id of a vtable and value represents its count.
+/// TODO: Rename class FunctionId to SymbolId in a separate PR.
+using TypeCountMap = std::map<FunctionId, uint64_t>;
 
 /// Representation of a single sample record.
 ///
@@ -342,7 +352,6 @@ class SampleRecord {
 
   using SortedCallTargetSet = std::set<CallTarget, CallTargetComparator>;
   using CallTargetMap = std::unordered_map<FunctionId, uint64_t>;
-
   SampleRecord() = default;
 
   /// Increment the number of samples for this record by \p S.
@@ -372,22 +381,12 @@ class SampleRecord {
   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
   /// around unsigned integers.
   sampleprof_error addCalledTarget(FunctionId F, uint64_t S,
-                                   uint64_t Weight = 1) {
-    uint64_t &TargetSamples = CallTargets[F];
-    bool Overflowed;
-    TargetSamples =
-        SaturatingMultiplyAdd(S, Weight, TargetSamples, &Overflowed);
-    return Overflowed ? sampleprof_error::counter_overflow
-                      : sampleprof_error::success;
-  }
+                                   uint64_t Weight = 1);
 
-  sampleprof_error addTypeCount(FunctionId F, uint64_t S, uint64_t Weight = 1) {
-    uint64_t &Samples = TypeCounts[F];
-    bool Overflowed;
-    Samples = SaturatingMultiplyAdd(S, Weight, Samples, &Overflowed);
-    return Overflowed ? sampleprof_error::counter_overflow
-                      : sampleprof_error::success;
-  }
+  /// Add vtable type \p F with samples \p S.
+  /// Optionally scale sample count \p S by \p Weight.
+  sampleprof_error addVTableAccessCount(FunctionId F, uint64_t S,
+                                        uint64_t Weight = 1);
 
   /// Remove called function from the call target map. Return the target sample
   /// count of the called function.
@@ -406,8 +405,10 @@ class SampleRecord {
 
   uint64_t getSamples() const { return NumSamples; }
   const CallTargetMap &getCallTargets() const { return CallTargets; }
-  const TypeMap &getTypes() const { return TypeCounts; }
-  TypeMap &getTypes() { return TypeCounts; }
+  const TypeCountMap &getVTableAccessCounts() const {
+    return VTableAccessCounts;
+  }
+  TypeCountMap &getVTableAccessCounts() { return VTableAccessCounts; }
   const SortedCallTargetSet getSortedCallTargets() const {
     return sortCallTargets(CallTargets);
   }
@@ -456,7 +457,8 @@ class SampleRecord {
 private:
   uint64_t NumSamples = 0;
   CallTargetMap CallTargets;
-  TypeMap TypeCounts;
+  // The vtable types and their counts in this sample record.
+  TypeCountMap VTableAccessCounts;
 };
 
 raw_ostream &operator<<(raw_ostream &OS, const SampleRecord &Sample);
@@ -752,7 +754,7 @@ using BodySampleMap = std::map<LineLocation, SampleRecord>;
 // memory, which is *very* significant for large profiles.
 using FunctionSamplesMap = std::map<FunctionId, FunctionSamples>;
 using CallsiteSampleMap = std::map<LineLocation, FunctionSamplesMap>;
-using CallsiteTypeMap = std::map<LineLocation, TypeMap>;
+using CallsiteTypeMap = std::map<LineLocation, TypeCountMap>;
 using LocToLocMap =
     std::unordered_map<LineLocation, LineLocation, LineLocationHash>;
 
@@ -810,20 +812,14 @@ class FunctionSamples {
         Func, Num, Weight);
   }
 
-  sampleprof_error addTypeSamples(uint32_t LineOffset, uint32_t Discriminator,
-                                  FunctionId Func, uint64_t Num,
-                                  uint64_t Weight = 1) {
-    return BodySamples[LineLocation(LineOffset, Discriminator)].addTypeCount(
-        Func, Num, Weight);
+  sampleprof_error addFunctionBodyTypeSamples(const LineLocation &Loc,
+                                              FunctionId Func, uint64_t Num,
+                                              uint64_t Weight = 1) {
+    return BodySamples[Loc].addVTableAccessCount(Func, Num, Weight);
   }
 
-  sampleprof_error addTypeSamples(const LineLocation &Loc, FunctionId Func,
-                                  uint64_t Num, uint64_t Weight = 1) {
-    return BodySamples[Loc].addTypeCount(Func, Num, Weight);
-  }
-
-  TypeMap &getTypeSamples(const LineLocation &Loc) {
-    return BodySamples[Loc].getTypes();
+  TypeCountMap &getFunctionBodyTypeSamples(const LineLocation &Loc) {
+    return BodySamples[Loc].getVTableAccessCounts();
   }
 
   sampleprof_error addSampleRecord(LineLocation Location,
@@ -951,7 +947,8 @@ class FunctionSamples {
     return &Iter->second;
   }
 
-  const TypeMap *findTypeSamplesAt(const LineLocation &Loc) const {
+  /// Returns the TypeCountMap for inlined callsites at the given \p Loc.
+  const TypeCountMap *findCallsiteTypeSamplesAt(const LineLocation &Loc) const {
     auto Iter = VirtualCallsiteTypes.find(mapIRLocToProfileLoc(Loc));
     if (Iter == VirtualCallsiteTypes.end())
       return nullptr;
@@ -1019,14 +1016,42 @@ class FunctionSamples {
     return CallsiteSamples;
   }
 
-  const CallsiteTypeMap &getCallsiteTypes() const {
+  /// Return all the callsite type samples collected in the body of the
+  /// function.
+  const CallsiteTypeMap &getCallsiteTypeCounts() const {
     return VirtualCallsiteTypes;
   }
 
-  TypeMap& getTypeSamplesAt(const LineLocation &Loc) {
+  /// Returns the type samples for the un-drifted location of \p Loc.
+  TypeCountMap &getTypeSamplesAt(const LineLocation &Loc) {
     return VirtualCallsiteTypes[mapIRLocToProfileLoc(Loc)];
   }
 
+  /// Merge the vtable counts in \p Other into the callsite type samples stored
+  /// for the undrifted location of \p Loc, scaling each count by \p Weight.
+  /// Counts saturate; returns counter_overflow if any entry overflowed.
+  template <typename T>
+  sampleprof_error addCallsiteVTableTypeProfAt(const LineLocation &Loc,
+                                               const T &Other,
+                                               uint64_t Weight = 1) {
+    static_assert((std::is_same_v<typename T::key_type, StringRef> ||
+                   std::is_same_v<typename T::key_type, FunctionId>) &&
+                      std::is_same_v<typename T::mapped_type, uint64_t>,
+                  "T must be a map with StringRef or FunctionId as key and "
+                  "uint64_t as value");
+    TypeCountMap &TypeCounts = getTypeSamplesAt(Loc);
+    bool Overflowed = false;
+    for (const auto &[Type, Count] : Other) {
+      // Look the slot up once; operator[] zero-initializes a new entry.
+      uint64_t &Slot = TypeCounts[FunctionId(Type)];
+      bool RowOverflow = false;
+      Slot = SaturatingMultiplyAdd(Count, Weight, Slot, &RowOverflow);
+      Overflowed |= RowOverflow;
+    }
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
+  }
+
   /// Return the maximum of sample counts in a function body. When SkipCallSite
   /// is false, which is the default, the return count includes samples in the
   /// inlined functions. When SkipCallSite is true, the return count only
@@ -1073,13 +1098,7 @@ class FunctionSamples {
       const LineLocation &Loc = I.first;
       const SampleRecord &Rec = I.second;
       mergeSampleProfErrors(Result, BodySamples[Loc].merge(Rec, Weight));
-      // const auto &OtherTypeCountMap = Rec.getTypes();
-      // for (const auto &[Type, Count] : OtherTypeCountMap) {
-      //   mergeSampleProfErrors(Result, addTypeSamples(Loc, Type, Count,
-      //   Weight));
-      // }
     }
-
     for (const auto &I : Other.getCallsiteSamples()) {
       const LineLocation &Loc = I.first;
       FunctionSamplesMap &FSMap = functionSamplesAt(Loc);
@@ -1087,13 +1106,10 @@ class FunctionSamples {
         mergeSampleProfErrors(Result,
                               FSMap[Rec.first].merge(Rec.second, Weight));
     }
-    for (const auto &[Loc, TypeCountMap] : Other.getCallsiteTypes()) {
-      TypeMap &TypeCounts = getTypeSamplesAt(Loc);
-      for (const auto &[Type, Count] : TypeCountMap) {
-        TypeCounts[Type] =
-            SaturatingMultiplyAdd(Count, Weight, TypeCounts[Type]);
-      }
-    }
+    for (const auto &[Loc, OtherTypeMap] : Other.getCallsiteTypeCounts())
+      mergeSampleProfErrors(
+          Result, addCallsiteVTableTypeProfAt(Loc, OtherTypeMap, Weight));
+
     return Result;
   }
 
@@ -1337,6 +1353,21 @@ class FunctionSamples {
   /// collected in the call to baz() at line offset 8.
   CallsiteSampleMap CallsiteSamples;
 
+  /// Map inlined virtual callsites to the vtable from which they are loaded.
+  ///
+  /// Each entry is a mapping from the location to the list of vtables and their
+  /// sampled counts. For example, given:
+  ///
+  ///     void foo() {
+  ///       ...
+  ///  5    inlined_vcall_bar();
+  ///       ...
+  ///  5    inlined_vcall_baz();
+  ///       ...
+  ///  200  inlined_vcall_qux();
+  ///     }
+  /// This map will contain two entries. One with two types for line offset 5
+  /// and one with one type for line offset 200.
   CallsiteTypeMap VirtualCallsiteTypes;
 
   /// IR to profile location map generated by stale profile matching.
diff --git a/llvm/include/llvm/ProfileData/SampleProfReader.h b/llvm/include/llvm/ProfileData/SampleProfReader.h
index 6ab3b650c0750..d7465c06af280 100644
--- a/llvm/include/llvm/ProfileData/SampleProfReader.h
+++ b/llvm/include/llvm/ProfileData/SampleProfReader.h
@@ -701,16 +701,10 @@ class SampleProfileReaderBinary : public SampleProfileReader {
   /// otherwise same as readStringFromTable, also return its hash value.
   ErrorOr<std::pair<SampleContext, uint64_t>> readSampleContextFromTable();
 
-  /// Overridden by SampleProfileReaderExtBinary to read the vtable profile.
-  virtual std::error_code readVTableProf(const LineLocation &Loc,
-                                         FunctionSamples &FProfile) {
-    return sampleprof_error::success;
-  }
-
-  virtual std::error_code readCallsiteVTableProf(FunctionSamples &FProfile) {
-    return sampleprof_error::success;
-  }
-
+  std::error_code readBodySampleVTableProf(const LineLocation &Loc,
+                                           FunctionSamples &FProfile);
+  /// Read all callsites' vtable access counts for \p FProfile.
+  std::error_code readCallsiteVTableProf(FunctionSamples &FProfile);
 
   /// Points to the current location in the buffer.
   const uint8_t *Data = nullptr;
@@ -736,6 +730,12 @@ class SampleProfileReaderBinary : public SampleProfileReader {
   /// to the start of MD5SampleContextTable.
   const uint64_t *MD5SampleContextStart = nullptr;
 
+  /// Read bytes from the input buffer pointed to by `Data` and decode them
+  /// into \p M. `Data` is advanced past the bytes read when this function
+  /// returns. Returns the first error encountered, if any.
+  std::error_code readVTableTypeCountMap(TypeCountMap &M);
+  bool ReadVTableProf = false;
+
 private:
   std::error_code readSummaryEntry(std::vector<ProfileSummaryEntry> &Entries);
   virtual std::error_code verifySPMagic(uint64_t Magic) = 0;
@@ -825,8 +825,6 @@ class SampleProfileReaderExtBinaryBase : public SampleProfileReaderBinary {
   /// The set containing the functions to use when compiling a module.
   DenseSet<StringRef> FuncsToUse;
 
-  bool ReadVTableProf = false;
-
 public:
   SampleProfileReaderExtBinaryBase(std::unique_ptr<MemoryBuffer> B,
                                    LLVMContext &C, SampleProfileFormat Format)
@@ -867,12 +865,6 @@ class SampleProfileReaderExtBinary : public SampleProfileReaderExtBinaryBase {
     return sampleprof_error::success;
   };
 
-  std::error_code readVTableProf(const LineLocation &Loc,
-                                 FunctionSamples &FProfile) override;
-
-  std::error_code readCallsiteVTableProf(FunctionSamples &FProfile) override;
-
-   std::error_code readTypeMap(TypeMap& M);
 public:
   SampleProfileReaderExtBinary(std::unique_ptr<MemoryBuffer> B, LLVMContext &C,
                                SampleProfileFormat Format = SPF_Ext_Binary)
diff --git a/llvm/include/llvm/ProfileData/SampleProfWriter.h b/llvm/include/llvm/ProfileData/SampleProfWriter.h
index fe3121de87415..b20ad11349912 100644
--- a/llvm/include/llvm/ProfileData/SampleProfWriter.h
+++ b/llvm/include/llvm/ProfileData/SampleProfWriter.h
@@ -209,17 +209,7 @@ class SampleProfileWriterBinary : public SampleProfileWriter {
   virtual MapVector<FunctionId, uint32_t> &getNameTable() { return NameTable; }
   virtual std::error_code writeMagicIdent(SampleProfileFormat Format);
   virtual std::error_code writeNameTable();
-  virtual std::error_code
-  writeSampleRecordVTableProf(const SampleRecord &Record, raw_ostream &OS) {
-    // TODO: This is not virtual because SampleProfWriter may create objects of
-    // type SampleProfileWriterRawBinary.
-    return sampleprof_error::success;
-  }
-  virtual std::error_code
-  writeCallsiteType(const FunctionSamples &FunctionSample, raw_ostream &OS) {
-    return sampleprof_error::success;
-  }
-  virtual void addTypeNames(const TypeMap &M) {}
+
   std::error_code writeHeader(const SampleProfileMap &ProfileMap) override;
   std::error_code writeSummary();
   virtual std::error_code writeContextIdx(const SampleContext &Context);
@@ -231,10 +221,23 @@ class SampleProfileWriterBinary : public SampleProfileWriter {
   MapVector<FunctionId, uint32_t> NameTable;
 
   void addName(FunctionId FName);
-  void addTypeName(FunctionId TypeName);
+  // void addTypeName(FunctionId TypeName);
   virtual void addContext(const SampleContext &Context);
   void addNames(const FunctionSamples &S);
 
+  /// Add the type names to NameTable.
+  void addTypeNames(const TypeCountMap &M);
+
+  /// Write \p CallsiteTypeMap to the output stream \p OS.
+  std::error_code
+  writeCallsiteVTableProf(const CallsiteTypeMap &CallsiteTypeMap,
+                          raw_ostream &OS);
+  /// Write \p Map to the output stream \p OS.
+  std::error_code writeTypeMap(const TypeCountMap &Map, raw_ostream &OS);
+
+  // TODO: This should be configurable by a flag.
+  bool WriteVTableProf = false;
+
 private:
   friend ErrorOr<std::unique_ptr<SampleProfileWriter>>
   SampleProfileWriter::create(std::unique_ptr<raw_ostream> &OS,
@@ -423,22 +426,6 @@ class SampleProfileWriterExtBinary : public SampleProfileWriterExtBinaryBase {
 public:
   SampleProfileWriterExtBinary(std::unique_ptr<raw_ostream> &OS);
 
-protected:
-  std::error_code writeSampleRecordVTableProf(const SampleRecord &Record,
-                                              raw_ostream &OS) override;
-
-  std::error_code writeCallsiteType(const FunctionSamples &FunctionSample,
-                                    raw_ostream &OS) override;
-
-  void addTypeNames(const TypeMap &M) override {
-    if (!WriteVTableProf)
-      return;
-    // Add type name to TypeNameTable.
-    for (const auto &[Type, Cnt] : M) {
-      addName(Type);
-    }
-  }
-
 private:
   std::error_code writeDefaultLayout(const SampleProfileMap &ProfileMap);
   std::error_code writeCtxSplitLayout(const SampleProfileMap &ProfileMap);
@@ -453,11 +440,6 @@ class SampleProfileWriterExtBinary : public SampleProfileWriterExtBinaryBase {
     assert((SL == DefaultLayout || SL == CtxSplitLayout) &&
            "Unsupported layout");
   }
-
-  std::error_code writeTypeMap(const TypeMap &Map, raw_ostream& OS);
-
-  // TODO:This should be configurable by flag.
-  bool WriteVTableProf = false;
 };
 
 } // end namespace sampleprof
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index e4c6d0d138167..2c86235f00b8b 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/LEB128.h"
 #include "llvm/Support/raw_ostream.h"
 #include <string>
 #include <system_error>
@@ -90,6 +91,10 @@ class SampleProfErrorCategoryType : public std::error_category {
       return "Zlib is unavailable";
     case sampleprof_error::hash_mismatch:
       return "Function hash mismatch";
+    case sampleprof_error::illegal_line_offset:
+      return "Illegal line offset in sample profile data";
+    case sampleprof_error::duplicate_vtable_type:
+      return "Duplicate vtable type in one map";
     }
     llvm_unreachable("A value of sampleprof_error has no message.");
   }
@@ -108,6 +113,11 @@ void LineLocation::print(raw_ostream &OS) const {
     OS << "." << Discriminator;
 }
 
+void LineLocation::serialize(raw_ostream &OS) const {
+  encodeULEB128(LineOffset, OS);
+  encodeULEB128(Discriminator, OS);
+}
+
 raw_ostream &llvm::sampleprof::operator<<(raw_ostream &OS,
                                           const LineLocation &Loc) {
   Loc.print(OS);
@@ -123,9 +133,10 @@ sampleprof_error SampleRecord::merge(const SampleRecord &Other,
   for (const auto &I : Other.getCallTargets()) {
     mergeSampleProfErrors(Result, addCalledTarget(I.first, I.second, Weight));
   }
-  for (const auto &[TypeName, Count] : Other.getTypes()) {
-    mergeSampleProfErrors(Result, addTypeCount(TypeName, Count, Weight));
-  }
+  for (const auto &[TypeName, Count] : Other.getVTableAccessCounts())
+    mergeSampleProfErrors(Result,
+                          addVTableAccessCount(TypeName, Count, Weight));
+
   return Result;
 }
 
@@ -141,14 +152,32 @@ void SampleRecord::print(raw_ostream &OS, unsigned Indent) const {
     for (const auto &I : getSortedCallTargets())
       OS << " " << I.first << ":" << I.second;
   }
-  if (!TypeCounts.empty()) {
-    OS << ", types:";
-    for (const auto &I : TypeCounts)
-      OS << " " << I.first << ":" << I.second;
+  if (!VTableAccessCounts.empty()) {
+    OS << ", vtables:";
+    for (const auto [Type, Count] : VTableAccessCounts)
+      OS << " " << Type << ":" << Count;
   }
   OS << "\n";
 }
 
+static sampleprof_error addWeightSample(uint64_t S, uint64_t Weight,
+                                        uint64_t &Samples) {
+  bool Overflowed;
+  Samples = SaturatingMultiplyAdd(S, Weight, Samples, &Overflowed);
+  return Overflowed ? sampleprof_error::counter_overflow
+                    : sampleprof_error::success;
+}
+
+sampleprof_error SampleRecord::addCalledTarget(FunctionId F, uint64_t S,
+                                               uint64_t Weight) {
+  return addWeightSample(S, Weight, CallTargets[F]);
+}
+
+sampleprof_error SampleRecord::addVTableAccessCount(FunctionId F, uint64_t S,
+                                                    uint64_t Weight) {
+  return addWeightSample(S, Weight, VTableAccessCounts[F]);
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 LLVM_DUMP_METHOD void SampleRecord::dump() const { print(dbgs(), 0); }
 #endif
diff --git a/llvm/lib/ProfileData/SampleProfReader.cpp b/llvm/lib/ProfileData/SampleProfReader.cpp
index 6b3766d6ebd91..e8f5a07081286 100644
--- a/llvm/lib/ProfileData/SampleProfReader.cpp
+++ b/llvm/lib/ProfileData/SampleProfReader.cpp
@@ -331,16 +331,14 @@ static bool ParseLine(const StringRef &Input, LineType &LineTy, uint32_t &Depth,
       // Change n3 to the next blank space after colon + integer pair.
       n3 = n4;
     }
-  } else if (Rest.ends_with("// CallTargetVtables")) {
+  } else if (Rest.starts_with(kBodySampleVTableProfPrefix)) {
     LineTy = LineType::CallTargetTypeProfile;
-    return parseTypeCountMap(
-        Rest.substr(0, Rest.size() - strlen("// CallTargetVtables")),
-        TypeCountMap);
-  } else if (Rest.ends_with("// CallSiteVtables")) {
+    return parseTypeCountMap(Rest.substr(strlen(kBodySampleVTableProfPrefix)),
+                             TypeCountMap);
+  } else if (Rest.starts_with(kInlinedCallsiteVTablerofPrefix)) {
     LineTy = LineType::CallSiteTypeProfile;
     return parseTypeCountMap(
-        Rest.substr(0, Rest.size() - strlen("// CallSiteVtables")),
-        TypeCountMap);
+        Rest.substr(strlen(kInlinedCallsiteVTablerofPrefix)), TypeCountMap);
   } else {
     LineTy = LineType::CallSiteProfile;
     size_t n3 = Rest.find_last_of(':');
@@ -448,10 +446,9 @@ std::error_code SampleProfileReaderText::readImpl() {
       }
 
       case LineType::CallSiteTypeProfile: {
-        TypeMap &Map = InlineStack.back()->getTypeSamplesAt(
-            LineLocation(LineOffset, Discriminator));
-        for (const auto [Type, Count] : TypeCountMap)
-          Map[FunctionId(Type)] += Count;
+        mergeSampleProfErrors(
+            Result, InlineStack.back()->addCallsiteVTableTypeProfAt(
+                        LineLocation(LineOffset, Discriminator), TypeCountMap));
         break;
       }
 
@@ -462,9 +459,9 @@ std::error_code SampleProfileReaderText::readImpl() {
         FunctionSamples &FProfile = *InlineStack.back();
         for (const auto &name_count : TypeCountMap) {
           mergeSampleProfErrors(
-              Result, FProfile.addTypeSamples(LineOffset, Discriminator,
-                                              FunctionId(name_count.first),
-                                              name_count.second));
+              Result, FProfile.addFunctionBodyTypeSamples(
+                          LineLocation(LineOffset, Discriminator),
+                          FunctionId(name_count.first), name_count.second));
         }
         break;
       }
@@ -653,7 +650,8 @@ SampleProfileReaderBinary::readSampleContextFromTable() {
   return std::make_pair(Context, Hash);
 }
 
-std::error_code SampleProfileReaderExtBinary::readTypeMap(TypeMap &M) {
+std::error_code
+SampleProfileReaderBinary::readVTableTypeCountMap(TypeCountMap &M) {
   auto NumVTableTypes = readNumber<uint32_t>();
   if (std::error_code EC = NumVTableTypes.getError())
     return EC;
@@ -667,23 +665,14 @@ std::error_code SampleProfileReaderExtBinary::readTypeMap(TypeMap &M) {
     if (std::error_code EC = VTableSamples.getError())
       return EC;
 
-    errs() << "readTypeMap\t" << *VTableType << "\t" << *VTableSamples << "\n";
-    M.insert(std::make_pair(*VTableType, *VTableSamples));
+    if (!M.insert(std::make_pair(*VTableType, *VTableSamples)).second)
+      return sampleprof_error::duplicate_vtable_type;
   }
   return sampleprof_error::success;
 }
 
 std::error_code
-SampleProfileReaderExtBinary::readVTableProf(const LineLocation &Loc,
-                                             FunctionSamples &FProfile) {
-  if (!ReadVTableProf)
-    return sampleprof_error::success;
-
-  return readTypeMap(FProfile.getTypeSamples(Loc));
-}
-
-std::error_code SampleProfileReaderExtBinary::readCallsiteVTableProf(
-    FunctionSamples &FProfile) {
+SampleProfileReaderBinary::readCallsiteVTableProf(FunctionSamples &FProfile) {
   if (!ReadVTableProf)
     return sampleprof_error::success;
 
@@ -697,18 +686,17 @@ std::error_code SampleProfileReaderExtBinary::readCallsiteVTableProf(
     if (std::error_code EC = LineOffset.getError())
       return EC;
 
+    if (!isOffsetLegal(*LineOffset))
+      return sampleprof_error::illegal_line_offset;
+
     auto Discriminator = readNumber<uint64_t>();
     if (std::error_code EC = Discriminator.getError())
       return EC;
 
     // Here we handle FS discriminators:
-    uint32_t DiscriminatorVal = (*Discriminator) & getDiscriminatorMask();
+    const uint32_t DiscriminatorVal = (*Discriminator) & getDiscriminatorMask();
 
-    if (!isOffsetLegal(*LineOffset)) {
-      return std::error_code();
-    }
-
-    if (std::error_code EC = readTypeMap(FProfile.getTypeSamplesAt(
+    if (std::error_code EC = readVTableTypeCountMap(FProfile.getTypeSamplesAt(
             LineLocation(*LineOffset, DiscriminatorVal))))
       return EC;
   }
@@ -764,10 +752,13 @@ SampleProfileReaderBinary::readProfile(FunctionSamples &FProfile) {
                                       *CalledFunction, *CalledFunctionSamples);
     }
 
-    // read vtable type profiles.
-    if (std::error_code EC = readVTableProf(
-            LineLocation(*LineOffset, DiscriminatorVal), FProfile))
-      return EC;
+    if (ReadVTableProf) {
+      // read vtable type profiles.
+      if (std::error_code EC =
+              readVTableTypeCountMap(FProfile.getFunctionBodyTypeSamples(
+                  LineLocation(*LineOffset, DiscriminatorVal))))
+        return EC;
+    }
 
     FProfile.addBodySamples(*LineOffset, DiscriminatorVal, *NumSamples);
   }
@@ -800,10 +791,7 @@ SampleProfileReaderBinary::readProfile(FunctionSamples &FProfile) {
       return EC;
   }
 
-  std::error_code EC = readCallsiteVTableProf(FProfile);
-  errs() << "readFunctionSample\t";
-  FProfile.print(errs(), 2);
-  return EC;
+  return readCallsiteVTableProf(FProfile);
 }
 
 std::error_code
@@ -865,10 +853,8 @@ std::error_code SampleProfileReaderExtBinaryBase::readOneSection(
       FunctionSamples::ProfileIsPreInlined = ProfileIsPreInlined = true;
     if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagFSDiscriminator))
       FunctionSamples::ProfileIsFS = ProfileIsFS = true;
-    if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagHasVTableTypeProf)) {
-      errs() << "SampleProfileReaderExtBinaryBase::readVTableProf\n";
+    if (hasSecFlag(Entry, SecProfSummaryFlags::SecFlagHasVTableTypeProf))
       ReadVTableProf = true;
-    }
     break;
   case SecNameTable: {
     bool FixedLengthMD5 =
diff --git a/llvm/lib/ProfileData/SampleProfWriter.cpp b/llvm/lib/ProfileData/SampleProfWriter.cpp
index dde9173565b50..7090d93c30292 100644
--- a/llvm/lib/ProfileData/SampleProfWriter.cpp
+++ b/llvm/lib/ProfileData/SampleProfWriter.cpp
@@ -488,11 +488,7 @@ std::error_code SampleProfileWriterExtBinaryBase::writeOneSection(
 SampleProfileWriterExtBinary::SampleProfileWriterExtBinary(
     std::unique_ptr<raw_ostream> &OS)
     : SampleProfileWriterExtBinaryBase(OS) {
-  // Initialize the section header layout.
-
   WriteVTableProf = ExtBinaryWriteVTableTypeProf;
-
-  errs() << "writeVTableProf value: " << WriteVTableProf << "\n";
 }
 
 std::error_code SampleProfileWriterExtBinary::writeDefaultLayout(
@@ -608,14 +604,14 @@ std::error_code SampleProfileWriterText::writeSample(const FunctionSamples &S) {
       OS << " " << J.first << ":" << J.second;
     OS << "\n";
 
-    if (!Sample.getTypes().empty()) {
+    if (!Sample.getVTableAccessCounts().empty()) {
       OS.indent(Indent + 1);
       Loc.print(OS);
       OS << ": ";
-      for (const auto &Type : Sample.getTypes()) {
-        OS << Type.first << ":" << Type.second << " ";
-      }
-      OS << " // CallTargetVtables\n";
+      OS << kBodySampleVTableProfPrefix;
+      for (const auto [TypeName, Count] : Sample.getVTableAccessCounts())
+        OS << TypeName << ":" << Count << " ";
+      OS << "\n";
     }
     LineCount++;
   }
@@ -636,14 +632,16 @@ std::error_code SampleProfileWriterText::writeSample(const FunctionSamples &S) {
         return EC;
     }
 
-    if (const TypeMap *Map = S.findTypeSamplesAt(Loc); Map && !Map->empty()) {
+    if (const TypeCountMap *Map = S.findCallsiteTypeSamplesAt(Loc);
+        Map && !Map->empty()) {
       OS.indent(Indent);
       Loc.print(OS);
       OS << ": ";
-      for (const auto &Type : *Map) {
-        OS << Type.first << ":" << Type.second << " ";
+      OS << kInlinedCallsiteVTablerofPrefix;
+      for (const auto [TypeId, Count] : *Map) {
+        OS << TypeId << ":" << Count << " ";
       }
-      OS << " // CallSiteVtables\n";
+      OS << "\n";
       LineCount++;
     }
   }
@@ -692,18 +690,27 @@ void SampleProfileWriterBinary::addContext(const SampleContext &Context) {
   addName(Context.getFunction());
 }
 
+void SampleProfileWriterBinary::addTypeNames(const TypeCountMap &M) {
+  if (!WriteVTableProf)
+    return;
+  // Register each vtable type name (the keys of M) through addName().
+  for (const auto Type : llvm::make_first_range(M))
+    addName(Type);
+}
+
 void SampleProfileWriterBinary::addNames(const FunctionSamples &S) {
   // Add all the names in indirect call targets.
   for (const auto &I : S.getBodySamples()) {
     const SampleRecord &Sample = I.second;
     for (const auto &J : Sample.getCallTargets())
       addName(J.first);
-    addTypeNames(Sample.getTypes());
+    addTypeNames(Sample.getVTableAccessCounts());
   }
 
   // Add all the names in callsite types.
-  for (const auto &CallsiteTypeSamples : S.getCallsiteTypes())
-    addTypeNames(CallsiteTypeSamples.second);
+  for (const auto &VTableAccessCountMap :
+       llvm::make_second_range(S.getCallsiteTypeCounts()))
+    addTypeNames(VTableAccessCountMap);
 
   // Recursively add all the names for inlined callsites.
   for (const auto &J : S.getCallsiteSamples())
@@ -850,37 +857,25 @@ std::error_code SampleProfileWriterExtBinaryBase::writeHeader(
   return sampleprof_error::success;
 }
 
-std::error_code SampleProfileWriterExtBinary::writeTypeMap(const TypeMap &Map,
-                                                           raw_ostream &OS) {
+std::error_code SampleProfileWriterBinary::writeTypeMap(const TypeCountMap &Map,
+                                                        raw_ostream &OS) {
   encodeULEB128(Map.size(), OS);
-  for (const auto &[TypeName, TypeSamples] : Map) {
-    errs() << "TypeName: " << TypeName << "\t" << "TypeSamples: " << TypeSamples
-           << "\n";
+  for (const auto &[TypeName, SampleCount] : Map) {
     if (std::error_code EC = writeNameIdx(TypeName))
       return EC;
-    encodeULEB128(TypeSamples, OS);
+    encodeULEB128(SampleCount, OS);
   }
   return sampleprof_error::success;
 }
 
-std::error_code SampleProfileWriterExtBinary::writeSampleRecordVTableProf(
-    const SampleRecord &Record, raw_ostream &OS) {
-  if (!WriteVTableProf)
-    return sampleprof_error::success;
-
-  return writeTypeMap(Record.getTypes(), OS);
-}
-
-std::error_code SampleProfileWriterExtBinary::writeCallsiteType(
-    const FunctionSamples &FunctionSample, raw_ostream &OS) {
+std::error_code SampleProfileWriterBinary::writeCallsiteVTableProf(
+    const CallsiteTypeMap &CallsiteTypeMap, raw_ostream &OS) {
   if (!WriteVTableProf)
     return sampleprof_error::success;
 
-  const CallsiteTypeMap &CallsiteTypeMap = FunctionSample.getCallsiteTypes();
   encodeULEB128(CallsiteTypeMap.size(), OS);
   for (const auto &[Loc, TypeMap] : CallsiteTypeMap) {
-    encodeULEB128(Loc.LineOffset, OS);
-    encodeULEB128(Loc.Discriminator, OS);
+    Loc.serialize(OS);
     if (std::error_code EC = writeTypeMap(TypeMap, OS))
       return EC;
   }
@@ -916,8 +911,7 @@ std::error_code SampleProfileWriterBinary::writeBody(const FunctionSamples &S) {
   for (const auto &I : S.getBodySamples()) {
     LineLocation Loc = I.first;
     const SampleRecord &Sample = I.second;
-    encodeULEB128(Loc.LineOffset, OS);
-    encodeULEB128(Loc.Discriminator, OS);
+    Loc.serialize(OS);
     encodeULEB128(Sample.getSamples(), OS);
     encodeULEB128(Sample.getCallTargets().size(), OS);
     for (const auto &J : Sample.getSortedCallTargets()) {
@@ -927,8 +921,9 @@ std::error_code SampleProfileWriterBinary::writeBody(const FunctionSamples &S) {
         return EC;
       encodeULEB128(CalleeSamples, OS);
     }
-    if (std::error_code EC = writeSampleRecordVTableProf(Sample, OS))
-      return EC;
+    if (WriteVTableProf)
+      if (std::error_code EC = writeTypeMap(Sample.getVTableAccessCounts(), OS))
+        return EC;
   }
 
   // Recursively emit all the callsite samples.
@@ -936,18 +931,14 @@ std::error_code SampleProfileWriterBinary::writeBody(const FunctionSamples &S) {
   for (const auto &J : S.getCallsiteSamples())
     NumCallsites += J.second.size();
   encodeULEB128(NumCallsites, OS);
-  for (const auto &J : S.getCallsiteSamples()) {
+  for (const auto &J : S.getCallsiteSamples())
     for (const auto &FS : J.second) {
-      LineLocation Loc = J.first;
-      const FunctionSamples &CalleeSamples = FS.second;
-      encodeULEB128(Loc.LineOffset, OS);
-      encodeULEB128(Loc.Discriminator, OS);
-      if (std::error_code EC = writeBody(CalleeSamples))
+      J.first.serialize(OS);
+      if (std::error_code EC = writeBody(FS.second))
         return EC;
     }
-  }
 
-  return writeCallsiteType(S, OS);
+  return writeCallsiteVTableProf(S.getCallsiteTypeCounts(), OS);
 }
 
 /// Write samples of a top-level function to a binary file.
diff --git a/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list-ext.expected b/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list-ext.expected
new file mode 100644
index 0000000000000..bbce2af2d3377
--- /dev/null
+++ b/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list-ext.expected
@@ -0,0 +1,43 @@
+Function: main: 368038, 0, 7 sampled lines
+Samples collected in the function's body {
+  4: 1068
+  4.2: 1068
+  5: 2150
+  5.1: 2150
+  6: 4160
+  7: 1068
+  9: 4128, calls: _Z3bari:2942 _Z3fooi:1262, vtables: _ZTVbar:2942 _ZTVfoo:1260
+}
+Samples collected in inlined callsites {
+  10: inlined callee: inline1: 2000, 0, 1 sampled lines
+    Samples collected in the function's body {
+      1: 2000
+    }
+    No inlined callsites in this function
+  10: inlined callee: inline2: 4000, 0, 1 sampled lines
+    Samples collected in the function's body {
+      1: 4000
+    }
+    No inlined callsites in this function
+  10: vtables: _ZTVinline1:2000 _ZTVinline2:4000
+}
+Function: _Z3bari: 40602, 2874, 1 sampled lines
+Samples collected in the function's body {
+  1: 2874
+}
+No inlined callsites in this function
+Function: _Z3fooi: 15422, 1220, 1 sampled lines
+Samples collected in the function's body {
+  1: 1220
+}
+No inlined callsites in this function
+======== Dump profile symbol list ========
+_Z3goov
+_Z3sumii
+__libc_csu_fini
+__libc_csu_init
+_dl_relocate_static_pie
+_fini
+_init
+_start
+main
diff --git a/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected b/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected
index a7912eadbc888..bd528b44b81c4 100644
--- a/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected
+++ b/llvm/test/tools/llvm-profdata/Inputs/profile-symbol-list.expected
@@ -6,7 +6,7 @@ Samples collected in the function's body {
   5.1: 2150
   6: 4160
   7: 1068
-  9: 4128, calls: _Z3bari:2942 _Z3fooi:1262, types: vtable_bar:2942 vtable_foo:1260
+  9: 4128, calls: _Z3bari:2942 _Z3fooi:1262
 }
 Samples collected in inlined callsites {
   10: inlined callee: inline1: 2000, 0, 1 sampled lines
@@ -19,7 +19,6 @@ Samples collected in inlined callsites {
       1: 4000
     }
     No inlined callsites in this function
-  10: vtables: inline1_vtable:2000 inline2_vtable:4000
 }
 Function: _Z3bari: 40602, 2874, 1 sampled lines
 Samples collected in the function's body {
diff --git a/llvm/test/tools/llvm-profdata/Inputs/sample-profile-ext.proftext b/llvm/test/tools/llvm-profdata/Inputs/sample-profile-ext.proftext
new file mode 100644
index 0000000000000..48a48d9e2b368
--- /dev/null
+++ b/llvm/test/tools/llvm-profdata/Inputs/sample-profile-ext.proftext
@@ -0,0 +1,18 @@
+main:184019:0
+ 4: 534
+ 4.2: 534
+ 5: 1075
+ 5.1: 1075
+ 6: 2080
+ 7: 534
+ 9: 2064 _Z3bari:1471 _Z3fooi:631
+ 9: <vt-call> _ZTVbar:1471 _ZTVfoo:630
+ 10: inline1:1000
+  1: 1000
+ 10: inline2:2000
+  1: 2000
+ 10: <vt-inline> _ZTVinline1:1000 _ZTVinline2:2000
+_Z3bari:20301:1437
+ 1: 1437
+_Z3fooi:7711:610
+ 1: 610
diff --git a/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext b/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext
index 0a12512161a8c..f9f87dfd661f8 100644
--- a/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext
+++ b/llvm/test/tools/llvm-profdata/Inputs/sample-profile.proftext
@@ -6,12 +6,10 @@ main:184019:0
  6: 2080
  7: 534
  9: 2064 _Z3bari:1471 _Z3fooi:631
- 9: vtable_bar:1471 vtable_foo:630  // CallTargetVtables
  10: inline1:1000
   1: 1000
  10: inline2:2000
   1: 2000
- 10: inline1_vtable:1000 inline2_vtable:2000  // CallSiteVtables
 _Z3bari:20301:1437
  1: 1437
 _Z3fooi:7711:610
diff --git a/llvm/test/tools/llvm-profdata/profile-symbol-list-compress.test b/llvm/test/tools/llvm-profdata/profile-symbol-list-compress.test
index b445695c8e8e4..9a95a62461919 100644
--- a/llvm/test/tools/llvm-profdata/profile-symbol-list-compress.test
+++ b/llvm/test/tools/llvm-profdata/profile-symbol-list-compress.test
@@ -4,3 +4,9 @@ REQUIRES: zlib
 ; RUN: llvm-profdata merge -sample -extbinary -compress-all-sections %t.1.output %t.2.output -o %t.3.output
 ; RUN: llvm-profdata show -sample -show-prof-sym-list %t.3.output > %t.4.output
 ; RUN: diff -b %S/Inputs/profile-symbol-list.expected %t.4.output
+
+; RUN: llvm-profdata merge -sample -extbinary -compress-all-sections -prof-sym-list=%S/Inputs/profile-symbol-list-1.text %S/Inputs/sample-profile-ext.proftext -o %t.1.output
+; RUN: llvm-profdata merge -sample -extbinary -compress-all-sections -prof-sym-list=%S/Inputs/profile-symbol-list-2.text %S/Inputs/sample-profile-ext.proftext -o %t.2.output
+; RUN: llvm-profdata merge -sample -extbinary -compress-all-sections %t.1.output %t.2.output -o %t.3.output
+; RUN: llvm-profdata show -sample -show-prof-sym-list %t.3.output > %t.4.output
+; RUN: diff -b %S/Inputs/profile-symbol-list-ext.expected %t.4.output
diff --git a/llvm/test/tools/llvm-profdata/profile-symbol-list.test b/llvm/test/tools/llvm-profdata/profile-symbol-list.test
index 39dcd11ec1db7..8c5f9d620ee76 100644
--- a/llvm/test/tools/llvm-profdata/profile-symbol-list.test
+++ b/llvm/test/tools/llvm-profdata/profile-symbol-list.test
@@ -7,3 +7,9 @@
 ; RUN: llvm-profdata show -sample -show-sec-info-only %t.5.output  | FileCheck %s -check-prefix=NOSYMLIST
 
 ; NOSYMLIST: ProfileSymbolListSection {{.*}} Size: 0
+
+; RUN: llvm-profdata merge -sample -extbinary -prof-sym-list=%S/Inputs/profile-symbol-list-1.text %S/Inputs/sample-profile-ext.proftext -o %t.1.output
+; RUN: llvm-profdata merge -sample -extbinary -prof-sym-list=%S/Inputs/profile-symbol-list-2.text %S/Inputs/sample-profile-ext.proftext -o %t.2.output
+; RUN: llvm-profdata merge -sample -extbinary %t.1.output %t.2.output -o %t.3.output
+; RUN: llvm-profdata show -sample -show-prof-sym-list %t.3.output > %t.4.output
+; RUN: diff -b %S/Inputs/profile-symbol-list-ext.expected %t.4.output
diff --git a/llvm/test/tools/llvm-profdata/roundtrip.test b/llvm/test/tools/llvm-profdata/roundtrip.test
index 3746ac191e875..0cf518397079d 100644
--- a/llvm/test/tools/llvm-profdata/roundtrip.test
+++ b/llvm/test/tools/llvm-profdata/roundtrip.test
@@ -6,13 +6,18 @@ RUN: llvm-profdata show -o %t.1.proftext -all-functions -text %t.1.profdata
 RUN: diff -b %t.1.proftext %S/Inputs/IR_profile.proftext
 RUN: llvm-profdata merge --sample --binary -output=%t.2.profdata %S/Inputs/sample-profile.proftext
 RUN: llvm-profdata merge --sample --text -output=%t.2.proftext %t.2.profdata
-COM: diff -b %t.2.proftext %S/Inputs/sample-profile.proftext
+RUN: diff -b %t.2.proftext %S/Inputs/sample-profile.proftext
 # Round trip from text --> extbinary --> text
 RUN: llvm-profdata merge --sample --extbinary -output=%t.3.profdata %S/Inputs/sample-profile.proftext
 RUN: llvm-profdata merge --sample --text -output=%t.3.proftext %t.3.profdata
 RUN: diff -b %t.3.proftext %S/Inputs/sample-profile.proftext
 # Round trip from text --> binary --> extbinary --> text
-COM: llvm-profdata merge --sample --binary -output=%t.4.profdata %S/Inputs/sample-profile.proftext
-COM: llvm-profdata merge --sample --extbinary -output=%t.5.profdata %t.4.profdata
-COM: llvm-profdata merge --sample --text -output=%t.4.proftext %t.5.profdata
-COM: diff -b %t.4.proftext %S/Inputs/sample-profile.proftext
+RUN: llvm-profdata merge --sample --binary -output=%t.4.profdata %S/Inputs/sample-profile.proftext
+RUN: llvm-profdata merge --sample --extbinary -output=%t.5.profdata %t.4.profdata
+RUN: llvm-profdata merge --sample --text -output=%t.4.proftext %t.5.profdata
+RUN: diff -b %t.4.proftext %S/Inputs/sample-profile.proftext
+# Round trip from text --> extbinary --> text.
+# This input has vtable type profiles, which extbinary supports but binary does not.
+RUN: llvm-profdata merge --sample --extbinary --output=%t.5.profdata %S/Inputs/sample-profile-ext.proftext
+RUN: llvm-profdata merge --sample --text --output=%t.5.proftext %t.5.profdata
+RUN: diff -b %t.5.proftext %S/Inputs/sample-profile-ext.proftext
diff --git a/llvm/tools/llvm-profdata/llvm-profdata.cpp b/llvm/tools/llvm-profdata/llvm-profdata.cpp
index 122b7e17b705d..9a5d3f91d6256 100644
--- a/llvm/tools/llvm-profdata/llvm-profdata.cpp
+++ b/llvm/tools/llvm-profdata/llvm-profdata.cpp
@@ -1614,7 +1614,6 @@ static void mergeSampleProfile(const WeightedFileVector &Inputs,
     }
 
     SampleProfileMap &Profiles = Reader->getProfiles();
-
     if (ProfileIsProbeBased &&
         ProfileIsProbeBased != FunctionSamples::ProfileIsProbeBased)
       exitWithError(
@@ -1630,13 +1629,9 @@ static void mergeSampleProfile(const WeightedFileVector &Inputs,
           Remapper ? remapSamples(I->second, *Remapper, Result)
                    : FunctionSamples();
       FunctionSamples &Samples = Remapper ? Remapped : I->second;
-      errs() << "llvm-profdata.cpp\tfunction samples:\t";
-      Samples.print(errs(), 2);
       SampleContext FContext = Samples.getContext();
       mergeSampleProfErrors(Result,
                             ProfileMap[FContext].merge(Samples, Input.Weight));
-      errs() << "llvm-profdata.cpp\tmerged samples:\t";
-      ProfileMap[FContext].print(errs(), 2);
       if (Result != sampleprof_error::success) {
         std::error_code EC = make_error_code(Result);
         handleMergeWriterError(errorCodeToError(EC), Input.Filename,



More information about the llvm-commits mailing list