[llvm] [TTI] Plumb CostKind through getPartialReductionCost (PR #144953)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 19 13:47:34 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-backend-webassembly

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

This is purely for the sake of being idiomatic with other TTI costing routines; there is no direct motivation beyond that.

---
Full diff: https://github.com/llvm/llvm-project/pull/144953.diff


11 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+1-1) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2-1) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+3-2) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+4-1) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2-1) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+2-2) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+2-1) 
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp (+4-1) 
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h (+1-1) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+1-1) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+1-1) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 8f4ce80ada5ed..864b3bfffa2ab 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1333,7 +1333,7 @@ class TargetTransformInfo {
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, PartialReductionExtendKind OpAExtend,
       PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const;
+      std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const;
 
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a80b4c5179bad..c131307e4fdef 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -657,7 +657,8 @@ class TargetTransformInfoImplBase {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp = std::nullopt) const {
+                          std::optional<unsigned> BinOp,
+                          TTI::TargetCostKind CostKind) const {
     return InstructionCost::getInvalid();
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2d053e55bdfa9..f1c28896777a3 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
 InstructionCost TargetTransformInfo::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, PartialReductionExtendKind OpAExtend,
-    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
+    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
                                           AccumType, VF, OpAExtend, OpBExtend,
-                                          BinOp);
+                                          BinOp, CostKind);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ed051f295752e..39d6bc7a8800d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5396,10 +5396,13 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
     TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   InstructionCost Cost(TTI::TCC_Basic);
 
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   // Sub opcodes currently only occur in chained cases.
   // Independent partial reduction subtractions are still costed as an add
   if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0184e748b3d86..7518e9e662967 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -387,7 +387,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+                          std::optional<unsigned> BinOp,
+                          TTI::TargetCostKind CostKind) const override;
 
   bool enableOrderedReductions() const override { return true; }
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 46e30ce4c18a9..26267648e8d4b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -298,7 +298,7 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
     TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
 
   // zve32x is broken for partial_reduce_umla, but let's make sure we
   // don't generate them.
@@ -313,7 +313,7 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
   // Note: Asuming all vqdot* variants are equal cost
   // TODO: Thread CostKind through this API
   return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
-                                            TTI::TCK_RecipThroughput);
+                                            CostKind);
 }
 
 bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index dd7e9f7709f8e..16b658af8850c 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -112,7 +112,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+                          std::optional<unsigned> BinOp,
+                          TTI::TargetCostKind CostKind) const override;
 
   bool shouldExpandReduction(const IntrinsicInst *II) const override;
   bool supportsScalableVectors() const override {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 978e08bb89551..b9cc5c328066e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -199,11 +199,14 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
     TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   if (!VF.isFixed() || !ST->hasSIMD128())
     return Invalid;
 
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   InstructionCost Cost(TTI::TCC_Basic);
 
   // Possible options:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 6b6d060076a80..4e183921ab873 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -87,7 +87,7 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
       TTI::PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const override;
+      std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const override;
   TTI::ReductionShuffle
   getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2f4416d2782e8..414047aa2eb0c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
           [&](ElementCount VF) {
             InstructionCost Cost = TTI->getPartialReductionCost(
                 Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
-                VF, OpAExtend, OpBExtend, BinOp->getOpcode());
+                VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
             return Cost.isValid();
           },
           Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f3b5c8cfa9885..029a2b10131f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -338,7 +338,7 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
 
   return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
                                          PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode);
+                                         GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/144953


More information about the llvm-commits mailing list