[llvm] [TTI] Plumb CostKind through getPartialReductionCost (PR #144953)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 19 13:47:06 PDT 2025


https://github.com/preames created https://github.com/llvm/llvm-project/pull/144953

Purely for the sake of being idiomatic with other TTI costing routines, no direct motivation beyond that.

>From 4ea64c11865d73655e5a4755006e2322b79b10a5 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 19 Jun 2025 13:34:16 -0700
Subject: [PATCH] [TTI] Plumb CostKind through getPartialReductionCost

Purely for the sake of being idiomatic with other TTI costing routines,
no direct motivation beyond that.
---
 llvm/include/llvm/Analysis/TargetTransformInfo.h             | 2 +-
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h         | 3 ++-
 llvm/lib/Analysis/TargetTransformInfo.cpp                    | 5 +++--
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp       | 6 +++++-
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h         | 3 ++-
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp           | 5 ++---
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h             | 3 ++-
 .../Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp    | 6 +++++-
 llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h | 3 ++-
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp              | 2 +-
 llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp               | 3 ++-
 11 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 8f4ce80ada5ed..864b3bfffa2ab 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1333,7 +1333,7 @@ class TargetTransformInfo {
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, PartialReductionExtendKind OpAExtend,
       PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const;
+      std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const;
 
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a80b4c5179bad..c131307e4fdef 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -657,7 +657,8 @@ class TargetTransformInfoImplBase {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp = std::nullopt) const {
+                          std::optional<unsigned> BinOp,
+                          TTI::TargetCostKind CostKind) const {
     return InstructionCost::getInvalid();
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2d053e55bdfa9..f1c28896777a3 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
 InstructionCost TargetTransformInfo::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, PartialReductionExtendKind OpAExtend,
-    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
+    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+    TTI::TargetCostKind CostKind) const {
   return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
                                           AccumType, VF, OpAExtend, OpBExtend,
-                                          BinOp);
+                                          BinOp, CostKind);
 }
 
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ed051f295752e..39d6bc7a8800d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5396,10 +5396,14 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
     TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   InstructionCost Cost(TTI::TCC_Basic);
 
+  // Partial reduction costs are only modeled for reciprocal throughput.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   // Sub opcodes currently only occur in chained cases.
   // Independent partial reduction subtractions are still costed as an add
   if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0184e748b3d86..7518e9e662967 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -387,7 +387,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+                          std::optional<unsigned> BinOp,
+                          TTI::TargetCostKind CostKind) const override;
 
   bool enableOrderedReductions() const override { return true; }
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 46e30ce4c18a9..26267648e8d4b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -298,7 +298,7 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
     TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
 
   // zve32x is broken for partial_reduce_umla, but let's make sure we
   // don't generate them.
@@ -313,7 +313,6 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
   // Note: Asuming all vqdot* variants are equal cost
-  // TODO: Thread CostKind through this API
   return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
-                                            TTI::TCK_RecipThroughput);
+                                            CostKind);
 }
 
 bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index dd7e9f7709f8e..16b658af8850c 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -112,7 +112,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const override;
+                          std::optional<unsigned> BinOp,
+                          TTI::TargetCostKind CostKind) const override;
 
   bool shouldExpandReduction(const IntrinsicInst *II) const override;
   bool supportsScalableVectors() const override {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 978e08bb89551..b9cc5c328066e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -199,11 +199,15 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
     TTI::PartialReductionExtendKind OpBExtend,
-    std::optional<unsigned> BinOp) const {
+    std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
   if (!VF.isFixed() || !ST->hasSIMD128())
     return Invalid;
 
+  // Partial reduction costs are only modeled for reciprocal throughput.
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return Invalid;
+
   InstructionCost Cost(TTI::TCC_Basic);
 
   // Possible options:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 6b6d060076a80..4e183921ab873 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -87,7 +87,8 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
       ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
       TTI::PartialReductionExtendKind OpBExtend,
-      std::optional<unsigned> BinOp = std::nullopt) const override;
+      std::optional<unsigned> BinOp,
+      TTI::TargetCostKind CostKind) const override;
   TTI::ReductionShuffle
   getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2f4416d2782e8..414047aa2eb0c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
           [&](ElementCount VF) {
             InstructionCost Cost = TTI->getPartialReductionCost(
                 Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
-                VF, OpAExtend, OpBExtend, BinOp->getOpcode());
+                VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
             return Cost.isValid();
           },
           Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f3b5c8cfa9885..029a2b10131f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -338,7 +338,8 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
 
   return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
                                          PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode);
+                                         GetExtendKind(ExtBR), Opcode,
+                                         Ctx.CostKind);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {



More information about the llvm-commits mailing list