[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 5 07:57:46 PST 2025


================
@@ -1928,3 +2024,174 @@ void VPlanTransforms::handleUncountableEarlyExit(
   Builder.createNaryOp(VPInstruction::BranchOnCond, AnyExitTaken);
   LatchExitingBranch->eraseFromParent();
 }
+
+/// Try to match the pattern below on the vector operand of \p Red. If the
+/// extended reduction is valid and beneficial, clamp \p Range and return a
+/// new VPExtendedReductionRecipe; otherwise return nullptr. The created
+/// VPExtendedReductionRecipe will be lowered to concrete recipes before
+/// execution. Pattern: reduce(ext(...)).
+static VPExtendedReductionRecipe *
+tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
+                                     VFRange &Range) {
+  using namespace VPlanPatternMatch;
+
+  Type *RedTy = Ctx.Types.inferScalarType(Red);
+  VPValue *VecOp = Red->getVecOp();
+  const RecurrenceDescriptor &RdxDesc = Red->getRecurrenceDescriptor();
+
+  // Per-VF test: valid extended-reduction cost < ext + reduction; clamps Range.
+  auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+                                             Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+              Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
+              CostKind);
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+          InstructionCost RedCost = Red->computeCost(VF, Ctx);
+          return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
+        },
+        Range);
+  };
+
+  VPValue *A;
+  // Match reduce(ext(A)); ZExt vs. SExt is forwarded to the cost query.
+  if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
+      IsExtendedRedValidAndClampRange(
+          RdxDesc.getOpcode(),
+          cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+              Instruction::CastOps::ZExt,
+          Ctx.Types.inferScalarType(A)))
+    return new VPExtendedReductionRecipe(Red, cast<VPWidenCastRecipe>(VecOp));
+
+  return nullptr;
+}
+
+/// Try to match one of the patterns below on the add reduction \p Red. If
+/// the multiply-accumulate reduction is valid and beneficial, clamp \p Range
+/// and return a new VPMulAccumulateReductionRecipe; otherwise return nullptr.
+/// The created recipe will be lowered to concrete recipes before execution.
+///   reduce.add(mul(...)),
+///   reduce.add(mul(ext(A), ext(B))),
+///   reduce.add(ext(mul(ext(A), ext(B)))).
+static VPMulAccumulateReductionRecipe *
+tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
+                                          VPCostContext &Ctx, VFRange &Range) {
+  using namespace VPlanPatternMatch;
+
+  Type *RedTy = Ctx.Types.inferScalarType(Red);
+
+  // Test if using multiply-accumulate-reduction is beneficial (valid and
+  // cheaper than extends + mul + reduction) and clamp the range.
+  auto IsMulAccValidAndClampRange =
+      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          Type *SrcTy =
+              Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          InstructionCost MulAccCost =
+              Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+          InstructionCost MulCost = Mul->computeCost(VF, Ctx);
+          InstructionCost RedCost = Red->computeCost(VF, Ctx);
+          InstructionCost ExtCost = 0;
+          if (Ext0)
+            ExtCost += Ext0->computeCost(VF, Ctx);
+          if (Ext1)
+            ExtCost += Ext1->computeCost(VF, Ctx);
+          if (OuterExt)
+            ExtCost += OuterExt->computeCost(VF, Ctx);
+
+          return MulAccCost.isValid() &&
+                 MulAccCost < ExtCost + MulCost + RedCost;
+        },
+        Range);
+  };
+
+  const RecurrenceDescriptor &RdxDesc = Red->getRecurrenceDescriptor();
+  if (RdxDesc.getOpcode() != Instruction::Add)
+    return nullptr;
+
+  VPValue *VecOp = Red->getVecOp();
+  VPValue *A, *B;
+  // Try to match reduce.add(mul(A, B)); A/B may or may not be widened casts.
+  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+    auto *RecipeA =
+        dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    auto *RecipeB =
+        dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+
+    // Match reduce.add(mul(ext, ext)) with same ext opcode (or A == B).
+    if (RecipeA && RecipeB &&
+        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
+        match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+        match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
+        IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
+                                       Instruction::CastOps::ZExt,
+                                   Mul, RecipeA, RecipeB, nullptr))
+      return new VPMulAccumulateReductionRecipe(Red, Mul, RecipeA, RecipeB);
+    // Otherwise try plain reduce.add(mul), costed with isZExt == true.
+    if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
+      return new VPMulAccumulateReductionRecipe(Red, Mul);
+  }
+  // Try to match reduce.add(ext(mul(ext(A), ext(B)))).
+  // The inner extends must share an opcode; the outer extend must match them
+  // too unless Ext0 == Ext1, e.g. reduce.add(zext(mul(sext(A), sext(A)))).
+  if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
+                                      m_ZExtOrSExt(m_VPValue()))))) {
+    auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
+    auto *Ext0 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
+    auto *Ext1 =
+        cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
+    if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+        Ext0->getOpcode() == Ext1->getOpcode() &&
+        IsMulAccValidAndClampRange(Ext0->getOpcode() ==
+                                       Instruction::CastOps::ZExt,
+                                   Mul, Ext0, Ext1, Ext))
+      return new VPMulAccumulateReductionRecipe(Red, Mul, Ext0, Ext1);
+  }
+  return nullptr;
+}
+
+/// This function tries to create abstract recipes from a reduction recipe
+/// for subsequent optimizations and cost estimation.
+static void tryToCreateAbstractReductionRecipe(VPReductionRecipe *Red,
+                                               VPCostContext &Ctx,
+                                               VFRange &Range) {
+  // TODO: Remove EVL check when we support EVL version of
+  // VPExtendedReductionRecipe and VPMulAccumulateReductionRecipe.
+  if (Ctx.foldTailWithEVL())
----------------
fhahn wrote:

Better to not call the transform instead of checking here?

https://github.com/llvm/llvm-project/pull/113903


More information about the llvm-commits mailing list