[llvm] 4bb3d0e - Revert D153927 "Resubmit with fix: [NFC] Refactor MBB hotness/coldness into templated PSI functions."

Fangrui Song via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 28 00:30:55 PDT 2023


Author: Fangrui Song
Date: 2023-06-28T00:30:52-07:00
New Revision: 4bb3d0e5318ef6083596853daf59f0bdb4700d55

URL: https://github.com/llvm/llvm-project/commit/4bb3d0e5318ef6083596853daf59f0bdb4700d55
DIFF: https://github.com/llvm/llvm-project/commit/4bb3d0e5318ef6083596853daf59f0bdb4700d55.diff

LOG: Revert D153927 "Resubmit with fix: [NFC] Refactor MBB hotness/coldness into templated PSI functions."

This reverts commit 4d8cf2ae6804e0d3f2b668dbec0f5c1983358328.

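There is a library layering violation: llvm/Analysis/ProfileSummaryInfo.h (part of LLVMAnalysis) includes llvm/CodeGen/MachineFunction.h, but LLVMAnalysis cannot depend on LLVMCodeGen.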

```
llvm/include/llvm/Analysis/ProfileSummaryInfo.h:19:10: fatal error: 'llvm/CodeGen/MachineFunction.h' file not found
   19 | #include "llvm/CodeGen/MachineFunction.h"
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ProfileSummaryInfo.h
    llvm/include/llvm/Transforms/Utils/SizeOpts.h
    llvm/lib/Analysis/ProfileSummaryInfo.cpp
    llvm/lib/CodeGen/MachineSizeOpts.cpp
    llvm/lib/Transforms/Utils/SizeOpts.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
index 39654c80b40b2..292c713f07ca0 100644
--- a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
+++ b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
@@ -15,10 +15,6 @@
 #define LLVM_ANALYSIS_PROFILESUMMARYINFO_H
 
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/Analysis/BlockFrequencyInfo.h"
-#include "llvm/CodeGen/MachineFunction.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/Pass.h"
@@ -27,7 +23,9 @@
 
 namespace llvm {
 class BasicBlock;
+class BlockFrequencyInfo;
 class CallBase;
+class Function;
 
 /// Analysis providing profile information.
 ///
@@ -109,77 +107,28 @@ class ProfileSummaryInfo {
   bool hasHugeWorkingSetSize() const;
   /// Returns true if the working set size of the code is considered large.
   bool hasLargeWorkingSetSize() const;
-  /// Returns true if \p F has hot function entry. If it returns false, it
-  /// either means it is not hot or it is unknown whether it is hot or not (for
-  /// example, no profile data is available).
-  template <typename FuncT> bool isFunctionEntryHot(const FuncT *F) const {
-    if (!F || !hasProfileSummary())
-      return false;
-    std::optional<Function::ProfileCount> FunctionCount = getEntryCount(F);
-    // FIXME: The heuristic used below for determining hotness is based on
-    // preliminary SPEC tuning for inliner. This will eventually be a
-    // convenience method that calls isHotCount.
-    return FunctionCount && isHotCount(FunctionCount->getCount());
-  }
-
+  /// Returns true if \p F has hot function entry.
+  bool isFunctionEntryHot(const Function *F) const;
   /// Returns true if \p F contains hot code.
-  template <typename FuncT, typename BFIT>
-  bool isFunctionHotInCallGraph(const FuncT *F, BFIT &BFI) const {
-    if (!F || !hasProfileSummary())
-      return false;
-    if (auto FunctionCount = getEntryCount(F))
-      if (isHotCount(FunctionCount->getCount()))
-        return true;
-
-    if (auto TotalCallCount = getTotalCallCount(F)) {
-      if (isHotCount(*TotalCallCount))
-        return true;
-    }
-
-    for (const auto &BB : *F)
-      if (isHotBlock(&BB, &BFI))
-        return true;
-    return false;
-  }
+  bool isFunctionHotInCallGraph(const Function *F,
+                                BlockFrequencyInfo &BFI) const;
   /// Returns true if \p F has cold function entry.
   bool isFunctionEntryCold(const Function *F) const;
   /// Returns true if \p F contains only cold code.
-  template <typename FuncT, typename BFIT>
-  bool isFunctionColdInCallGraph(const FuncT *F, BFIT &BFI) const {
-    if (!F || !hasProfileSummary())
-      return false;
-    if (auto FunctionCount = getEntryCount(F))
-      if (!isColdCount(FunctionCount->getCount()))
-        return false;
-
-    if (auto TotalCallCount = getTotalCallCount(F)) {
-      if (!isColdCount(*TotalCallCount))
-        return false;
-    }
-
-    for (const auto &BB : *F)
-      if (!isColdBlock(&BB, &BFI))
-        return false;
-    return true;
-  }
+  bool isFunctionColdInCallGraph(const Function *F,
+                                 BlockFrequencyInfo &BFI) const;
   /// Returns true if the hotness of \p F is unknown.
   bool isFunctionHotnessUnknown(const Function &F) const;
   /// Returns true if \p F contains hot code with regard to a given hot
   /// percentile cutoff value.
-  template <typename FuncT, typename BFIT>
   bool isFunctionHotInCallGraphNthPercentile(int PercentileCutoff,
-                                             const FuncT *F, BFIT &BFI) const {
-    return isFunctionHotOrColdInCallGraphNthPercentile<true, FuncT, BFIT>(
-        PercentileCutoff, F, BFI);
-  }
+                                             const Function *F,
+                                             BlockFrequencyInfo &BFI) const;
   /// Returns true if \p F contains cold code with regard to a given cold
   /// percentile cutoff value.
-  template <typename FuncT, typename BFIT>
   bool isFunctionColdInCallGraphNthPercentile(int PercentileCutoff,
-                                              const FuncT *F, BFIT &BFI) const {
-    return isFunctionHotOrColdInCallGraphNthPercentile<false, FuncT, BFIT>(
-        PercentileCutoff, F, BFI);
-  }
+                                              const Function *F,
+                                              BlockFrequencyInfo &BFI) const;
   /// Returns true if count \p C is considered hot.
   bool isHotCount(uint64_t C) const;
   /// Returns true if count \p C is considered cold.
@@ -194,57 +143,22 @@ class ProfileSummaryInfo {
   /// PercentileCutoff is encoded as a 6 digit decimal fixed point number, where
   /// the first two digits are the whole part. E.g. 995000 for 99.5 percentile.
   bool isColdCountNthPercentile(int PercentileCutoff, uint64_t C) const;
-
   /// Returns true if BasicBlock \p BB is considered hot.
-  template <typename BBType, typename BFIT>
-  bool isHotBlock(const BBType *BB, BFIT *BFI) const {
-    auto Count = BFI->getBlockProfileCount(BB);
-    return Count && isHotCount(*Count);
-  }
-
+  bool isHotBlock(const BasicBlock *BB, BlockFrequencyInfo *BFI) const;
   /// Returns true if BasicBlock \p BB is considered cold.
-  template <typename BBType, typename BFIT>
-  bool isColdBlock(const BBType *BB, BFIT *BFI) const {
-    auto Count = BFI->getBlockProfileCount(BB);
-    return Count && isColdCount(*Count);
-  }
-
-  template <typename BFIT>
-  bool isColdBlock(BlockFrequency BlockFreq, const BFIT *BFI) const {
-    auto Count = BFI->getProfileCountFromFreq(BlockFreq.getFrequency());
-    return Count && isColdCount(*Count);
-  }
-
-  template <typename BBType, typename BFIT>
-  bool isHotBlockNthPercentile(int PercentileCutoff, const BBType *BB,
-                               BFIT *BFI) const {
-    return isHotOrColdBlockNthPercentile<true, BBType, BFIT>(PercentileCutoff,
-                                                             BB, BFI);
-  }
-
-  template <typename BFIT>
-  bool isHotBlockNthPercentile(int PercentileCutoff, BlockFrequency BlockFreq,
-                               BFIT *BFI) const {
-    return isHotOrColdBlockNthPercentile<true, BFIT>(PercentileCutoff,
-                                                     BlockFreq, BFI);
-  }
-
+  bool isColdBlock(const BasicBlock *BB, BlockFrequencyInfo *BFI) const;
+  /// Returns true if BasicBlock \p BB is considered hot with regard to a given
+  /// hot percentile cutoff value.
+  /// PercentileCutoff is encoded as a 6 digit decimal fixed point number, where
+  /// the first two digits are the whole part. E.g. 995000 for 99.5 percentile.
+  bool isHotBlockNthPercentile(int PercentileCutoff, const BasicBlock *BB,
+                               BlockFrequencyInfo *BFI) const;
   /// Returns true if BasicBlock \p BB is considered cold with regard to a given
   /// cold percentile cutoff value.
   /// PercentileCutoff is encoded as a 6 digit decimal fixed point number, where
   /// the first two digits are the whole part. E.g. 995000 for 99.5 percentile.
-  template <typename BBType, typename BFIT>
-  bool isColdBlockNthPercentile(int PercentileCutoff, const BBType *BB,
-                                BFIT *BFI) const {
-    return isHotOrColdBlockNthPercentile<false, BBType, BFIT>(PercentileCutoff,
-                                                              BB, BFI);
-  }
-  template <typename BFIT>
-  bool isColdBlockNthPercentile(int PercentileCutoff, BlockFrequency BlockFreq,
-                                BFIT *BFI) const {
-    return isHotOrColdBlockNthPercentile<false, BFIT>(PercentileCutoff,
-                                                      BlockFreq, BFI);
-  }
+  bool isColdBlockNthPercentile(int PercentileCutoff, const BasicBlock *BB,
+                                BlockFrequencyInfo *BFI) const;
   /// Returns true if the call site \p CB is considered hot.
   bool isHotCallSite(const CallBase &CB, BlockFrequencyInfo *BFI) const;
   /// Returns true if call site \p CB is considered cold.
@@ -264,97 +178,18 @@ class ProfileSummaryInfo {
     return ColdCountThreshold.value_or(0);
   }
 
-private:
-  template <typename FuncT>
-  std::optional<uint64_t> getTotalCallCount(const FuncT *F) const {
-    return std::nullopt;
-  }
-
-  template <bool isHot, typename FuncT, typename BFIT>
-  bool isFunctionHotOrColdInCallGraphNthPercentile(int PercentileCutoff,
-                                                   const FuncT *F,
-                                                   BFIT &FI) const {
-    if (!F || !hasProfileSummary())
-      return false;
-    if (auto FunctionCount = getEntryCount(F)) {
-      if (isHot &&
-          isHotCountNthPercentile(PercentileCutoff, FunctionCount->getCount()))
-        return true;
-      if (!isHot && !isColdCountNthPercentile(PercentileCutoff,
-                                              FunctionCount->getCount()))
-        return false;
-    }
-    if (auto TotalCallCount = getTotalCallCount(F)) {
-      if (isHot && isHotCountNthPercentile(PercentileCutoff, *TotalCallCount))
-        return true;
-      if (!isHot &&
-          !isColdCountNthPercentile(PercentileCutoff, *TotalCallCount))
-        return false;
-    }
-    for (const auto &BB : *F) {
-      if (isHot && isHotBlockNthPercentile(PercentileCutoff, &BB, &FI))
-        return true;
-      if (!isHot && !isColdBlockNthPercentile(PercentileCutoff, &BB, &FI))
-        return false;
-    }
-    return !isHot;
-  }
-
-  template <bool isHot>
-  bool isHotOrColdCountNthPercentile(int PercentileCutoff, uint64_t C) const;
-
-  template <bool isHot, typename BBType, typename BFIT>
-  bool isHotOrColdBlockNthPercentile(int PercentileCutoff, const BBType *BB,
-                                     BFIT *BFI) const {
-    auto Count = BFI->getBlockProfileCount(BB);
-    if (isHot)
-      return Count && isHotCountNthPercentile(PercentileCutoff, *Count);
-    else
-      return Count && isColdCountNthPercentile(PercentileCutoff, *Count);
-  }
-
-  template <bool isHot, typename BFIT>
-  bool isHotOrColdBlockNthPercentile(int PercentileCutoff,
-                                     BlockFrequency BlockFreq,
-                                     BFIT *BFI) const {
-    auto Count = BFI->getProfileCountFromFreq(BlockFreq.getFrequency());
-    if (isHot)
-      return Count && isHotCountNthPercentile(PercentileCutoff, *Count);
-    else
-      return Count && isColdCountNthPercentile(PercentileCutoff, *Count);
-  }
-
-  template <typename FuncT>
-  std::optional<Function::ProfileCount> getEntryCount(const FuncT *F) const {
-    return F->getEntryCount();
-  }
+ private:
+   template <bool isHot>
+   bool isFunctionHotOrColdInCallGraphNthPercentile(
+       int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) const;
+   template <bool isHot>
+   bool isHotOrColdCountNthPercentile(int PercentileCutoff, uint64_t C) const;
+   template <bool isHot>
+   bool isHotOrColdBlockNthPercentile(int PercentileCutoff,
+                                      const BasicBlock *BB,
+                                      BlockFrequencyInfo *BFI) const;
 };
 
-template <>
-inline std::optional<uint64_t>
-ProfileSummaryInfo::getTotalCallCount<Function>(const Function *F) const {
-  if (!hasSampleProfile())
-    return std::nullopt;
-  uint64_t TotalCallCount = 0;
-  for (const auto &BB : *F) {
-    for (const auto &I : BB) {
-      if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
-        if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr)) {
-          TotalCallCount += *CallCount;
-        }
-      }
-    }
-  }
-  return TotalCallCount;
-}
-
-template <>
-inline std::optional<Function::ProfileCount>
-ProfileSummaryInfo::getEntryCount<MachineFunction>(
-    const MachineFunction *F) const {
-  return F->getFunction().getEntryCount();
-}
-
 /// An analysis pass based on legacy pass manager to deliver ProfileSummaryInfo.
 class ProfileSummaryInfoWrapperPass : public ImmutablePass {
   std::unique_ptr<ProfileSummaryInfo> PSI;

diff  --git a/llvm/include/llvm/Transforms/Utils/SizeOpts.h b/llvm/include/llvm/Transforms/Utils/SizeOpts.h
index a9e72768f81e3..aa9e9bd6c69b7 100644
--- a/llvm/include/llvm/Transforms/Utils/SizeOpts.h
+++ b/llvm/include/llvm/Transforms/Utils/SizeOpts.h
@@ -47,7 +47,7 @@ static inline bool isPGSOColdCodeOnly(ProfileSummaryInfo *PSI) {
          (PGSOLargeWorkingSetSizeOnly && !PSI->hasLargeWorkingSetSize());
 }
 
-template <typename FuncT, typename BFIT>
+template<typename AdapterT, typename FuncT, typename BFIT>
 bool shouldFuncOptimizeForSizeImpl(const FuncT *F, ProfileSummaryInfo *PSI,
                                    BFIT *BFI, PGSOQueryType QueryType) {
   assert(F);
@@ -58,20 +58,19 @@ bool shouldFuncOptimizeForSizeImpl(const FuncT *F, ProfileSummaryInfo *PSI,
   if (!EnablePGSO)
     return false;
   if (isPGSOColdCodeOnly(PSI))
-    return PSI->isFunctionColdInCallGraph(F, *BFI);
+    return AdapterT::isFunctionColdInCallGraph(F, PSI, *BFI);
   if (PSI->hasSampleProfile())
     // The "isCold" check seems to work better for Sample PGO as it could have
     // many profile-unannotated functions.
-    return PSI->isFunctionColdInCallGraphNthPercentile(PgsoCutoffSampleProf, F,
-                                                       *BFI);
-  return !PSI->isFunctionHotInCallGraphNthPercentile(PgsoCutoffInstrProf, F,
-                                                     *BFI);
+    return AdapterT::isFunctionColdInCallGraphNthPercentile(
+        PgsoCutoffSampleProf, F, PSI, *BFI);
+  return !AdapterT::isFunctionHotInCallGraphNthPercentile(PgsoCutoffInstrProf,
+                                                          F, PSI, *BFI);
 }
 
-template <typename BlockTOrBlockFreq, typename BFIT>
-bool shouldOptimizeForSizeImpl(BlockTOrBlockFreq BBOrBlockFreq,
-                               ProfileSummaryInfo *PSI, BFIT *BFI,
-                               PGSOQueryType QueryType) {
+template<typename AdapterT, typename BlockTOrBlockFreq, typename BFIT>
+bool shouldOptimizeForSizeImpl(BlockTOrBlockFreq BBOrBlockFreq, ProfileSummaryInfo *PSI,
+                               BFIT *BFI, PGSOQueryType QueryType) {
   if (!PSI || !BFI || !PSI->hasProfileSummary())
     return false;
   if (ForcePGSO)
@@ -79,13 +78,14 @@ bool shouldOptimizeForSizeImpl(BlockTOrBlockFreq BBOrBlockFreq,
   if (!EnablePGSO)
     return false;
   if (isPGSOColdCodeOnly(PSI))
-    return PSI->isColdBlock(BBOrBlockFreq, BFI);
+    return AdapterT::isColdBlock(BBOrBlockFreq, PSI, BFI);
   if (PSI->hasSampleProfile())
     // The "isCold" check seems to work better for Sample PGO as it could have
     // many profile-unannotated functions.
-    return PSI->isColdBlockNthPercentile(PgsoCutoffSampleProf, BBOrBlockFreq,
-                                         BFI);
-  return !PSI->isHotBlockNthPercentile(PgsoCutoffInstrProf, BBOrBlockFreq, BFI);
+    return AdapterT::isColdBlockNthPercentile(PgsoCutoffSampleProf,
+                                              BBOrBlockFreq, PSI, BFI);
+  return !AdapterT::isHotBlockNthPercentile(PgsoCutoffInstrProf, BBOrBlockFreq,
+                                            PSI, BFI);
 }
 
 /// Returns true if function \p F is suggested to be size-optimized based on the

diff  --git a/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
index 203f1e42733f3..6b9f15bf2f647 100644
--- a/llvm/lib/Analysis/ProfileSummaryInfo.cpp
+++ b/llvm/lib/Analysis/ProfileSummaryInfo.cpp
@@ -95,11 +95,129 @@ std::optional<uint64_t> ProfileSummaryInfo::getProfileCount(
   return std::nullopt;
 }
 
+/// Returns true if the function's entry is hot. If it returns false, it
+/// either means it is not hot or it is unknown whether it is hot or not (for
+/// example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) const {
+  if (!F || !hasProfileSummary())
+    return false;
+  auto FunctionCount = F->getEntryCount();
+  // FIXME: The heuristic used below for determining hotness is based on
+  // preliminary SPEC tuning for inliner. This will eventually be a
+  // convenience method that calls isHotCount.
+  return FunctionCount && isHotCount(FunctionCount->getCount());
+}
+
+/// Returns true if the function contains hot code. This can include a hot
+/// function entry count, hot basic block, or (in the case of Sample PGO)
+/// hot total call edge count.
+/// If it returns false, it either means it is not hot or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionHotInCallGraph(
+    const Function *F, BlockFrequencyInfo &BFI) const {
+  if (!F || !hasProfileSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount())
+    if (isHotCount(FunctionCount->getCount()))
+      return true;
+
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
+            TotalCallCount += *CallCount;
+    if (isHotCount(TotalCallCount))
+      return true;
+  }
+  for (const auto &BB : *F)
+    if (isHotBlock(&BB, &BFI))
+      return true;
+  return false;
+}
+
+/// Returns true if the function only contains cold code. This means that
+/// the function entry and blocks are all cold, and (in the case of Sample PGO)
+/// the total call edge count is cold.
+/// If it returns false, it either means it is not cold or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionColdInCallGraph(
+    const Function *F, BlockFrequencyInfo &BFI) const {
+  if (!F || !hasProfileSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount())
+    if (!isColdCount(FunctionCount->getCount()))
+      return false;
+
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
+            TotalCallCount += *CallCount;
+    if (!isColdCount(TotalCallCount))
+      return false;
+  }
+  for (const auto &BB : *F)
+    if (!isColdBlock(&BB, &BFI))
+      return false;
+  return true;
+}
+
 bool ProfileSummaryInfo::isFunctionHotnessUnknown(const Function &F) const {
   assert(hasPartialSampleProfile() && "Expect partial sample profile");
   return !F.getEntryCount();
 }
 
+template <bool isHot>
+bool ProfileSummaryInfo::isFunctionHotOrColdInCallGraphNthPercentile(
+    int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) const {
+  if (!F || !hasProfileSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount()) {
+    if (isHot &&
+        isHotCountNthPercentile(PercentileCutoff, FunctionCount->getCount()))
+      return true;
+    if (!isHot &&
+        !isColdCountNthPercentile(PercentileCutoff, FunctionCount->getCount()))
+      return false;
+  }
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr))
+            TotalCallCount += *CallCount;
+    if (isHot && isHotCountNthPercentile(PercentileCutoff, TotalCallCount))
+      return true;
+    if (!isHot && !isColdCountNthPercentile(PercentileCutoff, TotalCallCount))
+      return false;
+  }
+  for (const auto &BB : *F) {
+    if (isHot && isHotBlockNthPercentile(PercentileCutoff, &BB, &BFI))
+      return true;
+    if (!isHot && !isColdBlockNthPercentile(PercentileCutoff, &BB, &BFI))
+      return false;
+  }
+  return !isHot;
+}
+
+// Like isFunctionHotInCallGraph but for a given cutoff.
+bool ProfileSummaryInfo::isFunctionHotInCallGraphNthPercentile(
+    int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) const {
+  return isFunctionHotOrColdInCallGraphNthPercentile<true>(
+      PercentileCutoff, F, BFI);
+}
+
+bool ProfileSummaryInfo::isFunctionColdInCallGraphNthPercentile(
+    int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) const {
+  return isFunctionHotOrColdInCallGraphNthPercentile<false>(
+      PercentileCutoff, F, BFI);
+}
+
 /// Returns true if the function's entry is a cold. If it returns false, it
 /// either means it is not cold or it is unknown whether it is cold or not (for
 /// example, no profile data is available).
@@ -207,6 +325,38 @@ uint64_t ProfileSummaryInfo::getOrCompColdCountThreshold() const {
   return ColdCountThreshold.value_or(0);
 }
 
+bool ProfileSummaryInfo::isHotBlock(const BasicBlock *BB,
+                                    BlockFrequencyInfo *BFI) const {
+  auto Count = BFI->getBlockProfileCount(BB);
+  return Count && isHotCount(*Count);
+}
+
+bool ProfileSummaryInfo::isColdBlock(const BasicBlock *BB,
+                                     BlockFrequencyInfo *BFI) const {
+  auto Count = BFI->getBlockProfileCount(BB);
+  return Count && isColdCount(*Count);
+}
+
+template <bool isHot>
+bool ProfileSummaryInfo::isHotOrColdBlockNthPercentile(
+    int PercentileCutoff, const BasicBlock *BB, BlockFrequencyInfo *BFI) const {
+  auto Count = BFI->getBlockProfileCount(BB);
+  if (isHot)
+    return Count && isHotCountNthPercentile(PercentileCutoff, *Count);
+  else
+    return Count && isColdCountNthPercentile(PercentileCutoff, *Count);
+}
+
+bool ProfileSummaryInfo::isHotBlockNthPercentile(
+    int PercentileCutoff, const BasicBlock *BB, BlockFrequencyInfo *BFI) const {
+  return isHotOrColdBlockNthPercentile<true>(PercentileCutoff, BB, BFI);
+}
+
+bool ProfileSummaryInfo::isColdBlockNthPercentile(
+    int PercentileCutoff, const BasicBlock *BB, BlockFrequencyInfo *BFI) const {
+  return isHotOrColdBlockNthPercentile<false>(PercentileCutoff, BB, BFI);
+}
+
 bool ProfileSummaryInfo::isHotCallSite(const CallBase &CB,
                                        BlockFrequencyInfo *BFI) const {
   auto C = getProfileCount(CB, BFI);

diff  --git a/llvm/lib/CodeGen/MachineSizeOpts.cpp b/llvm/lib/CodeGen/MachineSizeOpts.cpp
index 53bed7397d099..28712d1a816bd 100644
--- a/llvm/lib/CodeGen/MachineSizeOpts.cpp
+++ b/llvm/lib/CodeGen/MachineSizeOpts.cpp
@@ -24,11 +24,168 @@ extern cl::opt<bool> ForcePGSO;
 extern cl::opt<int> PgsoCutoffInstrProf;
 extern cl::opt<int> PgsoCutoffSampleProf;
 
+namespace {
+namespace machine_size_opts_detail {
+
+/// Like ProfileSummaryInfo::isColdBlock but for MachineBasicBlock.
+bool isColdBlock(const MachineBasicBlock *MBB,
+                 ProfileSummaryInfo *PSI,
+                 const MachineBlockFrequencyInfo *MBFI) {
+  auto Count = MBFI->getBlockProfileCount(MBB);
+  return Count && PSI->isColdCount(*Count);
+}
+
+bool isColdBlock(BlockFrequency BlockFreq,
+                 ProfileSummaryInfo *PSI,
+                 const MachineBlockFrequencyInfo *MBFI) {
+  auto Count = MBFI->getProfileCountFromFreq(BlockFreq.getFrequency());
+  return Count && PSI->isColdCount(*Count);
+}
+
+/// Like ProfileSummaryInfo::isHotBlockNthPercentile but for MachineBasicBlock.
+static bool isHotBlockNthPercentile(int PercentileCutoff,
+                                    const MachineBasicBlock *MBB,
+                                    ProfileSummaryInfo *PSI,
+                                    const MachineBlockFrequencyInfo *MBFI) {
+  auto Count = MBFI->getBlockProfileCount(MBB);
+  return Count && PSI->isHotCountNthPercentile(PercentileCutoff, *Count);
+}
+
+static bool isHotBlockNthPercentile(int PercentileCutoff,
+                                    BlockFrequency BlockFreq,
+                                    ProfileSummaryInfo *PSI,
+                                    const MachineBlockFrequencyInfo *MBFI) {
+  auto Count = MBFI->getProfileCountFromFreq(BlockFreq.getFrequency());
+  return Count && PSI->isHotCountNthPercentile(PercentileCutoff, *Count);
+}
+
+static bool isColdBlockNthPercentile(int PercentileCutoff,
+                                     const MachineBasicBlock *MBB,
+                                     ProfileSummaryInfo *PSI,
+                                     const MachineBlockFrequencyInfo *MBFI) {
+  auto Count = MBFI->getBlockProfileCount(MBB);
+  return Count && PSI->isColdCountNthPercentile(PercentileCutoff, *Count);
+}
+
+static bool isColdBlockNthPercentile(int PercentileCutoff,
+                                     BlockFrequency BlockFreq,
+                                     ProfileSummaryInfo *PSI,
+                                     const MachineBlockFrequencyInfo *MBFI) {
+  auto Count = MBFI->getProfileCountFromFreq(BlockFreq.getFrequency());
+  return Count && PSI->isColdCountNthPercentile(PercentileCutoff, *Count);
+}
+
+/// Like ProfileSummaryInfo::isFunctionColdInCallGraph but for
+/// MachineFunction.
+bool isFunctionColdInCallGraph(
+    const MachineFunction *MF,
+    ProfileSummaryInfo *PSI,
+    const MachineBlockFrequencyInfo &MBFI) {
+  if (auto FunctionCount = MF->getFunction().getEntryCount())
+    if (!PSI->isColdCount(FunctionCount->getCount()))
+      return false;
+  for (const auto &MBB : *MF)
+    if (!isColdBlock(&MBB, PSI, &MBFI))
+      return false;
+  return true;
+}
+
+/// Like ProfileSummaryInfo::isFunctionHotInCallGraphNthPercentile but for
+/// MachineFunction.
+bool isFunctionHotInCallGraphNthPercentile(
+    int PercentileCutoff,
+    const MachineFunction *MF,
+    ProfileSummaryInfo *PSI,
+    const MachineBlockFrequencyInfo &MBFI) {
+  if (auto FunctionCount = MF->getFunction().getEntryCount())
+    if (PSI->isHotCountNthPercentile(PercentileCutoff,
+                                     FunctionCount->getCount()))
+      return true;
+  for (const auto &MBB : *MF)
+    if (isHotBlockNthPercentile(PercentileCutoff, &MBB, PSI, &MBFI))
+      return true;
+  return false;
+}
+
+bool isFunctionColdInCallGraphNthPercentile(
+    int PercentileCutoff, const MachineFunction *MF, ProfileSummaryInfo *PSI,
+    const MachineBlockFrequencyInfo &MBFI) {
+  if (auto FunctionCount = MF->getFunction().getEntryCount())
+    if (!PSI->isColdCountNthPercentile(PercentileCutoff,
+                                       FunctionCount->getCount()))
+      return false;
+  for (const auto &MBB : *MF)
+    if (!isColdBlockNthPercentile(PercentileCutoff, &MBB, PSI, &MBFI))
+      return false;
+  return true;
+}
+} // namespace machine_size_opts_detail
+
+struct MachineBasicBlockBFIAdapter {
+  static bool isFunctionColdInCallGraph(const MachineFunction *MF,
+                                        ProfileSummaryInfo *PSI,
+                                        const MachineBlockFrequencyInfo &MBFI) {
+    return machine_size_opts_detail::isFunctionColdInCallGraph(MF, PSI, MBFI);
+  }
+  static bool isFunctionHotInCallGraphNthPercentile(
+      int CutOff,
+      const MachineFunction *MF,
+      ProfileSummaryInfo *PSI,
+      const MachineBlockFrequencyInfo &MBFI) {
+    return machine_size_opts_detail::isFunctionHotInCallGraphNthPercentile(
+        CutOff, MF, PSI, MBFI);
+  }
+  static bool isFunctionColdInCallGraphNthPercentile(
+      int CutOff, const MachineFunction *MF, ProfileSummaryInfo *PSI,
+      const MachineBlockFrequencyInfo &MBFI) {
+    return machine_size_opts_detail::isFunctionColdInCallGraphNthPercentile(
+        CutOff, MF, PSI, MBFI);
+  }
+  static bool isColdBlock(const MachineBasicBlock *MBB,
+                          ProfileSummaryInfo *PSI,
+                          const MachineBlockFrequencyInfo *MBFI) {
+    return machine_size_opts_detail::isColdBlock(MBB, PSI, MBFI);
+  }
+  static bool isColdBlock(BlockFrequency BlockFreq,
+                          ProfileSummaryInfo *PSI,
+                          const MachineBlockFrequencyInfo *MBFI) {
+    return machine_size_opts_detail::isColdBlock(BlockFreq, PSI, MBFI);
+  }
+  static bool isHotBlockNthPercentile(int CutOff,
+                                      const MachineBasicBlock *MBB,
+                                      ProfileSummaryInfo *PSI,
+                                      const MachineBlockFrequencyInfo *MBFI) {
+    return machine_size_opts_detail::isHotBlockNthPercentile(
+        CutOff, MBB, PSI, MBFI);
+  }
+  static bool isHotBlockNthPercentile(int CutOff,
+                                      BlockFrequency BlockFreq,
+                                      ProfileSummaryInfo *PSI,
+                                      const MachineBlockFrequencyInfo *MBFI) {
+    return machine_size_opts_detail::isHotBlockNthPercentile(
+        CutOff, BlockFreq, PSI, MBFI);
+  }
+  static bool isColdBlockNthPercentile(int CutOff, const MachineBasicBlock *MBB,
+                                       ProfileSummaryInfo *PSI,
+                                       const MachineBlockFrequencyInfo *MBFI) {
+    return machine_size_opts_detail::isColdBlockNthPercentile(CutOff, MBB, PSI,
+                                                              MBFI);
+  }
+  static bool isColdBlockNthPercentile(int CutOff, BlockFrequency BlockFreq,
+                                       ProfileSummaryInfo *PSI,
+                                       const MachineBlockFrequencyInfo *MBFI) {
+    return machine_size_opts_detail::isColdBlockNthPercentile(CutOff, BlockFreq,
+                                                              PSI, MBFI);
+  }
+};
+} // end anonymous namespace
+
 bool llvm::shouldOptimizeForSize(const MachineFunction *MF,
                                  ProfileSummaryInfo *PSI,
                                  const MachineBlockFrequencyInfo *MBFI,
                                  PGSOQueryType QueryType) {
-  return shouldFuncOptimizeForSizeImpl(MF, PSI, MBFI, QueryType);
+  return shouldFuncOptimizeForSizeImpl<MachineBasicBlockBFIAdapter>(
+      MF, PSI, MBFI, QueryType);
 }
 
 bool llvm::shouldOptimizeForSize(const MachineBasicBlock *MBB,
@@ -36,7 +193,8 @@ bool llvm::shouldOptimizeForSize(const MachineBasicBlock *MBB,
                                  const MachineBlockFrequencyInfo *MBFI,
                                  PGSOQueryType QueryType) {
   assert(MBB);
-  return shouldOptimizeForSizeImpl(MBB, PSI, MBFI, QueryType);
+  return shouldOptimizeForSizeImpl<MachineBasicBlockBFIAdapter>(
+      MBB, PSI, MBFI, QueryType);
 }
 
 bool llvm::shouldOptimizeForSize(const MachineBasicBlock *MBB,
@@ -47,6 +205,6 @@ bool llvm::shouldOptimizeForSize(const MachineBasicBlock *MBB,
   if (!PSI || !MBFIW)
     return false;
   BlockFrequency BlockFreq = MBFIW->getBlockFreq(MBB);
-  return shouldOptimizeForSizeImpl(BlockFreq, PSI, &MBFIW->getMBFI(),
-                                   QueryType);
+  return shouldOptimizeForSizeImpl<MachineBasicBlockBFIAdapter>(
+      BlockFreq, PSI, &MBFIW->getMBFI(), QueryType);
 }

diff  --git a/llvm/lib/Transforms/Utils/SizeOpts.cpp b/llvm/lib/Transforms/Utils/SizeOpts.cpp
index 1ca2e0e6ebb90..1242380f73c16 100644
--- a/llvm/lib/Transforms/Utils/SizeOpts.cpp
+++ b/llvm/lib/Transforms/Utils/SizeOpts.cpp
@@ -98,12 +98,14 @@ struct BasicBlockBFIAdapter {
 bool llvm::shouldOptimizeForSize(const Function *F, ProfileSummaryInfo *PSI,
                                  BlockFrequencyInfo *BFI,
                                  PGSOQueryType QueryType) {
-  return shouldFuncOptimizeForSizeImpl(F, PSI, BFI, QueryType);
+  return shouldFuncOptimizeForSizeImpl<BasicBlockBFIAdapter>(F, PSI, BFI,
+                                                             QueryType);
 }
 
 bool llvm::shouldOptimizeForSize(const BasicBlock *BB, ProfileSummaryInfo *PSI,
                                  BlockFrequencyInfo *BFI,
                                  PGSOQueryType QueryType) {
   assert(BB);
-  return shouldOptimizeForSizeImpl(BB, PSI, BFI, QueryType);
+  return shouldOptimizeForSizeImpl<BasicBlockBFIAdapter>(BB, PSI, BFI,
+                                                         QueryType);
 }


        


More information about the llvm-commits mailing list