[llvm] LoopVectorize: Set branch_weight for conditional branches (PR #72450)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 15 15:03:41 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-pgo
Author: Matthias Braun (MatzeB)
<details>
<summary>Changes</summary>
Consistently add `branch_weights` metadata to every conditional branch
created in `LoopVectorize.cpp`:
- Will only add metadata if the original loop-latch branch had metadata
assigned.
- Most checks should rarely trigger, so I am using a 127:1 ratio in favor of not bypassing the vectorized loop.
- For the middle block, we assume that the results of the trip-count modulo
are equally distributed.
---
Patch is 44.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72450.diff
14 Files Affected:
- (modified) llvm/include/llvm/IR/ProfDataUtils.h (+4)
- (modified) llvm/lib/IR/ProfDataUtils.cpp (+7)
- (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+7-7)
- (modified) llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp (+1-2)
- (modified) llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp (+2-4)
- (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+3-5)
- (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+7-11)
- (modified) llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp (+2-4)
- (modified) llvm/lib/Transforms/Utils/Local.cpp (+1-3)
- (modified) llvm/lib/Transforms/Utils/LoopPeel.cpp (+4-13)
- (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+10-9)
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+60-19)
- (added) llvm/test/Transforms/LoopVectorize/branch-weights.ll (+81)
- (modified) llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll (+30-30)
``````````diff
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b61199372de0de8..255fa2ff1c79065 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
/// metadata was found.
bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
+/// Create a new `branch_weights` metadata node and add or overwrite
+/// a `prof` metadata reference to instruction `I`.
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+
} // namespace llvm
#endif
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 77b3c1cb95d686c..29536b0b090cd76 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+ MDBuilder MDB(I.getContext());
+ MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+ I.setMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
} // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 063f7b42022ff83..6c6f0a0eca72a7a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/PseudoProbe.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/ProfileData/InstrProf.h"
@@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
- I.setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(
- {static_cast<uint32_t>(BlockWeights[BB])}));
+ setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
// clear it for cold code.
for (auto &I : *BB) {
if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
- if (cast<CallBase>(I).isIndirectCall())
+ if (cast<CallBase>(I).isIndirectCall()) {
I.setMetadata(LLVMContext::MD_prof, nullptr);
- else
- I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+ } else {
+ setBranchWeights(I, {uint32_t(0)});
+ }
}
}
}
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
if (MaxWeight > 0 &&
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
- TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
ORE->emit([&]() {
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
<< "most popular destination for conditional branches at "
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 97cf583510f9395..0a3d8d6000cf47d 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
};
- MDBuilder MDB(F.getContext());
- MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*MergedBR, Weights);
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
<< "\n");
}
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017a8a..7344fea17517191 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Value.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
if (AttachProfToDirectCall) {
- MDBuilder MDB(NewInst.getContext());
- NewInst.setMetadata(
- LLVMContext::MD_prof,
- MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+ setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
}
using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index aea0f2f7cae7786..4a5a0b25bebbaf1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
// If A is uncovered, set weight=1.
// This setup will allow BFI to give nonzero profile counts to only covered
// blocks.
- SmallVector<unsigned, 4> Weights;
+ SmallVector<uint32_t, 4> Weights;
for (auto *Succ : successors(&BB))
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
if (Weights.size() >= 2)
- BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Weights));
+ llvm::setBranchWeights(*BB.getTerminator(), Weights);
}
unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
void llvm::setProfMetadata(Module *M, Instruction *TI,
ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
- MDBuilder MDB(M->getContext());
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
- TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
if (EmitBranchProbability) {
std::string BrCondStr = getBranchCondString(TI);
if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 7a8128c5b6c0901..2d899f100f8154d 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
if (BP >= BranchProbability(50, 100))
continue;
- SmallVector<uint32_t, 2> Weights;
+ uint32_t Weights[2];
if (PredBr->getSuccessor(0) == PredOutEdge.second) {
- Weights.push_back(BP.getNumerator());
- Weights.push_back(BP.getCompl().getNumerator());
+ Weights[0] = BP.getNumerator();
+ Weights[1] = BP.getCompl().getNumerator();
} else {
- Weights.push_back(BP.getCompl().getNumerator());
- Weights.push_back(BP.getNumerator());
+ Weights[0] = BP.getCompl().getNumerator();
+ Weights[1] = BP.getNumerator();
}
- PredBr->setMetadata(LLVMContext::MD_prof,
- MDBuilder(PredBr->getParent()->getContext())
- .createBranchWeights(Weights));
+ setBranchWeights(*PredBr, Weights);
}
}
@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
Weights.push_back(Prob.getNumerator());
auto TI = BB->getTerminator();
- TI->setMetadata(
- LLVMContext::MD_prof,
- MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+ setBranchWeights(*TI, Weights);
}
}
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index ac87ee736c0d169..6f87e4d91d2c794 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -20,6 +20,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/MisExpect.h"
@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
SI.setCondition(ArgValue);
-
- SI.setMetadata(LLVMContext::MD_prof,
- MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+ setBranchWeights(SI, Weights);
return true;
}
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index aacf66bfe38eb91..c8c8d02cdcbb2b3 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
// Remove weight for this case.
std::swap(Weights[Idx + 1], Weights.back());
Weights.pop_back();
- SI->setMetadata(LLVMContext::MD_prof,
- MDBuilder(BB->getContext()).
- createBranchWeights(Weights));
+ setBranchWeights(*SI, Weights);
}
// Remove this entry.
BasicBlock *ParentBB = SI->getParent();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2881444206b0bb5..7566f70661baf48 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -631,9 +631,7 @@ struct WeightInfo {
/// To avoid dealing with division rounding we can just multiple both part
/// of weights to E and use weight as (F - I * E, E).
static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
- MDBuilder MDB(Term->getContext());
- Term->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Info.Weights));
+ setBranchWeights(*Term, Info.Weights);
for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
if (SubWeight != 0)
// Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
}
}
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
- MDBuilder MDB(Term->getContext());
- Term->setMetadata(LLVMContext::MD_prof,
- MDB.createBranchWeights(Info.Weights));
-}
-
/// Clones the body of the loop L, putting it between \p InsertTop and \p
/// InsertBot.
/// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
}
- for (const auto &[Term, Info] : Weights)
- fixupBranchWeights(Term, Info);
+ for (const auto &[Term, Info] : Weights) {
+ setBranchWeights(*Term, Info.Weights);
+ }
// Update Metadata for count of peeled off iterations.
unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 012aa5dbb9ca004..ae155ac082d8111 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
LoopBackWeight = 0;
}
- MDBuilder MDB(LoopBI.getContext());
- MDNode *LoopWeightMD =
- MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
- SuccsSwapped ? ExitWeight1 : LoopBackWeight);
- LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+ const uint32_t LoopBIWeights[] = {
+ SuccsSwapped ? LoopBackWeight : ExitWeight1,
+ SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+ };
+ setBranchWeights(LoopBI, LoopBIWeights);
if (HasConditionalPreHeader) {
- MDNode *PreHeaderWeightMD =
- MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
- SuccsSwapped ? ExitWeight0 : EnterWeight);
- PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+ const uint32_t PreHeaderBIWeights[] = {
+ SuccsSwapped ? EnterWeight : ExitWeight0,
+ SuccsSwapped ? ExitWeight0 : EnterWeight,
+ };
+ setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
}
}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e9d0315d114f65c..0a10f0d6471769a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -112,10 +112,12 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
@@ -396,6 +398,19 @@ static cl::opt<bool> UseWiderVFIfCallVariantsPresent(
cl::Hidden,
cl::desc("Try wider VFs if they enable the use of vector variants"));
+// Likelyhood of bypassing the vectorized loop because assumptions about SCEV
+// variables not overflowing do not hold. See `emitSCEVChecks`.
+static constexpr uint32_t SCEVCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because pointers overlap. See
+// `emitMemRuntimeChecks`.
+static constexpr uint32_t MemCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because there are zero trips left
+// after prolog. See `emitIterationCountCheck`.
+static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because of zero trips necessary.
+// See `emitMinimumVectorEpilogueIterCountCheck`.
+static constexpr uint32_t EpilogueMinItersBypassWeights[] = {1, 127};
+
/// A helper function that returns true if the given type is irregular. The
/// type is irregular if its allocated size doesn't equal the store size of an
/// element of the corresponding vector type.
@@ -1962,12 +1977,14 @@ class GeneratedRTChecks {
SCEVExpander MemCheckExp;
bool CostTooHigh = false;
+ const bool AddBranchWeights;
public:
GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
- TargetTransformInfo *TTI, const DataLayout &DL)
+ TargetTransformInfo *TTI, const DataLayout &DL,
+ bool AddBranchWeights)
: DT(DT), LI(LI), TTI(TTI), SCEVExp(SE, DL, "scev.check"),
- MemCheckExp(SE, DL, "scev.check") {}
+ MemCheckExp(SE, DL, "scev.check"), AddBranchWeights(AddBranchWeights) {}
/// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can
/// accurately estimate the cost of the runtime checks. The blocks are
@@ -2160,8 +2177,10 @@ class GeneratedRTChecks {
DT->addNewBlock(SCEVCheckBlock, Pred);
DT->changeImmediateDominator(LoopVectorPreHeader, SCEVCheckBlock);
- ReplaceInstWithInst(SCEVCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, Cond));
+ BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, Cond);
+ if (AddBranchWeights)
+ setBranchWeights(BI, SCEVCheckBypassWeights);
+ ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), &BI);
return SCEVCheckBlock;
}
@@ -2185,9 +2204,12 @@ class GeneratedRTChecks {
if (auto *PL = LI->getLoopFor(LoopVectorPreHeader))
PL->addBasicBlockToLoop(MemCheckBlock, *LI);
- ReplaceInstWithInst(
- MemCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond);
+ if (AddBranchWeights) {
+ setBranchWeights(BI, MemCheckBypassWeights);
+ }
+ ReplaceInstWithInst(MemCheckBlock->getTerminator(), &BI);
MemCheckBlock->getTerminator()->setDebugLoc(
Pred->getTerminator()->getDebugLoc());
@@ -2900,9 +2922,11 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
// dominator of the exit blocks.
DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock);
- ReplaceInstWithInst(
- TCCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+ if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+ setBranchWeights(BI, MinItersBypassWeights);
+ ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
LoopBypassBlocks.push_back(TCCheckBlock);
}
@@ -3133,7 +3157,16 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
IRBuilder<> B(LoopMiddleBlock->getTerminator());
B.SetCurrentDebugLocation(ScalarLatchTerm->getDebugLoc());
Value *CmpN = B.CreateICmpEQ(Count, VectorTripCount, "cmp.n");
- cast<BranchInst>(LoopMiddleBlock->getTerminator())->setCondition(CmpN);
+ BranchInst &BI = *cast<BranchInst>(LoopMiddleBlock->getTerminator());
+ BI.setCondition(CmpN);
+ if (hasBranchWeightMD(*ScalarLatchTerm)) {
+ // Assume that `Count % VectorTripCount` is equally distributed.
+ unsigned TripCount = UF * VF.getKnownMinValue();
+ assert(TripCount > 0 && "trip count should not be zero");
+ MDBuilder MDB(ScalarLatchTerm->getContext());
+ MDNode *BranchWeights = MDB.createBranchWeights(1, TripCount - 1);
+ BI.setMetadata(LLVMContext::MD_prof, BranchWeights);
+ }
}
#ifdef EXPENSIVE_CHECKS
@@ -7896,9 +7929,11 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
EPI.TripCount = Count;
}
- ReplaceInstWithInst(
- TCCheckBlock->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+ if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+ setBranchWeights(BI, MinItersBypassWeights);
+ ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
return TCCheckBlock;
}
@@ -8042,9 +8077,11 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
EPI.EpilogueVF, EPI.EpilogueUF),
"min.epilog.iters.check");
- ReplaceInstWithInst(
- Insert->getTerminator(),
- BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+ BranchInst &BI =
+ *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+ if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+ setBranchWeights(BI, EpilogueMinItersBypassWeights);
+ ReplaceInstWithInst(Insert->getTerminator(), &BI);
LoopBypassBlocks.push_back(Insert);
return Insert;
@@ -9731,8 +9768,10 @@ static bool processLoopInVPlanNativePath(
VPlan &BestPlan = LVP.getBestPlanFor(VF.Width);
{
+ bool AddBranchWeights =
+ hasBranchWeightMD(*L->ge...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/72450
More information about the llvm-commits
mailing list