[llvm] [llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils (PR #86607)

Paul Kirth via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 16:03:40 PDT 2024


https://github.com/ilovepi updated https://github.com/llvm/llvm-project/pull/86607

>From 31588d29e9299dfcbf6909fd0bcd71dca73e3e8d Mon Sep 17 00:00:00 2001
From: Paul Kirth <paulkirth at google.com>
Date: Tue, 26 Mar 2024 00:49:00 +0000
Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?=
 =?UTF-8?q?itial=20version?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.4
---
 llvm/include/llvm/IR/ProfDataUtils.h          |  9 +++-
 llvm/lib/IR/ProfDataUtils.cpp                 | 45 +++++++++++++------
 .../Transforms/Utils/LoopRotationUtils.cpp    |  2 +-
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     |  7 +--
 4 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 255fa2ff1c7906..dc983eed13a8d3 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -65,10 +65,15 @@ bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights);
 
 /// Faster version of extractBranchWeights() that skips checks and must only
-/// be called with "branch_weights" metadata nodes.
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+/// be called with "branch_weights" metadata nodes. Supports uint32_t.
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
                                SmallVectorImpl<uint32_t> &Weights);
 
+/// Faster version of extractBranchWeights() that skips checks and must only
+/// be called with "branch_weights" metadata nodes. Supports uint64_t.
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+                               SmallVectorImpl<uint64_t> &Weights);
+
 /// Extract branch weights attatched to an Instruction
 ///
 /// \param I The Instruction to extract weights from.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1a10d0ce5a522..b4e09e76993f99 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
   return ProfDataName->getString().equals(Name);
 }
 
+/// Copies every weight operand of the "branch_weights" node \p ProfileData
+/// into \p Weights, resizing it to the weight count (old contents are lost).
+template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
+static void extractFromBranchWeightMD(const MDNode *ProfileData,
+                                      SmallVectorImpl<T> &Weights) {
+  assert(isBranchWeightMD(ProfileData) && "wrong metadata");
+  unsigned NOps = ProfileData->getNumOperands();
+  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+  Weights.resize(NOps - WeightsIdx);
+  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+    assert(Weight && "Malformed branch_weight in MD_prof node");
+    // Bound by T's width so the 64-bit instantiation accepts wide weights.
+    assert(Weight->getValue().getActiveBits() <= sizeof(T) * 8 &&
+           "Too many bits for MD_prof branch_weight");
+    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+  }
+}
+
 } // namespace
 
 namespace llvm {
@@ -100,24 +120,21 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
   return nullptr;
 }
 
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
                                SmallVectorImpl<uint32_t> &Weights) {
-  assert(isBranchWeightMD(ProfileData) && "wrong metadata");
-
-  unsigned NOps = ProfileData->getNumOperands();
-  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
-  Weights.resize(NOps - WeightsIdx);
+  extractFromBranchWeightMD(ProfileData, Weights);
+}
 
-  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    assert(Weight && "Malformed branch_weight in MD_prof node");
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
-  }
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+                               SmallVectorImpl<uint64_t> &Weights) {
+  extractFromBranchWeightMD(ProfileData, Weights);
 }
 
+
+
+
+
+
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights) {
   if (!isBranchWeightMD(ProfileData))
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index bc671171137199..f4b43ce370a5da 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     return;
 
   SmallVector<uint32_t, 2> Weights;
-  extractFromBranchWeightMD(WeightMD, Weights);
+  extractFromBranchWeightMD32(WeightMD, Weights);
   if (Weights.size() != 2)
     return;
   uint32_t OrigLoopExitWeight = Weights[0];
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 55bbffb18879fb..a425e26d490e4f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1065,11 +1065,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
 static void GetBranchWeights(Instruction *TI,
                              SmallVectorImpl<uint64_t> &Weights) {
   MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
-  assert(MD);
-  for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
-    ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
-    Weights.push_back(CI->getValue().getZExtValue());
-  }
+  assert(MD && "Invalid branch-weight metadata");
+  extractFromBranchWeightMD64(MD, Weights);
 
   // If TI is a conditional eq, the default case is the false case,
   // and the corresponding branch-weight data is at index 2. We swap the

>From b69df3fade21cbae9a81cb6161b8f171e60762fd Mon Sep 17 00:00:00 2001
From: Paul Kirth <paulkirth at google.com>
Date: Tue, 26 Mar 2024 00:55:03 +0000
Subject: [PATCH 2/2] git clang-format

Created using spr 1.3.4
---
 llvm/include/llvm/IR/ProfDataUtils.h | 4 ++--
 llvm/lib/IR/ProfDataUtils.cpp        | 9 ++-------
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index dc983eed13a8d3..457ffdff8fe37f 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -67,12 +67,12 @@ bool extractBranchWeights(const MDNode *ProfileData,
 /// Faster version of extractBranchWeights() that skips checks and must only
 /// be called with "branch_weights" metadata nodes. Supports uint32_t.
 void extractFromBranchWeightMD32(const MDNode *ProfileData,
-                               SmallVectorImpl<uint32_t> &Weights);
+                                 SmallVectorImpl<uint32_t> &Weights);
 
 /// Faster version of extractBranchWeights() that skips checks and must only
 /// be called with "branch_weights" metadata nodes. Supports uint64_t.
 void extractFromBranchWeightMD64(const MDNode *ProfileData,
-                               SmallVectorImpl<uint64_t> &Weights);
+                                 SmallVectorImpl<uint64_t> &Weights);
 
 /// Extract branch weights attatched to an Instruction
 ///
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b4e09e76993f99..36e165e641f464 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -121,20 +121,15 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
 }
 
 void extractFromBranchWeightMD32(const MDNode *ProfileData,
-                               SmallVectorImpl<uint32_t> &Weights) {
+                                 SmallVectorImpl<uint32_t> &Weights) {
   extractFromBranchWeightMD(ProfileData, Weights);
 }
 
 void extractFromBranchWeightMD64(const MDNode *ProfileData,
-                               SmallVectorImpl<uint64_t> &Weights) {
+                                 SmallVectorImpl<uint64_t> &Weights) {
   extractFromBranchWeightMD(ProfileData, Weights);
 }
 
-
-
-
-
-
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights) {
   if (!isBranchWeightMD(ProfileData))



More information about the llvm-commits mailing list