[llvm] d434e40 - [llvm][NFC] Refactor code to use ProfDataUtils

Paul Kirth via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 2 17:09:53 PDT 2022


Author: Paul Kirth
Date: 2022-08-03T00:09:45Z
New Revision: d434e40f398e3144c69d57d2a142d35e2f760a8e

URL: https://github.com/llvm/llvm-project/commit/d434e40f398e3144c69d57d2a142d35e2f760a8e
DIFF: https://github.com/llvm/llvm-project/commit/d434e40f398e3144c69d57d2a142d35e2f760a8e.diff

LOG: [llvm][NFC] Refactor code to use ProfDataUtils

In this patch we replace common code patterns with the use of utility
functions for dealing with profiling metadata. There should be no change
in functionality, as the existing checks should be preserved in all
cases.

Reviewed By: bogner, davidxl

Differential Revision: https://reviews.llvm.org/D128860

Added: 
    

Modified: 
    llvm/include/llvm/IR/Instruction.h
    llvm/include/llvm/IR/ProfDataUtils.h
    llvm/lib/Analysis/BranchProbabilityInfo.cpp
    llvm/lib/CodeGen/CodeGenPrepare.cpp
    llvm/lib/CodeGen/SelectOptimize.cpp
    llvm/lib/IR/Metadata.cpp
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
    llvm/lib/Transforms/IPO/PartialInlining.cpp
    llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
    llvm/lib/Transforms/Scalar/JumpThreading.cpp
    llvm/lib/Transforms/Utils/LoopPeel.cpp
    llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
    llvm/lib/Transforms/Utils/LoopUtils.cpp
    llvm/lib/Transforms/Utils/MisExpect.cpp
    llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index 15b0bdf557fb1..1044a634408ca 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -335,11 +335,6 @@ class Instruction : public User,
   /// Sets the AA metadata on this instruction from the AAMDNodes structure.
   void setAAMetadata(const AAMDNodes &N);
 
-  /// Retrieve the raw weight values of a conditional branch or select.
-  /// Returns true on success with profile weights filled in.
-  /// Returns false if no metadata or invalid metadata was found.
-  bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) const;
-
   /// Retrieve total raw weight values of a branch.
   /// Returns true on success with profile total weights filled in.
   /// Returns false if no metadata was found.

diff  --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b6c53a3dc1a54..0051c41536b55 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -16,15 +16,15 @@ bool isBranchWeightMD(const MDNode *ProfileData);
 /// Checks if an instructions has Branch Weight Metadata
 ///
 /// \param I The instruction to check
-/// \return True if I has an MD_prof node containing Branch Weights. False
+/// \returns True if I has an MD_prof node containing Branch Weights. False
 /// otherwise.
 bool hasBranchWeightMD(const Instruction &I);
 
 /// Extract branch weights from MD_prof metadata
 ///
 /// \param ProfileData A pointer to an MDNode.
-/// \param Weights An output vector to fill with branch weights
-/// \return True if weights were extracted, False otherwise. When false Weights
+/// \param [out] Weights An output vector to fill with branch weights
+/// \returns True if weights were extracted, False otherwise. When false Weights
 /// will be cleared.
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights);
@@ -32,24 +32,28 @@ bool extractBranchWeights(const MDNode *ProfileData,
 /// Extract branch weights attatched to an Instruction
 ///
 /// \param I The Instruction to extract weights from.
-/// \param Weights An output vector to fill with branch weights
-/// \return True if weights were extracted, False otherwise. When false Weights
+/// \param [out] Weights An output vector to fill with branch weights
+/// \returns True if weights were extracted, False otherwise. When false Weights
 /// will be cleared.
 bool extractBranchWeights(const Instruction &I,
                           SmallVectorImpl<uint32_t> &Weights);
 
-/// Retrieve the raw weight values of a conditional branch or select.
-/// Returns true on success with profile weights filled in.
-/// Returns false if no metadata or invalid metadata was found.
+/// Extract branch weights from a conditional branch or select Instruction.
+///
+/// \param I The instruction to extract branch weights from.
+/// \param [out] TrueVal will contain the branch weight for the True branch
+/// \param [out] FalseVal will contain the branch weight for the False branch
+/// \returns True on success with profile weights filled in. False if no
+/// metadata or invalid metadata was found.
 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
                           uint64_t &FalseVal);
 
 /// Retrieve the total of all weights from MD_prof data.
 ///
 /// \param ProfileData The profile data to extract the total weight from
-/// \param TotalWeights input variable to fill with total weights
-/// \return true on success with profile total weights filled in.
-/// \return false if no metadata was found.
+/// \param [out] TotalWeights output variable to fill with total weights
+/// \returns True on success with profile total weights filled in. False if no
+/// metadata was found.
 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 
 } // namespace llvm

diff  --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index f45728768fcdb..8918fb967594e 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -31,6 +31,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -401,24 +402,18 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
   SmallVector<uint32_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
-  Weights.reserve(TI->getNumSuccessors());
-  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
-    if (!Weight)
-      return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+
+  extractBranchWeights(*TI, Weights);
+  for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+    WeightSum += Weights[I];
     const LoopBlock SrcLoopBB = getLoopBlock(BB);
-    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
+    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
         *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
-      UnreachableIdxs.push_back(I - 1);
+      UnreachableIdxs.push_back(I);
     else
-      ReachableIdxs.push_back(I - 1);
+      ReachableIdxs.push_back(I);
   }
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
 

diff  --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d5d08f04d4cd8..b100fbe2b33c4 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -65,6 +65,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Statepoint.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
@@ -6620,7 +6621,7 @@ static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI,
   // If metadata tells us that the select condition is obviously predictable,
   // then we want to replace the select with a branch.
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -8366,7 +8367,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // Another choice is to assume TrueProb for BB1 equals to TrueProb for
       // TmpBB, but the math is more complicated.
       uint64_t TrueWeight, FalseWeight;
-      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
+      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = TrueWeight;
         uint64_t NewFalseWeight = TrueWeight + 2 * FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);
@@ -8399,7 +8400,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // assumes that
       //   FalseProb for BB1 == TrueProb for BB1 * FalseProb for TmpBB.
       uint64_t TrueWeight, FalseWeight;
-      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
+      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = 2 * TrueWeight + FalseWeight;
         uint64_t NewFalseWeight = FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);

diff  --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index 011f55efce1d5..a7a698744591c 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ScaledNumber.h"
@@ -655,7 +656,7 @@ bool SelectOptimize::hasExpensiveColdOperand(
     const SmallVector<SelectInst *, 2> &ASI) {
   bool ColdOperand = false;
   uint64_t TrueWeight, FalseWeight, TotalWeight;
-  if (ASI.front()->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*ASI.front(), TrueWeight, FalseWeight)) {
     uint64_t MinWeight = std::min(TrueWeight, FalseWeight);
     TotalWeight = TrueWeight + FalseWeight;
     // Is there a path with frequency <ColdOperandThreshold% (default:20%) ?
@@ -750,7 +751,7 @@ void SelectOptimize::getExclBackwardsSlice(Instruction *I,
 
 bool SelectOptimize::isSelectHighlyPredictable(const SelectInst *SI) {
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -959,7 +960,7 @@ SelectOptimize::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost,
                                      const SelectInst *SI) {
   Scaled64 PredPathCost;
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t SumWeight = TrueWeight + FalseWeight;
     if (SumWeight != 0) {
       PredPathCost = TrueCost * Scaled64::get(TrueWeight) +

diff  --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index 2a1a514922fdc..a2027439240fe 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/TrackingMDRef.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -1493,31 +1494,6 @@ void Instruction::getAllMetadataImpl(
   Value::getAllMetadata(Result);
 }
 
-bool Instruction::extractProfMetadata(uint64_t &TrueVal,
-                                      uint64_t &FalseVal) const {
-  assert(
-      (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select) &&
-      "Looking for branch weights on something besides branch or select");
-
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData || ProfileData->getNumOperands() != 3)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return false;
-
-  auto *CITrue = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1));
-  auto *CIFalse = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
-  if (!CITrue || !CIFalse)
-    return false;
-
-  TrueVal = CITrue->getValue().getZExtValue();
-  FalseVal = CIFalse->getValue().getZExtValue();
-
-  return true;
-}
-
 bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select ||
@@ -1526,32 +1502,7 @@ bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
        getOpcode() == Instruction::Switch) &&
       "Looking for branch weights on something besides branch");
 
-  TotalVal = 0;
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName)
-    return false;
-
-  if (ProfDataName->getString().equals("branch_weights")) {
-    TotalVal = 0;
-    for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) {
-      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i));
-      if (!V)
-        return false;
-      TotalVal += V->getValue().getZExtValue();
-    }
-    return true;
-  } else if (ProfDataName->getString().equals("VP") &&
-             ProfileData->getNumOperands() > 3) {
-    TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
-                   ->getValue()
-                   .getZExtValue();
-    return true;
-  }
-  return false;
+  return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
 void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index cf728933c08d2..c8945879d0bc3 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -15,6 +15,7 @@
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/TargetSchedule.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/KnownBits.h"
@@ -757,7 +758,7 @@ bool PPCTTIImpl::isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
     if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
       uint64_t TrueWeight = 0, FalseWeight = 0;
       if (!BI->isConditional() ||
-          !BI->extractProfMetadata(TrueWeight, FalseWeight))
+          !extractBranchWeights(*BI, TrueWeight, FalseWeight))
         continue;
 
       // If the exit path is more frequent than the loop path,

diff  --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 54c72bdbb203f..ec2e7fb215b2f 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/User.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -717,7 +718,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) {
     if (!BR || BR->isUnconditional())
       continue;
     uint64_t T, F;
-    if (BR->extractProfMetadata(T, F))
+    if (extractBranchWeights(*BR, T, F))
       return true;
   }
   return false;

diff  --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c4512d0222cde..90cc61c6b291b 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -91,6 +91,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -2067,7 +2068,7 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
       // Display scaled counts for SELECT instruction:
       OS << "SELECT : { T = ";
       uint64_t TC, FC;
-      bool HasProf = I.extractProfMetadata(TC, FC);
+      bool HasProf = extractBranchWeights(I, TC, FC);
       if (!HasProf)
         OS << "Unknown, F = Unknown }\\l";
       else

diff  --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index b31eab50c5ecc..0113e7bf570df 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -54,6 +54,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/Value.h"
@@ -216,7 +217,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     return;
 
   uint64_t TrueWeight, FalseWeight;
-  if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*CondBr, TrueWeight, FalseWeight))
     return;
 
   if (TrueWeight + FalseWeight == 0)
@@ -279,7 +280,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     // With PGO, this can be used to refine even existing profile data with
     // context information. This needs to be done after more performance
     // testing.
-    if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight))
+    if (extractBranchWeights(*PredBr, PredTrueWeight, PredFalseWeight))
       continue;
 
     // We can not infer anything useful when BP >= 50%, because BP is the

diff  --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index f093fea19c4d4..9a7f9df7a330f 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -532,7 +533,7 @@ static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
                               uint64_t &ExitWeight,
                               uint64_t &FallThroughWeight) {
   uint64_t TrueWeight, FalseWeight;
-  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
     return;
   unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
   ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;

diff  --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 023a0afd329b5..1c44ccb589bcc 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -471,7 +472,7 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop,
   uint64_t TrueWeight, FalseWeight;
   BranchInst *LatchBR =
       cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator());
-  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
     return;
   uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader()
                             ? FalseWeight

diff  --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 349063dd5e892..03272e5f8a549 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -38,6 +38,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -789,7 +790,7 @@ getEstimatedTripCount(BranchInst *ExitingBranch, Loop *L,
   // know the number of times the backedge was taken, vs. the number of times
   // we exited the loop.
   uint64_t LoopWeight, ExitWeight;
-  if (!ExitingBranch->extractProfMetadata(LoopWeight, ExitWeight))
+  if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
     return None;
 
   if (L->contains(ExitingBranch->getSuccessor(1)))

diff  --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
index d85e9ddb2beec..8f94902b331fd 100644
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -118,34 +119,6 @@ void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
 namespace llvm {
 namespace misexpect {
 
-// Helper function to extract branch weights into a vector
-Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
-                                                  LLVMContext &Ctx) {
-  assert(I && "MisExpect::extractWeights given invalid pointer");
-
-  auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return None;
-
-  unsigned NOps = ProfileData->getNumOperands();
-  if (NOps < 3)
-    return None;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return None;
-
-  SmallVector<uint32_t, 4> Weights(NOps - 1);
-  for (unsigned Idx = 1; Idx < NOps; Idx++) {
-    ConstantInt *Value =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    uint32_t V = Value->getZExtValue();
-    Weights[Idx - 1] = V;
-  }
-
-  return Weights;
-}
-
 // TODO: when clang allows c++17, use std::clamp instead
 uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
   if (value > hi)
@@ -218,19 +191,17 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights,
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef<uint32_t> RealWeights) {
-  auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
-  if (!ExpectedWeightsOpt)
+  SmallVector<uint32_t> ExpectedWeights;
+  if (!extractBranchWeights(I, ExpectedWeights))
     return;
-  auto ExpectedWeights = ExpectedWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
 void checkFrontendInstrumentation(Instruction &I,
                                   const ArrayRef<uint32_t> ExpectedWeights) {
-  auto RealWeightsOpt = extractWeights(&I, I.getContext());
-  if (!RealWeightsOpt)
+  SmallVector<uint32_t> RealWeights;
+  if (!extractBranchWeights(I, RealWeights))
     return;
-  auto RealWeights = RealWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 

diff  --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 1806081678a86..bba831521ff59 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -57,6 +57,7 @@
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1050,15 +1051,6 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
   return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
 }
 
-static inline bool HasBranchWeights(const Instruction *I) {
-  MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
-  if (ProfMD && ProfMD->getOperand(0))
-    if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
-      return MDS->getString().equals("branch_weights");
-
-  return false;
-}
-
 /// Get Weights of a given terminator, the default weight is at the front
 /// of the vector. If TI is a conditional eq, we need to swap the branch-weight
 /// metadata.
@@ -1177,8 +1169,8 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
 
   // Update the branch weight metadata along the way
   SmallVector<uint64_t, 8> Weights;
-  bool PredHasWeights = HasBranchWeights(PTI);
-  bool SuccHasWeights = HasBranchWeights(TI);
+  bool PredHasWeights = hasBranchWeightMD(*PTI);
+  bool SuccHasWeights = hasBranchWeightMD(*TI);
 
   if (PredHasWeights) {
     GetBranchWeights(PTI, Weights);
@@ -2752,7 +2744,8 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
   // the `then` block, then avoid speculating it.
   if (!BI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (BI->extractProfMetadata(TWeight, FWeight) && (TWeight + FWeight) != 0) {
+    if (extractBranchWeights(*BI, TWeight, FWeight) &&
+        (TWeight + FWeight) != 0) {
       uint64_t EndWeight = Invert ? TWeight : FWeight;
       BranchProbability BIEndProb =
           BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight);
@@ -3174,7 +3167,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
   // from the block that we know is predictably not entered.
   if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (DomBI->extractProfMetadata(TWeight, FWeight) &&
+    if (extractBranchWeights(*DomBI, TWeight, FWeight) &&
         (TWeight + FWeight) != 0) {
       BranchProbability BITrueProb =
           BranchProbability::getBranchProbability(TWeight, TWeight + FWeight);
@@ -3354,9 +3347,9 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI,
                                    uint64_t &SuccTrueWeight,
                                    uint64_t &SuccFalseWeight) {
   bool PredHasWeights =
-      PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight);
+      extractBranchWeights(*PBI, PredTrueWeight, PredFalseWeight);
   bool SuccHasWeights =
-      BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight);
+      extractBranchWeights(*BI, SuccTrueWeight, SuccFalseWeight);
   if (PredHasWeights || SuccHasWeights) {
     if (!PredHasWeights)
       PredTrueWeight = PredFalseWeight = 1;
@@ -3384,7 +3377,7 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI,
   uint64_t PTWeight, PFWeight;
   BranchProbability PBITrueProb, Likely;
   if (TTI && !PBI->getMetadata(LLVMContext::MD_unpredictable) &&
-      PBI->extractProfMetadata(PTWeight, PFWeight) &&
+      extractBranchWeights(*PBI, PTWeight, PFWeight) &&
       (PTWeight + PFWeight) != 0) {
     PBITrueProb =
         BranchProbability::getBranchProbability(PTWeight, PTWeight + PFWeight);
@@ -4349,7 +4342,7 @@ bool SimplifyCFGOpt::SimplifySwitchOnSelect(SwitchInst *SI,
   // Get weight for TrueBB and FalseBB.
   uint32_t TrueWeight = 0, FalseWeight = 0;
   SmallVector<uint64_t, 8> Weights;
-  bool HasWeights = HasBranchWeights(SI);
+  bool HasWeights = hasBranchWeightMD(*SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
@@ -5209,7 +5202,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI,
   BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
 
   // Update weight for the newly-created conditional branch.
-  if (HasBranchWeights(SI)) {
+  if (hasBranchWeightMD(*SI)) {
     SmallVector<uint64_t, 8> Weights;
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {


        


More information about the llvm-commits mailing list