[llvm] [VPlan] Add new recipes for extended-reduction and mul-accumulate-reduction. NFC (PR #137745)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 28 19:49:40 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Elvis Wang (ElvisWang123)

<details>
<summary>Changes</summary>

This patch adds two new recipes: one for extended reductions and one for mul-accumulate reductions.

* VPExtendedReductionRecipe.
  - Contains widen-cast + reduction.
* VPMulAccumulateReductionRecipe.
  - Contains widen-mul + widen-cast + reduction.

The transformations and the cost models for these recipes will be added in follow-up patches.

Split from #<!-- -->113904.

---
Full diff: https://github.com/llvm/llvm-project/pull/137745.diff


4 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+232-5) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+68) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+2) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index afad73bcd3501..591a3f26d4d13 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
     case VPRecipeBase::VPInstructionSC:
     case VPRecipeBase::VPReductionEVLSC:
     case VPRecipeBase::VPReductionSC:
+    case VPRecipeBase::VPMulAccumulateReductionSC:
+    case VPRecipeBase::VPExtendedReductionSC:
     case VPRecipeBase::VPReplicateSC:
     case VPRecipeBase::VPScalarIVStepsSC:
     case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
   };
 
+  struct NonNegFlagsTy {
+    char NonNeg : 1;
+    NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
+  };
+
 private:
   struct ExactFlagsTy {
     char IsExact : 1;
   };
-  struct NonNegFlagsTy {
-    char NonNeg : 1;
-  };
   struct FastMathFlagsTy {
     char AllowReassoc : 1;
     char NoNaNs : 1;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
       : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
         DisjointFlags(DisjointFlags) {}
 
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
+      : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
+        NonNegFlags(NonNegFlags) {}
+
 protected:
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
            R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
            R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
-           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
+           R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
 
   FastMathFlags getFastMathFlags() const;
 
+  /// Returns true if the recipe has non-negative flag.
+  bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
+
+  bool isNonNeg() const {
+    assert(OpType == OperationType::NonNegOp &&
+           "recipe doesn't have a NNEG flag");
+    return NonNegFlags.NonNeg;
+  }
+
   bool hasNoUnsignedWrap() const {
     assert(OpType == OperationType::OverflowingBinOp &&
            "recipe doesn't have a NUW flag");
@@ -2373,6 +2394,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
     setUnderlyingValue(I);
   }
 
+  /// For VPExtendedReductionRecipe.
+  /// Note that the debug location is from the extend.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
+  /// For VPMulAccumulateReductionRecipe.
+  /// Note that the NUW/NSW flags and the debug location are from the Mul.
+  VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
+                    ArrayRef<VPValue *> Operands, VPValue *CondOp,
+                    bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
+      : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
+        IsOrdered(IsOrdered), IsConditional(CondOp) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
+
 public:
   VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2381,6 +2424,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
                           IsOrdered, DL) {}
 
+  VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
+                    VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                    bool IsOrdered, DebugLoc DL = {})
+      : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
+                          ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
+                          IsOrdered, DL) {}
+
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
@@ -2391,7 +2441,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
 
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
-           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+           R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
   }
 
   static inline bool classof(const VPUser *U) {
@@ -2471,6 +2523,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent inloop extended reduction operations, performing a
+/// reduction on an extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is abstract and needs to be lowered to
+/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
+/// [Condition]}.
+class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe will be lowered to.
+  Instruction::CastOps ExtOp;
+
+  Type *ResultTy;
+
+  /// For cloning VPExtendedReductionRecipe.
+  VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
+      : VPReductionRecipe(
+            VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
+            {ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
+            ExtRed->isOrdered(), ExtRed->getDebugLoc()),
+        ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
+    transferFlags(*ExtRed);
+  }
+
+public:
+  VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
+      : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
+                          {R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
+                          R->isOrdered(), Ext->getDebugLoc()),
+        ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
+    // Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
+    // the original recipe to prevent setting wrong flags.
+    transferFlags(*Ext);
+  }
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    auto *Copy = new VPExtendedReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transform to "
+                     "VPExtendedRecipe + VPReductionRecipe before execution.");
+  };
+
+  /// Return the cost of VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The scalar type after extending.
+  Type *getResultType() const { return ResultTy; }
+
+  /// Is the extend ZExt?
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
+  /// The opcode of extend recipe.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+};
+
+/// A recipe to represent inloop MulAccumulateReduction operations, performing a
+/// reduction.add on the result of vector operands (might be extended)
+/// multiplication into a scalar value, and adding the result to a chain. This
+/// recipe is abstract and needs to be lowered to concrete recipes before
+/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend recipe.
+  Instruction::CastOps ExtOp;
+
+  /// Non-neg flag of the extend recipe.
+  bool IsNonNeg = false;
+
+  Type *ResultTy;
+
+  /// For cloning VPMulAccumulateReductionRecipe.
+  VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
+            {MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
+            MulAcc->getCondOp(), MulAcc->isOrdered(),
+            WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
+            MulAcc->getDebugLoc()),
+        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        ResultTy(MulAcc->getResultType()) {}
+
+public:
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
+                                 VPWidenCastRecipe *Ext0,
+                                 VPWidenCastRecipe *Ext1, Type *ResultTy)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateteReductionRecipe must "
+           "be Add");
+    // Only set the non-negative flag if the original recipe contains.
+    if (Ext0->hasNonNegFlag())
+      IsNonNeg = Ext0->isNonNeg();
+  }
+
+  VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
+            {R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
+            R->getCondOp(), R->isOrdered(),
+            WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
+            R->getDebugLoc()),
+        ExtOp(Instruction::CastOps::CastOpsEnd) {
+    assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
+               Instruction::Add &&
+           "The reduction instruction in MulAccumulateReductionRecipe must be "
+           "Add");
+  }
+
+  ~VPMulAccumulateReductionRecipe() override = default;
+
+  VPMulAccumulateReductionRecipe *clone() override {
+    auto *Copy = new VPMulAccumulateReductionRecipe(this);
+    Copy->transferFlags(*this);
+    return Copy;
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
+                     "VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
+
+  /// Return the cost of VPMulAccumulateReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  Type *getResultType() const {
+    assert(isExtended() && "Only support getResultType when this recipe "
+                           "contains implicit extend.");
+    return ResultTy;
+  }
+
+  /// The VPValue of the vector value to be extended and reduced.
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
+
+  /// Return if this MulAcc recipe contains extended operands.
+  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+
+  /// Return the opcode of the extends for the operands.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+
+  /// Return if the operands are zero extended.
+  bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+
+  /// Return the non negative flag of the ext recipe.
+  bool isNonNeg() const { return IsNonNeg; }
+};
+
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
 /// copies of the original scalar type, one per lane, instead of producing a
 /// single copy of widened type for all lanes. If the instruction is known to be
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index c86815c84d8d9..7dcbd72c25191 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
             // TODO: Use info from interleave group.
             return V->getUnderlyingValue()->getType();
           })
+          .Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
+              [](const auto *R) { return R->getResultType(); })
           .Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
             return R->getSCEV()->getType();
           })
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 75d056026025a..8978a4d5e93cf 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -71,6 +71,8 @@ bool VPRecipeBase::mayWriteToMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -118,6 +120,8 @@ bool VPRecipeBase::mayReadFromMemory() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
   case VPWidenCastSC:
@@ -155,6 +159,8 @@ bool VPRecipeBase::mayHaveSideEffects() const {
   case VPBlendSC:
   case VPReductionEVLSC:
   case VPReductionSC:
+  case VPExtendedReductionSC:
+  case VPMulAccumulateReductionSC:
   case VPScalarIVStepsSC:
   case VPVectorPointerSC:
   case VPWidenCanonicalIVSC:
@@ -2513,6 +2519,18 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                    Ctx.CostKind);
 }
 
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  return 0;
+}
+
+InstructionCost
+VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  return 0;
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
@@ -2555,6 +2573,56 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   }
   O << ")";
 }
+
+void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                      VPSlotTracker &SlotTracker) const {
+  O << Indent << "EXTENDED-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " +";
+  O << " reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  O << " extended to " << *getResultType();
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+}
+
+void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
+                                           VPSlotTracker &SlotTracker) const {
+  O << Indent << "MULACC-REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " + ";
+  O << "reduce."
+    << Instruction::getOpcodeName(
+           RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+    << " (";
+  O << "mul";
+  printFlags(O);
+  if (isExtended())
+    O << "(";
+  getVecOp0()->printAsOperand(O, SlotTracker);
+  if (isExtended())
+    O << " extended to " << *getResultType() << "), (";
+  else
+    O << ", ";
+  getVecOp1()->printAsOperand(O, SlotTracker);
+  if (isExtended())
+    O << " extended to " << *getResultType() << ")";
+  if (isConditional()) {
+    O << ", ";
+    getCondOp()->printAsOperand(O, SlotTracker);
+  }
+  O << ")";
+}
 #endif
 
 bool VPReplicateRecipe::shouldPack() const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 638156eab7a84..64065edd315f9 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -339,6 +339,8 @@ class VPDef {
     VPInterleaveSC,
     VPReductionEVLSC,
     VPReductionSC,
+    VPMulAccumulateReductionSC,
+    VPExtendedReductionSC,
     VPPartialReductionSC,
     VPReplicateSC,
     VPScalarIVStepsSC,

``````````

</details>


https://github.com/llvm/llvm-project/pull/137745


More information about the llvm-commits mailing list