[llvm] [CostModel] Move default expand cost for partial reductions to BasicTTIImpl (PR #189905)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 1 01:03:22 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Sander de Smalen (sdesmalen-arm)
<details>
<summary>Changes</summary>
This is a follow-up to the suggestion left here:
https://github.com/llvm/llvm-project/pull/181707#discussion_r2995733831
The override functions in AMDGPU/ARM/SystemZ/X86 return an invalid cost; they are required to avoid enabling partial reductions on targets where they were previously disabled (I've added this override for all targets that implement getArithmeticReductionCost).
---
Full diff: https://github.com/llvm/llvm-project/pull/189905.diff
6 Files Affected:
- (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+41)
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+8-28)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h (+9)
- (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.h (+9)
- (modified) llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h (+10)
- (modified) llvm/lib/Target/X86/X86TargetTransformInfo.h (+9)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 7812a301efbd7..02f054581529c 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3435,6 +3435,47 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return RedCost + MulCost + 2 * ExtCost;
}
+ InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind,
+ std::optional<FastMathFlags> FMF) const override {
+ unsigned Ratio =
+ AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+ if (VF.getKnownMinValue() <= Ratio)
+ return InstructionCost::getInvalid();
+
+ Type *InputVectorType = VectorType::get(InputTypeA, VF);
+ Type *ExtInputVectorType = VectorType::get(AccumType, VF);
+ Type *AccumVectorType =
+ VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
+
+ auto ExtendCostA = InstructionCost(0);
+ if (OpAExtend != TTI::PartialReductionExtendKind::PR_None)
+ ExtendCostA = getCastInstrCost(
+ TTI::getOpcodeForPartialReductionExtendKind(OpAExtend),
+ ExtInputVectorType, InputVectorType, TTI::CastContextHint::None,
+ CostKind);
+
+ // TODO: add cost of extracting subvectors from the source vector that
+ // is to be partially reduced.
+ auto ReductionOpCost =
+ Ratio * getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
+
+ if (!BinOp)
+ return ExtendCostA + ReductionOpCost;
+
+ auto ExtendCostB = InstructionCost(0);
+ if (OpBExtend != TTI::PartialReductionExtendKind::PR_None)
+ ExtendCostB = getCastInstrCost(
+ TTI::getOpcodeForPartialReductionExtendKind(OpBExtend),
+ ExtInputVectorType, InputVectorType, TTI::CastContextHint::None,
+ CostKind);
+ return ExtendCostA + ExtendCostB + ReductionOpCost +
+ getArithmeticInstrCost(*BinOp, ExtInputVectorType, CostKind);
+ }
+
InstructionCost getVectorSplitCost() const { return 1; }
/// @}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index de578ea29cbe9..e6eab7a3bdb87 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -6056,34 +6056,14 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
return Cost * 2;
}
- // Returns cost of expanding the partial reduction in ISel.
- auto GetExpandCost = [&]() -> InstructionCost {
- Type *ExtVectorType =
- VectorType::get(AccumVectorType->getElementType(), VF);
- auto ExtendCostA = getCastInstrCost(
- TTI::getOpcodeForPartialReductionExtendKind(OpAExtend), ExtVectorType,
- InputVectorType, TTI::CastContextHint::None, CostKind);
- auto RedOpCost =
- Ratio * getArithmeticInstrCost(Opcode, AccumVectorType, CostKind);
- if (!BinOp)
- return ExtendCostA + RedOpCost;
-
- auto ExtendCostB = getCastInstrCost(
- TTI::getOpcodeForPartialReductionExtendKind(OpBExtend), ExtVectorType,
- InputVectorType, TTI::CastContextHint::None, CostKind);
- return ExtendCostA + ExtendCostB + RedOpCost +
- getArithmeticInstrCost(*BinOp, ExtVectorType, CostKind);
- };
-
- if (IsSub) {
- // Slightly lower the cost of a sub reduction so that it can be considered
- // as candidate for 'cdot' operations. This is a somewhat arbitrary number,
- // because we don't yet model these operations directly.
- return (8 * GetExpandCost()) / 10;
- }
-
- // By default, assume the operation is expanded.
- return GetExpandCost();
+ InstructionCost ExpandCost = BaseT::getPartialReductionCost(
+ Opcode, InputTypeA, InputTypeB, AccumType, VF, OpAExtend, OpBExtend,
+ BinOp, CostKind, FMF);
+
+ // Slightly lower the cost of a sub reduction so that it can be considered
+ // as candidate for 'cdot' operations. This is a somewhat arbitrary number,
+ // because we don't yet model these operations directly.
+ return IsSub ? ((8 * ExpandCost) / 10) : ExpandCost;
}
InstructionCost
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
index ea2bf72836199..555c711a3b810 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -269,6 +269,15 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
std::optional<FastMathFlags> FMF,
TTI::TargetCostKind CostKind) const override;
+ InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind,
+ std::optional<FastMathFlags> FMF) const override {
+ return InstructionCost::getInvalid();
+ }
+
InstructionCost
getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
TTI::TargetCostKind CostKind) const override;
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index f766deb884e0b..0d6d5d202bddf 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -413,6 +413,15 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
TTI::TargetCostKind CostKind) const override;
+ InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind,
+ std::optional<FastMathFlags> FMF) const override {
+ return InstructionCost::getInvalid();
+ }
+
/// getScalingFactorCost - Return the cost of the scaling used in
/// addressing mode represented by AM.
/// If the AM is supported, the return value must be >= 0.
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index d96036067c786..456604ef9f627 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -101,6 +101,16 @@ class SystemZTTIImpl final : public BasicTTIImplBase<SystemZTTIImpl> {
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
ArrayRef<const Value *> Args = {},
const Instruction *CxtI = nullptr) const override;
+
+ InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind,
+ std::optional<FastMathFlags> FMF) const override {
+ return InstructionCost::getInvalid();
+ }
+
InstructionCost
getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, VectorType *SrcTy,
ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index b3dde1555d0a0..b5124c3276896 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -224,6 +224,15 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
std::optional<FastMathFlags> FMF,
TTI::TargetCostKind CostKind) const override;
+ InstructionCost getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
+ TTI::TargetCostKind CostKind,
+ std::optional<FastMathFlags> FMF) const override {
+ return InstructionCost::getInvalid();
+ }
+
InstructionCost getMinMaxCost(Intrinsic::ID IID, Type *Ty,
TTI::TargetCostKind CostKind,
FastMathFlags FMF) const;
``````````
</details>
https://github.com/llvm/llvm-project/pull/189905
More information about the llvm-commits mailing list