[llvm] [GlobalOpt][FMV] Fix static resolution of calls. (PR #160011)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Sep 21 13:34:32 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Alexandros Lamprineas (labrinea)
<details>
<summary>Changes</summary>
Addresses the issues found during the review of
https://github.com/llvm/llvm-project/pull/150267/files#r2356936355
Currently, when collecting the users of an IFunc symbol to determine its callers, we incorrectly mix versions of different functions together with non-FMV callers, all in the same bag. That is problematic because we then incorrectly deduce which features are unavailable as we iterate over the callers.
I have updated the unit tests to require a resolver function for the callers, and regenerated the resolvers, since some FMV features have been removed, making the detection bitmasks different. I've replaced the deleted FMV feature ls64 with cssc. I've added a new test to cover unrelated callers.
---
Patch is 34.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160011.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/IPO/GlobalOpt.cpp (+115-75)
- (modified) llvm/test/Transforms/GlobalOpt/resolve-fmv-ifunc.ll (+301-54)
``````````diff
diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
index f88d51f443bcf..0707eb5eacf5d 100644
--- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
@@ -2482,20 +2482,21 @@ DeleteDeadIFuncs(Module &M,
// Follows the use-def chain of \p V backwards until it finds a Function,
// in which case it collects in \p Versions. Return true on successful
// use-def chain traversal, false otherwise.
-static bool collectVersions(TargetTransformInfo &TTI, Value *V,
- SmallVectorImpl<Function *> &Versions) {
+static bool
+collectVersions(Value *V, SmallVectorImpl<Function *> &Versions,
+ function_ref<TargetTransformInfo &(Function &)> GetTTI) {
if (auto *F = dyn_cast<Function>(V)) {
- if (!TTI.isMultiversionedFunction(*F))
+ if (!GetTTI(*F).isMultiversionedFunction(*F))
return false;
Versions.push_back(F);
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
- if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
+ if (!collectVersions(Sel->getTrueValue(), Versions, GetTTI))
return false;
- if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
+ if (!collectVersions(Sel->getFalseValue(), Versions, GetTTI))
return false;
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
- if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
+ if (!collectVersions(Phi->getIncomingValue(I), Versions, GetTTI))
return false;
} else {
// Unknown instruction type. Bail.
@@ -2525,8 +2526,14 @@ static bool OptimizeNonTrivialIFuncs(
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
bool Changed = false;
- // Cache containing the mask constructed from a function's target features.
+ // Map containing the feature bits for a given function.
DenseMap<Function *, APInt> FeatureMask;
+ // Map containing all the versions corresponding to an IFunc symbol.
+ DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
+ // Map containing the IFunc symbol a function is version of.
+ DenseMap<Function *, GlobalIFunc *> VersionOf;
+ // List of all the interesting IFuncs found in the module.
+ SmallVector<GlobalIFunc *> IFuncs;
for (GlobalIFunc &IF : M.ifuncs()) {
if (IF.isInterposable())
@@ -2539,107 +2546,140 @@ static bool OptimizeNonTrivialIFuncs(
if (Resolver->isInterposable())
continue;
- TargetTransformInfo &TTI = GetTTI(*Resolver);
-
- // Discover the callee versions.
- SmallVector<Function *> Callees;
- if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
+ SmallVector<Function *> Versions;
+ // Discover the versioned functions.
+ if (any_of(*Resolver, [&](BasicBlock &BB) {
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
- if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
+ if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
return true;
return false;
}))
continue;
- if (Callees.empty())
+ if (Versions.empty())
continue;
- LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
- << Resolver->getName() << "\n");
-
- // Cache the feature mask for each callee.
- for (Function *Callee : Callees) {
- auto [It, Inserted] = FeatureMask.try_emplace(Callee);
+ for (Function *V : Versions) {
+ VersionOf.insert({V, &IF});
+ auto [It, Inserted] = FeatureMask.try_emplace(V);
if (Inserted)
- It->second = TTI.getFeatureMask(*Callee);
+ It->second = GetTTI(*V).getFeatureMask(*V);
}
- // Sort the callee versions in decreasing priority order.
- sort(Callees, [&](auto *LHS, auto *RHS) {
+ // Sort function versions in decreasing priority order.
+ sort(Versions, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});
- // Find the callsites and cache the feature mask for each caller.
- SmallVector<Function *> Callers;
+ IFuncs.push_back(&IF);
+ VersionedFuncs.try_emplace(&IF, std::move(Versions));
+ }
+
+ for (GlobalIFunc *CalleeIF : IFuncs) {
+ SmallVector<Function *> NonFMVCallers;
+ SmallVector<GlobalIFunc *> CallerIFuncs;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
- for (User *U : IF.users()) {
+
+ // Find the callsites.
+ for (User *U : CalleeIF->users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
- if (CB->getCalledOperand() == &IF) {
+ if (CB->getCalledOperand() == CalleeIF) {
Function *Caller = CB->getFunction();
- auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
- if (FeatInserted)
- FeatIt->second = TTI.getFeatureMask(*Caller);
- auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
- if (CallInserted)
- Callers.push_back(Caller);
- CallIt->second.push_back(CB);
+ GlobalIFunc *CallerIFunc = nullptr;
+ TargetTransformInfo &TTI = GetTTI(*Caller);
+ bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
+ // The caller is a version of a known IFunc.
+ if (auto It = VersionOf.find(Caller); It != VersionOf.end())
+ CallerIFunc = It->second;
+ else if (!CallerIsFMV && OptimizeNonFMVCallers) {
+ // The caller is non-FMV.
+ auto [It, Inserted] = FeatureMask.try_emplace(Caller);
+ if (Inserted)
+ It->second = TTI.getFeatureMask(*Caller);
+ } else
+ // The caller is none of the above, skip.
+ continue;
+ auto [It, Inserted] = CallSites.try_emplace(Caller);
+ if (Inserted) {
+ if (CallerIsFMV)
+ CallerIFuncs.push_back(CallerIFunc);
+ else
+ NonFMVCallers.push_back(Caller);
+ }
+ It->second.push_back(CB);
}
}
}
- // Sort the caller versions in decreasing priority order.
- sort(Callers, [&](auto *LHS, auto *RHS) {
- return FeatureMask[LHS].ugt(FeatureMask[RHS]);
- });
-
- auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };
+ LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
+ << CalleeIF->getResolverFunction()->getName() << "\n");
- // Index to the highest priority candidate.
- unsigned I = 0;
- // Now try to redirect calls starting from higher priority callers.
- for (Function *Caller : Callers) {
- assert(I < Callees.size() && "Found callers of equal priority");
+ auto redirectCalls = [&](SmallVectorImpl<Function *> &Callers,
+ SmallVectorImpl<Function *> &Callees) {
+ // Index to the current callee candidate.
+ unsigned I = 0;
- Function *Callee = Callees[I];
- APInt CallerBits = FeatureMask[Caller];
- APInt CalleeBits = FeatureMask[Callee];
+ // Try to redirect calls starting from higher priority callers.
+ for (Function *Caller : Callers) {
+ if (I == Callees.size())
+ break;
- // In the case of FMV callers, we know that all higher priority callers
- // than the current one did not get selected at runtime, which helps
- // reason about the callees (if they have versions that mandate presence
- // of the features which we already know are unavailable on this target).
- if (TTI.isMultiversionedFunction(*Caller)) {
+ bool CallerIsFMV = GetTTI(*Caller).isMultiversionedFunction(*Caller);
+ // In the case of FMV callers, we know that all higher priority callers
+ // than the current one did not get selected at runtime, which helps
+ // reason about the callees (if they have versions that mandate presence
+ // of the features which we already know are unavailable on this
+ // target).
+ if (!CallerIsFMV)
+ // We can't reason much about non-FMV callers. Just pick the highest
+ // priority callee if it matches, otherwise bail.
+ assert(I == 0 && "Should only select the highest priority candidate");
+
+ Function *Callee = Callees[I];
+ APInt CallerBits = FeatureMask[Caller];
+ APInt CalleeBits = FeatureMask[Callee];
// If the feature set of the caller implies the feature set of the
- // highest priority candidate then it shall be picked. In case of
- // identical sets advance the candidate index one position.
- if (CallerBits == CalleeBits)
- ++I;
- else if (!implies(CallerBits, CalleeBits)) {
- // Keep advancing the candidate index as long as the caller's
- // features are a subset of the current candidate's.
- while (implies(CalleeBits, CallerBits)) {
+ // highest priority candidate then it shall be picked.
+ if (CalleeBits.isSubsetOf(CallerBits)) {
+ // If there are no records of call sites for this particular function
+ // version, then it is not actually a caller, in which case skip.
+ if (auto It = CallSites.find(Caller); It != CallSites.end()) {
+ for (CallBase *CS : It->second) {
+ LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName()
+ << " -> " << Callee->getName() << "\n");
+ CS->setCalledOperand(Callee);
+ }
+ Changed = true;
+ }
+ }
+ // Keep advancing the candidate index as long as the caller's
+ // features are a subset of the current candidate's.
+ if (CallerIsFMV) {
+ while (CallerBits.isSubsetOf(CalleeBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
}
- continue;
}
- } else {
- // We can't reason much about non-FMV callers. Just pick the highest
- // priority callee if it matches, otherwise bail.
- if (!OptimizeNonFMVCallers || I > 0 || !implies(CallerBits, CalleeBits))
- continue;
}
- auto &Calls = CallSites[Caller];
- for (CallBase *CS : Calls) {
- LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName() << " -> "
- << Callee->getName() << "\n");
- CS->setCalledOperand(Callee);
+ };
+
+ auto &Callees = VersionedFuncs[CalleeIF];
+
+ // Optimize non-FMV calls.
+ if (!NonFMVCallers.empty() && OptimizeNonFMVCallers)
+ redirectCalls(NonFMVCallers, Callees);
+
+ // Optimize FMV calls.
+ if (!CallerIFuncs.empty()) {
+ for (GlobalIFunc *CallerIF : CallerIFuncs) {
+ auto &Callers = VersionedFuncs[CallerIF];
+ redirectCalls(Callers, Callees);
}
- Changed = true;
}
- if (IF.use_empty() ||
- all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
+
+ if (CalleeIF->use_empty() ||
+ all_of(CalleeIF->users(), [](User *U) { return isa<GlobalAlias>(U); }))
NumIFuncsResolved++;
}
return Changed;
diff --git a/llvm/test/Transforms/GlobalOpt/resolve-fmv-ifunc.ll b/llvm/test/Transforms/GlobalOpt/resolve-fmv-ifunc.ll
index 4b6a19d3f05cf..7ace67e3857ff 100644
--- a/llvm/test/Transforms/GlobalOpt/resolve-fmv-ifunc.ll
+++ b/llvm/test/Transforms/GlobalOpt/resolve-fmv-ifunc.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --filter "call i32 @(test_single_bb_resolver|test_multi_bb_resolver|test_caller_feats_not_implied|test_non_fmv_caller|test_priority|test_alternative_names)" --version 4
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --filter "call i32 @(test_single_bb_resolver|test_multi_bb_resolver|test_caller_feats_not_implied|test_non_fmv_caller|test_priority|test_alternative_names|test_unrelated_callers)" --version 4
; REQUIRES: aarch64-registered-target
@@ -13,6 +13,14 @@ $test_caller_feats_not_implied.resolver = comdat any
$test_non_fmv_caller.resolver = comdat any
$test_priority.resolver = comdat any
$test_alternative_names.resolver = comdat any
+$test_unrelated_callers.resolver = comdat any
+$caller1.resolver = comdat any
+$caller2.resolver = comdat any
+$caller3.resolver = comdat any
+$caller6.resolver = comdat any
+$caller7.resolver = comdat any
+$caller8.resolver = comdat any
+$caller9.resolver = comdat any
@__aarch64_cpu_features = external local_unnamed_addr global { i64 }
@@ -22,6 +30,14 @@ $test_alternative_names.resolver = comdat any
@test_non_fmv_caller = weak_odr ifunc i32 (), ptr @test_non_fmv_caller.resolver
@test_priority = weak_odr ifunc i32 (), ptr @test_priority.resolver
@test_alternative_names = weak_odr ifunc i32 (), ptr @test_alternative_names.resolver
+@test_unrelated_callers = weak_odr ifunc i32 (), ptr @test_unrelated_callers.resolver
+@caller1 = weak_odr ifunc i32 (), ptr @caller1.resolver
+@caller2 = weak_odr ifunc i32 (), ptr @caller2.resolver
+@caller3 = weak_odr ifunc i32 (), ptr @caller3.resolver
+@caller6 = weak_odr ifunc i32 (), ptr @caller6.resolver
+@caller7 = weak_odr ifunc i32 (), ptr @caller7.resolver
+@caller8 = weak_odr ifunc i32 (), ptr @caller8.resolver
+@caller9 = weak_odr ifunc i32 (), ptr @caller9.resolver
declare void @__init_cpu_features_resolver() local_unnamed_addr
@@ -34,18 +50,18 @@ define weak_odr ptr @test_single_bb_resolver.resolver() comdat {
resolver_entry:
tail call void @__init_cpu_features_resolver()
%0 = load i64, ptr @__aarch64_cpu_features, align 8
- %1 = and i64 %0, 68719476736
- %.not = icmp eq i64 %1, 0
- %2 = and i64 %0, 1073741824
- %.not3 = icmp eq i64 %2, 0
- %test_single_bb_resolver._Msve.test_single_bb_resolver.default = select i1 %.not3, ptr @test_single_bb_resolver.default, ptr @test_single_bb_resolver._Msve
- %common.ret.op = select i1 %.not, ptr %test_single_bb_resolver._Msve.test_single_bb_resolver.default, ptr @test_single_bb_resolver._Msve2
+ %1 = and i64 %0, 69793284352
+ %2 = icmp eq i64 %1, 69793284352
+ %3 = and i64 %0, 1073807616
+ %4 = icmp eq i64 %3, 1073807616
+ %test_single_bb_resolver._Msve.test_single_bb_resolver.default = select i1 %4, ptr @test_single_bb_resolver._Msve, ptr @test_single_bb_resolver.default
+ %common.ret.op = select i1 %2, ptr @test_single_bb_resolver._Msve2, ptr %test_single_bb_resolver._Msve.test_single_bb_resolver.default
ret ptr %common.ret.op
}
define i32 @caller1._Msve() #1 {
; CHECK-LABEL: define i32 @caller1._Msve(
-; CHECK-SAME: ) local_unnamed_addr #[[ATTR1:[0-9]+]] {
+; CHECK-SAME: ) #[[ATTR1:[0-9]+]] {
; CHECK: [[CALL:%.*]] = tail call i32 @test_single_bb_resolver._Msve()
;
entry:
@@ -55,7 +71,7 @@ entry:
define i32 @caller1._Msve2() #2 {
; CHECK-LABEL: define i32 @caller1._Msve2(
-; CHECK-SAME: ) local_unnamed_addr #[[ATTR2:[0-9]+]] {
+; CHECK-SAME: ) #[[ATTR2:[0-9]+]] {
; CHECK: [[CALL:%.*]] = tail call i32 @test_single_bb_resolver._Msve2()
;
entry:
@@ -65,7 +81,7 @@ entry:
define i32 @caller1.default() #0 {
; CHECK-LABEL: define i32 @caller1.default(
-; CHECK-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK: [[CALL:%.*]] = tail call i32 @test_single_bb_resolver.default()
;
entry:
@@ -73,6 +89,20 @@ entry:
ret i32 %call
}
+define weak_odr ptr @caller1.resolver() comdat {
+; CHECK-LABEL: define weak_odr ptr @caller1.resolver() comdat {
+resolver_entry:
+ tail call void @__init_cpu_features_resolver()
+ %0 = load i64, ptr @__aarch64_cpu_features, align 8
+ %1 = and i64 %0, 69793284352
+ %2 = icmp eq i64 %1, 69793284352
+ %3 = and i64 %0, 1073807616
+ %4 = icmp eq i64 %3, 1073807616
+ %caller1._Msve.caller1.default = select i1 %4, ptr @caller1._Msve, ptr @caller1.default
+ %common.ret.op = select i1 %2, ptr @caller1._Msve2, ptr %caller1._Msve.caller1.default
+ ret ptr %common.ret.op
+}
+
declare i32 @test_multi_bb_resolver._Mmops() #3
declare i32 @test_multi_bb_resolver._Msve2() #2
declare i32 @test_multi_bb_resolver._Msve() #1
@@ -92,20 +122,20 @@ common.ret: ; preds = %resolver_else2, %re
ret ptr %common.ret.op
resolver_else: ; preds = %resolver_entry
- %2 = and i64 %0, 68719476736
- %.not5 = icmp eq i64 %2, 0
- br i1 %.not5, label %resolver_else2, label %common.ret
+ %2 = and i64 %0, 69793284352
+ %3 = icmp eq i64 %2, 69793284352
+ br i1 %3, label %common.ret, label %resolver_else2
resolver_else2: ; preds = %resolver_else
- %3 = and i64 %0, 1073741824
- %.not6 = icmp eq i64 %3, 0
- %test_multi_bb_resolver._Msve.test_multi_bb_resolver.default = select i1 %.not6, ptr @test_multi_bb_resolver.default, ptr @test_multi_bb_resolver._Msve
+ %4 = and i64 %0, 1073807616
+ %5 = icmp eq i64 %4, 1073807616
+ %test_multi_bb_resolver._Msve.test_multi_bb_resolver.default = select i1 %5, ptr @test_multi_bb_resolver._Msve, ptr @test_multi_bb_resolver.default
br label %common.ret
}
define i32 @caller2._MmopsMsve2() #4 {
; CHECK-LABEL: define i32 @caller2._MmopsMsve2(
-; CHECK-SAME: ) local_unnamed_addr #[[ATTR4:[0-9]+]] {
+; CHECK-SAME: ) #[[ATTR4:[0-9]+]] {
; CHECK: [[CALL:%.*]] = tail call i32 @test_multi_bb_resolver._Mmops()
;
entry:
@@ -115,7 +145,7 @@ entry:
define i32 @caller2._Mmops() #3 {
; CHECK-LABEL: define i32 @caller2._Mmops(
-; CHECK-SAME: ) local_unnamed_addr #[[ATTR3:[0-9]+]] {
+; CHECK-SAME: ) #[[ATTR3:[0-9]+]] {
; CHECK: [[CALL:%.*]] = tail call i32 @test_multi_bb_resolver._Mmops()
;
entry:
@@ -125,7 +155,7 @@ entry:
define i32 @caller2._Msve() #1 {
; CHECK-LABEL: define i32 @caller2._Msve(
-; CHECK-SAME: ) local_unnamed_addr #[[ATTR1]] {
+; CHECK-SAME: ) #[[ATTR1]] {
; CHECK: [[CALL:%.*]] = tail call i32 @test_multi_bb_resolver()
;
entry:
@@ -135,7 +165,7 @@ entry:
define i32 @caller2.default() #0 {
; CHECK-LABEL: define i32 @caller2.default(
-; CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
+; CHECK-SAME: ) #[[ATTR0]] {
; CHECK: [[CALL:%.*]] = tail call i32 @test_multi_bb_resolver.default()
;
entry:
@@ -143,6 +173,31 @@ entry:
ret i32 %call
}
+define weak_odr ptr @caller2.resolver() comdat {
+; CHECK-LABEL: define weak_odr ptr @caller2.resolver() comdat {
+resolver_entry:
+ tail call void @__init_cpu_features_resolver()
+ %0 = load i64, ptr @__aarch64_cpu_features, align 8
+ %1 = and i64 %0, 576460822096707840
+ %2 = icmp eq i64 %1, 576460822096707840
+ br i1 %2, label %common.ret, label %resolver_else
+
+common.ret: ; preds = %resolver_else2, %resolver_else, %resolver_entry
+ %common.ret.op = phi ptr [ @caller2._MmopsMsve2, %resolver_entry ], [ @caller2._Mmops, %resolver_else ], [ %caller2._Msve.caller2.default, %resolver_else2 ]
+ ret ptr %common.ret.op
+
+resolver_else: ; preds = %resolver_entry
+ %3 = and i64 %0, 576460752303423488
+ %.not = icmp eq i64 %3, 0
+ br i1 %.not, label %resolver_else2, label %common.ret
+
+resolver_else2: ; preds = %resolver_else
+ %4 = and i64 %0, 1073807616
+ %5 = icmp eq i64 %4, 1073807616
+ %caller2._Msve.caller2.default = select i1 %5, ptr @caller2._Msve, ptr @caller2.default
+ br label %common.ret
+}
+
declare i32 @test_caller_feats_not_implied._Mmops() #3
declare i32 @test_caller_feats_not_implied._Msme() #5
declare i32 @test_caller_feats_not_implied._Msve() #1
@@ -162,20 +217,20 @@ common.ret: ; preds = %resolver_else2, %re
ret ptr %common.ret.op
resolver_else: ; preds = %resolver_entry
- %2 = and i64 %0, 4398046511104
- %.not5 = icmp eq i64 %2, 0
- br i1 %.not5, label %resolver_else2, label %common.ret
+ %2 = and i64 %0, 4398180795136
+ %3 = icmp eq i64 %2, 4398180795136
+ br i1 %3, label %common.ret, label %resolver_else2
resolver_else2: ; preds = %resolver_else
- %3 = and i64 %0, 1073741824
- %.not6 = icmp eq i64 %3, 0
- %test_caller_feats_not_implied._Msve.test_caller_feats_not_implied.default = select i1 %.not6, ptr @test_caller_feats_not_implied.default, ptr @test_caller_feats_not_implied._Msve
+ %4 = and i64 %0, 1073807616
+ %5 = icmp eq i64 %4, 1073807616
+ %test_caller_feats_n...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/160011
More information about the llvm-commits
mailing list