[llvm] 300c9a7 - [llvm][NFC] Refactor code to use ProfDataUtils

Paul Kirth via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 27 14:14:06 PDT 2022


Author: Paul Kirth
Date: 2022-07-27T21:13:54Z
New Revision: 300c9a78819b4608b96bb26f9320bea6b8a0c4d0

URL: https://github.com/llvm/llvm-project/commit/300c9a78819b4608b96bb26f9320bea6b8a0c4d0
DIFF: https://github.com/llvm/llvm-project/commit/300c9a78819b4608b96bb26f9320bea6b8a0c4d0.diff

LOG: [llvm][NFC] Refactor code to use ProfDataUtils

This patch replaces common code patterns with calls to the ProfDataUtils
utility functions for handling profiling metadata. No functional change
is intended, as the existing checks should be preserved in all cases.

Reviewed By: bogner, davidxl

Differential Revision: https://reviews.llvm.org/D128860

Added: 
    

Modified: 
    llvm/include/llvm/IR/Instruction.h
    llvm/lib/Analysis/BranchProbabilityInfo.cpp
    llvm/lib/CodeGen/CodeGenPrepare.cpp
    llvm/lib/CodeGen/SelectOptimize.cpp
    llvm/lib/IR/Metadata.cpp
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
    llvm/lib/Transforms/IPO/PartialInlining.cpp
    llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
    llvm/lib/Transforms/Scalar/JumpThreading.cpp
    llvm/lib/Transforms/Utils/LoopPeel.cpp
    llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
    llvm/lib/Transforms/Utils/LoopUtils.cpp
    llvm/lib/Transforms/Utils/MisExpect.cpp
    llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index 15b0bdf557fb1..1044a634408ca 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -335,11 +335,6 @@ class Instruction : public User,
   /// Sets the AA metadata on this instruction from the AAMDNodes structure.
   void setAAMetadata(const AAMDNodes &N);
 
-  /// Retrieve the raw weight values of a conditional branch or select.
-  /// Returns true on success with profile weights filled in.
-  /// Returns false if no metadata or invalid metadata was found.
-  bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) const;
-
   /// Retrieve total raw weight values of a branch.
   /// Returns true on success with profile total weights filled in.
   /// Returns false if no metadata was found.

diff  --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index f45728768fcdb..8918fb967594e 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -31,6 +31,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -401,24 +402,18 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
   SmallVector<uint32_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
-  Weights.reserve(TI->getNumSuccessors());
-  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
-    if (!Weight)
-      return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+
+  extractBranchWeights(*TI, Weights);
+  for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+    WeightSum += Weights[I];
     const LoopBlock SrcLoopBB = getLoopBlock(BB);
-    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
+    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
         *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
-      UnreachableIdxs.push_back(I - 1);
+      UnreachableIdxs.push_back(I);
     else
-      ReachableIdxs.push_back(I - 1);
+      ReachableIdxs.push_back(I);
   }
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
 

diff  --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index b8f6fc9bbcde9..a6b3ced721599 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -65,6 +65,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Statepoint.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
@@ -6620,7 +6621,7 @@ static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI,
   // If metadata tells us that the select condition is obviously predictable,
   // then we want to replace the select with a branch.
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -8362,7 +8363,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // Another choice is to assume TrueProb for BB1 equals to TrueProb for
       // TmpBB, but the math is more complicated.
       uint64_t TrueWeight, FalseWeight;
-      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
+      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = TrueWeight;
         uint64_t NewFalseWeight = TrueWeight + 2 * FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);
@@ -8395,7 +8396,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // assumes that
       //   FalseProb for BB1 == TrueProb for BB1 * FalseProb for TmpBB.
       uint64_t TrueWeight, FalseWeight;
-      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
+      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = 2 * TrueWeight + FalseWeight;
         uint64_t NewFalseWeight = FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);

diff  --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index 011f55efce1d5..a7a698744591c 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ScaledNumber.h"
@@ -655,7 +656,7 @@ bool SelectOptimize::hasExpensiveColdOperand(
     const SmallVector<SelectInst *, 2> &ASI) {
   bool ColdOperand = false;
   uint64_t TrueWeight, FalseWeight, TotalWeight;
-  if (ASI.front()->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*ASI.front(), TrueWeight, FalseWeight)) {
     uint64_t MinWeight = std::min(TrueWeight, FalseWeight);
     TotalWeight = TrueWeight + FalseWeight;
     // Is there a path with frequency <ColdOperandThreshold% (default:20%) ?
@@ -750,7 +751,7 @@ void SelectOptimize::getExclBackwardsSlice(Instruction *I,
 
 bool SelectOptimize::isSelectHighlyPredictable(const SelectInst *SI) {
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -959,7 +960,7 @@ SelectOptimize::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost,
                                      const SelectInst *SI) {
   Scaled64 PredPathCost;
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t SumWeight = TrueWeight + FalseWeight;
     if (SumWeight != 0) {
       PredPathCost = TrueCost * Scaled64::get(TrueWeight) +

diff  --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index 2a1a514922fdc..a2027439240fe 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/TrackingMDRef.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -1493,31 +1494,6 @@ void Instruction::getAllMetadataImpl(
   Value::getAllMetadata(Result);
 }
 
-bool Instruction::extractProfMetadata(uint64_t &TrueVal,
-                                      uint64_t &FalseVal) const {
-  assert(
-      (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select) &&
-      "Looking for branch weights on something besides branch or select");
-
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData || ProfileData->getNumOperands() != 3)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return false;
-
-  auto *CITrue = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1));
-  auto *CIFalse = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
-  if (!CITrue || !CIFalse)
-    return false;
-
-  TrueVal = CITrue->getValue().getZExtValue();
-  FalseVal = CIFalse->getValue().getZExtValue();
-
-  return true;
-}
-
 bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select ||
@@ -1526,32 +1502,7 @@ bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
        getOpcode() == Instruction::Switch) &&
       "Looking for branch weights on something besides branch");
 
-  TotalVal = 0;
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName)
-    return false;
-
-  if (ProfDataName->getString().equals("branch_weights")) {
-    TotalVal = 0;
-    for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) {
-      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i));
-      if (!V)
-        return false;
-      TotalVal += V->getValue().getZExtValue();
-    }
-    return true;
-  } else if (ProfDataName->getString().equals("VP") &&
-             ProfileData->getNumOperands() > 3) {
-    TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
-                   ->getValue()
-                   .getZExtValue();
-    return true;
-  }
-  return false;
+  return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
 void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index cf728933c08d2..34dfdcf2b1109 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -15,6 +15,7 @@
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/TargetSchedule.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/KnownBits.h"
@@ -757,7 +758,7 @@ bool PPCTTIImpl::isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
     if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
       uint64_t TrueWeight = 0, FalseWeight = 0;
       if (!BI->isConditional() ||
-          !BI->extractProfMetadata(TrueWeight, FalseWeight))
+          !extractBranchWeights(BI, TrueWeight, FalseWeight))
         continue;
 
       // If the exit path is more frequent than the loop path,

diff  --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 54c72bdbb203f..ec2e7fb215b2f 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/User.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -717,7 +718,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) {
     if (!BR || BR->isUnconditional())
       continue;
     uint64_t T, F;
-    if (BR->extractProfMetadata(T, F))
+    if (extractBranchWeights(*BR, T, F))
       return true;
   }
   return false;

diff  --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c4512d0222cde..90cc61c6b291b 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -91,6 +91,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -2067,7 +2068,7 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
       // Display scaled counts for SELECT instruction:
       OS << "SELECT : { T = ";
       uint64_t TC, FC;
-      bool HasProf = I.extractProfMetadata(TC, FC);
+      bool HasProf = extractBranchWeights(I, TC, FC);
       if (!HasProf)
         OS << "Unknown, F = Unknown }\\l";
       else

diff  --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index b31eab50c5ecc..0113e7bf570df 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -54,6 +54,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/Value.h"
@@ -216,7 +217,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     return;
 
   uint64_t TrueWeight, FalseWeight;
-  if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*CondBr, TrueWeight, FalseWeight))
     return;
 
   if (TrueWeight + FalseWeight == 0)
@@ -279,7 +280,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     // With PGO, this can be used to refine even existing profile data with
     // context information. This needs to be done after more performance
     // testing.
-    if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight))
+    if (extractBranchWeights(*PredBr, PredTrueWeight, PredFalseWeight))
       continue;
 
     // We can not infer anything useful when BP >= 50%, because BP is the

diff  --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index f093fea19c4d4..9a7f9df7a330f 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -532,7 +533,7 @@ static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
                               uint64_t &ExitWeight,
                               uint64_t &FallThroughWeight) {
   uint64_t TrueWeight, FalseWeight;
-  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
     return;
   unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
   ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;

diff  --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 023a0afd329b5..1c44ccb589bcc 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -471,7 +472,7 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop,
   uint64_t TrueWeight, FalseWeight;
   BranchInst *LatchBR =
       cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator());
-  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
     return;
   uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader()
                             ? FalseWeight

diff  --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 349063dd5e892..03272e5f8a549 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -38,6 +38,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -789,7 +790,7 @@ getEstimatedTripCount(BranchInst *ExitingBranch, Loop *L,
   // know the number of times the backedge was taken, vs. the number of times
   // we exited the loop.
   uint64_t LoopWeight, ExitWeight;
-  if (!ExitingBranch->extractProfMetadata(LoopWeight, ExitWeight))
+  if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
     return None;
 
   if (L->contains(ExitingBranch->getSuccessor(1)))

diff  --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
index 4414b04c72645..2960c79f5dceb 100644
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -118,34 +119,6 @@ void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
 namespace llvm {
 namespace misexpect {
 
-// Helper function to extract branch weights into a vector
-Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
-                                                  LLVMContext &Ctx) {
-  assert(I && "MisExpect::extractWeights given invalid pointer");
-
-  auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return None;
-
-  unsigned NOps = ProfileData->getNumOperands();
-  if (NOps < 3)
-    return None;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return None;
-
-  SmallVector<uint32_t, 4> Weights(NOps - 1);
-  for (unsigned Idx = 1; Idx < NOps; Idx++) {
-    ConstantInt *Value =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    uint32_t V = Value->getZExtValue();
-    Weights[Idx - 1] = V;
-  }
-
-  return Weights;
-}
-
 // TODO: when clang allows c++17, use std::clamp instead
 uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
   if (value > hi)
@@ -218,19 +191,17 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights,
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef<uint32_t> RealWeights) {
-  auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
-  if (!ExpectedWeightsOpt)
+  SmallVector<uint32_t> ExpectedWeights;
+  if (!extractBranchWeights(I, ExpectedWeights))
     return;
-  auto ExpectedWeights = ExpectedWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
 void checkFrontendInstrumentation(Instruction &I,
                                   const ArrayRef<uint32_t> ExpectedWeights) {
-  auto RealWeightsOpt = extractWeights(&I, I.getContext());
-  if (!RealWeightsOpt)
+  SmallVector<uint32_t> RealWeights;
+  if (!extractBranchWeights(I, RealWeights))
     return;
-  auto RealWeights = RealWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 

diff  --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 1806081678a86..bba831521ff59 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -57,6 +57,7 @@
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1050,15 +1051,6 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
   return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
 }
 
-static inline bool HasBranchWeights(const Instruction *I) {
-  MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
-  if (ProfMD && ProfMD->getOperand(0))
-    if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
-      return MDS->getString().equals("branch_weights");
-
-  return false;
-}
-
 /// Get Weights of a given terminator, the default weight is at the front
 /// of the vector. If TI is a conditional eq, we need to swap the branch-weight
 /// metadata.
@@ -1177,8 +1169,8 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
 
   // Update the branch weight metadata along the way
   SmallVector<uint64_t, 8> Weights;
-  bool PredHasWeights = HasBranchWeights(PTI);
-  bool SuccHasWeights = HasBranchWeights(TI);
+  bool PredHasWeights = hasBranchWeightMD(*PTI);
+  bool SuccHasWeights = hasBranchWeightMD(*TI);
 
   if (PredHasWeights) {
     GetBranchWeights(PTI, Weights);
@@ -2752,7 +2744,8 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
   // the `then` block, then avoid speculating it.
   if (!BI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (BI->extractProfMetadata(TWeight, FWeight) && (TWeight + FWeight) != 0) {
+    if (extractBranchWeights(*BI, TWeight, FWeight) &&
+        (TWeight + FWeight) != 0) {
       uint64_t EndWeight = Invert ? TWeight : FWeight;
       BranchProbability BIEndProb =
           BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight);
@@ -3174,7 +3167,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
   // from the block that we know is predictably not entered.
   if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (DomBI->extractProfMetadata(TWeight, FWeight) &&
+    if (extractBranchWeights(*DomBI, TWeight, FWeight) &&
         (TWeight + FWeight) != 0) {
       BranchProbability BITrueProb =
           BranchProbability::getBranchProbability(TWeight, TWeight + FWeight);
@@ -3354,9 +3347,9 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI,
                                    uint64_t &SuccTrueWeight,
                                    uint64_t &SuccFalseWeight) {
   bool PredHasWeights =
-      PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight);
+      extractBranchWeights(*PBI, PredTrueWeight, PredFalseWeight);
   bool SuccHasWeights =
-      BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight);
+      extractBranchWeights(*BI, SuccTrueWeight, SuccFalseWeight);
   if (PredHasWeights || SuccHasWeights) {
     if (!PredHasWeights)
       PredTrueWeight = PredFalseWeight = 1;
@@ -3384,7 +3377,7 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI,
   uint64_t PTWeight, PFWeight;
   BranchProbability PBITrueProb, Likely;
   if (TTI && !PBI->getMetadata(LLVMContext::MD_unpredictable) &&
-      PBI->extractProfMetadata(PTWeight, PFWeight) &&
+      extractBranchWeights(*PBI, PTWeight, PFWeight) &&
       (PTWeight + PFWeight) != 0) {
     PBITrueProb =
         BranchProbability::getBranchProbability(PTWeight, PTWeight + PFWeight);
@@ -4349,7 +4342,7 @@ bool SimplifyCFGOpt::SimplifySwitchOnSelect(SwitchInst *SI,
   // Get weight for TrueBB and FalseBB.
   uint32_t TrueWeight = 0, FalseWeight = 0;
   SmallVector<uint64_t, 8> Weights;
-  bool HasWeights = HasBranchWeights(SI);
+  bool HasWeights = hasBranchWeightMD(*SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
@@ -5209,7 +5202,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI,
   BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
 
   // Update weight for the newly-created conditional branch.
-  if (HasBranchWeights(SI)) {
+  if (hasBranchWeightMD(*SI)) {
     SmallVector<uint64_t, 8> Weights;
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {


        


More information about the llvm-commits mailing list