[llvm] [pgo][nfc] Model `Count` as a `std::optional` in `PGOUseBBInfo` (PR #83364)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 28 17:10:52 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-pgo

Author: Mircea Trofin (mtrofin)

<details>
<summary>Changes</summary>

This makes the code simpler than tracking the state of two variables, and it removes the ambiguity of a "0" `CountValue` (is it a genuine count of 0, or is it invalid?).

---
Full diff: https://github.com/llvm/llvm-project/pull/83364.diff


1 File Affected:

- (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+39-47) 


``````````diff
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c20fc942eaf0d5..0c042e73ba0836 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -983,27 +983,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>;
 
 // This class stores the auxiliary information for each BB.
 struct PGOUseBBInfo : public PGOBBInfo {
-  uint64_t CountValue = 0;
-  bool CountValid;
+  std::optional<uint64_t> Count;
   int32_t UnknownCountInEdge = 0;
   int32_t UnknownCountOutEdge = 0;
   DirectEdges InEdges;
   DirectEdges OutEdges;
 
-  PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
+  PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {}
 
   // Set the profile count value for this BB.
-  void setBBInfoCount(uint64_t Value) {
-    CountValue = Value;
-    CountValid = true;
-  }
+  void setBBInfoCount(uint64_t Value) { Count = Value; }
 
   // Return the information string of this object.
   std::string infoString() const {
-    if (!CountValid)
+    if (!Count)
       return PGOBBInfo::infoString();
-    return (Twine(PGOBBInfo::infoString()) + "  Count=" + Twine(CountValue))
-        .str();
+    return (Twine(PGOBBInfo::infoString()) + "  Count=" + Twine(*Count)).str();
   }
 
   // Add an OutEdge and update the edge count.
@@ -1216,15 +1211,15 @@ bool PGOUseFunc::setInstrumentedCounts(
 
     // If only one out-edge, the edge profile count should be the same as BB
     // profile count.
-    if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
-      setEdgeCount(E.get(), SrcInfo.CountValue);
+    if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
+      setEdgeCount(E.get(), *SrcInfo.Count);
     else {
       const BasicBlock *DestBB = E->DestBB;
       PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
       // If only one in-edge, the edge profile count should be the same as BB
       // profile count.
-      if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
-        setEdgeCount(E.get(), DestInfo.CountValue);
+      if (DestInfo.Count && DestInfo.InEdges.size() == 1)
+        setEdgeCount(E.get(), *DestInfo.Count);
     }
     if (E->CountValid)
       continue;
@@ -1481,38 +1476,36 @@ void PGOUseFunc::populateCounters() {
     // For efficient traversal, it's better to start from the end as most
     // of the instrumented edges are at the end.
     for (auto &BB : reverse(F)) {
-      PGOUseBBInfo *Count = findBBInfo(&BB);
-      if (Count == nullptr)
+      PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
+      if (UseBBInfo == nullptr)
         continue;
-      if (!Count->CountValid) {
-        if (Count->UnknownCountOutEdge == 0) {
-          Count->CountValue = sumEdgeCount(Count->OutEdges);
-          Count->CountValid = true;
+      if (!UseBBInfo->Count) {
+        if (UseBBInfo->UnknownCountOutEdge == 0) {
+          UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
           Changes = true;
-        } else if (Count->UnknownCountInEdge == 0) {
-          Count->CountValue = sumEdgeCount(Count->InEdges);
-          Count->CountValid = true;
+        } else if (UseBBInfo->UnknownCountInEdge == 0) {
+          UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
           Changes = true;
         }
       }
-      if (Count->CountValid) {
-        if (Count->UnknownCountOutEdge == 1) {
+      if (UseBBInfo->Count) {
+        if (UseBBInfo->UnknownCountOutEdge == 1) {
           uint64_t Total = 0;
-          uint64_t OutSum = sumEdgeCount(Count->OutEdges);
+          uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
           // If the one of the successor block can early terminate (no-return),
           // we can end up with situation where out edge sum count is larger as
           // the source BB's count is collected by a post-dominated block.
-          if (Count->CountValue > OutSum)
-            Total = Count->CountValue - OutSum;
-          setEdgeCount(Count->OutEdges, Total);
+          if (*UseBBInfo->Count > OutSum)
+            Total = *UseBBInfo->Count - OutSum;
+          setEdgeCount(UseBBInfo->OutEdges, Total);
           Changes = true;
         }
-        if (Count->UnknownCountInEdge == 1) {
+        if (UseBBInfo->UnknownCountInEdge == 1) {
           uint64_t Total = 0;
-          uint64_t InSum = sumEdgeCount(Count->InEdges);
-          if (Count->CountValue > InSum)
-            Total = Count->CountValue - InSum;
-          setEdgeCount(Count->InEdges, Total);
+          uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges);
+          if (*UseBBInfo->Count > InSum)
+            Total = *UseBBInfo->Count - InSum;
+          setEdgeCount(UseBBInfo->InEdges, Total);
           Changes = true;
         }
       }
@@ -1527,16 +1520,16 @@ void PGOUseFunc::populateCounters() {
     auto BI = findBBInfo(&BB);
     if (BI == nullptr)
       continue;
-    assert(BI->CountValid && "BB count is not valid");
+    assert(BI->Count && "BB count is not valid");
   }
 #endif
-  uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
+  uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count;
   uint64_t FuncMaxCount = FuncEntryCount;
   for (auto &BB : F) {
     auto BI = findBBInfo(&BB);
     if (BI == nullptr)
       continue;
-    FuncMaxCount = std::max(FuncMaxCount, BI->CountValue);
+    FuncMaxCount = std::max(FuncMaxCount, *BI->Count);
   }
 
   // Fix the obviously inconsistent entry count.
@@ -1566,11 +1559,11 @@ void PGOUseFunc::setBranchWeights() {
           isa<CallBrInst>(TI)))
       continue;
 
-    if (getBBInfo(&BB).CountValue == 0)
+    const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
+    if (!*BBCountInfo.Count)
       continue;
 
     // We have a non-zero Branch BB.
-    const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
     unsigned Size = BBCountInfo.OutEdges.size();
     SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
     uint64_t MaxCount = 0;
@@ -1622,7 +1615,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
     if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
       Instruction *TI = BB.getTerminator();
       const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
-      setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
+      setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count);
     }
   }
 }
@@ -1649,7 +1642,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
   uint64_t TotalCount = 0;
   auto BI = UseFunc->findBBInfo(SI.getParent());
   if (BI != nullptr)
-    TotalCount = BI->CountValue;
+    TotalCount = *BI->Count;
   // False Count
   SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
   uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
@@ -1850,7 +1843,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
     if (!Func.findBBInfo(&BBI))
       continue;
     auto BFICount = NBFI.getBlockProfileCount(&BBI);
-    CountValue = Func.getBBInfo(&BBI).CountValue;
+    CountValue = *Func.getBBInfo(&BBI).Count;
     BFICountValue = *BFICount;
     SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
     SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
@@ -1866,7 +1859,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
   if (Scale < 1.001 && Scale > 0.999)
     return;
 
-  uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
+  uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count;
   uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
   if (NewEntryCount == 0)
     NewEntryCount = 1;
@@ -1896,8 +1889,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
     uint64_t CountValue = 0;
     uint64_t BFICountValue = 0;
 
-    if (Func.getBBInfo(&BBI).CountValid)
-      CountValue = Func.getBBInfo(&BBI).CountValue;
+    CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue);
 
     BBNum++;
     if (CountValue)
@@ -2279,8 +2271,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
     OS << getSimpleNodeName(Node) << ":\\l";
     PGOUseBBInfo *BI = Graph->findBBInfo(Node);
     OS << "Count : ";
-    if (BI && BI->CountValid)
-      OS << BI->CountValue << "\\l";
+    if (BI && BI->Count)
+      OS << *BI->Count << "\\l";
     else
       OS << "Unknown\\l";
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/83364


More information about the llvm-commits mailing list