[llvm] 23cbea9 - [TRE][NFC] Refactor shared state into member variables.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Fri May 8 14:36:26 PDT 2020


Author: Layton Kifer
Date: 2020-05-08T14:36:02-07:00
New Revision: 23cbea9a04e023d5b79dfee5964fae769340c993

URL: https://github.com/llvm/llvm-project/commit/23cbea9a04e023d5b79dfee5964fae769340c993
DIFF: https://github.com/llvm/llvm-project/commit/23cbea9a04e023d5b79dfee5964fae769340c993.diff

LOG: [TRE][NFC] Refactor shared state into member variables.

Move functions that require shared state into a class, so that the
state no longer has to be passed through multiple functions just to be
available where it is needed.

The main motivation for this is that we would like to remove the
limitation that accumulator values be dynamic constants, which would
require additional shared state between call eliminations in the same
function, compounding this issue.

Differential Revision: https://reviews.llvm.org/D79299

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 54187bce1a0d..f10198d26542 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -445,11 +445,53 @@ static Instruction *firstNonDbg(BasicBlock::iterator I) {
   return &*I;
 }
 
-static CallInst *findTRECandidate(Instruction *TI,
-                                  bool CannotTailCallElimCallsMarkedTail,
-                                  const TargetTransformInfo *TTI) {
+namespace {
+class TailRecursionEliminator {
+  Function &F;
+  const TargetTransformInfo *TTI;
+  AliasAnalysis *AA;
+  OptimizationRemarkEmitter *ORE;
+  DomTreeUpdater &DTU;
+
+  // State shared by all call eliminations in this function. HeaderBB stays
+  // null until createTailRecurseLoopHeader populates these members, which
+  // happens the first time we find a call we can eliminate.
+  BasicBlock *HeaderBB = nullptr;
+  SmallVector<PHINode *, 8> ArgumentPHIs;
+  bool RemovableCallsMustBeMarkedTail = false;
+
+  TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
+                          AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
+                          DomTreeUpdater &DTU)
+      : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
+
+  CallInst *findTRECandidate(Instruction *TI,
+                             bool CannotTailCallElimCallsMarkedTail);
+
+  void createTailRecurseLoopHeader(CallInst *CI);
+
+  PHINode *insertAccumulator(Value *AccumulatorRecursionEliminationInitVal);
+
+  bool eliminateCall(CallInst *CI);
+
+  bool foldReturnAndProcessPred(ReturnInst *Ret,
+                                bool CannotTailCallElimCallsMarkedTail);
+
+  bool processReturningBlock(ReturnInst *Ret,
+                             bool CannotTailCallElimCallsMarkedTail);
+
+  void cleanupAndFinalize();
+
+public:
+  static bool eliminate(Function &F, const TargetTransformInfo *TTI,
+                        AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
+                        DomTreeUpdater &DTU);
+};
+} // namespace
+
+CallInst *TailRecursionEliminator::findTRECandidate(
+    Instruction *TI, bool CannotTailCallElimCallsMarkedTail) {
   BasicBlock *BB = TI->getParent();
-  Function *F = BB->getParent();
 
   if (&BB->front() == TI) // Make sure there is something before the terminator.
     return nullptr;
@@ -460,7 +502,7 @@ static CallInst *findTRECandidate(Instruction *TI,
   BasicBlock::iterator BBI(TI);
   while (true) {
     CI = dyn_cast<CallInst>(BBI);
-    if (CI && CI->getCalledFunction() == F)
+    if (CI && CI->getCalledFunction() == &F)
       break;
 
     if (BBI == BB->begin())
@@ -477,15 +519,14 @@ static CallInst *findTRECandidate(Instruction *TI,
   //   double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
   // and disable this xform in this case, because the code generator will
   // lower the call to fabs into inline code.
-  if (BB == &F->getEntryBlock() &&
+  if (BB == &F.getEntryBlock() &&
       firstNonDbg(BB->front().getIterator()) == CI &&
       firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() &&
       !TTI->isLoweredToCall(CI->getCalledFunction())) {
     // A single-block function with just a call and a return. Check that
     // the arguments match.
     auto I = CI->arg_begin(), E = CI->arg_end();
-    Function::arg_iterator FI = F->arg_begin(),
-                           FE = F->arg_end();
+    Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end();
     for (; I != E && FI != FE; ++I, ++FI)
       if (*I != &*FI) break;
     if (I == E && FI == FE)
@@ -495,10 +536,81 @@ static CallInst *findTRECandidate(Instruction *TI,
   return CI;
 }
 
-static bool eliminateRecursiveTailCall(
-    CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry,
-    bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs,
-    AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) {
+void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) {
+  HeaderBB = &F.getEntryBlock();
+  BasicBlock *NewEntry = BasicBlock::Create(F.getContext(), "", &F, HeaderBB);
+  NewEntry->takeName(HeaderBB);
+  HeaderBB->setName("tailrecurse");
+  BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry);
+  BI->setDebugLoc(CI->getDebugLoc());
+
+  // If this function has self-recursive calls in the tail position where some
+  // are marked tail and some are not, only transform one flavor or the other.
+  // We have to choose whether we move allocas in the entry block to the new
+  // entry block or not, so we can't make a good choice for both. We make this
+  // decision here based on whether the first call we found to remove is
+  // marked tail.
+  // NOTE: We could do slightly better here in the case that the function has
+  // no entry block allocas.
+  RemovableCallsMustBeMarkedTail = CI->isTailCall();
+
+  // If this tail call is marked 'tail' and if there are any allocas in the
+  // entry block, move them up to the new entry block.
+  if (RemovableCallsMustBeMarkedTail)
+    // Move all fixed-size allocas from HeaderBB to NewEntry.
+    for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(),
+                              NEBI = NewEntry->begin();
+         OEBI != E;)
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
+        if (isa<ConstantInt>(AI->getArraySize()))
+          AI->moveBefore(&*NEBI);
+
+  // Now that we have created a new block, which jumps to the entry
+  // block, insert a PHI node for each argument of the function.
+  // For now, we initialize each PHI to only have the real argument
+  // which is passed in, incoming from NewEntry.
+  Instruction *InsertPos = &HeaderBB->front();
+  for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) {
+    PHINode *PN =
+        PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos);
+    I->replaceAllUsesWith(PN); // All uses now go through the PHI node.
+    PN->addIncoming(&*I, NewEntry);
+    ArgumentPHIs.push_back(PN);
+  }
+  // The entry block was changed from HeaderBB to NewEntry.
+  // The forward DominatorTree needs to be recalculated when the EntryBB is
+  // changed. In this corner-case we recalculate the entire tree.
+  DTU.recalculate(*NewEntry->getParent());
+}
+
+PHINode *TailRecursionEliminator::insertAccumulator(
+    Value *AccumulatorRecursionEliminationInitVal) {
+  // Create the accumulator PHI in HeaderBB; +1 for the block being rewritten.
+  pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB);
+  PHINode *AccPN = PHINode::Create(
+      AccumulatorRecursionEliminationInitVal->getType(),
+      std::distance(PB, PE) + 1, "accumulator.tr", &HeaderBB->front());
+
+  // Loop over all of the predecessors of the tail recursion block.  For the
+  // real entry into the function we seed the PHI with the initial value,
+  // computed by the caller. For any other existing branches to this block (due
+  // to other tail recursions eliminated) the accumulator is not modified.
+  // Because we haven't added the branch in the current block to HeaderBB yet,
+  // it will not show up as a predecessor.
+  for (pred_iterator PI = PB; PI != PE; ++PI) {
+    BasicBlock *P = *PI;
+    if (P == &F.getEntryBlock())
+      AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P);
+    else
+      AccPN->addIncoming(AccPN, P);
+  }
+
+  return AccPN;
+}
+
+bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
+  ReturnInst *Ret = cast<ReturnInst>(CI->getParent()->getTerminator());
+
   // If we are introducing accumulator recursion to eliminate operations after
   // the call instruction that are both associative and commutative, the initial
   // value for the accumulator is placed in this variable.  If this value is set
@@ -556,7 +668,6 @@ static bool eliminateRecursiveTailCall(
   }
 
   BasicBlock *BB = Ret->getParent();
-  Function *F = BB->getParent();
 
   using namespace ore;
   ORE->emit([&]() {
@@ -566,51 +677,10 @@ static bool eliminateRecursiveTailCall(
 
   // OK! We can transform this tail call.  If this is the first one found,
   // create the new entry block, allowing us to branch back to the old entry.
-  if (!OldEntry) {
-    OldEntry = &F->getEntryBlock();
-    BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry);
-    NewEntry->takeName(OldEntry);
-    OldEntry->setName("tailrecurse");
-    BranchInst *BI = BranchInst::Create(OldEntry, NewEntry);
-    BI->setDebugLoc(CI->getDebugLoc());
-
-    // If this tail call is marked 'tail' and if there are any allocas in the
-    // entry block, move them up to the new entry block.
-    TailCallsAreMarkedTail = CI->isTailCall();
-    if (TailCallsAreMarkedTail)
-      // Move all fixed sized allocas from OldEntry to NewEntry.
-      for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(),
-             NEBI = NewEntry->begin(); OEBI != E; )
-        if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
-          if (isa<ConstantInt>(AI->getArraySize()))
-            AI->moveBefore(&*NEBI);
-
-    // Now that we have created a new block, which jumps to the entry
-    // block, insert a PHI node for each argument of the function.
-    // For now, we initialize each PHI to only have the real arguments
-    // which are passed in.
-    Instruction *InsertPos = &OldEntry->front();
-    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
-         I != E; ++I) {
-      PHINode *PN = PHINode::Create(I->getType(), 2,
-                                    I->getName() + ".tr", InsertPos);
-      I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
-      PN->addIncoming(&*I, NewEntry);
-      ArgumentPHIs.push_back(PN);
-    }
-    // The entry block was changed from OldEntry to NewEntry.
-    // The forward DominatorTree needs to be recalculated when the EntryBB is
-    // changed. In this corner-case we recalculate the entire tree.
-    DTU.recalculate(*NewEntry->getParent());
-  }
+  if (!HeaderBB)
+    createTailRecurseLoopHeader(CI);
 
-  // If this function has self recursive calls in the tail position where some
-  // are marked tail and some are not, only transform one flavor or another.  We
-  // have to choose whether we move allocas in the entry block to the new entry
-  // block or not, so we can't make a good choice for both.  NOTE: We could do
-  // slightly better here in the case that the function has no entry block
-  // allocas.
-  if (TailCallsAreMarkedTail && !CI->isTailCall())
+  if (RemovableCallsMustBeMarkedTail && !CI->isTailCall())
     return false;
 
   // Ok, now that we know we have a pseudo-entry block WITH all of the
@@ -625,27 +695,9 @@ static bool eliminateRecursiveTailCall(
   // accumulator recursion predicate is set up.
   //
   if (AccumulatorRecursionEliminationInitVal) {
-    Instruction *AccRecInstr = AccumulatorRecursionInstr;
-    // Start by inserting a new PHI node for the accumulator.
-    pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry);
-    PHINode *AccPN = PHINode::Create(
-        AccumulatorRecursionEliminationInitVal->getType(),
-        std::distance(PB, PE) + 1, "accumulator.tr", &OldEntry->front());
-
-    // Loop over all of the predecessors of the tail recursion block.  For the
-    // real entry into the function we seed the PHI with the initial value,
-    // computed earlier.  For any other existing branches to this block (due to
-    // other tail recursions eliminated) the accumulator is not modified.
-    // Because we haven't added the branch in the current block to OldEntry yet,
-    // it will not show up as a predecessor.
-    for (pred_iterator PI = PB; PI != PE; ++PI) {
-      BasicBlock *P = *PI;
-      if (P == &F->getEntryBlock())
-        AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P);
-      else
-        AccPN->addIncoming(AccPN, P);
-    }
+    PHINode *AccPN = insertAccumulator(AccumulatorRecursionEliminationInitVal);
 
+    Instruction *AccRecInstr = AccumulatorRecursionInstr;
     if (AccRecInstr) {
       // Add an incoming argument for the current block, which is computed by
       // our associative and commutative accumulator instruction.
@@ -664,7 +716,7 @@ static bool eliminateRecursiveTailCall(
     // Finally, rewrite any return instructions in the program to return the PHI
     // node instead of the "initval" that they do currently.  This loop will
     // actually rewrite the return value we are destroying, but that's ok.
-    for (BasicBlock &BBI : *F)
+    for (BasicBlock &BBI : F)
       if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()))
         RI->setOperand(0, AccPN);
     ++NumAccumAdded;
@@ -672,21 +724,20 @@ static bool eliminateRecursiveTailCall(
 
   // Now that all of the PHI nodes are in place, remove the call and
   // ret instructions, replacing them with an unconditional branch.
-  BranchInst *NewBI = BranchInst::Create(OldEntry, Ret);
+  BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret);
   NewBI->setDebugLoc(CI->getDebugLoc());
 
   BB->getInstList().erase(Ret);  // Remove return.
   BB->getInstList().erase(CI);   // Remove call.
-  DTU.applyUpdates({{DominatorTree::Insert, BB, OldEntry}});
+  DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
   ++NumEliminated;
   return true;
 }
 
-static bool foldReturnAndProcessPred(
-    BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry,
-    bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs,
-    bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI,
-    AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) {
+bool TailRecursionEliminator::foldReturnAndProcessPred(
+    ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) {
+  BasicBlock *BB = Ret->getParent();
+
   bool Change = false;
 
   // Make sure this block is a trivial return block.
@@ -709,10 +760,11 @@ static bool foldReturnAndProcessPred(
   while (!UncondBranchPreds.empty()) {
     BranchInst *BI = UncondBranchPreds.pop_back_val();
     BasicBlock *Pred = BI->getParent();
-    if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){
+    if (CallInst *CI =
+            findTRECandidate(BI, CannotTailCallElimCallsMarkedTail)) {
       LLVM_DEBUG(dbgs() << "FOLDING: " << *BB
                         << "INTO UNCOND BRANCH PRED: " << *Pred);
-      ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU);
+      FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU);
 
       // Cleanup: if all predecessors of BB have been eliminated by
       // FoldReturnIntoUncondBranch, delete it.  It is important to empty it,
@@ -721,8 +773,7 @@ static bool foldReturnAndProcessPred(
       if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB))
         DTU.deleteBB(BB);
 
-      eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail,
-                                 ArgumentPHIs, AA, ORE, DTU);
+      eliminateCall(CI);
       ++NumRetDuped;
       Change = true;
     }
@@ -731,23 +782,35 @@ static bool foldReturnAndProcessPred(
   return Change;
 }
 
-static bool processReturningBlock(
-    ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail,
-    SmallVectorImpl<PHINode *> &ArgumentPHIs,
-    bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI,
-    AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) {
-  CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI);
+bool TailRecursionEliminator::processReturningBlock(
+    ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) {
+  CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail);
   if (!CI)
     return false;
 
-  return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail,
-                                    ArgumentPHIs, AA, ORE, DTU);
+  return eliminateCall(CI);
+}
+
+void TailRecursionEliminator::cleanupAndFinalize() {
+  // If we eliminated any tail recursions, it's possible that we inserted some
+  // silly PHI nodes which just merge an initial value (the incoming operand)
+  // with themselves.  Check to see if we did and clean up our mess if so.  This
+  // occurs when a function passes an argument straight through to its tail
+  // call.
+  for (PHINode *PN : ArgumentPHIs) {
+    // If the PHI node simplifies (e.g. a dynamic constant), use that value.
+    if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) {
+      PN->replaceAllUsesWith(PNV);
+      PN->eraseFromParent();
+    }
+  }
 }
 
-static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
-                                   AliasAnalysis *AA,
-                                   OptimizationRemarkEmitter *ORE,
-                                   DomTreeUpdater &DTU) {
+bool TailRecursionEliminator::eliminate(Function &F,
+                                        const TargetTransformInfo *TTI,
+                                        AliasAnalysis *AA,
+                                        OptimizationRemarkEmitter *ORE,
+                                        DomTreeUpdater &DTU) {
   if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true")
     return false;
 
@@ -762,15 +825,13 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
   if (F.getFunctionType()->isVarArg())
     return false;
 
-  BasicBlock *OldEntry = nullptr;
-  bool TailCallsAreMarkedTail = false;
-  SmallVector<PHINode*, 8> ArgumentPHIs;
-
   // If false, we cannot perform TRE on tail calls marked with the 'tail'
   // attribute, because doing so would cause the stack size to increase (real
   // TRE would deallocate variable sized allocas, TRE doesn't).
   bool CanTRETailMarkedCall = canTRE(F);
 
+  TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
+
   // Change any tail recursive calls to loops.
   //
   // FIXME: The code generator produces really bad code when an 'escaping
@@ -780,29 +841,14 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
   for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) {
     BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB.
     if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
-      bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
-                                          ArgumentPHIs, !CanTRETailMarkedCall,
-                                          TTI, AA, ORE, DTU);
+      bool Change = TRE.processReturningBlock(Ret, !CanTRETailMarkedCall);
       if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
-        Change = foldReturnAndProcessPred(
-            BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs,
-            !CanTRETailMarkedCall, TTI, AA, ORE, DTU);
+        Change = TRE.foldReturnAndProcessPred(Ret, !CanTRETailMarkedCall);
       MadeChange |= Change;
     }
   }
 
-  // If we eliminated any tail recursions, it's possible that we inserted some
-  // silly PHI nodes which just merge an initial value (the incoming operand)
-  // with themselves.  Check to see if we did and clean up our mess if so.  This
-  // occurs when a function passes an argument straight through to its tail
-  // call.
-  for (PHINode *PN : ArgumentPHIs) {
-    // If the PHI Node is a dynamic constant, replace it with the value it is.
-    if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) {
-      PN->replaceAllUsesWith(PNV);
-      PN->eraseFromParent();
-    }
-  }
+  TRE.cleanupAndFinalize();
 
   return MadeChange;
 }
@@ -836,7 +882,7 @@ struct TailCallElim : public FunctionPass {
     // UpdateStrategy to Lazy if we find it profitable later.
     DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
 
-    return eliminateTailRecursion(
+    return TailRecursionEliminator::eliminate(
         F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
         &getAnalysis<AAResultsWrapperPass>().getAAResults(),
         &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
@@ -869,7 +915,7 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
   // UpdateStrategy based on some test results. It is feasible to switch the
   // UpdateStrategy to Lazy if we find it profitable later.
   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
-  bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE, DTU);
+  bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
 
   if (!Changed)
     return PreservedAnalyses::all();


        


More information about the llvm-commits mailing list