[llvm] [AArch64] Refactor and refine cost-model for partial reductions (PR #158641)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 30 07:06:57 PDT 2025


================
@@ -5632,75 +5632,97 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
     TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
-  InstructionCost Cost(TTI::TCC_Basic);
 
   if (CostKind != TTI::TCK_RecipThroughput)
     return Invalid;
 
-  // Sub opcodes currently only occur in chained cases.
-  // Independent partial reduction subtractions are still costed as an add
+  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+    return Invalid;
+
+  if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
+      (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
       OpAExtend == TTI::PR_None)
     return Invalid;
 
   // We only support multiply binary operations for now, and for muls we
   // require the types being extended to be the same.
-  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
-  // only if the i8mm or sve/streaming features are available.
-  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
-                OpBExtend == TTI::PR_None ||
-                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-                 !ST->isSVEorStreamingSVEAvailable())))
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
     return Invalid;
   assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
          "Unexpected values for OpBExtend or InputTypeB");
 
-  EVT InputEVT = EVT::getEVT(InputTypeA);
-  EVT AccumEVT = EVT::getEVT(AccumType);
+  bool IsUSDot = OpBExtend && OpAExtend != OpBExtend;
+  if (IsUSDot && !ST->hasMatMulInt8())
+    return Invalid;
 
-  unsigned VFMinValue = VF.getKnownMinValue();
+  unsigned Ratio =
+      AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+  if (VF.getKnownMinValue() < Ratio)
----------------
SamTebbs33 wrote:

Don't we still want to catch the case where the VF is equal to the ratio? Or is that already covered by the type legalisation costs below?

https://github.com/llvm/llvm-project/pull/158641


More information about the llvm-commits mailing list