[llvm] 228c2e1 - [SLP]Fix incorrect promotion of nodes before shuffling.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 13:39:42 PDT 2024


Author: Alexey Bataev
Date: 2024-06-25T13:39:28-07:00
New Revision: 228c2e147390ca996a22e460e0bb804431469d25

URL: https://github.com/llvm/llvm-project/commit/228c2e147390ca996a22e460e0bb804431469d25
DIFF: https://github.com/llvm/llvm-project/commit/228c2e147390ca996a22e460e0bb804431469d25.diff

LOG: [SLP]Fix incorrect promotion of nodes before shuffling.

If the base node is signed but some of its values are unsigned, the
whole node should still be considered signed. Also, an extra bitwidth
analysis should be performed when estimating the minimal bitwidth.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/RISCV/shuffled-gather-casted.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 08fcca6e9bef8..056340eccec5c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11803,13 +11803,13 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
     Value *V1 = E1.VectorizedValue;
     if (V1->getType()->isIntOrIntVectorTy())
-      V1 = castToScalarTyElem(V1, all_of(E1.Scalars, [&](Value *V) {
+      V1 = castToScalarTyElem(V1, any_of(E1.Scalars, [&](Value *V) {
                                 return !isKnownNonNegative(
                                     V, SimplifyQuery(*R.DL));
                               }));
     Value *V2 = E2.VectorizedValue;
     if (V2->getType()->isIntOrIntVectorTy())
-      V2 = castToScalarTyElem(V2, all_of(E2.Scalars, [&](Value *V) {
+      V2 = castToScalarTyElem(V2, any_of(E2.Scalars, [&](Value *V) {
                                 return !isKnownNonNegative(
                                     V, SimplifyQuery(*R.DL));
                               }));
@@ -11820,7 +11820,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
   void add(const TreeEntry &E1, ArrayRef<int> Mask) {
     Value *V1 = E1.VectorizedValue;
     if (V1->getType()->isIntOrIntVectorTy())
-      V1 = castToScalarTyElem(V1, all_of(E1.Scalars, [&](Value *V) {
+      V1 = castToScalarTyElem(V1, any_of(E1.Scalars, [&](Value *V) {
                                 return !isKnownNonNegative(
                                     V, SimplifyQuery(*R.DL));
                               }));
@@ -14900,24 +14900,30 @@ bool BoUpSLP::collectValuesToDemote(
   // If the value is not a vectorized instruction in the expression and not used
   // by the insertelement instruction and not used in multiple vector nodes, it
   // cannot be demoted.
+  bool IsSignedNode = any_of(E.Scalars, [&](Value *R) {
+    return !isKnownNonNegative(R, SimplifyQuery(*DL));
+  });
   auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool {
     if (MultiNodeScalars.contains(V))
       return false;
-    if (OrigBitWidth > BitWidth) {
+    // For lat shuffle of sext/zext with many uses need to check the extra bit
+    // for unsigned values, otherwise may have incorrect casting for reused
+    // scalars.
+    bool IsSignedVal = !isKnownNonNegative(V, SimplifyQuery(*DL));
+    if ((!IsSignedNode || IsSignedVal) && OrigBitWidth > BitWidth) {
       APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
       if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
         return true;
     }
-    auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+    unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
     unsigned BitWidth1 = OrigBitWidth - NumSignBits;
-    bool IsSigned = !isKnownNonNegative(V, SimplifyQuery(*DL));
-    if (IsSigned)
+    if (IsSignedNode)
       ++BitWidth1;
     if (auto *I = dyn_cast<Instruction>(V)) {
       APInt Mask = DB->getDemandedBits(I);
       unsigned BitWidth2 =
           std::max<unsigned>(1, Mask.getBitWidth() - Mask.countl_zero());
-      while (!IsSigned && BitWidth2 < OrigBitWidth) {
+      while (!IsSignedNode && BitWidth2 < OrigBitWidth) {
         APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth2 - 1);
         if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
           break;

diff  --git a/llvm/test/Transforms/SLPVectorizer/RISCV/shuffled-gather-casted.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/shuffled-gather-casted.ll
index 9d7cf368fbc3f..7b65467d7f9ef 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/shuffled-gather-casted.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/shuffled-gather-casted.ll
@@ -59,13 +59,12 @@ define i32 @test1(ptr %p) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[D_0:%.*]] = load i16, ptr [[P]], align 4
 ; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x i16> <i16 0, i16 poison, i16 0, i16 0>, i16 [[D_0]], i32 1
-; CHECK-NEXT:    [[TMP1:%.*]] = or <4 x i16> [[TMP0]], zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = and <4 x i16> [[TMP1]], <i16 -16383, i16 -1, i16 -1, i16 -1>
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext <4 x i16> [[TMP0]] to <4 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i32> <i32 49153, i32 65535, i32 65535, i32 65535>, <4 x i32> [[TMP3]], <4 x i32> <i32 0, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = or <4 x i32> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = and <4 x i32> [[TMP2]], <i32 -16383, i32 -1, i32 65535, i32 -1>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i32> <i32 -16383, i32 -1, i32 65535, i32 -1>, <4 x i32> [[TMP3]], <4 x i32> <i32 0, i32 5, i32 2, i32 3>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq <4 x i32> [[TMP4]], <i32 -16383, i32 65535, i32 65535, i32 65535>
-; CHECK-NEXT:    [[TMP6:%.*]] = select <4 x i1> [[TMP5]], <4 x i16> [[TMP2]], <4 x i16> <i16 3, i16 4, i16 2, i16 1>
-; CHECK-NEXT:    [[TMP7:%.*]] = zext <4 x i16> [[TMP6]] to <4 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[TMP5]], <4 x i32> [[TMP6]], <4 x i32> <i32 3, i32 4, i32 2, i32 1>
 ; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP7]])
 ; CHECK-NEXT:    ret i32 [[TMP8]]
 ;


        


More information about the llvm-commits mailing list