[llvm] 910eb98 - [FuncSpec][NFC] Refactor internal structures.

Alexandros Lamprineas via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 3 05:09:50 PST 2022


Author: Alexandros Lamprineas
Date: 2022-03-03T13:08:13Z
New Revision: 910eb988eb442fd25a0821022b696dcfb5ea6e27

URL: https://github.com/llvm/llvm-project/commit/910eb988eb442fd25a0821022b696dcfb5ea6e27
DIFF: https://github.com/llvm/llvm-project/commit/910eb988eb442fd25a0821022b696dcfb5ea6e27.diff

LOG: [FuncSpec][NFC] Refactor internal structures.

`ArgInfo` is reduced to only contain a pair of {formal,actual} values.
The specialized function `Fn` and the `Partial` flag are redundant in
this structure. The `Gain` is moved to a new struct `SpecializationInfo`.

The value mappings created by cloneCandidateFunction() are now passed to
rewriteCallSites(), which uses them to find the cloned formal argument when
matching recursive call sites inside the specialized function.

The list of specializations is passed by reference to calculateGains()
instead of being returned by value.

The `IsPartial` flag is removed from isArgumentInteresting(), and
getPossibleConstants() no longer reports whether all incoming values are
constant, as that information is no longer used anywhere in the code.

Differential Revision: https://reviews.llvm.org/D120753

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/SCCPSolver.h
    llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
    llvm/lib/Transforms/Utils/SCCPSolver.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
index 8ffb731c44b4f..fb94b1dc20b81 100644
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -43,6 +43,14 @@ struct AnalysisResultsForFn {
   PostDominatorTree *PDT;
 };
 
+/// Helper struct shared between Function Specialization and SCCP Solver.
+struct ArgInfo {
+  Argument *Formal; // The Formal argument being analysed.
+  Constant *Actual; // A corresponding actual constant argument.
+
+  ArgInfo(Argument *F, Constant *A) : Formal(F), Actual(A){};
+};
+
 class SCCPInstVisitor;
 
 //===----------------------------------------------------------------------===//
@@ -143,11 +151,13 @@ class SCCPSolver {
   /// Return a reference to the set of argument tracked functions.
   SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions();
 
-  /// Mark argument \p A constant with value \p C in a new function
-  /// specialization. The argument's parent function is a specialization of the
-  /// original function \p F. All other arguments of the specialization inherit
-  /// the lattice state of their corresponding values in the original function.
-  void markArgInFuncSpecialization(Function *F, Argument *A, Constant *C);
+  /// Mark the constant argument of a new function specialization. \p F points
+  /// to the cloned function and \p Arg represents the constant argument as a
+  /// pair of {formal,actual} values (the formal argument is associated with the
+  /// original function definition). All other arguments of the specialization
+  /// inherit the lattice state of their corresponding values in the original
+  /// function.
+  void markArgInFuncSpecialization(Function *F, const ArgInfo &Arg);
 
   /// Mark all of the blocks in function \p F non-executable. Clients can used
   /// this method to erase a function from the module (e.g., if it has been

diff  --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 5ef0161a28610..e625868fe6a83 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -112,24 +112,18 @@ static cl::opt<bool> EnableSpecializationForLiteralConstant(
 namespace {
 // Bookkeeping struct to pass data from the analysis and profitability phase
 // to the actual transform helper functions.
-struct ArgInfo {
-  Function *Fn;         // The function to perform specialisation on.
-  Argument *Formal;     // The Formal argument being analysed.
-  Constant *Actual;     // A corresponding actual constant argument.
+struct SpecializationInfo {
+  ArgInfo Arg;          // Stores the {formal,actual} argument pair.
   InstructionCost Gain; // Profitability: Gain = Bonus - Cost.
 
-  // Flag if this will be a partial specialization, in which case we will need
-  // to keep the original function around in addition to the added
-  // specializations.
-  bool Partial = false;
-
-  ArgInfo(Function *F, Argument *A, Constant *C, InstructionCost G)
-      : Fn(F), Formal(A), Actual(C), Gain(G){};
+  SpecializationInfo(Argument *A, Constant *C, InstructionCost G)
+      : Arg(A, C), Gain(G){};
 };
 } // Anonymous namespace
 
 using FuncList = SmallVectorImpl<Function *>;
-using ConstList = SmallVectorImpl<Constant *>;
+using ConstList = SmallVector<Constant *>;
+using SpecializationList = SmallVector<SpecializationInfo>;
 
 // Helper to check if \p LV is either a constant or a constant
 // range with a single element. This should cover exactly the same cases as the
@@ -316,14 +310,15 @@ class FunctionSpecializer {
       LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
                         << F->getName() << " is " << Cost << "\n");
 
-      auto ConstArgs = calculateGains(F, Cost);
-      if (ConstArgs.empty()) {
+      SpecializationList Specializations;
+      calculateGains(F, Cost, Specializations);
+      if (Specializations.empty()) {
         LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n");
         continue;
       }
 
-      for (auto &CA : ConstArgs) {
-        specializeFunction(CA, WorkList);
+      for (SpecializationInfo &S : Specializations) {
+        specializeFunction(F, S, WorkList);
         Changed = true;
       }
     }
@@ -394,9 +389,8 @@ class FunctionSpecializer {
 
   /// Clone the function \p F and remove the ssa_copy intrinsics added by
   /// the SCCPSolver in the cloned version.
-  Function *cloneCandidateFunction(Function *F) {
-    ValueToValueMapTy EmptyMap;
-    Function *Clone = CloneFunction(F, EmptyMap);
+  Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) {
+    Function *Clone = CloneFunction(F, Mappings);
     removeSSACopy(*Clone);
     return Clone;
   }
@@ -407,21 +401,16 @@ class FunctionSpecializer {
   /// profitable to specialize. Specialization is performed on the first
   /// interesting argument. Specializations based on additional arguments will
   /// be evaluated on following iterations of the main IPSCCP solve loop.
-  SmallVector<ArgInfo> calculateGains(Function *F, InstructionCost Cost) {
-    SmallVector<ArgInfo> Worklist;
+  void calculateGains(Function *F, InstructionCost Cost,
+                      SpecializationList &WorkList) {
     // Determine if we should specialize the function based on the values the
     // argument can take on. If specialization is not profitable, we continue
     // on to the next argument.
     for (Argument &FormalArg : F->args()) {
       // Determine if this argument is interesting. If we know the argument can
-      // take on any constant values, they are collected in Constants. If the
-      // argument can only ever equal a constant value in Constants, the
-      // function will be completely specialized, and the IsPartial flag will
-      // be set to false by isArgumentInteresting (that function only adds
-      // values to the Constants list that are deemed profitable).
-      bool IsPartial = true;
-      SmallVector<Constant *> ActualArgs;
-      if (!isArgumentInteresting(&FormalArg, ActualArgs, IsPartial)) {
+      // take on any constant values, they are collected in Constants.
+      ConstList ActualArgs;
+      if (!isArgumentInteresting(&FormalArg, ActualArgs)) {
         LLVM_DEBUG(dbgs() << "FnSpecialization: Argument "
                           << FormalArg.getNameOrAsOperand()
                           << " is not interesting\n");
@@ -436,47 +425,41 @@ class FunctionSpecializer {
 
         if (Gain <= 0)
           continue;
-        Worklist.push_back({F, &FormalArg, ActualArg, Gain});
+        WorkList.push_back({&FormalArg, ActualArg, Gain});
       }
 
-      if (Worklist.empty())
+      if (WorkList.empty())
         continue;
 
       // Sort the candidates in descending order.
-      llvm::stable_sort(Worklist, [](const ArgInfo &L, const ArgInfo &R) {
+      llvm::stable_sort(WorkList, [](const SpecializationInfo &L,
+                                     const SpecializationInfo &R) {
         return L.Gain > R.Gain;
       });
 
       // Truncate the worklist to 'MaxClonesThreshold' candidates if
       // necessary.
-      if (Worklist.size() > MaxClonesThreshold) {
+      if (WorkList.size() > MaxClonesThreshold) {
         LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed "
                           << "the maximum number of clones threshold.\n"
                           << "FnSpecialization: Truncating worklist to "
                           << MaxClonesThreshold << " candidates.\n");
-        Worklist.erase(Worklist.begin() + MaxClonesThreshold,
-                       Worklist.end());
+        WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
       }
 
-      if (IsPartial || Worklist.size() < ActualArgs.size())
-        for (auto &ActualArg : Worklist)
-          ActualArg.Partial = true;
-
-      LLVM_DEBUG(
-        dbgs() << "FnSpecialization: Specializations for function "
-               << F->getName() << "\n";
-        for (auto &C : Worklist) {
-          dbgs() << "FnSpecialization:   FormalArg = "
-                 << C.Formal->getNameOrAsOperand() << ", ActualArg = "
-                 << C.Actual->getNameOrAsOperand() << ", Gain = "
-                 << C.Gain << "\n";
-        }
-      );
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
+                        << F->getName() << "\n";
+                 for (SpecializationInfo &S : WorkList) {
+                   dbgs() << "FnSpecialization:   FormalArg = "
+                          << S.Arg.Formal->getNameOrAsOperand()
+                          << ", ActualArg = "
+                          << S.Arg.Actual->getNameOrAsOperand()
+                          << ", Gain = " << S.Gain << "\n";
+                 });
 
       // FIXME: Only one argument per function.
       break;
     }
-    return Worklist;
   }
 
   bool isCandidateFunction(Function *F) {
@@ -503,17 +486,18 @@ class FunctionSpecializer {
     return true;
   }
 
-  void specializeFunction(ArgInfo &AI, FuncList &WorkList) {
-    Function *Clone = cloneCandidateFunction(AI.Fn);
-    Argument *ClonedArg = Clone->getArg(AI.Formal->getArgNo());
+  void specializeFunction(Function *F, SpecializationInfo &S,
+                          FuncList &WorkList) {
+    ValueToValueMapTy Mappings;
+    Function *Clone = cloneCandidateFunction(F, Mappings);
 
     // Rewrite calls to the function so that they call the clone instead.
-    rewriteCallSites(AI.Fn, Clone, *ClonedArg, AI.Actual);
+    rewriteCallSites(Clone, S.Arg, Mappings);
 
     // Initialize the lattice state of the arguments of the function clone,
     // marking the argument on which we specialized the function constant
     // with the given value.
-    Solver.markArgInFuncSpecialization(AI.Fn, ClonedArg, AI.Actual);
+    Solver.markArgInFuncSpecialization(Clone, S.Arg);
 
     // Mark all the specialized functions
     WorkList.push_back(Clone);
@@ -521,14 +505,13 @@ class FunctionSpecializer {
 
     // If the function has been completely specialized, the original function
     // is no longer needed. Mark it unreachable.
-    if (AI.Fn->getNumUses() == 0 ||
-        all_of(AI.Fn->users(), [&AI](User *U) {
+    if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) {
           if (auto *CS = dyn_cast<CallBase>(U))
-            return CS->getFunction() == AI.Fn;
+            return CS->getFunction() == F;
           return false;
         })) {
-      Solver.markFunctionUnreachable(AI.Fn);
-      FullySpecialized.insert(AI.Fn);
+      Solver.markFunctionUnreachable(F);
+      FullySpecialized.insert(F);
     }
   }
 
@@ -667,15 +650,11 @@ class FunctionSpecializer {
   /// specializing the function based on the incoming values of argument \p A
   /// would result in any significant optimization opportunities. If
   /// optimization opportunities exist, the constant values of \p A on which to
-  /// specialize the function are collected in \p Constants. If the values in
-  /// \p Constants represent the complete set of values that \p A can take on,
-  /// the function will be completely specialized, and the \p IsPartial flag is
-  /// set to false.
+  /// specialize the function are collected in \p Constants.
   ///
   /// \returns true if the function should be specialized on the given
   /// argument.
-  bool isArgumentInteresting(Argument *A, ConstList &Constants,
-                             bool &IsPartial) {
+  bool isArgumentInteresting(Argument *A, ConstList &Constants) {
     // For now, don't attempt to specialize functions based on the values of
     // composite types.
     if (!A->getType()->isSingleValueType() || A->user_empty())
@@ -703,7 +682,11 @@ class FunctionSpecializer {
     //
     // TODO 2: this currently does not support constants, i.e. integer ranges.
     //
-    IsPartial = !getPossibleConstants(A, Constants);
+    getPossibleConstants(A, Constants);
+
+    if (Constants.empty())
+      return false;
+
     LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument "
                       << A->getNameOrAsOperand() << "\n");
     return true;
@@ -711,13 +694,8 @@ class FunctionSpecializer {
 
   /// Collect in \p Constants all the constant values that argument \p A can
   /// take on.
-  ///
-  /// \returns true if all of the values the argument can take on are constant
-  /// (e.g., the argument's parent function cannot be called with an
-  /// overdefined value).
-  bool getPossibleConstants(Argument *A, ConstList &Constants) {
+  void getPossibleConstants(Argument *A, ConstList &Constants) {
     Function *F = A->getParent();
-    bool AllConstant = true;
 
     // Iterate over all the call sites of the argument's parent function.
     for (User *U : F->users()) {
@@ -726,10 +704,8 @@ class FunctionSpecializer {
       auto &CS = *cast<CallBase>(U);
       // If the call site has attribute minsize set, that callsite won't be
       // specialized.
-      if (CS.hasFnAttr(Attribute::MinSize)) {
-        AllConstant = false;
+      if (CS.hasFnAttr(Attribute::MinSize))
         continue;
-      }
 
       // If the parent of the call site will never be executed, we don't need
       // to worry about the passed value.
@@ -738,13 +714,13 @@ class FunctionSpecializer {
 
       auto *V = CS.getArgOperand(A->getArgNo());
       if (isa<PoisonValue>(V))
-        return false;
+        return;
 
       // For now, constant expressions are fine but only if they are function
       // calls.
       if (auto *CE = dyn_cast<ConstantExpr>(V))
         if (!isa<Function>(CE->getOperand(0)))
-          return false;
+          return;
 
       // TrackValueOfGlobalVariable only tracks scalar global variables.
       if (auto *GV = dyn_cast<GlobalVariable>(V)) {
@@ -752,35 +728,30 @@ class FunctionSpecializer {
         // global values.
         if (!GV->isConstant())
           if (!SpecializeOnAddresses)
-            return false;
+            return;
 
         if (!GV->getValueType()->isSingleValueType())
-          return false;
+          return;
       }
 
       if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
                                EnableSpecializationForLiteralConstant))
         Constants.push_back(cast<Constant>(V));
-      else
-        AllConstant = false;
     }
-
-    // If the argument can only take on constant values, AllConstant will be
-    // true.
-    return AllConstant;
   }
 
   /// Rewrite calls to function \p F to call function \p Clone instead.
   ///
-  /// This function modifies calls to function \p F whose argument at index \p
-  /// ArgNo is equal to constant \p C. The calls are rewritten to call function
-  /// \p Clone instead.
+  /// This function modifies calls to function \p F as long as the actual
+  /// argument matches the one in \p Arg. Note that for recursive calls we
+  /// need to compare against the cloned formal argument.
   ///
   /// Callsites that have been marked with the MinSize function attribute won't
   /// be specialized and rewritten.
-  void rewriteCallSites(Function *F, Function *Clone, Argument &Arg,
-                        Constant *C) {
-    unsigned ArgNo = Arg.getArgNo();
+  void rewriteCallSites(Function *Clone, const ArgInfo &Arg,
+                        ValueToValueMapTy &Mappings) {
+    Function *F = Arg.Formal->getParent();
+    unsigned ArgNo = Arg.Formal->getArgNo();
     SmallVector<CallBase *, 4> CallSitesToRewrite;
     for (auto *U : F->users()) {
       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
@@ -799,8 +770,11 @@ class FunctionSpecializer {
       LLVM_DEBUG(dbgs() << "FnSpecialization:   "
                         << CS->getFunction()->getName() << " ->"
                         << *CS << "\n");
-      if ((CS->getFunction() == Clone && CS->getArgOperand(ArgNo) == &Arg) ||
-          CS->getArgOperand(ArgNo) == C) {
+      if (/* recursive call */
+          (CS->getFunction() == Clone &&
+           CS->getArgOperand(ArgNo) == Mappings[Arg.Formal]) ||
+          /* normal call */
+          CS->getArgOperand(ArgNo) == Arg.Actual) {
         CS->setCalledFunction(Clone);
         Solver.markOverdefined(CS);
       }

diff  --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index e930aebd1df3a..88dd5e6031ecf 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -450,7 +450,7 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
     return TrackingIncomingArguments;
   }
 
-  void markArgInFuncSpecialization(Function *F, Argument *A, Constant *C);
+  void markArgInFuncSpecialization(Function *F, const ArgInfo &Arg);
 
   void markFunctionUnreachable(Function *F) {
     for (auto &BB : *F)
@@ -524,24 +524,25 @@ Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const {
   return nullptr;
 }
 
-void SCCPInstVisitor::markArgInFuncSpecialization(Function *F, Argument *A,
-                                                  Constant *C) {
-  assert(F->arg_size() == A->getParent()->arg_size() &&
+void SCCPInstVisitor::markArgInFuncSpecialization(Function *F,
+                                                  const ArgInfo &Arg) {
+  assert(F->arg_size() == Arg.Formal->getParent()->arg_size() &&
          "Functions should have the same number of arguments");
 
-  // Mark the argument constant in the new function.
-  markConstant(A, C);
-
-  // For the remaining arguments in the new function, copy the lattice state
-  // over from the old function.
-  for (Argument *OldArg = F->arg_begin(), *NewArg = A->getParent()->arg_begin(),
-                *End = F->arg_end();
-       OldArg != End; ++OldArg, ++NewArg) {
+  Argument *NewArg = F->arg_begin();
+  Argument *OldArg = Arg.Formal->getParent()->arg_begin();
+  for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) {
 
     LLVM_DEBUG(dbgs() << "SCCP: Marking argument "
                       << NewArg->getNameOrAsOperand() << "\n");
 
-    if (NewArg != A && ValueState.count(OldArg)) {
+    if (OldArg == Arg.Formal) {
+      // Mark the argument constants in the new function.
+      markConstant(NewArg, Arg.Actual);
+    } else if (ValueState.count(OldArg)) {
+      // For the remaining arguments in the new function, copy the lattice state
+      // over from the old function.
+      //
       // Note: This previously looked like this:
       // ValueState[NewArg] = ValueState[OldArg];
       // This is incorrect because the DenseMap class may resize the underlying
@@ -1716,9 +1717,8 @@ SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() {
   return Visitor->getArgumentTrackedFunctions();
 }
 
-void SCCPSolver::markArgInFuncSpecialization(Function *F, Argument *A,
-                                             Constant *C) {
-  Visitor->markArgInFuncSpecialization(F, A, C);
+void SCCPSolver::markArgInFuncSpecialization(Function *F, const ArgInfo &Arg) {
+  Visitor->markArgInFuncSpecialization(F, Arg);
 }
 
 void SCCPSolver::markFunctionUnreachable(Function *F) {


        


More information about the llvm-commits mailing list