[llvm] 6a8606e - [VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI (#131300)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 24 04:18:57 PDT 2025
Author: Luke Lau
Date: 2025-03-24T19:18:54+08:00
New Revision: 6a8606e99e399ccc68a89b1c22b396e00021d3fb
URL: https://github.com/llvm/llvm-project/commit/6a8606e99e399ccc68a89b1c22b396e00021d3fb
DIFF: https://github.com/llvm/llvm-project/commit/6a8606e99e399ccc68a89b1c22b396e00021d3fb.diff
LOG: [VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI (#131300)
VPReductionRecipes take a RecurrenceDescriptor, but only use the
RecurKind and FastMathFlags in it when executing. This patch makes the
recipe more lightweight by stripping it to only take the latter two.
The motivation for this is to simplify an upcoming patch to support
in-loop AnyOf reductions. For an in-loop AnyOf reduction we want to
create an Or reduction, and by using RecurKind we can create an
arbitrary reduction without needing a full RecurrenceDescriptor.
Added:
Modified:
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/lib/Transforms/Vectorize/VPlan.h
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
llvm/test/Transforms/LoopVectorize/vplan-printing.ll
llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 92160a421e59c..b64bac329e05d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9772,8 +9772,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (CM.blockNeedsPredicationForAnyReason(BB))
CondOp = RecipeBuilder.getBlockInMask(BB);
+ // Non-FP RdxDescs will have all fast math flags set, so clear them.
+ FastMathFlags FMFs = isa<FPMathOperator>(CurrentLinkI)
+ ? RdxDesc.getFastMathFlags()
+ : FastMathFlags();
auto *RedRecipe = new VPReductionRecipe(
- RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
+ Kind, FMFs, CurrentLinkI, PreviousLink, VecOp, CondOp,
CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
// Append the recipe to the end of the VPBasicBlock because we need to
// ensure that it comes after all of it's inputs, including CondOp.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 3059b87ae63c8..433fb247754bc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2239,22 +2239,19 @@ class VPInterleaveRecipe : public VPRecipeBase {
/// a vector operand into a scalar value, and adding the result to a chain.
/// The Operands are {ChainOp, VecOp, [Condition]}.
class VPReductionRecipe : public VPRecipeWithIRFlags {
- /// The recurrence decriptor for the reduction in question.
- const RecurrenceDescriptor &RdxDesc;
+ /// The recurrence kind for the reduction in question.
+ RecurKind RdxKind;
bool IsOrdered;
/// Whether the reduction is conditional.
bool IsConditional = false;
protected:
- VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
- Instruction *I, ArrayRef<VPValue *> Operands,
- VPValue *CondOp, bool IsOrdered, DebugLoc DL)
- : VPRecipeWithIRFlags(SC, Operands,
- isa_and_nonnull<FPMathOperator>(I)
- ? R.getFastMathFlags()
- : FastMathFlags(),
- DL),
- RdxDesc(R), IsOrdered(IsOrdered) {
+ VPReductionRecipe(const unsigned char SC, RecurKind RdxKind,
+ FastMathFlags FMFs, Instruction *I,
+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
+ bool IsOrdered, DebugLoc DL)
+ : VPRecipeWithIRFlags(SC, Operands, FMFs, DL), RdxKind(RdxKind),
+ IsOrdered(IsOrdered) {
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
@@ -2263,19 +2260,19 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
}
public:
- VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
+ VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
bool IsOrdered, DebugLoc DL = {})
- : VPReductionRecipe(VPDef::VPReductionSC, R, I,
+ : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, I,
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
IsOrdered, DL) {}
~VPReductionRecipe() override = default;
VPReductionRecipe *clone() override {
- return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
- getVecOp(), getCondOp(), IsOrdered,
- getDebugLoc());
+ return new VPReductionRecipe(RdxKind, getFastMathFlags(),
+ getUnderlyingInstr(), getChainOp(), getVecOp(),
+ getCondOp(), IsOrdered, getDebugLoc());
}
static inline bool classof(const VPRecipeBase *R) {
@@ -2301,10 +2298,8 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
VPSlotTracker &SlotTracker) const override;
#endif
- /// Return the recurrence decriptor for the in-loop reduction.
- const RecurrenceDescriptor &getRecurrenceDescriptor() const {
- return RdxDesc;
- }
+ /// Return the recurrence kind for the in-loop reduction.
+ RecurKind getRecurrenceKind() const { return RdxKind; }
/// Return true if the in-loop reduction is ordered.
bool isOrdered() const { return IsOrdered; };
/// Return true if the in-loop reduction is conditional.
@@ -2328,7 +2323,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp,
DebugLoc DL = {})
: VPReductionRecipe(
- VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
+ VPDef::VPReductionEVLSC, R.getRecurrenceKind(),
+ R.getFastMathFlags(),
cast_or_null<Instruction>(R.getUnderlyingValue()),
ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
R.isOrdered(), DL) {}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c7190b3187d94..cdef7972f3bdc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2300,7 +2300,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
void VPReductionRecipe::execute(VPTransformState &State) {
assert(!State.Lane && "Reduction being replicated.");
Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
- RecurKind Kind = RdxDesc.getRecurrenceKind();
+ RecurKind Kind = getRecurrenceKind();
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
"In-loop AnyOf reductions aren't currently supported");
// Propagate the fast-math flags carried by the underlying instruction.
@@ -2313,8 +2313,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType());
Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType();
- Value *Start =
- getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags());
+ Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags());
if (State.VF.isVector())
Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start);
@@ -2329,18 +2328,19 @@ void VPReductionRecipe::execute(VPTransformState &State) {
createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
else
NewRed = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind),
+ PrevInChain, NewVecOp);
PrevInChain = NewRed;
NextInChain = NewRed;
} else {
PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
- NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
- NewRed, PrevInChain);
+ NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain);
else
NextInChain = State.Builder.CreateBinOp(
- (Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain);
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed,
+ PrevInChain);
}
State.set(this, NextInChain, /*IsScalar*/ true);
}
@@ -2351,10 +2351,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
auto &Builder = State.Builder;
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
Builder.setFastMathFlags(getFastMathFlags());
- RecurKind Kind = RdxDesc.getRecurrenceKind();
+ RecurKind Kind = getRecurrenceKind();
Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
Value *VecOp = State.get(getVecOp());
Value *EVL = State.get(getEVL(), VPLane(0));
@@ -2377,18 +2376,19 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
else
- NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(),
- NewRed, Prev);
+ NewRed = Builder.CreateBinOp(
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind), NewRed,
+ Prev);
}
State.set(this, NewRed, /*IsScalar*/ true);
}
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
- RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+ RecurKind RdxKind = getRecurrenceKind();
Type *ElementTy = Ctx.Types.inferScalarType(this);
auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
- unsigned Opcode = RdxDesc.getOpcode();
+ unsigned Opcode = RecurrenceDescriptor::getOpcode(RdxKind);
FastMathFlags FMFs = getFastMathFlags();
// TODO: Support any-of and in-loop reductions.
@@ -2401,9 +2401,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
"In-loop reduction not implemented in VPlan-based cost model currently.");
- assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
- "Inferred type and recurrence type mismatch.");
-
// Cost = Reduction cost + BinOp cost
InstructionCost Cost =
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
@@ -2426,28 +2423,30 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
printFlags(O);
- O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << " reduce."
+ << Instruction::getOpcodeName(
+ RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+ << " (";
getVecOp()->printAsOperand(O, SlotTracker);
if (isConditional()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (RdxDesc.IntermediateStore)
- O << " (with final reduction value stored in invariant address sank "
- "outside of loop)";
}
void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
O << Indent << "REDUCE ";
printAsOperand(O, SlotTracker);
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
printFlags(O);
- O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
+ O << " vp.reduce."
+ << Instruction::getOpcodeName(
+ RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
+ << " (";
getVecOp()->printAsOperand(O, SlotTracker);
O << ", ";
getEVL()->printAsOperand(O, SlotTracker);
@@ -2456,9 +2455,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
- if (RdxDesc.IntermediateStore)
- O << " (with final reduction value stored in invariant address sank "
- "outside of loop)";
}
#endif
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 207cb8b4a0d30..9274915ff46a2 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -234,7 +234,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[IV]]>
; CHECK-NEXT: vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
; CHECK-NEXT: WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
-; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>) (with final reduction value stored in invariant address sank outside of loop)
+; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
; CHECK-NEXT: EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
; CHECK-NEXT: EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
; CHECK-NEXT: No successors
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index ca1e48290f25b..c51ab5df3ff07 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -1170,8 +1170,8 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
- VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
- VecOp, false);
+ VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
+ CondOp, VecOp, false);
EXPECT_FALSE(Recipe.mayHaveSideEffects());
EXPECT_FALSE(Recipe.mayReadFromMemory());
EXPECT_FALSE(Recipe.mayWriteToMemory());
@@ -1185,8 +1185,8 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
- VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
- VecOp, false);
+ VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
+ CondOp, VecOp, false);
VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4));
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
EXPECT_FALSE(EVLRecipe.mayHaveSideEffects());
@@ -1540,8 +1540,8 @@ TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
- VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
- false);
+ VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
+ CondOp, VecOp, false);
EXPECT_TRUE(isa<VPUser>(&Recipe));
VPRecipeBase *BaseR = &Recipe;
EXPECT_TRUE(isa<VPUser>(BaseR));
@@ -1555,8 +1555,8 @@ TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
- VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp,
- false);
+ VPReductionRecipe Recipe(RecurKind::Add, FastMathFlags(), Add, ChainOp,
+ CondOp, VecOp, false);
VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0));
VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
EXPECT_TRUE(isa<VPUser>(&EVLRecipe));
More information about the llvm-commits
mailing list