[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Elvis Wang via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 23 01:37:59 PST 2024


================
@@ -9387,6 +9330,170 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
   return Plan;
 }
 
+/// Try to match the extended-reduction and create VPExtendedReductionRecipe.
+///
+/// This function tries to match the following pattern, which will generate an
+/// extended-reduction instruction.
+///    reduce(ext(...)).
+///
+/// \p VecOp is the candidate reduction operand; it must match a zext or sext
+/// for the pattern to apply. \p Range is passed to
+/// getDecisionAndClampRange() together with the cost predicate below.
+///
+/// \returns a newly allocated VPExtendedReductionRecipe if \p VecOp matches
+/// and the cost predicate holds; nullptr otherwise.
+/// NOTE(review): the recipe is created with a raw `new`; presumably the caller
+/// takes ownership by inserting it into a VPlan -- confirm at the call site.
+static VPExtendedReductionRecipe *tryToMatchAndCreateExtendedReduction(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
+  using namespace VPlanPatternMatch;
+
+  // Bound to the operand of the extend when the match below succeeds.
+  VPValue *A;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if using extended-reduction is beneficial and clamp the range.
+  // For a given VF the predicate is true when the target reports a valid
+  // extended-reduction cost that is strictly lower than the cost of the
+  // separate extend plus the cost of the plain reduction.
+  auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+                                             Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          // Vector types of the pre-extend source and of the reduction type.
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          // Cost of the fused form: scalar result type RedTy, vector source
+          // type SrcVecTy.
+          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+              Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
+              CostKind);
+          // Cost of the standalone extend.
+          // NOTE(review): cast<> asserts that VecOp is a VPWidenCastRecipe;
+          // confirm m_ZExtOrSExt cannot match any other recipe kind.
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+          RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+          // Cost of the unfused reduction on the already-extended vector type;
+          // min/max recurrences are costed through the min/max intrinsic.
+          InstructionCost RedCost;
+          if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+            Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+            RedCost = Ctx.TTI.getMinMaxReductionCost(
+                Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          } else {
+            RedCost = Ctx.TTI.getArithmeticReductionCost(
+                Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          }
+          return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
+        },
+        Range);
+  };
+
+  // Match reduce(ext(...)).
+  if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+    // The second argument tells the cost query whether the extend is a zext
+    // (true) or not (false); the third is the scalar type of the extend's
+    // operand.
+    if (!IsExtendedRedValidAndClampRange(
+            RdxDesc.getOpcode(),
+            cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+                Instruction::CastOps::ZExt,
+            Ctx.Types.inferScalarType(A)))
+      return nullptr;
+    return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+                                         cast<VPWidenCastRecipe>(VecOp), CondOp,
+                                         CM.useOrderedReductions(RdxDesc));
+  }
+  // VecOp is not an extend: no extended-reduction recipe is created.
+  return nullptr;
+}
+
+/// Try to match the mul-acc-reduction and create VPMulAccRecipe.
+///
+/// This function tries to match the following patterns, which generate mul-acc
+/// instructions.
+///    reduce.add(mul(...)),
+///    reduce.add(mul(ext(A), ext(B))),
+///    reduce.add(ext(mul(ext(A), ext(B)))).
+static VPMulAccRecipe *tryToMatchAndCreateMulAcc(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp,
+    LoopVectorizationCostModel &CM, VPCostContext &Ctx, VFRange &Range) {
----------------
ElvisWang123 wrote:

Migrated, thanks.

https://github.com/llvm/llvm-project/pull/113903


More information about the llvm-commits mailing list