[llvm] 491ac28 - [LoopFlatten] Use SCEV and Loop APIs to identify increment and trip count

Rosie Sumpter via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 27 01:02:34 PDT 2021


Author: Rosie Sumpter
Date: 2021-07-27T08:42:59+01:00
New Revision: 491ac2802805f65c0960ae6685f9599048517a97

URL: https://github.com/llvm/llvm-project/commit/491ac2802805f65c0960ae6685f9599048517a97
DIFF: https://github.com/llvm/llvm-project/commit/491ac2802805f65c0960ae6685f9599048517a97.diff

LOG: [LoopFlatten] Use SCEV and Loop APIs to identify increment and trip count

Replace pattern-matching with existing SCEV and Loop APIs as a more
robust way of identifying the loop increment and trip count. Also
rename 'Limit' to 'TripCount' to be consistent with terminology.

Differential Revision: https://reviews.llvm.org/D106580

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp
    llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 38bf0bca6cdd1..f54289f85ef53 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -63,7 +63,7 @@ static cl::opt<bool>
     AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
                      cl::init(false),
                      cl::desc("Assume that the product of the two iteration "
-                              "limits will never overflow"));
+                              "trip counts will never overflow"));
 
 static cl::opt<bool>
     WidenIV("loop-flatten-widen-iv", cl::Hidden,
@@ -74,10 +74,12 @@ static cl::opt<bool>
 struct FlattenInfo {
   Loop *OuterLoop = nullptr;
   Loop *InnerLoop = nullptr;
+  // These PHINodes correspond to loop induction variables, which are expected
+  // to start at zero and increment by one on each loop.
   PHINode *InnerInductionPHI = nullptr;
   PHINode *OuterInductionPHI = nullptr;
-  Value *InnerLimit = nullptr;
-  Value *OuterLimit = nullptr;
+  Value *InnerTripCount = nullptr;
+  Value *OuterTripCount = nullptr;
   BinaryOperator *InnerIncrement = nullptr;
   BinaryOperator *OuterIncrement = nullptr;
   BranchInst *InnerBranch = nullptr;
@@ -91,12 +93,12 @@ struct FlattenInfo {
   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
 };
 
-// Finds the induction variable, increment and limit for a simple loop that we
-// can flatten.
+// Finds the induction variable, increment and trip count for a simple loop that
+// we can flatten.
 static bool findLoopComponents(
     Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
-    PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment,
-    BranchInst *&BackBranch, ScalarEvolution *SE) {
+    PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
+    BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
   LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
 
   if (!L->isLoopSimplifyForm()) {
@@ -104,6 +106,13 @@ static bool findLoopComponents(
     return false;
   }
 
+  // Currently, to simplify the implementation, the Loop induction variable must
+  // start at zero and increment with a step size of one.
+  if (!L->isCanonical(*SE)) {
+    LLVM_DEBUG(dbgs() << "Loop is not canonical\n");
+    return false;
+  }
+
   // There must be exactly one exiting block, and it must be the same at the
   // latch.
   BasicBlock *Latch = L->getLoopLatch();
@@ -144,40 +153,44 @@ static bool findLoopComponents(
   IterationInstructions.insert(Compare);
   LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
 
-  // Find increment and limit from the compare
-  Increment = nullptr;
-  if (match(Compare->getOperand(0),
-            m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
-    Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0));
-    Limit = Compare->getOperand(1);
-  } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE &&
-             match(Compare->getOperand(1),
-                   m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
-    Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1));
-    Limit = Compare->getOperand(0);
-  }
-  if (!Increment || Increment->hasNUsesOrMore(3)) {
-    LLVM_DEBUG(dbgs() << "Cound not find valid increment\n");
+  // Find increment and trip count.
+  // There are exactly 2 incoming values to the induction phi; one from the
+  // pre-header and one from the latch. The incoming latch value is the
+  // increment variable.
+  Increment =
+      dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));
+  if (Increment->hasNUsesOrMore(3)) {
+    LLVM_DEBUG(dbgs() << "Could not find valid increment\n");
     return false;
   }
+  // The trip count is the RHS of the compare. If this doesn't match the trip
+  // count computed by SCEV then this is either because the trip count variable
+  // has been widened (then leave the trip count as it is), or because it is a
+  // constant and another transformation has changed the compare, e.g.
+  // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten
+  // the loop (yet).
+  TripCount = Compare->getOperand(1);
+  const SCEV *SCEVTripCount =
+      SE->getTripCountFromExitCount(SE->getBackedgeTakenCount(L));
+  if (SE->getSCEV(TripCount) != SCEVTripCount) {
+    if (!IsWidened) {
+      LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+      return false;
+    }
+    auto TripCountInst = dyn_cast<Instruction>(TripCount);
+    if (!TripCountInst) {
+      LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+      return false;
+    }
+    if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
+        SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
+      LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+      return false;
+    }
+  }
   IterationInstructions.insert(Increment);
   LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
-  LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump());
-
-  assert(InductionPHI->getNumIncomingValues() == 2);
-
-  if (InductionPHI->getIncomingValueForBlock(Latch) != Increment) {
-    LLVM_DEBUG(
-        dbgs() << "Incoming value from latch is not the increment inst\n");
-    return false;
-  }
-
-  auto *CI = dyn_cast<ConstantInt>(
-      InductionPHI->getIncomingValueForBlock(L->getLoopPreheader()));
-  if (!CI || !CI->isZero()) {
-    LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump());
-    return false;
-  }
+  LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
 
   LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
   return true;
@@ -300,7 +313,7 @@ checkOuterLoopInsts(FlattenInfo &FI,
       // Multiplies of the outer iteration variable and inner iteration
       // count will be optimised out.
       if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
-                            m_Specific(FI.InnerLimit))))
+                            m_Specific(FI.InnerTripCount))))
         continue;
       InstructionCost Cost =
           TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
@@ -325,16 +338,16 @@ checkOuterLoopInsts(FlattenInfo &FI,
 static bool checkIVUsers(FlattenInfo &FI) {
   // We require all uses of both induction variables to match this pattern:
   //
-  //   (OuterPHI * InnerLimit) + InnerPHI
+  //   (OuterPHI * InnerTripCount) + InnerPHI
   //
   // Any uses of the induction variables not matching that pattern would
   // require a div/mod to reconstruct in the flattened loop, so the
   // transformation wouldn't be profitable.
 
-  Value *InnerLimit = FI.InnerLimit;
+  Value *InnerTripCount = FI.InnerTripCount;
   if (FI.Widened &&
-      (isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit)))
-    InnerLimit = cast<Instruction>(InnerLimit)->getOperand(0);
+      (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
+    InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
 
   // Check that all uses of the inner loop's induction variable match the
   // expected pattern, recording the uses of the outer IV.
@@ -368,7 +381,7 @@ static bool checkIVUsers(FlattenInfo &FI) {
                             m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
                             m_Value(MatchedItCount)));
 
-    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
+    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
       ValidOuterPHIUses.insert(MatchedMul);
       FI.LinearIVUses.insert(U);
@@ -417,7 +430,7 @@ static bool checkIVUsers(FlattenInfo &FI) {
 }
 
 // Return an OverflowResult dependant on if overflow of the multiplication of
-// InnerLimit and OuterLimit can be assumed not to happen.
+// InnerTripCount and OuterTripCount can be assumed not to happen.
 static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
                                     AssumptionCache *AC) {
   Function *F = FI.OuterLoop->getHeader()->getParent();
@@ -430,7 +443,7 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
   // Check if the multiply could not overflow due to known ranges of the
   // input values.
   OverflowResult OR = computeOverflowForUnsignedMul(
-      FI.InnerLimit, FI.OuterLimit, DL, AC,
+      FI.InnerTripCount, FI.OuterTripCount, DL, AC,
       FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
   if (OR != OverflowResult::MayOverflow)
     return OR;
@@ -461,21 +474,23 @@ static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
                                ScalarEvolution *SE, AssumptionCache *AC,
                                const TargetTransformInfo *TTI) {
   SmallPtrSet<Instruction *, 8> IterationInstructions;
-  if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
-                          FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
+  if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
+                          FI.InnerInductionPHI, FI.InnerTripCount,
+                          FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
     return false;
-  if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI,
-                          FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE))
+  if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
+                          FI.OuterInductionPHI, FI.OuterTripCount,
+                          FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
     return false;
 
-  // Both of the loop limit values must be invariant in the outer loop
+  // Both of the loop trip count values must be invariant in the outer loop
   // (non-instructions are all inherently invariant).
-  if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) {
-    LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n");
+  if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
+    LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n");
     return false;
   }
-  if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) {
-    LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n");
+  if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
+    LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n");
     return false;
   }
 
@@ -515,9 +530,9 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
     ORE.emit(Remark);
   }
 
-  Value *NewTripCount =
-      BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount",
-                                FI.OuterLoop->getLoopPreheader()->getTerminator());
+  Value *NewTripCount = BinaryOperator::CreateMul(
+      FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
+      FI.OuterLoop->getLoopPreheader()->getTerminator());
   LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
              NewTripCount->dump());
 
@@ -581,7 +596,7 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
 
   // If both induction types are less than the maximum legal integer width,
   // promote both to the widest type available so we know calculating
-  // (OuterLimit * InnerLimit) as the new trip count is safe.
+  // (OuterTripCount * InnerTripCount) as the new trip count is safe.
   if (InnerType != OuterType ||
       InnerType->getScalarSizeInBits() >= MaxLegalSize ||
       MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {

diff  --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
index c563078e25da1..e7c8697b14c68 100644
--- a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
@@ -341,6 +341,37 @@ for.end8:                                         ; preds = %for.inc6
   ret i32 10
 }
 
+; When the loop trip count is a constant (e.g. 20) and the step size is
+; 1, InstCombine causes the transformation icmp ult i32 %inc, 20 ->
+; icmp ult i32 %j, 19. In this case a valid trip count is not found so
+; the loop is not flattened. 
+define i32 @test9(i32* nocapture %A) {
+entry:
+  br label %for.cond1.preheader
+
+for.cond1.preheader:
+  %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
+  %mul = mul i32 %i.017, 20
+  br label %for.body4
+
+for.cond.cleanup3:
+  %inc6 = add i32 %i.017, 1
+  %cmp = icmp ult i32 %inc6, 11
+  br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
+
+for.body4:
+  %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
+  %add = add i32 %j.016, %mul
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
+  store i32 30, i32* %arrayidx, align 4
+  %inc = add nuw nsw i32 %j.016, 1
+  %cmp2 = icmp ult i32 %j.016, 19
+  br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
+
+for.cond.cleanup:
+  %0 = load i32, i32* %A, align 4
+  ret i32 %0
+}
 
 ; Outer loop conditional phi
 define i32 @e() {


        


More information about the llvm-commits mailing list