[llvm] 0ad6be1 - [SLPVectorizer, TargetTransformInfo, SystemZ] Improve SLP getGatherCost(). (#112491)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 29 12:19:48 PST 2024


Author: Jonas Paulsson
Date: 2024-11-29T21:19:45+01:00
New Revision: 0ad6be1927f89cef09aa5d0fb244873f687997c9

URL: https://github.com/llvm/llvm-project/commit/0ad6be1927f89cef09aa5d0fb244873f687997c9
DIFF: https://github.com/llvm/llvm-project/commit/0ad6be1927f89cef09aa5d0fb244873f687997c9.diff

LOG: [SLPVectorizer, TargetTransformInfo, SystemZ]  Improve SLP getGatherCost(). (#112491)

As vector element loads are free on SystemZ, this patch improves the cost
computation in getGatherCost() to reflect this.

getScalarizationOverhead() gets an optional parameter, VL, which can hold the
actual Values being inserted so that they in turn can be passed (by
BasicTTIImpl) to getVectorInstrCost().

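SystemZTTIImpl::getVectorInstrCost() will now recognize an inserted LoadInst
and typically return a cost of 0 for it. The exceptions are a load with more
than one use, or one whose only user is a store (where MVC is preferred).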

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/SystemZ/vec-elt-insertion.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 985ca1532e0149..89231e23e388a7 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -909,11 +909,13 @@ class TargetTransformInfo {
 
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
-  /// extracted from vectors.
+  /// extracted from vectors.  The involved values may be passed in VL if
+  /// Insert is true.
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
-                                           TTI::TargetCostKind CostKind) const;
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {}) const;
 
   /// Estimate the overhead of scalarizing an instructions unique
   /// non-constant operands. The (potentially vector) types to use for each of
@@ -2001,10 +2003,10 @@ class TargetTransformInfo::Concept {
                                                   unsigned ScalarOpdIdx) = 0;
   virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
                                                       int ScalarOpdIdx) = 0;
-  virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
-                                                   const APInt &DemandedElts,
-                                                   bool Insert, bool Extract,
-                                                   TargetCostKind CostKind) = 0;
+  virtual InstructionCost
+  getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+                           bool Insert, bool Extract, TargetCostKind CostKind,
+                           ArrayRef<Value *> VL = {}) = 0;
   virtual InstructionCost
   getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                    ArrayRef<Type *> Tys,
@@ -2585,9 +2587,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
-                                           TargetCostKind CostKind) override {
+                                           TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {}) override {
     return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                         CostKind);
+                                         CostKind, VL);
   }
   InstructionCost
   getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 38aba183f6a173..48ebffff8cbfc2 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -404,7 +404,8 @@ class TargetTransformInfoImplBase {
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
-                                           TTI::TargetCostKind CostKind) const {
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {}) const {
     return 0;
   }
 

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 98cbb4886642bf..f46f07122329e7 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -780,7 +780,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   InstructionCost getScalarizationOverhead(VectorType *InTy,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
-                                           TTI::TargetCostKind CostKind) {
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {}) {
     /// FIXME: a bitfield is not a reasonable abstraction for talking about
     /// which elements are needed from a scalable vector
     if (isa<ScalableVectorType>(InTy))
@@ -788,6 +789,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     auto *Ty = cast<FixedVectorType>(InTy);
 
     assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
+           (VL.empty() || VL.size() == Ty->getNumElements()) &&
            "Vector size mismatch");
 
     InstructionCost Cost = 0;
@@ -795,9 +797,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
       if (!DemandedElts[i])
         continue;
-      if (Insert)
+      if (Insert) {
+        Value *InsertedVal = VL.empty() ? nullptr : VL[i];
         Cost += thisT()->getVectorInstrCost(Instruction::InsertElement, Ty,
-                                            CostKind, i, nullptr, nullptr);
+                                            CostKind, i, nullptr, InsertedVal);
+      }
       if (Extract)
         Cost += thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty,
                                             CostKind, i, nullptr, nullptr);

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1fb2b9836de0cc..d4b6c08c5a32b2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -622,9 +622,9 @@ bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
 
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind) const {
+    TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) const {
   return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                           CostKind);
+                                           CostKind, VL);
 }
 
 InstructionCost TargetTransformInfo::getOperandsScalarizationOverhead(

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index d1536a276a9040..919226eb54fa59 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3363,7 +3363,7 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
 
 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind) {
+    TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
   if (isa<ScalableVectorType>(Ty))
     return InstructionCost::getInvalid();
   if (Ty->getElementType()->isFloatingPointTy())

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 201bc831b816b3..83b86e31565e49 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -423,7 +423,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
-                                           TTI::TargetCostKind CostKind);
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {});
 
   /// Return the cost of the scaling factor used in the addressing
   /// mode represented by AM for this target, for a load/store

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index bbded57bb92ab0..57f635ca6f42a8 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -669,7 +669,7 @@ static unsigned isM1OrSmaller(MVT VT) {
 
 InstructionCost RISCVTTIImpl::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
-    TTI::TargetCostKind CostKind) {
+    TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
   if (isa<ScalableVectorType>(Ty))
     return InstructionCost::getInvalid();
 

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 6fd36e90a02ddd..bd90bfed6e2c95 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -149,7 +149,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
-                                           TTI::TargetCostKind CostKind);
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {});
 
   InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                         TTI::TargetCostKind CostKind);

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index a586eedd58b667..83b42f6d1794d5 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -468,6 +468,42 @@ bool SystemZTTIImpl::hasDivRemOp(Type *DataType, bool IsSigned) {
   return (VT.isScalarInteger() && TLI->isTypeLegal(VT));
 }
 
+static bool isFreeEltLoad(Value *Op) {
+  if (isa<LoadInst>(Op) && Op->hasOneUse()) {
+    const Instruction *UserI = cast<Instruction>(*Op->user_begin());
+    return !isa<StoreInst>(UserI); // Prefer MVC
+  }
+  return false;
+}
+
+InstructionCost SystemZTTIImpl::getScalarizationOverhead(
+    VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
+    TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
+  unsigned NumElts = cast<FixedVectorType>(Ty)->getNumElements();
+  InstructionCost Cost = 0;
+
+  if (Insert && Ty->isIntOrIntVectorTy(64)) {
+    // VLVGP will insert two GPRs with one instruction, while VLE will load
+    // an element directly with no extra cost
+    assert((VL.empty() || VL.size() == NumElts) &&
+           "Type does not match the number of values.");
+    InstructionCost CurrVectorCost = 0;
+    for (unsigned Idx = 0; Idx < NumElts; ++Idx) {
+      if (DemandedElts[Idx] && !(VL.size() && isFreeEltLoad(VL[Idx])))
+        ++CurrVectorCost;
+      if (Idx % 2 == 1) {
+        Cost += std::min(InstructionCost(1), CurrVectorCost);
+        CurrVectorCost = 0;
+      }
+    }
+    Insert = false;
+  }
+
+  Cost += BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
+                                          CostKind, VL);
+  return Cost;
+}
+
 // Return the bit size for the scalar type or vector element
 // type. getScalarSizeInBits() returns 0 for a pointer type.
 static unsigned getScalarSizeInBits(Type *Ty) {
@@ -609,7 +645,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
     if (DivRemConst) {
       SmallVector<Type *> Tys(Args.size(), Ty);
       return VF * DivMulSeqCost +
-             getScalarizationOverhead(VTy, Args, Tys, CostKind);
+             BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
     }
     if ((SignedDivRem || UnsignedDivRem) && VF > 4)
       // Temporary hack: disable high vectorization factors with integer
@@ -636,7 +672,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
         SmallVector<Type *> Tys(Args.size(), Ty);
         InstructionCost Cost =
             (VF * ScalarCost) +
-            getScalarizationOverhead(VTy, Args, Tys, CostKind);
+            BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
         // FIXME: VF 2 for these FP operations are currently just as
         // expensive as for VF 4.
         if (VF == 2)
@@ -654,8 +690,9 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
     // There is no native support for FRem.
     if (Opcode == Instruction::FRem) {
       SmallVector<Type *> Tys(Args.size(), Ty);
-      InstructionCost Cost = (VF * LIBCALL_COST) +
-                             getScalarizationOverhead(VTy, Args, Tys, CostKind);
+      InstructionCost Cost =
+          (VF * LIBCALL_COST) +
+          BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
       // FIXME: VF 2 for float is currently just as expensive as for VF 4.
       if (VF == 2 && ScalarBits == 32)
         Cost *= 2;
@@ -975,10 +1012,10 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
           (Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
         NeedsExtracts = false;
 
-      TotCost += getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
-                                          NeedsExtracts, CostKind);
-      TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts,
-                                          /*Extract*/ false, CostKind);
+      TotCost += BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
+                                                 NeedsExtracts, CostKind);
+      TotCost += BaseT::getScalarizationOverhead(DstVecTy, NeedsInserts,
+                                                 /*Extract*/ false, CostKind);
 
       // FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
       if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -990,8 +1027,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     if (Opcode == Instruction::FPTrunc) {
       if (SrcScalarBits == 128)  // fp128 -> double/float + inserts of elements.
         return VF /*ldxbr/lexbr*/ +
-               getScalarizationOverhead(DstVecTy, /*Insert*/ true,
-                                        /*Extract*/ false, CostKind);
+               BaseT::getScalarizationOverhead(DstVecTy, /*Insert*/ true,
+                                               /*Extract*/ false, CostKind);
       else // double -> float
         return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
     }
@@ -1004,8 +1041,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
         return VF * 2;
       }
       // -> fp128.  VF * lxdb/lxeb + extraction of elements.
-      return VF + getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
-                                           /*Extract*/ true, CostKind);
+      return VF + BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
+                                                  /*Extract*/ true, CostKind);
     }
   }
 
@@ -1114,10 +1151,17 @@ InstructionCost SystemZTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
                                                    TTI::TargetCostKind CostKind,
                                                    unsigned Index, Value *Op0,
                                                    Value *Op1) {
-  // vlvgp will insert two grs into a vector register, so only count half the
-  // number of instructions.
-  if (Opcode == Instruction::InsertElement && Val->isIntOrIntVectorTy(64))
-    return ((Index % 2 == 0) ? 1 : 0);
+  if (Opcode == Instruction::InsertElement) {
+    // Vector Element Load.
+    if (Op1 != nullptr && isFreeEltLoad(Op1))
+      return 0;
+
+    // vlvgp will insert two grs into a vector register, so count half the
+    // number of instructions as an estimate when we don't have the full
+    // picture (as in getScalarizationOverhead()).
+    if (Val->isIntOrIntVectorTy(64))
+      return ((Index % 2 == 0) ? 1 : 0);
+  }
 
   if (Opcode == Instruction::ExtractElement) {
     int Cost = ((getScalarSizeInBits(Val) == 1) ? 2 /*+test-under-mask*/ : 1);

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index 8cc71a6c528f82..6795da59bf5b16 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -81,6 +81,11 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
   bool hasDivRemOp(Type *DataType, bool IsSigned);
   bool prefersVectorizedAddressing() { return false; }
   bool LSRWithInstrQueries() { return true; }
+  InstructionCost getScalarizationOverhead(VectorType *Ty,
+                                           const APInt &DemandedElts,
+                                           bool Insert, bool Extract,
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {});
   bool supportsEfficientVectorElementLoadStore() { return true; }
   bool enableInterleavedAccessVectorization() { return true; }
 

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 179e29e40614e7..abe70268108963 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -4854,10 +4854,9 @@ InstructionCost X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
          RegisterFileMoveCost;
 }
 
-InstructionCost
-X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
-                                     bool Insert, bool Extract,
-                                     TTI::TargetCostKind CostKind) {
+InstructionCost X86TTIImpl::getScalarizationOverhead(
+    VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
+    TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
   assert(DemandedElts.getBitWidth() ==
              cast<FixedVectorType>(Ty)->getNumElements() &&
          "Vector size mismatch");

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 36d00cee0d18b5..7786616f89aa6e 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -169,7 +169,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
-                                           TTI::TargetCostKind CostKind);
+                                           TTI::TargetCostKind CostKind,
+                                           ArrayRef<Value *> VL = {});
   InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
                                             int VF,
                                             const APInt &DemandedDstElts,

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7723442bc0fb6e..04755102643364 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3110,9 +3110,8 @@ class BoUpSLP {
       SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries,
       unsigned NumParts, bool ForOrder = false);
 
-  /// \returns the scalarization cost for this list of values. Assuming that
-  /// this subtree gets vectorized, we may need to extract the values from the
-  /// roots. This method calculates the cost of extracting the values.
+  /// \returns the cost of gathering (inserting) the values in \p VL into a
+  /// vector.
   /// \param ForPoisonSrc true if initial vector is poison, false otherwise.
   InstructionCost getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
                                 Type *ScalarTy) const;
@@ -13498,9 +13497,10 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
               TTI::SK_InsertSubvector, VecTy, std::nullopt, CostKind,
               I * ScalarTyNumElements, cast<FixedVectorType>(ScalarTy));
     } else {
-      Cost = TTI->getScalarizationOverhead(VecTy, ~ShuffledElements,
+      Cost = TTI->getScalarizationOverhead(VecTy,
+                                           /*DemandedElts*/ ~ShuffledElements,
                                            /*Insert*/ true,
-                                           /*Extract*/ false, CostKind);
+                                           /*Extract*/ false, CostKind, VL);
     }
   }
   if (DuplicateNonConst)

diff  --git a/llvm/test/Transforms/SLPVectorizer/SystemZ/vec-elt-insertion.ll b/llvm/test/Transforms/SLPVectorizer/SystemZ/vec-elt-insertion.ll
index eb8dd72e0304d9..85b8157c949f1f 100644
--- a/llvm/test/Transforms/SLPVectorizer/SystemZ/vec-elt-insertion.ll
+++ b/llvm/test/Transforms/SLPVectorizer/SystemZ/vec-elt-insertion.ll
@@ -1,29 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt < %s -mtriple=s390x-unknown-linux -mcpu=z16 -S -passes=slp-vectorizer \
 ; RUN:   -pass-remarks-output=%t | FileCheck %s
 ; RUN: cat %t | FileCheck -check-prefix=REMARK %s
 ;
-; NB! This is a pre-commit version (for #112491) with current codegen and remarks.
-;
 ; Test functions that (at least currently) only gets vectorized if the
 ; insertion cost for an element load is counted as free.
 
+declare double @llvm.fmuladd.f64(double, double, double)
+
 ; This function needs the free element load to be recognized in SLP
 ; getGatherCost().
-define void @fun0(ptr nocapture %0, double %1) {
+define void @fun0(ptr %0, double %1) {
 ; CHECK-LABEL: define void @fun0(
-; CHECK:         fmul double
-; CHECK:         call double @llvm.fmuladd.f64(
-; CHECK-NEXT:    call double @llvm.fmuladd.f64(
-; CHECK-NEXT:    call double @llvm.sqrt.f64(
-; CHECK:         fmul double
-; CHECK:         call double @llvm.fmuladd.f64(
-; CHECK-NEXT:    call double @llvm.fmuladd.f64(
-; CHECK-NEXT:    call double @llvm.sqrt.f64(
+; CHECK-SAME: ptr [[TMP0:%.*]], double [[TMP1:%.*]]) #[[ATTR1:[0-9]+]] {
+; CHECK-NEXT:    [[TMP3:%.*]] = load double, ptr [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x double> poison, double [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP4]], double [[TMP3]], i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = fmul <2 x double> [[TMP5]], splat (double 2.000000e+00)
+; CHECK-NEXT:    [[TMP7:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[TMP6]], <2 x double> [[TMP6]], <2 x double> zeroinitializer)
+; CHECK-NEXT:    [[TMP8:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[TMP6]], <2 x double> [[TMP6]], <2 x double> [[TMP7]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call <2 x double> @llvm.sqrt.v2f64(<2 x double> [[TMP8]])
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP9]], i32 0
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <2 x double> [[TMP9]], i32 1
+; CHECK-NEXT:    [[TMP12:%.*]] = fadd double [[TMP10]], [[TMP11]]
+; CHECK-NEXT:    store double [[TMP12]], ptr [[TMP0]], align 8
+; CHECK-NEXT:    ret void
 ;
 ; REMARK-LABEL: Function: fun0
 ; REMARK: Args:
-; REMARK-NEXT: - String:          'List vectorization was possible but not beneficial with cost '
-; REMARK-NEXT: - Cost:            '0'
+; REMARK-NEXT: - String:          'SLP vectorized with cost '
+; REMARK-NEXT: - Cost:            '-1'
 
   %3 = fmul double %1, 2.000000e+00
   %4 = tail call double @llvm.fmuladd.f64(double %3, double %3, double 0.000000e+00)
@@ -43,32 +49,31 @@ define void @fun0(ptr nocapture %0, double %1) {
 ; getVectorInstrCost().
 define void @fun1(double %0) {
 ; CHECK-LABEL: define void @fun1(
-; CHECK:         phi double
-; CHECK-NEXT:    phi double
-; CHECK-NEXT:    phi double
-; CHECK-NEXT:    phi double
-; CHECK-NEXT:    phi double
-; CHECK-NEXT:    phi double
-; CHECK-NEXT:    fsub double
-; CHECK-NEXT:    fsub double
-; CHECK-NEXT:    fmul double
-; CHECK-NEXT:    fmul double
-; CHECK-NEXT:    fsub double
-; CHECK-NEXT:    fsub double
-; CHECK-NEXT:    call double @llvm.fmuladd.f64(
-; CHECK-NEXT:    call double @llvm.fmuladd.f64(
-; CHECK-NEXT:    fsub double
-; CHECK-NEXT:    fsub double
-; CHECK-NEXT:    call double @llvm.fmuladd.f64(
-; CHECK-NEXT:    call double @llvm.fmuladd.f64(
-; CHECK:         fcmp olt double
-; CHECK-NEXT:    fcmp olt double
-; CHECK-NEXT:    or i1
+; CHECK-SAME: double [[TMP0:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> <double 0.000000e+00, double poison>, double [[TMP0]], i32 1
+; CHECK-NEXT:    br label %[[BB3:.*]]
+; CHECK:       [[BB3]]:
+; CHECK-NEXT:    [[TMP4:%.*]] = phi <2 x double> [ <double poison, double undef>, [[TMP1:%.*]] ], [ poison, %[[BB3]] ]
+; CHECK-NEXT:    [[TMP5:%.*]] = phi <2 x double> [ zeroinitializer, [[TMP1]] ], [ poison, %[[BB3]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = phi <2 x double> [ zeroinitializer, [[TMP1]] ], [ [[TMP18:%.*]], %[[BB3]] ]
+; CHECK-NEXT:    [[TMP7:%.*]] = fsub <2 x double> zeroinitializer, [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = fsub <2 x double> zeroinitializer, [[TMP5]]
+; CHECK-NEXT:    [[TMP9:%.*]] = fsub <2 x double> zeroinitializer, [[TMP4]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load double, ptr null, align 8
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul <2 x double> [[TMP7]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[TMP8]], <2 x double> [[TMP8]], <2 x double> [[TMP11]])
+; CHECK-NEXT:    [[TMP13:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[TMP9]], <2 x double> [[TMP9]], <2 x double> [[TMP12]])
+; CHECK-NEXT:    [[TMP14:%.*]] = fcmp olt <2 x double> [[TMP13]], [[TMP2]]
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <2 x i1> [[TMP14]], i32 0
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x i1> [[TMP14]], i32 1
+; CHECK-NEXT:    [[TMP17:%.*]] = or i1 [[TMP15]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18]] = insertelement <2 x double> poison, double [[TMP10]], i32 1
+; CHECK-NEXT:    br label %[[BB3]]
 ;
 ; REMARK-LABEL: Function: fun1
 ; REMARK: Args:
-; REMARK:      - String:          'List vectorization was possible but not beneficial with cost '
-; REMARK-NEXT: - Cost:            '0'
+; REMARK:      - String:          'SLP vectorized with cost '
+; REMARK-NEXT: - Cost:            '-1'
 
   br label %2
 
@@ -98,19 +103,24 @@ define void @fun1(double %0) {
   br label %2
 }
 
-declare double @llvm.fmuladd.f64(double, double, double)
-
 ; This should *not* be vectorized as the insertion into the vector isn't free,
 ; which is recognized in SystemZTTImpl::getScalarizationOverhead().
 define void @fun2(ptr %0, ptr %Dst) {
 ; CHECK-LABEL: define void @fun2(
-; CHECK: insertelement
-; CHECK: store <2 x i64>
+; CHECK-SAME: ptr [[TMP0:%.*]], ptr [[DST:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP2:%.*]] = load i64, ptr [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[TMP2]], 0
+; CHECK-NEXT:    br i1 [[TMP3]], label %[[BB4:.*]], label %[[BB5:.*]]
+; CHECK:       [[BB4]]:
+; CHECK-NEXT:    ret void
+; CHECK:       [[BB5]]:
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[DST]], i64 24
+; CHECK-NEXT:    store i64 [[TMP2]], ptr [[TMP6]], align 8
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[DST]], i64 16
+; CHECK-NEXT:    store i64 0, ptr [[TMP7]], align 8
+; CHECK-NEXT:    br label %[[BB4]]
 ;
-; REMARK-LABEL: Function: fun2
-; REMARK: Args:
-; REMARK-NEXT: - String:          'Stores SLP vectorized with cost '
-; REMARK-NEXT: - Cost:            '-1'
+; REMARK-NOT: Function: fun2
 
   %3 = load i64, ptr %0, align 8
   %4 = icmp eq i64 %3, 0
@@ -126,3 +136,55 @@ define void @fun2(ptr %0, ptr %Dst) {
   store i64 0, ptr %8, align 8
   br label %5
 }
+
+; This should *not* be vectorized as the load is immediately stored, in which
+; case MVC is preferred.
+define void @fun3(ptr %0)  {
+; CHECK-LABEL: define void @fun3(
+; CHECK-SAME: ptr [[TMP0:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP2:%.*]] = load ptr, ptr inttoptr (i64 568 to ptr), align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i64 40
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i64 48
+; CHECK-NEXT:    br label %[[BB5:.*]]
+; CHECK:       [[BB5]]:
+; CHECK-NEXT:    store ptr null, ptr [[TMP3]], align 8, !tbaa [[TBAA0:![0-9]+]]
+; CHECK-NEXT:    [[TMP6:%.*]] = load ptr, ptr inttoptr (i64 64 to ptr), align 8, !tbaa [[TBAA8:![0-9]+]]
+; CHECK-NEXT:    store ptr [[TMP6]], ptr [[TMP4]], align 8
+; CHECK-NEXT:    [[TMP7:%.*]] = tail call i64 [[TMP0]](ptr noundef poison, i64 noundef poison)
+; CHECK-NEXT:    br label %[[BB5]]
+;
+  %2 = load ptr, ptr inttoptr (i64 568 to ptr), align 8
+  %3 = getelementptr inbounds nuw i8, ptr %2, i64 40
+  %4 = getelementptr inbounds nuw i8, ptr %2, i64 48
+  br label %5
+
+5:
+  store ptr null, ptr %3, align 8, !tbaa !1
+  %6 = load ptr, ptr inttoptr (i64 64 to ptr), align 8, !tbaa !9
+  store ptr %6, ptr %4, align 8
+  %7 = tail call i64 %0(ptr noundef poison, i64 noundef poison)
+  br label %5
+}
+
+!1 = !{!2, !7, i64 40}
+!2 = !{!"arc", !3, i64 0, !6, i64 8, !7, i64 16, !7, i64 24, !8, i64 32, !7, i64 40, !7, i64 48, !6, i64 56, !6, i64 64}
+!3 = !{!"int", !4, i64 0}
+!4 = !{!"omnipotent char", !5, i64 0}
+!5 = !{!"Simple C/C++ TBAA"}
+!6 = !{!"long", !4, i64 0}
+!7 = !{!"any pointer", !4, i64 0}
+!8 = !{!"short", !4, i64 0}
+!9 = !{!10, !7, i64 64}
+!10 = !{!"node", !6, i64 0, !3, i64 8, !7, i64 16, !7, i64 24, !7, i64 32, !7, i64 40, !7, i64 48, !7, i64 56, !7, i64 64, !7, i64 72, !6, i64 80, !6, i64 88, !3, i64 96, !3, i64 100}
+;.
+; CHECK: [[TBAA0]] = !{[[META1:![0-9]+]], [[META6:![0-9]+]], i64 40}
+; CHECK: [[META1]] = !{!"arc", [[META2:![0-9]+]], i64 0, [[META5:![0-9]+]], i64 8, [[META6]], i64 16, [[META6]], i64 24, [[META7:![0-9]+]], i64 32, [[META6]], i64 40, [[META6]], i64 48, [[META5]], i64 56, [[META5]], i64 64}
+; CHECK: [[META2]] = !{!"int", [[META3:![0-9]+]], i64 0}
+; CHECK: [[META3]] = !{!"omnipotent char", [[META4:![0-9]+]], i64 0}
+; CHECK: [[META4]] = !{!"Simple C/C++ TBAA"}
+; CHECK: [[META5]] = !{!"long", [[META3]], i64 0}
+; CHECK: [[META6]] = !{!"any pointer", [[META3]], i64 0}
+; CHECK: [[META7]] = !{!"short", [[META3]], i64 0}
+; CHECK: [[TBAA8]] = !{[[META9:![0-9]+]], [[META6]], i64 64}
+; CHECK: [[META9]] = !{!"node", [[META5]], i64 0, [[META2]], i64 8, [[META6]], i64 16, [[META6]], i64 24, [[META6]], i64 32, [[META6]], i64 40, [[META6]], i64 48, [[META6]], i64 56, [[META6]], i64 64, [[META6]], i64 72, [[META5]], i64 80, [[META5]], i64 88, [[META2]], i64 96, [[META2]], i64 100}
+;.


        


More information about the llvm-commits mailing list