[llvm] [SLP] Support vectorization of previously vectorized scalars in split nodes (PR #134286)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 3 10:52:27 PDT 2025


https://github.com/alexey-bataev created https://github.com/llvm/llvm-project/pull/134286

This patch removes the restriction on revectorizing previously vectorized
scalars in split nodes, and moves the cost-profitability check earlier to
avoid regressions.


>From 37876801e79fbfdb1c47089015af0caf0a6ab779 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 3 Apr 2025 17:52:18 +0000
Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20initia?=
 =?UTF-8?q?l=20version?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 39 ++++++++-----------
 .../X86/splat-score-adjustment.ll             | 20 ++++++----
 .../X86/split-node-no-reorder-copy.ll         |  5 +--
 3 files changed, 30 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index a115fec47aeec..8a12962ccf5d5 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9213,17 +9213,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
         TTI.preferAlternateOpcodeVectorization() || !SplitAlternateInstructions)
       return false;
 
-    // Any value is used in split node already - just gather.
-    if (any_of(VL, [&](Value *V) {
-          return ScalarsInSplitNodes.contains(V) || isVectorized(V);
-        })) {
-      if (TryToFindDuplicates(S)) {
-        auto Invalid = ScheduleBundle::invalid();
-        newTreeEntry(VL, Invalid /*not vectorized*/, S, UserTreeIdx,
-                     ReuseShuffleIndices);
-      }
-      return true;
-    }
     SmallVector<Value *> Op1, Op2;
     OrdersType ReorderIndices(VL.size(), VL.size());
     SmallBitVector Op1Indices(VL.size());
@@ -9282,6 +9271,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     // as alternate ops.
     if (NumParts >= VL.size())
       return false;
+    constexpr TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
+    InstructionCost InsertCost = ::getShuffleCost(
+        TTI, TTI::SK_InsertSubvector, VecTy, {}, Kind, Op1.size(), Op2VecTy);
+    FixedVectorType *SubVecTy =
+        getWidenedType(ScalarTy, std::max(Op1.size(), Op2.size()));
+    InstructionCost NewShuffleCost =
+        ::getShuffleCost(TTI, TTI::SK_PermuteTwoSrc, SubVecTy, Mask, Kind);
+    if (LocalState.getOpcode() != Instruction::ICmp &&
+        LocalState.getOpcode() != Instruction::FCmp &&
+        LocalState.getAltOpcode() != Instruction::ICmp &&
+        LocalState.getAltOpcode() != Instruction::FCmp && NumParts <= 1 &&
+        (Mask.empty() || InsertCost >= NewShuffleCost))
+      return false;
     if ((LocalState.getMainOp()->isBinaryOp() &&
          LocalState.getAltOp()->isBinaryOp() &&
          (LocalState.isShiftOp() || LocalState.isBitwiseLogicOp() ||
@@ -9289,15 +9291,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
         (LocalState.getMainOp()->isCast() && LocalState.getAltOp()->isCast()) ||
         (LocalState.getMainOp()->isUnaryOp() &&
          LocalState.getAltOp()->isUnaryOp())) {
-      constexpr TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
-      InstructionCost InsertCost = ::getShuffleCost(
-          TTI, TTI::SK_InsertSubvector, VecTy, {}, Kind, Op1.size(), Op2VecTy);
-      FixedVectorType *SubVecTy =
-          getWidenedType(ScalarTy, std::max(Op1.size(), Op2.size()));
-      InstructionCost NewShuffleCost =
-          ::getShuffleCost(TTI, TTI::SK_PermuteTwoSrc, SubVecTy, Mask, Kind);
-      if (NumParts <= 1 && (Mask.empty() || InsertCost >= NewShuffleCost))
-        return false;
       InstructionCost OriginalVecOpsCost =
           TTI.getArithmeticInstrCost(Opcode0, VecTy, Kind) +
           TTI.getArithmeticInstrCost(Opcode1, VecTy, Kind);
@@ -9500,9 +9493,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
           ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc, VecTy, {}, Kind) +
           ::getScalarizationOverhead(*TTI, ScalarTy, VecTy, Extracted,
                                      /*Insert=*/false, /*Extract=*/true, Kind);
-      InstructionCost ScalarizeCostEstimate =
-          ::getScalarizationOverhead(*TTI, ScalarTy, VecTy, Vectorized,
-                                     /*Insert=*/true, /*Extract=*/false, Kind);
+      InstructionCost ScalarizeCostEstimate = ::getScalarizationOverhead(
+          *TTI, ScalarTy, VecTy, Vectorized,
+          /*Insert=*/true, /*Extract=*/false, Kind, /*ForPoisonSrc=*/false);
       PreferScalarize = VectorizeCostEstimate > ScalarizeCostEstimate;
     }
     if (PreferScalarize) {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/splat-score-adjustment.ll b/llvm/test/Transforms/SLPVectorizer/X86/splat-score-adjustment.ll
index 38e9ba7ce7028..1c4f51700d083 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/splat-score-adjustment.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/splat-score-adjustment.ll
@@ -7,18 +7,22 @@ define i32 @a() {
 ; CHECK-NEXT:    br label %[[BB1:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    [[TMP2:%.*]] = phi <4 x i8> [ zeroinitializer, [[TMP0:%.*]] ], [ [[TMP6:%.*]], %[[BB1]] ]
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP3:%.*]] = phi <2 x i8> [ zeroinitializer, [[TMP0]] ], [ [[TMP17:%.*]], %[[BB1]] ]
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <8 x i32> <i32 0, i32 0, i32 1, i32 1, i32 2, i32 2, i32 3, i32 3>
 ; CHECK-NEXT:    [[TMP6]] = load <4 x i8>, ptr null, align 4
-; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <4 x i32> <i32 2, i32 3, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i8> [[TMP7]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <4 x i8> [[TMP6]], <4 x i8> poison, <4 x i32> <i32 poison, i32 poison, i32 0, i32 1>
+; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i8> [[TMP12]], <4 x i8> [[TMP7]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
 ; CHECK-NEXT:    [[TMP9:%.*]] = xor <4 x i8> [[TMP6]], [[TMP8]]
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP18:%.*]] = call <8 x i8> @llvm.vector.insert.v8i8.v4i8(<8 x i8> [[TMP10]], <4 x i8> [[TMP6]], i64 4)
-; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <8 x i8> [[TMP5]], <8 x i8> [[TMP18]], <8 x i32> <i32 1, i32 2, i32 3, i32 12, i32 3, i32 12, i32 13, i32 14>
+; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <4 x i8> [[TMP6]], <4 x i8> poison, <8 x i32> <i32 poison, i32 0, i32 poison, i32 1, i32 poison, i32 2, i32 poison, i32 3>
+; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> poison, <8 x i32> <i32 0, i32 poison, i32 1, i32 poison, i32 2, i32 poison, i32 3, i32 poison>
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <8 x i8> [[TMP10]], <8 x i8> [[TMP11]], <8 x i32> <i32 8, i32 1, i32 10, i32 3, i32 12, i32 5, i32 14, i32 7>
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <8 x i8> [[TMP13]], <8 x i8> [[TMP18]], <8 x i32> <i32 1, i32 3, i32 2, i32 9, i32 3, i32 11, i32 9, i32 13>
 ; CHECK-NEXT:    [[TMP22:%.*]] = xor <8 x i8> [[TMP18]], [[TMP21]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = xor <8 x i8> [[TMP22]], [[TMP5]]
-; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <8 x i8> [[TMP23]], <8 x i8> poison, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
-; CHECK-NEXT:    store <8 x i8> [[TMP13]], ptr null, align 4
+; CHECK-NEXT:    store <8 x i8> [[TMP23]], ptr null, align 4
+; CHECK-NEXT:    [[TMP17]] = shufflevector <4 x i8> [[TMP6]], <4 x i8> poison, <2 x i32> <i32 2, i32 3>
 ; CHECK-NEXT:    br label %[[BB1]]
 ;
   br label %1
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/split-node-no-reorder-copy.ll b/llvm/test/Transforms/SLPVectorizer/X86/split-node-no-reorder-copy.ll
index e9884b24e1078..b7b6c10137b64 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/split-node-no-reorder-copy.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/split-node-no-reorder-copy.ll
@@ -16,9 +16,8 @@ define i1 @test(ptr %0, ptr %1, <2 x float> %2, <2 x float> %3, <2 x float> %4)
 ; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <8 x float> [[TMP14]], float [[TMP9]], i32 7
 ; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <8 x float> [[TMP13]], <8 x float> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP17:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP16]], <8 x float> [[TMP15]], i64 8)
-; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <8 x float> [[TMP14]], <8 x float> [[TMP12]], <16 x i32> <i32 8, i32 9, i32 9, i32 9, i32 9, i32 9, i32 14, i32 14, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 poison>
-; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <16 x float> [[TMP18]], float [[TMP9]], i32 15
-; CHECK-NEXT:    [[TMP20:%.*]] = fmul <16 x float> [[TMP17]], [[TMP19]]
+; CHECK-NEXT:    [[TMP18:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP16]], <8 x float> [[TMP15]], i64 8)
+; CHECK-NEXT:    [[TMP20:%.*]] = fmul <16 x float> [[TMP18]], [[TMP17]]
 ; CHECK-NEXT:    [[TMP21:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v16f32(float 0.000000e+00, <16 x float> [[TMP20]])
 ; CHECK-NEXT:    [[TMP22:%.*]] = call float @foo(float [[TMP21]])
 ; CHECK-NEXT:    ret i1 false



More information about the llvm-commits mailing list