[llvm] a3a44bf - [llvm][ProfDataUtils] Provide getNumBranchWeights API (#90146)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 15:17:43 PDT 2024


Author: Paul Kirth
Date: 2024-06-24T15:17:40-07:00
New Revision: a3a44bfbdfefe0928124f9e40d242507f75b87f4

URL: https://github.com/llvm/llvm-project/commit/a3a44bfbdfefe0928124f9e40d242507f75b87f4
DIFF: https://github.com/llvm/llvm-project/commit/a3a44bfbdfefe0928124f9e40d242507f75b87f4.diff

LOG: [llvm][ProfDataUtils] Provide getNumBranchWeights API (#90146)

As suggested in
https://github.com/llvm/llvm-project/pull/86609/files#r1556689262
an API for getting the number of branch weights directly from the MD
node would be useful in a variety of checks, and keeps the logic within
ProfDataUtils.

Added: 
    

Modified: 
    llvm/include/llvm/IR/ProfDataUtils.h
    llvm/lib/IR/Instructions.cpp
    llvm/lib/IR/ProfDataUtils.cpp
    llvm/lib/IR/Verifier.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 1d7c97d9be953..0bea517df832e 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -66,6 +66,8 @@ bool hasBranchWeightOrigin(const MDNode *ProfileData);
 /// Return the offset to the first branch weight data
 unsigned getBranchWeightOffset(const MDNode *ProfileData);
 
+unsigned getNumBranchWeights(const MDNode &ProfileData);
+
 /// Extract branch weights from MD_prof metadata
 ///
 /// \param ProfileData A pointer to an MDNode.

diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 445323f2a085b..52e25a8cc09be 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -4002,11 +4002,7 @@ void SwitchInstProfUpdateWrapper::init() {
   if (!ProfileData)
     return;
 
-  // FIXME: This check belongs in ProfDataUtils. Its almost equivalent to
-  // getValidBranchWeightMDNode(), but the need to use llvm_unreachable
-  // makes them slightly different.
-  if (ProfileData->getNumOperands() !=
-      SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
+  if (getNumBranchWeights(*ProfileData) != SI.getNumSuccessors()) {
     llvm_unreachable("number of prof branch_weights metadata operands does "
                      "not correspond to number of succesors");
   }

diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 2f8e61c41b1da..992ce34e00034 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -142,6 +142,10 @@ unsigned getBranchWeightOffset(const MDNode *ProfileData) {
   return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
 }
 
+unsigned getNumBranchWeights(const MDNode &ProfileData) {
+  return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
+}
+
 MDNode *getBranchWeightMDNode(const Instruction &I) {
   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
   if (!isBranchWeightMD(ProfileData))
@@ -151,9 +155,7 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
 
 MDNode *getValidBranchWeightMDNode(const Instruction &I) {
   auto *ProfileData = getBranchWeightMDNode(I);
-  auto Offset = getBranchWeightOffset(ProfileData);
-  if (ProfileData &&
-      ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
+  if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
     return ProfileData;
   return nullptr;
 }

diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 0abd5f76d449c..c98f61d555140 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -4818,10 +4818,9 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
 
   // Check consistency of !prof branch_weights metadata.
   if (ProfName == "branch_weights") {
-    unsigned int Offset = getBranchWeightOffset(MD);
+    unsigned NumBranchWeights = getNumBranchWeights(*MD);
     if (isa<InvokeInst>(&I)) {
-      Check(MD->getNumOperands() == (1 + Offset) ||
-                MD->getNumOperands() == (2 + Offset),
+      Check(NumBranchWeights == 1 || NumBranchWeights == 2,
             "Wrong number of InvokeInst branch_weights operands", MD);
     } else {
       unsigned ExpectedNumOperands = 0;
@@ -4841,10 +4840,11 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
         CheckFailed("!prof branch_weights are not allowed for this instruction",
                     MD);
 
-      Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
-            "Wrong number of operands", MD);
+      Check(NumBranchWeights == ExpectedNumOperands, "Wrong number of operands",
+            MD);
     }
-    for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
+    for (unsigned i = getBranchWeightOffset(MD); i < MD->getNumOperands();
+         ++i) {
       auto &MDO = MD->getOperand(i);
       Check(MDO, "second operand should not be null", MD);
       Check(mdconst::dyn_extract<ConstantInt>(MDO),


        


More information about the llvm-commits mailing list