[llvm] [nfc] Fix RTTI for `InstrProf` intrinsics (PR #83511)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 29 23:34:30 PST 2024
https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/83511
>From c746a8f38e342695864cee5bdfebe9eed5542337 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 29 Feb 2024 15:12:07 -0800
Subject: [PATCH] [pgo][nfc] Model `Count` as a `std::optional` in
`PGOUseBBInfo` (#83364)
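This simplifies the code: instead of tracking two variables (`CountValue` and `CountValid`), where a `CountValue` of 0 was ambiguous (is it a real count of 0, or is it invalid?), a single `std::optional` `Count` carries both the value and its validity.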
---
llvm/include/llvm/IR/IntrinsicInst.h | 37 +++++++-
.../Instrumentation/PGOInstrumentation.cpp | 86 +++++++++----------
2 files changed, 74 insertions(+), 49 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index fbaaef8ea44315..c07b83a81a63e1 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1429,7 +1429,35 @@ class VACopyInst : public IntrinsicInst {
/// A base class for all instrprof intrinsics.
class InstrProfInstBase : public IntrinsicInst {
+protected:
+ static bool isCounterBase(const IntrinsicInst &I) {
+ switch (I.getIntrinsicID()) {
+ case Intrinsic::instrprof_cover:
+ case Intrinsic::instrprof_increment:
+ case Intrinsic::instrprof_increment_step:
+ case Intrinsic::instrprof_timestamp:
+ case Intrinsic::instrprof_value_profile:
+ return true;
+ }
+ return false;
+ }
+ static bool isMCDCBitmapBase(const IntrinsicInst &I) {
+ switch (I.getIntrinsicID()) {
+ case Intrinsic::instrprof_mcdc_parameters:
+ case Intrinsic::instrprof_mcdc_tvbitmap_update:
+ return true;
+ }
+ return false;
+ }
+
public:
+ static bool classof(const Value *V) {
+ if (const auto *Instr = dyn_cast<IntrinsicInst>(V))
+ return isCounterBase(*Instr) || isMCDCBitmapBase(*Instr) ||
+ Instr->getIntrinsicID() ==
+ Intrinsic::instrprof_mcdc_condbitmap_update;
+ return false;
+ }
// The name of the instrumented function.
GlobalVariable *getName() const {
return cast<GlobalVariable>(
@@ -1444,6 +1472,12 @@ class InstrProfInstBase : public IntrinsicInst {
/// A base class for all instrprof counter intrinsics.
class InstrProfCntrInstBase : public InstrProfInstBase {
public:
+ static bool classof(const Value *V) {
+ if (const auto *Instr = dyn_cast<IntrinsicInst>(V))
+ return InstrProfInstBase::isCounterBase(*Instr);
+ return false;
+ }
+
// The number of counters for the instrumented function.
ConstantInt *getNumCounters() const;
// The index of the counter that this instruction acts on.
@@ -1524,8 +1558,7 @@ class InstrProfValueProfileInst : public InstrProfCntrInstBase {
class InstrProfMCDCBitmapInstBase : public InstrProfInstBase {
public:
static bool classof(const IntrinsicInst *I) {
- return I->getIntrinsicID() == Intrinsic::instrprof_mcdc_parameters ||
- I->getIntrinsicID() == Intrinsic::instrprof_mcdc_tvbitmap_update;
+ return InstrProfInstBase::isMCDCBitmapBase(*I);
}
static bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c20fc942eaf0d5..0c042e73ba0836 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -983,27 +983,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>;
// This class stores the auxiliary information for each BB.
struct PGOUseBBInfo : public PGOBBInfo {
- uint64_t CountValue = 0;
- bool CountValid;
+ std::optional<uint64_t> Count;
int32_t UnknownCountInEdge = 0;
int32_t UnknownCountOutEdge = 0;
DirectEdges InEdges;
DirectEdges OutEdges;
- PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
+ PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {}
// Set the profile count value for this BB.
- void setBBInfoCount(uint64_t Value) {
- CountValue = Value;
- CountValid = true;
- }
+ void setBBInfoCount(uint64_t Value) { Count = Value; }
// Return the information string of this object.
std::string infoString() const {
- if (!CountValid)
+ if (!Count)
return PGOBBInfo::infoString();
- return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue))
- .str();
+ return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(*Count)).str();
}
// Add an OutEdge and update the edge count.
@@ -1216,15 +1211,15 @@ bool PGOUseFunc::setInstrumentedCounts(
// If only one out-edge, the edge profile count should be the same as BB
// profile count.
- if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
- setEdgeCount(E.get(), SrcInfo.CountValue);
+ if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
+ setEdgeCount(E.get(), *SrcInfo.Count);
else {
const BasicBlock *DestBB = E->DestBB;
PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
// If only one in-edge, the edge profile count should be the same as BB
// profile count.
- if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
- setEdgeCount(E.get(), DestInfo.CountValue);
+ if (DestInfo.Count && DestInfo.InEdges.size() == 1)
+ setEdgeCount(E.get(), *DestInfo.Count);
}
if (E->CountValid)
continue;
@@ -1481,38 +1476,36 @@ void PGOUseFunc::populateCounters() {
// For efficient traversal, it's better to start from the end as most
// of the instrumented edges are at the end.
for (auto &BB : reverse(F)) {
- PGOUseBBInfo *Count = findBBInfo(&BB);
- if (Count == nullptr)
+ PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
+ if (UseBBInfo == nullptr)
continue;
- if (!Count->CountValid) {
- if (Count->UnknownCountOutEdge == 0) {
- Count->CountValue = sumEdgeCount(Count->OutEdges);
- Count->CountValid = true;
+ if (!UseBBInfo->Count) {
+ if (UseBBInfo->UnknownCountOutEdge == 0) {
+ UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
Changes = true;
- } else if (Count->UnknownCountInEdge == 0) {
- Count->CountValue = sumEdgeCount(Count->InEdges);
- Count->CountValid = true;
+ } else if (UseBBInfo->UnknownCountInEdge == 0) {
+ UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
Changes = true;
}
}
- if (Count->CountValid) {
- if (Count->UnknownCountOutEdge == 1) {
+ if (UseBBInfo->Count) {
+ if (UseBBInfo->UnknownCountOutEdge == 1) {
uint64_t Total = 0;
- uint64_t OutSum = sumEdgeCount(Count->OutEdges);
+ uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
// If the one of the successor block can early terminate (no-return),
// we can end up with situation where out edge sum count is larger as
// the source BB's count is collected by a post-dominated block.
- if (Count->CountValue > OutSum)
- Total = Count->CountValue - OutSum;
- setEdgeCount(Count->OutEdges, Total);
+ if (*UseBBInfo->Count > OutSum)
+ Total = *UseBBInfo->Count - OutSum;
+ setEdgeCount(UseBBInfo->OutEdges, Total);
Changes = true;
}
- if (Count->UnknownCountInEdge == 1) {
+ if (UseBBInfo->UnknownCountInEdge == 1) {
uint64_t Total = 0;
- uint64_t InSum = sumEdgeCount(Count->InEdges);
- if (Count->CountValue > InSum)
- Total = Count->CountValue - InSum;
- setEdgeCount(Count->InEdges, Total);
+ uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges);
+ if (*UseBBInfo->Count > InSum)
+ Total = *UseBBInfo->Count - InSum;
+ setEdgeCount(UseBBInfo->InEdges, Total);
Changes = true;
}
}
@@ -1527,16 +1520,16 @@ void PGOUseFunc::populateCounters() {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
- assert(BI->CountValid && "BB count is not valid");
+ assert(BI->Count && "BB count is not valid");
}
#endif
- uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
+ uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count;
uint64_t FuncMaxCount = FuncEntryCount;
for (auto &BB : F) {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
- FuncMaxCount = std::max(FuncMaxCount, BI->CountValue);
+ FuncMaxCount = std::max(FuncMaxCount, *BI->Count);
}
// Fix the obviously inconsistent entry count.
@@ -1566,11 +1559,11 @@ void PGOUseFunc::setBranchWeights() {
isa<CallBrInst>(TI)))
continue;
- if (getBBInfo(&BB).CountValue == 0)
+ const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
+ if (!*BBCountInfo.Count)
continue;
// We have a non-zero Branch BB.
- const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
unsigned Size = BBCountInfo.OutEdges.size();
SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
uint64_t MaxCount = 0;
@@ -1622,7 +1615,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
Instruction *TI = BB.getTerminator();
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
- setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
+ setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count);
}
}
}
@@ -1649,7 +1642,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
uint64_t TotalCount = 0;
auto BI = UseFunc->findBBInfo(SI.getParent());
if (BI != nullptr)
- TotalCount = BI->CountValue;
+ TotalCount = *BI->Count;
// False Count
SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
@@ -1850,7 +1843,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (!Func.findBBInfo(&BBI))
continue;
auto BFICount = NBFI.getBlockProfileCount(&BBI);
- CountValue = Func.getBBInfo(&BBI).CountValue;
+ CountValue = *Func.getBBInfo(&BBI).Count;
BFICountValue = *BFICount;
SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
@@ -1866,7 +1859,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (Scale < 1.001 && Scale > 0.999)
return;
- uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
+ uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count;
uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
if (NewEntryCount == 0)
NewEntryCount = 1;
@@ -1896,8 +1889,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
uint64_t CountValue = 0;
uint64_t BFICountValue = 0;
- if (Func.getBBInfo(&BBI).CountValid)
- CountValue = Func.getBBInfo(&BBI).CountValue;
+ CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue);
BBNum++;
if (CountValue)
@@ -2279,8 +2271,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
OS << getSimpleNodeName(Node) << ":\\l";
PGOUseBBInfo *BI = Graph->findBBInfo(Node);
OS << "Count : ";
- if (BI && BI->CountValid)
- OS << BI->CountValue << "\\l";
+ if (BI && BI->Count)
+ OS << *BI->Count << "\\l";
else
OS << "Unknown\\l";
More information about the llvm-commits
mailing list