[llvm] [VPlan] Implement VPExtendedReduction, VPMulAccumulateReductionRecipe and corresponding vplan transformations. (PR #137746)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon May 12 01:20:36 PDT 2025


================
@@ -2482,6 +2552,185 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
   }
 };
 
+/// A recipe to represent in-loop extended reduction operations, performing a
+/// reduction on an extended vector operand into a scalar value, and adding the
+/// result to a chain. This recipe is abstract and needs to be lowered to
+/// concrete recipes (an extend followed by a reduction) before codegen. The
+/// operands are {ChainOp, VecOp, [Condition]}, where VecOp is the operand
+/// *before* the extend is applied.
+class VPExtendedReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend for VecOp (ZExt or SExt).
+  Instruction::CastOps ExtOp;
+
+  /// The scalar type after extending.
+  Type *ResultTy;
+
+  /// For cloning VPExtendedReductionRecipe. Copies the operands, extend
+  /// opcode, result type, flags and underlying value of \p ExtRed so the
+  /// clone matches what the public constructor would have produced.
+  explicit VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
+      : VPReductionRecipe(
+            VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
+            {ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
+            ExtRed->isOrdered(), ExtRed->getDebugLoc()),
+        ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
+    transferFlags(*ExtRed);
+    // The public constructor attaches the original reduction's underlying
+    // value; carry it over so a clone does not silently lose it.
+    setUnderlyingValue(ExtRed->getUnderlyingValue());
+  }
+
+public:
+  /// Fold the extend \p Ext feeding the reduction \p R into a single recipe.
+  /// The vector operand becomes the extend's source operand; the chain,
+  /// condition, ordering and underlying value are taken from \p R, while the
+  /// debug location and flags are taken from \p Ext.
+  VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
+      : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
+                          {R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
+                          R->isOrdered(), Ext->getDebugLoc()),
+        ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
+    assert((ExtOp == Instruction::CastOps::ZExt ||
+            ExtOp == Instruction::CastOps::SExt) &&
+           "VPExtendedReductionRecipe only supports zext and sext.");
+
+    transferFlags(*Ext);
+    setUnderlyingValue(R->getUnderlyingValue());
+  }
+
+  ~VPExtendedReductionRecipe() override = default;
+
+  VPExtendedReductionRecipe *clone() override {
+    return new VPExtendedReductionRecipe(this);
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
+
+  /// This abstract recipe must be lowered before execution; reaching this is
+  /// a bug.
+  void execute(VPTransformState &State) override {
+    llvm_unreachable("VPExtendedReductionRecipe should be transformed to "
+                     "VPWidenCastRecipe + VPReductionRecipe before execution.");
+  }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+
+  /// The scalar type after extending.
+  Type *getResultType() const { return ResultTy; }
+
+  /// Is the extend ZExt?
+  bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
+
+  /// Get the opcode of the extend for VecOp.
+  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+};
+
+/// A recipe to represent inloop MulAccumulateReduction operations, performing a
+/// reduction.add on the result of vector operands (might be extended)
+/// multiplication into a scalar value, and adding the result to a chain. This
+/// recipe is abstract and needs to be lowered to concrete recipes before
+/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
+class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
+  /// Opcode of the extend for VecOp1 and VecOp2.
+  Instruction::CastOps ExtOp;
+
+  /// Non-neg flag of the extend recipe.
+  bool IsNonNeg = false;
+
+  /// The scalar type after extending.
+  Type *ResultTy;
+
+  /// For cloning VPMulAccumulateReductionRecipe.
+  VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
+      : VPReductionRecipe(
+            VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
+            {MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
+            MulAcc->getCondOp(), MulAcc->isOrdered(),
+            WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
+            MulAcc->getDebugLoc()),
+        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()) {
+    if (MulAcc->isExtended())
----------------
fhahn wrote:

Ah yes, but we should still set the result type here as well; otherwise `getResultType` may read uninitialized memory.

We have the result type available at construction through type inference, so it would be good to always pass it to the constructor, even if there's no extend, and to always copy it when cloning.

https://github.com/llvm/llvm-project/pull/137746


More information about the llvm-commits mailing list