[llvm] [VPlan] Detect and create partial reductions in VPlan. (NFCI) (PR #167851)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 19 08:55:24 PST 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/167851

>From 28840375524bbebc5b79525eeda6485e3bbd7731 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 3 Nov 2025 15:57:31 +0000
Subject: [PATCH 1/3] [VPlan] Detect and create partial reductions in VPlan.
 (NFCI)

As a first step, move the existing partial reduction detection logic to
VPlan, trying to preserve the existing code structure & behavior as
closely as possible.

With this, partial reductions are detected and created together in a
single step.

This allows forming partial reductions and, if profitable, bundling them
together in a follow-up.
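
For illustration, a minimal (hypothetical) input loop of the kind the
detection targets, matching the pattern
reduction_bin_op(bin_op(extend(A), extend(B)), accumulator) described in
the code comments; names, types and the trip count are made up:

  loop:
    %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
    %acc = phi i32 [ 0, %entry ], [ %red, %loop ]      ; reduction phi
    %gep.a = getelementptr i8, ptr %A, i64 %iv
    %gep.b = getelementptr i8, ptr %B, i64 %iv
    %a = load i8, ptr %gep.a
    %b = load i8, ptr %gep.b
    %ext.a = sext i8 %a to i32                         ; ExtendA
    %ext.b = sext i8 %b to i32                         ; ExtendB
    %mul = mul i32 %ext.a, %ext.b                      ; ExtendUser
    %red = add i32 %acc, %mul                          ; reduction update
    %iv.next = add i64 %iv, 1
    %ec = icmp eq i64 %iv.next, %n
    br i1 %ec, label %exit, label %loop

Here the accumulator phi can stay at a narrower VF than the extended mul;
the scale factor is the ratio of the phi type to the extend source type
(i32 vs i8, so 4 in this sketch), and the chain can then be lowered using
the llvm.vector.partial.reduce.add intrinsic as in the updated test below.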
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 304 +++--------------
 .../Transforms/Vectorize/VPRecipeBuilder.h    |  72 +---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  42 +--
 .../Vectorize/VPlanConstruction.cpp           | 320 +++++++++++++++++-
 .../Transforms/Vectorize/VPlanTransforms.h    |  10 +
 .../partial-reduce-incomplete-chains.ll       |  40 +--
 6 files changed, 420 insertions(+), 368 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f4e3d899749ed..c87634da87ecf 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8011,212 +8011,51 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
   return Recipe;
 }
 
-/// Find all possible partial reductions in the loop and track all of those that
-/// are valid so recipes can be formed later.
-void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
-  // Find all possible partial reductions, grouping chains by their PHI. This
-  // grouping allows invalidating the whole chain, if any link is not a valid
-  // partial reduction.
-  MapVector<Instruction *,
-            SmallVector<std::pair<PartialReductionChain, unsigned>>>
-      ChainsByPhi;
-  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
-    if (Instruction *RdxExitInstr = RdxDesc.getLoopExitInstr())
-      getScaledReductions(Phi, RdxExitInstr, Range, ChainsByPhi[Phi]);
-  }
-
-  // A partial reduction is invalid if any of its extends are used by
-  // something that isn't another partial reduction. This is because the
-  // extends are intended to be lowered along with the reduction itself.
-
-  // Build up a set of partial reduction ops for efficient use checking.
-  SmallPtrSet<User *, 4> PartialReductionOps;
-  for (const auto &[_, Chains] : ChainsByPhi)
-    for (const auto &[PartialRdx, _] : Chains)
-      PartialReductionOps.insert(PartialRdx.ExtendUser);
-
-  auto ExtendIsOnlyUsedByPartialReductions =
-      [&PartialReductionOps](Instruction *Extend) {
-        return all_of(Extend->users(), [&](const User *U) {
-          return PartialReductionOps.contains(U);
-        });
-      };
-
-  // Check if each use of a chain's two extends is a partial reduction
-  // and only add those that don't have non-partial reduction users.
-  for (const auto &[_, Chains] : ChainsByPhi) {
-    for (const auto &[Chain, Scale] : Chains) {
-      if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
-          (!Chain.ExtendB ||
-           ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
-        ScaledReductionMap.try_emplace(Chain.Reduction, Scale);
-    }
-  }
-
-  // Check that all partial reductions in a chain are only used by other
-  // partial reductions with the same scale factor. Otherwise we end up creating
-  // users of scaled reductions where the types of the other operands don't
-  // match.
-  for (const auto &[Phi, Chains] : ChainsByPhi) {
-    for (const auto &[Chain, Scale] : Chains) {
-      auto AllUsersPartialRdx = [ScaleVal = Scale, RdxPhi = Phi,
-                                 this](const User *U) {
-        auto *UI = cast<Instruction>(U);
-        if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader())
-          return UI == RdxPhi;
-        return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
-               !OrigLoop->contains(UI->getParent());
-      };
-
-      // If any partial reduction entry for the phi is invalid, invalidate the
-      // whole chain.
-      if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) {
-        for (const auto &[Chain, _] : Chains)
-          ScaledReductionMap.erase(Chain.Reduction);
-        break;
-      }
-    }
-  }
-}
-
-bool VPRecipeBuilder::getScaledReductions(
-    Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
-    SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
-  if (!CM.TheLoop->contains(RdxExitInstr))
-    return false;
-
-  auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
-  if (!Update)
-    return false;
+VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
+                                                      VFRange &Range) {
+  // First, check for specific widening recipes that deal with inductions, Phi
+  // nodes, calls and memory operations.
+  VPRecipeBase *Recipe;
+  if (auto *PhiR = dyn_cast<VPPhi>(R)) {
+    VPBasicBlock *Parent = PhiR->getParent();
+    [[maybe_unused]] VPRegionBlock *LoopRegionOf =
+        Parent->getEnclosingLoopRegion();
+    assert(LoopRegionOf && LoopRegionOf->getEntry() == Parent &&
+           "Non-header phis should have been handled during predication");
+    auto *Phi = cast<PHINode>(R->getUnderlyingInstr());
+    assert(R->getNumOperands() == 2 && "Must have 2 operands for header phis");
+
+    VPHeaderPHIRecipe *PhiRecipe = nullptr;
+    assert((Legal->isReductionVariable(Phi) ||
+            Legal->isFixedOrderRecurrence(Phi)) &&
+           "can only widen reductions and fixed-order recurrences here");
+    VPValue *StartV = R->getOperand(0);
+    if (Legal->isReductionVariable(Phi)) {
+      const RecurrenceDescriptor &RdxDesc = Legal->getRecurrenceDescriptor(Phi);
+      assert(RdxDesc.getRecurrenceStartValue() ==
+             Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
 
-  Value *Op = Update->getOperand(0);
-  Value *PhiOp = Update->getOperand(1);
-  if (Op == PHI)
-    std::swap(Op, PhiOp);
+      // If the PHI is used by a partial reduction, set the scale factor.
+      bool UseInLoopReduction = CM.isInLoopReduction(Phi);
+      bool UseOrderedReductions = CM.useOrderedReductions(RdxDesc);
 
-  using namespace llvm::PatternMatch;
-  // If Op is an extend, then it's still a valid partial reduction if the
-  // extended mul fulfills the other requirements.
-  // For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
-  // reduction since the inner extends will be widened. We already have oneUse
-  // checks on the inner extends so widening them is safe.
-  std::optional<TTI::PartialReductionExtendKind> OuterExtKind = std::nullopt;
-  if (match(Op, m_ZExtOrSExt(m_Mul(m_Value(), m_Value())))) {
-    auto *Cast = cast<CastInst>(Op);
-    OuterExtKind = TTI::getPartialReductionExtendKind(Cast->getOpcode());
-    Op = Cast->getOperand(0);
-  }
-
-  // Try and get a scaled reduction from the first non-phi operand.
-  // If one is found, we use the discovered reduction instruction in
-  // place of the accumulator for costing.
-  if (auto *OpInst = dyn_cast<Instruction>(Op)) {
-    if (getScaledReductions(PHI, OpInst, Range, Chains)) {
-      PHI = Chains.rbegin()->first.Reduction;
-
-      Op = Update->getOperand(0);
-      PhiOp = Update->getOperand(1);
-      if (Op == PHI)
-        std::swap(Op, PhiOp);
+      PhiRecipe = new VPReductionPHIRecipe(
+          Phi, RdxDesc.getRecurrenceKind(), *StartV,
+          getReductionStyle(UseInLoopReduction, UseOrderedReductions, 1),
+          RdxDesc.hasUsesOutsideReductionChain());
+    } else {
+      // TODO: Currently fixed-order recurrences are modeled as chains of
+      // first-order recurrences. If there are no users of the intermediate
+      // recurrences in the chain, the fixed order recurrence should be modeled
+      // directly, enabling more efficient codegen.
+      PhiRecipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *StartV);
     }
+    // Add backedge value.
+    PhiRecipe->addOperand(R->getOperand(1));
+    return PhiRecipe;
   }
-  if (PhiOp != PHI)
-    return false;
-
-  // If the update is a binary operator, check both of its operands to see if
-  // they are extends. Otherwise, see if the update comes directly from an
-  // extend.
-  Instruction *Exts[2] = {nullptr};
-  BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
-  std::optional<unsigned> BinOpc;
-  Type *ExtOpTypes[2] = {nullptr};
-  TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None};
-
-  auto CollectExtInfo = [this, OuterExtKind, &Exts, &ExtOpTypes,
-                         &ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
-    for (const auto &[I, OpI] : enumerate(Ops)) {
-      const APInt *C;
-      if (I > 0 && match(OpI, m_APInt(C)) &&
-          canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) {
-        ExtOpTypes[I] = ExtOpTypes[0];
-        ExtKinds[I] = ExtKinds[0];
-        continue;
-      }
-      Value *ExtOp;
-      if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
-        return false;
-      Exts[I] = cast<Instruction>(OpI);
+  assert(!R->isPhi() && "only VPPhi nodes expected at this point");
 
-      // TODO: We should be able to support live-ins.
-      if (!CM.TheLoop->contains(Exts[I]))
-        return false;
-
-      ExtOpTypes[I] = ExtOp->getType();
-      ExtKinds[I] = TTI::getPartialReductionExtendKind(Exts[I]);
-      // The outer extend kind must be the same as the inner extends, so that
-      // they can be folded together.
-      if (OuterExtKind.has_value() && OuterExtKind.value() != ExtKinds[I])
-        return false;
-    }
-    return true;
-  };
-
-  if (ExtendUser) {
-    if (!ExtendUser->hasOneUse())
-      return false;
-
-    // Use the side-effect of match to replace BinOp only if the pattern is
-    // matched, we don't care at this point whether it actually matched.
-    match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
-
-    SmallVector<Value *> Ops(ExtendUser->operands());
-    if (!CollectExtInfo(Ops))
-      return false;
-
-    BinOpc = std::make_optional(ExtendUser->getOpcode());
-  } else if (match(Update, m_Add(m_Value(), m_Value()))) {
-    // We already know the operands for Update are Op and PhiOp.
-    SmallVector<Value *> Ops({Op});
-    if (!CollectExtInfo(Ops))
-      return false;
-
-    ExtendUser = Update;
-    BinOpc = std::nullopt;
-  } else
-    return false;
-
-  PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
-
-  TypeSize PHISize = PHI->getType()->getPrimitiveSizeInBits();
-  TypeSize ASize = ExtOpTypes[0]->getPrimitiveSizeInBits();
-  if (!PHISize.hasKnownScalarFactor(ASize))
-    return false;
-  unsigned TargetScaleFactor = PHISize.getKnownScalarFactor(ASize);
-
-  if (LoopVectorizationPlanner::getDecisionAndClampRange(
-          [&](ElementCount VF) {
-            InstructionCost Cost = TTI->getPartialReductionCost(
-                Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
-                PHI->getType(), VF, ExtKinds[0], ExtKinds[1], BinOpc,
-                CM.CostKind);
-            return Cost.isValid();
-          },
-          Range)) {
-    Chains.emplace_back(Chain, TargetScaleFactor);
-    return true;
-  }
-
-  return false;
-}
-
-VPRecipeBase *
-VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
-                                              VFRange &Range) {
-  assert(!R->isPhi() && "phis must be handled earlier");
-  // First, check for specific widening recipes that deal with optimizing
-  // truncates, calls and memory operations.
-
-  VPRecipeBase *Recipe;
   auto *VPI = cast<VPInstruction>(R);
   if (VPI->getOpcode() == Instruction::Trunc &&
       (Recipe = tryToOptimizeInductionTruncate(VPI, Range)))
@@ -8239,9 +8078,6 @@ VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
       VPI->getOpcode() == Instruction::Store)
     return tryToWidenMemory(VPI, Range);
 
-  if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr))
-    return tryToCreatePartialReduction(VPI, ScaleFactor.value());
-
   if (!shouldWiden(Instr, Range))
     return nullptr;
 
@@ -8264,46 +8100,6 @@ VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
   return tryToWiden(VPI);
 }
 
-VPRecipeBase *
-VPRecipeBuilder::tryToCreatePartialReduction(VPInstruction *Reduction,
-                                             unsigned ScaleFactor) {
-  assert(Reduction->getNumOperands() == 2 &&
-         "Unexpected number of operands for partial reduction");
-
-  VPValue *BinOp = Reduction->getOperand(0);
-  VPValue *Accumulator = Reduction->getOperand(1);
-  VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
-  if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
-      (isa<VPReductionRecipe>(BinOpRecipe) &&
-       cast<VPReductionRecipe>(BinOpRecipe)->isPartialReduction()))
-    std::swap(BinOp, Accumulator);
-
-  if (auto *RedPhiR = dyn_cast<VPReductionPHIRecipe>(Accumulator))
-    RedPhiR->setVFScaleFactor(ScaleFactor);
-
-  assert(ScaleFactor ==
-             vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()) &&
-         "all accumulators in chain must have same scale factor");
-
-  auto *ReductionI = Reduction->getUnderlyingInstr();
-  if (Reduction->getOpcode() == Instruction::Sub) {
-    SmallVector<VPValue *, 2> Ops;
-    Ops.push_back(Plan.getConstantInt(ReductionI->getType(), 0));
-    Ops.push_back(BinOp);
-    BinOp = new VPWidenRecipe(*ReductionI, Ops, VPIRFlags(*ReductionI),
-                              VPIRMetadata(), ReductionI->getDebugLoc());
-    Builder.insert(BinOp->getDefiningRecipe());
-  }
-
-  VPValue *Cond = nullptr;
-  if (CM.blockNeedsPredicationForAnyReason(ReductionI->getParent()))
-    Cond = getBlockInMask(Builder.getInsertBlock());
-
-  return new VPReductionRecipe(
-      RecurKind::Add, FastMathFlags(), ReductionI, Accumulator, BinOp, Cond,
-      RdxUnordered{/*VFScaleFactor=*/ScaleFactor}, ReductionI->getDebugLoc());
-}
-
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   if (ElementCount::isKnownGT(MinVF, MaxVF))
@@ -8438,11 +8234,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
   // Construct wide recipes and apply predication for original scalar
   // VPInstructions in the loop.
   // ---------------------------------------------------------------------------
-  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, Builder,
+  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder,
                                 BlockMaskCache);
-  // TODO: Handle partial reductions with EVL tail folding.
-  if (!CM.foldTailWithEVL())
-    RecipeBuilder.collectScaledReductions(Range);
 
   // Scan the body of the loop in a topological order to visit each basic block
   // after having visited its predecessor basic blocks.
@@ -8490,7 +8283,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
       }
 
       VPRecipeBase *Recipe =
-          RecipeBuilder.tryToCreateWidenNonPhiRecipe(VPI, Range);
+          RecipeBuilder.tryToCreateWidenRecipe(VPI, Range);
       if (!Recipe)
         Recipe =
             RecipeBuilder.handleReplication(cast<VPInstruction>(VPI), Range);
@@ -8552,11 +8345,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
                                 *Plan))
     return nullptr;
 
-  // Transform recipes to abstract recipes if it is legal and beneficial and
-  // clamp the range for better cost estimation.
-  // TODO: Enable following transform when the EVL-version of extended-reduction
-  // and mulacc-reduction are implemented.
   if (!CM.foldTailWithEVL()) {
+    // Create partial reduction recipes for scaled reductions.
+    VPlanTransforms::createPartialReductions(*Plan, Range, &TTI, CM.CostKind);
+
     VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind,
                           *CM.PSE.getSE(), OrigLoop);
     VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
@@ -8992,11 +8784,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       VPBuilder PHBuilder(Plan->getVectorPreheader());
       VPValue *Iden = Plan->getOrAddLiveIn(
           getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags()));
-      // If the PHI is used by a partial reduction, set the scale factor.
-      unsigned ScaleFactor =
-          RecipeBuilder.getScalingForReduction(RdxDesc.getLoopExitInstr())
-              .value_or(1);
-      auto *ScaleFactorVPV = Plan->getConstantInt(32, ScaleFactor);
+      auto *ScaleFactorVPV = Plan->getConstantInt(32, 1);
       VPValue *StartV = PHBuilder.createNaryOp(
           VPInstruction::ReductionStartVector,
           {PhiR->getStartValue(), Iden, ScaleFactorVPV},
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 534f51e86a9b7..d537f62c95a13 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -19,30 +19,9 @@ namespace llvm {
 class LoopVectorizationLegality;
 class LoopVectorizationCostModel;
 class TargetLibraryInfo;
-class TargetTransformInfo;
 struct HistogramInfo;
 struct VFRange;
 
-/// A chain of instructions that form a partial reduction.
-/// Designed to match either:
-///   reduction_bin_op (extend (A), accumulator), or
-///   reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
-struct PartialReductionChain {
-  PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
-                        Instruction *ExtendB, Instruction *ExtendUser)
-      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
-        ExtendUser(ExtendUser) {}
-  /// The top-level binary operation that forms the reduction to a scalar
-  /// after the loop body.
-  Instruction *Reduction;
-  /// The extension of each of the inner binary operation's operands.
-  Instruction *ExtendA;
-  Instruction *ExtendB;
-
-  /// The user of the extends that is then reduced.
-  Instruction *ExtendUser;
-};
-
 /// Helper class to create VPRecipies from IR instructions.
 class VPRecipeBuilder {
   /// The VPlan new recipes are added to.
@@ -54,15 +33,14 @@ class VPRecipeBuilder {
   /// Target Library Info.
   const TargetLibraryInfo *TLI;
 
-  // Target Transform Info.
-  const TargetTransformInfo *TTI;
-
   /// The legality analysis.
   LoopVectorizationLegality *Legal;
 
   /// The profitablity analysis.
   LoopVectorizationCostModel &CM;
 
+  PredicatedScalarEvolution &PSE;
+
   VPBuilder &Builder;
 
   /// The mask of each VPBB, generated earlier and used for predicating recipes
@@ -79,9 +57,6 @@ class VPRecipeBuilder {
   /// created.
   SmallVector<VPHeaderPHIRecipe *, 4> PhisToFix;
 
-  /// A mapping of partial reduction exit instructions to their scaling factor.
-  DenseMap<const Instruction *, unsigned> ScaledReductionMap;
-
   /// Check if \p I can be widened at the start of \p Range and possibly
   /// decrease the range such that the returned value holds for the entire \p
   /// Range. The function should not be called for memory instructions or calls.
@@ -114,47 +89,18 @@ class VPRecipeBuilder {
   VPHistogramRecipe *tryToWidenHistogram(const HistogramInfo *HI,
                                          VPInstruction *VPI);
 
-  /// Examines reduction operations to see if the target can use a cheaper
-  /// operation with a wider per-iteration input VF and narrower PHI VF.
-  /// Each element within Chains is a pair with a struct containing reduction
-  /// information and the scaling factor between the number of elements in
-  /// the input and output.
-  /// Recursively calls itself to identify chained scaled reductions.
-  /// Returns true if this invocation added an entry to Chains, otherwise false.
-  /// i.e. returns false in the case that a subcall adds an entry to Chains,
-  /// but the top-level call does not.
-  bool getScaledReductions(
-      Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
-      SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains);
-
 public:
   VPRecipeBuilder(VPlan &Plan, Loop *OrigLoop, const TargetLibraryInfo *TLI,
-                  const TargetTransformInfo *TTI,
                   LoopVectorizationLegality *Legal,
-                  LoopVectorizationCostModel &CM, VPBuilder &Builder,
+                  LoopVectorizationCostModel &CM,
+                  PredicatedScalarEvolution &PSE, VPBuilder &Builder,
                   DenseMap<VPBasicBlock *, VPValue *> &BlockMaskCache)
-      : Plan(Plan), OrigLoop(OrigLoop), TLI(TLI), TTI(TTI), Legal(Legal),
-        CM(CM), Builder(Builder), BlockMaskCache(BlockMaskCache) {}
-
-  std::optional<unsigned> getScalingForReduction(const Instruction *ExitInst) {
-    auto It = ScaledReductionMap.find(ExitInst);
-    return It == ScaledReductionMap.end() ? std::nullopt
-                                          : std::make_optional(It->second);
-  }
-
-  /// Find all possible partial reductions in the loop and track all of those
-  /// that are valid so recipes can be formed later.
-  void collectScaledReductions(VFRange &Range);
-
-  /// Create and return a widened recipe for a non-phi recipe \p R if one can be
-  /// created within the given VF \p Range.
-  VPRecipeBase *tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
-                                             VFRange &Range);
+      : Plan(Plan), OrigLoop(OrigLoop), TLI(TLI), Legal(Legal), CM(CM),
+        PSE(PSE), Builder(Builder), BlockMaskCache(BlockMaskCache) {}
 
-  /// Create and return a partial reduction recipe for a reduction instruction
-  /// along with binary operation and reduction phi operands.
-  VPRecipeBase *tryToCreatePartialReduction(VPInstruction *Reduction,
-                                            unsigned ScaleFactor);
+  /// Create and return a widened recipe for \p R if one can be created within
+  /// the given VF \p Range.
+  VPRecipeBase *tryToCreateWidenRecipe(VPSingleDefRecipe *R, VFRange &Range);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 40f8fe9203d2a..a697f829343a5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1525,6 +1525,12 @@ class LLVM_ABI_FOR_TEST VPWidenRecipe : public VPRecipeWithIRFlags,
     setUnderlyingValue(&I);
   }
 
+  VPWidenRecipe(unsigned Opcode, ArrayRef<VPValue *> Operands,
+                const VPIRFlags &Flags = {}, const VPIRMetadata &Metadata = {},
+                DebugLoc DL = {})
+      : VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, Flags, DL),
+        VPIRMetadata(Metadata), Opcode(Opcode) {}
+
   ~VPWidenRecipe() override = default;
 
   VPWidenRecipe *clone() override {
@@ -2373,17 +2379,16 @@ class LLVM_ABI_FOR_TEST VPWidenPHIRecipe : public VPSingleDefRecipe,
 /// first operand of the recipe and the incoming value from the backedge is the
 /// second operand.
 struct VPFirstOrderRecurrencePHIRecipe : public VPHeaderPHIRecipe {
-  VPFirstOrderRecurrencePHIRecipe(PHINode *Phi, VPValue &Start,
-                                  VPValue &BackedgeValue)
-      : VPHeaderPHIRecipe(VPDef::VPFirstOrderRecurrencePHISC, Phi, &Start) {
-    addOperand(&BackedgeValue);
-  }
+  VPFirstOrderRecurrencePHIRecipe(PHINode *Phi, VPValue &Start)
+      : VPHeaderPHIRecipe(VPDef::VPFirstOrderRecurrencePHISC, Phi, &Start) {}
 
   VP_CLASSOF_IMPL(VPDef::VPFirstOrderRecurrencePHISC)
 
   VPFirstOrderRecurrencePHIRecipe *clone() override {
-    return new VPFirstOrderRecurrencePHIRecipe(
-        cast<PHINode>(getUnderlyingInstr()), *getOperand(0), *getOperand(1));
+    auto *R = new VPFirstOrderRecurrencePHIRecipe(
+        cast<PHINode>(getUnderlyingInstr()), *getOperand(0));
+    R->addOperand(getBackedgeValue());
+    return R;
   }
 
   void execute(VPTransformState &State) override;
@@ -2449,21 +2454,20 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
 public:
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi.
   VPReductionPHIRecipe(PHINode *Phi, RecurKind Kind, VPValue &Start,
-                       VPValue &BackedgeValue, ReductionStyle Style,
+                       ReductionStyle Style,
                        bool HasUsesOutsideReductionChain = false)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start), Kind(Kind),
         Style(Style),
-        HasUsesOutsideReductionChain(HasUsesOutsideReductionChain) {
-    addOperand(&BackedgeValue);
-  }
+        HasUsesOutsideReductionChain(HasUsesOutsideReductionChain) {}
 
   ~VPReductionPHIRecipe() override = default;
 
   VPReductionPHIRecipe *clone() override {
-    return new VPReductionPHIRecipe(
+    auto *R = new VPReductionPHIRecipe(
         dyn_cast_or_null<PHINode>(getUnderlyingValue()), getRecurrenceKind(),
-        *getOperand(0), *getBackedgeValue(), Style,
-        HasUsesOutsideReductionChain);
+        *getOperand(0), Style, HasUsesOutsideReductionChain);
+    R->addOperand(getBackedgeValue());
+    return R;
   }
 
   VP_CLASSOF_IMPL(VPDef::VPReductionPHISC)
@@ -2478,11 +2482,11 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
     return Partial ? Partial->VFScaleFactor : 1;
   }
 
-  /// Set the VFScaleFactor for this reduction phi. Can only be set to a factor
-  /// > 1.
-  void setVFScaleFactor(unsigned ScaleFactor) {
-    assert(ScaleFactor > 1 && "must set to scale factor > 1");
-    Style = RdxUnordered{ScaleFactor};
+  /// Set the factor that the VF of this recipe's output should be scaled by.
+  void setVFScaleFactor(unsigned Factor) {
+    auto *Partial = std::get_if<RdxUnordered>(&Style);
+    assert(Partial && "Can only set VFScaleFactor for unordered reductions");
+    Partial->VFScaleFactor = Factor;
   }
 
   /// Returns the number of incoming values, also number of incoming blocks.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 318c05d8ef7c5..d03943ad04869 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -12,15 +12,20 @@
 //===----------------------------------------------------------------------===//
 
 #include "LoopVectorizationPlanner.h"
+#include "VPRecipeBuilder.h"
 #include "VPlan.h"
+#include "VPlanAnalysis.h"
 #include "VPlanCFG.h"
 #include "VPlanDominatorTree.h"
+#include "VPlanHelpers.h"
 #include "VPlanPatternMatch.h"
 #include "VPlanTransforms.h"
+#include "VPlanUtils.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -664,7 +669,9 @@ void VPlanTransforms::createHeaderPhiRecipes(
       // first-order recurrences. If there are no users of the intermediate
       // recurrences in the chain, the fixed order recurrence should be
       // modeled directly, enabling more efficient codegen.
-      return new VPFirstOrderRecurrencePHIRecipe(Phi, *Start, *BackedgeValue);
+      auto *Recipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *Start);
+      Recipe->addOperand(BackedgeValue);
+      return Recipe;
     }
 
     auto InductionIt = Inductions.find(Phi);
@@ -681,11 +688,13 @@ void VPlanTransforms::createHeaderPhiRecipes(
     // Will be updated later to >1 if reduction is partial.
     unsigned ScaleFactor = 1;
     bool UseOrderedReductions = !AllowReordering && RdxDesc.isOrdered();
-    return new VPReductionPHIRecipe(
-        Phi, RdxDesc.getRecurrenceKind(), *Start, *BackedgeValue,
+    auto *Recipe = new VPReductionPHIRecipe(
+        Phi, RdxDesc.getRecurrenceKind(), *Start,
         getReductionStyle(InLoopReductions.contains(Phi), UseOrderedReductions,
                           ScaleFactor),
         RdxDesc.hasUsesOutsideReductionChain());
+    Recipe->addOperand(BackedgeValue);
+    return Recipe;
   };
 
   for (VPRecipeBase &R : make_early_inc_range(HeaderVPBB->phis())) {
@@ -1271,3 +1280,308 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
   }
   return true;
 }
+
+namespace {
+
+/// A chain of recipes that form a partial reduction. Matches either
+///   reduction_bin_op (extend (A), accumulator), or
+///   reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
+struct VPPartialReductionChain {
+  VPPartialReductionChain(VPWidenRecipe *Reduction, VPWidenCastRecipe *ExtendA,
+                          VPWidenCastRecipe *ExtendB, VPWidenRecipe *ExtendUser)
+      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
+        ExtendUser(ExtendUser) {}
+  /// The top-level binary operation that forms the reduction to a scalar
+  /// after the loop body.
+  VPWidenRecipe *Reduction;
+  /// The extension of each of the inner binary operation's operands.
+  VPWidenCastRecipe *ExtendA;
+  VPWidenCastRecipe *ExtendB;
+
+  /// The user of the extends that is then reduced.
+  VPWidenRecipe *ExtendUser;
+};
+
+// Helper to transform a single widen recipe into a partial reduction recipe.
+// Returns true if transformation succeeded.
+static bool transformToPartialReduction(VPWidenRecipe *WidenRecipe,
+                                        unsigned ScaleFactor,
+                                        VPTypeAnalysis &TypeInfo, VPlan &Plan) {
+  assert(WidenRecipe->getNumOperands() == 2 && "Expected binary operation");
+
+  VPValue *BinOp = WidenRecipe->getOperand(0);
+  VPValue *Accumulator = WidenRecipe->getOperand(1);
+
+  // Swap if needed to ensure Accumulator is the PHI or partial reduction.
+  if (auto *R = BinOp->getDefiningRecipe())
+    if (isa<VPReductionPHIRecipe>(R) || isa<VPReductionRecipe>(R))
+      std::swap(BinOp, Accumulator);
+
+  // For chained reductions, only transform if accumulator is already a PHI or
+  // partial reduction. Otherwise, it needs to be transformed first.
+  auto *AccumRecipe = Accumulator->getDefiningRecipe();
+  if (!AccumRecipe || (!isa<VPReductionPHIRecipe>(AccumRecipe) &&
+                       !isa<VPReductionRecipe>(AccumRecipe)))
+    return false;
+
+  VPValue *Cond = nullptr;
+  VPValue *ExitValue = nullptr;
+  if (auto *RdxPhi = dyn_cast<VPReductionPHIRecipe>(AccumRecipe)) {
+    assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
+    RdxPhi->setVFScaleFactor(ScaleFactor);
+
+    // Update ReductionStartVector instruction scale factor.
+    VPValue *StartValue = RdxPhi->getOperand(0);
+    auto *StartInst = cast<VPInstruction>(StartValue);
+    assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
+    auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
+    StartInst->setOperand(2, NewScaleFactor);
+
+    ExitValue =
+        findUserOf<VPInstruction::ComputeReductionResult>(RdxPhi)->getOperand(
+            1);
+    match(ExitValue, m_Select(m_VPValue(Cond), m_VPValue(), m_VPValue()));
+  }
+
+  // Handle SUB by negating the operand and using ADD for the partial reduction.
+  if (WidenRecipe->getOpcode() == Instruction::Sub) {
+    VPBuilder Builder(WidenRecipe);
+    Type *ElemTy = TypeInfo.inferScalarType(BinOp);
+    auto *Zero = Plan.getConstantInt(ElemTy, 0);
+    VPIRFlags Flags = WidenRecipe->getUnderlyingInstr()
+                          ? VPIRFlags(*WidenRecipe->getUnderlyingInstr())
+                          : VPIRFlags();
+    auto *NegRecipe = new VPWidenRecipe(Instruction::Sub, {Zero, BinOp}, Flags,
+                                        VPIRMetadata(), DebugLoc());
+    Builder.insert(NegRecipe);
+    BinOp = NegRecipe;
+  }
+
+  auto *PartialRed = new VPReductionRecipe(
+      RecurKind::Add, {}, WidenRecipe->getUnderlyingInstr(), Accumulator, BinOp,
+      Cond, RdxUnordered{/*VFScaleFactor=*/ScaleFactor});
+  PartialRed->insertBefore(WidenRecipe);
+
+  if (Cond)
+    ExitValue->replaceAllUsesWith(PartialRed);
+  WidenRecipe->replaceAllUsesWith(PartialRed);
+  return true;
+}
+
+static bool getScaledReductions(
+    VPSingleDefRecipe *RedPhiR, VPValue *PrevValue, VFRange &Range,
+    SmallVectorImpl<std::pair<VPPartialReductionChain, unsigned>> &Chains,
+    VPTypeAnalysis &TypeInfo, const TargetTransformInfo *TTI,
+    TargetTransformInfo::TargetCostKind CostKind) {
+  auto *UpdateRecipe = dyn_cast<VPWidenRecipe>(PrevValue);
+  if (!UpdateRecipe || UpdateRecipe->getNumOperands() != 2)
+    return false;
+
+  VPValue *Op = UpdateRecipe->getOperand(0);
+  VPValue *PhiOp = UpdateRecipe->getOperand(1);
+  if (Op->getDefiningRecipe() == RedPhiR)
+    std::swap(Op, PhiOp);
+
+  // If Op is an extend, then it's still a valid partial reduction if the
+  // extended mul fulfills the other requirements.
+  // For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
+  // reduction since the inner extends will be widened. We already have oneUse
+  // checks on the inner extends so widening them is safe.
+  std::optional<TTI::PartialReductionExtendKind> OuterExtKind = std::nullopt;
+  if (match(Op, m_ZExtOrSExt(m_Mul(m_VPValue(), m_VPValue())))) {
+    auto *CastRecipe = cast<VPWidenCastRecipe>(Op);
+    auto CastOp = static_cast<Instruction::CastOps>(CastRecipe->getOpcode());
+    OuterExtKind = TTI::getPartialReductionExtendKind(CastOp);
+    Op = CastRecipe->getOperand(0);
+  }
+
+  if (getScaledReductions(RedPhiR, Op, Range, Chains, TypeInfo, TTI,
+                          CostKind)) {
+    RedPhiR = Chains.rbegin()->first.Reduction;
+    Op = UpdateRecipe->getOperand(0);
+    PhiOp = UpdateRecipe->getOperand(1);
+    if (Op == RedPhiR)
+      std::swap(Op, PhiOp);
+  }
+  if (RedPhiR != PhiOp)
+    return false;
+
+  // Collect extension information for partial reduction.
+  VPWidenCastRecipe *CastRecipes[2] = {nullptr};
+  std::optional<unsigned> BinOpc;
+  Type *ExtOpTypes[2] = {nullptr};
+  TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None};
+
+  auto CollectExtInfo = [OuterExtKind, &CastRecipes, &ExtOpTypes, &ExtKinds,
+                         &TypeInfo](ArrayRef<VPValue *> Operands) {
+    if (Operands.size() > 2)
+      return false;
+
+    for (const auto &[I, OpVal] : enumerate(Operands)) {
+      // Allow constant as second operand.
+      if (I > 0 && ExtKinds[0] != TTI::PR_None) {
+        const APInt *C;
+        if (match(OpVal, m_APInt(C)) &&
+            canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) {
+          ExtOpTypes[I] = ExtOpTypes[0];
+          ExtKinds[I] = ExtKinds[0];
+          continue;
+        }
+      }
+
+      // Check if operand is an extend (SExt or ZExt) using VPlan pattern
+      // matching.
+      VPValue *ExtInput;
+      if (!match(OpVal, m_ZExtOrSExt(m_VPValue(ExtInput))))
+        return false;
+
+      // Extends are always VPWidenCastRecipe after pattern matching.
+      CastRecipes[I] = cast<VPWidenCastRecipe>(OpVal->getDefiningRecipe());
+
+      unsigned DefOpc = CastRecipes[I]->getOpcode();
+      assert((DefOpc == Instruction::SExt || DefOpc == Instruction::ZExt) &&
+             "Pattern matched but opcode is not SExt or ZExt");
+
+      auto CastOp = static_cast<Instruction::CastOps>(DefOpc);
+      ExtKinds[I] = TTI::getPartialReductionExtendKind(CastOp);
+      if (OuterExtKind && *OuterExtKind != ExtKinds[I])
+        return false;
+
+      ExtOpTypes[I] = TypeInfo.inferScalarType(CastRecipes[I]->getOperand(0));
+    }
+    return ExtKinds[0] != TTI::PR_None;
+  };
+
+  // If Op is a binary operator, check both of its operands to see if they are
+  // extends. Otherwise, see if the update comes directly from an extend.
+  auto *ExtendUser = dyn_cast<VPWidenRecipe>(Op);
+  if (ExtendUser) {
+    if (!ExtendUser->hasOneUse())
+      return false;
+
+    // Handle neg(binop(ext, ext)) pattern.
+    VPValue *OtherOp = nullptr;
+    if (match(ExtendUser, m_Sub(m_ZeroInt(), m_VPValue(OtherOp))))
+      ExtendUser = dyn_cast<VPWidenRecipe>(OtherOp);
+
+    if (!ExtendUser || !CollectExtInfo(ExtendUser->operands()))
+      return false;
+
+    BinOpc = std::make_optional(ExtendUser->getOpcode());
+  } else if (match(UpdateRecipe, m_Add(m_VPValue(), m_VPValue()))) {
+    if (!CollectExtInfo({Op}))
+      return false;
+    ExtendUser = UpdateRecipe;
+  } else {
+    return false;
+  }
+
+  Type *PhiType = TypeInfo.inferScalarType(RedPhiR);
+  TypeSize PHISize = PhiType->getPrimitiveSizeInBits();
+  TypeSize ASize = ExtOpTypes[0]->getPrimitiveSizeInBits();
+  if (!PHISize.hasKnownScalarFactor(ASize))
+    return false;
+
+  if (!LoopVectorizationPlanner::getDecisionAndClampRange(
+          [&](ElementCount VF) {
+            return TTI
+                ->getPartialReductionCost(
+                    UpdateRecipe->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
+                    PhiType, VF, ExtKinds[0], ExtKinds[1], BinOpc, CostKind)
+                .isValid();
+          },
+          Range))
+    return false;
+
+  Chains.emplace_back(VPPartialReductionChain(UpdateRecipe, CastRecipes[0],
+                                              CastRecipes[1], ExtendUser),
+                      PHISize.getKnownScalarFactor(ASize));
+  return true;
+}
+} // namespace
+
+void VPlanTransforms::createPartialReductions(
+    VPlan &Plan, VFRange &Range, const TargetTransformInfo *TTI,
+    TargetTransformInfo::TargetCostKind CostKind) {
+  // Find all possible partial reductions, grouping chains by their PHI. This
+  // grouping allows invalidating the whole chain, if any link is not a valid
+  // partial reduction.
+  MapVector<VPReductionPHIRecipe *,
+            SmallVector<std::pair<VPPartialReductionChain, unsigned>>>
+      ChainsByPhi;
+  VPTypeAnalysis TypeInfo(Plan);
+  VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
+  for (VPRecipeBase &R : HeaderVPBB->phis()) {
+    auto *RedPhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+    if (!RedPhiR)
+      continue;
+
+    auto *RdxResult =
+        findUserOf<VPInstruction::ComputeReductionResult>(RedPhiR);
+    if (!RdxResult)
+      continue;
+
+    VPValue *ExitValue = RdxResult->getOperand(1);
+    // Look through selects for predicated reductions.
+    match(ExitValue, m_Select(m_VPValue(), m_VPValue(ExitValue), m_VPValue()));
+    getScaledReductions(RedPhiR, ExitValue, Range, ChainsByPhi[RedPhiR],
+                        TypeInfo, TTI, CostKind);
+  }
+
+  if (ChainsByPhi.empty())
+    return;
+
+  // Build set of partial reduction operations for extend user validation.
+  SmallPtrSet<VPRecipeBase *, 4> PartialReductionOps;
+  for (const auto &[_, Chains] : ChainsByPhi)
+    for (const auto &[Chain, _] : Chains)
+      PartialReductionOps.insert(Chain.ExtendUser);
+
+  // A partial reduction is invalid if any of its extends are used by
+  // something that isn't another partial reduction. This is because the
+  // extends are intended to be lowered along with the reduction itself.
+  // Check extends are only used by partial reductions and build validated map.
+  auto ExtendUsersValid = [&](VPWidenCastRecipe *Ext) {
+    return !Ext || all_of(Ext->users(), [&](VPUser *U) {
+      auto *R = cast<VPRecipeBase>(U);
+      return PartialReductionOps.contains(R);
+    });
+  };
+
+  DenseMap<VPSingleDefRecipe *, unsigned> ScaledReductionMap;
+  for (auto &[_, Chains] : ChainsByPhi)
+    for (auto &[Chain, ScaleFactor] : Chains) {
+      if (!ExtendUsersValid(Chain.ExtendA) ||
+          !ExtendUsersValid(Chain.ExtendB)) {
+        Chains.clear();
+        break;
+      }
+      ScaledReductionMap.try_emplace(Chain.Reduction, ScaleFactor);
+    }
+
+  // Check that all partial reductions in a chain are only used by other
+  // partial reductions with the same scale factor. Otherwise we end up creating
+  // users of scaled reductions where the types of the other operands don't
+  // match.
+  for (auto &[RedPhiR, Chains] : ChainsByPhi) {
+    for (auto &[Chain, ScaleFactor] : Chains) {
+      auto AllUsersPartialRdx = [&, Scale = ScaleFactor,
+                                 RedPhi = RedPhiR](VPUser *U) {
+        if (auto *PhiR = dyn_cast<VPReductionPHIRecipe>(U))
+          return PhiR == RedPhi;
+        auto *R = cast<VPSingleDefRecipe>(U);
+        return match(R, m_Select(m_VPValue(), m_VPValue(), m_VPValue())) ||
+               Scale == ScaledReductionMap.lookup_or(R, 0) || !R->getRegion();
+      };
+
+      if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) {
+        Chains.clear();
+        break;
+      }
+    }
+  }
+
+  for (const auto &[_, Chains] : ChainsByPhi)
+    for (const auto &[Chain, ScaleFactor] : Chains)
+      transformToPartialReduction(Chain.Reduction, ScaleFactor, TypeInfo, Plan);
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 1a3ff4f9b9bbc..7ce26da2591f4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -16,6 +16,7 @@
 #include "VPlan.h"
 #include "VPlanVerifier.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
 
@@ -28,6 +29,7 @@ class PHINode;
 class ScalarEvolution;
 class PredicatedScalarEvolution;
 class TargetLibraryInfo;
+class TargetTransformInfo;
 class VPBuilder;
 class VPRecipeBuilder;
 struct VFRange;
@@ -413,6 +415,14 @@ struct VPlanTransforms {
   /// users in the original exit block using the VPIRInstruction wrapping to the
   /// LCSSA phi.
   static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range);
+
+  /// Detect and create partial reduction recipes for scaled reductions in
+  /// \p Plan. Must be called after recipe construction. If partial reductions
+  /// are only valid for a subset of VFs in Range, Range.End is updated.
+  static void
+  createPartialReductions(VPlan &Plan, VFRange &Range,
+                          const TargetTransformInfo *TTI,
+                          TargetTransformInfo::TargetCostKind CostKind);
 };
 
 } // namespace llvm
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll
index 2060168c531fb..c189ebec78e7a 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll
@@ -168,42 +168,32 @@ define void @chained_sext_adds(ptr noalias %src, ptr noalias %dst) #0 {
 ;
 ; CHECK-LABEL: define void @chained_sext_adds(
 ; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:  [[ENTRY:.*]]:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 2
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1000, [[TMP1]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    br label %[[VECTOR_PH:.*]]
 ; CHECK:       [[VECTOR_PH]]:
-; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1000, [[TMP3]]
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 1000, [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP7:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE1:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP4]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP5]]
-; CHECK-NEXT:    [[TMP7]] = add <vscale x 4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
-; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP8]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP1]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE1]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP1]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[INDEX_NEXT]], 992
+; CHECK-NEXT:    br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
-; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP7]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE1]])
 ; CHECK-NEXT:    store i32 [[TMP9]], ptr [[DST]], align 4
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1000, [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
+; CHECK-NEXT:    br label %[[SCALAR_PH:.*]]
 ; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP9]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
 ; CHECK-NEXT:    br label %[[LOOP:.*]]
-; CHECK:       [[EXIT]]:
+; CHECK:       [[EXIT:.*]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[LOOP]]:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
-; CHECK-NEXT:    [[RED:%.*]] = phi i32 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 992, %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[RED:%.*]] = phi i32 [ [[TMP9]], %[[SCALAR_PH]] ], [ [[ADD_1:%.*]], %[[LOOP]] ]
 ; CHECK-NEXT:    [[GEP_SRC:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
 ; CHECK-NEXT:    [[L:%.*]] = load i8, ptr [[GEP_SRC]], align 1
 ; CHECK-NEXT:    [[CONV8:%.*]] = sext i8 [[L]] to i32

>From 6c2bd5db78cb563078cb3eec97d427fe39fb662c Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 5 Dec 2025 14:16:37 +0000
Subject: [PATCH 2/3] !fixup re-add comment

---
 llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c87634da87ecf..c629554df689e 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8345,10 +8345,13 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
                                 *Plan))
     return nullptr;
 
+  // Create partial reduction recipes for scaled reductions and transform
+  // recipes to abstract recipes if it is legal and beneficial and clamp the
+  // range for better cost estimation.
+  // TODO: Enable following transform when the EVL-version of extended-reduction
+  // and mulacc-reduction are implemented.
   if (!CM.foldTailWithEVL()) {
-    // Create partial reduction recipes for scaled reductions.
     VPlanTransforms::createPartialReductions(*Plan, Range, &TTI, CM.CostKind);
-
     VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind,
                           *CM.PSE.getSE(), OrigLoop);
     VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,

>From 0a181fd5b14044697e2cf4e55f82801dc968c4af Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 18 Dec 2025 17:11:26 +0000
Subject: [PATCH 3/3] !fixup cleanup after rebase

---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 56 +++----------------
 .../Transforms/Vectorize/VPRecipeBuilder.h    | 21 +++----
 llvm/lib/Transforms/Vectorize/VPlan.h         | 36 ++++++------
 .../Vectorize/VPlanConstruction.cpp           | 18 ++----
 4 files changed, 42 insertions(+), 89 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c629554df689e..e9871f3a5604d 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8011,51 +8011,14 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
   return Recipe;
 }
 
-VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
-                                                      VFRange &Range) {
-  // First, check for specific widening recipes that deal with inductions, Phi
-  // nodes, calls and memory operations.
-  VPRecipeBase *Recipe;
-  if (auto *PhiR = dyn_cast<VPPhi>(R)) {
-    VPBasicBlock *Parent = PhiR->getParent();
-    [[maybe_unused]] VPRegionBlock *LoopRegionOf =
-        Parent->getEnclosingLoopRegion();
-    assert(LoopRegionOf && LoopRegionOf->getEntry() == Parent &&
-           "Non-header phis should have been handled during predication");
-    auto *Phi = cast<PHINode>(R->getUnderlyingInstr());
-    assert(R->getNumOperands() == 2 && "Must have 2 operands for header phis");
-
-    VPHeaderPHIRecipe *PhiRecipe = nullptr;
-    assert((Legal->isReductionVariable(Phi) ||
-            Legal->isFixedOrderRecurrence(Phi)) &&
-           "can only widen reductions and fixed-order recurrences here");
-    VPValue *StartV = R->getOperand(0);
-    if (Legal->isReductionVariable(Phi)) {
-      const RecurrenceDescriptor &RdxDesc = Legal->getRecurrenceDescriptor(Phi);
-      assert(RdxDesc.getRecurrenceStartValue() ==
-             Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
-
-      // If the PHI is used by a partial reduction, set the scale factor.
-      bool UseInLoopReduction = CM.isInLoopReduction(Phi);
-      bool UseOrderedReductions = CM.useOrderedReductions(RdxDesc);
-
-      PhiRecipe = new VPReductionPHIRecipe(
-          Phi, RdxDesc.getRecurrenceKind(), *StartV,
-          getReductionStyle(UseInLoopReduction, UseOrderedReductions, 1),
-          RdxDesc.hasUsesOutsideReductionChain());
-    } else {
-      // TODO: Currently fixed-order recurrences are modeled as chains of
-      // first-order recurrences. If there are no users of the intermediate
-      // recurrences in the chain, the fixed order recurrence should be modeled
-      // directly, enabling more efficient codegen.
-      PhiRecipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *StartV);
-    }
-    // Add backedge value.
-    PhiRecipe->addOperand(R->getOperand(1));
-    return PhiRecipe;
-  }
-  assert(!R->isPhi() && "only VPPhi nodes expected at this point");
+VPRecipeBase *
+VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
+                                              VFRange &Range) {
+  assert(!R->isPhi() && "phis must be handled earlier");
+  // First, check for specific widening recipes that deal with optimizing
+  // truncates, calls and memory operations.
 
+  VPRecipeBase *Recipe;
   auto *VPI = cast<VPInstruction>(R);
   if (VPI->getOpcode() == Instruction::Trunc &&
       (Recipe = tryToOptimizeInductionTruncate(VPI, Range)))
@@ -8234,8 +8197,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
   // Construct wide recipes and apply predication for original scalar
   // VPInstructions in the loop.
   // ---------------------------------------------------------------------------
-  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder,
-                                BlockMaskCache);
+  VPRecipeBuilder RecipeBuilder(*Plan, TLI, Legal, CM, Builder, BlockMaskCache);
 
   // Scan the body of the loop in a topological order to visit each basic block
   // after having visited its predecessor basic blocks.
@@ -8283,7 +8245,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
       }
 
       VPRecipeBase *Recipe =
-          RecipeBuilder.tryToCreateWidenRecipe(VPI, Range);
+          RecipeBuilder.tryToCreateWidenNonPhiRecipe(VPI, Range);
       if (!Recipe)
         Recipe =
             RecipeBuilder.handleReplication(cast<VPInstruction>(VPI), Range);
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index d537f62c95a13..3280fc4827239 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -27,9 +27,6 @@ class VPRecipeBuilder {
   /// The VPlan new recipes are added to.
   VPlan &Plan;
 
-  /// The loop that we evaluate.
-  Loop *OrigLoop;
-
   /// Target Library Info.
   const TargetLibraryInfo *TLI;
 
@@ -39,8 +36,6 @@ class VPRecipeBuilder {
   /// The profitablity analysis.
   LoopVectorizationCostModel &CM;
 
-  PredicatedScalarEvolution &PSE;
-
   VPBuilder &Builder;
 
   /// The mask of each VPBB, generated earlier and used for predicating recipes
@@ -90,17 +85,17 @@ class VPRecipeBuilder {
                                          VPInstruction *VPI);
 
 public:
-  VPRecipeBuilder(VPlan &Plan, Loop *OrigLoop, const TargetLibraryInfo *TLI,
+  VPRecipeBuilder(VPlan &Plan, const TargetLibraryInfo *TLI,
                   LoopVectorizationLegality *Legal,
-                  LoopVectorizationCostModel &CM,
-                  PredicatedScalarEvolution &PSE, VPBuilder &Builder,
+                  LoopVectorizationCostModel &CM, VPBuilder &Builder,
                   DenseMap<VPBasicBlock *, VPValue *> &BlockMaskCache)
-      : Plan(Plan), OrigLoop(OrigLoop), TLI(TLI), Legal(Legal), CM(CM),
-        PSE(PSE), Builder(Builder), BlockMaskCache(BlockMaskCache) {}
+      : Plan(Plan), TLI(TLI), Legal(Legal), CM(CM), Builder(Builder),
+        BlockMaskCache(BlockMaskCache) {}
 
-  /// Create and return a widened recipe for \p R if one can be created within
-  /// the given VF \p Range.
-  VPRecipeBase *tryToCreateWidenRecipe(VPSingleDefRecipe *R, VFRange &Range);
+  /// Create and return a widened recipe for a non-phi recipe \p R if one can be
+  /// created within the given VF \p Range.
+  VPRecipeBase *tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
+                                             VFRange &Range);
 
   /// Set the recipe created for given ingredient.
   void setRecipe(Instruction *I, VPRecipeBase *R) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index a697f829343a5..d283322b6dcd3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2379,16 +2379,17 @@ class LLVM_ABI_FOR_TEST VPWidenPHIRecipe : public VPSingleDefRecipe,
 /// first operand of the recipe and the incoming value from the backedge is the
 /// second operand.
 struct VPFirstOrderRecurrencePHIRecipe : public VPHeaderPHIRecipe {
-  VPFirstOrderRecurrencePHIRecipe(PHINode *Phi, VPValue &Start)
-      : VPHeaderPHIRecipe(VPDef::VPFirstOrderRecurrencePHISC, Phi, &Start) {}
+  VPFirstOrderRecurrencePHIRecipe(PHINode *Phi, VPValue &Start,
+                                  VPValue &BackedgeValue)
+      : VPHeaderPHIRecipe(VPDef::VPFirstOrderRecurrencePHISC, Phi, &Start) {
+    addOperand(&BackedgeValue);
+  }
 
   VP_CLASSOF_IMPL(VPDef::VPFirstOrderRecurrencePHISC)
 
   VPFirstOrderRecurrencePHIRecipe *clone() override {
-    auto *R = new VPFirstOrderRecurrencePHIRecipe(
-        cast<PHINode>(getUnderlyingInstr()), *getOperand(0));
-    R->addOperand(getBackedgeValue());
-    return R;
+    return new VPFirstOrderRecurrencePHIRecipe(
+        cast<PHINode>(getUnderlyingInstr()), *getOperand(0), *getOperand(1));
   }
 
   void execute(VPTransformState &State) override;
@@ -2454,20 +2455,21 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
 public:
   /// Create a new VPReductionPHIRecipe for the reduction \p Phi.
   VPReductionPHIRecipe(PHINode *Phi, RecurKind Kind, VPValue &Start,
-                       ReductionStyle Style,
+                       VPValue &BackedgeValue, ReductionStyle Style,
                        bool HasUsesOutsideReductionChain = false)
       : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start), Kind(Kind),
         Style(Style),
-        HasUsesOutsideReductionChain(HasUsesOutsideReductionChain) {}
+        HasUsesOutsideReductionChain(HasUsesOutsideReductionChain) {
+    addOperand(&BackedgeValue);
+  }
 
   ~VPReductionPHIRecipe() override = default;
 
   VPReductionPHIRecipe *clone() override {
-    auto *R = new VPReductionPHIRecipe(
+    return new VPReductionPHIRecipe(
         dyn_cast_or_null<PHINode>(getUnderlyingValue()), getRecurrenceKind(),
-        *getOperand(0), Style, HasUsesOutsideReductionChain);
-    R->addOperand(getBackedgeValue());
-    return R;
+        *getOperand(0), *getBackedgeValue(), Style,
+        HasUsesOutsideReductionChain);
   }
 
   VP_CLASSOF_IMPL(VPDef::VPReductionPHISC)
@@ -2482,11 +2484,11 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
     return Partial ? Partial->VFScaleFactor : 1;
   }
 
-  /// Set the factor that the VF of this recipe's output should be scaled by.
-  void setVFScaleFactor(unsigned Factor) {
-    auto *Partial = std::get_if<RdxUnordered>(&Style);
-    assert(Partial && "Can only set VFScaleFactor for unordered reductions");
-    Partial->VFScaleFactor = Factor;
+  /// Set the VFScaleFactor for this reduction phi. Can only be set to a factor
+  /// > 1.
+  void setVFScaleFactor(unsigned ScaleFactor) {
+    assert(ScaleFactor > 1 && "must set to scale factor > 1");
+    Style = RdxUnordered{ScaleFactor};
   }
 
   /// Returns the number of incoming values, also number of incoming blocks.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index d03943ad04869..fa9d386dca446 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "LoopVectorizationPlanner.h"
-#include "VPRecipeBuilder.h"
 #include "VPlan.h"
 #include "VPlanAnalysis.h"
 #include "VPlanCFG.h"
@@ -669,9 +668,7 @@ void VPlanTransforms::createHeaderPhiRecipes(
       // first-order recurrences. If there are no users of the intermediate
       // recurrences in the chain, the fixed order recurrence should be
       // modeled directly, enabling more efficient codegen.
-      auto *Recipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *Start);
-      Recipe->addOperand(BackedgeValue);
-      return Recipe;
+      return new VPFirstOrderRecurrencePHIRecipe(Phi, *Start, *BackedgeValue);
     }
 
     auto InductionIt = Inductions.find(Phi);
@@ -688,13 +685,11 @@ void VPlanTransforms::createHeaderPhiRecipes(
     // Will be updated later to >1 if reduction is partial.
     unsigned ScaleFactor = 1;
     bool UseOrderedReductions = !AllowReordering && RdxDesc.isOrdered();
-    auto *Recipe = new VPReductionPHIRecipe(
-        Phi, RdxDesc.getRecurrenceKind(), *Start,
+    return new VPReductionPHIRecipe(
+        Phi, RdxDesc.getRecurrenceKind(), *Start, *BackedgeValue,
         getReductionStyle(InLoopReductions.contains(Phi), UseOrderedReductions,
                           ScaleFactor),
         RdxDesc.hasUsesOutsideReductionChain());
-    Recipe->addOperand(BackedgeValue);
-    return Recipe;
   };
 
   for (VPRecipeBase &R : make_early_inc_range(HeaderVPBB->phis())) {
@@ -1387,7 +1382,7 @@ static bool getScaledReductions(
   // For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
   // reduction since the inner extends will be widened. We already have oneUse
   // checks on the inner extends so widening them is safe.
-  std::optional<TTI::PartialReductionExtendKind> OuterExtKind = std::nullopt;
+  std::optional<TTI::PartialReductionExtendKind> OuterExtKind;
   if (match(Op, m_ZExtOrSExt(m_Mul(m_VPValue(), m_VPValue())))) {
     auto *CastRecipe = cast<VPWidenCastRecipe>(Op);
     auto CastOp = static_cast<Instruction::CastOps>(CastRecipe->getOpcode());
@@ -1414,8 +1409,7 @@ static bool getScaledReductions(
 
   auto CollectExtInfo = [OuterExtKind, &CastRecipes, &ExtOpTypes, &ExtKinds,
                          &TypeInfo](ArrayRef<VPValue *> Operands) {
-    if (Operands.size() > 2)
-      return false;
+    assert(Operands.size() <= 2 && "expected at most 2 operands");
 
     for (const auto &[I, OpVal] : enumerate(Operands)) {
       // Allow constant as second operand.
@@ -1467,7 +1461,7 @@ static bool getScaledReductions(
     if (!ExtendUser || !CollectExtInfo(ExtendUser->operands()))
       return false;
 
-    BinOpc = std::make_optional(ExtendUser->getOpcode());
+    BinOpc = ExtendUser->getOpcode();
   } else if (match(UpdateRecipe, m_Add(m_VPValue(), m_VPValue()))) {
     if (!CollectExtInfo({Op}))
       return false;


