[llvm] [TTI] Plumb CostKind through getPartialReductionCost (PR #144953)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 19 13:49:45 PDT 2025


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/144953

>From 5c8e30d69bf426569de1c43cf93559badb4b3120 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 19 Jun 2025 13:34:16 -0700
Subject: [PATCH] [TTI] Plumb CostKind through getPartialReductionCost

This is purely for the sake of consistency with other TTI costing routines,
which already take a CostKind; there is no direct motivation beyond that.
---
 llvm/include/llvm/Analysis/TargetTransformInfo.h      |  4 ++--
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h  | 11 +++++------
 llvm/lib/Analysis/TargetTransformInfo.cpp             |  5 +++--
 .../lib/Target/AArch64/AArch64TargetTransformInfo.cpp |  7 +++++--
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h  | 11 +++++------
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp    |  9 ++++-----
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h      | 11 +++++------
 .../WebAssembly/WebAssemblyTargetTransformInfo.cpp    |  7 +++++--
 .../WebAssembly/WebAssemblyTargetTransformInfo.h      |  4 ++--
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp       |  2 +-
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp        |  6 +++---
 11 files changed, 40 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 8f4ce80ada5ed..60d4f0ddf3500 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1332,8 +1332,8 @@ class TargetTransformInfo {
   LLVM_ABI InstructionCost getPartialReductionCost(
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, PartialReductionExtendKind OpAExtend,
-      PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const;
+      PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const;
 
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a80b4c5179bad..c807b60e2b8d8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -652,12 +652,11 @@ class TargetTransformInfoImplBase {
   virtual bool enableWritePrefetching() const { return false; }
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
-  virtual InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp = std::nullopt) const {
+  virtual InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const {
     return InstructionCost::getInvalid();
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2d053e55bdfa9..f1c28896777a3 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
 InstructionCost TargetTransformInfo::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, PartialReductionExtendKind OpAExtend,
-    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
+    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
                                           AccumType, VF, OpAExtend, OpBExtend,
-                                          BinOp);
+                                          BinOp, CostKind);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ed051f295752e..9d5c984fa4f16 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5395,11 +5395,14 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
 InstructionCost AArch64TTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   InstructionCost Cost(TTI::TCC_Basic);
 
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   // Sub opcodes currently only occur in chained cases.
   // Independent partial reduction subtractions are still costed as an add
   if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0184e748b3d86..470af01be3154 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -382,12 +382,11 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
-  InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
 
   bool enableOrderedReductions() const override { return true; }
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 46e30ce4c18a9..a503a9731cf3b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
 InstructionCost RISCVTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
 
   // zve32x is broken for partial_reduce_umla, but let's make sure we
   // don't generate them.
@@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
   Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4));
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
   // Note: Asuming all vqdot* variants are equal cost
-  // TODO: Thread CostKind through this API
-  return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
-                                            TTI::TCK_RecipThroughput);
+  return LT.first *
+         getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind);
 }
 
 bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index dd7e9f7709f8e..1497b00541bfe 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -107,12 +107,11 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
   TargetTransformInfo::PopcntSupportKind
   getPopcntSupport(unsigned TyWidth) const override;
 
-  InstructionCost
-  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
-                          Type *AccumType, ElementCount VF,
-                          TTI::PartialReductionExtendKind OpAExtend,
-                          TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+      ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
 
   bool shouldExpandReduction(const IntrinsicInst *II) const override;
   bool supportsScalableVectors() const override {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 978e08bb89551..4f159996e4c6c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -198,12 +198,15 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
 InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-    TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   if (!VF.isFixed() || !ST->hasSIMD128())
     return Invalid;
 
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   InstructionCost Cost(TTI::TCC_Basic);
 
   // Possible options:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 6b6d060076a80..d83b8d1f45dbd 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -86,8 +86,8 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
   InstructionCost getPartialReductionCost(
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
-      TTI::PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const override;
+      TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
   TTI::ReductionShuffle
   getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2f4416d2782e8..414047aa2eb0c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
           [&](ElementCount VF) {
             InstructionCost Cost = TTI->getPartialReductionCost(
                 Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
-                VF, OpAExtend, OpBExtend, BinOp->getOpcode());
+                VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
             return Cost.isValid();
           },
           Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f3b5c8cfa9885..22861eb1c7dfc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -336,9 +336,9 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
     return TargetTransformInfo::PR_None;
   };
 
-  return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
-                                         PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode);
+  return Ctx.TTI.getPartialReductionCost(
+      getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
+      GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {



More information about the llvm-commits mailing list