[llvm] [AArch64] Refactor and refine cost-model for partial reductions (PR #158641)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 2 02:09:54 PDT 2025


================
@@ -5632,75 +5632,93 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
     TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
-  InstructionCost Cost(TTI::TCC_Basic);
 
   if (CostKind != TTI::TCK_RecipThroughput)
     return Invalid;
 
-  // Sub opcodes currently only occur in chained cases.
-  // Independent partial reduction subtractions are still costed as an add
+  if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
+      (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
       OpAExtend == TTI::PR_None)
     return Invalid;
 
   // We only support multiply binary operations for now, and for muls we
   // require the types being extended to be the same.
-  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
-  // only if the i8mm or sve/streaming features are available.
-  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
-                OpBExtend == TTI::PR_None ||
-                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-                 !ST->isSVEorStreamingSVEAvailable())))
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
     return Invalid;
   assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
          "Unexpected values for OpBExtend or InputTypeB");
 
-  EVT InputEVT = EVT::getEVT(InputTypeA);
-  EVT AccumEVT = EVT::getEVT(AccumType);
+  bool IsUSDot = OpBExtend && OpAExtend != OpBExtend;
+  if (IsUSDot && !ST->hasMatMulInt8())
+    return Invalid;
+
+  unsigned Ratio =
+      AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+  if (VF.getKnownMinValue() <= Ratio)
+    return Invalid;
+
+  VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
+  VectorType *AccumVectorType =
+      VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
+  // We don't yet support all kinds of legalization (e.g. widening
+  // of <[vscale x] 1 x ..> accumulators)
+  auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
+                               EVT::getEVT(AccumVectorType));
+  switch (TA) {
+  default:
+    return Invalid;
+  case TargetLowering::TypeLegal:
+  case TargetLowering::TypePromoteInteger:
+  case TargetLowering::TypeSplitVector:
+    break;
+  }
+
+  // Check what kind of type-legalisation happens.
+  std::pair<InstructionCost, MVT> AccumLT =
+      getTypeLegalizationCost(AccumVectorType);
+  std::pair<InstructionCost, MVT> InputLT =
+      getTypeLegalizationCost(InputVectorType);
 
-  unsigned VFMinValue = VF.getKnownMinValue();
+  InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
 
-  if (VF.isScalable()) {
-    if (!ST->isSVEorStreamingSVEAvailable())
-      return Invalid;
+  // Prefer using full types by costing half-full input types as more expensive.
+  if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
----------------
sdesmalen-arm wrote:

Yes, that was actually the point. We only want to favour fully packed scalable/fixed-length vectors, so a 64-bit fixed-length vector is one we'd prefer not to favour (i.e. it should be costed as more expensive, like other half-full input types).

https://github.com/llvm/llvm-project/pull/158641


More information about the llvm-commits mailing list