[llvm] 4e3c005 - [TTI] getScalarizationOverhead - use explicit VectorType operand

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue May 5 09:03:53 PDT 2020


Author: Simon Pilgrim
Date: 2020-05-05T16:59:23+01:00
New Revision: 4e3c005554f9fd886e838b0cdc533f43ab819867

URL: https://github.com/llvm/llvm-project/commit/4e3c005554f9fd886e838b0cdc533f43ab819867
DIFF: https://github.com/llvm/llvm-project/commit/4e3c005554f9fd886e838b0cdc533f43ab819867.diff

LOG: [TTI] getScalarizationOverhead - use explicit VectorType operand

getScalarizationOverhead is only ever called with vector types (and we already had a number of cast<VectorType> calls immediately inside the functions), so take an explicit VectorType operand instead.

Follow-up to D78357

Reviewed By: @samparker

Differential Revision: https://reviews.llvm.org/D79341

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7a819f0aa5ad..f3e57567b6bd 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -620,7 +620,7 @@ class TargetTransformInfo {
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) const;
 
   /// Estimate the overhead of scalarizing an instructions unique
@@ -1261,7 +1261,8 @@ class TargetTransformInfo::Concept {
   virtual bool shouldBuildLookupTables() = 0;
   virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
   virtual bool useColdCCForColdCall(Function &F) = 0;
-  virtual unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  virtual unsigned getScalarizationOverhead(VectorType *Ty,
+                                            const APInt &DemandedElts,
                                             bool Insert, bool Extract) = 0;
   virtual unsigned
   getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
@@ -1609,7 +1610,7 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.useColdCCForColdCall(F);
   }
 
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) override {
     return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
   }

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 6171ff9fbf0d..529cdbcb20dd 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -240,7 +240,7 @@ class TargetTransformInfoImplBase {
 
   bool useColdCCForColdCall(Function &F) { return false; }
 
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) {
     return 0;
   }

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index e885b1158d07..140e39d26da7 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -552,32 +552,30 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract) {
-    auto *VTy = cast<VectorType>(Ty);
-    assert(DemandedElts.getBitWidth() == VTy->getNumElements() &&
+    assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
            "Vector size mismatch");
 
     unsigned Cost = 0;
 
-    for (int i = 0, e = VTy->getNumElements(); i < e; ++i) {
+    for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
       if (!DemandedElts[i])
         continue;
       if (Insert)
         Cost += static_cast<T *>(this)->getVectorInstrCost(
-            Instruction::InsertElement, VTy, i);
+            Instruction::InsertElement, Ty, i);
       if (Extract)
         Cost += static_cast<T *>(this)->getVectorInstrCost(
-            Instruction::ExtractElement, VTy, i);
+            Instruction::ExtractElement, Ty, i);
     }
 
     return Cost;
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
-  unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) {
-    auto *VTy = cast<VectorType>(Ty);
-    APInt DemandedElts = APInt::getAllOnesValue(VTy->getNumElements());
+  unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) {
+    APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements());
     return static_cast<T *>(this)->getScalarizationOverhead(Ty, DemandedElts,
                                                             Insert, Extract);
   }
@@ -591,11 +589,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     SmallPtrSet<const Value*, 4> UniqueOperands;
     for (const Value *A : Args) {
       if (!isa<Constant>(A) && UniqueOperands.insert(A).second) {
-        Type *VecTy = nullptr;
-        if (A->getType()->isVectorTy()) {
-          VecTy = A->getType();
+        auto *VecTy = dyn_cast<VectorType>(A->getType());
+        if (VecTy) {
           // If A is a vector operand, VF should be 1 or correspond to A.
-          assert((VF == 1 || VF == cast<VectorType>(VecTy)->getNumElements()) &&
+          assert((VF == 1 || VF == VecTy->getNumElements()) &&
                  "Vector argument does not match VF");
         }
         else
@@ -608,17 +605,16 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return Cost;
   }
 
-  unsigned getScalarizationOverhead(Type *VecTy, ArrayRef<const Value *> Args) {
+  unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef<const Value *> Args) {
     unsigned Cost = 0;
-    auto *VecVTy = cast<VectorType>(VecTy);
 
-    Cost += getScalarizationOverhead(VecVTy, true, false);
+    Cost += getScalarizationOverhead(Ty, true, false);
     if (!Args.empty())
-      Cost += getOperandsScalarizationOverhead(Args, VecVTy->getNumElements());
+      Cost += getOperandsScalarizationOverhead(Args, Ty->getNumElements());
     else
       // When no information on arguments is provided, we add the cost
       // associated with one argument as a heuristic.
-      Cost += getScalarizationOverhead(VecVTy, false, true);
+      Cost += getScalarizationOverhead(Ty, false, true);
 
     return Cost;
   }
@@ -742,13 +738,16 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       break;
     }
 
+    auto *SrcVTy = dyn_cast<VectorType>(Src);
+    auto *DstVTy = dyn_cast<VectorType>(Dst);
+
     // If the cast is marked as legal (or promote) then assume low cost.
     if (SrcLT.first == DstLT.first &&
         TLI->isOperationLegalOrPromote(ISD, DstLT.second))
       return SrcLT.first;
 
     // Handle scalar conversions.
-    if (!Src->isVectorTy() && !Dst->isVectorTy()) {
+    if (!SrcVTy && !DstVTy) {
       // Scalar bitcasts are usually free.
       if (Opcode == Instruction::BitCast)
         return 0;
@@ -763,9 +762,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     }
 
     // Check vector-to-vector casts.
-    if (Dst->isVectorTy() && Src->isVectorTy()) {
-      auto *SrcVTy = cast<VectorType>(Src);
-      auto *DstVTy = cast<VectorType>(Dst);
+    if (DstVTy && SrcVTy) {
       // If the cast is between same-sized registers, then the check is simple.
       if (SrcLT.first == DstLT.first &&
           SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) {
@@ -819,19 +816,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
       // Return the cost of multiple scalar invocation plus the cost of
       // inserting and extracting the values.
-      return getScalarizationOverhead(Dst, true, true) + Num * Cost;
+      return getScalarizationOverhead(DstVTy, true, true) + Num * Cost;
     }
 
     // We already handled vector-to-vector and scalar-to-scalar conversions.
     // This
     // is where we handle bitcast between vectors and scalars. We need to assume
     //  that the conversion is scalarized in one way or another.
-    if (Opcode == Instruction::BitCast)
+    if (Opcode == Instruction::BitCast) {
       // Illegal bitcasts are done by storing and loading from a stack slot.
-      return (Src->isVectorTy() ? getScalarizationOverhead(Src, false, true)
-                                : 0) +
-             (Dst->isVectorTy() ? getScalarizationOverhead(Dst, true, false)
-                                : 0);
+      return (SrcVTy ? getScalarizationOverhead(SrcVTy, false, true) : 0) +
+             (DstVTy ? getScalarizationOverhead(DstVTy, true, false) : 0);
+    }
 
     llvm_unreachable("Unhandled cast");
   }
@@ -923,7 +919,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
         // This is a vector load/store for some illegal type that is scalarized.
         // We must account for the cost of building or decomposing the vector.
-        Cost += getScalarizationOverhead(Src, Opcode != Instruction::Store,
+        Cost += getScalarizationOverhead(cast<VectorType>(Src),
+                                         Opcode != Instruction::Store,
                                          Opcode == Instruction::Store);
       }
     }
@@ -1118,7 +1115,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       if (RetVF > 1 || VF > 1) {
         ScalarizationCost = 0;
         if (!RetTy->isVoidTy())
-          ScalarizationCost += getScalarizationOverhead(RetTy, true, false);
+          ScalarizationCost +=
+              getScalarizationOverhead(cast<VectorType>(RetTy), true, false);
         ScalarizationCost += getOperandsScalarizationOverhead(Args, VF);
       }
 
@@ -1224,21 +1222,19 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       unsigned ScalarizationCost = ScalarizationCostPassed;
       unsigned ScalarCalls = 1;
       Type *ScalarRetTy = RetTy;
-      if (RetTy->isVectorTy()) {
+      if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
         if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
-          ScalarizationCost = getScalarizationOverhead(RetTy, true, false);
-        ScalarCalls =
-            std::max(ScalarCalls, cast<VectorType>(RetTy)->getNumElements());
+          ScalarizationCost = getScalarizationOverhead(RetVTy, true, false);
+        ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements());
         ScalarRetTy = RetTy->getScalarType();
       }
       SmallVector<Type *, 4> ScalarTys;
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
         Type *Ty = Tys[i];
-        if (Ty->isVectorTy()) {
+        if (auto *VTy = dyn_cast<VectorType>(Ty)) {
           if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
-            ScalarizationCost += getScalarizationOverhead(Ty, false, true);
-          ScalarCalls =
-              std::max(ScalarCalls, cast<VectorType>(Ty)->getNumElements());
+            ScalarizationCost += getScalarizationOverhead(VTy, false, true);
+          ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
           Ty = Ty->getScalarType();
         }
         ScalarTys.push_back(Ty);
@@ -1588,12 +1584,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // Else, assume that we need to scalarize this intrinsic. For math builtins
     // this will emit a costly libcall, adding call overhead and spills. Make it
     // very expensive.
-    if (RetTy->isVectorTy()) {
+    if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
       unsigned ScalarizationCost =
           ((ScalarizationCostPassed != std::numeric_limits<unsigned>::max())
                ? ScalarizationCostPassed
-               : getScalarizationOverhead(RetTy, true, false));
-      unsigned ScalarCalls = cast<VectorType>(RetTy)->getNumElements();
+               : getScalarizationOverhead(RetVTy, true, false));
+      unsigned ScalarCalls = RetVTy->getNumElements();
       SmallVector<Type *, 4> ScalarTys;
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
         Type *Ty = Tys[i];
@@ -1604,14 +1600,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       unsigned ScalarCost = ConcreteTTI->getIntrinsicInstrCost(
           IID, RetTy->getScalarType(), ScalarTys, FMF, CostKind);
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
-        if (Tys[i]->isVectorTy()) {
+        if (auto *VTy = dyn_cast<VectorType>(Tys[i])) {
           if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
-            ScalarizationCost += getScalarizationOverhead(Tys[i], false, true);
-          ScalarCalls =
-              std::max(ScalarCalls, cast<VectorType>(Tys[i])->getNumElements());
+            ScalarizationCost += getScalarizationOverhead(VTy, false, true);
+          ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
         }
       }
-
       return ScalarCalls * ScalarCost + ScalarizationCost;
     }
 

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 743160a26966..95b17aa702d0 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -370,8 +370,10 @@ bool TargetTransformInfo::useColdCCForColdCall(Function &F) const {
   return TTIImpl->useColdCCForColdCall(F);
 }
 
-unsigned TargetTransformInfo::getScalarizationOverhead(
-    Type *Ty, const APInt &DemandedElts, bool Insert, bool Extract) const {
+unsigned
+TargetTransformInfo::getScalarizationOverhead(VectorType *Ty,
+                                              const APInt &DemandedElts,
+                                              bool Insert, bool Extract) const {
   return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
 }
 

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 1dac45a029b3..d6e082d64e7a 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -807,7 +807,7 @@ int ARMTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                            CostKind);
     // Return the cost of multiple scalar invocation plus the cost of
     // inserting and extracting the values.
-    return BaseT::getScalarizationOverhead(Ty, Args) + Num * Cost;
+    return BaseT::getScalarizationOverhead(VTy, Args) + Num * Cost;
   }
 
   return BaseCost;
@@ -899,7 +899,7 @@ unsigned ARMTTIImpl::getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
   // The scalarization cost should be a lot higher. We use the number of vector
   // elements plus the scalarization overhead.
   unsigned ScalarCost =
-      NumElems * LT.first + BaseT::getScalarizationOverhead(DataTy, {});
+      NumElems * LT.first + BaseT::getScalarizationOverhead(VTy, {});
 
   if (Alignment < EltSize / 8)
     return ScalarCost;

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index b8571476d66a..99845ae7ca84 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -115,7 +115,7 @@ unsigned HexagonTTIImpl::getMinimumVF(unsigned ElemWidth) const {
   return (8 * ST.getVectorLength()) / ElemWidth;
 }
 
-unsigned HexagonTTIImpl::getScalarizationOverhead(Type *Ty,
+unsigned HexagonTTIImpl::getScalarizationOverhead(VectorType *Ty,
                                                   const APInt &DemandedElts,
                                                   bool Insert, bool Extract) {
   return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 4b0625a67ffd..b2191910a238 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -101,7 +101,7 @@ class HexagonTTIImpl : public BasicTTIImplBase<HexagonTTIImpl> {
     return true;
   }
 
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract);
   unsigned getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                             unsigned VF);

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 4bf03da45397..9ec7b07fc3f8 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -464,7 +464,8 @@ int SystemZTTIImpl::getArithmeticInstrCost(
       return DivInstrCost;
   }
   else if (ST->hasVector()) {
-    unsigned VF = cast<VectorType>(Ty)->getNumElements();
+    auto *VTy = cast<VectorType>(Ty);
+    unsigned VF = VTy->getNumElements();
     unsigned NumVectors = getNumVectorRegs(Ty);
 
     // These vector operations are custom handled, but are still supported
@@ -477,7 +478,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
     if (DivRemConstPow2)
       return (NumVectors * (SignedDivRem ? SDivPow2Cost : 1));
     if (DivRemConst)
-      return VF * DivMulSeqCost + getScalarizationOverhead(Ty, Args);
+      return VF * DivMulSeqCost + getScalarizationOverhead(VTy, Args);
     if ((SignedDivRem || UnsignedDivRem) && VF > 4)
       // Temporary hack: disable high vectorization factors with integer
       // division/remainder, which will get scalarized and handled with
@@ -500,7 +501,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
         // inserting and extracting the values.
         unsigned ScalarCost =
             getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind);
-        unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(Ty, Args);
+        unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(VTy, Args);
         // FIXME: VF 2 for these FP operations are currently just as
         // expensive as for VF 4.
         if (VF == 2)
@@ -517,7 +518,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
 
     // There is no native support for FRem.
     if (Opcode == Instruction::FRem) {
-      unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(Ty, Args);
+      unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(VTy, Args);
       // FIXME: VF 2 for float is currently just as expensive as for VF 4.
       if (VF == 2 && ScalarBits == 32)
         Cost *= 2;
@@ -724,8 +725,9 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     }
   }
   else if (ST->hasVector()) {
-    assert (Dst->isVectorTy());
-    unsigned VF = cast<VectorType>(Src)->getNumElements();
+    auto *SrcVecTy = cast<VectorType>(Src);
+    auto *DstVecTy = cast<VectorType>(Dst);
+    unsigned VF = SrcVecTy->getNumElements();
     unsigned NumDstVectors = getNumVectorRegs(Dst);
     unsigned NumSrcVectors = getNumVectorRegs(Src);
 
@@ -781,8 +783,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
           (Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
         NeedsExtracts = false;
 
-      TotCost += getScalarizationOverhead(Src, false, NeedsExtracts);
-      TotCost += getScalarizationOverhead(Dst, NeedsInserts, false);
+      TotCost += getScalarizationOverhead(SrcVecTy, false, NeedsExtracts);
+      TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts, false);
 
       // FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
       if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -793,7 +795,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
 
     if (Opcode == Instruction::FPTrunc) {
       if (SrcScalarBits == 128)  // fp128 -> double/float + inserts of elements.
-        return VF /*ldxbr/lexbr*/ + getScalarizationOverhead(Dst, true, false);
+        return VF /*ldxbr/lexbr*/ +
+               getScalarizationOverhead(DstVecTy, true, false);
       else // double -> float
         return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
     }
@@ -806,7 +809,7 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
         return VF * 2;
       }
       // -> fp128.  VF * lxdb/lxeb + extraction of elements.
-      return VF + getScalarizationOverhead(Src, false, true);
+      return VF + getScalarizationOverhead(SrcVecTy, false, true);
     }
   }
 

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index f2f34f5f0bd1..98f698826605 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -2888,10 +2888,9 @@ int X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index) {
   return BaseT::getVectorInstrCost(Opcode, Val, Index) + RegisterFileMoveCost;
 }
 
-unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
+unsigned X86TTIImpl::getScalarizationOverhead(VectorType *Ty,
                                               const APInt &DemandedElts,
                                               bool Insert, bool Extract) {
-  auto* VecTy = cast<VectorType>(Ty);
   unsigned Cost = 0;
 
   // For insertions, a ISD::BUILD_VECTOR style vector initialization can be much
@@ -2917,7 +2916,7 @@ unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
         // 128-bit vector is free.
         // NOTE: This assumes legalization widens vXf32 vectors.
         if (MScalarTy == MVT::f32)
-          for (unsigned i = 0, e = VecTy->getNumElements(); i < e; i += 4)
+          for (unsigned i = 0, e = Ty->getNumElements(); i < e; i += 4)
             if (DemandedElts[i])
               Cost--;
       }
@@ -2933,7 +2932,7 @@ unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
       // vector elements, which represents the number of unpacks we'll end up
       // performing.
       unsigned NumElts = LT.second.getVectorNumElements();
-      unsigned Pow2Elts = PowerOf2Ceil(VecTy->getNumElements());
+      unsigned Pow2Elts = PowerOf2Ceil(Ty->getNumElements());
       Cost += (std::min<unsigned>(NumElts, Pow2Elts) - 1) * LT.first;
     }
   }
@@ -2970,7 +2969,7 @@ int X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
       APInt DemandedElts = APInt::getAllOnesValue(NumElem);
       int Cost = BaseT::getMemoryOpCost(Opcode, VTy->getScalarType(), Alignment,
                                         AddressSpace, CostKind);
-      int SplitCost = getScalarizationOverhead(Src, DemandedElts,
+      int SplitCost = getScalarizationOverhead(VTy, DemandedElts,
                                                Opcode == Instruction::Load,
                                                Opcode == Instruction::Store);
       return NumElem * Cost + SplitCost;

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index eabd0f132363..ee9f3a67cd3b 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -135,7 +135,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
                          TTI::TargetCostKind CostKind,
                          const Instruction *I = nullptr);
   int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
-  unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
                                     bool Insert, bool Extract);
   int getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
                       unsigned AddressSpace,

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 612f32ec034b..b139f8520df3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5702,9 +5702,9 @@ int LoopVectorizationCostModel::computePredInstDiscount(
     // Compute the scalarization overhead of needed insertelement instructions
     // and phi nodes.
     if (isScalarWithPredication(I) && !I->getType()->isVoidTy()) {
-      ScalarCost +=
-          TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF),
-                                       APInt::getAllOnesValue(VF), true, false);
+      ScalarCost += TTI.getScalarizationOverhead(
+          cast<VectorType>(ToVectorTy(I->getType(), VF)),
+          APInt::getAllOnesValue(VF), true, false);
       ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI);
     }
 
@@ -5720,8 +5720,8 @@ int LoopVectorizationCostModel::computePredInstDiscount(
           Worklist.push_back(J);
         else if (needsExtract(J, VF))
           ScalarCost += TTI.getScalarizationOverhead(
-              ToVectorTy(J->getType(), VF), APInt::getAllOnesValue(VF), false,
-              true);
+              cast<VectorType>(ToVectorTy(J->getType(), VF)),
+              APInt::getAllOnesValue(VF), false, true);
       }
 
     // Scale the total scalar cost by block probability.
@@ -6016,8 +6016,8 @@ unsigned LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I,
   Type *RetTy = ToVectorTy(I->getType(), VF);
   if (!RetTy->isVoidTy() &&
       (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore()))
-    Cost += TTI.getScalarizationOverhead(RetTy, APInt::getAllOnesValue(VF),
-                                         true, false);
+    Cost += TTI.getScalarizationOverhead(
+        cast<VectorType>(RetTy), APInt::getAllOnesValue(VF), true, false);
 
   // Some targets keep addresses scalar.
   if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing())
@@ -6222,7 +6222,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
 
     if (ScalarPredicatedBB) {
       // Return cost for branches around scalarized and predicated blocks.
-      Type *Vec_i1Ty =
+      VectorType *Vec_i1Ty =
           VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF);
       return (TTI.getScalarizationOverhead(Vec_i1Ty, APInt::getAllOnesValue(VF),
                                            false, true) +


        


More information about the llvm-commits mailing list