[llvm] d24a0e8 - [SCEV] Use constant range of RHS to prove NUW on narrow IV in trip count logic

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 5 15:37:01 PDT 2021


Author: Philip Reames
Date: 2021-11-05T15:36:47-07:00
New Revision: d24a0e88576dca1c475a7f48d4361136a46f9b72

URL: https://github.com/llvm/llvm-project/commit/d24a0e88576dca1c475a7f48d4361136a46f9b72
DIFF: https://github.com/llvm/llvm-project/commit/d24a0e88576dca1c475a7f48d4361136a46f9b72.diff

LOG: [SCEV] Use constant range of RHS to prove NUW on narrow IV in trip count logic

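The basic idea: given a zero-extended narrow IV, we can prove the inner IV is NUW if we can show that there is a value the inner IV must take before it could overflow, and that reaching this value must exit the loop. For example, an i8 IV with stride 1 that is compared `ult` against 255 must reach 255 before wrapping, and reaching 255 exits the loop.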

Differential Revision: https://reviews.llvm.org/D109457

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/trip-count-implied-addrec.ll
    llvm/test/Transforms/IndVarSimplify/finite-exit-comparisons.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e97f793427f73..764c637ce3b09 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11792,9 +11792,34 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
 
           SmallVector<const SCEV*> Operands{AR->operands()};
           Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
-
-          setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
         }
+
+        auto canProveNUW = [&]() {
+          if (!isLoopInvariant(RHS, L))
+            return false;
+
+          if (!isKnownNonZero(AR->getStepRecurrence(*this)))
+            // We need the sequence defined by AR to strictly increase in the
+            // unsigned integer domain for the logic below to hold.
+            return false;
+
+          const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
+          const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
+          // If RHS <=u Limit, then there must exist a value V in the sequence
+          // defined by AR (i.e. {Start,+,Step}) such that V >=u RHS, and
+          // V <=u UINT_MAX.  Thus, we must exit the loop before unsigned
+          // overflow occurs.  This limit also implies that a signed comparison
+          // (in the wide bitwidth) is equivalent to an unsigned comparison as
+          // the high bits on both sides must be zero.
+          APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
+          APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
+          Limit = Limit.zext(OuterBitWidth);
+          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
+        };
+        if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
+          Flags = setFlags(Flags, SCEV::FlagNUW);
+
+        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
         if (AR->hasNoUnsignedWrap()) {
           // Emulate what getZeroExtendExpr would have done during construction
           // if we'd been able to infer the fact just above at that time.

diff  --git a/llvm/test/Analysis/ScalarEvolution/trip-count-implied-addrec.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-implied-addrec.ll
index f7e978e1faf4d..1d9babddedc3a 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-implied-addrec.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-implied-addrec.ll
@@ -279,11 +279,11 @@ for.end:                                          ; preds = %for.body, %entry
 define void @rhs_narrow_range(i16 %n.raw) {
 ; CHECK-LABEL: 'rhs_narrow_range'
 ; CHECK-NEXT:  Determining loop execution counts for: @rhs_narrow_range
-; CHECK-NEXT:  Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + (1 umax (2 * (zext i7 (trunc i16 (%n.raw /u 2) to i7) to i16))<nuw><nsw>))<nsw>
+; CHECK-NEXT:  Loop %for.body: max backedge-taken count is 253
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + (1 umax (2 * (zext i7 (trunc i16 (%n.raw /u 2) to i7) to i16))<nuw><nsw>))<nsw>
 ; CHECK-NEXT:   Predicates:
-; CHECK-NEXT:    {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK:       Loop %for.body: Trip multiple is 1
 ;
 entry:
   %n = and i16 %n.raw, 254
@@ -301,6 +301,150 @@ for.end:                                          ; preds = %for.body, %entry
   ret void
 }
 
+define void @ugt_constant_rhs(i16 %n.raw, i8 %start) mustprogress {
+; Negative test: with this ugt exit compare no trip count is computed.
+; CHECK-LABEL: 'ugt_constant_rhs'
+; CHECK-NEXT:  Determining loop execution counts for: @ugt_constant_rhs
+; CHECK-NEXT:  Loop %for.body: Unpredictable backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: Unpredictable max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: Unpredictable predicated backedge-taken count.
+;
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.body
+  %iv = phi i8 [ %iv.next, %for.body ], [ %start, %entry ]
+  %iv.next = add i8 %iv, 1
+  %zext = zext i8 %iv.next to i16
+  %cmp = icmp ugt i16 %zext, 254
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
+define void @ult_constant_rhs(i16 %n.raw, i8 %start) {
+; Stride 1, RHS 255: 255 <=u 255 - (1 - 1), so NUW on the i8 IV is provable.
+; CHECK-LABEL: 'ult_constant_rhs'
+; CHECK-NEXT:  Determining loop execution counts for: @ult_constant_rhs
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (255 + (-1 * (zext i8 (1 + %start) to i16))<nsw>)<nsw>
+; CHECK-NEXT:  Loop %for.body: max backedge-taken count is 255
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (255 + (-1 * (zext i8 (1 + %start) to i16))<nsw>)<nsw>
+; CHECK-NEXT:   Predicates:
+; CHECK:       Loop %for.body: Trip multiple is 1
+;
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.body
+  %iv = phi i8 [ %iv.next, %for.body ], [ %start, %entry ]
+  %iv.next = add i8 %iv, 1
+  %zext = zext i8 %iv.next to i16
+  %cmp = icmp ult i16 %zext, 255
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
+define void @ult_constant_rhs_stride2(i16 %n.raw, i8 %start) {
+; Stride 2, RHS 254: 254 <=u 255 - (2 - 1), so NUW on the i8 IV is provable.
+; CHECK-LABEL: 'ult_constant_rhs_stride2'
+; CHECK-NEXT:  Determining loop execution counts for: @ult_constant_rhs_stride2
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is ((1 + (-1 * (zext i8 (2 + %start) to i16))<nsw> + (254 umax (zext i8 (2 + %start) to i16))) /u 2)
+; CHECK-NEXT:  Loop %for.body: max backedge-taken count is 127
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is ((1 + (-1 * (zext i8 (2 + %start) to i16))<nsw> + (254 umax (zext i8 (2 + %start) to i16))) /u 2)
+; CHECK-NEXT:   Predicates:
+; CHECK:       Loop %for.body: Trip multiple is 1
+;
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.body
+  %iv = phi i8 [ %iv.next, %for.body ], [ %start, %entry ]
+  %iv.next = add i8 %iv, 2
+  %zext = zext i8 %iv.next to i16
+  %cmp = icmp ult i16 %zext, 254
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
+define void @ult_constant_rhs_stride2_neg(i16 %n.raw, i8 %start) {
+; Negative: RHS 255 >u 255 - (2 - 1) = 254, so only a predicated (nusw) count exists.
+; CHECK-LABEL: 'ult_constant_rhs_stride2_neg'
+; CHECK-NEXT:  Determining loop execution counts for: @ult_constant_rhs_stride2_neg
+; CHECK-NEXT:  Loop %for.body: Unpredictable backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: Unpredictable max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is ((256 + (-1 * (zext i8 (2 + %start) to i16))<nsw>)<nsw> /u 2)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:    {(2 + %start),+,2}<%for.body> Added Flags: <nusw>
+;
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.body
+  %iv = phi i8 [ %iv.next, %for.body ], [ %start, %entry ]
+  %iv.next = add i8 %iv, 2
+  %zext = zext i8 %iv.next to i16
+  %cmp = icmp ult i16 %zext, 255
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
+
+define void @ult_restricted_rhs(i16 %n.raw) {
+; CHECK-LABEL: 'ult_restricted_rhs'
+; CHECK-NEXT:  Determining loop execution counts for: @ult_restricted_rhs
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + (1 umax (zext i8 (trunc i16 %n.raw to i8) to i16)))<nsw>
+; CHECK-NEXT:  Loop %for.body: max backedge-taken count is 254
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + (1 umax (zext i8 (trunc i16 %n.raw to i8) to i16)))<nsw>
+; CHECK-NEXT:   Predicates:
+; CHECK:       Loop %for.body: Trip multiple is 1
+; RHS = %n.raw & 255 has unsigned max 255, within the stride-1 limit of 255.
+entry:
+  %n = and i16 %n.raw, 255
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.body
+  %iv = phi i8 [ %iv.next, %for.body ], [ 0, %entry ]
+  %iv.next = add i8 %iv, 1
+  %zext = zext i8 %iv.next to i16
+  %cmp = icmp ult i16 %zext, %n
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
+define void @ult_guarded_rhs(i16 %n) {;
+; CHECK-LABEL: 'ult_guarded_rhs'
+; CHECK-NEXT:  Determining loop execution counts for: @ult_guarded_rhs
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + (1 umax %n))
+; CHECK-NEXT:  Loop %for.body: max backedge-taken count is -2
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + (1 umax %n))
+; CHECK-NEXT:   Predicates:
+; CHECK:       Loop %for.body: Trip multiple is 1
+; Only the entry guard %n <u 256 (seen via loop guards) bounds RHS here.
+entry:
+  %in_range = icmp ult i16 %n, 256
+  br i1 %in_range, label %for.body, label %for.end
+
+for.body:                                         ; preds = %entry, %for.body
+  %iv = phi i8 [ %iv.next, %for.body ], [ 0, %entry ]
+  %iv.next = add i8 %iv, 1
+  %zext = zext i8 %iv.next to i16
+  %cmp = icmp ult i16 %zext, %n
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
+
 
 declare void @llvm.assume(i1)
 

diff  --git a/llvm/test/Transforms/IndVarSimplify/finite-exit-comparisons.ll b/llvm/test/Transforms/IndVarSimplify/finite-exit-comparisons.ll
index 8ae677e1f6df5..c5e7a76c64387 100644
--- a/llvm/test/Transforms/IndVarSimplify/finite-exit-comparisons.ll
+++ b/llvm/test/Transforms/IndVarSimplify/finite-exit-comparisons.ll
@@ -129,13 +129,13 @@ define void @slt_non_constant_rhs_no_mustprogress(i16 %n.raw) {
 ; CHECK-LABEL: @slt_non_constant_rhs_no_mustprogress(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[N:%.*]] = and i16 [[N_RAW:%.*]], 255
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc i16 [[N]] to i8
+; CHECK-NEXT:    [[SMAX:%.*]] = call i16 @llvm.smax.i16(i16 [[N]], i16 1)
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[IV:%.*]] = phi i8 [ [[IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[IV_NEXT]] = add i8 [[IV]], 1
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[IV_NEXT]], [[TMP0]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i16 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i16 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i16 [[INDVARS_IV_NEXT]], [[SMAX]]
+; CHECK-NEXT:    br i1 [[EXITCOND]], label [[FOR_BODY]], label [[FOR_END:%.*]]
 ; CHECK:       for.end:
 ; CHECK-NEXT:    ret void
 ;
@@ -932,17 +932,18 @@ for.end:                                          ; preds = %for.body, %entry
 define i16 @ult_multiuse_profit(i16 %n.raw, i8 %start) mustprogress {
 ; CHECK-LABEL: @ult_multiuse_profit(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc i16 254 to i8
+; CHECK-NEXT:    [[TMP0:%.*]] = add i8 [[START:%.*]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 [[TMP0]] to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i16 254 to i8
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[IV:%.*]] = phi i8 [ [[IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[START:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i8 [ [[IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[START]], [[ENTRY:%.*]] ]
 ; CHECK-NEXT:    [[IV_NEXT]] = add i8 [[IV]], 1
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[IV_NEXT]] to i16
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[IV_NEXT]], [[TMP0]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[IV_NEXT]], [[TMP2]]
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
 ; CHECK:       for.end:
-; CHECK-NEXT:    [[ZEXT_LCSSA:%.*]] = phi i16 [ [[ZEXT]], [[FOR_BODY]] ]
-; CHECK-NEXT:    ret i16 [[ZEXT_LCSSA]]
+; CHECK-NEXT:    [[UMAX:%.*]] = call i16 @llvm.umax.i16(i16 [[TMP1]], i16 254)
+; CHECK-NEXT:    ret i16 [[UMAX]]
 ;
 entry:
   br label %for.body
@@ -993,13 +994,13 @@ define void @slt_restricted_rhs(i16 %n.raw) mustprogress {
 ; CHECK-LABEL: @slt_restricted_rhs(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[N:%.*]] = and i16 [[N_RAW:%.*]], 255
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc i16 [[N]] to i8
+; CHECK-NEXT:    [[SMAX:%.*]] = call i16 @llvm.smax.i16(i16 [[N]], i16 1)
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[IV:%.*]] = phi i8 [ [[IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[IV_NEXT]] = add i8 [[IV]], 1
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[IV_NEXT]], [[TMP0]]
-; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i16 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i16 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i16 [[INDVARS_IV_NEXT]], [[SMAX]]
+; CHECK-NEXT:    br i1 [[EXITCOND]], label [[FOR_BODY]], label [[FOR_END:%.*]]
 ; CHECK:       for.end:
 ; CHECK-NEXT:    ret void
 ;


        


More information about the llvm-commits mailing list