[llvm] e741b8c - [llvm][ir] Purge MD_prof custom accessors

Christian Ulmann via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 19 05:30:11 PST 2023


Author: Christian Ulmann
Date: 2023-01-19T14:26:26+01:00
New Revision: e741b8c2e5200269f846bdd88ca98e44681fe8df

URL: https://github.com/llvm/llvm-project/commit/e741b8c2e5200269f846bdd88ca98e44681fe8df
DIFF: https://github.com/llvm/llvm-project/commit/e741b8c2e5200269f846bdd88ca98e44681fe8df.diff

LOG: [llvm][ir] Purge MD_prof custom accessors

This commit purges direct accesses to MD_prof metadata and replaces them
with the accessors provided from the utility file wherever possible.
This commit can be seen as the first step towards switching the branch weights to 64 bits.
See post here: https://discourse.llvm.org/t/extend-md-prof-branch-weights-metadata-from-32-to-64-bits/67492

Reviewed By: davidxl, paulkirth

Differential Revision: https://reviews.llvm.org/D141393

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/CFGPrinter.h
    llvm/include/llvm/IR/Instructions.h
    llvm/include/llvm/IR/ProfDataUtils.h
    llvm/lib/Analysis/BranchProbabilityInfo.cpp
    llvm/lib/IR/Instruction.cpp
    llvm/lib/IR/Instructions.cpp
    llvm/lib/IR/Metadata.cpp
    llvm/lib/IR/ProfDataUtils.cpp
    llvm/lib/Transforms/IPO/PartialInlining.cpp
    llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
    llvm/lib/Transforms/Scalar/JumpThreading.cpp
    llvm/lib/Transforms/Scalar/LoopPredication.cpp
    llvm/lib/Transforms/Utils/Local.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/CFGPrinter.h b/llvm/include/llvm/Analysis/CFGPrinter.h
index 5def897839a4c..eeac11bc7af12 100644
--- a/llvm/include/llvm/Analysis/CFGPrinter.h
+++ b/llvm/include/llvm/Analysis/CFGPrinter.h
@@ -26,6 +26,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/DOTGraphTraits.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -276,14 +277,10 @@ struct DOTGraphTraits<DOTFuncInfo *> : public DefaultDOTGraphTraits {
     if (Attrs.size())
       return Attrs;
 
-    MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
+    MDNode *WeightsNode = getBranchWeightMDNode(*TI);
     if (!WeightsNode)
       return "";
 
-    MDString *MDName = cast<MDString>(WeightsNode->getOperand(0));
-    if (MDName->getString() != "branch_weights")
-      return "";
-
     OpNo = I.getSuccessorIndex() + 1;
     if (OpNo >= WeightsNode->getNumOperands())
       return "";

diff  --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 9f82ac91e92d0..b24176a573c7c 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -3620,8 +3620,6 @@ class SwitchInstProfUpdateWrapper {
   bool Changed = false;
 
 protected:
-  static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);
-
   MDNode *buildProfBranchWeightsMD();
 
   void init();

diff  --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 1298af12c6481..8019d0a3969b1 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -34,6 +34,27 @@ bool isBranchWeightMD(const MDNode *ProfileData);
 /// otherwise.
 bool hasBranchWeightMD(const Instruction &I);
 
+/// Checks if an instruction has valid Branch Weight Metadata
+///
+/// \param I The instruction to check
+/// \returns True if I has an MD_prof node containing valid Branch Weights,
+/// i.e., one weight for each successor. False otherwise.
+bool hasValidBranchWeightMD(const Instruction &I);
+
+/// Get the branch weights metadata node
+///
+/// \param I The Instruction to get the weights from.
+/// \returns A pointer to I's branch weights metadata node, if it exists.
+/// Nullptr otherwise.
+MDNode *getBranchWeightMDNode(const Instruction &I);
+
+/// Get the valid branch weights metadata node
+///
+/// \param I The Instruction to get the weights from.
+/// \returns A pointer to I's valid branch weights metadata node, if it exists.
+/// Nullptr otherwise.
+MDNode *getValidBranchWeightMDNode(const Instruction &I);
+
 /// Extract branch weights from MD_prof metadata
 ///
 /// \param ProfileData A pointer to an MDNode.
@@ -70,5 +91,13 @@ bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
 /// metadata was found.
 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 
+/// Retrieve the total of all weights from an instruction.
+///
+/// \param I The instruction to extract the total weight from
+/// \param [out] TotalWeights output variable to fill with total weights
+/// \returns True on success with profile total weights filled in. False if no
+/// metadata was found.
+bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
+
 } // namespace llvm
 #endif

diff  --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index fc908e136c240..7931001d0a2b2 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -383,18 +383,13 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
         isa<InvokeInst>(TI) || isa<CallBrInst>(TI)))
     return false;
 
-  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
+  MDNode *WeightsNode = getValidBranchWeightMDNode(*TI);
   if (!WeightsNode)
     return false;
 
   // Check that the number of successors is manageable.
   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
 
-  // Ensure there are weights for all of the successors. Note that the first
-  // operand to the metadata node is a name, not a weight.
-  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
-    return false;
-
   // Build up the final weights that will be used in a temporary buffer.
   // Compute the sum of all weights to later decide whether they need to
   // be scaled to fit in 32 bits.
@@ -403,7 +398,7 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
 
-  extractBranchWeights(*TI, Weights);
+  extractBranchWeights(WeightsNode, Weights);
   for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
     WeightSum += Weights[I];
     const LoopBlock SrcLoopBB = getLoopBlock(BB);

diff  --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 242b242a8ce26..8e91f0e8929b4 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 using namespace llvm;
 
@@ -855,13 +856,8 @@ Instruction *Instruction::cloneImpl() const {
 }
 
 void Instruction::swapProfMetadata() {
-  MDNode *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData || ProfileData->getNumOperands() != 3 ||
-      !isa<MDString>(ProfileData->getOperand(0)))
-    return;
-
-  MDString *MDName = cast<MDString>(ProfileData->getOperand(0));
-  if (MDName->getString() != "branch_weights")
+  MDNode *ProfileData = getBranchWeightMDNode(*this);
+  if (!ProfileData || ProfileData->getNumOperands() != 3)
     return;
 
   // The first operand is the name. Fetch them backwards and build a new one.

diff  --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 53e7477c16d4f..603ea9de213ff 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -31,6 +31,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/AtomicOrdering.h"
@@ -4572,15 +4573,6 @@ void SwitchInst::growOperands() {
   growHungoffUses(ReservedSpace);
 }
 
-MDNode *
-SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) {
-  if (MDNode *ProfileData = SI.getMetadata(LLVMContext::MD_prof))
-    if (auto *MDName = dyn_cast<MDString>(ProfileData->getOperand(0)))
-      if (MDName->getString() == "branch_weights")
-        return ProfileData;
-  return nullptr;
-}
-
 MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
   assert(Changed && "called only if metadata has changed");
 
@@ -4599,7 +4591,7 @@ MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
 }
 
 void SwitchInstProfUpdateWrapper::init() {
-  MDNode *ProfileData = getProfBranchWeightsMD(SI);
+  MDNode *ProfileData = getBranchWeightMDNode(SI);
   if (!ProfileData)
     return;
 
@@ -4609,11 +4601,8 @@ void SwitchInstProfUpdateWrapper::init() {
   }
 
   SmallVector<uint32_t, 8> Weights;
-  for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
-    ConstantInt *C = mdconst::extract<ConstantInt>(ProfileData->getOperand(CI));
-    uint32_t CW = C->getValue().getZExtValue();
-    Weights.push_back(CW);
-  }
+  if (!extractBranchWeights(ProfileData, Weights))
+    return;
   this->Weights = std::move(Weights);
 }
 
@@ -4686,7 +4675,7 @@ void SwitchInstProfUpdateWrapper::setSuccessorWeight(
 SwitchInstProfUpdateWrapper::CaseWeightOpt
 SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
                                                 unsigned idx) {
-  if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
+  if (MDNode *ProfileData = getBranchWeightMDNode(SI))
     if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1)
       return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
           ->getValue()

diff  --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index bc1deb019c2b0..bb757269e55f9 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -1544,7 +1544,7 @@ bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
        getOpcode() == Instruction::Switch) &&
       "Looking for branch weights on something besides branch");
 
-  return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal);
+  return ::extractProfTotalWeight(*this, TotalVal);
 }
 
 void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {

diff  --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 66c645c54fea7..e534368b05e41 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -101,6 +101,24 @@ bool hasBranchWeightMD(const Instruction &I) {
   return isBranchWeightMD(ProfileData);
 }
 
+bool hasValidBranchWeightMD(const Instruction &I) {
+  return getValidBranchWeightMDNode(I);
+}
+
+MDNode *getBranchWeightMDNode(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  if (!isBranchWeightMD(ProfileData))
+    return nullptr;
+  return ProfileData;
+}
+
+MDNode *getValidBranchWeightMDNode(const Instruction &I) {
+  auto *ProfileData = getBranchWeightMDNode(I);
+  if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
+    return ProfileData;
+  return nullptr;
+}
+
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights) {
   if (!isBranchWeightMD(ProfileData))
@@ -118,7 +136,8 @@ bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
                           uint64_t &FalseVal) {
   assert((I.getOpcode() == Instruction::Br ||
           I.getOpcode() == Instruction::Select) &&
-         "Looking for branch weights on something besides branch or select");
+         "Looking for branch weights on something besides branch, select, or "
+         "switch");
 
   SmallVector<uint32_t, 2> Weights;
   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
@@ -161,4 +180,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
   return false;
 }
 
+bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
+  return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
+}
+
 } // namespace llvm

diff  --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 1aa737b865aa6..310e4d4164a5c 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -716,8 +716,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) {
     BranchInst *BR = dyn_cast<BranchInst>(E->getTerminator());
     if (!BR || BR->isUnconditional())
       continue;
-    uint64_t T, F;
-    if (extractBranchWeights(*BR, T, F))
+    if (hasBranchWeightMD(*BR))
       return true;
   }
   return false;

diff  --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 096faad245ccc..a072ba278fced 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -575,32 +576,26 @@ checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT,
   return true;
 }
 
-// Returns true and sets the true probability and false probability of an
-// MD_prof metadata if it's well-formed.
-static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb,
-                        BranchProbability &FalseProb) {
-  if (!MD) return false;
-  MDString *MDName = cast<MDString>(MD->getOperand(0));
-  if (MDName->getString() != "branch_weights" ||
-      MD->getNumOperands() != 3)
+// Constructs the true and false branch probabilities if the instruction has
+// valid branch weights. Returns true when this was successful, false otherwise.
+static bool extractBranchProbabilities(Instruction *I,
+                                       BranchProbability &TrueProb,
+                                       BranchProbability &FalseProb) {
+  uint64_t TrueWeight;
+  uint64_t FalseWeight;
+  if (!extractBranchWeights(*I, TrueWeight, FalseWeight))
     return false;
-  ConstantInt *TrueWeight = mdconst::extract<ConstantInt>(MD->getOperand(1));
-  ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2));
-  if (!TrueWeight || !FalseWeight)
-    return false;
-  uint64_t TrueWt = TrueWeight->getValue().getZExtValue();
-  uint64_t FalseWt = FalseWeight->getValue().getZExtValue();
-  uint64_t SumWt = TrueWt + FalseWt;
+  uint64_t SumWeight = TrueWeight + FalseWeight;
 
-  assert(SumWt >= TrueWt && SumWt >= FalseWt &&
+  assert(SumWeight >= TrueWeight && SumWeight >= FalseWeight &&
          "Overflow calculating branch probabilities.");
 
   // Guard against 0-to-0 branch weights to avoid a division-by-zero crash.
-  if (SumWt == 0)
+  if (SumWeight == 0)
     return false;
 
-  TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt);
-  FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt);
+  TrueProb = BranchProbability::getBranchProbability(TrueWeight, SumWeight);
+  FalseProb = BranchProbability::getBranchProbability(FalseWeight, SumWeight);
   return true;
 }
 
@@ -639,8 +634,7 @@ static bool checkBiasedBranch(BranchInst *BI, Region *R,
   if (!BI->isConditional())
     return false;
   BranchProbability ThenProb, ElseProb;
-  if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof),
-                   ThenProb, ElseProb))
+  if (!extractBranchProbabilities(BI, ThenProb, ElseProb))
     return false;
   BasicBlock *IfThen = BI->getSuccessor(0);
   BasicBlock *IfElse = BI->getSuccessor(1);
@@ -669,8 +663,7 @@ static bool checkBiasedSelect(
     DenseSet<SelectInst *> &FalseBiasedSelectsGlobal,
     DenseMap<SelectInst *, BranchProbability> &SelectBiasMap) {
   BranchProbability TrueProb, FalseProb;
-  if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof),
-                   TrueProb, FalseProb))
+  if (!extractBranchProbabilities(SI, TrueProb, FalseProb))
     return false;
   CHR_DEBUG(dbgs() << "SI " << *SI << " ");
   CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " ");

diff  --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index eefbaec2d90d7..f41eaed2e3e7e 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -2522,18 +2522,7 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB,
 bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) {
   const Instruction *TI = BB->getTerminator();
   assert(TI->getNumSuccessors() > 1 && "not a split");
-
-  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
-  if (!WeightsNode)
-    return false;
-
-  MDString *MDName = cast<MDString>(WeightsNode->getOperand(0));
-  if (MDName->getString() != "branch_weights")
-    return false;
-
-  // Ensure there are weights for all of the successors. Note that the first
-  // operand to the metadata node is a name, not a weight.
-  return WeightsNode->getNumOperands() == TI->getNumSuccessors() + 1;
+  return hasValidBranchWeightMD(*TI);
 }
 
 /// Update the block frequency of BB and branch weight and the metadata on the

diff  --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index 578b399363678..49c0fff84d817 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -191,6 +191,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
@@ -974,37 +975,24 @@ bool LoopPredication::isLoopProfitableToPredicate() {
       LatchExitBlock->getTerminatingDeoptimizeCall())
     return false;
 
-  auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) {
-    if (!ProfileData || !ProfileData->getOperand(0))
-      return false;
-    if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0)))
-      if (!MDS->getString().equals("branch_weights"))
-        return false;
-    if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors())
-      return false;
-    return true;
-  };
-  MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof);
   // Latch terminator has no valid profile data, so nothing to check
   // profitability on.
-  if (!IsValidProfileData(LatchProfileData, LatchTerm))
+  if (!hasValidBranchWeightMD(*LatchTerm))
     return true;
 
   auto ComputeBranchProbability =
       [&](const BasicBlock *ExitingBlock,
           const BasicBlock *ExitBlock) -> BranchProbability {
     auto *Term = ExitingBlock->getTerminator();
-    MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof);
     unsigned NumSucc = Term->getNumSuccessors();
-    if (IsValidProfileData(ProfileData, Term)) {
-      uint64_t Numerator = 0, Denominator = 0, ProfVal = 0;
-      for (unsigned i = 0; i < NumSucc; i++) {
-        ConstantInt *CI =
-            mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1));
-        ProfVal = CI->getValue().getZExtValue();
+    if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
+      SmallVector<uint32_t> Weights;
+      extractBranchWeights(ProfileData, Weights);
+      uint64_t Numerator = 0, Denominator = 0;
+      for (auto [i, Weight] : llvm::enumerate(Weights)) {
         if (Term->getSuccessor(i) == ExitBlock)
-          Numerator += ProfVal;
-        Denominator += ProfVal;
+          Numerator += Weight;
+        Denominator += Weight;
       }
       return BranchProbability::getBranchProbability(Numerator, Denominator);
     } else {

diff  --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index a8a7b64e6a235..00cbee9a25e70 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -62,6 +62,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -210,20 +211,18 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
       // Check to see if this branch is going to the same place as the default
       // dest.  If so, eliminate it as an explicit compare.
       if (i->getCaseSuccessor() == DefaultDest) {
-        MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
+        MDNode *MD = getValidBranchWeightMDNode(*SI);
         unsigned NCases = SI->getNumCases();
         // Fold the case metadata into the default if there will be any branches
         // left, unless the metadata doesn't match the switch.
-        if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) {
+        if (NCases > 1 && MD) {
           // Collect branch weights into a vector.
           SmallVector<uint32_t, 8> Weights;
-          for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e;
-               ++MD_i) {
-            auto *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i));
-            Weights.push_back(CI->getValue().getZExtValue());
-          }
+          extractBranchWeights(MD, Weights);
+
           // Merge weight of this case to the default weight.
           unsigned idx = i->getCaseIndex();
+          // TODO: Add overflow check.
           Weights[0] += Weights[idx+1];
           // Remove weight for this case.
           std::swap(Weights[idx+1], Weights.back());
@@ -313,18 +312,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
       BranchInst *NewBr = Builder.CreateCondBr(Cond,
                                                FirstCase.getCaseSuccessor(),
                                                SI->getDefaultDest());
-      MDNode *MD = SI->getMetadata(LLVMContext::MD_prof);
-      if (MD && MD->getNumOperands() == 3) {
-        ConstantInt *SICase =
-            mdconst::dyn_extract<ConstantInt>(MD->getOperand(2));
-        ConstantInt *SIDef =
-            mdconst::dyn_extract<ConstantInt>(MD->getOperand(1));
-        assert(SICase && SIDef);
+      SmallVector<uint32_t> Weights;
+      if (extractBranchWeights(*SI, Weights) && Weights.size() == 2) {
+        uint32_t DefWeight = Weights[0];
+        uint32_t CaseWeight = Weights[1];
         // The TrueWeight should be the weight for the single case of SI.
         NewBr->setMetadata(LLVMContext::MD_prof,
-                        MDBuilder(BB->getContext()).
-                        createBranchWeights(SICase->getValue().getZExtValue(),
-                                            SIDef->getValue().getZExtValue()));
+                           MDBuilder(BB->getContext())
+                               .createBranchWeights(CaseWeight, DefWeight));
       }
 
       // Update make.implicit metadata to the newly-created conditional branch.


        


More information about the llvm-commits mailing list