[llvm] b3e77c6 - [SVE] Remove invalid calls to VectorType::getNumElements from BasicTTIImpl

Christopher Tetreault via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 16 14:16:30 PDT 2020


Author: Christopher Tetreault
Date: 2020-06-16T14:16:15-07:00
New Revision: b3e77c6d55853eea5f5c32ec8a3510c0b0e438e1

URL: https://github.com/llvm/llvm-project/commit/b3e77c6d55853eea5f5c32ec8a3510c0b0e438e1
DIFF: https://github.com/llvm/llvm-project/commit/b3e77c6d55853eea5f5c32ec8a3510c0b0e438e1.diff

LOG: [SVE] Remove invalid calls to VectorType::getNumElements from BasicTTIImpl

Summary:
Most of these operations are reasonable for scalable vectors. Because of
this, we have decided not to change the interface to take FixedVectorType
specifically, even though the current implementations make fixed-width
assumptions. Instead, we cast to FixedVectorType inside each function body,
so the cast itself asserts if it is handed a scalable vector. If a future
change causes one of these asserts to fire, the developer should either
change their code or make the function they are trying to call handle
scalable vectors.

Reviewers: efriedma, samparker, RKSimon, craig.topper, sdesmalen, c-rhodes

Reviewed By: efriedma

Subscribers: tschuett, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D81495

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/BasicTTIImpl.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index d31332b27923..1e36aae36ee0 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -80,7 +80,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   /// Estimate a cost of Broadcast as an extract and sequence of insert
   /// operations.
-  unsigned getBroadcastShuffleOverhead(VectorType *VTy) {
+  unsigned getBroadcastShuffleOverhead(FixedVectorType *VTy) {
     unsigned Cost = 0;
     // Broadcast cost is equal to the cost of extracting the zero'th element
     // plus the cost of inserting it into every element of the result vector.
@@ -96,7 +96,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   /// Estimate a cost of shuffle as a sequence of extract and insert
   /// operations.
-  unsigned getPermuteShuffleOverhead(VectorType *VTy) {
+  unsigned getPermuteShuffleOverhead(FixedVectorType *VTy) {
     unsigned Cost = 0;
     // Shuffle cost is equal to the cost of extracting element from its argument
     // plus the cost of inserting them onto the result vector.
@@ -116,8 +116,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   /// Estimate a cost of subvector extraction as a sequence of extract and
   /// insert operations.
-  unsigned getExtractSubvectorOverhead(VectorType *VTy, int Index,
-                                       VectorType *SubVTy) {
+  unsigned getExtractSubvectorOverhead(FixedVectorType *VTy, int Index,
+                                       FixedVectorType *SubVTy) {
     assert(VTy && SubVTy &&
            "Can only extract subvectors from vectors");
     int NumSubElts = SubVTy->getNumElements();
@@ -139,8 +139,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   /// Estimate a cost of subvector insertion as a sequence of extract and
   /// insert operations.
-  unsigned getInsertSubvectorOverhead(VectorType *VTy, int Index,
-                                      VectorType *SubVTy) {
+  unsigned getInsertSubvectorOverhead(FixedVectorType *VTy, int Index,
+                                      FixedVectorType *SubVTy) {
     assert(VTy && SubVTy &&
            "Can only insert subvectors into vectors");
     int NumSubElts = SubVTy->getNumElements();
@@ -525,8 +525,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
-  unsigned getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
+  unsigned getScalarizationOverhead(VectorType *InTy, const APInt &DemandedElts,
                                     bool Insert, bool Extract) {
+    /// FIXME: a bitfield is not a reasonable abstraction for talking about
+    /// which elements are needed from a scalable vector
+    auto *Ty = cast<FixedVectorType>(InTy);
+
     assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
            "Vector size mismatch");
 
@@ -547,7 +551,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }
 
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
-  unsigned getScalarizationOverhead(VectorType *Ty, bool Insert, bool Extract) {
+  unsigned getScalarizationOverhead(VectorType *InTy, bool Insert,
+                                    bool Extract) {
+    auto *Ty = cast<FixedVectorType>(InTy);
+
     APInt DemandedElts = APInt::getAllOnesValue(Ty->getNumElements());
     return static_cast<T *>(this)->getScalarizationOverhead(Ty, DemandedElts,
                                                             Insert, Extract);
@@ -565,11 +572,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         auto *VecTy = dyn_cast<VectorType>(A->getType());
         if (VecTy) {
           // If A is a vector operand, VF should be 1 or correspond to A.
-          assert((VF == 1 || VF == VecTy->getNumElements()) &&
+          assert((VF == 1 ||
+                  VF == cast<FixedVectorType>(VecTy)->getNumElements()) &&
                  "Vector argument does not match VF");
         }
         else
-          VecTy = VectorType::get(A->getType(), VF);
+          VecTy = FixedVectorType::get(A->getType(), VF);
 
         Cost += getScalarizationOverhead(VecTy, false, true);
       }
@@ -578,7 +586,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return Cost;
   }
 
-  unsigned getScalarizationOverhead(VectorType *Ty, ArrayRef<const Value *> Args) {
+  unsigned getScalarizationOverhead(VectorType *InTy,
+                                    ArrayRef<const Value *> Args) {
+    auto *Ty = cast<FixedVectorType>(InTy);
+
     unsigned Cost = 0;
 
     Cost += getScalarizationOverhead(Ty, true, false);
@@ -638,7 +649,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // TODO: If one of the types get legalized by splitting, handle this
     // similarly to what getCastInstrCost() does.
     if (auto *VTy = dyn_cast<VectorType>(Ty)) {
-      unsigned Num = VTy->getNumElements();
+      unsigned Num = cast<FixedVectorType>(VTy)->getNumElements();
       unsigned Cost = static_cast<T *>(this)->getArithmeticInstrCost(
           Opcode, VTy->getScalarType(), CostKind);
       // Return the cost of multiple scalar invocation plus the cost of
@@ -652,19 +663,22 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, int Index,
                           VectorType *SubTp) {
+
     switch (Kind) {
     case TTI::SK_Broadcast:
-      return getBroadcastShuffleOverhead(Tp);
+      return getBroadcastShuffleOverhead(cast<FixedVectorType>(Tp));
     case TTI::SK_Select:
     case TTI::SK_Reverse:
     case TTI::SK_Transpose:
     case TTI::SK_PermuteSingleSrc:
     case TTI::SK_PermuteTwoSrc:
-      return getPermuteShuffleOverhead(Tp);
+      return getPermuteShuffleOverhead(cast<FixedVectorType>(Tp));
     case TTI::SK_ExtractSubvector:
-      return getExtractSubvectorOverhead(Tp, Index, SubTp);
+      return getExtractSubvectorOverhead(cast<FixedVectorType>(Tp), Index,
+                                         cast<FixedVectorType>(SubTp));
     case TTI::SK_InsertSubvector:
-      return getInsertSubvectorOverhead(Tp, Index, SubTp);
+      return getInsertSubvectorOverhead(cast<FixedVectorType>(Tp), Index,
+                                        cast<FixedVectorType>(SubTp));
     }
     llvm_unreachable("Unknown TTI::ShuffleKind");
   }
@@ -784,12 +798,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       bool SplitDst =
           TLI->getTypeAction(Dst->getContext(), TLI->getValueType(DL, Dst)) ==
           TargetLowering::TypeSplitVector;
-      if ((SplitSrc || SplitDst) && SrcVTy->getNumElements() > 1 &&
-          DstVTy->getNumElements() > 1) {
-        Type *SplitDstTy = VectorType::get(DstVTy->getElementType(),
-                                           DstVTy->getNumElements() / 2);
-        Type *SplitSrcTy = VectorType::get(SrcVTy->getElementType(),
-                                           SrcVTy->getNumElements() / 2);
+      if ((SplitSrc || SplitDst) &&
+          cast<FixedVectorType>(SrcVTy)->getNumElements() > 1 &&
+          cast<FixedVectorType>(DstVTy)->getNumElements() > 1) {
+        Type *SplitDstTy = VectorType::getHalfElementsVectorType(DstVTy);
+        Type *SplitSrcTy = VectorType::getHalfElementsVectorType(SrcVTy);
         T *TTI = static_cast<T *>(this);
         // If both types need to be split then the split is free.
         unsigned SplitCost =
@@ -801,7 +814,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
       // In other cases where the source or destination are illegal, assume
       // the operation will get scalarized.
-      unsigned Num = DstVTy->getNumElements();
+      unsigned Num = cast<FixedVectorType>(DstVTy)->getNumElements();
       unsigned Cost = static_cast<T *>(this)->getCastInstrCost(
           Opcode, Dst->getScalarType(), Src->getScalarType(),
           CostKind, I);
@@ -867,7 +880,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // TODO: If one of the types get legalized by splitting, handle this
     // similarly to what getCastInstrCost() does.
     if (auto *ValVTy = dyn_cast<VectorType>(ValTy)) {
-      unsigned Num = ValVTy->getNumElements();
+      unsigned Num = cast<FixedVectorType>(ValVTy)->getNumElements();
       if (CondTy)
         CondTy = CondTy->getScalarType();
       unsigned Cost = static_cast<T *>(this)->getCmpSelInstrCost(
@@ -935,13 +948,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                       TTI::TargetCostKind CostKind,
                                       bool UseMaskForCond = false,
                                       bool UseMaskForGaps = false) {
-    auto *VT = cast<VectorType>(VecTy);
+    auto *VT = cast<FixedVectorType>(VecTy);
 
     unsigned NumElts = VT->getNumElements();
     assert(Factor > 1 && NumElts % Factor == 0 && "Invalid interleave factor");
 
     unsigned NumSubElts = NumElts / Factor;
-    VectorType *SubVT = VectorType::get(VT->getElementType(), NumSubElts);
+    auto *SubVT = FixedVectorType::get(VT->getElementType(), NumSubElts);
 
     // Firstly, the cost of load/store operation.
     unsigned Cost;
@@ -1050,8 +1063,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       return Cost;
 
     Type *I8Type = Type::getInt8Ty(VT->getContext());
-    VectorType *MaskVT = VectorType::get(I8Type, NumElts);
-    SubVT = VectorType::get(I8Type, NumSubElts);
+    auto *MaskVT = FixedVectorType::get(I8Type, NumElts);
+    SubVT = FixedVectorType::get(I8Type, NumSubElts);
 
     // The Mask shuffling cost is extract all the elements of the Mask
     // and insert each of them Factor times into the wide vector:
@@ -1119,7 +1132,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     Type *RetTy = ICA.getReturnType();
     unsigned VF = ICA.getVectorFactor();
     unsigned RetVF =
-        (RetTy->isVectorTy() ? cast<VectorType>(RetTy)->getNumElements() : 1);
+        (RetTy->isVectorTy() ? cast<FixedVectorType>(RetTy)->getNumElements()
+                             : 1);
     assert((RetVF == 1 || VF == 1) && "VF > 1 and RetVF is a vector type");
     const IntrinsicInst *I = ICA.getInst();
     const SmallVectorImpl<Value *> &Args = ICA.getArgs();
@@ -1132,11 +1146,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       for (Value *Op : Args) {
         Type *OpTy = Op->getType();
         assert(VF == 1 || !OpTy->isVectorTy());
-        Types.push_back(VF == 1 ? OpTy : VectorType::get(OpTy, VF));
+        Types.push_back(VF == 1 ? OpTy : FixedVectorType::get(OpTy, VF));
       }
 
       if (VF > 1 && !RetTy->isVoidTy())
-        RetTy = VectorType::get(RetTy, VF);
+        RetTy = FixedVectorType::get(RetTy, VF);
 
       // Compute the scalarization overhead based on Args for a vector
       // intrinsic. A vectorizer will pass a scalar RetTy and VF > 1, while
@@ -1262,7 +1276,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       if (auto *RetVTy = dyn_cast<VectorType>(RetTy)) {
         if (!SkipScalarizationCost)
           ScalarizationCost = getScalarizationOverhead(RetVTy, true, false);
-        ScalarCalls = std::max(ScalarCalls, RetVTy->getNumElements());
+        ScalarCalls = std::max(ScalarCalls,
+                               cast<FixedVectorType>(RetVTy)->getNumElements());
         ScalarRetTy = RetTy->getScalarType();
       }
       SmallVector<Type *, 4> ScalarTys;
@@ -1271,7 +1286,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         if (auto *VTy = dyn_cast<VectorType>(Ty)) {
           if (!SkipScalarizationCost)
             ScalarizationCost += getScalarizationOverhead(VTy, false, true);
-          ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
+          ScalarCalls = std::max(ScalarCalls,
+                                 cast<FixedVectorType>(VTy)->getNumElements());
           Ty = Ty->getScalarType();
         }
         ScalarTys.push_back(Ty);
@@ -1629,7 +1645,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       unsigned ScalarizationCost = SkipScalarizationCost ?
         ScalarizationCostPassed : getScalarizationOverhead(RetVTy, true, false);
 
-      unsigned ScalarCalls = RetVTy->getNumElements();
+      unsigned ScalarCalls = cast<FixedVectorType>(RetVTy)->getNumElements();
       SmallVector<Type *, 4> ScalarTys;
       for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) {
         Type *Ty = Tys[i];
@@ -1643,7 +1659,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         if (auto *VTy = dyn_cast<VectorType>(Tys[i])) {
           if (!ICA.skipScalarizationCost())
             ScalarizationCost += getScalarizationOverhead(VTy, false, true);
-          ScalarCalls = std::max(ScalarCalls, VTy->getNumElements());
+          ScalarCalls = std::max(ScalarCalls,
+                                 cast<FixedVectorType>(VTy)->getNumElements());
         }
       }
       return ScalarCalls * ScalarCost + ScalarizationCost;
@@ -1718,7 +1735,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                       bool IsPairwise,
                                       TTI::TargetCostKind CostKind) {
     Type *ScalarTy = Ty->getElementType();
-    unsigned NumVecElts = Ty->getNumElements();
+    unsigned NumVecElts = cast<FixedVectorType>(Ty)->getNumElements();
     unsigned NumReduxLevels = Log2_32(NumVecElts);
     unsigned ArithCost = 0;
     unsigned ShuffleCost = 0;
@@ -1730,7 +1747,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         LT.second.isVector() ? LT.second.getVectorNumElements() : 1;
     while (NumVecElts > MVTLen) {
       NumVecElts /= 2;
-      VectorType *SubTy = VectorType::get(ScalarTy, NumVecElts);
+      VectorType *SubTy = FixedVectorType::get(ScalarTy, NumVecElts);
       // Assume the pairwise shuffles add a cost.
       ShuffleCost += (IsPairwise + 1) *
                      ConcreteTTI->getShuffleCost(TTI::SK_ExtractSubvector, Ty,
@@ -1769,7 +1786,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                   TTI::TargetCostKind CostKind) {
     Type *ScalarTy = Ty->getElementType();
     Type *ScalarCondTy = CondTy->getElementType();
-    unsigned NumVecElts = Ty->getNumElements();
+    unsigned NumVecElts = cast<FixedVectorType>(Ty)->getNumElements();
     unsigned NumReduxLevels = Log2_32(NumVecElts);
     unsigned CmpOpcode;
     if (Ty->isFPOrFPVectorTy()) {
@@ -1789,8 +1806,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         LT.second.isVector() ? LT.second.getVectorNumElements() : 1;
     while (NumVecElts > MVTLen) {
       NumVecElts /= 2;
-      VectorType *SubTy = VectorType::get(ScalarTy, NumVecElts);
-      CondTy = VectorType::get(ScalarCondTy, NumVecElts);
+      auto *SubTy = FixedVectorType::get(ScalarTy, NumVecElts);
+      CondTy = FixedVectorType::get(ScalarCondTy, NumVecElts);
 
       // Assume the pairwise shuffles add a cost.
       ShuffleCost += (IsPairwise + 1) *


        


More information about the llvm-commits mailing list