[llvm] r251800 - [SCEV][LV] Add SCEV Predicates and use them to re-implement stride versioning
Adam Nemet via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 4 21:39:38 PST 2015
> On Nov 2, 2015, at 6:41 AM, Silviu Baranga via llvm-commits <llvm-commits at lists.llvm.org> wrote:
>
> Author: sbaranga
> Date: Mon Nov 2 08:41:02 2015
> New Revision: 251800
>
> URL: http://llvm.org/viewvc/llvm-project?rev=251800&view=rev
> Log:
> [SCEV][LV] Add SCEV Predicates and use them to re-implement stride versioning
>
> Summary:
> SCEV Predicates represent conditions that typically cannot be derived from
> static analysis, but can be used to reduce SCEV expressions to forms which are
> usable for different optimizers.
>
> ScalarEvolution now has the rewriteUsingPredicate method which can simplify a
> SCEV expression using a SCEVPredicateSet. The normal workflow of a pass using
> SCEVPredicates would be to hold a SCEVPredicateSet and every time assumptions
> need to be made a new SCEV Predicate would be created and added to the set.
> Each time after calling getSCEV, the user will call the rewriteUsingPredicate
> method.
>
> We add two types of predicates
> SCEVPredicateSet - implements a set of predicates
> SCEVEqualPredicate - tests for equality between two SCEV expressions
>
> We use the SCEVEqualPredicate to re-implement stride versioning. Every time we
> version a stride, we will add a SCEVEqualPredicate to the context.
> Instead of adding specific stride checks, LoopVectorize now adds a more
> generic SCEV check.
>
> We only need to add support for this in the LoopVectorizer since this is the
> only pass that will do stride versioning.
>
> Reviewers: mzolotukhin, anemet, hfinkel, sanjoy
>
> Subscribers: sanjoy, hfinkel, rengolin, jmolloy, llvm-commits
>
> Differential Revision: http://reviews.llvm.org/D13595
>
> Modified:
> llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h
> llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
> llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h
> llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp
> llvm/trunk/lib/Analysis/ScalarEvolution.cpp
> llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp
> llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp
>
> Modified: llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h (original)
> +++ llvm/trunk/include/llvm/Analysis/LoopAccessAnalysis.h Mon Nov 2 08:41:02 2015
> @@ -32,6 +32,7 @@ class DataLayout;
> class ScalarEvolution;
> class Loop;
> class SCEV;
> +class SCEVUnionPredicate;
>
> /// Optimization analysis message produced during vectorization. Messages inform
> /// the user why vectorization did not occur.
> @@ -176,10 +177,11 @@ public:
> const SmallVectorImpl<Instruction *> &Instrs) const;
> };
>
> - MemoryDepChecker(ScalarEvolution *Se, const Loop *L)
> + MemoryDepChecker(ScalarEvolution *Se, const Loop *L,
> + SCEVUnionPredicate &Preds)
> : SE(Se), InnermostLoop(L), AccessIdx(0),
> ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true),
> - RecordInterestingDependences(true) {}
> + RecordInterestingDependences(true), Preds(Preds) {}
>
> /// \brief Register the location (instructions are given increasing numbers)
> /// of a write access.
> @@ -289,6 +291,15 @@ private:
> /// \brief Check whether the data dependence could prevent store-load
> /// forwarding.
> bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize);
> +
> + /// The SCEV predicate containing all the SCEV-related assumptions.
> + /// The dependence checker needs this in order to convert SCEVs of pointers
> + /// to more accurate expressions in the context of existing assumptions.
> + /// We also need this in case assumptions about SCEV expressions need to
> + /// be made in order to avoid unknown dependences. For example we might
> + /// assume a unit stride for a pointer in order to prove that a memory access
> + /// is strided and doesn't wrap.
> + SCEVUnionPredicate &Preds;
> };
>
> /// \brief Holds information about the memory runtime legality checks to verify
> @@ -330,8 +341,13 @@ public:
> }
>
> /// Insert a pointer and calculate the start and end SCEVs.
> + /// \p We need Preds in order to compute the SCEV expression of the pointer
> + /// according to the assumptions that we've made during the analysis.
> + /// The method might also version the pointer stride according to \p Strides,
> + /// and change \p Preds.
> void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
> - unsigned ASId, const ValueToValueMap &Strides);
> + unsigned ASId, const ValueToValueMap &Strides,
> + SCEVUnionPredicate &Preds);
>
> /// \brief No run-time memory checking is necessary.
> bool empty() const { return Pointers.empty(); }
> @@ -537,6 +553,15 @@ public:
> return StoreToLoopInvariantAddress;
> }
>
> + /// The SCEV predicate contains all the SCEV-related assumptions.
> + /// The is used to keep track of the minimal set of assumptions on SCEV
> + /// expressions that the analysis needs to make in order to return a
> + /// meaningful result. All SCEV expressions during the analysis should be
> + /// re-written (and therefore simplified) according to Preds.
> + /// A user of LoopAccessAnalysis will need to emit the runtime checks
> + /// associated with this predicate.
> + SCEVUnionPredicate Preds;
Can you please also add a comment, probably right before LoopAccessInfo, noting that both sets of information it provides (run-time alias checks and dependence information) are only correct if the predicates contained here are true at run time?
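
To make the contract concrete, here is a toy sketch in plain C++. It does not use the LLVM API: ToyPredicate, ToyAccessInfo, choosePath and Path are made-up names for illustration only. The point is that a client relies on the analysis result only on the path where every recorded assumption has been checked at run time.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for one run-time-checkable assumption,
// for example "Stride == 1". Not an LLVM class.
struct ToyPredicate {
  std::string Description;
  std::function<bool()> HoldsAtRuntime;
};

// Hypothetical stand-in for the analysis result. Both pieces of information
// below are only correct if every predicate in Preds is true at run time.
struct ToyAccessInfo {
  bool NeedsAliasChecks = false; // stands in for the run-time alias checks
  bool DepsAreSafe = false;      // stands in for the dependence information
  std::vector<ToyPredicate> Preds;

  // Logical AND over all assumptions; an empty set is trivially true.
  bool allPredicatesHold() const {
    for (const ToyPredicate &P : Preds)
      if (!P.HoldsAtRuntime())
        return false;
    return true;
  }
};

enum class Path { Versioned, Fallback };

// The client takes the transformed (versioned) path only when the analysis
// says it is safe and all of the assumptions behind that answer hold.
Path choosePath(const ToyAccessInfo &Info) {
  if (Info.DepsAreSafe && Info.allPredicatesHold())
    return Path::Versioned;
  return Path::Fallback;
}
```

In this sketch, a ToyAccessInfo with no predicates behaves like an analysis that made no assumptions, so its result can be used unconditionally.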
Thanks,
Adam
> +
> private:
> /// \brief Analyze the loop. Substitute symbolic strides using Strides.
> void analyzeLoop(const ValueToValueMap &Strides);
> @@ -583,19 +608,26 @@ private:
> Value *stripIntegerCast(Value *V);
>
> ///\brief Return the SCEV corresponding to a pointer with the symbolic stride
> -///replaced with constant one.
> +/// replaced with constant one, assuming \p Preds is true.
> +///
> +/// If necessary this method will version the stride of the pointer according
> +/// to \p PtrToStride and therefore add a new predicate to \p Preds.
> ///
> /// If \p OrigPtr is not null, use it to look up the stride value instead of \p
> /// Ptr. \p PtrToStride provides the mapping between the pointer value and its
> /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
> const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE,
> const ValueToValueMap &PtrToStride,
> - Value *Ptr, Value *OrigPtr = nullptr);
> + SCEVUnionPredicate &Preds, Value *Ptr,
> + Value *OrigPtr = nullptr);
>
> /// \brief Check the stride of the pointer and ensure that it does not wrap in
> -/// the address space.
> +/// the address space, assuming \p Preds is true.
> +///
> +/// If necessary this method will version the stride of the pointer according
> +/// to \p PtrToStride and therefore add a new predicate to \p Preds.
> int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
> - const ValueToValueMap &StridesMap);
> + const ValueToValueMap &StridesMap, SCEVUnionPredicate &Preds);
>
> /// \brief This analysis provides dependence information for the memory accesses
> /// of a loop.
>
> Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
> +++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Mon Nov 2 08:41:02 2015
> @@ -48,10 +48,15 @@ namespace llvm {
> class Loop;
> class LoopInfo;
> class Operator;
> - class SCEVUnknown;
> - class SCEVAddRecExpr;
> class SCEV;
> - template<> struct FoldingSetTrait<SCEV>;
> + class SCEVAddRecExpr;
> + class SCEVConstant;
> + class SCEVExpander;
> + class SCEVPredicate;
> + class SCEVUnknown;
> +
> + template <> struct FoldingSetTrait<SCEV>;
> + template <> struct FoldingSetTrait<SCEVPredicate>;
>
> /// This class represents an analyzed expression in the program. These are
> /// opaque objects that the client is not allowed to do much with directly.
> @@ -164,6 +169,148 @@ namespace llvm {
> static bool classof(const SCEV *S);
> };
>
> + /// SCEVPredicate - This class represents an assumption made using SCEV
> + /// expressions which can be checked at run-time.
> + class SCEVPredicate : public FoldingSetNode {
> + friend struct FoldingSetTrait<SCEVPredicate>;
> +
> + /// A reference to an Interned FoldingSetNodeID for this node. The
> + /// ScalarEvolution's BumpPtrAllocator holds the data.
> + FoldingSetNodeIDRef FastID;
> +
> + public:
> + enum SCEVPredicateKind { P_Union, P_Equal };
> +
> + protected:
> + SCEVPredicateKind Kind;
> +
> + public:
> + SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
> +
> + virtual ~SCEVPredicate() {}
> +
> + SCEVPredicateKind getKind() const { return Kind; }
> +
> + /// \brief Returns the estimated complexity of this predicate.
> + /// This is roughly measured in the number of run-time checks required.
> + virtual unsigned getComplexity() { return 1; }
> +
> + /// \brief Returns true if the predicate is always true. This means that no
> + /// assumptions were made and nothing needs to be checked at run-time.
> + virtual bool isAlwaysTrue() const = 0;
> +
> + /// \brief Returns true if this predicate implies \p N.
> + virtual bool implies(const SCEVPredicate *N) const = 0;
> +
> + /// \brief Prints a textual representation of this predicate with an
> + /// indentation of \p Depth.
> + virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0;
> +
> + /// \brief Returns the SCEV to which this predicate applies, or nullptr
> + /// if this is a SCEVUnionPredicate.
> + virtual const SCEV *getExpr() const = 0;
> + };
> +
> + inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) {
> + P.print(OS);
> + return OS;
> + }
> +
> + // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute
> + // temporary FoldingSetNodeID values.
> + template <>
> + struct FoldingSetTrait<SCEVPredicate>
> + : DefaultFoldingSetTrait<SCEVPredicate> {
> +
> + static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) {
> + ID = X.FastID;
> + }
> +
> + static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID,
> + unsigned IDHash, FoldingSetNodeID &TempID) {
> + return ID == X.FastID;
> + }
> + static unsigned ComputeHash(const SCEVPredicate &X,
> + FoldingSetNodeID &TempID) {
> + return X.FastID.ComputeHash();
> + }
> + };
> +
> + /// SCEVEqualPredicate - This class represents an assumption that two SCEV
> + /// expressions are equal, and this can be checked at run-time. We assume
> + /// that the left hand side is a SCEVUnknown and the right hand side a
> + /// constant.
> + class SCEVEqualPredicate : public SCEVPredicate {
> + /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
> + /// constant.
> + const SCEVUnknown *LHS;
> + const SCEVConstant *RHS;
> +
> + public:
> + SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
> + const SCEVConstant *RHS);
> +
> + /// Implementation of the SCEVPredicate interface
> + bool implies(const SCEVPredicate *N) const override;
> + void print(raw_ostream &OS, unsigned Depth = 0) const override;
> + bool isAlwaysTrue() const override;
> + const SCEV *getExpr() const;
> +
> + /// \brief Returns the left hand side of the equality.
> + const SCEVUnknown *getLHS() const { return LHS; }
> +
> + /// \brief Returns the right hand side of the equality.
> + const SCEVConstant *getRHS() const { return RHS; }
> +
> + /// Methods for support type inquiry through isa, cast, and dyn_cast:
> + static inline bool classof(const SCEVPredicate *P) {
> + return P->getKind() == P_Equal;
> + }
> + };
> +
> + /// SCEVUnionPredicate - This class represents a composition of other
> + /// SCEV predicates, and is the class that most clients will interact with.
> + /// This is equivalent to a logical "AND" of all the predicates in the union.
> + class SCEVUnionPredicate : public SCEVPredicate {
> + private:
> + typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
> + PredicateMap;
> +
> + /// Vector with references to all predicates in this union.
> + SmallVector<const SCEVPredicate *, 16> Preds;
> + /// Maps SCEVs to predicates for quick look-ups.
> + PredicateMap SCEVToPreds;
> +
> + public:
> + SCEVUnionPredicate();
> +
> + const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
> + return Preds;
> + }
> +
> + /// \brief Adds a predicate to this union.
> + void add(const SCEVPredicate *N);
> +
> + /// \brief Returns a reference to a vector containing all predicates
> + /// which apply to \p Expr.
> + ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
> +
> + /// Implementation of the SCEVPredicate interface
> + bool isAlwaysTrue() const override;
> + bool implies(const SCEVPredicate *N) const override;
> + void print(raw_ostream &OS, unsigned Depth) const;
> + const SCEV *getExpr() const override;
> +
> + /// \brief We estimate the complexity of a union predicate as the size
> + /// number of predicates in the union.
> + unsigned getComplexity() override { return Preds.size(); }
> +
> + /// Methods for support type inquiry through isa, cast, and dyn_cast:
> + static inline bool classof(const SCEVPredicate *P) {
> + return P->getKind() == P_Union;
> + }
> + };
> +
> /// The main scalar evolution driver. Because client code (intentionally)
> /// can't do much with the SCEV objects directly, they must ask this class
> /// for services.
> @@ -1108,6 +1255,12 @@ namespace llvm {
> return F.getParent()->getDataLayout();
> }
>
> + const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
> + const SCEVConstant *RHS);
> +
> + /// Re-writes the SCEV according to the Predicates in \p Preds.
> + const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
> +
> private:
> /// Compute the backedge taken count knowing the interval difference, the
> /// stride and presence of the equality in the comparison.
> @@ -1128,6 +1281,7 @@ namespace llvm {
>
> private:
> FoldingSet<SCEV> UniqueSCEVs;
> + FoldingSet<SCEVPredicate> UniquePreds;
> BumpPtrAllocator SCEVAllocator;
>
> /// The head of a linked list of all SCEVUnknown values that have been
>
> Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h (original)
> +++ llvm/trunk/include/llvm/Analysis/ScalarEvolutionExpander.h Mon Nov 2 08:41:02 2015
> @@ -151,6 +151,22 @@ namespace llvm {
> /// block.
> Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I);
>
> + /// \brief Generates a code sequence that evaluates this predicate.
> + /// The inserted instructions will be at position \p Loc.
> + /// The result will be of type i1 and will have a value of 0 when the
> + /// predicate is false and 1 otherwise.
> + Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc);
> +
> + /// \brief A specialized variant of expandCodeForPredicate, handling the
> + /// case when we are expanding code for a SCEVEqualPredicate.
> + Value *expandEqualPredicate(const SCEVEqualPredicate *Pred,
> + Instruction *Loc);
> +
> + /// \brief A specialized variant of expandCodeForPredicate, handling the
> + /// case when we are expanding code for a SCEVUnionPredicate.
> + Value *expandUnionPredicate(const SCEVUnionPredicate *Pred,
> + Instruction *Loc);
> +
> /// \brief Set the current IV increment loop and position.
> void setIVIncInsertPos(const Loop *L, Instruction *Pos) {
> assert(!CanonicalMode &&
>
> Modified: llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp (original)
> +++ llvm/trunk/lib/Analysis/LoopAccessAnalysis.cpp Mon Nov 2 08:41:02 2015
> @@ -89,8 +89,8 @@ Value *llvm::stripIntegerCast(Value *V)
>
> const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
> const ValueToValueMap &PtrToStride,
> + SCEVUnionPredicate &Preds,
> Value *Ptr, Value *OrigPtr) {
> -
> const SCEV *OrigSCEV = SE->getSCEV(Ptr);
>
> // If there is an entry in the map return the SCEV of the pointer with the
> @@ -108,22 +108,28 @@ const SCEV *llvm::replaceSymbolicStrideS
> ValueToValueMap RewriteMap;
> RewriteMap[StrideVal] = One;
>
> - const SCEV *ByOne =
> - SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true);
> + const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
> + const auto *CT =
> + static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
> +
> + Preds.add(SE->getEqualPredicate(U, CT));
> +
> + const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds);
> DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne
> << "\n");
> return ByOne;
> }
>
> // Otherwise, just return the SCEV of the original pointer.
> - return SE->getSCEV(Ptr);
> + return OrigSCEV;
> }
>
> void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
> unsigned DepSetId, unsigned ASId,
> - const ValueToValueMap &Strides) {
> + const ValueToValueMap &Strides,
> + SCEVUnionPredicate &Preds) {
> // Get the stride replaced scev.
> - const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
> + const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
> const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
> assert(AR && "Invalid addrec expression");
> const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
> @@ -417,9 +423,9 @@ public:
> typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;
>
> AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI,
> - MemoryDepChecker::DepCandidates &DA)
> - : DL(Dl), AST(*AA), LI(LI), DepCands(DA),
> - IsRTCheckAnalysisNeeded(false) {}
> + MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds)
> + : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false),
> + Preds(Preds) {}
>
> /// \brief Register a load and whether it is only read from.
> void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
> @@ -504,14 +510,18 @@ private:
> /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared
> /// while this remains set if we have potentially dependent accesses.
> bool IsRTCheckAnalysisNeeded;
> +
> + /// The SCEV predicate containing all the SCEV-related assumptions.
> + SCEVUnionPredicate &Preds;
> };
>
> } // end anonymous namespace
>
> /// \brief Check whether a pointer can participate in a runtime bounds check.
> static bool hasComputableBounds(ScalarEvolution *SE,
> - const ValueToValueMap &Strides, Value *Ptr) {
> - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
> + const ValueToValueMap &Strides, Value *Ptr,
> + Loop *L, SCEVUnionPredicate &Preds) {
> + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
> const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
> if (!AR)
> return false;
> @@ -554,11 +564,11 @@ bool AccessAnalysis::canCheckPtrAtRT(Run
> else
> ++NumReadPtrChecks;
>
> - if (hasComputableBounds(SE, StridesMap, Ptr) &&
> + if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) &&
> // When we run after a failing dependency check we have to make sure
> // we don't have wrapping pointers.
> (!ShouldCheckStride ||
> - isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) {
> + isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) {
> // The id of the dependence set.
> unsigned DepId;
>
> @@ -572,7 +582,7 @@ bool AccessAnalysis::canCheckPtrAtRT(Run
> // Each access has its own dependence set.
> DepId = RunningDepId++;
>
> - RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap);
> + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds);
>
> DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
> } else {
> @@ -803,7 +813,8 @@ static bool isNoWrapAddRec(Value *Ptr, c
>
> /// \brief Check whether the access through \p Ptr has a constant stride.
> int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
> - const ValueToValueMap &StridesMap) {
> + const ValueToValueMap &StridesMap,
> + SCEVUnionPredicate &Preds) {
> Type *Ty = Ptr->getType();
> assert(Ty->isPointerTy() && "Unexpected non-ptr");
>
> @@ -815,7 +826,7 @@ int llvm::isStridedPtr(ScalarEvolution *
> return 0;
> }
>
> - const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr);
> + const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr);
>
> const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
> if (!AR) {
> @@ -1026,11 +1037,11 @@ MemoryDepChecker::isDependent(const MemA
> BPtr->getType()->getPointerAddressSpace())
> return Dependence::Unknown;
>
> - const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr);
> - const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr);
> + const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr);
> + const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr);
>
> - int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides);
> - int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides);
> + int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds);
> + int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds);
>
> const SCEV *Src = AScev;
> const SCEV *Sink = BScev;
> @@ -1429,7 +1440,7 @@ void LoopAccessInfo::analyzeLoop(const V
>
> MemoryDepChecker::DepCandidates DependentAccesses;
> AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(),
> - AA, LI, DependentAccesses);
> + AA, LI, DependentAccesses, Preds);
>
> // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects
> // multiple times on the same object. If the ptr is accessed twice, once
> @@ -1480,7 +1491,8 @@ void LoopAccessInfo::analyzeLoop(const V
> // read a few words, modify, and write a few words, and some of the
> // words may be written to the same address.
> bool IsReadOnlyPtr = false;
> - if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) {
> + if (Seen.insert(Ptr).second ||
> + !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) {
> ++NumReads;
> IsReadOnlyPtr = true;
> }
> @@ -1730,7 +1742,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L,
> const TargetLibraryInfo *TLI, AliasAnalysis *AA,
> DominatorTree *DT, LoopInfo *LI,
> const ValueToValueMap &Strides)
> - : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL),
> + : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL),
> TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
> MaxSafeDepDistBytes(-1U), CanVecMem(false),
> StoreToLoopInvariantAddress(false) {
> @@ -1765,6 +1777,9 @@ void LoopAccessInfo::print(raw_ostream &
> OS.indent(Depth) << "Store to invariant address was "
> << (StoreToLoopInvariantAddress ? "" : "not ")
> << "found in loop.\n";
> +
> + OS.indent(Depth) << "SCEV assumptions:\n";
> + Preds.print(OS, Depth);
> }
>
> const LoopAccessInfo &
> @@ -1778,8 +1793,8 @@ LoopAccessAnalysis::getInfo(Loop *L, con
>
> if (!LAI) {
> const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
> - LAI = llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI,
> - Strides);
> + LAI =
> + llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI, Strides);
> #ifndef NDEBUG
> LAI->NumSymbolicStrides = Strides.size();
> #endif
>
> Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
> +++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Mon Nov 2 08:41:02 2015
> @@ -9093,6 +9093,7 @@ ScalarEvolution::ScalarEvolution(ScalarE
> UnsignedRanges(std::move(Arg.UnsignedRanges)),
> SignedRanges(std::move(Arg.SignedRanges)),
> UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
> + UniquePreds(std::move(Arg.UniquePreds)),
> SCEVAllocator(std::move(Arg.SCEVAllocator)),
> FirstUnknown(Arg.FirstUnknown) {
> Arg.FirstUnknown = nullptr;
> @@ -9596,3 +9597,134 @@ void ScalarEvolutionWrapperPass::getAnal
> AU.addRequiredTransitive<DominatorTreeWrapperPass>();
> AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
> }
> +
> +const SCEVPredicate *
> +ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
> + const SCEVConstant *RHS) {
> + FoldingSetNodeID ID;
> + // Unique this node based on the arguments
> + ID.AddInteger(SCEVPredicate::P_Equal);
> + ID.AddPointer(LHS);
> + ID.AddPointer(RHS);
> + void *IP = nullptr;
> + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
> + return S;
> + SCEVEqualPredicate *Eq = new (SCEVAllocator)
> + SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
> + UniquePreds.InsertNode(Eq, IP);
> + return Eq;
> +}
> +
> +class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
> +public:
> + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
> + SCEVUnionPredicate &A) {
> + SCEVPredicateRewriter Rewriter(SE, A);
> + return Rewriter.visit(Scev);
> + }
> +
> + SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
> + : SCEVRewriteVisitor(SE), P(P) {}
> +
> + const SCEV *visitUnknown(const SCEVUnknown *Expr) {
> + auto ExprPreds = P.getPredicatesForExpr(Expr);
> + for (auto *Pred : ExprPreds)
> + if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
> + if (IPred->getLHS() == Expr)
> + return IPred->getRHS();
> +
> + return Expr;
> + }
> +
> +private:
> + SCEVUnionPredicate &P;
> +};
> +
> +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
> + SCEVUnionPredicate &Preds) {
> + return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
> +}
> +
> +/// SCEV predicates
> +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
> + SCEVPredicateKind Kind)
> + : FastID(ID), Kind(Kind) {}
> +
> +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
> + const SCEVUnknown *LHS,
> + const SCEVConstant *RHS)
> + : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
> +
> +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
> + const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
> +
> + if (!Op)
> + return false;
> +
> + return Op->LHS == LHS && Op->RHS == RHS;
> +}
> +
> +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
> +
> +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
> +
> +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
> + OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
> +}
> +
> +/// Union predicates don't get cached so create a dummy set ID for it.
> +SCEVUnionPredicate::SCEVUnionPredicate()
> + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
> +
> +bool SCEVUnionPredicate::isAlwaysTrue() const {
> + return std::all_of(Preds.begin(), Preds.end(),
> + [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
> +}
> +
> +ArrayRef<const SCEVPredicate *>
> +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
> + auto I = SCEVToPreds.find(Expr);
> + if (I == SCEVToPreds.end())
> + return ArrayRef<const SCEVPredicate *>();
> + return I->second;
> +}
> +
> +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
> + if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
> + return std::all_of(
> + Set->Preds.begin(), Set->Preds.end(),
> + [this](const SCEVPredicate *I) { return this->implies(I); });
> +
> + auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
> + if (ScevPredsIt == SCEVToPreds.end())
> + return false;
> + auto &SCEVPreds = ScevPredsIt->second;
> +
> + return std::any_of(SCEVPreds.begin(), SCEVPreds.end(),
> + [N](const SCEVPredicate *I) { return I->implies(N); });
> +}
> +
> +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
> +
> +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
> + for (auto Pred : Preds)
> + Pred->print(OS, Depth);
> +}
> +
> +void SCEVUnionPredicate::add(const SCEVPredicate *N) {
> + if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
> + for (auto Pred : Set->Preds)
> + add(Pred);
> + return;
> + }
> +
> + if (implies(N))
> + return;
> +
> + const SCEV *Key = N->getExpr();
> + assert(Key && "Only SCEVUnionPredicate doesn't have an "
> + " associated expression!");
> +
> + SCEVToPreds[Key].push_back(N);
> + Preds.push_back(N);
> +}
>
> Modified: llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp (original)
> +++ llvm/trunk/lib/Analysis/ScalarEvolutionExpander.cpp Mon Nov 2 08:41:02 2015
> @@ -1944,6 +1944,43 @@ bool SCEVExpander::isHighCostExpansionHe
> return false;
> }
>
> +Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
> + Instruction *IP) {
> + assert(IP);
> + switch (Pred->getKind()) {
> + case SCEVPredicate::P_Union:
> + return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
> + case SCEVPredicate::P_Equal:
> + return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
> + }
> + llvm_unreachable("Unknown SCEV predicate type");
> +}
> +
> +Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
> + Instruction *IP) {
> + Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
> + Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
> +
> + Builder.SetInsertPoint(IP);
> + auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
> + return I;
> +}
> +
> +Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
> + Instruction *IP) {
> + auto *BoolType = IntegerType::get(IP->getContext(), 1);
> + Value *Check = ConstantInt::getNullValue(BoolType);
> +
> + // Loop over all checks in this set.
> + for (auto Pred : Union->getPredicates()) {
> + auto *NextCheck = expandCodeForPredicate(Pred, IP);
> + Builder.SetInsertPoint(IP);
> + Check = Builder.CreateOr(Check, NextCheck);
> + }
> +
> + return Check;
> +}
> +
> namespace {
> // Search for a SCEV subexpression that is not safe to expand. Any expression
> // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely
>
> Modified: llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp
> URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp?rev=251800&r1=251799&r2=251800&view=diff
> ==============================================================================
> --- llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp (original)
> +++ llvm/trunk/lib/Transforms/Vectorize/LoopVectorize.cpp Mon Nov 2 08:41:02 2015
> @@ -222,6 +222,15 @@ static cl::opt<unsigned> PragmaVectorize
> cl::desc("The maximum allowed number of runtime memory checks with a "
> "vectorize(enable) pragma."));
>
> +static cl::opt<unsigned> VectorizeSCEVCheckThreshold(
> + "vectorize-scev-check-threshold", cl::init(16), cl::Hidden,
> + cl::desc("The maximum number of SCEV checks allowed."));
> +
> +static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold(
> + "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden,
> + cl::desc("The maximum number of SCEV checks allowed with a "
> + "vectorize(enable) pragma"));
> +
> namespace {
>
> // Forward declarations.
> @@ -273,12 +282,12 @@ public:
> InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
> DominatorTree *DT, const TargetLibraryInfo *TLI,
> const TargetTransformInfo *TTI, unsigned VecWidth,
> - unsigned UnrollFactor)
> + unsigned UnrollFactor, SCEVUnionPredicate &Preds)
> : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
> VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()),
> Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
> TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr),
> - AddedSafetyChecks(false) {}
> + AddedSafetyChecks(false), Preds(Preds) {}
>
> // Perform the actual loop widening (vectorization).
> // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
> @@ -315,12 +324,6 @@ protected:
> typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>,
> VectorParts> EdgeMaskCache;
>
> - /// \brief Add checks for strides that were assumed to be 1.
> - ///
> - /// Returns the last check instruction and the first check instruction in the
> - /// pair as (first, last).
> - std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
> -
> /// Create an empty loop, based on the loop ranges of the old loop.
> void createEmptyLoop();
> /// Create a new induction variable inside L.
> @@ -404,11 +407,12 @@ protected:
> void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass);
> /// Emit a bypass check to see if the vector trip count is nonzero.
> void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass);
> - /// Emit bypass checks to check if strides we've assumed to be one really are.
> - void emitStrideChecks(Loop *L, BasicBlock *Bypass);
> + /// Emit a bypass check to see if all of the SCEV assumptions we've
> + /// had to make are correct.
> + void emitSCEVChecks(Loop *L, BasicBlock *Bypass);
> /// Emit bypass checks to check any memory assumptions we may have made.
> void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass);
> -
> +
> /// This is a helper class that holds the vectorizer state. It maps scalar
> /// instructions to vector instructions. When the code is 'unrolled' then
> /// then a single scalar value is mapped to multiple vector parts. The parts
> @@ -516,14 +520,23 @@ protected:
>
> // Record whether runtime check is added.
> bool AddedSafetyChecks;
> +
> + /// The SCEV predicate containing all the SCEV-related assumptions.
> + /// The predicate is used to simplify existing expressions in the
> + /// context of existing SCEV assumptions. Since legality checking is
> + /// not done here, we don't need to use this predicate to record
> + /// further assumptions.
> + SCEVUnionPredicate &Preds;
> };
>
> class InnerLoopUnroller : public InnerLoopVectorizer {
> public:
> InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
> DominatorTree *DT, const TargetLibraryInfo *TLI,
> - const TargetTransformInfo *TTI, unsigned UnrollFactor)
> - : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
> + const TargetTransformInfo *TTI, unsigned UnrollFactor,
> + SCEVUnionPredicate &Preds)
> + : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor,
> + Preds) {}
>
> private:
> void scalarizeInstruction(Instruction *Instr,
> @@ -744,8 +757,9 @@ private:
> /// between the member and the group in a map.
> class InterleavedAccessInfo {
> public:
> - InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT)
> - : SE(SE), TheLoop(L), DT(DT) {}
> + InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT,
> + SCEVUnionPredicate &Preds)
> + : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {}
>
> ~InterleavedAccessInfo() {
> SmallSet<InterleaveGroup *, 4> DelSet;
> @@ -779,6 +793,13 @@ private:
> Loop *TheLoop;
> DominatorTree *DT;
>
> + /// The SCEV predicate containing all the SCEV-related assumptions.
> + /// The predicate is used to simplify SCEV expressions in the
> + /// context of existing SCEV assumptions. The interleaved access
> + /// analysis can also add new predicates (for example by versioning
> + /// strides of pointers).
> + SCEVUnionPredicate &Preds;
> +
> /// Holds the relationships between the members and the interleave group.
> DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
>
> @@ -1141,11 +1162,13 @@ public:
> Function *F, const TargetTransformInfo *TTI,
> LoopAccessAnalysis *LAA,
> LoopVectorizationRequirements *R,
> - const LoopVectorizeHints *H)
> + const LoopVectorizeHints *H,
> + SCEVUnionPredicate &Preds)
> : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F),
> - TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT),
> - Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
> - Requirements(R), Hints(H) {}
> + TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr),
> + InterleaveInfo(SE, L, DT, Preds), Induction(nullptr),
> + WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H),
> + Preds(Preds) {}
>
> /// ReductionList contains the reduction descriptors for all
> /// of the reductions that were found in the loop.
> @@ -1344,7 +1367,14 @@ private:
>
> /// While vectorizing these instructions we have to generate a
> /// call to the appropriate masked intrinsic
> - SmallPtrSet<const Instruction*, 8> MaskedOp;
> + SmallPtrSet<const Instruction *, 8> MaskedOp;
> +
> + /// The SCEV predicate containing all the SCEV-related assumptions.
> + /// The predicate is used to simplify SCEV expressions in the
> + /// context of existing SCEV assumptions. The analysis will also
> + /// add a minimal set of new predicates if this is required to
> + /// enable vectorization/unrolling.
> + SCEVUnionPredicate &Preds;
> };
>
> /// LoopVectorizationCostModel - estimates the expected speedups due to
> @@ -1360,9 +1390,10 @@ public:
> LoopVectorizationLegality *Legal,
> const TargetTransformInfo &TTI,
> const TargetLibraryInfo *TLI, DemandedBits *DB,
> - AssumptionCache *AC,
> - const Function *F, const LoopVectorizeHints *Hints,
> - SmallPtrSetImpl<const Value *> &ValuesToIgnore)
> + AssumptionCache *AC, const Function *F,
> + const LoopVectorizeHints *Hints,
> + SmallPtrSetImpl<const Value *> &ValuesToIgnore,
> + SCEVUnionPredicate &Preds)
> : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
> TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
>
> @@ -1690,10 +1721,12 @@ struct LoopVectorize : public FunctionPa
> }
> }
>
> + SCEVUnionPredicate Preds;
> +
> // Check if it is legal to vectorize the loop.
> LoopVectorizationRequirements Requirements;
> LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA,
> - &Requirements, &Hints);
> + &Requirements, &Hints, Preds);
> if (!LVL.canVectorize()) {
> DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
> emitMissedWarning(F, L, Hints);
> @@ -1712,7 +1745,7 @@ struct LoopVectorize : public FunctionPa
>
> // Use the cost model.
> LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints,
> - ValuesToIgnore);
> + ValuesToIgnore, Preds);
>
> // Check the function attributes to find out if this function should be
> // optimized for size.
> @@ -1823,7 +1856,7 @@ struct LoopVectorize : public FunctionPa
> assert(IC > 1 && "interleave count should not be 1 or 0");
> // If we decided that it is not legal to vectorize the loop then
> // interleave it.
> - InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC);
> + InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds);
> Unroller.vectorize(&LVL, CM.MinBWs);
>
> emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
> @@ -1831,7 +1864,7 @@ struct LoopVectorize : public FunctionPa
> Twine(IC) + ")");
> } else {
> // If we decided that it is *legal* to vectorize the loop then do it.
> - InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC);
> + InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds);
> LB.vectorize(&LVL, CM.MinBWs);
> ++LoopsVectorized;
>
> @@ -1992,7 +2025,7 @@ int LoopVectorizationLegality::isConsecu
> // %idxprom = zext i32 %mul to i64 << Safe cast.
> // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
> //
> - Last = replaceSymbolicStrideSCEV(SE, Strides,
> + Last = replaceSymbolicStrideSCEV(SE, Strides, Preds,
> Gep->getOperand(InductionOperand), Gep);
> if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
> Last =
> @@ -2551,56 +2584,8 @@ void InnerLoopVectorizer::scalarizeInstr
> }
> }
>
> -static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
> - Instruction *Loc) {
> - if (FirstInst)
> - return FirstInst;
> - if (Instruction *I = dyn_cast<Instruction>(V))
> - return I->getParent() == Loc->getParent() ? I : nullptr;
> - return nullptr;
> -}
> -
> -std::pair<Instruction *, Instruction *>
> -InnerLoopVectorizer::addStrideCheck(Instruction *Loc) {
> - Instruction *tnullptr = nullptr;
> - if (!Legal->mustCheckStrides())
> - return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
> -
> - IRBuilder<> ChkBuilder(Loc);
> -
> - // Emit checks.
> - Value *Check = nullptr;
> - Instruction *FirstInst = nullptr;
> - for (SmallPtrSet<Value *, 8>::iterator SI = Legal->strides_begin(),
> - SE = Legal->strides_end();
> - SI != SE; ++SI) {
> - Value *Ptr = stripIntegerCast(*SI);
> - Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1),
> - "stride.chk");
> - // Store the first instruction we create.
> - FirstInst = getFirstInst(FirstInst, C, Loc);
> - if (Check)
> - Check = ChkBuilder.CreateOr(Check, C);
> - else
> - Check = C;
> - }
> -
> - // We have to do this trickery because the IRBuilder might fold the check to a
> - // constant expression in which case there is no Instruction anchored in a
> - // the block.
> - LLVMContext &Ctx = Loc->getContext();
> - Instruction *TheCheck =
> - BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx));
> - ChkBuilder.Insert(TheCheck, "stride.not.one");
> - FirstInst = getFirstInst(FirstInst, TheCheck, Loc);
> -
> - return std::make_pair(FirstInst, TheCheck);
> -}
> -
> -PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L,
> - Value *Start,
> - Value *End,
> - Value *Step,
> +PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start,
> + Value *End, Value *Step,
> Instruction *DL) {
> BasicBlock *Header = L->getHeader();
> BasicBlock *Latch = L->getLoopLatch();
> @@ -2735,26 +2720,26 @@ void InnerLoopVectorizer::emitVectorLoop
> LoopBypassBlocks.push_back(BB);
> }
>
> -void InnerLoopVectorizer::emitStrideChecks(Loop *L,
> - BasicBlock *Bypass) {
> +void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
> BasicBlock *BB = L->getLoopPreheader();
> -
> - // Generate the code to check that the strides we assumed to be one are really
> - // one. We want the new basic block to start at the first instruction in a
> +
> + // Generate the code to check that the SCEV assumptions that we made.
> + // We want the new basic block to start at the first instruction in a
> // sequence of instructions that form a check.
> - Instruction *StrideCheck;
> - Instruction *FirstCheckInst;
> - std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator());
> - if (!StrideCheck)
> - return;
> + SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check");
> + Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator());
> +
> + if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
> + if (C->isZero())
> + return;
>
> // Create a new block containing the stride check.
> - BB->setName("vector.stridecheck");
> + BB->setName("vector.scevcheck");
> auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph");
> if (L->getParentLoop())
> L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI);
> ReplaceInstWithInst(BB->getTerminator(),
> - BranchInst::Create(Bypass, NewBB, StrideCheck));
> + BranchInst::Create(Bypass, NewBB, SCEVCheck));
> LoopBypassBlocks.push_back(BB);
> AddedSafetyChecks = true;
> }
> @@ -2874,10 +2859,10 @@ void InnerLoopVectorizer::createEmptyLoo
> // Now, compare the new count to zero. If it is zero skip the vector loop and
> // jump to the scalar loop.
> emitVectorLoopEnteredCheck(Lp, ScalarPH);
> - // Generate the code to check that the strides we assumed to be one are really
> - // one. We want the new basic block to start at the first instruction in a
> - // sequence of instructions that form a check.
> - emitStrideChecks(Lp, ScalarPH);
> + // Generate the code to check any assumptions that we've made for SCEV
> + // expressions.
> + emitSCEVChecks(Lp, ScalarPH);
> +
> // Generate the code that checks in runtime if arrays overlap. We put the
> // checks into a separate block to make the more common case of few elements
> // faster.
> @@ -4130,7 +4115,19 @@ bool LoopVectorizationLegality::canVecto
>
> // Analyze interleaved memory accesses.
> if (UseInterleaved)
> - InterleaveInfo.analyzeInterleaving(Strides);
> + InterleaveInfo.analyzeInterleaving(Strides);
> +
> + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold;
> + if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
> + SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
> +
> + if (Preds.getComplexity() > SCEVThreshold) {
> + emitAnalysis(VectorizationReport()
> + << "Too many SCEV assumptions need to be made and checked "
> + << "at runtime");
> + DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n");
> + return false;
> + }
>
> // Okay! We can vectorize. At this point we don't have any other mem analysis
> // which may limit our maximum vectorization factor, so just return true with
> @@ -4436,6 +4433,7 @@ bool LoopVectorizationLegality::canVecto
> }
>
> Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
> + Preds.add(&LAI->Preds);
>
> return true;
> }
> @@ -4550,7 +4548,7 @@ void InterleavedAccessInfo::collectConst
> StoreInst *SI = dyn_cast<StoreInst>(I);
>
> Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
> - int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides);
> + int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds);
>
> // The factor of the corresponding interleave group.
> unsigned Factor = std::abs(Stride);
> @@ -4559,7 +4557,7 @@ void InterleavedAccessInfo::collectConst
> if (Factor < 2 || Factor > MaxInterleaveGroupFactor)
> continue;
>
> - const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
> + const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
> PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
> unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());
>
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> http://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits