[llvm] 9403514 - [LoopPredication] Calculate profitability without BPI

Anna Thomas via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 19 11:25:15 PDT 2021


Author: Anna Thomas
Date: 2021-10-19T14:24:04-04:00
New Revision: 9403514e764950a0dfcd627fc90c73432314bced

URL: https://github.com/llvm/llvm-project/commit/9403514e764950a0dfcd627fc90c73432314bced
DIFF: https://github.com/llvm/llvm-project/commit/9403514e764950a0dfcd627fc90c73432314bced.diff

LOG: [LoopPredication] Calculate profitability without BPI

Using BPI within loop predication is non-trivial because BPI is only
preserved lossily in the loop pass manager (one fix for an issue exposed
by lossy preservation is up for review at D111448). However, since loop
predication is only used in downstream pipelines, it is hard to keep
upstream changes in BPI from breaking it when the BPI state is incomplete.
Also, correctly preserving BPI for all loop passes is a non-trivial
undertaking (D110438 does this lossily), while the benefit of using it
in loop predication isn't clear.

In this patch, we rely on profile metadata to get nearly the same benefit
as BPI, without actually using the complete heuristics provided by BPI.
This avoids the compile-time explosion we tried to fix with D110438 and
also avoids the hard-to-track bugs that arise because BPI can be lossy in
loop passes (D111448).

Reviewed-By: asbirlea, apilipenko
Differential Revision: https://reviews.llvm.org/D111668

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopPredication.cpp
    llvm/test/Transforms/LoopPredication/profitability.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index d8ddcd6a18cc6..aa7e79a589f27 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -256,7 +256,6 @@ class LoopPredication {
   DominatorTree *DT;
   ScalarEvolution *SE;
   LoopInfo *LI;
-  BranchProbabilityInfo *BPI;
   MemorySSAUpdater *MSSAU;
 
   Loop *L;
@@ -305,16 +304,15 @@ class LoopPredication {
   // If the loop always exits through another block in the loop, we should not
   // predicate based on the latch check. For example, the latch check can be a
   // very coarse grained check and there can be more fine grained exit checks
-  // within the loop. We identify such unprofitable loops through BPI.
+  // within the loop.
   bool isLoopProfitableToPredicate();
 
   bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
 
 public:
   LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
-                  LoopInfo *LI, BranchProbabilityInfo *BPI,
-                  MemorySSAUpdater *MSSAU)
-      : AA(AA), DT(DT), SE(SE), LI(LI), BPI(BPI), MSSAU(MSSAU){};
+                  LoopInfo *LI, MemorySSAUpdater *MSSAU)
+      : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
   bool runOnLoop(Loop *L);
 };
 
@@ -341,10 +339,8 @@ class LoopPredicationLegacyPass : public LoopPass {
     std::unique_ptr<MemorySSAUpdater> MSSAU;
     if (MSSAWP)
       MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
-    BranchProbabilityInfo &BPI =
-        getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
     auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
-    LoopPredication LP(AA, DT, SE, LI, &BPI, MSSAU ? MSSAU.get() : nullptr);
+    LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr);
     return LP.runOnLoop(L);
   }
 };
@@ -369,7 +365,7 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
   std::unique_ptr<MemorySSAUpdater> MSSAU;
   if (AR.MSSA)
     MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
-  LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, AR.BPI,
+  LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
                      MSSAU ? MSSAU.get() : nullptr);
   if (!LP.runOnLoop(&L))
     return PreservedAnalyses::all();
@@ -922,7 +918,7 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
 
 
 bool LoopPredication::isLoopProfitableToPredicate() {
-  if (SkipProfitabilityChecks || !BPI)
+  if (SkipProfitabilityChecks)
     return true;
 
   SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
@@ -944,8 +940,61 @@ bool LoopPredication::isLoopProfitableToPredicate() {
          "expected to be an exiting block with 2 succs!");
   unsigned LatchBrExitIdx =
       LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
+  // We compute branch probabilities without BPI. We do not rely on BPI since
+  // Loop predication is usually run in an LPM and BPI is only preserved
+  // lossily within loop pass managers, while BPI has an inherent notion of
+  // being complete for an entire function.
+
+  // If the latch exits into a deoptimize or an unreachable block, do not
+  // predicate on that latch check.
+  auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
+  if (isa<UnreachableInst>(LatchTerm) ||
+      LatchExitBlock->getTerminatingDeoptimizeCall())
+    return false;
+
+  auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) {
+    if (!ProfileData || !ProfileData->getOperand(0))
+      return false;
+    if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0)))
+      if (!MDS->getString().equals("branch_weights"))
+        return false;
+    if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors())
+      return false;
+    return true;
+  };
+  MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof);
+  // Latch terminator has no valid profile data, so nothing to check
+  // profitability on.
+  if (!IsValidProfileData(LatchProfileData, LatchTerm))
+    return true;
+
+  auto ComputeBranchProbability =
+      [&](const BasicBlock *ExitingBlock,
+          const BasicBlock *ExitBlock) -> BranchProbability {
+    auto *Term = ExitingBlock->getTerminator();
+    MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof);
+    unsigned NumSucc = Term->getNumSuccessors();
+    if (IsValidProfileData(ProfileData, Term)) {
+      uint64_t Numerator = 0, Denominator = 0, ProfVal = 0;
+      for (unsigned i = 0; i < NumSucc; i++) {
+        ConstantInt *CI =
+            mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1));
+        ProfVal = CI->getValue().getZExtValue();
+        if (Term->getSuccessor(i) == ExitBlock)
+          Numerator += ProfVal;
+        Denominator += ProfVal;
+      }
+      return BranchProbability::getBranchProbability(Numerator, Denominator);
+    } else {
+      assert(LatchBlock != ExitingBlock &&
+             "Latch term should always have profile data!");
+      // No profile data, so we choose the weight as 1/num_of_succ(Src)
+      return BranchProbability::getBranchProbability(1, NumSucc);
+    }
+  };
+
   BranchProbability LatchExitProbability =
-      BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx);
+      ComputeBranchProbability(LatchBlock, LatchExitBlock);
 
   // Protect against degenerate inputs provided by the user. Providing a value
   // less than one, can invert the definition of profitable loop predication.
@@ -958,18 +1007,18 @@ bool LoopPredication::isLoopProfitableToPredicate() {
     LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
     ScaleFactor = 1.0;
   }
-  const auto LatchProbabilityThreshold =
-      LatchExitProbability * ScaleFactor;
+  const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
 
   for (const auto &ExitEdge : ExitEdges) {
     BranchProbability ExitingBlockProbability =
-        BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second);
+        ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
     // Some exiting edge has higher probability than the latch exiting edge.
     // No longer profitable to predicate.
     if (ExitingBlockProbability > LatchProbabilityThreshold)
       return false;
   }
-  // Using BPI, we have concluded that the most probable way to exit from the
+
+  // We have concluded that the most probable way to exit from the
   // loop is through the latch (or there's no profile information and all
   // exits are equally likely).
   return true;

diff  --git a/llvm/test/Transforms/LoopPredication/profitability.ll b/llvm/test/Transforms/LoopPredication/profitability.ll
index 942e0fea2a777..aae893b86c9d5 100644
--- a/llvm/test/Transforms/LoopPredication/profitability.ll
+++ b/llvm/test/Transforms/LoopPredication/profitability.ll
@@ -1,9 +1,9 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -S -loop-predication-skip-profitability-checks=false -passes='require<scalar-evolution>,require<branch-prob>,loop-mssa(loop-predication)' -verify-memoryssa < %s 2>&1 | FileCheck %s
+; RUN: opt -S -loop-predication-skip-profitability-checks=false -passes='require<scalar-evolution>,loop-mssa(loop-predication)' -verify-memoryssa < %s 2>&1 | FileCheck %s
 
-; latch block exits to a speculation block. BPI already knows (without prof
-; data) that deopt is very rarely
-; taken. So we do not predicate this loop using that coarse latch check.
+; latch block exits to a speculation block. We account for this since deopt is
+; very rarely taken. So we do not predicate this loop using that coarse latch
+; check.
 ; LatchExitProbability: 0x04000000 / 0x80000000 = 3.12%
 ; ExitingBlockProbability: 0x7ffa572a / 0x80000000 = 99.98%
 define i64 @donot_predicate(i64* nocapture readonly %arg, i32 %length, i64* nocapture readonly %arg2, i64* nocapture readonly %n_addr, i64 %i) !prof !21 {


        


More information about the llvm-commits mailing list