[llvm] 46abd1f - [LoopFlatten] Fix assertion failure in checkOverflow

Rosie Sumpter via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 13 02:15:41 PDT 2021


Author: Rosie Sumpter
Date: 2021-08-13T10:07:49+01:00
New Revision: 46abd1fbe88fe1f4b0e6cb2b87f3e7d148bbadf7

URL: https://github.com/llvm/llvm-project/commit/46abd1fbe88fe1f4b0e6cb2b87f3e7d148bbadf7
DIFF: https://github.com/llvm/llvm-project/commit/46abd1fbe88fe1f4b0e6cb2b87f3e7d148bbadf7.diff

LOG: [LoopFlatten] Fix assertion failure in checkOverflow

There is an assertion failure in computeOverflowForUnsignedMul
(called from checkOverflow) because the inner and outer trip counts
have different types. This occurs when the IV has been widened but
the loop components are not successfully rediscovered afterwards.
This is fixed by refactoring the code in findLoopComponents that
identifies the loop's trip count.

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp
    llvm/test/Transforms/LoopFlatten/widen-iv.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 61f6d214bd6c..3343bdd0b573 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -93,6 +93,17 @@ struct FlattenInfo {
   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
 };
 
+// Sets TripCount = TC, adds Increment to IterationInstructions; returns true.
+static bool
+setLoopComponents(Value *TC, Value *&TripCount, BinaryOperator *Increment,
+                  SmallPtrSetImpl<Instruction *> &IterationInstructions) {
+  TripCount = TC;
+  IterationInstructions.insert(Increment);
+  LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump());
+  LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
+  LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
+  return true;
+}
 // Finds the induction variable, increment and trip count for a simple loop that
 // we can flatten.
 static bool findLoopComponents(
@@ -164,49 +175,63 @@ static bool findLoopComponents(
     return false;
   }
   // The trip count is the RHS of the compare. If this doesn't match the trip
-  // count computed by SCEV then this is either because the trip count variable
-  // has been widened (then leave the trip count as it is), or because it is a
-  // constant and another transformation has changed the compare, e.g.
-  // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1.
-  TripCount = Compare->getOperand(1);
+  // count computed by SCEV then this is because the trip count variable
+  // has been widened so the types don't match, or because it is a constant and
+  // another transformation has changed the compare (e.g. icmp ult %inc,
+  // tripcount -> icmp ult %j, tripcount-1), or both.
+  Value *RHS = Compare->getOperand(1);
   const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
   if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
     LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
     return false;
   }
   const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
-  if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) {
-    ConstantInt *RHS = dyn_cast<ConstantInt>(TripCount);
-    if (!RHS) {
-      LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
-      return false;
-    }
-    // The L->isCanonical check above ensures we only get here if the loop
-    // increments by 1 on each iteration, so the RHS of the Compare is
-    // tripcount-1 (i.e equivalent to the backedge taken count).
-    assert(SE->getSCEV(RHS) == BackedgeTakenCount &&
-           "Expected RHS of compare to be equal to the backedge taken count");
-    ConstantInt *One = ConstantInt::get(RHS->getType(), 1);
-    TripCount = ConstantInt::get(TripCount->getContext(),
-                                 RHS->getValue() + One->getValue());
-  } else if (SE->getSCEV(TripCount) != SCEVTripCount) {
-    auto *TripCountInst = dyn_cast<Instruction>(TripCount);
-    if (!TripCountInst) {
-      LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
-      return false;
+  const SCEV *SCEVRHS = SE->getSCEV(RHS);
+  if (SCEVRHS == SCEVTripCount)
+    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
+  ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
+  if (ConstantRHS) {
+    const SCEV *BackedgeTCExt = nullptr;
+    if (IsWidened) {
+      const SCEV *SCEVTripCountExt;
+      // Find the extended backedge taken count and extended trip count using
+      // SCEV. One of these should now match the RHS of the compare.
+      BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
+      SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt);
+      if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
+        LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+        return false;
+      }
     }
-    if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
-        SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
-      LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
-      return false;
+    // If the RHS of the compare is equal to the backedge taken count we need
+    // to add one to get the trip count.
+    if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
+      ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
+      Value *NewRHS = ConstantInt::get(
+          ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
+      return setLoopComponents(NewRHS, TripCount, Increment,
+                               IterationInstructions);
     }
+    return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
   }
-  IterationInstructions.insert(Increment);
-  LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
-  LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
-
-  LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
-  return true;
+  // If the RHS isn't a constant then check that the reason it doesn't match
+  // the SCEV trip count is because the RHS is a ZExt or SExt instruction
+  // (and take the trip count to be the RHS).
+  if (!IsWidened) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  auto *TripCountInst = dyn_cast<Instruction>(RHS);
+  if (!TripCountInst) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
+      SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
+    LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+    return false;
+  }
+  return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
 }
 
 static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {

diff  --git a/llvm/test/Transforms/LoopFlatten/widen-iv.ll b/llvm/test/Transforms/LoopFlatten/widen-iv.ll
index a6b13e43c64a..abd70138e4c1 100644
--- a/llvm/test/Transforms/LoopFlatten/widen-iv.ll
+++ b/llvm/test/Transforms/LoopFlatten/widen-iv.ll
@@ -525,6 +525,52 @@ for.cond.cleanup:
   ret void
 }
 
+; Identify a constant trip count after the i8 IVs have been widened to i64.
+define i32 @constTripCount() {
+; CHECK-LABEL: @constTripCount(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 20, 20
+; CHECK-NEXT:    br label [[I_LOOP:%.*]]
+; CHECK:       i.loop:
+; CHECK-NEXT:    [[INDVAR1:%.*]] = phi i64 [ [[INDVAR_NEXT2:%.*]], [[J_LOOPDONE:%.*]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    br label [[J_LOOP:%.*]]
+; CHECK:       j.loop:
+; CHECK-NEXT:    [[INDVAR:%.*]] = phi i64 [ 0, [[I_LOOP]] ]
+; CHECK-NEXT:    call void @payload()
+; CHECK-NEXT:    [[INDVAR_NEXT:%.*]] = add i64 [[INDVAR]], 1
+; CHECK-NEXT:    [[J_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT]], 20
+; CHECK-NEXT:    br label [[J_LOOPDONE]]
+; CHECK:       j.loopdone:
+; CHECK-NEXT:    [[INDVAR_NEXT2]] = add i64 [[INDVAR1]], 1
+; CHECK-NEXT:    [[I_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT2]], [[FLATTEN_TRIPCOUNT]]
+; CHECK-NEXT:    br i1 [[I_ATEND]], label [[I_LOOPDONE:%.*]], label [[I_LOOP]]
+; CHECK:       i.loopdone:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  br label %i.loop
+
+i.loop:
+  %i = phi i8 [ 0, %entry ], [ %i.inc, %j.loopdone ]
+  br label %j.loop
+
+j.loop:
+  %j = phi i8 [ 0, %i.loop ], [ %j.inc, %j.loop ]
+  call void @payload()
+  %j.inc = add i8 %j, 1
+  %j.atend = icmp eq i8 %j.inc, 20 ; Inner loop: constant trip count of 20.
+  br i1 %j.atend, label %j.loopdone, label %j.loop
+
+j.loopdone:
+  %i.inc = add i8 %i, 1
+  %i.atend = icmp eq i8 %i.inc, 20 ; Outer loop: constant trip count of 20.
+  br i1 %i.atend, label %i.loopdone, label %i.loop
+
+i.loopdone:
+  ret i32 0
+}
+
+declare void @payload()
 declare dso_local i32 @use_32(i32)
 declare dso_local i32 @use_16(i16)
 declare dso_local i32 @use_64(i64)


        


More information about the llvm-commits mailing list