[llvm] 910d2de - [SLP]Fix PR88103: consider the sign of the compare for non-negative operands.
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 9 10:55:00 PDT 2024
Author: Alexey Bataev
Date: 2024-04-09T10:47:47-07:00
New Revision: 910d2de357de8a490cac3ecbd27196356fe1f2a3
URL: https://github.com/llvm/llvm-project/commit/910d2de357de8a490cac3ecbd27196356fe1f2a3
DIFF: https://github.com/llvm/llvm-project/commit/910d2de357de8a490cac3ecbd27196356fe1f2a3.diff
LOG: [SLP]Fix PR88103: consider the sign of the compare for non-negative operands.
Need to improve detection of the number of bits required for the operand
before doing a reduction. If the instruction is an incoming operand of a
signed compare, an extra bit must be added to account for the sign.
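To illustrate why the extra bit matters, here is a minimal standalone sketch
(plain C++20, not LLVM code; the helper name minBitsForOperand is made up for
this example): a value whose demanded bits fit in N bits as unsigned data
still needs N+1 bits once it is an operand of a signed compare, otherwise
e.g. the i8 value 0xFF would be reinterpreted as -1.

    // Minimal standalone sketch (assumes C++20 for std::bit_width); the helper
    // name is hypothetical and only mirrors the idea behind the fix.
    #include <bit>
    #include <cassert>
    #include <cstdint>

    // Minimum number of bits needed to represent V without changing its meaning.
    static unsigned minBitsForOperand(uint8_t V, bool FeedsSignedCmp) {
      unsigned Bits =
          static_cast<unsigned>(std::bit_width(static_cast<unsigned>(V)));
      if (Bits == 0)
        Bits = 1; // the value 0 still occupies one bit
      if (FeedsSignedCmp)
        ++Bits; // reserve a sign bit so the value stays non-negative when read as signed
      return Bits;
    }

    int main() {
      // 0xFF fits in 8 bits as unsigned data, but as a signed i8 it reads as -1.
      // A signed-compare user therefore needs at least 9 bits; the vectorizer
      // rounds this up, which is why the updated test compares at i16.
      assert(minBitsForOperand(0xFF, /*FeedsSignedCmp=*/false) == 8);
      assert(minBitsForOperand(0xFF, /*FeedsSignedCmp=*/true) == 9);
      return 0;
    }

In the patch itself this shows up as the new IsSignedCmp flag threaded into
ComputeMaxBitWidth and as the widened i16 compare in the updated test.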
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index da2b61ea6a635e..c3dcf73b0b7626 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -14430,6 +14430,8 @@ bool BoUpSLP::collectValuesToDemote(
return FinalAnalysis();
}
+static RecurKind getRdxKind(Value *V);
+
void BoUpSLP::computeMinimumValueSizes() {
// We only attempt to truncate integer expressions.
bool IsStoreOrInsertElt =
@@ -14480,7 +14482,7 @@ void BoUpSLP::computeMinimumValueSizes() {
auto ComputeMaxBitWidth = [&](ArrayRef<Value *> TreeRoot, unsigned VF,
bool IsTopRoot, bool IsProfitableToDemoteRoot,
unsigned Opcode, unsigned Limit,
- bool IsTruncRoot) {
+ bool IsTruncRoot, bool IsSignedCmp) {
ToDemote.clear();
auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType());
if (!TreeRootIT || !Opcode)
@@ -14503,7 +14505,7 @@ void BoUpSLP::computeMinimumValueSizes() {
// True.
// Determine if the sign bit of all the roots is known to be zero. If not,
// IsKnownPositive is set to False.
- bool IsKnownPositive = all_of(TreeRoot, [&](Value *R) {
+ bool IsKnownPositive = !IsSignedCmp && all_of(TreeRoot, [&](Value *R) {
KnownBits Known = computeKnownBits(R, *DL);
return Known.isNonNegative();
});
@@ -14590,8 +14592,11 @@ void BoUpSLP::computeMinimumValueSizes() {
unsigned BitWidth1 = NumTypeBits - NumSignBits;
if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
++BitWidth1;
- auto Mask = DB->getDemandedBits(cast<Instruction>(V));
- unsigned BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+ unsigned BitWidth2 = BitWidth1;
+ if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+ auto Mask = DB->getDemandedBits(cast<Instruction>(V));
+ BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+ }
ReductionBitWidth =
std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
}
@@ -14608,6 +14613,7 @@ void BoUpSLP::computeMinimumValueSizes() {
++NodeIdx;
IsTruncRoot = true;
}
+ bool IsSignedCmp = false;
while (NodeIdx < VectorizableTree.size()) {
ArrayRef<Value *> TreeRoot = VectorizableTree[NodeIdx]->Scalars;
unsigned Limit = 2;
@@ -14619,7 +14625,7 @@ void BoUpSLP::computeMinimumValueSizes() {
Limit = 3;
unsigned MaxBitWidth = ComputeMaxBitWidth(
TreeRoot, VectorizableTree[NodeIdx]->getVectorFactor(), IsTopRoot,
- IsProfitableToDemoteRoot, Opcode, Limit, IsTruncRoot);
+ IsProfitableToDemoteRoot, Opcode, Limit, IsTruncRoot, IsSignedCmp);
if (ReductionBitWidth != 0 && (IsTopRoot || !RootDemotes.empty())) {
if (MaxBitWidth != 0 && ReductionBitWidth < MaxBitWidth)
ReductionBitWidth = bit_ceil(MaxBitWidth);
@@ -14657,6 +14663,16 @@ void BoUpSLP::computeMinimumValueSizes() {
EI.UserTE->getOpcode() == Instruction::Trunc &&
!EI.UserTE->isAltShuffle();
});
+ IsSignedCmp =
+ NodeIdx < VectorizableTree.size() &&
+ any_of(VectorizableTree[NodeIdx]->UserTreeIndices,
+ [](const EdgeInfo &EI) {
+ return EI.UserTE->getOpcode() == Instruction::ICmp &&
+ any_of(EI.UserTE->Scalars, [](Value *V) {
+ auto *IC = dyn_cast<ICmpInst>(V);
+ return IC && IC->isSigned();
+ });
+ });
}
// If the maximum bit width we compute is less than the width of the roots'
@@ -16697,6 +16713,10 @@ class HorizontalReduction {
};
} // end anonymous namespace
+/// Gets recurrence kind from the specified value.
+static RecurKind getRdxKind(Value *V) {
+ return HorizontalReduction::getRdxKind(V);
+}
static std::optional<unsigned> getAggregateSize(Instruction *InsertInst) {
if (auto *IE = dyn_cast<InsertElementInst>(InsertInst))
return cast<FixedVectorType>(IE->getType())->getNumElements();
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll
index f76b6be02477b8..7f086d17ca4c08 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll
@@ -10,7 +10,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
; CHECK-NEXT: [[TMP2:%.*]] = add <4 x i8> [[TMP1]], <i8 -1, i8 -2, i8 -3, i8 -4>
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i8> poison, i8 [[B]], i32 0
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT: [[TMP5:%.*]] = icmp sle <4 x i8> [[TMP2]], [[TMP4]]
+; CHECK-NEXT: [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
+; CHECK-NEXT: [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
+; CHECK-NEXT: [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
; CHECK-NEXT: [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]