[llvm] a32c2c3 - [NFC] Use Optional<ProfileCount> to model invalid counts

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 14 19:17:44 PST 2021


Author: Mircea Trofin
Date: 2021-11-14T19:03:30-08:00
New Revision: a32c2c380863d02eb0fd5e8757a62d96114b9519

URL: https://github.com/llvm/llvm-project/commit/a32c2c380863d02eb0fd5e8757a62d96114b9519
DIFF: https://github.com/llvm/llvm-project/commit/a32c2c380863d02eb0fd5e8757a62d96114b9519.diff

LOG: [NFC] Use Optional<ProfileCount> to model invalid counts

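ProfileCount could model invalid values, but a user had no indication
that the getCount method could return bogus data. Optional<ProfileCount>
addresses that, because the user must dereference the optional. In
addition, the patch removes concept duplication: ProfileCount no longer
carries its own "invalid" state (PCT_Invalid, hasValue, getInvalid),
which duplicated what Optional already expresses.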

Differential Revision: https://reviews.llvm.org/D113839

Added: 
    

Modified: 
    llvm/include/llvm/IR/Function.h
    llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
    llvm/lib/Analysis/InlineCost.cpp
    llvm/lib/Analysis/ProfileSummaryInfo.cpp
    llvm/lib/CodeGen/MachineSizeOpts.cpp
    llvm/lib/IR/Function.cpp
    llvm/lib/Transforms/IPO/PartialInlining.cpp
    llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
    llvm/lib/Transforms/Utils/InlineFunction.cpp
    llvm/unittests/IR/MetadataTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index e5c675a64af00..669418eacbb0a 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -247,33 +247,22 @@ class LLVM_EXTERNAL_VISIBILITY Function : public GlobalObject,
     setValueSubclassData((getSubclassDataFromValue() & 0xc00f) | (ID << 4));
   }
 
-  enum ProfileCountType { PCT_Invalid, PCT_Real, PCT_Synthetic };
+  enum ProfileCountType { PCT_Real, PCT_Synthetic };
 
   /// Class to represent profile counts.
   ///
   /// This class represents both real and synthetic profile counts.
   class ProfileCount {
   private:
-    uint64_t Count;
-    ProfileCountType PCT;
-    static ProfileCount Invalid;
+    uint64_t Count = 0;
+    ProfileCountType PCT = PCT_Real;
 
   public:
-    ProfileCount() : Count(-1), PCT(PCT_Invalid) {}
     ProfileCount(uint64_t Count, ProfileCountType PCT)
         : Count(Count), PCT(PCT) {}
-    bool hasValue() const { return PCT != PCT_Invalid; }
     uint64_t getCount() const { return Count; }
     ProfileCountType getType() const { return PCT; }
     bool isSynthetic() const { return PCT == PCT_Synthetic; }
-    explicit operator bool() { return hasValue(); }
-    bool operator!() const { return !hasValue(); }
-    // Update the count retaining the same profile count type.
-    ProfileCount &setCount(uint64_t C) {
-      Count = C;
-      return *this;
-    }
-    static ProfileCount getInvalid() { return ProfileCount(-1, PCT_Invalid); }
   };
 
   /// Set the entry count for this function.
@@ -293,7 +282,7 @@ class LLVM_EXTERNAL_VISIBILITY Function : public GlobalObject,
   ///
   /// Entry count is the number of times the function was executed.
   /// When AllowSynthetic is false, only pgo_data will be returned.
-  ProfileCount getEntryCount(bool AllowSynthetic = false) const;
+  Optional<ProfileCount> getEntryCount(bool AllowSynthetic = false) const;
 
   /// Return true if the function is annotated with profile data.
   ///

diff  --git a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
index e4e45b3076beb..2a5e1f65d7316 100644
--- a/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
+++ b/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -602,7 +602,7 @@ BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
   if (!EntryCount)
     return None;
   // Use 128 bit APInt to do the arithmetic to avoid overflow.
-  APInt BlockCount(128, EntryCount.getCount());
+  APInt BlockCount(128, EntryCount->getCount());
   APInt BlockFreq(128, Freq);
   APInt EntryFreq(128, getEntryFreq());
   BlockCount *= BlockFreq;

diff  --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index a77dd32f1816e..6d97898c08e3a 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -771,7 +771,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
     // Make sure we have a nonzero entry count.
     auto EntryCount = F.getEntryCount();
-    if (!EntryCount || !EntryCount.getCount())
+    if (!EntryCount || !EntryCount->getCount())
       return false;
 
     BlockFrequencyInfo *CalleeBFI = &(GetBFI(F));
@@ -837,8 +837,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
     // Compute the cycle savings per call.
     auto EntryProfileCount = F.getEntryCount();
-    assert(EntryProfileCount.hasValue() && EntryProfileCount.getCount());
-    auto EntryCount = EntryProfileCount.getCount();
+    assert(EntryProfileCount.hasValue() && EntryProfileCount->getCount());
+    auto EntryCount = EntryProfileCount->getCount();
     CycleSavings += EntryCount / 2;
     CycleSavings = CycleSavings.udiv(EntryCount);
 

diff  --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
index 69ca5502e5753..268ed9d047419 100644
--- a/llvm/lib/Analysis/ProfileSummaryInfo.cpp
+++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
@@ -103,7 +103,7 @@ bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) const {
   // FIXME: The heuristic used below for determining hotness is based on
   // preliminary SPEC tuning for inliner. This will eventually be a
   // convenience method that calls isHotCount.
-  return FunctionCount && isHotCount(FunctionCount.getCount());
+  return FunctionCount && isHotCount(FunctionCount->getCount());
 }
 
 /// Returns true if the function contains hot code. This can include a hot
@@ -116,7 +116,7 @@ bool ProfileSummaryInfo::isFunctionHotInCallGraph(
   if (!F || !hasProfileSummary())
     return false;
   if (auto FunctionCount = F->getEntryCount())
-    if (isHotCount(FunctionCount.getCount()))
+    if (isHotCount(FunctionCount->getCount()))
       return true;
 
   if (hasSampleProfile()) {
@@ -145,7 +145,7 @@ bool ProfileSummaryInfo::isFunctionColdInCallGraph(
   if (!F || !hasProfileSummary())
     return false;
   if (auto FunctionCount = F->getEntryCount())
-    if (!isColdCount(FunctionCount.getCount()))
+    if (!isColdCount(FunctionCount->getCount()))
       return false;
 
   if (hasSampleProfile()) {
@@ -176,10 +176,10 @@ bool ProfileSummaryInfo::isFunctionHotOrColdInCallGraphNthPercentile(
     return false;
   if (auto FunctionCount = F->getEntryCount()) {
     if (isHot &&
-        isHotCountNthPercentile(PercentileCutoff, FunctionCount.getCount()))
+        isHotCountNthPercentile(PercentileCutoff, FunctionCount->getCount()))
       return true;
     if (!isHot &&
-        !isColdCountNthPercentile(PercentileCutoff, FunctionCount.getCount()))
+        !isColdCountNthPercentile(PercentileCutoff, FunctionCount->getCount()))
       return false;
   }
   if (hasSampleProfile()) {
@@ -230,7 +230,7 @@ bool ProfileSummaryInfo::isFunctionEntryCold(const Function *F) const {
   // FIXME: The heuristic used below for determining coldness is based on
   // preliminary SPEC tuning for inliner. This will eventually be a
   // convenience method that calls isHotCount.
-  return FunctionCount && isColdCount(FunctionCount.getCount());
+  return FunctionCount && isColdCount(FunctionCount->getCount());
 }
 
 /// Compute the hot and cold thresholds.

diff  --git a/llvm/lib/CodeGen/MachineSizeOpts.cpp b/llvm/lib/CodeGen/MachineSizeOpts.cpp
index 584d43b420044..28712d1a816bd 100644
--- a/llvm/lib/CodeGen/MachineSizeOpts.cpp
+++ b/llvm/lib/CodeGen/MachineSizeOpts.cpp
@@ -82,7 +82,7 @@ bool isFunctionColdInCallGraph(
     ProfileSummaryInfo *PSI,
     const MachineBlockFrequencyInfo &MBFI) {
   if (auto FunctionCount = MF->getFunction().getEntryCount())
-    if (!PSI->isColdCount(FunctionCount.getCount()))
+    if (!PSI->isColdCount(FunctionCount->getCount()))
       return false;
   for (const auto &MBB : *MF)
     if (!isColdBlock(&MBB, PSI, &MBFI))
@@ -99,7 +99,7 @@ bool isFunctionHotInCallGraphNthPercentile(
     const MachineBlockFrequencyInfo &MBFI) {
   if (auto FunctionCount = MF->getFunction().getEntryCount())
     if (PSI->isHotCountNthPercentile(PercentileCutoff,
-                                     FunctionCount.getCount()))
+                                     FunctionCount->getCount()))
       return true;
   for (const auto &MBB : *MF)
     if (isHotBlockNthPercentile(PercentileCutoff, &MBB, PSI, &MBFI))
@@ -112,7 +112,7 @@ bool isFunctionColdInCallGraphNthPercentile(
     const MachineBlockFrequencyInfo &MBFI) {
   if (auto FunctionCount = MF->getFunction().getEntryCount())
     if (!PSI->isColdCountNthPercentile(PercentileCutoff,
-                                       FunctionCount.getCount()))
+                                       FunctionCount->getCount()))
       return false;
   for (const auto &MBB : *MF)
     if (!isColdBlockNthPercentile(PercentileCutoff, &MBB, PSI, &MBFI))

diff  --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 7eddffab13b9d..82b20a8af91bf 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1892,10 +1892,9 @@ void Function::setValueSubclassDataBit(unsigned Bit, bool On) {
 
 void Function::setEntryCount(ProfileCount Count,
                              const DenseSet<GlobalValue::GUID> *S) {
-  assert(Count.hasValue());
 #if !defined(NDEBUG)
   auto PrevCount = getEntryCount();
-  assert(!PrevCount.hasValue() || PrevCount.getType() == Count.getType());
+  assert(!PrevCount.hasValue() || PrevCount->getType() == Count.getType());
 #endif
 
   auto ImportGUIDs = getImportGUIDs();
@@ -1913,7 +1912,7 @@ void Function::setEntryCount(uint64_t Count, Function::ProfileCountType Type,
   setEntryCount(ProfileCount(Count, Type), Imports);
 }
 
-ProfileCount Function::getEntryCount(bool AllowSynthetic) const {
+Optional<ProfileCount> Function::getEntryCount(bool AllowSynthetic) const {
   MDNode *MD = getMetadata(LLVMContext::MD_prof);
   if (MD && MD->getOperand(0))
     if (MDString *MDS = dyn_cast<MDString>(MD->getOperand(0))) {
@@ -1923,7 +1922,7 @@ ProfileCount Function::getEntryCount(bool AllowSynthetic) const {
         // A value of -1 is used for SamplePGO when there were no samples.
         // Treat this the same as unknown.
         if (Count == (uint64_t)-1)
-          return ProfileCount::getInvalid();
+          return None;
         return ProfileCount(Count, PCT_Real);
       } else if (AllowSynthetic &&
                  MDS->getString().equals("synthetic_function_entry_count")) {
@@ -1932,7 +1931,7 @@ ProfileCount Function::getEntryCount(bool AllowSynthetic) const {
         return ProfileCount(Count, PCT_Synthetic);
       }
     }
-  return ProfileCount::getInvalid();
+  return None;
 }
 
 DenseSet<GlobalValue::GUID> Function::getImportGUIDs() const {

diff  --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 6921c46be7a8e..7402e399a88a0 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -1411,7 +1411,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {
     computeCallsiteToProfCountMap(Cloner.ClonedFunc, CallSiteToProfCountMap);
 
   uint64_t CalleeEntryCountV =
-      (CalleeEntryCount ? CalleeEntryCount.getCount() : 0);
+      (CalleeEntryCount ? CalleeEntryCount->getCount() : 0);
 
   bool AnyInline = false;
   for (User *User : Users) {
@@ -1459,8 +1459,8 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {
   if (AnyInline) {
     Cloner.IsFunctionInlined = true;
     if (CalleeEntryCount)
-      Cloner.OrigFunc->setEntryCount(
-          CalleeEntryCount.setCount(CalleeEntryCountV));
+      Cloner.OrigFunc->setEntryCount(Function::ProfileCount(
+          CalleeEntryCountV, CalleeEntryCount->getType()));
     OptimizationRemarkEmitter OrigFuncORE(Cloner.OrigFunc);
     OrigFuncORE.emit([&]() {
       return OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", Cloner.OrigFunc)

diff  --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c437003060cbb..af5946325bbb8 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1686,7 +1686,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
   BlockFrequencyInfo NBFI(F, NBPI, LI);
 #ifndef NDEBUG
   auto BFIEntryCount = F.getEntryCount();
-  assert(BFIEntryCount.hasValue() && (BFIEntryCount.getCount() > 0) &&
+  assert(BFIEntryCount.hasValue() && (BFIEntryCount->getCount() > 0) &&
          "Invalid BFI Entrycount");
 #endif
   auto SumCount = APFloat::getZero(APFloat::IEEEdouble());

diff  --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index 6a3ac1ee77baa..f4776589910f2 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1600,8 +1600,7 @@ static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap,
                               const ProfileCount &CalleeEntryCount,
                               const CallBase &TheCall, ProfileSummaryInfo *PSI,
                               BlockFrequencyInfo *CallerBFI) {
-  if (!CalleeEntryCount.hasValue() || CalleeEntryCount.isSynthetic() ||
-      CalleeEntryCount.getCount() < 1)
+  if (CalleeEntryCount.isSynthetic() || CalleeEntryCount.getCount() < 1)
     return;
   auto CallSiteCount = PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None;
   int64_t CallCount =
@@ -1616,7 +1615,7 @@ void llvm::updateProfileCallee(
   if (!CalleeCount.hasValue())
     return;
 
-  const uint64_t PriorEntryCount = CalleeCount.getCount();
+  const uint64_t PriorEntryCount = CalleeCount->getCount();
 
   // Since CallSiteCount is an estimate, it could exceed the original callee
   // count and has to be set to 0 so guard against underflow.
@@ -1969,8 +1968,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
         updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI,
                         CalledFunc->front());
 
-      updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), CB,
-                        IFI.PSI, IFI.CallerBFI);
+      if (auto Profile = CalledFunc->getEntryCount())
+        updateCallProfile(CalledFunc, VMap, *Profile, CB, IFI.PSI,
+                          IFI.CallerBFI);
     }
 
     // Inject byval arguments initialization.

diff  --git a/llvm/unittests/IR/MetadataTest.cpp b/llvm/unittests/IR/MetadataTest.cpp
index 959d76db7d2d3..f870ec81a0172 100644
--- a/llvm/unittests/IR/MetadataTest.cpp
+++ b/llvm/unittests/IR/MetadataTest.cpp
@@ -3424,23 +3424,24 @@ TEST_F(FunctionAttachmentTest, Verifier) {
   EXPECT_FALSE(verifyFunction(*F));
 }
 
-TEST_F(FunctionAttachmentTest, EntryCount) {
+TEST_F(FunctionAttachmentTest, RealEntryCount) {
   Function *F = getFunction("foo");
   EXPECT_FALSE(F->getEntryCount().hasValue());
   F->setEntryCount(12304, Function::PCT_Real);
   auto Count = F->getEntryCount();
   EXPECT_TRUE(Count.hasValue());
-  EXPECT_EQ(12304u, Count.getCount());
-  EXPECT_EQ(Function::PCT_Real, Count.getType());
+  EXPECT_EQ(12304u, Count->getCount());
+  EXPECT_EQ(Function::PCT_Real, Count->getType());
+}
 
-  // Repeat the same for synthetic counts.
-  F = getFunction("bar");
+TEST_F(FunctionAttachmentTest, SyntheticEntryCount) {
+  Function *F = getFunction("bar");
   EXPECT_FALSE(F->getEntryCount().hasValue());
   F->setEntryCount(123, Function::PCT_Synthetic);
-  Count = F->getEntryCount(true /*allow synthetic*/);
+  auto Count = F->getEntryCount(true /*allow synthetic*/);
   EXPECT_TRUE(Count.hasValue());
-  EXPECT_EQ(123u, Count.getCount());
-  EXPECT_EQ(Function::PCT_Synthetic, Count.getType());
+  EXPECT_EQ(123u, Count->getCount());
+  EXPECT_EQ(Function::PCT_Synthetic, Count->getType());
 }
 
 TEST_F(FunctionAttachmentTest, SubprogramAttachment) {


        


More information about the llvm-commits mailing list