[llvm] a733c1f - [AArch64][NFC] Move getPartialReductionCost into cpp file (#123370)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 20 06:07:08 PST 2025


Author: David Sherwood
Date: 2025-01-20T14:07:03Z
New Revision: a733c1fa90f3d26dbf399f7676e11fad0e3f5eeb

URL: https://github.com/llvm/llvm-project/commit/a733c1fa90f3d26dbf399f7676e11fad0e3f5eeb
DIFF: https://github.com/llvm/llvm-project/commit/a733c1fa90f3d26dbf399f7676e11fad0e3f5eeb.diff

LOG: [AArch64][NFC] Move getPartialReductionCost into cpp file (#123370)

The function getPartialReductionCost is already quite large and
is likely to grow as we add support for more cases in the
future. Therefore, I think it's best to move its definition out
of the header and into the .cpp file.

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 050fd71d3b1438..cd093317275ee9 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4670,6 +4670,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
   return LegalizationCost * LT.first;
 }
 
+InstructionCost AArch64TTIImpl::getPartialReductionCost(
+    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+    TTI::PartialReductionExtendKind OpBExtend,
+    std::optional<unsigned> BinOp) const {
+  InstructionCost Invalid = InstructionCost::getInvalid();
+  InstructionCost Cost(TTI::TCC_Basic);
+  // Unsupported combinations return Invalid; supported ones return Cost.
+  if (Opcode != Instruction::Add)
+    return Invalid;
+
+  if (InputTypeA != InputTypeB)
+    return Invalid;
+
+  EVT InputEVT = EVT::getEVT(InputTypeA);
+  EVT AccumEVT = EVT::getEVT(AccumType);
+  // Scalable VFs need (streaming) SVE; fixed VFs need NEON with dotprod.
+  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+    return Invalid;
+  if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+  // Accepted shapes, keyed on input type and known-minimum VF lane count.
+  if (InputEVT == MVT::i8) {
+    switch (VF.getKnownMinValue()) {
+    default:
+      return Invalid;
+    case 8: // i64 accumulator: base cost; i32 accumulator: 2x base cost.
+      if (AccumEVT == MVT::i32)
+        Cost *= 2;
+      else if (AccumEVT != MVT::i64)
+        return Invalid;
+      break;
+    case 16: // i32 accumulator: base cost; i64 accumulator: 2x base cost.
+      if (AccumEVT == MVT::i64)
+        Cost *= 2;
+      else if (AccumEVT != MVT::i32)
+        return Invalid;
+      break;
+    }
+  } else if (InputEVT == MVT::i16) {
+    // FIXME: Allow i32 accumulator but increase cost, as we would extend
+    //        it to i64.
+    if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+      return Invalid;
+  } else
+    return Invalid; // Only i8 and i16 inputs are handled.
+
+  // Both operands must be extended; mixed extension kinds can only be
+  // lowered to a usdot when i8mm or SVE/streaming SVE is available.
+  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+       !ST->isSVEorStreamingSVEAvailable()))
+    return Invalid;
+
+  if (!BinOp || *BinOp != Instruction::Mul)
+    return Invalid;
+
+  return Cost;
+}
+
 InstructionCost AArch64TTIImpl::getShuffleCost(
     TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int> Mask,
     TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1eb805ae00b1bb..b65e3c7a1ab20e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const {
-
-    InstructionCost Invalid = InstructionCost::getInvalid();
-    InstructionCost Cost(TTI::TCC_Basic);
-
-    if (Opcode != Instruction::Add)
-      return Invalid;
-
-    if (InputTypeA != InputTypeB)
-      return Invalid;
-
-    EVT InputEVT = EVT::getEVT(InputTypeA);
-    EVT AccumEVT = EVT::getEVT(AccumType);
-
-    if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
-      return Invalid;
-    if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
-      return Invalid;
-
-    if (InputEVT == MVT::i8) {
-      switch (VF.getKnownMinValue()) {
-      default:
-        return Invalid;
-      case 8:
-        if (AccumEVT == MVT::i32)
-          Cost *= 2;
-        else if (AccumEVT != MVT::i64)
-          return Invalid;
-        break;
-      case 16:
-        if (AccumEVT == MVT::i64)
-          Cost *= 2;
-        else if (AccumEVT != MVT::i32)
-          return Invalid;
-        break;
-      }
-    } else if (InputEVT == MVT::i16) {
-      // FIXME: Allow i32 accumulator but increase cost, as we would extend
-      //        it to i64.
-      if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
-        return Invalid;
-    } else
-      return Invalid;
-
-    // AArch64 supports lowering mixed extensions to a usdot but only if the
-    // i8mm or sve/streaming features are available.
-    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-        (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-         !ST->isSVEorStreamingSVEAvailable()))
-      return Invalid;
-
-    if (!BinOp || *BinOp != Instruction::Mul)
-      return Invalid;
-
-    return Cost;
-  }
+                          std::optional<unsigned> BinOp) const;
 
   bool enableOrderedReductions() const { return true; }
 


        


More information about the llvm-commits mailing list