[llvm] [SLP][REVEC] When ScalarTy is FixedVectorType, the insertion index should consider the number of elements of ScalarTy. (PR #114526)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 1 02:48:26 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Han-Kuan Chen (HanKuanChen)

<details>
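<summary>Changes</summary>

When `ScalarTy` is a `FixedVectorType` (REVEC), each scalar lane of a tree entry spans `getNumElements(ScalarTy)` elements of the final vector. `ShuffleInstructionBuilder::finalize` therefore needs to scale `Idx` by that element count in two places. The first is the index passed to `Builder.CreateInsertVector` when inserting a subvector. The second is the `CommonMask` range updated by `std::iota`. The patch also reuses the cached `ScalarTyNumElements` for the existing `transformScalarShuffleIndiciesToVector` calls, and adds `@test12` to `revec.ll` to cover this case.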

---
Full diff: https://github.com/llvm/llvm-project/pull/114526.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+10-9) 
- (modified) llvm/test/Transforms/SLPVectorizer/revec.ll (+54) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2afd02dae3a8b8..328ccf30641a83 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13803,13 +13803,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
            unsigned VF = 0,
            function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
     IsFinalized = true;
+    unsigned ScalarTyNumElements = getNumElements(ScalarTy);
     SmallVector<int> NewExtMask(ExtMask);
-    if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+    if (ScalarTyNumElements != 1) {
       assert(SLPReVec && "FixedVectorType is not expected.");
-      transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
-                                             CommonMask);
-      transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
-                                             NewExtMask);
+      transformScalarShuffleIndiciesToVector(ScalarTyNumElements, CommonMask);
+      transformScalarShuffleIndiciesToVector(ScalarTyNumElements, NewExtMask);
       ExtMask = NewExtMask;
     }
     if (Action) {
@@ -13852,12 +13851,14 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
                                    return !isKnownNonNegative(
                                        V, SimplifyQuery(*R.DL));
                                  }));
+        unsigned InsertionIndex = Idx * ScalarTyNumElements;
         Vec = Builder.CreateInsertVector(Vec->getType(), Vec, V,
-                                         Builder.getInt64(Idx));
+                                         Builder.getInt64(InsertionIndex));
         if (!CommonMask.empty()) {
-          std::iota(std::next(CommonMask.begin(), Idx),
-                    std::next(CommonMask.begin(), Idx + E->getVectorFactor()),
-                    Idx);
+          std::iota(std::next(CommonMask.begin(), InsertionIndex),
+                    std::next(CommonMask.begin(), (Idx + E->getVectorFactor()) *
+                                                      ScalarTyNumElements),
+                    InsertionIndex);
         }
       }
       InVectors.front() = Vec;
diff --git a/llvm/test/Transforms/SLPVectorizer/revec.ll b/llvm/test/Transforms/SLPVectorizer/revec.ll
index f32e315142767f..aec81086105d68 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec.ll
@@ -355,3 +355,57 @@ entry:
   %10 = icmp ne <2 x i8> %8, zeroinitializer
   ret void
 }
+
+define void @test12() {
+; CHECK-LABEL: @test12(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr float, ptr null, i64 33
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr float, ptr null, i64 50
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr float, ptr null, i64 75
+; CHECK-NEXT:    [[TMP3:%.*]] = load <8 x float>, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = load <8 x float>, ptr [[TMP2]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = load <16 x float>, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> [[TMP4]], i64 0)
+; CHECK-NEXT:    [[TMP7:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP6]], <8 x float> [[TMP3]], i64 8)
+; CHECK-NEXT:    [[TMP8:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v16f32(<32 x float> [[TMP7]], <16 x float> [[TMP5]], i64 16)
+; CHECK-NEXT:    [[TMP9:%.*]] = fpext <32 x float> [[TMP8]] to <32 x double>
+; CHECK-NEXT:    [[TMP10:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> poison, <8 x double> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP11:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP10]], <8 x double> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP12:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP11]], <8 x double> zeroinitializer, i64 16)
+; CHECK-NEXT:    [[TMP13:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP12]], <8 x double> zeroinitializer, i64 24)
+; CHECK-NEXT:    [[TMP14:%.*]] = fadd <32 x double> [[TMP13]], [[TMP9]]
+; CHECK-NEXT:    [[TMP15:%.*]] = fptrunc <32 x double> [[TMP14]] to <32 x float>
+; CHECK-NEXT:    [[TMP16:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP17:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP16]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP18:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP17]], <8 x float> zeroinitializer, i64 16)
+; CHECK-NEXT:    [[TMP19:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP18]], <8 x float> zeroinitializer, i64 24)
+; CHECK-NEXT:    [[TMP20:%.*]] = fcmp ogt <32 x float> [[TMP19]], [[TMP15]]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = getelementptr float, ptr null, i64 33
+  %1 = getelementptr float, ptr null, i64 41
+  %2 = getelementptr float, ptr null, i64 50
+  %3 = getelementptr float, ptr null, i64 75
+  %4 = load <8 x float>, ptr %0, align 4
+  %5 = load <8 x float>, ptr %1, align 4
+  %6 = load <8 x float>, ptr %2, align 4
+  %7 = load <8 x float>, ptr %3, align 4
+  %8 = fpext <8 x float> %4 to <8 x double>
+  %9 = fpext <8 x float> %5 to <8 x double>
+  %10 = fpext <8 x float> %6 to <8 x double>
+  %11 = fpext <8 x float> %7 to <8 x double>
+  %12 = fadd <8 x double> zeroinitializer, %8
+  %13 = fadd <8 x double> zeroinitializer, %9
+  %14 = fadd <8 x double> zeroinitializer, %10
+  %15 = fadd <8 x double> zeroinitializer, %11
+  %16 = fptrunc <8 x double> %12 to <8 x float>
+  %17 = fptrunc <8 x double> %13 to <8 x float>
+  %18 = fptrunc <8 x double> %14 to <8 x float>
+  %19 = fptrunc <8 x double> %15 to <8 x float>
+  %20 = fcmp ogt <8 x float> zeroinitializer, %16
+  %21 = fcmp ogt <8 x float> zeroinitializer, %17
+  %22 = fcmp ogt <8 x float> zeroinitializer, %18
+  %23 = fcmp ogt <8 x float> zeroinitializer, %19
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/114526


More information about the llvm-commits mailing list