[llvm] 7538df9 - [llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils (#86607)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 30 13:53:09 PDT 2024
Author: Paul Kirth
Date: 2024-04-30T13:53:04-07:00
New Revision: 7538df90aee11603bce5146a3f34c06e9ef3f793
URL: https://github.com/llvm/llvm-project/commit/7538df90aee11603bce5146a3f34c06e9ef3f793
DIFF: https://github.com/llvm/llvm-project/commit/7538df90aee11603bce5146a3f34c06e9ef3f793.diff
LOG: [llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils (#86607)
Some code, such as SimplifyCFG, works with 64-bit branch weights, so we add
an API to ProfDataUtils that extracts the weights as uint64_t.
To disambiguate the two versions, the existing extractFromBranchWeightMD is
renamed to extractFromBranchWeightMD32 and the new 64-bit variant is named
extractFromBranchWeightMD64.
Added:
Modified:
llvm/include/llvm/IR/ProfDataUtils.h
llvm/lib/IR/ProfDataUtils.cpp
llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index c0897408986fb3..88fbad4d6b9d82 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -65,9 +65,14 @@ bool extractBranchWeights(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights);
/// Faster version of extractBranchWeights() that skips checks and must only
-/// be called with "branch_weights" metadata nodes.
-void extractFromBranchWeightMD(const MDNode *ProfileData,
- SmallVectorImpl<uint32_t> &Weights);
+/// be called with "branch_weights" metadata nodes. Supports uint32_t.
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
+ SmallVectorImpl<uint32_t> &Weights);
+
+/// Faster version of extractBranchWeights() that skips checks and must only
+/// be called with "branch_weights" metadata nodes. Supports uint64_t.
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+ SmallVectorImpl<uint64_t> &Weights);
/// Extract branch weights attatched to an Instruction
///
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 186f9bf10d0cd7..3d72418593a7a6 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
return ProfDataName->getString() == Name;
}
+template <typename T,
+ typename = typename std::enable_if<std::is_arithmetic_v<T>>>
+static void extractFromBranchWeightMD(const MDNode *ProfileData,
+ SmallVectorImpl<T> &Weights) {
+ assert(isBranchWeightMD(ProfileData) && "wrong metadata");
+
+ unsigned NOps = ProfileData->getNumOperands();
+ assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+ Weights.resize(NOps - WeightsIdx);
+
+ for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+ ConstantInt *Weight =
+ mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+ assert(Weight && "Malformed branch_weight in MD_prof node");
+ assert(Weight->getValue().getActiveBits() <= 32 &&
+ "Too many bits for uint32_t");
+ Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+ }
+}
+
} // namespace
namespace llvm {
@@ -100,22 +120,14 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
return nullptr;
}
-void extractFromBranchWeightMD(const MDNode *ProfileData,
- SmallVectorImpl<uint32_t> &Weights) {
- assert(isBranchWeightMD(ProfileData) && "wrong metadata");
-
- unsigned NOps = ProfileData->getNumOperands();
- assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
- Weights.resize(NOps - WeightsIdx);
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
+ SmallVectorImpl<uint32_t> &Weights) {
+ extractFromBranchWeightMD(ProfileData, Weights);
+}
- for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
- ConstantInt *Weight =
- mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
- assert(Weight && "Malformed branch_weight in MD_prof node");
- assert(Weight->getValue().getActiveBits() <= 32 &&
- "Too many bits for uint32_t");
- Weights[Idx - WeightsIdx] = Weight->getZExtValue();
- }
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+ SmallVectorImpl<uint64_t> &Weights) {
+ extractFromBranchWeightMD(ProfileData, Weights);
}
bool extractBranchWeights(const MDNode *ProfileData,
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 0f55af3b6eddf8..5cd96412a322d1 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
return;
SmallVector<uint32_t, 2> Weights;
- extractFromBranchWeightMD(WeightMD, Weights);
+ extractFromBranchWeightMD32(WeightMD, Weights);
if (Weights.size() != 2)
return;
uint32_t OrigLoopExitWeight = Weights[0];
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 4db72461c95e47..5a44a11ecfd2c5 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1066,11 +1066,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
static void GetBranchWeights(Instruction *TI,
SmallVectorImpl<uint64_t> &Weights) {
MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
- assert(MD);
- for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
- ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
- Weights.push_back(CI->getValue().getZExtValue());
- }
+ assert(MD && "Invalid branch-weight metadata");
+ extractFromBranchWeightMD64(MD, Weights);
// If TI is a conditional eq, the default case is the false case,
// and the corresponding branch-weight data is at index 2. We swap the
More information about the llvm-commits
mailing list