[llvm] 2bf3fe9 - [TRE] Allow elimination when the returned value is non-constant

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Wed May 27 16:55:30 PDT 2020


Author: Layton Kifer
Date: 2020-05-27T16:55:03-07:00
New Revision: 2bf3fe9b6dedf727990e68244a3d637518ea8bc3

URL: https://github.com/llvm/llvm-project/commit/2bf3fe9b6dedf727990e68244a3d637518ea8bc3
DIFF: https://github.com/llvm/llvm-project/commit/2bf3fe9b6dedf727990e68244a3d637518ea8bc3.diff

LOG: [TRE] Allow elimination when the returned value is non-constant

Currently we can only eliminate call return pairs that either return the
result of the call or a dynamic constant. This patch removes that
limitation.

Differential Revision: https://reviews.llvm.org/D79660

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
    llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll
    llvm/test/Transforms/TailCallElim/basic.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 4fd63fa1838b..a752e356b727 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -460,6 +460,16 @@ class TailRecursionEliminator {
   SmallVector<PHINode *, 8> ArgumentPHIs;
   bool RemovableCallsMustBeMarkedTail = false;
 
+  // PHI node to store our return value.
+  PHINode *RetPN = nullptr;
+
+  // i1 PHI node to track if we have a valid return value stored in RetPN.
+  PHINode *RetKnownPN = nullptr;
+
+  // Vector of select instructions we insereted. These selects use RetKnownPN
+  // to either propagate RetPN or select a new return value.
+  SmallVector<SelectInst *, 8> RetSelects;
+
   TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
                           AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
                           DomTreeUpdater &DTU)
@@ -577,6 +587,21 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) {
     PN->addIncoming(&*I, NewEntry);
     ArgumentPHIs.push_back(PN);
   }
+
+  // If the function doen't return void, create the RetPN and RetKnownPN PHI
+  // nodes to track our return value. We initialize RetPN with undef and
+  // RetKnownPN with false since we can't know our return value at function
+  // entry.
+  Type *RetType = F.getReturnType();
+  if (!RetType->isVoidTy()) {
+    Type *BoolType = Type::getInt1Ty(F.getContext());
+    RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos);
+    RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos);
+
+    RetPN->addIncoming(UndefValue::get(RetType), NewEntry);
+    RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry);
+  }
+
   // The entry block was changed from HeaderBB to NewEntry.
   // The forward DominatorTree needs to be recalculated when the EntryBB is
   // changed. In this corner-case we recalculate the entire tree.
@@ -616,11 +641,7 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
   // value for the accumulator is placed in this variable.  If this value is set
   // then we actually perform accumulator recursion elimination instead of
   // simple tail recursion elimination.  If the operation is an LLVM instruction
-  // (eg: "add") then it is recorded in AccumulatorRecursionInstr.  If not, then
-  // we are handling the case when the return instruction returns a constant C
-  // which is different to the constant returned by other return instructions
-  // (which is recorded in AccumulatorRecursionEliminationInitVal).  This is a
-  // special case of accumulator recursion, the operation being "return C".
+  // (eg: "add") then it is recorded in AccumulatorRecursionInstr.
   Value *AccumulatorRecursionEliminationInitVal = nullptr;
   Instruction *AccumulatorRecursionInstr = nullptr;
 
@@ -647,26 +668,6 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
     }
   }
 
-  // We can only transform call/return pairs that either ignore the return value
-  // of the call and return void, ignore the value of the call and return a
-  // constant, return the value returned by the tail call, or that are being
-  // accumulator recursion variable eliminated.
-  if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI &&
-      !isa<UndefValue>(Ret->getReturnValue()) &&
-      AccumulatorRecursionEliminationInitVal == nullptr &&
-      !getCommonReturnValue(nullptr, CI)) {
-    // One case remains that we are able to handle: the current return
-    // instruction returns a constant, and all other return instructions
-    // return a different constant.
-    if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret))
-      return false; // Current return instruction does not return a constant.
-    // Check that all other return instructions return a common constant.  If
-    // so, record it in AccumulatorRecursionEliminationInitVal.
-    AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI);
-    if (!AccumulatorRecursionEliminationInitVal)
-      return false;
-  }
-
   BasicBlock *BB = Ret->getParent();
 
   using namespace ore;
@@ -698,20 +699,15 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
     PHINode *AccPN = insertAccumulator(AccumulatorRecursionEliminationInitVal);
 
     Instruction *AccRecInstr = AccumulatorRecursionInstr;
-    if (AccRecInstr) {
-      // Add an incoming argument for the current block, which is computed by
-      // our associative and commutative accumulator instruction.
-      AccPN->addIncoming(AccRecInstr, BB);
-
-      // Next, rewrite the accumulator recursion instruction so that it does not
-      // use the result of the call anymore, instead, use the PHI node we just
-      // inserted.
-      AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
-    } else {
-      // Add an incoming argument for the current block, which is just the
-      // constant returned by the current return instruction.
-      AccPN->addIncoming(Ret->getReturnValue(), BB);
-    }
+
+    // Add an incoming argument for the current block, which is computed by
+    // our associative and commutative accumulator instruction.
+    AccPN->addIncoming(AccRecInstr, BB);
+
+    // Next, rewrite the accumulator recursion instruction so that it does not
+    // use the result of the call anymore, instead, use the PHI node we just
+    // inserted.
+    AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
 
     // Finally, rewrite any return instructions in the program to return the PHI
     // node instead of the "initval" that they do currently.  This loop will
@@ -722,6 +718,25 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
     ++NumAccumAdded;
   }
 
+  // Update our return value tracking
+  if (RetPN) {
+    if (Ret->getReturnValue() == CI || AccumulatorRecursionEliminationInitVal) {
+      // Defer selecting a return value
+      RetPN->addIncoming(RetPN, BB);
+      RetKnownPN->addIncoming(RetKnownPN, BB);
+    } else {
+      // We found a return value we want to use, insert a select instruction to
+      // select it if we don't already know what our return value will be and
+      // store the result in our return value PHI node.
+      SelectInst *SI = SelectInst::Create(
+          RetKnownPN, RetPN, Ret->getReturnValue(), "current.ret.tr", Ret);
+      RetSelects.push_back(SI);
+
+      RetPN->addIncoming(SI, BB);
+      RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB);
+    }
+  }
+
   // Now that all of the PHI nodes are in place, remove the call and
   // ret instructions, replacing them with an unconditional branch.
   BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret);
@@ -804,6 +819,30 @@ void TailRecursionEliminator::cleanupAndFinalize() {
       PN->eraseFromParent();
     }
   }
+
+  if (RetPN) {
+    if (RetSelects.empty()) {
+      // If we didn't insert any select instructions, then we know we didn't
+      // store a return value and we can remove the PHI nodes we inserted.
+      RetPN->dropAllReferences();
+      RetPN->eraseFromParent();
+
+      RetKnownPN->dropAllReferences();
+      RetKnownPN->eraseFromParent();
+    } else {
+      // We need to insert a select instruction before any return left in the
+      // function to select our stored return value if we have one.
+      for (BasicBlock &BB : F) {
+        ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator());
+        if (!RI)
+          continue;
+
+        SelectInst *SI = SelectInst::Create(
+            RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI);
+        RI->setOperand(0, SI);
+      }
+    }
+  }
 }
 
 bool TailRecursionEliminator::eliminate(Function &F,

diff  --git a/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll b/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll
index 48110e3283cf..4e0346c14c34 100644
--- a/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll
+++ b/llvm/test/Transforms/TailCallElim/2010-06-26-MultipleReturnValues.ll
@@ -1,20 +1,112 @@
 ; RUN: opt < %s -tailcallelim -verify-dom-info -S | FileCheck %s
 ; PR7328
 ; PR7506
-define i32 @foo(i32 %x) {
-; CHECK-LABEL: define i32 @foo(
-; CHECK: %accumulator.tr = phi i32 [ 1, %entry ], [ 0, %body ]
+define i32 @test1_constants(i32 %x) {
 entry:
   %cond = icmp ugt i32 %x, 0                      ; <i1> [#uses=1]
   br i1 %cond, label %return, label %body
 
 body:                                             ; preds = %entry
   %y = add i32 %x, 1                              ; <i32> [#uses=1]
-  %tmp = call i32 @foo(i32 %y)                    ; <i32> [#uses=0]
-; CHECK-NOT: call
+  %recurse = call i32 @test1_constants(i32 %y)        ; <i32> [#uses=0]
   ret i32 0
-; CHECK: ret i32 %accumulator.tr
 
 return:                                           ; preds = %entry
   ret i32 1
 }
+
+; CHECK-LABEL: define i32 @test1_constants(
+; CHECK: tailrecurse:
+; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %body ]
+; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %body ]
+; CHECK: body:
+; CHECK-NOT: %recurse
+; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 0
+; CHECK-NOT: ret
+; CHECK: return:
+; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 1
+; CHECK: ret i32 %current.ret.tr1
+
+define i32 @test2_non_constants(i32 %x) {
+entry:
+  %cond = icmp ugt i32 %x, 0
+  br i1 %cond, label %return, label %body
+
+body:
+  %y = add i32 %x, 1
+  %helper1 = call i32 @test2_helper()
+  %recurse = call i32 @test2_non_constants(i32 %y)
+  ret i32 %helper1
+
+return:
+  %helper2 = call i32 @test2_helper()
+  ret i32 %helper2
+}
+
+declare i32 @test2_helper()
+
+; CHECK-LABEL: define i32 @test2_non_constants(
+; CHECK: tailrecurse:
+; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %body ]
+; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %body ]
+; CHECK: body:
+; CHECK-NOT: %recurse
+; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 %helper1
+; CHECK-NOT: ret
+; CHECK: return:
+; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 %helper2
+; CHECK: ret i32 %current.ret.tr1
+
+define i32 @test3_mixed(i32 %x) {
+entry:
+  switch i32 %x, label %default [
+    i32 0, label %case0
+    i32 1, label %case1
+    i32 2, label %case2
+  ]
+
+case0:
+  %helper1 = call i32 @test3_helper()
+  br label %return
+
+case1:
+  %y1 = add i32 %x, -1
+  %recurse1 = call i32 @test3_mixed(i32 %y1)
+  br label %return
+
+case2:
+  %y2 = add i32 %x, -1
+  %helper2 = call i32 @test3_helper()
+  %recurse2 = call i32 @test3_mixed(i32 %y2)
+  br label %return
+
+default:
+  %y3 = urem i32 %x, 3
+  %recurse3 = call i32 @test3_mixed(i32 %y3)
+  br label %return
+
+return:
+  %retval = phi i32 [ %recurse3, %default ], [ %helper2, %case2 ], [ 9, %case1 ], [ %helper1, %case0 ]
+  ret i32 %retval
+}
+
+declare i32 @test3_helper()
+
+; CHECK-LABEL: define i32 @test3_mixed(
+; CHECK: tailrecurse:
+; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %case1 ], [ %current.ret.tr1, %case2 ], [ %ret.tr, %default ]
+; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %case1 ], [ true, %case2 ], [ %ret.known.tr, %default ]
+; CHECK: case1:
+; CHECK-NOT: %recurse
+; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 9
+; CHECK: br label %tailrecurse
+; CHECK: case2:
+; CHECK-NOT: %recurse
+; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 %helper2
+; CHECK: br label %tailrecurse
+; CHECK: default:
+; CHECK-NOT: %recurse
+; CHECK: br label %tailrecurse
+; CHECK: return:
+; CHECK: %current.ret.tr2 = select i1 %ret.known.tr, i32 %ret.tr, i32 %helper1
+; CHECK: ret i32 %current.ret.tr2

diff  --git a/llvm/test/Transforms/TailCallElim/basic.ll b/llvm/test/Transforms/TailCallElim/basic.ll
index 576f2fec1244..6116014a024b 100644
--- a/llvm/test/Transforms/TailCallElim/basic.ll
+++ b/llvm/test/Transforms/TailCallElim/basic.ll
@@ -46,8 +46,16 @@ endif.0:		; preds = %entry
 ; plunked it into the demo script, so maybe they care about it.
 define i32 @test3(i32 %c) {
 ; CHECK: i32 @test3
+; CHECK: tailrecurse:
+; CHECK: %ret.tr = phi i32 [ undef, %entry ], [ %current.ret.tr, %else ]
+; CHECK: %ret.known.tr = phi i1 [ false, %entry ], [ true, %else ]
+; CHECK: else:
 ; CHECK-NOT: call
-; CHECK: ret i32 0
+; CHECK: %current.ret.tr = select i1 %ret.known.tr, i32 %ret.tr, i32 0
+; CHECK-NOT: ret
+; CHECK: return:
+; CHECK: %current.ret.tr1 = select i1 %ret.known.tr, i32 %ret.tr, i32 0
+; CHECK: ret i32 %current.ret.tr1
 entry:
 	%tmp.1 = icmp eq i32 %c, 0		; <i1> [#uses=1]
 	br i1 %tmp.1, label %return, label %else


        


More information about the llvm-commits mailing list