[llvm] r251800 - [SCEV][LV] Add SCEV Predicates and use them to re-implement stride versioning

Adam Nemet via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 4 21:39:38 PST 2015


> On Nov 2, 2015, at 6:41 AM, Silviu Baranga via llvm-commits <llvm-commits at lists.llvm.org> wrote:
> 
> Author: sbaranga
> Date: Mon Nov  2 08:41:02 2015
> New Revision: 251800
> 
> URL: http://llvm.org/viewvc/llvm-project?rev=251800&view=rev
> Log:
> [SCEV][LV] Add SCEV Predicates and use them to re-implement stride versioning
> 
> Summary:
> SCEV Predicates represent conditions that typically cannot be derived from
> static analysis, but can be used to reduce SCEV expressions to forms which are
> usable for different optimizers.
> 
> ScalarEvolution now has the rewriteUsingPredicate method which can simplify a
> SCEV expression using a SCEVPredicateSet. The normal workflow of a pass using
> SCEVPredicates would be to hold a SCEVPredicateSet and every time assumptions
> need to be made a new SCEV Predicate would be created and added to the set.
> Each time after calling getSCEV, the user will call the rewriteUsingPredicate
> method.
> 
> We add two types of predicates
> SCEVPredicateSet - implements a set of predicates
> SCEVEqualPredicate - tests for equality between two SCEV expressions
> 
> We use the SCEVEqualPredicate to re-implement stride versioning. Every time we
> version a stride, we will add a SCEVEqualPredicate to the context.
> Instead of adding specific stride checks, LoopVectorize now adds a more
> generic SCEV check.
> 
> We only need to add support for this in the LoopVectorizer since this is the
> only pass that will do stride versioning.
> 
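A toy model of the workflow described in the log: hold a predicate set, add an equality predicate whenever an assumption is needed, and re-run the rewrite on each expression obtained from getSCEV. The types and functions below (Expr, PredicateSet, rewrite, cst, unk, mul, addrec, str) are hypothetical stand-ins, not LLVM's API. The rewrite mirrors SCEVPredicateRewriter::visitUnknown by substituting an unknown with the constant it is assumed to equal, and mul() re-folds constant products the way ScalarEvolution's expression getters would.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy expression tree standing in for SCEV (hypothetical, not LLVM's classes).
struct Expr;
using ExprP = std::shared_ptr<const Expr>;
struct Expr {
  enum Kind { Const, Unknown, Mul, AddRec } K = Const;
  long Val = 0;      // Const: the value
  std::string Name;  // Unknown: the symbol name
  ExprP L, R;        // Mul: L * R; AddRec: {L,+,R}
};

ExprP make(Expr::Kind K, long V, std::string N, ExprP L, ExprP R) {
  auto E = std::make_shared<Expr>();
  E->K = K; E->Val = V; E->Name = std::move(N);
  E->L = std::move(L); E->R = std::move(R);
  return E;
}
ExprP cst(long V) { return make(Expr::Const, V, "", nullptr, nullptr); }
ExprP unk(std::string N) { return make(Expr::Unknown, 0, std::move(N), nullptr, nullptr); }
ExprP addrec(ExprP S, ExprP St) { return make(Expr::AddRec, 0, "", S, St); }
// Fold constant * constant, as a real expression builder would.
ExprP mul(ExprP A, ExprP B) {
  if (A->K == Expr::Const && B->K == Expr::Const) return cst(A->Val * B->Val);
  return make(Expr::Mul, 0, "", A, B);
}

std::string str(const ExprP &E) {
  switch (E->K) {
  case Expr::Const:   return std::to_string(E->Val);
  case Expr::Unknown: return "%" + E->Name;
  case Expr::Mul:     return "(" + str(E->L) + " * " + str(E->R) + ")";
  case Expr::AddRec:  return "{" + str(E->L) + ",+," + str(E->R) + "}";
  }
  return "";
}

// An assumption "%LHS == RHS" and a set of such assumptions.
struct EqualPredicate { std::string LHS; long RHS; };
struct PredicateSet {
  std::vector<EqualPredicate> Preds;
  void add(EqualPredicate P) { Preds.push_back(std::move(P)); }
};

// Rebuild E, replacing every unknown that has an equality predicate.
ExprP rewrite(const ExprP &E, const PredicateSet &PS) {
  switch (E->K) {
  case Expr::Const: return E;
  case Expr::Unknown:
    for (const EqualPredicate &P : PS.Preds)
      if (P.LHS == E->Name) return cst(P.RHS);
    return E;
  case Expr::Mul:    return mul(rewrite(E->L, PS), rewrite(E->R, PS));
  case Expr::AddRec: return addrec(rewrite(E->L, PS), rewrite(E->R, PS));
  }
  return E;
}
```

With %stride assumed to be 1, the symbolic step (4 * %stride) of a 4-byte access folds to 4, which is the unit-stride form that stride versioning relies on.
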
> Reviewers: mzolotukhin, anemet, hfinkel, sanjoy
> 
> Subscribers: sanjoy, hfinkel, rengolin, jmolloy, llvm-commits
> 
> Differential Revision: http://reviews.llvm.org/D13595
> 
> Modified:
>    llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h
>    llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
>    llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h
>    llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp
>    llvm/trunk/lib/Analysis/ScalarEvolution.cpp
>    llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp
>    llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp
> 
> Modified: llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h (original)
> +++ llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h Mon Nov  2 08:41:02 2015
> @@ -32,6 +32,7 @@ class DataLayout;
> class ScalarEvolution;
> class Loop;
> class SCEV;
> +class SCEVUnionPredicate;
> 
> /// Optimization analysis message produced during vectorization. Messages inform
> /// the user why vectorization did not occur.
> @@ -176,10 +177,11 @@ public:
>                const SmallVectorImpl<Instruction *> &Instrs) const;
>   };
> 
> -  MemoryDepChecker(ScalarEvolution *Se, const Loop *L)
> +  MemoryDepChecker(ScalarEvolution *Se, const Loop *L,
> +                   SCEVUnionPredicate &Preds)
>       : SE(Se), InnermostLoop(L), AccessIdx(0),
>         ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true),
> -        RecordInterestingDependences(true) {}
> +        RecordInterestingDependences(true), Preds(Preds) {}
> 
>   /// \brief Register the location (instructions are given increasing numbers)
>   /// of a write access.
> @@ -289,6 +291,15 @@ private:
>   /// \brief Check whether the data dependence could prevent store-load
>   /// forwarding.
>   bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize);
> +
> +  /// The SCEV predicate containing all the SCEV-related assumptions.
> +  /// The dependence checker needs this in order to convert SCEVs of pointers
> +  /// to more accurate expressions in the context of existing assumptions.
> +  /// We also need this in case assumptions about SCEV expressions need to
> +  /// be made in order to avoid unknown dependences. For example we might
> +  /// assume a unit stride for a pointer in order to prove that a memory access
> +  /// is strided and doesn't wrap.
> +  SCEVUnionPredicate &Preds;
> };
> 
> /// \brief Holds information about the memory runtime legality checks to verify
> @@ -330,8 +341,13 @@ public:
>   }
> 
>   /// Insert a pointer and calculate the start and end SCEVs.
> +  /// We need \p Preds in order to compute the SCEV expression of the pointer
> +  /// according to the assumptions that we've made during the analysis.
> +  /// The method might also version the pointer stride according to \p Strides,
> +  /// and change \p Preds.
>   void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
> -              unsigned ASId, const ValueToValueMap &Strides);
> +              unsigned ASId, const ValueToValueMap &Strides,
> +              SCEVUnionPredicate &Preds);
> 
>   /// \brief No run-time memory checking is necessary.
>   bool empty() const { return Pointers.empty(); }
> @@ -537,6 +553,15 @@ public:
>     return StoreToLoopInvariantAddress;
>   }
> 
> +  /// The SCEV predicate contains all the SCEV-related assumptions.
> +  /// This is used to keep track of the minimal set of assumptions on SCEV
> +  /// expressions that the analysis needs to make in order to return a
> +  /// meaningful result. All SCEV expressions during the analysis should be
> +  /// re-written (and therefore simplified) according to Preds.
> +  /// A user of LoopAccessAnalysis will need to emit the runtime checks
> +  /// associated with this predicate.
> +  SCEVUnionPredicate Preds;

Can you please also add a comment, probably right before LoopAccessInfo, saying that both kinds of information it provides (run-time alias checks and dependence information) are only correct if the predicates contained here are true at run time?

Thanks,
Adam

> +
> private:
>   /// \brief Analyze the loop.  Substitute symbolic strides using Strides.
>   void analyzeLoop(const ValueToValueMap &Strides);
> @@ -583,19 +608,26 @@ private:
> Value *stripIntegerCast(Value *V);
> 
> ///\brief Return the SCEV corresponding to a pointer with the symbolic stride
> -///replaced with constant one.
> +/// replaced with constant one, assuming \p Preds is true.
> +///
> +/// If necessary this method will version the stride of the pointer according
> +/// to \p PtrToStride and therefore add a new predicate to \p Preds.
> ///
> /// If \p OrigPtr is not null, use it to look up the stride value instead of \p
> /// Ptr.  \p PtrToStride provides the mapping between the pointer value and its
> /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
> const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE,
>                                       const ValueToValueMap &PtrToStride,
> -                                      Value *Ptr, Value *OrigPtr = nullptr);
> +                                      SCEVUnionPredicate &Preds, Value *Ptr,
> +                                      Value *OrigPtr = nullptr);
> 
> /// \brief Check the stride of the pointer and ensure that it does not wrap in
> -/// the address space.
> +/// the address space, assuming \p Preds is true.
> +///
> +/// If necessary this method will version the stride of the pointer according
> +/// to \p PtrToStride and therefore add a new predicate to \p Preds.
> int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
> -                 const ValueToValueMap &StridesMap);
> +                 const ValueToValueMap &StridesMap, SCEVUnionPredicate &Preds);
> 
> /// \brief This analysis provides dependence information for the memory accesses
> /// of a loop.
> 
> Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
> +++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Mon Nov  2 08:41:02 2015
> @@ -48,10 +48,15 @@ namespace llvm {
>   class Loop;
>   class LoopInfo;
>   class Operator;
> -  class SCEVUnknown;
> -  class SCEVAddRecExpr;
>   class SCEV;
> -  template<> struct FoldingSetTrait<SCEV>;
> +  class SCEVAddRecExpr;
> +  class SCEVConstant;
> +  class SCEVExpander;
> +  class SCEVPredicate;
> +  class SCEVUnknown;
> +
> +  template <> struct FoldingSetTrait<SCEV>;
> +  template <> struct FoldingSetTrait<SCEVPredicate>;
> 
>   /// This class represents an analyzed expression in the program.  These are
>   /// opaque objects that the client is not allowed to do much with directly.
> @@ -164,6 +169,148 @@ namespace llvm {
>     static bool classof(const SCEV *S);
>   };
> 
> +  /// SCEVPredicate - This class represents an assumption made using SCEV
> +  /// expressions which can be checked at run-time.
> +  class SCEVPredicate : public FoldingSetNode {
> +    friend struct FoldingSetTrait<SCEVPredicate>;
> +
> +    /// A reference to an Interned FoldingSetNodeID for this node.  The
> +    /// ScalarEvolution's BumpPtrAllocator holds the data.
> +    FoldingSetNodeIDRef FastID;
> +
> +  public:
> +    enum SCEVPredicateKind { P_Union, P_Equal };
> +
> +  protected:
> +    SCEVPredicateKind Kind;
> +
> +  public:
> +    SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
> +
> +    virtual ~SCEVPredicate() {}
> +
> +    SCEVPredicateKind getKind() const { return Kind; }
> +
> +    /// \brief Returns the estimated complexity of this predicate.
> +    /// This is roughly measured in the number of run-time checks required.
> +    virtual unsigned getComplexity() { return 1; }
> +
> +    /// \brief Returns true if the predicate is always true. This means that no
> +    /// assumptions were made and nothing needs to be checked at run-time.
> +    virtual bool isAlwaysTrue() const = 0;
> +
> +    /// \brief Returns true if this predicate implies \p N.
> +    virtual bool implies(const SCEVPredicate *N) const = 0;
> +
> +    /// \brief Prints a textual representation of this predicate with an
> +    /// indentation of \p Depth.
> +    virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0;
> +
> +    /// \brief Returns the SCEV to which this predicate applies, or nullptr
> +    /// if this is a SCEVUnionPredicate.
> +    virtual const SCEV *getExpr() const = 0;
> +  };
> +
> +  inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) {
> +    P.print(OS);
> +    return OS;
> +  }
> +
> +  // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute
> +  // temporary FoldingSetNodeID values.
> +  template <>
> +  struct FoldingSetTrait<SCEVPredicate>
> +      : DefaultFoldingSetTrait<SCEVPredicate> {
> +
> +    static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) {
> +      ID = X.FastID;
> +    }
> +
> +    static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID,
> +                       unsigned IDHash, FoldingSetNodeID &TempID) {
> +      return ID == X.FastID;
> +    }
> +    static unsigned ComputeHash(const SCEVPredicate &X,
> +                                FoldingSetNodeID &TempID) {
> +      return X.FastID.ComputeHash();
> +    }
> +  };
> +
> +  /// SCEVEqualPredicate - This class represents an assumption that two SCEV
> +  /// expressions are equal, and this can be checked at run-time. We assume
> +  /// that the left hand side is a SCEVUnknown and the right hand side a
> +  /// constant.
> +  class SCEVEqualPredicate : public SCEVPredicate {
> +    /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
> +    /// constant.
> +    const SCEVUnknown *LHS;
> +    const SCEVConstant *RHS;
> +
> +  public:
> +    SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
> +                       const SCEVConstant *RHS);
> +
> +    /// Implementation of the SCEVPredicate interface
> +    bool implies(const SCEVPredicate *N) const override;
> +    void print(raw_ostream &OS, unsigned Depth = 0) const override;
> +    bool isAlwaysTrue() const override;
> +    const SCEV *getExpr() const;
> +
> +    /// \brief Returns the left hand side of the equality.
> +    const SCEVUnknown *getLHS() const { return LHS; }
> +
> +    /// \brief Returns the right hand side of the equality.
> +    const SCEVConstant *getRHS() const { return RHS; }
> +
> +    /// Methods for support type inquiry through isa, cast, and dyn_cast:
> +    static inline bool classof(const SCEVPredicate *P) {
> +      return P->getKind() == P_Equal;
> +    }
> +  };
> +
> +  /// SCEVUnionPredicate - This class represents a composition of other
> +  /// SCEV predicates, and is the class that most clients will interact with.
> +  /// This is equivalent to a logical "AND" of all the predicates in the union.
> +  class SCEVUnionPredicate : public SCEVPredicate {
> +  private:
> +    typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
> +        PredicateMap;
> +
> +    /// Vector with references to all predicates in this union.
> +    SmallVector<const SCEVPredicate *, 16> Preds;
> +    /// Maps SCEVs to predicates for quick look-ups.
> +    PredicateMap SCEVToPreds;
> +
> +  public:
> +    SCEVUnionPredicate();
> +
> +    const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
> +      return Preds;
> +    }
> +
> +    /// \brief Adds a predicate to this union.
> +    void add(const SCEVPredicate *N);
> +
> +    /// \brief Returns a reference to a vector containing all predicates
> +    /// which apply to \p Expr.
> +    ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
> +
> +    /// Implementation of the SCEVPredicate interface
> +    bool isAlwaysTrue() const override;
> +    bool implies(const SCEVPredicate *N) const override;
> +    void print(raw_ostream &OS, unsigned Depth) const;
> +    const SCEV *getExpr() const override;
> +
> +    /// \brief We estimate the complexity of a union predicate as the
> +    /// number of predicates in the union.
> +    unsigned getComplexity() override { return Preds.size(); }
> +
> +    /// Methods for support type inquiry through isa, cast, and dyn_cast:
> +    static inline bool classof(const SCEVPredicate *P) {
> +      return P->getKind() == P_Union;
> +    }
> +  };
> +
>   /// The main scalar evolution driver. Because client code (intentionally)
>   /// can't do much with the SCEV objects directly, they must ask this class
>   /// for services.
> @@ -1108,6 +1255,12 @@ namespace llvm {
>       return F.getParent()->getDataLayout();
>     }
> 
> +    const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
> +                                           const SCEVConstant *RHS);
> +
> +    /// Re-writes the SCEV according to the Predicates in \p Preds.
> +    const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
> +
>   private:
>     /// Compute the backedge taken count knowing the interval difference, the
>     /// stride and presence of the equality in the comparison.
> @@ -1128,6 +1281,7 @@ namespace llvm {
> 
>   private:
>     FoldingSet<SCEV> UniqueSCEVs;
> +    FoldingSet<SCEVPredicate> UniquePreds;
>     BumpPtrAllocator SCEVAllocator;
> 
>     /// The head of a linked list of all SCEVUnknown values that have been
> 
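The SCEVUnionPredicate interface above describes a conjunction that is kept minimal: add() flattens nested unions and skips any predicate the union already implies, and a union implies another union only if it implies every member. Below is a hypothetical, simplified model of those rules (strings stand in for SCEVUnknown, longs for SCEVConstant, and only equality predicates exist). It is a sketch for intuition, not the LLVM implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Toy equality predicate "LHS == RHS".
struct EqPred {
  std::string LHS;  // stands in for a SCEVUnknown
  long RHS;         // stands in for a SCEVConstant
  bool implies(const EqPred &N) const { return LHS == N.LHS && RHS == N.RHS; }
};

// Toy conjunction of equality predicates.
class UnionPred {
  std::vector<EqPred> Preds;                          // all members, in order
  std::map<std::string, std::vector<EqPred>> ByExpr;  // expression -> members
public:
  // A single predicate is implied if some member on the same expression implies it.
  bool implies(const EqPred &N) const {
    auto It = ByExpr.find(N.LHS);
    if (It == ByExpr.end()) return false;
    return std::any_of(It->second.begin(), It->second.end(),
                       [&](const EqPred &P) { return P.implies(N); });
  }
  // A union is implied if every one of its members is.
  bool implies(const UnionPred &U) const {
    return std::all_of(U.Preds.begin(), U.Preds.end(),
                       [&](const EqPred &P) { return implies(P); });
  }
  void add(const EqPred &N) {
    if (implies(N)) return;  // already known: keep the set minimal
    ByExpr[N.LHS].push_back(N);
    Preds.push_back(N);
  }
  void add(const UnionPred &U) {  // flatten a nested union
    for (const EqPred &P : U.Preds) add(P);
  }
  // Equality predicates are never trivially true, so only the empty union is.
  bool isAlwaysTrue() const { return Preds.empty(); }
  std::size_t getComplexity() const { return Preds.size(); }
};
```
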
> Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h (original)
> +++ llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h Mon Nov  2 08:41:02 2015
> @@ -151,6 +151,22 @@ namespace llvm {
>     /// block.
>     Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I);
> 
> +    /// \brief Generates a code sequence that evaluates this predicate.
> +    /// The inserted instructions will be at position \p Loc.
> +    /// The result will be of type i1 and will have a value of 0 when the
> +    /// predicate is false and 1 otherwise.
> +    Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc);
> +
> +    /// \brief A specialized variant of expandCodeForPredicate, handling the
> +    /// case when we are expanding code for a SCEVEqualPredicate.
> +    Value *expandEqualPredicate(const SCEVEqualPredicate *Pred,
> +                                Instruction *Loc);
> +
> +    /// \brief A specialized variant of expandCodeForPredicate, handling the
> +    /// case when we are expanding code for a SCEVUnionPredicate.
> +    Value *expandUnionPredicate(const SCEVUnionPredicate *Pred,
> +                                Instruction *Loc);
> +
>     /// \brief Set the current IV increment loop and position.
>     void setIVIncInsertPos(const Loop *L, Instruction *Pos) {
>       assert(!CanonicalMode &&
> 
> Modified: llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp (original)
> +++ llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp Mon Nov  2 08:41:02 2015
> @@ -89,8 +89,8 @@ Value *llvm::stripIntegerCast(Value *V)
> 
> const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
>                                             const ValueToValueMap &PtrToStride,
> +                                            SCEVUnionPredicate &Preds,
>                                             Value *Ptr, Value *OrigPtr) {
> -
>   const SCEV *OrigSCEV = SE->getSCEV(Ptr);
> 
>   // If there is an entry in the map return the SCEV of the pointer with the
> @@ -108,22 +108,28 @@ const SCEV *llvm::replaceSymbolicStrideS
>     ValueToValueMap RewriteMap;
>     RewriteMap[StrideVal] = One;
> 
> -    const SCEV *ByOne =
> -        SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true);
> +    const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
> +    const auto *CT =
> +        static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
> +
> +    Preds.add(SE->getEqualPredicate(U, CT));
> +
> +    const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds);
>     DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne
>                  << "\n");
>     return ByOne;
>   }
> 
>   // Otherwise, just return the SCEV of the original pointer.
> -  return SE->getSCEV(Ptr);
> +  return OrigSCEV;
> }
> 
> void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
>                                     unsigned DepSetId, unsigned ASId,
> -                                    const ValueToValueMap &Strides) {
> +                                    const ValueToValueMap &Strides,
> +                                    SCEVUnionPredicate &Preds) {
>   // Get the stride replaced scev.
> -  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
> +  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
>   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
>   assert(AR && "Invalid addrec expression");
>   const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
> @@ -417,9 +423,9 @@ public:
>   typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;
> 
>   AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI,
> -                 MemoryDepChecker::DepCandidates &DA)
> -      : DL(Dl), AST(*AA), LI(LI), DepCands(DA),
> -        IsRTCheckAnalysisNeeded(false) {}
> +                 MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds)
> +      : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false),
> +        Preds(Preds) {}
> 
>   /// \brief Register a load  and whether it is only read from.
>   void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
> @@ -504,14 +510,18 @@ private:
>   /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared
>   /// while this remains set if we have potentially dependent accesses.
>   bool IsRTCheckAnalysisNeeded;
> +
> +  /// The SCEV predicate containing all the SCEV-related assumptions.
> +  SCEVUnionPredicate &Preds;
> };
> 
> } // end anonymous namespace
> 
> /// \brief Check whether a pointer can participate in a runtime bounds check.
> static bool hasComputableBounds(ScalarEvolution *SE,
> -                                const ValueToValueMap &Strides, Value *Ptr) {
> -  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
> +                                const ValueToValueMap &Strides, Value *Ptr,
> +                                Loop *L, SCEVUnionPredicate &Preds) {
> +  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
>   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
>   if (!AR)
>     return false;
> @@ -554,11 +564,11 @@ bool AccessAnalysis::canCheckPtrAtRT(Run
>       else
>         ++NumReadPtrChecks;
> 
> -      if (hasComputableBounds(SE, StridesMap, Ptr) &&
> +      if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) &&
>           // When we run after a failing dependency check we have to make sure
>           // we don't have wrapping pointers.
>           (!ShouldCheckStride ||
> -           isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) {
> +           isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) {
>         // The id of the dependence set.
>         unsigned DepId;
> 
> @@ -572,7 +582,7 @@ bool AccessAnalysis::canCheckPtrAtRT(Run
>           // Each access has its own dependence set.
>           DepId = RunningDepId++;
> 
> -        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap);
> +        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds);
> 
>         DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
>       } else {
> @@ -803,7 +813,8 @@ static bool isNoWrapAddRec(Value *Ptr, c
> 
> /// \brief Check whether the access through \p Ptr has a constant stride.
> int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
> -                       const ValueToValueMap &StridesMap) {
> +                       const ValueToValueMap &StridesMap,
> +                       SCEVUnionPredicate &Preds) {
>   Type *Ty = Ptr->getType();
>   assert(Ty->isPointerTy() && "Unexpected non-ptr");
> 
> @@ -815,7 +826,7 @@ int llvm::isStridedPtr(ScalarEvolution *
>     return 0;
>   }
> 
> -  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr);
> +  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr);
> 
>   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
>   if (!AR) {
> @@ -1026,11 +1037,11 @@ MemoryDepChecker::isDependent(const MemA
>       BPtr->getType()->getPointerAddressSpace())
>     return Dependence::Unknown;
> 
> -  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr);
> -  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr);
> +  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr);
> +  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr);
> 
> -  int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides);
> -  int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides);
> +  int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds);
> +  int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds);
> 
>   const SCEV *Src = AScev;
>   const SCEV *Sink = BScev;
> @@ -1429,7 +1440,7 @@ void LoopAccessInfo::analyzeLoop(const V
> 
>   MemoryDepChecker::DepCandidates DependentAccesses;
>   AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(),
> -                          AA, LI, DependentAccesses);
> +                          AA, LI, DependentAccesses, Preds);
> 
>   // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects
>   // multiple times on the same object. If the ptr is accessed twice, once
> @@ -1480,7 +1491,8 @@ void LoopAccessInfo::analyzeLoop(const V
>     // read a few words, modify, and write a few words, and some of the
>     // words may be written to the same address.
>     bool IsReadOnlyPtr = false;
> -    if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) {
> +    if (Seen.insert(Ptr).second ||
> +        !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) {
>       ++NumReads;
>       IsReadOnlyPtr = true;
>     }
> @@ -1730,7 +1742,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L,
>                                const TargetLibraryInfo *TLI, AliasAnalysis *AA,
>                                DominatorTree *DT, LoopInfo *LI,
>                                const ValueToValueMap &Strides)
> -    : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL),
> +    : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL),
>       TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
>       MaxSafeDepDistBytes(-1U), CanVecMem(false),
>       StoreToLoopInvariantAddress(false) {
> @@ -1765,6 +1777,9 @@ void LoopAccessInfo::print(raw_ostream &
>   OS.indent(Depth) << "Store to invariant address was "
>                    << (StoreToLoopInvariantAddress ? "" : "not ")
>                    << "found in loop.\n";
> +
> +  OS.indent(Depth) << "SCEV assumptions:\n";
> +  Preds.print(OS, Depth);
> }
> 
> const LoopAccessInfo &
> @@ -1778,8 +1793,8 @@ LoopAccessAnalysis::getInfo(Loop *L, con
> 
>   if (!LAI) {
>     const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
> -    LAI = llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI,
> -                                            Strides);
> +    LAI =
> +        llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI, Strides);
> #ifndef NDEBUG
>     LAI->NumSymbolicStrides = Strides.size();
> #endif
> 
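At the source level, versioning a symbolic stride amounts to guarding a unit-stride copy of the loop with a run-time check of the assumed equality, and keeping the original loop as the fallback. A hypothetical C++ sketch (function names are illustrative; LoopVectorize does this on IR and would vectorize the guarded body, whereas the fast path here stays scalar):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Original loop: Stride is only known at run time, so A[i * Stride]
// cannot be proven to be a consecutive access.
void scaleGeneric(float *A, std::size_t N, std::size_t Stride, float K) {
  for (std::size_t i = 0; i < N; ++i)
    A[i * Stride] *= K;
}

// Versioned loop: assume Stride == 1 (the equality predicate), check that
// assumption once before the loop, and fall back to the original otherwise.
void scaleVersioned(float *A, std::size_t N, std::size_t Stride, float K) {
  if (Stride == 1) {
    // Under the assumption the access is unit-strided, the form a
    // vectorizer can widen.
    for (std::size_t i = 0; i < N; ++i)
      A[i] *= K;
  } else {
    scaleGeneric(A, N, Stride, K);
  }
}
```

Both versions must compute the same result for every stride; the predicate only decides which copy runs.
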
> Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
> +++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Mon Nov  2 08:41:02 2015
> @@ -9093,6 +9093,7 @@ ScalarEvolution::ScalarEvolution(ScalarE
>       UnsignedRanges(std::move(Arg.UnsignedRanges)),
>       SignedRanges(std::move(Arg.SignedRanges)),
>       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
> +      UniquePreds(std::move(Arg.UniquePreds)),
>       SCEVAllocator(std::move(Arg.SCEVAllocator)),
>       FirstUnknown(Arg.FirstUnknown) {
>   Arg.FirstUnknown = nullptr;
> @@ -9596,3 +9597,134 @@ void ScalarEvolutionWrapperPass::getAnal
>   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
>   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
> }
> +
> +const SCEVPredicate *
> +ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
> +                                   const SCEVConstant *RHS) {
> +  FoldingSetNodeID ID;
> +  // Unique this node based on the arguments
> +  ID.AddInteger(SCEVPredicate::P_Equal);
> +  ID.AddPointer(LHS);
> +  ID.AddPointer(RHS);
> +  void *IP = nullptr;
> +  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
> +    return S;
> +  SCEVEqualPredicate *Eq = new (SCEVAllocator)
> +      SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
> +  UniquePreds.InsertNode(Eq, IP);
> +  return Eq;
> +}
> +
> +class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
> +public:
> +  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
> +                             SCEVUnionPredicate &A) {
> +    SCEVPredicateRewriter Rewriter(SE, A);
> +    return Rewriter.visit(Scev);
> +  }
> +
> +  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
> +      : SCEVRewriteVisitor(SE), P(P) {}
> +
> +  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
> +    auto ExprPreds = P.getPredicatesForExpr(Expr);
> +    for (auto *Pred : ExprPreds)
> +      if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
> +        if (IPred->getLHS() == Expr)
> +          return IPred->getRHS();
> +
> +    return Expr;
> +  }
> +
> +private:
> +  SCEVUnionPredicate &P;
> +};
> +
> +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
> +                                                   SCEVUnionPredicate &Preds) {
> +  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
> +}
> +
> +/// SCEV predicates
> +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
> +                             SCEVPredicateKind Kind)
> +    : FastID(ID), Kind(Kind) {}
> +
> +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
> +                                       const SCEVUnknown *LHS,
> +                                       const SCEVConstant *RHS)
> +    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
> +
> +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
> +  const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
> +
> +  if (!Op)
> +    return false;
> +
> +  return Op->LHS == LHS && Op->RHS == RHS;
> +}
> +
> +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
> +
> +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
> +
> +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
> +  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
> +}
> +
> +/// Union predicates don't get cached so create a dummy set ID for it.
> +SCEVUnionPredicate::SCEVUnionPredicate()
> +    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
> +
> +bool SCEVUnionPredicate::isAlwaysTrue() const {
> +  return std::all_of(Preds.begin(), Preds.end(),
> +                     [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
> +}
> +
> +ArrayRef<const SCEVPredicate *>
> +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
> +  auto I = SCEVToPreds.find(Expr);
> +  if (I == SCEVToPreds.end())
> +    return ArrayRef<const SCEVPredicate *>();
> +  return I->second;
> +}
> +
> +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
> +  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
> +    return std::all_of(
> +        Set->Preds.begin(), Set->Preds.end(),
> +        [this](const SCEVPredicate *I) { return this->implies(I); });
> +
> +  auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
> +  if (ScevPredsIt == SCEVToPreds.end())
> +    return false;
> +  auto &SCEVPreds = ScevPredsIt->second;
> +
> +  return std::any_of(SCEVPreds.begin(), SCEVPreds.end(),
> +                     [N](const SCEVPredicate *I) { return I->implies(N); });
> +}
> +
> +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
> +
> +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
> +  for (auto Pred : Preds)
> +    Pred->print(OS, Depth);
> +}
> +
> +void SCEVUnionPredicate::add(const SCEVPredicate *N) {
> +  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
> +    for (auto Pred : Set->Preds)
> +      add(Pred);
> +    return;
> +  }
> +
> +  if (implies(N))
> +    return;
> +
> +  const SCEV *Key = N->getExpr();
> +  assert(Key && "Only SCEVUnionPredicate doesn't have an "
> +                "associated expression!");
> +
> +  SCEVToPreds[Key].push_back(N);
> +  Preds.push_back(N);
> +}
> 
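getEqualPredicate uniques predicates through the UniquePreds FoldingSet: the node ID is built from (P_Equal, LHS, RHS), so asking twice for the same equality returns the same node, which is what lets SCEVEqualPredicate::implies compare LHS and RHS by pointer. A hypothetical sketch of that uniquing, using std::map in place of FoldingSet and integer IDs in place of interned SCEVUnknown pointers:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <tuple>

// Toy predicate node: "value LHSId == RHS".
struct EqualPred { int LHSId; long RHS; };

class PredInterner {
  enum Kind { P_Equal };
  // Key mirrors the FoldingSetNodeID contents: (kind, LHS, RHS).
  std::map<std::tuple<int, int, long>, std::unique_ptr<EqualPred>> Unique;
public:
  const EqualPred *getEqualPredicate(int LHSId, long RHS) {
    auto Key = std::make_tuple(int(P_Equal), LHSId, RHS);
    auto It = Unique.find(Key);
    if (It != Unique.end())
      return It->second.get();  // existing node, like FindNodeOrInsertPos
    auto Node = std::make_unique<EqualPred>(EqualPred{LHSId, RHS});
    const EqualPred *Raw = Node.get();
    Unique.emplace(Key, std::move(Node));  // new node, like InsertNode
    return Raw;
  }
  std::size_t size() const { return Unique.size(); }
};
```

FoldingSet avoids storing the key separately from the node; std::map is used only to keep the sketch self-contained.
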
> Modified: llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp (original)
> +++ llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp Mon Nov  2 08:41:02 2015
> @@ -1944,6 +1944,43 @@ bool SCEVExpander::isHighCostExpansionHe
>   return false;
> }
> 
> +Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
> +                                            Instruction *IP) {
> +  assert(IP);
> +  switch (Pred->getKind()) {
> +  case SCEVPredicate::P_Union:
> +    return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
> +  case SCEVPredicate::P_Equal:
> +    return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
> +  }
> +  llvm_unreachable("Unknown SCEV predicate type");
> +}
> +
> +Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
> +                                          Instruction *IP) {
> +  Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
> +  Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
> +
> +  Builder.SetInsertPoint(IP);
> +  auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
> +  return I;
> +}
> +
> +Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
> +                                          Instruction *IP) {
> +  auto *BoolType = IntegerType::get(IP->getContext(), 1);
> +  Value *Check = ConstantInt::getNullValue(BoolType);
> +
> +  // Loop over all checks in this set.
> +  for (auto Pred : Union->getPredicates()) {
> +    auto *NextCheck = expandCodeForPredicate(Pred, IP);
> +    Builder.SetInsertPoint(IP);
> +    Check = Builder.CreateOr(Check, NextCheck);
> +  }
> +
> +  return Check;
> +}
> +
> namespace {
> // Search for a SCEV subexpression that is not safe to expand.  Any expression
> // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely
> 
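To make the semantics of expandUnionPredicate concrete, here is a hedged sketch that evaluates the expanded check on concrete values instead of emitting IR. EqualPred, anyAssumptionFails and the value map are hypothetical. The fold mirrors the shape of the patch: start from i1 false and OR in one "ident.check" (icmp ne) per equality predicate. A true result therefore means that at least one assumption is violated and the scalar loop must run.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for one SCEVEqualPredicate: "Expr == Value".
struct EqualPred {
  std::string Expr;
  long Value;
};

// Evaluates what expandUnionPredicate builds as IR:
//   false | (lhs0 != rhs0) | (lhs1 != rhs1) | ...
// RuntimeValues supplies the value each expression has when the loop runs.
bool anyAssumptionFails(const std::vector<EqualPred> &Union,
                        const std::map<std::string, long> &RuntimeValues) {
  bool Check = false; // corresponds to ConstantInt::getNullValue(i1)
  for (const EqualPred &P : Union) {
    long Actual = RuntimeValues.at(P.Expr);
    bool NextCheck = Actual != P.Value; // the "ident.check" icmp ne
    Check = Check || NextCheck;         // Builder.CreateOr(Check, NextCheck)
  }
  return Check;
}
```

The polarity follows the removed addStrideCheck: the combined value is true on failure, so it can be used directly as the condition of the branch to the bypass block.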
> Modified: llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp (original)
> +++ llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp Mon Nov  2 08:41:02 2015
> @@ -222,6 +222,15 @@ static cl::opt<unsigned> PragmaVectorize
>     cl::desc("The maximum allowed number of runtime memory checks with a "
>              "vectorize(enable) pragma."));
> 
> +static cl::opt<unsigned> VectorizeSCEVCheckThreshold(
> +    "vectorize-scev-check-threshold", cl::init(16), cl::Hidden,
> +    cl::desc("The maximum number of SCEV checks allowed."));
> +
> +static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold(
> +    "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden,
> +    cl::desc("The maximum number of SCEV checks allowed with a "
> +             "vectorize(enable) pragma"));
> +
> namespace {
> 
> // Forward declarations.
> @@ -273,12 +282,12 @@ public:
>   InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
>                       DominatorTree *DT, const TargetLibraryInfo *TLI,
>                       const TargetTransformInfo *TTI, unsigned VecWidth,
> -                      unsigned UnrollFactor)
> +                      unsigned UnrollFactor, SCEVUnionPredicate &Preds)
>       : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
>         VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()),
>         Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
>         TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr),
> -        AddedSafetyChecks(false) {}
> +        AddedSafetyChecks(false), Preds(Preds) {}
> 
>   // Perform the actual loop widening (vectorization).
>   // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
> @@ -315,12 +324,6 @@ protected:
>   typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>,
>                    VectorParts> EdgeMaskCache;
> 
> -  /// \brief Add checks for strides that were assumed to be 1.
> -  ///
> -  /// Returns the last check instruction and the first check instruction in the
> -  /// pair as (first, last).
> -  std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
> -
>   /// Create an empty loop, based on the loop ranges of the old loop.
>   void createEmptyLoop();
>   /// Create a new induction variable inside L.
> @@ -404,11 +407,12 @@ protected:
>   void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass);
>   /// Emit a bypass check to see if the vector trip count is nonzero.
>   void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass);
> -  /// Emit bypass checks to check if strides we've assumed to be one really are.
> -  void emitStrideChecks(Loop *L, BasicBlock *Bypass);
> +  /// Emit a bypass check to see if all of the SCEV assumptions we've
> +  /// had to make are correct.
> +  void emitSCEVChecks(Loop *L, BasicBlock *Bypass);
>   /// Emit bypass checks to check any memory assumptions we may have made.
>   void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass);
> -  
> +
>   /// This is a helper class that holds the vectorizer state. It maps scalar
>   /// instructions to vector instructions. When the code is 'unrolled' then
>   /// then a single scalar value is mapped to multiple vector parts. The parts
> @@ -516,14 +520,23 @@ protected:
> 
>   // Record whether runtime check is added.
>   bool AddedSafetyChecks;
> +
> +  /// The SCEV predicate containing all the SCEV-related assumptions.
> +  /// The predicate is used to simplify existing expressions in the
> +  /// context of existing SCEV assumptions. Since legality checking is
> +  /// not done here, we don't need to use this predicate to record
> +  /// further assumptions.
> +  SCEVUnionPredicate &Preds;
> };
> 
> class InnerLoopUnroller : public InnerLoopVectorizer {
> public:
>   InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
>                     DominatorTree *DT, const TargetLibraryInfo *TLI,
> -                    const TargetTransformInfo *TTI, unsigned UnrollFactor)
> -      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
> +                    const TargetTransformInfo *TTI, unsigned UnrollFactor,
> +                    SCEVUnionPredicate &Preds)
> +      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor,
> +                            Preds) {}
> 
> private:
>   void scalarizeInstruction(Instruction *Instr,
> @@ -744,8 +757,9 @@ private:
> /// between the member and the group in a map.
> class InterleavedAccessInfo {
> public:
> -  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT)
> -      : SE(SE), TheLoop(L), DT(DT) {}
> +  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT,
> +                        SCEVUnionPredicate &Preds)
> +      : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {}
> 
>   ~InterleavedAccessInfo() {
>     SmallSet<InterleaveGroup *, 4> DelSet;
> @@ -779,6 +793,13 @@ private:
>   Loop *TheLoop;
>   DominatorTree *DT;
> 
> +  /// The SCEV predicate containing all the SCEV-related assumptions.
> +  /// The predicate is used to simplify SCEV expressions in the
> +  /// context of existing SCEV assumptions. The interleaved access
> +  /// analysis can also add new predicates (for example by versioning
> +  /// strides of pointers).
> +  SCEVUnionPredicate &Preds;
> +
>   /// Holds the relationships between the members and the interleave group.
>   DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
> 
> @@ -1141,11 +1162,13 @@ public:
>                             Function *F, const TargetTransformInfo *TTI,
>                             LoopAccessAnalysis *LAA,
>                             LoopVectorizationRequirements *R,
> -                            const LoopVectorizeHints *H)
> +                            const LoopVectorizeHints *H,
> +                            SCEVUnionPredicate &Preds)
>       : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F),
> -        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT),
> -        Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
> -        Requirements(R), Hints(H) {}
> +        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr),
> +        InterleaveInfo(SE, L, DT, Preds), Induction(nullptr),
> +        WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H),
> +        Preds(Preds) {}
> 
>   /// ReductionList contains the reduction descriptors for all
>   /// of the reductions that were found in the loop.
> @@ -1344,7 +1367,14 @@ private:
> 
>   /// While vectorizing these instructions we have to generate a
>   /// call to the appropriate masked intrinsic
> -  SmallPtrSet<const Instruction*, 8> MaskedOp;
> +  SmallPtrSet<const Instruction *, 8> MaskedOp;
> +
> +  /// The SCEV predicate containing all the SCEV-related assumptions.
> +  /// The predicate is used to simplify SCEV expressions in the
> +  /// context of existing SCEV assumptions. The analysis will also
> +  /// add a minimal set of new predicates if this is required to
> +  /// enable vectorization/unrolling.
> +  SCEVUnionPredicate &Preds;
> };
> 
> /// LoopVectorizationCostModel - estimates the expected speedups due to
> @@ -1360,9 +1390,10 @@ public:
>                              LoopVectorizationLegality *Legal,
>                              const TargetTransformInfo &TTI,
>                              const TargetLibraryInfo *TLI, DemandedBits *DB,
> -                             AssumptionCache *AC,
> -                             const Function *F, const LoopVectorizeHints *Hints,
> -                             SmallPtrSetImpl<const Value *> &ValuesToIgnore)
> +                             AssumptionCache *AC, const Function *F,
> +                             const LoopVectorizeHints *Hints,
> +                             SmallPtrSetImpl<const Value *> &ValuesToIgnore,
> +                             SCEVUnionPredicate &Preds)
>       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
>         TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
> 
> @@ -1690,10 +1721,12 @@ struct LoopVectorize : public FunctionPa
>       }
>     }
> 
> +    SCEVUnionPredicate Preds;
> +
>     // Check if it is legal to vectorize the loop.
>     LoopVectorizationRequirements Requirements;
>     LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA,
> -                                  &Requirements, &Hints);
> +                                  &Requirements, &Hints, Preds);
>     if (!LVL.canVectorize()) {
>       DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
>       emitMissedWarning(F, L, Hints);
> @@ -1712,7 +1745,7 @@ struct LoopVectorize : public FunctionPa
> 
>     // Use the cost model.
>     LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints,
> -                                  ValuesToIgnore);
> +                                  ValuesToIgnore, Preds);
> 
>     // Check the function attributes to find out if this function should be
>     // optimized for size.
> @@ -1823,7 +1856,7 @@ struct LoopVectorize : public FunctionPa
>       assert(IC > 1 && "interleave count should not be 1 or 0");
>       // If we decided that it is not legal to vectorize the loop then
>       // interleave it.
> -      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC);
> +      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds);
>       Unroller.vectorize(&LVL, CM.MinBWs);
> 
>       emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
> @@ -1831,7 +1864,7 @@ struct LoopVectorize : public FunctionPa
>                                  Twine(IC) + ")");
>     } else {
>       // If we decided that it is *legal* to vectorize the loop then do it.
> -      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC);
> +      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds);
>       LB.vectorize(&LVL, CM.MinBWs);
>       ++LoopsVectorized;
> 
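The patch threads a single SCEVUnionPredicate by reference through the legality check, the cost model and both code generators. Below is a simplified sketch of that ownership pattern. The PredicateSet, Legality and Vectorizer classes are hypothetical stand-ins for SCEVUnionPredicate, LoopVectorizationLegality and InnerLoopVectorizer.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for SCEVUnionPredicate: one set per loop.
struct PredicateSet {
  std::vector<std::string> Assumptions;
};

// Hypothetical stand-in for the legality analysis, which may record new
// assumptions in the shared set.
class Legality {
  PredicateSet &Preds;

public:
  explicit Legality(PredicateSet &P) : Preds(P) {}
  void versionStride(const std::string &Stride) {
    Preds.Assumptions.push_back(Stride + " == 1");
  }
};

// Hypothetical stand-in for the vectorizer, which only reads the set when
// emitting the runtime checks.
class Vectorizer {
  PredicateSet &Preds;

public:
  explicit Vectorizer(PredicateSet &P) : Preds(P) {}
  std::size_t numChecksToEmit() const { return Preds.Assumptions.size(); }
};
```

Passing the set by reference rather than by copy is what lets assumptions added during legality reach emitSCEVChecks. It also requires the set to outlive its users. In the patch this holds because Preds is a local declared before LVL, so it is destroyed after LVL, CM and LB.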
> @@ -1992,7 +2025,7 @@ int LoopVectorizationLegality::isConsecu
>     //  %idxprom = zext i32 %mul to i64  << Safe cast.
>     //  %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
>     //
> -    Last = replaceSymbolicStrideSCEV(SE, Strides,
> +    Last = replaceSymbolicStrideSCEV(SE, Strides, Preds,
>                                      Gep->getOperand(InductionOperand), Gep);
>     if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
>       Last =
> @@ -2551,56 +2584,8 @@ void InnerLoopVectorizer::scalarizeInstr
>   }
> }
> 
> -static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
> -                                 Instruction *Loc) {
> -  if (FirstInst)
> -    return FirstInst;
> -  if (Instruction *I = dyn_cast<Instruction>(V))
> -    return I->getParent() == Loc->getParent() ? I : nullptr;
> -  return nullptr;
> -}
> -
> -std::pair<Instruction *, Instruction *>
> -InnerLoopVectorizer::addStrideCheck(Instruction *Loc) {
> -  Instruction *tnullptr = nullptr;
> -  if (!Legal->mustCheckStrides())
> -    return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
> -
> -  IRBuilder<> ChkBuilder(Loc);
> -
> -  // Emit checks.
> -  Value *Check = nullptr;
> -  Instruction *FirstInst = nullptr;
> -  for (SmallPtrSet<Value *, 8>::iterator SI = Legal->strides_begin(),
> -                                         SE = Legal->strides_end();
> -       SI != SE; ++SI) {
> -    Value *Ptr = stripIntegerCast(*SI);
> -    Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1),
> -                                       "stride.chk");
> -    // Store the first instruction we create.
> -    FirstInst = getFirstInst(FirstInst, C, Loc);
> -    if (Check)
> -      Check = ChkBuilder.CreateOr(Check, C);
> -    else
> -      Check = C;
> -  }
> -
> -  // We have to do this trickery because the IRBuilder might fold the check to a
> -  // constant expression in which case there is no Instruction anchored in a
> -  // the block.
> -  LLVMContext &Ctx = Loc->getContext();
> -  Instruction *TheCheck =
> -      BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx));
> -  ChkBuilder.Insert(TheCheck, "stride.not.one");
> -  FirstInst = getFirstInst(FirstInst, TheCheck, Loc);
> -
> -  return std::make_pair(FirstInst, TheCheck);
> -}
> -
> -PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L,
> -                                                      Value *Start,
> -                                                      Value *End,
> -                                                      Value *Step,
> +PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start,
> +                                                      Value *End, Value *Step,
>                                                       Instruction *DL) {
>   BasicBlock *Header = L->getHeader();
>   BasicBlock *Latch = L->getLoopLatch();
> @@ -2735,26 +2720,26 @@ void InnerLoopVectorizer::emitVectorLoop
>   LoopBypassBlocks.push_back(BB);
> }
> 
> -void InnerLoopVectorizer::emitStrideChecks(Loop *L,
> -                                           BasicBlock *Bypass) {
> +void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
>   BasicBlock *BB = L->getLoopPreheader();
> -  
> -  // Generate the code to check that the strides we assumed to be one are really
> -  // one. We want the new basic block to start at the first instruction in a
> +
> +  // Generate the code to check the SCEV assumptions that we made.
> +  // We want the new basic block to start at the first instruction in a
>   // sequence of instructions that form a check.
> -  Instruction *StrideCheck;
> -  Instruction *FirstCheckInst;
> -  std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator());
> -  if (!StrideCheck)
> -    return;
> +  SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check");
> +  Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator());
> +
> +  if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
> +    if (C->isZero())
> +      return;
> 
>   // Create a new block containing the stride check.
> -  BB->setName("vector.stridecheck");
> +  BB->setName("vector.scevcheck");
>   auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph");
>   if (L->getParentLoop())
>     L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI);
>   ReplaceInstWithInst(BB->getTerminator(),
> -                      BranchInst::Create(Bypass, NewBB, StrideCheck));
> +                      BranchInst::Create(Bypass, NewBB, SCEVCheck));
>   LoopBypassBlocks.push_back(BB);
>   AddedSafetyChecks = true;
> }
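emitSCEVChecks creates the vector.scevcheck block only when the expanded check is not the constant false. An empty union, for instance, expands to a ConstantInt false, and the function returns early. The sketch below models that decision with a toy representation in which each operand is either a known constant or a runtime value. All names in it are hypothetical, and the folding is simplified relative to what IRBuilder actually does.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical operand: a compile-time constant, or std::nullopt when the
// value is only known at run time.
using Operand = std::optional<long>;

// Hypothetical equality predicate "LHS == RHS".
struct EqualPred {
  Operand LHS;
  Operand RHS;
};

// Returns the constant value of the union check if every comparison is
// decidable at compile time, otherwise std::nullopt (a runtime i1 value).
std::optional<bool> foldUnionCheck(const std::vector<EqualPred> &Union) {
  bool Folded = false; // the union starts from the constant false
  for (const EqualPred &P : Union) {
    if (!P.LHS || !P.RHS)
      return std::nullopt; // needs a real icmp at run time
    Folded = Folded || (*P.LHS != *P.RHS);
  }
  return Folded;
}

// The patch skips the block only when the check is the constant zero;
// a runtime value, or a constant true, still gets a check block.
bool needsSCEVCheckBlock(const std::vector<EqualPred> &Union) {
  std::optional<bool> C = foldUnionCheck(Union);
  return !(C && !*C);
}
```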
> @@ -2874,10 +2859,10 @@ void InnerLoopVectorizer::createEmptyLoo
>   // Now, compare the new count to zero. If it is zero skip the vector loop and
>   // jump to the scalar loop.
>   emitVectorLoopEnteredCheck(Lp, ScalarPH);
> -  // Generate the code to check that the strides we assumed to be one are really
> -  // one. We want the new basic block to start at the first instruction in a
> -  // sequence of instructions that form a check.
> -  emitStrideChecks(Lp, ScalarPH);
> +  // Generate the code to check any assumptions that we've made for SCEV
> +  // expressions.
> +  emitSCEVChecks(Lp, ScalarPH);
> +
>   // Generate the code that checks in runtime if arrays overlap. We put the
>   // checks into a separate block to make the more common case of few elements
>   // faster.
> @@ -4130,7 +4115,19 @@ bool LoopVectorizationLegality::canVecto
> 
>   // Analyze interleaved memory accesses.
>   if (UseInterleaved)
> -     InterleaveInfo.analyzeInterleaving(Strides);
> +    InterleaveInfo.analyzeInterleaving(Strides);
> +
> +  unsigned SCEVThreshold = VectorizeSCEVCheckThreshold;
> +  if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
> +    SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
> +
> +  if (Preds.getComplexity() > SCEVThreshold) {
> +    emitAnalysis(VectorizationReport()
> +                 << "Too many SCEV assumptions need to be made and checked "
> +                 << "at runtime");
> +    DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n");
> +    return false;
> +  }
> 
>   // Okay! We can vectorize. At this point we don't have any other mem analysis
>   // which may limit our maximum vectorization factor, so just return true with
> @@ -4436,6 +4433,7 @@ bool LoopVectorizationLegality::canVecto
>   }
> 
>   Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
> +  Preds.add(&LAI->Preds);
> 
>   return true;
> }
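The budget check added to canVectorize can be sketched as follows. The constants mirror the cl::opt defaults in this patch: 16 by default, and 128 under a vectorize(enable) pragma. withinSCEVCheckBudget and its parameters are hypothetical names. The patch rejects the loop only when the complexity strictly exceeds the threshold, so a count equal to the threshold is still accepted.

```cpp
#include <cassert>

// Hypothetical mirrors of the defaults of -vectorize-scev-check-threshold
// and -pragma-vectorize-scev-check-threshold in this patch.
constexpr unsigned DefaultSCEVCheckThreshold = 16;
constexpr unsigned PragmaSCEVCheckThreshold = 128;

// Returns true if vectorization may proceed given the number of SCEV
// predicates (Complexity) that would have to be checked at run time.
// ForceEnabled models Hints->getForce() == LoopVectorizeHints::FK_Enabled.
bool withinSCEVCheckBudget(unsigned Complexity, bool ForceEnabled) {
  unsigned Threshold =
      ForceEnabled ? PragmaSCEVCheckThreshold : DefaultSCEVCheckThreshold;
  return Complexity <= Threshold; // the patch bails out on Complexity > Threshold
}
```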
> @@ -4550,7 +4548,7 @@ void InterleavedAccessInfo::collectConst
>     StoreInst *SI = dyn_cast<StoreInst>(I);
> 
>     Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
> -    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides);
> +    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds);
> 
>     // The factor of the corresponding interleave group.
>     unsigned Factor = std::abs(Stride);
> @@ -4559,7 +4557,7 @@ void InterleavedAccessInfo::collectConst
>     if (Factor < 2 || Factor > MaxInterleaveGroupFactor)
>       continue;
> 
> -    const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
> +    const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
>     PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
>     unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());
> 
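This patch re-expresses stride versioning with predicates. Instead of emitting a dedicated stride.chk compare, the code now records an equality predicate whenever a symbolic stride is assumed to be 1, and emitSCEVChecks later guards the vector loop with it. The following is a minimal sketch of that idea, assuming a pointer-to-stride map like the patch's Strides. versionedStride, EqualPred and PtrToStrideMap are hypothetical and are not the LoopAccessAnalysis API.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for SCEVEqualPredicate: "Expr == Value".
struct EqualPred {
  std::string Expr;
  long Value;
};

// Hypothetical: maps a pointer to the symbolic stride that legality chose
// to version for it (the role played by Strides in the patch).
using PtrToStrideMap = std::map<std::string, std::string>;

// Returns the stride to use for Ptr: the constant 1 when its symbolic stride
// is versioned (recording "Stride == 1" in Preds), or std::nullopt to leave
// the address expression unchanged.
std::optional<long> versionedStride(const std::string &Ptr,
                                    const PtrToStrideMap &Strides,
                                    std::vector<EqualPred> &Preds) {
  auto It = Strides.find(Ptr);
  if (It == Strides.end())
    return std::nullopt; // no symbolic stride to version for this pointer
  // Assume the stride is 1 and remember the assumption so that it can be
  // checked at run time before entering the vector loop.
  Preds.push_back({It->second, 1});
  return 1;
}
```

Because the assumption now lives in the shared predicate set rather than in a stride list, the same runtime-check path serves any future kind of SCEV assumption, not just unit strides.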


