[llvm] [LoopVectorizer] Add support for chaining partial reductions (PR #120272)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 20 03:30:45 PST 2025


================
@@ -8823,26 +8823,41 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   }
 }
 
-std::optional<std::pair<PartialReductionChain, unsigned>>
-VPRecipeBuilder::getScaledReduction(PHINode *PHI,
-                                    const RecurrenceDescriptor &Rdx,
+std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
+VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
                                     VFRange &Range) {
+
+  if (!CM.TheLoop->contains(RdxExitInstr))
+    return std::nullopt;
+
   // TODO: Allow scaling reductions when predicating. The select at
   // the end of the loop chooses between the phi value and most recent
   // reduction result, both of which have different VFs to the active lane
   // mask when scaling.
-  if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
+  if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
     return std::nullopt;
 
-  auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
+  auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
   if (!Update)
     return std::nullopt;
 
   Value *Op = Update->getOperand(0);
   Value *PhiOp = Update->getOperand(1);
-  if (Op == PHI) {
-    Op = Update->getOperand(1);
-    PhiOp = Update->getOperand(0);
+  if (Op == PHI)
+    std::swap(Op, PhiOp);
+
+  SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
+
+  if (auto *OpInst = dyn_cast<Instruction>(Op)) {
----------------
SamTebbs33 wrote:

A comment here explaining what this block does would be good.

https://github.com/llvm/llvm-project/pull/120272


More information about the llvm-commits mailing list