[llvm] [LV] Handle partial sub-reductions with sub in middle block. (PR #178919)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 10 01:52:33 PST 2026
https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/178919
>From 565ee7f90a490bec7db22c4b02988d6be51fc4ac Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 30 Jan 2026 16:36:54 +0000
Subject: [PATCH 1/5] Pre-commit test
---
.../AArch64/partial-reduce-sub-sdot.ll | 76 +++++++++++++++++++
1 file changed, 76 insertions(+)
create mode 100644 llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
new file mode 100644
index 0000000000000..d0c935a6385cc
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
@@ -0,0 +1,76 @@
+; REQUIRES: asserts
+; RUN: opt -passes=loop-vectorize \
+; RUN: -enable-epilogue-vectorization=false -debug-only=loop-vectorize \
+; RUN: -disable-output < %s 2>&1 | FileCheck %s
+
+; CHECK: LV: Checking a loop in 'sub_reduction'
+; CHECK: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+
+; CHECK: LV: Checking a loop in 'add_sub_chained_reduction'
+; CHECK: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; CHECK: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+
+target triple = "aarch64"
+
+; Test the cost of a SUB reduction, where the SUB is implemented outside the loop
+; and therefore not part of the partial reduction.
+define i32 @sub_reduction(ptr %arr1, ptr %arr2, i32 %init, i32 %n) #0 {
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
+ %acc = phi i32 [ %init, %entry ], [ %sub, %loop ]
+ %gep1 = getelementptr inbounds i8, ptr %arr1, i32 %iv
+ %load1 = load i8, ptr %gep1
+ %sext1 = sext i8 %load1 to i32
+ %gep2 = getelementptr inbounds i8, ptr %arr2, i32 %iv
+ %load2 = load i8, ptr %gep2
+ %sext2 = sext i8 %load2 to i32
+ %mul = mul i32 %sext1, %sext2
+ %sub = sub i32 %acc, %mul
+ %iv.next = add i32 %iv, 1
+ %cmp = icmp ult i32 %iv.next, %n
+ br i1 %cmp, label %loop, label %exit, !llvm.loop !0
+
+exit:
+ ret i32 %sub
+}
+
+; Test that the cost of a SUB that is part of an ADD-SUB reduction chain
+; is high, because the negation happens inside the loop and cannot be
+; folded into the SDOT instruction (because of the extend).
+define i32 @add_sub_chained_reduction(ptr %arr1, ptr %arr2, ptr %arr3, i32 %init, i32 %n) #0 {
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
+ %acc = phi i32 [ %init, %entry ], [ %sub, %loop ]
+ %gep1 = getelementptr inbounds i8, ptr %arr1, i32 %iv
+ %load1 = load i8, ptr %gep1
+ %sext1 = sext i8 %load1 to i32
+ %gep2 = getelementptr inbounds i8, ptr %arr2, i32 %iv
+ %load2 = load i8, ptr %gep2
+ %sext2 = sext i8 %load2 to i32
+ %mul1 = mul i32 %sext1, %sext2
+ %add = add i32 %acc, %mul1
+ %gep3 = getelementptr inbounds i8, ptr %arr3, i32 %iv
+ %load3 = load i8, ptr %gep3
+ %sext3 = sext i8 %load3 to i32
+ %mul2 = mul i32 %sext2, %sext3
+ %sub = sub i32 %add, %mul2
+ %iv.next = add i32 %iv, 1
+ %cmp = icmp ult i32 %iv.next, %n
+ br i1 %cmp, label %loop, label %exit, !llvm.loop !0
+
+exit:
+ ret i32 %sub
+}
+
+attributes #0 = { vscale_range(1,16) "target-features"="+sve2" }
+
+!0 = distinct !{!0, !1, !2, !3}
+!1 = !{!"llvm.loop.scalable.enable", i1 true}
+!2 = !{!"llvm.loop.interleave.count", i32 1}
+!3 = !{!"llvm.loop.vectorize.width", i32 16}
>From 97a21654e6e90622684ddad10beb88c008c9e1ea Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 30 Jan 2026 16:21:43 +0000
Subject: [PATCH 2/5] [LV] Handle partial sub-reductions with sub in middle
block.
Sub-reductions can be implemented in two ways:
(1) negate the operand in the vector loop (the default way).
(2) subtract the reduced value from the init value in the middle block.
Note that both ways keep the reduction itself as an 'add' reduction,
which is necessary because only llvm.vector.partial.reduce.add exists.
The ISD nodes for partial reductions don't support folding the
sub/negation into its operands because the following is not a valid
transformation:
sub(0, mul(ext(a), ext(b)))
-> mul(ext(a), ext(sub(0, b)))
It can therefore be better to choose option (2) such that the partial
reduction accumulates the un-negated products (starting at '0') and to do
a single final subtract from the init value in the middle block.
For AArch64 there are also no dot-product operations that can
do a `partial.reduce(mul(ext(a), sub(0, ext(b))))` operation.
I'm not sure if such instructions exist for other targets.
If so, then we may want to make this decision a target option.
This PR also increases the AArch64 cost of a partial sub-reduction
when it is part of an 'add-sub' reduction chain.
Fixes https://github.com/llvm/llvm-project/issues/178703
---
.../AArch64/AArch64TargetTransformInfo.cpp | 5 ++
.../Transforms/Vectorize/VPlanTransforms.cpp | 54 +++++++++++++++++--
.../AArch64/partial-reduce-chained.ll | 21 ++++----
.../AArch64/partial-reduce-sub-sdot.ll | 2 +-
.../AArch64/partial-reduce-sub.ll | 21 ++++----
5 files changed, 73 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index c9d775367f929..71f52ae55d3ec 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5912,6 +5912,11 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
+ // The sub/negation cannot be folded into the operands of
+ // ISD::PARTIAL_REDUCE_*MLA, so make the cost more expensive.
+ if (Opcode == Instruction::Sub)
+ Cost += 8;
+
// Prefer using full types by costing half-full input types as more expensive.
if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
TypeSize::getScalable(128)))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index efea585114947..18ae31526c452 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -5551,7 +5551,7 @@ struct VPPartialReductionChain {
// clamps VF range.
static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
VFRange &Range, VPCostContext &CostCtx,
- VPlan &Plan) {
+ VPlan &Plan, RecurKind RK) {
VPWidenRecipe *WidenRecipe = Chain.ReductionBinOp;
unsigned ScaleFactor = Chain.ScaleFactor;
assert(WidenRecipe->getNumOperands() == 2 && "Expected binary operation");
@@ -5639,8 +5639,13 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
}
- // Handle SUB by negating the operand and using ADD for the partial reduction.
- if (WidenRecipe->getOpcode() == Instruction::Sub) {
+ // For partial reductions the 'sub' is performed outside the loop,
+ // so that the reduction itself is all positive, because otherwise
+ // the neg/sub can't be folded into the operands. For non-partial
+ // sub-reductions and reduction chains that have both adds and subs,
+ // the sub is performed on the operand inside the vector loop.
+ // Note that both cases use an add to implement the actual reduction step.
+ if (WidenRecipe->getOpcode() == Instruction::Sub && RK != RecurKind::Sub) {
VPBuilder Builder(WidenRecipe);
Type *ElemTy = CostCtx.Types.inferScalarType(BinOp);
auto *Zero = Plan.getConstantInt(ElemTy, 0);
@@ -5889,7 +5894,46 @@ void VPlanTransforms::createPartialReductions(VPlan &Plan,
}
}
- for (const auto &[_, Chains] : ChainsByPhi)
+ for (auto &[Phi, Chains] : ChainsByPhi) {
+ RecurKind RK = cast<VPReductionPHIRecipe>(Phi)->getRecurrenceKind();
+ bool IsTransformed = false;
for (const VPPartialReductionChain &Chain : Chains)
- transformToPartialReduction(Chain, Range, CostCtx, Plan);
+ IsTransformed |=
+ transformToPartialReduction(Chain, Range, CostCtx, Plan, RK);
+
+ // Sub-reductions can be implemented in two ways:
+ // (1) negate the operand in the vector loop (the default way).
+ // (2) subtract the reduced value from the init value in the middle block.
+ // Both ways keep the reduction itself as an 'add' reduction.
+ //
+ // The ISD nodes for partial reductions don't support folding the
+ // sub/negation into its operands because the following is not a valid
+ // transformation:
+ // sub(0, mul(ext(a), ext(b)))
+ // -> mul(ext(a), ext(sub(0, b)))
+ //
+ // It's therefore better to choose option (2) such that the partial
+ // reduction is always positive (starting at '0') and to do a final
+ // subtract in the middle block.
+ if (IsTransformed && !Range.isEmpty() && RK == RecurKind::Sub) {
+ // Update start value of PHI node.
+ auto *StartInst = cast<VPInstruction>(Phi->getStartValue());
+ assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
+ VPValue *OldStartValue = StartInst->getOperand(0);
+ StartInst->setOperand(0, StartInst->getOperand(1));
+
+ // Replace reduction_result by 'sub (startval, reductionresult)'.
+ VPInstruction *RdxResult = vputils::findComputeReductionResult(Phi);
+ assert(RdxResult && "Could not find reduction result");
+
+ VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
+ constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
+ VPInstruction *NewResult = Builder.createNaryOp(
+ SubOpc, {OldStartValue, RdxResult},
+ VPIRFlags::getDefaultFlags(SubOpc), Phi->getDebugLoc());
+ RdxResult->replaceUsesWithIf(
+ NewResult,
+ [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
+ }
+ }
}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index f1dee958fa09c..d1fde2cdaafe1 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -494,17 +494,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP10]]
-; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
; CHECK-NEON-NEXT: [[TMP16:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP16]]
-; CHECK-NEON-NEXT: [[TMP13:%.*]] = sub <16 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP12]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub i32 0, [[TMP15]]
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
; CHECK-NEON: scalar.ph:
@@ -537,17 +536,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-NEXT: [[TMP14:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
; CHECK-SVE-NEXT: [[TMP16:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
; CHECK-SVE-NEXT: [[TMP17:%.*]] = mul nsw <vscale x 16 x i32> [[TMP14]], [[TMP16]]
-; CHECK-SVE-NEXT: [[TMP10:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP17]]
-; CHECK-SVE-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP10]])
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP17]])
; CHECK-SVE-NEXT: [[TMP11:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
; CHECK-SVE-NEXT: [[TMP12:%.*]] = mul nsw <vscale x 16 x i32> [[TMP14]], [[TMP11]]
-; CHECK-SVE-NEXT: [[TMP13:%.*]] = sub <vscale x 16 x i32> zeroinitializer, [[TMP12]]
-; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP13]])
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP12]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-SVE-NEXT: [[TMP18:%.*]] = sub i32 0, [[TMP15]]
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-SVE-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
; CHECK-SVE: scalar.ph:
@@ -580,17 +578,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[TMP13:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
-; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP16]])
; CHECK-SVE-MAXBW-NEXT: [[TMP12:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP12]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = sub <vscale x 8 x i32> zeroinitializer, [[TMP18]]
-; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP19]])
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP18]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-SVE-MAXBW-NEXT: [[TMP15:%.*]] = sub i32 0, [[TMP21]]
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
; CHECK-SVE-MAXBW: scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
index d0c935a6385cc..ffedf1f9eb100 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
@@ -8,7 +8,7 @@
; CHECK: LV: Checking a loop in 'add_sub_chained_reduction'
; CHECK: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; CHECK: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; CHECK: Cost of 9 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
target triple = "aarch64"
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll
index 107d1d8a78706..9bf37a43e467d 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll
@@ -154,13 +154,13 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-INTERLEAVE1-NEXT: [[TMP3:%.*]] = sext <4 x i32> [[TMP2]] to <4 x i64>
; CHECK-INTERLEAVE1-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[STRIDED_VEC]] to <4 x i64>
; CHECK-INTERLEAVE1-NEXT: [[TMP5:%.*]] = mul <4 x i64> [[TMP3]], [[TMP4]]
-; CHECK-INTERLEAVE1-NEXT: [[TMP6:%.*]] = sub <4 x i64> zeroinitializer, [[TMP5]]
-; CHECK-INTERLEAVE1-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP6]])
+; CHECK-INTERLEAVE1-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP5]])
; CHECK-INTERLEAVE1-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-INTERLEAVE1-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 36
; CHECK-INTERLEAVE1-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-INTERLEAVE1: middle.block:
; CHECK-INTERLEAVE1-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[PARTIAL_REDUCE]])
+; CHECK-INTERLEAVE1-NEXT: [[TMP9:%.*]] = sub i64 0, [[TMP8]]
; CHECK-INTERLEAVE1-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[STRIDED_VEC]], i32 3
; CHECK-INTERLEAVE1-NEXT: br label [[SCALAR_PH:%.*]]
; CHECK-INTERLEAVE1: scalar.ph:
@@ -204,23 +204,19 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = sext <4 x i32> [[TMP5]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP9:%.*]] = sext <4 x i32> [[STRIDED_VEC]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP11:%.*]] = mul <4 x i64> [[TMP7]], [[TMP9]]
-; CHECK-INTERLEAVED-NEXT: [[TMP18:%.*]] = sub <4 x i64> zeroinitializer, [[TMP11]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP18]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP11]])
; CHECK-INTERLEAVED-NEXT: [[TMP19:%.*]] = sext <4 x i32> [[TMP31]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP20:%.*]] = sext <4 x i32> [[STRIDED_VEC5]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP21:%.*]] = mul <4 x i64> [[TMP19]], [[TMP20]]
-; CHECK-INTERLEAVED-NEXT: [[TMP22:%.*]] = sub <4 x i64> zeroinitializer, [[TMP21]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE10]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI1]], <4 x i64> [[TMP22]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE10]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI1]], <4 x i64> [[TMP21]])
; CHECK-INTERLEAVED-NEXT: [[TMP8:%.*]] = sext <4 x i32> [[TMP13]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP10:%.*]] = sext <4 x i32> [[STRIDED_VEC7]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP12:%.*]] = mul <4 x i64> [[TMP8]], [[TMP10]]
-; CHECK-INTERLEAVED-NEXT: [[TMP26:%.*]] = sub <4 x i64> zeroinitializer, [[TMP12]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE11]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI2]], <4 x i64> [[TMP26]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE11]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI2]], <4 x i64> [[TMP12]])
; CHECK-INTERLEAVED-NEXT: [[TMP27:%.*]] = sext <4 x i32> [[TMP14]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP28:%.*]] = sext <4 x i32> [[STRIDED_VEC3]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP29:%.*]] = mul <4 x i64> [[TMP27]], [[TMP28]]
-; CHECK-INTERLEAVED-NEXT: [[TMP30:%.*]] = sub <4 x i64> zeroinitializer, [[TMP29]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE12]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI3]], <4 x i64> [[TMP30]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE12]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI3]], <4 x i64> [[TMP29]])
; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-INTERLEAVED-NEXT: [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], 32
; CHECK-INTERLEAVED-NEXT: br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
@@ -229,6 +225,7 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-INTERLEAVED-NEXT: [[BIN_RDX13:%.*]] = add <2 x i64> [[PARTIAL_REDUCE11]], [[BIN_RDX]]
; CHECK-INTERLEAVED-NEXT: [[BIN_RDX14:%.*]] = add <2 x i64> [[PARTIAL_REDUCE12]], [[BIN_RDX13]]
; CHECK-INTERLEAVED-NEXT: [[TMP32:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[BIN_RDX14]])
+; CHECK-INTERLEAVED-NEXT: [[TMP30:%.*]] = sub i64 0, [[TMP32]]
; CHECK-INTERLEAVED-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[STRIDED_VEC3]], i32 3
; CHECK-INTERLEAVED-NEXT: br label [[SCALAR_PH:%.*]]
; CHECK-INTERLEAVED: scalar.ph:
@@ -251,13 +248,13 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-MAXBW-NEXT: [[TMP3:%.*]] = sext <4 x i32> [[TMP2]] to <4 x i64>
; CHECK-MAXBW-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[STRIDED_VEC]] to <4 x i64>
; CHECK-MAXBW-NEXT: [[TMP5:%.*]] = mul <4 x i64> [[TMP3]], [[TMP4]]
-; CHECK-MAXBW-NEXT: [[TMP6:%.*]] = sub <4 x i64> zeroinitializer, [[TMP5]]
-; CHECK-MAXBW-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP6]])
+; CHECK-MAXBW-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP5]])
; CHECK-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-MAXBW-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 36
; CHECK-MAXBW-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-MAXBW: middle.block:
; CHECK-MAXBW-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[PARTIAL_REDUCE]])
+; CHECK-MAXBW-NEXT: [[TMP9:%.*]] = sub i64 0, [[TMP8]]
; CHECK-MAXBW-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[STRIDED_VEC]], i32 3
; CHECK-MAXBW-NEXT: br label [[SCALAR_PH:%.*]]
; CHECK-MAXBW: scalar.ph:
>From 50b20b5e540f0e17c31cb1267a4c2ab76cce8cf5 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 5 Feb 2026 16:09:52 +0000
Subject: [PATCH 3/5] Add RUN line to check for NEON costs
---
.../AArch64/partial-reduce-sub-sdot.ll | 34 ++++++++++++-------
1 file changed, 21 insertions(+), 13 deletions(-)
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
index ffedf1f9eb100..f06b2137c2b8d 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
@@ -1,14 +1,23 @@
; REQUIRES: asserts
-; RUN: opt -passes=loop-vectorize \
-; RUN: -enable-epilogue-vectorization=false -debug-only=loop-vectorize \
-; RUN: -disable-output < %s 2>&1 | FileCheck %s
+; RUN: opt -passes=loop-vectorize \
+; RUN: -scalable-vectorization=on -mattr=+sve2 \
+; RUN: -enable-epilogue-vectorization=false -debug-only=loop-vectorize \
+; RUN: -disable-output < %s 2>&1 | FileCheck %s --check-prefixes=COMMON,SVE
-; CHECK: LV: Checking a loop in 'sub_reduction'
-; CHECK: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; RUN: opt -passes=loop-vectorize \
+; RUN: -scalable-vectorization=off -mattr=+neon,+dotprod \
+; RUN: -enable-epilogue-vectorization=false -debug-only=loop-vectorize \
+; RUN: -disable-output < %s 2>&1 | FileCheck %s --check-prefixes=COMMON,NEON
-; CHECK: LV: Checking a loop in 'add_sub_chained_reduction'
-; CHECK: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
-; CHECK: Cost of 9 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; COMMON: LV: Checking a loop in 'sub_reduction'
+; SVE: Cost of 1 for VF vscale x 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; NEON: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+
+; COMMON: LV: Checking a loop in 'add_sub_chained_reduction'
+; SVE: Cost of 1 for VF vscale x 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; SVE: Cost of 9 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; NEON: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; NEON: Cost of 9 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
target triple = "aarch64"
@@ -68,9 +77,8 @@ exit:
ret i32 %sub
}
-attributes #0 = { vscale_range(1,16) "target-features"="+sve2" }
+attributes #0 = { vscale_range(1,16) }
-!0 = distinct !{!0, !1, !2, !3}
-!1 = !{!"llvm.loop.scalable.enable", i1 true}
-!2 = !{!"llvm.loop.interleave.count", i32 1}
-!3 = !{!"llvm.loop.vectorize.width", i32 16}
+!0 = distinct !{!0, !1, !2}
+!1 = !{!"llvm.loop.interleave.count", i32 1}
+!2 = !{!"llvm.loop.vectorize.width", i32 16}
>From b4c491ff2a7d7a6576bbaa6dc0f801277ef44c1f Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 9 Feb 2026 11:48:23 +0000
Subject: [PATCH 4/5] Update PHI node and ComputeReductionResult in
transformToPartialReduction
---
.../Transforms/Vectorize/VPlanTransforms.cpp | 127 +++++++++---------
1 file changed, 60 insertions(+), 67 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 18ae31526c452..c7da4090f51c5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -5551,7 +5551,9 @@ struct VPPartialReductionChain {
// clamps VF range.
static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
VFRange &Range, VPCostContext &CostCtx,
- VPlan &Plan, RecurKind RK) {
+ VPlan &Plan,
+ VPReductionPHIRecipe *RdxPhi,
+ RecurKind RK) {
VPWidenRecipe *WidenRecipe = Chain.ReductionBinOp;
unsigned ScaleFactor = Chain.ScaleFactor;
assert(WidenRecipe->getNumOperands() == 2 && "Expected binary operation");
@@ -5624,27 +5626,20 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
Range))
return false;
- VPValue *Cond = nullptr;
- VPValue *ExitValue = nullptr;
- if (auto *RdxPhi = dyn_cast<VPReductionPHIRecipe>(AccumRecipe)) {
- assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
- RdxPhi->setVFScaleFactor(ScaleFactor);
-
- // Update ReductionStartVector instruction scale factor.
- VPValue *StartValue = RdxPhi->getOperand(0);
- auto *StartInst = cast<VPInstruction>(StartValue);
- assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
- auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
- StartInst->setOperand(2, NewScaleFactor);
-
- }
-
- // For partial reductions the 'sub' is performed outside the loop,
- // so that the reduction itself is all positive, because otherwise
- // the neg/sub can't be folded into the operands. For non-partial
- // sub-reductions and reduction chains that have both adds and subs,
- // the sub is performed on the operand inside the vector loop.
- // Note that both cases use an add to implement the actual reduction step.
+ // Sub-reductions can be implemented in two ways:
+ // (1) negate the operand in the vector loop (the default way).
+ // (2) subtract the reduced value from the init value in the middle block.
+ // Both ways keep the reduction itself as an 'add' reduction.
+ //
+ // The ISD nodes for partial reductions don't support folding the
+ // sub/negation into its operands because the following is not a valid
+ // transformation:
+ // sub(0, mul(ext(a), ext(b)))
+ // -> mul(ext(a), ext(sub(0, b)))
+ //
+ // It's therefore better to choose option (2) such that the partial
+ // reduction is always positive (starting at '0') and to do a final
+ // subtract in the middle block.
if (WidenRecipe->getOpcode() == Instruction::Sub && RK != RecurKind::Sub) {
VPBuilder Builder(WidenRecipe);
Type *ElemTy = CostCtx.Types.inferScalarType(BinOp);
@@ -5660,13 +5655,14 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
// Check if WidenRecipe is the final result of the reduction. If so look
// through selects for predicated reductions.
- VPReductionPHIRecipe *RdxPhi;
- ExitValue = cast_or_null<VPInstruction>(vputils::findUserOf(
- WidenRecipe, m_Select(m_VPValue(Cond), m_Specific(WidenRecipe),
- m_ReductionPhi(RdxPhi))));
- assert(!ExitValue || RdxPhi->getBackedgeValue() == WidenRecipe ||
- RdxPhi->getBackedgeValue() == ExitValue &&
- "if we found ExitValue, it must match RdxPhi's backedge value");
+ VPValue *Cond = nullptr;
+ VPValue *ExitValue = cast_or_null<VPInstruction>(vputils::findUserOf(
+ WidenRecipe,
+ m_Select(m_VPValue(Cond), m_Specific(WidenRecipe), m_Specific(RdxPhi))));
+ bool IsLastInChain = RdxPhi->getBackedgeValue() == WidenRecipe ||
+ RdxPhi->getBackedgeValue() == ExitValue;
+ assert((!ExitValue || IsLastInChain) &&
+ "if we found ExitValue, it must match RdxPhi's backedge value");
RecurKind RdxKind =
PhiType->isFloatingPointTy() ? RecurKind::FAdd : RecurKind::Add;
@@ -5681,6 +5677,40 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
if (Cond)
ExitValue->replaceAllUsesWith(PartialRed);
WidenRecipe->replaceAllUsesWith(PartialRed);
+
+ if (IsLastInChain) {
+ // Scale the PHI and ReductionStartVector by the VFScaleFactor
+ assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
+ RdxPhi->setVFScaleFactor(ScaleFactor);
+
+ auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue());
+ assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
+ auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
+ StartInst->setOperand(2, NewScaleFactor);
+
+ // If this is the last value in a sub-reduction chain, then update the PHI
+ // node to start at `0` and update the reduction-result to subtract from
+ // the PHI's start value.
+ if (RK == RecurKind::Sub) {
+ // Update start value of PHI node.
+ VPValue *OldStartValue = StartInst->getOperand(0);
+ StartInst->setOperand(0, StartInst->getOperand(1));
+
+ // Replace reduction_result by 'sub (startval, reductionresult)'.
+ VPInstruction *RdxResult = vputils::findComputeReductionResult(RdxPhi);
+ assert(RdxResult && "Could not find reduction result");
+
+ VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
+ constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
+ VPInstruction *NewResult = Builder.createNaryOp(
+ SubOpc, {OldStartValue, RdxResult},
+ VPIRFlags::getDefaultFlags(SubOpc), RdxPhi->getDebugLoc());
+ RdxResult->replaceUsesWithIf(
+ NewResult,
+ [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
+ }
+ }
+
return true;
}
@@ -5896,44 +5926,7 @@ void VPlanTransforms::createPartialReductions(VPlan &Plan,
for (auto &[Phi, Chains] : ChainsByPhi) {
RecurKind RK = cast<VPReductionPHIRecipe>(Phi)->getRecurrenceKind();
- bool IsTransformed = false;
for (const VPPartialReductionChain &Chain : Chains)
- IsTransformed |=
- transformToPartialReduction(Chain, Range, CostCtx, Plan, RK);
-
- // Sub-reductions can be implemented in two ways:
- // (1) negate the operand in the vector loop (the default way).
- // (2) subtract the reduced value from the init value in the middle block.
- // Both ways keep the reduction itself as an 'add' reduction.
- //
- // The ISD nodes for partial reductions don't support folding the
- // sub/negation into its operands because the following is not a valid
- // transformation:
- // sub(0, mul(ext(a), ext(b)))
- // -> mul(ext(a), ext(sub(0, b)))
- //
- // It's therefore better to choose option (2) such that the partial
- // reduction is always positive (starting at '0') and to do a final
- // subtract in the middle block.
- if (IsTransformed && !Range.isEmpty() && RK == RecurKind::Sub) {
- // Update start value of PHI node.
- auto *StartInst = cast<VPInstruction>(Phi->getStartValue());
- assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
- VPValue *OldStartValue = StartInst->getOperand(0);
- StartInst->setOperand(0, StartInst->getOperand(1));
-
- // Replace reduction_result by 'sub (startval, reductionresult)'.
- VPInstruction *RdxResult = vputils::findComputeReductionResult(Phi);
- assert(RdxResult && "Could not find reduction result");
-
- VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
- constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
- VPInstruction *NewResult = Builder.createNaryOp(
- SubOpc, {OldStartValue, RdxResult},
- VPIRFlags::getDefaultFlags(SubOpc), Phi->getDebugLoc());
- RdxResult->replaceUsesWithIf(
- NewResult,
- [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
- }
+ transformToPartialReduction(Chain, Range, CostCtx, Plan, Phi, RK);
}
}
>From a0e369a46afd76fe5ac33bbecec5e0743a052436 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Tue, 10 Feb 2026 09:40:20 +0000
Subject: [PATCH 5/5] Remove indentation by exiting early
---
.../Transforms/Vectorize/VPlanTransforms.cpp | 67 ++++++++++---------
1 file changed, 35 insertions(+), 32 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index c7da4090f51c5..a99641c472b9f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -5678,38 +5678,41 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
ExitValue->replaceAllUsesWith(PartialRed);
WidenRecipe->replaceAllUsesWith(PartialRed);
- if (IsLastInChain) {
- // Scale the PHI and ReductionStartVector by the VFScaleFactor
- assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
- RdxPhi->setVFScaleFactor(ScaleFactor);
-
- auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue());
- assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
- auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
- StartInst->setOperand(2, NewScaleFactor);
-
- // If this is the last value in a sub-reduction chain, then update the PHI
- // node to start at `0` and update the reduction-result to subtract from
- // the PHI's start value.
- if (RK == RecurKind::Sub) {
- // Update start value of PHI node.
- VPValue *OldStartValue = StartInst->getOperand(0);
- StartInst->setOperand(0, StartInst->getOperand(1));
-
- // Replace reduction_result by 'sub (startval, reductionresult)'.
- VPInstruction *RdxResult = vputils::findComputeReductionResult(RdxPhi);
- assert(RdxResult && "Could not find reduction result");
-
- VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
- constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
- VPInstruction *NewResult = Builder.createNaryOp(
- SubOpc, {OldStartValue, RdxResult},
- VPIRFlags::getDefaultFlags(SubOpc), RdxPhi->getDebugLoc());
- RdxResult->replaceUsesWithIf(
- NewResult,
- [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
- }
- }
+ // We only need to update the PHI node once, which is when we find the
+ // last reduction in the chain.
+ if (!IsLastInChain)
+ return true;
+
+ // Scale the PHI and its ReductionStartVector by ScaleFactor.
+ assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
+ RdxPhi->setVFScaleFactor(ScaleFactor);
+
+ auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue());
+ assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
+ auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
+ StartInst->setOperand(2, NewScaleFactor);
+
+ // Only sub-reductions need further changes. For the last value in such a
+ // chain, update the PHI node to start at `0` and replace the
+ // reduction-result by a subtract of it from the PHI's original start value.
+ if (RK != RecurKind::Sub)
+ return true;
+
+ VPValue *OldStartValue = StartInst->getOperand(0);
+ StartInst->setOperand(0, StartInst->getOperand(1));
+
+ // Replace reduction_result by 'sub (startval, reductionresult)'.
+ VPInstruction *RdxResult = vputils::findComputeReductionResult(RdxPhi);
+ assert(RdxResult && "Could not find reduction result");
+
+ VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
+ constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
+ VPInstruction *NewResult = Builder.createNaryOp(
+ SubOpc, {OldStartValue, RdxResult}, VPIRFlags::getDefaultFlags(SubOpc),
+ RdxPhi->getDebugLoc());
+ RdxResult->replaceUsesWithIf(
+ NewResult,
+ [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
return true;
}
More information about the llvm-commits
mailing list