[llvm] [VPlan] Add initial VPScalarEvolution, use to get trip count SCEV (NFC) (PR #94464)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 12 02:34:06 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/94464

>From fcd41480cd1539e37e046c31a95062ba43633103 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 5 Jun 2024 12:59:47 +0100
Subject: [PATCH 1/2] [VPlan] Add initial VPScalarEvolution, use to get trip
 count SCEV (NFC)

Add an initial version of VPScalarEvolution, which can be used to return
a SCEV expression for a VPValue. The initial implementation only returns
SCEVs for live-in IR values (by constructing a SCEV based on the live-in
IR value) and for values defined by a VPExpandSCEVRecipe (by returning
the recipe's SCEV). Any other VPValue yields SCEVCouldNotCompute. This
is enough to serve its first use, getting a SCEV for a VPlan's trip
count, but will be extended in the future.
---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp   |  5 +++--
 llvm/lib/Transforms/Vectorize/VPlan.h             |  3 ---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp   | 13 +++++++++++++
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h     | 15 +++++++++++++++
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp |  4 +---
 5 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c7c19ef456c7c..14197259a9f98 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -978,8 +978,9 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) {
   return B.CreateElementCount(Ty, VF);
 }
 
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
-                                Loop *OrigLoop) {
+static const SCEV *createTripCountSCEV(Type *IdxTy,
+                                       PredicatedScalarEvolution &PSE,
+                                       Loop *OrigLoop) {
   const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
   assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 943edc3520869..de37262d2c3de 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -79,9 +79,6 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF);
 Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF,
                        int64_t Step);
 
-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
-                                Loop *CurLoop = nullptr);
-
 /// A range of powers-of-2 vectorization factors with fixed start and
 /// adjustable end. The range includes start and excludes end, e.g.,:
 /// [1, 16) = {1, 2, 4, 8}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 90bbf2d5d99fa..e9f01c1a334e1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -9,6 +9,7 @@
 #include "VPlanAnalysis.h"
 #include "VPlan.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 
 using namespace llvm;
 
@@ -265,3 +266,15 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
   CachedTypes[V] = ResultTy;
   return ResultTy;
 }
+
+const SCEV *VPScalarEvolution::getSCEV(VPValue *V) {
+  if (V->isLiveIn())
+    return SE.getSCEV(V->getLiveInIRValue());
+
+  // TODO: Support constructing SCEVs for more recipes as needed.
+  return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
+      .Case<VPExpandSCEVRecipe>(
+          [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
+      .Default(
+          [this](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7d310b1b31b6f..9d549778092c9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -23,6 +23,8 @@ class VPWidenIntOrFpInductionRecipe;
 class VPWidenMemoryRecipe;
 struct VPWidenSelectRecipe;
 class VPReplicateRecipe;
+class SCEV;
+class ScalarEvolution;
 class Type;
 
 /// An analysis for type-inference for VPValues.
@@ -61,6 +63,19 @@ class VPTypeAnalysis {
   LLVMContext &getContext() { return Ctx; }
 };
 
+/// A light wrapper over ScalarEvolution to construct SCEV expressions for
+/// VPValues and recipes.
+class VPScalarEvolution {
+  ScalarEvolution &SE;
+
+public:
+  VPScalarEvolution(ScalarEvolution &SE) : SE(SE) {}
+
+  /// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
+  /// SCEV expression could be constructed.
+  const SCEV *getSCEV(VPValue *V);
+};
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ab3b5cf2b9dab..541246266e1d5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -674,10 +674,8 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
              m_BranchOnCond(m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue())))))
     return;
 
-  Type *IdxTy =
-      Plan.getCanonicalIV()->getStartValue()->getLiveInIRValue()->getType();
-  const SCEV *TripCount = createTripCountSCEV(IdxTy, PSE);
   ScalarEvolution &SE = *PSE.getSE();
+  const SCEV *TripCount = VPScalarEvolution(SE).getSCEV(Plan.getTripCount());
   ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
   const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
   if (TripCount->isZero() ||

>From d7d389e7a072799989e94dad1e440e1826bbe0ab Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 12 Jul 2024 10:33:19 +0100
Subject: [PATCH 2/2] !fixup move logic to single getSCEVExprForVPValue
 function

---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp  |  5 ++---
 llvm/lib/Transforms/Vectorize/VPlanAnalysis.h    | 16 ++++++----------
 .../lib/Transforms/Vectorize/VPlanTransforms.cpp |  3 ++-
 3 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index d3f4c18d98fb4..245b26e4ce22c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -318,7 +318,7 @@ void llvm::collectEphemeralRecipesForVPlan(
   }
 }
 
-const SCEV *VPScalarEvolution::getSCEV(VPValue *V) {
+const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
   if (V->isLiveIn())
     return SE.getSCEV(V->getLiveInIRValue());
 
@@ -326,6 +326,5 @@ const SCEV *VPScalarEvolution::getSCEV(VPValue *V) {
   return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
       .Case<VPExpandSCEVRecipe>(
           [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
-      .Default(
-          [this](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
+      .Default([&SE](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 51eaa0f702018..ef2012a32e66b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -70,18 +70,14 @@ class VPTypeAnalysis {
 void collectEphemeralRecipesForVPlan(VPlan &Plan,
                                      DenseSet<VPRecipeBase *> &EphRecipes);
 
-/// A light wrapper over ScalarEvolution to construct SCEV expressions for
-/// VPValues and recipes.
-class VPScalarEvolution {
-  ScalarEvolution &SE;
+namespace vputils {
 
-public:
-  VPScalarEvolution(ScalarEvolution &SE) : SE(SE) {}
+/// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
+/// SCEV expression could be constructed.
+const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE);
+
+} // namespace vputils
 
-  /// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
-  /// SCEV expression could be constructed.
-  const SCEV *getSCEV(VPValue *V);
-};
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 734ab96a4a68a..b55e716fbff37 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -684,7 +684,8 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
     return;
 
   ScalarEvolution &SE = *PSE.getSE();
-  const SCEV *TripCount = VPScalarEvolution(SE).getSCEV(Plan.getTripCount());
+  const SCEV *TripCount =
+      vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
   ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
   const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
   if (TripCount->isZero() ||



More information about the llvm-commits mailing list