[llvm-branch-commits] [llvm] [LV] Add support for partial reduction chains with fsubs (PR #195116)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Apr 30 09:05:51 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-vectorizers
Author: Sander de Smalen (sdesmalen-arm)
<details>
<summary>Changes</summary>
Previously, the cost model prevented this from happening; had it not, the LV would have generated incorrect code (missing the fneg).
---
Full diff: https://github.com/llvm/llvm-project/pull/195116.diff
6 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+10-4)
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+13-7)
- (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+5)
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+11-2)
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+26-6)
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll (+67)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 755321c65881c..a2fea03fb6fc0 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5978,7 +5978,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
return Invalid;
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub &&
- Opcode != Instruction::FAdd) ||
+ Opcode != Instruction::FSub && Opcode != Instruction::FAdd) ||
OpAExtend == TTI::PR_None)
return Invalid;
@@ -6045,9 +6045,10 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
NEONPred);
};
- bool IsSub = Opcode == Instruction::Sub;
+ bool IsSub = Opcode == Instruction::Sub || Opcode == Instruction::FSub;
InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
InstructionCost INegCost = IsSub ? 2 * InputLT.first * TTI::TCC_Basic : 0;
+ InstructionCost FNegCost = IsSub ? InputLT.first * TTI::TCC_Basic : 0;
if (AccumLT.second.getScalarType() == MVT::i32 &&
InputLT.second.getScalarType() == MVT::i8) {
@@ -6102,13 +6103,18 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
llvm::is_contained({MVT::i8, MVT::i16, MVT::i32}, InVT.SimpleTy))
return Cost * 2;
- // SVE2 fmlalb/t and NEON fmlal(2)
+ // SVE2 fml[as]lb/t and NEON fml[as]l(2)
if (IsSupported(ST->hasSVE2(), ST->hasFP16FML()) && InVT == MVT::f16)
return Cost * 2;
+ // SVE2p1 bfmlslb/t
+ if (IsSupported(ST->hasSVE2p1() && ST->hasBF16(), false) &&
+ InVT == MVT::bf16 && IsSub)
+ return Cost * 2;
+
// SVE and NEON bfmlalb/t
if (IsSupported(ST->hasBF16(), ST->hasBF16()) && InVT == MVT::bf16)
- return Cost * 2;
+ return Cost * 2 + FNegCost;
}
return BaseT::getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 934cca006e91c..b7e92747dffcd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3369,17 +3369,23 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
: VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
{Ext0, Ext1, Mul, Red}) {}
VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
- VPWidenRecipe *Mul, VPWidenRecipe *Sub,
+ VPWidenRecipe *Mul, VPWidenRecipe *Neg,
VPReductionRecipe *Red)
: VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
- {Ext0, Ext1, Mul, Sub, Red}) {
- assert(Mul->getOpcode() == Instruction::Mul && "Expected a mul");
- assert(Red->getRecurrenceKind() == RecurKind::Add &&
+ {Ext0, Ext1, Mul, Neg, Red}) {
+ assert((Mul->getOpcode() == Instruction::Mul ||
+ Mul->getOpcode() == Instruction::FMul) &&
+ "Expected a mul");
+ assert((Red->getRecurrenceKind() == RecurKind::Add ||
+ Red->getRecurrenceKind() == RecurKind::FAdd) &&
"Expected an add reduction");
assert(getNumOperands() >= 3 && "Expected at least three operands");
- [[maybe_unused]] auto *SubConst = dyn_cast<VPConstantInt>(getOperand(2));
- assert(SubConst && SubConst->isZero() &&
- Sub->getOpcode() == Instruction::Sub && "Expected a negating sub");
+ if (Neg->getOpcode() == Instruction::Sub) {
+ [[maybe_unused]] auto *SubConst = dyn_cast<VPConstantInt>(getOperand(2));
+ assert(SubConst && SubConst->isZero() &&
+ Neg->getOpcode() == Instruction::Sub && "Expected a negating sub");
+ } else
+ assert(Neg->getOpcode() == Instruction::FNeg && "Unexpected opcode");
}
~VPExpressionRecipe() override {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 242046480f6e9..254697883a084 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -524,6 +524,11 @@ inline AllRecipe_match<Instruction::FPExt, Op0_t> m_FPExt(const Op0_t &Op0) {
return m_Unary<Instruction::FPExt, Op0_t>(Op0);
}
+template <typename Op0_t>
+inline AllRecipe_match<Instruction::FNeg, Op0_t> m_FNeg(const Op0_t &Op0) {
+ return m_Unary<Instruction::FNeg, Op0_t>(Op0);
+}
+
template <typename Op0_t>
inline match_combine_or<AllRecipe_match<Instruction::ZExt, Op0_t>,
AllRecipe_match<Instruction::SExt, Op0_t>>
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 25fe37124b017..0a94c86cec751 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3063,8 +3063,16 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
Ctx.CostKind);
case ExpressionTypes::ExtNegatedMulAccReduction:
- assert(Opcode == Instruction::Add && "Unexpected opcode");
- Opcode = Instruction::Sub;
+ switch (Opcode) {
+ case Instruction::Add:
+ Opcode = Instruction::Sub;
+ break;
+ case Instruction::FAdd:
+ Opcode = Instruction::FSub;
+ break;
+ default:
+ llvm_unreachable("Unsupported opcode for ExtNegatedMulAccReduction");
+ }
[[fallthrough]];
case ExpressionTypes::ExtMulAccReduction: {
auto *RedR = cast<VPReductionRecipe>(ExpressionRecipes.back());
@@ -3083,6 +3091,7 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
RedTy->isFloatingPointTy() ? std::optional{RedR->getFastMathFlags()}
: std::nullopt);
}
+ assert(Opcode != Instruction::FSub && "Only integer types are supported");
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 5e6d1bbcd5a7c..58060bf0670c0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -5895,6 +5895,17 @@ createPartialReductionExpression(VPReductionRecipe *Red) {
return new VPExpressionRecipe(ExtA, ExtB, Mul, Red);
}
+ // reduce.fadd(fneg(fmul(fpext(a), fpext(b))))
+ // -> VPExpressionRecipe(a, b, fmul, fneg, red)
+ if (match(VecOp,
+ m_FNeg(m_FMul(m_FPExt(m_VPValue()), m_FPExt(m_VPValue()))))) {
+ auto *FNeg = cast<VPWidenRecipe>(VecOp);
+ auto *FMul = cast<VPWidenRecipe>(FNeg->getOperand(0));
+ auto *ExtA = cast<VPWidenCastRecipe>(FMul->getOperand(0));
+ auto *ExtB = cast<VPWidenCastRecipe>(FMul->getOperand(1));
+ return new VPExpressionRecipe(ExtA, ExtB, FMul, FNeg, Red);
+ }
+
// reduce.add(neg(mul(ext(a), ext(b))))
// -> VPExpressionRecipe(a, b, mul, sub, red)
if (match(VecOp, m_Sub(m_ZeroInt(), m_Mul(m_ZExtOrSExt(m_VPValue()),
@@ -5935,14 +5946,23 @@ static void transformToPartialReduction(const VPPartialReductionChain &Chain,
// It's therefore better to choose option (2) such that the partial
// reduction is always positive (starting at '0') and to do a final
// subtract in the middle block.
- if (WidenRecipe->getOpcode() == Instruction::Sub &&
- Chain.RK != RecurKind::Sub) {
+ if ((WidenRecipe->getOpcode() == Instruction::Sub &&
+ Chain.RK != RecurKind::Sub) ||
+ // FIXME: We don't have a RecurKind::FSub yet.
+ WidenRecipe->getOpcode() == Instruction::FSub) {
VPBuilder Builder(WidenRecipe);
Type *ElemTy = TypeInfo.inferScalarType(ExtendedOp);
- auto *Zero = Plan.getZero(ElemTy);
- auto *NegRecipe =
- new VPWidenRecipe(Instruction::Sub, {Zero, ExtendedOp}, VPIRFlags(),
- VPIRMetadata(), DebugLoc::getUnknown());
+ VPWidenRecipe *NegRecipe;
+ if (WidenRecipe->getOpcode() == Instruction::FSub)
+ NegRecipe =
+ new VPWidenRecipe(Instruction::FNeg, {ExtendedOp}, VPIRFlags(),
+ VPIRMetadata(), DebugLoc::getUnknown());
+ else {
+ auto *Zero = Plan.getZero(ElemTy);
+ NegRecipe =
+ new VPWidenRecipe(Instruction::Sub, {Zero, ExtendedOp}, VPIRFlags(),
+ VPIRMetadata(), DebugLoc::getUnknown());
+ }
Builder.insert(NegRecipe);
ExtendedOp = NegRecipe;
}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
index f7b73371345bc..dfddba3a0025e 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
@@ -1000,6 +1000,73 @@ exit:
ret double %add
}
+; Test that a 'fneg' is generated for the fsub.
+define float @fadd_fsub_chain_f16_f32(ptr %a, ptr %b, ptr %c) #0 {
+; CHECK-LABEL: define float @fadd_fsub_chain_f16_f32(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 3
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[FOR_EXIT:.*]], label %[[VECTOR_PH:.*]]
+; CHECK: [[VECTOR_PH]]:
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP3:%.*]] = shl nuw i64 [[TMP2]], 3
+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; CHECK-NEXT: br label %[[FOR_BODY:.*]]
+; CHECK: [[FOR_BODY]]:
+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 4 x float> [ insertelement (<vscale x 4 x float> splat (float -0.000000e+00), float 0.000000e+00, i32 0), %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE9:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT: [[GEP_A:%.*]] = getelementptr half, ptr [[A]], i64 [[IV]]
+; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 8 x half>, ptr [[GEP_A]], align 1
+; CHECK-NEXT: [[GEP_B:%.*]] = getelementptr half, ptr [[B]], i64 [[IV]]
+; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 8 x half>, ptr [[GEP_B]], align 1
+; CHECK-NEXT: [[TMP12:%.*]] = fpext <vscale x 8 x half> [[WIDE_LOAD4]] to <vscale x 8 x float>
+; CHECK-NEXT: [[TMP13:%.*]] = fpext <vscale x 8 x half> [[WIDE_LOAD2]] to <vscale x 8 x float>
+; CHECK-NEXT: [[TMP14:%.*]] = fmul <vscale x 8 x float> [[TMP12]], [[TMP13]]
+; CHECK-NEXT: [[PARTIAL_REDUCE5:%.*]] = call reassoc contract <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x float> [[VEC_PHI1]], <vscale x 8 x float> [[TMP14]])
+; CHECK-NEXT: [[WIDE_LOAD7:%.*]] = load <vscale x 8 x half>, ptr [[GEP_A]], align 1
+; CHECK-NEXT: [[TMP17:%.*]] = fpext <vscale x 8 x half> [[WIDE_LOAD7]] to <vscale x 8 x float>
+; CHECK-NEXT: [[TMP18:%.*]] = fmul <vscale x 8 x float> [[TMP17]], [[TMP12]]
+; CHECK-NEXT: [[TMP22:%.*]] = fneg <vscale x 8 x float> [[TMP18]]
+; CHECK-NEXT: [[PARTIAL_REDUCE9]] = call reassoc contract <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x float> [[PARTIAL_REDUCE5]], <vscale x 8 x float> [[TMP22]])
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[IV]], [[TMP3]]
+; CHECK-NEXT: [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP19]], label %[[MIDDLE_BLOCK:.*]], label %[[FOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
+; CHECK: [[MIDDLE_BLOCK]]:
+; CHECK-NEXT: [[TMP15:%.*]] = call reassoc contract float @llvm.vector.reduce.fadd.nxv4f32(float -0.000000e+00, <vscale x 4 x float> [[PARTIAL_REDUCE9]])
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], [[FOR_EXIT1:label %.*]], label %[[FOR_EXIT]]
+; CHECK: [[FOR_EXIT]]:
+;
+entry:
+ br label %for.body
+
+for.body:
+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+ %accum = phi float [ 0.0, %entry ], [ %sub, %for.body ]
+ %gep.a = getelementptr half, ptr %a, i64 %iv
+ %load.a = load half, ptr %gep.a, align 1
+ %ext.a = fpext half %load.a to float
+ %gep.b = getelementptr half, ptr %b, i64 %iv
+ %load.b = load half, ptr %gep.b, align 1
+ %ext.b = fpext half %load.b to float
+ %mul = fmul float %ext.b, %ext.a
+ %add = fadd reassoc contract float %mul, %accum
+ %gep.c = getelementptr half, ptr %a, i64 %iv
+ %load.c = load half, ptr %gep.c, align 1
+ %ext.c = fpext half %load.c to float
+ %mul2 = fmul float %ext.c, %ext.b
+ %sub = fsub reassoc contract float %add, %mul2
+ %iv.next = add i64 %iv, 1
+ %exitcond.not = icmp eq i64 %iv.next, 1024
+ br i1 %exitcond.not, label %for.exit, label %for.body, !llvm.loop !0
+
+for.exit:
+ ret float %sub
+}
+
attributes #0 = { "target-features"="+sve2p1,+dotprod" }
attributes #1 = { "target-features"="+sve" }
attributes #2 = { "target-features"="+dotprod" }
``````````
</details>
https://github.com/llvm/llvm-project/pull/195116
More information about the llvm-branch-commits
mailing list