[llvm] [TTI] Plumb CostKind through getPartialReductionCost (PR #144953)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 19 13:49:45 PDT 2025
https://github.com/preames updated https://github.com/llvm/llvm-project/pull/144953
>From 5c8e30d69bf426569de1c43cf93559badb4b3120 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 19 Jun 2025 13:34:16 -0700
Subject: [PATCH] [TTI] Plumb CostKind through getPartialReductionCost
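Purely for the sake of being idiomatic with the other TTI costing
routines; there is no direct motivation beyond that.

Behavior is unchanged for callers that request TCK_RecipThroughput.
For any other cost kind:
- The AArch64 and WebAssembly implementations now return an invalid cost.
- The RISC-V implementation now forwards the requested cost kind to
  getRISCVInstructionCost instead of hard-coding TCK_RecipThroughput.

The std::nullopt default for BinOp is also removed from the public
declaration, so callers must now pass both BinOp and CostKind explicitly.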
---
llvm/include/llvm/Analysis/TargetTransformInfo.h | 4 ++--
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 11 +++++------
llvm/lib/Analysis/TargetTransformInfo.cpp | 5 +++--
.../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 7 +++++--
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 11 +++++------
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 9 ++++-----
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h | 11 +++++------
.../WebAssembly/WebAssemblyTargetTransformInfo.cpp | 7 +++++--
.../WebAssembly/WebAssemblyTargetTransformInfo.h | 4 ++--
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 2 +-
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 6 +++---
11 files changed, 40 insertions(+), 37 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 8f4ce80ada5ed..60d4f0ddf3500 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1332,8 +1332,8 @@ class TargetTransformInfo {
LLVM_ABI InstructionCost getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
- PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp = std::nullopt) const;
+ PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const;
/// \return The maximum interleave factor that any transform should try to
/// perform for this target. This number depends on the level of parallelism
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a80b4c5179bad..c807b60e2b8d8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -652,12 +652,11 @@ class TargetTransformInfoImplBase {
virtual bool enableWritePrefetching() const { return false; }
virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
- virtual InstructionCost
- getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
- Type *AccumType, ElementCount VF,
- TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp = std::nullopt) const {
+ virtual InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const {
return InstructionCost::getInvalid();
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 2d053e55bdfa9..f1c28896777a3 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
InstructionCost TargetTransformInfo::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
- PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
+ PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const {
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
AccumType, VF, OpAExtend, OpBExtend,
- BinOp);
+ BinOp, CostKind);
}
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ed051f295752e..9d5c984fa4f16 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5395,11 +5395,14 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
InstructionCost AArch64TTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid();
InstructionCost Cost(TTI::TCC_Basic);
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return Invalid;
+
// Sub opcodes currently only occur in chained cases.
// Independent partial reduction subtractions are still costed as an add
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 0184e748b3d86..470af01be3154 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -382,12 +382,11 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
return BaseT::isLegalNTLoad(DataType, Alignment);
}
- InstructionCost
- getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
- Type *AccumType, ElementCount VF,
- TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const override;
+ InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const override;
bool enableOrderedReductions() const override { return true; }
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 46e30ce4c18a9..a503a9731cf3b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
InstructionCost RISCVTTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const {
// zve32x is broken for partial_reduce_umla, but let's make sure we
// don't generate them.
@@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4));
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
// Note: Asuming all vqdot* variants are equal cost
- // TODO: Thread CostKind through this API
- return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
- TTI::TCK_RecipThroughput);
+ return LT.first *
+ getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind);
}
bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index dd7e9f7709f8e..1497b00541bfe 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -107,12 +107,11 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
TargetTransformInfo::PopcntSupportKind
getPopcntSupport(unsigned TyWidth) const override;
- InstructionCost
- getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
- Type *AccumType, ElementCount VF,
- TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const override;
+ InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const override;
bool shouldExpandReduction(const IntrinsicInst *II) const override;
bool supportsScalableVectors() const override {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 978e08bb89551..4f159996e4c6c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -198,12 +198,15 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid();
if (!VF.isFixed() || !ST->hasSIMD128())
return Invalid;
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return Invalid;
+
InstructionCost Cost(TTI::TCC_Basic);
// Possible options:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 6b6d060076a80..d83b8d1f45dbd 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -86,8 +86,8 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
InstructionCost getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp = std::nullopt) const override;
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind) const override;
TTI::ReductionShuffle
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 2f4416d2782e8..414047aa2eb0c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
[&](ElementCount VF) {
InstructionCost Cost = TTI->getPartialReductionCost(
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
- VF, OpAExtend, OpBExtend, BinOp->getOpcode());
+ VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
return Cost.isValid();
},
Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index f3b5c8cfa9885..22861eb1c7dfc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -336,9 +336,9 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
return TargetTransformInfo::PR_None;
};
- return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
- PhiType, VF, GetExtendKind(ExtAR),
- GetExtendKind(ExtBR), Opcode);
+ return Ctx.TTI.getPartialReductionCost(
+ getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
+ GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
}
void VPPartialReductionRecipe::execute(VPTransformState &State) {
More information about the llvm-commits
mailing list