[llvm] Reland "[LoopVectorizer] Add support for partial reductions" with non-phi operand fix. (PR #121744)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 6 02:25:24 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Sam Tebbs (SamTebbs33)

<details>
<summary>Changes</summary>

This relands the reverted PR #120721 with a fix for cases where neither reduction operand is the reduction phi. Only commit 63114239cc8d26225a0ef9920baacfc7cc00fc58 and one further follow-up commit are new on top of the reverted PR. (The original summary lists 63114239cc8d26225a0ef9920baacfc7cc00fc58 twice, so the second commit's ID is missing.)

---

Patch is 301.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121744.diff


16 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+39) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+9) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+17) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+56) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+139-5) 
- (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+57-2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+57-6) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+4-4) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+70-4) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll (+10-10) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll (+213) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll (+1375) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+1888) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-no-dotprod.ll (+61) 
- (added) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+93) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 752313ab15858c..c6b846f96f1622 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI;
 /// for IR-level transformations.
 class TargetTransformInfo {
 public:
+  enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
+
+  /// Get the kind of extension that an instruction represents.
+  static PartialReductionExtendKind
+  getPartialReductionExtendKind(Instruction *I);
+
   /// Construct a TTI object using a type implementing the \c Concept
   /// API below.
   ///
@@ -1280,6 +1286,18 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
+  /// \return The cost of a partial reduction, which is a reduction from a
+  /// vector to another vector with fewer elements of larger size. They are
+  /// represented by the llvm.experimental.partial.reduce.add intrinsic, which
+  /// takes an accumulator and a binary operation operand that itself is fed by
+  /// two extends. An example of an operation that uses a partial reduction is a
+  /// dot product, which reduces a vector to another of 4 times fewer elements.
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF, PartialReductionExtendKind OpAExtend,
+                          PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp = std::nullopt) const;
+
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -2107,6 +2125,18 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
+  /// \return The cost of a partial reduction, which is a reduction from a
+  /// vector to another vector with fewer elements of larger size. They are
+  /// represented by the llvm.experimental.partial.reduce.add intrinsic, which
+  /// takes an accumulator and a binary operation operand that itself is fed by
+  /// two extends. An example of an operation that uses a partial reduction is a
+  /// dot product, which reduces a vector to another of 4 times fewer elements.
+  virtual InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF, PartialReductionExtendKind OpAExtend,
+                          PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp) const = 0;
+
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2786,6 +2816,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
+      PartialReductionExtendKind OpAExtend,
+      PartialReductionExtendKind OpBExtend,
+      std::optional<unsigned> BinOp = std::nullopt) const override {
+    return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
+                                        OpAExtend, OpBExtend, BinOp);
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9c74b2a0c31df1..5fa0c46ad292d8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -585,6 +585,15 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp = std::nullopt) const {
+    return InstructionCost::getInvalid();
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
   InstructionCost getArithmeticInstrCost(
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index b32dffa9f0fe86..c62e40db0c5775 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -863,6 +863,14 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
   return TTIImpl->shouldPrefetchAddressSpace(AS);
 }
 
+InstructionCost TargetTransformInfo::getPartialReductionCost(
+    unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
+    PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
+    std::optional<unsigned> BinOp) const {
+  return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
+                                          OpAExtend, OpBExtend, BinOp);
+}
+
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
@@ -974,6 +982,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
   return Cost;
 }
 
+TargetTransformInfo::PartialReductionExtendKind
+TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
+  if (isa<SExtInst>(I))
+    return PR_SignExtend;
+  if (isa<ZExtInst>(I))
+    return PR_ZeroExtend;
+  return PR_None;
+}
+
 TTI::CastContextHint
 TargetTransformInfo::getCastContextHint(const Instruction *I) {
   if (!I)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 214fb4e352eeb2..c8ed48dd093984 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -23,6 +23,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/Support/InstructionCost.h"
 #include <cstdint>
 #include <optional>
 
@@ -357,6 +358,61 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp) const {
+
+    InstructionCost Invalid = InstructionCost::getInvalid();
+    InstructionCost Cost(TTI::TCC_Basic);
+
+    if (Opcode != Instruction::Add)
+      return Invalid;
+
+    EVT InputEVT = EVT::getEVT(InputType);
+    EVT AccumEVT = EVT::getEVT(AccumType);
+
+    if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+      return Invalid;
+    if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+      return Invalid;
+
+    if (InputEVT == MVT::i8) {
+      switch (VF.getKnownMinValue()) {
+      default:
+        return Invalid;
+      case 8:
+        if (AccumEVT == MVT::i32)
+          Cost *= 2;
+        else if (AccumEVT != MVT::i64)
+          return Invalid;
+        break;
+      case 16:
+        if (AccumEVT == MVT::i64)
+          Cost *= 2;
+        else if (AccumEVT != MVT::i32)
+          return Invalid;
+        break;
+      }
+    } else if (InputEVT == MVT::i16) {
+      // FIXME: Allow i32 accumulator but increase cost, as we would extend
+      //        it to i64.
+      if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+        return Invalid;
+    } else
+      return Invalid;
+
+    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
+      return Invalid;
+
+    if (!BinOp || (*BinOp) != Instruction::Mul)
+      return Invalid;
+
+    return Cost;
+  }
+
   bool enableOrderedReductions() const { return true; }
 
   InstructionCost getInterleavedMemoryOpCost(
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0797100b182cb1..d6c7870b0905ef 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7532,6 +7532,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
         }
         continue;
       }
+      // The VPlan-based cost model is more accurate for partial reduction and
+      // comparing against the legacy cost isn't desirable.
+      if (isa<VPPartialReductionRecipe>(&R))
+        return true;
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
     }
@@ -8746,6 +8750,108 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
   return Recipe;
 }
 
+/// Find all possible partial reductions in the loop and track all of those that
+/// are valid so recipes can be formed later.
+void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
+  // Find all possible partial reductions.
+  SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
+      PartialReductionChains;
+  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
+    if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
+            getScaledReduction(Phi, RdxDesc, Range))
+      PartialReductionChains.push_back(*Pair);
+
+  // A partial reduction is invalid if any of its extends are used by
+  // something that isn't another partial reduction. This is because the
+  // extends are intended to be lowered along with the reduction itself.
+
+  // Build up a set of partial reduction bin ops for efficient use checking.
+  SmallSet<User *, 4> PartialReductionBinOps;
+  for (const auto &[PartialRdx, _] : PartialReductionChains)
+    PartialReductionBinOps.insert(PartialRdx.BinOp);
+
+  auto ExtendIsOnlyUsedByPartialReductions =
+      [&PartialReductionBinOps](Instruction *Extend) {
+        return all_of(Extend->users(), [&](const User *U) {
+          return PartialReductionBinOps.contains(U);
+        });
+      };
+
+  // Check if each use of a chain's two extends is a partial reduction
+  // and only add those that don't have non-partial reduction users.
+  for (auto Pair : PartialReductionChains) {
+    PartialReductionChain Chain = Pair.first;
+    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+      ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
+  }
+}
+
+std::optional<std::pair<PartialReductionChain, unsigned>>
+VPRecipeBuilder::getScaledReduction(PHINode *PHI,
+                                    const RecurrenceDescriptor &Rdx,
+                                    VFRange &Range) {
+  // TODO: Allow scaling reductions when predicating. The select at
+  // the end of the loop chooses between the phi value and most recent
+  // reduction result, both of which have different VFs to the active lane
+  // mask when scaling.
+  if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
+    return std::nullopt;
+
+  auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
+  if (!Update)
+    return std::nullopt;
+
+  Value *Op = Update->getOperand(0);
+  Value *PhiOp = Update->getOperand(1);
+  if (Op == PHI) {
+    Op = Update->getOperand(1);
+    PhiOp = Update->getOperand(0);
+  }
+  if (PhiOp != PHI)
+    return std::nullopt;
+
+  auto *BinOp = dyn_cast<BinaryOperator>(Op);
+  if (!BinOp || !BinOp->hasOneUse())
+    return std::nullopt;
+
+  using namespace llvm::PatternMatch;
+  Value *A, *B;
+  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
+      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
+    return std::nullopt;
+
+  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
+  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+
+  // Check that the extends extend from the same type.
+  if (A->getType() != B->getType())
+    return std::nullopt;
+
+  TTI::PartialReductionExtendKind OpAExtend =
+      TargetTransformInfo::getPartialReductionExtendKind(ExtA);
+  TTI::PartialReductionExtendKind OpBExtend =
+      TargetTransformInfo::getPartialReductionExtendKind(ExtB);
+
+  PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
+
+  unsigned TargetScaleFactor =
+      PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
+          A->getType()->getPrimitiveSizeInBits());
+
+  if (LoopVectorizationPlanner::getDecisionAndClampRange(
+          [&](ElementCount VF) {
+            InstructionCost Cost = TTI->getPartialReductionCost(
+                Update->getOpcode(), A->getType(), PHI->getType(), VF,
+                OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode()));
+            return Cost.isValid();
+          },
+          Range))
+    return std::make_pair(Chain, TargetScaleFactor);
+
+  return std::nullopt;
+}
+
 VPRecipeBase *
 VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                         ArrayRef<VPValue *> Operands,
@@ -8770,9 +8876,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
           Legal->getReductionVars().find(Phi)->second;
       assert(RdxDesc.getRecurrenceStartValue() ==
              Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
-      PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
-                                           CM.isInLoopReduction(Phi),
-                                           CM.useOrderedReductions(RdxDesc));
+
+      // If the PHI is used by a partial reduction, set the scale factor.
+      std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
+          getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
+      unsigned ScaleFactor = Pair ? Pair->second : 1;
+      PhiRecipe = new VPReductionPHIRecipe(
+          Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
+          CM.useOrderedReductions(RdxDesc), ScaleFactor);
     } else {
       // TODO: Currently fixed-order recurrences are modeled as chains of
       // first-order recurrences. If there are no users of the intermediate
@@ -8804,6 +8915,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
   if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
     return tryToWidenMemory(Instr, Operands, Range);
 
+  if (getScaledReductionForInstr(Instr))
+    return tryToCreatePartialReduction(Instr, Operands);
+
   if (!shouldWiden(Instr, Range))
     return nullptr;
 
@@ -8824,6 +8938,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *
+VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
+                                             ArrayRef<VPValue *> Operands) {
+  assert(Operands.size() == 2 &&
+         "Unexpected number of operands for partial reduction");
+
+  VPValue *BinOp = Operands[0];
+  VPValue *Phi = Operands[1];
+  if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
+    std::swap(BinOp, Phi);
+
+  return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
+                                      Reduction);
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9247,7 +9376,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
   bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
   addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
 
-  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
+  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
+                                Builder);
 
   // ---------------------------------------------------------------------------
   // Pre-construction: record ingredients whose recipes we'll need to further
@@ -9293,6 +9423,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
         bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
         return Legal->blockNeedsPredication(BB) || NeedsBlends;
       });
+
+  RecipeBuilder.collectScaledReductions(Range);
+
   auto *MiddleVPBB = Plan->getMiddleBlock();
   VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
   for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
@@ -9516,7 +9649,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
 
   // Collect mapping of IR header phis to header phi recipes, to be used in
   // addScalarResumePhis.
-  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
+  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
+                                Builder);
   for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
     if (isa<VPCanonicalIVPHIRecipe>(&R))
       continue;
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 5d4a3b555981ce..cf653e2d3e6584 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -21,8 +21,28 @@ namespace llvm {
 class LoopVectorizationLegality;
 class LoopVectorizationCostModel;
 class TargetLibraryInfo;
+class TargetTransformInfo;
 struct HistogramInfo;
 
+/// A chain of instructions that form a partial reduction.
+/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
+/// accumulator).
+struct PartialReductionChain {
+  PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
+                        Instruction *ExtendB, Instruction *BinOp)
+      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
+  }
+  /// The top-level binary operation that forms the reduction to a scalar
+  /// after the loop body.
+  Instruction *Reduction;
+  /// The extension of each of the inner binary operation's operands.
+  Instruction *ExtendA;
+  Instruction *ExtendB;
+
+  /// The binary operation using the extends that is then reduced.
+  Instruction *BinOp;
+};
+
 /// Helper class to create VPRecipies from IR instructions.
 class VPRecipeBuilder {
   /// The VPlan new recipes are added to.
@@ -34,6 +54,9 @@ class VPRecipeBuilder {
   /// Target Library Info.
   const TargetLibraryInfo *TLI;
 
+  // Target Transform Info.
+  const TargetTransformInfo *TTI;
+
   /// The legality analysis.
   LoopVectorizationLegality *Legal;
 
@@ -63,6 +86,11 @@ class VPRecipeBuilder {
   /// created.
   SmallVector<VPHeaderPHIRecipe *, 4> PhisToFix;
 
+  /// The set of reduction exit instructions that will be scaled to
+  /// a smaller VF via partial reductions, paired with the scaling factor.
+  DenseMap<const Instruction *, std::pair<PartialReductionChain, unsigned>>
+      ScaledReductionExitInstrs;
+
   /// Check if \p I can be widened at the start of \p Range and possibly
   /// decrease the range such that the returned value holds for the entire \p
   /// Range. The function should not be called for memory instructions or calls.
@@ -111,13 +139,35 @@ class VPRecipeBuilder {
   VPHistogramRecipe *tryToWidenHistogram(const HistogramInfo *HI,
                                          ArrayRef<VPValue *> Operands);
 
+  /// Examines reduction operations to see if the target can use a cheaper
+  /// operation with a wider per-iteration input VF and narrower PHI VF.
+  /// Returns null if no scaled reduction was found, otherwise a pair with a
+  /// struct containing reduction information and the scaling factor between the
+  /// number of elements in the input and output.
+  std::optional<std::pair<PartialReductionChain, unsigned>>
+  getScaledReduct...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/121744


More information about the llvm-commits mailing list