[llvm] 97743b8 - [SLP][REVEC] Make ShuffleCostEstimator and ShuffleInstructionBuilder support vector instructions. (#99499)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 7 08:48:02 PDT 2024
Author: Han-Kuan Chen
Date: 2024-08-07T23:47:57+08:00
New Revision: 97743b8be86ab96afb26ba93e1876406c1f4d541
URL: https://github.com/llvm/llvm-project/commit/97743b8be86ab96afb26ba93e1876406c1f4d541
DIFF: https://github.com/llvm/llvm-project/commit/97743b8be86ab96afb26ba93e1876406c1f4d541.diff
LOG: [SLP][REVEC] Make ShuffleCostEstimator and ShuffleInstructionBuilder support vector instructions. (#99499)
1. When REVEC is enabled, we need to expand vector types into scalar
types.
2. When REVEC is enabled, CreateInsertVector (and CreateExtractVector)
is used because the scalar type may be a FixedVectorType.
3. Since the mask indices which are used by processBuildVector expect
the source is scalar type, we need to transform the mask indices into a
form which can be used when REVEC is enabled. The transform is only
called when the mask is really used.
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/revec.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 1fee06d145b9f..4186b17e644b0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -253,6 +253,21 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
VF * getNumElements(ScalarTy));
}
+static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
+ SmallVectorImpl<int> &Mask) {
+ // The ShuffleBuilder implementation use shufflevector to splat an "element".
+ // But the element have different meaning for SLP (scalar) and REVEC
+ // (vector). We need to expand Mask into masks which shufflevector can use
+ // directly.
+ SmallVector<int> NewMask(Mask.size() * VecTyNumElements);
+ for (unsigned I : seq<unsigned>(Mask.size()))
+ for (auto [J, MaskV] : enumerate(MutableArrayRef(NewMask).slice(
+ I * VecTyNumElements, VecTyNumElements)))
+ MaskV = Mask[I] == PoisonMaskElem ? PoisonMaskElem
+ : Mask[I] * VecTyNumElements + J;
+ Mask.swap(NewMask);
+}
+
/// \returns True if the value is a constant (but not globals/constant
/// expressions).
static bool isConstant(Value *V) {
@@ -7772,6 +7787,31 @@ namespace {
/// The base class for shuffle instruction emission and shuffle cost estimation.
class BaseShuffleAnalysis {
protected:
+ Type *ScalarTy = nullptr;
+
+ BaseShuffleAnalysis(Type *ScalarTy) : ScalarTy(ScalarTy) {}
+
+ /// V is expected to be a vectorized value.
+ /// When REVEC is disabled, there is no difference between VF and
+ /// VNumElements.
+ /// When REVEC is enabled, VF is VNumElements / ScalarTyNumElements.
+ /// e.g., if ScalarTy is <4 x Ty> and V1 is <8 x Ty>, 2 is returned instead
+ /// of 8.
+ unsigned getVF(Value *V) const {
+ assert(V && "V cannot be nullptr");
+ assert(isa<FixedVectorType>(V->getType()) &&
+ "V does not have FixedVectorType");
+ assert(ScalarTy && "ScalarTy cannot be nullptr");
+ unsigned ScalarTyNumElements = getNumElements(ScalarTy);
+ unsigned VNumElements =
+ cast<FixedVectorType>(V->getType())->getNumElements();
+ assert(VNumElements > ScalarTyNumElements &&
+ "the number of elements of V is not large enough");
+ assert(VNumElements % ScalarTyNumElements == 0 &&
+ "the number of elements of V is not a vectorized value");
+ return VNumElements / ScalarTyNumElements;
+ }
+
/// Checks if the mask is an identity mask.
/// \param IsStrict if is true the function returns false if mask size does
/// not match vector size.
@@ -8265,7 +8305,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
bool IsFinalized = false;
SmallVector<int> CommonMask;
SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors;
- Type *ScalarTy = nullptr;
const TargetTransformInfo &TTI;
InstructionCost Cost = 0;
SmallDenseSet<Value *> VectorizedVals;
@@ -8847,14 +8886,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
} else if (V1 && P2.isNull()) {
// Shuffle single vector.
ExtraCost += GetValueMinBWAffectedCost(V1);
- CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ CommonVF = getVF(V1);
assert(
all_of(Mask,
[=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
"All elements in mask must be less than CommonVF.");
} else if (V1 && !V2) {
// Shuffle vector and tree node.
- unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ unsigned VF = getVF(V1);
const TreeEntry *E2 = P2.get<const TreeEntry *>();
CommonVF = std::max(VF, E2->getVectorFactor());
assert(all_of(Mask,
@@ -8880,7 +8919,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
} else if (!V1 && V2) {
// Shuffle vector and tree node.
- unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
+ unsigned VF = getVF(V2);
const TreeEntry *E1 = P1.get<const TreeEntry *>();
CommonVF = std::max(VF, E1->getVectorFactor());
assert(all_of(Mask,
@@ -8908,9 +8947,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
} else {
assert(V1 && V2 && "Expected both vectors.");
- unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
- CommonVF =
- std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
+ unsigned VF = getVF(V1);
+ CommonVF = std::max(VF, getVF(V2));
assert(all_of(Mask,
[=](int Idx) {
return Idx < 2 * static_cast<int>(CommonVF);
@@ -8928,6 +8966,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
}
}
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+ assert(SLPReVec && "FixedVectorType is not expected.");
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+ CommonMask);
+ }
InVectors.front() =
Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
if (InVectors.size() == 2)
@@ -8940,7 +8983,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
ShuffleCostEstimator(Type *ScalarTy, TargetTransformInfo &TTI,
ArrayRef<Value *> VectorizedVals, BoUpSLP &R,
SmallPtrSetImpl<Value *> &CheckedExtracts)
- : ScalarTy(ScalarTy), TTI(TTI),
+ : BaseShuffleAnalysis(ScalarTy), TTI(TTI),
VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()), R(R),
CheckedExtracts(CheckedExtracts) {}
Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -9145,7 +9188,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
}
assert(!InVectors.empty() && !CommonMask.empty() &&
"Expected only tree entries from extracts/reused buildvectors.");
- unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ unsigned VF = getVF(V1);
if (InVectors.size() == 2) {
Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
transformMaskAfterShuffle(CommonMask, CommonMask);
@@ -9179,12 +9222,32 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
}
Vals.push_back(Constant::getNullValue(V->getType()));
}
+ if (auto *VecTy = dyn_cast<FixedVectorType>(Vals.front()->getType())) {
+ assert(SLPReVec && "FixedVectorType is not expected.");
+ // When REVEC is enabled, we need to expand vector types into scalar
+ // types.
+ unsigned VecTyNumElements = VecTy->getNumElements();
+ SmallVector<Constant *> NewVals(VF * VecTyNumElements, nullptr);
+ for (auto [I, V] : enumerate(Vals)) {
+ Type *ScalarTy = V->getType()->getScalarType();
+ Constant *NewVal;
+ if (isa<PoisonValue>(V))
+ NewVal = PoisonValue::get(ScalarTy);
+ else if (isa<UndefValue>(V))
+ NewVal = UndefValue::get(ScalarTy);
+ else
+ NewVal = Constant::getNullValue(ScalarTy);
+ std::fill_n(NewVals.begin() + I * VecTyNumElements, VecTyNumElements,
+ NewVal);
+ }
+ Vals.swap(NewVals);
+ }
return ConstantVector::get(Vals);
}
return ConstantVector::getSplat(
ElementCount::getFixed(
cast<FixedVectorType>(Root->getType())->getNumElements()),
- getAllOnesValue(*R.DL, ScalarTy));
+ getAllOnesValue(*R.DL, ScalarTy->getScalarType()));
}
InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
/// Finalize emission of the shuffles.
@@ -11685,8 +11748,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
Type *Ty) {
Value *Scalar = V;
if (Scalar->getType() != Ty) {
- assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy() &&
- "Expected integer types only.");
+ assert(Scalar->getType()->isIntOrIntVectorTy() &&
+ Ty->isIntOrIntVectorTy() && "Expected integer types only.");
Value *V = Scalar;
if (auto *CI = dyn_cast<CastInst>(Scalar);
isa_and_nonnull<SExtInst, ZExtInst>(CI)) {
@@ -11699,10 +11762,21 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
}
- Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
- auto *InsElt = dyn_cast<InsertElementInst>(Vec);
- if (!InsElt)
- return Vec;
+ Instruction *InsElt;
+ if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
+ assert(SLPReVec && "FixedVectorType is not expected.");
+ Vec = InsElt = Builder.CreateInsertVector(
+ Vec->getType(), Vec, V,
+ Builder.getInt64(Pos * VecTy->getNumElements()));
+ auto *II = dyn_cast<IntrinsicInst>(InsElt);
+ if (!II || II->getIntrinsicID() != Intrinsic::vector_insert)
+ return Vec;
+ } else {
+ Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
+ InsElt = dyn_cast<InsertElementInst>(Vec);
+ if (!InsElt)
+ return Vec;
+ }
GatherShuffleExtractSeq.insert(InsElt);
CSEBlocks.insert(InsElt->getParent());
// Add to our 'need-to-extract' list.
@@ -11803,7 +11877,6 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
/// resulting shuffle and the second operand sets to be the newly added
/// operand. The \p CommonMask is transformed in the proper way after that.
SmallVector<Value *, 2> InVectors;
- Type *ScalarTy = nullptr;
IRBuilderBase &Builder;
BoUpSLP &R;
@@ -11929,7 +12002,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
public:
ShuffleInstructionBuilder(Type *ScalarTy, IRBuilderBase &Builder, BoUpSLP &R)
- : ScalarTy(ScalarTy), Builder(Builder), R(R) {}
+ : BaseShuffleAnalysis(ScalarTy), Builder(Builder), R(R) {}
/// Adjusts extractelements after reusing them.
Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -12186,7 +12259,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
break;
}
}
- int VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ int VF = getVF(V1);
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
@@ -12209,6 +12282,15 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
IsFinalized = true;
+ SmallVector<int> NewExtMask(ExtMask);
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+ assert(SLPReVec && "FixedVectorType is not expected.");
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+ CommonMask);
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+ NewExtMask);
+ ExtMask = NewExtMask;
+ }
if (Action) {
Value *Vec = InVectors.front();
if (InVectors.size() == 2) {
@@ -13992,6 +14074,17 @@ Value *BoUpSLP::vectorizeTree(
if (GEP->hasName())
CloneGEP->takeName(GEP);
Ex = CloneGEP;
+ } else if (auto *VecTy =
+ dyn_cast<FixedVectorType>(Scalar->getType())) {
+ assert(SLPReVec && "FixedVectorType is not expected.");
+ unsigned VecTyNumElements = VecTy->getNumElements();
+ // When REVEC is enabled, we need to extract a vector.
+ // Note: The element size of Scalar may be different from the
+ // element size of Vec.
+ Ex = Builder.CreateExtractVector(
+ FixedVectorType::get(Vec->getType()->getScalarType(),
+ VecTyNumElements),
+ Vec, Builder.getInt64(ExternalUse.Lane * VecTyNumElements));
} else {
Ex = Builder.CreateExtractElement(Vec, Lane);
}
diff --git a/llvm/test/Transforms/SLPVectorizer/revec.ll b/llvm/test/Transforms/SLPVectorizer/revec.ll
index a6e1061189980..d6dd4128de9c7 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec.ll
@@ -88,3 +88,39 @@ entry:
store <4 x i32> %9, ptr %10, align 4
ret void
}
+
+define void @test4(ptr %in, ptr %out) {
+; CHECK-LABEL: @test4(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = load <8 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 8)
+; CHECK-NEXT: [[TMP2:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP1]], <8 x float> [[TMP0]], i64 0)
+; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP4:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> zeroinitializer, i64 0)
+; CHECK-NEXT: [[TMP5:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP4]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT: [[TMP6:%.*]] = fmul <16 x float> [[TMP3]], [[TMP5]]
+; CHECK-NEXT: [[TMP7:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 0)
+; CHECK-NEXT: [[TMP8:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP7]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> [[TMP8]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT: [[TMP10:%.*]] = fadd <16 x float> [[TMP9]], [[TMP6]]
+; CHECK-NEXT: [[TMP11:%.*]] = fcmp ogt <16 x float> [[TMP10]], [[TMP5]]
+; CHECK-NEXT: [[TMP12:%.*]] = getelementptr i1, ptr [[OUT:%.*]], i64 8
+; CHECK-NEXT: [[TMP13:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 8)
+; CHECK-NEXT: store <8 x i1> [[TMP13]], ptr [[OUT]], align 1
+; CHECK-NEXT: [[TMP14:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 0)
+; CHECK-NEXT: store <8 x i1> [[TMP14]], ptr [[TMP12]], align 1
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = load <8 x float>, ptr %in, align 4
+ %1 = fmul <8 x float> %0, zeroinitializer
+ %2 = fmul <8 x float> %0, zeroinitializer
+ %3 = fadd <8 x float> zeroinitializer, %1
+ %4 = fadd <8 x float> %0, %2
+ %5 = fcmp ogt <8 x float> %3, zeroinitializer
+ %6 = fcmp ogt <8 x float> %4, zeroinitializer
+ %7 = getelementptr i1, ptr %out, i64 8
+ store <8 x i1> %5, ptr %out, align 1
+ store <8 x i1> %6, ptr %7, align 1
+ ret void
+}
More information about the llvm-commits
mailing list