[llvm] [LV] Simplify the chain traversal in `getScaledReductions()` (NFCI) (PR #184830)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 5 09:55:04 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

I found the logic of this function quite hard to reason about. This patch attempts to rectify this by separating the matching of an extended reduction operand from the traversal of the reduction chain.

- `matchExtendedReductionOperand()` contains all the logic to match an extended operand.
- `getScaledReductions()` validates each operation in the chain, starting from the exit value and walking backwards through the operand that is not extended.

---
Full diff: https://github.com/llvm/llvm-project/pull/184830.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+71-55) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index a24a483ab5e32..2ec582c18ebba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -6049,22 +6049,20 @@ static bool isValidPartialReduction(const VPPartialReductionChain &Chain,
       Range);
 }
 
-/// Examines reduction operations to see if the target can use a cheaper
-/// operation with a wider per-iteration input VF and narrower PHI VF.
-/// Recursively calls itself to identify chained scaled reductions.
-/// Returns true if this invocation added an entry to Chains, otherwise false.
-static bool
-getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
-                    SmallVectorImpl<VPPartialReductionChain> &Chains,
-                    VPCostContext &CostCtx, VFRange &Range) {
-  auto *UpdateR = dyn_cast<VPWidenRecipe>(PrevValue);
-  if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
-    return false;
+/// Result of matchExtendedReductionOperand(): the binary op computing the
+/// extended operand (possibly the update itself) and the extends feeding it.
+struct ExtendedReductionOperand {
+  VPWidenRecipe *BinOp = nullptr;
+  std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
+};
 
-  VPValue *Op = UpdateR->getOperand(0);
-  VPValue *PhiOp = UpdateR->getOperand(1);
-  if (Op == RedPhiR)
-    std::swap(Op, PhiOp);
+/// Checks if \p Op (which is an operand of \p UpdateR) is an extended reduction
+/// operand. This is an operand where the source of the value (e.g. a load) has
+/// been extended (sext, zext, or fpext) before it is used in the reduction.
+static std::optional<ExtendedReductionOperand>
+matchExtendedReductionOperand(VPWidenRecipe *UpdateR, VPValue *Op) {
+  assert(is_contained(UpdateR->operands(), Op) &&
+         "Op should be operand of UpdateR");
 
   // If Op is an extend, then it's still a valid partial reduction if the
   // extended mul fulfills the other requirements.
@@ -6076,36 +6074,16 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
       match(Op, m_FPExt(m_FMul(m_VPValue(), m_VPValue())))) {
     auto *CastRecipe = dyn_cast<VPWidenCastRecipe>(Op);
     if (!CastRecipe)
-      return false;
+      return std::nullopt;
     auto CastOp = static_cast<Instruction::CastOps>(CastRecipe->getOpcode());
     OuterExtKind = TTI::getPartialReductionExtendKind(CastOp);
     Op = CastRecipe->getOperand(0);
   }
 
-  // Try and get a scaled reduction from the first non-phi operand.
-  // If one is found, we use the discovered reduction instruction in
-  // place of the accumulator for costing.
-  if (getScaledReductions(RedPhiR, Op, Chains, CostCtx, Range)) {
-    Op = UpdateR->getOperand(0);
-    PhiOp = UpdateR->getOperand(1);
-    if (Op == Chains.rbegin()->ReductionBinOp)
-      std::swap(Op, PhiOp);
-    assert(PhiOp == Chains.rbegin()->ReductionBinOp &&
-           "PhiOp must be the chain value");
-    assert(CostCtx.Types.inferScalarType(RedPhiR) ==
-               CostCtx.Types.inferScalarType(PhiOp) &&
-           "Unexpected type for chain values");
-  } else if (RedPhiR != PhiOp) {
-    // If neither operand of this instruction is the reduction PHI node or a
-    // link in the reduction chain, then this is just an operand to the chain
-    // and not a link in the chain itself.
-    return false;
-  }
-
   // If the update is a binary op, check both of its operands to see if
   // they are extends. Otherwise, see if the update comes directly from an
   // extend.
-  VPWidenCastRecipe *CastRecipes[2] = {nullptr};
+  std::array<VPWidenCastRecipe *, 2> CastRecipes = {nullptr};
 
   // Match extends and populate CastRecipes. Returns false if matching fails.
   auto MatchExtends = [OuterExtKind,
@@ -6144,7 +6122,7 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
   auto *BinOp = dyn_cast<VPWidenRecipe>(Op);
   if (BinOp && Instruction::isBinaryOp(BinOp->getOpcode())) {
     if (!BinOp->hasOneUse())
-      return false;
+      return std::nullopt;
 
     // Handle neg(binop(ext, ext)) pattern.
     VPValue *OtherOp = nullptr;
@@ -6153,33 +6131,71 @@ getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue,
 
     if (!BinOp || !Instruction::isBinaryOp(BinOp->getOpcode()) ||
         !MatchExtends(BinOp->operands()))
-      return false;
+      return std::nullopt;
   } else if (match(UpdateR, m_Add(m_VPValue(), m_VPValue())) ||
              match(UpdateR, m_FAdd(m_VPValue(), m_VPValue()))) {
-    // We already know the operands for Update are Op and PhiOp.
+    // We already know Op is an operand of UpdateR.
     if (!MatchExtends({Op}))
-      return false;
+      return std::nullopt;
     BinOp = UpdateR;
   } else {
-    return false;
+    return std::nullopt;
   }
 
+  return ExtendedReductionOperand{BinOp, CastRecipes};
+}
+
+/// Examines reduction operations to see if the target can use a cheaper
+/// operation with a wider per-iteration input VF and narrower PHI VF. Walks
+/// backwards from \p ExitValue to \p RedPhiR, appending one entry to \p Chains
+/// per link. Returns true, with the new entries in program order, on success.
+static bool
+getScaledReductions(VPReductionPHIRecipe *RedPhiR, VPValue *ExitValue,
+                    SmallVectorImpl<VPPartialReductionChain> &Chains,
+                    VPCostContext &CostCtx, VFRange &Range) {
   Type *PhiType = CostCtx.Types.inferScalarType(RedPhiR);
-  TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
-  Type *ExtOpType =
-      CostCtx.Types.inferScalarType(CastRecipes[0]->getOperand(0));
-  TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
-  if (!PHISize.hasKnownScalarFactor(ASize))
-    return false;

-  RecurKind RK = cast<VPReductionPHIRecipe>(RedPhiR)->getRecurrenceKind();
-  VPPartialReductionChain Chain(
-      {UpdateR, CastRecipes[0], CastRecipes[1], BinOp,
-       static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
-  if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
-    return false;
+  TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
+  RecurKind RK = RedPhiR->getRecurrenceKind();
+
+  const size_t FirstNewChain = Chains.size();
+  VPValue *CurrentValue = ExitValue;
+  while (CurrentValue != RedPhiR) {
+    auto *UpdateR = dyn_cast<VPWidenRecipe>(CurrentValue);
+    if (!UpdateR || !Instruction::isBinaryOp(UpdateR->getOpcode()))
+      return false;
+
+    VPValue *Op = UpdateR->getOperand(0);
+    VPValue *PrevValue = UpdateR->getOperand(1);
+
+    // Find the extended operand; the other one (PrevValue) is the next link.
+    auto ExtendedOp = matchExtendedReductionOperand(UpdateR, Op);
+    if (!ExtendedOp) {
+      ExtendedOp = matchExtendedReductionOperand(UpdateR, PrevValue);
+      if (!ExtendedOp)
+        return false;
+      std::swap(Op, PrevValue);
+    }
+
+    Type *ExtOpType = CostCtx.Types.inferScalarType(
+        ExtendedOp->CastRecipes[0]->getOperand(0));
+    TypeSize ASize = ExtOpType->getPrimitiveSizeInBits();
+    if (!PHISize.hasKnownScalarFactor(ASize))
+      return false;
+
+    VPPartialReductionChain Chain(
+        {UpdateR, ExtendedOp->CastRecipes[0], ExtendedOp->CastRecipes[1],
+         ExtendedOp->BinOp,
+         static_cast<unsigned>(PHISize.getKnownScalarFactor(ASize)), RK});
+    if (!isValidPartialReduction(Chain, PhiType, CostCtx, Range))
+      return false;
+    Chains.push_back(Chain);
+    CurrentValue = PrevValue;
+  }

-  Chains.push_back(Chain);
+  // The links were collected walking backwards from the exit value. Reverse
+  // only the entries added by this call; earlier entries keep their order.
+  std::reverse(Chains.begin() + FirstNewChain, Chains.end());
   return true;
 }
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/184830


More information about the llvm-commits mailing list