[llvm] 910d2de - [SLP]Fix PR88103: consider the sign of the compare for non-negative operands.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 9 10:55:00 PDT 2024


Author: Alexey Bataev
Date: 2024-04-09T10:47:47-07:00
New Revision: 910d2de357de8a490cac3ecbd27196356fe1f2a3

URL: https://github.com/llvm/llvm-project/commit/910d2de357de8a490cac3ecbd27196356fe1f2a3
DIFF: https://github.com/llvm/llvm-project/commit/910d2de357de8a490cac3ecbd27196356fe1f2a3.diff

LOG: [SLP]Fix PR88103: consider the sign of the compare for non-negative operands.

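We need to improve the detection of the number of bits required for the
operand before doing a reduction. If the instruction is an incoming operand
of a signed compare, we need to consider adding an extra bit for the sign.
Also, for integer min/max reductions, the demanded bits of the reduced value
are no longer used to narrow the reduction bit width.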

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index da2b61ea6a635e..c3dcf73b0b7626 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -14430,6 +14430,8 @@ bool BoUpSLP::collectValuesToDemote(
   return FinalAnalysis();
 }
 
+static RecurKind getRdxKind(Value *V);
+
 void BoUpSLP::computeMinimumValueSizes() {
   // We only attempt to truncate integer expressions.
   bool IsStoreOrInsertElt =
@@ -14480,7 +14482,7 @@ void BoUpSLP::computeMinimumValueSizes() {
   auto ComputeMaxBitWidth = [&](ArrayRef<Value *> TreeRoot, unsigned VF,
                                 bool IsTopRoot, bool IsProfitableToDemoteRoot,
                                 unsigned Opcode, unsigned Limit,
-                                bool IsTruncRoot) {
+                                bool IsTruncRoot, bool IsSignedCmp) {
     ToDemote.clear();
     auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType());
     if (!TreeRootIT || !Opcode)
@@ -14503,7 +14505,7 @@ void BoUpSLP::computeMinimumValueSizes() {
     // True.
     // Determine if the sign bit of all the roots is known to be zero. If not,
     // IsKnownPositive is set to False.
-    bool IsKnownPositive = all_of(TreeRoot, [&](Value *R) {
+    bool IsKnownPositive = !IsSignedCmp && all_of(TreeRoot, [&](Value *R) {
       KnownBits Known = computeKnownBits(R, *DL);
       return Known.isNonNegative();
     });
@@ -14590,8 +14592,11 @@ void BoUpSLP::computeMinimumValueSizes() {
       unsigned BitWidth1 = NumTypeBits - NumSignBits;
       if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
         ++BitWidth1;
-      auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-      unsigned BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+      unsigned BitWidth2 = BitWidth1;
+      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
+        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+      }
       ReductionBitWidth =
           std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
     }
@@ -14608,6 +14613,7 @@ void BoUpSLP::computeMinimumValueSizes() {
     ++NodeIdx;
     IsTruncRoot = true;
   }
+  bool IsSignedCmp = false;
   while (NodeIdx < VectorizableTree.size()) {
     ArrayRef<Value *> TreeRoot = VectorizableTree[NodeIdx]->Scalars;
     unsigned Limit = 2;
@@ -14619,7 +14625,7 @@ void BoUpSLP::computeMinimumValueSizes() {
       Limit = 3;
     unsigned MaxBitWidth = ComputeMaxBitWidth(
         TreeRoot, VectorizableTree[NodeIdx]->getVectorFactor(), IsTopRoot,
-        IsProfitableToDemoteRoot, Opcode, Limit, IsTruncRoot);
+        IsProfitableToDemoteRoot, Opcode, Limit, IsTruncRoot, IsSignedCmp);
     if (ReductionBitWidth != 0 && (IsTopRoot || !RootDemotes.empty())) {
       if (MaxBitWidth != 0 && ReductionBitWidth < MaxBitWidth)
         ReductionBitWidth = bit_ceil(MaxBitWidth);
@@ -14657,6 +14663,16 @@ void BoUpSLP::computeMinimumValueSizes() {
                           EI.UserTE->getOpcode() == Instruction::Trunc &&
                           !EI.UserTE->isAltShuffle();
                  });
+      IsSignedCmp =
+          NodeIdx < VectorizableTree.size() &&
+          any_of(VectorizableTree[NodeIdx]->UserTreeIndices,
+                 [](const EdgeInfo &EI) {
+                   return EI.UserTE->getOpcode() == Instruction::ICmp &&
+                          any_of(EI.UserTE->Scalars, [](Value *V) {
+                            auto *IC = dyn_cast<ICmpInst>(V);
+                            return IC && IC->isSigned();
+                          });
+                 });
     }
 
     // If the maximum bit width we compute is less than the with of the roots'
@@ -16697,6 +16713,10 @@ class HorizontalReduction {
 };
 } // end anonymous namespace
 
+/// Returns the recurrence kind of \p V by forwarding to HorizontalReduction::getRdxKind; kept as a free function so it can be forward-declared and called from BoUpSLP::computeMinimumValueSizes.
+static RecurKind getRdxKind(Value *V) {
+  return HorizontalReduction::getRdxKind(V);
+}
 static std::optional<unsigned> getAggregateSize(Instruction *InsertInst) {
   if (auto *IE = dyn_cast<InsertElementInst>(InsertInst))
     return cast<FixedVectorType>(IE->getType())->getNumElements();

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll
index f76b6be02477b8..7f086d17ca4c08 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/zext-incoming-for-neg-icmp.ll
@@ -10,7 +10,9 @@ define i32 @test(i32 %a, i8 %b, i8 %c) {
 ; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i8> [[TMP1]], <i8 -1, i8 -2, i8 -3, i8 -4>
 ; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i8> poison, i8 [[B]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp sle <4 x i8> [[TMP2]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i16>
+; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i8> [[TMP4]] to <4 x i16>
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp sle <4 x i16> [[TMP8]], [[TMP9]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = zext <4 x i1> [[TMP5]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
 ; CHECK-NEXT:    [[OP_RDX:%.*]] = add i32 [[TMP7]], [[A]]


        


More information about the llvm-commits mailing list