[llvm] [LV] Simplify the chain traversal in `getScaledReductions()` (NFCI) (PR #184830)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 6 03:18:39 PST 2026
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/184830
>From d49b92a7def08427535fe038ee2cc05c30a9add9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Mar 2026 17:13:32 +0000
Subject: [PATCH 1/4] [LV] Simplify the chain traversal in
`getScaledReductions()` (NFCI)
I found the logic of this function quite hard to reason about. This
patch attempts to rectify this by separating the matching of an extended
reduction operand from the traversal of the reduction chain.
- `matchExtendedReductionOperand()` contains all the logic to match an
extended operand.
- `getScaledReductions()` validates each operation in the chain,
starting from the exit value and walking backwards through the operand
that is not extended (i.e. the previous link in the chain).
---
.../Transforms/Vectorize/VPlanTransforms.cpp | 126 ++++++++++--------
1 file changed, 71 insertions(+), 55 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a24a483ab5e32..a09dc1ddf8115 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6049,22 +6049,20 @@ static bool isValidPartialReduction(const VPPartialReductionChain &Chain,
Range);
}
-/// Examines reduction operations to see if the target can use a cheaper
-/// operation with a wider per-iteration input VF and narrower PHI VF.
-/// Recursively calls itself to identify chained scaled reductions.
-/// Returns true if this invocation added an entry to Chains, otherwise false.
-static bool
-getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
- SmallVectorImpl<VPPartialReductionChain> &Chains,
- VPCostContext &CostCtx, VFRange &Range) {
- auto *UpdateR = dyn_cast<VPWidenRecipe>(PrevValue);
- if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
- return false;
+/// Holds the binary operation used to compute the extended operand and the
+/// casts that feed into it.
+struct ExtendedReductionOperand {
+ VPWidenRecipe *BinOp = nullptr;
+ std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
+};
- VPValue *Op = UpdateR->getOperand(0);
- VPValue *PhiOp = UpdateR->getOperand(1);
- if (Op == RedPhiR)
- std::swap(Op, PhiOp);
+/// Checks if \p Op (which is an operand of \p UpdateR) is an extended reduction
+/// operand. This is an operand where the source of the value (e.g. a load) has
+/// been extended (sext, zext, or fpext) before it is used in the reduction.
+static std::optional<ExtendedReductionOperand>
+matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op) {
+ assert(is_contained(UpdateR->operands(), Op) &&
+ "Op should be operand of UpdateR");
// If Op is an extend, then it's still a valid partial reduction if the
// extended mul fulfills the other requirements.
@@ -6076,36 +6074,16 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
match(Op, m_FPExt(m_FMul(m_VPValue(), m_VPValue())))) {
auto *CastRecipe = dyn_cast<VPWidenCastRecipe>(Op);
if (!CastRecipe)
- return false;
+ return std::nullopt;
auto CastOp = static_cast<Instruction::CastOps>(CastRecipe->getOpcode());
OuterExtKind = TTI::getPartialReductionExtendKind(CastOp);
Op = CastRecipe->getOperand(0);
}
- // Try and get a scaled reduction from the first non-phi operand.
- // If one is found, we use the discovered reduction instruction in
- // place of the accumulator for costing.
- if (getScaledReductions(RedPhiR, Op, Chains, CostCtx, Range)) {
- Op = UpdateR->getOperand(0);
- PhiOp = UpdateR->getOperand(1);
- if (Op == Chains.rbegin()->ReductionBinOp)
- std::swap(Op, PhiOp);
- assert(PhiOp == Chains.rbegin()->ReductionBinOp &&
- "PhiOp must be the chain value");
- assert(CostCtx.Types.inferScalarType(RedPhiR) ==
- CostCtx.Types.inferScalarType(PhiOp) &&
- "Unexpected type for chain values");
- } else if (RedPhiR != PhiOp) {
- // If neither operand of this instruction is the reduction PHI node or a
- // link in the reduction chain, then this is just an operand to the chain
- // and not a link in the chain itself.
- return false;
- }
-
// If the update is a binary op, check both of its operands to see if
// they are extends. Otherwise, see if the update comes directly from an
// extend.
- VPWidenCastRecipe *CastRecipes[2] = {nullptr};
+ std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
// Match extends and populate CastRecipes. Returns false if matching fails.
auto MatchExtends = [OuterExtKind,
@@ -6144,7 +6122,7 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
auto *BinOp = dyn_cast<VPWidenRecipe>(Op);
if (BinOp && Instruction::isBinaryOp(BinOp->getOpcode())) {
if (!BinOp->hasOneUse())
- return false;
+ return std::nullopt;
// Handle neg(binop(ext, ext)) pattern.
VPValue *OtherOp = nullptr;
@@ -6153,33 +6131,71 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
if (!BinOp || !Instruction::isBinaryOp(BinOp->getOpcode()) ||
!MatchExtends(BinOp->operands()))
- return false;
+ return std::nullopt;
} else if (match(UpdateR, m_Add(m_VPValue(), m_VPValue())) ||
match(UpdateR, m_FAdd(m_VPValue(), m_VPValue()))) {
- // We already know the operands for Update are Op and PhiOp.
+ // We already know Op is an operand of UpdateR.
if (!MatchExtends({Op}))
- return false;
+ return std::nullopt;
BinOp = UpdateR;
} else {
- return false;
+ return std::nullopt;
}
+ return ExtendedReductionOperand{BinOp, CastRecipes};
+}
+
+/// Examines reduction operations to see if the target can use a cheaper
+/// operation with a wider per-iteration input VF and narrower PHI VF.
+/// This works backwards from the \p ExitValue, examining each operation
+/// in the reduction.
+static bool
+getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
+ SmallVectorImpl<VPPartialReductionChain> &Chains,
+ VPCostContext &CostCtx, VFRange &Range) {
Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR);
- TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
- Type *ExtOpType =
- CostCtx.Types.inferScalarType(CastRecipes[0]->getOperand(0));
- TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
- if (!PHISize.hasKnownScalarFactor(ASize))
- return false;
- RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
- VPPartialReductionChain Chain(
- {UpdateR, CastRecipes[0], CastRecipes[1], BinOp,
- static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
- if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
- return false;
+ VPValue *CurrentValue = ExitValue;
+ while (CurrentValue != RedPhiR) {
+ auto *UpdateR = dyn_cast<VPWidenRecipe>(CurrentValue);
+ if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
+ return false;
+
+ VPValue *Op = UpdateR->getOperand(0);
+ VPValue *PrevValue = UpdateR->getOperand(1);
+
+ // Find the extended operand. The other operand (PrevValue) is the next link
+ // in the reduction chain.
+ auto ExtendedOp = matchExtendedReductionOperand(UpdateR, Op);
+ if (!ExtendedOp) {
+ ExtendedOp = matchExtendedReductionOperand(UpdateR, PrevValue);
+ if (!ExtendedOp)
+ return false;
+ std::swap(Op, PrevValue);
+ }
+
+ TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
+ Type *ExtOpType = CostCtx.Types.inferScalarType(
+ ExtendedOp->CastRecipes[0]->getOperand(0));
+ TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
+ if (!PHISize.hasKnownScalarFactor(ASize))
+ return false;
+
+ RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
+ VPPartialReductionChain Chain(
+ {UpdateR, ExtendedOp->CastRecipes[0], ExtendedOp->CastRecipes[1],
+ ExtendedOp->BinOp,
+ static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
+ if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
+ return false;
+
+ Chains.push_back(Chain);
+ CurrentValue = PrevValue;
+ }
- Chains.push_back(Chain);
+ // The chains were collected by traversing the chain backwards from the exit
+ // value. Reverse them to they are in program order.
+ std::reverse(Chains.begin(), Chains.end());
return true;
}
} // namespace
>From 22e021c3d9273f451d59865fa580a55f53435ec2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Mar 2026 17:35:11 +0000
Subject: [PATCH 2/4] Fix typo
---
llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a09dc1ddf8115..2ec582c18ebba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6193,8 +6193,8 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
CurrentValue = PrevValue;
}
- // The chains were collected by traversing the chain backwards from the exit
- // value. Reverse them to they are in program order.
+ // The chains were collected by traversing backwards from the exit value.
+ // Reverse the chains so they are in program order.
std::reverse(Chains.begin(), Chains.end());
return true;
}
>From 89ca4658823f76464c65a680512f5be38594afd9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Mar 2026 19:46:18 +0000
Subject: [PATCH 3/4] Fixups
---
llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 2ec582c18ebba..d09486bce8079 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6153,7 +6153,9 @@ static bool
getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
SmallVectorImpl<VPPartialReductionChain> &Chains,
VPCostContext &CostCtx, VFRange &Range) {
+ RecurKind RK = RedPhiR->getRecurrenceKind();
Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR);
+ TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
VPValue *CurrentValue = ExitValue;
while (CurrentValue != RedPhiR) {
@@ -6174,18 +6176,16 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
std::swap(Op, PrevValue);
}
- TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
- Type *ExtOpType = CostCtx.Types.inferScalarType(
+ Type *ExtSrcType = CostCtx.Types.inferScalarType(
ExtendedOp->CastRecipes[0]->getOperand(0));
- TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
- if (!PHISize.hasKnownScalarFactor(ASize))
+ TypeSize ExtSrcSize = ExtSrcType->getPrimitiveSizeInBits();
+ if (!PHISize.hasKnownScalarFactor(ExtSrcSize))
return false;
- RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
VPPartialReductionChain Chain(
{UpdateR, ExtendedOp->CastRecipes[0], ExtendedOp->CastRecipes[1],
ExtendedOp->BinOp,
- static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
+ static_cast<unsigned>(PHISize.getKnownScalarFactor(ExtSrcSize)), RK});
if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
return false;
>From f08f2b612ea871b88fefdbb01763ef05695c7637 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 6 Mar 2026 11:17:47 +0000
Subject: [PATCH 4/4] Fixups
---
.../Transforms/Vectorize/VPlanTransforms.cpp | 21 +++++++++++++++----
1 file changed, 17 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d09486bce8079..83c38a43f3694 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6053,12 +6053,24 @@ static bool isValidPartialReduction(const VPPartialReductionChain &Chain,
/// casts that feed into it.
struct ExtendedReductionOperand {
VPWidenRecipe *BinOp = nullptr;
- std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
+ // Note: The second cast recipe may be null.
+ std::array<VPWidenCastRecipe *, 2> CastRecipes = {};
};
/// Checks if \p Op (which is an operand of \p UpdateR) is an extended reduction
/// operand. This is an operand where the source of the value (e.g. a load) has
/// been extended (sext, zext, or fpext) before it is used in the reduction.
+///
+/// Possible forms matched by this function:
+/// - UpdateR(PrevValue, ext(...))
+/// - UpdateR(PrevValue, BinOp(ext(...), ext(...)))
+/// - UpdateR(PrevValue, BinOp(ext(...), Constant))
+/// - UpdateR(PrevValue, neg(BinOp(ext(...), ext(...))))
+/// - UpdateR(PrevValue, neg(BinOp(ext(...), Constant)))
+/// - UpdateR(PrevValue, ext(mul(ext(...), ext(...))))
+/// - UpdateR(PrevValue, ext(mul(ext(...), Constant)))
+///
+/// Note: The second operand of UpdateR corresponds to \p Op in the examples.
static std::optional<ExtendedReductionOperand>
matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op) {
assert(is_contained(UpdateR->operands(), Op) &&
@@ -6072,7 +6084,7 @@ matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op) {
std::optional<TTI::PartialReductionExtendKind> OuterExtKind;
if (match(Op, m_ZExtOrSExt(m_Mul(m_VPValue(), m_VPValue()))) ||
match(Op, m_FPExt(m_FMul(m_VPValue(), m_VPValue())))) {
- auto *CastRecipe = dyn_cast<VPWidenCastRecipe>(Op);
+ auto *CastRecipe = cast<VPWidenCastRecipe>(Op);
if (!CastRecipe)
return std::nullopt;
auto CastOp = static_cast<Instruction::CastOps>(CastRecipe->getOpcode());
@@ -6083,7 +6095,7 @@ matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op) {
// If the update is a binary op, check both of its operands to see if
// they are extends. Otherwise, see if the update comes directly from an
// extend.
- std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
+ std::array<VPWidenCastRecipe *, 2> CastRecipes = {};
// Match extends and populate CastRecipes. Returns false if matching fails.
auto MatchExtends = [OuterExtKind,
@@ -6168,7 +6180,8 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
// Find the extended operand. The other operand (PrevValue) is the next link
// in the reduction chain.
- auto ExtendedOp = matchExtendedReductionOperand(UpdateR, Op);
+ std::optional<ExtendedReductionOperand> ExtendedOp =
+ matchExtendedReductionOperand(UpdateR, Op);
if (!ExtendedOp) {
ExtendedOp = matchExtendedReductionOperand(UpdateR, PrevValue);
if (!ExtendedOp)
More information about the llvm-commits
mailing list