[llvm] fa8bff0 - [CostModel] Unify getArithmeticInstrCost

Sam Parker via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 10 01:09:11 PDT 2020


Author: Sam Parker
Date: 2020-06-10T09:08:45+01:00
New Revision: fa8bff0cd1ad28b78a8910ebb1be077ef010f91f

URL: https://github.com/llvm/llvm-project/commit/fa8bff0cd1ad28b78a8910ebb1be077ef010f91f
DIFF: https://github.com/llvm/llvm-project/commit/fa8bff0cd1ad28b78a8910ebb1be077ef010f91f.diff

LOG: [CostModel] Unify getArithmeticInstrCost

Add the remaining arithmetic opcodes into the generic implementation
of getUserCost and then call this from getInstructionThroughput. Most
of the backends have been modified to return the base implementation
for cost kinds other than RecipThroughput. The outlier here is AMDGPU
which already uses getArithmeticInstrCost for all the cost kinds.
This change means that most of the opcodes can be removed from that
backend's implementation of getUserCost.

Differential Revision: https://reviews.llvm.org/D80992

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
    llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
    llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9408485e95cb..96153be094a6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -376,6 +376,20 @@ class TargetTransformInfoImplBase {
                                   TTI::OperandValueProperties Opd2PropInfo,
                                   ArrayRef<const Value *> Args,
                                   const Instruction *CxtI = nullptr) {
+    // FIXME: A number of transformation tests seem to require these values
+    // which seems a little odd given how arbitrary they are.
+    switch (Opcode) {
+    default:
+      break;
+    case Instruction::FDiv:
+    case Instruction::FRem:
+    case Instruction::SDiv:
+    case Instruction::SRem:
+    case Instruction::UDiv:
+    case Instruction::URem:
+      // FIXME: Unlikely to be true for CodeSize.
+      return TTI::TCC_Expensive;
+    }
     return 1;
   }
 
@@ -830,14 +844,33 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
                                    GEP->getPointerOperand(),
                                    Operands.drop_front());
     }
-    case Instruction::FDiv:
-    case Instruction::FRem:
-    case Instruction::SDiv:
-    case Instruction::SRem:
+    case Instruction::Add:
+    case Instruction::FAdd:
+    case Instruction::Sub:
+    case Instruction::FSub:
+    case Instruction::Mul:
+    case Instruction::FMul:
     case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::FDiv:
     case Instruction::URem:
-      // FIXME: Unlikely to be true for CodeSize.
-      return TTI::TCC_Expensive;
+    case Instruction::SRem:
+    case Instruction::FRem:
+    case Instruction::Shl:
+    case Instruction::LShr:
+    case Instruction::AShr:
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor: {
+      TargetTransformInfo::OperandValueKind Op1VK, Op2VK;
+      TargetTransformInfo::OperandValueProperties Op1VP, Op2VP;
+      Op1VK = TTI::getOperandInfo(U->getOperand(0), Op1VP);
+      Op2VK = TTI::getOperandInfo(U->getOperand(1), Op2VP);
+      SmallVector<const Value *, 2> Operands(U->operand_values());
+      return TargetTTI->getArithmeticInstrCost(Opcode, Ty, CostKind,
+                                               Op1VK, Op2VK,
+                                               Op1VP, Op2VP, Operands, I);
+    }
     case Instruction::IntToPtr:
     case Instruction::PtrToInt:
     case Instruction::SIToFP:

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 361d9813a60e..3880bf6bbf37 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -608,6 +608,13 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     int ISD = TLI->InstructionOpcodeToISD(Opcode);
     assert(ISD && "Invalid opcode");
 
+    // TODO: Handle more cost kinds.
+    if (CostKind != TTI::TCK_RecipThroughput)
+      return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind,
+                                           Opd1Info, Opd2Info,
+                                           Opd1PropInfo, Opd2PropInfo,
+                                           Args, CxtI);
+
     std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
 
     bool IsFloat = Ty->isFPOrFPVectorTy();

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index ed5db3373127..7325597f68b8 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1275,16 +1275,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
   case Instruction::AShr:
   case Instruction::And:
   case Instruction::Or:
-  case Instruction::Xor: {
-    TargetTransformInfo::OperandValueKind Op1VK, Op2VK;
-    TargetTransformInfo::OperandValueProperties Op1VP, Op2VP;
-    Op1VK = getOperandInfo(I->getOperand(0), Op1VP);
-    Op2VK = getOperandInfo(I->getOperand(1), Op2VP);
-    SmallVector<const Value *, 2> Operands(I->operand_values());
-    return getArithmeticInstrCost(I->getOpcode(), I->getType(), CostKind,
-                                  Op1VK, Op2VK,
-                                  Op1VP, Op2VP, Operands, I);
-  }
+  case Instruction::Xor:
+    return getUserCost(I, CostKind);
   case Instruction::FNeg: {
     TargetTransformInfo::OperandValueKind Op1VK, Op2VK;
     TargetTransformInfo::OperandValueProperties Op1VP, Op2VP;

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cc5157fbf9c8..1ec88b37e849 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -500,6 +500,12 @@ int AArch64TTIImpl::getArithmeticInstrCost(
     TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
+  // TODO: Handle more cost kinds.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
+                                         Opd2Info, Opd1PropInfo,
+                                         Opd2PropInfo, Args, CxtI);
+
   // Legalize the type.
   std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
 

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
index 085ba4702b5c..80864107fb18 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -437,8 +437,11 @@ int GCNTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                        const Instruction *CxtI) {
   EVT OrigTy = TLI->getValueType(DL, Ty);
   if (!OrigTy.isSimple()) {
-    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
-                                         Opd2Info,
+    // FIXME: We're having to query the throughput cost so that the basic
+    // implementation tries to generate legalize and scalarization costs. Maybe
+    // we could hoist the scalarization code here?
+    return BaseT::getArithmeticInstrCost(Opcode, Ty, TTI::TCK_RecipThroughput,
+                                         Opd1Info, Opd2Info,
                                          Opd1PropInfo, Opd2PropInfo);
   }
 
@@ -1036,29 +1039,10 @@ GCNTTIImpl::getUserCost(const User *U, ArrayRef<const Value *> Operands,
 
     return getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, 0, nullptr);
   }
-  case Instruction::Add:
-  case Instruction::FAdd:
-  case Instruction::Sub:
-  case Instruction::FSub:
-  case Instruction::Mul:
-  case Instruction::FMul:
-  case Instruction::UDiv:
-  case Instruction::SDiv:
-  case Instruction::FDiv:
-  case Instruction::URem:
-  case Instruction::SRem:
-  case Instruction::FRem:
-  case Instruction::Shl:
-  case Instruction::LShr:
-  case Instruction::AShr:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor:
-  case Instruction::FNeg: {
+  case Instruction::FNeg:
     return getArithmeticInstrCost(I->getOpcode(), I->getType(), CostKind,
                                   TTI::OK_AnyValue, TTI::OK_AnyValue,
                                   TTI::OP_None, TTI::OP_None, Operands, I);
-  }
   default:
     break;
   }

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 6b4899f624d8..fb2c4c619f78 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -759,6 +759,12 @@ int ARMTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                        TTI::OperandValueProperties Opd2PropInfo,
                                        ArrayRef<const Value *> Args,
                                        const Instruction *CxtI) {
+  // TODO: Handle more cost kinds.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
+                                         Op2Info, Opd1PropInfo,
+                                         Opd2PropInfo, Args, CxtI);
+
   int ISDOpcode = TLI->InstructionOpcodeToISD(Opcode);
   std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
 

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index b70386f1565c..cc1aaf919434 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -250,6 +250,12 @@ unsigned HexagonTTIImpl::getArithmeticInstrCost(
     TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
+  // TODO: Handle more cost kinds.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
+                                         Opd2Info, Opd1PropInfo,
+                                         Opd2PropInfo, Args, CxtI);
+
   if (Ty->isVectorTy()) {
     std::pair<int, MVT> LT = TLI.getTypeLegalizationCost(DL, Ty);
     if (LT.second.isFloatingPoint())

diff  --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
index bc8897664b14..5ea5b2e12aeb 100644
--- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
+++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
@@ -740,6 +740,11 @@ int PPCTTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                        ArrayRef<const Value *> Args,
                                        const Instruction *CxtI) {
   assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode");
+  // TODO: Handle more cost kinds.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
+                                         Op2Info, Opd1PropInfo,
+                                         Opd2PropInfo, Args, CxtI);
 
   // Fallback to the default implementation.
   int Cost = BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,

diff  --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index f7a4c76fbbd5..cebc3fdc52c1 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -374,6 +374,12 @@ int SystemZTTIImpl::getArithmeticInstrCost(
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
 
+  // TODO: Handle more cost kinds.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
+                                         Op2Info, Opd1PropInfo,
+                                         Opd2PropInfo, Args, CxtI);
+
   // TODO: return a good value for BB-VECTORIZER that includes the
   // immediate loads, which we do not want to count for the loop
   // vectorizer, since they are hopefully hoisted out of the loop. This

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 964ac06c863a..c9ebc5e1ffbc 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -177,6 +177,11 @@ int X86TTIImpl::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
                                        TTI::OperandValueProperties Opd2PropInfo,
                                        ArrayRef<const Value *> Args,
                                        const Instruction *CxtI) {
+  // TODO: Handle more cost kinds.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
+                                         Op2Info, Opd1PropInfo,
+                                         Opd2PropInfo, Args, CxtI);
   // Legalize the type.
   std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
 


        


More information about the llvm-commits mailing list