[llvm] 03f1d09 - [FuncSpec] Add Phi nodes to the InstCostVisitor.
Alexandros Lamprineas via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 25 03:01:34 PDT 2023
Author: Alexandros Lamprineas
Date: 2023-07-25T11:00:20+01:00
New Revision: 03f1d09fe484f6c924434bc9c888e022b3514455
URL: https://github.com/llvm/llvm-project/commit/03f1d09fe484f6c924434bc9c888e022b3514455
DIFF: https://github.com/llvm/llvm-project/commit/03f1d09fe484f6c924434bc9c888e022b3514455.diff
LOG: [FuncSpec] Add Phi nodes to the InstCostVisitor.
This patch allows constant folding of PHIs when estimating the user
bonus. Phi nodes are a special case since some of their inputs may
remain unresolved until all the specialization arguments have been
processed by the InstCostVisitor. Therefore, we keep track of the basic
blocks found to be dead, record the Phi nodes that could not be folded on
their first visit, and lazily revisit those Phi nodes once the user bonus
has been computed for all the specialization arguments.
Differential Revision: https://reviews.llvm.org/D154852
Added:
Modified:
llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index f780385f7f67da..4e78d9db024c56 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -126,6 +126,15 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
SCCPSolver &Solver;
ConstMap KnownConstants;
+ // Basic blocks known to be unreachable after constant propagation.
+ DenseSet<BasicBlock *> DeadBlocks;
+ // PHI nodes we have visited before.
+ DenseSet<Instruction *> VisitedPHIs;
+ // PHI nodes we have visited once without successfully constant folding them.
+ // Once the InstCostVisitor has processed all the specialization arguments,
+ // it should be possible to determine whether those PHIs can be folded
+ // (some of their incoming values may have become constant or dead).
+ SmallVector<Instruction *> PendingPHIs;
ConstMap::iterator LastVisited;
@@ -134,7 +143,10 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
TargetTransformInfo &TTI, SCCPSolver &Solver)
: DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}
- Cost getUserBonus(Instruction *User, Value *Use, Constant *C);
+ Cost getUserBonus(Instruction *User, Value *Use = nullptr,
+ Constant *C = nullptr);
+
+ Cost getBonusFromPendingPHIs();
private:
friend class InstVisitor<InstCostVisitor, Constant *>;
@@ -143,6 +155,7 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
Cost estimateBranchInst(BranchInst &I);
Constant *visitInstruction(Instruction &I) { return nullptr; }
+ Constant *visitPHINode(PHINode &I);
Constant *visitFreezeInst(FreezeInst &I);
Constant *visitCallBase(CallBase &I);
Constant *visitLoadInst(LoadInst &I);
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 47d1fcfabc7845..cae0fc7b733543 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -78,6 +78,11 @@ static cl::opt<unsigned> MaxClones(
"The maximum number of clones allowed for a single function "
"specialization"));
+// PHI nodes whose number of incoming values exceeds this limit are not
+// constant folded by InstCostVisitor::visitPHINode.
+static cl::opt<unsigned> MaxIncomingPhiValues(
+    "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden,
+    cl::desc("The maximum number of incoming values a PHI node can have to be "
+             "considered during the specialization bonus estimation"));
+
static cl::opt<unsigned> MinFunctionSize(
"funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
"Don't specialize functions that have less than this number of "
@@ -104,6 +109,7 @@ static cl::opt<bool> SpecializeLiteralConstant(
// the combination of size and latency savings in comparison to the non
// specialized version of the function.
static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
+ DenseSet<BasicBlock *> &DeadBlocks,
ConstMap &KnownConstants, SCCPSolver &Solver,
BlockFrequencyInfo &BFI,
TargetTransformInfo &TTI) {
@@ -118,6 +124,12 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
if (!Weight)
continue;
+ // These blocks are considered dead as far as the InstCostVisitor is
+ // concerned. They haven't been proven dead by the Solver yet, but
+ // may become so once the constant specialization arguments are propagated.
+ if (!DeadBlocks.insert(BB).second)
+ continue;
+
for (Instruction &I : *BB) {
// Disregard SSA copies.
if (auto *II = dyn_cast<IntrinsicInst>(&I))
@@ -152,9 +164,19 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
return nullptr;
}
+/// Collect the bonus of the PHI nodes that could not be constant folded on
+/// their first visit. Call this only after getUserBonus has run for every
+/// specialization argument. Only then are all the incoming values that can
+/// become constant, or dead, known.
+Cost InstCostVisitor::getBonusFromPendingPHIs() {
+  Cost Bonus = 0;
+  while (!PendingPHIs.empty()) {
+    Instruction *Phi = PendingPHIs.pop_back_val();
+    // The PHI's block may have been found dead while a later argument was
+    // processed. Blocks in DeadBlocks have already been costed instruction
+    // by instruction by estimateBasicBlocks, so counting the PHI here would
+    // count it twice.
+    if (DeadBlocks.contains(Phi->getParent()))
+      continue;
+    // A later, non-first visit may already have folded the PHI. Folded
+    // values are recorded in KnownConstants, which is what findConstantFor
+    // consults. Both the PHI and its users were credited then, so replaying
+    // getUserBonus would double count them.
+    if (KnownConstants.contains(Phi))
+      continue;
+    Bonus += getUserBonus(Phi);
+  }
+  return Bonus;
+}
+
Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
// Cache the iterator before visiting.
- LastVisited = KnownConstants.insert({Use, C}).first;
+ LastVisited = Use ? KnownConstants.insert({Use, C}).first
+ : KnownConstants.end();
if (auto *I = dyn_cast<SwitchInst>(User))
return estimateSwitchInst(*I);
@@ -181,13 +203,15 @@ Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
for (auto *U : User->users())
if (auto *UI = dyn_cast<Instruction>(U))
- if (Solver.isBlockExecutable(UI->getParent()))
+ if (UI != User && Solver.isBlockExecutable(UI->getParent()))
Bonus += getUserBonus(UI, User, C);
return Bonus;
}
Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
if (I.getCondition() != LastVisited->first)
return 0;
@@ -208,10 +232,13 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
WorkList.push_back(BB);
}
- return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+ return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
+ TTI);
}
Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
if (I.getCondition() != LastVisited->first)
return 0;
@@ -223,10 +250,39 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
Succ->getUniquePredecessor() == I.getParent())
WorkList.push_back(Succ);
- return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+ return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
+ TTI);
+}
+
+/// Try to fold the PHI node \p I to a single constant.
+///
+/// Two kinds of incoming value are ignored:
+///  - self-references;
+///  - any value, of whatever kind, that flows in over an edge from a block
+///    in DeadBlocks, because that edge is not taken in the specialized
+///    function.
+///
+/// If every remaining incoming value is the same known constant, that
+/// constant is returned. If a remaining incoming value is not known yet, the
+/// PHI is queued in PendingPHIs on its first visit, so that
+/// getBonusFromPendingPHIs can retry it after all the specialization
+/// arguments have been processed.
+///
+/// Returns nullptr when the PHI does not fold. This includes PHIs with more
+/// than MaxIncomingPhiValues incoming values.
+Constant *InstCostVisitor::visitPHINode(PHINode &I) {
+  if (I.getNumIncomingValues() > MaxIncomingPhiValues)
+    return nullptr;
+
+  bool Inserted = VisitedPHIs.insert(&I).second;
+  Constant *Const = nullptr;
+
+  for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
+    Value *V = I.getIncomingValue(Idx);
+    // Disregard self-references and values arriving over dead edges. The
+    // dead-edge test must not depend on V being an Instruction. A constant
+    // or argument coming from a dead block is just as irrelevant, and
+    // comparing it against the live values would needlessly block folding.
+    if (V == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
+      continue;
+    Constant *C = findConstantFor(V, KnownConstants);
+    if (!C) {
+      // Some live incoming value is still unknown. On the first visit,
+      // queue the PHI so it is retried once every argument is propagated.
+      if (Inserted)
+        PendingPHIs.push_back(&I);
+      return nullptr;
+    }
+    if (!Const)
+      Const = C;
+    else if (C != Const)
+      return nullptr; // Live incoming constants differ: the PHI won't fold.
+  }
+  return Const;
}
Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
return LastVisited->second;
return nullptr;
@@ -253,6 +309,8 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) {
}
Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
if (isa<ConstantPointerNull>(LastVisited->second))
return nullptr;
return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
@@ -275,6 +333,8 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
}
Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
if (I.getCondition() != LastVisited->first)
return nullptr;
@@ -290,6 +350,8 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) {
}
Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
@@ -303,10 +365,14 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
}
+/// Fold the unary operator \p I. Its operand is the value cached in
+/// LastVisited by getUserBonus, and LastVisited->second is the constant that
+/// was propagated into it.
Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
+ // getUserBonus sets LastVisited to KnownConstants.end() only when it is
+ // called without a Use. That happens only for pending PHIs, so the
+ // iterator is expected to be valid for a unary operator.
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL);
}
Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
+ assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
+
bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
@@ -713,13 +779,17 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
AllSpecs[Index].CallSites.push_back(&CS);
} else {
// Calculate the specialisation gain.
- Cost Score = 0 - SpecCost;
+ Cost Score = 0;
InstCostVisitor Visitor = getInstCostVisitorFor(F);
for (ArgInfo &A : S.Args)
Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
+ Score += Visitor.getBonusFromPendingPHIs();
+
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
+ << Score << "\n");
// Discard unprofitable specialisations.
- if (!ForceSpecialization && Score <= 0)
+ if (!ForceSpecialization && Score <= SpecCost)
continue;
// Create a new specialisation entry.
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
index 222311dc040a86..be6f8fc990b3c7 100644
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -287,3 +287,56 @@ TEST_F(FunctionSpecializationTest, Misc) {
Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
EXPECT_TRUE(Bonus == 0);
}
+
+// Checks that a PHI node which cannot be folded while the specialization
+// arguments are processed one at a time is retried, and then folded, by
+// InstCostVisitor::getBonusFromPendingPHIs.
+TEST_F(FunctionSpecializationTest, PhiNode) {
+  const char *ModuleString = R"(
+ define void @foo(i32 %a, i32 %b, i32 %i) {
+ entry:
+ br label %loop
+ loop:
+ switch i32 %i, label %default
+ [ i32 1, label %case1
+ i32 2, label %case2 ]
+ case1:
+ %0 = add i32 %a, 1
+ br label %bb
+ case2:
+ %1 = sub i32 %b, 1
+ br label %bb
+ bb:
+ %2 = phi i32 [ %0, %case1 ], [ %1, %case2 ], [ %2, %bb ]
+ %3 = icmp eq i32 %2, 2
+ br i1 %3, label %bb, label %loop
+ default:
+ ret void
+ }
+ )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer Specializer = getSpecializerFor(F);
+  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+
+  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+
+  // Blocks in layout order: entry, loop, case1, case2, bb.
+  BasicBlock &BB = *std::next(F->begin(), 4);
+
+  Instruction &Phi = BB.front();
+  Instruction &Icmp = *std::next(BB.begin());
+
+  // Specializing on %a = 1, %b = 1 and %i = 1 folds %0 (= 2) and %1 (= 0).
+  // It also resolves the switch on %i so that %case2 is treated as dead.
+  // Before %i is known, the PHI's live incoming values are unknown or differ,
+  // so it is left pending.
+  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor) +
+               Specializer.getSpecializationBonus(F->getArg(1), One, Visitor) +
+               Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
+  EXPECT_TRUE(Bonus > 0);
+
+  // With %case2 dead and the self-reference ignored, the PHI folds to %0.
+  // That in turn folds the icmp: phi + icmp.
+  Cost Ref = getInstCost(Phi) + getInstCost(Icmp);
+  Bonus = Visitor.getBonusFromPendingPHIs();
+  EXPECT_EQ(Bonus, Ref);
+  EXPECT_TRUE(Bonus > 0);
+
+  // The pending list has been drained, so nothing is left to collect.
+  EXPECT_TRUE(Visitor.getBonusFromPendingPHIs() == 0);
+}
+
More information about the llvm-commits
mailing list