[llvm] a733c1f - [AArch64][NFC] Move getPartialReductionCost into cpp file (#123370)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 20 06:07:08 PST 2025
Author: David Sherwood
Date: 2025-01-20T14:07:03Z
New Revision: a733c1fa90f3d26dbf399f7676e11fad0e3f5eeb
URL: https://github.com/llvm/llvm-project/commit/a733c1fa90f3d26dbf399f7676e11fad0e3f5eeb
DIFF: https://github.com/llvm/llvm-project/commit/a733c1fa90f3d26dbf399f7676e11fad0e3f5eeb.diff
LOG: [AArch64][NFC] Move getPartialReductionCost into cpp file (#123370)
The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.
Added:
Modified:
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 050fd71d3b1438..cd093317275ee9 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4670,6 +4670,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
return LegalizationCost * LT.first;
}
+InstructionCost AArch64TTIImpl::getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend,
+ std::optional<unsigned> BinOp) const {
+ InstructionCost Invalid = InstructionCost::getInvalid();
+ InstructionCost Cost(TTI::TCC_Basic);
+
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ if (InputTypeA != InputTypeB)
+ return Invalid;
+
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ EVT AccumEVT = EVT::getEVT(AccumType);
+
+ if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+ return Invalid;
+ if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+ return Invalid;
+
+ if (InputEVT == MVT::i8) {
+ switch (VF.getKnownMinValue()) {
+ default:
+ return Invalid;
+ case 8:
+ if (AccumEVT == MVT::i32)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i64)
+ return Invalid;
+ break;
+ case 16:
+ if (AccumEVT == MVT::i64)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i32)
+ return Invalid;
+ break;
+ }
+ } else if (InputEVT == MVT::i16) {
+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
+ // it to i64.
+ if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+ return Invalid;
+ } else
+ return Invalid;
+
+ // AArch64 supports lowering mixed extensions to a usdot but only if the
+ // i8mm or sve/streaming features are available.
+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+ !ST->isSVEorStreamingSVEAvailable()))
+ return Invalid;
+
+ if (!BinOp || *BinOp != Instruction::Mul)
+ return Invalid;
+
+ return Cost;
+}
+
InstructionCost AArch64TTIImpl::getShuffleCost(
TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int> Mask,
TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1eb805ae00b1bb..b65e3c7a1ab20e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
-
- InstructionCost Invalid = InstructionCost::getInvalid();
- InstructionCost Cost(TTI::TCC_Basic);
-
- if (Opcode != Instruction::Add)
- return Invalid;
-
- if (InputTypeA != InputTypeB)
- return Invalid;
-
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
-
- if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
- return Invalid;
- if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
- return Invalid;
-
- if (InputEVT == MVT::i8) {
- switch (VF.getKnownMinValue()) {
- default:
- return Invalid;
- case 8:
- if (AccumEVT == MVT::i32)
- Cost *= 2;
- else if (AccumEVT != MVT::i64)
- return Invalid;
- break;
- case 16:
- if (AccumEVT == MVT::i64)
- Cost *= 2;
- else if (AccumEVT != MVT::i32)
- return Invalid;
- break;
- }
- } else if (InputEVT == MVT::i16) {
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
- // it to i64.
- if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
- return Invalid;
- } else
- return Invalid;
-
- // AArch64 supports lowering mixed extensions to a usdot but only if the
- // i8mm or sve/streaming features are available.
- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable()))
- return Invalid;
-
- if (!BinOp || *BinOp != Instruction::Mul)
- return Invalid;
-
- return Cost;
- }
+ std::optional<unsigned> BinOp) const;
bool enableOrderedReductions() const { return true; }
More information about the llvm-commits
mailing list