[llvm] [LV] Simplify the chain traversal in `getScaledReductions()` (NFCI) (PR #184830)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 5 11:47:10 PST 2026


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/184830

>From d49b92a7def08427535fe038ee2cc05c30a9add9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Mar 2026 17:13:32 +0000
Subject: [PATCH 1/3] [LV] Simplify the chain traversal in
 `getScaledReductions()` (NFCI)

I found the logic of this function quite hard to reason about. This
patch attempts to rectify this by separating the matching of an extended
reduction operand from the traversal of the reduction chain.

- `matchExtendedReductionOperand()` contains all the logic to match an
  extended operand.
- `getScaledReductions()` validates each operation in the chain,
  starting from the exit value and walking backwards through the operand
  that is not extended.
---
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 126 ++++++++++--------
 1 file changed, 71 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a24a483ab5e32..a09dc1ddf8115 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6049,22 +6049,20 @@ static bool isValidPartialReduction(const VPPartialReductionChain &Chain,
       Range);
 }
 
-/// Examines reduction operations to see if the target can use a cheaper
-/// operation with a wider per-iteration input VF and narrower PHI VF.
-/// Recursively calls itself to identify chained scaled reductions.
-/// Returns true if this invocation added an entry to Chains, otherwise false.
-static bool
-getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
-                    SmallVectorImpl<VPPartialReductionChain> &Chains,
-                    VPCostContext &CostCtx, VFRange &Range) {
-  auto *UpdateR = dyn_cast<VPWidenRecipe>(PrevValue);
-  if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
-    return false;
+/// Holds the binary operation used to compute the extended operand and the
+/// casts that feed into it.
+struct ExtendedReductionOperand {
+  VPWidenRecipe *BinOp = nullptr;
+  std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
+};
 
-  VPValue *Op = UpdateR->getOperand(0);
-  VPValue *PhiOp = UpdateR->getOperand(1);
-  if (Op == RedPhiR)
-    std::swap(Op, PhiOp);
+/// Checks if \p Op (which is an operand of \p UpdateR) is an extended reduction
+/// operand. This is an operand where the source of the value (e.g. a load) has
+/// been extended (sext, zext, or fpext) before it is used in the reduction.
+static std::optional<ExtendedReductionOperand>
+matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op) {
+  assert(is_contained(UpdateR->operands(), Op) &&
+         "Op should be operand of UpdateR");
 
   // If Op is an extend, then it's still a valid partial reduction if the
   // extended mul fulfills the other requirements.
@@ -6076,36 +6074,16 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
       match(Op, m_FPExt(m_FMul(m_VPValue(), m_VPValue())))) {
     auto *CastRecipe = dyn_cast<VPWidenCastRecipe>(Op);
     if (!CastRecipe)
-      return false;
+      return std::nullopt;
     auto CastOp = static_cast<Instruction::CastOps>(CastRecipe->getOpcode());
     OuterExtKind = TTI::getPartialReductionExtendKind(CastOp);
     Op = CastRecipe->getOperand(0);
   }
 
-  // Try and get a scaled reduction from the first non-phi operand.
-  // If one is found, we use the discovered reduction instruction in
-  // place of the accumulator for costing.
-  if (getScaledReductions(RedPhiR, Op, Chains, CostCtx, Range)) {
-    Op = UpdateR->getOperand(0);
-    PhiOp = UpdateR->getOperand(1);
-    if (Op == Chains.rbegin()->ReductionBinOp)
-      std::swap(Op, PhiOp);
-    assert(PhiOp == Chains.rbegin()->ReductionBinOp &&
-           "PhiOp must be the chain value");
-    assert(CostCtx.Types.inferScalarType(RedPhiR) ==
-               CostCtx.Types.inferScalarType(PhiOp) &&
-           "Unexpected type for chain values");
-  } else if (RedPhiR != PhiOp) {
-    // If neither operand of this instruction is the reduction PHI node or a
-    // link in the reduction chain, then this is just an operand to the chain
-    // and not a link in the chain itself.
-    return false;
-  }
-
   // If the update is a binary op, check both of its operands to see if
   // they are extends. Otherwise, see if the update comes directly from an
   // extend.
-  VPWidenCastRecipe *CastRecipes[2] = {nullptr};
+  std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
 
   // Match extends and populate CastRecipes. Returns false if matching fails.
   auto MatchExtends = [OuterExtKind,
@@ -6144,7 +6122,7 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
   auto *BinOp = dyn_cast<VPWidenRecipe>(Op);
   if (BinOp && Instruction::isBinaryOp(BinOp->getOpcode())) {
     if (!BinOp->hasOneUse())
-      return false;
+      return std::nullopt;
 
     // Handle neg(binop(ext, ext)) pattern.
     VPValue *OtherOp = nullptr;
@@ -6153,33 +6131,71 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
 
     if (!BinOp || !Instruction::isBinaryOp(BinOp->getOpcode()) ||
         !MatchExtends(BinOp->operands()))
-      return false;
+      return std::nullopt;
   } else if (match(UpdateR, m_Add(m_VPValue(), m_VPValue())) ||
              match(UpdateR, m_FAdd(m_VPValue(), m_VPValue()))) {
-    // We already know the operands for Update are Op and PhiOp.
+    // We already know Op is an operand of UpdateR.
     if (!MatchExtends({Op}))
-      return false;
+      return std::nullopt;
     BinOp = UpdateR;
   } else {
-    return false;
+    return std::nullopt;
   }
 
+  return ExtendedReductionOperand{BinOp, CastRecipes};
+}
+
+/// Examines reduction operations to see if the target can use a cheaper
+/// operation with a wider per-iteration input VF and narrower PHI VF.
+/// This works backwards from the \p ExitValue, examining each operation in
+/// the reduction.
+static bool
+getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
+                    SmallVectorImpl<VPPartialReductionChain> &Chains,
+                    VPCostContext &CostCtx, VFRange &Range) {
   Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR);
-  TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
-  Type *ExtOpType =
-      CostCtx.Types.inferScalarType(CastRecipes[0]->getOperand(0));
-  TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
-  if (!PHISize.hasKnownScalarFactor(ASize))
-    return false;
 
-  RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
-  VPPartialReductionChain Chain(
-      {UpdateR, CastRecipes[0], CastRecipes[1], BinOp,
-       static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
-  if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
-    return false;
+  VPValue *CurrentValue = ExitValue;
+  while (CurrentValue != RedPhiR) {
+    auto *UpdateR = dyn_cast<VPWidenRecipe>(CurrentValue);
+    if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
+      return false;
+
+    VPValue *Op = UpdateR->getOperand(0);
+    VPValue *PrevValue = UpdateR->getOperand(1);
+
+    // Find the extended operand. The other operand (PrevValue) is the next link
+    // in the reduction chain.
+    auto ExtendedOp = matchExtendedReductionOperand(UpdateR, Op);
+    if (!ExtendedOp) {
+      ExtendedOp = matchExtendedReductionOperand(UpdateR, PrevValue);
+      if (!ExtendedOp)
+        return false;
+      std::swap(Op, PrevValue);
+    }
+
+    TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
+    Type *ExtOpType = CostCtx.Types.inferScalarType(
+        ExtendedOp->CastRecipes[0]->getOperand(0));
+    TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
+    if (!PHISize.hasKnownScalarFactor(ASize))
+      return false;
+
+    RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
+    VPPartialReductionChain Chain(
+        {UpdateR, ExtendedOp->CastRecipes[0], ExtendedOp->CastRecipes[1],
+         ExtendedOp->BinOp,
+         static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
+    if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
+      return false;
+
+    Chains.push_back(Chain);
+    CurrentValue = PrevValue;
+  }
 
-  Chains.push_back(Chain);
+  // The chains were collected by traversing the chain backwards from the exit
+  // value. Reverse them to they are in program order.
+  std::reverse(Chains.begin(), Chains.end());
   return true;
 }
 } // namespace

>From 22e021c3d9273f451d59865fa580a55f53435ec2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Mar 2026 17:35:11 +0000
Subject: [PATCH 2/3] Fix typo

---
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a09dc1ddf8115..2ec582c18ebba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6193,8 +6193,8 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
     CurrentValue = PrevValue;
   }
 
-  // The chains were collected by traversing the chain backwards from the exit
-  // value. Reverse them to they are in program order.
+  // The chains were collected by traversing backwards from the exit value.
+  // Reverse the chains so they are in program order.
   std::reverse(Chains.begin(), Chains.end());
   return true;
 }

>From 89ca4658823f76464c65a680512f5be38594afd9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Mar 2026 19:46:18 +0000
Subject: [PATCH 3/3] Fixups

---
 llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 2ec582c18ebba..d09486bce8079 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6153,7 +6153,9 @@ static bool
 getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
                     SmallVectorImpl<VPPartialReductionChain> &Chains,
                     VPCostContext &CostCtx, VFRange &Range) {
+  RecurKind RK = RedPhiR->getRecurrenceKind();
   Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR);
+  TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
 
   VPValue *CurrentValue = ExitValue;
   while (CurrentValue != RedPhiR) {
@@ -6174,18 +6176,16 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
       std::swap(Op, PrevValue);
     }
 
-    TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
-    Type *ExtOpType = CostCtx.Types.inferScalarType(
+    Type *ExtSrcType = CostCtx.Types.inferScalarType(
         ExtendedOp->CastRecipes[0]->getOperand(0));
-    TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
-    if (!PHISize.hasKnownScalarFactor(ASize))
+    TypeSize ExtSrcSize = ExtSrcType->getPrimitiveSizeInBits();
+    if (!PHISize.hasKnownScalarFactor(ExtSrcSize))
       return false;
 
-    RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
     VPPartialReductionChain Chain(
         {UpdateR, ExtendedOp->CastRecipes[0], ExtendedOp->CastRecipes[1],
          ExtendedOp->BinOp,
-         static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
+         static_cast<unsigned>(PHISize.getKnownScalarFactor(ExtSrcSize)), RK});
     if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
       return false;
 



More information about the llvm-commits mailing list