[llvm] [SLP]Improve minbitwidth analysis for shifts. (PR #84356)
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 14 16:49:28 PDT 2024
https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/84356
>From 107d47c8dddaa4b106bac5267ddc31f841d66208 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 7 Mar 2024 18:25:36 +0000
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Created using spr 1.3.5
---
.../Transforms/Vectorize/SLPVectorizer.cpp | 112 +++++++++++++++---
.../X86/reorder-possible-strided-node.ll | 6 +-
.../X86/reorder_diamond_match.ll | 25 ++--
3 files changed, 112 insertions(+), 31 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 1889bc09e85028..3364f34d0148cc 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10094,16 +10094,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
BitWidth = UserIt->second.second;
}
}
- auto CheckBitwidth = [&](const TreeEntry &TE) {
- Type *ScalarTy = TE.Scalars.front()->getType();
- if (!ScalarTy->isIntegerTy())
- return true;
- unsigned TEBitWidth = DL->getTypeStoreSize(ScalarTy);
- auto UserIt = MinBWs.find(TEUseEI.UserTE);
- if (UserIt != MinBWs.end())
- TEBitWidth = UserIt->second.second;
- return BitWidth == TEBitWidth;
- };
SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs;
DenseMap<Value *, int> UsedValuesEntry;
for (Value *V : VL) {
@@ -10138,8 +10128,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
continue;
}
- if (!CheckBitwidth(*TEPtr))
- continue;
// Check if the user node of the TE comes after user node of TEPtr,
// otherwise TEPtr depends on TE.
if ((TEInsertBlock != InsertPt->getParent() ||
@@ -10157,7 +10145,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
VTE = *It->getSecond().begin();
// Iterate through all vectorized nodes.
auto *MIt = find_if(It->getSecond(), [&](const TreeEntry *MTE) {
- return MTE->State == TreeEntry::Vectorize && CheckBitwidth(*MTE);
+ return MTE->State == TreeEntry::Vectorize;
});
if (MIt == It->getSecond().end())
continue;
@@ -10167,8 +10155,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
continue;
- if (!CheckBitwidth(*VTE))
- continue;
VToTEs.insert(VTE);
}
if (VToTEs.empty())
@@ -10216,6 +10202,45 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
return std::nullopt;
}
+ if (BitWidth > 0) {
+ // Check if the used TEs are supposed to be resized and choose the best
+ // candidates.
+ unsigned NodesBitWidth = 0;
+ auto CheckBitwidth = [&](const TreeEntry &TE) {
+ unsigned TEBitWidth = BitWidth;
+ auto UserIt = MinBWs.find(TEUseEI.UserTE);
+ if (UserIt != MinBWs.end())
+ TEBitWidth = UserIt->second.second;
+ if (BitWidth <= TEBitWidth) {
+ if (NodesBitWidth == 0)
+ NodesBitWidth = TEBitWidth;
+ return NodesBitWidth == TEBitWidth;
+ }
+ return false;
+ };
+ for (auto [Idx, Set] : enumerate(UsedTEs)) {
+ DenseSet<const TreeEntry *> ForRemoval;
+ for (const TreeEntry *TE : Set) {
+ if (!CheckBitwidth(*TE))
+ ForRemoval.insert(TE);
+ }
+ // All elements must be removed - remove the whole container.
+ if (ForRemoval.size() == Set.size()) {
+ Set.clear();
+ continue;
+ }
+ for (const TreeEntry *TE : ForRemoval)
+ Set.erase(TE);
+ }
+ for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) {
+ if (It->empty()) {
+ UsedTEs.erase(It);
+ continue;
+ }
+ std::advance(It, 1);
+ }
+ }
+
unsigned VF = 0;
if (UsedTEs.size() == 1) {
// Keep the order to avoid non-determinism.
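
The BitWidth filtering introduced in the hunk above can be sketched in isolation: a candidate entry survives only if its width is acceptable and agrees with the first acceptable width seen, and candidate sets emptied by the filter are dropped entirely. A minimal standalone sketch (C++20; the Entry type and filterByBitWidth helper are made up for illustration, this is not the PR's code):

  #include <cstdio>
  #include <set>
  #include <vector>

  // Hypothetical stand-in for a tree entry carrying a target bit width.
  struct Entry { unsigned BitWidth; };

  // Mirrors the CheckBitwidth lambda above: keep an entry only if its width
  // is >= MinBW and matches the first acceptable width seen; afterwards drop
  // any candidate set the filter emptied.
  static void filterByBitWidth(std::vector<std::set<const Entry *>> &Sets,
                               unsigned MinBW) {
    unsigned ChosenBW = 0;
    auto Keep = [&](const Entry *E) {
      if (MinBW > E->BitWidth)
        return false;
      if (ChosenBW == 0)
        ChosenBW = E->BitWidth;
      return ChosenBW == E->BitWidth;
    };
    for (auto &Set : Sets)
      for (auto It = Set.begin(); It != Set.end();)
        It = Keep(*It) ? std::next(It) : Set.erase(It);
    std::erase_if(Sets, [](const auto &S) { return S.empty(); });
  }

  int main() {
    Entry A{16}, B{16}, C{8};
    std::vector<std::set<const Entry *>> Sets = {{&A, &C}, {&B}, {&C}};
    filterByBitWidth(Sets, /*MinBW=*/16);
    std::printf("surviving sets: %zu\n", Sets.size()); // 2: {A} and {B}
  }
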
@@ -13946,6 +13971,63 @@ bool BoUpSLP::collectValuesToDemote(
MaxDepthLevel = std::max(Level1, Level2);
break;
}
+ case Instruction::Shl: {
+ // If we are truncating the result of this SHL, and if it's a shift of an
+ // in-range amount, we can always perform a SHL in a smaller type.
+ unsigned Level1, Level2;
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+ !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level1, IsProfitableToDemote) ||
+ !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level2, IsProfitableToDemote))
+ return false;
+ MaxDepthLevel = std::max(Level1, Level2);
+ break;
+ }
+ case Instruction::LShr: {
+ // If this is a truncate of a logical shr, we can truncate it to a smaller
+ // lshr iff we know that the bits we would otherwise be shifting in are
+ // already zeros.
+ uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+ unsigned Level1, Level2;
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+ if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+ !MaskedValueIsZero(I->getOperand(0), ShiftedBits, SimplifyQuery(*DL)) ||
+ !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level1, IsProfitableToDemote) ||
+ !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level2, IsProfitableToDemote))
+ return false;
+ MaxDepthLevel = std::max(Level1, Level2);
+ break;
+ }
+ case Instruction::AShr: {
+ // If this is a truncate of an arithmetic shr, we can truncate it to a
+ // smaller ashr iff we know that all the bits from the sign bit of the
+ // original type down to the sign bit of the truncated type are identical.
+ uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+ unsigned Level1, Level2;
+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+ unsigned ShiftedBits = OrigBitWidth - BitWidth;
+ if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+ ShiftedBits >=
+ ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT) ||
+ !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level1, IsProfitableToDemote) ||
+ !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+ BitWidth, ToDemote, DemotedConsts, Visited,
+ Level2, IsProfitableToDemote))
+ return false;
+ MaxDepthLevel = std::max(Level1, Level2);
+ break;
+ }
// We can demote selects if we can demote their true and false values.
case Instruction::Select: {
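
The three new cases share one shape: bound the shift amount with KnownBits, add a per-opcode precondition, then recurse into both operands. A scalar sanity check of the preconditions they encode (plain C++, illustrative only, not part of the patch):

  #include <cassert>
  #include <cstdint>

  int main() {
    // Shl: the low 16 bits of a 32-bit shl match a 16-bit shl whenever the
    // amount is known to be < 16 (the AmtKnownBits guard).
    for (uint32_t X : {0x12345678u, 0xFFFFFFFFu})
      for (unsigned Amt = 0; Amt < 16; ++Amt)
        assert(uint16_t(X << Amt) == uint16_t(uint16_t(X) << Amt));

    // LShr: additionally the bits shifted in (bits 16..31) must be known
    // zero -- the MaskedValueIsZero guard.
    for (uint32_t X : {0x00001234u, 0x0000FFFFu})
      for (unsigned Amt = 0; Amt < 16; ++Amt)
        assert(uint16_t(X >> Amt) == uint16_t(uint16_t(X) >> Amt));

    // AShr: additionally bits 15..31 must all be sign copies, i.e. at least
    // 17 sign bits -- the ComputeNumSignBits guard.
    for (int32_t X : {0x1234, -42})
      for (unsigned Amt = 0; Amt < 16; ++Amt)
        assert(int16_t(X >> Amt) == int16_t(int16_t(X) >> Amt));
  }
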
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
index 6f5d3d3785e0c8..6378f696b470d4 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
@@ -10,10 +10,8 @@ define void @test() {
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[ARRAYIDX22]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
; CHECK-NEXT: [[TMP3:%.*]] = mul <4 x i32> [[TMP2]], [[TMP0]]
-; CHECK-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[TMP3]] to <4 x i64>
-; CHECK-NEXT: [[TMP5:%.*]] = ashr <4 x i64> [[TMP4]], zeroinitializer
-; CHECK-NEXT: [[TMP6:%.*]] = trunc <4 x i64> [[TMP5]] to <4 x i32>
-; CHECK-NEXT: store <4 x i32> [[TMP6]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
+; CHECK-NEXT: [[TMP4:%.*]] = ashr <4 x i32> [[TMP3]], zeroinitializer
+; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
; CHECK-NEXT: ret void
;
entry:
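
The CHECK-line change above is the AShr rule in action: the shift by zeroinitializer has a known maximum amount of 0, and the shifted value is itself a sext from i32, so it has at least 33 sign bits; both guards pass and the sext/ashr/trunc chain collapses into a single 32-bit ashr. The underlying identity, spot-checked in scalar C++ (illustrative only):

  #include <cassert>
  #include <cstdint>

  int main() {
    // trunc(i64 ashr(sext(x), k)) == i32 ashr(x, k) for any k < 32.
    for (int32_t X : {123456, -123456, INT32_MIN, INT32_MAX})
      for (unsigned K = 0; K < 32; ++K)
        assert(int32_t(int64_t(X) >> K) == (X >> K));
  }
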
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
index 86b1e1a801e32f..91ee4dba07009f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
@@ -5,18 +5,19 @@ define void @test() {
; CHECK-LABEL: @test(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr undef, i64 4
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [4 x [4 x i32]], ptr undef, i64 0, i64 1, i64 0
-; CHECK-NEXT: [[TMP4:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP5]]
-; CHECK-NEXT: [[TMP7:%.*]] = shl nsw <4 x i32> [[TMP6]], zeroinitializer
-; CHECK-NEXT: [[TMP8:%.*]] = add nsw <4 x i32> [[TMP7]], zeroinitializer
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
-; CHECK-NEXT: [[TMP10:%.*]] = add nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT: [[TMP11:%.*]] = sub nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
-; CHECK-NEXT: [[TMP13:%.*]] = add nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT: [[TMP14:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
+; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i16>
+; CHECK-NEXT: [[TMP5:%.*]] = sub <4 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
+; CHECK-NEXT: [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
+; CHECK-NEXT: [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT: [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
+; CHECK-NEXT: [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT: [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT: [[TMP15:%.*]] = zext <4 x i16> [[TMP14]] to <4 x i32>
; CHECK-NEXT: store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
; CHECK-NEXT: ret void
;
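
For this diamond the CHECK lines show the opposite direction: with i8-zext inputs every intermediate value fits comfortably in i16, so the whole chain now runs at i16 and a single zext restores i32 for the store. A rough scalar check of those value ranges (illustrative only; SLP derives the width from KnownBits/ComputeNumSignBits rather than by enumeration):

  #include <cassert>
  #include <cstdint>

  int main() {
    // One lane of the diamond with i8-zext inputs A and B. The shl amount in
    // the test is zero, so it is written out as a plain copy here.
    for (int A = 0; A <= 255; ++A)
      for (int B = 0; B <= 255; ++B) {
        int La = (0 - A) + 0;      // sub, shl by 0, add:  [-255, 0]
        int Lb = (0 - B) + 0;
        int Add = La + Lb;         // first butterfly:     [-510, 0]
        int Sub = La - Lb;         //                      [-255, 255]
        for (int V : {0 + Add, 0 - Add, 0 + Sub, 0 - Sub}) // second stage
          assert(V >= INT16_MIN && V <= INT16_MAX);
      }
  }
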