[llvm] a795a18 - [SLP][REVEC] VF should be scaled when ScalarTy is FixedVectorType. (#114551)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 1 12:03:55 PDT 2024


Author: Han-Kuan Chen
Date: 2024-11-02T03:03:52+08:00
New Revision: a795a18bbae1800d8ee6b2eb23bc2a454a1269ef

URL: https://github.com/llvm/llvm-project/commit/a795a18bbae1800d8ee6b2eb23bc2a454a1269ef
DIFF: https://github.com/llvm/llvm-project/commit/a795a18bbae1800d8ee6b2eb23bc2a454a1269ef.diff

LOG: [SLP][REVEC] VF should be scaled when ScalarTy is FixedVectorType. (#114551)

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 248a107ded514c..427b8bd0e75ab0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5086,6 +5086,7 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
             VecLdCost +=
                 TTI.getInstructionCost(cast<Instruction>(VL[Idx]), CostKind);
       }
+      unsigned ScalarTyNumElements = getNumElements(ScalarTy);
       auto *SubVecTy = getWidenedType(ScalarTy, VF);
       for (auto [I, LS] : enumerate(States)) {
         auto *LI0 = cast<LoadInst>(VL[I * VF]);
@@ -5109,11 +5110,12 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
                 SubVecTy, APInt::getAllOnes(VF),
                 /*Insert=*/true, /*Extract=*/false, CostKind);
           else
-            VectorGEPCost += TTI.getScalarizationOverhead(
-                                 SubVecTy, APInt::getOneBitSet(VF, 0),
-                                 /*Insert=*/true, /*Extract=*/false, CostKind) +
-                             ::getShuffleCost(TTI, TTI::SK_Broadcast, SubVecTy,
-                                              {}, CostKind);
+            VectorGEPCost +=
+                TTI.getScalarizationOverhead(
+                    SubVecTy, APInt::getOneBitSet(ScalarTyNumElements * VF, 0),
+                    /*Insert=*/true, /*Extract=*/false, CostKind) +
+                ::getShuffleCost(TTI, TTI::SK_Broadcast, SubVecTy, {},
+                                 CostKind);
         }
         switch (LS) {
         case LoadsState::Vectorize:

diff  --git a/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll
index 65d0078080d227..0cf4da623a0fe9 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll
@@ -40,3 +40,57 @@ sw.bb509.i:                                       ; preds = %if.then458.i, %if.e
   %5 = phi <2 x i32> [ %1, %if.then458.i ], [ zeroinitializer, %if.end.i87 ], [ zeroinitializer, %if.end.i87 ]
   ret i32 0
 }
+
+define void @test2() {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr null, i64 132
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr null, i64 200
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr null, i64 300
+; CHECK-NEXT:    [[TMP3:%.*]] = load <8 x float>, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = load <8 x float>, ptr [[TMP2]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = load <16 x float>, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> [[TMP4]], i64 0)
+; CHECK-NEXT:    [[TMP7:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP6]], <8 x float> [[TMP3]], i64 8)
+; CHECK-NEXT:    [[TMP8:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v16f32(<32 x float> [[TMP7]], <16 x float> [[TMP5]], i64 16)
+; CHECK-NEXT:    [[TMP9:%.*]] = fpext <32 x float> [[TMP8]] to <32 x double>
+; CHECK-NEXT:    [[TMP10:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> poison, <8 x double> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP11:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP10]], <8 x double> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP12:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP11]], <8 x double> zeroinitializer, i64 16)
+; CHECK-NEXT:    [[TMP13:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP12]], <8 x double> zeroinitializer, i64 24)
+; CHECK-NEXT:    [[TMP14:%.*]] = fadd <32 x double> [[TMP13]], [[TMP9]]
+; CHECK-NEXT:    [[TMP15:%.*]] = fptrunc <32 x double> [[TMP14]] to <32 x float>
+; CHECK-NEXT:    [[TMP16:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP17:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP16]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP18:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP17]], <8 x float> zeroinitializer, i64 16)
+; CHECK-NEXT:    [[TMP19:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP18]], <8 x float> zeroinitializer, i64 24)
+; CHECK-NEXT:    [[TMP20:%.*]] = fcmp ogt <32 x float> [[TMP19]], [[TMP15]]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = getelementptr i8, ptr null, i64 132
+  %1 = getelementptr i8, ptr null, i64 164
+  %2 = getelementptr i8, ptr null, i64 200
+  %3 = getelementptr i8, ptr null, i64 300
+  %4 = load <8 x float>, ptr %0, align 4
+  %5 = load <8 x float>, ptr %1, align 4
+  %6 = load <8 x float>, ptr %2, align 4
+  %7 = load <8 x float>, ptr %3, align 4
+  %8 = fpext <8 x float> %4 to <8 x double>
+  %9 = fpext <8 x float> %5 to <8 x double>
+  %10 = fpext <8 x float> %6 to <8 x double>
+  %11 = fpext <8 x float> %7 to <8 x double>
+  %12 = fadd <8 x double> zeroinitializer, %8
+  %13 = fadd <8 x double> zeroinitializer, %9
+  %14 = fadd <8 x double> zeroinitializer, %10
+  %15 = fadd <8 x double> zeroinitializer, %11
+  %16 = fptrunc <8 x double> %12 to <8 x float>
+  %17 = fptrunc <8 x double> %13 to <8 x float>
+  %18 = fptrunc <8 x double> %14 to <8 x float>
+  %19 = fptrunc <8 x double> %15 to <8 x float>
+  %20 = fcmp ogt <8 x float> zeroinitializer, %16
+  %21 = fcmp ogt <8 x float> zeroinitializer, %17
+  %22 = fcmp ogt <8 x float> zeroinitializer, %18
+  %23 = fcmp ogt <8 x float> zeroinitializer, %19
+  ret void
+}


        


More information about the llvm-commits mailing list