[llvm] 7ea597e - [FuncSpec] Consider constant struct arguments when specializing.
Alexandros Lamprineas via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 17 10:40:10 PDT 2023
Author: Alexandros Lamprineas
Date: 2023-04-17T18:39:21+01:00
New Revision: 7ea597ead9fc9773ee5a5f5595fa4a61a68cc90f
URL: https://github.com/llvm/llvm-project/commit/7ea597ead9fc9773ee5a5f5595fa4a61a68cc90f
DIFF: https://github.com/llvm/llvm-project/commit/7ea597ead9fc9773ee5a5f5595fa4a61a68cc90f.diff
LOG: [FuncSpec] Consider constant struct arguments when specializing.
Optionally enabled just like integer and floating point arguments.
Differential Revision: https://reviews.llvm.org/D145374
Added:
llvm/test/Transforms/FunctionSpecialization/constant-struct.ll
Modified:
llvm/include/llvm/Transforms/Utils/SCCPSolver.h
llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
llvm/lib/Transforms/Utils/SCCPSolver.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
index 633052f1c15e7..3512f0a867b91 100644
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -168,17 +168,18 @@ class SCCPSolver {
/// range with a single element.
Constant *getConstant(const ValueLatticeElement &LV) const;
+ /// Return the constant the solver has deduced for \p V, or nullptr if \p V
+ /// is overdefined. For a struct-typed \p V a ConstantStruct is built element
+ /// by element, and nullptr is returned if any element is overdefined.
+ /// Values (or struct elements) whose lattice state is still unknown/undef
+ /// are represented by UndefValue.
+ Constant *getConstantOrNull(Value *V) const;
+
/// Return a reference to the set of argument tracked functions.
SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions();
- /// Mark the constant arguments of a new function specialization. \p F points
- /// to the cloned function and \p Args contains a list of constant arguments
- /// represented as pairs of {formal,actual} values (the formal argument is
- /// associated with the original function definition). All other arguments of
- /// the specialization inherit the lattice state of their corresponding values
- /// in the original function.
- void markArgInFuncSpecialization(Function *F,
- const SmallVectorImpl<ArgInfo> &Args);
+ /// Set the Lattice Value for the arguments of a specialization \p F.
+ /// \p Args lists the specialized arguments as {formal, actual} pairs, where
+ /// each formal argument belongs to the original function; \p Args is
+ /// expected to list them in argument order. An argument of \p F whose
+ /// counterpart in the original function appears in \p Args has its lattice
+ /// value marked constant with the paired actual argument. Every other
+ /// argument inherits (copies) the lattice value of the corresponding
+ /// argument of the original function.
+ void setLatticeValueForSpecializationArguments(Function *F,
+ const SmallVectorImpl<ArgInfo> &Args);
/// Mark all of the blocks in function \p F non-executable. Clients can used
/// this method to erase a function from the module (e.g., if it has been
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 8f5878feab60f..4fc609e4405e4 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -536,7 +536,7 @@ Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &
// Initialize the lattice state of the arguments of the function clone,
// marking the argument on which we specialized the function constant
// with the given value.
- Solver.markArgInFuncSpecialization(Clone, S.Args);
+ Solver.setLatticeValueForSpecializationArguments(Clone, S.Args);
Solver.addArgumentTrackedFunction(Clone);
Solver.markBlockExecutable(&Clone->front());
@@ -666,16 +666,9 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
if (A->user_empty())
return false;
- // For now, don't attempt to specialize functions based on the values of
- // composite types.
- Type *ArgTy = A->getType();
- if (!ArgTy->isSingleValueType())
- return false;
-
- // Specialization of integer and floating point types needs to be explicitly
- // enabled.
- if (!SpecializeLiteralConstant &&
- (ArgTy->isIntegerTy() || ArgTy->isFloatingPointTy()))
+ Type *Ty = A->getType();
+ if (!Ty->isPointerTy() && (!SpecializeLiteralConstant ||
+ (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isStructTy())))
return false;
// SCCP solver does not record an argument that will be constructed on
@@ -686,21 +679,22 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
// Check the lattice value and decide if we should attemt to specialize,
// based on this argument. No point in specialization, if the lattice value
// is already a constant.
- const ValueLatticeElement &LV = Solver.getLatticeValueFor(A);
- if (LV.isUnknownOrUndef() || LV.isConstant() ||
- (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) {
- LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter "
- << A->getNameOrAsOperand() << " is already constant\n");
- return false;
- }
-
- LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter "
- << A->getNameOrAsOperand() << "\n");
+ bool IsOverdefined = Ty->isStructTy()
+ ? any_of(Solver.getStructLatticeValueFor(A), SCCPSolver::isOverdefined)
+ : SCCPSolver::isOverdefined(Solver.getLatticeValueFor(A));
- return true;
+ LLVM_DEBUG(
+ if (IsOverdefined)
+ dbgs() << "FnSpecialization: Found interesting parameter "
+ << A->getNameOrAsOperand() << "\n";
+ else
+ dbgs() << "FnSpecialization: Nothing to do, parameter "
+ << A->getNameOrAsOperand() << " is already constant\n";
+ );
+ return IsOverdefined;
}
-/// Check if the valuy \p V (an actual argument) is a constant or can only
+/// Check if the value \p V (an actual argument) is a constant or can only
/// have a constant value. Return that constant.
Constant *FunctionSpecializer::getCandidateConstant(Value *V) {
if (isa<PoisonValue>(V))
@@ -720,18 +714,8 @@ Constant *FunctionSpecializer::getCandidateConstant(Value *V) {
// Select for possible specialisation values that are constants or
// are deduced to be constants or constant ranges with a single element.
Constant *C = dyn_cast<Constant>(V);
- if (!C) {
- const ValueLatticeElement &LV = Solver.getLatticeValueFor(V);
- if (LV.isConstant())
- C = LV.getConstant();
- else if (LV.isConstantRange() && LV.getConstantRange().isSingleElement()) {
- assert(V->getType()->isIntegerTy() && "Non-integral constant range");
- C = Constant::getIntegerValue(V->getType(),
- *LV.getConstantRange().getSingleElement());
- } else
- return nullptr;
- }
-
+ if (!C)
+ C = Solver.getConstantOrNull(V);
return C;
}
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index 3dd155ba3c9b2..cc86ab26471a6 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -73,30 +73,9 @@ static bool canRemoveInstruction(Instruction *I) {
}
bool SCCPSolver::tryToReplaceWithConstant(Value *V) {
- Constant *Const = nullptr;
- if (V->getType()->isStructTy()) {
- std::vector<ValueLatticeElement> IVs = getStructLatticeValueFor(V);
- if (llvm::any_of(IVs, isOverdefined))
- return false;
- std::vector<Constant *> ConstVals;
- auto *ST = cast<StructType>(V->getType());
- for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
- ValueLatticeElement V = IVs[i];
- ConstVals.push_back(SCCPSolver::isConstant(V)
- ? getConstant(V)
- : UndefValue::get(ST->getElementType(i)));
- }
- Const = ConstantStruct::get(ST, ConstVals);
- } else {
- const ValueLatticeElement &IV = getLatticeValueFor(V);
- if (isOverdefined(IV))
- return false;
-
- Const = SCCPSolver::isConstant(IV) ? getConstant(IV)
- : UndefValue::get(V->getType());
- }
- assert(Const && "Constant is nullptr here!");
-
+ Constant *Const = getConstantOrNull(V);
+ if (!Const)
+ return false;
// Replacing `musttail` instructions with constant breaks `musttail` invariant
// unless the call itself can be removed.
// Calls with "clang.arc.attachedcall" implicitly use the return value and
@@ -734,12 +713,14 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
Constant *getConstant(const ValueLatticeElement &LV) const;
+ Constant *getConstantOrNull(Value *V) const;
+
SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions() {
return TrackingIncomingArguments;
}
- void markArgInFuncSpecialization(Function *F,
- const SmallVectorImpl<ArgInfo> &Args);
+ void setLatticeValueForSpecializationArguments(Function *F,
+ const SmallVectorImpl<ArgInfo> &Args);
void markFunctionUnreachable(Function *F) {
for (auto &BB : *F)
@@ -833,36 +814,68 @@ Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const {
return nullptr;
}
-void SCCPInstVisitor::markArgInFuncSpecialization(
- Function *F, const SmallVectorImpl<ArgInfo> &Args) {
+/// Return the constant the solver has deduced for \p V, or nullptr if \p V
+/// (or, for a struct, any of its elements) is overdefined. Lattice values
+/// that are still unknown/undef are materialized as UndefValue.
+Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
+ Constant *Const = nullptr;
+ if (V->getType()->isStructTy()) {
+ std::vector<ValueLatticeElement> LVs = getStructLatticeValueFor(V);
+ if (any_of(LVs, SCCPSolver::isOverdefined))
+ return nullptr;
+ auto *ST = cast<StructType>(V->getType());
+ std::vector<Constant *> ConstVals;
+ ConstVals.reserve(ST->getNumElements());
+ for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) {
+ // Bind by reference: a ValueLatticeElement can hold a ConstantRange,
+ // so copying it per element is not free.
+ const ValueLatticeElement &LV = LVs[I];
+ ConstVals.push_back(SCCPSolver::isConstant(LV)
+ ? getConstant(LV)
+ : UndefValue::get(ST->getElementType(I)));
+ }
+ Const = ConstantStruct::get(ST, ConstVals);
+ } else {
+ const ValueLatticeElement &LV = getLatticeValueFor(V);
+ if (SCCPSolver::isOverdefined(LV))
+ return nullptr;
+ Const = SCCPSolver::isConstant(LV) ? getConstant(LV)
+ : UndefValue::get(V->getType());
+ }
+ assert(Const && "Constant is nullptr here!");
+ return Const;
+}
+
+/// Initialize the lattice state of the arguments of the specialization \p F.
+/// Arguments listed in \p Args are marked constant with their actual value
+/// (element-wise for structs); all others copy the state of the matching
+/// argument of the original function.
+void SCCPInstVisitor::setLatticeValueForSpecializationArguments(Function *F,
+ const SmallVectorImpl<ArgInfo> &Args) {
assert(!Args.empty() && "Specialization without arguments");
assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() &&
"Functions should have the same number of arguments");
auto Iter = Args.begin();
- Argument *NewArg = F->arg_begin();
- Argument *OldArg = Args[0].Formal->getParent()->arg_begin();
+ Function::arg_iterator NewArg = F->arg_begin();
+ Function::arg_iterator OldArg = Args[0].Formal->getParent()->arg_begin();
for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) {
LLVM_DEBUG(dbgs() << "SCCP: Marking argument "
<< NewArg->getNameOrAsOperand() << "\n");
- if (Iter != Args.end() && OldArg == Iter->Formal) {
- // Mark the argument constants in the new function.
- markConstant(NewArg, Iter->Actual);
+ // Mark the argument constants in the new function
+ // or copy the lattice state over from the old function.
+ if (Iter != Args.end() && Iter->Formal == &*OldArg) {
+ if (auto *STy = dyn_cast<StructType>(NewArg->getType())) {
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+ ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}];
+ NewValue.markConstant(Iter->Actual->getAggregateElement(I));
+ }
+ } else {
+ ValueState[&*NewArg].markConstant(Iter->Actual);
+ }
++Iter;
- } else if (ValueState.count(OldArg)) {
- // For the remaining arguments in the new function, copy the lattice state
- // over from the old function.
- //
- // Note: This previously looked like this:
- // ValueState[NewArg] = ValueState[OldArg];
- // This is incorrect because the DenseMap class may resize the underlying
- // memory when inserting `NewArg`, which will invalidate the reference to
- // `OldArg`. Instead, we make sure `NewArg` exists before setting it.
- auto &NewValue = ValueState[NewArg];
- NewValue = ValueState[OldArg];
- pushToWorkList(NewValue, NewArg);
+ } else {
+ // Copy the lattice state over from the old function. OldArg is not
+ // guaranteed to have an entry yet, and DenseMap::operator[] may grow
+ // the map on insertion, invalidating any reference obtained earlier.
+ // So read the old state by value with lookup() (which never inserts)
+ // before calling operator[] for the new argument.
+ if (auto *STy = dyn_cast<StructType>(NewArg->getType())) {
+ for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+ ValueLatticeElement OldValue = StructValueState.lookup({&*OldArg, I});
+ StructValueState[{&*NewArg, I}] = OldValue;
+ }
+ } else {
+ ValueLatticeElement OldValue = ValueState.lookup(&*OldArg);
+ ValueState[&*NewArg] = OldValue;
+ }
}
}
}
@@ -1945,13 +1958,17 @@ Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const {
return Visitor->getConstant(LV);
}
+/// Forwards to SCCPInstVisitor::getConstantOrNull().
+Constant *SCCPSolver::getConstantOrNull(Value *V) const {
+ return Visitor->getConstantOrNull(V);
+}
+
SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() {
return Visitor->getArgumentTrackedFunctions();
}
-void SCCPSolver::markArgInFuncSpecialization(
- Function *F, const SmallVectorImpl<ArgInfo> &Args) {
- Visitor->markArgInFuncSpecialization(F, Args);
+/// Forwards to SCCPInstVisitor::setLatticeValueForSpecializationArguments().
+void SCCPSolver::setLatticeValueForSpecializationArguments(Function *F,
+ const SmallVectorImpl<ArgInfo> &Args) {
+ Visitor->setLatticeValueForSpecializationArguments(F, Args);
}
void SCCPSolver::markFunctionUnreachable(Function *F) {
diff --git a/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll
new file mode 100644
index 0000000000000..6c3bfaef49b0a
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll
@@ -0,0 +1,46 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+
+; RUN: opt -passes="ipsccp<func-spec>" -force-specialization \
+; RUN: -funcspec-for-literal-constant -S < %s | FileCheck %s
+
+define i32 @foo(i32 %y0, i32 %y1) {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[Y:%.*]] = insertvalue { i32, i32 } undef, i32 [[Y0:%.*]], 0
+; CHECK-NEXT: [[YY:%.*]] = insertvalue { i32, i32 } [[Y]], i32 [[Y1:%.*]], 1
+; CHECK-NEXT: [[CALL:%.*]] = tail call i32 @add.1({ i32, i32 } { i32 2, i32 3 }, { i32, i32 } [[YY]])
+; CHECK-NEXT: ret i32 [[CALL]]
+;
+entry:
+ %y = insertvalue { i32, i32 } undef, i32 %y0, 0
+ %yy = insertvalue { i32, i32 } %y, i32 %y1, 1
+ %call = tail call i32 @add({i32, i32} {i32 2, i32 3}, {i32, i32} %yy)
+ ret i32 %call
+}
+
+define i32 @bar(i32 %x0, i32 %x1) {
+; CHECK-LABEL: @bar(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[X:%.*]] = insertvalue { i32, i32 } undef, i32 [[X0:%.*]], 0
+; CHECK-NEXT: [[XX:%.*]] = insertvalue { i32, i32 } [[X]], i32 [[X1:%.*]], 1
+; CHECK-NEXT: [[CALL:%.*]] = tail call i32 @add.2({ i32, i32 } [[XX]], { i32, i32 } { i32 3, i32 2 })
+; CHECK-NEXT: ret i32 [[CALL]]
+;
+entry:
+ %x = insertvalue { i32, i32 } undef, i32 %x0, 0
+ %xx = insertvalue { i32, i32 } %x, i32 %x1, 1
+ %call = tail call i32 @add({i32, i32} %xx, {i32, i32} {i32 3, i32 2})
+ ret i32 %call
+}
+
+define internal i32 @add({i32, i32} %x, {i32, i32} %y) {
+entry:
+ %x0 = extractvalue {i32, i32} %x, 0
+ %y0 = extractvalue {i32, i32} %y, 0
+ %add0 = add nsw i32 %x0, %y0
+ %x1 = extractvalue {i32, i32} %x, 1
+ %y1 = extractvalue {i32, i32} %y, 1
+ %add1 = add nsw i32 %x1, %y1
+ %mul = mul i32 %add0, %add1
+ ret i32 %mul
+}
More information about the llvm-commits
mailing list