[llvm] [LV] Decide early between partial reduce and a regular reduction based on cost-model (PR #169898)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 19 04:40:17 PST 2025


================
@@ -8282,9 +8290,49 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
   return tryToWiden(VPI);
 }
 
-VPRecipeBase *
-VPRecipeBuilder::tryToCreatePartialReduction(VPInstruction *Reduction,
-                                             unsigned ScaleFactor) {
+static VPExpressionRecipe *
+tryToCreateExtendedReduction(VPValue *BinOp, VPValue *Acc, VPCostContext &Ctx,
+                             VFRange &Range, Instruction *ReductionI,
+                             unsigned ScaleFactor, VPValue *Cond) {
+  Type *RedTy = Ctx.Types.inferScalarType(Acc);
+  unsigned ReductionOpcode = ReductionI->getOpcode();
+  using namespace llvm::VPlanPatternMatch;
+  VPValue *A;
+  match(BinOp, m_ZExtOrSExt(m_VPValue(A)));
+  auto IsExtendedRedValidAndClampRange =
+      [&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          InstructionCost ExtRedCost;
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(BinOp)->computeCost(VF, Ctx);
+          InstructionCost RedCost =
+              Ctx.TTI.getArithmeticInstrCost(Opcode, SrcVecTy, Ctx.CostKind);
+          TargetTransformInfo::PartialReductionExtendKind ExtKind =
+              TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
+          ExtRedCost = Ctx.TTI.getPartialReductionCost(
+              Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
+              llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
+          return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
----------------
david-arm wrote:

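In theory, 'ExtCost' and 'RedCost' could also be invalid costs. Because an invalid InstructionCost compares as greater than any valid cost, `ExtRedCost < ExtCost + RedCost` would be true whenever the sum is invalid. So you might want to do something like: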

```
  InstructionCost RedPlusExtCost = ExtCost + RedCost;
  return PartialRedCost.isValid() && RedPlusExtCost.isValid() && PartialRedCost < RedPlusExtCost;
```

What do you think?

https://github.com/llvm/llvm-project/pull/169898


More information about the llvm-commits mailing list