[llvm] 8aaabad - [CostModel] Unify getCastInstrCost

Sam Parker via llvm-commits llvm-commits at lists.llvm.org
Tue May 26 03:31:58 PDT 2020


Author: Sam Parker
Date: 2020-05-26T11:29:57+01:00
New Revision: 8aaabadeced32a1cd959a5b1524b9c927e82bcc0

URL: https://github.com/llvm/llvm-project/commit/8aaabadeced32a1cd959a5b1524b9c927e82bcc0
DIFF: https://github.com/llvm/llvm-project/commit/8aaabadeced32a1cd959a5b1524b9c927e82bcc0.diff

LOG: [CostModel] Unify getCastInstrCost

Add the remaining cast instruction opcodes to the base implementation
of getUserCost and directly return the result. This allows
getInstructionThroughput to return getUserCost for the casts. This
has required changes to PPC and SystemZ because they implement
getUserCost and/or getCastInstrCost with adjustments for vector
operations. Adjustments have also been made in the remaining backends
that implement the method so that they still produce a cost of zero
or one for cost kinds other than throughput.

Differential Revision: https://reviews.llvm.org/D79848

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 60de70dcb16a..bd8d29cb22a1 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -826,18 +826,18 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
       return TTI::TCC_Expensive;
     case Instruction::IntToPtr:
     case Instruction::PtrToInt:
+    case Instruction::SIToFP:
+    case Instruction::UIToFP:
+    case Instruction::FPToUI:
+    case Instruction::FPToSI:
     case Instruction::Trunc:
+    case Instruction::FPTrunc:
     case Instruction::BitCast:
-      if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) ==
-          TTI::TCC_Free)
-        return TTI::TCC_Free;
-      break;
     case Instruction::FPExt:
     case Instruction::SExt:
     case Instruction::ZExt:
-      if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) == TTI::TCC_Free)
-        return TTI::TCC_Free;
-      break;
+    case Instruction::AddrSpaceCast:
+      return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I);
     }
     // By default, just classify everything as 'basic'.
     return TTI::TCC_Basic;

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 86952a5ad659..a14199515faf 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1325,10 +1325,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
   case Instruction::Trunc:
   case Instruction::FPTrunc:
   case Instruction::BitCast:
-  case Instruction::AddrSpaceCast: {
-    Type *SrcTy = I->getOperand(0)->getType();
-    return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, CostKind, I);
-  }
+  case Instruction::AddrSpaceCast:
+    return getUserCost(I, CostKind);
   case Instruction::ExtractElement: {
     const ExtractElementInst *EEI = cast<ExtractElementInst>(I);
     ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 1324945c4d4e..f0961646c31f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -295,11 +295,18 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     }
   }
 
+  // TODO: Allow non-throughput costs that aren't binary.
+  auto AdjustCost = [&CostKind](int Cost) {
+    if (CostKind != TTI::TCK_RecipThroughput)
+      return Cost == 0 ? 0 : 1;
+    return Cost;
+  };
+
   EVT SrcTy = TLI->getValueType(DL, Src);
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
 
   static const TypeConversionCostTblEntry
   ConversionTbl[] = {
@@ -401,9 +408,9 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
                                                  DstTy.getSimpleVT(),
                                                  SrcTy.getSimpleVT()))
-    return Entry->Cost;
+    return AdjustCost(Entry->Cost);
 
-  return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
 }
 
 int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 1ca74bfc3df0..c1af19727ba2 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -173,6 +173,13 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
+  // TODO: Allow non-throughput costs that aren't binary.
+  auto AdjustCost = [&CostKind](int Cost) {
+    if (CostKind != TTI::TCK_RecipThroughput)
+      return Cost == 0 ? 0 : 1;
+    return Cost;
+  };
+
   // Single to/from double precision conversions.
   static const CostTblEntry NEONFltDblTbl[] = {
     // Vector fptrunc/fpext conversions.
@@ -185,14 +192,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                           ISD == ISD::FP_EXTEND)) {
     std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Src);
     if (const auto *Entry = CostTableLookup(NEONFltDblTbl, ISD, LT.second))
-      return LT.first * Entry->Cost;
+      return AdjustCost(LT.first * Entry->Cost);
   }
 
   EVT SrcTy = TLI->getValueType(DL, Src);
   EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
 
   // The extend of a load is free
   if (I && isa<LoadInst>(I->getOperand(0))) {
@@ -212,7 +219,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     };
     if (const auto *Entry = ConvertCostTableLookup(
             LoadConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
 
     static const TypeConversionCostTblEntry MVELoadConversionTbl[] = {
         {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 0},
@@ -226,7 +233,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
       if (const auto *Entry =
               ConvertCostTableLookup(MVELoadConversionTbl, ISD,
                                      DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
-        return Entry->Cost;
+        return AdjustCost(Entry->Cost);
     }
   }
 
@@ -253,7 +260,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     if (auto *Entry = ConvertCostTableLookup(NEONDoubleWidthTbl, UserISD,
                                              DstTy.getSimpleVT(),
                                              SrcTy.getSimpleVT())) {
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
     }
   }
 
@@ -347,7 +354,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     if (const auto *Entry = ConvertCostTableLookup(NEONVectorConversionTbl, ISD,
                                                    DstTy.getSimpleVT(),
                                                    SrcTy.getSimpleVT()))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
   // Scalar float to integer conversions.
@@ -377,7 +384,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     if (const auto *Entry = ConvertCostTableLookup(NEONFloatConversionTbl, ISD,
                                                    DstTy.getSimpleVT(),
                                                    SrcTy.getSimpleVT()))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
   // Scalar integer to float conversions.
@@ -408,7 +415,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     if (const auto *Entry = ConvertCostTableLookup(NEONIntegerConversionTbl,
                                                    ISD, DstTy.getSimpleVT(),
                                                    SrcTy.getSimpleVT()))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
   // MVE extend costs, taken from codegen tests. i8->i16 or i16->i32 is one
@@ -433,7 +440,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     if (const auto *Entry = ConvertCostTableLookup(MVEVectorConversionTbl,
                                                    ISD, DstTy.getSimpleVT(),
                                                    SrcTy.getSimpleVT()))
-      return Entry->Cost * ST->getMVEVectorCostFactor();
+      return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor());
   }
 
   // Scalar integer conversion costs.
@@ -452,13 +459,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     if (const auto *Entry = ConvertCostTableLookup(ARMIntegerConversionTbl, ISD,
                                                    DstTy.getSimpleVT(),
                                                    SrcTy.getSimpleVT()))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
   int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy()
                      ? ST->getMVEVectorCostFactor()
                      : 1;
-  return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  return AdjustCost(
+    BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
 }
 
 int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy,

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 92e32ca99090..381941df2fb4 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -263,7 +263,11 @@ unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy,
 
     std::pair<int, MVT> SrcLT = TLI.getTypeLegalizationCost(DL, SrcTy);
     std::pair<int, MVT> DstLT = TLI.getTypeLegalizationCost(DL, DstTy);
-    return std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN);
+    unsigned Cost = std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN);
+    // TODO: Allow non-throughput costs that aren't binary.
+    if (CostKind != TTI::TCK_RecipThroughput)
+      return Cost == 0 ? 0 : 1;
+    return Cost;
   }
   return 1;
 }

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index 002905febbc8..a41c6b41a991 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -212,7 +212,8 @@ int PPCTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
 unsigned
 PPCTTIImpl::getUserCost(const User *U, ArrayRef<const Value *> Operands,
                         TTI::TargetCostKind CostKind) {
-  if (U->getType()->isVectorTy()) {
+  // We already implement getCastInstrCost and perform the vector adjustment there.
+  if (!isa<CastInst>(U) && U->getType()->isVectorTy()) {
     // Instructions that need to be split should cost more.
     std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, U->getType());
     return LT.first * BaseT::getUserCost(U, Operands, CostKind);
@@ -760,7 +761,11 @@ int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
 
   int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
-  return vectorCostAdjustment(Cost, Opcode, Dst, Src);
+  Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src);
+  // TODO: Allow non-throughput costs that aren't binary.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Cost == 0 ? 0 : 1;
+  return Cost;
 }
 
 int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index d9efb40f0ab6..bce02cc793bf 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -691,6 +691,12 @@ getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst,
 int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                      TTI::TargetCostKind CostKind,
                                      const Instruction *I) {
+  // FIXME: Can the logic below also be used for these cost kinds?
+  if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) {
+    int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+    return BaseCost == 0 ? BaseCost : 1;
+  }
+
   unsigned DstScalarBits = Dst->getScalarSizeInBits();
   unsigned SrcScalarBits = Src->getScalarSizeInBits();
 

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 4170b102f2b3..6bfcadeaf8b6 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1368,6 +1368,13 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
+  // TODO: Allow non-throughput costs that aren't binary.
+  auto AdjustCost = [&CostKind](int Cost) {
+    if (CostKind != TTI::TCK_RecipThroughput)
+      return Cost == 0 ? 0 : 1;
+    return Cost;
+  };
+
   // FIXME: Need a better design of the cost table to handle non-simple types of
   // potential massive combinations (elem_num x src_type x dst_type).
 
@@ -1969,7 +1976,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
   if (ST->hasSSE2() && !ST->hasAVX()) {
     if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
                                                    LTDest.second, LTSrc.second))
-      return LTSrc.first * Entry->Cost;
+      return AdjustCost(LTSrc.first * Entry->Cost);
   }
 
   EVT SrcTy = TLI->getValueType(DL, Src);
@@ -1977,7 +1984,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
 
   // The function getSimpleVT only handles simple value types.
   if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind);
+    return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind));
 
   MVT SimpleSrcTy = SrcTy.getSimpleVT();
   MVT SimpleDstTy = DstTy.getSimpleVT();
@@ -1986,59 +1993,59 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
     if (ST->hasBWI())
       if (const auto *Entry = ConvertCostTableLookup(AVX512BWConversionTbl, ISD,
                                                      SimpleDstTy, SimpleSrcTy))
-        return Entry->Cost;
+        return AdjustCost(Entry->Cost);
 
     if (ST->hasDQI())
       if (const auto *Entry = ConvertCostTableLookup(AVX512DQConversionTbl, ISD,
                                                      SimpleDstTy, SimpleSrcTy))
-        return Entry->Cost;
+        return AdjustCost(Entry->Cost);
 
     if (ST->hasAVX512())
       if (const auto *Entry = ConvertCostTableLookup(AVX512FConversionTbl, ISD,
                                                      SimpleDstTy, SimpleSrcTy))
-        return Entry->Cost;
+        return AdjustCost(Entry->Cost);
   }
 
   if (ST->hasBWI())
     if (const auto *Entry = ConvertCostTableLookup(AVX512BWVLConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
 
   if (ST->hasDQI())
     if (const auto *Entry = ConvertCostTableLookup(AVX512DQVLConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
 
   if (ST->hasAVX512())
     if (const auto *Entry = ConvertCostTableLookup(AVX512VLConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
 
   if (ST->hasAVX2()) {
     if (const auto *Entry = ConvertCostTableLookup(AVX2ConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
   if (ST->hasAVX()) {
     if (const auto *Entry = ConvertCostTableLookup(AVXConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
   if (ST->hasSSE41()) {
     if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
   if (ST->hasSSE2()) {
     if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD,
                                                    SimpleDstTy, SimpleSrcTy))
-      return Entry->Cost;
+      return AdjustCost(Entry->Cost);
   }
 
-  return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I);
+  return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I));
 }
 
 int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,


        


More information about the llvm-commits mailing list