[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 24 10:58:02 PST 2024


================
@@ -1867,9 +1957,173 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
       PhiR->replaceAllUsesWith(ScalarR);
       PhiR->eraseFromParent();
     }
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (!isa<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(&R))
+        continue;
+      if (auto *ExtRed = dyn_cast<VPExtendedReductionRecipe>(&R)) {
+        expandVPExtendedReduction(ExtRed);
+      }
+      if (auto *MulAcc = dyn_cast<VPMulAccumulateReductionRecipe>(&R)) {
+        expandVPMulAccumulateReduction(MulAcc);
+      }
+    }
   }
 }
 
+VPExtendedReductionRecipe *
+VPlanTransforms::tryToMatchAndCreateExtendedReduction(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp, bool IsOrderedRed,
+    VPCostContext &Ctx, VFRange &Range) {
+  using namespace VPlanPatternMatch;
+
+  VPValue *A;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if using extended-reduction is beneficial and clamp the range.
+  auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
+                                             Type *SrcTy) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+              Opcode, isZExt, RedTy, SrcVecTy, RdxDesc.getFastMathFlags(),
+              CostKind);
+          InstructionCost ExtCost =
+              cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+          RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+          InstructionCost RedCost;
+          if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+            Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+            RedCost = Ctx.TTI.getMinMaxReductionCost(
+                Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          } else {
+            RedCost = Ctx.TTI.getArithmeticReductionCost(
+                Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          }
+          return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
+        },
+        Range);
+  };
+
+  // Try to match reduce(ext(A)).
+  if (match(VecOp, m_ZExtOrSExt(m_VPValue(A)))) {
+    if (!IsExtendedRedValidAndClampRange(
+            RdxDesc.getOpcode(),
+            cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
+                Instruction::CastOps::ZExt,
+            Ctx.Types.inferScalarType(A)))
+      return nullptr;
+    return new VPExtendedReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink,
+                                         cast<VPWidenCastRecipe>(VecOp), CondOp,
+                                         IsOrderedRed);
+  }
+  return nullptr;
+}
+
+VPMulAccumulateReductionRecipe *
+VPlanTransforms::tryToMatchAndCreateMulAccumulateReduction(
+    const RecurrenceDescriptor &RdxDesc, Instruction *CurrentLinkI,
+    VPValue *PreviousLink, VPValue *VecOp, VPValue *CondOp, bool IsOrderedRed,
+    VPCostContext &Ctx, VFRange &Range) {
+  using namespace VPlanPatternMatch;
+
+  VPValue *A, *B;
+  Type *RedTy = RdxDesc.getRecurrenceType();
+
+  // Test if using multiply-accumulate-reduction is beneficial and clamp the
+  // range.
+  auto IsMulAccValidAndClampRange =
+      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+    return LoopVectorizationPlanner::getDecisionAndClampRange(
+        [&](ElementCount VF) {
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          Type *SrcTy =
+              Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
+          auto *SrcVecTy = cast<VectorType>(ToVectorTy(SrcTy, VF));
+          auto *VectorTy = cast<VectorType>(ToVectorTy(RedTy, VF));
+          InstructionCost MulAccCost =
+              Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+          InstructionCost MulCost = Mul->computeCost(VF, Ctx);
+          InstructionCost RedCost = Ctx.TTI.getArithmeticReductionCost(
+              Instruction::Add, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+          InstructionCost ExtCost = 0;
+          if (Ext0)
+            ExtCost += Ext0->computeCost(VF, Ctx);
+          if (Ext1)
+            ExtCost += Ext1->computeCost(VF, Ctx);
+          if (OuterExt)
+            ExtCost += OuterExt->computeCost(VF, Ctx);
+
+          return MulAccCost.isValid() &&
+                 MulAccCost < ExtCost + MulCost + RedCost;
+        },
+        Range);
+  };
+
+  if (RdxDesc.getOpcode() != Instruction::Add)
+    return nullptr;
+
+  // Try to match reduce.add(mul(...))
+  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+    auto *RecipeA =
+        dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    auto *RecipeB =
+        dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+
+    // Matched reduce.add(mul(ext, ext))
+    if (RecipeA && RecipeB &&
+        (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
+        match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+        match(RecipeB, m_ZExtOrSExt(m_VPValue()))) {
+
+      // Only create MulAccRecipe if the cost is valid.
+      if (!IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
+                                          Instruction::CastOps::ZExt,
+                                      Mul, RecipeA, RecipeB, nullptr))
+        return nullptr;
+
+      return new VPMulAccumulateReductionRecipe(
+          RdxDesc, CurrentLinkI, PreviousLink, CondOp, IsOrderedRed, Mul,
+          RecipeA, RecipeB);
+    } else {
----------------
alexey-bataev wrote:

```suggestion
    }
```


https://github.com/llvm/llvm-project/pull/113903


More information about the llvm-commits mailing list