[llvm] [llvm][IR] Extend BranchWeightMetadata to track provenance of weights (PR #86609)

Paul Kirth via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 14:13:37 PDT 2024


https://github.com/ilovepi updated https://github.com/llvm/llvm-project/pull/86609

>From 1e9dbac8b5578a3d10228466177f2d7c605978ba Mon Sep 17 00:00:00 2001
From: Paul Kirth <paulkirth at google.com>
Date: Tue, 26 Mar 2024 00:49:16 +0000
Subject: [PATCH 1/4] [𝘀𝗽𝗿] changes to main this commit is based on
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.4

[skip ci]
---
 llvm/include/llvm/IR/ProfDataUtils.h          |  9 +++-
 llvm/lib/IR/ProfDataUtils.cpp                 | 45 +++++++++++++------
 .../Transforms/Utils/LoopRotationUtils.cpp    |  2 +-
 llvm/lib/Transforms/Utils/MisExpect.cpp       |  3 +-
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     |  7 +--
 5 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 255fa2ff1c790..dc983eed13a8d 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -65,10 +65,15 @@ bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights);
 
 /// Faster version of extractBranchWeights() that skips checks and must only
-/// be called with "branch_weights" metadata nodes.
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+/// be called with "branch_weights" metadata nodes. Supports uint32_t.
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
                                SmallVectorImpl<uint32_t> &Weights);
 
+/// Faster version of extractBranchWeights() that skips checks and must only
+/// be called with "branch_weights" metadata nodes. Supports uint64_t.
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+                               SmallVectorImpl<uint64_t> &Weights);
+
 /// Extract branch weights attatched to an Instruction
 ///
 /// \param I The Instruction to extract weights from.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1a10d0ce5a52..b4e09e76993f9 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
   return ProfDataName->getString().equals(Name);
 }
 
+template <typename T,
+          typename = typename std::enable_if<std::is_arithmetic_v<T>>>
+static void extractFromBranchWeightMD(const MDNode *ProfileData,
+                                      SmallVectorImpl<T> &Weights) {
+  assert(isBranchWeightMD(ProfileData) && "wrong metadata");
+
+  unsigned NOps = ProfileData->getNumOperands();
+  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+  Weights.resize(NOps - WeightsIdx);
+
+  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+    assert(Weight && "Malformed branch_weight in MD_prof node");
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+  }
+}
+
 } // namespace
 
 namespace llvm {
@@ -100,24 +120,21 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
   return nullptr;
 }
 
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
                                SmallVectorImpl<uint32_t> &Weights) {
-  assert(isBranchWeightMD(ProfileData) && "wrong metadata");
-
-  unsigned NOps = ProfileData->getNumOperands();
-  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
-  Weights.resize(NOps - WeightsIdx);
+  extractFromBranchWeightMD(ProfileData, Weights);
+}
 
-  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    assert(Weight && "Malformed branch_weight in MD_prof node");
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
-  }
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+                               SmallVectorImpl<uint64_t> &Weights) {
+  extractFromBranchWeightMD(ProfileData, Weights);
 }
 
+
+
+
+
+
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights) {
   if (!isBranchWeightMD(ProfileData))
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index bc67117113719..f4b43ce370a5d 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     return;
 
   SmallVector<uint32_t, 2> Weights;
-  extractFromBranchWeightMD(WeightMD, Weights);
+  extractFromBranchWeightMD32(WeightMD, Weights);
   if (Weights.size() != 2)
     return;
   uint32_t OrigLoopExitWeight = Weights[0];
diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
index 6f5a25a26821b..759289384ee06 100644
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -59,9 +59,10 @@ static cl::opt<bool> PGOWarnMisExpect(
     cl::desc("Use this option to turn on/off "
              "warnings about incorrect usage of llvm.expect intrinsics."));
 
+// Command line option for setting the diagnostic tolerance threshold
 static cl::opt<uint32_t> MisExpectTolerance(
     "misexpect-tolerance", cl::init(0),
-    cl::desc("Prevents emiting diagnostics when profile counts are "
+    cl::desc("Prevents emitting diagnostics when profile counts are "
              "within N% of the threshold.."));
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 55bbffb18879f..a425e26d490e4 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1065,11 +1065,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
 static void GetBranchWeights(Instruction *TI,
                              SmallVectorImpl<uint64_t> &Weights) {
   MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
-  assert(MD);
-  for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
-    ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
-    Weights.push_back(CI->getValue().getZExtValue());
-  }
+  assert(MD && "Invalid branch-weight metadata");
+  extractFromBranchWeightMD64(MD, Weights);
 
   // If TI is a conditional eq, the default case is the false case,
   // and the corresponding branch-weight data is at index 2. We swap the

>From ee503bf633c47f9b8abb76848327bcbd2b769be3 Mon Sep 17 00:00:00 2001
From: Paul Kirth <paulkirth at google.com>
Date: Tue, 26 Mar 2024 00:56:29 +0000
Subject: [PATCH 2/4] [𝘀𝗽𝗿] changes introduced through rebase
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.4

[skip ci]
---
 llvm/include/llvm/IR/ProfDataUtils.h | 4 ++--
 llvm/lib/IR/ProfDataUtils.cpp        | 9 ++-------
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index dc983eed13a8d..457ffdff8fe37 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -67,12 +67,12 @@ bool extractBranchWeights(const MDNode *ProfileData,
 /// Faster version of extractBranchWeights() that skips checks and must only
 /// be called with "branch_weights" metadata nodes. Supports uint32_t.
 void extractFromBranchWeightMD32(const MDNode *ProfileData,
-                               SmallVectorImpl<uint32_t> &Weights);
+                                 SmallVectorImpl<uint32_t> &Weights);
 
 /// Faster version of extractBranchWeights() that skips checks and must only
 /// be called with "branch_weights" metadata nodes. Supports uint64_t.
 void extractFromBranchWeightMD64(const MDNode *ProfileData,
-                               SmallVectorImpl<uint64_t> &Weights);
+                                 SmallVectorImpl<uint64_t> &Weights);
 
 /// Extract branch weights attatched to an Instruction
 ///
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b4e09e76993f9..36e165e641f46 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -121,20 +121,15 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
 }
 
 void extractFromBranchWeightMD32(const MDNode *ProfileData,
-                               SmallVectorImpl<uint32_t> &Weights) {
+                                 SmallVectorImpl<uint32_t> &Weights) {
   extractFromBranchWeightMD(ProfileData, Weights);
 }
 
 void extractFromBranchWeightMD64(const MDNode *ProfileData,
-                               SmallVectorImpl<uint64_t> &Weights) {
+                                 SmallVectorImpl<uint64_t> &Weights) {
   extractFromBranchWeightMD(ProfileData, Weights);
 }
 
-
-
-
-
-
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights) {
   if (!isBranchWeightMD(ProfileData))

>From 7760282ed8dba340d6873d06ff4c18c6efc25b56 Mon Sep 17 00:00:00 2001
From: Paul Kirth <paulkirth at google.com>
Date: Thu, 6 Jun 2024 19:04:19 +0000
Subject: [PATCH 3/4] Add assert for metadata string value

Created using spr 1.3.4
---
 llvm/lib/IR/ProfDataUtils.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index f738d76937c24..af536d2110eac 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -133,6 +133,7 @@ bool hasBranchWeightProvenance(const MDNode *ProfileData) {
   // NOTE: if we ever have more types of branch weight provenance,
   // we need to check the string value is "expected". For now, we
   // supply a more generic API, and avoid the spurious comparisons.
+  assert(ProfDataName->getString() == "expected");
   return ProfDataName;
 }
 

>From 947f9e16732197418c9f49ed02a01b187a50f936 Mon Sep 17 00:00:00 2001
From: Paul Kirth <paulkirth at google.com>
Date: Thu, 6 Jun 2024 21:13:24 +0000
Subject: [PATCH 4/4] Rename hasBranchWeightProvenance to hasBranchWeightOrigin

Created using spr 1.3.4
---
 llvm/include/llvm/IR/ProfDataUtils.h         |  4 ++--
 llvm/lib/CodeGen/CodeGenPrepare.cpp          |  9 ++++-----
 llvm/lib/IR/Metadata.cpp                     |  8 ++++----
 llvm/lib/IR/ProfDataUtils.cpp                | 14 +++++++-------
 llvm/lib/IR/Verifier.cpp                     |  2 +-
 llvm/lib/Transforms/Scalar/JumpThreading.cpp |  4 ++--
 llvm/lib/Transforms/Utils/Local.cpp          |  2 +-
 7 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 3c761bdc1bf3e..1d7c97d9be953 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -57,11 +57,11 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I);
 
 /// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
 /// intrinsic
-bool hasBranchWeightProvenance(const Instruction &I);
+bool hasBranchWeightOrigin(const Instruction &I);
 
 /// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
 /// intrinsic
-bool hasBranchWeightProvenance(const MDNode *ProfileData);
+bool hasBranchWeightOrigin(const MDNode *ProfileData);
 
 /// Return the offset to the first branch weight data
 unsigned getBranchWeightOffset(const MDNode *ProfileData);
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index ce11caca38988..0e01080bd75cc 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -8864,11 +8864,10 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, ModifyDT &ModifiedDT) {
         uint64_t NewTrueWeight = TrueWeight;
         uint64_t NewFalseWeight = TrueWeight + 2 * FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);
-        Br1->setMetadata(
-            LLVMContext::MD_prof,
-            MDBuilder(Br1->getContext())
-                .createBranchWeights(TrueWeight, FalseWeight,
-                                     hasBranchWeightProvenance(*Br1)));
+        Br1->setMetadata(LLVMContext::MD_prof,
+                         MDBuilder(Br1->getContext())
+                             .createBranchWeights(TrueWeight, FalseWeight,
+                                                  hasBranchWeightOrigin(*Br1)));
 
         NewTrueWeight = TrueWeight;
         NewFalseWeight = 2 * FalseWeight;
diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index b6c932495a145..5f42ce22f72fe 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -1196,10 +1196,10 @@ MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
   StringRef AProfName = AMDS->getString();
   StringRef BProfName = BMDS->getString();
   if (AProfName == "branch_weights" && BProfName == "branch_weights") {
-    ConstantInt *AInstrWeight =
-        mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
-    ConstantInt *BInstrWeight =
-        mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
+    ConstantInt *AInstrWeight = mdconst::dyn_extract<ConstantInt>(
+        A->getOperand(getBranchWeightOffset(A)));
+    ConstantInt *BInstrWeight = mdconst::dyn_extract<ConstantInt>(
+        B->getOperand(getBranchWeightOffset(B)));
     assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier");
     return MDNode::get(Ctx,
                        {MDHelper.createString("branch_weights"),
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index af536d2110eac..c4b1ed55de8a2 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -121,24 +121,24 @@ bool hasValidBranchWeightMD(const Instruction &I) {
   return getValidBranchWeightMDNode(I);
 }
 
-bool hasBranchWeightProvenance(const Instruction &I) {
+bool hasBranchWeightOrigin(const Instruction &I) {
   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
-  return hasBranchWeightProvenance(ProfileData);
+  return hasBranchWeightOrigin(ProfileData);
 }
 
-bool hasBranchWeightProvenance(const MDNode *ProfileData) {
+bool hasBranchWeightOrigin(const MDNode *ProfileData) {
   if (!isBranchWeightMD(ProfileData))
     return false;
   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
   // NOTE: if we ever have more types of branch weight provenance,
   // we need to check the string value is "expected". For now, we
   // supply a more generic API, and avoid the spurious comparisons.
-  assert(ProfDataName->getString() == "expected");
-  return ProfDataName;
+  assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
+  return ProfDataName != nullptr;
 }
 
 unsigned getBranchWeightOffset(const MDNode *ProfileData) {
-  return hasBranchWeightProvenance(ProfileData) ? 2 : 1;
+  return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
 }
 
 MDNode *getBranchWeightMDNode(const Instruction &I) {
@@ -210,7 +210,7 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
   if (!ProfDataName)
     return false;
 
-  if (ProfDataName->getString().equals("branch_weights")) {
+  if (ProfDataName->getString() == "branch_weights") {
     unsigned Offset = getBranchWeightOffset(ProfileData);
     for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
       auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 39185905a1516..e0fde2b7d90dc 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -4808,7 +4808,7 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
   StringRef ProfName = MDS->getString();
 
   // Check consistency of !prof branch_weights metadata.
-  if (ProfName.equals("branch_weights")) {
+  if (ProfName == "branch_weights") {
     unsigned int Offset = getBranchWeightOffset(MD);
     if (isa<InvokeInst>(&I)) {
       Check(MD->getNumOperands() == (1 + Offset) ||
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 88307b8b074ed..b9583836aea06 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -231,7 +231,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
       Weights[0] = BP.getCompl().getNumerator();
       Weights[1] = BP.getNumerator();
     }
-    setBranchWeights(*PredBr, Weights, hasBranchWeightProvenance(*PredBr));
+    setBranchWeights(*PredBr, Weights, hasBranchWeightOrigin(*PredBr));
   }
 }
 
@@ -2618,7 +2618,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
       Weights.push_back(Prob.getNumerator());
 
     auto TI = BB->getTerminator();
-    setBranchWeights(*TI, Weights, hasBranchWeightProvenance(*TI));
+    setBranchWeights(*TI, Weights, hasBranchWeightOrigin(*TI));
   }
 }
 
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 8f116c42d3d78..12229123675e7 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -231,7 +231,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
           // Remove weight for this case.
           std::swap(Weights[Idx + 1], Weights.back());
           Weights.pop_back();
-          setBranchWeights(*SI, Weights, hasBranchWeightProvenance(MD));
+          setBranchWeights(*SI, Weights, hasBranchWeightOrigin(MD));
         }
         // Remove this entry.
         BasicBlock *ParentBB = SI->getParent();



More information about the llvm-commits mailing list