[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 02:18:58 PST 2024


================
@@ -2653,6 +2657,210 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent in-loop extended reduction operations, performing a
+/// reduction on a vector operand into a scalar value, and adding the result to
+/// a chain. This is a high-level abstract recipe that must be lowered to a
+/// VPWidenCastRecipe and a VPReductionRecipe before execution. The operands
+/// are {ChainOp, VecOp, [Condition]}; VecOp is the input of the folded extend.
+class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Result type of the folded extend (taken from the VPWidenCastRecipe).
+  Type *ResultTy;
+  CastInst *ExtInstr; ///< Underlying IR cast instruction of the folded extend.
+
+protected:
+  VPExtendedReductionRecipe(const unsigned char SC,
+                            const RecurrenceDescriptor &R, Instruction *RedI,
+                            Instruction::CastOps ExtOp, CastInst *ExtInstr,
+                            VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
+                            bool IsOrdered, Type *ResultTy)
+      : VPReductionRecipe(SC, R, RedI, ArrayRef<VPValue *>({ChainOp, VecOp}),
+                          CondOp, IsOrdered),
+        ResultTy(ResultTy), ExtInstr(ExtInstr) {} // NOTE(review): ExtOp is unused here; getExtOpcode() reads ExtInstr instead.
+
+public:
+  VPExtendedReductionRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                            VPValue *ChainOp, VPWidenCastRecipe *Ext,
+                            VPValue *CondOp, bool IsOrdered)
+      : VPExtendedReductionRecipe(
+            VPDef::VPExtendedReductionSC, R, RedI, Ext->getOpcode(),
+            cast<CastInst>(Ext->getUnderlyingInstr()), ChainOp,
+            Ext->getOperand(0), CondOp, IsOrdered, Ext->getResultType()) {}
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    llvm_unreachable("Not implement yet");
+  }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPDef::VPExtendedReductionSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transform to "
+                     "VPExtendedRecipe + VPReductionRecipe before execution.");
+  }; // NOTE(review): message names "VPExtendedRecipe"; presumably VPWidenCastRecipe is meant -- confirm.
+
+  /// Return the cost of this VPExtendedReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The type of the value after the extend.
+  Type *getResultType() const { return ResultTy; }
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+  /// The opcode of the extend instruction, read from the underlying CastInst.
+  Instruction::CastOps getExtOpcode() const { return ExtInstr->getOpcode(); }
+  /// The underlying CastInst of the extend.
+  CastInst *getExtInstr() const { return ExtInstr; }
+};
+
+/// A recipe to represent in-loop multiply-accumulate reduction operations,
+/// performing a reduction on a vector operand into a scalar value, and adding
+/// the result to a chain. This is a high-level abstract recipe that must be
+/// lowered to VPWidenCastRecipes, a VPWidenRecipe (mul) and a
+/// VPReductionRecipe before execution. The operands are
+/// {ChainOp, VecOp0, VecOp1, [Condition]}.
+class VPMulAccRecipe : public VPReductionRecipe {
+  /// Type after extend.
+  Type *ResultType;
+
+  /// reduce.add(ext(mul(ext0(), ext1())))
+  Instruction *MulInstr;
+  CastInst *ExtInstr = nullptr;
+  CastInst *Ext0Instr = nullptr;
+  CastInst *Ext1Instr = nullptr;
+
+protected:
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction *ExtInstr,
+                 Instruction *MulInstr, Instruction *Ext0Instr,
+                 Instruction *Ext1Instr, VPValue *ChainOp, VPValue *VecOp0,
+                 VPValue *VecOp1, VPValue *CondOp, bool IsOrdered,
+                 Type *ResultType)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
+        ResultType(ResultType), MulInstr(MulInstr),
+        ExtInstr(cast_if_present<CastInst>(ExtInstr)),
+        Ext0Instr(cast<CastInst>(Ext0Instr)),
+        Ext1Instr(cast<CastInst>(Ext1Instr)) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
+  }
+
+  VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
+                 Instruction *RedI, Instruction *MulInstr, VPValue *ChainOp,
+                 VPValue *VecOp0, VPValue *VecOp1, VPValue *CondOp,
+                 bool IsOrdered)
+      : VPReductionRecipe(SC, R, RedI,
+                          ArrayRef<VPValue *>({ChainOp, VecOp0, VecOp1}),
+                          CondOp, IsOrdered),
+        MulInstr(MulInstr) {
+    assert(MulInstr->getOpcode() == Instruction::Mul);
+  }
+
+public:
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+                 VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, nullptr,
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
+
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenRecipe *Mul)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
+                       ChainOp, Mul->getOperand(0), Mul->getOperand(1), CondOp,
+                       IsOrdered) {}
+
+  VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
+                 VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
+                 VPWidenCastRecipe *Ext, VPWidenRecipe *Mul,
+                 VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1)
+      : VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Ext->getUnderlyingInstr(),
+                       Mul->getUnderlyingInstr(), Ext0->getUnderlyingInstr(),
+                       Ext1->getUnderlyingInstr(), ChainOp, Ext0->getOperand(0),
+                       Ext1->getOperand(0), CondOp, IsOrdered,
+                       Ext0->getResultType()) {}
+
+  ~VPMulAccRecipe() override = default;
+
+  VPMulAccRecipe *clone() override { llvm_unreachable("Not implement yet"); }
+
+  static inline bool classof(const VPRecipeBase *R) {
+    return R->getVPDefID() == VPRecipeBase::VPMulAccSC;
+  }
+
+  static inline bool classof(const VPUser *U) {
+    auto *R = dyn_cast<VPRecipeBase>(U);
+    return R && classof(R);
+  }
+
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
+                     "VPWidenRecipe + VPReductionRecipe before execution");
+  }
+
+  /// Return the cost of VPMulAccRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The VPValues of the two vector operands to be multiplied and reduced.
+  VPValue *getVecOp0() const { return getOperand(1); }
+  VPValue *getVecOp1() const { return getOperand(2); }
+
+  /// Return the type after the inner extends, which must equal the type of the
+  /// mul instruction. If ResultType != the recurrence type, then there must be
+  /// an extend recipe after the mul recipe.
+  Type *getResultType() const { return ResultType; }
+
+  /// The underlying instruction for VPWidenRecipe.
+  Instruction *getMulInstr() const { return MulInstr; }
+
+  /// The underlying Instruction for outer VPWidenCastRecipe.
+  CastInst *getExtInstr() const { return ExtInstr; }
----------------
davemgreen wrote:

Is it worth representing a `reduce(ext(mul(ext, ext)))` reduction as `reduce(mul(ext, ext))`, as (so long as the extends are large enough) they should be equivalent?

https://github.com/llvm/llvm-project/pull/113903


More information about the llvm-commits mailing list