[llvm] [AArch64][NFC] Move getPartialReductionCost into cpp file (PR #123370)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 20 05:11:55 PST 2025
https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/123370
From 681994c19dfefaa4dd4ac5f0ba2960d3aba9426d Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 17 Jan 2025 16:45:57 +0000
Subject: [PATCH 1/2] [AArch64][NFC] Move getPartialReductionCost into cpp file
The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.
---
.../AArch64/AArch64TargetTransformInfo.cpp | 61 +++++++++++++++++++
.../AArch64/AArch64TargetTransformInfo.h | 57 +----------------
2 files changed, 62 insertions(+), 56 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7f10bfed739b41..ba26af129f2757 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5573,3 +5573,64 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
}
return false;
}
+
+InstructionCost
+AArch64TTIImpl::getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+ Type *AccumType, ElementCount VF,
+ TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend,
+ std::optional<unsigned> BinOp) const {
+ InstructionCost Invalid = InstructionCost::getInvalid();
+ InstructionCost Cost(TTI::TCC_Basic);
+
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ if (InputTypeA != InputTypeB)
+ return Invalid;
+
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ EVT AccumEVT = EVT::getEVT(AccumType);
+
+ if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+ return Invalid;
+ if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+ return Invalid;
+
+ if (InputEVT == MVT::i8) {
+ switch (VF.getKnownMinValue()) {
+ default:
+ return Invalid;
+ case 8:
+ if (AccumEVT == MVT::i32)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i64)
+ return Invalid;
+ break;
+ case 16:
+ if (AccumEVT == MVT::i64)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i32)
+ return Invalid;
+ break;
+ }
+ } else if (InputEVT == MVT::i16) {
+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
+ // it to i64.
+ if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+ return Invalid;
+ } else
+ return Invalid;
+
+ // AArch64 supports lowering mixed extensions to a usdot but only if the
+ // i8mm or sve/streaming features are available.
+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+ !ST->isSVEorStreamingSVEAvailable()))
+ return Invalid;
+
+ if (!BinOp || *BinOp != Instruction::Mul)
+ return Invalid;
+
+ return Cost;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1eb805ae00b1bb..b65e3c7a1ab20e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
-
- InstructionCost Invalid = InstructionCost::getInvalid();
- InstructionCost Cost(TTI::TCC_Basic);
-
- if (Opcode != Instruction::Add)
- return Invalid;
-
- if (InputTypeA != InputTypeB)
- return Invalid;
-
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
-
- if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
- return Invalid;
- if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
- return Invalid;
-
- if (InputEVT == MVT::i8) {
- switch (VF.getKnownMinValue()) {
- default:
- return Invalid;
- case 8:
- if (AccumEVT == MVT::i32)
- Cost *= 2;
- else if (AccumEVT != MVT::i64)
- return Invalid;
- break;
- case 16:
- if (AccumEVT == MVT::i64)
- Cost *= 2;
- else if (AccumEVT != MVT::i32)
- return Invalid;
- break;
- }
- } else if (InputEVT == MVT::i16) {
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
- // it to i64.
- if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
- return Invalid;
- } else
- return Invalid;
-
- // AArch64 supports lowering mixed extensions to a usdot but only if the
- // i8mm or sve/streaming features are available.
- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable()))
- return Invalid;
-
- if (!BinOp || *BinOp != Instruction::Mul)
- return Invalid;
-
- return Cost;
- }
+ std::optional<unsigned> BinOp) const;
bool enableOrderedReductions() const { return true; }
From 4da541fa77d459bc8f39b37384aee1008f97e4c1 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Mon, 20 Jan 2025 13:10:51 +0000
Subject: [PATCH 2/2] Address review comment
---
.../AArch64/AArch64TargetTransformInfo.cpp | 121 +++++++++---------
1 file changed, 60 insertions(+), 61 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ba26af129f2757..f32e4d91c833e1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4664,6 +4664,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
return LegalizationCost * LT.first;
}
+InstructionCost AArch64TTIImpl::getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend,
+ std::optional<unsigned> BinOp) const {
+ InstructionCost Invalid = InstructionCost::getInvalid();
+ InstructionCost Cost(TTI::TCC_Basic);
+
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ if (InputTypeA != InputTypeB)
+ return Invalid;
+
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ EVT AccumEVT = EVT::getEVT(AccumType);
+
+ if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+ return Invalid;
+ if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+ return Invalid;
+
+ if (InputEVT == MVT::i8) {
+ switch (VF.getKnownMinValue()) {
+ default:
+ return Invalid;
+ case 8:
+ if (AccumEVT == MVT::i32)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i64)
+ return Invalid;
+ break;
+ case 16:
+ if (AccumEVT == MVT::i64)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i32)
+ return Invalid;
+ break;
+ }
+ } else if (InputEVT == MVT::i16) {
+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
+ // it to i64.
+ if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+ return Invalid;
+ } else
+ return Invalid;
+
+ // AArch64 supports lowering mixed extensions to a usdot but only if the
+ // i8mm or sve/streaming features are available.
+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+ !ST->isSVEorStreamingSVEAvailable()))
+ return Invalid;
+
+ if (!BinOp || *BinOp != Instruction::Mul)
+ return Invalid;
+
+ return Cost;
+}
+
InstructionCost AArch64TTIImpl::getShuffleCost(
TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int> Mask,
TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,
@@ -5573,64 +5633,3 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
}
return false;
}
-
-InstructionCost
-AArch64TTIImpl::getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
- Type *AccumType, ElementCount VF,
- TTI::PartialReductionExtendKind OpAExtend,
- TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
- InstructionCost Invalid = InstructionCost::getInvalid();
- InstructionCost Cost(TTI::TCC_Basic);
-
- if (Opcode != Instruction::Add)
- return Invalid;
-
- if (InputTypeA != InputTypeB)
- return Invalid;
-
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
-
- if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
- return Invalid;
- if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
- return Invalid;
-
- if (InputEVT == MVT::i8) {
- switch (VF.getKnownMinValue()) {
- default:
- return Invalid;
- case 8:
- if (AccumEVT == MVT::i32)
- Cost *= 2;
- else if (AccumEVT != MVT::i64)
- return Invalid;
- break;
- case 16:
- if (AccumEVT == MVT::i64)
- Cost *= 2;
- else if (AccumEVT != MVT::i32)
- return Invalid;
- break;
- }
- } else if (InputEVT == MVT::i16) {
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
- // it to i64.
- if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
- return Invalid;
- } else
- return Invalid;
-
- // AArch64 supports lowering mixed extensions to a usdot but only if the
- // i8mm or sve/streaming features are available.
- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable()))
- return Invalid;
-
- if (!BinOp || *BinOp != Instruction::Mul)
- return Invalid;
-
- return Cost;
-}
More information about the llvm-commits
mailing list