[llvm] [VPlan] Implement transformation of widen-cast + (widen-mul) + reduction to abstract recipe. (PR #137746)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 28 19:53:02 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Elvis Wang (ElvisWang123)

<details>
<summary>Changes</summary>

This patch implements a transformation that matches the following
patterns in the VPlan and converts them to abstract recipes for better
cost estimation.

* VPExtendedReductionRecipe
  - cast + reduction.

* VPMulAccumulateReductionRecipe
  - (cast) + mul + reduction.

The converted abstract recipes will be lowered to concrete recipes
(widen-cast + widen-mul + reduction) just before vector codegen.

This should be a cost-model-based decision, which will be implemented in a
follow-up patch. In its current state, this patch still relies on the legacy
cost model to calculate the cost.

Split from #<!-- -->113903. Stacked on #<!-- -->137745.

---

Patch is 68.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137746.diff


12 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+15-4) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+251-7) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+89-12) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+247) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.h (+7) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+2) 
- (modified) llvm/test/Transforms/LoopVectorize/ARM/mve-reduction-types.ll (+2-2) 
- (modified) llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll (+73-47) 
- (modified) llvm/test/Transforms/LoopVectorize/reduction-inloop-pred.ll (+1-1) 
- (modified) llvm/test/Transforms/LoopVectorize/reduction-inloop.ll (+3-5) 
- (modified) llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll (+145) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 4684378687ef6..f5e3d1664b407 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9631,10 +9631,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
          "entry block must be set to a VPRegionBlock having a non-empty entry "
          "VPBasicBlock");
 
-  for (ElementCount VF : Range)
-    Plan->addVF(VF);
-  Plan->setName("Initial VPlan");
-
   // Update wide induction increments to use the same step as the corresponding
   // wide induction. This enables detecting induction increments directly in
   // VPlan and removes redundant splats.
@@ -9670,6 +9666,21 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
   // Adjust the recipes for any inloop reductions.
   adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
 
+  // Transform recipes to abstract recipes if it is legal and beneficial and
+  // clamp the range for better cost estimation.
+  // TODO: Enable following transform when the EVL-version of extended-reduction
+  // and mulacc-reduction are implemented.
+  if (!CM.foldTailWithEVL()) {
+    VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(), CM,
+                          CM.CostKind);
+    VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
+                             CostCtx, Range);
+  }
+
+  for (ElementCount VF : Range)
+    Plan->addVF(VF);
+  Plan->setName("Initial VPlan");
+
   // Interleave memory: for each Interleave Group we marked earlier as relevant
   // for this VPlan, replace the Recipes widening its memory instructions with a
   // single VPInterleaveRecipe at its insertion point.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index afad73bcd3501..587ba29965646 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccumulateReductionSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
   };
 
+  struct NonNegFlagsTy {
+    char NonNeg : 1;
+    NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
+  };
+
 private:
   struct ExactFlagsTy {
     char IsExact : 1;
   };
-  struct NonNegFlagsTy {
-    char NonNeg : 1;
-  };
   struct FastMathFlagsTy {
     char AllowReassoc : 1;
     char NoNaNs : 1;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
         DisjointFlags(DisjointFlags) {}
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
+        NonNegFlags(NonNegFlags) {}
+
 protected:
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
            R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
            R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 
   FastMathFlags getFastMathFlags() const;
 
+  /// Returns true if the recipe has non-negative flag.
+  bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
+
+  bool isNonNeg() const {
+    assert(OpType == OperationType::NonNegOp &&
+           "recipe doesn't have a NNEG flag");
+    return NonNegFlags.NonNeg;
+  }
+
   bool hasNoUnsignedWrap() const {
     assert(OpType == OperationType::OverflowingBinOp &&
            "recipe doesn't have a NUW flag");
@@ -1231,11 +1252,22 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
       : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I),
         Opcode(I.getOpcode()) {}
 
+  template <typename IterT>
+    VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode,
+                  iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
+    : VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
+    Opcode(Opcode) {}
+
 public:
   template <typename IterT>
   VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands)
       : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
 
+  template <typename IterT>
+  VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
+                bool NSW, DebugLoc DL)
+      : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
+
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -1280,10 +1312,16 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
            "opcode of underlying cast doesn't match");
   }
 
-  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
-      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(),
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy, DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
         Opcode(Opcode), ResultTy(ResultTy) {}
 
+  VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
+                    bool IsNonNeg, DebugLoc DL = {})
+    : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
+                          DL),
+    Opcode(Opcode), ResultTy(ResultTy) {}
+
   ~VPWidenCastRecipe() override = default;
 
   VPWidenCastRecipe *clone() override {
@@ -2373,6 +2411,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
     setUnderlyingValue(I);
   }
 
+  /// For VPExtendedReductionRecipe.
+  /// Note that the debug location is from the extend.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
+  /// For VPMulAccumulateReductionRecipe.
+  /// Note that the NUW/NSW flags and the debug location are from the Mul.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
 public:
   VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2381,6 +2441,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
 
+  VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
+                    VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL = {})
+      : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
+                          ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                          IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
@@ -2391,7 +2458,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
 
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2471,6 +2540,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent inloop extended reduction operations, performing a
+/// reduction on a extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is abstract and needs to be lowered to
+/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
+/// [Condition]}.
+class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe will be lowered to.
+  Instruction::CastOps ExtOp;
+
+  Type *ResultTy;
+
+  /// For cloning VPExtendedReductionRecipe.
+  VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
+      : VPReductionRecipe(
+            VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
+            {ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
+            ExtRed->isOrdered(), ExtRed->getDebugLoc()),
+        ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
+    transferFlags(*ExtRed);
+  }
+
+public:
+  VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
+      : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
+                          {R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
+                          R->isOrdered(), Ext->getDebugLoc()),
+        ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
+    // Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
+    // the original recipe to prevent setting wrong flags.
+    transferFlags(*Ext);
+  }
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    auto *Copy = new VPExtendedReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transform to "
+                     "VPExtendedRecipe + VPReductionRecipe before execution.");
+  };
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The scalar type after extending.
+  Type *getResultType() const { return ResultTy; }
+
+  /// Is the extend ZExt?
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
+  /// The opcode of extend recipe.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+};
+
+/// A recipe to represent inloop MulAccumulateReduction operations, performing a
+/// reduction.add on the result of vector operands (might be extended)
+/// multiplication into a scalar value, and adding the result to a chain. This
+/// recipe is abstract and needs to be lowered to concrete recipes before
+/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe.
+  Instruction::CastOps ExtOp;
+
+  /// Non-neg flag of the extend recipe.
+  bool IsNonNeg = false;
+
+  Type *ResultTy;
+
+  /// For cloning VPMulAccumulateReductionRecipe.
+  VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
+            {MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
+            MulAcc->getCondOp(), MulAcc->isOrdered(),
+            WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
+            MulAcc->getDebugLoc()),
+        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        ResultTy(MulAcc->getResultType()) {}
+
+public:
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
+                                 VPWidenCastRecipe *Ext0,
+                                 VPWidenCastRecipe *Ext1, Type *ResultTy)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateteReductionRecipe must "
+           "be Add");
+    // Only set the non-negative flag if the original recipe contains.
+    if (Ext0->hasNonNegFlag())
+      IsNonNeg = Ext0->isNonNeg();
+  }
+
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Instruction::CastOps::CastOpsEnd) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateReductionRecipe must be "
+           "Add");
+  }
+
+  ~VPMulAccumulateReductionRecipe() override = default;
+
+  VPMulAccumulateReductionRecipe *clone() override {
+    auto *Copy = new VPMulAccumulateReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
+                     "VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
+
+  /// Return the cost of VPMulAccumulateReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  Type *getResultType() const {
+    assert(isExtended() && "Only support getResultType when this recipe "
+                           "contains implicit extend.");
+    return ResultTy;
+  }
+
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
+
+  /// Return if this MulAcc recipe contains extended operands.
+  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+
+  /// Return the opcode of the extends for the operands.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+
+  /// Return if the operands are zero extended.
+  bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+
+  /// Return the non negative flag of the ext recipe.
+  bool isNonNeg() const { return IsNonNeg; }
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index c86815c84d8d9..7dcbd72c25191 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             // TODO: Use info from interleave group.
             return V->getUnderlyingValue()->getType();
           })
+          .Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
+              [](const auto *R) { return R->getResultType(); })
           .Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
             return R->getSCEV()->getType();
           })
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75d056026025a..32e35b0fb78d3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -71,6 +71,8 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -118,6 +120,8 @@ bool VPRecipeBase::mayReadFromMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -155,6 +159,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPScalarIVStepsSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
@@ -2489,28 +2495,49 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
   unsigned Opcode = RecurrenceDescriptor::getOpcode(RdxKind);
   FastMathFlags FMFs = getFastMathFlags();
+  std::optional<FastMathFlags> OptionalFMF =
+      ElementTy->isFloatingPointTy() ? std::make_optional(FMFs) : std::nullopt;
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost +
-           Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
+    return Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
-                                                   Ctx.CostKind);
+  return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, OptionalFMF,
+                                            Ctx.CostKind);
+}
+
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  unsigned Opcode = RecurrenceDescriptor::getOpcode(getRecurrenceKind());
+  Type *RedTy = Ctx.Types.inferScalarType(this);
+  auto *SrcVecTy =
+      cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  assert(RedTy->isIntegerTy() &&
+         "Exte...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/137746


More information about the llvm-commits mailing list