[llvm] e4b7724 - [SLP]Do extra analysis in minbitwidth if some checks return false.
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 14 16:44:51 PDT 2024
Author: Alexey Bataev
Date: 2024-03-14T16:41:04-07:00
New Revision: e4b772444c8176abe30d364e4a946ee6c8ae8de4
URL: https://github.com/llvm/llvm-project/commit/e4b772444c8176abe30d364e4a946ee6c8ae8de4
DIFF: https://github.com/llvm/llvm-project/commit/e4b772444c8176abe30d364e4a946ee6c8ae8de4.diff
LOG: [SLP]Do extra analysis in minbitwidth if some checks return false.
The instruction itself can still be considered a good candidate for minbitwidth
casting, even if the check on one of its operands returns false.
Reviewers: RKSimon
Reviewed By: RKSimon
Pull Request: https://github.com/llvm/llvm-project/pull/84363
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 4481306209d49f..acb738ef281e0d 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10226,9 +10226,11 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
for (const TreeEntry *TE : ForRemoval)
Set.erase(TE);
}
+ bool NeedToRemapValues = false;
for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) {
if (It->empty()) {
UsedTEs.erase(It);
+ NeedToRemapValues = true;
continue;
}
std::advance(It, 1);
@@ -10237,6 +10239,19 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
Entries.clear();
return std::nullopt;
}
+ // Recalculate the mapping between the values and entries sets.
+ if (NeedToRemapValues) {
+ DenseMap<Value *, int> PrevUsedValuesEntry;
+ PrevUsedValuesEntry.swap(UsedValuesEntry);
+ for (auto [Idx, Set] : enumerate(UsedTEs)) {
+ DenseSet<Value *> Values;
+ for (const TreeEntry *E : Set)
+ Values.insert(E->Scalars.begin(), E->Scalars.end());
+ for (const auto &P : PrevUsedValuesEntry)
+ if (Values.contains(P.first))
+ UsedValuesEntry.try_emplace(P.first, Idx);
+ }
+ }
}
unsigned VF = 0;
@@ -11935,7 +11950,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
Builder.SetCurrentDebugLocation(PH->getDebugLoc());
Value *Vec = vectorizeOperand(E, I, /*PostponedPHIs=*/true);
if (VecTy != Vec->getType()) {
- assert((getOperandEntry(E, I)->State == TreeEntry::NeedToGather ||
+ assert((It != MinBWs.end() ||
+ getOperandEntry(E, I)->State == TreeEntry::NeedToGather ||
MinBWs.contains(getOperandEntry(E, I))) &&
"Expected item in MinBWs.");
Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second);
@@ -12193,7 +12209,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
return E->VectorizedValue;
}
if (L->getType() != R->getType()) {
- assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
+ assert((It != MinBWs.end() ||
+ getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
MinBWs.contains(getOperandEntry(E, 0)) ||
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12232,7 +12249,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
return E->VectorizedValue;
}
if (True->getType() != False->getType()) {
- assert((getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
+ assert((It != MinBWs.end() ||
+ getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
getOperandEntry(E, 2)->State == TreeEntry::NeedToGather ||
MinBWs.contains(getOperandEntry(E, 1)) ||
MinBWs.contains(getOperandEntry(E, 2))) &&
@@ -12302,7 +12320,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
return E->VectorizedValue;
}
if (LHS->getType() != RHS->getType()) {
- assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
+ assert((It != MinBWs.end() ||
+ getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
MinBWs.contains(getOperandEntry(E, 0)) ||
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12540,7 +12559,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
return E->VectorizedValue;
}
if (LHS && RHS && LHS->getType() != RHS->getType()) {
- assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
+ assert((It != MinBWs.end() ||
+ getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
MinBWs.contains(getOperandEntry(E, 0)) ||
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -14002,6 +14022,33 @@ bool BoUpSLP::collectValuesToDemote(
};
unsigned Start = 0;
unsigned End = I->getNumOperands();
+
+ auto FinalAnalysis = [&](const TreeEntry *ITE = nullptr) {
+ if (!IsProfitableToDemote)
+ return false;
+ return (ITE && ITE->UserTreeIndices.size() > 1) ||
+ IsPotentiallyTruncated(I, BitWidth);
+ };
+ auto ProcessOperands = [&](ArrayRef<Value *> Operands, bool &NeedToExit) {
+ NeedToExit = false;
+ unsigned InitLevel = MaxDepthLevel;
+ for (Value *IncValue : Operands) {
+ unsigned Level = InitLevel;
+ if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
+ ToDemote, DemotedConsts, Visited, Level,
+ IsProfitableToDemote, IsTruncRoot)) {
+ if (!IsProfitableToDemote)
+ return false;
+ NeedToExit = true;
+ if (!FinalAnalysis(ITE))
+ return false;
+ continue;
+ }
+ MaxDepthLevel = std::max(MaxDepthLevel, Level);
+ }
+ return true;
+ };
+ bool NeedToExit = false;
switch (I->getOpcode()) {
// We can always demote truncations and extensions. Since truncations can
@@ -14027,35 +14074,21 @@ bool BoUpSLP::collectValuesToDemote(
case Instruction::And:
case Instruction::Or:
case Instruction::Xor: {
- unsigned Level1, Level2;
- if ((ITE->UserTreeIndices.size() > 1 &&
- !IsPotentiallyTruncated(I, BitWidth)) ||
- !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
- BitWidth, ToDemote, DemotedConsts, Visited,
- Level1, IsProfitableToDemote, IsTruncRoot) ||
- !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
- BitWidth, ToDemote, DemotedConsts, Visited,
- Level2, IsProfitableToDemote, IsTruncRoot))
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+ return false;
+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
return false;
- MaxDepthLevel = std::max(Level1, Level2);
break;
}
// We can demote selects if we can demote their true and false values.
case Instruction::Select: {
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+ return false;
Start = 1;
- unsigned Level1, Level2;
- SelectInst *SI = cast<SelectInst>(I);
- if ((ITE->UserTreeIndices.size() > 1 &&
- !IsPotentiallyTruncated(I, BitWidth)) ||
- !collectValuesToDemote(SI->getTrueValue(), IsProfitableToDemoteRoot,
- BitWidth, ToDemote, DemotedConsts, Visited,
- Level1, IsProfitableToDemote, IsTruncRoot) ||
- !collectValuesToDemote(SI->getFalseValue(), IsProfitableToDemoteRoot,
- BitWidth, ToDemote, DemotedConsts, Visited,
- Level2, IsProfitableToDemote, IsTruncRoot))
+ auto *SI = cast<SelectInst>(I);
+ if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit))
return false;
- MaxDepthLevel = std::max(Level1, Level2);
break;
}
@@ -14066,22 +14099,20 @@ bool BoUpSLP::collectValuesToDemote(
MaxDepthLevel = 0;
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
return false;
- for (Value *IncValue : PN->incoming_values()) {
- unsigned Level;
- if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
- ToDemote, DemotedConsts, Visited, Level,
- IsProfitableToDemote, IsTruncRoot))
- return false;
- MaxDepthLevel = std::max(MaxDepthLevel, Level);
- }
+ SmallVector<Value *> Ops(PN->incoming_values().begin(),
+ PN->incoming_values().end());
+ if (!ProcessOperands(Ops, NeedToExit))
+ return false;
break;
}
// Otherwise, conservatively give up.
default:
MaxDepthLevel = 1;
- return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth);
+ return FinalAnalysis();
}
+ if (NeedToExit)
+ return true;
++MaxDepthLevel;
// Gather demoted constant operands.
@@ -14120,6 +14151,7 @@ void BoUpSLP::computeMinimumValueSizes() {
// The first value node for store/insertelement is sext/zext/trunc? Skip it,
// resize to the final type.
+ bool IsTruncRoot = false;
bool IsProfitableToDemoteRoot = !IsStoreOrInsertElt;
if (NodeIdx != 0 &&
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
@@ -14127,8 +14159,9 @@ void BoUpSLP::computeMinimumValueSizes() {
VectorizableTree[NodeIdx]->getOpcode() == Instruction::SExt ||
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc)) {
assert(IsStoreOrInsertElt && "Expected store/insertelement seeded graph.");
- ++NodeIdx;
+ IsTruncRoot = VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc;
IsProfitableToDemoteRoot = true;
+ ++NodeIdx;
}
// Analyzed in reduction already and not profitable - exit.
@@ -14260,7 +14293,6 @@ void BoUpSLP::computeMinimumValueSizes() {
ReductionBitWidth = bit_ceil(ReductionBitWidth);
}
bool IsTopRoot = NodeIdx == 0;
- bool IsTruncRoot = false;
while (NodeIdx < VectorizableTree.size() &&
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) {
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
index 1986b51ec94828..02d1f9f60d0ca1 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
@@ -228,7 +228,7 @@ for.end: ; preds = %for.end.loopexit, %
; YAML-NEXT: Function: test_unrolled_select
; YAML-NEXT: Args:
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
-; YAML-NEXT: - Cost: '-36'
+; YAML-NEXT: - Cost: '-40'
; YAML-NEXT: - String: ' and with tree size '
; YAML-NEXT: - TreeSize: '10'
@@ -246,15 +246,17 @@ define i32 @test_unrolled_select(ptr noalias nocapture readonly %blk1, ptr noali
; CHECK-NEXT: [[P2_045:%.*]] = phi ptr [ [[BLK2:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR88:%.*]], [[IF_END_86]] ]
; CHECK-NEXT: [[P1_044:%.*]] = phi ptr [ [[BLK1:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR:%.*]], [[IF_END_86]] ]
; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[P1_044]], align 1
-; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i32>
+; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i16>
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i8>, ptr [[P2_045]], align 1
-; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i32>
-; CHECK-NEXT: [[TMP4:%.*]] = sub nsw <8 x i32> [[TMP1]], [[TMP3]]
-; CHECK-NEXT: [[TMP5:%.*]] = icmp slt <8 x i32> [[TMP4]], zeroinitializer
-; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <8 x i32> zeroinitializer, [[TMP4]]
-; CHECK-NEXT: [[TMP7:%.*]] = select <8 x i1> [[TMP5]], <8 x i32> [[TMP6]], <8 x i32> [[TMP4]]
-; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
-; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP8]], [[S_047]]
+; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i16>
+; CHECK-NEXT: [[TMP4:%.*]] = sub <8 x i16> [[TMP1]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = trunc <8 x i16> [[TMP4]] to <8 x i1>
+; CHECK-NEXT: [[TMP6:%.*]] = icmp slt <8 x i1> [[TMP5]], zeroinitializer
+; CHECK-NEXT: [[TMP7:%.*]] = sub <8 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT: [[TMP8:%.*]] = select <8 x i1> [[TMP6]], <8 x i16> [[TMP7]], <8 x i16> [[TMP4]]
+; CHECK-NEXT: [[TMP9:%.*]] = zext <8 x i16> [[TMP8]] to <8 x i32>
+; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP9]])
+; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP10]], [[S_047]]
; CHECK-NEXT: [[CMP83:%.*]] = icmp slt i32 [[OP_RDX]], [[LIM:%.*]]
; CHECK-NEXT: br i1 [[CMP83]], label [[IF_END_86]], label [[FOR_END_LOOPEXIT:%.*]]
; CHECK: if.end.86:
More information about the llvm-commits
mailing list