[llvm] Add setBranchWeights convenience function (PR #72446)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 15 14:46:17 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-pgo

Author: Matthias Braun (MatzeB)

<details>
<summary>Changes</summary>

Add `setBranchWeights` convenience function to ProfDataUtils.h and use
it where appropriate.


---
Full diff: https://github.com/llvm/llvm-project/pull/72446.diff


11 Files Affected:

- (modified) llvm/include/llvm/IR/ProfDataUtils.h (+4) 
- (modified) llvm/lib/IR/ProfDataUtils.cpp (+7) 
- (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+7-7) 
- (modified) llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp (+1-2) 
- (modified) llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp (+2-4) 
- (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+3-5) 
- (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+7-11) 
- (modified) llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp (+2-4) 
- (modified) llvm/lib/Transforms/Utils/Local.cpp (+1-3) 
- (modified) llvm/lib/Transforms/Utils/LoopPeel.cpp (+4-13) 
- (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+10-9) 


``````````diff
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b61199372de0de8..255fa2ff1c79065 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 /// metadata was found.
 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
 
+/// Create a new `branch_weights` metadata node and add or overwrite
+/// a `prof` metadata reference to instruction `I`.
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+
 } // namespace llvm
 #endif
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 77b3c1cb95d686c..29536b0b090cd76 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+  MDBuilder MDB(I.getContext());
+  MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+  I.setMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 063f7b42022ff83..6c6f0a0eca72a7a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/PseudoProbe.h"
 #include "llvm/IR/ValueSymbolTable.h"
 #include "llvm/ProfileData/InstrProf.h"
@@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          I.setMetadata(LLVMContext::MD_prof,
-                        MDB.createBranchWeights(
-                            {static_cast<uint32_t>(BlockWeights[BB])}));
+          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       // clear it for cold code.
       for (auto &I : *BB) {
         if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
-          if (cast<CallBase>(I).isIndirectCall())
+          if (cast<CallBase>(I).isIndirectCall()) {
             I.setMetadata(LLVMContext::MD_prof, nullptr);
-          else
-            I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+          } else {
+            setBranchWeights(I, {uint32_t(0)});
+          }
         }
       }
     }
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
     if (MaxWeight > 0 &&
         (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
       LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
-      TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+      setBranchWeights(*TI, Weights);
       ORE->emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
                << "most popular destination for conditional branches at "
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 97cf583510f9395..0a3d8d6000cf47d 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
       static_cast<uint32_t>(CHRBranchBias.scale(1000)),
       static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
   };
-  MDBuilder MDB(F.getContext());
-  MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*MergedBR, Weights);
   CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
             << "\n");
 }
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017a8a..7344fea17517191 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Value.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
       promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
 
   if (AttachProfToDirectCall) {
-    MDBuilder MDB(NewInst.getContext());
-    NewInst.setMetadata(
-        LLVMContext::MD_prof,
-        MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+    setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
   }
 
   using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index aea0f2f7cae7786..4a5a0b25bebbaf1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
     // If A is uncovered, set weight=1.
     // This setup will allow BFI to give nonzero profile counts to only covered
     // blocks.
-    SmallVector<unsigned, 4> Weights;
+    SmallVector<uint32_t, 4> Weights;
     for (auto *Succ : successors(&BB))
       Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
     if (Weights.size() >= 2)
-      BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
-                                      MDB.createBranchWeights(Weights));
+      llvm::setBranchWeights(*BB.getTerminator(), Weights);
   }
 
   unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
 
 void llvm::setProfMetadata(Module *M, Instruction *TI,
                            ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
-  MDBuilder MDB(M->getContext());
   assert(MaxCount > 0 && "Bad max count");
   uint64_t Scale = calculateCountScale(MaxCount);
   SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
 
   misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
 
-  TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*TI, Weights);
   if (EmitBranchProbability) {
     std::string BrCondStr = getBranchCondString(TI);
     if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 7a8128c5b6c0901..2d899f100f8154d 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     if (BP >= BranchProbability(50, 100))
       continue;
 
-    SmallVector<uint32_t, 2> Weights;
+    uint32_t Weights[2];
     if (PredBr->getSuccessor(0) == PredOutEdge.second) {
-      Weights.push_back(BP.getNumerator());
-      Weights.push_back(BP.getCompl().getNumerator());
+      Weights[0] = BP.getNumerator();
+      Weights[1] = BP.getCompl().getNumerator();
     } else {
-      Weights.push_back(BP.getCompl().getNumerator());
-      Weights.push_back(BP.getNumerator());
+      Weights[0] = BP.getCompl().getNumerator();
+      Weights[1] = BP.getNumerator();
     }
-    PredBr->setMetadata(LLVMContext::MD_prof,
-                        MDBuilder(PredBr->getParent()->getContext())
-                            .createBranchWeights(Weights));
+    setBranchWeights(*PredBr, Weights);
   }
 }
 
@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
       Weights.push_back(Prob.getNumerator());
 
     auto TI = BB->getTerminator();
-    TI->setMetadata(
-        LLVMContext::MD_prof,
-        MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+    setBranchWeights(*TI, Weights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index ac87ee736c0d169..6f87e4d91d2c794 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 
@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
   misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
 
   SI.setCondition(ArgValue);
-
-  SI.setMetadata(LLVMContext::MD_prof,
-                 MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+  setBranchWeights(SI, Weights);
   return true;
 }
 
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index aacf66bfe38eb91..c8c8d02cdcbb2b3 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
           // Remove weight for this case.
           std::swap(Weights[Idx + 1], Weights.back());
           Weights.pop_back();
-          SI->setMetadata(LLVMContext::MD_prof,
-                          MDBuilder(BB->getContext()).
-                          createBranchWeights(Weights));
+          setBranchWeights(*SI, Weights);
         }
         // Remove this entry.
         BasicBlock *ParentBB = SI->getParent();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2881444206b0bb5..7566f70661baf48 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -631,9 +631,7 @@ struct WeightInfo {
 /// To avoid dealing with division rounding we can just multiple both part
 /// of weights to E and use weight as (F - I * E, E).
 static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
+  setBranchWeights(*Term, Info.Weights);
   for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
     if (SubWeight != 0)
       // Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
   }
 }
 
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
-}
-
 /// Clones the body of the loop L, putting it between \p InsertTop and \p
 /// InsertBot.
 /// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
     PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
   }
 
-  for (const auto &[Term, Info] : Weights)
-    fixupBranchWeights(Term, Info);
+  for (const auto &[Term, Info] : Weights) {
+    setBranchWeights(*Term, Info.Weights);
+  }
 
   // Update Metadata for count of peeled off iterations.
   unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 012aa5dbb9ca004..ae155ac082d8111 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     LoopBackWeight = 0;
   }
 
-  MDBuilder MDB(LoopBI.getContext());
-  MDNode *LoopWeightMD =
-      MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
-                              SuccsSwapped ? ExitWeight1 : LoopBackWeight);
-  LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+  const uint32_t LoopBIWeights[] = {
+      SuccsSwapped ? LoopBackWeight : ExitWeight1,
+      SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+  };
+  setBranchWeights(LoopBI, LoopBIWeights);
   if (HasConditionalPreHeader) {
-    MDNode *PreHeaderWeightMD =
-        MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
-                                SuccsSwapped ? ExitWeight0 : EnterWeight);
-    PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+    const uint32_t PreHeaderBIWeights[] = {
+        SuccsSwapped ? EnterWeight : ExitWeight0,
+        SuccsSwapped ? ExitWeight0 : EnterWeight,
+    };
+    setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
   }
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/72446


More information about the llvm-commits mailing list