[llvm] [AArch64][NFC] Move getPartialReductionCost into cpp file (PR #123370)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 17 08:49:09 PST 2025


https://github.com/david-arm created https://github.com/llvm/llvm-project/pull/123370

The function getPartialReductionCost is already quite large and
is likely to grow as we add support for more cases in the future.
Therefore, I think it's best to move its definition out of the
header and into the .cpp file, leaving only the declaration behind.

From f92cbd2a4804398efda507c0f894ec0821f58268 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 17 Jan 2025 16:45:57 +0000
Subject: [PATCH] [AArch64][NFC] Move getPartialReductionCost into cpp file

The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.
---
 .../AArch64/AArch64TargetTransformInfo.cpp    | 61 +++++++++++++++++++
 .../AArch64/AArch64TargetTransformInfo.h      | 57 +----------------
 2 files changed, 62 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7f10bfed739b41..ba26af129f2757 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5573,3 +5573,64 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
   }
   return false;
 }
+
+InstructionCost
+AArch64TTIImpl::getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+                        Type *AccumType, ElementCount VF,
+                        TTI::PartialReductionExtendKind OpAExtend,
+                        TTI::PartialReductionExtendKind OpBExtend,
+                        std::optional<unsigned> BinOp) const {
+  InstructionCost Invalid = InstructionCost::getInvalid();
+  InstructionCost Cost(TTI::TCC_Basic);
+
+  if (Opcode != Instruction::Add)
+    return Invalid;
+
+  if (InputTypeA != InputTypeB)
+    return Invalid;
+
+  EVT InputEVT = EVT::getEVT(InputTypeA);
+  EVT AccumEVT = EVT::getEVT(AccumType);
+
+  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+    return Invalid;
+  if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
+  if (InputEVT == MVT::i8) {
+    switch (VF.getKnownMinValue()) {
+    default:
+      return Invalid;
+    case 8:
+      if (AccumEVT == MVT::i32)
+        Cost *= 2;
+      else if (AccumEVT != MVT::i64)
+        return Invalid;
+      break;
+    case 16:
+      if (AccumEVT == MVT::i64)
+        Cost *= 2;
+      else if (AccumEVT != MVT::i32)
+        return Invalid;
+      break;
+    }
+  } else if (InputEVT == MVT::i16) {
+    // FIXME: Allow i32 accumulator but increase cost, as we would extend
+    //        it to i64.
+    if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+      return Invalid;
+  } else
+    return Invalid;
+
+  // AArch64 supports lowering mixed extensions to a usdot but only if the
+  // i8mm or sve/streaming features are available.
+  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+       !ST->isSVEorStreamingSVEAvailable()))
+    return Invalid;
+
+  if (!BinOp || *BinOp != Instruction::Mul)
+    return Invalid;
+
+  return Cost;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1eb805ae00b1bb..b65e3c7a1ab20e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const {
-
-    InstructionCost Invalid = InstructionCost::getInvalid();
-    InstructionCost Cost(TTI::TCC_Basic);
-
-    if (Opcode != Instruction::Add)
-      return Invalid;
-
-    if (InputTypeA != InputTypeB)
-      return Invalid;
-
-    EVT InputEVT = EVT::getEVT(InputTypeA);
-    EVT AccumEVT = EVT::getEVT(AccumType);
-
-    if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
-      return Invalid;
-    if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
-      return Invalid;
-
-    if (InputEVT == MVT::i8) {
-      switch (VF.getKnownMinValue()) {
-      default:
-        return Invalid;
-      case 8:
-        if (AccumEVT == MVT::i32)
-          Cost *= 2;
-        else if (AccumEVT != MVT::i64)
-          return Invalid;
-        break;
-      case 16:
-        if (AccumEVT == MVT::i64)
-          Cost *= 2;
-        else if (AccumEVT != MVT::i32)
-          return Invalid;
-        break;
-      }
-    } else if (InputEVT == MVT::i16) {
-      // FIXME: Allow i32 accumulator but increase cost, as we would extend
-      //        it to i64.
-      if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
-        return Invalid;
-    } else
-      return Invalid;
-
-    // AArch64 supports lowering mixed extensions to a usdot but only if the
-    // i8mm or sve/streaming features are available.
-    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-        (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-         !ST->isSVEorStreamingSVEAvailable()))
-      return Invalid;
-
-    if (!BinOp || *BinOp != Instruction::Mul)
-      return Invalid;
-
-    return Cost;
-  }
+                          std::optional<unsigned> BinOp) const;
 
   bool enableOrderedReductions() const { return true; }
 



More information about the llvm-commits mailing list