[llvm] [LV] Handle partial sub-reductions with sub in middle block. (PR #178919)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 9 08:44:09 PST 2026
================
@@ -5615,6 +5616,40 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
if (Cond)
ExitValue->replaceAllUsesWith(PartialRed);
WidenRecipe->replaceAllUsesWith(PartialRed);
+
+ if (IsLastInChain) {
+ // Scale the PHI and ReductionStartVector by the VFScaleFactor
+ assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
+ RdxPhi->setVFScaleFactor(ScaleFactor);
+
+ auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue());
+ assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
+ auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
+ StartInst->setOperand(2, NewScaleFactor);
+
+ // If this is the last value in a sub-reduction chain, then update the PHI
+ // node to start at `0` and update the reduction-result to subtract from
+ // the PHI's start value.
+ if (RK == RecurKind::Sub) {
+ // Update start value of PHI node.
+ VPValue *OldStartValue = StartInst->getOperand(0);
+ StartInst->setOperand(0, StartInst->getOperand(1));
+
+ // Replace reduction_result by 'sub (startval, reductionresult)'.
+ VPInstruction *RdxResult = vputils::findComputeReductionResult(RdxPhi);
+ assert(RdxResult && "Could not find reduction result");
+
+ VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
+ constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
+ VPInstruction *NewResult = Builder.createNaryOp(
+ SubOpc, {OldStartValue, RdxResult},
+ VPIRFlags::getDefaultFlags(SubOpc), RdxPhi->getDebugLoc());
+ RdxResult->replaceUsesWithIf(
+ NewResult,
+ [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
+ }
+ }
----------------
fhahn wrote:
nit: could exit early to reduce the indentation level; the same applies to the `RecurKind::Sub` handling below.
```suggestion
if (!IsLastInChain)
return true
// Scale the PHI and ReductionStartVector by the VFScaleFactor
assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
RdxPhi->setVFScaleFactor(ScaleFactor);
auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue());
assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
StartInst->setOperand(2, NewScaleFactor);
// If this is the last value in a sub-reduction chain, then update the PHI
// node to start at `0` and update the reduction-result to subtract from
// the PHI's start value.
if (RK == RecurKind::Sub) {
// Update start value of PHI node.
VPValue *OldStartValue = StartInst->getOperand(0);
StartInst->setOperand(0, StartInst->getOperand(1));
// Replace reduction_result by 'sub (startval, reductionresult)'.
VPInstruction *RdxResult = vputils::findComputeReductionResult(RdxPhi);
assert(RdxResult && "Could not find reduction result");
VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
VPInstruction *NewResult = Builder.createNaryOp(
SubOpc, {OldStartValue, RdxResult},
VPIRFlags::getDefaultFlags(SubOpc), RdxPhi->getDebugLoc());
RdxResult->replaceUsesWithIf(
NewResult,
[&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
}
```
https://github.com/llvm/llvm-project/pull/178919
More information about the llvm-commits
mailing list