[llvm-commits] [llvm] r64532 - in /llvm/trunk: lib/Transforms/Scalar/IndVarSimplify.cpp test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll

Dan Gohman gohman at apple.com
Fri Feb 13 18:31:09 PST 2009


Author: djg
Date: Fri Feb 13 20:31:09 2009
New Revision: 64532

URL: http://llvm.org/viewvc/llvm-project?rev=64532&view=rev
Log:
Extend the IndVarSimplify support for promoting induction variables:
 - Test for signed and unsigned wrapping conditions, instead of just
   testing for non-negative induction ranges. 
 - Handle loops with GT comparisons, in addition to LT comparisons.
 - Support more cases of induction variables that don't start at 0.

Modified:
    llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp
    llvm/trunk/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll

Modified: llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp?rev=64532&r1=64531&r2=64532&view=diff

==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/IndVarSimplify.cpp Fri Feb 13 20:31:09 2009
@@ -458,33 +458,98 @@
   return Ty;
 }
 
-/// isOrigIVAlwaysNonNegative - Analyze the original induction variable
-/// in the loop to determine whether it would ever have a negative
-/// value.
+/// TestOrigIVForWrap - Analyze the original induction variable
+/// in the loop to determine whether it would ever undergo signed
+/// or unsigned overflow.
 ///
 /// TODO: This duplicates a fair amount of ScalarEvolution logic.
-/// Perhaps this can be merged with ScalarEvolution::getIterationCount.
+/// Perhaps this can be merged with ScalarEvolution::getIterationCount
+/// and/or ScalarEvolution::get{Sign,Zero}ExtendExpr.
 ///
-static bool isOrigIVAlwaysNonNegative(const Loop *L,
-                                      const Instruction *OrigCond) {
+static void TestOrigIVForWrap(const Loop *L,
+                              const BranchInst *BI,
+                              const Instruction *OrigCond,
+                              bool &NoSignedWrap,
+                              bool &NoUnsignedWrap) {
   // Verify that the loop is sane and find the exit condition.
   const ICmpInst *Cmp = dyn_cast<ICmpInst>(OrigCond);
-  if (!Cmp) return false;
+  if (!Cmp) return;
 
-  // For now, analyze only SLT loops for signed overflow.
-  if (Cmp->getPredicate() != ICmpInst::ICMP_SLT) return false;
+  const Value *CmpLHS = Cmp->getOperand(0);
+  const Value *CmpRHS = Cmp->getOperand(1);
+  const BasicBlock *TrueBB = BI->getSuccessor(0);
+  const BasicBlock *FalseBB = BI->getSuccessor(1);
+  ICmpInst::Predicate Pred = Cmp->getPredicate();
+
+  // Canonicalize a constant to the RHS.
+  if (isa<ConstantInt>(CmpLHS)) {
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+    std::swap(CmpLHS, CmpRHS);
+  }
+  // Canonicalize SLE to SLT.
+  if (Pred == ICmpInst::ICMP_SLE)
+    if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
+      if (!CI->getValue().isMaxSignedValue()) {
+        CmpRHS = ConstantInt::get(CI->getValue() + 1);
+        Pred = ICmpInst::ICMP_SLT;
+      }
+  // Canonicalize SGT to SGE.
+  if (Pred == ICmpInst::ICMP_SGT)
+    if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
+      if (!CI->getValue().isMaxSignedValue()) {
+        CmpRHS = ConstantInt::get(CI->getValue() + 1);
+        Pred = ICmpInst::ICMP_SGE;
+      }
+  // Canonicalize SGE to SLT.
+  if (Pred == ICmpInst::ICMP_SGE) {
+    std::swap(TrueBB, FalseBB);
+    Pred = ICmpInst::ICMP_SLT;
+  }
+  // Canonicalize ULE to ULT.
+  if (Pred == ICmpInst::ICMP_ULE)
+    if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
+      if (!CI->getValue().isMaxValue()) {
+        CmpRHS = ConstantInt::get(CI->getValue() + 1);
+        Pred = ICmpInst::ICMP_ULT;
+      }
+  // Canonicalize UGT to UGE.
+  if (Pred == ICmpInst::ICMP_UGT)
+    if (const ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
+      if (!CI->getValue().isMaxValue()) {
+        CmpRHS = ConstantInt::get(CI->getValue() + 1);
+        Pred = ICmpInst::ICMP_UGE;
+      }
+  // Canonicalize UGE to ULT.
+  if (Pred == ICmpInst::ICMP_UGE) {
+    std::swap(TrueBB, FalseBB);
+    Pred = ICmpInst::ICMP_ULT;
+  }
+  // For now, analyze only LT loops for signed or unsigned overflow.
+  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_ULT)
+    return;
+
+  bool isSigned = Pred == ICmpInst::ICMP_SLT;
 
-  // Get the increment instruction. Look past SExtInsts if we will
+  // Get the increment instruction. Look past casts if we will
   // be able to prove that the original induction variable doesn't
-  // undergo signed overflow.
-  const Value *OrigIncrVal = Cmp->getOperand(0);
-  const Value *IncrVal = OrigIncrVal;
-  if (SExtInst *SI = dyn_cast<SExtInst>(Cmp->getOperand(0))) {
-    if (!isa<ConstantInt>(Cmp->getOperand(1)) ||
-        !cast<ConstantInt>(Cmp->getOperand(1))->getValue()
-          .isSignedIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
-      return false;
-    IncrVal = SI->getOperand(0);
+  // undergo signed or unsigned overflow, respectively.
+  const Value *IncrVal = CmpLHS;
+  if (isSigned) {
+    if (const SExtInst *SI = dyn_cast<SExtInst>(CmpLHS)) {
+      if (!isa<ConstantInt>(CmpRHS) ||
+          !cast<ConstantInt>(CmpRHS)->getValue()
+            .isSignedIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
+        return;
+      IncrVal = SI->getOperand(0);
+    }
+  } else {
+    if (const ZExtInst *ZI = dyn_cast<ZExtInst>(CmpLHS)) {
+      if (!isa<ConstantInt>(CmpRHS) ||
+          !cast<ConstantInt>(CmpRHS)->getValue()
+            .isIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
+        return;
+      IncrVal = ZI->getOperand(0);
+    }
   }
 
   // For now, only analyze induction variables that have simple increments.
@@ -493,32 +558,36 @@
       IncrOp->getOpcode() != Instruction::Add ||
       !isa<ConstantInt>(IncrOp->getOperand(1)) ||
       !cast<ConstantInt>(IncrOp->getOperand(1))->equalsInt(1))
-    return false;
+    return;
 
   // Make sure the PHI looks like a normal IV.
   const PHINode *PN = dyn_cast<PHINode>(IncrOp->getOperand(0));
   if (!PN || PN->getNumIncomingValues() != 2)
-    return false;
+    return;
   unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
   unsigned BackEdge = !IncomingEdge;
   if (!L->contains(PN->getIncomingBlock(BackEdge)) ||
       PN->getIncomingValue(BackEdge) != IncrOp)
-    return false;
+    return;
+  if (!L->contains(TrueBB))
+    return;
 
   // For now, only analyze loops with a constant start value, so that
-  // we can easily determine if the start value is non-negative and
-  // not a maximum value which would wrap on the first iteration.
+  // we can easily determine if the start value is not a maximum value
+  // which would wrap on the first iteration.
   const Value *InitialVal = PN->getIncomingValue(IncomingEdge);
-  if (!isa<ConstantInt>(InitialVal) ||
-      cast<ConstantInt>(InitialVal)->getValue().isNegative() ||
-      cast<ConstantInt>(InitialVal)->getValue().isMaxSignedValue())
-    return false;
+  if (!isa<ConstantInt>(InitialVal))
+    return;
 
-  // The original induction variable will start at some non-negative
-  // non-max value, it counts up by one, and the loop iterates only
-  // while it remans less than (signed) some value in the same type.
-  // As such, it will always be non-negative.
-  return true;
+  // The original induction variable will start at some non-max value,
+  // it counts up by one, and the loop iterates only while it remains
+  // less than some value in the same type. As such, it will never wrap.
+  if (isSigned &&
+      !cast<ConstantInt>(InitialVal)->getValue().isMaxSignedValue())
+    NoSignedWrap = true;
+  else if (!isSigned &&
+           !cast<ConstantInt>(InitialVal)->getValue().isMaxValue())
+    NoUnsignedWrap = true;
 }
 
 bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
@@ -596,13 +665,15 @@
 
   // If we have a trip count expression, rewrite the loop's exit condition
   // using it.  We can currently only handle loops with a single exit.
-  bool OrigIVAlwaysNonNegative = false;
+  bool NoSignedWrap = false;
+  bool NoUnsignedWrap = false;
   if (!isa<SCEVCouldNotCompute>(IterationCount) && ExitingBlock)
     // Can't rewrite non-branch yet.
     if (BranchInst *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator())) {
       if (Instruction *OrigCond = dyn_cast<Instruction>(BI->getCondition())) {
-        // Determine if the OrigIV will ever have a non-zero sign bit.
-        OrigIVAlwaysNonNegative = isOrigIVAlwaysNonNegative(L, OrigCond);
+        // Determine if the OrigIV will ever undergo overflow.
+        TestOrigIVForWrap(L, BI, OrigCond,
+                          NoSignedWrap, NoUnsignedWrap);
 
         // We'll be replacing the original condition, so it'll be dead.
         DeadInsts.insert(OrigCond);
@@ -642,19 +713,38 @@
     /// If the new canonical induction variable is wider than the original,
     /// and the original has uses that are casts to wider types, see if the
     /// truncate and extend can be omitted.
-    if (isa<TruncInst>(NewVal))
+    if (PN->getType() != LargestType)
       for (Value::use_iterator UI = PN->use_begin(), UE = PN->use_end();
-           UI != UE; ++UI)
-        if (isa<ZExtInst>(UI) ||
-            (isa<SExtInst>(UI) && OrigIVAlwaysNonNegative)) {
-          Value *TruncIndVar = IndVar;
-          if (TruncIndVar->getType() != UI->getType())
-            TruncIndVar = new TruncInst(IndVar, UI->getType(), "truncindvar",
-                                        InsertPt);
+           UI != UE; ++UI) {
+        if (isa<SExtInst>(UI) && NoSignedWrap) {
+          SCEVHandle ExtendedStart =
+            SE->getSignExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStart(), LargestType);
+          SCEVHandle ExtendedStep =
+            SE->getSignExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStepRecurrence(*SE), LargestType);
+          SCEVHandle ExtendedAddRec =
+            SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
+          if (LargestType != UI->getType())
+            ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType());
+          Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
+          UI->replaceAllUsesWith(TruncIndVar);
+          if (Instruction *DeadUse = dyn_cast<Instruction>(*UI))
+            DeadInsts.insert(DeadUse);
+        }
+        if (isa<ZExtInst>(UI) && NoUnsignedWrap) {
+          SCEVHandle ExtendedStart =
+            SE->getZeroExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStart(), LargestType);
+          SCEVHandle ExtendedStep =
+            SE->getZeroExtendExpr(cast<SCEVAddRecExpr>(IndVars.back().second)->getStepRecurrence(*SE), LargestType);
+          SCEVHandle ExtendedAddRec =
+            SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
+          if (LargestType != UI->getType())
+            ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType());
+          Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
           UI->replaceAllUsesWith(TruncIndVar);
           if (Instruction *DeadUse = dyn_cast<Instruction>(*UI))
             DeadInsts.insert(DeadUse);
         }
+      }
 
     // Replace the old PHI Node with the inserted computation.
     PN->replaceAllUsesWith(NewVal);

Modified: llvm/trunk/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll?rev=64532&r1=64531&r2=64532&view=diff

==============================================================================
--- llvm/trunk/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll (original)
+++ llvm/trunk/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll Fri Feb 13 20:31:09 2009
@@ -60,3 +60,41 @@
 return:		; preds = %bb1.return_crit_edge, %entry
 	ret void
 }
+
+; Test cases from PR1301:
+
+define void @kinds__srangezero([21 x i32]* nocapture %a) nounwind {
+bb.thread:
+  br label %bb
+
+bb:             ; preds = %bb, %bb.thread
+  %i.0.reg2mem.0 = phi i8 [ -10, %bb.thread ], [ %tmp7, %bb ]           ; <i8> [#uses=2]
+  %tmp12 = sext i8 %i.0.reg2mem.0 to i32                ; <i32> [#uses=1]
+  %tmp4 = add i32 %tmp12, 10            ; <i32> [#uses=1]
+  %tmp5 = getelementptr [21 x i32]* %a, i32 0, i32 %tmp4                ; <i32*> [#uses=1]
+  store i32 0, i32* %tmp5
+  %tmp7 = add i8 %i.0.reg2mem.0, 1              ; <i8> [#uses=2]
+  %0 = icmp sgt i8 %tmp7, 10            ; <i1> [#uses=1]
+  br i1 %0, label %return, label %bb
+
+return:         ; preds = %bb
+  ret void
+}
+
+define void @kinds__urangezero([21 x i32]* nocapture %a) nounwind {
+bb.thread:
+  br label %bb
+
+bb:             ; preds = %bb, %bb.thread
+  %i.0.reg2mem.0 = phi i8 [ 10, %bb.thread ], [ %tmp7, %bb ]            ; <i8> [#uses=2]
+  %tmp12 = sext i8 %i.0.reg2mem.0 to i32                ; <i32> [#uses=1]
+  %tmp4 = add i32 %tmp12, -10           ; <i32> [#uses=1]
+  %tmp5 = getelementptr [21 x i32]* %a, i32 0, i32 %tmp4                ; <i32*> [#uses=1]
+  store i32 0, i32* %tmp5
+  %tmp7 = add i8 %i.0.reg2mem.0, 1              ; <i8> [#uses=2]
+  %0 = icmp sgt i8 %tmp7, 30            ; <i1> [#uses=1]
+  br i1 %0, label %return, label %bb
+
+return:         ; preds = %bb
+  ret void
+}





More information about the llvm-commits mailing list