[compiler-rt] [llvm] [ctxprof] Capture sampling info for context roots (PR #131201)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 13 14:45:20 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-pgo

Author: Mircea Trofin (mtrofin)

<details>
<summary>Changes</summary>

When we collect a contextual profile, we sample the threads entering its root and only collect on one of them at a time (see `ContextRoot::Taken`). If we want to compare contextual profiles with each other and/or with flat profiles, we have a problem: we don't know how the counter values relate to each other, because the counters only reflect the sampled entries. To that end, we add `ContextRoot::TotalEntries`, which is incremented every time a root is entered (whether or not that entry is sampled) and serves as a scaling factor for the counter values collected under that root.

We expose this count in the profile and leave the normalization to the consumer of the profile, for a few reasons:

* it's only needed when reasoning about all profiles in aggregate.
* in compiler-rt, the goal is to flush the profile as quickly as possible; performing the multiplications there adds overhead that may not even be necessary if the consumer of the profile doesn't need to combine profiles.
* the raw entry count may itself be interesting, as an indication of how often the various contexts were sampled relative to each other.

---

Patch is 35.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131201.diff


29 Files Affected:

- (modified) compiler-rt/lib/ctx_profile/CtxInstrContextNode.h (+2-1) 
- (modified) compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp (+6-1) 
- (modified) compiler-rt/lib/ctx_profile/CtxInstrProfiling.h (+4) 
- (modified) compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp (+3-1) 
- (modified) compiler-rt/test/ctx_profile/TestCases/generate-context.cpp (+5-1) 
- (modified) llvm/include/llvm/ProfileData/CtxInstrContextNode.h (+2-1) 
- (modified) llvm/include/llvm/ProfileData/PGOCtxProfReader.h (+8-2) 
- (modified) llvm/include/llvm/ProfileData/PGOCtxProfWriter.h (+15-5) 
- (modified) llvm/lib/ProfileData/PGOCtxProfReader.cpp (+27-5) 
- (modified) llvm/lib/ProfileData/PGOCtxProfWriter.cpp (+46-14) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll (+1) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll (+3) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll (+1) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/flatten-zero-path.ll (+1) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll (+2) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/handle-select.ll (+1) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/inline.ll (+2) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/load-unapplicable.ll (+3) 
- (modified) llvm/test/Analysis/CtxProfAnalysis/load.ll (+5) 
- (modified) llvm/test/ThinLTO/X86/ctxprof.ll (+2-2) 
- (modified) llvm/test/Transforms/EliminateAvailableExternally/transform-to-local.ll (+2-2) 
- (added) llvm/test/tools/llvm-ctxprof-util/Inputs/invalid-no-entrycount.yaml (+3) 
- (modified) llvm/test/tools/llvm-ctxprof-util/Inputs/valid-ctx-only.yaml (+2) 
- (modified) llvm/test/tools/llvm-ctxprof-util/Inputs/valid-flat-first.yaml (+2) 
- (modified) llvm/test/tools/llvm-ctxprof-util/Inputs/valid.yaml (+2) 
- (modified) llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util-negative.test (+2) 
- (modified) llvm/test/tools/llvm-ctxprof-util/llvm-ctxprof-util.test (+7-5) 
- (modified) llvm/unittests/ProfileData/PGOCtxProfReaderWriterTest.cpp (+7-7) 
- (modified) llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp (+4) 


``````````diff
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
index 0fc4883305145..55962df57fb58 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
@@ -120,7 +120,8 @@ class ContextNode final {
 class ProfileWriter {
 public:
   virtual void startContextSection() = 0;
-  virtual void writeContextual(const ctx_profile::ContextNode &RootNode) = 0;
+  virtual void writeContextual(const ctx_profile::ContextNode &RootNode,
+                               uint64_t TotalRootEntryCount) = 0;
   virtual void endContextSection() = 0;
 
   virtual void startFlatSection() = 0;
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
index d7ec8fde4ec7d..1c2cad1ca506e 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
@@ -340,6 +340,9 @@ ContextNode *__llvm_ctx_profile_start_context(
     ContextRoot *Root, GUID Guid, uint32_t Counters,
     uint32_t Callsites) SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
   IsUnderContext = true;
+  __sanitizer::atomic_fetch_add(&Root->TotalEntries, 1,
+                                __sanitizer::memory_order_relaxed);
+
   if (!Root->FirstMemBlock) {
     setupContext(Root, Guid, Counters, Callsites);
   }
@@ -374,6 +377,7 @@ void __llvm_ctx_profile_start_collection() {
       ++NumMemUnits;
 
     resetContextNode(*Root->FirstNode);
+    __sanitizer::atomic_store_relaxed(&Root->TotalEntries, 0);
   }
   __sanitizer::atomic_store_relaxed(&ProfilingStarted, true);
   __sanitizer::Printf("[ctxprof] Initial NumMemUnits: %zu \n", NumMemUnits);
@@ -393,7 +397,8 @@ bool __llvm_ctx_profile_fetch(ProfileWriter &Writer) {
       __sanitizer::Printf("[ctxprof] Contextual Profile is %s\n", "invalid");
       return false;
     }
-    Writer.writeContextual(*Root->FirstNode);
+    Writer.writeContextual(*Root->FirstNode, __sanitizer::atomic_load_relaxed(
+                                                 &Root->TotalEntries));
   }
   Writer.endContextSection();
   Writer.startFlatSection();
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
index ab6df6d15e704..72cc60bf523e1 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
@@ -80,6 +80,10 @@ struct ContextRoot {
   ContextNode *FirstNode = nullptr;
   Arena *FirstMemBlock = nullptr;
   Arena *CurrentMem = nullptr;
+
+  // Count the number of entries - regardless if we could take the `Taken` mutex
+  ::__sanitizer::atomic_uint64_t TotalEntries = {};
+
   // This is init-ed by the static zero initializer in LLVM.
   // Taken is used to ensure only one thread traverses the contextual graph -
   // either to read it or to write it. On server side, the same entrypoint will
diff --git a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
index f837424ccca52..62c7f53acec5f 100644
--- a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
+++ b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
@@ -238,7 +238,9 @@ TEST_F(ContextTest, Dump) {
     TestProfileWriter(ContextRoot *Root, size_t Entries)
         : Root(Root), Entries(Entries) {}
 
-    void writeContextual(const ContextNode &Node) override {
+    void writeContextual(const ContextNode &Node,
+                         uint64_t TotalRootEntryCount) override {
+      EXPECT_EQ(TotalRootEntryCount, Entries);
       EXPECT_EQ(EnteredSectionCount, 1);
       EXPECT_EQ(ExitedSectionCount, 0);
       EXPECT_FALSE(Root->Taken.TryLock());
diff --git a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
index bf33b4423fd1f..319f17debe48f 100644
--- a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
+++ b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
@@ -84,7 +84,10 @@ class TestProfileWriter : public ProfileWriter {
     std::cout << "Exited Context Section" << std::endl;
   }
 
-  void writeContextual(const ContextNode &RootNode) override {
+  void writeContextual(const ContextNode &RootNode,
+                       uint64_t EntryCount) override {
+    std::cout << "Entering Root " << RootNode.guid()
+              << " with total entry count " << EntryCount << std::endl;
     printProfile(RootNode, "", "");
   }
 
@@ -115,6 +118,7 @@ class TestProfileWriter : public ProfileWriter {
 // The second context is in the loop. We expect 2 entries and each of the
 // branches would be taken once, so the second counter is 1.
 // CHECK-NEXT: Entered Context Section
+// CHECK-NEXT: Entering Root 8657661246551306189 with total entry count 1
 // CHECK-NEXT: Guid: 8657661246551306189
 // CHECK-NEXT: Entries: 1
 // CHECK-NEXT: 2 counters and 3 callsites
diff --git a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
index 0fc4883305145..55962df57fb58 100644
--- a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
+++ b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
@@ -120,7 +120,8 @@ class ContextNode final {
 class ProfileWriter {
 public:
   virtual void startContextSection() = 0;
-  virtual void writeContextual(const ctx_profile::ContextNode &RootNode) = 0;
+  virtual void writeContextual(const ctx_profile::ContextNode &RootNode,
+                               uint64_t TotalRootEntryCount) = 0;
   virtual void endContextSection() = 0;
 
   virtual void startFlatSection() = 0;
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
index 4b0c944a5258c..48f2c4efd020d 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
@@ -92,10 +92,13 @@ class PGOCtxProfContext final : public internal::IndexNode {
 
   GlobalValue::GUID GUID = 0;
   SmallVector<uint64_t, 16> Counters;
+  const std::optional<uint64_t> RootEntryCount;
   CallsiteMapTy Callsites;
 
-  PGOCtxProfContext(GlobalValue::GUID G, SmallVectorImpl<uint64_t> &&Counters)
-      : GUID(G), Counters(std::move(Counters)) {}
+  PGOCtxProfContext(GlobalValue::GUID G, SmallVectorImpl<uint64_t> &&Counters,
+                    std::optional<uint64_t> RootEntryCount = std::nullopt)
+      : GUID(G), Counters(std::move(Counters)), RootEntryCount(RootEntryCount) {
+  }
 
   Expected<PGOCtxProfContext &>
   getOrEmplace(uint32_t Index, GlobalValue::GUID G,
@@ -115,6 +118,9 @@ class PGOCtxProfContext final : public internal::IndexNode {
   const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
   SmallVectorImpl<uint64_t> &counters() { return Counters; }
 
+  bool isRoot() const { return RootEntryCount.has_value(); }
+  uint64_t getTotalRootEntryCount() const { return *RootEntryCount; }
+
   uint64_t getEntrycount() const {
     assert(!Counters.empty() &&
            "Functions are expected to have at their entry BB instrumented, so "
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h b/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
index c5a724d9a2142..b2bb8fea10cfe 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
@@ -19,7 +19,14 @@
 #include "llvm/ProfileData/CtxInstrContextNode.h"
 
 namespace llvm {
-enum PGOCtxProfileRecords { Invalid = 0, Version, Guid, CalleeIndex, Counters };
+enum PGOCtxProfileRecords {
+  Invalid = 0,
+  Version,
+  Guid,
+  CallsiteIndex,
+  Counters,
+  TotalRootEntryCount
+};
 
 enum PGOCtxProfileBlockIDs {
   FIRST_VALID = bitc::FIRST_APPLICATION_BLOCKID,
@@ -73,9 +80,11 @@ class PGOCtxProfileWriter final : public ctx_profile::ProfileWriter {
   const bool IncludeEmpty;
 
   void writeGuid(ctx_profile::GUID Guid);
+  void writeCallsiteIndex(uint32_t Index);
+  void writeRootEntryCount(uint64_t EntryCount);
   void writeCounters(ArrayRef<uint64_t> Counters);
-  void writeImpl(std::optional<uint32_t> CallerIndex,
-                 const ctx_profile::ContextNode &Node);
+  void writeNode(uint32_t CallerIndex, const ctx_profile::ContextNode &Node);
+  void writeSubcontexts(const ctx_profile::ContextNode &Node);
 
 public:
   PGOCtxProfileWriter(raw_ostream &Out,
@@ -84,7 +93,8 @@ class PGOCtxProfileWriter final : public ctx_profile::ProfileWriter {
   ~PGOCtxProfileWriter() { Writer.ExitBlock(); }
 
   void startContextSection() override;
-  void writeContextual(const ctx_profile::ContextNode &RootNode) override;
+  void writeContextual(const ctx_profile::ContextNode &RootNode,
+                       uint64_t TotalRootEntryCount) override;
   void endContextSection() override;
 
   void startFlatSection() override;
@@ -94,7 +104,7 @@ class PGOCtxProfileWriter final : public ctx_profile::ProfileWriter {
 
   // constants used in writing which a reader may find useful.
   static constexpr unsigned CodeLen = 2;
-  static constexpr uint32_t CurrentVersion = 2;
+  static constexpr uint32_t CurrentVersion = 3;
   static constexpr unsigned VBREncodingBits = 6;
   static constexpr StringRef ContainerMagic = "CTXP";
 };
diff --git a/llvm/lib/ProfileData/PGOCtxProfReader.cpp b/llvm/lib/ProfileData/PGOCtxProfReader.cpp
index 5cc4c94c74b76..f53f2956a7b7e 100644
--- a/llvm/lib/ProfileData/PGOCtxProfReader.cpp
+++ b/llvm/lib/ProfileData/PGOCtxProfReader.cpp
@@ -96,16 +96,19 @@ PGOCtxProfileReader::readProfile(PGOCtxProfileBlockIDs Kind) {
   std::optional<ctx_profile::GUID> Guid;
   std::optional<SmallVector<uint64_t, 16>> Counters;
   std::optional<uint32_t> CallsiteIndex;
+  std::optional<uint64_t> TotalEntryCount;
 
   SmallVector<uint64_t, 1> RecordValues;
 
   const bool ExpectIndex = Kind == PGOCtxProfileBlockIDs::ContextNodeBlockID;
+  const bool IsRoot = Kind == PGOCtxProfileBlockIDs::ContextRootBlockID;
   // We don't prescribe the order in which the records come in, and we are ok
   // if other unsupported records appear. We seek in the current subblock until
   // we get all we know.
   auto GotAllWeNeed = [&]() {
     return Guid.has_value() && Counters.has_value() &&
-           (!ExpectIndex || CallsiteIndex.has_value());
+           (!ExpectIndex || CallsiteIndex.has_value()) &&
+           (!IsRoot || TotalEntryCount.has_value());
   };
   while (!GotAllWeNeed()) {
     RecordValues.clear();
@@ -127,13 +130,21 @@ PGOCtxProfileReader::readProfile(PGOCtxProfileBlockIDs Kind) {
         return wrongValue("Empty counters. At least the entry counter (one "
                           "value) was expected");
       break;
-    case PGOCtxProfileRecords::CalleeIndex:
+    case PGOCtxProfileRecords::CallsiteIndex:
       if (!ExpectIndex)
         return wrongValue("The root context should not have a callee index");
       if (RecordValues.size() != 1)
         return wrongValue("The callee index should have exactly one value");
       CallsiteIndex = RecordValues[0];
       break;
+    case PGOCtxProfileRecords::TotalRootEntryCount:
+      if (!IsRoot)
+        return wrongValue("Non-root has a total entry count record");
+      if (RecordValues.size() != 1)
+        return wrongValue(
+            "The root total entry count record should have exactly one value");
+      TotalEntryCount = RecordValues[0];
+      break;
     default:
       // OK if we see records we do not understand, like records (profile
       // components) introduced later.
@@ -141,7 +152,7 @@ PGOCtxProfileReader::readProfile(PGOCtxProfileBlockIDs Kind) {
     }
   }
 
-  PGOCtxProfContext Ret(*Guid, std::move(*Counters));
+  PGOCtxProfContext Ret(*Guid, std::move(*Counters), TotalEntryCount);
 
   while (canEnterBlockWithID(PGOCtxProfileBlockIDs::ContextNodeBlockID)) {
     EXPECT_OR_RET(SC, readProfile(PGOCtxProfileBlockIDs::ContextNodeBlockID));
@@ -278,7 +289,8 @@ void toYaml(yaml::Output &Out,
 
 void toYaml(yaml::Output &Out, GlobalValue::GUID Guid,
             const SmallVectorImpl<uint64_t> &Counters,
-            const PGOCtxProfContext::CallsiteMapTy &Callsites) {
+            const PGOCtxProfContext::CallsiteMapTy &Callsites,
+            std::optional<uint64_t> TotalRootEntryCount = std::nullopt) {
   yaml::EmptyContext Empty;
   Out.beginMapping();
   void *SaveInfo = nullptr;
@@ -289,6 +301,11 @@ void toYaml(yaml::Output &Out, GlobalValue::GUID Guid,
     yaml::yamlize(Out, Guid, true, Empty);
     Out.postflightKey(nullptr);
   }
+  if (TotalRootEntryCount) {
+    Out.preflightKey("TotalRootEntryCount", true, false, UseDefault, SaveInfo);
+    yaml::yamlize(Out, *TotalRootEntryCount, true, Empty);
+    Out.postflightKey(nullptr);
+  }
   {
     Out.preflightKey("Counters", true, false, UseDefault, SaveInfo);
     Out.beginFlowSequence();
@@ -308,8 +325,13 @@ void toYaml(yaml::Output &Out, GlobalValue::GUID Guid,
   }
   Out.endMapping();
 }
+
 void toYaml(yaml::Output &Out, const PGOCtxProfContext &Ctx) {
-  toYaml(Out, Ctx.guid(), Ctx.counters(), Ctx.callsites());
+  if (Ctx.isRoot())
+    toYaml(Out, Ctx.guid(), Ctx.counters(), Ctx.callsites(),
+           Ctx.getTotalRootEntryCount());
+  else
+    toYaml(Out, Ctx.guid(), Ctx.counters(), Ctx.callsites());
 }
 
 } // namespace
diff --git a/llvm/lib/ProfileData/PGOCtxProfWriter.cpp b/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
index e906836b16b2b..95981d231cd6c 100644
--- a/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
+++ b/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
@@ -55,10 +55,12 @@ PGOCtxProfileWriter::PGOCtxProfileWriter(
     DescribeBlock(PGOCtxProfileBlockIDs::ContextsSectionBlockID, "Contexts");
     DescribeBlock(PGOCtxProfileBlockIDs::ContextRootBlockID, "Root");
     DescribeRecord(PGOCtxProfileRecords::Guid, "GUID");
+    DescribeRecord(PGOCtxProfileRecords::TotalRootEntryCount,
+                   "TotalRootEntryCount");
     DescribeRecord(PGOCtxProfileRecords::Counters, "Counters");
     DescribeBlock(PGOCtxProfileBlockIDs::ContextNodeBlockID, "Context");
     DescribeRecord(PGOCtxProfileRecords::Guid, "GUID");
-    DescribeRecord(PGOCtxProfileRecords::CalleeIndex, "CalleeIndex");
+    DescribeRecord(PGOCtxProfileRecords::CallsiteIndex, "CalleeIndex");
     DescribeRecord(PGOCtxProfileRecords::Counters, "Counters");
     DescribeBlock(PGOCtxProfileBlockIDs::FlatProfilesSectionBlockID,
                   "FlatProfiles");
@@ -85,29 +87,39 @@ void PGOCtxProfileWriter::writeGuid(ctx_profile::GUID Guid) {
   Writer.EmitRecord(PGOCtxProfileRecords::Guid, SmallVector<uint64_t, 1>{Guid});
 }
 
+void PGOCtxProfileWriter::writeCallsiteIndex(uint32_t CallsiteIndex) {
+  Writer.EmitRecord(PGOCtxProfileRecords::CallsiteIndex,
+                    SmallVector<uint64_t, 1>{CallsiteIndex});
+}
+
+void PGOCtxProfileWriter::writeRootEntryCount(uint64_t TotalRootEntryCount) {
+  Writer.EmitRecord(PGOCtxProfileRecords::TotalRootEntryCount,
+                    SmallVector<uint64_t, 1>{TotalRootEntryCount});
+}
+
 // recursively write all the subcontexts. We do need to traverse depth first to
 // model the context->subcontext implicitly, and since this captures call
 // stacks, we don't really need to be worried about stack overflow and we can
 // keep the implementation simple.
-void PGOCtxProfileWriter::writeImpl(std::optional<uint32_t> CallerIndex,
+void PGOCtxProfileWriter::writeNode(uint32_t CallsiteIndex,
                                     const ContextNode &Node) {
   // A node with no counters is an error. We don't expect this to happen from
   // the runtime, rather, this is interesting for testing the reader.
   if (!IncludeEmpty && (Node.counters_size() > 0 && Node.entrycount() == 0))
     return;
-  Writer.EnterSubblock(CallerIndex ? PGOCtxProfileBlockIDs::ContextNodeBlockID
-                                   : PGOCtxProfileBlockIDs::ContextRootBlockID,
-                       CodeLen);
+  Writer.EnterSubblock(PGOCtxProfileBlockIDs::ContextNodeBlockID, CodeLen);
   writeGuid(Node.guid());
-  if (CallerIndex)
-    Writer.EmitRecord(PGOCtxProfileRecords::CalleeIndex,
-                      SmallVector<uint64_t, 1>{*CallerIndex});
+  writeCallsiteIndex(CallsiteIndex);
   writeCounters({Node.counters(), Node.counters_size()});
+  writeSubcontexts(Node);
+  Writer.ExitBlock();
+}
+
+void PGOCtxProfileWriter::writeSubcontexts(const ContextNode &Node) {
   for (uint32_t I = 0U; I < Node.callsites_size(); ++I)
     for (const auto *Subcontext = Node.subContexts()[I]; Subcontext;
          Subcontext = Subcontext->next())
-      writeImpl(I, *Subcontext);
-  Writer.ExitBlock();
+      writeNode(I, *Subcontext);
 }
 
 void PGOCtxProfileWriter::startContextSection() {
@@ -122,8 +134,16 @@ void PGOCtxProfileWriter::startFlatSection() {
 void PGOCtxProfileWriter::endContextSection() { Writer.ExitBlock(); }
 void PGOCtxProfileWriter::endFlatSection() { Writer.ExitBlock(); }
 
-void PGOCtxProfileWriter::writeContextual(const ContextNode &RootNode) {
-  writeImpl(std::nullopt, RootNode);
+void PGOCtxProfileWriter::writeContextual(const ContextNode &RootNode,
+                                          uint64_t TotalRootEntryCount) {
+  if (!IncludeEmpty && !TotalRootEntryCount)
+    return;
+  Writer.EnterSubblock(PGOCtxProfileBlockIDs::ContextRootBlockID, CodeLen);
+  writeGuid(RootNode.guid());
+  writeRootEntryCount(TotalRootEntryCount);
+  writeCounters({RootNode.counters(), RootNode.counters_size()});
+  writeSubcontexts(RootNode);
+  Writer.ExitBlock();
 }
 
 void PGOCtxProfileWriter::writeFlat(ctx_profile::GUID Guid,
@@ -144,11 +164,15 @@ struct SerializableCtxRepresentation {
   std::vector<std::vector<SerializableCtxRepresentation>> Callsites;
 };
 
+struct SerializableRootRepresentation : public SerializableCtxRepresentation {
+  uint64_t TotalRootEntryCount = 0;
+};
+
 using SerializableFlatProfileRepresentation =
     std::pair<ctx_profile::GUID, std::vector<uint64_t>>;
 
 struct SerializableProfileRepresentation {
-  std::vector<SerializableCtxRepresentation> Contexts;
+  std::vector<SerializableRootRepresentation> Contexts;
   std::vector<SerializableFlatProfileRepresentation> FlatProfiles;
 };
 
@@ -189,6 +213,7 @@ createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
 
 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializableCtxRepresentation)
 LLVM_YAML_IS_SEQUENCE_VECTOR(std::vector<SerializableCtxRepresentation>)
+LLVM_YAML_IS_SEQUENCE_VECTOR(SerializableRootRepresentation)
 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializableFlatProfileRepresentation)
 template <> struct yaml::MappingTraits<SerializableCtxRepresentation> {
   static void mapping(yaml::IO &IO, SerializableCtxRepresentation &SCR) {
@@ -198,6 +223,13 @@ template <> struct yaml::MappingTraits<SerializableCtxRepresentation> {
   }
 };
 
+template <> struct yaml::MappingTraits<SerializableRootRepresentation> {
+  static void mapping(yaml::IO &IO, SerializableRootRepresentation &R) {
+    yaml::MappingTraits<SerializableCtxRepresentation>::mapping(IO, R);
+    IO.mapRequired("TotalRootEntryCount", R.TotalRootEntryCount);
+  }
+};
+
 template <> struct yaml::MappingTraits<SerializableProfileRepresentation> {
   static void mapping(yaml::IO &IO, SerializableProfileRepresentation &SPR) {
     IO.mapOptional("Contexts", SPR.Contexts);
@@ -232,7 +264,7 @@ Error llvm::createCtxProfFromYAML(StringRef Profile, raw_ostream &Out) {
       if (!TopList)
         return createStringError(
             "Unexpected error converting internal structure to ctx profile");
-      Writer.writeContextual(*TopList);
+      Writer.writeContextual(*TopList, DC.TotalRootEntryCount);
     }
     Writer.endContextSection();
   }
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
index 20eaf59576855..c7b325bdbfff9 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
@@ -61,6 +61,7 @@
 ;--- profile.yaml
 Contexts:
   - Guid: 4909520559318251808
+    TotalRootEntryCount: 100
     Counters: [100, 40]
     Callsites: -
                 - Guid: 11872291593386833696
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll
index eb697b69e2c02..b10eb6a6ec1b1 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-check-path.ll
@@ -41,6 +41,7 @@ exit:
 ;--- profile_ok.yaml...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/131201


More information about the llvm-commits mailing list