[llvm] [SLP][REVEC] Make SLP vectorize shufflevector. (PR #102489)
Han-Kuan Chen via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 8 08:22:39 PDT 2024
https://github.com/HanKuanChen created https://github.com/llvm/llvm-project/pull/102489
Add getShufflevectorNumGroups so that SLP (with REVEC) can vectorize
shufflevector instructions.
Currently, getShufflevectorNumGroups only handles a limited pattern
(e.g., the shufflevector masks must use the elements of the source in
order).
In addition, ReuseShuffleIndices and ReorderIndices are not supported yet.
>From 268a88a7bfa98a74057365be7066604da9b681d0 Mon Sep 17 00:00:00 2001
From: Han-Kuan Chen <hankuan.chen at sifive.com>
Date: Wed, 10 Jul 2024 02:06:09 -0700
Subject: [PATCH 1/2] [SLP][REVEC] Pre-commit test.
---
.../SLPVectorizer/revec-shufflevector.ll | 86 +++++++++++++++++++
1 file changed, 86 insertions(+)
create mode 100644 llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
diff --git a/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
new file mode 100644
index 00000000000000..0c80d5c0ccb53a
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
@@ -0,0 +1,86 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=slp-vectorizer,instcombine -S -slp-revec -slp-max-reg-size=1024 -slp-threshold=-100 %s | FileCheck %s
+
+define void @test1(ptr %in, ptr %out) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr [[IN:%.*]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i32> [[TMP2]] to <4 x i64>
+; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
+; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
+; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 48
+; CHECK-NEXT: store <2 x i64> [[TMP5]], ptr [[OUT]], align 8
+; CHECK-NEXT: store <2 x i64> [[TMP6]], ptr [[TMP9]], align 8
+; CHECK-NEXT: store <2 x i64> [[TMP7]], ptr [[TMP10]], align 8
+; CHECK-NEXT: store <2 x i64> [[TMP8]], ptr [[TMP11]], align 8
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = load <8 x i32>, ptr %in, align 1
+ %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+ %3 = zext <4 x i32> %1 to <4 x i64>
+ %4 = zext <4 x i32> %2 to <4 x i64>
+ %5 = shufflevector <4 x i64> %3, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+ %6 = shufflevector <4 x i64> %3, <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+ %7 = shufflevector <4 x i64> %4, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+ %8 = shufflevector <4 x i64> %4, <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+ %9 = getelementptr inbounds i64, ptr %out, i64 0
+ %10 = getelementptr inbounds i64, ptr %out, i64 2
+ %11 = getelementptr inbounds i64, ptr %out, i64 4
+ %12 = getelementptr inbounds i64, ptr %out, i64 6
+ store <2 x i64> %5, ptr %9, align 8
+ store <2 x i64> %6, ptr %10, align 8
+ store <2 x i64> %7, ptr %11, align 8
+ store <2 x i64> %8, ptr %12, align 8
+ ret void
+}
+
+define void @test2(ptr %in, ptr %out) {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr [[IN:%.*]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i32> [[TMP1]] to <4 x i64>
+; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i32> [[TMP2]] to <4 x i64>
+; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
+; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
+; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 48
+; CHECK-NEXT: store <2 x i64> [[TMP5]], ptr [[OUT]], align 8
+; CHECK-NEXT: store <2 x i64> [[TMP6]], ptr [[TMP9]], align 8
+; CHECK-NEXT: store <2 x i64> [[TMP7]], ptr [[TMP10]], align 8
+; CHECK-NEXT: store <2 x i64> [[TMP8]], ptr [[TMP11]], align 8
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = load <8 x i32>, ptr %in, align 1
+ %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+ %3 = zext <4 x i32> %1 to <4 x i64>
+ %4 = zext <4 x i32> %2 to <4 x i64>
+ %5 = shufflevector <4 x i64> %3, <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+ %6 = shufflevector <4 x i64> %3, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+ %7 = shufflevector <4 x i64> %4, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
+ %8 = shufflevector <4 x i64> %4, <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+ %9 = getelementptr inbounds i64, ptr %out, i64 0
+ %10 = getelementptr inbounds i64, ptr %out, i64 2
+ %11 = getelementptr inbounds i64, ptr %out, i64 4
+ %12 = getelementptr inbounds i64, ptr %out, i64 6
+ store <2 x i64> %5, ptr %9, align 8
+ store <2 x i64> %6, ptr %10, align 8
+ store <2 x i64> %7, ptr %11, align 8
+ store <2 x i64> %8, ptr %12, align 8
+ ret void
+}
>From 0da34834ba5d6c10fbc38b9c0a371d41474d76d0 Mon Sep 17 00:00:00 2001
From: Han-Kuan Chen <hankuan.chen at sifive.com>
Date: Wed, 7 Aug 2024 20:06:01 -0700
Subject: [PATCH 2/2] [SLP][REVEC] Make SLP vectorize shufflevector.
Add getShufflevectorNumGroups so that SLP (with REVEC) can vectorize
shufflevector instructions.
Currently, getShufflevectorNumGroups only handles a limited pattern
(e.g., the shufflevector masks must use the elements of the source in
order).
In addition, ReuseShuffleIndices and ReorderIndices are not supported yet.
---
.../Transforms/Vectorize/SLPVectorizer.cpp | 364 ++++++++++++------
.../SLPVectorizer/revec-shufflevector.ll | 29 +-
2 files changed, 251 insertions(+), 142 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 4186b17e644b0b..8e8531a14652ae 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -268,6 +268,98 @@ static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
Mask.swap(NewMask);
}
+/// \returns the number of groups of shufflevector in \p VL, or 0 if \p VL
+/// cannot be split into such groups. A group is N / M consecutive values (N is
+/// the element count of operand 0, M is the mask size) with these features
+/// 1. All values in a group are shufflevectors of the same source, and the
+///    sources of all groups have the same type.
+/// 2. Every mask is isExtractSubvectorMask and reads only operand 0.
+/// 3. The masks use every element of the source exactly once, in order.
+/// Anything else (e.g., overlapping, skipped or reordered extracts) yields 0.
+/// e.g., it is 2 groups (%3 and %4)
+/// %5 = shufflevector <8 x i16> %3, <8 x i16> poison,
+///                    <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+/// %6 = shufflevector <8 x i16> %3, <8 x i16> poison,
+///                    <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+/// %7 = shufflevector <8 x i16> %4, <8 x i16> poison,
+///                    <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+/// %8 = shufflevector <8 x i16> %4, <8 x i16> poison,
+///                    <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+/// it is 0 groups (%12 and %13 do not share a source)
+/// %12 = shufflevector <8 x i16> %10, <8 x i16> poison,
+///                     <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+/// %13 = shufflevector <8 x i16> %11, <8 x i16> poison,
+///                     <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+static unsigned getShufflevectorNumGroups(ArrayRef<Value *> VL) {
+  if (VL.empty() || !all_of(VL, IsaPred<ShuffleVectorInst>))
+    return 0;
+  auto *SV = cast<ShuffleVectorInst>(VL.front());
+  Type *SrcTy = SV->getOperand(0)->getType();
+  unsigned SVNumElements = cast<FixedVectorType>(SrcTy)->getNumElements();
+  unsigned MaskSize = SV->getShuffleMask().size();
+  // A group must tile its source exactly, so MaskSize must divide it.
+  if (MaskSize == 0 || SVNumElements % MaskSize != 0)
+    return 0;
+  unsigned GroupSize = SVNumElements / MaskSize;
+  if (VL.size() % GroupSize != 0)
+    return 0;
+  unsigned NumGroup = 0;
+  for (size_t I = 0, E = VL.size(); I != E; I += GroupSize) {
+    Value *Src = cast<ShuffleVectorInst>(VL[I])->getOperand(0);
+    // Start index the next shufflevector in the group must extract from.
+    int NextIndex = 0;
+    if (!all_of(VL.slice(I, GroupSize), [&](Value *V) {
+          auto *Shuf = cast<ShuffleVectorInst>(V);
+          // From the same source, whose type matches every other group's.
+          if (Shuf->getOperand(0) != Src || Src->getType() != SrcTy)
+            return false;
+          // Only operand 0 is vectorized, so reject masks that read operand 1
+          // (isExtractSubvectorMask accepts either single source).
+          if (any_of(Shuf->getShuffleMask(),
+                     [&](int M) { return M >= (int)SVNumElements; }))
+            return false;
+          int Index;
+          if (!Shuf->isExtractSubvectorMask(Index) || Index != NextIndex)
+            return false;
+          NextIndex += MaskSize;
+          return true;
+        }))
+      return 0;
+    ++NumGroup;
+  }
+  assert(NumGroup == (VL.size() / GroupSize) && "Unexpected number of groups");
+  return NumGroup;
+}
+
+/// \returns the mask of one shufflevector that replaces every shufflevector in
+/// \p VL. The I-th value's mask is rebased by I * N, where N is the element
+/// count of operand 0 of VL.front(); poison mask elements stay poison.
+/// e.g.,
+/// %5 = shufflevector <8 x i16> %3, <8 x i16> poison,
+///                    <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+/// %6 = shufflevector <8 x i16> %3, <8 x i16> poison,
+///                    <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+/// %7 = shufflevector <8 x i16> %4, <8 x i16> poison,
+///                    <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+/// %8 = shufflevector <8 x i16> %4, <8 x i16> poison,
+///                    <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+/// gives (N = 8)
+/// <0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 28, 29, 30, 31>
+static SmallVector<int> calculateShufflevectorMask(ArrayRef<Value *> VL) {
+  assert(getShufflevectorNumGroups(VL) && "Not supported shufflevector usage.");
+  auto *SV0 = cast<ShuffleVectorInst>(VL.front());
+  unsigned SrcLanes =
+      cast<FixedVectorType>(SV0->getOperand(0)->getType())->getNumElements();
+  SmallVector<int> Mask;
+  for (unsigned Idx = 0, End = VL.size(); Idx != End; ++Idx) {
+    // Lane offset of the Idx-th value's source in the vectorized operand 0.
+    unsigned Base = Idx * SrcLanes;
+    for (int Elt : cast<ShuffleVectorInst>(VL[Idx])->getShuffleMask())
+      Mask.push_back(Elt == PoisonMaskElem ? PoisonMaskElem : Base + Elt);
+  }
+  return Mask;
+}
+
/// \returns True if the value is a constant (but not globals/constant
/// expressions).
static bool isConstant(Value *V) {
@@ -6643,9 +6735,12 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
return TreeEntry::Vectorize;
}
case Instruction::ShuffleVector: {
- // If this is not an alternate sequence of opcode like add-sub
- // then do not vectorize this instruction.
if (!S.isAltShuffle()) {
+ // REVEC can support non alternate shuffle.
+ if (SLPReVec && getShufflevectorNumGroups(VL))
+ return TreeEntry::Vectorize;
+ // If this is not an alternate sequence of opcode like add-sub
+ // then do not vectorize this instruction.
LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n");
return TreeEntry::NeedToGather;
}
@@ -10003,13 +10098,14 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return GetCostDiff(GetScalarCost, GetVectorCost);
}
case Instruction::ShuffleVector: {
- assert(E->isAltShuffle() &&
- ((Instruction::isBinaryOp(E->getOpcode()) &&
- Instruction::isBinaryOp(E->getAltOpcode())) ||
- (Instruction::isCast(E->getOpcode()) &&
- Instruction::isCast(E->getAltOpcode())) ||
- (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
- "Invalid Shuffle Vector Operand");
+ if (!SLPReVec || E->isAltShuffle())
+ assert(E->isAltShuffle() &&
+ ((Instruction::isBinaryOp(E->getOpcode()) &&
+ Instruction::isBinaryOp(E->getAltOpcode())) ||
+ (Instruction::isCast(E->getOpcode()) &&
+ Instruction::isCast(E->getAltOpcode())) ||
+ (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
+ "Invalid Shuffle Vector Operand");
// Try to find the previous shuffle node with the same operands and same
// main/alternate ops.
auto TryFindNodeWithEqualOperands = [=]() {
@@ -10116,6 +10212,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
// TODO: Check the reverse order too.
return VecCost;
};
+ if (SLPReVec && !E->isAltShuffle())
+ return GetCostDiff(GetScalarCost, [](InstructionCost) {
+ // shufflevector will be eliminated by instcombine because the
+ // shufflevector masks are used in order (guaranteed by
+ // getShufflevectorNumGroups). The vector cost is 0.
+ return InstructionCost();
+ });
return GetCostDiff(GetScalarCost, GetVectorCost);
}
default:
@@ -13699,128 +13802,151 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
return V;
}
case Instruction::ShuffleVector: {
- assert(E->isAltShuffle() &&
- ((Instruction::isBinaryOp(E->getOpcode()) &&
- Instruction::isBinaryOp(E->getAltOpcode())) ||
- (Instruction::isCast(E->getOpcode()) &&
- Instruction::isCast(E->getAltOpcode())) ||
- (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
- "Invalid Shuffle Vector Operand");
-
- Value *LHS = nullptr, *RHS = nullptr;
- if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0)) {
+ Value *V;
+ if (SLPReVec && !E->isAltShuffle()) {
+ assert(E->ReuseShuffleIndices.empty() &&
+ "Not support ReuseShuffleIndices yet.");
+ assert(E->ReorderIndices.empty() && "Not support ReorderIndices yet.");
setInsertPointAfterBundle(E);
- LHS = vectorizeOperand(E, 0, PostponedPHIs);
+ Value *Src = vectorizeOperand(E, 0, PostponedPHIs);
if (E->VectorizedValue) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- RHS = vectorizeOperand(E, 1, PostponedPHIs);
- } else {
- setInsertPointAfterBundle(E);
- LHS = vectorizeOperand(E, 0, PostponedPHIs);
- }
- if (E->VectorizedValue) {
- LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
- return E->VectorizedValue;
- }
- if (LHS && RHS &&
- ((Instruction::isBinaryOp(E->getOpcode()) &&
- (LHS->getType() != VecTy || RHS->getType() != VecTy)) ||
- (isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()))) {
- assert((It != MinBWs.end() || getOperandEntry(E, 0)->isGather() ||
- getOperandEntry(E, 1)->isGather() ||
- MinBWs.contains(getOperandEntry(E, 0)) ||
- MinBWs.contains(getOperandEntry(E, 1))) &&
- "Expected item in MinBWs.");
- Type *CastTy = VecTy;
- if (isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()) {
- if (cast<VectorType>(LHS->getType())
- ->getElementType()
- ->getIntegerBitWidth() < cast<VectorType>(RHS->getType())
- ->getElementType()
- ->getIntegerBitWidth())
- CastTy = RHS->getType();
- else
- CastTy = LHS->getType();
- }
- if (LHS->getType() != CastTy)
- LHS = Builder.CreateIntCast(LHS, CastTy, GetOperandSignedness(0));
- if (RHS->getType() != CastTy)
- RHS = Builder.CreateIntCast(RHS, CastTy, GetOperandSignedness(1));
- }
-
- Value *V0, *V1;
- if (Instruction::isBinaryOp(E->getOpcode())) {
- V0 = Builder.CreateBinOp(
- static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS);
- V1 = Builder.CreateBinOp(
- static_cast<Instruction::BinaryOps>(E->getAltOpcode()), LHS, RHS);
- } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) {
- V0 = Builder.CreateCmp(CI0->getPredicate(), LHS, RHS);
- auto *AltCI = cast<CmpInst>(E->getAltOp());
- CmpInst::Predicate AltPred = AltCI->getPredicate();
- V1 = Builder.CreateCmp(AltPred, LHS, RHS);
+ // The current shufflevector usage always duplicate the source.
+ V = Builder.CreateShuffleVector(Src,
+ calculateShufflevectorMask(E->Scalars));
+ propagateIRFlags(V, E->Scalars, VL0);
} else {
- if (LHS->getType()->isIntOrIntVectorTy() && ScalarTy->isIntegerTy()) {
- unsigned SrcBWSz = DL->getTypeSizeInBits(
- cast<VectorType>(LHS->getType())->getElementType());
- unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
- if (BWSz <= SrcBWSz) {
- if (BWSz < SrcBWSz)
- LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
- assert(LHS->getType() == VecTy && "Expected same type as operand.");
- if (auto *I = dyn_cast<Instruction>(LHS))
- LHS = propagateMetadata(I, E->Scalars);
- E->VectorizedValue = LHS;
- ++NumVectorInstructions;
- return LHS;
+ assert(E->isAltShuffle() &&
+ ((Instruction::isBinaryOp(E->getOpcode()) &&
+ Instruction::isBinaryOp(E->getAltOpcode())) ||
+ (Instruction::isCast(E->getOpcode()) &&
+ Instruction::isCast(E->getAltOpcode())) ||
+ (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) &&
+ "Invalid Shuffle Vector Operand");
+
+ Value *LHS = nullptr, *RHS = nullptr;
+ if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0)) {
+ setInsertPointAfterBundle(E);
+ LHS = vectorizeOperand(E, 0, PostponedPHIs);
+ if (E->VectorizedValue) {
+ LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
+ return E->VectorizedValue;
}
+ RHS = vectorizeOperand(E, 1, PostponedPHIs);
+ } else {
+ setInsertPointAfterBundle(E);
+ LHS = vectorizeOperand(E, 0, PostponedPHIs);
}
- V0 = Builder.CreateCast(
- static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);
- V1 = Builder.CreateCast(
- static_cast<Instruction::CastOps>(E->getAltOpcode()), LHS, VecTy);
- }
- // Add V0 and V1 to later analysis to try to find and remove matching
- // instruction, if any.
- for (Value *V : {V0, V1}) {
- if (auto *I = dyn_cast<Instruction>(V)) {
- GatherShuffleExtractSeq.insert(I);
- CSEBlocks.insert(I->getParent());
+ if (E->VectorizedValue) {
+ LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
+ return E->VectorizedValue;
+ }
+ if (LHS && RHS &&
+ ((Instruction::isBinaryOp(E->getOpcode()) &&
+ (LHS->getType() != VecTy || RHS->getType() != VecTy)) ||
+ (isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()))) {
+ assert((It != MinBWs.end() ||
+ getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
+ getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
+ MinBWs.contains(getOperandEntry(E, 0)) ||
+ MinBWs.contains(getOperandEntry(E, 1))) &&
+ "Expected item in MinBWs.");
+ Type *CastTy = VecTy;
+ if (isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()) {
+ if (cast<VectorType>(LHS->getType())
+ ->getElementType()
+ ->getIntegerBitWidth() < cast<VectorType>(RHS->getType())
+ ->getElementType()
+ ->getIntegerBitWidth())
+ CastTy = RHS->getType();
+ else
+ CastTy = LHS->getType();
+ }
+ if (LHS->getType() != CastTy)
+ LHS = Builder.CreateIntCast(LHS, CastTy, GetOperandSignedness(0));
+ if (RHS->getType() != CastTy)
+ RHS = Builder.CreateIntCast(RHS, CastTy, GetOperandSignedness(1));
}
- }
- // Create shuffle to take alternate operations from the vector.
- // Also, gather up main and alt scalar ops to propagate IR flags to
- // each vector operation.
- ValueList OpScalars, AltScalars;
- SmallVector<int> Mask;
- E->buildAltOpShuffleMask(
- [E, this](Instruction *I) {
- assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
- return isAlternateInstruction(I, E->getMainOp(), E->getAltOp(),
- *TLI);
- },
- Mask, &OpScalars, &AltScalars);
+ Value *V0, *V1;
+ if (Instruction::isBinaryOp(E->getOpcode())) {
+ V0 = Builder.CreateBinOp(
+ static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS);
+ V1 = Builder.CreateBinOp(
+ static_cast<Instruction::BinaryOps>(E->getAltOpcode()), LHS, RHS);
+ } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) {
+ V0 = Builder.CreateCmp(CI0->getPredicate(), LHS, RHS);
+ auto *AltCI = cast<CmpInst>(E->getAltOp());
+ CmpInst::Predicate AltPred = AltCI->getPredicate();
+ V1 = Builder.CreateCmp(AltPred, LHS, RHS);
+ } else {
+ if (LHS->getType()->isIntOrIntVectorTy() && ScalarTy->isIntegerTy()) {
+ unsigned SrcBWSz = DL->getTypeSizeInBits(
+ cast<VectorType>(LHS->getType())->getElementType());
+ unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+ if (BWSz <= SrcBWSz) {
+ if (BWSz < SrcBWSz)
+ LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
+ assert(LHS->getType() == VecTy &&
+ "Expected same type as operand.");
+ if (auto *I = dyn_cast<Instruction>(LHS))
+ LHS = propagateMetadata(I, E->Scalars);
+ E->VectorizedValue = LHS;
+ ++NumVectorInstructions;
+ return LHS;
+ }
+ }
+ V0 = Builder.CreateCast(
+ static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);
+ V1 = Builder.CreateCast(
+ static_cast<Instruction::CastOps>(E->getAltOpcode()), LHS, VecTy);
+ }
+ // Add V0 and V1 to later analysis to try to find and remove matching
+ // instruction, if any.
+ for (Value *V : {V0, V1}) {
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ GatherShuffleExtractSeq.insert(I);
+ CSEBlocks.insert(I->getParent());
+ }
+ }
- propagateIRFlags(V0, OpScalars, E->getMainOp(), It == MinBWs.end());
- propagateIRFlags(V1, AltScalars, E->getAltOp(), It == MinBWs.end());
- auto DropNuwFlag = [&](Value *Vec, unsigned Opcode) {
- // Drop nuw flags for abs(sub(commutative), true).
- if (auto *I = dyn_cast<Instruction>(Vec);
- I && Opcode == Instruction::Sub && !MinBWs.contains(E) &&
- any_of(E->Scalars, [](Value *V) {
- auto *IV = cast<Instruction>(V);
- return IV->getOpcode() == Instruction::Sub &&
- isCommutative(cast<Instruction>(IV));
- }))
- I->setHasNoUnsignedWrap(/*b=*/false);
- };
- DropNuwFlag(V0, E->getOpcode());
- DropNuwFlag(V1, E->getAltOpcode());
+ // Create shuffle to take alternate operations from the vector.
+ // Also, gather up main and alt scalar ops to propagate IR flags to
+ // each vector operation.
+ ValueList OpScalars, AltScalars;
+ SmallVector<int> Mask;
+ E->buildAltOpShuffleMask(
+ [E, this](Instruction *I) {
+ assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
+ return isAlternateInstruction(I, E->getMainOp(), E->getAltOp(),
+ *TLI);
+ },
+ Mask, &OpScalars, &AltScalars);
+
+ propagateIRFlags(V0, OpScalars, E->getMainOp(), It == MinBWs.end());
+ propagateIRFlags(V1, AltScalars, E->getAltOp(), It == MinBWs.end());
+ auto DropNuwFlag = [&](Value *Vec, unsigned Opcode) {
+ // Drop nuw flags for abs(sub(commutative), true).
+ if (auto *I = dyn_cast<Instruction>(Vec);
+ I && Opcode == Instruction::Sub && !MinBWs.contains(E) &&
+ any_of(E->Scalars, [](Value *V) {
+ auto *IV = cast<Instruction>(V);
+ return IV->getOpcode() == Instruction::Sub &&
+ isCommutative(cast<Instruction>(IV));
+ }))
+ I->setHasNoUnsignedWrap(/*b=*/false);
+ };
+ DropNuwFlag(V0, E->getOpcode());
+ DropNuwFlag(V1, E->getAltOpcode());
- Value *V = Builder.CreateShuffleVector(V0, V1, Mask);
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+ assert(SLPReVec && "FixedVectorType is not expected.");
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(), Mask);
+ }
+ V = Builder.CreateShuffleVector(V0, V1, Mask);
+ }
if (auto *I = dyn_cast<Instruction>(V)) {
V = propagateMetadata(I, E->Scalars);
GatherShuffleExtractSeq.insert(I);
diff --git a/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
index 0c80d5c0ccb53a..6028a8b918941c 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
@@ -5,21 +5,8 @@ define void @test1(ptr %in, ptr %out) {
; CHECK-LABEL: @test1(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr [[IN:%.*]], align 1
-; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i32> [[TMP2]] to <4 x i64>
-; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
-; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
-; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 48
-; CHECK-NEXT: store <2 x i64> [[TMP5]], ptr [[OUT]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP6]], ptr [[TMP9]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP7]], ptr [[TMP10]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP8]], ptr [[TMP11]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i32> [[TMP0]] to <8 x i64>
+; CHECK-NEXT: store <8 x i64> [[TMP1]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT: ret void
;
entry:
@@ -53,15 +40,11 @@ define void @test2(ptr %in, ptr %out) {
; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i32> [[TMP2]] to <4 x i64>
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i64> [[TMP4]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
-; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
-; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 48
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
; CHECK-NEXT: store <2 x i64> [[TMP5]], ptr [[OUT]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP6]], ptr [[TMP9]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP7]], ptr [[TMP10]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP8]], ptr [[TMP11]], align 8
+; CHECK-NEXT: store <2 x i64> [[TMP6]], ptr [[TMP7]], align 8
+; CHECK-NEXT: store <4 x i64> [[TMP4]], ptr [[TMP8]], align 8
; CHECK-NEXT: ret void
;
entry:
More information about the llvm-commits
mailing list