[llvm] [SLP][REVEC] When ScalarTy is FixedVectorType, the insertion index should consider the number of elements of ScalarTy. (PR #114526)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 1 02:48:26 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Han-Kuan Chen (HanKuanChen)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/114526.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+10-9)
- (modified) llvm/test/Transforms/SLPVectorizer/revec.ll (+54)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2afd02dae3a8b8..328ccf30641a83 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13803,13 +13803,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
unsigned VF = 0,
function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
IsFinalized = true;
+ unsigned ScalarTyNumElements = getNumElements(ScalarTy);
SmallVector<int> NewExtMask(ExtMask);
- if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+ if (ScalarTyNumElements != 1) {
assert(SLPReVec && "FixedVectorType is not expected.");
- transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
- CommonMask);
- transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
- NewExtMask);
+ transformScalarShuffleIndiciesToVector(ScalarTyNumElements, CommonMask);
+ transformScalarShuffleIndiciesToVector(ScalarTyNumElements, NewExtMask);
ExtMask = NewExtMask;
}
if (Action) {
@@ -13852,12 +13851,14 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
return !isKnownNonNegative(
V, SimplifyQuery(*R.DL));
}));
+ unsigned InsertionIndex = Idx * ScalarTyNumElements;
Vec = Builder.CreateInsertVector(Vec->getType(), Vec, V,
- Builder.getInt64(Idx));
+ Builder.getInt64(InsertionIndex));
if (!CommonMask.empty()) {
- std::iota(std::next(CommonMask.begin(), Idx),
- std::next(CommonMask.begin(), Idx + E->getVectorFactor()),
- Idx);
+ std::iota(std::next(CommonMask.begin(), InsertionIndex),
+ std::next(CommonMask.begin(), (Idx + E->getVectorFactor()) *
+ ScalarTyNumElements),
+ InsertionIndex);
}
}
InVectors.front() = Vec;
diff --git a/llvm/test/Transforms/SLPVectorizer/revec.ll b/llvm/test/Transforms/SLPVectorizer/revec.ll
index f32e315142767f..aec81086105d68 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec.ll
@@ -355,3 +355,57 @@ entry:
%10 = icmp ne <2 x i8> %8, zeroinitializer
ret void
}
+
+define void @test12() {
+; CHECK-LABEL: @test12(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr float, ptr null, i64 33
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr null, i64 50
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr float, ptr null, i64 75
+; CHECK-NEXT: [[TMP3:%.*]] = load <8 x float>, ptr [[TMP1]], align 4
+; CHECK-NEXT: [[TMP4:%.*]] = load <8 x float>, ptr [[TMP2]], align 4
+; CHECK-NEXT: [[TMP5:%.*]] = load <16 x float>, ptr [[TMP0]], align 4
+; CHECK-NEXT: [[TMP6:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> [[TMP4]], i64 0)
+; CHECK-NEXT: [[TMP7:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP6]], <8 x float> [[TMP3]], i64 8)
+; CHECK-NEXT: [[TMP8:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v16f32(<32 x float> [[TMP7]], <16 x float> [[TMP5]], i64 16)
+; CHECK-NEXT: [[TMP9:%.*]] = fpext <32 x float> [[TMP8]] to <32 x double>
+; CHECK-NEXT: [[TMP10:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> poison, <8 x double> zeroinitializer, i64 0)
+; CHECK-NEXT: [[TMP11:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP10]], <8 x double> zeroinitializer, i64 8)
+; CHECK-NEXT: [[TMP12:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP11]], <8 x double> zeroinitializer, i64 16)
+; CHECK-NEXT: [[TMP13:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP12]], <8 x double> zeroinitializer, i64 24)
+; CHECK-NEXT: [[TMP14:%.*]] = fadd <32 x double> [[TMP13]], [[TMP9]]
+; CHECK-NEXT: [[TMP15:%.*]] = fptrunc <32 x double> [[TMP14]] to <32 x float>
+; CHECK-NEXT: [[TMP16:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> zeroinitializer, i64 0)
+; CHECK-NEXT: [[TMP17:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP16]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT: [[TMP18:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP17]], <8 x float> zeroinitializer, i64 16)
+; CHECK-NEXT: [[TMP19:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP18]], <8 x float> zeroinitializer, i64 24)
+; CHECK-NEXT: [[TMP20:%.*]] = fcmp ogt <32 x float> [[TMP19]], [[TMP15]]
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = getelementptr float, ptr null, i64 33
+ %1 = getelementptr float, ptr null, i64 41
+ %2 = getelementptr float, ptr null, i64 50
+ %3 = getelementptr float, ptr null, i64 75
+ %4 = load <8 x float>, ptr %0, align 4
+ %5 = load <8 x float>, ptr %1, align 4
+ %6 = load <8 x float>, ptr %2, align 4
+ %7 = load <8 x float>, ptr %3, align 4
+ %8 = fpext <8 x float> %4 to <8 x double>
+ %9 = fpext <8 x float> %5 to <8 x double>
+ %10 = fpext <8 x float> %6 to <8 x double>
+ %11 = fpext <8 x float> %7 to <8 x double>
+ %12 = fadd <8 x double> zeroinitializer, %8
+ %13 = fadd <8 x double> zeroinitializer, %9
+ %14 = fadd <8 x double> zeroinitializer, %10
+ %15 = fadd <8 x double> zeroinitializer, %11
+ %16 = fptrunc <8 x double> %12 to <8 x float>
+ %17 = fptrunc <8 x double> %13 to <8 x float>
+ %18 = fptrunc <8 x double> %14 to <8 x float>
+ %19 = fptrunc <8 x double> %15 to <8 x float>
+ %20 = fcmp ogt <8 x float> zeroinitializer, %16
+ %21 = fcmp ogt <8 x float> zeroinitializer, %17
+ %22 = fcmp ogt <8 x float> zeroinitializer, %18
+ %23 = fcmp ogt <8 x float> zeroinitializer, %19
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/114526
More information about the llvm-commits mailing list