[llvm] [LV] Simplify the chain traversal in `getScaledReductions()` (NFCI) (PR #184830)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 9 02:51:43 PDT 2026
================
@@ -6159,33 +6143,72 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
if (!BinOp || !Instruction::isBinaryOp(BinOp->getOpcode()) ||
!MatchExtends(BinOp->operands()))
- return false;
+ return std::nullopt;
} else if (match(UpdateR, m_Add(m_VPValue(), m_VPValue())) ||
match(UpdateR, m_FAdd(m_VPValue(), m_VPValue()))) {
- // We already know the operands for Update are Op and PhiOp.
+ // We already know Op is an operand of UpdateR.
if (!MatchExtends({Op}))
- return false;
+ return std::nullopt;
BinOp = UpdateR;
} else {
- return false;
+ return std::nullopt;
}
+ return ExtendedReductionOperand{BinOp, CastRecipes};
+}
+
+/// Examines reduction operations to see if the target can use a cheaper
+/// operation with a wider per-iteration input VF and narrower PHI VF.
+/// This works backwards from the \p ExitValue, examining each operation
+/// in the reduction.
+static bool
+getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
+ SmallVectorImpl<VPPartialReductionChain> &Chains,
+ VPCostContext &CostCtx, VFRange &Range) {
+ RecurKind RK = RedPhiR->getRecurrenceKind();
Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR);
TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
- Type *ExtOpType =
- CostCtx.Types.inferScalarType(CastRecipes[0]->getOperand(0));
- TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
- if (!PHISize.hasKnownScalarFactor(ASize))
- return false;
- RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
- VPPartialReductionChain Chain(
- {UpdateR, CastRecipes[0], CastRecipes[1], BinOp,
- static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
- if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
- return false;
+ VPValue *CurrentValue = ExitValue;
+ while (CurrentValue != RedPhiR) {
+ auto *UpdateR = dyn_cast<VPWidenRecipe>(CurrentValue);
+ if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
+ return false;
+
+ VPValue *Op = UpdateR->getOperand(0);
+ VPValue *PrevValue = UpdateR->getOperand(1);
+
+ // Find the extended operand. The other operand (PrevValue) is the next link
+ // in the reduction chain.
+ std::optional<ExtendedReductionOperand> ExtendedOp =
+ matchExtendedReductionOperand(UpdateR, Op);
+ if (!ExtendedOp) {
+ ExtendedOp = matchExtendedReductionOperand(UpdateR, PrevValue);
+ if (!ExtendedOp)
+ return false;
----------------
fhahn wrote:
Would be good to also add the test case, to guard against regressions
https://github.com/llvm/llvm-project/pull/184830
More information about the llvm-commits
mailing list