[llvm] [GlobalOpt][FMV] Fix static resolution of calls. (PR #160011)

Momchil Velikov via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 6 03:57:18 PST 2025


================
@@ -2545,107 +2552,148 @@ static bool OptimizeNonTrivialIFuncs(
     if (Resolver->isInterposable())
       continue;
 
-    TargetTransformInfo &TTI = GetTTI(*Resolver);
-
-    // Discover the callee versions.
-    SmallVector<Function *> Callees;
-    if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
+    SmallVector<Function *> Versions;
+    // Discover the versioned functions.
+    if (any_of(*Resolver, [&](BasicBlock &BB) {
           if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
-            if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
+            if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
               return true;
           return false;
         }))
       continue;
 
-    if (Callees.empty())
+    if (Versions.empty())
       continue;
 
-    LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
-                      << Resolver->getName() << "\n");
-
-    // Cache the feature mask for each callee.
-    for (Function *Callee : Callees) {
-      auto [It, Inserted] = FeatureMask.try_emplace(Callee);
+    for (Function *V : Versions) {
+      VersionOf.insert({V, &IF});
+      auto [It, Inserted] = FeatureMask.try_emplace(V);
       if (Inserted)
-        It->second = TTI.getFeatureMask(*Callee);
+        It->second = GetTTI(*V).getFeatureMask(*V);
     }
 
-    // Sort the callee versions in decreasing priority order.
-    sort(Callees, [&](auto *LHS, auto *RHS) {
+    // Sort function versions in decreasing priority order.
+    sort(Versions, [&](auto *LHS, auto *RHS) {
       return FeatureMask[LHS].ugt(FeatureMask[RHS]);
     });
 
-    // Find the callsites and cache the feature mask for each caller.
-    SmallVector<Function *> Callers;
+    IFuncs.push_back(&IF);
+    VersionedFuncs.try_emplace(&IF, std::move(Versions));
+  }
+
+  for (GlobalIFunc *CalleeIF : IFuncs) {
+    SmallVector<Function *> NonFMVCallers;
+    DenseSet<GlobalIFunc *> CallerIFuncs;
     DenseMap<Function *, SmallVector<CallBase *>> CallSites;
-    for (User *U : IF.users()) {
+
+    // Find the callsites.
+    for (User *U : CalleeIF->users()) {
       if (auto *CB = dyn_cast<CallBase>(U)) {
-        if (CB->getCalledOperand() == &IF) {
+        if (CB->getCalledOperand() == CalleeIF) {
           Function *Caller = CB->getFunction();
-          auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
-          if (FeatInserted)
-            FeatIt->second = TTI.getFeatureMask(*Caller);
-          auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
-          if (CallInserted)
-            Callers.push_back(Caller);
-          CallIt->second.push_back(CB);
+          GlobalIFunc *CallerIF = nullptr;
+          TargetTransformInfo &TTI = GetTTI(*Caller);
+          bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
+          // The caller is a version of a known IFunc.
+          if (auto It = VersionOf.find(Caller); It != VersionOf.end())
+            CallerIF = It->second;
+          else if (!CallerIsFMV && OptimizeNonFMVCallers) {
+            // The caller is non-FMV.
+            auto [It, Inserted] = FeatureMask.try_emplace(Caller);
+            if (Inserted)
+              It->second = TTI.getFeatureMask(*Caller);
+          } else
+            // The caller is none of the above, skip.
+            continue;
+          auto [It, Inserted] = CallSites.try_emplace(Caller);
+          if (Inserted) {
+            if (CallerIsFMV)
+              CallerIFuncs.insert(CallerIF);
+            else
+              NonFMVCallers.push_back(Caller);
+          }
+          It->second.push_back(CB);
         }
       }
     }
 
-    // Sort the caller versions in decreasing priority order.
-    sort(Callers, [&](auto *LHS, auto *RHS) {
-      return FeatureMask[LHS].ugt(FeatureMask[RHS]);
-    });
-
-    auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };
+    if (CallSites.empty())
+      continue;
 
-    // Index to the highest priority candidate.
-    unsigned I = 0;
-    // Now try to redirect calls starting from higher priority callers.
-    for (Function *Caller : Callers) {
-      assert(I < Callees.size() && "Found callers of equal priority");
+    LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
+                      << CalleeIF->getResolverFunction()->getName() << "\n");
+
+    // The complexity of this algorithm is linear: O(NumCallers + NumCallees).
+    // TODO
+    // A limitation it has is that we are not using information about the
+    // current caller to deduce why an earlier caller of higher priority was
+    // skipped. For example let's say the current caller is aes+sve2 and a
+    // previous caller was mops+sve2. Knowing that sve2 is available we could
+    // infer that mops is unavailable. This would allow us to skip callee
+    // versions which depend on mops. I tried implementing this but the
+    // complexity was cubic :/
+    auto redirectCalls = [&](ArrayRef<Function *> Callers,
+                             ArrayRef<Function *> Callees) {
+      // Index to the highest callee candidate.
+      unsigned I = 0;
+
+      for (Function *const &Caller : Callers) {
+        bool CallerIsFMV = GetTTI(*Caller).isMultiversionedFunction(*Caller);
+
+        LLVM_DEBUG(dbgs() << "   Examining "
+                          << (CallerIsFMV ? "FMV" : "regular") << " caller "
+                          << Caller->getName() << "\n");
+
+        if (I == Callees.size())
----------------
momchil-velikov wrote:

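Suggestion: move this `if` to the very beginning of the loop body. It depends only on `I` and `Callees`, not on `Caller`, so it can run before the per-caller `GetTTI(...).isMultiversionedFunction(...)` query and the debug output.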

https://github.com/llvm/llvm-project/pull/160011


More information about the llvm-commits mailing list