[llvm] 6c1d445 - [SLP] Improve minbitwidth analysis for shifts.
Wed Mar 20 06:08:08 PDT 2024
Author: Alexey Bataev
Date: 2024-03-20T09:07:26-04:00
New Revision: 6c1d4454ad00414cf20f9a69e04856de99f6bf1d
URL: https://github.com/llvm/llvm-project/commit/6c1d4454ad00414cf20f9a69e04856de99f6bf1d
DIFF: https://github.com/llvm/llvm-project/commit/6c1d4454ad00414cf20f9a69e04856de99f6bf1d.diff
LOG: [SLP] Improve minbitwidth analysis for shifts.
Adds improved bitwidth analysis for shl/ashr/lshr instructions. The
analysis is based on a similar version in InstCombiner.
Reviewers: RKSimon
Reviewed By: RKSimon
Pull Request: https://github.com/llvm/llvm-project/pull/84356
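A minimal sketch of the kind of chain the new analysis narrows, with
hypothetical value names (the reorder-possible-strided-node.ll diff below
shows the real case): because %wide is a sext from i32, it has more than
32 sign bits, so the ashr and trunc can be performed entirely in i32:

  ; before: shift performed in the widened type
  %wide = sext <4 x i32> %mul to <4 x i64>
  %shr  = ashr <4 x i64> %wide, zeroinitializer
  %res  = trunc <4 x i64> %shr to <4 x i32>
  ; after: the whole chain stays in i32
  %res = ashr <4 x i32> %mul, zeroinitializer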
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index b0cb3cd47dd793..5d59f35f30810e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13995,9 +13995,11 @@ bool BoUpSLP::collectValuesToDemote(
if (MultiNodeScalars.contains(V))
return false;
uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
- APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
- if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
- return true;
+ if (OrigBitWidth > BitWidth) {
+ APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+ if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
+ return true;
+ }
auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
unsigned BitWidth1 = OrigBitWidth - NumSignBits;
if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
@@ -14042,6 +14044,30 @@ bool BoUpSLP::collectValuesToDemote(
}
return true;
};
+ auto AttemptCheckBitwidth =
+ [&](function_ref<bool(unsigned, unsigned)> Checker, bool &NeedToExit) {
+ // Try all bitwidths < OrigBitWidth.
+ NeedToExit = false;
+ uint32_t OrigBitWidth = DL->getTypeSizeInBits(I->getType());
+ unsigned BestFailBitwidth = 0;
+ for (; BitWidth < OrigBitWidth; BitWidth *= 2) {
+ if (Checker(BitWidth, OrigBitWidth))
+ return true;
+ if (BestFailBitwidth == 0 && FinalAnalysis())
+ BestFailBitwidth = BitWidth;
+ }
+ if (BitWidth >= OrigBitWidth) {
+ if (BestFailBitwidth == 0) {
+ BitWidth = OrigBitWidth;
+ return false;
+ }
+ MaxDepthLevel = 1;
+ BitWidth = BestFailBitwidth;
+ NeedToExit = true;
+ return true;
+ }
+ return false;
+ };
bool NeedToExit = false;
switch (I->getOpcode()) {
@@ -14074,6 +14100,71 @@ bool BoUpSLP::collectValuesToDemote(
return false;
break;
}
+ case Instruction::Shl: {
+ // Several vectorized uses? Check if we can truncate it; otherwise, exit.
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+ return false;
+ // If we are truncating the result of this SHL, and if it's a shift of an
+ // in-range amount, we can always perform a SHL in a smaller type.
+ if (!AttemptCheckBitwidth(
+ [&](unsigned BitWidth, unsigned) {
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ return AmtKnownBits.getMaxValue().ult(BitWidth);
+ },
+ NeedToExit))
+ return false;
+ if (NeedToExit)
+ return true;
+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
+ return false;
+ break;
+ }
+ case Instruction::LShr: {
+ // Several vectorized uses? Check if we can truncate it; otherwise, exit.
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+ return false;
+ // If this is a truncate of a logical shr, we can truncate it to a smaller
+ // lshr iff we know that the bits we would otherwise be shifting in are
+ // already zeros.
+ if (!AttemptCheckBitwidth(
+ [&](unsigned BitWidth, unsigned OrigBitWidth) {
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+ return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+ MaskedValueIsZero(I->getOperand(0), ShiftedBits,
+ SimplifyQuery(*DL));
+ },
+ NeedToExit))
+ return false;
+ if (NeedToExit)
+ return true;
+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
+ return false;
+ break;
+ }
+ case Instruction::AShr: {
+ // Several vectorized uses? Check if we can truncate it; otherwise, exit.
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+ return false;
+ // If this is a truncate of an arithmetic shr, we can truncate it to a
+ // smaller ashr iff we know that all the bits from the sign bit of the
+ // original type and the sign bit of the truncate type are similar.
+ if (!AttemptCheckBitwidth(
+ [&](unsigned BitWidth, unsigned OrigBitWidth) {
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ unsigned ShiftedBits = OrigBitWidth - BitWidth;
+ return AmtKnownBits.getMaxValue().ult(BitWidth) &&
+ ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0,
+ AC, nullptr, DT);
+ },
+ NeedToExit))
+ return false;
+ if (NeedToExit)
+ return true;
+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
+ return false;
+ break;
+ }
// We can demote selects if we can demote their true and false values.
case Instruction::Select: {
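A minimal scalar sketch of the three legality checks added above, using
hypothetical values %x, %y, %w and assuming a demotion from i32 to i16
(OrigBitWidth = 32, BitWidth = 16); SLP applies the same reasoning per
vector lane:

  ; shl: the amount is known to be < 16, so the low 16 result bits are
  ; computed identically in i16
  %s = shl i32 %x, 3            ; -> %s16 = shl i16 %x16, 3
  ; lshr: amount < 16 and bits 16..31 of the operand are known zero, so
  ; only zeros would be shifted in from above
  %z = zext i16 %y to i32
  %l = lshr i32 %z, 3           ; -> %l16 = lshr i16 %y, 3
  ; ashr: amount < 16 and the operand has more than 16 sign bits, so the
  ; narrow sign bit reproduces every bit that would be shifted in
  %e = sext i16 %w to i32
  %a = ashr i32 %e, 3           ; -> %a16 = ashr i16 %w, 3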
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
index f6db831ed97b19..4a23abf182e888 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
@@ -10,10 +10,8 @@ define void @test() {
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[ARRAYIDX22]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
; CHECK-NEXT: [[TMP3:%.*]] = mul <4 x i32> [[TMP2]], [[TMP0]]
-; CHECK-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[TMP3]] to <4 x i64>
-; CHECK-NEXT: [[TMP5:%.*]] = ashr <4 x i64> [[TMP4]], zeroinitializer
-; CHECK-NEXT: [[TMP6:%.*]] = trunc <4 x i64> [[TMP5]] to <4 x i32>
-; CHECK-NEXT: store <4 x i32> [[TMP6]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
+; CHECK-NEXT: [[TMP4:%.*]] = ashr <4 x i32> [[TMP3]], zeroinitializer
+; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
; CHECK-NEXT: ret void
;
entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
index 86b1e1a801e32f..dce85b4b2a195e 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
@@ -5,18 +5,19 @@ define void @test() {
; CHECK-LABEL: @test(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr undef, i64 4
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [4 x [4 x i32]], ptr undef, i64 0, i64 1, i64 0
-; CHECK-NEXT: [[TMP4:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP5]]
-; CHECK-NEXT: [[TMP7:%.*]] = shl nsw <4 x i32> [[TMP6]], zeroinitializer
-; CHECK-NEXT: [[TMP8:%.*]] = add nsw <4 x i32> [[TMP7]], zeroinitializer
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
-; CHECK-NEXT: [[TMP10:%.*]] = add nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT: [[TMP11:%.*]] = sub nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
-; CHECK-NEXT: [[TMP13:%.*]] = add nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT: [[TMP14:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
+; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i16>
+; CHECK-NEXT: [[TMP5:%.*]] = sub <4 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
+; CHECK-NEXT: [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
+; CHECK-NEXT: [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT: [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
+; CHECK-NEXT: [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT: [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT: [[TMP15:%.*]] = sext <4 x i16> [[TMP14]] to <4 x i32>
; CHECK-NEXT: store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
; CHECK-NEXT: ret void
;