[llvm-branch-commits] [llvm][ProfDataUtils] provide getNumBranchWeights API (PR #90146)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Apr 25 16:11:33 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir
Author: Paul Kirth (ilovepi)
<details>
<summary>Changes</summary>
As suggested in https://github.com/llvm/llvm-project/pull/86609/files#r1556689262,
an API for getting the number of branch weights directly from the MD node would
be useful in a variety of checks, and would keep the logic within ProfDataUtils.
---
Full diff: https://github.com/llvm/llvm-project/pull/90146.diff
4 Files Affected:
- (modified) llvm/include/llvm/IR/ProfDataUtils.h (+2)
- (modified) llvm/lib/IR/Instructions.cpp (+1-5)
- (modified) llvm/lib/IR/ProfDataUtils.cpp (+5-3)
- (modified) llvm/lib/IR/Verifier.cpp (+5-5)
``````````diff
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 3c761bdc1bf3e9..7008d3240feded 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -66,6 +66,8 @@ bool hasBranchWeightProvenance(const MDNode *ProfileData);
/// Return the offset to the first branch weight data
unsigned getBranchWeightOffset(const MDNode *ProfileData);
+unsigned getNumBranchWeights(const MDNode &ProfileData);
+
/// Extract branch weights from MD_prof metadata
///
/// \param ProfileData A pointer to an MDNode.
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 650d32ac17fc2b..a14d6758cad1d8 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -5165,11 +5165,7 @@ void SwitchInstProfUpdateWrapper::init() {
if (!ProfileData)
return;
- // FIXME: This check belongs in ProfDataUtils. Its almost equivalent to
- // getValidBranchWeightMDNode(), but the need to use llvm_unreachable
- // makes them slightly different.
- if (ProfileData->getNumOperands() !=
- SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
+ if (getNumBranchWeights(*ProfileData) != SI.getNumSuccessors()) {
llvm_unreachable("number of prof branch_weights metadata operands does "
"not correspond to number of succesors");
}
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index cd219c22e3dfe6..9544ea85b93d96 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -123,6 +123,10 @@ unsigned getBranchWeightOffset(const MDNode *ProfileData) {
return hasBranchWeightProvenance(ProfileData) ? 2 : 1;
}
+unsigned getNumBranchWeights(const MDNode &ProfileData) {
+ return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
+}
+
MDNode *getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
@@ -132,9 +136,7 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
MDNode *getValidBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = getBranchWeightMDNode(I);
- auto Offset = getBranchWeightOffset(ProfileData);
- if (ProfileData &&
- ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
+ if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
return ProfileData;
return nullptr;
}
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 4a142be71eec41..ecccb1790ff8ff 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -4787,10 +4787,9 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
// Check consistency of !prof branch_weights metadata.
if (ProfName.equals("branch_weights")) {
- unsigned int Offset = getBranchWeightOffset(I);
+ unsigned NumBranchWeights = getNumBranchWeights(*MD);
if (isa<InvokeInst>(&I)) {
- Check(MD->getNumOperands() == (1 + Offset) ||
- MD->getNumOperands() == (2 + Offset),
+ Check(NumBranchWeights == 1 || NumBranchWeights == 2,
"Wrong number of InvokeInst branch_weights operands", MD);
} else {
unsigned ExpectedNumOperands = 0;
@@ -4810,10 +4809,11 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
CheckFailed("!prof branch_weights are not allowed for this instruction",
MD);
- Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
+ Check(NumBranchWeights == ExpectedNumOperands,
"Wrong number of operands", MD);
}
- for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
+ for (unsigned i = getBranchWeightOffset(MD); i < MD->getNumOperands();
+ ++i) {
auto &MDO = MD->getOperand(i);
Check(MDO, "second operand should not be null", MD);
Check(mdconst::dyn_extract<ConstantInt>(MDO),
``````````
</details>
https://github.com/llvm/llvm-project/pull/90146
More information about the llvm-branch-commits
mailing list