[llvm] ce8ec31 - [SLP][REVEC] Support more mask pattern usage in shufflevector. (#106212)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 3 06:30:49 PDT 2024
Author: Han-Kuan Chen
Date: 2024-09-03T21:30:40+08:00
New Revision: ce8ec31298d5fbd81712af0f6bc34dae87f7f30c
URL: https://github.com/llvm/llvm-project/commit/ce8ec31298d5fbd81712af0f6bc34dae87f7f30c
DIFF: https://github.com/llvm/llvm-project/commit/ce8ec31298d5fbd81712af0f6bc34dae87f7f30c.diff
LOG: [SLP][REVEC] Support more mask pattern usage in shufflevector. (#106212)
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 93e7bfcdd87c44..e6a0e9b458966b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -293,8 +293,7 @@ static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
/// A group has the following features
/// 1. All of value in a group are shufflevector.
/// 2. The mask of all shufflevector is isExtractSubvectorMask.
-/// 3. The mask of all shufflevector uses all of the elements of the source (and
-/// the elements are used in order).
+/// 3. The mask of all shufflevector uses all of the elements of the source.
/// e.g., it is 1 group (%0)
/// %1 = shufflevector <16 x i8> %0, <16 x i8> poison,
/// <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
@@ -322,7 +321,8 @@ static unsigned getShufflevectorNumGroups(ArrayRef<Value *> VL) {
auto *SV = cast<ShuffleVectorInst>(VL.front());
unsigned SVNumElements =
cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
- unsigned GroupSize = SVNumElements / SV->getShuffleMask().size();
+ unsigned ShuffleMaskSize = SV->getShuffleMask().size();
+ unsigned GroupSize = SVNumElements / ShuffleMaskSize;
if (GroupSize == 0 || (VL.size() % GroupSize) != 0)
return 0;
unsigned NumGroup = 0;
@@ -330,7 +330,7 @@ static unsigned getShufflevectorNumGroups(ArrayRef<Value *> VL) {
auto *SV = cast<ShuffleVectorInst>(VL[I]);
Value *Src = SV->getOperand(0);
ArrayRef<Value *> Group = VL.slice(I, GroupSize);
- SmallVector<int> ExtractionIndex(SVNumElements);
+ SmallBitVector ExpectedIndex(GroupSize);
if (!all_of(Group, [&](Value *V) {
auto *SV = cast<ShuffleVectorInst>(V);
// From the same source.
@@ -339,12 +339,11 @@ static unsigned getShufflevectorNumGroups(ArrayRef<Value *> VL) {
int Index;
if (!SV->isExtractSubvectorMask(Index))
return false;
- for (int I : seq<int>(Index, Index + SV->getShuffleMask().size()))
- ExtractionIndex.push_back(I);
+ ExpectedIndex.set(Index / ShuffleMaskSize);
return true;
}))
return 0;
- if (!is_sorted(ExtractionIndex))
+ if (!ExpectedIndex.all())
return 0;
++NumGroup;
}
@@ -10289,12 +10288,40 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return VecCost;
};
if (SLPReVec && !E->isAltShuffle())
- return GetCostDiff(GetScalarCost, [](InstructionCost) {
- // shufflevector will be eliminated by instcombine because the
- // shufflevector masks are used in order (guaranteed by
- // getShufflevectorNumGroups). The vector cost is 0.
- return TTI::TCC_Free;
- });
+ return GetCostDiff(
+ GetScalarCost, [&](InstructionCost) -> InstructionCost {
+ // If a group uses mask in order, the shufflevector can be
+ // eliminated by instcombine. Then the cost is 0.
+ assert(isa<ShuffleVectorInst>(VL.front()) &&
+ "Not supported shufflevector usage.");
+ auto *SV = cast<ShuffleVectorInst>(VL.front());
+ unsigned SVNumElements =
+ cast<FixedVectorType>(SV->getOperand(0)->getType())
+ ->getNumElements();
+ unsigned GroupSize = SVNumElements / SV->getShuffleMask().size();
+ for (size_t I = 0, End = VL.size(); I != End; I += GroupSize) {
+ ArrayRef<Value *> Group = VL.slice(I, GroupSize);
+ int NextIndex = 0;
+ if (!all_of(Group, [&](Value *V) {
+ assert(isa<ShuffleVectorInst>(V) &&
+ "Not supported shufflevector usage.");
+ auto *SV = cast<ShuffleVectorInst>(V);
+ int Index;
+ bool isExtractSubvectorMask =
+ SV->isExtractSubvectorMask(Index);
+ assert(isExtractSubvectorMask &&
+ "Not supported shufflevector usage.");
+ if (NextIndex != Index)
+ return false;
+ NextIndex += SV->getShuffleMask().size();
+ return true;
+ }))
+ return ::getShuffleCost(
+ *TTI, TargetTransformInfo::SK_PermuteSingleSrc, VecTy,
+ calculateShufflevectorMask(E->Scalars));
+ }
+ return TTI::TCC_Free;
+ });
return GetCostDiff(GetScalarCost, GetVectorCost);
}
case Instruction::Freeze:
@@ -14072,9 +14099,16 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- // The current shufflevector usage always duplicate the source.
- V = Builder.CreateShuffleVector(Src,
- calculateShufflevectorMask(E->Scalars));
+ assert(isa<ShuffleVectorInst>(Src) &&
+ "Not supported shufflevector usage.");
+ auto *SVSrc = cast<ShuffleVectorInst>(Src);
+ assert(isa<PoisonValue>(SVSrc->getOperand(1)) &&
+ "Not supported shufflevector usage.");
+ SmallVector<int> ThisMask(calculateShufflevectorMask(E->Scalars));
+ SmallVector<int> NewMask(ThisMask.size());
+ transform(ThisMask, NewMask.begin(),
+ [&SVSrc](int Mask) { return SVSrc->getShuffleMask()[Mask]; });
+ V = Builder.CreateShuffleVector(SVSrc->getOperand(0), NewMask);
propagateIRFlags(V, E->Scalars, VL0);
} else {
assert(E->isAltShuffle() &&
diff --git a/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
index 6028a8b918941c..1fc0b0306d1194 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
@@ -34,17 +34,9 @@ define void @test2(ptr %in, ptr %out) {
; CHECK-LABEL: @test2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr [[IN:%.*]], align 1
-; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i32> [[TMP2]] to <4 x i64>
-; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
-; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
-; CHECK-NEXT: store <2 x i64> [[TMP5]], ptr [[OUT]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP6]], ptr [[TMP7]], align 8
-; CHECK-NEXT: store <4 x i64> [[TMP4]], ptr [[TMP8]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i32> [[TMP0]] to <8 x i64>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i64> [[TMP1]], <8 x i64> poison, <8 x i32> <i32 2, i32 3, i32 0, i32 1, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: store <8 x i64> [[TMP2]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT: ret void
;
entry:
@@ -67,3 +59,26 @@ entry:
store <2 x i64> %8, ptr %12, align 8
ret void
}
+
+define void @test3(<16 x i32> %0, ptr %out) { ; extract-subvector shuffles of %0 in descending (unsorted) order, stored contiguously; expected to become one <16 x i32> permute (see CHECK lines)
+; CHECK-LABEL: @test3(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <16 x i32> [[TMP0:%.*]], <16 x i32> poison, <16 x i32> <i32 12, i32 13, i32 14, i32 15, i32 8, i32 9, i32 10, i32 11, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: store <16 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT: ret void
+;
+entry:
+ %1 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 12, i32 13, i32 14, i32 15> ; %0[12..15]
+ %2 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 8, i32 9, i32 10, i32 11> ; %0[8..11]
+ %3 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7> ; %0[4..7]
+ %4 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3> ; %0[0..3]
+ %5 = getelementptr inbounds i32, ptr %out, i64 0
+ %6 = getelementptr inbounds i32, ptr %out, i64 4
+ %7 = getelementptr inbounds i32, ptr %out, i64 8
+ %8 = getelementptr inbounds i32, ptr %out, i64 12
+ store <4 x i32> %1, ptr %5, align 4 ; %out[0..3] = %0[12..15]
+ store <4 x i32> %2, ptr %6, align 4 ; %out[4..7] = %0[8..11]
+ store <4 x i32> %3, ptr %7, align 4 ; %out[8..11] = %0[4..7]
+ store <4 x i32> %4, ptr %8, align 4 ; %out[12..15] = %0[0..3]
+ ret void
+}
More information about the llvm-commits mailing list