[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (PR #113903)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 24 03:31:37 PDT 2025


================
@@ -2376,30 +2382,55 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   unsigned Opcode = RdxDesc.getOpcode();
   FastMathFlags FMFs = getFastMathFlags();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
+  // Note that TTI should model the cost of moving the result to a scalar
+  // register and the binOp cost in getReductionCost().
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost +
-           Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
+    return Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
-                                                   Ctx.CostKind);
+  if (ElementTy->isFloatingPointTy())
+    return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
+                                              Ctx.CostKind);
+  // Cannot get correct cost when querying TTI with FMFs that do not contain
+  // `reassoc` for non-FP reductions.
+  return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, std::nullopt,
+                                            Ctx.CostKind);
+}
+
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  unsigned Opcode = RdxDesc.getOpcode();
+  Type *RedTy = Ctx.Types.inferScalarType(this);
+  auto *SrcVecTy =
+      cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+
+  assert(RedTy->isIntegerTy() &&
+         "ExtendedReduction only support integer type currently.");
+  return Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), RedTy, SrcVecTy,
+                                          std::nullopt, Ctx.CostKind);
+}
+
+InstructionCost
+VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  Type *RedTy = Ctx.Types.inferScalarType(this);
+  auto *SrcVecTy =
+      cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+
----------------
fhahn wrote:

```suggestion
```

https://github.com/llvm/llvm-project/pull/113903


More information about the llvm-commits mailing list