[llvm] [AArch64][NFC] Move getPartialReductionCost into cpp file (PR #123370)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 17 08:49:09 PST 2025
https://github.com/david-arm created https://github.com/llvm/llvm-project/pull/123370
The function getPartialReductionCost is already quite large and
is likely to grow as we add support for more cases in the
future. Therefore, I think it's best to move its definition out of
the header and into the cpp file, leaving only the declaration in
the header.
From f92cbd2a4804398efda507c0f894ec0821f58268 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 17 Jan 2025 16:45:57 +0000
Subject: [PATCH] [AArch64][NFC] Move getPartialReductionCost into cpp file
The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.
---
.../AArch64/AArch64TargetTransformInfo.cpp | 61 +++++++++++++++++++
.../AArch64/AArch64TargetTransformInfo.h | 57 +----------------
2 files changed, 62 insertions(+), 56 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7f10bfed739b41..ba26af129f2757 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5573,3 +5573,64 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
}
return false;
}
+
+InstructionCost
+AArch64TTIImpl::getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+ Type *AccumType, ElementCount VF,
+ TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend,
+ std::optional<unsigned> BinOp) const {
+ InstructionCost Invalid = InstructionCost::getInvalid();
+ InstructionCost Cost(TTI::TCC_Basic);
+
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ if (InputTypeA != InputTypeB)
+ return Invalid;
+
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ EVT AccumEVT = EVT::getEVT(AccumType);
+
+ if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+ return Invalid;
+ if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+ return Invalid;
+
+ if (InputEVT == MVT::i8) {
+ switch (VF.getKnownMinValue()) {
+ default:
+ return Invalid;
+ case 8:
+ if (AccumEVT == MVT::i32)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i64)
+ return Invalid;
+ break;
+ case 16:
+ if (AccumEVT == MVT::i64)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i32)
+ return Invalid;
+ break;
+ }
+ } else if (InputEVT == MVT::i16) {
+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
+ // it to i64.
+ if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+ return Invalid;
+ } else
+ return Invalid;
+
+ // AArch64 supports lowering mixed extensions to a usdot but only if the
+ // i8mm or sve/streaming features are available.
+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+ !ST->isSVEorStreamingSVEAvailable()))
+ return Invalid;
+
+ if (!BinOp || *BinOp != Instruction::Mul)
+ return Invalid;
+
+ return Cost;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1eb805ae00b1bb..b65e3c7a1ab20e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
-
- InstructionCost Invalid = InstructionCost::getInvalid();
- InstructionCost Cost(TTI::TCC_Basic);
-
- if (Opcode != Instruction::Add)
- return Invalid;
-
- if (InputTypeA != InputTypeB)
- return Invalid;
-
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
-
- if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
- return Invalid;
- if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
- return Invalid;
-
- if (InputEVT == MVT::i8) {
- switch (VF.getKnownMinValue()) {
- default:
- return Invalid;
- case 8:
- if (AccumEVT == MVT::i32)
- Cost *= 2;
- else if (AccumEVT != MVT::i64)
- return Invalid;
- break;
- case 16:
- if (AccumEVT == MVT::i64)
- Cost *= 2;
- else if (AccumEVT != MVT::i32)
- return Invalid;
- break;
- }
- } else if (InputEVT == MVT::i16) {
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
- // it to i64.
- if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
- return Invalid;
- } else
- return Invalid;
-
- // AArch64 supports lowering mixed extensions to a usdot but only if the
- // i8mm or sve/streaming features are available.
- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable()))
- return Invalid;
-
- if (!BinOp || *BinOp != Instruction::Mul)
- return Invalid;
-
- return Cost;
- }
+ std::optional<unsigned> BinOp) const;
bool enableOrderedReductions() const { return true; }
More information about the llvm-commits
mailing list