[llvm] a32c2c3 - [NFC] Use Optional<ProfileCount> to model invalid counts
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 14 19:17:44 PST 2021
Author: Mircea Trofin
Date: 2021-11-14T19:03:30-08:00
New Revision: a32c2c380863d02eb0fd5e8757a62d96114b9519
URL: https://github.com/llvm/llvm-project/commit/a32c2c380863d02eb0fd5e8757a62d96114b9519
DIFF: https://github.com/llvm/llvm-project/commit/a32c2c380863d02eb0fd5e8757a62d96114b9519.diff
LOG: [NFC] Use Optional<ProfileCount> to model invalid counts
ProfileCount could model invalid values, but a user had no indication
that the getCount method could return bogus data. Optional<ProfileCount>
addresses that, because the user must dereference the optional. In
addition, the patch removes the duplicated notion of validity
(PCT_Invalid, hasValue(), getInvalid()) from ProfileCount itself.
Differential Revision: https://reviews.llvm.org/D113839
Added:
Modified:
llvm/include/llvm/IR/Function.h
llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
llvm/lib/Analysis/InlineCost.cpp
llvm/lib/Analysis/ProfileSummaryInfo.cpp
llvm/lib/CodeGen/MachineSizeOpts.cpp
llvm/lib/IR/Function.cpp
llvm/lib/Transforms/IPO/PartialInlining.cpp
llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
llvm/lib/Transforms/Utils/InlineFunction.cpp
llvm/unittests/IR/MetadataTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index e5c675a64af00..669418eacbb0a 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -247,33 +247,22 @@ class LLVM_EXTERNAL_VISIBILITY Function : public GlobalObject,
setValueSubclassData((getSubclassDataFromValue() & 0xc00f) | (ID << 4));
}
- enum ProfileCountType { PCT_Invalid, PCT_Real, PCT_Synthetic };
+ enum ProfileCountType { PCT_Real, PCT_Synthetic };
/// Class to represent profile counts.
///
/// This class represents both real and synthetic profile counts.
class ProfileCount {
private:
- uint64_t Count;
- ProfileCountType PCT;
- static ProfileCount Invalid;
+ uint64_t Count = 0;
+ ProfileCountType PCT = PCT_Real;
public:
- ProfileCount() : Count(-1), PCT(PCT_Invalid) {}
ProfileCount(uint64_t Count, ProfileCountType PCT)
: Count(Count), PCT(PCT) {}
- bool hasValue() const { return PCT != PCT_Invalid; }
uint64_t getCount() const { return Count; }
ProfileCountType getType() const { return PCT; }
bool isSynthetic() const { return PCT == PCT_Synthetic; }
- explicit operator bool() { return hasValue(); }
- bool operator!() const { return !hasValue(); }
- // Update the count retaining the same profile count type.
- ProfileCount &setCount(uint64_t C) {
- Count = C;
- return *this;
- }
- static ProfileCount getInvalid() { return ProfileCount(-1, PCT_Invalid); }
};
/// Set the entry count for this function.
@@ -293,7 +282,7 @@ class LLVM_EXTERNAL_VISIBILITY Function : public GlobalObject,
///
/// Entry count is the number of times the function was executed.
/// When AllowSynthetic is false, only pgo_data will be returned.
- ProfileCount getEntryCount(bool AllowSynthetic = false) const;
+ Optional<ProfileCount> getEntryCount(bool AllowSynthetic = false) const;
/// Return true if the function is annotated with profile data.
///
diff --git a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
index e4e45b3076beb..2a5e1f65d7316 100644
--- a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
+++ b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -602,7 +602,7 @@ BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
if (!EntryCount)
return None;
// Use 128 bit APInt to do the arithmetic to avoid overflow.
- APInt BlockCount(128, EntryCount.getCount());
+ APInt BlockCount(128, EntryCount->getCount());
APInt BlockFreq(128, Freq);
APInt EntryFreq(128, getEntryFreq());
BlockCount *= BlockFreq;
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index a77dd32f1816e..6d97898c08e3a 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -771,7 +771,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
// Make sure we have a nonzero entry count.
auto EntryCount = F.getEntryCount();
- if (!EntryCount || !EntryCount.getCount())
+ if (!EntryCount || !EntryCount->getCount())
return false;
BlockFrequencyInfo *CalleeBFI = &(GetBFI(F));
@@ -837,8 +837,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
// Compute the cycle savings per call.
auto EntryProfileCount = F.getEntryCount();
- assert(EntryProfileCount.hasValue() && EntryProfileCount.getCount());
- auto EntryCount = EntryProfileCount.getCount();
+ assert(EntryProfileCount.hasValue() && EntryProfileCount->getCount());
+ auto EntryCount = EntryProfileCount->getCount();
CycleSavings += EntryCount / 2;
CycleSavings = CycleSavings.udiv(EntryCount);
diff --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
index 69ca5502e5753..268ed9d047419 100644
--- a/llvm/lib/Analysis/ProfileSummaryInfo.cpp
+++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
@@ -103,7 +103,7 @@ bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) const {
// FIXME: The heuristic used below for determining hotness is based on
// preliminary SPEC tuning for inliner. This will eventually be a
// convenience method that calls isHotCount.
- return FunctionCount && isHotCount(FunctionCount.getCount());
+ return FunctionCount && isHotCount(FunctionCount->getCount());
}
/// Returns true if the function contains hot code. This can include a hot
@@ -116,7 +116,7 @@ bool ProfileSummaryInfo::isFunctionHotInCallGraph(
if (!F || !hasProfileSummary())
return false;
if (auto FunctionCount = F->getEntryCount())
- if (isHotCount(FunctionCount.getCount()))
+ if (isHotCount(FunctionCount->getCount()))
return true;
if (hasSampleProfile()) {
@@ -145,7 +145,7 @@ bool ProfileSummaryInfo::isFunctionColdInCallGraph(
if (!F || !hasProfileSummary())
return false;
if (auto FunctionCount = F->getEntryCount())
- if (!isColdCount(FunctionCount.getCount()))
+ if (!isColdCount(FunctionCount->getCount()))
return false;
if (hasSampleProfile()) {
@@ -176,10 +176,10 @@ bool ProfileSummaryInfo::isFunctionHotOrColdInCallGraphNthPercentile(
return false;
if (auto FunctionCount = F->getEntryCount()) {
if (isHot &&
- isHotCountNthPercentile(PercentileCutoff, FunctionCount.getCount()))
+ isHotCountNthPercentile(PercentileCutoff, FunctionCount->getCount()))
return true;
if (!isHot &&
- !isColdCountNthPercentile(PercentileCutoff, FunctionCount.getCount()))
+ !isColdCountNthPercentile(PercentileCutoff, FunctionCount->getCount()))
return false;
}
if (hasSampleProfile()) {
@@ -230,7 +230,7 @@ bool ProfileSummaryInfo::isFunctionEntryCold(const Function *F) const {
// FIXME: The heuristic used below for determining coldness is based on
// preliminary SPEC tuning for inliner. This will eventually be a
// convenience method that calls isHotCount.
- return FunctionCount && isColdCount(FunctionCount.getCount());
+ return FunctionCount && isColdCount(FunctionCount->getCount());
}
/// Compute the hot and cold thresholds.
diff --git a/llvm/lib/CodeGen/MachineSizeOpts.cpp b/llvm/lib/CodeGen/MachineSizeOpts.cpp
index 584d43b420044..28712d1a816bd 100644
--- a/llvm/lib/CodeGen/MachineSizeOpts.cpp
+++ b/llvm/lib/CodeGen/MachineSizeOpts.cpp
@@ -82,7 +82,7 @@ bool isFunctionColdInCallGraph(
ProfileSummaryInfo *PSI,
const MachineBlockFrequencyInfo &MBFI) {
if (auto FunctionCount = MF->getFunction().getEntryCount())
- if (!PSI->isColdCount(FunctionCount.getCount()))
+ if (!PSI->isColdCount(FunctionCount->getCount()))
return false;
for (const auto &MBB : *MF)
if (!isColdBlock(&MBB, PSI, &MBFI))
@@ -99,7 +99,7 @@ bool isFunctionHotInCallGraphNthPercentile(
const MachineBlockFrequencyInfo &MBFI) {
if (auto FunctionCount = MF->getFunction().getEntryCount())
if (PSI->isHotCountNthPercentile(PercentileCutoff,
- FunctionCount.getCount()))
+ FunctionCount->getCount()))
return true;
for (const auto &MBB : *MF)
if (isHotBlockNthPercentile(PercentileCutoff, &MBB, PSI, &MBFI))
@@ -112,7 +112,7 @@ bool isFunctionColdInCallGraphNthPercentile(
const MachineBlockFrequencyInfo &MBFI) {
if (auto FunctionCount = MF->getFunction().getEntryCount())
if (!PSI->isColdCountNthPercentile(PercentileCutoff,
- FunctionCount.getCount()))
+ FunctionCount->getCount()))
return false;
for (const auto &MBB : *MF)
if (!isColdBlockNthPercentile(PercentileCutoff, &MBB, PSI, &MBFI))
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 7eddffab13b9d..82b20a8af91bf 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1892,10 +1892,9 @@ void Function::setValueSubclassDataBit(unsigned Bit, bool On) {
void Function::setEntryCount(ProfileCount Count,
const DenseSet<GlobalValue::GUID> *S) {
- assert(Count.hasValue());
#if !defined(NDEBUG)
auto PrevCount = getEntryCount();
- assert(!PrevCount.hasValue() || PrevCount.getType() == Count.getType());
+ assert(!PrevCount.hasValue() || PrevCount->getType() == Count.getType());
#endif
auto ImportGUIDs = getImportGUIDs();
@@ -1913,7 +1912,7 @@ void Function::setEntryCount(uint64_t Count, Function::ProfileCountType Type,
setEntryCount(ProfileCount(Count, Type), Imports);
}
-ProfileCount Function::getEntryCount(bool AllowSynthetic) const {
+Optional<ProfileCount> Function::getEntryCount(bool AllowSynthetic) const {
MDNode *MD = getMetadata(LLVMContext::MD_prof);
if (MD && MD->getOperand(0))
if (MDString *MDS = dyn_cast<MDString>(MD->getOperand(0))) {
@@ -1923,7 +1922,7 @@ ProfileCount Function::getEntryCount(bool AllowSynthetic) const {
// A value of -1 is used for SamplePGO when there were no samples.
// Treat this the same as unknown.
if (Count == (uint64_t)-1)
- return ProfileCount::getInvalid();
+ return None;
return ProfileCount(Count, PCT_Real);
} else if (AllowSynthetic &&
MDS->getString().equals("synthetic_function_entry_count")) {
@@ -1932,7 +1931,7 @@ ProfileCount Function::getEntryCount(bool AllowSynthetic) const {
return ProfileCount(Count, PCT_Synthetic);
}
}
- return ProfileCount::getInvalid();
+ return None;
}
DenseSet<GlobalValue::GUID> Function::getImportGUIDs() const {
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 6921c46be7a8e..7402e399a88a0 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -1411,7 +1411,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {
computeCallsiteToProfCountMap(Cloner.ClonedFunc, CallSiteToProfCountMap);
uint64_t CalleeEntryCountV =
- (CalleeEntryCount ? CalleeEntryCount.getCount() : 0);
+ (CalleeEntryCount ? CalleeEntryCount->getCount() : 0);
bool AnyInline = false;
for (User *User : Users) {
@@ -1459,8 +1459,8 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {
if (AnyInline) {
Cloner.IsFunctionInlined = true;
if (CalleeEntryCount)
- Cloner.OrigFunc->setEntryCount(
- CalleeEntryCount.setCount(CalleeEntryCountV));
+ Cloner.OrigFunc->setEntryCount(Function::ProfileCount(
+ CalleeEntryCountV, CalleeEntryCount->getType()));
OptimizationRemarkEmitter OrigFuncORE(Cloner.OrigFunc);
OrigFuncORE.emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", Cloner.OrigFunc)
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c437003060cbb..af5946325bbb8 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1686,7 +1686,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
BlockFrequencyInfo NBFI(F, NBPI, LI);
#ifndef NDEBUG
auto BFIEntryCount = F.getEntryCount();
- assert(BFIEntryCount.hasValue() && (BFIEntryCount.getCount() > 0) &&
+ assert(BFIEntryCount.hasValue() && (BFIEntryCount->getCount() > 0) &&
"Invalid BFI Entrycount");
#endif
auto SumCount = APFloat::getZero(APFloat::IEEEdouble());
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index 6a3ac1ee77baa..f4776589910f2 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1600,8 +1600,7 @@ static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap,
const ProfileCount &CalleeEntryCount,
const CallBase &TheCall, ProfileSummaryInfo *PSI,
BlockFrequencyInfo *CallerBFI) {
- if (!CalleeEntryCount.hasValue() || CalleeEntryCount.isSynthetic() ||
- CalleeEntryCount.getCount() < 1)
+ if (CalleeEntryCount.isSynthetic() || CalleeEntryCount.getCount() < 1)
return;
auto CallSiteCount = PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None;
int64_t CallCount =
@@ -1616,7 +1615,7 @@ void llvm::updateProfileCallee(
if (!CalleeCount.hasValue())
return;
- const uint64_t PriorEntryCount = CalleeCount.getCount();
+ const uint64_t PriorEntryCount = CalleeCount->getCount();
// Since CallSiteCount is an estimate, it could exceed the original callee
// count and has to be set to 0 so guard against underflow.
@@ -1969,8 +1968,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI,
CalledFunc->front());
- updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), CB,
- IFI.PSI, IFI.CallerBFI);
+ if (auto Profile = CalledFunc->getEntryCount())
+ updateCallProfile(CalledFunc, VMap, *Profile, CB, IFI.PSI,
+ IFI.CallerBFI);
}
// Inject byval arguments initialization.
diff --git a/llvm/unittests/IR/MetadataTest.cpp b/llvm/unittests/IR/MetadataTest.cpp
index 959d76db7d2d3..f870ec81a0172 100644
--- a/llvm/unittests/IR/MetadataTest.cpp
+++ b/llvm/unittests/IR/MetadataTest.cpp
@@ -3424,23 +3424,24 @@ TEST_F(FunctionAttachmentTest, Verifier) {
EXPECT_FALSE(verifyFunction(*F));
}
-TEST_F(FunctionAttachmentTest, EntryCount) {
+TEST_F(FunctionAttachmentTest, RealEntryCount) {
Function *F = getFunction("foo");
EXPECT_FALSE(F->getEntryCount().hasValue());
F->setEntryCount(12304, Function::PCT_Real);
auto Count = F->getEntryCount();
EXPECT_TRUE(Count.hasValue());
- EXPECT_EQ(12304u, Count.getCount());
- EXPECT_EQ(Function::PCT_Real, Count.getType());
+ EXPECT_EQ(12304u, Count->getCount());
+ EXPECT_EQ(Function::PCT_Real, Count->getType());
+}
- // Repeat the same for synthetic counts.
- F = getFunction("bar");
+TEST_F(FunctionAttachmentTest, SyntheticEntryCount) {
+ Function *F = getFunction("bar");
EXPECT_FALSE(F->getEntryCount().hasValue());
F->setEntryCount(123, Function::PCT_Synthetic);
- Count = F->getEntryCount(true /*allow synthetic*/);
+ auto Count = F->getEntryCount(true /*allow synthetic*/);
EXPECT_TRUE(Count.hasValue());
- EXPECT_EQ(123u, Count.getCount());
- EXPECT_EQ(Function::PCT_Synthetic, Count.getType());
+ EXPECT_EQ(123u, Count->getCount());
+ EXPECT_EQ(Function::PCT_Synthetic, Count->getType());
}
TEST_F(FunctionAttachmentTest, SubprogramAttachment) {
More information about the llvm-commits
mailing list