[llvm] e6b9fc4 - [FuncSpec] Global ranking of specialisations

Momchil Velikov via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 14 07:35:20 PST 2022


Author: Momchil Velikov
Date: 2022-12-14T15:30:32Z
New Revision: e6b9fc4c8be00660c5e1f1605a6a5d92475bdba7

URL: https://github.com/llvm/llvm-project/commit/e6b9fc4c8be00660c5e1f1605a6a5d92475bdba7
DIFF: https://github.com/llvm/llvm-project/commit/e6b9fc4c8be00660c5e1f1605a6a5d92475bdba7.diff

LOG: [FuncSpec] Global ranking of specialisations

The `FunctionSpecialization` pass chooses specializations among the
opportunities presented by a single function and its calls,
progressively penalizing subsequent specialization attempts by
artificially increasing the cost of a specialization, depending on how
many specializations were applied before. Thus the chosen
specializations are sensitive to the order in which the functions
appear in the module and may be worse than others, had those others
been considered earlier.

This patch makes the `FunctionSpecialization` pass rank the
specializations globally, i.e. choose the "best" specializations
among all the possible specializations in the module, for all
functions.

Since this involved quite a bit of redesign of the pass data
structures, this patch also carries:

  * removal of duplicate specializations

  * optimization of the call site update, by collecting, per
    specialization, the list of call sites that can be rewritten
    directly, without a prior expensive check of whether the call
    constants and their positions match those of the specialized
    function.

A bit of a write-up about the FuncSpec data structures and
operation:

Each potential function specialisation is kept in a single vector
(`AllSpecs` in `FunctionSpecializer::run`).  This vector is populated
by `FunctionSpecializer::findSpecializations`.

The `findSpecializations` member function has a local `DenseMap` to
eliminate duplicates - with each call to the current function,
`findSpecializations` builds a specialisation signature (`SpecSig`)
and looks it up in the duplicates map. If the signature is present,
the function records the call to rewrite into the existing
specialisation instance.  If the signature is absent, it means we have
a new specialisation instance - the function calculates the gain and
creates a new entry in `AllSpecs`. Specialisations with a non-positive
gain are discarded at this point, unless forced.
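
The de-duplication step above can be sketched with the standard
library alone. This is an illustration under assumptions, not the code
from the patch: `Sig`, `SigHash`, `SpecEntry`, `Call` and `findSpecs`
are hypothetical stand-ins for `SpecSig`, its hashing, `Spec`, a call
site and `findSpecializations`; `std::unordered_map` replaces
`DenseMap`; constants are plain strings; and the special handling of
recursive calls and of forced specialisation is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for SpecSig: (argument number, constant) pairs.
using Sig = std::vector<std::pair<unsigned, std::string>>;

// Simple combining hash over all pairs of a signature.
struct SigHash {
  std::size_t operator()(const Sig &S) const {
    std::size_t H = 0;
    for (const auto &[ArgNo, Const] : S) {
      H = H * 31 + std::hash<unsigned>{}(ArgNo);
      H = H * 31 + std::hash<std::string>{}(Const);
    }
    return H;
  }
};

// Hypothetical stand-in for Spec.
struct SpecEntry {
  Sig Signature;
  long Gain;
  std::vector<int> CallSites; // ids of calls that can be rewritten directly
};

// Hypothetical call site: its id, its constant operands, and the bonus
// that specialising on those operands would bring.
struct Call {
  int Id;
  Sig Signature;
  long Bonus;
};

// Collect the unique, profitable specialisations for one function.
std::vector<SpecEntry> findSpecs(const std::vector<Call> &Calls, long Cost) {
  std::vector<SpecEntry> AllSpecs;
  std::unordered_map<Sig, std::size_t, SigHash> Seen; // signature -> index
  for (const Call &C : Calls) {
    assert(!C.Signature.empty() && "a call needs at least one constant");
    if (auto It = Seen.find(C.Signature); It != Seen.end()) {
      // Known signature: only remember the call site to rewrite.
      AllSpecs[It->second].CallSites.push_back(C.Id);
      continue;
    }
    long Gain = C.Bonus - Cost;
    if (Gain <= 0)
      continue; // Discard unprofitable specialisations.
    Seen[C.Signature] = AllSpecs.size();
    AllSpecs.push_back(SpecEntry{C.Signature, Gain, {C.Id}});
  }
  return AllSpecs;
}
```

Two calls with the same constants in the same positions end up sharing
one entry, whose `CallSites` list then holds both call ids.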

The potential specialisations for a function form a contiguous range
in the `AllSpecs` [1]. This range is recorded in `SpecMap SM`, so we
can quickly find all specialisations for a function.
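
A small sketch of how such a range map can be maintained while entries
are appended, assuming (as in the patch) that all entries of one
function are added back to back. `RangeMap` and `recordSpec` are
hypothetical names, functions are identified by name, and `std::map`
stands in for the `DenseMap` behind `SpecMap`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Range = std::pair<unsigned, unsigned>; // [begin, end) into AllSpecs
using RangeMap = std::map<std::string, Range>;

// Append one specialisation of Func and extend the range recorded for it.
void recordSpec(std::vector<std::string> &AllSpecs, RangeMap &SM,
                const std::string &Func) {
  const unsigned Index = static_cast<unsigned>(AllSpecs.size());
  AllSpecs.push_back(Func);
  // First entry for Func: the range is [Index, Index + 1).
  auto [It, Inserted] = SM.try_emplace(Func, Index, Index + 1);
  if (!Inserted) {
    // Later entries must directly follow the previous ones.
    assert(It->second.second == Index && "range must stay contiguous");
    It->second.second = Index + 1;
  }
}
```

Looking up a function then yields the half-open index range of its
entries, which is what `updateCallSites` receives as `Begin` and `End`.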

Once we have all the potential specialisations with their gains we
need to choose the best ones, which fit in the module specialisation
budget. This is done by using a max-heap (`std::make_heap`,
`std::push_heap`, etc) to find the best `NSpec` specialisations with a
single traversal of the `AllSpecs` vector. The heap itself is
contained in a small vector (`BestSpecs`) of indices into
`AllSpecs`, since elements of `AllSpecs` are a bit too heavy to
shuffle around.
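
The selection can be sketched as follows, using only the standard
library. `selectBest`, `Gains` and `Budget` are hypothetical names:
`Gains[I]` plays the role of `AllSpecs[I].Gain` and `Budget` that of
the module specialisation budget. Note that `std::make_heap` builds a
max-heap with respect to the comparator; since the comparator here is
`>`, the heap root is the lowest-gain element among those kept, which
is exactly the one to evict when a better candidate arrives.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Return the indices of the (at most) Budget highest gains, in heap order.
std::vector<unsigned> selectBest(const std::vector<long> &Gains,
                                 unsigned Budget) {
  const unsigned N =
      std::min<unsigned>(Budget, static_cast<unsigned>(Gains.size()));
  auto CompareGain = [&Gains](unsigned I, unsigned J) {
    return Gains[I] > Gains[J];
  };

  // One extra slot at the end serves as scratch space for the candidate.
  std::vector<unsigned> Best(N + 1);
  std::iota(Best.begin(), Best.begin() + N, 0u);

  if (Gains.size() > N) {
    std::make_heap(Best.begin(), Best.begin() + N, CompareGain);
    for (unsigned I = N; I < Gains.size(); ++I) {
      Best[N] = I;
      // Add the candidate, then move the lowest-gain index to the end.
      std::push_heap(Best.begin(), Best.end(), CompareGain);
      std::pop_heap(Best.begin(), Best.end(), CompareGain);
    }
  }

  Best.resize(N); // Drop the scratch slot.
  return Best;
}
```

Only indices are moved around, so the cost of each heap operation does
not depend on the size of the elements being ranked.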

Next the chosen specialisations are performed, that is, the functions
are cloned, the `SCCPSolver` is primed, and the known call sites are
updated.

Then we run the `SCCPSolver` to propagate constants in the cloned
functions, after which we walk the calls of the original functions to
update them to call the specialised functions.
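
For the final step, a call is matched against the specialisations of
its callee and redirected to the applied one with the highest gain
whose constants all agree with the call. A rough sketch of that
matching, with hypothetical names (`Candidate`, `CallConstants`,
`findBestMatch`) and constants modelled as strings:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for Spec, reduced to what the matching needs.
struct Candidate {
  bool HasClone;                        // was this specialisation applied?
  long Gain;
  std::map<unsigned, std::string> Args; // argument number -> constant
};

// Constant operands known at a call: argument number -> constant.
using CallConstants = std::map<unsigned, std::string>;

// Return the applied specialisation with the highest gain that matches
// the call, or nullptr if none does.
const Candidate *findBestMatch(const std::vector<Candidate> &Specs,
                               const CallConstants &Call) {
  const Candidate *Best = nullptr;
  for (const Candidate &S : Specs) {
    // Skip discarded entries and entries that cannot beat the current best.
    if (!S.HasClone || (Best && S.Gain <= Best->Gain))
      continue;
    bool Matches = true;
    for (const auto &[ArgNo, Const] : S.Args) {
      auto It = Call.find(ArgNo);
      if (It == Call.end() || It->second != Const) {
        Matches = false;
        break;
      }
    }
    if (Matches)
      Best = &S;
  }
  return Best;
}
```

A null result means the call keeps targeting the original function.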

---

[1] This range may contain specialisations that were discarded and is
not ordered in any way. One alternative design is to keep a vector of
indices of all specialisations for this function (which would
initially be `i`, `i+1`, `i+2`, etc) and later sort them by gain,
pushing non-applied ones to the back. This has the potential to speed
`updateCallSites` up.

Reviewed By: ChuanqiXu, labrinea

Differential Revision: https://reviews.llvm.org/D139346

Change-Id: I708851eb38f07c42066637085b833ca91b195998

Added: 
    llvm/test/Transforms/FunctionSpecialization/global-rank.ll

Modified: 
    llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
    llvm/include/llvm/Transforms/Utils/SCCPSolver.h
    llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
    llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll
    llvm/test/Transforms/FunctionSpecialization/specialization-order.ll
    llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
index bd278257f9314..8db246f739ab1 100644
--- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
+++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
@@ -60,18 +60,57 @@
 using namespace llvm;
 
 namespace llvm {
-// Bookkeeping struct to pass data from the analysis and profitability phase
-// to the actual transform helper functions.
-struct SpecializationInfo {
-  SmallVector<ArgInfo, 8> Args; // Stores the {formal,actual} argument pairs.
-  InstructionCost Gain;         // Profitability: Gain = Bonus - Cost.
-  Function *Clone;              // The definition of the specialized function.
+// Specialization signature, used to uniquely designate a specialization within
+// a function.
+struct SpecSig {
+  // Hashing support, used to distinguish between ordinary, empty, or tombstone
+  // keys.
+  unsigned Key = 0;
+  SmallVector<ArgInfo, 4> Args;
+
+  bool operator==(const SpecSig &Other) const {
+    if (Key != Other.Key || Args.size() != Other.Args.size())
+      return false;
+    for (size_t I = 0; I < Args.size(); ++I)
+      if (Args[I] != Other.Args[I])
+        return false;
+    return true;
+  }
+
+  friend hash_code hash_value(const SpecSig &S) {
+    return hash_combine(hash_value(S.Key),
+                        hash_combine_range(S.Args.begin(), S.Args.end()));
+  }
+};
+
+// Specialization instance.
+struct Spec {
+  // Original function.
+  Function *F;
+
+  // Cloned function, a specialized version of the original one.
+  Function *Clone = nullptr;
+
+  // Specialization signature.
+  SpecSig Sig;
+
+  // Profitability of the specialization.
+  InstructionCost Gain;
+
+  // List of call sites, matching this specialization.
+  SmallVector<CallBase *> CallSites;
+
+  Spec(Function *F, const SpecSig &S, InstructionCost G)
+      : F(F), Sig(S), Gain(G) {}
+  Spec(Function *F, const SpecSig &&S, InstructionCost G)
+      : F(F), Sig(S), Gain(G) {}
 };
 
-using CallSpecBinding = std::pair<CallBase *, SpecializationInfo>;
-// We are using MapVector because it guarantees deterministic iteration
-// order across executions.
-using SpecializationMap = SmallMapVector<CallBase *, SpecializationInfo, 8>;
+// Map of potential specializations for each function. The FunctionSpecializer
+// keeps the discovered specialisation opportunities for the module in a single
+// vector, where the specialisations of each function form a contiguous range.
+// This map's value is the beginning and the end of that range.
+using SpecMap = DenseMap<Function *, std::pair<unsigned, unsigned>>;
 
 class FunctionSpecializer {
 
@@ -137,18 +176,23 @@ class FunctionSpecializer {
   // Compute the code metrics for function \p F.
   CodeMetrics &analyzeFunction(Function *F);
 
-  /// This function decides whether it's worthwhile to specialize function
-  /// \p F based on the known constant values its arguments can take on. It
-  /// only discovers potential specialization opportunities without actually
-  /// applying them.
-  ///
-  /// \returns true if any specializations have been found.
+  /// @brief  Find potential specialization opportunities.
+  /// @param F Function to specialize
+  /// @param Cost Cost of specializing a function. Final gain is this cost
+  /// minus benefit
+  /// @param AllSpecs A vector to add potential specializations to.
+  /// @param SM  A map for a function's specialisation range
+  /// @return True, if any potential specializations were found
   bool findSpecializations(Function *F, InstructionCost Cost,
-                           SmallVectorImpl<CallSpecBinding> &WorkList);
+                           SmallVectorImpl<Spec> &AllSpecs, SpecMap &SM);
 
   bool isCandidateFunction(Function *F);
 
-  Function *createSpecialization(Function *F, CallSpecBinding &Specialization);
+  /// @brief Create a specialization of \p F and prime the SCCPSolver
+  /// @param F Function to specialize
+  /// @param S Which specialization to create
+  /// @return The new, cloned function
+  Function *createSpecialization(Function *F, const SpecSig &S);
 
   /// Compute and return the cost of specializing function \p F.
   InstructionCost getSpecializationCost(Function *F);
@@ -165,9 +209,11 @@ class FunctionSpecializer {
   /// have a constant value. Return that constant.
   Constant *getCandidateConstant(Value *V);
 
-  /// Redirects callsites of function \p F to its specialized copies.
-  void updateCallSites(Function *F,
-                       SmallVectorImpl<CallSpecBinding> &Specializations);
+  /// @brief Find and update calls to \p F, which match a specialization
+  /// @param F Orginal function
+  /// @param Begin Start of a range of possibly matching specialisations
+  /// @param End End of a range (exclusive) of possibly matching specialisations
+  void updateCallSites(Function *F, const Spec *Begin, const Spec *End);
 };
 } // namespace llvm
 

diff  --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
index 9e73425b07fbf..d5c99afafe30b 100644
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -52,7 +52,17 @@ struct ArgInfo {
   Argument *Formal; // The Formal argument being analysed.
   Constant *Actual; // A corresponding actual constant argument.
 
-  ArgInfo(Argument *F, Constant *A) : Formal(F), Actual(A){};
+  ArgInfo(Argument *F, Constant *A) : Formal(F), Actual(A) {}
+
+  bool operator==(const ArgInfo &Other) const {
+    return Formal == Other.Formal && Actual == Other.Actual;
+  }
+
+  bool operator!=(const ArgInfo &Other) const { return !(*this == Other); }
+
+  friend hash_code hash_value(const ArgInfo &A) {
+    return hash_combine(hash_value(A.Formal), hash_value(A.Actual));
+  }
 };
 
 class SCCPInstVisitor;

diff  --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 84ddeebde853d..c2f6c7eced2ea 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -234,51 +234,132 @@ void FunctionSpecializer::cleanUpSSA() {
     removeSSACopy(*F);
 }
 
+
+template <> struct llvm::DenseMapInfo<SpecSig> {
+  static inline SpecSig getEmptyKey() { return {~0U, {}}; }
+
+  static inline SpecSig getTombstoneKey() { return {~1U, {}}; }
+
+  static unsigned getHashValue(const SpecSig &S) {
+    return static_cast<unsigned>(hash_value(S));
+  }
+
+  static bool isEqual(const SpecSig &LHS, const SpecSig &RHS) {
+    return LHS == RHS;
+  }
+};
+
 /// Attempt to specialize functions in the module to enable constant
 /// propagation across function boundaries.
 ///
 /// \returns true if at least one function is specialized.
 bool FunctionSpecializer::run() {
-  bool Changed = false;
-
+  // Find possible specializations for each function.
+  SpecMap SM;
+  SmallVector<Spec, 32> AllSpecs;
+  unsigned NumCandidates = 0;
   for (Function &F : M) {
     if (!isCandidateFunction(&F))
       continue;
 
     auto Cost = getSpecializationCost(&F);
     if (!Cost.isValid()) {
-      LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialization cost.\n");
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialization cost for "
+                        << F.getName() << "\n");
       continue;
     }
 
     LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
                       << F.getName() << " is " << Cost << "\n");
 
-    SmallVector<CallSpecBinding, 8> Specializations;
-    if (!findSpecializations(&F, Cost, Specializations)) {
+    if (!findSpecializations(&F, Cost, AllSpecs, SM)) {
       LLVM_DEBUG(
-          dbgs() << "FnSpecialization: No possible specializations found\n");
+          dbgs() << "FnSpecialization: No possible specializations found for "
+                 << F.getName() << "\n");
       continue;
     }
 
-    Changed = true;
+    ++NumCandidates;
+  }
+
+  if (!NumCandidates) {
+    LLVM_DEBUG(
+        dbgs()
+        << "FnSpecialization: No possible specializations found in module\n");
+    return false;
+  }
+
+  // Choose the most profitable specialisations, which fit in the module
+  // specialization budget, which is derived from maximum number of
+  // specializations per specialization candidate function.
+  auto CompareGain = [&AllSpecs](unsigned I, unsigned J) {
+    return AllSpecs[I].Gain > AllSpecs[J].Gain;
+  };
+  const unsigned NSpecs =
+      std::min(NumCandidates * MaxClonesThreshold, unsigned(AllSpecs.size()));
+  SmallVector<unsigned> BestSpecs(NSpecs + 1);
+  std::iota(BestSpecs.begin(), BestSpecs.begin() + NSpecs, 0);
+  if (AllSpecs.size() > NSpecs) {
+    LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed "
+                      << "the maximum number of clones threshold.\n"
+                      << "FnSpecialization: Specializing the "
+                      << NSpecs
+                      << " most profitable candidates.\n");
+    std::make_heap(BestSpecs.begin(), BestSpecs.begin() + NSpecs, CompareGain);
+    for (unsigned I = NSpecs, N = AllSpecs.size(); I < N; ++I) {
+      BestSpecs[NSpecs] = I;
+      std::push_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain);
+      std::pop_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain);
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "FnSpecialization: List of specializations \n";
+             for (unsigned I = 0; I < NSpecs; ++I) {
+               const Spec &S = AllSpecs[BestSpecs[I]];
+               dbgs() << "FnSpecialization: Function " << S.F->getName()
+                      << " , gain " << S.Gain << "\n";
+               for (const ArgInfo &Arg : S.Sig.Args)
+                 dbgs() << "FnSpecialization:   FormalArg = "
+                        << Arg.Formal->getNameOrAsOperand()
+                        << ", ActualArg = " << Arg.Actual->getNameOrAsOperand()
+                        << "\n";
+             });
+
+  // Create the chosen specializations.
+  SmallPtrSet<Function *, 8> OriginalFuncs;
+  SmallVector<Function *> Clones;
+  for (unsigned I = 0; I < NSpecs; ++I) {
+    Spec &S = AllSpecs[BestSpecs[I]];
+    S.Clone = createSpecialization(S.F, S.Sig);
+
+    // Update the known call sites to call the clone.
+    for (CallBase *Call : S.CallSites) {
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call
+                        << " to call " << S.Clone->getName() << "\n");
+      Call->setCalledFunction(S.Clone);
+    }
+
+    Clones.push_back(S.Clone);
+    OriginalFuncs.insert(S.F);
+  }
 
-    SmallVector<Function *, 4> Clones;
-    for (CallSpecBinding &Specialization : Specializations)
-      Clones.push_back(createSpecialization(&F, Specialization));
+  Solver.solveWhileResolvedUndefsIn(Clones);
 
-    Solver.solveWhileResolvedUndefsIn(Clones);
-    updateCallSites(&F, Specializations);
+  // Update the rest of the call sites - these are the recursive calls, calls
+  // to discarded specialisations and calls that may match a specialisation
+  // after the solver runs.
+  for (Function *F : OriginalFuncs) {
+    auto [Begin, End] = SM[F];
+    updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End);
   }
 
   promoteConstantStackValues();
-
   LLVM_DEBUG(if (NbFunctionsSpecialized) dbgs()
              << "FnSpecialization: Specialized " << NbFunctionsSpecialized
              << " functions in module " << M.getName() << "\n");
 
   NumFuncSpecialized += NbFunctionsSpecialized;
-  return Changed;
+  return true;
 }
 
 void FunctionSpecializer::removeDeadFunctions() {
@@ -319,32 +400,30 @@ static Function *cloneCandidateFunction(Function *F) {
   return Clone;
 }
 
-/// This function decides whether it's worthwhile to specialize function
-/// \p F based on the known constant values its arguments can take on. It
-/// only discovers potential specialization opportunities without actually
-/// applying them.
-///
-/// \returns true if any specializations have been found.
-bool FunctionSpecializer::findSpecializations(
-    Function *F, InstructionCost Cost,
-    SmallVectorImpl<CallSpecBinding> &WorkList) {
+bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost,
+                                              SmallVectorImpl<Spec> &AllSpecs,
+                                              SpecMap &SM) {
+  // A mapping from a specialisation signature to the index of the respective
+  // entry in the all specialisation array. Used to ensure uniqueness of
+  // specialisations.
+  DenseMap<SpecSig, unsigned> UM;
+
   // Get a list of interesting arguments.
-  SmallVector<Argument *, 4> Args;
+  SmallVector<Argument *> Args;
   for (Argument &Arg : F->args())
     if (isArgumentInteresting(&Arg))
       Args.push_back(&Arg);
 
-  if (!Args.size())
+  if (Args.empty())
     return false;
 
-  // Find all the call sites for the function.
-  SpecializationMap Specializations;
+  bool Found = false;
   for (User *U : F->users()) {
     if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
       continue;
     auto &CS = *cast<CallBase>(U);
 
-    // Skip irrelevant users.
+    // The user instruction does not call our function.
     if (CS.getCalledFunction() != F)
       continue;
 
@@ -358,62 +437,58 @@ bool FunctionSpecializer::findSpecializations(
     if (!Solver.isBlockExecutable(CS.getParent()))
       continue;
 
-    // Examine arguments and create specialization candidates from call sites
-    // with constant arguments.
-    bool Added = false;
+    // Examine arguments and create a specialisation candidate from the
+    // constant operands of this call site.
+    SpecSig S;
     for (Argument *A : Args) {
       Constant *C = getCandidateConstant(CS.getArgOperand(A->getArgNo()));
       if (!C)
         continue;
-
-      if (!Added) {
-        Specializations[&CS] = {{}, 0 - Cost, nullptr};
-        Added = true;
-      }
-
-      SpecializationInfo &S = Specializations.back().second;
-      S.Gain += getSpecializationBonus(A, C, Solver.getLoopInfo(*F));
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument "
+                        << A->getName() << " : " << C->getNameOrAsOperand()
+                        << "\n");
       S.Args.push_back({A, C});
     }
-    Added = false;
-  }
 
-  // Remove unprofitable specializations.
-  if (!ForceFunctionSpecialization)
-    Specializations.remove_if(
-        [](const auto &Entry) { return Entry.second.Gain <= 0; });
-
-  // Clear the MapVector and return the underlying vector.
-  WorkList = Specializations.takeVector();
+    if (S.Args.empty())
+      continue;
 
-  // Sort the candidates in descending order.
-  llvm::stable_sort(WorkList, [](const auto &L, const auto &R) {
-    return L.second.Gain > R.second.Gain;
-  });
+    // Check if we have encountered the same specialisation already.
+    if (auto It = UM.find(S); It != UM.end()) {
+      // Existing specialisation. Add the call to the list to rewrite, unless
+      // it's a recursive call. A specialisation, generated because of a
+      // recursive call may end up as not the best specialisation for all
+      // the cloned instances of this call, which result from specialising
+      // functions. Hence we don't rewrite the call directly, but match it with
+      // the best specialisation once all specialisations are known.
+      if (CS.getFunction() == F)
+        continue;
+      const unsigned Index = It->second;
+      AllSpecs[Index].CallSites.push_back(&CS);
+    } else {
+      // Calculate the specialisation gain.
+      InstructionCost Gain = 0 - Cost;
+      for (ArgInfo &A : S.Args)
+        Gain +=
+            getSpecializationBonus(A.Formal, A.Actual, Solver.getLoopInfo(*F));
+
+      // Discard unprofitable specialisations.
+      if (!ForceFunctionSpecialization && Gain <= 0)
+        continue;
 
-  // Truncate the worklist to 'MaxClonesThreshold' candidates if necessary.
-  if (WorkList.size() > MaxClonesThreshold) {
-    LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed "
-                      << "the maximum number of clones threshold.\n"
-                      << "FnSpecialization: Truncating worklist to "
-                      << MaxClonesThreshold << " candidates.\n");
-    WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
+      // Create a new specialisation entry.
+      auto &Spec = AllSpecs.emplace_back(F, S, Gain);
+      if (CS.getFunction() != F)
+        Spec.CallSites.push_back(&CS);
+      const unsigned Index = AllSpecs.size() - 1;
+      UM[S] = Index;
+      if (auto [It, Inserted] = SM.try_emplace(F, Index, Index + 1); !Inserted)
+        It->second.second = Index + 1;
+      Found = true;
+    }
   }
 
-  LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
-                    << F->getName() << "\n";
-             for (const auto &Entry
-                  : WorkList) {
-               dbgs() << "FnSpecialization:   Gain = " << Entry.second.Gain
-                      << "\n";
-               for (const ArgInfo &Arg : Entry.second.Args)
-                 dbgs() << "FnSpecialization:   FormalArg = "
-                        << Arg.Formal->getNameOrAsOperand()
-                        << ", ActualArg = " << Arg.Actual->getNameOrAsOperand()
-                        << "\n";
-             });
-
-  return !WorkList.empty();
+  return Found;
 }
 
 bool FunctionSpecializer::isCandidateFunction(Function *F) {
@@ -449,16 +524,13 @@ bool FunctionSpecializer::isCandidateFunction(Function *F) {
   return true;
 }
 
-Function *
-FunctionSpecializer::createSpecialization(Function *F,
-                                          CallSpecBinding &Specialization) {
+Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) {
   Function *Clone = cloneCandidateFunction(F);
-  Specialization.second.Clone = Clone;
 
   // Initialize the lattice state of the arguments of the function clone,
   // marking the argument on which we specialized the function constant
   // with the given value.
-  Solver.markArgInFuncSpecialization(Clone, Specialization.second.Args);
+  Solver.markArgInFuncSpecialization(Clone, S.Args);
 
   Solver.addArgumentTrackedFunction(Clone);
   Solver.markBlockExecutable(&Clone->front());
@@ -484,9 +556,8 @@ InstructionCost FunctionSpecializer::getSpecializationCost(Function *F) {
     return InstructionCost::getInvalid();
 
   // Otherwise, set the specialization cost to be the cost of all the
-  // instructions in the function and penalty for specializing more functions.
-  unsigned Penalty = NbFunctionsSpecialized + 1;
-  return Metrics.NumInsts * InlineConstants::getInstrCost() * Penalty;
+  // instructions in the function.
+  return Metrics.NumInsts * InlineConstants::getInstrCost();
 }
 
 static InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
@@ -611,11 +682,14 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
   const ValueLatticeElement &LV = Solver.getLatticeValueFor(A);
   if (LV.isUnknownOrUndef() || LV.isConstant() ||
       (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) {
-    LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, argument "
+    LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter "
                       << A->getNameOrAsOperand() << " is already constant\n");
     return false;
   }
 
+  LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter "
+                    << A->getNameOrAsOperand() << "\n");
+
   return true;
 }
 
@@ -651,44 +725,45 @@ Constant *FunctionSpecializer::getCandidateConstant(Value *V) {
       return nullptr;
   }
 
-  LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument "
-                    << V->getNameOrAsOperand() << "\n");
-
   return C;
 }
 
-/// Redirects callsites of function \p F to its specialized copies.
-void FunctionSpecializer::updateCallSites(
-    Function *F, SmallVectorImpl<CallSpecBinding> &Specializations) {
-  SmallVector<CallBase *, 8> ToUpdate;
-  for (User *U : F->users()) {
-    if (auto *CS = dyn_cast<CallBase>(U))
-      if (CS->getCalledFunction() == F &&
-          Solver.isBlockExecutable(CS->getParent()))
-        ToUpdate.push_back(CS);
-  }
+void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin,
+                                          const Spec *End) {
+  // Collect the call sites that need updating.
+  SmallVector<CallBase *> ToUpdate;
+  for (User *U : F->users())
+    if (auto *CS = dyn_cast<CallBase>(U);
+        CS && CS->getCalledFunction() == F &&
+        Solver.isBlockExecutable(CS->getParent()))
+      ToUpdate.push_back(CS);
 
   unsigned NCallsLeft = ToUpdate.size();
   for (CallBase *CS : ToUpdate) {
-    // Decrement the counter if the callsite is either recursive or updated.
     bool ShouldDecrementCount = CS->getFunction() == F;
-    for (CallSpecBinding &Specialization : Specializations) {
-      Function *Clone = Specialization.second.Clone;
-      SmallVectorImpl<ArgInfo> &Args = Specialization.second.Args;
 
-      if (any_of(Args, [CS, this](const ArgInfo &Arg) {
+    // Find the best matching specialisation.
+    const Spec *BestSpec = nullptr;
+    for (const Spec &S : make_range(Begin, End)) {
+      if (!S.Clone || (BestSpec && S.Gain <= BestSpec->Gain))
+        continue;
+
+      if (any_of(S.Sig.Args, [CS, this](const ArgInfo &Arg) {
             unsigned ArgNo = Arg.Formal->getArgNo();
             return getCandidateConstant(CS->getArgOperand(ArgNo)) != Arg.Actual;
           }))
         continue;
 
-      LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call site " << *CS
-                        << " with " << Clone->getName() << "\n");
+      BestSpec = &S;
+    }
 
-      CS->setCalledFunction(Clone);
+    if (BestSpec) {
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *CS
+                        << " to call " << BestSpec->Clone->getName() << "\n");
+      CS->setCalledFunction(BestSpec->Clone);
       ShouldDecrementCount = true;
-      break;
     }
+
     if (ShouldDecrementCount)
       --NCallsLeft;
   }

diff  --git a/llvm/test/Transforms/FunctionSpecialization/global-rank.ll b/llvm/test/Transforms/FunctionSpecialization/global-rank.ll
new file mode 100644
index 0000000000000..db799d6ce9238
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/global-rank.ll
@@ -0,0 +1,51 @@
+; RUN: opt -S --passes=ipsccp  -specialize-functions -func-specialization-max-clones=1 < %s | FileCheck %s
+define internal i32 @f(i32 noundef %x, ptr nocapture noundef readonly %p, ptr nocapture noundef readonly %q) noinline {
+entry:
+  %call = tail call i32 %p(i32 noundef %x)
+  %call1 = tail call i32 %q(i32 noundef %x)
+  %add = add nsw i32 %call1, %call
+  ret i32 %add
+}
+
+define internal i32 @g(i32 noundef %x, ptr nocapture noundef readonly %p, ptr nocapture noundef readonly %q) noinline {
+entry:
+  %call = tail call i32 %p(i32 noundef %x)
+  %call1 = tail call i32 %q(i32 noundef %x)
+  %sub = sub nsw i32 %call, %call1
+  ret i32 %sub
+}
+
+define i32 @h0(i32 noundef %x) {
+entry:
+  %call = tail call i32 @f(i32 noundef %x, ptr noundef nonnull @pp, ptr noundef nonnull @qq)
+  ret i32 %call
+}
+
+define i32 @h1(i32 noundef %x) {
+entry:
+  %call = tail call i32 @f(i32 noundef %x, ptr noundef nonnull @qq, ptr noundef nonnull @pp)
+  ret i32 %call
+}
+
+define i32 @h2(i32 noundef %x, ptr nocapture noundef readonly %p) {
+entry:
+  %call = tail call i32 @g(i32 noundef %x, ptr noundef %p, ptr noundef nonnull @pp)
+  ret i32 %call
+}
+
+define i32 @h3(i32 noundef %x, ptr nocapture noundef readonly %p) {
+entry:
+  %call = tail call i32 @g(i32 noundef %x, ptr noundef %p, ptr noundef nonnull @qq)
+  ret i32 %call
+}
+
+declare i32 @pp(i32 noundef)
+declare i32 @qq(i32 noundef)
+
+
+; Check that the global ranking causes two specialisations of
+; `f` to be chosen, whereas the old algorithm would choose
+; one specialsation of `f` and one of `g`.
+
+; CHECK-DAG: define internal i32 @f.1
+; CHECK-DAG: define internal i32 @f.2

diff  --git a/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll b/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll
index 8ceb251908dc6..ff8341a584069 100644
--- a/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/identical-specializations.ll
@@ -6,14 +6,14 @@ define i64 @main(i64 %x, i64 %y, i1 %flag) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[FLAG:%.*]], label [[PLUS:%.*]], label [[MINUS:%.*]]
 ; CHECK:       plus:
-; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @compute.1(i64 [[X:%.*]], i64 [[Y:%.*]], ptr @plus, ptr @minus)
+; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @compute.2(i64 [[X:%.*]], i64 [[Y:%.*]], ptr @plus, ptr @minus)
 ; CHECK-NEXT:    br label [[MERGE:%.*]]
 ; CHECK:       minus:
-; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @compute.2(i64 [[X]], i64 [[Y]], ptr @minus, ptr @plus)
+; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @compute.3(i64 [[X]], i64 [[Y]], ptr @minus, ptr @plus)
 ; CHECK-NEXT:    br label [[MERGE]]
 ; CHECK:       merge:
 ; CHECK-NEXT:    [[PH:%.*]] = phi i64 [ [[CMP0]], [[PLUS]] ], [ [[CMP1]], [[MINUS]] ]
-; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute.1(i64 [[PH]], i64 42, ptr @plus, ptr @minus)
+; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute.2(i64 [[PH]], i64 42, ptr @plus, ptr @minus)
 ; CHECK-NEXT:    ret i64 [[CMP2]]
 ;
 entry:
@@ -62,18 +62,18 @@ entry:
 
 ; CHECK-LABEL: @compute.1
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @plus(i64 [[X:%.*]], i64 [[Y:%.*]])
-; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @minus(i64 [[X]], i64 [[Y]])
-; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute(i64 [[X]], i64 [[Y]], ptr @plus, ptr @plus)
+; CHECK-NEXT:    [[CMP0:%.*]] = call i64 %binop1(i64 [[X:%.*]], i64 [[Y:%.*]])
+; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @plus(i64 [[X]], i64 [[Y]])
+; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute.1(i64 [[X]], i64 [[Y]], ptr %binop1, ptr @plus)
 
 ; CHECK-LABEL: @compute.2
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @minus(i64 [[X:%.*]], i64 [[Y:%.*]])
-; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @plus(i64 [[X]], i64 [[Y]])
-; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute.2(i64 [[X]], i64 [[Y]], ptr @minus, ptr @plus)
+; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @plus(i64 [[X:%.*]], i64 [[Y:%.*]])
+; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @minus(i64 [[X]], i64 [[Y]])
+; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute.1(i64 [[X]], i64 [[Y]], ptr @plus, ptr @plus)
 
 ; CHECK-LABEL: @compute.3
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @plus(i64 [[X:%.*]], i64 [[Y:%.*]])
-; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @minus(i64 [[X]], i64 [[Y]])
-; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute(i64 [[X]], i64 [[Y]], ptr @plus, ptr @plus)
+; CHECK-NEXT:    [[CMP0:%.*]] = call i64 @minus(i64 [[X:%.*]], i64 [[Y:%.*]])
+; CHECK-NEXT:    [[CMP1:%.*]] = call i64 @plus(i64 [[X]], i64 [[Y]])
+; CHECK-NEXT:    [[CMP2:%.*]] = call i64 @compute.3(i64 [[X]], i64 [[Y]], ptr @minus, ptr @plus)

diff  --git a/llvm/test/Transforms/FunctionSpecialization/specialization-order.ll b/llvm/test/Transforms/FunctionSpecialization/specialization-order.ll
index 83b1ca19c5e5c..0ce526a6efe75 100644
--- a/llvm/test/Transforms/FunctionSpecialization/specialization-order.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/specialization-order.ll
@@ -21,7 +21,7 @@ entry:
 
 define dso_local i32 @g0(i32 %x, i32 %y) {
 ; CHECK-LABEL: @g0
-; CHECK:       call i32 @f.2(i32 [[X:%.*]], i32 [[Y:%.*]])
+; CHECK:       call i32 @f.3(i32 [[X:%.*]], i32 [[Y:%.*]])
 entry:
   %call = tail call i32 @f(i32 %x, i32 %y, ptr @add, ptr @add)
   ret i32 %call
@@ -30,7 +30,7 @@ entry:
 
 define dso_local i32 @g1(i32 %x, i32 %y) {
 ; CHECK-LABEL: @g1(
-; CHECK:       call i32 @f.1(i32 [[X:%.*]], i32 [[Y:%.*]])
+; CHECK:       call i32 @f.2(i32 [[X:%.*]], i32 [[Y:%.*]])
 entry:
   %call = tail call i32 @f(i32 %x, i32 %y, ptr @sub, ptr @add)
   ret i32 %call
@@ -38,7 +38,7 @@ entry:
 
 define dso_local i32 @g2(i32 %x, i32 %y, ptr %v) {
 ; CHECK-LABEL @g2
-; CHECK       call i32 @f.3(i32 [[X:%.*]], i32 [[Y:%.*]], ptr [[V:%.*]])
+; CHECK       call i32 @f.1(i32 [[X:%.*]], i32 [[Y:%.*]], ptr [[V:%.*]])
 entry:
   %call = tail call i32 @f(i32 %x, i32 %y, ptr @sub, ptr %v)
   ret i32 %call
@@ -46,13 +46,13 @@ entry:
 
 ; CHECK-LABEL: define {{.*}} i32 @f.1
 ; CHECK:       call i32 @sub(i32 %x, i32 %y)
-; CHECK-NEXT:  call i32 @add(i32 %x, i32 %y)
+; CHECK-NEXT:  call i32 %v(i32 %x, i32 %y)
 
 ; CHECK-LABEL: define {{.*}} i32 @f.2
-; CHECK:       call i32 @add(i32 %x, i32 %y)
-; CHECK-NEXT   call i32 @add(i32 %x, i32 %y)
+; CHECK:       call i32 @sub(i32 %x, i32 %y)
+; CHECK-NEXT:  call i32 @add(i32 %x, i32 %y)
 
 ; CHECK-LABEL: define {{.*}} i32 @f.3
-; CHECK:       call i32 @sub(i32 %x, i32 %y)
-; CHECK-NEXT:  call i32 %v(i32 %x, i32 %y)
+; CHECK:       call i32 @add(i32 %x, i32 %y)
+; CHECK-NEXT   call i32 @add(i32 %x, i32 %y)
 

diff  --git a/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll b/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll
index f576764281e4b..6ef938739262a 100644
--- a/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/specialize-multiple-arguments.ll
@@ -52,25 +52,25 @@ define i64 @main(i64 %x, i64 %y, i1 %flag) {
 ; TWO-NEXT:    [[TMP0:%.*]] = call i64 @compute(i64 [[X:%.*]], i64 [[Y:%.*]], ptr @power, ptr @mul)
 ; TWO-NEXT:    br label [[MERGE:%.*]]
 ; TWO:       minus:
-; TWO-NEXT:    [[TMP1:%.*]] = call i64 @compute.1(i64 [[X]], i64 [[Y]], ptr @plus, ptr @minus)
+; TWO-NEXT:    [[TMP1:%.*]] = call i64 @compute.2(i64 [[X]], i64 [[Y]], ptr @plus, ptr @minus)
 ; TWO-NEXT:    br label [[MERGE]]
 ; TWO:       merge:
 ; TWO-NEXT:    [[TMP2:%.*]] = phi i64 [ [[TMP0]], [[PLUS]] ], [ [[TMP1]], [[MINUS]] ]
-; TWO-NEXT:    [[TMP3:%.*]] = call i64 @compute.2(i64 [[TMP2]], i64 42, ptr @minus, ptr @power)
+; TWO-NEXT:    [[TMP3:%.*]] = call i64 @compute.1(i64 [[TMP2]], i64 42, ptr @minus, ptr @power)
 ; TWO-NEXT:    ret i64 [[TMP3]]
 ;
 ; THREE-LABEL: @main(
 ; THREE-NEXT:  entry:
 ; THREE-NEXT:    br i1 [[FLAG:%.*]], label [[PLUS:%.*]], label [[MINUS:%.*]]
 ; THREE:       plus:
-; THREE-NEXT:    [[TMP0:%.*]] = call i64 @compute.3(i64 [[X:%.*]], i64 [[Y:%.*]], ptr @power, ptr @mul)
+; THREE-NEXT:    [[TMP0:%.*]] = call i64 @compute.1(i64 [[X:%.*]], i64 [[Y:%.*]], ptr @power, ptr @mul)
 ; THREE-NEXT:    br label [[MERGE:%.*]]
 ; THREE:       minus:
-; THREE-NEXT:    [[TMP1:%.*]] = call i64 @compute.1(i64 [[X]], i64 [[Y]], ptr @plus, ptr @minus)
+; THREE-NEXT:    [[TMP1:%.*]] = call i64 @compute.2(i64 [[X]], i64 [[Y]], ptr @plus, ptr @minus)
 ; THREE-NEXT:    br label [[MERGE]]
 ; THREE:       merge:
 ; THREE-NEXT:    [[TMP2:%.*]] = phi i64 [ [[TMP0]], [[PLUS]] ], [ [[TMP1]], [[MINUS]] ]
-; THREE-NEXT:    [[TMP3:%.*]] = call i64 @compute.2(i64 [[TMP2]], i64 42, ptr @minus, ptr @power)
+; THREE-NEXT:    [[TMP3:%.*]] = call i64 @compute.3(i64 [[TMP2]], i64 42, ptr @minus, ptr @power)
 ; THREE-NEXT:    ret i64 [[TMP3]]
 ;
 entry:
@@ -94,8 +94,8 @@ merge:
 ;
 ; THREE-LABEL: define internal i64 @compute.1(i64 %x, i64 %y, ptr %binop1, ptr %binop2) {
 ; THREE-NEXT:  entry:
-; THREE-NEXT:    [[TMP0:%.+]] = call i64 @plus(i64 %x, i64 %y)
-; THREE-NEXT:    [[TMP1:%.+]] = call i64 @minus(i64 %x, i64 %y)
+; THREE-NEXT:    [[TMP0:%.+]] = call i64 @power(i64 %x, i64 %y)
+; THREE-NEXT:    [[TMP1:%.+]] = call i64 @mul(i64 %x, i64 %y)
 ; THREE-NEXT:    [[TMP2:%.+]] = add i64 [[TMP0]], [[TMP1]]
 ; THREE-NEXT:    [[TMP3:%.+]] = sdiv i64 [[TMP2]], %x
 ; THREE-NEXT:    [[TMP4:%.+]] = sub i64 [[TMP3]], %y
@@ -105,8 +105,8 @@ merge:
 ;
 ; THREE-LABEL: define internal i64 @compute.2(i64 %x, i64 %y, ptr %binop1, ptr %binop2) {
 ; THREE-NEXT:  entry:
-; THREE-NEXT:    [[TMP0:%.+]] = call i64 @minus(i64 %x, i64 %y)
-; THREE-NEXT:    [[TMP1:%.+]] = call i64 @power(i64 %x, i64 %y)
+; THREE-NEXT:    [[TMP0:%.+]] = call i64 @plus(i64 %x, i64 %y)
+; THREE-NEXT:    [[TMP1:%.+]] = call i64 @minus(i64 %x, i64 %y)
 ; THREE-NEXT:    [[TMP2:%.+]] = add i64 [[TMP0]], [[TMP1]]
 ; THREE-NEXT:    [[TMP3:%.+]] = sdiv i64 [[TMP2]], %x
 ; THREE-NEXT:    [[TMP4:%.+]] = sub i64 [[TMP3]], %y
@@ -116,8 +116,8 @@ merge:
 ;
 ; THREE-LABEL: define internal i64 @compute.3(i64 %x, i64 %y, ptr %binop1, ptr %binop2) {
 ; THREE-NEXT:  entry:
-; THREE-NEXT:    [[TMP0:%.+]] = call i64 @power(i64 %x, i64 %y)
-; THREE-NEXT:    [[TMP1:%.+]] = call i64 @mul(i64 %x, i64 %y)
+; THREE-NEXT:    [[TMP0:%.+]] = call i64 @minus(i64 %x, i64 %y)
+; THREE-NEXT:    [[TMP1:%.+]] = call i64 @power(i64 %x, i64 %y)
 ; THREE-NEXT:    [[TMP2:%.+]] = add i64 [[TMP0]], [[TMP1]]
 ; THREE-NEXT:    [[TMP3:%.+]] = sdiv i64 [[TMP2]], %x
 ; THREE-NEXT:    [[TMP4:%.+]] = sub i64 [[TMP3]], %y


        

