[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)
Elvis Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 11 22:54:14 PST 2024
================
@@ -9387,6 +9330,171 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
return Plan;
}
+/// Try to match the extended-reduction and create VPExtendedReductionRecipe.
+/// Returns nullptr if \p VecOp is not a zext/sext or if the cost check fails.
+/// This function tries to match the following pattern, which will generate an
+/// extended-reduction instruction:
+/// reduce(ext(...)).
+static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
+ const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+ VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+ LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
+ using namespace VPlanPatternMatch;
+
+ VPValue *A;
+ Type *RedTy = RdxDesc.getRecurrenceType();
+
+ // Test whether an extended reduction is valid and cheaper than a separate extend plus reduction, and clamp the range.
+ auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+ Type *SrcTy) -> bool {
+ return LoopVectorizationPlanner::getDecisionAndClampRange(
+ [&](ElementCount VF) {
+ VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+ VectorType *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+ Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
+ CostKind);
+ InstructionCost ExtCost =
+ cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+ RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+ InstructionCost RedCost;
+ if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) { // Min/max kinds are costed via their reduction intrinsic.
+ Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+ RedCost = Ctx.TTI.getMinMaxReductionCost(
+ Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ } else {
+ RedCost = Ctx.TTI.getArithmeticReductionCost(
+ Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ }
+ return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost; // Fused form must be valid and strictly cheaper.
+ },
+ Range);
+ };
+
+ // Matched reduce(ext(...)).
+ if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+ if (!IsExtendedRedValidAndClampRange(
+ RdxDesc.getOpcode(),
+ cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+ Instruction::CastOps::ZExt,
+ Ctx.Types.inferScalarType(A)))
+ return nullptr;
+ return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+ cast<VPWidenCastRecipe>(VecOp), CondOp,
+ CM.useOrderedReductions(RdxDesc));
+ }
+ return nullptr;
+}
+
+/// Try to match the mul-acc-reduction and create VPMulAccRecipe.
+///
+/// This function tries to match the following patterns, which will generate
+/// mul-acc instructions:
+/// reduce.add(mul(...)),
+/// reduce.add(mul(ext(A), ext(B))),
+/// reduce.add(ext(mul(ext(A), ext(B)))).
+static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
+ const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+ VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+ LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
+ using namespace VPlanPatternMatch;
+
+ VPValue *A, *B;
+ Type *RedTy = RdxDesc.getRecurrenceType();
+
+ // Test if using mul-acc-reduction is beneficial and clamp the range.
+ auto IsMulAccValidAndClampRange =
+ [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+ VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+ return LoopVectorizationPlanner::getDecisionAndClampRange(
+ [&](ElementCount VF) {
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ Type *SrcTy =
+ Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
+ VectorType *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+ VectorType *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+ InstructionCost MulAccCost =
+ Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+ InstructionCost MulCost = Mul->computeCost(VF, Ctx);
+ InstructionCost RedCost = Ctx.TTI.getArithmeticReductionCost(
+ Instruction::Add, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+ InstructionCost ExtCost = 0;
+ if (Ext0)
+ ExtCost += Ext0->computeCost(VF, Ctx);
+ if (Ext1)
+ ExtCost += Ext1->computeCost(VF, Ctx);
+ if (OuterExt)
+ ExtCost += OuterExt->computeCost(VF, Ctx);
+
+ return MulAccCost.isValid() &&
+ MulAccCost < ExtCost + MulCost + RedCost;
+ },
+ Range);
+ };
+
+ if (RdxDesc.getOpcode() != Instruction::Add)
+ return nullptr;
+
+ // Try to match reduce.add(mul(...))
+ if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+ VPWidenCastRecipe *RecipeA =
+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+ VPWidenCastRecipe *RecipeB =
+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+ VPWidenRecipe *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+
+ // Matched reduce.add(mul(ext, ext))
+ if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+ match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+ (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B)) {
+
+ // Only create MulAccRecipe if the cost is valid.
+ if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
+ Instruction::CastOps::ZExt,
+ Mul, RecipeA, RecipeB, nullptr))
+ return nullptr;
+
+ return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc), Mul, RecipeA,
+ RecipeB);
+ } else {
+ // Matched reduce.add(mul)
+ if (!IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
+ return nullptr;
+
+ return new VPMulAccRecipe(RdxDesc, CurrentLinkI, PreviousLink, CondOp,
+ CM.useOrderedReductions(RdxDesc), Mul);
+ }
+ // Matched reduce.add(ext(mul(ext(A), ext(B))))
+ // All extend instructions must have the same opcode, or A == B, which can be
+ // transformed to reduce.add(zext(mul(sext(A), sext(B)))).
+ } else if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+ m_ZExtOrSExt(m_VPValue()))))) {
+ VPWidenCastRecipe *Ext =
+ cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+ VPWidenRecipe *Mul =
+ cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+ VPWidenCastRecipe *Ext0 =
+ cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+ VPWidenCastRecipe *Ext1 =
+ cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
----------------
ElvisWang123 wrote:
Fixed, thanks.
https://github.com/llvm/llvm-project/pull/113903
More information about the llvm-commits
mailing list