[llvm] 6e9bab7 - Revert "[llvm][NFC] Refactor code to use ProfDataUtils"

Paul Kirth via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 27 14:39:04 PDT 2022


Author: Paul Kirth
Date: 2022-07-27T21:38:11Z
New Revision: 6e9bab71b626183211625f150cc25fa22cb0973c

URL: https://github.com/llvm/llvm-project/commit/6e9bab71b626183211625f150cc25fa22cb0973c
DIFF: https://github.com/llvm/llvm-project/commit/6e9bab71b626183211625f150cc25fa22cb0973c.diff

LOG: Revert "[llvm][NFC] Refactor code to use ProfDataUtils"

This reverts commit 300c9a78819b4608b96bb26f9320bea6b8a0c4d0.

We will reland the change once the issues with it are ironed out.

Added: 
    

Modified: 
    llvm/include/llvm/IR/Instruction.h
    llvm/lib/Analysis/BranchProbabilityInfo.cpp
    llvm/lib/CodeGen/CodeGenPrepare.cpp
    llvm/lib/CodeGen/SelectOptimize.cpp
    llvm/lib/IR/Metadata.cpp
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
    llvm/lib/Transforms/IPO/PartialInlining.cpp
    llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
    llvm/lib/Transforms/Scalar/JumpThreading.cpp
    llvm/lib/Transforms/Utils/LoopPeel.cpp
    llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
    llvm/lib/Transforms/Utils/LoopUtils.cpp
    llvm/lib/Transforms/Utils/MisExpect.cpp
    llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index 1044a634408ca..15b0bdf557fb1 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -335,6 +335,11 @@ class Instruction : public User,
   /// Sets the AA metadata on this instruction from the AAMDNodes structure.
   void setAAMetadata(const AAMDNodes &N);
 
+  /// Retrieve the raw weight values of a conditional branch or select.
+  /// Returns true on success with profile weights filled in.
+  /// Returns false if no metadata or invalid metadata was found.
+  bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) const;
+
   /// Retrieve total raw weight values of a branch.
   /// Returns true on success with profile total weights filled in.
   /// Returns false if no metadata was found.

diff  --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
index 8918fb967594e..f45728768fcdb 100644
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -31,7 +31,6 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -402,18 +401,24 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
   SmallVector<uint32_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
-
-  extractBranchWeights(*TI, Weights);
-  for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
-    WeightSum += Weights[I];
+  Weights.reserve(TI->getNumSuccessors());
+  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
+    if (!Weight)
+      return false;
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights.push_back(Weight->getZExtValue());
+    WeightSum += Weights.back();
     const LoopBlock SrcLoopBB = getLoopBlock(BB);
-    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
+    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
         *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
-      UnreachableIdxs.push_back(I);
+      UnreachableIdxs.push_back(I - 1);
     else
-      ReachableIdxs.push_back(I);
+      ReachableIdxs.push_back(I - 1);
   }
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
 

diff  --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index a6b3ced721599..b8f6fc9bbcde9 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -65,7 +65,6 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Statepoint.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
@@ -6621,7 +6620,7 @@ static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI,
   // If metadata tells us that the select condition is obviously predictable,
   // then we want to replace the select with a branch.
   uint64_t TrueWeight, FalseWeight;
-  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
+  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -8363,7 +8362,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // Another choice is to assume TrueProb for BB1 equals to TrueProb for
       // TmpBB, but the math is more complicated.
       uint64_t TrueWeight, FalseWeight;
-      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
+      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = TrueWeight;
         uint64_t NewFalseWeight = TrueWeight + 2 * FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);
@@ -8396,7 +8395,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // assumes that
       //   FalseProb for BB1 == TrueProb for BB1 * FalseProb for TmpBB.
       uint64_t TrueWeight, FalseWeight;
-      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
+      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = 2 * TrueWeight + FalseWeight;
         uint64_t NewFalseWeight = FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);

diff  --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index a7a698744591c..011f55efce1d5 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -29,7 +29,6 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ScaledNumber.h"
@@ -656,7 +655,7 @@ bool SelectOptimize::hasExpensiveColdOperand(
     const SmallVector<SelectInst *, 2> &ASI) {
   bool ColdOperand = false;
   uint64_t TrueWeight, FalseWeight, TotalWeight;
-  if (extractBranchWeights(*ASI.front(), TrueWeight, FalseWeight)) {
+  if (ASI.front()->extractProfMetadata(TrueWeight, FalseWeight)) {
     uint64_t MinWeight = std::min(TrueWeight, FalseWeight);
     TotalWeight = TrueWeight + FalseWeight;
     // Is there a path with frequency <ColdOperandThreshold% (default:20%) ?
@@ -751,7 +750,7 @@ void SelectOptimize::getExclBackwardsSlice(Instruction *I,
 
 bool SelectOptimize::isSelectHighlyPredictable(const SelectInst *SI) {
   uint64_t TrueWeight, FalseWeight;
-  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
+  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -960,7 +959,7 @@ SelectOptimize::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost,
                                      const SelectInst *SI) {
   Scaled64 PredPathCost;
   uint64_t TrueWeight, FalseWeight;
-  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
+  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
     uint64_t SumWeight = TrueWeight + FalseWeight;
     if (SumWeight != 0) {
       PredPathCost = TrueCost * Scaled64::get(TrueWeight) +

diff  --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index a2027439240fe..2a1a514922fdc 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -40,7 +40,6 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/TrackingMDRef.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -1494,6 +1493,31 @@ void Instruction::getAllMetadataImpl(
   Value::getAllMetadata(Result);
 }
 
+bool Instruction::extractProfMetadata(uint64_t &TrueVal,
+                                      uint64_t &FalseVal) const {
+  assert(
+      (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select) &&
+      "Looking for branch weights on something besides branch or select");
+
+  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
+  if (!ProfileData || ProfileData->getNumOperands() != 3)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
+    return false;
+
+  auto *CITrue = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1));
+  auto *CIFalse = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
+  if (!CITrue || !CIFalse)
+    return false;
+
+  TrueVal = CITrue->getValue().getZExtValue();
+  FalseVal = CIFalse->getValue().getZExtValue();
+
+  return true;
+}
+
 bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select ||
@@ -1502,7 +1526,32 @@ bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
        getOpcode() == Instruction::Switch) &&
       "Looking for branch weights on something besides branch");
 
-  return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal);
+  TotalVal = 0;
+  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
+  if (!ProfileData)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  if (ProfDataName->getString().equals("branch_weights")) {
+    TotalVal = 0;
+    for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) {
+      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i));
+      if (!V)
+        return false;
+      TotalVal += V->getValue().getZExtValue();
+    }
+    return true;
+  } else if (ProfDataName->getString().equals("VP") &&
+             ProfileData->getNumOperands() > 3) {
+    TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
+                   ->getValue()
+                   .getZExtValue();
+    return true;
+  }
+  return false;
 }
 
 void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index 34dfdcf2b1109..cf728933c08d2 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -15,7 +15,6 @@
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/TargetSchedule.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/KnownBits.h"
@@ -758,7 +757,7 @@ bool PPCTTIImpl::isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
     if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
       uint64_t TrueWeight = 0, FalseWeight = 0;
       if (!BI->isConditional() ||
-          !extractBranchWeights(BI, TrueWeight, FalseWeight))
+          !BI->extractProfMetadata(TrueWeight, FalseWeight))
         continue;
 
       // If the exit path is more frequent than the loop path,

diff  --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index ec2e7fb215b2f..54c72bdbb203f 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -40,7 +40,6 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/User.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -718,7 +717,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) {
     if (!BR || BR->isUnconditional())
       continue;
     uint64_t T, F;
-    if (extractBranchWeights(*BR, T, F))
+    if (BR->extractProfMetadata(T, F))
       return true;
   }
   return false;

diff  --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 90cc61c6b291b..c4512d0222cde 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -91,7 +91,6 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -2068,7 +2067,7 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
       // Display scaled counts for SELECT instruction:
       OS << "SELECT : { T = ";
       uint64_t TC, FC;
-      bool HasProf = extractBranchWeights(I, TC, FC);
+      bool HasProf = I.extractProfMetadata(TC, FC);
       if (!HasProf)
         OS << "Unknown, F = Unknown }\\l";
       else

diff  --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 0113e7bf570df..b31eab50c5ecc 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -54,7 +54,6 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/Value.h"
@@ -217,7 +216,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     return;
 
   uint64_t TrueWeight, FalseWeight;
-  if (!extractBranchWeights(*CondBr, TrueWeight, FalseWeight))
+  if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight))
     return;
 
   if (TrueWeight + FalseWeight == 0)
@@ -280,7 +279,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     // With PGO, this can be used to refine even existing profile data with
     // context information. This needs to be done after more performance
     // testing.
-    if (extractBranchWeights(*PredBr, PredTrueWeight, PredFalseWeight))
+    if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight))
       continue;
 
     // We can not infer anything useful when BP >= 50%, because BP is the

diff  --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 9a7f9df7a330f..f093fea19c4d4 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -29,7 +29,6 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -533,7 +532,7 @@ static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
                               uint64_t &ExitWeight,
                               uint64_t &FallThroughWeight) {
   uint64_t TrueWeight, FalseWeight;
-  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
+  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
     return;
   unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
   ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;

diff  --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 1c44ccb589bcc..023a0afd329b5 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -30,7 +30,6 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -472,7 +471,7 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop,
   uint64_t TrueWeight, FalseWeight;
   BranchInst *LatchBR =
       cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator());
-  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
+  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
     return;
   uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader()
                             ? FalseWeight

diff  --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 03272e5f8a549..349063dd5e892 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -38,7 +38,6 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -790,7 +789,7 @@ getEstimatedTripCount(BranchInst *ExitingBranch, Loop *L,
   // know the number of times the backedge was taken, vs. the number of times
   // we exited the loop.
   uint64_t LoopWeight, ExitWeight;
-  if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
+  if (!ExitingBranch->extractProfMetadata(LoopWeight, ExitWeight))
     return None;
 
   if (L->contains(ExitingBranch->getSuccessor(1)))

diff  --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
index 2960c79f5dceb..4414b04c72645 100644
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -35,7 +35,6 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -119,6 +118,34 @@ void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
 namespace llvm {
 namespace misexpect {
 
+// Helper function to extract branch weights into a vector
+Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
+                                                  LLVMContext &Ctx) {
+  assert(I && "MisExpect::extractWeights given invalid pointer");
+
+  auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
+  if (!ProfileData)
+    return None;
+
+  unsigned NOps = ProfileData->getNumOperands();
+  if (NOps < 3)
+    return None;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
+    return None;
+
+  SmallVector<uint32_t, 4> Weights(NOps - 1);
+  for (unsigned Idx = 1; Idx < NOps; Idx++) {
+    ConstantInt *Value =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+    uint32_t V = Value->getZExtValue();
+    Weights[Idx - 1] = V;
+  }
+
+  return Weights;
+}
+
 // TODO: when clang allows c++17, use std::clamp instead
 uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
   if (value > hi)
@@ -191,17 +218,19 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights,
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef<uint32_t> RealWeights) {
-  SmallVector<uint32_t> ExpectedWeights;
-  if (!extractBranchWeights(I, ExpectedWeights))
+  auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
+  if (!ExpectedWeightsOpt)
     return;
+  auto ExpectedWeights = ExpectedWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
 void checkFrontendInstrumentation(Instruction &I,
                                   const ArrayRef<uint32_t> ExpectedWeights) {
-  SmallVector<uint32_t> RealWeights;
-  if (!extractBranchWeights(I, RealWeights))
+  auto RealWeightsOpt = extractWeights(&I, I.getContext());
+  if (!RealWeightsOpt)
     return;
+  auto RealWeights = RealWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 

diff  --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index bba831521ff59..1806081678a86 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -57,7 +57,6 @@
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1051,6 +1050,15 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
   return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
 }
 
+static inline bool HasBranchWeights(const Instruction *I) {
+  MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
+  if (ProfMD && ProfMD->getOperand(0))
+    if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
+      return MDS->getString().equals("branch_weights");
+
+  return false;
+}
+
 /// Get Weights of a given terminator, the default weight is at the front
 /// of the vector. If TI is a conditional eq, we need to swap the branch-weight
 /// metadata.
@@ -1169,8 +1177,8 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
 
   // Update the branch weight metadata along the way
   SmallVector<uint64_t, 8> Weights;
-  bool PredHasWeights = hasBranchWeightMD(*PTI);
-  bool SuccHasWeights = hasBranchWeightMD(*TI);
+  bool PredHasWeights = HasBranchWeights(PTI);
+  bool SuccHasWeights = HasBranchWeights(TI);
 
   if (PredHasWeights) {
     GetBranchWeights(PTI, Weights);
@@ -2744,8 +2752,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
   // the `then` block, then avoid speculating it.
   if (!BI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (extractBranchWeights(*BI, TWeight, FWeight) &&
-        (TWeight + FWeight) != 0) {
+    if (BI->extractProfMetadata(TWeight, FWeight) && (TWeight + FWeight) != 0) {
       uint64_t EndWeight = Invert ? TWeight : FWeight;
       BranchProbability BIEndProb =
           BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight);
@@ -3167,7 +3174,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
   // from the block that we know is predictably not entered.
   if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (extractBranchWeights(*DomBI, TWeight, FWeight) &&
+    if (DomBI->extractProfMetadata(TWeight, FWeight) &&
         (TWeight + FWeight) != 0) {
       BranchProbability BITrueProb =
           BranchProbability::getBranchProbability(TWeight, TWeight + FWeight);
@@ -3347,9 +3354,9 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI,
                                    uint64_t &SuccTrueWeight,
                                    uint64_t &SuccFalseWeight) {
   bool PredHasWeights =
-      extractBranchWeights(*PBI, PredTrueWeight, PredFalseWeight);
+      PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight);
   bool SuccHasWeights =
-      extractBranchWeights(*BI, SuccTrueWeight, SuccFalseWeight);
+      BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight);
   if (PredHasWeights || SuccHasWeights) {
     if (!PredHasWeights)
       PredTrueWeight = PredFalseWeight = 1;
@@ -3377,7 +3384,7 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI,
   uint64_t PTWeight, PFWeight;
   BranchProbability PBITrueProb, Likely;
   if (TTI && !PBI->getMetadata(LLVMContext::MD_unpredictable) &&
-      extractBranchWeights(*PBI, PTWeight, PFWeight) &&
+      PBI->extractProfMetadata(PTWeight, PFWeight) &&
       (PTWeight + PFWeight) != 0) {
     PBITrueProb =
         BranchProbability::getBranchProbability(PTWeight, PTWeight + PFWeight);
@@ -4342,7 +4349,7 @@ bool SimplifyCFGOpt::SimplifySwitchOnSelect(SwitchInst *SI,
   // Get weight for TrueBB and FalseBB.
   uint32_t TrueWeight = 0, FalseWeight = 0;
   SmallVector<uint64_t, 8> Weights;
-  bool HasWeights = hasBranchWeightMD(*SI);
+  bool HasWeights = HasBranchWeights(SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
@@ -5202,7 +5209,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI,
   BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
 
   // Update weight for the newly-created conditional branch.
-  if (hasBranchWeightMD(*SI)) {
+  if (HasBranchWeights(SI)) {
     SmallVector<uint64_t, 8> Weights;
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {


        


More information about the llvm-commits mailing list