[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 9 02:05:49 PST 2024
================
@@ -2192,30 +2193,146 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
- // TODO: Support any-of and in-loop reductions.
+ // TODO: Support any-of reductions.
assert(
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
"Any-of reduction not implemented in VPlan-based cost model currently.");
- assert(
- (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
- ForceTargetInstructionCost.getNumOccurrences() > 0) &&
- "In-loop reduction not implemented in VPlan-based cost model currently.");
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
"Inferred type and recurrence type mismatch.");
- // Cost = Reduction cost + BinOp cost
- InstructionCost Cost =
- Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+ // Note that TTI should model the cost of moving result to the scalar register
+ // and the binOp cost in the getReductionCost().
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
- return Cost + Ctx.TTI.getMinMaxReductionCost(
- Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ return Ctx.TTI.getMinMaxReductionCost(Id, VectorTy,
+ RdxDesc.getFastMathFlags(), CostKind);
}
- return Cost + Ctx.TTI.getArithmeticReductionCost(
- Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ return Ctx.TTI.getArithmeticReductionCost(
+ Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+}
+
+/// Cost of an extend (zext or sext, selected by isZExt()) feeding a
+/// reduction. The recipe can be costed either as a single target
+/// extended-reduction operation or as a separate vector cast followed by a
+/// plain reduction; the cheaper of the two is returned.
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  // Opcode of the reduction itself (e.g. Add); NOT the extend's opcode.
+  unsigned Opcode = RdxDesc.getOpcode();
+  // Opcode of the extend being folded into the reduction.
+  unsigned ExtOpcode = isZExt() ? Instruction::ZExt : Instruction::SExt;
+
+  // Cost of the fused extended-reduction operation.
+  InstructionCost ExtendedRedCost =
+      Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), ElementTy, SrcVecTy,
+                                       RdxDesc.getFastMathFlags(), CostKind);
+
+  assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
+                                      "created if the cost is invalid.");
+
+  // Cost of the plain reduction on the already-extended vector.
+  InstructionCost ReductionCost;
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost = Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost = Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Cost of a standalone vector extend SrcVecTy -> VectorTy. This must use
+  // the cast opcode (ZExt/SExt); passing the reduction opcode here would ask
+  // TTI for the cost of a non-cast instruction.
+  TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
+  // Arm TTI will use the underlying instruction to determine the cost.
+  // NOTE(review): the underlying value is presumably the reduction's
+  // instruction rather than the extend itself — confirm this is what the
+  // target hook expects.
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      ExtOpcode, VectorTy, SrcVecTy, CCH, CostKind,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+
+  // Use the fused form only if it is cheaper than extend + reduce. The
+  // isValid() check is kept because the assert above is compiled out in
+  // release builds.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost) {
+    return ExtendedRedCost;
+  }
+
+  return ExtendedCost + ReductionCost;
+}
+
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+ VPCostContext &Ctx) const {
+ const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+ Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
+ : Ctx.Types.inferScalarType(getVecOp0());
+ auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ unsigned Opcode = RdxDesc.getOpcode();
+
+ assert(Opcode == Instruction::Add &&
+ "Reduction opcode must be add in the VPMulAccRecipe.");
+ // MulAccReduction Cost
+ VectorType *SrcVecTy =
+ cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+ InstructionCost MulAccCost =
+ Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
+
+ assert(MulAccCost.isValid() && "VPMulAccRecipe should not be "
+ "created if the cost is invalid.");
+
+ // BaseCost = Reduction cost + BinOp cost
+ InstructionCost ReductionCost = Ctx.TTI.getArithmeticReductionCost(
+ Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+ // Extended cost
----------------
fhahn wrote:
Redundant comment
https://github.com/llvm/llvm-project/pull/113903
More information about the llvm-commits mailing list