[llvm] [TTI] Plumb CostKind through getPartialReductionCost (PR #144953)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 19 13:47:06 PDT 2025
https://github.com/preames created https://github.com/llvm/llvm-project/pull/144953
This is purely for the sake of being idiomatic with other TTI costing routines; there is no direct motivation beyond that.
From 4ea64c11865d73655e5a4755006e2322b79b10a5 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 19 Jun 2025 13:34:16 -0700
Subject: [PATCH] [TTI] Plumb CostKind through getPartialReductionCost
Purely for the sake of being idiomatic with other TTI costing routines,
no direct motivation beyond that.
---
llvm/include/llvm/Analysis/TargetTransformInfo.h | 2 +-
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 3 ++-
llvm/lib/Analysis/TargetTransformInfo.cpp | 5 +++--
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 5 ++++-
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 3 ++-
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 4 ++--
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h | 3 ++-
.../Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp | 5 ++++-
llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h | 2 +-
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 2 +-
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 2 +-
11 files changed, 23 insertions(+), 13 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 8f4ce80ada5ed..864b3bfffa2ab 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1333,7 +1333,7 @@ class TargetTransformInfo {
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp = std::nullopt) const;
+ std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const;
/// \return The maximum interleave factor that any transform should try to
/// perform for this target. This number depends on the level of parallelism
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a80b4c5179bad..c131307e4fdef 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -657,7 +657,8 @@ class TargetTransformInfoImplBase {
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp = std::nullopt) const {
+ std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const {
return InstructionCost::getInvalid();
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2d053e55bdfa9..f1c28896777a3 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
InstructionCost TargetTransformInfo::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
- PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
+ PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const {
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
AccumType, VF, OpAExtend, OpBExtend,
- BinOp);
+ BinOp, CostKind);
}
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ed051f295752e..39d6bc7a8800d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5396,10 +5396,13 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
+ std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid();
InstructionCost Cost(TTI::TCC_Basic);
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return Invalid;
+
// Sub opcodes currently only occur in chained cases.
// Independent partial reduction subtractions are still costed as an add
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0184e748b3d86..7518e9e662967 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -387,7 +387,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const override;
+ std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const override;
bool enableOrderedReductions() const override { return true; }
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 46e30ce4c18a9..26267648e8d4b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -298,7 +298,7 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
+ std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
// zve32x is broken for partial_reduce_umla, but let's make sure we
// don't generate them.
@@ -313,7 +313,7 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
// Note: Asuming all vqdot* variants are equal cost
// TODO: Thread CostKind through this API
return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
- TTI::TCK_RecipThroughput);
+ CostKind);
}
bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index dd7e9f7709f8e..16b658af8850c 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -112,7 +112,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const override;
+ std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const override;
bool shouldExpandReduction(const IntrinsicInst *II) const override;
bool supportsScalableVectors() const override {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 978e08bb89551..b9cc5c328066e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -199,11 +199,14 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
+ std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid();
if (!VF.isFixed() || !ST->hasSIMD128())
return Invalid;
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return Invalid;
+
InstructionCost Cost(TTI::TCC_Basic);
// Possible options:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 6b6d060076a80..4e183921ab873 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -87,7 +87,7 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp = std::nullopt) const override;
+ std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const override;
TTI::ReductionShuffle
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2f4416d2782e8..414047aa2eb0c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
[&](ElementCount VF) {
InstructionCost Cost = TTI->getPartialReductionCost(
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
- VF, OpAExtend, OpBExtend, BinOp->getOpcode());
+ VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
return Cost.isValid();
},
Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f3b5c8cfa9885..029a2b10131f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -338,7 +338,7 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
PhiType, VF, GetExtendKind(ExtAR),
- GetExtendKind(ExtBR), Opcode);
+ GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
}
void VPPartialReductionRecipe::execute(VPTransformState &State) {
More information about the llvm-commits mailing list