[llvm] [pgo][nfc] Model `Count` as a `std::optional` in `PGOUseBBInfo` (PR #83364)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 28 17:10:21 PST 2024
https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/83364
Simpler code than tracking the state of two separate variables, and it removes the ambiguity of a "0" CountValue (is the count really 0, or is it invalid?)
From 4563c9d36fedcec8848f5ca3dd3871edcc3af8d8 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 28 Feb 2024 17:06:31 -0800
Subject: [PATCH] [pgo][nfc] Model `Count` as a `std::optional` in
`PGOUseBBInfo`
Simpler code than tracking the state of two separate variables, and it
removes the ambiguity of a "0" CountValue (is the count really 0, or is
it invalid?)
---
.../Instrumentation/PGOInstrumentation.cpp | 86 +++++++++----------
1 file changed, 39 insertions(+), 47 deletions(-)
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c20fc942eaf0d5..0c042e73ba0836 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -983,27 +983,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>;
// This class stores the auxiliary information for each BB.
struct PGOUseBBInfo : public PGOBBInfo {
- uint64_t CountValue = 0;
- bool CountValid;
+ std::optional<uint64_t> Count;
int32_t UnknownCountInEdge = 0;
int32_t UnknownCountOutEdge = 0;
DirectEdges InEdges;
DirectEdges OutEdges;
- PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
+ PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {}
// Set the profile count value for this BB.
- void setBBInfoCount(uint64_t Value) {
- CountValue = Value;
- CountValid = true;
- }
+ void setBBInfoCount(uint64_t Value) { Count = Value; }
// Return the information string of this object.
std::string infoString() const {
- if (!CountValid)
+ if (!Count)
return PGOBBInfo::infoString();
- return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue))
- .str();
+ return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(*Count)).str();
}
// Add an OutEdge and update the edge count.
@@ -1216,15 +1211,15 @@ bool PGOUseFunc::setInstrumentedCounts(
// If only one out-edge, the edge profile count should be the same as BB
// profile count.
- if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
- setEdgeCount(E.get(), SrcInfo.CountValue);
+ if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
+ setEdgeCount(E.get(), *SrcInfo.Count);
else {
const BasicBlock *DestBB = E->DestBB;
PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
// If only one in-edge, the edge profile count should be the same as BB
// profile count.
- if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
- setEdgeCount(E.get(), DestInfo.CountValue);
+ if (DestInfo.Count && DestInfo.InEdges.size() == 1)
+ setEdgeCount(E.get(), *DestInfo.Count);
}
if (E->CountValid)
continue;
@@ -1481,38 +1476,36 @@ void PGOUseFunc::populateCounters() {
// For efficient traversal, it's better to start from the end as most
// of the instrumented edges are at the end.
for (auto &BB : reverse(F)) {
- PGOUseBBInfo *Count = findBBInfo(&BB);
- if (Count == nullptr)
+ PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
+ if (UseBBInfo == nullptr)
continue;
- if (!Count->CountValid) {
- if (Count->UnknownCountOutEdge == 0) {
- Count->CountValue = sumEdgeCount(Count->OutEdges);
- Count->CountValid = true;
+ if (!UseBBInfo->Count) {
+ if (UseBBInfo->UnknownCountOutEdge == 0) {
+ UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
Changes = true;
- } else if (Count->UnknownCountInEdge == 0) {
- Count->CountValue = sumEdgeCount(Count->InEdges);
- Count->CountValid = true;
+ } else if (UseBBInfo->UnknownCountInEdge == 0) {
+ UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
Changes = true;
}
}
- if (Count->CountValid) {
- if (Count->UnknownCountOutEdge == 1) {
+ if (UseBBInfo->Count) {
+ if (UseBBInfo->UnknownCountOutEdge == 1) {
uint64_t Total = 0;
- uint64_t OutSum = sumEdgeCount(Count->OutEdges);
+ uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
// If the one of the successor block can early terminate (no-return),
// we can end up with situation where out edge sum count is larger as
// the source BB's count is collected by a post-dominated block.
- if (Count->CountValue > OutSum)
- Total = Count->CountValue - OutSum;
- setEdgeCount(Count->OutEdges, Total);
+ if (*UseBBInfo->Count > OutSum)
+ Total = *UseBBInfo->Count - OutSum;
+ setEdgeCount(UseBBInfo->OutEdges, Total);
Changes = true;
}
- if (Count->UnknownCountInEdge == 1) {
+ if (UseBBInfo->UnknownCountInEdge == 1) {
uint64_t Total = 0;
- uint64_t InSum = sumEdgeCount(Count->InEdges);
- if (Count->CountValue > InSum)
- Total = Count->CountValue - InSum;
- setEdgeCount(Count->InEdges, Total);
+ uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges);
+ if (*UseBBInfo->Count > InSum)
+ Total = *UseBBInfo->Count - InSum;
+ setEdgeCount(UseBBInfo->InEdges, Total);
Changes = true;
}
}
@@ -1527,16 +1520,16 @@ void PGOUseFunc::populateCounters() {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
- assert(BI->CountValid && "BB count is not valid");
+ assert(BI->Count && "BB count is not valid");
}
#endif
- uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
+ uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count;
uint64_t FuncMaxCount = FuncEntryCount;
for (auto &BB : F) {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
- FuncMaxCount = std::max(FuncMaxCount, BI->CountValue);
+ FuncMaxCount = std::max(FuncMaxCount, *BI->Count);
}
// Fix the obviously inconsistent entry count.
@@ -1566,11 +1559,11 @@ void PGOUseFunc::setBranchWeights() {
isa<CallBrInst>(TI)))
continue;
- if (getBBInfo(&BB).CountValue == 0)
+ const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
+ if (!*BBCountInfo.Count)
continue;
// We have a non-zero Branch BB.
- const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
unsigned Size = BBCountInfo.OutEdges.size();
SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
uint64_t MaxCount = 0;
@@ -1622,7 +1615,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
Instruction *TI = BB.getTerminator();
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
- setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
+ setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count);
}
}
}
@@ -1649,7 +1642,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
uint64_t TotalCount = 0;
auto BI = UseFunc->findBBInfo(SI.getParent());
if (BI != nullptr)
- TotalCount = BI->CountValue;
+ TotalCount = *BI->Count;
// False Count
SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
@@ -1850,7 +1843,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (!Func.findBBInfo(&BBI))
continue;
auto BFICount = NBFI.getBlockProfileCount(&BBI);
- CountValue = Func.getBBInfo(&BBI).CountValue;
+ CountValue = *Func.getBBInfo(&BBI).Count;
BFICountValue = *BFICount;
SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
@@ -1866,7 +1859,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (Scale < 1.001 && Scale > 0.999)
return;
- uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
+ uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count;
uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
if (NewEntryCount == 0)
NewEntryCount = 1;
@@ -1896,8 +1889,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
uint64_t CountValue = 0;
uint64_t BFICountValue = 0;
- if (Func.getBBInfo(&BBI).CountValid)
- CountValue = Func.getBBInfo(&BBI).CountValue;
+ CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue);
BBNum++;
if (CountValue)
@@ -2279,8 +2271,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
OS << getSimpleNodeName(Node) << ":\\l";
PGOUseBBInfo *BI = Graph->findBBInfo(Node);
OS << "Count : ";
- if (BI && BI->CountValid)
- OS << BI->CountValue << "\\l";
+ if (BI && BI->Count)
+ OS << *BI->Count << "\\l";
else
OS << "Unknown\\l";
More information about the llvm-commits mailing list