[llvm] 37cb9bd - [VectorCombine] Add explicit CostKind to all getArithmeticInstrCost calls. NFC.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 9 05:06:25 PST 2024


Author: Simon Pilgrim
Date: 2024-12-09T13:01:14Z
New Revision: 37cb9bdecac2f291f54866bbb9660525ebe6fb16

URL: https://github.com/llvm/llvm-project/commit/37cb9bdecac2f291f54866bbb9660525ebe6fb16
DIFF: https://github.com/llvm/llvm-project/commit/37cb9bdecac2f291f54866bbb9660525ebe6fb16.diff

LOG: [VectorCombine] Add explicit CostKind to all getArithmeticInstrCost calls. NFC.

We currently hardwire CostKind to TTI::TCK_RecipThroughput, which matches the default CostKind for getArithmeticInstrCost, so passing it explicitly does not change behavior.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index b0dd3d5a4fbaa2..2053cc6ab13ee4 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -421,8 +421,8 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
   // Get cost estimates for scalar and vector versions of the operation.
   bool IsBinOp = Instruction::isBinaryOp(Opcode);
   if (IsBinOp) {
-    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
-    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
+    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
+    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   } else {
     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
            "Expected a compare");
@@ -684,7 +684,7 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
 
   Type *ScalarTy = VecTy->getScalarType();
   InstructionCost OldCost =
-      TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy) +
+      TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) +
       TTI.getVectorInstrCost(I, VecTy, CostKind, Index);
 
   // If the extract has one use, it will be eliminated, so count it in the
@@ -694,7 +694,7 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
     OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);
 
   InstructionCost NewCost =
-      TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy) +
+      TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) +
       TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind);
 
   if (NewCost > OldCost)
@@ -871,8 +871,8 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
     IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
     ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
   } else {
-    ScalarOpCost =
-        TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
+    ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
+                                              VecTy->getScalarType(), CostKind);
   }
 
   // The existing splats may be kept around if other instructions use them.
@@ -995,8 +995,8 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
     VectorOpCost = TTI.getCmpSelInstrCost(
         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
   } else {
-    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
-    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
+    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
+    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
   }
 
   // Get cost estimate for the insert element. This cost will factor into
@@ -1101,7 +1101,7 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
       TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
                              CmpInst::makeCmpResultType(I0->getType()), Pred) *
           2 +
-      TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
+      TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
 
   // The proposed vector pattern is:
   // vcmp = cmp Pred X, VecC
@@ -1115,7 +1115,7 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   ShufMask[CheapIndex] = ExpensiveIndex;
   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
                                 ShufMask, CostKind);
-  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
+  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
   NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
   NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
   NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
@@ -2616,8 +2616,8 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   // Get the costs of the shuffles + binops before and after with the new
   // shuffle masks.
   InstructionCost CostBefore =
-      TTI.getArithmeticInstrCost(Op0->getOpcode(), VT) +
-      TTI.getArithmeticInstrCost(Op1->getOpcode(), VT);
+      TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
+      TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
   CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
                                 InstructionCost(0), AddShuffleCost);
   CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
@@ -2630,8 +2630,8 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
   FixedVectorType *Op1SmallVT =
       FixedVectorType::get(VT->getScalarType(), V2.size());
   InstructionCost CostAfter =
-      TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT) +
-      TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT);
+      TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
+      TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
   CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
                                InstructionCost(0), AddShuffleMaskCost);
   std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});


        


More information about the llvm-commits mailing list