[llvm] r319659 - [Loop Predication] Teach LP about reverse loops

Anna Thomas via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 4 07:11:48 PST 2017


Author: annat
Date: Mon Dec  4 07:11:48 2017
New Revision: 319659

URL: http://llvm.org/viewvc/llvm-project?rev=319659&view=rev
Log:
[Loop Predication] Teach LP about reverse loops

Summary:
Currently, we only support predication for forward loops with a step
of 1.  This patch enables loop predication for reverse (count-down)
loops, which satisfy the following conditions:
   1. The step of the IV is -1.
   2. The loop has a single latch of the form B(X) = X <pred> latchLimit,
      with <pred> being s> or u>.
   3. The IV of the guard is the decremented IV of the latch condition
      (the guard is G(X) = X-1 u< guardLimit).

This patch was maintained downstream for a while and is the last in the
series of patches upstreaming our downstream LP implementation.

Reviewers: apilipenko, mkazantsev, sanjoy

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D40353

Added:
    llvm/trunk/test/Transforms/LoopPredication/reverse.ll
Modified:
    llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp

Modified: llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp?rev=319659&r1=319658&r2=319659&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/LoopPredication.cpp Mon Dec  4 07:11:48 2017
@@ -98,60 +98,79 @@
 // Note that we can use anything stronger than M, i.e. any condition which
 // implies M.
 //
-// For now the transformation is limited to the following case:
+// When S = 1 (i.e. forward iterating loop), the transformation is supported
+// when:
 //   * The loop has a single latch with the condition of the form:
 //     B(X) = latchStart + X <pred> latchLimit,
 //     where <pred> is u<, u<=, s<, or s<=.
-//   * The step of the IV used in the latch condition is 1.
 //   * The guard condition is of the form
 //     G(X) = guardStart + X u< guardLimit
 //
-// For the ult latch comparison case M is:
-//   forall X . guardStart + X u< guardLimit && latchStart + X <u latchLimit =>
-//      guardStart + X + 1 u< guardLimit
+//   For the ult latch comparison case M is:
+//     forall X . guardStart + X u< guardLimit && latchStart + X <u latchLimit =>
+//        guardStart + X + 1 u< guardLimit
 //
-// The only way the antecedent can be true and the consequent can be false is
-// if
-//   X == guardLimit - 1 - guardStart
-// (and guardLimit is non-zero, but we won't use this latter fact).
-// If X == guardLimit - 1 - guardStart then the second half of the antecedent is
-//   latchStart + guardLimit - 1 - guardStart u< latchLimit
-// and its negation is
-//   latchStart + guardLimit - 1 - guardStart u>= latchLimit
+//   The only way the antecedent can be true and the consequent can be false is
+//   if
+//     X == guardLimit - 1 - guardStart
+//   (and guardLimit is non-zero, but we won't use this latter fact).
+//   If X == guardLimit - 1 - guardStart then the second half of the antecedent is
+//     latchStart + guardLimit - 1 - guardStart u< latchLimit
+//   and its negation is
+//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
 //
-// In other words, if
-//   latchLimit u<= latchStart + guardLimit - 1 - guardStart
-// then:
-// (the ranges below are written in ConstantRange notation, where [A, B) is the
-// set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
+//   In other words, if
+//     latchLimit u<= latchStart + guardLimit - 1 - guardStart
+//   then:
+//   (the ranges below are written in ConstantRange notation, where [A, B) is the
+//   set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
 //
-//    forall X . guardStart + X u< guardLimit &&
-//               latchStart + X u< latchLimit =>
-//      guardStart + X + 1 u< guardLimit
-// == forall X . guardStart + X u< guardLimit &&
-//               latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
-//      guardStart + X + 1 u< guardLimit
-// == forall X . (guardStart + X) in [0, guardLimit) &&
-//               (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
-//      (guardStart + X + 1) in [0, guardLimit)
-// == forall X . X in [-guardStart, guardLimit - guardStart) &&
-//               X in [-latchStart, guardLimit - 1 - guardStart) =>
-//       X in [-guardStart - 1, guardLimit - guardStart - 1)
-// == true
+//      forall X . guardStart + X u< guardLimit &&
+//                 latchStart + X u< latchLimit =>
+//        guardStart + X + 1 u< guardLimit
+//   == forall X . guardStart + X u< guardLimit &&
+//                 latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
+//        guardStart + X + 1 u< guardLimit
+//   == forall X . (guardStart + X) in [0, guardLimit) &&
+//                 (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
+//        (guardStart + X + 1) in [0, guardLimit)
+//   == forall X . X in [-guardStart, guardLimit - guardStart) &&
+//                 X in [-latchStart, guardLimit - 1 - guardStart) =>
+//         X in [-guardStart - 1, guardLimit - guardStart - 1)
+//   == true
 //
-// So the widened condition is:
-//   guardStart u< guardLimit &&
-//   latchStart + guardLimit - 1 - guardStart u>= latchLimit
-// Similarly for ule condition the widened condition is:
-//   guardStart u< guardLimit &&
-//   latchStart + guardLimit - 1 - guardStart u> latchLimit
-// For slt condition the widened condition is:
-//   guardStart u< guardLimit &&
-//   latchStart + guardLimit - 1 - guardStart s>= latchLimit
-// For sle condition the widened condition is:
-//   guardStart u< guardLimit &&
-//   latchStart + guardLimit - 1 - guardStart s> latchLimit
+//   So the widened condition is:
+//     guardStart u< guardLimit &&
+//     latchStart + guardLimit - 1 - guardStart u>= latchLimit
+//   Similarly for ule condition the widened condition is:
+//     guardStart u< guardLimit &&
+//     latchStart + guardLimit - 1 - guardStart u> latchLimit
+//   For slt condition the widened condition is:
+//     guardStart u< guardLimit &&
+//     latchStart + guardLimit - 1 - guardStart s>= latchLimit
+//   For sle condition the widened condition is:
+//     guardStart u< guardLimit &&
+//     latchStart + guardLimit - 1 - guardStart s> latchLimit
 //
+// When S = -1 (i.e. reverse iterating loop), the transformation is supported
+// when:
+//   * The loop has a single latch with the condition of the form:
+//     B(X) = X <pred> latchLimit, where <pred> is u> or s>.
+//   * The guard condition is of the form
+//     G(X) = X - 1 u< guardLimit
+//
+//   For the ugt latch comparison case M is:
+//     forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
+//
+//   The only way the antecedent can be true and the consequent can be false
+//   is if X == 1 (then X-2 wraps around to the maximum unsigned value).
+//   If X == 1 then the second half of the antecedent is
+//     1 u> latchLimit, and its negation is latchLimit u>= 1.
+//
+//   So the widened condition is (guardStart = first value of X - 1):
+//     guardStart u< guardLimit && latchLimit u>= 1.
+//   Similarly for sgt condition the widened condition is:
+//     guardStart u< guardLimit && latchLimit s>= 1.
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar/LoopPredication.h"
@@ -177,6 +196,8 @@ using namespace llvm;
 static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
                                         cl::Hidden, cl::init(true));
 
+static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
+                                        cl::Hidden, cl::init(true));
 namespace {
 class LoopPredication {
   /// Represents an induction variable check:
@@ -223,7 +244,10 @@ class LoopPredication {
                                                         LoopICmp RangeCheck,
                                                         SCEVExpander &Expander,
                                                         IRBuilder<> &Builder);
-
+  Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
+                                                        LoopICmp RangeCheck,
+                                                        SCEVExpander &Expander,
+                                                        IRBuilder<> &Builder);
   bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
 
   // When the IV type is wider than the range operand type, we can still do loop
@@ -360,7 +384,7 @@ LoopPredication::generateLoopLatchCheck(
 }
 
 bool LoopPredication::isSupportedStep(const SCEV* Step) {
-  return Step->isOne();
+  return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
 }
 
 bool LoopPredication::CanExpand(const SCEV* S) {
@@ -420,6 +444,44 @@ Optional<Value *> LoopPredication::widen
                                           GuardStart, GuardLimit, InsertAt);
   return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
 }
+
+Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
+    LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
+    SCEVExpander &Expander, IRBuilder<> &Builder) {
+  auto *Ty = RangeCheck.IV->getType();
+  const SCEV *GuardStart = RangeCheck.IV->getStart();
+  const SCEV *GuardLimit = RangeCheck.Limit;
+  const SCEV *LatchLimit = LatchCheck.Limit;
+  if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
+      !CanExpand(LatchLimit)) {
+    DEBUG(dbgs() << "Can't expand limit check!\n");
+    return None;
+  }
+  // Only handle a guard on X - 1 where the latch tests X: the range check IV
+  // must be exactly the post-decrement form of the latch check IV.
+  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
+  if (RangeCheck.IV != PostDecLatchCheckIV) {
+    DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
+                 << *PostDecLatchCheckIV
+                 << "  and RangeCheckIV: " << *RangeCheck.IV << "\n");
+    return None;
+  }
+
+  // Emit, before the preheader terminator, the widened count-down condition
+  //   guardStart u< guardLimit && latchLimit >= 1,
+  // where >= is s>= for a signed latch predicate and u>= otherwise.
+  // See the file header comment for why this is sufficient.
+  Instruction *InsertAt = Preheader->getTerminator();
+  auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred)
+                            ? ICmpInst::ICMP_SGE
+                            : ICmpInst::ICMP_UGE;
+  auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
+                                          GuardStart, GuardLimit, InsertAt);
+  auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
+                                 SE->getOne(Ty), InsertAt);
+  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
+}
+
 /// If ICI can be widened to a loop invariant condition emits the loop
 /// invariant condition in the loop preheader and return it, otherwise
 /// returns None.
@@ -467,13 +529,24 @@ Optional<Value *> LoopPredication::widen
   }
 
   LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
-  // At this point the range check step and latch step should have the same
-  // value and type.
-  assert(Step == CurrLatchCheck.IV->getStepRecurrence(*SE) &&
-         "Range and latch should have same step recurrence!");
+  // At this point, the range and latch step should have the same type, but need
+  // not have the same value (we support both 1 and -1 steps).
+  assert(Step->getType() ==
+             CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
+         "Range and latch steps should be of same type!");
+  if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
+    DEBUG(dbgs() << "Range and latch have different step values!\n");
+    return None;
+  }
 
-  return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
-                                             Expander, Builder);
+  if (Step->isOne())
+    return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
+                                               Expander, Builder);
+  else {
+    assert(Step->isAllOnesValue() && "Step should be -1!");
+    return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
+                                               Expander, Builder);
+  }
 }
 
 bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
@@ -580,9 +653,13 @@ Optional<LoopPredication::LoopICmp> Loop
   }
 
   auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
-    assert(Step->isOne() && "expected Step to be one!");
-    return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
-           Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
+    if (Step->isOne()) {
+      return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
+             Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
+    } else {
+      assert(Step->isAllOnesValue() && "Step should be -1!");
+      return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT;
+    }
   };
 
   if (IsUnsupportedPredicate(Step, Result->Pred)) {

Added: llvm/trunk/test/Transforms/LoopPredication/reverse.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopPredication/reverse.ll?rev=319659&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/LoopPredication/reverse.ll (added)
+++ llvm/trunk/test/Transforms/LoopPredication/reverse.ll Mon Dec  4 07:11:48 2017
@@ -0,0 +1,140 @@
+; RUN: opt -S -loop-predication -loop-predication-enable-count-down-loop=true < %s 2>&1 | FileCheck %s
+; RUN: opt -S -passes='require<scalar-evolution>,loop(loop-predication)' -loop-predication-enable-count-down-loop=true < %s 2>&1 | FileCheck %s
+
+declare void @llvm.experimental.guard(i1, ...)
+
+define i32 @signed_reverse_loop_n_to_lower_limit(i32* %array, i32 %length, i32 %n, i32 %lowerlimit) {
+; CHECK-LABEL: @signed_reverse_loop_n_to_lower_limit(
+entry:
+  %tmp5 = icmp eq i32 %n, 0
+  br i1 %tmp5, label %exit, label %loop.preheader
+
+; CHECK:       loop.preheader:
+; CHECK-NEXT:    [[range_start:%.*]] = add i32 %n, -1
+; CHECK-NEXT:    [[first_iteration_check:%.*]] = icmp ult i32 [[range_start]], %length
+; CHECK-NEXT:    [[no_wrap_check:%.*]] = icmp sge i32 %lowerlimit, 1
+; CHECK-NEXT:    [[wide_cond:%.*]] = and i1 [[first_iteration_check]], [[no_wrap_check]]
+loop.preheader:
+  br label %loop
+
+; CHECK: loop:
+; CHECK:    call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ]
+loop:
+  %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
+  %i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
+  %i.next = add nsw i32 %i, -1
+  %within.bounds = icmp ult i32 %i.next, %length
+  call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
+  %i.i64 = zext i32 %i.next to i64
+  %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
+  %array.i = load i32, i32* %array.i.ptr, align 4
+  %loop.acc.next = add i32 %loop.acc, %array.i
+  %continue = icmp sgt i32 %i, %lowerlimit
+  br i1 %continue, label %loop, label %exit
+
+exit:
+  %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
+  ret i32 %result
+}
+
+define i32 @unsigned_reverse_loop_n_to_lower_limit(i32* %array, i32 %length, i32 %n, i32 %lowerlimit) {
+; CHECK-LABEL: @unsigned_reverse_loop_n_to_lower_limit(
+entry:
+  %tmp5 = icmp eq i32 %n, 0
+  br i1 %tmp5, label %exit, label %loop.preheader
+
+; CHECK:       loop.preheader:
+; CHECK-NEXT:    [[range_start:%.*]] = add i32 %n, -1
+; CHECK-NEXT:    [[first_iteration_check:%.*]] = icmp ult i32 [[range_start]], %length
+; CHECK-NEXT:    [[no_wrap_check:%.*]] = icmp uge i32 %lowerlimit, 1
+; CHECK-NEXT:    [[wide_cond:%.*]] = and i1 [[first_iteration_check]], [[no_wrap_check]]
+loop.preheader:
+  br label %loop
+
+; CHECK: loop:
+; CHECK:    call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ]
+loop:
+  %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
+  %i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
+  %i.next = add nsw i32 %i, -1
+  %within.bounds = icmp ult i32 %i.next, %length
+  call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
+  %i.i64 = zext i32 %i.next to i64
+  %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
+  %array.i = load i32, i32* %array.i.ptr, align 4
+  %loop.acc.next = add i32 %loop.acc, %array.i
+  %continue = icmp ugt i32 %i, %lowerlimit
+  br i1 %continue, label %loop, label %exit
+
+exit:
+  %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
+  ret i32 %result
+}
+
+
+; The latch limit is 0, so the widened limit check (0 u>= 1) folds to false:
+; if we predicate this loop, the guard always fails and we deoptimize early.
+define i32 @unsigned_reverse_loop_n_to_0(i32* %array, i32 %length, i32 %n, i32 %lowerlimit) {
+; CHECK-LABEL: @unsigned_reverse_loop_n_to_0(
+entry:
+  %tmp5 = icmp eq i32 %n, 0
+  br i1 %tmp5, label %exit, label %loop.preheader
+
+; CHECK:       loop.preheader:
+; CHECK-NEXT:    [[range_start:%.*]] = add i32 %n, -1
+; CHECK-NEXT:    [[first_iteration_check:%.*]] = icmp ult i32 [[range_start]], %length
+; CHECK-NEXT:    [[wide_cond:%.*]] = and i1 [[first_iteration_check]], false
+loop.preheader:
+  br label %loop
+
+; CHECK: loop:
+; CHECK:    call void (i1, ...) @llvm.experimental.guard(i1 [[wide_cond]], i32 9) [ "deopt"() ]
+loop:
+  %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
+  %i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
+  %i.next = add nsw i32 %i, -1
+  %within.bounds = icmp ult i32 %i.next, %length
+  call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
+  %i.i64 = zext i32 %i.next to i64
+  %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
+  %array.i = load i32, i32* %array.i.ptr, align 4
+  %loop.acc.next = add i32 %loop.acc, %array.i
+  %continue = icmp ugt i32 %i, 0
+  br i1 %continue, label %loop, label %exit
+
+exit:
+  %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
+  ret i32 %result
+}
+
+; Do not predicate the loop: the range check IV has step 1 but the latch IV has step -1.
+define i32 @reverse_loop_range_step_increment(i32 %n, i32* %array, i32 %length) {
+; CHECK-LABEL: @reverse_loop_range_step_increment(
+entry:
+  %tmp5 = icmp eq i32 %n, 0
+  br i1 %tmp5, label %exit, label %loop.preheader
+
+loop.preheader:
+  br label %loop
+
+; CHECK: loop:
+; CHECK: llvm.experimental.guard(i1 %within.bounds, i32 9)
+loop:
+  %loop.acc = phi i32 [ %loop.acc.next, %loop ], [ 0, %loop.preheader ]
+  %i = phi i32 [ %i.next, %loop ], [ %n, %loop.preheader ]
+  %irc = phi i32 [ %i.inc, %loop ], [ 1, %loop.preheader ]
+  %i.inc = add nuw nsw i32 %irc, 1
+  %within.bounds = icmp ult i32 %irc, %length
+  call void (i1, ...) @llvm.experimental.guard(i1 %within.bounds, i32 9) [ "deopt"() ]
+  %i.i64 = zext i32 %irc to i64
+  %array.i.ptr = getelementptr inbounds i32, i32* %array, i64 %i.i64
+  %array.i = load i32, i32* %array.i.ptr, align 4
+  %i.next = add nsw i32 %i, -1
+  %loop.acc.next = add i32 %loop.acc, %array.i
+  %continue = icmp ugt i32 %i, 65534
+  br i1 %continue, label %loop, label %exit
+
+exit:
+  %result = phi i32 [ 0, %entry ], [ %loop.acc.next, %loop ]
+  ret i32 %result
+}




More information about the llvm-commits mailing list