[llvm] 4e3c005 - [TTI] getScalarizationOverhead - use explicit VectorType operand
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Tue May 5 09:03:53 PDT 2020
Author: Simon Pilgrim
Date: 2020-05-05T16:59:23+01:00
New Revision: 4e3c005554f9fd886e838b0cdc533f43ab819867
URL: https://github.com/llvm/llvm-project/commit/4e3c005554f9fd886e838b0cdc533f43ab819867
DIFF: https://github.com/llvm/llvm-project/commit/4e3c005554f9fd886e838b0cdc533f43ab819867.diff
LOG: [TTI] getScalarizationOverhead - use explicit VectorType operand
getScalarizationOverhead is only ever called with vectors (and we already had a load of cast<VectorType> calls immediately inside the functions).
Followup to D78357
Reviewed By: samparker
Differential Revision: https://reviews.llvm.org/D79341
Added:
Modified:
llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/include/llvm/CodeGen/BasicTTIImpl.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
llvm/lib/Target/X86/X86TargetTransformInfo.cpp
llvm/lib/Target/X86/X86TargetTransformInfo.h
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7a819f0aa5ad..f3e57567b6bd 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -620,7 +620,7 @@ class TargetTransformInfo {
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
/// extracted from vectors.
- unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+ unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) const;
/// Estimate the overhead of scalarizing an instructions unique
@@ -1261,7 +1261,8 @@ class TargetTransformInfo::Concept {
virtual bool shouldBuildLookupTables() = 0;
virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0;
virtual bool useColdCCForColdCall(Function &F) = 0;
- virtual unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+ virtual unsigned getScalarizationOverhead(VectorType *Ty,
+ const APInt &DemandedElts,
bool Insert, bool Extract) = 0;
virtual unsigned
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
@@ -1609,7 +1610,7 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.useColdCCForColdCall(F);
}
- unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+ unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) override {
return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
}
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 6171ff9fbf0d..529cdbcb20dd 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -240,7 +240,7 @@ class TargetTransformInfoImplBase {
bool useColdCCForColdCall(Function &F) { return false; }
- unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+ unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) {
return 0;
}
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index e885b1158d07..140e39d26da7 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -552,32 +552,30 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
/// extracted from vectors.
- unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+ unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract) {
- auto *VTy = cast<VectorType>(Ty);
- assert(DemandedElts.getBitWidth() == VTy->getNumElements() &&
+ assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
"Vector size mismatch");
unsigned Cost = 0;
- for (int i = 0, e = VTy->getNumElements(); i < e; ++i) {
+ for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
if (!DemandedElts[i])
continue;
if (Insert)
Cost += static_cast<T *>(this)->getVectorInstrCost(
- Instruction::InsertElement, VTy, i);
+ Instruction::InsertElement, Ty, i);
if (Extract)
Cost += static_cast<T *>(this)->getVectorInstrCost(
- Instruction::ExtractElement, VTy, i);
+ Instruction::ExtractElement, Ty, i);
}
return Cost;
}
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
- unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) {
- auto *VTy = cast<VectorType>(Ty);
- APInt DemandedElts = APInt::getAllOnesValue(VTy->getNumElements());
+ unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) {
+ APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements());
return static_cast<T *>(this)->getScalarizationOverhead(Ty, DemandedElts,
Insert, Extract);
}
@@ -591,11 +589,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
SmallPtrSet<const Value*, 4> UniqueOperands;
for (const Value *A : Args) {
if (!isa<Constant>(A) && UniqueOperands.insert(A).second) {
- Type *VecTy = nullptr;
- if (A->getType()->isVectorTy()) {
- VecTy = A->getType();
+ auto *VecTy = dyn_cast<VectorType>(A->getType());
+ if (VecTy) {
// If A is a vector operand, VF should be 1 or correspond to A.
- assert((VF == 1 || VF == cast<VectorType>(VecTy)->getNumElements()) &&
+ assert((VF == 1 || VF == VecTy->getNumElements()) &&
"Vector argument does not match VF");
}
else
@@ -608,17 +605,16 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return Cost;
}
- unsigned getScalarizationOverhead(Type *VecTy, ArrayRef<const Value *> Args) {
+ unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef<const Value *> Args) {
unsigned Cost = 0;
- auto *VecVTy = cast<VectorType>(VecTy);
- Cost += getScalarizationOverhead(VecVTy, true, false);
+ Cost += getScalarizationOverhead(Ty, true, false);
if (!Args.empty())
- Cost += getOperandsScalarizationOverhead(Args, VecVTy->getNumElements());
+ Cost += getOperandsScalarizationOverhead(Args, Ty->getNumElements());
else
// When no information on arguments is provided, we add the cost
// associated with one argument as a heuristic.
- Cost += getScalarizationOverhead(VecVTy, false, true);
+ Cost += getScalarizationOverhead(Ty, false, true);
return Cost;
}
@@ -742,13 +738,16 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
break;
}
+ auto *SrcVTy = dyn_cast<VectorType>(Src);
+ auto *DstVTy = dyn_cast<VectorType>(Dst);
+
// If the cast is marked as legal (or promote) then assume low cost.
if (SrcLT.first == DstLT.first &&
TLI->isOperationLegalOrPromote(ISD, DstLT.second))
return SrcLT.first;
// Handle scalar conversions.
- if (!Src->isVectorTy() && !Dst->isVectorTy()) {
+ if (!SrcVTy && !DstVTy) {
// Scalar bitcasts are usually free.
if (Opcode == Instruction::BitCast)
return 0;
@@ -763,9 +762,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
}
// Check vector-to-vector casts.
- if (Dst->isVectorTy() && Src->isVectorTy()) {
- auto *SrcVTy = cast<VectorType>(Src);
- auto *DstVTy = cast<VectorType>(Dst);
+ if (DstVTy && SrcVTy) {
// If the cast is between same-sized registers, then the check is simple.
if (SrcLT.first == DstLT.first &&
SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) {
@@ -819,19 +816,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
- return getScalarizationOverhead(Dst, true, true) + Num * Cost;
+ return getScalarizationOverhead(DstVTy, true, true) + Num * Cost;
}
// We already handled vector-to-vector and scalar-to-scalar conversions.
// This
// is where we handle bitcast between vectors and scalars. We need to assume
// that the conversion is scalarized in one way or another.
- if (Opcode == Instruction::BitCast)
+ if (Opcode == Instruction::BitCast) {
// Illegal bitcasts are done by storing and loading from a stack slot.
- return (Src->isVectorTy() ? getScalarizationOverhead(Src, false, true)
- : 0) +
- (Dst->isVectorTy() ? getScalarizationOverhead(Dst, true, false)
- : 0);
+ return (SrcVTy ? getScalarizationOverhead(SrcVTy, false, true) : 0) +
+ (DstVTy ? getScalarizationOverhead(DstVTy, true, false) : 0);
+ }
llvm_unreachable("Unhandled cast");
}
@@ -923,7 +919,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
// This is a vector load/store for some illegal type that is scalarized.
// We must account for the cost of building or decomposing the vector.
- Cost += getScalarizationOverhead(Src, Opcode != Instruction::Store,
+ Cost += getScalarizationOverhead(cast<VectorType>(Src),
+ Opcode != Instruction::Store,
Opcode == Instruction::Store);
}
}
@@ -1118,7 +1115,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
if (RetVF > 1 || VF > 1) {
ScalarizationCost = 0;
if (!RetTy->isVoidTy())
- ScalarizationCost += getScalarizationOverhead(RetTy, true, false);
+ ScalarizationCost +=
+ getScalarizationOverhead(cast<VectorType>(RetTy), true, false);
ScalarizationCost += getOperandsScalarizationOverhead(Args, VF);
}
@@ -1224,21 +1222,19 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
unsigned ScalarizationCost = ScalarizationCostPassed;
unsigned ScalarCalls = 1;
Type *ScalarRetTy = RetTy;
- if (RetTy->isVectorTy()) {
+ if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
- ScalarizationCost = getScalarizationOverhead(RetTy, true, false);
- ScalarCalls =
- std::max(ScalarCalls, cast<VectorType>(RetTy)->getNumElements());
+ ScalarizationCost = getScalarizationOverhead(RetVTy, true, false);
+ ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements());
ScalarRetTy = RetTy->getScalarType();
}
SmallVector<Type *, 4> ScalarTys;
for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
Type *Ty = Tys[i];
- if (Ty->isVectorTy()) {
+ if (auto *VTy = dyn_cast<VectorType>(Ty)) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
- ScalarizationCost += getScalarizationOverhead(Ty, false, true);
- ScalarCalls =
- std::max(ScalarCalls, cast<VectorType>(Ty)->getNumElements());
+ ScalarizationCost += getScalarizationOverhead(VTy, false, true);
+ ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
Ty = Ty->getScalarType();
}
ScalarTys.push_back(Ty);
@@ -1588,12 +1584,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
// Else, assume that we need to scalarize this intrinsic. For math builtins
// this will emit a costly libcall, adding call overhead and spills. Make it
// very expensive.
- if (RetTy->isVectorTy()) {
+ if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
unsigned ScalarizationCost =
((ScalarizationCostPassed != std::numeric_limits<unsigned>::max())
? ScalarizationCostPassed
- : getScalarizationOverhead(RetTy, true, false));
- unsigned ScalarCalls = cast<VectorType>(RetTy)->getNumElements();
+ : getScalarizationOverhead(RetVTy, true, false));
+ unsigned ScalarCalls = RetVTy->getNumElements();
SmallVector<Type *, 4> ScalarTys;
for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
Type *Ty = Tys[i];
@@ -1604,14 +1600,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
unsigned ScalarCost = ConcreteTTI->getIntrinsicInstrCost(
IID, RetTy->getScalarType(), ScalarTys, FMF, CostKind);
for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
- if (Tys[i]->isVectorTy()) {
+ if (auto *VTy = dyn_cast<VectorType>(Tys[i])) {
if (ScalarizationCostPassed == std::numeric_limits<unsigned>::max())
- ScalarizationCost += getScalarizationOverhead(Tys[i], false, true);
- ScalarCalls =
- std::max(ScalarCalls, cast<VectorType>(Tys[i])->getNumElements());
+ ScalarizationCost += getScalarizationOverhead(VTy, false, true);
+ ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
}
}
-
return ScalarCalls * ScalarCost + ScalarizationCost;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 743160a26966..95b17aa702d0 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -370,8 +370,10 @@ bool TargetTransformInfo::useColdCCForColdCall(Function &F) const {
return TTIImpl->useColdCCForColdCall(F);
}
-unsigned TargetTransformInfo::getScalarizationOverhead(
- Type *Ty, const APInt &DemandedElts, bool Insert, bool Extract) const {
+unsigned
+TargetTransformInfo::getScalarizationOverhead(VectorType *Ty,
+ const APInt &DemandedElts,
+ bool Insert, bool Extract) const {
return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 1dac45a029b3..d6e082d64e7a 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -807,7 +807,7 @@ int ARMTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
CostKind);
// Return the cost of multiple scalar invocation plus the cost of
// inserting and extracting the values.
- return BaseT::getScalarizationOverhead(Ty, Args) + Num * Cost;
+ return BaseT::getScalarizationOverhead(VTy, Args) + Num * Cost;
}
return BaseCost;
@@ -899,7 +899,7 @@ unsigned ARMTTIImpl::getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
// The scalarization cost should be a lot higher. We use the number of vector
// elements plus the scalarization overhead.
unsigned ScalarCost =
- NumElems * LT.first + BaseT::getScalarizationOverhead(DataTy, {});
+ NumElems * LT.first + BaseT::getScalarizationOverhead(VTy, {});
if (Alignment < EltSize / 8)
return ScalarCost;
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index b8571476d66a..99845ae7ca84 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -115,7 +115,7 @@ unsigned HexagonTTIImpl::getMinimumVF(unsigned ElemWidth) const {
return (8 * ST.getVectorLength()) / ElemWidth;
}
-unsigned HexagonTTIImpl::getScalarizationOverhead(Type *Ty,
+unsigned HexagonTTIImpl::getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract) {
return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 4b0625a67ffd..b2191910a238 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -101,7 +101,7 @@ class HexagonTTIImpl : public BasicTTIImplBase<HexagonTTIImpl> {
return true;
}
- unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+ unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract);
unsigned getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
unsigned VF);
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 4bf03da45397..9ec7b07fc3f8 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -464,7 +464,8 @@ int SystemZTTIImpl::getArithmeticInstrCost(
return DivInstrCost;
}
else if (ST->hasVector()) {
- unsigned VF = cast<VectorType>(Ty)->getNumElements();
+ auto *VTy = cast<VectorType>(Ty);
+ unsigned VF = VTy->getNumElements();
unsigned NumVectors = getNumVectorRegs(Ty);
// These vector operations are custom handled, but are still supported
@@ -477,7 +478,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
if (DivRemConstPow2)
return (NumVectors * (SignedDivRem ? SDivPow2Cost : 1));
if (DivRemConst)
- return VF * DivMulSeqCost + getScalarizationOverhead(Ty, Args);
+ return VF * DivMulSeqCost + getScalarizationOverhead(VTy, Args);
if ((SignedDivRem || UnsignedDivRem) && VF > 4)
// Temporary hack: disable high vectorization factors with integer
// division/remainder, which will get scalarized and handled with
@@ -500,7 +501,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
// inserting and extracting the values.
unsigned ScalarCost =
getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind);
- unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(Ty, Args);
+ unsigned Cost = (VF * ScalarCost) + getScalarizationOverhead(VTy, Args);
// FIXME: VF 2 for these FP operations are currently just as
// expensive as for VF 4.
if (VF == 2)
@@ -517,7 +518,7 @@ int SystemZTTIImpl::getArithmeticInstrCost(
// There is no native support for FRem.
if (Opcode == Instruction::FRem) {
- unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(Ty, Args);
+ unsigned Cost = (VF * LIBCALL_COST) + getScalarizationOverhead(VTy, Args);
// FIXME: VF 2 for float is currently just as expensive as for VF 4.
if (VF == 2 && ScalarBits == 32)
Cost *= 2;
@@ -724,8 +725,9 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
}
}
else if (ST->hasVector()) {
- assert (Dst->isVectorTy());
- unsigned VF = cast<VectorType>(Src)->getNumElements();
+ auto *SrcVecTy = cast<VectorType>(Src);
+ auto *DstVecTy = cast<VectorType>(Dst);
+ unsigned VF = SrcVecTy->getNumElements();
unsigned NumDstVectors = getNumVectorRegs(Dst);
unsigned NumSrcVectors = getNumVectorRegs(Src);
@@ -781,8 +783,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
(Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
NeedsExtracts = false;
- TotCost += getScalarizationOverhead(Src, false, NeedsExtracts);
- TotCost += getScalarizationOverhead(Dst, NeedsInserts, false);
+ TotCost += getScalarizationOverhead(SrcVecTy, false, NeedsExtracts);
+ TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts, false);
// FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -793,7 +795,8 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
if (Opcode == Instruction::FPTrunc) {
if (SrcScalarBits == 128) // fp128 -> double/float + inserts of elements.
- return VF /*ldxbr/lexbr*/ + getScalarizationOverhead(Dst, true, false);
+ return VF /*ldxbr/lexbr*/ +
+ getScalarizationOverhead(DstVecTy, true, false);
else // double -> float
return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
}
@@ -806,7 +809,7 @@ int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
return VF * 2;
}
// -> fp128. VF * lxdb/lxeb + extraction of elements.
- return VF + getScalarizationOverhead(Src, false, true);
+ return VF + getScalarizationOverhead(SrcVecTy, false, true);
}
}
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index f2f34f5f0bd1..98f698826605 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -2888,10 +2888,9 @@ int X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index) {
return BaseT::getVectorInstrCost(Opcode, Val, Index) + RegisterFileMoveCost;
}
-unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
+unsigned X86TTIImpl::getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract) {
- auto* VecTy = cast<VectorType>(Ty);
unsigned Cost = 0;
// For insertions, a ISD::BUILD_VECTOR style vector initialization can be much
@@ -2917,7 +2916,7 @@ unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
// 128-bit vector is free.
// NOTE: This assumes legalization widens vXf32 vectors.
if (MScalarTy == MVT::f32)
- for (unsigned i = 0, e = VecTy->getNumElements(); i < e; i += 4)
+ for (unsigned i = 0, e = Ty->getNumElements(); i < e; i += 4)
if (DemandedElts[i])
Cost--;
}
@@ -2933,7 +2932,7 @@ unsigned X86TTIImpl::getScalarizationOverhead(Type *Ty,
// vector elements, which represents the number of unpacks we'll end up
// performing.
unsigned NumElts = LT.second.getVectorNumElements();
- unsigned Pow2Elts = PowerOf2Ceil(VecTy->getNumElements());
+ unsigned Pow2Elts = PowerOf2Ceil(Ty->getNumElements());
Cost += (std::min<unsigned>(NumElts, Pow2Elts) - 1) * LT.first;
}
}
@@ -2970,7 +2969,7 @@ int X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
APInt DemandedElts = APInt::getAllOnesValue(NumElem);
int Cost = BaseT::getMemoryOpCost(Opcode, VTy->getScalarType(), Alignment,
AddressSpace, CostKind);
- int SplitCost = getScalarizationOverhead(Src, DemandedElts,
+ int SplitCost = getScalarizationOverhead(VTy, DemandedElts,
Opcode == Instruction::Load,
Opcode == Instruction::Store);
return NumElem * Cost + SplitCost;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index eabd0f132363..ee9f3a67cd3b 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -135,7 +135,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr);
int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
- unsigned getScalarizationOverhead(Type *Ty, const APInt &DemandedElts,
+ unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract);
int getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
unsigned AddressSpace,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 612f32ec034b..b139f8520df3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5702,9 +5702,9 @@ int LoopVectorizationCostModel::computePredInstDiscount(
// Compute the scalarization overhead of needed insertelement instructions
// and phi nodes.
if (isScalarWithPredication(I) && !I->getType()->isVoidTy()) {
- ScalarCost +=
- TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF),
- APInt::getAllOnesValue(VF), true, false);
+ ScalarCost += TTI.getScalarizationOverhead(
+ cast<VectorType>(ToVectorTy(I->getType(), VF)),
+ APInt::getAllOnesValue(VF), true, false);
ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI);
}
@@ -5720,8 +5720,8 @@ int LoopVectorizationCostModel::computePredInstDiscount(
Worklist.push_back(J);
else if (needsExtract(J, VF))
ScalarCost += TTI.getScalarizationOverhead(
- ToVectorTy(J->getType(), VF), APInt::getAllOnesValue(VF), false,
- true);
+ cast<VectorType>(ToVectorTy(J->getType(), VF)),
+ APInt::getAllOnesValue(VF), false, true);
}
// Scale the total scalar cost by block probability.
@@ -6016,8 +6016,8 @@ unsigned LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I,
Type *RetTy = ToVectorTy(I->getType(), VF);
if (!RetTy->isVoidTy() &&
(!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore()))
- Cost += TTI.getScalarizationOverhead(RetTy, APInt::getAllOnesValue(VF),
- true, false);
+ Cost += TTI.getScalarizationOverhead(
+ cast<VectorType>(RetTy), APInt::getAllOnesValue(VF), true, false);
// Some targets keep addresses scalar.
if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing())
@@ -6222,7 +6222,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
if (ScalarPredicatedBB) {
// Return cost for branches around scalarized and predicated blocks.
- Type *Vec_i1Ty =
+ VectorType *Vec_i1Ty =
VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF);
return (TTI.getScalarizationOverhead(Vec_i1Ty, APInt::getAllOnesValue(VF),
false, true) +
More information about the llvm-commits
mailing list