[llvm] [nfc][PGO]Factor out profile scaling into a standalone helper function (PR #83780)

Mingming Liu via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 00:06:36 PST 2024


https://github.com/minglotus-6 created https://github.com/llvm/llvm-project/pull/83780

- Put the helper function in `ProfDataUtils.h/cpp`, which is already a dependency of `Instructions.cpp`
- The helper function could be re-used to update profiles of `InvokeInst` (in a follow-up pull request)

>From e7c6220a4f2c42c94fa33fd8c61da569ef67d4db Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 4 Mar 2024 00:04:11 -0800
Subject: [PATCH] [nfc][PGO]Factor out profile scaling into a standalone
 function

---
 llvm/include/llvm/IR/ProfDataUtils.h |  3 ++
 llvm/lib/IR/Instructions.cpp         | 46 +-------------------------
 llvm/lib/IR/ProfDataUtils.cpp        | 48 ++++++++++++++++++++++++++++
 3 files changed, 52 insertions(+), 45 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 255fa2ff1c7906..c0897408986fb3 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -108,5 +108,8 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
 /// a `prof` metadata reference to instruction `I`.
 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
 
+/// Scale the profile data attached to 'I' by the ratio S/T.
+void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
+
 } // namespace llvm
 #endif
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 42cdcad78228f6..9ae71acd523c36 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -825,15 +825,6 @@ CallInst *CallInst::Create(CallInst *CI, ArrayRef<OperandBundleDef> OpB,
 // of S/T. The meaning of "branch_weights" meta data for call instruction is
 // transfered to represent call count.
 void CallInst::updateProfWeight(uint64_t S, uint64_t T) {
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (ProfileData == nullptr)
-    return;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || (!ProfDataName->getString().equals("branch_weights") &&
-                        !ProfDataName->getString().equals("VP")))
-    return;
-
   if (T == 0) {
     LLVM_DEBUG(dbgs() << "Attempting to update profile weights will result in "
                          "div by 0. Ignoring. Likely the function "
@@ -842,42 +833,7 @@ void CallInst::updateProfWeight(uint64_t S, uint64_t T) {
                          "with non-zero prof info.");
     return;
   }
-
-  MDBuilder MDB(getContext());
-  SmallVector<Metadata *, 3> Vals;
-  Vals.push_back(ProfileData->getOperand(0));
-  APInt APS(128, S), APT(128, T);
-  if (ProfDataName->getString().equals("branch_weights") &&
-      ProfileData->getNumOperands() > 0) {
-    // Using APInt::div may be expensive, but most cases should fit 64 bits.
-    APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
-                       ->getValue()
-                       .getZExtValue());
-    Val *= APS;
-    Vals.push_back(MDB.createConstant(
-        ConstantInt::get(Type::getInt32Ty(getContext()),
-                         Val.udiv(APT).getLimitedValue(UINT32_MAX))));
-  } else if (ProfDataName->getString().equals("VP"))
-    for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
-      // The first value is the key of the value profile, which will not change.
-      Vals.push_back(ProfileData->getOperand(i));
-      uint64_t Count =
-          mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
-              ->getValue()
-              .getZExtValue();
-      // Don't scale the magic number.
-      if (Count == NOMORE_ICP_MAGICNUM) {
-        Vals.push_back(ProfileData->getOperand(i + 1));
-        continue;
-      }
-      // Using APInt::div may be expensive, but most cases should fit 64 bits.
-      APInt Val(128, Count);
-      Val *= APS;
-      Vals.push_back(MDB.createConstant(
-          ConstantInt::get(Type::getInt64Ty(getContext()),
-                           Val.udiv(APT).getLimitedValue())));
-    }
-  setMetadata(LLVMContext::MD_prof, MDNode::get(getContext(), Vals));
+  scaleProfData(*this, S, T);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1a10d0ce5a522..dc86f4204b1a1d 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -190,4 +190,52 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
 }
 
+void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
+  assert(T != 0 && "Caller should guarantee");
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  if (ProfileData == nullptr)
+    return;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName || (!ProfDataName->getString().equals("branch_weights") &&
+                        !ProfDataName->getString().equals("VP")))
+    return;
+
+  LLVMContext &C = I.getContext();
+
+  MDBuilder MDB(C);
+  SmallVector<Metadata *, 3> Vals;
+  Vals.push_back(ProfileData->getOperand(0));
+  APInt APS(128, S), APT(128, T);
+  if (ProfDataName->getString().equals("branch_weights") &&
+      ProfileData->getNumOperands() > 0) {
+    // Using APInt::div may be expensive, but most cases should fit 64 bits.
+    APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
+                       ->getValue()
+                       .getZExtValue());
+    Val *= APS;
+    Vals.push_back(MDB.createConstant(ConstantInt::get(
+        Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
+  } else if (ProfDataName->getString().equals("VP"))
+    for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
+      // The first value is the key of the value profile, which will not change.
+      Vals.push_back(ProfileData->getOperand(i));
+      uint64_t Count =
+          mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
+              ->getValue()
+              .getZExtValue();
+      // Don't scale the magic number.
+      if (Count == NOMORE_ICP_MAGICNUM) {
+        Vals.push_back(ProfileData->getOperand(i + 1));
+        continue;
+      }
+      // Using APInt::div may be expensive, but most cases should fit 64 bits.
+      APInt Val(128, Count);
+      Val *= APS;
+      Vals.push_back(MDB.createConstant(ConstantInt::get(
+          Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
+    }
+  I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
+}
+
 } // namespace llvm



More information about the llvm-commits mailing list