[llvm] 893d3a6 - Reland [FuncSpec] Add Phi nodes to the InstCostVisitor.

Alexandros Lamprineas via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 31 00:29:36 PDT 2023


Author: Alexandros Lamprineas
Date: 2023-07-31T08:25:48+01:00
New Revision: 893d3a61c0574a539f2c15206117e6535846aacb

URL: https://github.com/llvm/llvm-project/commit/893d3a61c0574a539f2c15206117e6535846aacb
DIFF: https://github.com/llvm/llvm-project/commit/893d3a61c0574a539f2c15206117e6535846aacb.diff

LOG: Reland [FuncSpec] Add Phi nodes to the InstCostVisitor.

This patch allows constant folding of PHIs when estimating the user
bonus. Phi nodes are a special case since some of their inputs may
remain unresolved until all the specialization arguments have been
processed by the InstCostVisitor. Therefore, we keep a list of dead
basic blocks and then lazily visit the Phi nodes once the user bonus
has been computed for all the specialization arguments.

Differential Revision: https://reviews.llvm.org/D154852

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
    llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
    llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index b20212a21b3264..5efb3cb240ff8d 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -123,6 +123,15 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   SCCPSolver &Solver;
 
   ConstMap KnownConstants;
+  // Basic blocks known to be unreachable after constant propagation.
+  DenseSet<BasicBlock *> DeadBlocks;
+  // PHI nodes we have visited before.
+  DenseSet<Instruction *> VisitedPHIs;
+  // PHI nodes we have visited once without successfully constant folding them.
+  // Once the InstCostVisitor has processed all the specialization arguments,
+  // it should be possible to determine whether those PHIs can be folded
+  // (some of their incoming values may have become constant or dead).
+  SmallVector<Instruction *> PendingPHIs;
 
   ConstMap::iterator LastVisited;
 
@@ -131,7 +140,14 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
                   TargetTransformInfo &TTI, SCCPSolver &Solver)
       : DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}
 
-  Cost getUserBonus(Instruction *User, Value *Use, Constant *C);
+  bool isBlockExecutable(BasicBlock *BB) {
+    return Solver.isBlockExecutable(BB) && !DeadBlocks.contains(BB);
+  }
+
+  Cost getUserBonus(Instruction *User, Value *Use = nullptr,
+                    Constant *C = nullptr);
+
+  Cost getBonusFromPendingPHIs();
 
 private:
   friend class InstVisitor<InstCostVisitor, Constant *>;
@@ -140,6 +156,7 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
   Cost estimateBranchInst(BranchInst &I);
 
   Constant *visitInstruction(Instruction &I) { return nullptr; }
+  Constant *visitPHINode(PHINode &I);
   Constant *visitFreezeInst(FreezeInst &I);
   Constant *visitCallBase(CallBase &I);
   Constant *visitLoadInst(LoadInst &I);

diff  --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index ac5dbc7cfb2a56..d917342a7d2905 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -78,6 +78,11 @@ static cl::opt<unsigned> MaxClones(
     "The maximum number of clones allowed for a single function "
     "specialization"));
 
+static cl::opt<unsigned> MaxIncomingPhiValues(
+    "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
+    "The maximum number of incoming values a PHI node can have to be "
+    "considered during the specialization bonus estimation"));
+
 static cl::opt<unsigned> MinFunctionSize(
     "funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
     "Don't specialize functions that have less than this number of "
@@ -104,6 +109,7 @@ static cl::opt<bool> SpecializeLiteralConstant(
 // the combination of size and latency savings in comparison to the non
 // specialized version of the function.
 static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
+                                DenseSet<BasicBlock *> &DeadBlocks,
                                 ConstMap &KnownConstants, SCCPSolver &Solver,
                                 BlockFrequencyInfo &BFI,
                                 TargetTransformInfo &TTI) {
@@ -118,6 +124,12 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
     if (!Weight)
       continue;
 
+    // These blocks are considered dead as far as the InstCostVisitor
+    // is concerned. They haven't been proven dead yet by the Solver,
+    // but may become dead if we propagate the specialization arguments.
+    if (!DeadBlocks.insert(BB).second)
+      continue;
+
     for (Instruction &I : *BB) {
       // Disregard SSA copies.
       if (auto *II = dyn_cast<IntrinsicInst>(&I))
@@ -152,9 +164,25 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
   return nullptr;
 }
 
+Cost InstCostVisitor::getBonusFromPendingPHIs() {
+  Cost Bonus = 0;
+  while (!PendingPHIs.empty()) {
+    Instruction *Phi = PendingPHIs.pop_back_val();
+    // The pending PHIs could have been proven dead by now.
+    if (isBlockExecutable(Phi->getParent()))
+      Bonus += getUserBonus(Phi);
+  }
+  return Bonus;
+}
+
 Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
+  // We have already propagated a constant for this user.
+  if (KnownConstants.contains(User))
+    return 0;
+
   // Cache the iterator before visiting.
-  LastVisited = KnownConstants.insert({Use, C}).first;
+  LastVisited = Use ? KnownConstants.insert({Use, C}).first
+                    : KnownConstants.end();
 
   if (auto *I = dyn_cast<SwitchInst>(User))
     return estimateSwitchInst(*I);
@@ -181,13 +209,15 @@ Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
 
   for (auto *U : User->users())
     if (auto *UI = dyn_cast<Instruction>(U))
-      if (Solver.isBlockExecutable(UI->getParent()))
+      if (UI != User && isBlockExecutable(UI->getParent()))
         Bonus += getUserBonus(UI, User, C);
 
   return Bonus;
 }
 
 Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   if (I.getCondition() != LastVisited->first)
     return 0;
 
@@ -208,10 +238,13 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
     WorkList.push_back(BB);
   }
 
-  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+  return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
+                             TTI);
 }
 
 Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   if (I.getCondition() != LastVisited->first)
     return 0;
 
@@ -223,10 +256,39 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
       Succ->getUniquePredecessor() == I.getParent())
     WorkList.push_back(Succ);
 
-  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+  return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
+                             TTI);
+}
+
+Constant *InstCostVisitor::visitPHINode(PHINode &I) {
+  if (I.getNumIncomingValues() > MaxIncomingPhiValues)
+    return nullptr;
+
+  bool Inserted = VisitedPHIs.insert(&I).second;
+  Constant *Const = nullptr;
+
+  for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
+    Value *V = I.getIncomingValue(Idx);
+    if (auto *Inst = dyn_cast<Instruction>(V))
+      if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
+        continue;
+    Constant *C = findConstantFor(V, KnownConstants);
+    if (!C) {
+      if (Inserted)
+        PendingPHIs.push_back(&I);
+      return nullptr;
+    }
+    if (!Const)
+      Const = C;
+    else if (C != Const)
+      return nullptr;
+  }
+  return Const;
 }
 
 Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
     return LastVisited->second;
   return nullptr;
@@ -253,6 +315,8 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) {
 }
 
 Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   if (isa<ConstantPointerNull>(LastVisited->second))
     return nullptr;
   return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
@@ -275,6 +339,8 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
 }
 
 Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   if (I.getCondition() != LastVisited->first)
     return nullptr;
 
@@ -290,6 +356,8 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) {
 }
 
 Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   bool Swap = I.getOperand(1) == LastVisited->first;
   Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
   Constant *Other = findConstantFor(V, KnownConstants);
@@ -303,10 +371,14 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
 }
 
 Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL);
 }
 
 Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
+  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
   bool Swap = I.getOperand(1) == LastVisited->first;
   Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
   Constant *Other = findConstantFor(V, KnownConstants);
@@ -713,13 +785,17 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
       AllSpecs[Index].CallSites.push_back(&CS);
     } else {
       // Calculate the specialisation gain.
-      Cost Score = 0 - SpecCost;
+      Cost Score = 0;
       InstCostVisitor Visitor = getInstCostVisitorFor(F);
       for (ArgInfo &A : S.Args)
         Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
+      Score += Visitor.getBonusFromPendingPHIs();
+
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
+                        << Score << "\n");
 
       // Discard unprofitable specialisations.
-      if (!ForceSpecialization && Score <= 0)
+      if (!ForceSpecialization && Score <= SpecCost)
         continue;
 
       // Create a new specialisation entry.
@@ -798,7 +874,7 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
   Cost TotalCost = 0;
   for (auto *U : A->users())
     if (auto *UI = dyn_cast<Instruction>(U))
-      if (Solver.isBlockExecutable(UI->getParent()))
+      if (Visitor.isBlockExecutable(UI->getParent()))
         TotalCost += Visitor.getUserBonus(UI, A, C);
 
   LLVM_DEBUG(dbgs() << "FnSpecialization:   Accumulated user bonus "

diff  --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
index 81da6d8f6ed5c9..6018263cad6586 100644
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -302,3 +302,69 @@ TEST_F(FunctionSpecializationTest, Misc) {
   Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
   EXPECT_TRUE(Bonus == 0);
 }
+
+TEST_F(FunctionSpecializationTest, PhiNode) {
+  const char *ModuleString = R"(
+    define void @foo(i32 %a, i32 %b, i32 %i) {
+    entry:
+      br label %loop
+    loop:
+      %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
+      switch i32 %i, label %default
+      [ i32 1, label %case1
+        i32 2, label %case2 ]
+    case1:
+      %1 = add i32 %0, 1
+      br label %bb
+    case2:
+      %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
+      br label %bb
+    bb:
+      %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
+      %4 = icmp eq i32 %3, 1
+      br i1 %4, label %bb, label %loop
+    default:
+      ret void
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer Specializer = getSpecializerFor(F);
+  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+
+  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+
+  auto FuncIter = F->begin();
+  BasicBlock &Loop = *++FuncIter;
+  BasicBlock &Case1 = *++FuncIter;
+  BasicBlock &Case2 = *++FuncIter;
+  BasicBlock &BB = *++FuncIter;
+
+  Instruction &PhiLoop = Loop.front();
+  Instruction &Add = Case1.front();
+  Instruction &PhiCase2 = Case2.front();
+  Instruction &BrBB = Case2.back();
+  Instruction &PhiBB = BB.front();
+  Instruction &Icmp = *++BB.begin();
+
+  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
+  EXPECT_EQ(Bonus, 0);
+
+  Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
+  EXPECT_EQ(Bonus, 0);
+
+  // phi + br
+  Cost Ref = getInstCost(PhiCase2) + getInstCost(BrBB);
+  Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+  EXPECT_TRUE(Bonus > 0);
+
+  // phi + phi + add + icmp
+  Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
+        getInstCost(Icmp);
+  Bonus = Visitor.getBonusFromPendingPHIs();
+  EXPECT_EQ(Bonus, Ref);
+  EXPECT_TRUE(Bonus > 0);
+}
+


        


More information about the llvm-commits mailing list