[llvm] [GlobalOpt][FMV] Fix static resolution of calls. (PR #160011)
Momchil Velikov via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 6 03:57:18 PST 2025
================
@@ -2545,107 +2552,148 @@ static bool OptimizeNonTrivialIFuncs(
if (Resolver->isInterposable())
continue;
- TargetTransformInfo &TTI = GetTTI(*Resolver);
-
- // Discover the callee versions.
- SmallVector<Function *> Callees;
- if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
+ SmallVector<Function *> Versions;
+ // Discover the versioned functions.
+ if (any_of(*Resolver, [&](BasicBlock &BB) {
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
- if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
+ if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
return true;
return false;
}))
continue;
- if (Callees.empty())
+ if (Versions.empty())
continue;
- LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
- << Resolver->getName() << "\n");
-
- // Cache the feature mask for each callee.
- for (Function *Callee : Callees) {
- auto [It, Inserted] = FeatureMask.try_emplace(Callee);
+ for (Function *V : Versions) {
+ VersionOf.insert({V, &IF});
+ auto [It, Inserted] = FeatureMask.try_emplace(V);
if (Inserted)
- It->second = TTI.getFeatureMask(*Callee);
+ It->second = GetTTI(*V).getFeatureMask(*V);
}
- // Sort the callee versions in decreasing priority order.
- sort(Callees, [&](auto *LHS, auto *RHS) {
+ // Sort function versions in decreasing priority order.
+ sort(Versions, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});
- // Find the callsites and cache the feature mask for each caller.
- SmallVector<Function *> Callers;
+ IFuncs.push_back(&IF);
+ VersionedFuncs.try_emplace(&IF, std::move(Versions));
+ }
+
+ for (GlobalIFunc *CalleeIF : IFuncs) {
+ SmallVector<Function *> NonFMVCallers;
+ DenseSet<GlobalIFunc *> CallerIFuncs;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
- for (User *U : IF.users()) {
+
+ // Find the callsites.
+ for (User *U : CalleeIF->users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
- if (CB->getCalledOperand() == &IF) {
+ if (CB->getCalledOperand() == CalleeIF) {
Function *Caller = CB->getFunction();
- auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
- if (FeatInserted)
- FeatIt->second = TTI.getFeatureMask(*Caller);
- auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
- if (CallInserted)
- Callers.push_back(Caller);
- CallIt->second.push_back(CB);
+ GlobalIFunc *CallerIF = nullptr;
+ TargetTransformInfo &TTI = GetTTI(*Caller);
+ bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
+ // The caller is a version of a known IFunc.
+ if (auto It = VersionOf.find(Caller); It != VersionOf.end())
+ CallerIF = It->second;
+ else if (!CallerIsFMV && OptimizeNonFMVCallers) {
+ // The caller is non-FMV.
+ auto [It, Inserted] = FeatureMask.try_emplace(Caller);
+ if (Inserted)
+ It->second = TTI.getFeatureMask(*Caller);
+ } else
+ // The caller is none of the above, skip.
+ continue;
+ auto [It, Inserted] = CallSites.try_emplace(Caller);
+ if (Inserted) {
+ if (CallerIsFMV)
+ CallerIFuncs.insert(CallerIF);
+ else
+ NonFMVCallers.push_back(Caller);
+ }
+ It->second.push_back(CB);
}
}
}
- // Sort the caller versions in decreasing priority order.
- sort(Callers, [&](auto *LHS, auto *RHS) {
- return FeatureMask[LHS].ugt(FeatureMask[RHS]);
- });
-
- auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };
+ if (CallSites.empty())
+ continue;
- // Index to the highest priority candidate.
- unsigned I = 0;
- // Now try to redirect calls starting from higher priority callers.
- for (Function *Caller : Callers) {
- assert(I < Callees.size() && "Found callers of equal priority");
+ LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
+ << CalleeIF->getResolverFunction()->getName() << "\n");
+
+ // The complexity of this algorithm is linear: O(NumCallers + NumCallees).
+ // TODO
+ // A limitation it has is that we are not using information about the
+ // current caller to deduce why an earlier caller of higher priority was
+ // skipped. For example let's say the current caller is aes+sve2 and a
+ // previous caller was mops+sve2. Knowing that sve2 is available we could
+ // infer that mops is unavailable. This would allow us to skip callee
+ // versions which depend on mops. I tried implementing this but the
+ // complexity was cubic :/
+ auto redirectCalls = [&](ArrayRef<Function *> Callers,
+ ArrayRef<Function *> Callees) {
+ // Index to the highest callee candidate.
+ unsigned I = 0;
+
+ for (Function *const &Caller : Callers) {
+ bool CallerIsFMV = GetTTI(*Caller).isMultiversionedFunction(*Caller);
----------------
momchil-velikov wrote:
Whether a caller is an FMV-caller or not is statically known at the call site, no?
Then you can just pass `false` or `true` as a parameter `CallerIsFMV`.
https://github.com/llvm/llvm-project/pull/160011
More information about the llvm-commits
mailing list