[llvm] [FuncSpec] Only compute Latency bonus when necessary (PR #113159)

Hari Limaye via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 22 04:09:39 PDT 2024


https://github.com/hazzlim updated https://github.com/llvm/llvm-project/pull/113159

>From d3823156e7d19413f76c3b1943e3d0d4b31bf09c Mon Sep 17 00:00:00 2001
From: Hari Limaye <hari.limaye at arm.com>
Date: Wed, 16 Oct 2024 23:25:54 +0000
Subject: [PATCH 1/3] [FuncSpec] Only compute Latency bonus when necessary

Only compute the Latency component of a specialisation's Bonus when
necessary, to avoid unnecessarily computing the Block Frequency
Information for a Function.
---
 .../Transforms/IPO/FunctionSpecialization.h   |  23 ++-
 .../Transforms/IPO/FunctionSpecialization.cpp |  85 +++++----
 .../Transforms/SCCP/ipsccp-preserve-pdt.ll    |  18 +-
 .../IPO/FunctionSpecializationTest.cpp        | 177 +++++++++++-------
 4 files changed, 182 insertions(+), 121 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index b001771951e0fe..a50eafe48293ad 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -173,8 +173,9 @@ struct Bonus {
 };
 
 class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
+  std::function<BlockFrequencyInfo &(Function &)> GetBFI;
+  Function *F;
   const DataLayout &DL;
-  BlockFrequencyInfo &BFI;
   TargetTransformInfo &TTI;
   SCCPSolver &Solver;
 
@@ -192,17 +193,20 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   ConstMap::iterator LastVisited;
 
 public:
-  InstCostVisitor(const DataLayout &DL, BlockFrequencyInfo &BFI,
-                  TargetTransformInfo &TTI, SCCPSolver &Solver)
-      : DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}
+  InstCostVisitor(std::function<BlockFrequencyInfo &(Function &)> GetBFI,
+                  Function *F, const DataLayout &DL, TargetTransformInfo &TTI,
+                  SCCPSolver &Solver)
+      : GetBFI(GetBFI), F(F), DL(DL), TTI(TTI), Solver(Solver) {}
 
   bool isBlockExecutable(BasicBlock *BB) {
     return Solver.isBlockExecutable(BB) && !DeadBlocks.contains(BB);
   }
 
-  Bonus getSpecializationBonus(Argument *A, Constant *C);
+  Cost getCodeSizeBonus(Argument *A, Constant *C);
+
+  Cost getCodeSizeBonusFromPendingPHIs();
 
-  Bonus getBonusFromPendingPHIs();
+  Cost getLatencyBonus();
 
 private:
   friend class InstVisitor<InstCostVisitor, Constant *>;
@@ -210,8 +214,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   static bool canEliminateSuccessor(BasicBlock *BB, BasicBlock *Succ,
                                     DenseSet<BasicBlock *> &DeadBlocks);
 
-  Bonus getUserBonus(Instruction *User, Value *Use = nullptr,
-                     Constant *C = nullptr);
+  Cost getUserCodeSizeBonus(Instruction *User, Value *Use = nullptr,
+                            Constant *C = nullptr);
 
   Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList);
   Cost estimateSwitchInst(SwitchInst &I);
@@ -283,9 +287,8 @@ class FunctionSpecializer {
   bool run();
 
   InstCostVisitor getInstCostVisitorFor(Function *F) {
-    auto &BFI = GetBFI(*F);
     auto &TTI = GetTTI(*F);
-    return InstCostVisitor(M.getDataLayout(), BFI, TTI, Solver);
+    return InstCostVisitor(GetBFI, F, M.getDataLayout(), TTI, Solver);
   }
 
 private:
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 7feebbe420ae53..38c10c29cf4e3a 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -112,7 +112,7 @@ bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB, BasicBlock *Succ,
 Cost InstCostVisitor::estimateBasicBlocks(
                           SmallVectorImpl<BasicBlock *> &WorkList) {
   Cost CodeSize = 0;
-  // Accumulate the instruction cost of each basic block weighted by frequency.
+  // Accumulate the codesize savings of each basic block.
   while (!WorkList.empty()) {
     BasicBlock *BB = WorkList.pop_back_val();
 
@@ -154,37 +154,55 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
   return KnownConstants.lookup(V);
 }
 
-Bonus InstCostVisitor::getBonusFromPendingPHIs() {
-  Bonus B;
+Cost InstCostVisitor::getCodeSizeBonusFromPendingPHIs() {
+  Cost CodeSize;
   while (!PendingPHIs.empty()) {
     Instruction *Phi = PendingPHIs.pop_back_val();
     // The pending PHIs could have been proven dead by now.
     if (isBlockExecutable(Phi->getParent()))
-      B += getUserBonus(Phi);
+      CodeSize += getUserCodeSizeBonus(Phi);
   }
-  return B;
+  return CodeSize;
 }
 
-/// Compute a bonus for replacing argument \p A with constant \p C.
-Bonus InstCostVisitor::getSpecializationBonus(Argument *A, Constant *C) {
+/// Compute the codesize savings for replacing argument \p A with constant \p C.
+Cost InstCostVisitor::getCodeSizeBonus(Argument *A, Constant *C) {
   LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
                     << C->getNameOrAsOperand() << "\n");
-  Bonus B;
+  Cost CodeSize;
   for (auto *U : A->users())
     if (auto *UI = dyn_cast<Instruction>(U))
       if (isBlockExecutable(UI->getParent()))
-        B += getUserBonus(UI, A, C);
+        CodeSize += getUserCodeSizeBonus(UI, A, C);
 
   LLVM_DEBUG(dbgs() << "FnSpecialization:   Accumulated bonus {CodeSize = "
-                    << B.CodeSize << ", Latency = " << B.Latency
-                    << "} for argument " << *A << "\n");
-  return B;
+                    << CodeSize << "} for argument " << *A << "\n");
+  return CodeSize;
+}
+
+Cost InstCostVisitor::getLatencyBonus() {
+  auto &BFI = GetBFI(*F);
+  Cost Latency = 0;
+
+  for (auto Pair : KnownConstants) {
+    Instruction *I = dyn_cast<Instruction>(Pair.first);
+    if (!I)
+      continue;
+
+    uint64_t Weight = BFI.getBlockFreq(I->getParent()).getFrequency() /
+                      BFI.getEntryFreq().getFrequency();
+    Latency +=
+        Weight * TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
+  }
+
+  return Latency;
 }
 
-Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
+Cost InstCostVisitor::getUserCodeSizeBonus(Instruction *User, Value *Use,
+                                           Constant *C) {
   // We have already propagated a constant for this user.
   if (KnownConstants.contains(User))
-    return {0, 0};
+    return 0;
 
   // Cache the iterator before visiting.
   LastVisited = Use ? KnownConstants.insert({Use, C}).first
@@ -198,7 +216,7 @@ Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C)
   } else {
     C = visit(*User);
     if (!C)
-      return {0, 0};
+      return 0;
   }
 
   // Even though it doesn't make sense to bind switch and branch instructions
@@ -208,23 +226,15 @@ Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C)
 
   CodeSize += TTI.getInstructionCost(User, TargetTransformInfo::TCK_CodeSize);
 
-  uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() /
-                    BFI.getEntryFreq().getFrequency();
-
-  Cost Latency = Weight *
-      TTI.getInstructionCost(User, TargetTransformInfo::TCK_Latency);
-
   LLVM_DEBUG(dbgs() << "FnSpecialization:     {CodeSize = " << CodeSize
-                    << ", Latency = " << Latency << "} for user "
-                    << *User << "\n");
+                    << "} for user " << *User << "\n");
 
-  Bonus B(CodeSize, Latency);
   for (auto *U : User->users())
     if (auto *UI = dyn_cast<Instruction>(U))
       if (UI != User && isBlockExecutable(UI->getParent()))
-        B += getUserBonus(UI, User, C);
+        CodeSize += getUserCodeSizeBonus(UI, User, C);
 
-  return B;
+  return CodeSize;
 }
 
 Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
@@ -875,24 +885,23 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
       AllSpecs[Index].CallSites.push_back(&CS);
     } else {
       // Calculate the specialisation gain.
-      Bonus B;
+      Cost CodeSize;
       unsigned Score = 0;
       InstCostVisitor Visitor = getInstCostVisitorFor(F);
       for (ArgInfo &A : S.Args) {
-        B += Visitor.getSpecializationBonus(A.Formal, A.Actual);
+        CodeSize += Visitor.getCodeSizeBonus(A.Formal, A.Actual);
         Score += getInliningBonus(A.Formal, A.Actual);
       }
-      B += Visitor.getBonusFromPendingPHIs();
-
+      CodeSize += Visitor.getCodeSizeBonusFromPendingPHIs();
 
       LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization bonus {CodeSize = "
-                        << B.CodeSize << ", Latency = " << B.Latency
-                        << ", Inlining = " << Score << "}\n");
+                        << CodeSize << ", Inlining = " << Score << "}\n");
 
+      Bonus B = {CodeSize, 0};
       FunctionGrowth[F] += FuncSize - B.CodeSize;
 
       auto IsProfitable = [](Bonus &B, unsigned Score, unsigned FuncSize,
-                             unsigned FuncGrowth) -> bool {
+                             unsigned FuncGrowth, InstCostVisitor &V) -> bool {
         // No check required.
         if (ForceSpecialization)
           return true;
@@ -902,6 +911,14 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
         // Minimum codesize savings.
         if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100)
           return false;
+
+        // Lazily compute the Latency, to avoid unnecessarily computing BFI.
+        B += {0, V.getLatencyBonus()};
+
+        LLVM_DEBUG(
+            dbgs() << "FnSpecialization: Specialization bonus {Latency = "
+                   << B.Latency << "}\n");
+
         // Minimum latency savings.
         if (B.Latency < MinLatencySavings * FuncSize / 100)
           return false;
@@ -912,7 +929,7 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
       };
 
       // Discard unprofitable specialisations.
-      if (!IsProfitable(B, Score, FuncSize, FunctionGrowth[F]))
+      if (!IsProfitable(B, Score, FuncSize, FunctionGrowth[F], Visitor))
         continue;
 
       // Create a new specialisation entry.
diff --git a/llvm/test/Transforms/SCCP/ipsccp-preserve-pdt.ll b/llvm/test/Transforms/SCCP/ipsccp-preserve-pdt.ll
index ff57569d127884..f8c8e33dfc2335 100644
--- a/llvm/test/Transforms/SCCP/ipsccp-preserve-pdt.ll
+++ b/llvm/test/Transforms/SCCP/ipsccp-preserve-pdt.ll
@@ -4,25 +4,25 @@
 
 ; This test case is trying to validate that the postdomtree is preserved
 ; correctly by the ipsccp pass. A tricky bug was introduced in commit
-; 1b1232047e83b69561 when PDT would be feched using getCachedAnalysis in order
+; 1b1232047e83b69561 when PDT would be fetched using getCachedAnalysis in order
 ; to setup a DomTreeUpdater (to update the PDT during transformation in order
 ; to preserve the analysis). But given that commit the PDT could end up being
 ; required and calculated via BlockFrequency analysis. So the problem was that
 ; when setting up the DomTreeUpdater we used a nullptr in case PDT wasn't
-; cached at the begininng of IPSCCP, to indicate that no updates where needed
+; cached at the beginning of IPSCCP, to indicate that no updates were needed
 ; for PDT. But then the PDT was calculated, given the input IR, and preserved
 ; using the non-updated state (as the DTU wasn't configured for updating the
 ; PDT).
 
 ; CHECK-NOT: <badref>
 ; CHECK: Inorder PostDominator Tree: DFSNumbers invalid: 0 slow queries.
-; CHECK-NEXT:   [1]  <<exit node>> {4294967295,4294967295} [0]
-; CHECK-NEXT:     [2] %for.cond34 {4294967295,4294967295} [1]
-; CHECK-NEXT:       [3] %for.cond16 {4294967295,4294967295} [2]
-; CHECK-NEXT:     [2] %for.body {4294967295,4294967295} [1]
-; CHECK-NEXT:     [2] %if.end4 {4294967295,4294967295} [1]
-; CHECK-NEXT:       [3] %entry {4294967295,4294967295} [2]
-; CHECK-NEXT: Roots: %for.cond34 %for.body
+; CHECK-NEXT:  [1]  <<exit node>> {4294967295,4294967295} [0]
+; CHECK-NEXT:    [2] %for.body {4294967295,4294967295} [1]
+; CHECK-NEXT:    [2] %if.end4 {4294967295,4294967295} [1]
+; CHECK-NEXT:      [3] %entry {4294967295,4294967295} [2]
+; CHECK-NEXT:    [2] %for.cond34 {4294967295,4294967295} [1]
+; CHECK-NEXT:      [3] %for.cond16 {4294967295,4294967295} [2]
+; CHECK-NEXT: Roots: %for.body %for.cond34
 ; CHECK-NEXT: PostDominatorTree for function: bar
 ; CHECK-NOT: <badref>
 
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
index b0ff55489e1762..9f2053948cb33b 100644
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -43,6 +43,7 @@ class FunctionSpecializationTest : public testing::Test {
   FunctionAnalysisManager FAM;
   std::unique_ptr<Module> M;
   std::unique_ptr<SCCPSolver> Solver;
+  SmallVector<Instruction *, 8> Instructions;
 
   FunctionSpecializationTest() {
     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
@@ -97,21 +98,29 @@ class FunctionSpecializationTest : public testing::Test {
                                GetAC);
   }
 
-  Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
+  Cost getInstCodeSize(Instruction &I, bool SizeOnly = false) {
     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
-    auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
 
     Cost CodeSize =
         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
 
-    Cost Latency =
-        SizeOnly
-            ? 0
-            : BFI.getBlockFreq(I.getParent()).getFrequency() /
-                  BFI.getEntryFreq().getFrequency() *
-                  TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
+    if (!SizeOnly)
+      Instructions.push_back(&I);
 
-    return {CodeSize, Latency};
+    return CodeSize;
+  }
+
+  Cost getLatency(Function *F) {
+    auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);
+    auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*F);
+
+    Cost Latency = 0;
+    for (const Instruction *I : Instructions)
+      Latency += BFI.getBlockFreq(I->getParent()).getFrequency() /
+                 BFI.getEntryFreq().getFrequency() *
+                 TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
+
+    return Latency;
   }
 };
 
@@ -171,25 +180,30 @@ TEST_F(FunctionSpecializationTest, SwitchInst) {
   Instruction &BrLoop = BB2.back();
 
   // mul
-  Bonus Ref = getInstCost(Mul);
-  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
+  Cost Ref = getInstCodeSize(Mul);
+  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 
   // and + or + add
-  Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
-  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
+  Ref = getInstCodeSize(And) + getInstCodeSize(Or) + getInstCodeSize(Add);
+  Test = Visitor.getCodeSizeBonus(F->getArg(1), One);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 
   // switch + sdiv + br + br
-  Ref = getInstCost(Switch) +
-        getInstCost(Sdiv, /*SizeOnly =*/ true) +
-        getInstCost(BrBB2, /*SizeOnly =*/ true) +
-        getInstCost(BrLoop, /*SizeOnly =*/ true);
-  Test = Visitor.getSpecializationBonus(F->getArg(2), One);
+  Ref = getInstCodeSize(Switch) + getInstCodeSize(Sdiv, /*SizeOnly =*/true) +
+        getInstCodeSize(BrBB2, /*SizeOnly =*/true) +
+        getInstCodeSize(BrLoop, /*SizeOnly =*/true);
+  Test = Visitor.getCodeSizeBonus(F->getArg(2), One);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
+
+  // Latency.
+  Ref = getLatency(F);
+  Test = Visitor.getLatencyBonus();
+  EXPECT_EQ(Test, Ref);
+  EXPECT_TRUE(Test > 0);
 }
 
 TEST_F(FunctionSpecializationTest, BranchInst) {
@@ -238,27 +252,31 @@ TEST_F(FunctionSpecializationTest, BranchInst) {
   Instruction &BrLoop = BB2.front();
 
   // mul
-  Bonus Ref = getInstCost(Mul);
-  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
+  Cost Ref = getInstCodeSize(Mul);
+  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 
   // add
-  Ref = getInstCost(Add);
-  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
+  Ref = getInstCodeSize(Add);
+  Test = Visitor.getCodeSizeBonus(F->getArg(1), One);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 
   // branch + sub + br + sdiv + br
-  Ref = getInstCost(Branch) +
-        getInstCost(Sub, /*SizeOnly =*/ true) +
-        getInstCost(BrBB1BB2) +
-        getInstCost(Sdiv, /*SizeOnly =*/ true) +
-        getInstCost(BrBB2, /*SizeOnly =*/ true) +
-        getInstCost(BrLoop, /*SizeOnly =*/ true);
-  Test = Visitor.getSpecializationBonus(F->getArg(2), False);
+  Ref = getInstCodeSize(Branch) + getInstCodeSize(Sub, /*SizeOnly =*/true) +
+        getInstCodeSize(BrBB1BB2) + getInstCodeSize(Sdiv, /*SizeOnly =*/true) +
+        getInstCodeSize(BrBB2, /*SizeOnly =*/true) +
+        getInstCodeSize(BrLoop, /*SizeOnly =*/true);
+  Test = Visitor.getCodeSizeBonus(F->getArg(2), False);
+  EXPECT_EQ(Test, Ref);
+  EXPECT_TRUE(Test > 0);
+
+  // Latency.
+  Ref = getLatency(F);
+  Test = Visitor.getLatencyBonus();
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 }
 
 TEST_F(FunctionSpecializationTest, SelectInst) {
@@ -279,14 +297,22 @@ TEST_F(FunctionSpecializationTest, SelectInst) {
   Constant *False = ConstantInt::getFalse(M.getContext());
   Instruction &Select = *F->front().begin();
 
-  Bonus Ref = getInstCost(Select);
-  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), False);
-  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
-  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
-  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
-  Test = Visitor.getSpecializationBonus(F->getArg(2), Zero);
-  EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  Cost RefCodeSize = getInstCodeSize(Select);
+  Cost RefLatency = getLatency(F);
+
+  Cost TestCodeSize = Visitor.getCodeSizeBonus(F->getArg(0), False);
+  EXPECT_TRUE(TestCodeSize == 0);
+  TestCodeSize = Visitor.getCodeSizeBonus(F->getArg(1), One);
+  EXPECT_TRUE(TestCodeSize == 0);
+  Cost TestLatency = Visitor.getLatencyBonus();
+  EXPECT_TRUE(TestLatency == 0);
+
+  TestCodeSize = Visitor.getCodeSizeBonus(F->getArg(2), Zero);
+  EXPECT_EQ(TestCodeSize, RefCodeSize);
+  EXPECT_TRUE(TestCodeSize > 0);
+  TestLatency = Visitor.getLatencyBonus();
+  EXPECT_EQ(TestLatency, RefLatency);
+  EXPECT_TRUE(TestLatency > 0);
 }
 
 TEST_F(FunctionSpecializationTest, Misc) {
@@ -332,26 +358,32 @@ TEST_F(FunctionSpecializationTest, Misc) {
   Instruction &Smax = *BlockIter++;
 
   // icmp + zext
-  Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
-  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
+  Cost Ref = getInstCodeSize(Icmp) + getInstCodeSize(Zext);
+  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 
   // select
-  Ref = getInstCost(Select);
-  Test = Visitor.getSpecializationBonus(F->getArg(1), True);
+  Ref = getInstCodeSize(Select);
+  Test = Visitor.getCodeSizeBonus(F->getArg(1), True);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 
   // gep + load + freeze + smax
-  Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
-        getInstCost(Smax);
-  Test = Visitor.getSpecializationBonus(F->getArg(2), GV);
+  Ref = getInstCodeSize(Gep) + getInstCodeSize(Load) + getInstCodeSize(Freeze) +
+        getInstCodeSize(Smax);
+  Test = Visitor.getCodeSizeBonus(F->getArg(2), GV);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 
-  Test = Visitor.getSpecializationBonus(F->getArg(3), Undef);
-  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
+  Test = Visitor.getCodeSizeBonus(F->getArg(3), Undef);
+  EXPECT_TRUE(Test == 0);
+
+  // Latency.
+  Ref = getLatency(F);
+  Test = Visitor.getLatencyBonus();
+  EXPECT_EQ(Test, Ref);
+  EXPECT_TRUE(Test > 0);
 }
 
 TEST_F(FunctionSpecializationTest, PhiNode) {
@@ -401,25 +433,34 @@ TEST_F(FunctionSpecializationTest, PhiNode) {
   Instruction &Icmp = *++BB.begin();
   Instruction &Branch = BB.back();
 
-  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
-  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
+  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
+  EXPECT_TRUE(Test == 0);
 
-  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
-  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
+  Test = Visitor.getCodeSizeBonus(F->getArg(1), One);
+  EXPECT_TRUE(Test == 0);
+
+  Test = Visitor.getLatencyBonus();
+  EXPECT_TRUE(Test == 0);
 
   // switch + phi + br
-  Bonus Ref = getInstCost(Switch) +
-              getInstCost(PhiCase2, /*SizeOnly =*/ true) +
-              getInstCost(BrBB, /*SizeOnly =*/ true);
-  Test = Visitor.getSpecializationBonus(F->getArg(2), One);
+  Cost Ref = getInstCodeSize(Switch) +
+             getInstCodeSize(PhiCase2, /*SizeOnly =*/true) +
+             getInstCodeSize(BrBB, /*SizeOnly =*/true);
+  Test = Visitor.getCodeSizeBonus(F->getArg(2), One);
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0 && Test > 0);
 
   // phi + phi + add + icmp + branch
-  Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
-        getInstCost(Icmp) + getInstCost(Branch);
-  Test = Visitor.getBonusFromPendingPHIs();
+  Ref = getInstCodeSize(PhiBB) + getInstCodeSize(PhiLoop) +
+        getInstCodeSize(Add) + getInstCodeSize(Icmp) + getInstCodeSize(Branch);
+  Test = Visitor.getCodeSizeBonusFromPendingPHIs();
+  EXPECT_EQ(Test, Ref);
+  EXPECT_TRUE(Test > 0);
+
+  // Latency.
+  Ref = getLatency(F);
+  Test = Visitor.getLatencyBonus();
   EXPECT_EQ(Test, Ref);
-  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+  EXPECT_TRUE(Test > 0);
 }
 

>From fc3aaa44dbcc0e5d78e55816b51dc7437aae85d0 Mon Sep 17 00:00:00 2001
From: Hari Limaye <hari.limaye at arm.com>
Date: Mon, 21 Oct 2024 15:16:41 +0000
Subject: [PATCH 2/3] Address review comments

---
 .../Transforms/IPO/FunctionSpecialization.h   |  42 +------
 .../Transforms/IPO/FunctionSpecialization.cpp |  62 +++++++---
 .../IPO/FunctionSpecializationTest.cpp        | 117 +++++++++---------
 3 files changed, 109 insertions(+), 112 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index a50eafe48293ad..5920dde9d77dfd 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -140,38 +140,6 @@ struct Spec {
       : F(F), Sig(S), Score(Score) {}
 };
 
-struct Bonus {
-  unsigned CodeSize = 0;
-  unsigned Latency = 0;
-
-  Bonus() = default;
-
-  Bonus(Cost CodeSize, Cost Latency) {
-    int64_t Sz = *CodeSize.getValue();
-    int64_t Ltc = *Latency.getValue();
-
-    assert(Sz >= 0 && Ltc >= 0 && "CodeSize and Latency cannot be negative");
-    // It is safe to down cast since we know the arguments
-    // cannot be negative and Cost is of type int64_t.
-    this->CodeSize = static_cast<unsigned>(Sz);
-    this->Latency = static_cast<unsigned>(Ltc);
-  }
-
-  Bonus &operator+=(const Bonus RHS) {
-    CodeSize += RHS.CodeSize;
-    Latency += RHS.Latency;
-    return *this;
-  }
-
-  Bonus operator+(const Bonus RHS) const {
-    return Bonus(CodeSize + RHS.CodeSize, Latency + RHS.Latency);
-  }
-
-  bool operator==(const Bonus RHS) const {
-    return CodeSize == RHS.CodeSize && Latency == RHS.Latency;
-  }
-};
-
 class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   std::function<BlockFrequencyInfo &(Function &)> GetBFI;
   Function *F;
@@ -202,11 +170,11 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
     return Solver.isBlockExecutable(BB) && !DeadBlocks.contains(BB);
   }
 
-  Cost getCodeSizeBonus(Argument *A, Constant *C);
+  Cost getCodeSizeSavingsForArg(Argument *A, Constant *C);
 
-  Cost getCodeSizeBonusFromPendingPHIs();
+  Cost getCodeSizeSavingsFromPendingPHIs();
 
-  Cost getLatencyBonus();
+  Cost getLatencySavingsForKnownConstants();
 
 private:
   friend class InstVisitor<InstCostVisitor, Constant *>;
@@ -214,8 +182,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   static bool canEliminateSuccessor(BasicBlock *BB, BasicBlock *Succ,
                                     DenseSet<BasicBlock *> &DeadBlocks);
 
-  Cost getUserCodeSizeBonus(Instruction *User, Value *Use = nullptr,
-                            Constant *C = nullptr);
+  Cost getCodeSizeSavingsForUser(Instruction *User, Value *Use = nullptr,
+                                 Constant *C = nullptr);
 
   Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList);
   Cost estimateSwitchInst(SwitchInst &I);
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 38c10c29cf4e3a..04a62837ff962b 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -154,33 +154,45 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
   return KnownConstants.lookup(V);
 }
 
-Cost InstCostVisitor::getCodeSizeBonusFromPendingPHIs() {
+Cost InstCostVisitor::getCodeSizeSavingsFromPendingPHIs() {
   Cost CodeSize;
   while (!PendingPHIs.empty()) {
     Instruction *Phi = PendingPHIs.pop_back_val();
     // The pending PHIs could have been proven dead by now.
     if (isBlockExecutable(Phi->getParent()))
-      CodeSize += getUserCodeSizeBonus(Phi);
+      CodeSize += getCodeSizeSavingsForUser(Phi);
   }
   return CodeSize;
 }
 
 /// Compute the codesize savings for replacing argument \p A with constant \p C.
-Cost InstCostVisitor::getCodeSizeBonus(Argument *A, Constant *C) {
+Cost InstCostVisitor::getCodeSizeSavingsForArg(Argument *A, Constant *C) {
   LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
                     << C->getNameOrAsOperand() << "\n");
   Cost CodeSize;
   for (auto *U : A->users())
     if (auto *UI = dyn_cast<Instruction>(U))
       if (isBlockExecutable(UI->getParent()))
-        CodeSize += getUserCodeSizeBonus(UI, A, C);
+        CodeSize += getCodeSizeSavingsForUser(UI, A, C);
 
   LLVM_DEBUG(dbgs() << "FnSpecialization:   Accumulated bonus {CodeSize = "
                     << CodeSize << "} for argument " << *A << "\n");
   return CodeSize;
 }
 
-Cost InstCostVisitor::getLatencyBonus() {
+/// Compute the latency savings from replacing all arguments with constants for
+/// a specialization candidate. As this function computes the latency savings
+/// for all Instructions in KnownConstants at once, it should be called only
+/// after every instruction has been visited, i.e. after:
+///
+/// * getCodeSizeSavingsForArg has been run for every constant argument of a
+///   specialization candidate
+///
+/// * getCodeSizeSavingsFromPendingPHIs has been run
+///
+/// to ensure that the latency savings are calculated for all Instructions we
+/// have visited and found to be constant.
+Cost InstCostVisitor::getLatencySavingsForKnownConstants() {
   auto &BFI = GetBFI(*F);
   Cost Latency = 0;
 
@@ -198,8 +210,8 @@ Cost InstCostVisitor::getLatencyBonus() {
   return Latency;
 }
 
-Cost InstCostVisitor::getUserCodeSizeBonus(Instruction *User, Value *Use,
-                                           Constant *C) {
+Cost InstCostVisitor::getCodeSizeSavingsForUser(Instruction *User, Value *Use,
+                                                Constant *C) {
   // We have already propagated a constant for this user.
   if (KnownConstants.contains(User))
     return 0;
@@ -232,7 +244,7 @@ Cost InstCostVisitor::getUserCodeSizeBonus(Instruction *User, Value *Use,
   for (auto *U : User->users())
     if (auto *UI = dyn_cast<Instruction>(U))
       if (UI != User && isBlockExecutable(UI->getParent()))
-        CodeSize += getUserCodeSizeBonus(UI, User, C);
+        CodeSize += getCodeSizeSavingsForUser(UI, User, C);
 
   return CodeSize;
 }
@@ -819,6 +831,15 @@ static Function *cloneCandidateFunction(Function *F, unsigned NSpecs) {
   return Clone;
 }
 
+static unsigned getCostValue(const Cost &C) {
+  int64_t Value = *C.getValue();
+
+  assert(Value >= 0 && "CodeSize and Latency cannot be negative");
+  // It is safe to down cast since we know the value cannot be negative and
+  // Cost is of type int64_t.
+  return static_cast<unsigned>(Value);
+}
+
 bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
                                               SmallVectorImpl<Spec> &AllSpecs,
                                               SpecMap &SM) {
@@ -889,18 +910,20 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
       unsigned Score = 0;
       InstCostVisitor Visitor = getInstCostVisitorFor(F);
       for (ArgInfo &A : S.Args) {
-        CodeSize += Visitor.getCodeSizeBonus(A.Formal, A.Actual);
+        CodeSize += Visitor.getCodeSizeSavingsForArg(A.Formal, A.Actual);
         Score += getInliningBonus(A.Formal, A.Actual);
       }
-      CodeSize += Visitor.getCodeSizeBonusFromPendingPHIs();
+      CodeSize += Visitor.getCodeSizeSavingsFromPendingPHIs();
 
       LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization bonus {CodeSize = "
                         << CodeSize << ", Inlining = " << Score << "}\n");
 
-      Bonus B = {CodeSize, 0};
-      FunctionGrowth[F] += FuncSize - B.CodeSize;
+      unsigned LatencySavings = 0;
+      unsigned CodeSizeSavings = getCostValue(CodeSize);
+      FunctionGrowth[F] += FuncSize - CodeSizeSavings;
 
-      auto IsProfitable = [](Bonus &B, unsigned Score, unsigned FuncSize,
+      auto IsProfitable = [](unsigned CodeSizeSavings, unsigned &LatencySavings,
+                             unsigned Score, unsigned FuncSize,
                              unsigned FuncGrowth, InstCostVisitor &V) -> bool {
         // No check required.
         if (ForceSpecialization)
@@ -909,18 +932,18 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
         if (Score > MinInliningBonus * FuncSize / 100)
           return true;
         // Minimum codesize savings.
-        if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100)
+        if (CodeSizeSavings < MinCodeSizeSavings * FuncSize / 100)
           return false;
 
         // Lazily compute the Latency, to avoid unnecessarily computing BFI.
-        B += {0, V.getLatencyBonus()};
+        LatencySavings = getCostValue(V.getLatencySavingsForKnownConstants());
 
         LLVM_DEBUG(
             dbgs() << "FnSpecialization: Specialization bonus {Latency = "
-                   << B.Latency << "}\n");
+                   << LatencySavings << "}\n");
 
         // Minimum latency savings.
-        if (B.Latency < MinLatencySavings * FuncSize / 100)
+        if (LatencySavings < MinLatencySavings * FuncSize / 100)
           return false;
         // Maximum codesize growth.
         if (FuncGrowth / FuncSize > MaxCodeSizeGrowth)
@@ -929,11 +952,12 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
       };
 
       // Discard unprofitable specialisations.
-      if (!IsProfitable(B, Score, FuncSize, FunctionGrowth[F], Visitor))
+      if (!IsProfitable(CodeSizeSavings, LatencySavings, Score, FuncSize,
+                        FunctionGrowth[F], Visitor))
         continue;
 
       // Create a new specialisation entry.
-      Score += std::max(B.CodeSize, B.Latency);
+      Score += std::max(CodeSizeSavings, LatencySavings);
       auto &Spec = AllSpecs.emplace_back(F, S, Score);
       if (CS.getFunction() != F)
         Spec.CallSites.push_back(&CS);
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
index 9f2053948cb33b..c8fd366bfac65f 100644
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -43,7 +43,7 @@ class FunctionSpecializationTest : public testing::Test {
   FunctionAnalysisManager FAM;
   std::unique_ptr<Module> M;
   std::unique_ptr<SCCPSolver> Solver;
-  SmallVector<Instruction *, 8> Instructions;
+  SmallVector<Instruction *, 8> KnownConstants;
 
   FunctionSpecializationTest() {
     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
@@ -98,24 +98,24 @@ class FunctionSpecializationTest : public testing::Test {
                                GetAC);
   }
 
-  Cost getInstCodeSize(Instruction &I, bool SizeOnly = false) {
+  Cost getCodeSizeSavings(Instruction &I, bool HasLatencySavings = true) {
     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
 
     Cost CodeSize =
         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
 
-    if (!SizeOnly)
-      Instructions.push_back(&I);
+    if (HasLatencySavings)
+      KnownConstants.push_back(&I);
 
     return CodeSize;
   }
 
-  Cost getLatency(Function *F) {
+  Cost getLatencySavings(Function *F) {
     auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);
     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*F);
 
     Cost Latency = 0;
-    for (const Instruction *I : Instructions)
+    for (const Instruction *I : KnownConstants)
       Latency += BFI.getBlockFreq(I->getParent()).getFrequency() /
                  BFI.getEntryFreq().getFrequency() *
                  TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
@@ -180,28 +180,30 @@ TEST_F(FunctionSpecializationTest, SwitchInst) {
   Instruction &BrLoop = BB2.back();
 
   // mul
-  Cost Ref = getInstCodeSize(Mul);
-  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
+  Cost Ref = getCodeSizeSavings(Mul);
+  Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // and + or + add
-  Ref = getInstCodeSize(And) + getInstCodeSize(Or) + getInstCodeSize(Add);
-  Test = Visitor.getCodeSizeBonus(F->getArg(1), One);
+  Ref = getCodeSizeSavings(And) + getCodeSizeSavings(Or) +
+        getCodeSizeSavings(Add);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // switch + sdiv + br + br
-  Ref = getInstCodeSize(Switch) + getInstCodeSize(Sdiv, /*SizeOnly =*/true) +
-        getInstCodeSize(BrBB2, /*SizeOnly =*/true) +
-        getInstCodeSize(BrLoop, /*SizeOnly =*/true);
-  Test = Visitor.getCodeSizeBonus(F->getArg(2), One);
+  Ref = getCodeSizeSavings(Switch) +
+        getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
+        getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
+        getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // Latency.
-  Ref = getLatency(F);
-  Test = Visitor.getLatencyBonus();
+  Ref = getLatencySavings(F);
+  Test = Visitor.getLatencySavingsForKnownConstants();
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 }
@@ -252,29 +254,31 @@ TEST_F(FunctionSpecializationTest, BranchInst) {
   Instruction &BrLoop = BB2.front();
 
   // mul
-  Cost Ref = getInstCodeSize(Mul);
-  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
+  Cost Ref = getCodeSizeSavings(Mul);
+  Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // add
-  Ref = getInstCodeSize(Add);
-  Test = Visitor.getCodeSizeBonus(F->getArg(1), One);
+  Ref = getCodeSizeSavings(Add);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // branch + sub + br + sdiv + br
-  Ref = getInstCodeSize(Branch) + getInstCodeSize(Sub, /*SizeOnly =*/true) +
-        getInstCodeSize(BrBB1BB2) + getInstCodeSize(Sdiv, /*SizeOnly =*/true) +
-        getInstCodeSize(BrBB2, /*SizeOnly =*/true) +
-        getInstCodeSize(BrLoop, /*SizeOnly =*/true);
-  Test = Visitor.getCodeSizeBonus(F->getArg(2), False);
+  Ref = getCodeSizeSavings(Branch) +
+        getCodeSizeSavings(Sub, /*HasLatencySavings=*/false) +
+        getCodeSizeSavings(BrBB1BB2) +
+        getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
+        getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
+        getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), False);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // Latency.
-  Ref = getLatency(F);
-  Test = Visitor.getLatencyBonus();
+  Ref = getLatencySavings(F);
+  Test = Visitor.getLatencySavingsForKnownConstants();
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 }
@@ -297,20 +301,20 @@ TEST_F(FunctionSpecializationTest, SelectInst) {
   Constant *False = ConstantInt::getFalse(M.getContext());
   Instruction &Select = *F->front().begin();
 
-  Cost RefCodeSize = getInstCodeSize(Select);
-  Cost RefLatency = getLatency(F);
+  Cost RefCodeSize = getCodeSizeSavings(Select);
+  Cost RefLatency = getLatencySavings(F);
 
-  Cost TestCodeSize = Visitor.getCodeSizeBonus(F->getArg(0), False);
+  Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
   EXPECT_TRUE(TestCodeSize == 0);
-  TestCodeSize = Visitor.getCodeSizeBonus(F->getArg(1), One);
+  TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
   EXPECT_TRUE(TestCodeSize == 0);
-  Cost TestLatency = Visitor.getLatencyBonus();
+  Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
   EXPECT_TRUE(TestLatency == 0);
 
-  TestCodeSize = Visitor.getCodeSizeBonus(F->getArg(2), Zero);
+  TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(2), Zero);
   EXPECT_EQ(TestCodeSize, RefCodeSize);
   EXPECT_TRUE(TestCodeSize > 0);
-  TestLatency = Visitor.getLatencyBonus();
+  TestLatency = Visitor.getLatencySavingsForKnownConstants();
   EXPECT_EQ(TestLatency, RefLatency);
   EXPECT_TRUE(TestLatency > 0);
 }
@@ -358,30 +362,30 @@ TEST_F(FunctionSpecializationTest, Misc) {
   Instruction &Smax = *BlockIter++;
 
   // icmp + zext
-  Cost Ref = getInstCodeSize(Icmp) + getInstCodeSize(Zext);
-  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
+  Cost Ref = getCodeSizeSavings(Icmp) + getCodeSizeSavings(Zext);
+  Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // select
-  Ref = getInstCodeSize(Select);
-  Test = Visitor.getCodeSizeBonus(F->getArg(1), True);
+  Ref = getCodeSizeSavings(Select);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), True);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // gep + load + freeze + smax
-  Ref = getInstCodeSize(Gep) + getInstCodeSize(Load) + getInstCodeSize(Freeze) +
-        getInstCodeSize(Smax);
-  Test = Visitor.getCodeSizeBonus(F->getArg(2), GV);
+  Ref = getCodeSizeSavings(Gep) + getCodeSizeSavings(Load) +
+        getCodeSizeSavings(Freeze) + getCodeSizeSavings(Smax);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), GV);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
-  Test = Visitor.getCodeSizeBonus(F->getArg(3), Undef);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(3), Undef);
   EXPECT_TRUE(Test == 0);
 
   // Latency.
-  Ref = getLatency(F);
-  Test = Visitor.getLatencyBonus();
+  Ref = getLatencySavings(F);
+  Test = Visitor.getLatencySavingsForKnownConstants();
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 }
@@ -433,33 +437,34 @@ TEST_F(FunctionSpecializationTest, PhiNode) {
   Instruction &Icmp = *++BB.begin();
   Instruction &Branch = BB.back();
 
-  Cost Test = Visitor.getCodeSizeBonus(F->getArg(0), One);
+  Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
   EXPECT_TRUE(Test == 0);
 
-  Test = Visitor.getCodeSizeBonus(F->getArg(1), One);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
   EXPECT_TRUE(Test == 0);
 
-  Test = Visitor.getLatencyBonus();
+  Test = Visitor.getLatencySavingsForKnownConstants();
   EXPECT_TRUE(Test == 0);
 
   // switch + phi + br
-  Cost Ref = getInstCodeSize(Switch) +
-             getInstCodeSize(PhiCase2, /*SizeOnly =*/true) +
-             getInstCodeSize(BrBB, /*SizeOnly =*/true);
-  Test = Visitor.getCodeSizeBonus(F->getArg(2), One);
+  Cost Ref = getCodeSizeSavings(Switch) +
+             getCodeSizeSavings(PhiCase2, /*HasLatencySavings=*/false) +
+             getCodeSizeSavings(BrBB, /*HasLatencySavings=*/false);
+  Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0 && Test > 0);
 
   // phi + phi + add + icmp + branch
-  Ref = getInstCodeSize(PhiBB) + getInstCodeSize(PhiLoop) +
-        getInstCodeSize(Add) + getInstCodeSize(Icmp) + getInstCodeSize(Branch);
-  Test = Visitor.getCodeSizeBonusFromPendingPHIs();
+  Ref = getCodeSizeSavings(PhiBB) + getCodeSizeSavings(PhiLoop) +
+        getCodeSizeSavings(Add) + getCodeSizeSavings(Icmp) +
+        getCodeSizeSavings(Branch);
+  Test = Visitor.getCodeSizeSavingsFromPendingPHIs();
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 
   // Latency.
-  Ref = getLatency(F);
-  Test = Visitor.getLatencyBonus();
+  Ref = getLatencySavings(F);
+  Test = Visitor.getLatencySavingsForKnownConstants();
   EXPECT_EQ(Test, Ref);
   EXPECT_TRUE(Test > 0);
 }

>From b9336f26bc021d91a1fc1d5a2495b11124ab67f5 Mon Sep 17 00:00:00 2001
From: Hari Limaye <hari.limaye at arm.com>
Date: Mon, 21 Oct 2024 21:49:53 +0000
Subject: [PATCH 3/3] Address review comments 2

- Fix outdated function names in comments.
- Sink debug messages about profitability into the IsProfitable lambda,
  and also print each value as a percentage of the function size.
- Move other profitability-related statements into IsProfitable.
- Add a TODO to improve the accumulation of codesize increases
  (FunctionGrowth).
---
 .../Transforms/IPO/FunctionSpecialization.cpp | 48 ++++++++++++-------
 1 file changed, 30 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 04a62837ff962b..35a75515cca223 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -185,10 +185,10 @@ Cost InstCostVisitor::getCodeSizeSavingsForArg(Argument *A, Constant *C) {
 /// for all Instructions in KnownConstants at once, it should be called only
 /// after every instruction has been visited, i.e. after:
 ///
-/// * getCodeSizeBonus has been run for every constant argument of a
+/// * getCodeSizeSavingsForArg has been run for every constant argument of a
 ///   specialization candidate
 ///
-/// * getCodeSizeBonusFromPendingPHIs has been run
+/// * getCodeSizeSavingsFromPendingPHIs has been run
 ///
 /// to ensure that the latency savings are calculated for all Instructions we
 /// have visited and found to be constant.
@@ -831,6 +831,9 @@ static Function *cloneCandidateFunction(Function *F, unsigned NSpecs) {
   return Clone;
 }
 
+/// Get the unsigned value of the given Cost object. Assumes the Cost is always
+/// valid and non-negative; the latter holds for both TCK_CodeSize and
+/// TCK_Latency.
 static unsigned getCostValue(const Cost &C) {
   int64_t Value = *C.getValue();
 
@@ -915,49 +918,58 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
       }
       CodeSize += Visitor.getCodeSizeSavingsFromPendingPHIs();
 
-      LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization bonus {CodeSize = "
-                        << CodeSize << ", Inlining = " << Score << "}\n");
-
-      unsigned LatencySavings = 0;
-      unsigned CodeSizeSavings = getCostValue(CodeSize);
-      FunctionGrowth[F] += FuncSize - CodeSizeSavings;
-
-      auto IsProfitable = [](unsigned CodeSizeSavings, unsigned &LatencySavings,
-                             unsigned Score, unsigned FuncSize,
-                             unsigned FuncGrowth, InstCostVisitor &V) -> bool {
+      auto IsProfitable = [&]() -> bool {
         // No check required.
         if (ForceSpecialization)
           return true;
+
+        unsigned CodeSizeSavings = getCostValue(CodeSize);
+        // TODO: We should only accumulate codesize increase of specializations
+        // that are actually created.
+        FunctionGrowth[F] += FuncSize - CodeSizeSavings;
+
+        LLVM_DEBUG(
+            dbgs() << "FnSpecialization: Specialization bonus {Inlining = "
+                   << Score << " (" << (Score * 100 / FuncSize) << "%)}\n");
+
         // Minimum inlining bonus.
         if (Score > MinInliningBonus * FuncSize / 100)
           return true;
+
+        LLVM_DEBUG(
+            dbgs() << "FnSpecialization: Specialization bonus {CodeSize = "
+                   << CodeSizeSavings << " ("
+                   << (CodeSizeSavings * 100 / FuncSize) << "%)}\n");
+
         // Minimum codesize savings.
         if (CodeSizeSavings < MinCodeSizeSavings * FuncSize / 100)
           return false;
 
         // Lazily compute the Latency, to avoid unnecessarily computing BFI.
-        LatencySavings = getCostValue(V.getLatencySavingsForKnownConstants());
+        unsigned LatencySavings =
+            getCostValue(Visitor.getLatencySavingsForKnownConstants());
 
         LLVM_DEBUG(
             dbgs() << "FnSpecialization: Specialization bonus {Latency = "
-                   << LatencySavings << "}\n");
+                   << LatencySavings << " ("
+                   << (LatencySavings * 100 / FuncSize) << "%)}\n");
 
         // Minimum latency savings.
         if (LatencySavings < MinLatencySavings * FuncSize / 100)
           return false;
         // Maximum codesize growth.
-        if (FuncGrowth / FuncSize > MaxCodeSizeGrowth)
+        if (FunctionGrowth[F] / FuncSize > MaxCodeSizeGrowth)
           return false;
+
+        Score += std::max(CodeSizeSavings, LatencySavings);
         return true;
       };
 
       // Discard unprofitable specialisations.
-      if (!IsProfitable(CodeSizeSavings, LatencySavings, Score, FuncSize,
-                        FunctionGrowth[F], Visitor))
+      if (!IsProfitable())
         continue;
 
       // Create a new specialisation entry.
-      Score += std::max(CodeSizeSavings, LatencySavings);
       auto &Spec = AllSpecs.emplace_back(F, S, Score);
       if (CS.getFunction() != F)
         Spec.CallSites.push_back(&CS);



More information about the llvm-commits mailing list