[llvm] 5ba1150 - [PSE] Remove assumption that top level predicate is union from public interface [NFC*]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 10 16:15:00 PST 2022


Author: Philip Reames
Date: 2022-02-10T16:14:52-08:00
New Revision: 5ba115031dd773780cafc8ade0140883473b8cee

URL: https://github.com/llvm/llvm-project/commit/5ba115031dd773780cafc8ade0140883473b8cee
DIFF: https://github.com/llvm/llvm-project/commit/5ba115031dd773780cafc8ade0140883473b8cee.diff

LOG: [PSE] Remove assumption that top level predicate is union from public interface [NFC*]

Note that this doesn't actually cause the top level predicate to become a non-union just yet.

The * above refers to a case in the LoopVectorizer where a predicate that is later proven to be always true no longer blocks vectorization. This is because the check changed from whether any predicates exist to whether the predicate can possibly be false.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/include/llvm/Transforms/Utils/LoopVersioning.h
    llvm/lib/Analysis/LoopAccessAnalysis.cpp
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/Transforms/Scalar/LoopDistribute.cpp
    llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
    llvm/lib/Transforms/Utils/LoopVersioning.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 30e62a640363e..4cafeff400a73 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2199,7 +2199,7 @@ class PredicatedScalarEvolution {
 public:
   PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L);
 
-  const SCEVUnionPredicate &getUnionPredicate() const;
+  const SCEVPredicate &getPredicate() const;
 
   /// Returns the SCEV expression of V, in the context of the current SCEV
   /// predicate.  The order of transformations applied on the expression of V

diff  --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
index 918e9f3302a59..3f3bf3353cac4 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
@@ -123,7 +123,7 @@ class LoopVersioning {
   SmallVector<RuntimePointerCheck, 4> AliasChecks;
 
   /// The set of SCEV checks that we are versioning for.
-  const SCEVUnionPredicate &Preds;
+  const SCEVPredicate &Preds;
 
   /// Maps a pointer to the pointer checking group that the pointer
   /// belongs to.

diff  --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 828877926ea8b..37f867e1c2e61 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -2342,7 +2342,7 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
                    << "found in loop.\n";
 
   OS.indent(Depth) << "SCEV assumptions:\n";
-  PSE->getUnionPredicate().print(OS, Depth);
+  PSE->getPredicate().print(OS, Depth);
 
   OS << "\n";
 

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a9e102a1056a7..465e52bf03757 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13975,7 +13975,7 @@ void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
   updateGeneration();
 }
 
-const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
+const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
   return *Preds;
 }
 

diff  --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
index 0f4c767c1e4cf..26e4837749cf3 100644
--- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp
@@ -770,7 +770,7 @@ class LoopDistributeForLoop {
 
     // Don't distribute the loop if we need too many SCEV run-time checks, or
     // any if it's illegal.
-    const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate();
+    const SCEVPredicate &Pred = LAI->getPSE().getPredicate();
     if (LAI->hasConvergentOp() && !Pred.isAlwaysTrue()) {
       return fail("RuntimeCheckWithConvergent",
                   "may not insert runtime check with convergent operation");

diff  --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
index 307d0b17a041d..15c698fb8178e 100644
--- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -529,7 +529,7 @@ class LoadEliminationForLoop {
       return false;
     }
 
-    if (LAI.getPSE().getUnionPredicate().getComplexity() >
+    if (LAI.getPSE().getPredicate().getComplexity() >
         LoadElimSCEVCheckThreshold) {
       LLVM_DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
       return false;
@@ -540,7 +540,7 @@ class LoadEliminationForLoop {
       return false;
     }
 
-    if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) {
+    if (!Checks.empty() || !LAI.getPSE().getPredicate().isAlwaysTrue()) {
       if (LAI.hasConvergentOp()) {
         LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with "
                              "convergent calls\n");

diff  --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
index 049f49df9434c..97f29527bb95c 100644
--- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
@@ -42,7 +42,7 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
                                LoopInfo *LI, DominatorTree *DT,
                                ScalarEvolution *SE)
     : VersionedLoop(L), AliasChecks(Checks.begin(), Checks.end()),
-      Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT),
+      Preds(LAI.getPSE().getPredicate()), LAI(LAI), LI(LI), DT(DT),
       SE(SE) {
 }
 
@@ -276,7 +276,7 @@ bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA,
     const LoopAccessInfo &LAI = GetLAA(*L);
     if (!LAI.hasConvergentOp() &&
         (LAI.getNumRuntimePointerChecks() ||
-         !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
+         !LAI.getPSE().getPredicate().isAlwaysTrue())) {
       LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L,
                           LI, DT, SE);
       LVer.versionLoop();

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 81e5aa223c070..2dbc34e7165ce 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -572,7 +572,7 @@ void LoopVectorizationLegality::addInductionPhi(
   // on predicates that only hold within the loop, since allowing the exit
   // currently means re-using this SCEV outside the loop (see PR33706 for more
   // details).
-  if (PSE.getUnionPredicate().isAlwaysTrue()) {
+  if (PSE.getPredicate().isAlwaysTrue()) {
     AllowedExit.insert(Phi);
     AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch()));
   }
@@ -849,7 +849,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
         // used outside the loop only if the SCEV predicates within the loop is
         // same as outside the loop. Allowing the exit means reusing the SCEV
         // outside the loop.
-        if (PSE.getUnionPredicate().isAlwaysTrue()) {
+        if (PSE.getPredicate().isAlwaysTrue()) {
           AllowedExit.insert(&I);
           continue;
         }
@@ -919,7 +919,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   }
 
   Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
-  PSE.addPredicate(LAI->getPSE().getUnionPredicate());
+  PSE.addPredicate(LAI->getPSE().getPredicate());
   return true;
 }
 
@@ -1266,7 +1266,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) {
   if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
     SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
 
-  if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) {
+  if (PSE.getPredicate().getComplexity() > SCEVThreshold) {
     reportVectorizationFailure("Too many SCEV checks needed",
         "Too many SCEV assumptions need to be made and checked at runtime",
         "TooManySCEVRunTimeChecks", ORE, TheLoop);

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e13149864a130..5f0129ec2b4e6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1998,7 +1998,7 @@ class GeneratedRTChecks {
   /// there is no vector code generation, the check blocks are removed
   /// completely.
   void Create(Loop *L, const LoopAccessInfo &LAI,
-              const SCEVUnionPredicate &UnionPred) {
+              const SCEVPredicate &Pred) {
 
     BasicBlock *LoopHeader = L->getHeader();
     BasicBlock *Preheader = L->getLoopPreheader();
@@ -2007,12 +2007,12 @@ class GeneratedRTChecks {
     // ensure the blocks are properly added to LoopInfo & DominatorTree. Those
     // may be used by SCEVExpander. The blocks will be un-linked from their
     // predecessors and removed from LI & DT at the end of the function.
-    if (!UnionPred.isAlwaysTrue()) {
+    if (!Pred.isAlwaysTrue()) {
       SCEVCheckBlock = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
                                   nullptr, "vector.scevcheck");
 
       SCEVCheckCond = SCEVExp.expandCodeForPredicate(
-          &UnionPred, SCEVCheckBlock->getTerminator());
+          &Pred, SCEVCheckBlock->getTerminator());
     }
 
     const auto &RtPtrChecking = *LAI.getRuntimePointerChecking();
@@ -5161,7 +5161,7 @@ bool LoopVectorizationCostModel::runtimeChecksRequired() {
     return true;
   }
 
-  if (!PSE.getUnionPredicate().getPredicates().empty()) {
+  if (!PSE.getPredicate().isAlwaysTrue()) {
     reportVectorizationFailure("Runtime SCEV check is required with -Os/-Oz",
         "runtime SCEV checks needed. Enable vectorization of this "
         "loop with '#pragma clang loop vectorize(enable)' when "
@@ -10557,7 +10557,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
     GeneratedRTChecks Checks(*PSE.getSE(), DT, LI,
                              F->getParent()->getDataLayout());
     if (!VF.Width.isScalar() || IC > 1)
-      Checks.Create(L, *LVL.getLAI(), PSE.getUnionPredicate());
+      Checks.Create(L, *LVL.getLAI(), PSE.getPredicate());
 
     using namespace ore;
     if (!VectorizeLoop) {


        


More information about the llvm-commits mailing list