[llvm] [SimplifyCFG][PGO] Reuse existing `setBranchWeights` (PR #160629)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 30 07:48:50 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/160629

>From ccef0cc48c2fbe8bf5ed6874c68f90216be9aef2 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 25 Sep 2025 02:18:30 +0000
Subject: [PATCH] [SimplifyCFG][profcheck] Fix artificially-failing
 `preserve-branchweights.ll`

---
 llvm/include/llvm/IR/Instructions.h           |   7 +-
 llvm/include/llvm/IR/ProfDataUtils.h          |   5 +
 llvm/lib/IR/Instructions.cpp                  |  17 ---
 llvm/lib/IR/ProfDataUtils.cpp                 |  36 ++++-
 llvm/lib/Transforms/IPO/SampleProfile.cpp     |   8 +-
 .../Instrumentation/IndirectCallPromotion.cpp |   2 +-
 llvm/lib/Transforms/Utils/ProfileVerify.cpp   |   5 +-
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 128 ++++++------------
 .../LoopRotate/update-branch-weights.ll       |   5 +-
 llvm/utils/profcheck-xfail.txt                |   2 -
 10 files changed, 96 insertions(+), 119 deletions(-)

diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 95a0a7fd2f97e..de7a237098594 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -32,6 +32,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/OperandTraits.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
 #include "llvm/Support/AtomicOrdering.h"
@@ -3536,8 +3537,6 @@ class SwitchInstProfUpdateWrapper {
   bool Changed = false;
 
 protected:
-  LLVM_ABI MDNode *buildProfBranchWeightsMD();
-
   LLVM_ABI void init();
 
 public:
@@ -3549,8 +3548,16 @@ class SwitchInstProfUpdateWrapper {
   SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); }
 
   ~SwitchInstProfUpdateWrapper() {
-    if (Changed)
-      SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
+    if (!Changed)
+      return;
+    assert((!Weights || SI.getNumSuccessors() == Weights->size()) &&
+           "num of prof branch_weights must accord with num of successors");
+    // With fewer than two weights there is no valid switch profile; drop the
+    // old !prof rather than leave weights that no longer match the cases.
+    if (Weights && Weights->size() >= 2)
+      setBranchWeights(SI, *Weights, /*IsExpected=*/false);
+    else
+      SI.setMetadata(LLVMContext::MD_prof, nullptr);
   }
 
   /// Delegate the call to the underlying SwitchInst::removeCase() and remove
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index e97160e59c795..47fe9c323a12b 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -147,6 +147,11 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
 LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
                                bool IsExpected);
 
+/// Variant of `setBranchWeights` taking 64-bit weights. If any weight exceeds
+/// UINT32_MAX, all are shifted right by the same amount so they fit uint32_t.
+LLVM_ABI void setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
+                                     bool IsExpected);
+
 /// downscale the given weights preserving the ratio. If the maximum value is
 /// not already known and not provided via \param KnownMaxCount , it will be
 /// obtained from \param Weights.
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index dd83168ab3c6e..941e41f3127d5 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -4141,23 +4141,6 @@ void SwitchInst::growOperands() {
   growHungoffUses(ReservedSpace);
 }
 
-MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
-  assert(Changed && "called only if metadata has changed");
-
-  if (!Weights)
-    return nullptr;
-
-  assert(SI.getNumSuccessors() == Weights->size() &&
-         "num of prof branch_weights must accord with num of successors");
-
-  bool AllZeroes = all_of(*Weights, [](uint32_t W) { return W == 0; });
-
-  if (AllZeroes || Weights->size() < 2)
-    return nullptr;
-
-  return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights);
-}
-
 void SwitchInstProfUpdateWrapper::init() {
   MDNode *ProfileData = getBranchWeightMDNode(SI);
   if (!ProfileData)
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 99029c1719507..438c2180e3e7a 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "llvm/IR/ProfDataUtils.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
@@ -19,6 +20,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/Support/CommandLine.h"
 
 using namespace llvm;
 
@@ -84,10 +86,44 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
   }
 }
 
+/// Shift all weights right by the same amount so that the largest one fits in
+/// uint32_t, approximately preserving their ratios. If the largest already
+/// fits (or there are no weights), the weights are copied unchanged.
+static SmallVector<uint32_t> fitWeights(ArrayRef<uint64_t> Weights) {
+  SmallVector<uint32_t> Ret;
+  Ret.reserve(Weights.size());
+  if (Weights.empty())
+    return Ret;
+  uint64_t Max = *llvm::max_element(Weights);
+  if (Max > UINT_MAX) {
+    // Drop just enough low-order bits for Max to occupy at most 32 bits.
+    unsigned Offset = 32 - llvm::countl_zero(Max);
+    for (uint64_t Value : Weights)
+      Ret.push_back(static_cast<uint32_t>(Value >> Offset));
+  } else {
+    append_range(Ret, Weights);
+  }
+  return Ret;
+}
+
 } // namespace
 
 namespace llvm {
 
+// When set, setBranchWeights drops !prof instead of attaching multi-entry
+// weights that are all zero. Profcheck builds keep such weights, presumably so
+// the profile verifier still finds !prof on the affected instructions.
+cl::opt<bool> ElideAllZeroBranchWeights(
+    "elide-all-zero-branch-weights", cl::Hidden,
+    cl::desc("Drop !prof instead of attaching branch weights that are all "
+             "zero (single-weight call-site counts are always kept)"),
+#if defined(LLVM_ENABLE_PROFCHECK)
+    cl::init(false)
+#else
+    cl::init(true)
+#endif
+);
+
 const char *MDProfLabels::BranchWeights = "branch_weights";
 const char *MDProfLabels::ExpectedBranchWeights = "expected";
 const char *MDProfLabels::ValueProfile = "VP";
@@ -283,11 +319,25 @@ bool hasExplicitlyUnknownBranchWeights(const Instruction &I) {
 
 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
                       bool IsExpected) {
+  // All-zero weights say nothing about which successor is hotter, so (when
+  // enabled) drop the metadata instead. A single weight is, in practice, a
+  // call-site count where 0 marks a call that never ran, so keep it.
+  if (ElideAllZeroBranchWeights && Weights.size() > 1 &&
+      llvm::all_of(Weights, [](uint32_t V) { return V == 0; })) {
+    I.setMetadata(LLVMContext::MD_prof, nullptr);
+    return;
+  }
+
   MDBuilder MDB(I.getContext());
   MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
 }
 
+void setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
+                            bool IsExpected) {
+  setBranchWeights(I, fitWeights(Weights), IsExpected);
+}
+
 SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
                                        std::optional<uint64_t> KnownMaxCount) {
   uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 5bc7e34938127..99b8b88ebedbb 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1664,8 +1664,9 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])},
-                           /*IsExpected=*/false);
+          setBranchWeights(
+              I, ArrayRef<uint32_t>{static_cast<uint32_t>(BlockWeights[BB])},
+              /*IsExpected=*/false);
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1676,7 +1677,8 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           if (cast<CallBase>(I).isIndirectCall()) {
             I.setMetadata(LLVMContext::MD_prof, nullptr);
           } else {
-            setBranchWeights(I, {uint32_t(0)}, /*IsExpected=*/false);
+            setBranchWeights(I, ArrayRef<uint32_t>{uint32_t(0)},
+                             /*IsExpected=*/false);
           }
         }
       }
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index f451c2b471aa6..5d21bac1db4c8 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -672,7 +672,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
       createBranchWeights(CB.getContext(), Count, TotalCount - Count));
 
   if (AttachProfToDirectCall)
-    setBranchWeights(NewInst, {static_cast<uint32_t>(Count)},
+    setBranchWeights(NewInst, ArrayRef<uint32_t>{static_cast<uint32_t>(Count)},
                      /*IsExpected=*/false);
 
   using namespace ore;
diff --git a/llvm/lib/Transforms/Utils/ProfileVerify.cpp b/llvm/lib/Transforms/Utils/ProfileVerify.cpp
index c578b4b839258..72d3dcba85dea 100644
--- a/llvm/lib/Transforms/Utils/ProfileVerify.cpp
+++ b/llvm/lib/Transforms/Utils/ProfileVerify.cpp
@@ -103,8 +103,9 @@ bool ProfileInjector::inject() {
     if (AnnotateSelect) {
       for (auto &I : BB) {
         if (isa<SelectInst>(I) && !I.getMetadata(LLVMContext::MD_prof))
-          setBranchWeights(I, {SelectTrueWeight, SelectFalseWeight},
-                           /*IsExpected=*/false);
+          setBranchWeights(
+              I, ArrayRef<uint32_t>{SelectTrueWeight, SelectFalseWeight},
+              /*IsExpected=*/false);
       }
     }
     auto *Term = getTerminatorBenefitingFromMDProf(BB);
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 216bdf4eb9efb..48987cd1087c0 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -955,33 +955,6 @@ static bool valuesOverlap(std::vector<ValueEqualityComparisonCase> &C1,
   return false;
 }
 
-// Set branch weights on SwitchInst. This sets the metadata if there is at
-// least one non-zero weight.
-static void setBranchWeights(SwitchInst *SI, ArrayRef<uint32_t> Weights,
-                             bool IsExpected) {
-  // Check that there is at least one non-zero weight. Otherwise, pass
-  // nullptr to setMetadata which will erase the existing metadata.
-  MDNode *N = nullptr;
-  if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; }))
-    N = MDBuilder(SI->getParent()->getContext())
-            .createBranchWeights(Weights, IsExpected);
-  SI->setMetadata(LLVMContext::MD_prof, N);
-}
-
-// Similar to the above, but for branch and select instructions that take
-// exactly 2 weights.
-static void setBranchWeights(Instruction *I, uint32_t TrueWeight,
-                             uint32_t FalseWeight, bool IsExpected) {
-  assert(isa<BranchInst>(I) || isa<SelectInst>(I));
-  // Check that there is at least one non-zero weight. Otherwise, pass
-  // nullptr to setMetadata which will erase the existing metadata.
-  MDNode *N = nullptr;
-  if (TrueWeight || FalseWeight)
-    N = MDBuilder(I->getParent()->getContext())
-            .createBranchWeights(TrueWeight, FalseWeight, IsExpected);
-  I->setMetadata(LLVMContext::MD_prof, N);
-}
-
 /// If TI is known to be a terminator instruction and its block is known to
 /// only have a single predecessor block, check to see if that predecessor is
 /// also a value comparison with the same value, and if that comparison
@@ -1181,16 +1154,6 @@ static void getBranchWeights(Instruction *TI,
   }
 }
 
-/// Keep halving the weights until all can fit in uint32_t.
-static void fitWeights(MutableArrayRef<uint64_t> Weights) {
-  uint64_t Max = *llvm::max_element(Weights);
-  if (Max > UINT_MAX) {
-    unsigned Offset = 32 - llvm::countl_zero(Max);
-    for (uint64_t &I : Weights)
-      I >>= Offset;
-  }
-}
-
 static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
     BasicBlock *BB, BasicBlock *PredBlock, ValueToValueMapTy &VMap) {
   Instruction *PTI = PredBlock->getTerminator();
@@ -1446,14 +1409,8 @@ bool SimplifyCFGOpt::performValueComparisonIntoPredecessorFolding(
   for (ValueEqualityComparisonCase &V : PredCases)
     NewSI->addCase(V.Value, V.Dest);
 
-  if (PredHasWeights || SuccHasWeights) {
-    // Halve the weights if any of them cannot fit in an uint32_t
-    fitWeights(Weights);
-
-    SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
-
-    setBranchWeights(NewSI, MDWeights, /*IsExpected=*/false);
-  }
+  if (PredHasWeights || SuccHasWeights)
+    setFittedBranchWeights(*NewSI, Weights, /*IsExpected=*/false);
 
   eraseTerminatorAndDCECond(PTI);
 
@@ -4053,39 +4010,33 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
 
   // Try to update branch weights.
   uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight;
-  SmallVector<uint32_t, 2> MDWeights;
+  SmallVector<uint64_t, 2> MDWeights;
   if (extractPredSuccWeights(PBI, BI, PredTrueWeight, PredFalseWeight,
                              SuccTrueWeight, SuccFalseWeight)) {
-    SmallVector<uint64_t, 8> NewWeights;
 
     if (PBI->getSuccessor(0) == BB) {
       // PBI: br i1 %x, BB, FalseDest
       // BI:  br i1 %y, UniqueSucc, FalseDest
       // TrueWeight is TrueWeight for PBI * TrueWeight for BI.
-      NewWeights.push_back(PredTrueWeight * SuccTrueWeight);
+      MDWeights.push_back(PredTrueWeight * SuccTrueWeight);
       // FalseWeight is FalseWeight for PBI * TotalWeight for BI +
       //               TrueWeight for PBI * FalseWeight for BI.
       // We assume that total weights of a BranchInst can fit into 32 bits.
       // Therefore, we will not have overflow using 64-bit arithmetic.
-      NewWeights.push_back(PredFalseWeight *
-                               (SuccFalseWeight + SuccTrueWeight) +
-                           PredTrueWeight * SuccFalseWeight);
+      MDWeights.push_back(PredFalseWeight * (SuccFalseWeight + SuccTrueWeight) +
+                          PredTrueWeight * SuccFalseWeight);
     } else {
       // PBI: br i1 %x, TrueDest, BB
       // BI:  br i1 %y, TrueDest, UniqueSucc
       // TrueWeight is TrueWeight for PBI * TotalWeight for BI +
       //              FalseWeight for PBI * TrueWeight for BI.
-      NewWeights.push_back(PredTrueWeight * (SuccFalseWeight + SuccTrueWeight) +
-                           PredFalseWeight * SuccTrueWeight);
+      MDWeights.push_back(PredTrueWeight * (SuccFalseWeight + SuccTrueWeight) +
+                          PredFalseWeight * SuccTrueWeight);
       // FalseWeight is FalseWeight for PBI * FalseWeight for BI.
-      NewWeights.push_back(PredFalseWeight * SuccFalseWeight);
+      MDWeights.push_back(PredFalseWeight * SuccFalseWeight);
     }
 
-    // Halve the weights if any of them cannot fit in an uint32_t
-    fitWeights(NewWeights);
-
-    append_range(MDWeights, NewWeights);
-    setBranchWeights(PBI, MDWeights[0], MDWeights[1], /*IsExpected=*/false);
+    setFittedBranchWeights(*PBI, MDWeights, /*IsExpected=*/false);
 
     // TODO: If BB is reachable from all paths through PredBlock, then we
     // could replace PBI's branch probabilities with BI's.
@@ -4125,8 +4076,8 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
     if (auto *SI = dyn_cast<SelectInst>(PBI->getCondition()))
       if (!MDWeights.empty()) {
         assert(isSelectInRoleOfConjunctionOrDisjunction(SI));
-        setBranchWeights(SI, MDWeights[0], MDWeights[1],
-                         /*IsExpected=*/false);
+        setFittedBranchWeights(*SI, {MDWeights[0], MDWeights[1]},
+                               /*IsExpected=*/false);
       }
 
   ++NumFoldBranchToCommonDest;
@@ -4478,9 +4429,9 @@ static bool mergeConditionalStoreToAddress(
     if (InvertQCond)
       std::swap(QWeights[0], QWeights[1]);
     auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
-    setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
-                     CombinedWeights[1],
-                     /*IsExpected=*/false);
+    setFittedBranchWeights(*PostBB->getTerminator(),
+                           {CombinedWeights[0], CombinedWeights[1]},
+                           /*IsExpected=*/false);
   }
 
   QB.SetInsertPoint(T);
@@ -4836,10 +4787,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
     uint64_t NewWeights[2] = {PredCommon * (SuccCommon + SuccOther) +
                                   PredOther * SuccCommon,
                               PredOther * SuccOther};
-    // Halve the weights if any of them cannot fit in an uint32_t
-    fitWeights(NewWeights);
 
-    setBranchWeights(PBI, NewWeights[0], NewWeights[1], /*IsExpected=*/false);
+    setFittedBranchWeights(*PBI, NewWeights, /*IsExpected=*/false);
     // Cond may be a select instruction with the first operand set to "true", or
     // the second to "false" (see how createLogicalOp works for `and` and `or`)
     if (!ProfcheckDisableMetadataFixes)
@@ -4849,8 +4798,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
         assert(dyn_cast<SelectInst>(SI)->getCondition() == PBICond);
         // The corresponding probabilities are what was referred to above as
         // PredCommon and PredOther.
-        setBranchWeights(SI, PredCommon, PredOther,
-                         /*IsExpected=*/false);
+        setFittedBranchWeights(*SI, {PredCommon, PredOther},
+                               /*IsExpected=*/false);
       }
   }
 
@@ -4876,8 +4825,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
       if (HasWeights) {
         uint64_t TrueWeight = PBIOp ? PredFalseWeight : PredTrueWeight;
         uint64_t FalseWeight = PBIOp ? PredTrueWeight : PredFalseWeight;
-        setBranchWeights(NV, TrueWeight, FalseWeight,
-                         /*IsExpected=*/false);
+        setFittedBranchWeights(*NV, {TrueWeight, FalseWeight},
+                               /*IsExpected=*/false);
       }
     }
   }
@@ -4940,7 +4889,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm,
       // Create a conditional branch sharing the condition of the select.
       BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB);
       if (TrueWeight != FalseWeight)
-        setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false);
+        setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
+                         /*IsExpected=*/false);
     }
   } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) {
     // Neither of the selected blocks were successors, so this
@@ -5889,7 +5839,8 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
         TrueWeight /= 2;
         FalseWeight /= 2;
       }
-      setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false);
+      setFittedBranchWeights(*NewBI, {TrueWeight, FalseWeight},
+                             /*IsExpected=*/false);
     }
   }
 
@@ -6364,9 +6315,9 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
         // BranchWeights. We want the probability and negative probability of
         // Condition == SecondCase.
         assert(BranchWeights.size() == 3);
-        setBranchWeights(SI, BranchWeights[2],
-                         BranchWeights[0] + BranchWeights[1],
-                         /*IsExpected=*/false);
+        setBranchWeights(
+            *SI, {BranchWeights[2], BranchWeights[0] + BranchWeights[1]},
+            /*IsExpected=*/false);
       }
     }
     Value *ValueCompare =
@@ -6381,8 +6332,9 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
       size_t FirstCasePos = (Condition != nullptr);
       size_t SecondCasePos = FirstCasePos + 1;
       uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0;
-      setBranchWeights(SI, BranchWeights[FirstCasePos],
-                       DefaultCase + BranchWeights[SecondCasePos],
+      setBranchWeights(*SI,
+                       {BranchWeights[FirstCasePos],
+                        DefaultCase + BranchWeights[SecondCasePos]},
                        /*IsExpected=*/false);
     }
     return Ret;
@@ -6427,8 +6379,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
             // We know there's a Default case. We base the resulting branch
             // weights off its probability.
             assert(BranchWeights.size() >= 2);
-            setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
-                             BranchWeights[0], /*IsExpected=*/false);
+            setBranchWeights(
+                *SI,
+                {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]},
+                /*IsExpected=*/false);
           }
           return Ret;
         }
@@ -6451,8 +6405,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
             Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
         if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
           assert(BranchWeights.size() >= 2);
-          setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
-                           BranchWeights[0], /*IsExpected=*/false);
+          setBranchWeights(
+              *SI,
+              {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]},
+              /*IsExpected=*/false);
         }
         return Ret;
       }
@@ -6469,8 +6425,9 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
           Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
       if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
         assert(BranchWeights.size() >= 2);
-        setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
-                         BranchWeights[0], /*IsExpected=*/false);
+        setBranchWeights(
+            *SI, {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]},
+            /*IsExpected=*/false);
       }
       return Ret;
     }
@@ -8152,8 +8109,7 @@ static bool mergeNestedCondBranch(BranchInst *BI, DomTreeUpdater *DTU) {
   if (HasWeight) {
     uint64_t Weights[2] = {BBTWeight * BB1FWeight + BBFWeight * BB2TWeight,
                            BBTWeight * BB1TWeight + BBFWeight * BB2FWeight};
-    fitWeights(Weights);
-    setBranchWeights(BI, Weights[0], Weights[1], /*IsExpected=*/false);
+    setFittedBranchWeights(*BI, Weights, /*IsExpected=*/false);
   }
   return true;
 }
diff --git a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
index 9a1f36ec5ff2b..486f62232d78c 100644
--- a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
+++ b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
@@ -209,10 +209,10 @@ loop_exit:
 
 ; IR-LABEL: define void @func5_zero_branch_weight
 ; IR: entry:
-; IR:   br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC5_0:![0-9]+]]
+; IR:   br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph
 
 ; IR: loop_body:
-; IR:   br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body, !prof [[PROF_FUNC5_0]]
+; IR:   br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body
 
 define void @func5_zero_branch_weight(i32 %n) !prof !3 {
 entry:
@@ -291,5 +291,4 @@ loop_exit:
 ; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}
 ; IR: [[PROF_FUNC3_0]] = !{!"branch_weights", i32 0, i32 1}
 ; IR: [[PROF_FUNC4_0]] = !{!"branch_weights", i32 1, i32 0}
-; IR: [[PROF_FUNC5_0]] = !{!"branch_weights", i32 0, i32 0}
 ; IR: [[PROF_FUNC6_0]] = !{!"branch_weights", i32 0, i32 1024}
diff --git a/llvm/utils/profcheck-xfail.txt b/llvm/utils/profcheck-xfail.txt
index 08c89441ec855..77e6ab7c5a6ea 100644
--- a/llvm/utils/profcheck-xfail.txt
+++ b/llvm/utils/profcheck-xfail.txt
@@ -1414,7 +1414,6 @@ Transforms/SimplifyCFG/merge-cond-stores.ll
 Transforms/SimplifyCFG/multiple-phis.ll
 Transforms/SimplifyCFG/PhiBlockMerge.ll
 Transforms/SimplifyCFG/pr48641.ll
-Transforms/SimplifyCFG/preserve-branchweights.ll
 Transforms/SimplifyCFG/preserve-store-alignment.ll
 Transforms/SimplifyCFG/rangereduce.ll
 Transforms/SimplifyCFG/RISCV/select-trunc-i64.ll
@@ -1424,7 +1423,6 @@ Transforms/SimplifyCFG/safe-abs.ll
 Transforms/SimplifyCFG/SimplifyEqualityComparisonWithOnlyPredecessor-domtree-preservation-edgecase.ll
 Transforms/SimplifyCFG/speculate-blocks.ll
 Transforms/SimplifyCFG/speculate-derefable-load.ll
-Transforms/SimplifyCFG/suppress-zero-branch-weights.ll
 Transforms/SimplifyCFG/switch_create-custom-dl.ll
 Transforms/SimplifyCFG/switch_create.ll
 Transforms/SimplifyCFG/switch-dup-bbs.ll



More information about the llvm-commits mailing list