[llvm] 3157758 - [LV] Handle partial sub-reductions with sub in middle block. (#178919)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 10 03:00:37 PST 2026
Author: Sander de Smalen
Date: 2026-02-10T11:00:32Z
New Revision: 3157758190a487e344ecf1d2ff6402609a38d6b5
URL: https://github.com/llvm/llvm-project/commit/3157758190a487e344ecf1d2ff6402609a38d6b5
DIFF: https://github.com/llvm/llvm-project/commit/3157758190a487e344ecf1d2ff6402609a38d6b5.diff
LOG: [LV] Handle partial sub-reductions with sub in middle block. (#178919)
Sub-reductions can be implemented in two ways:
(1) negate the operand in the vector loop (the default way).
(2) subtract the reduced value from the init value in the middle block.
Note that both ways keep the reduction itself as an 'add' reduction,
which is necessary because only llvm.vector.partial.reduce.add exists.
The ISD nodes for partial reductions don't support folding the
sub/negation into their operands, because the following is not a valid
transformation:
```
sub(0, mul(ext(a), ext(b)))
-> mul(ext(a), ext(sub(0, b)))
```
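(Negating before the extend is not equivalent: e.g. for an i8 `b` of -128,
`sub(0, b)` is still -128, so `sext(sub(0, b)) != sub(0, sext(b))`.)
It can therefore be better to choose option (2), so that the partial
reduction accumulates the un-negated products (starting at '0') and a
single final subtract is done in the middle block.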
For AArch64, there are no dot-product instructions that can
perform a `partial.reduce.sub(acc, mul(ext(a), ext(b)))` operation.
I'm not sure whether such instructions exist for other targets
(if so, we may want to make this decision a target option).
This PR also increases the AArch64 cost of a partial sub-reduction
when it occurs in an 'add-sub' reduction chain, where the negation
must stay inside the loop.
Fixes https://github.com/llvm/llvm-project/issues/178703
Added:
llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
Modified:
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index c9d775367f929..71f52ae55d3ec 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5912,6 +5912,11 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
+ // The sub/negation cannot be folded into the operands of
+ // ISD::PARTIAL_REDUCE_*MLA, so make the cost more expensive.
+ if (Opcode == Instruction::Sub)
+ Cost += 8;
+
// Prefer using full types by costing half-full input types as more expensive.
if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
TypeSize::getScalable(128)))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index efea585114947..a99641c472b9f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -5551,7 +5551,9 @@ struct VPPartialReductionChain {
// clamps VF range.
static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
VFRange &Range, VPCostContext &CostCtx,
- VPlan &Plan) {
+ VPlan &Plan,
+ VPReductionPHIRecipe *RdxPhi,
+ RecurKind RK) {
VPWidenRecipe *WidenRecipe = Chain.ReductionBinOp;
unsigned ScaleFactor = Chain.ScaleFactor;
assert(WidenRecipe->getNumOperands() == 2 && "Expected binary operation");
@@ -5624,23 +5626,21 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
Range))
return false;
- VPValue *Cond = nullptr;
- VPValue *ExitValue = nullptr;
- if (auto *RdxPhi = dyn_cast<VPReductionPHIRecipe>(AccumRecipe)) {
- assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
- RdxPhi->setVFScaleFactor(ScaleFactor);
-
- // Update ReductionStartVector instruction scale factor.
- VPValue *StartValue = RdxPhi->getOperand(0);
- auto *StartInst = cast<VPInstruction>(StartValue);
- assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
- auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
- StartInst->setOperand(2, NewScaleFactor);
-
- }
-
- // Handle SUB by negating the operand and using ADD for the partial reduction.
- if (WidenRecipe->getOpcode() == Instruction::Sub) {
+ // Sub-reductions can be implemented in two ways:
+ // (1) negate the operand in the vector loop (the default way).
+ // (2) subtract the reduced value from the init value in the middle block.
+ // Both ways keep the reduction itself as an 'add' reduction.
+ //
+ // The ISD nodes for partial reductions don't support folding the
+ // sub/negation into its operands because the following is not a valid
+ // transformation:
+ // sub(0, mul(ext(a), ext(b)))
+ // -> mul(ext(a), ext(sub(0, b)))
+ //
+ // It's therefore better to choose option (2) such that the partial
+ // reduction is always positive (starting at '0') and to do a final
+ // subtract in the middle block.
+ if (WidenRecipe->getOpcode() == Instruction::Sub && RK != RecurKind::Sub) {
VPBuilder Builder(WidenRecipe);
Type *ElemTy = CostCtx.Types.inferScalarType(BinOp);
auto *Zero = Plan.getConstantInt(ElemTy, 0);
@@ -5655,13 +5655,14 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
// Check if WidenRecipe is the final result of the reduction. If so look
// through selects for predicated reductions.
- VPReductionPHIRecipe *RdxPhi;
- ExitValue = cast_or_null<VPInstruction>(vputils::findUserOf(
- WidenRecipe, m_Select(m_VPValue(Cond), m_Specific(WidenRecipe),
- m_ReductionPhi(RdxPhi))));
- assert(!ExitValue || RdxPhi->getBackedgeValue() == WidenRecipe ||
- RdxPhi->getBackedgeValue() == ExitValue &&
- "if we found ExitValue, it must match RdxPhi's backedge value");
+ VPValue *Cond = nullptr;
+ VPValue *ExitValue = cast_or_null<VPInstruction>(vputils::findUserOf(
+ WidenRecipe,
+ m_Select(m_VPValue(Cond), m_Specific(WidenRecipe), m_Specific(RdxPhi))));
+ bool IsLastInChain = RdxPhi->getBackedgeValue() == WidenRecipe ||
+ RdxPhi->getBackedgeValue() == ExitValue;
+ assert((!ExitValue || IsLastInChain) &&
+ "if we found ExitValue, it must match RdxPhi's backedge value");
RecurKind RdxKind =
PhiType->isFloatingPointTy() ? RecurKind::FAdd : RecurKind::Add;
@@ -5676,6 +5677,43 @@ static bool transformToPartialReduction(const VPPartialReductionChain &Chain,
if (Cond)
ExitValue->replaceAllUsesWith(PartialRed);
WidenRecipe->replaceAllUsesWith(PartialRed);
+
+ // We only need to update the PHI node once, which is when we find the
+ // last reduction in the chain.
+ if (!IsLastInChain)
+ return true;
+
+ // Scale the PHI and ReductionStartVector by the VFScaleFactor
+ assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set");
+ RdxPhi->setVFScaleFactor(ScaleFactor);
+
+ auto *StartInst = cast<VPInstruction>(RdxPhi->getStartValue());
+ assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector);
+ auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor);
+ StartInst->setOperand(2, NewScaleFactor);
+
+ // If this is the last value in a sub-reduction chain, then update the PHI
+ // node to start at `0` and update the reduction-result to subtract from
+ // the PHI's start value.
+ if (RK != RecurKind::Sub)
+ return true;
+
+ VPValue *OldStartValue = StartInst->getOperand(0);
+ StartInst->setOperand(0, StartInst->getOperand(1));
+
+ // Replace reduction_result by 'sub (startval, reductionresult)'.
+ VPInstruction *RdxResult = vputils::findComputeReductionResult(RdxPhi);
+ assert(RdxResult && "Could not find reduction result");
+
+ VPBuilder Builder = VPBuilder::getToInsertAfter(RdxResult);
+ constexpr unsigned SubOpc = Instruction::BinaryOps::Sub;
+ VPInstruction *NewResult = Builder.createNaryOp(
+ SubOpc, {OldStartValue, RdxResult}, VPIRFlags::getDefaultFlags(SubOpc),
+ RdxPhi->getDebugLoc());
+ RdxResult->replaceUsesWithIf(
+ NewResult,
+ [&NewResult](VPUser &U, unsigned Idx) { return &U != NewResult; });
+
return true;
}
@@ -5889,7 +5927,9 @@ void VPlanTransforms::createPartialReductions(VPlan &Plan,
}
}
- for (const auto &[_, Chains] : ChainsByPhi)
+ for (auto &[Phi, Chains] : ChainsByPhi) {
+ RecurKind RK = cast<VPReductionPHIRecipe>(Phi)->getRecurrenceKind();
for (const VPPartialReductionChain &Chain : Chains)
- transformToPartialReduction(Chain, Range, CostCtx, Plan);
+ transformToPartialReduction(Chain, Range, CostCtx, Plan, Phi, RK);
+ }
}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index f1dee958fa09c..d1fde2cdaafe1 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -494,17 +494,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-NEON-NEXT: [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
-; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub nsw <16 x i32> zeroinitializer, [[TMP10]]
-; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
; CHECK-NEON-NEXT: [[TMP16:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
; CHECK-NEON-NEXT: [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP16]]
-; CHECK-NEON-NEXT: [[TMP13:%.*]] = sub <16 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP12]])
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-NEON-NEXT: [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-NEON: middle.block:
; CHECK-NEON-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-NEON-NEXT: [[TMP11:%.*]] = sub i32 0, [[TMP15]]
; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-NEON-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
; CHECK-NEON: scalar.ph:
@@ -537,17 +536,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-NEXT: [[TMP14:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i32>
; CHECK-SVE-NEXT: [[TMP16:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD1]] to <vscale x 16 x i32>
; CHECK-SVE-NEXT: [[TMP17:%.*]] = mul nsw <vscale x 16 x i32> [[TMP14]], [[TMP16]]
-; CHECK-SVE-NEXT: [[TMP10:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP17]]
-; CHECK-SVE-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP10]])
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP17]])
; CHECK-SVE-NEXT: [[TMP11:%.*]] = sext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i32>
; CHECK-SVE-NEXT: [[TMP12:%.*]] = mul nsw <vscale x 16 x i32> [[TMP14]], [[TMP11]]
-; CHECK-SVE-NEXT: [[TMP13:%.*]] = sub <vscale x 16 x i32> zeroinitializer, [[TMP12]]
-; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP13]])
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP12]])
; CHECK-SVE-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-SVE-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-SVE: middle.block:
; CHECK-SVE-NEXT: [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-SVE-NEXT: [[TMP18:%.*]] = sub i32 0, [[TMP15]]
; CHECK-SVE-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-SVE-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
; CHECK-SVE: scalar.ph:
@@ -580,17 +578,16 @@ define i32 @chained_partial_reduce_sub_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
; CHECK-SVE-MAXBW-NEXT: [[TMP13:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP17:%.*]] = sub nsw <vscale x 8 x i32> zeroinitializer, [[TMP16]]
-; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP17]])
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP16]])
; CHECK-SVE-MAXBW-NEXT: [[TMP12:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
; CHECK-SVE-MAXBW-NEXT: [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP12]]
-; CHECK-SVE-MAXBW-NEXT: [[TMP19:%.*]] = sub <vscale x 8 x i32> zeroinitializer, [[TMP18]]
-; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP19]])
+; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE3]] = call <vscale x 2 x i32> @llvm.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[PARTIAL_REDUCE]], <vscale x 8 x i32> [[TMP18]])
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
; CHECK-SVE-MAXBW-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP10:![0-9]+]]
; CHECK-SVE-MAXBW: middle.block:
; CHECK-SVE-MAXBW-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-SVE-MAXBW-NEXT: [[TMP15:%.*]] = sub i32 0, [[TMP21]]
; CHECK-SVE-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
; CHECK-SVE-MAXBW-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
; CHECK-SVE-MAXBW: scalar.ph:
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
new file mode 100644
index 0000000000000..f06b2137c2b8d
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub-sdot.ll
@@ -0,0 +1,84 @@
+; REQUIRES: asserts
+; RUN: opt -passes=loop-vectorize \
+; RUN: -scalable-vectorization=on -mattr=+sve2 \
+; RUN: -enable-epilogue-vectorization=false -debug-only=loop-vectorize \
+; RUN: -disable-output < %s 2>&1 | FileCheck %s --check-prefixes=COMMON,SVE
+
+; RUN: opt -passes=loop-vectorize \
+; RUN: -scalable-vectorization=off -mattr=+neon,+dotprod \
+; RUN: -enable-epilogue-vectorization=false -debug-only=loop-vectorize \
+; RUN: -disable-output < %s 2>&1 | FileCheck %s --check-prefixes=COMMON,NEON
+
+; COMMON: LV: Checking a loop in 'sub_reduction'
+; SVE: Cost of 1 for VF vscale x 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; NEON: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+
+; COMMON: LV: Checking a loop in 'add_sub_chained_reduction'
+; SVE: Cost of 1 for VF vscale x 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; SVE: Cost of 9 for VF vscale x 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+; NEON: Cost of 1 for VF 16: EXPRESSION vp<{{.*}}> = ir<%acc> + partial.reduce.add (mul (ir<%load1> sext to i32), (ir<%load2> sext to i32))
+; NEON: Cost of 9 for VF 16: EXPRESSION vp<{{.*}}> = vp<%9> + partial.reduce.add (sub (0, mul (ir<%load2> sext to i32), (ir<%load3> sext to i32)))
+
+target triple = "aarch64"
+
+; Test the cost of a SUB reduction, where the SUB is implemented outside the loop
+; and therefore not part of the partial reduction.
+define i32 @sub_reduction(ptr %arr1, ptr %arr2, i32 %init, i32 %n) #0 {
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
+ %acc = phi i32 [ %init, %entry ], [ %sub, %loop ] ; accumulator, seeded with %init
+ %gep1 = getelementptr inbounds i8, ptr %arr1, i32 %iv
+ %load1 = load i8, ptr %gep1
+ %sext1 = sext i8 %load1 to i32
+ %gep2 = getelementptr inbounds i8, ptr %arr2, i32 %iv
+ %load2 = load i8, ptr %gep2
+ %sext2 = sext i8 %load2 to i32
+ %mul = mul i32 %sext1, %sext2
+ %sub = sub i32 %acc, %mul ; the only update of %acc is a sub, so the SUB can be applied to %init after the loop
+ %iv.next = add i32 %iv, 1
+ %cmp = icmp ult i32 %iv.next, %n
+ br i1 %cmp, label %loop, label %exit, !llvm.loop !0 ; !0 requests VF 16 and interleave count 1
+
+exit:
+ ret i32 %sub
+}
+
+; Test that the cost of a SUB that is part of an ADD-SUB reduction chain
+; is high, because the negation happens inside the loop and cannot be
+; folded into the SDOT instruction (because of the extend).
+define i32 @add_sub_chained_reduction(ptr %arr1, ptr %arr2, ptr %arr3, i32 %init, i32 %n) #0 {
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
+ %acc = phi i32 [ %init, %entry ], [ %sub, %loop ] ; accumulator, seeded with %init
+ %gep1 = getelementptr inbounds i8, ptr %arr1, i32 %iv
+ %load1 = load i8, ptr %gep1
+ %sext1 = sext i8 %load1 to i32
+ %gep2 = getelementptr inbounds i8, ptr %arr2, i32 %iv
+ %load2 = load i8, ptr %gep2
+ %sext2 = sext i8 %load2 to i32
+ %mul1 = mul i32 %sext1, %sext2
+ %add = add i32 %acc, %mul1 ; add link of the add-sub chain
+ %gep3 = getelementptr inbounds i8, ptr %arr3, i32 %iv
+ %load3 = load i8, ptr %gep3
+ %sext3 = sext i8 %load3 to i32
+ %mul2 = mul i32 %sext2, %sext3
+ %sub = sub i32 %add, %mul2 ; sub link: negation of %mul2 stays in the loop (checked above as cost 9)
+ %iv.next = add i32 %iv, 1
+ %cmp = icmp ult i32 %iv.next, %n
+ br i1 %cmp, label %loop, label %exit, !llvm.loop !0 ; !0 requests VF 16 and interleave count 1
+
+exit:
+ ret i32 %sub
+}
+
+attributes #0 = { vscale_range(1,16) }
+
+!0 = distinct !{!0, !1, !2}
+!1 = !{!"llvm.loop.interleave.count", i32 1}
+!2 = !{!"llvm.loop.vectorize.width", i32 16}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll
index 107d1d8a78706..9bf37a43e467d 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll
@@ -154,13 +154,13 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-INTERLEAVE1-NEXT: [[TMP3:%.*]] = sext <4 x i32> [[TMP2]] to <4 x i64>
; CHECK-INTERLEAVE1-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[STRIDED_VEC]] to <4 x i64>
; CHECK-INTERLEAVE1-NEXT: [[TMP5:%.*]] = mul <4 x i64> [[TMP3]], [[TMP4]]
-; CHECK-INTERLEAVE1-NEXT: [[TMP6:%.*]] = sub <4 x i64> zeroinitializer, [[TMP5]]
-; CHECK-INTERLEAVE1-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP6]])
+; CHECK-INTERLEAVE1-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP5]])
; CHECK-INTERLEAVE1-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-INTERLEAVE1-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 36
; CHECK-INTERLEAVE1-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-INTERLEAVE1: middle.block:
; CHECK-INTERLEAVE1-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[PARTIAL_REDUCE]])
+; CHECK-INTERLEAVE1-NEXT: [[TMP9:%.*]] = sub i64 0, [[TMP8]]
; CHECK-INTERLEAVE1-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[STRIDED_VEC]], i32 3
; CHECK-INTERLEAVE1-NEXT: br label [[SCALAR_PH:%.*]]
; CHECK-INTERLEAVE1: scalar.ph:
@@ -204,23 +204,19 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = sext <4 x i32> [[TMP5]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP9:%.*]] = sext <4 x i32> [[STRIDED_VEC]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP11:%.*]] = mul <4 x i64> [[TMP7]], [[TMP9]]
-; CHECK-INTERLEAVED-NEXT: [[TMP18:%.*]] = sub <4 x i64> zeroinitializer, [[TMP11]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP18]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP11]])
; CHECK-INTERLEAVED-NEXT: [[TMP19:%.*]] = sext <4 x i32> [[TMP31]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP20:%.*]] = sext <4 x i32> [[STRIDED_VEC5]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP21:%.*]] = mul <4 x i64> [[TMP19]], [[TMP20]]
-; CHECK-INTERLEAVED-NEXT: [[TMP22:%.*]] = sub <4 x i64> zeroinitializer, [[TMP21]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE10]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI1]], <4 x i64> [[TMP22]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE10]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI1]], <4 x i64> [[TMP21]])
; CHECK-INTERLEAVED-NEXT: [[TMP8:%.*]] = sext <4 x i32> [[TMP13]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP10:%.*]] = sext <4 x i32> [[STRIDED_VEC7]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP12:%.*]] = mul <4 x i64> [[TMP8]], [[TMP10]]
-; CHECK-INTERLEAVED-NEXT: [[TMP26:%.*]] = sub <4 x i64> zeroinitializer, [[TMP12]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE11]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI2]], <4 x i64> [[TMP26]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE11]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI2]], <4 x i64> [[TMP12]])
; CHECK-INTERLEAVED-NEXT: [[TMP27:%.*]] = sext <4 x i32> [[TMP14]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP28:%.*]] = sext <4 x i32> [[STRIDED_VEC3]] to <4 x i64>
; CHECK-INTERLEAVED-NEXT: [[TMP29:%.*]] = mul <4 x i64> [[TMP27]], [[TMP28]]
-; CHECK-INTERLEAVED-NEXT: [[TMP30:%.*]] = sub <4 x i64> zeroinitializer, [[TMP29]]
-; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE12]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI3]], <4 x i64> [[TMP30]])
+; CHECK-INTERLEAVED-NEXT: [[PARTIAL_REDUCE12]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI3]], <4 x i64> [[TMP29]])
; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
; CHECK-INTERLEAVED-NEXT: [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], 32
; CHECK-INTERLEAVED-NEXT: br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
@@ -229,6 +225,7 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-INTERLEAVED-NEXT: [[BIN_RDX13:%.*]] = add <2 x i64> [[PARTIAL_REDUCE11]], [[BIN_RDX]]
; CHECK-INTERLEAVED-NEXT: [[BIN_RDX14:%.*]] = add <2 x i64> [[PARTIAL_REDUCE12]], [[BIN_RDX13]]
; CHECK-INTERLEAVED-NEXT: [[TMP32:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[BIN_RDX14]])
+; CHECK-INTERLEAVED-NEXT: [[TMP30:%.*]] = sub i64 0, [[TMP32]]
; CHECK-INTERLEAVED-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[STRIDED_VEC3]], i32 3
; CHECK-INTERLEAVED-NEXT: br label [[SCALAR_PH:%.*]]
; CHECK-INTERLEAVED: scalar.ph:
@@ -251,13 +248,13 @@ define i64 @partial_reduce_sub_sext_mul(ptr %x) #1 {
; CHECK-MAXBW-NEXT: [[TMP3:%.*]] = sext <4 x i32> [[TMP2]] to <4 x i64>
; CHECK-MAXBW-NEXT: [[TMP4:%.*]] = sext <4 x i32> [[STRIDED_VEC]] to <4 x i64>
; CHECK-MAXBW-NEXT: [[TMP5:%.*]] = mul <4 x i64> [[TMP3]], [[TMP4]]
-; CHECK-MAXBW-NEXT: [[TMP6:%.*]] = sub <4 x i64> zeroinitializer, [[TMP5]]
-; CHECK-MAXBW-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP6]])
+; CHECK-MAXBW-NEXT: [[PARTIAL_REDUCE]] = call <2 x i64> @llvm.vector.partial.reduce.add.v2i64.v4i64(<2 x i64> [[VEC_PHI]], <4 x i64> [[TMP5]])
; CHECK-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-MAXBW-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 36
; CHECK-MAXBW-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-MAXBW: middle.block:
; CHECK-MAXBW-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> [[PARTIAL_REDUCE]])
+; CHECK-MAXBW-NEXT: [[TMP9:%.*]] = sub i64 0, [[TMP8]]
; CHECK-MAXBW-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i32> [[STRIDED_VEC]], i32 3
; CHECK-MAXBW-NEXT: br label [[SCALAR_PH:%.*]]
; CHECK-MAXBW: scalar.ph:
More information about the llvm-commits
mailing list