[llvm] [SLP]Fix perfect diamond match with extractelements in scalars (PR #132466)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 21 13:18:28 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

<details>
<summary>Changes</summary>

When a perfect diamond match is found, all previous
estimations/vectorizations need to be dropped. This improves both cost
estimation and code emission.
Also, the getScalarizationOverhead cost needs to be adjusted for a non-poison
input vector. Currently, it cannot estimate this case correctly, so
instead add the cost of inserting the first demanded element into the
non-poison vector value, and then the cost of the remaining elements.


---
Full diff: https://github.com/llvm/llvm-project/pull/132466.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+45-14) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll (+9-12) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll (+8-36) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 0201955b8b559..7050549d61d74 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5310,12 +5310,11 @@ getShuffleCost(const TargetTransformInfo &TTI, TTI::ShuffleKind Kind,
 /// This is similar to TargetTransformInfo::getScalarizationOverhead, but if
 /// ScalarTy is a FixedVectorType, a vector will be inserted or extracted
 /// instead of a scalar.
-static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
-                                                Type *ScalarTy, VectorType *Ty,
-                                                const APInt &DemandedElts,
-                                                bool Insert, bool Extract,
-                                                TTI::TargetCostKind CostKind,
-                                                ArrayRef<Value *> VL = {}) {
+static InstructionCost
+getScalarizationOverhead(const TargetTransformInfo &TTI, Type *ScalarTy,
+                         VectorType *Ty, const APInt &DemandedElts, bool Insert,
+                         bool Extract, TTI::TargetCostKind CostKind,
+                         bool ForPoisonSrc = true, ArrayRef<Value *> VL = {}) {
   assert(!isa<ScalableVectorType>(Ty) &&
          "ScalableVectorType is not supported.");
   assert(getNumElements(ScalarTy) * DemandedElts.getBitWidth() ==
@@ -5339,8 +5338,19 @@ static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
     }
     return Cost;
   }
-  return TTI.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                      CostKind, VL);
+  APInt NewDemandedElts = DemandedElts;
+  InstructionCost Cost = 0;
+  if (!ForPoisonSrc && Insert) {
+    // Handle insert into non-poison vector.
+    unsigned LeftMostBit = NewDemandedElts.countr_zero();
+    NewDemandedElts.clearBit(LeftMostBit);
+    Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, CostKind,
+                                   LeftMostBit, Constant::getNullValue(Ty));
+  }
+  return Cost + (NewDemandedElts.isZero()
+                     ? 0
+                     : TTI.getScalarizationOverhead(Ty, NewDemandedElts, Insert,
+                                                    Extract, CostKind, VL));
 }
 
 /// Correctly creates insert_subvector, checking that the index is multiple of
@@ -11684,6 +11694,15 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     // No need to delay the cost estimation during analysis.
     return std::nullopt;
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+    Cost = 0;
+    VectorizedVals.clear();
+    SameNodesEstimated = true;
+  }
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
     if (&E1 == &E2) {
       assert(all_of(Mask,
@@ -14890,15 +14909,18 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
     ShuffledElements.setBit(I);
     ShuffleMask[I] = Res.first->second;
   }
-  if (!DemandedElements.isZero())
-    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
-                                     /*Insert=*/true,
-                                     /*Extract=*/false, CostKind, VL);
-  if (ForPoisonSrc)
+  if (ForPoisonSrc) {
     Cost = getScalarizationOverhead(*TTI, ScalarTy, VecTy,
                                     /*DemandedElts*/ ~ShuffledElements,
                                     /*Insert*/ true,
-                                    /*Extract*/ false, CostKind, VL);
+                                    /*Extract*/ false, CostKind,
+                                    /*ForPoisonSrc=*/true, VL);
+  } else if (!DemandedElements.isZero()) {
+    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
+                                     /*Insert=*/true,
+                                     /*Extract=*/false, CostKind,
+                                     /*ForPoisonSrc=*/false, VL);
+  }
   if (DuplicateNonConst)
     Cost += ::getShuffleCost(*TTI, TargetTransformInfo::SK_PermuteSingleSrc,
                              VecTy, ShuffleMask);
@@ -15556,6 +15578,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
         PoisonValue::get(PointerType::getUnqual(ScalarTy->getContext())),
         MaybeAlign());
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+  }
   /// Adds 2 input vectors (in form of tree entries) and the mask for their
   /// shuffling.
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
@@ -16111,6 +16139,9 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
             Mask[I] = FrontTE->findLaneForValue(V);
           }
         }
+        // Reset the builder(s) to correctly handle perfect diamond matched
+        // nodes.
+        ShuffleBuilder.resetForSameNode();
         ShuffleBuilder.add(*FrontTE, Mask);
         // Full matched entry found, no need to insert subvectors.
         Res = ShuffleBuilder.finalize(E->getCommonMask(), {}, {});
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll b/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
index 75a413ffc1fb1..579239bc659bd 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/buildvector-with-reuses.ll
@@ -10,18 +10,15 @@ define <4 x double> @test(ptr %ia, ptr %ib, ptr %ic, ptr %id, ptr %ie, ptr %x) {
 ; CHECK-NEXT:    [[I4275:%.*]] = load double, ptr [[ID]], align 8
 ; CHECK-NEXT:    [[I4277:%.*]] = load double, ptr [[IE]], align 8
 ; CHECK-NEXT:    [[I4326:%.*]] = load <4 x double>, ptr [[X]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> poison, double [[I4238]], i32 0
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[I4252]], i32 1
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[I4264]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x double> [[TMP6]], double [[I4277]], i32 1
-; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <2 x double> [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[I44281:%.*]] = shufflevector <4 x double> [[TMP9]], <4 x double> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    ret <4 x double> [[I44281]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> <i32 0, i32 poison>
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> poison, <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x double> poison, double [[I4238]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP4]], double [[I4252]], i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x double> [[TMP5]], double [[I4264]], i32 2
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP6]], double [[I4277]], i32 3
+; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <4 x double> [[TMP3]], [[TMP7]]
+; CHECK-NEXT:    ret <4 x double> [[TMP8]]
 ;
   %i4238 = load double, ptr %ia, align 8
   %i4252 = load double, ptr %ib, align 8
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll b/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
index cb4783010965e..32dccd353da17 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reduction-transpose.ll
@@ -49,24 +49,10 @@ define i32 @reduce_and4(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i32> %v3, <
 ;
 ; AVX512-LABEL: @reduce_and4(
 ; AVX512-NEXT:  entry:
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT2:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT4:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT10:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT12:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP1:%.*]] = insertelement <16 x i32> [[TMP0]], i32 [[VECEXT8]], i32 8
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT7]], i32 9
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT10]], i32 10
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT12]], i32 11
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 12
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT]], i32 13
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT2]], i32 14
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT4]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP8]])
+; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP0]], [[TMP1]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;
@@ -144,24 +130,10 @@ define i32 @reduce_and4_transpose(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i
 ; AVX2-NEXT:    ret i32 [[OP_RDX]]
 ;
 ; AVX512-LABEL: @reduce_and4_transpose(
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT15:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT16:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT23:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT24:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT24]], i32 8
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT16]], i32 9
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT8]], i32 10
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 11
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT23]], i32 12
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT15]], i32 13
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT7]], i32 14
-; AVX512-NEXT:    [[TMP9:%.*]] = insertelement <16 x i32> [[TMP8]], i32 [[VECEXT]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP9]])
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP1]], [[TMP2]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;

``````````

</details>


https://github.com/llvm/llvm-project/pull/132466


More information about the llvm-commits mailing list