[llvm] 6c1d445 - [SLP]Improve minbitwidth analysis for shifts.

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 20 06:08:08 PDT 2024


Author: Alexey Bataev
Date: 2024-03-20T09:07:26-04:00
New Revision: 6c1d4454ad00414cf20f9a69e04856de99f6bf1d

URL: https://github.com/llvm/llvm-project/commit/6c1d4454ad00414cf20f9a69e04856de99f6bf1d
DIFF: https://github.com/llvm/llvm-project/commit/6c1d4454ad00414cf20f9a69e04856de99f6bf1d.diff

LOG: [SLP]Improve minbitwidth analysis for shifts.

Adds improved bitwidth analysis for shl/ashr/lshr instructions. The
analysis is based on a similar version in InstCombiner.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: https://github.com/llvm/llvm-project/pull/84356

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
    llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index b0cb3cd47dd793..5d59f35f30810e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13995,9 +13995,11 @@ bool BoUpSLP::collectValuesToDemote(
     if (MultiNodeScalars.contains(V))
       return false;
     uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
-    APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
-    if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
-      return true;
+    if (OrigBitWidth < BitWidth) {
+      APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+      if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
+        return true;
+    }
     auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
     unsigned BitWidth1 = OrigBitWidth - NumSignBits;
     if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
@@ -14042,6 +14044,30 @@ bool BoUpSLP::collectValuesToDemote(
     }
     return true;
   };
+  auto AttemptCheckBitwidth =
+      [&](function_ref<bool(unsigned, unsigned)> Checker, bool &NeedToExit) {
+        // Try all bitwidth < OrigBitWidth.
+        NeedToExit = false;
+        uint32_t OrigBitWidth = DL->getTypeSizeInBits(I->getType());
+        unsigned BestFailBitwidth = 0;
+        for (; BitWidth < OrigBitWidth; BitWidth *= 2) {
+          if (Checker(BitWidth, OrigBitWidth))
+            return true;
+          if (BestFailBitwidth == 0 && FinalAnalysis())
+            BestFailBitwidth = BitWidth;
+        }
+        if (BitWidth >= OrigBitWidth) {
+          if (BestFailBitwidth == 0) {
+            BitWidth = OrigBitWidth;
+            return false;
+          }
+          MaxDepthLevel = 1;
+          BitWidth = BestFailBitwidth;
+          NeedToExit = true;
+          return true;
+        }
+        return false;
+      };
   bool NeedToExit = false;
   switch (I->getOpcode()) {
 
@@ -14074,6 +14100,71 @@ bool BoUpSLP::collectValuesToDemote(
       return false;
     break;
   }
+  case Instruction::Shl: {
+    // Several vectorized uses? Check if we can truncate it, otherwise - exit.
+    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+      return false;
+    // If we are truncating the result of this SHL, and if it's a shift of an
+    // inrange amount, we can always perform a SHL in a smaller type.
+    if (!AttemptCheckBitwidth(
+            [&](unsigned BitWidth, unsigned) {
+              KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+              return AmtKnownBits.getMaxValue().ult(BitWidth);
+            },
+            NeedToExit))
+      return false;
+    if (NeedToExit)
+      return true;
+    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
+      return false;
+    break;
+  }
+  case Instruction::LShr: {
+    // Several vectorized uses? Check if we can truncate it, otherwise - exit.
+    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+      return false;
+    // If this is a truncate of a logical shr, we can truncate it to a smaller
+    // lshr iff we know that the bits we would otherwise be shifting in are
+    // already zeros.
+    if (!AttemptCheckBitwidth(
+            [&](unsigned BitWidth, unsigned OrigBitWidth) {
+              KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+              APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+              return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+                     MaskedValueIsZero(I->getOperand(0), ShiftedBits,
+                                       SimplifyQuery(*DL));
+            },
+            NeedToExit))
+      return false;
+    if (NeedToExit)
+      return true;
+    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
+      return false;
+    break;
+  }
+  case Instruction::AShr: {
+    // Several vectorized uses? Check if we can truncate it, otherwise - exit.
+    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+      return false;
+    // If this is a truncate of an arithmetic shr, we can truncate it to a
+    // smaller ashr iff we know that all the bits from the sign bit of the
+    // original type and the sign bit of the truncate type are similar.
+    if (!AttemptCheckBitwidth(
+            [&](unsigned BitWidth, unsigned OrigBitWidth) {
+              KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+              unsigned ShiftedBits = OrigBitWidth - BitWidth;
+              return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+                     ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0,
+                                                      AC, nullptr, DT);
+            },
+            NeedToExit))
+      return false;
+    if (NeedToExit)
+      return true;
+    if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
+      return false;
+    break;
+  }
 
   // We can demote selects if we can demote their true and false values.
   case Instruction::Select: {

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
index f6db831ed97b19..4a23abf182e888 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
@@ -10,10 +10,8 @@ define void @test() {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[ARRAYIDX22]], align 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul <4 x i32> [[TMP2]], [[TMP0]]
-; CHECK-NEXT:    [[TMP4:%.*]] = sext <4 x i32> [[TMP3]] to <4 x i64>
-; CHECK-NEXT:    [[TMP5:%.*]] = ashr <4 x i64> [[TMP4]], zeroinitializer
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i64> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    store <4 x i32> [[TMP6]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
+; CHECK-NEXT:    [[TMP4:%.*]] = ashr <4 x i32> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    store <4 x i32> [[TMP4]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
 ; CHECK-NEXT:    ret void
 ;
 entry:

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
index 86b1e1a801e32f..dce85b4b2a195e 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
@@ -5,18 +5,19 @@ define void @test() {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr undef, i64 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [4 x [4 x i32]], ptr undef, i64 0, i64 1, i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = shl nsw <4 x i32> [[TMP6]], zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = add nsw <4 x i32> [[TMP7]], zeroinitializer
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
-; CHECK-NEXT:    [[TMP10:%.*]] = add nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP11:%.*]] = sub nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
-; CHECK-NEXT:    [[TMP13:%.*]] = add nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i16>
+; CHECK-NEXT:    [[TMP5:%.*]] = sub <4 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
+; CHECK-NEXT:    [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
+; CHECK-NEXT:    [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP15:%.*]] = sext <4 x i16> [[TMP14]] to <4 x i32>
 ; CHECK-NEXT:    store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
 ; CHECK-NEXT:    ret void
 ;


        


More information about the llvm-commits mailing list