[llvm] [LV][NFC]Introduce isScalableVectorizationAllowed() to refactor getMaxLegalScalableVF(). (PR #98916)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 16 10:28:29 PDT 2024
https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/98916
>From e2c6fc26f3a0c1fb5f8ab0df53e2c5edf3f8cd80 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Mon, 15 Jul 2024 15:22:05 +0000
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Created using spr 1.3.5
---
.../Transforms/Vectorize/LoopVectorize.cpp | 45 ++++++++++++++-----
1 file changed, 35 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 5520baef7152d..86957af9ce6c7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1627,6 +1627,10 @@ class LoopVectorizationCostModel {
ElementCount MaxSafeVF,
bool FoldTailByMasking);
+ /// Checks if scalable vectorization is supported and enabled. Caches the
+ /// result to avoid repeated debug dumps for repeated queries.
+ bool isScalableVectorizationAllowed();
+
/// \return the maximum legal scalable VF, based on the safe max number
/// of elements.
ElementCount getMaxLegalScalableVF(unsigned MaxSafeElements);
@@ -1691,6 +1695,9 @@ class LoopVectorizationCostModel {
std::optional<std::pair<TailFoldingStyle, TailFoldingStyle>>
ChosenTailFoldingStyle;
+ /// true if scalable vectorization is supported and enabled.
+ std::optional<bool> IsScalableVectorizationAllowed;
+
/// A map holding scalar costs for different vectorization factors. The
/// presence of a cost for an instruction in the mapping indicates that the
/// instruction will be scalarized when vectorizing with the associated
@@ -4143,15 +4150,18 @@ bool LoopVectorizationCostModel::runtimeChecksRequired() {
return false;
}
-ElementCount
-LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
+bool LoopVectorizationCostModel::isScalableVectorizationAllowed() {
+ if (IsScalableVectorizationAllowed)
+ return *IsScalableVectorizationAllowed;
+
+ IsScalableVectorizationAllowed = false;
if (!TTI.supportsScalableVectors() && !ForceTargetSupportsScalableVectors)
- return ElementCount::getScalable(0);
+ return false;
if (Hints->isScalableVectorizationDisabled()) {
reportVectorizationInfo("Scalable vectorization is explicitly disabled",
"ScalableVectorizationDisabled", ORE, TheLoop);
- return ElementCount::getScalable(0);
+ return false;
}
LLVM_DEBUG(dbgs() << "LV: Scalable vectorization is available\n");
@@ -4171,7 +4181,7 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
"Scalable vectorization not supported for the reduction "
"operations found in this loop.",
"ScalableVFUnfeasible", ORE, TheLoop);
- return ElementCount::getScalable(0);
+ return false;
}
// Disable scalable vectorization if the loop contains any instructions
@@ -4183,17 +4193,32 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
reportVectorizationInfo("Scalable vectorization is not supported "
"for all element types found in this loop.",
"ScalableVFUnfeasible", ORE, TheLoop);
- return ElementCount::getScalable(0);
+ return false;
}
+ IsScalableVectorizationAllowed = true;
+ return true;
+}
+
+ElementCount
+LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
+ if (!isScalableVectorizationAllowed())
+ return ElementCount::getScalable(0);
+
+ auto MaxScalableVF = ElementCount::getScalable(
+ std::numeric_limits<ElementCount::ScalarTy>::max());
if (Legal->isSafeForAnyVectorWidth())
return MaxScalableVF;
+ std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI);
+ if (!MaxVScale) {
+ reportVectorizationInfo("The target does not provide maximum vscale value.",
+ "ScalableVFUnfeasible", ORE, TheLoop);
+ return ElementCount::getScalable(0);
+ }
+
// Limit MaxScalableVF by the maximum safe dependence distance.
- if (std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI))
- MaxScalableVF = ElementCount::getScalable(MaxSafeElements / *MaxVScale);
- else
- MaxScalableVF = ElementCount::getScalable(0);
+ MaxScalableVF = ElementCount::getScalable(MaxSafeElements / *MaxVScale);
if (!MaxScalableVF)
reportVectorizationInfo(
More information about the llvm-commits
mailing list