[llvm] [FMV][GlobalOpt] Statically resolve calls to versioned functions. (PR #87939)

Alexandros Lamprineas via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 10 13:35:18 PST 2025


================
@@ -2641,6 +2641,165 @@ DeleteDeadIFuncs(Module &M,
   return Changed;
 }
 
+// Follows the use-def chain of \p V backwards through selects and phis,
+// appending each multiversioned Function it reaches to \p Versions. Returns
+// true if every leaf reached is such a Function, false on the first failure.
+static bool collectVersions(TargetTransformInfo &TTI, Value *V,
+                            SmallVectorImpl<Function *> &Versions) {
+  if (auto *F = dyn_cast<Function>(V)) {
+    if (!TTI.isMultiversionedFunction(*F))
+      return false;
+    Versions.push_back(F);
+  } else if (auto *Sel = dyn_cast<SelectInst>(V)) {
+    if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
+      return false;
+    if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
+      return false;
+  } else if (auto *Phi = dyn_cast<PHINode>(V)) {
+    for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
+      if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
+        return false;
+  } else {
+    // Not a Function, SelectInst or PHINode: unsupported value kind. Bail.
+    return false;
+  }
+  return true;
+}
+
+// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
+// deduce whether the optimization is legal we need to compare the target
+// features between caller and callee versions. The criteria for bypassing
+// the resolver are the following:
+//
+// * If the callee's feature set is a subset of the caller's feature set,
+//   then the callee is a candidate for direct call.
+//
+// * Among such candidates the one of highest priority is the best match
+//   and it shall be picked, unless there is a version of the callee with
+//   higher priority than the best match which cannot be picked from a
+//   higher priority caller (directly or through the resolver).
+//
+// * For every higher priority callee version than the best match, there
+//   is a higher priority caller version whose feature set availability
+//   is implied by the callee's feature set.
+//
+static bool OptimizeNonTrivialIFuncs(
+    Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
+  bool Changed = false;
+
+  // Cache containing the mask constructed from a function's target features.
+  DenseMap<Function *, uint64_t> FeatureMask;
+
+  for (GlobalIFunc &IF : M.ifuncs()) {
+    if (IF.isInterposable())
+      continue;
+
+    Function *Resolver = IF.getResolverFunction();
+    if (!Resolver)
+      continue;
+
+    if (Resolver->isInterposable())
+      continue;
+
+    TargetTransformInfo &TTI = GetTTI(*Resolver);
+
+    // Discover the callee versions.
+    SmallVector<Function *> Callees;
+    if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
+          if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
+            if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
+              return true;
+          return false;
+        }))
+      continue;
+
+    assert(!Callees.empty() && "Expecting successful collection of versions");
+
+    // Cache the feature mask for each callee.
+    for (Function *Callee : Callees) {
+      auto [It, Inserted] = FeatureMask.try_emplace(Callee);
+      if (Inserted)
+        It->second = TTI.getFeatureMask(*Callee);
+    }
+
+    // Sort the callee versions in decreasing priority order.
+    sort(Callees, [&](auto *LHS, auto *RHS) {
+      return FeatureMask[LHS] > FeatureMask[RHS];
+    });
+
+    // Find the callsites and cache the feature mask for each caller.
+    SmallVector<Function *> Callers;
+    DenseMap<Function *, SmallVector<CallBase *>> CallSites;
+    for (User *U : IF.users()) {
+      if (auto *CB = dyn_cast<CallBase>(U)) {
+        if (CB->getCalledOperand() == &IF) {
+          Function *Caller = CB->getFunction();
+          auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
+          if (FeatInserted)
+            FeatIt->second = TTI.getFeatureMask(*Caller);
+          auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
+          if (CallInserted)
+            Callers.push_back(Caller);
+          CallIt->second.push_back(CB);
+        }
+      }
+    }
+
+    // Sort the caller versions in decreasing priority order.
+    sort(Callers, [&](auto *LHS, auto *RHS) {
+      return FeatureMask[LHS] > FeatureMask[RHS];
+    });
+
+    auto implies = [](uint64_t A, uint64_t B) { return (A & B) == B; };
+
+    // Index to the highest priority candidate.
+    unsigned I = 0;
+    // Now try to redirect calls starting from higher priority callers.
+    for (Function *Caller : Callers) {
+      assert(I < Callees.size() && "Found callers of equal priority");
+
+      Function *Callee = Callees[I];
+      uint64_t CallerBits = FeatureMask[Caller];
+      uint64_t CalleeBits = FeatureMask[Callee];
+
+      // In the case of FMV callers, we know that all higher priority callers
+      // than the current one did not get selected at runtime, which helps
+      // reason about the callees (if they have versions that mandate presence
+      // of the features which we already know are unavailable on this target).
+      if (TTI.isMultiversionedFunction(*Caller)) {
+        // If the feature set of the caller implies the feature set of the
+        // highest priority candidate then it shall be picked. In case of
+        // identical sets advance the candidate index one position.
+        if (CallerBits == CalleeBits)
+          ++I;
+        else if (!implies(CallerBits, CalleeBits)) {
+          // Keep advancing the candidate index as long as the caller's
+          // features are a subset of the current candidate's.
+          while (implies(CalleeBits, CallerBits)) {
+            if (++I == Callees.size())
+              break;
+            CalleeBits = FeatureMask[Callees[I]];
+          }
+          continue;
+        }
+      } else {
+        // We can't reason much about non-FMV callers. Just pick the highest
+        // priority callee if it matches, otherwise bail.
+        if (I > 0 || !implies(CallerBits, CalleeBits))
+          continue;
+      }
+      auto &Calls = CallSites[Caller];
+      for (CallBase *CS : Calls)
+        CS->setCalledOperand(Callee);
+      Changed = true;
+    }
+    if (IF.use_empty() ||
+        all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
----------------
labrinea wrote:

That is probably a leftover from the time we had ifunc aliases. Subject to removal.

https://github.com/llvm/llvm-project/pull/87939


More information about the llvm-commits mailing list