[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (PR #113903)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 24 03:31:37 PDT 2025
================
@@ -2376,30 +2382,55 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
unsigned Opcode = RdxDesc.getOpcode();
FastMathFlags FMFs = getFastMathFlags();
- // TODO: Support any-of and in-loop reductions.
+ // TODO: Support any-of reductions.
assert(
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
"Any-of reduction not implemented in VPlan-based cost model currently.");
- assert(
- (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
- ForceTargetInstructionCost.getNumOccurrences() > 0) &&
- "In-loop reduction not implemented in VPlan-based cost model currently.");
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
"Inferred type and recurrence type mismatch.");
- // Cost = Reduction cost + BinOp cost
- InstructionCost Cost =
- Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
+ // Note that TTI should model the cost of moving the result to a scalar
+ // register and the binOp cost in the getReductionCost().
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
- return Cost +
- Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
+ return Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
}
- return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
- Ctx.CostKind);
+ if (ElementTy->isFloatingPointTy())
+ return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
+ Ctx.CostKind);
+ // For non-FP reductions, querying TTI with FMFs that do not contain `reassoc`
+ // does not give the correct cost, so pass std::nullopt instead.
+ return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, std::nullopt,
+ Ctx.CostKind);
+}
+
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+ VPCostContext &Ctx) const {
+ const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+ unsigned Opcode = RdxDesc.getOpcode();
+ Type *RedTy = Ctx.Types.inferScalarType(this);
+ auto *SrcVecTy =
+ cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+
+ assert(RedTy->isIntegerTy() &&
+ "ExtendedReduction only support integer type currently.");
+ return Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), RedTy, SrcVecTy,
+ std::nullopt, Ctx.CostKind);
+}
+
+InstructionCost
+VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
+ VPCostContext &Ctx) const {
+ Type *RedTy = Ctx.Types.inferScalarType(this);
+ auto *SrcVecTy =
+ cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+
----------------
fhahn wrote:
```suggestion
```
https://github.com/llvm/llvm-project/pull/113903
More information about the llvm-commits mailing list