[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 8 13:29:00 PST 2024
================
@@ -2185,30 +2186,150 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
unsigned Opcode = RdxDesc.getOpcode();
- // TODO: Support any-of and in-loop reductions.
+ // TODO: Support any-of reductions.
assert(
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
"Any-of reduction not implemented in VPlan-based cost model currently.");
- assert(
- (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
- ForceTargetInstructionCost.getNumOccurrences() > 0) &&
- "In-loop reduction not implemented in VPlan-based cost model currently.");
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
"Inferred type and recurrence type mismatch.");
- // Cost = Reduction cost + BinOp cost
- InstructionCost Cost =
+ // BaseCost = Reduction cost + BinOp cost
+ InstructionCost BaseCost =
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
- return Cost + Ctx.TTI.getMinMaxReductionCost(
- Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ BaseCost += Ctx.TTI.getMinMaxReductionCost(
+ Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ } else {
+ BaseCost += Ctx.TTI.getArithmeticReductionCost(
+ Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
}
- return Cost + Ctx.TTI.getArithmeticReductionCost(
- Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ // Default cost.
+ return BaseCost;
+}
+
+/// Compute the cost of an extend feeding a reduction. Two alternatives are
+/// priced and the cheaper one is returned:
+///   1. a standalone vector extend followed by the reduction (scalar BinOp
+///      cost plus vector reduction cost), and
+///   2. a single fused extended reduction, when the target reports a valid
+///      cost for it via TTI::getExtendedReductionCost.
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+  // The cast cost must be queried with the extend's cast opcode, not with the
+  // reduction's binary opcode; IsZExt selects between zext and sext.
+  Instruction::CastOps ExtOpcode =
+      IsZExt ? Instruction::CastOps::ZExt : Instruction::CastOps::SExt;
+
+  // ReductionCost = scalar BinOp cost + vector reduction cost.
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Cost of the standalone vector extend: source vector type -> result vector
+  // type (the destination is exactly VectorTy).
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
+  // Arm TTI will use the underlying instruction to determine the cost.
+  // NOTE(review): confirm getUnderlyingValue() is the extend instruction; if it
+  // is the reduction instead, the cast hook receives the wrong context.
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      ExtOpcode, VectorTy, SrcTy, CCH, CostKind,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+
+  // Cost of the fused extended reduction, if the target supports one.
+  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+
+  // Fold the extend into the reduction only when that is valid and strictly
+  // cheaper than doing the two operations separately.
+  InstructionCost UnfusedCost = ExtendedCost + ReductionCost;
+  if (ExtendedRedCost.isValid() && ExtendedRedCost < UnfusedCost)
+    return ExtendedRedCost;
+  return UnfusedCost;
+}
+
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+ VPCostContext &Ctx) const {
+ Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
+ : Ctx.Types.inferScalarType(getVecOp0());
+ auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ unsigned Opcode = RdxDesc.getOpcode();
+
+ assert(Opcode == Instruction::Add &&
+ "Reduction opcode must be add in the VPMulAccRecipe.");
+
+ // BaseCost = Reduction cost + BinOp cost
+ InstructionCost ReductionCost =
+ Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+ ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+ Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+ // Extended cost
+ InstructionCost ExtendedCost = 0;
+ if (IsExtended) {
+ auto *SrcTy = cast<VectorType>(
+ ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+ auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+ TTI::CastContextHint CCH0 =
+ computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+ // Arm TTI will use the underlying instruction to determine the cost.
+ ExtendedCost = Ctx.TTI.getCastInstrCost(
+ ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+ dyn_cast_if_present<Instruction>(getExt0Instr()));
+ TTI::CastContextHint CCH1 =
+ computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+ ExtendedCost += Ctx.TTI.getCastInstrCost(
+ ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+ dyn_cast_if_present<Instruction>(getExt1Instr()));
+ }
+
+ // Mul cost
+ InstructionCost MulCost;
+ SmallVector<const Value *, 4> Operands;
+ Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
+ if (IsExtended)
+ MulCost = Ctx.TTI.getArithmeticInstrCost(
+ Instruction::Mul, VectorTy, CostKind,
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+ Operands, MulInstr, &Ctx.TLI);
+ else {
+ VPValue *RHS = getVecOp1();
+ // Certain instructions can be cheaper to vectorize if they have a constant
+ // second vector operand. One example of this are shifts on x86.
+ TargetTransformInfo::OperandValueInfo RHSInfo = {
+ TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
+ if (RHS->isLiveIn())
+ RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
+
+ if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
+ RHS->isDefinedOutsideLoopRegions())
+ RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+ MulCost = Ctx.TTI.getArithmeticInstrCost(
+ Instruction::Mul, VectorTy, CostKind,
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+ RHSInfo, Operands, MulInstr, &Ctx.TLI);
+ }
+
+ // MulAccReduction Cost
+ VectorType *SrcVecTy =
+ cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+ InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+ getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
+ CostKind);
+
+ // Check if folding ext into ExtendedReduction is profitable.
+ if (MulAccCost.isValid() &&
----------------
fhahn wrote:
Can we create the MulAcc recipe only when its cost is valid?
https://github.com/llvm/llvm-project/pull/113903
More information about the llvm-commits mailing list