[llvm] r308299 - [PSCEV] Create AddRec for Phis in cases of possible integer overflow,

Dorit Nuzman via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 18 04:57:08 PDT 2017


Author: dorit
Date: Tue Jul 18 04:57:08 2017
New Revision: 308299

URL: http://llvm.org/viewvc/llvm-project?rev=308299&view=rev
Log:
[PSCEV] Create AddRec for Phis in cases of possible integer overflow,
using runtime checks

Extend the SCEVPredicateRewriter to work a bit harder when it encounters an
UnknownSCEV for a Phi node; Try to build an AddRecurrence also for Phi nodes
whose update chain involves casts that can be ignored under the proper runtime
overflow test. This is one step towards addressing PR30654.

Differential revision: http://reviews.llvm.org/D30041


Added:
    llvm/trunk/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll
Modified:
    llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=308299&r1=308298&r2=308299&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Tue Jul 18 04:57:08 2017
@@ -237,17 +237,15 @@ struct FoldingSetTrait<SCEVPredicate> :
 };
 
 /// This class represents an assumption that two SCEV expressions are equal,
-/// and this can be checked at run-time. We assume that the left hand side is
-/// a SCEVUnknown and the right hand side a constant.
+/// and this can be checked at run-time.
 class SCEVEqualPredicate final : public SCEVPredicate {
-  /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
-  /// constant.
-  const SCEVUnknown *LHS;
-  const SCEVConstant *RHS;
+  /// We assume that LHS == RHS.
+  const SCEV *LHS;
+  const SCEV *RHS;
 
 public:
-  SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
-                     const SCEVConstant *RHS);
+  SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEV *LHS,
+                     const SCEV *RHS);
 
   /// Implementation of the SCEVPredicate interface
   bool implies(const SCEVPredicate *N) const override;
@@ -256,10 +254,10 @@ public:
   const SCEV *getExpr() const override;
 
   /// Returns the left hand side of the equality.
-  const SCEVUnknown *getLHS() const { return LHS; }
+  const SCEV *getLHS() const { return LHS; }
 
   /// Returns the right hand side of the equality.
-  const SCEVConstant *getRHS() const { return RHS; }
+  const SCEV *getRHS() const { return RHS; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const SCEVPredicate *P) {
@@ -1241,6 +1239,14 @@ public:
     SmallVector<const SCEV *, 4> NewOp(Operands.begin(), Operands.end());
     return getAddRecExpr(NewOp, L, Flags);
   }
+
+  /// Checks if \p SymbolicPHI can be rewritten as an AddRecExpr under some
+  /// Predicates. If successful return these <AddRecExpr, Predicates>;
+  /// The function is intended to be called from PSCEV (the caller will decide
+  /// whether to actually add the predicates and carry out the rewrites).
+  Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+  createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI);
+  
   /// Returns an expression for a GEP
   ///
   /// \p GEP The GEP. The indices contained in the GEP itself are ignored,
@@ -1675,8 +1681,7 @@ public:
     return F.getParent()->getDataLayout();
   }
 
-  const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
-                                         const SCEVConstant *RHS);
+  const SCEVPredicate *getEqualPredicate(const SCEV *LHS, const SCEV *RHS);
 
   const SCEVPredicate *
   getWrapPredicate(const SCEVAddRecExpr *AR,
@@ -1692,6 +1697,19 @@ public:
       SmallPtrSetImpl<const SCEVPredicate *> &Preds);
 
 private:
+  /// Similar to createAddRecFromPHI, but with the additional flexibility of 
+  /// suggesting runtime overflow checks in case casts are encountered.
+  /// If successful, the analysis records that for this loop, \p SymbolicPHI,
+  /// which is the UnknownSCEV currently representing the PHI, can be rewritten
+  /// into an AddRec, assuming some predicates; The function then returns the
+  /// AddRec and the predicates as a pair, and caches this pair in
+  /// PredicatedSCEVRewrites.
+  /// If the analysis is not successful, None is returned; the public wrapper
+  /// createAddRecFromPHIWithCasts then records a mapping from \p SymbolicPHI
+  /// to itself (with no predicates) so that the failure is cached.
+  Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+  createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI);
+
   /// Compute the backedge taken count knowing the interval difference, the
   /// stride and presence of the equality in the comparison.
   const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride,
@@ -1722,6 +1740,12 @@ private:
   FoldingSet<SCEVPredicate> UniquePreds;
   BumpPtrAllocator SCEVAllocator;
 
+  /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression
+  /// they can be rewritten into under certain predicates.
+  DenseMap<std::pair<const SCEVUnknown *, const Loop *>,
+           std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+      PredicatedSCEVRewrites;
+   
   /// The head of a linked list of all SCEVUnknown values that have been
   /// allocated. This is used by releaseMemory to locate them all and call
   /// their destructors.

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=308299&r1=308298&r2=308299&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Tue Jul 18 04:57:08 2017
@@ -4173,6 +4173,319 @@ static Optional<BinaryOp> MatchBinaryOp(
   return None;
 }
 
+/// Helper function to createAddRecFromPHIWithCasts. We have a phi 
+/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
+/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the 
+/// way. This function checks if \p Op, an operand of this SCEVAddExpr, 
+/// follows one of the following patterns:
+/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+/// If the SCEV expression of \p Op conforms with one of the expected patterns
+/// we return the type of the truncation operation, and indicate whether the
+/// truncated type should be treated as signed/unsigned by setting 
+/// \p Signed to true/false, respectively.
+static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
+                               bool &Signed, ScalarEvolution &SE) {
+
+  // The case where Op == SymbolicPHI (that is, with no type conversions on 
+  // the way) is handled by the regular add recurrence creating logic and 
+  // would have already been triggered in createAddRecFromPHI. Reaching it here
+  // means that createAddRecFromPHI had failed for this PHI before (e.g., 
+  // because one of the other operands of the SCEVAddExpr updating this PHI is
+  // not invariant). 
+  //
+  // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in 
+  // this case predicates that allow us to prove that Op == SymbolicPHI will
+  // be added.
+  if (Op == SymbolicPHI)
+    return nullptr;
+
+  unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
+  unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
+  if (SourceBits != NewBits)
+    return nullptr;
+
+  const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
+  const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
+  if (!SExt && !ZExt)
+    return nullptr;
+  const SCEVTruncateExpr *Trunc =
+      SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
+           : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
+  if (!Trunc)
+    return nullptr;
+  const SCEV *X = Trunc->getOperand();
+  if (X != SymbolicPHI)
+    return nullptr;
+  Signed = SExt ? true : false; 
+  return Trunc->getType();
+}
+
+static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
+  if (!PN->getType()->isIntegerTy())
+    return nullptr;
+  const Loop *L = LI.getLoopFor(PN->getParent());
+  if (!L || L->getHeader() != PN->getParent())
+    return nullptr;
+  return L;
+}
+
+// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
+// computation that updates the phi follows the following pattern:
+//   (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
+// which correspond to a phi->trunc->sext/zext->add->phi update chain.
+// If so, try to see if it can be rewritten as an AddRecExpr under some
+// Predicates. If successful, return them as a pair. Also cache the results
+// of the analysis.
+//
+// Example usage scenario:
+//    Say the Rewriter is called for the following SCEV:
+//         8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
+//    where:
+//         %X = phi i64 (%Start, %BEValue)
+//    It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
+//    and call this function with %SymbolicPHI = %X.
+//
+//    The analysis will find that the value coming around the backedge has 
+//    the following SCEV:
+//         BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
+//    Upon concluding that this matches the desired pattern, the function
+//    will return the pair {NewAddRec, SmallPredsVec} where:
+//         NewAddRec = {%Start,+,%Step}
+//         SmallPredsVec = {P1, P2, P3} as follows:
+//           P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
+//           P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
+//           P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
+//    The returned pair means that SymbolicPHI can be rewritten into NewAddRec
+//    under the predicates {P1,P2,P3}.
+//    This predicated rewrite will be cached in PredicatedSCEVRewrites:
+//         PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
+//
+// TODO's:
+//
+// 1) Extend the Induction descriptor to also support inductions that involve
+//    casts: When needed (namely, when we are called in the context of the 
+//    vectorizer induction analysis), a Set of cast instructions will be 
+//    populated by this method, and provided back to isInductionPHI. This is
+//    needed to allow the vectorizer to properly record them to be ignored by
+//    the cost model and to avoid vectorizing them (otherwise these casts,
+//    which are redundant under the runtime overflow checks, will be 
+//    vectorized, which can be costly).  
+//
+// 2) Support additional induction/PHISCEV patterns: We also want to support
+//    inductions where the sext-trunc / zext-trunc operations (partly) occur 
+//    after the induction update operation (the induction increment):
+//
+//      (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
+//    which correspond to a phi->add->trunc->sext/zext->phi update chain.
+//
+//      (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
+//    which correspond to a phi->trunc->add->sext/zext->phi update chain.
+//
+// 3) Outline common code with createAddRecFromPHI to avoid duplication.
+//
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
+  SmallVector<const SCEVPredicate *, 3> Predicates;
+
+  // *** Part1: Analyze if we have a phi-with-cast pattern for which we can 
+  // return an AddRec expression under some predicate.
+ 
+  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+  assert (L && "Expecting an integer loop header phi");
+
+  // The loop may have multiple entrances or multiple exits; we can analyze
+  // this phi as an addrec if it has a unique entry value and a unique
+  // backedge value.
+  Value *BEValueV = nullptr, *StartValueV = nullptr;
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    Value *V = PN->getIncomingValue(i);
+    if (L->contains(PN->getIncomingBlock(i))) {
+      if (!BEValueV) {
+        BEValueV = V;
+      } else if (BEValueV != V) {
+        BEValueV = nullptr;
+        break;
+      }
+    } else if (!StartValueV) {
+      StartValueV = V;
+    } else if (StartValueV != V) {
+      StartValueV = nullptr;
+      break;
+    }
+  }
+  if (!BEValueV || !StartValueV)
+    return None;
+
+  const SCEV *BEValue = getSCEV(BEValueV);
+
+  // If the value coming around the backedge is an add with the symbolic
+  // value we just inserted, possibly with casts that we can ignore under
+  // an appropriate runtime guard, then we found a simple induction variable!
+  const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
+  if (!Add)
+    return None;
+
+  // If there is a single occurrence of the symbolic value, possibly
+  // casted, replace it with a recurrence. 
+  unsigned FoundIndex = Add->getNumOperands();
+  Type *TruncTy = nullptr;
+  bool Signed;
+  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+    if ((TruncTy = 
+             isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
+      if (FoundIndex == e) {
+        FoundIndex = i;
+        break;
+      }
+
+  if (FoundIndex == Add->getNumOperands())
+    return None;
+
+  // Create an add with everything but the specified operand.
+  SmallVector<const SCEV *, 8> Ops;
+  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+    if (i != FoundIndex)
+      Ops.push_back(Add->getOperand(i));
+  const SCEV *Accum = getAddExpr(Ops);
+
+  // The runtime checks will not be valid if the step amount is
+  // varying inside the loop.
+  if (!isLoopInvariant(Accum, L))
+    return None;
+
+  
+  // *** Part2: Create the predicates 
+
+  // Analysis was successful: we have a phi-with-cast pattern for which we
+  // can return an AddRec expression under the following predicates:
+  //
+  // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
+  //     fits within the truncated type (does not overflow) for i = 0 to n-1.
+  // P2: An Equal predicate that guarantees that 
+  //     Start = (Ext ix (Trunc iy (Start) to ix) to iy)
+  // P3: An Equal predicate that guarantees that 
+  //     Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
+  //
+  // As we next prove, the above predicates guarantee that: 
+  //     Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
+  //
+  //
+  // More formally, we want to prove that:
+  //     Expr(i+1) = Start + (i+1) * Accum 
+  //               = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum 
+  //
+  // Given that:
+  // 1) Expr(0) = Start 
+  // 2) Expr(1) = Start + Accum 
+  //            = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
+  // 3) Induction hypothesis (step i):
+  //    Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum 
+  //
+  // Proof:
+  //  Expr(i+1) =
+  //   = Start + (i+1)*Accum
+  //   = (Start + i*Accum) + Accum
+  //   = Expr(i) + Accum  
+  //   = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum 
+  //                                                             :: from step i
+  //
+  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum 
+  //
+  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
+  //     + (Ext ix (Trunc iy (Accum) to ix) to iy)
+  //     + Accum                                                     :: from P3
+  //
+  //   = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) 
+  //     + Accum                            :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
+  //
+  //   = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
+  //   = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum 
+  //
+  // By induction, the same applies to all iterations 1<=i<n:
+  //
+  
+  // Create a truncated addrec for which we will add a no overflow check (P1).
+  const SCEV *StartVal = getSCEV(StartValueV);
+  const SCEV *PHISCEV = 
+      getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
+                    getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); 
+  const auto *AR = cast<SCEVAddRecExpr>(PHISCEV);
+
+  SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
+      Signed ? SCEVWrapPredicate::IncrementNSSW
+             : SCEVWrapPredicate::IncrementNUSW;
+  const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
+  Predicates.push_back(AddRecPred);
+
+  // Create the Equal Predicates P2,P3:
+  auto AppendPredicate = [&](const SCEV *Expr) -> void {
+    assert (isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
+    const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
+    const SCEV *ExtendedExpr =
+        Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType())
+               : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+    if (Expr != ExtendedExpr &&
+        !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
+      const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
+      DEBUG (dbgs() << "Added Predicate: " << *Pred);
+      Predicates.push_back(Pred);
+    }
+  };
+  
+  AppendPredicate(StartVal);
+  AppendPredicate(Accum);
+  
+  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
+  // which the casts had been folded away. The caller can rewrite SymbolicPHI
+  // into NewAR if it will also add the runtime overflow checks specified in
+  // Predicates.  
+  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
+
+  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
+      std::make_pair(NewAR, Predicates);
+  // Remember the result of the analysis for this SCEV at this location.
+  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
+  return PredRewrite;
+}
+
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
+
+  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+  if (!L)
+    return None;
+
+  // Check to see if we already analyzed this PHI.
+  auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
+  if (I != PredicatedSCEVRewrites.end()) {
+    std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
+        I->second;
+    // Analysis was done before and failed to create an AddRec:
+    if (Rewrite.first == SymbolicPHI) 
+      return None;
+    // Analysis was done before and succeeded to create an AddRec under
+    // a predicate:
+    assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
+    assert(!(Rewrite.second).empty() && "Expected to find Predicates");
+    return Rewrite;
+  }
+
+  Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+    Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
+
+  // Record in the cache that the analysis failed
+  if (!Rewrite) {
+    SmallVector<const SCEVPredicate *, 3> Predicates;
+    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
+    return None;
+  }
+
+  return Rewrite;
+}
+
 /// A helper function for createAddRecFromPHI to handle simple cases.
 ///
 /// This function tries to find an AddRec expression for the simplest (yet most
@@ -5904,6 +6217,16 @@ void ScalarEvolution::forgetLoop(const L
   RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
   RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
 
+  // Drop information about predicated SCEV rewrites for this loop.
+  for (auto I = PredicatedSCEVRewrites.begin();
+       I != PredicatedSCEVRewrites.end();) {
+    std::pair<const SCEV *, const Loop *> Entry = I->first;
+    if (Entry.second == L)
+      PredicatedSCEVRewrites.erase(I++);
+    else
+      ++I;
+  }
+
   // Drop information about expressions based on loop-header PHIs.
   SmallVector<Instruction *, 16> Worklist;
   PushLoopPHIs(L, Worklist);
@@ -10062,6 +10385,7 @@ ScalarEvolution::ScalarEvolution(ScalarE
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
       UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
+      PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
 }
@@ -10462,6 +10786,15 @@ void ScalarEvolution::forgetMemoizedResu
   HasRecMap.erase(S);
   MinTrailingZerosCache.erase(S);
 
+  for (auto I = PredicatedSCEVRewrites.begin(); 
+       I != PredicatedSCEVRewrites.end();) {
+    std::pair<const SCEV *, const Loop *> Entry = I->first;
+    if (Entry.first == S)
+      PredicatedSCEVRewrites.erase(I++);
+    else
+      ++I;
+  }
+
   auto RemoveSCEVFromBackedgeMap =
       [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
         for (auto I = Map.begin(), E = Map.end(); I != E;) {
@@ -10621,10 +10954,11 @@ void ScalarEvolutionWrapperPass::getAnal
   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
 }
 
-const SCEVPredicate *
-ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
-                                   const SCEVConstant *RHS) {
+const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
+                                                        const SCEV *RHS) {
   FoldingSetNodeID ID;
+  assert(LHS->getType() == RHS->getType() &&
+         "Type mismatch between LHS and RHS");
   // Unique this node based on the arguments
   ID.AddInteger(SCEVPredicate::P_Equal);
   ID.AddPointer(LHS);
@@ -10687,8 +11021,7 @@ public:
           if (IPred->getLHS() == Expr)
             return IPred->getRHS();
     }
-
-    return Expr;
+    return convertToAddRecWithPreds(Expr);
   }
 
   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
@@ -10724,17 +11057,41 @@ public:
   }
 
 private:
-  bool addOverflowAssumption(const SCEVAddRecExpr *AR,
-                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
-    auto *A = SE.getWrapPredicate(AR, AddedFlags);
+  bool addOverflowAssumption(const SCEVPredicate *P) {
     if (!NewPreds) {
       // Check if we've already made this assumption.
-      return Pred && Pred->implies(A);
+      return Pred && Pred->implies(P);
     }
-    NewPreds->insert(A);
+    NewPreds->insert(P);
     return true;
   }
 
+  bool addOverflowAssumption(const SCEVAddRecExpr *AR,
+                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
+    auto *A = SE.getWrapPredicate(AR, AddedFlags);
+    return addOverflowAssumption(A);
+  }
+
+  // If \p Expr represents a PHINode, we try to see if it can be represented
+  // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible 
+  // to add this predicate as a runtime overflow check, we return the AddRec.
+  // If \p Expr does not meet these conditions (is not a PHI node, or we 
+  // couldn't create an AddRec for it, or couldn't add the predicate), we just 
+  // return \p Expr.
+  const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
+    if (!isa<PHINode>(Expr->getValue()))
+      return Expr;
+    Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+    PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
+    if (!PredicatedRewrite)
+      return Expr;
+    for (auto *P : PredicatedRewrite->second){
+      if (!addOverflowAssumption(P))
+        return Expr;
+    }
+    return PredicatedRewrite->first;
+  }
+  
   SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
   SCEVUnionPredicate *Pred;
   const Loop *L;
@@ -10771,9 +11128,11 @@ SCEVPredicate::SCEVPredicate(const Foldi
     : FastID(ID), Kind(Kind) {}
 
 SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
-                                       const SCEVUnknown *LHS,
-                                       const SCEVConstant *RHS)
-    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+                                       const SCEV *LHS, const SCEV *RHS)
+    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {
+  assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
+  assert(LHS != RHS && "LHS and RHS are the same SCEV");
+}
 
 bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
   const auto *Op = dyn_cast<SCEVEqualPredicate>(N);

Added: llvm/trunk/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll?rev=308299&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll (added)
+++ llvm/trunk/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll Tue Jul 18 04:57:08 2017
@@ -0,0 +1,240 @@
+; RUN: opt -S -loop-vectorize -force-vector-width=4 -force-vector-interleave=1 < %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+; Check that the vectorizer identifies the %p.09 phi,
+; as an induction variable, despite the potential overflow
+; due to the truncation from 32bit to 8bit. 
+; SCEV will detect the pattern "sext(trunc(%p.09)) + %step"
+; and generate the required runtime checks under which
+; we can assume no overflow. We check here that we generate
+; exactly two runtime checks:
+; 1) an overflow check:
+;    {0,+,(trunc i32 %step to i8)}<%for.body> Added Flags: <nssw>
+; 2) an equality check verifying that the step of the induction 
+;    is equal to sext(trunc(step)): 
+;    Equal predicate: %step == (sext i8 (trunc i32 %step to i8) to i32)
+; 
+; See also pr30654.
+;
+; int a[N];
+; void doit1(int n, int step) {
+;   int i;
+;   char p = 0;
+;   for (i = 0; i < n; i++) {
+;      a[i] = p;
+;      p = p + step;
+;   }
+; }
+; 
+
+; CHECK-LABEL: @doit1
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: %[[TEST:[0-9]+]] = or i1 {{.*}}, %mul.overflow
+; CHECK: %[[NTEST:[0-9]+]] = or i1 false, %[[TEST]]
+; CHECK: %ident.check = icmp ne i32 {{.*}}, %{{.*}}
+; CHECK: %{{.*}} = or i1 %[[NTEST]], %ident.check
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+@a = common local_unnamed_addr global [250 x i32] zeroinitializer, align 16
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit1(i32 %n, i32 %step) local_unnamed_addr {
+entry:
+  %cmp7 = icmp sgt i32 %n, 0
+  br i1 %cmp7, label %for.body.preheader, label %for.end
+
+for.body.preheader:                    
+  %wide.trip.count = zext i32 %n to i64
+  br label %for.body
+
+for.body:                  
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %p.09 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+  %sext = shl i32 %p.09, 24
+  %conv = ashr exact i32 %sext, 24
+  %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+  store i32 %conv, i32* %arrayidx, align 4
+  %add = add nsw i32 %conv, %step
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:                        
+  br label %for.end
+
+for.end:                                
+  ret void
+}
+
+; Same as above, but for checking the SCEV "zext(trunc(%p.09)) + %step".
+; Here we expect the following two predicates to be added for runtime checking:
+; 1) {0,+,(trunc i32 %step to i8)}<%for.body> Added Flags: <nusw>
+; 2) Equal predicate: %step == (zext i8 (trunc i32 %step to i8) to i32)
+;
+; int a[N];
+; void doit2(int n, int step) {
+;   int i;
+;   unsigned char p = 0;
+;   for (i = 0; i < n; i++) {
+;      a[i] = p;
+;      p = p + step;
+;   }
+; }
+; 
+
+; CHECK-LABEL: @doit2
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: %[[TEST:[0-9]+]] = or i1 {{.*}}, %mul.overflow
+; CHECK: %[[NTEST:[0-9]+]] = or i1 false, %[[TEST]]
+; CHECK: %ident.check = icmp ne i32 {{.*}}, %{{.*}}
+; CHECK: %{{.*}} = or i1 %[[NTEST]], %ident.check
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit2(i32 %n, i32 %step) local_unnamed_addr  {
+entry:
+  %cmp7 = icmp sgt i32 %n, 0
+  br i1 %cmp7, label %for.body.preheader, label %for.end
+
+for.body.preheader:                             
+  %wide.trip.count = zext i32 %n to i64
+  br label %for.body
+
+for.body:                                      
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %p.09 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+  %conv = and i32 %p.09, 255
+  %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+  store i32 %conv, i32* %arrayidx, align 4
+  %add = add nsw i32 %conv, %step
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:                        
+  br label %for.end
+
+for.end:                                
+  ret void
+}
+
+; Here we check that the same phi scev analysis would fail 
+; to create the runtime checks because the step is not invariant.
+; As a result vectorization will fail.
+;
+; int a[N];
+; void doit3(int n, int step) {
+;   int i;
+;   char p = 0;
+;   for (i = 0; i < n; i++) {
+;      a[i] = p;
+;      p = p + step;
+;      step += 2;
+;   }
+; }
+;
+
+; CHECK-LABEL: @doit3
+; CHECK-NOT: vector.scevcheck
+; CHECK-NOT: vector.body:
+; CHECK-LABEL: for.body:
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit3(i32 %n, i32 %step) local_unnamed_addr {
+entry:
+  %cmp9 = icmp sgt i32 %n, 0
+  br i1 %cmp9, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+  %wide.trip.count = zext i32 %n to i64
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %p.012 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+  %step.addr.010 = phi i32 [ %add3, %for.body ], [ %step, %for.body.preheader ]
+  %sext = shl i32 %p.012, 24
+  %conv = ashr exact i32 %sext, 24
+  %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+  store i32 %conv, i32* %arrayidx, align 4
+  %add = add nsw i32 %conv, %step.addr.010
+  %add3 = add nsw i32 %step.addr.010, 2
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+
+; Lastly, we also check the case where we can tell at compile time that
+; the step of the induction is equal to sext(trunc(step)), in which case
+; we don't have to check this equality at runtime (we only need the
+; runtime overflow check). Therefore only the following overflow predicate
+; will be added for runtime checking:
+; {0,+,%cstep}<%for.body> Added Flags: <nssw>
+;
+; a[N];
+; void doit4(int n, char cstep) {
+;   int i;
+;   char p = 0;
+;   int istep = cstep;
+;  for (i = 0; i < n; i++) {
+;      a[i] = p;
+;      p = p + istep;
+;   }
+; }
+
+; CHECK-LABEL: @doit4
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: %{{.*}} = or i1 {{.*}}, %mul.overflow
+; CHECK-NOT: %ident.check = icmp ne i32 {{.*}}, %{{.*}}
+; CHECK-NOT: %{{.*}} = or i1 %{{.*}}, %ident.check
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit4(i32 %n, i8 signext %cstep) local_unnamed_addr {
+entry:
+  %conv = sext i8 %cstep to i32
+  %cmp10 = icmp sgt i32 %n, 0
+  br i1 %cmp10, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+  %wide.trip.count = zext i32 %n to i64
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %p.011 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+  %sext = shl i32 %p.011, 24
+  %conv2 = ashr exact i32 %sext, 24
+  %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+  store i32 %conv2, i32* %arrayidx, align 4
+  %add = add nsw i32 %conv2, %conv
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}




More information about the llvm-commits mailing list