[llvm] [AArch64] Refactor and refine cost-model for partial reductions (PR #158641)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 1 07:42:49 PDT 2025


================
@@ -5632,75 +5632,93 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
     TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
-  InstructionCost Cost(TTI::TCC_Basic);
 
   if (CostKind != TTI::TCK_RecipThroughput)
     return Invalid;
 
-  // Sub opcodes currently only occur in chained cases.
-  // Independent partial reduction subtractions are still costed as an add
+  if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
+      (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
       OpAExtend == TTI::PR_None)
     return Invalid;
 
   // We only support multiply binary operations for now, and for muls we
   // require the types being extended to be the same.
-  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
-  // only if the i8mm or sve/streaming features are available.
-  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
-                OpBExtend == TTI::PR_None ||
-                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-                 !ST->isSVEorStreamingSVEAvailable())))
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
     return Invalid;
   assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
          "Unexpected values for OpBExtend or InputTypeB");
 
-  EVT InputEVT = EVT::getEVT(InputTypeA);
-  EVT AccumEVT = EVT::getEVT(AccumType);
+  bool IsUSDot = OpBExtend && OpAExtend != OpBExtend;
+  if (IsUSDot && !ST->hasMatMulInt8())
+    return Invalid;
+
+  unsigned Ratio =
+      AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+  if (VF.getKnownMinValue() <= Ratio)
+    return Invalid;
+
+  VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
+  VectorType *AccumVectorType =
+      VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
+  // We don't yet support all kinds of legalization (e.g. widening
+  // of <[vscale x] 1 x ..> accumulators)
----------------
david-arm wrote:

Can this specific example (widening a `<vscale x 1 x ...>` type) ever happen given the check above?

```
  if (VF.getKnownMinValue() <= Ratio)
    return Invalid;
```

Also, the VF should always be a power of 2 (and hence `Ratio` should be a power of 2), which means we shouldn't really end up with `TypeWidenVector`. Perhaps a better example would be something like `<2 x i128>`, where presumably we'd see a `TypeExpandInteger` action? Or, if we ever support FP element types, maybe we'd end up with `TypePromoteFloat`?
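
To make the arithmetic behind this question concrete, here is a minimal standalone sketch. It is not LLVM code: `accumMinElts` and `accumNeverHasOneElt` are hypothetical helpers that only mirror the `Ratio` computation, the `VF.getKnownMinValue() <= Ratio` check, and the `divideCoefficientBy(Ratio)` step from the hunk, and it assumes that both the VF and the scalar bit widths are powers of 2.

```cpp
#include <cassert>
#include <initializer_list>

// Hypothetical stand-in for the logic in the hunk. Returns the accumulator's
// known-minimum element count, or 0 where the patch would return Invalid.
unsigned accumMinElts(unsigned VFMin, unsigned AccumBits, unsigned InputBits) {
  unsigned Ratio = AccumBits / InputBits;
  if (VFMin <= Ratio)
    return 0;           // mirrors "return Invalid"
  return VFMin / Ratio; // mirrors VF.divideCoefficientBy(Ratio)
}

// Checks, for power-of-2 VFs and a few power-of-2 widths (an assumption of
// this sketch), that a surviving accumulator never has exactly one element.
bool accumNeverHasOneElt() {
  for (unsigned InputBits : {8u, 16u, 32u})
    for (unsigned AccumBits : {16u, 32u, 64u}) {
      if (AccumBits <= InputBits)
        continue; // not an extending reduction, skip
      for (unsigned VFMin = 1; VFMin <= 64; VFMin *= 2)
        if (accumMinElts(VFMin, AccumBits, InputBits) == 1)
          return false;
    }
  return true;
}
```

Under those assumptions, `VFMin > Ratio` implies `VFMin >= 2 * Ratio`, so any accumulator type that gets past the check has at least two elements (known-minimum, for scalable vectors). For instance, `accumMinElts(16, 32, 8)` yields 4, while `accumMinElts(4, 32, 8)` is rejected.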

https://github.com/llvm/llvm-project/pull/158641

