[llvm] b963701 - [TTI] Plumb CostKind through getPartialReductionCost (#144953)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 19 15:29:59 PDT 2025


Author: Philip Reames
Date: 2025-06-19T15:29:56-07:00
New Revision: b96370131d1572feb9c51442ac8ba1ccb16d7071

URL: https://github.com/llvm/llvm-project/commit/b96370131d1572feb9c51442ac8ba1ccb16d7071
DIFF: https://github.com/llvm/llvm-project/commit/b96370131d1572feb9c51442ac8ba1ccb16d7071.diff

LOG: [TTI] Plumb CostKind through getPartialReductionCost (#144953)

This is purely for the sake of being idiomatic with other TTI costing
routines; there is no direct motivation beyond that.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 9dc4eca82492e..ba47cef274bec 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1332,8 +1332,8 @@ class TargetTransformInfo {
   LLVM_ABI InstructionCost getPartialReductionCost(
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, PartialReductionExtendKind OpAExtend,
-      PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const;
+      PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const;
 
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index d933752183949..640766cf8cd10 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -652,12 +652,11 @@ class TargetTransformInfoImplBase {
   virtual bool enableWritePrefetching() const { return false; }
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
-  virtual InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp = std::nullopt) const {
+  virtual InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const {
     return InstructionCost::getInvalid();
   }
 

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index d9cb11de9c097..8cc7f8a9d2ab2 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
 InstructionCost TargetTransformInfo::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, PartialReductionExtendKind OpAExtend,
-    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
+    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
                                           AccumType, VF, OpAExtend, OpBExtend,
-                                          BinOp);
+                                          BinOp, CostKind);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ed051f295752e..9d5c984fa4f16 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5395,11 +5395,14 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
 InstructionCost AArch64TTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   InstructionCost Cost(TTI::TCC_Basic);
 
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   // Sub opcodes currently only occur in chained cases.
   // Independent partial reduction subtractions are still costed as an add
   if (Opcode != Instruction::Add && Opcode != Instruction::Sub)

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0184e748b3d86..470af01be3154 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -382,12 +382,11 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
-  InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
 
   bool enableOrderedReductions() const override { return true; }
 

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 63c5f17a84872..1b80b0fcaf10a 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
 InstructionCost RISCVTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
 
   // zve32x is broken for partial_reduce_umla, but let's make sure we
   // don't generate them.
@@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
   Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4));
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
   // Note: Asuming all vqdot* variants are equal cost
-  // TODO: Thread CostKind through this API
-  return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
-                                            TTI::TCK_RecipThroughput);
+  return LT.first *
+         getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind);
 }
 
 bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 75d377abb0e7d..83ac71ed9da69 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -100,12 +100,11 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
   TargetTransformInfo::PopcntSupportKind
   getPopcntSupport(unsigned TyWidth) const override;
 
-  InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
 
   bool shouldExpandReduction(const IntrinsicInst *II) const override;
   bool supportsScalableVectors() const override {

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 978e08bb89551..4f159996e4c6c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -198,12 +198,15 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
 InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   if (!VF.isFixed() || !ST->hasSIMD128())
     return Invalid;
 
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   InstructionCost Cost(TTI::TCC_Basic);
 
   // Possible options:

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 6b6d060076a80..d83b8d1f45dbd 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -86,8 +86,8 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
   InstructionCost getPartialReductionCost(
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-      TTI::PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const override;
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
   TTI::ReductionShuffle
   getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
 

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e14f985efd96a..9a2cd94eda58f 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
           [&](ElementCount VF) {
             InstructionCost Cost = TTI->getPartialReductionCost(
                 Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
-                VF, OpAExtend, OpBExtend, BinOp->getOpcode());
+                VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
             return Cost.isValid();
           },
           Range)) {

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f3b5c8cfa9885..22861eb1c7dfc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -336,9 +336,9 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
     return TargetTransformInfo::PR_None;
   };
 
-  return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
-                                         PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode);
+  return Ctx.TTI.getPartialReductionCost(
+      getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
+      GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {


        


More information about the llvm-commits mailing list