[llvm] [VPlan] Add initial VPScalarEvolution, use to get trip count SCEV (NFC) (PR #94464)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 17 02:34:08 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/94464

>From fcd41480cd1539e37e046c31a95062ba43633103 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 5 Jun 2024 12:59:47 +0100
Subject: [PATCH 1/4] [VPlan] Add initial VPScalarEvolution, use to get trip
 count SCEV (NFC)

Add an initial version of VPScalarEvolution, which can be used to return
a SCEV expression for a VPValue. The initial implementation only returns
SCEVs for live-in IR values (by constructing a SCEV for the underlying
live-in IR value) and for values defined by a VPExpandSCEVRecipe. This is
sufficient for its first use, getting the SCEV for a VPlan's trip count,
and will be extended in the future.
---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp   |  5 +++--
 llvm/lib/Transforms/Vectorize/VPlan.h             |  3 ---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp   | 13 +++++++++++++
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h     | 15 +++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp |  4 +---
 5 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c7c19ef456c7cb..14197259a9f98d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -978,8 +978,9 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) {
   return B.CreateElementCount(Ty, VF);
 }
 
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
-                                Loop *OrigLoop) {
+static const SCEV *createTripCountSCEV(Type *IdxTy,
+                                       PredicatedScalarEvolution &PSE,
+                                       Loop *OrigLoop) {
   const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
   assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 943edc3520869f..de37262d2c3dee 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -79,9 +79,6 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF);
 Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF,
                        int64_t Step);
 
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
-                                Loop *CurLoop = nullptr);
-
 /// A range of powers-of-2 vectorization factors with fixed start and
 /// adjustable end. The range includes start and excludes end, e.g.,:
 /// [1, 16) = {1, 2, 4, 8}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 90bbf2d5d99faf..e9f01c1a334e1f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -9,6 +9,7 @@
 #include "VPlanAnalysis.h"
 #include "VPlan.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 
 using namespace llvm;
 
@@ -265,3 +266,15 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   CachedTypes[V] = ResultTy;
   return ResultTy;
 }
+
+const SCEV *VPScalarEvolution::getSCEV(VPValue *V) {
+  if (V->isLiveIn())
+    return SE.getSCEV(V->getLiveInIRValue());
+
+  // TODO: Support constructing SCEVs for more recipes as needed.
+  return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
+      .Case<VPExpandSCEVRecipe>(
+          [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
+      .Default(
+          [this](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6fe..9d549778092c99 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,8 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class SCEV;
+class ScalarEvolution;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -61,6 +63,19 @@ class VPTypeAnalysis {
   LLVMContext &getContext() { return Ctx; }
 };
 
+/// A light wrapper over ScalarEvolution to construct SCEV expressions for
+/// VPValues and recipes.
+class VPScalarEvolution {
+  ScalarEvolution &SE;
+
+public:
+  VPScalarEvolution(ScalarEvolution &SE) : SE(SE) {}
+
+  /// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
+  /// SCEV expression could be constructed.
+  const SCEV *getSCEV(VPValue *V);
+};
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ab3b5cf2b9dabe..541246266e1d51 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -674,10 +674,8 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
              m_BranchOnCond(m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue())))))
     return;
 
-  Type *IdxTy =
-      Plan.getCanonicalIV()->getStartValue()->getLiveInIRValue()->getType();
-  const SCEV *TripCount = createTripCountSCEV(IdxTy, PSE);
   ScalarEvolution &SE = *PSE.getSE();
+  const SCEV *TripCount = VPScalarEvolution(SE).getSCEV(Plan.getTripCount());
   ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
   const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
   if (TripCount->isZero() ||

>From d7d389e7a072799989e94dad1e440e1826bbe0ab Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 12 Jul 2024 10:33:19 +0100
Subject: [PATCH 2/4] !fixup move logic to single getSCEVExprForVPValue
 function

---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp  |  5 ++---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h    | 16 ++++++----------
 .../lib/Transforms/Vectorize/VPlanTransforms.cpp |  3 ++-
 3 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index d3f4c18d98fb49..245b26e4ce22c0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -318,7 +318,7 @@ void llvm::collectEphemeralRecipesForVPlan(
   }
 }
 
-const SCEV *VPScalarEvolution::getSCEV(VPValue *V) {
+const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
   if (V->isLiveIn())
     return SE.getSCEV(V->getLiveInIRValue());
 
@@ -326,6 +326,5 @@ const SCEV *VPScalarEvolution::getSCEV(VPValue *V) {
   return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
       .Case<VPExpandSCEVRecipe>(
           [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
-      .Default(
-          [this](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
+      .Default([&SE](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 51eaa0f7020184..ef2012a32e66bd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -70,18 +70,14 @@ class VPTypeAnalysis {
 void collectEphemeralRecipesForVPlan(VPlan &Plan,
                                      DenseSet<VPRecipeBase *> &EphRecipes);
 
-/// A light wrapper over ScalarEvolution to construct SCEV expressions for
-/// VPValues and recipes.
-class VPScalarEvolution {
-  ScalarEvolution &SE;
+namespace vputils {
 
-public:
-  VPScalarEvolution(ScalarEvolution &SE) : SE(SE) {}
+/// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
+/// SCEV expression could be constructed.
+const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE);
+
+} // namespace vputils
 
-  /// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
-  /// SCEV expression could be constructed.
-  const SCEV *getSCEV(VPValue *V);
-};
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 734ab96a4a68af..b55e716fbff37b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -684,7 +684,8 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
     return;
 
   ScalarEvolution &SE = *PSE.getSE();
-  const SCEV *TripCount = VPScalarEvolution(SE).getSCEV(Plan.getTripCount());
+  const SCEV *TripCount =
+      vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
   ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
   const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
   if (TripCount->isZero() ||

>From 9bc1413a4b1929d16d42a27e94139bd917fb8de9 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 29 Aug 2024 20:08:07 +0100
Subject: [PATCH 3/4] !fixup move to VPlanUtils.cpp

---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp | 11 -----------
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h   | 10 ----------
 llvm/lib/Transforms/Vectorize/VPlanUtils.cpp    | 12 ++++++++++++
 llvm/lib/Transforms/Vectorize/VPlanUtils.h      | 10 ++++++++++
 4 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 2151400c52a096..0977f764addb8e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -323,16 +323,6 @@ void llvm::collectEphemeralRecipesForVPlan(
   }
 }
 
-const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
-  if (V->isLiveIn())
-    return SE.getSCEV(V->getLiveInIRValue());
-
-  // TODO: Support constructing SCEVs for more recipes as needed.
-  return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
-      .Case<VPExpandSCEVRecipe>(
-          [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
-      .Default([&SE](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
-
 template void DomTreeBuilder::Calculate<DominatorTreeBase<VPBlockBase, false>>(
     DominatorTreeBase<VPBlockBase, false> &DT);
 
@@ -374,5 +364,4 @@ bool VPDominatorTree::properlyDominates(const VPRecipeBase *A,
          "No replicate regions expected at this point");
 #endif
   return Base::properlyDominates(ParentA, ParentB);
->>>>>>> origin/main
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index ef2012a32e66bd..e91d6f17f3cc6b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -26,8 +26,6 @@ struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
 class VPRecipeBase;
 class VPlan;
-class SCEV;
-class ScalarEvolution;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -70,14 +68,6 @@ class VPTypeAnalysis {
 void collectEphemeralRecipesForVPlan(VPlan &Plan,
                                      DenseSet<VPRecipeBase *> &EphRecipes);
 
-namespace vputils {
-
-/// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
-/// SCEV expression could be constructed.
-const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE);
-
-} // namespace vputils
-
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
index c18bea4f4c5926..414f8866d24f0f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -8,6 +8,7 @@
 
 #include "VPlanUtils.h"
 #include "VPlanPatternMatch.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 
 using namespace llvm;
@@ -60,3 +61,14 @@ bool vputils::isHeaderMask(const VPValue *V, VPlan &Plan) {
   return match(V, m_Binary<Instruction::ICmp>(m_VPValue(A), m_VPValue(B))) &&
          IsWideCanonicalIV(A) && B == Plan.getOrCreateBackedgeTakenCount();
 }
+
+const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
+  if (V->isLiveIn())
+    return SE.getSCEV(V->getLiveInIRValue());
+
+  // TODO: Support constructing SCEVs for more recipes as needed.
+  return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
+      .Case<VPExpandSCEVRecipe>(
+          [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
+      .Default([&SE](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h
index fc11208a433961..db047c8a671a2d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h
@@ -11,6 +11,11 @@
 
 #include "VPlan.h"
 
+namespace llvm {
+class ScalarEvolution;
+class SCEV;
+} // namespace llvm
+
 namespace llvm::vputils {
 /// Returns true if only the first lane of \p Def is used.
 bool onlyFirstLaneUsed(const VPValue *Def);
@@ -45,6 +50,11 @@ inline bool isUniformAfterVectorization(const VPValue *VPV) {
 
 /// Return true if \p V is a header mask in \p Plan.
 bool isHeaderMask(const VPValue *V, VPlan &Plan);
+
+/// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
+/// SCEV expression could be constructed.
+const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE);
+
 } // end namespace llvm::vputils
 
 #endif

>From ca98f415bb9254944ea3e77da450609685220f94 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 17 Sep 2024 10:33:20 +0100
Subject: [PATCH 4/4] !fixup address comments, thanks!

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 27 +++++++------------
 llvm/lib/Transforms/Vectorize/VPlan.cpp       | 26 +++++++++++-------
 llvm/lib/Transforms/Vectorize/VPlan.h         |  4 +--
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h |  1 -
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  2 ++
 llvm/lib/Transforms/Vectorize/VPlanUtils.h    |  9 +++----
 .../Transforms/Vectorize/VPlanTestBase.h      | 18 +++++++------
 7 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 8319a41c9fcd52..9fb684427cfe9d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -905,16 +905,6 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) {
   return B.CreateElementCount(Ty, VF);
 }
 
-static const SCEV *createTripCountSCEV(Type *IdxTy,
-                                       PredicatedScalarEvolution &PSE,
-                                       Loop *OrigLoop) {
-  const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
-  assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
-
-  ScalarEvolution &SE = *PSE.getSE();
-  return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
-}
-
 void reportVectorizationFailure(const StringRef DebugMsg,
                                 const StringRef OREMsg, const StringRef ORETag,
                                 OptimizationRemarkEmitter *ORE, Loop *TheLoop,
@@ -4751,7 +4741,10 @@ VectorizationFactor LoopVectorizationPlanner::selectEpilogueVectorizationFactor(
     if (!MainLoopVF.isScalable() && !NextVF.Width.isScalable()) {
       // TODO: extend to support scalable VFs.
       if (!RemainingIterations) {
-        const SCEV *TC = createTripCountSCEV(TCType, PSE, OrigLoop);
+        const SCEV *TC = vputils::getSCEVExprForVPValue(
+            getPlanFor(NextVF.Width).getTripCount(), SE);
+        assert(!isa<SCEVCouldNotCompute>(TC) &&
+               "Trip count SCEV must be computable");
         RemainingIterations = SE.getURemExpr(
             TC, SE.getConstant(TCType, MainLoopVF.getKnownMinValue() * IC));
       }
@@ -8864,10 +8857,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
             return !CM.requiresScalarEpilogue(VF.isVector());
           },
           Range);
-  VPlanPtr Plan = VPlan::createInitialVPlan(
-      createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop),
-      *PSE.getSE(), RequiresScalarEpilogueCheck, CM.foldTailByMasking(),
-      OrigLoop);
+  VPlanPtr Plan = VPlan::createInitialVPlan(Legal->getWidestInductionType(),
+                                            PSE, RequiresScalarEpilogueCheck,
+                                            CM.foldTailByMasking(), OrigLoop);
 
   // Don't use getDecisionAndClampRange here, because we don't know the UF
   // so this function is better to be conservative, rather than to split
@@ -9082,9 +9074,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
   assert(EnableVPlanNativePath && "VPlan-native path is not enabled.");
 
   // Create new empty VPlan
-  auto Plan = VPlan::createInitialVPlan(
-      createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop),
-      *PSE.getSE(), true, false, OrigLoop);
+  auto Plan = VPlan::createInitialVPlan(Legal->getWidestInductionType(), PSE,
+                                        true, false, OrigLoop);
 
   // Build hierarchical CFG
   VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan);
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index a310756793a515..4f9addb705328f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -869,14 +869,23 @@ static VPIRBasicBlock *createVPIRBasicBlockFor(BasicBlock *BB) {
   return VPIRBB;
 }
 
-VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE,
+VPlanPtr VPlan::createInitialVPlan(Type *IdxTy, PredicatedScalarEvolution &PSE,
                                    bool RequiresScalarEpilogueCheck,
                                    bool TailFolded, Loop *TheLoop) {
   VPIRBasicBlock *Entry = createVPIRBasicBlockFor(TheLoop->getLoopPreheader());
   VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph");
   auto Plan = std::make_unique<VPlan>(Entry, VecPreheader);
-  Plan->TripCount =
-      vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE);
+  {
+    const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
+    assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
+           "Invalid loop count");
+
+    ScalarEvolution &SE = *PSE.getSE();
+    const SCEV *TripCount =
+        SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, TheLoop);
+    Plan->TripCount =
+        vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE);
+  }
   // Create VPRegionBlock, with empty header and latch blocks, to be filled
   // during processing later.
   VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body");
@@ -916,12 +925,11 @@ VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE,
   // debugging. Eg. if the compare has got a line number inside the loop.
   VPBuilder Builder(MiddleVPBB);
   VPValue *Cmp =
-      TailFolded
-          ? Plan->getOrAddLiveIn(ConstantInt::getTrue(
-                IntegerType::getInt1Ty(TripCount->getType()->getContext())))
-          : Builder.createICmp(CmpInst::ICMP_EQ, Plan->getTripCount(),
-                               &Plan->getVectorTripCount(),
-                               ScalarLatchTerm->getDebugLoc(), "cmp.n");
+      TailFolded ? Plan->getOrAddLiveIn(ConstantInt::getTrue(
+                       IntegerType::getInt1Ty(IdxTy->getContext())))
+                 : Builder.createICmp(CmpInst::ICMP_EQ, Plan->getTripCount(),
+                                      &Plan->getVectorTripCount(),
+                                      ScalarLatchTerm->getDebugLoc(), "cmp.n");
   Builder.createNaryOp(VPInstruction::BranchOnCond, {Cmp},
                        ScalarLatchTerm->getDebugLoc());
   return Plan;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index cf5b15308606e3..3513d1bb861d64 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3474,8 +3474,8 @@ class VPlan {
   /// middle VPBasicBlock. If a check is needed to guard executing the scalar
   /// epilogue loop, it will be added to the middle block, together with
   /// VPBasicBlocks for the scalar preheader and exit blocks.
-  static VPlanPtr createInitialVPlan(const SCEV *TripCount,
-                                     ScalarEvolution &PSE,
+  static VPlanPtr createInitialVPlan(Type *IdxTy,
+                                     PredicatedScalarEvolution &PSE,
                                      bool RequiresScalarEpilogueCheck,
                                      bool TailFolded, Loop *TheLoop);
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7540d9570d37eb..cc21870bee2e3b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -68,7 +68,6 @@ class VPTypeAnalysis {
 // Collect a VPlan's ephemeral recipes (those used only by an assume).
 void collectEphemeralRecipesForVPlan(VPlan &Plan,
                                      DenseSet<VPRecipeBase *> &EphRecipes);
-
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index bb0c6bd5745d90..edcd7d26e60daa 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -688,6 +688,8 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
   ScalarEvolution &SE = *PSE.getSE();
   const SCEV *TripCount =
       vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
+  assert(!isa<SCEVCouldNotCompute>(TripCount) &&
+         "Trip count SCEV must be computable");
   ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
   const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
   if (TripCount->isZero() ||
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h
index db047c8a671a2d..7b5d4300655f5a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h
@@ -31,6 +31,10 @@ bool onlyFirstPartUsed(const VPValue *Def);
 VPValue *getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr,
                                        ScalarEvolution &SE);
 
+/// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
+/// SCEV expression could be constructed.
+const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE);
+
 /// Returns true if \p VPV is uniform after vectorization.
 inline bool isUniformAfterVectorization(const VPValue *VPV) {
   // A value defined outside the vector region must be uniform after
@@ -50,11 +54,6 @@ inline bool isUniformAfterVectorization(const VPValue *VPV) {
 
 /// Return true if \p V is a header mask in \p Plan.
 bool isHeaderMask(const VPValue *V, VPlan &Plan);
-
-/// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
-/// SCEV expression could be constructed.
-const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE);
-
 } // end namespace llvm::vputils
 
 #endif
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
index e7b51190489159..06e091da9054e3 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h
@@ -67,10 +67,11 @@ class VPlanTestBase : public testing::Test {
     assert(!verifyFunction(F) && "input function must be valid");
     doAnalysis(F);
 
-    auto Plan = VPlan::createInitialVPlan(
-        SE->getBackedgeTakenCount(LI->getLoopFor(LoopHeader)), *SE, true, false,
-        LI->getLoopFor(LoopHeader));
-    VPlanHCFGBuilder HCFGBuilder(LI->getLoopFor(LoopHeader), LI.get(), *Plan);
+    Loop *L = LI->getLoopFor(LoopHeader);
+    PredicatedScalarEvolution PSE(*SE, *L);
+    auto Plan = VPlan::createInitialVPlan(IntegerType::get(*Ctx, 64), PSE, true,
+                                          false, L);
+    VPlanHCFGBuilder HCFGBuilder(L, LI.get(), *Plan);
     HCFGBuilder.buildHierarchicalCFG();
     return Plan;
   }
@@ -81,10 +82,11 @@ class VPlanTestBase : public testing::Test {
     assert(!verifyFunction(F) && "input function must be valid");
     doAnalysis(F);
 
-    auto Plan = VPlan::createInitialVPlan(
-        SE->getBackedgeTakenCount(LI->getLoopFor(LoopHeader)), *SE, true, false,
-        LI->getLoopFor(LoopHeader));
-    VPlanHCFGBuilder HCFGBuilder(LI->getLoopFor(LoopHeader), LI.get(), *Plan);
+    Loop *L = LI->getLoopFor(LoopHeader);
+    PredicatedScalarEvolution PSE(*SE, *L);
+    auto Plan = VPlan::createInitialVPlan(IntegerType::get(*Ctx, 64), PSE, true,
+                                          false, L);
+    VPlanHCFGBuilder HCFGBuilder(L, LI.get(), *Plan);
     HCFGBuilder.buildPlainCFG();
     return Plan;
   }



More information about the llvm-commits mailing list