[llvm] 2df8bf9 - [LoopFlatten] Fix missed LoopFlatten opportunity

Rosie Sumpter via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 29 01:53:28 PDT 2021


Author: Rosie Sumpter
Date: 2021-07-29T09:47:41+01:00
New Revision: 2df8bf9339e43de63d8d28e07182e1d6d7ffb843

URL: https://github.com/llvm/llvm-project/commit/2df8bf9339e43de63d8d28e07182e1d6d7ffb843
DIFF: https://github.com/llvm/llvm-project/commit/2df8bf9339e43de63d8d28e07182e1d6d7ffb843.diff

LOG: [LoopFlatten] Fix missed LoopFlatten opportunity

When the trip count of the inner loop is a constant, the InstCombine
pass now performs a transformation such as icmp ult i32 %inc, tripcount ->
icmp ult i32 %j, tripcount-step (where %j is the inner loop induction
variable and %inc is add %j, step); this is now accounted for when
identifying the trip count of the loop. This compare is also an acceptable
use of %j (provided the step is 1), so it is ignored as long as the compare
it is used in is also the condition of the inner branch.

Differential Revision: https://reviews.llvm.org/D105802

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp
    llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
    llvm/test/Transforms/LoopFlatten/loop-flatten.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 46f12aafb552d..9b94cf9dde9f5 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -167,8 +167,7 @@ static bool findLoopComponents(
   // count computed by SCEV then this is either because the trip count variable
   // has been widened (then leave the trip count as it is), or because it is a
   // constant and another transformation has changed the compare, e.g.
-  // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten
-  // the loop (yet).
+  // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1.
   TripCount = Compare->getOperand(1);
   const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
   if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
@@ -176,12 +175,22 @@ static bool findLoopComponents(
     return false;
   }
   const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
-  if (SE->getSCEV(TripCount) != SCEVTripCount) {
-    if (!IsWidened) {
-      LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
-      return false;
-    }
-    auto TripCountInst = dyn_cast<Instruction>(TripCount);
+  if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) {
+    ConstantInt *RHS = dyn_cast<ConstantInt>(TripCount);
+    // If the IV hasn't been widened, the only way the RHS of the Compare can be
+    // different from the SCEV trip count is if it is a constant which has been
+    // changed by another transformation.
+    assert(RHS && "Expected RHS of compare to be constant");
+    // The L->isCanonical check above ensures we only get here if the loop
+    // increments by 1 on each iteration, so the RHS of the Compare is
+    // tripcount-1 (i.e equivalent to the backedge taken count).
+    assert(SE->getSCEV(RHS) == BackedgeTakenCount &&
+           "Expected RHS of compare to be equal to the backedge taken count");
+    ConstantInt *One = ConstantInt::get(RHS->getType(), 1);
+    TripCount = ConstantInt::get(TripCount->getContext(),
+                                 RHS->getValue() + One->getValue());
+  } else if (SE->getSCEV(TripCount) != SCEVTripCount) {
+    auto *TripCountInst = dyn_cast<Instruction>(TripCount);
     if (!TripCountInst) {
       LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
       return false;
@@ -368,6 +377,13 @@ static bool checkIVUsers(FlattenInfo &FI) {
       U = *U->user_begin();
     }
 
+    // If the use is in the compare (which is also the condition of the inner
+    // branch) then the compare has been altered by another transformation e.g
+    // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is
+    // a constant. Ignore this use as the compare gets removed later anyway.
+    if (U == FI.InnerBranch->getCondition())
+      continue;
+
     LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump());
 
     Value *MatchedMul;

diff  --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
index 3806414e1e503..1e264ff427524 100644
--- a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
@@ -341,38 +341,111 @@ for.end8:                                         ; preds = %for.inc6
   ret i32 10
 }
 
-; When the loop trip count is a constant (e.g. 20) and the step size is
-; 1, InstCombine causes the transformation icmp ult i32 %inc, 20 ->
-; icmp ult i32 %j, 19. In this case a valid trip count is not found so
-; the loop is not flattened. 
-define i32 @test9(i32* nocapture %A) {
+; test_10, test_11 and test_12 are for the case when the
+; inner trip count is a constant, then the InstCombine pass
+; makes the transformation icmp ult i32 %inc, tripcount ->
+; icmp ult i32 %j, tripcount-step.
+
+; test_10: The step is not 1.
+define i32 @test_10(i32* nocapture %A) {
 entry:
   br label %for.cond1.preheader
 
 for.cond1.preheader:
-  %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
+  %i.017 = phi i32 [ 0, %entry ], [ %inc, %for.cond.cleanup3 ]
   %mul = mul i32 %i.017, 20
   br label %for.body4
 
+for.body4:
+  %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %add5, %for.body4 ]
+  %add = add i32 %j.016, %mul
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
+  store i32 30, i32* %arrayidx, align 4
+  %add5 = add nuw nsw i32 %j.016, 2
+  %cmp2 = icmp ult i32 %j.016, 18
+  br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
+
 for.cond.cleanup3:
-  %inc6 = add i32 %i.017, 1
-  %cmp = icmp ult i32 %inc6, 11
+  %inc = add i32 %i.017, 1
+  %cmp = icmp ult i32 %inc, 11
   br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
 
+for.cond.cleanup:
+  %0 = load i32, i32* %A, align 4
+  ret i32 %0
+}
+
+; test_11: The inner inducation variable is used in a compare which
+; isn't the condition of the inner branch.
+define i32 @test_11(i32* nocapture %A) {
+entry:
+  br label %for.cond1.preheader
+
+for.cond1.preheader:
+  %i.020 = phi i32 [ 0, %entry ], [ %inc7, %for.cond.cleanup3 ]
+  %mul = mul i32 %i.020, 20
+  br label %for.body4
+
 for.body4:
-  %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
-  %add = add i32 %j.016, %mul
+  %j.019 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
+  %cmp5 = icmp ult i32 %j.019, 5
+  %cond = select i1 %cmp5, i32 30, i32 15
+  %add = add i32 %j.019, %mul
   %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
-  store i32 30, i32* %arrayidx, align 4
-  %inc = add nuw nsw i32 %j.016, 1
-  %cmp2 = icmp ult i32 %j.016, 19
+  store i32 %cond, i32* %arrayidx, align 4
+  %inc = add nuw nsw i32 %j.019, 1
+  %cmp2 = icmp ult i32 %j.019, 19
   br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
 
+for.cond.cleanup3:
+  %inc7 = add i32 %i.020, 1
+  %cmp = icmp ult i32 %inc7, 11
+  br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
+
 for.cond.cleanup:
   %0 = load i32, i32* %A, align 4
   ret i32 %0
 }
 
+; test_12: Incoming phi node value for preheader is a variable
+define i32 @test_12(i32* %A) {
+entry:
+  br label %while.cond1.preheader
+
+while.cond1.preheader:
+  %j.017 = phi i32 [ 0, %entry ], [ %j.1, %while.end ]
+  %i.016 = phi i32 [ 0, %entry ], [ %inc4, %while.end ]
+  %mul = mul i32 %i.016, 20
+  %cmp214 = icmp ult i32 %j.017, 20
+  br i1 %cmp214, label %while.body3.preheader, label %while.end
+
+while.body3.preheader:
+  br label %while.body3
+
+while.body3:
+  %j.115 = phi i32 [ %inc, %while.body3 ], [ %j.017, %while.body3.preheader ]
+  %add = add i32 %j.115, %mul
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
+  store i32 30, i32* %arrayidx, align 4
+  %inc = add nuw nsw i32 %j.115, 1
+  %cmp2 = icmp ult i32 %j.115, 19
+  br i1 %cmp2, label %while.body3, label %while.end.loopexit
+
+while.end.loopexit:
+  %inc.lcssa = phi i32 [ %inc, %while.body3 ]
+  br label %while.end
+  
+while.end:
+  %j.1 = phi i32 [ %j.017, %while.cond1.preheader], [ %inc.lcssa, %while.end.loopexit ]
+  %inc4 = add i32 %i.016, 1
+  %cmp = icmp ult i32 %inc4, 11
+  br i1 %cmp, label %while.cond1.preheader, label %while.end5
+
+while.end5:
+  %0 = load i32, i32* %A, align 4
+  ret i32 %0
+}
+
 ; Outer loop conditional phi
 define i32 @e() {
 entry:

diff  --git a/llvm/test/Transforms/LoopFlatten/loop-flatten.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten.ll
index 2d7f897472cea..73f50c0c843c5 100644
--- a/llvm/test/Transforms/LoopFlatten/loop-flatten.ll
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten.ll
@@ -586,6 +586,59 @@ for.end8:                                         ; preds = %for.inc6
   ret i32 10
 }
 
+; When the inner loop trip count is a constant and the step
+; is 1, the InstCombine pass causes the transformation e.g.
+; icmp ult i32 %inc, 20 -> icmp ult i32 %j, 19. This doesn't
+; match the pattern (OuterPHI * InnerTripCount) + InnerPHI but
+; we should still flatten the loop as the compare is removed
+; later anyway.
+define i32 @test9(i32* nocapture %A) {
+entry:
+  br label %for.cond1.preheader
+; CHECK-LABEL: test9
+; CHECK: entry:
+; CHECK: %flatten.tripcount = mul i32 20, 11
+; CHECK: br label %for.cond1.preheader
+
+for.cond1.preheader:
+  %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
+  %mul = mul i32 %i.017, 20
+  br label %for.body4
+; CHECK: for.cond1.preheader:
+; CHECK:   %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ]
+; CHECK:   %mul = mul i32 %i.017, 20
+; CHECK:   br label %for.body4
+
+for.cond.cleanup3:
+  %inc6 = add i32 %i.017, 1
+  %cmp = icmp ult i32 %inc6, 11
+  br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
+; CHECK: for.cond.cleanup3:
+; CHECK:   %inc6 = add i32 %i.017, 1
+; CHECK:   %cmp = icmp ult i32 %inc6, %flatten.tripcount
+; CHECK:   br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup
+
+for.body4:
+  %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ]
+  %add = add i32 %j.016, %mul
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add
+  store i32 30, i32* %arrayidx, align 4
+  %inc = add nuw nsw i32 %j.016, 1
+  %cmp2 = icmp ult i32 %j.016, 19
+  br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
+; CHECK: for.body4
+; CHECK:   %j.016 = phi i32 [ 0, %for.cond1.preheader ]
+; CHECK:   %add = add i32 %j.016, %mul
+; CHECK:   %arrayidx = getelementptr inbounds i32, i32* %A, i32 %i.017
+; CHECK:   store i32 30, i32* %arrayidx, align 4
+; CHECK:   %inc = add nuw nsw i32 %j.016, 1
+; CHECK:   %cmp2 = icmp ult i32 %j.016, 19
+; CHECK:   br label %for.cond.cleanup3
+
+for.cond.cleanup:
+  %0 = load i32, i32* %A, align 4
+  ret i32 %0
+}
 
 declare i32 @func(i32)
 


        


More information about the llvm-commits mailing list