[llvm] 54e5fb7 - [FuncSpec] Track the return values of specializations.
Alexandros Lamprineas via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 24 05:24:21 PDT 2023
Author: Alexandros Lamprineas
Date: 2023-04-24T13:18:49+01:00
New Revision: 54e5fb789cfafc2eee4aa6df1650b89ee7b85ba4
URL: https://github.com/llvm/llvm-project/commit/54e5fb789cfafc2eee4aa6df1650b89ee7b85ba4
DIFF: https://github.com/llvm/llvm-project/commit/54e5fb789cfafc2eee4aa6df1650b89ee7b85ba4.diff
LOG: [FuncSpec] Track the return values of specializations.
To track the return values of specializations, we need to invalidate all
the lattice values along the use-def chains that originate from the
specialized call sites, and then recompute and propagate them.
Differential Revision: https://reviews.llvm.org/D146158
Added:
llvm/test/Transforms/FunctionSpecialization/track-return.ll
Modified:
llvm/include/llvm/Transforms/Utils/SCCPSolver.h
llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
llvm/lib/Transforms/Utils/SCCPSolver.cpp
llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll
llvm/test/Transforms/FunctionSpecialization/non-argument-tracked.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
index 3512f0a867b91..cf3c3b7eee49f 100644
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -132,6 +132,8 @@ class SCCPSolver {
void solveWhileResolvedUndefsIn(SmallVectorImpl<Function *> &WorkList);
+ void solveWhileResolvedUndefs();
+
bool isBlockExecutable(BasicBlock *BB) const;
// isEdgeFeasible - Return true if the control flow edge from the 'From' basic
@@ -142,6 +144,10 @@ class SCCPSolver {
void removeLatticeValueFor(Value *V);
+ /// Invalidate the Lattice Value of \p Call and its users after specializing
+ /// the call. Then recompute it.
+ void resetLatticeValueFor(CallBase *Call);
+
const ValueLatticeElement &getLatticeValueFor(Value *V) const;
/// getTrackedRetVals - Get the inferred return value map.
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 207737f1185c3..379211316b95a 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -364,6 +364,33 @@ bool FunctionSpecializer::run() {
updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End);
}
+ for (Function *F : Clones) {
+ if (F->getReturnType()->isVoidTy())
+ continue;
+ if (F->getReturnType()->isStructTy()) {
+ auto *STy = cast<StructType>(F->getReturnType());
+ if (!Solver.isStructLatticeConstant(F, STy))
+ continue;
+ } else {
+ auto It = Solver.getTrackedRetVals().find(F);
+ assert(It != Solver.getTrackedRetVals().end() &&
+ "Return value ought to be tracked");
+ if (SCCPSolver::isOverdefined(It->second))
+ continue;
+ }
+ for (User *U : F->users()) {
+ if (auto *CS = dyn_cast<CallBase>(U)) {
+ //The user instruction does not call our function.
+ if (CS->getCalledFunction() != F)
+ continue;
+ Solver.resetLatticeValueFor(CS);
+ }
+ }
+ }
+
+ // Rerun the solver to notify the users of the modified callsites.
+ Solver.solveWhileResolvedUndefs();
+
promoteConstantStackValues();
return true;
}
@@ -538,9 +565,9 @@ Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &
// marking the argument on which we specialized the function constant
// with the given value.
Solver.setLatticeValueForSpecializationArguments(Clone, S.Args);
-
- Solver.addArgumentTrackedFunction(Clone);
Solver.markBlockExecutable(&Clone->front());
+ Solver.addArgumentTrackedFunction(Clone);
+ Solver.addTrackedFunction(Clone);
// Mark all the specialized functions
Specializations.insert(Clone);
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index cc86ab26471a6..881c3cc7b56f6 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -352,6 +352,10 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
MapVector<std::pair<Function *, unsigned>, ValueLatticeElement>
TrackedMultipleRetVals;
+ /// The set of values whose lattice has been invalidated.
+ /// Populated by resetLatticeValueFor(), cleared after resolving undefs.
+ DenseSet<Value *> Invalidated;
+
/// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is
/// represented here for efficient lookup.
SmallPtrSet<Function *, 16> MRVFunctionsTracked;
@@ -477,6 +481,64 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
return LV;
}
+ /// Traverse the use-def chain of \p Call, marking itself and its users as
+ /// "unknown" on the way.
+ void invalidate(CallBase *Call) {
+ SmallVector<Instruction *, 64> ToInvalidate;
+ ToInvalidate.push_back(Call);
+
+ while (!ToInvalidate.empty()) {
+ Instruction *Inst = ToInvalidate.pop_back_val();
+
+ if (!Invalidated.insert(Inst).second)
+ continue;
+
+ if (!BBExecutable.count(Inst->getParent()))
+ continue;
+
+ Value *V = nullptr;
+ // For return instructions we need to invalidate the tracked returns map.
+ // Anything else has its lattice in the value map.
+ if (auto *RetInst = dyn_cast<ReturnInst>(Inst)) {
+ Function *F = RetInst->getParent()->getParent();
+ if (auto It = TrackedRetVals.find(F); It != TrackedRetVals.end()) {
+ It->second = ValueLatticeElement();
+ V = F;
+ } else if (MRVFunctionsTracked.count(F)) {
+ auto *STy = cast<StructType>(F->getReturnType());
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I)
+ TrackedMultipleRetVals[{F, I}] = ValueLatticeElement();
+ V = F;
+ }
+ } else if (auto *STy = dyn_cast<StructType>(Inst->getType())) {
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+ if (auto It = StructValueState.find({Inst, I});
+ It != StructValueState.end()) {
+ It->second = ValueLatticeElement();
+ V = Inst;
+ }
+ }
+ } else if (auto It = ValueState.find(Inst); It != ValueState.end()) {
+ It->second = ValueLatticeElement();
+ V = Inst;
+ }
+
+ if (V) {
+ LLVM_DEBUG(dbgs() << "Invalidated lattice for " << *V << "\n");
+
+ for (User *U : V->users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ ToInvalidate.push_back(UI);
+
+ auto It = AdditionalUsers.find(V);
+ if (It != AdditionalUsers.end())
+ for (User *U : It->second)
+ if (auto *UI = dyn_cast<Instruction>(U))
+ ToInvalidate.push_back(UI);
+ }
+ }
+ }
+
/// markEdgeExecutable - Mark a basic block as executable, adding it to the BB
/// work list if it is not already executable.
bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest);
@@ -657,6 +719,8 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
void solve();
+ bool resolvedUndef(Instruction &I);
+
bool resolvedUndefsIn(Function &F);
bool isBlockExecutable(BasicBlock *BB) const {
@@ -679,6 +743,19 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
void removeLatticeValueFor(Value *V) { ValueState.erase(V); }
+ /// Invalidate the Lattice Value of \p Call and its users after specializing
+ /// the call. Then recompute it.
+ void resetLatticeValueFor(CallBase *Call) {
+ // Calls to void returning functions do not need invalidation.
+ Function *F = Call->getCalledFunction();
+ (void)F;
+ assert(!F->getReturnType()->isVoidTy() &&
+ (TrackedRetVals.count(F) || MRVFunctionsTracked.count(F)) &&
+ "All non void specializations should be tracked");
+ invalidate(Call);
+ handleCallResult(*Call);
+ }
+
const ValueLatticeElement &getLatticeValueFor(Value *V) const {
assert(!V->getType()->isStructTy() &&
"Should use getStructLatticeValueFor");
@@ -746,6 +823,18 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
ResolvedUndefs |= resolvedUndefsIn(*F);
}
}
+
+ void solveWhileResolvedUndefs() {
+ bool ResolvedUndefs = true;
+ while (ResolvedUndefs) {
+ solve();
+ ResolvedUndefs = false;
+ for (Value *V : Invalidated)
+ if (auto *I = dyn_cast<Instruction>(V))
+ ResolvedUndefs |= resolvedUndef(*I);
+ }
+ Invalidated.clear();
+ }
};
} // namespace llvm
@@ -1720,6 +1809,7 @@ void SCCPInstVisitor::solve() {
// things to overdefined more quickly.
while (!OverdefinedInstWorkList.empty()) {
Value *I = OverdefinedInstWorkList.pop_back_val();
+ Invalidated.erase(I);
LLVM_DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n');
@@ -1736,6 +1826,7 @@ void SCCPInstVisitor::solve() {
// Process the instruction work list.
while (!InstWorkList.empty()) {
Value *I = InstWorkList.pop_back_val();
+ Invalidated.erase(I);
LLVM_DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n');
@@ -1763,6 +1854,61 @@ void SCCPInstVisitor::solve() {
}
}
+bool SCCPInstVisitor::resolvedUndef(Instruction &I) {
+ // Look for instructions which produce undef values.
+ if (I.getType()->isVoidTy())
+ return false;
+
+ if (auto *STy = dyn_cast<StructType>(I.getType())) {
+ // Only a few things that can be structs matter for undef.
+
+ // Tracked calls must never be marked overdefined in resolvedUndefsIn.
+ if (auto *CB = dyn_cast<CallBase>(&I))
+ if (Function *F = CB->getCalledFunction())
+ if (MRVFunctionsTracked.count(F))
+ return false;
+
+ // extractvalue and insertvalue don't need to be marked; they are
+ // tracked as precisely as their operands.
+ if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))
+ return false;
+ // Send the results of everything else to overdefined. We could be
+ // more precise than this but it isn't worth bothering.
+ for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+ ValueLatticeElement &LV = getStructValueState(&I, i);
+ if (LV.isUnknown()) {
+ markOverdefined(LV, &I);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ ValueLatticeElement &LV = getValueState(&I);
+ if (!LV.isUnknown())
+ return false;
+
+ // There are two reasons a call can have an undef result
+ // 1. It could be tracked.
+ // 2. It could be constant-foldable.
+ // Because of the way we solve return values, tracked calls must
+ // never be marked overdefined in resolvedUndefsIn.
+ if (auto *CB = dyn_cast<CallBase>(&I))
+ if (Function *F = CB->getCalledFunction())
+ if (TrackedRetVals.count(F))
+ return false;
+
+ if (isa<LoadInst>(I)) {
+ // A load here means one of two things: a load of undef from a global,
+ // a load from an unknown pointer. Either way, having it return undef
+ // is okay.
+ return false;
+ }
+
+ markOverdefined(&I);
+ return true;
+}
+
/// While solving the dataflow for a function, we don't compute a result for
/// operations with an undef operand, to allow undef to be lowered to a
/// constant later. For example, constant folding of "zext i8 undef to i16"
@@ -1782,60 +1928,8 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) {
if (!BBExecutable.count(&BB))
continue;
- for (Instruction &I : BB) {
- // Look for instructions which produce undef values.
- if (I.getType()->isVoidTy())
- continue;
-
- if (auto *STy = dyn_cast<StructType>(I.getType())) {
- // Only a few things that can be structs matter for undef.
-
- // Tracked calls must never be marked overdefined in resolvedUndefsIn.
- if (auto *CB = dyn_cast<CallBase>(&I))
- if (Function *F = CB->getCalledFunction())
- if (MRVFunctionsTracked.count(F))
- continue;
-
- // extractvalue and insertvalue don't need to be marked; they are
- // tracked as precisely as their operands.
- if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))
- continue;
- // Send the results of everything else to overdefined. We could be
- // more precise than this but it isn't worth bothering.
- for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
- ValueLatticeElement &LV = getStructValueState(&I, i);
- if (LV.isUnknown()) {
- markOverdefined(LV, &I);
- MadeChange = true;
- }
- }
- continue;
- }
-
- ValueLatticeElement &LV = getValueState(&I);
- if (!LV.isUnknown())
- continue;
-
- // There are two reasons a call can have an undef result
- // 1. It could be tracked.
- // 2. It could be constant-foldable.
- // Because of the way we solve return values, tracked calls must
- // never be marked overdefined in resolvedUndefsIn.
- if (auto *CB = dyn_cast<CallBase>(&I))
- if (Function *F = CB->getCalledFunction())
- if (TrackedRetVals.count(F))
- continue;
-
- if (isa<LoadInst>(I)) {
- // A load here means one of two things: a load of undef from a global,
- // a load from an unknown pointer. Either way, having it return undef
- // is okay.
- continue;
- }
-
- markOverdefined(&I);
- MadeChange = true;
- }
+ for (Instruction &I : BB)
+ MadeChange |= resolvedUndef(I);
}
LLVM_DEBUG(if (MadeChange) dbgs()
@@ -1913,6 +2007,10 @@ SCCPSolver::solveWhileResolvedUndefsIn(SmallVectorImpl<Function *> &WorkList) {
Visitor->solveWhileResolvedUndefsIn(WorkList);
}
+void SCCPSolver::solveWhileResolvedUndefs() {
+ Visitor->solveWhileResolvedUndefs();
+}
+
bool SCCPSolver::isBlockExecutable(BasicBlock *BB) const {
return Visitor->isBlockExecutable(BB);
}
@@ -1930,6 +2028,10 @@ void SCCPSolver::removeLatticeValueFor(Value *V) {
return Visitor->removeLatticeValueFor(V);
}
+void SCCPSolver::resetLatticeValueFor(CallBase *Call) {
+ Visitor->resetLatticeValueFor(Call);
+}
+
const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const {
return Visitor->getLatticeValueFor(V);
}
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll
index a9cb4184bb8a7..003f80fa260ff 100644
--- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression.ll
@@ -36,7 +36,7 @@ define internal i64 @zoo(i1 %flag) {
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @func2.1(ptr getelementptr inbounds ([[STRUCT]], ptr @Global, i32 0, i32 4))
; CHECK-NEXT: br label [[MERGE]]
; CHECK: merge:
-; CHECK-NEXT: [[TMP2:%.*]] = phi i64 [ [[TMP0]], [[PLUS]] ], [ [[TMP1]], [[MINUS]] ]
+; CHECK-NEXT: [[TMP2:%.*]] = phi i64 [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 3) to i64), [[PLUS]] ], [ ptrtoint (ptr getelementptr inbounds ([[STRUCT:%.*]], ptr @Global, i32 0, i32 4) to i64), [[MINUS]] ]
; CHECK-NEXT: ret i64 [[TMP2]]
;
entry:
@@ -70,3 +70,4 @@ define i64 @main() {
%3 = add i64 %1, %2
ret i64 %3
}
+
diff --git a/llvm/test/Transforms/FunctionSpecialization/non-argument-tracked.ll b/llvm/test/Transforms/FunctionSpecialization/non-argument-tracked.ll
index 2fb7205f7e33d..14a6fd746d09e 100644
--- a/llvm/test/Transforms/FunctionSpecialization/non-argument-tracked.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/non-argument-tracked.ll
@@ -29,9 +29,10 @@ define internal i32 @f2(i32 %i) {
;; All calls are to specilisation instances.
; CHECK-LABEL: define i32 @g0
-; CHECK: [[U0:%.*]] = call i32 @f0.[[#A:]]()
-; CHECK-NEXT: [[U1:%.*]] = call i32 @f1.[[#B:]]()
-; CHECK-NEXT: [[U2:%.*]] = call i32 @f2.[[#C:]]()
+; CHECK: call void @f0.[[#A:]]()
+; CHECK-NEXT: call void @f1.[[#B:]]()
+; CHECK-NEXT: call void @f2.[[#C:]]()
+; CHECK-NEXT: ret i32 9
define i32 @g0(i32 %i) {
%u0 = call i32 @f0(i32 1)
%u1 = call i32 @f1(i32 2)
@@ -42,9 +43,10 @@ define i32 @g0(i32 %i) {
}
; CHECK-LABEL: define i32 @g1
-; CHECK: [[U0:%.*]] = call i32 @f0.[[#D:]]()
-; CHECK-NEXT: [[U1:%.*]] = call i32 @f1.[[#E:]]()
-; CHECK-NEXT: [[U2:%.*]] = call i32 @f2.[[#F:]]()
+; CHECK: call void @f0.[[#D:]]()
+; CHECK-NEXT: call void @f1.[[#E:]]()
+; CHECK-NEXT: call void @f2.[[#F:]]()
+; CHECK-NEXT: ret i32 12
define i32 @g1(i32 %i) {
%u0 = call i32 @f0(i32 2)
%u1 = call i32 @f1(i32 3)
@@ -56,9 +58,9 @@ define i32 @g1(i32 %i) {
; All of the function are specialized and all clones are with internal linkage.
-; CHECK-DAG: define internal i32 @f0.[[#A]]() {
-; CHECK-DAG: define internal i32 @f1.[[#B]]() {
-; CHECK-DAG: define internal i32 @f2.[[#C]]() {
-; CHECK-DAG: define internal i32 @f0.[[#D]]() {
-; CHECK-DAG: define internal i32 @f1.[[#E]]() {
-; CHECK-DAG: define internal i32 @f2.[[#F]]() {
+; CHECK-DAG: define internal void @f0.[[#A]]() {
+; CHECK-DAG: define internal void @f1.[[#B]]() {
+; CHECK-DAG: define internal void @f2.[[#C]]() {
+; CHECK-DAG: define internal void @f0.[[#D]]() {
+; CHECK-DAG: define internal void @f1.[[#E]]() {
+; CHECK-DAG: define internal void @f2.[[#F]]() {
diff --git a/llvm/test/Transforms/FunctionSpecialization/track-return.ll b/llvm/test/Transforms/FunctionSpecialization/track-return.ll
new file mode 100644
index 0000000000000..58a1c5f2a5904
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/track-return.ll
@@ -0,0 +1,106 @@
+; RUN: opt -passes="ipsccp<func-spec>" -force-specialization -funcspec-for-literal-constant -funcspec-max-iters=3 -S < %s | FileCheck %s
+
+define i64 @main() {
+; CHECK: define i64 @main
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[C1:%.*]] = call i64 @foo.1(i1 true, i64 3, i64 1)
+; CHECK-NEXT: [[C2:%.*]] = call i64 @foo.2(i1 false, i64 4, i64 -1)
+; CHECK-NEXT: ret i64 8
+;
+entry:
+ %c1 = call i64 @foo(i1 true, i64 3, i64 1)
+ %c2 = call i64 @foo(i1 false, i64 4, i64 -1)
+ %add = add i64 %c1, %c2
+ ret i64 %add
+}
+
+define internal i64 @foo(i1 %flag, i64 %m, i64 %n) {
+;
+; CHECK: define internal i64 @foo.1
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %plus
+; CHECK: plus:
+; CHECK-NEXT: [[N0:%.*]] = call i64 @binop.4(i64 3, i64 1)
+; CHECK-NEXT: [[RES0:%.*]] = call i64 @bar.6(i64 4)
+; CHECK-NEXT: br label %merge
+; CHECK: merge:
+; CHECK-NEXT: ret i64 undef
+;
+; CHECK: define internal i64 @foo.2
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %minus
+; CHECK: minus:
+; CHECK-NEXT: [[N1:%.*]] = call i64 @binop.3(i64 4, i64 -1)
+; CHECK-NEXT: [[RES1:%.*]] = call i64 @bar.5(i64 3)
+; CHECK-NEXT: br label %merge
+; CHECK: merge:
+; CHECK-NEXT: ret i64 undef
+;
+entry:
+ br i1 %flag, label %plus, label %minus
+
+plus:
+ %n0 = call i64 @binop(i64 %m, i64 %n)
+ %res0 = call i64 @bar(i64 %n0)
+ br label %merge
+
+minus:
+ %n1 = call i64 @binop(i64 %m, i64 %n)
+ %res1 = call i64 @bar(i64 %n1)
+ br label %merge
+
+merge:
+ %res = phi i64 [ %res0, %plus ], [ %res1, %minus]
+ ret i64 %res
+}
+
+define internal i64 @binop(i64 %x, i64 %y) {
+;
+; CHECK: define internal i64 @binop.3
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret i64 undef
+;
+; CHECK: define internal i64 @binop.4
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret i64 undef
+;
+entry:
+ %z = add i64 %x, %y
+ ret i64 %z
+}
+
+define internal i64 @bar(i64 %n) {
+;
+; CHECK: define internal i64 @bar.5
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %if.else
+; CHECK: if.else:
+; CHECK-NEXT: br label %if.end
+; CHECK: if.end:
+; CHECK-NEXT: ret i64 undef
+;
+; CHECK: define internal i64 @bar.6
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %if.then
+; CHECK: if.then:
+; CHECK-NEXT: br label %if.end
+; CHECK: if.end:
+; CHECK-NEXT: ret i64 undef
+;
+entry:
+ %cmp = icmp sgt i64 %n, 3
+ br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+ %res0 = sdiv i64 %n, 2
+ br label %if.end
+
+if.else:
+ %res1 = mul i64 %n, 2
+ br label %if.end
+
+if.end:
+ %res = phi i64 [ %res0, %if.then ], [ %res1, %if.else]
+ ret i64 %res
+}
+
More information about the llvm-commits
mailing list