[llvm] Avoid BlockFrequency overflow problems (PR #66280)

Matthias Braun via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 13 12:48:09 PDT 2023


https://github.com/MatzeB created https://github.com/llvm/llvm-project/pull/66280:

Multiplying raw block frequency with an integer carries a high risk of overflow.

- Introduce a new `BlockFrequency::mul` function returning a `bool` indicating overflow.
- Mark the function with `__attribute__((warn_unused_result))` so that callers cannot accidentally ignore the indicator.
- Fix two call sites where overflows led to wrong results.

>From df68ed11a80aa9b40b78af9f3aede618dcbc50c5 Mon Sep 17 00:00:00 2001
From: Matthias Braun <matze at braunis.de>
Date: Mon, 11 Sep 2023 19:35:23 -0700
Subject: [PATCH] Avoid BlockFrequency overflow problems

Multiplying raw block frequency with an integer carries a high risk
of overflow.

- Introduce a new `BlockFrequency::mul` function returning a `bool`
  indicating overflow.
- Mark the function with `__attribute__((warn_unused_result))` so that
  callers cannot accidentally ignore the indicator.
- Fix two call sites where overflows led to wrong results.
---
 llvm/include/llvm/Support/BlockFrequency.h |  8 ++++++++
 llvm/include/llvm/Support/Compiler.h       |  8 ++++++++
 llvm/lib/Analysis/InlineCost.cpp           | 11 ++++++-----
 llvm/lib/CodeGen/CodeGenPrepare.cpp        | 10 +++++-----
 llvm/lib/Support/BlockFrequency.cpp        |  8 ++++++++
 5 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/Support/BlockFrequency.h b/llvm/include/llvm/Support/BlockFrequency.h
index 6c624d7dad7d801..1711fb592485b4c 100644
--- a/llvm/include/llvm/Support/BlockFrequency.h
+++ b/llvm/include/llvm/Support/BlockFrequency.h
@@ -16,6 +16,8 @@
 #include <cassert>
 #include <cstdint>
 
+#include "llvm/Support/Compiler.h"
+
 namespace llvm {
 
 class BranchProbability;
@@ -76,6 +78,12 @@ class BlockFrequency {
     return NewFreq;
   }
 
+  /// Multiplies frequency with `Factor` and stores the result into `Result`.
+  /// Returns `true` if an overflow occurred; `Result` is then saturated to the
+  /// maximum frequency. Overflows are common and must be checked by callers.
+  bool mul(uint64_t Factor,
+           BlockFrequency *Result) const LLVM_WARN_UNUSED_RESULT;
+
   /// Shift block frequency to the right by count digits saturating to 1.
   BlockFrequency &operator>>=(const unsigned count) {
     // Frequency can never be 0 by design.
diff --git a/llvm/include/llvm/Support/Compiler.h b/llvm/include/llvm/Support/Compiler.h
index 12afe90f8facd47..9527e377317ac33 100644
--- a/llvm/include/llvm/Support/Compiler.h
+++ b/llvm/include/llvm/Support/Compiler.h
@@ -269,6 +269,14 @@
 #define LLVM_ATTRIBUTE_RETURNS_NOALIAS
 #endif
 
+/// LLVM_WARN_UNUSED_RESULT - Mark a function whose return value must not be
+/// ignored; supporting compilers warn at call sites that discard the result.
+#if __has_attribute(warn_unused_result)
+#define LLVM_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
+#else
+#define LLVM_WARN_UNUSED_RESULT
+#endif
+
 /// LLVM_FALLTHROUGH - Mark fallthrough cases in switch statements.
 #if defined(__cplusplus) && __cplusplus > 201402L && LLVM_HAS_CPP_ATTRIBUTE(fallthrough)
 #define LLVM_FALLTHROUGH [[fallthrough]]
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index a9de1dde7c7f717..d921047d6466f52 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -118,7 +118,7 @@ static cl::opt<int> ColdCallSiteRelFreq(
              "entry frequency, for a callsite to be cold in the absence of "
              "profile information."));
 
-static cl::opt<int> HotCallSiteRelFreq(
+static cl::opt<uint64_t> HotCallSiteRelFreq(
     "hot-callsite-rel-freq", cl::Hidden, cl::init(60),
     cl::desc("Minimum block frequency, expressed as a multiple of caller's "
              "entry frequency, for a callsite to be hot in the absence of "
@@ -1820,10 +1820,11 @@ InlineCostCallAnalyzer::getHotCallSiteThreshold(CallBase &Call,
   // potentially cache the computation of scaled entry frequency, but the added
   // complexity is not worth it unless this scaling shows up high in the
   // profiles.
-  auto CallSiteBB = Call.getParent();
-  auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB).getFrequency();
-  auto CallerEntryFreq = CallerBFI->getEntryFreq();
-  if (CallSiteFreq >= CallerEntryFreq * HotCallSiteRelFreq)
+  BlockFrequency EntryFreq = CallerBFI->getEntryFreq();
+  BlockFrequency HotLimit;
+  bool LimitOverflowed = EntryFreq.mul(HotCallSiteRelFreq, &HotLimit);
+  BlockFrequency SiteFreq = CallerBFI->getBlockFreq(Call.getParent());
+  if (!LimitOverflowed && SiteFreq >= HotLimit)
     return Params.LocallyHotCallSiteThreshold;
 
   // Otherwise treat it normally.
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index f07fc4fc52bffba..e24361c1f93970d 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -198,7 +198,7 @@ static cl::opt<bool> BBSectionsGuidedSectionPrefix(
              "impacted, i.e., their prefixes will be decided by FDO/sampleFDO "
              "profiles."));
 
-static cl::opt<unsigned> FreqRatioToSkipMerge(
+static cl::opt<uint64_t> FreqRatioToSkipMerge(
     "cgp-freq-ratio-to-skip-merge", cl::Hidden, cl::init(2),
     cl::desc("Skip merging empty blocks if (frequency of empty block) / "
              "(frequency of destination block) is greater than this ratio"));
@@ -978,16 +978,19 @@ bool CodeGenPrepare::isMergingEmptyBlockProfitable(BasicBlock *BB,
   if (SameIncomingValueBBs.count(Pred))
     return true;
 
   BlockFrequency PredFreq = BFI->getBlockFreq(Pred);
   BlockFrequency BBFreq = BFI->getBlockFreq(BB);
 
   for (auto *SameValueBB : SameIncomingValueBBs)
     if (SameValueBB->getUniquePredecessor() == Pred &&
         DestBB == findDestBlockOfMergeableEmptyBlock(SameValueBB))
       BBFreq += BFI->getBlockFreq(SameValueBB);
 
-  return PredFreq.getFrequency() <=
-         BBFreq.getFrequency() * FreqRatioToSkipMerge;
+  // Merging is unprofitable only if PredFreq exceeds BBFreq scaled by the
+  // ratio. An overflowing product is larger than any frequency, so merge.
+  BlockFrequency Limit;
+  bool Overflow = BBFreq.mul(FreqRatioToSkipMerge, &Limit);
+  return Overflow || PredFreq <= Limit;
 }
 
 /// Return true if we can merge BB into DestBB if there is a single
diff --git a/llvm/lib/Support/BlockFrequency.cpp b/llvm/lib/Support/BlockFrequency.cpp
index a4a1e477d9403f7..08fe3ef6061ecae 100644
--- a/llvm/lib/Support/BlockFrequency.cpp
+++ b/llvm/lib/Support/BlockFrequency.cpp
@@ -12,6 +12,7 @@
 
 #include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/MathExtras.h"
 
 using namespace llvm;
 
@@ -36,3 +37,10 @@ BlockFrequency BlockFrequency::operator/(BranchProbability Prob) const {
   Freq /= Prob;
   return Freq;
 }
+
+bool BlockFrequency::mul(uint64_t Factor, BlockFrequency *Result) const {
+  // SaturatingMultiply clamps the product to UINT64_MAX on overflow.
+  bool DidOverflow = false;
+  *Result = BlockFrequency(SaturatingMultiply(Frequency, Factor, &DidOverflow));
+  return DidOverflow;
+}



More information about the llvm-commits mailing list