[llvm-branch-commits] [llvm] [LV] Add support for partial reduction chains with fsubs (PR #195116)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Apr 30 09:05:51 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-vectorizers

Author: Sander de Smalen (sdesmalen-arm)

<details>
<summary>Changes</summary>

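Previously the cost model rejected partial-reduction chains containing an fsub by returning an invalid cost. Without that restriction, the loop vectorizer (LV) would have generated incorrect code, because it did not emit an fneg for the subtracted operand.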

---
Full diff: https://github.com/llvm/llvm-project/pull/195116.diff


6 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+10-4) 
- (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+13-7) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+5) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+11-2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+26-6) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll (+67) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 755321c65881c..a2fea03fb6fc0 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5978,7 +5978,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     return Invalid;
 
   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub &&
-       Opcode != Instruction::FAdd) ||
+       Opcode != Instruction::FSub && Opcode != Instruction::FAdd) ||
       OpAExtend == TTI::PR_None)
     return Invalid;
 
@@ -6045,9 +6045,10 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
             NEONPred);
   };
 
-  bool IsSub = Opcode == Instruction::Sub;
+  bool IsSub = Opcode == Instruction::Sub || Opcode == Instruction::FSub;
   InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
   InstructionCost INegCost = IsSub ? 2 * InputLT.first * TTI::TCC_Basic : 0;
+  InstructionCost FNegCost = IsSub ? InputLT.first * TTI::TCC_Basic : 0;
 
   if (AccumLT.second.getScalarType() == MVT::i32 &&
       InputLT.second.getScalarType() == MVT::i8) {
@@ -6102,13 +6103,18 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
         llvm::is_contained({MVT::i8, MVT::i16, MVT::i32}, InVT.SimpleTy))
       return Cost * 2;
 
-    // SVE2 fmlalb/t and NEON fmlal(2)
+    // SVE2 fml[as]lb/t and NEON fml[as]l(2)
     if (IsSupported(ST->hasSVE2(), ST->hasFP16FML()) && InVT == MVT::f16)
       return Cost * 2;
 
+    // SVE2p1 bfmlslb/t
+    if (IsSupported(ST->hasSVE2p1() && ST->hasBF16(), false) &&
+        InVT == MVT::bf16 && IsSub)
+      return Cost * 2;
+
     // SVE and NEON bfmlalb/t
     if (IsSupported(ST->hasBF16(), ST->hasBF16()) && InVT == MVT::bf16)
-      return Cost * 2;
+      return Cost * 2 + FNegCost;
   }
 
   return BaseT::getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 934cca006e91c..b7e92747dffcd 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3369,17 +3369,23 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
       : VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
                            {Ext0, Ext1, Mul, Red}) {}
   VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
-                     VPWidenRecipe *Mul, VPWidenRecipe *Sub,
+                     VPWidenRecipe *Mul, VPWidenRecipe *Neg,
                      VPReductionRecipe *Red)
       : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
-                           {Ext0, Ext1, Mul, Sub, Red}) {
-    assert(Mul->getOpcode() == Instruction::Mul && "Expected a mul");
-    assert(Red->getRecurrenceKind() == RecurKind::Add &&
+                           {Ext0, Ext1, Mul, Neg, Red}) {
+    assert((Mul->getOpcode() == Instruction::Mul ||
+            Mul->getOpcode() == Instruction::FMul) &&
+           "Expected a mul");
+    assert((Red->getRecurrenceKind() == RecurKind::Add ||
+            Red->getRecurrenceKind() == RecurKind::FAdd) &&
            "Expected an add reduction");
     assert(getNumOperands() >= 3 && "Expected at least three operands");
-    [[maybe_unused]] auto *SubConst = dyn_cast<VPConstantInt>(getOperand(2));
-    assert(SubConst && SubConst->isZero() &&
-           Sub->getOpcode() == Instruction::Sub && "Expected a negating sub");
+    if (Neg->getOpcode() == Instruction::Sub) { // Integer: sub 0, x.
+      [[maybe_unused]] auto *SubConst = dyn_cast<VPConstantInt>(getOperand(2));
+      assert(SubConst && SubConst->isZero() && "Expected a negating sub");
+    } else { // Floating-point: fneg x.
+      assert(Neg->getOpcode() == Instruction::FNeg && "Unexpected opcode");
+    }
   }
 
   ~VPExpressionRecipe() override {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 242046480f6e9..254697883a084 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -524,6 +524,11 @@ inline AllRecipe_match<Instruction::FPExt, Op0_t> m_FPExt(const Op0_t &Op0) {
   return m_Unary<Instruction::FPExt, Op0_t>(Op0);
 }
 
+template <typename Op0_t>
+inline AllRecipe_match<Instruction::FNeg, Op0_t> m_FNeg(const Op0_t &Op0) {
+  return m_Unary<Instruction::FNeg, Op0_t>(Op0);
+}
+
 template <typename Op0_t>
 inline match_combine_or<AllRecipe_match<Instruction::ZExt, Op0_t>,
                         AllRecipe_match<Instruction::SExt, Op0_t>>
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 25fe37124b017..0a94c86cec751 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3063,8 +3063,16 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
                                           Ctx.CostKind);
 
   case ExpressionTypes::ExtNegatedMulAccReduction:
-    assert(Opcode == Instruction::Add && "Unexpected opcode");
-    Opcode = Instruction::Sub;
+    switch (Opcode) {
+    case Instruction::Add:
+      Opcode = Instruction::Sub;
+      break;
+    case Instruction::FAdd:
+      Opcode = Instruction::FSub;
+      break;
+    default:
+      llvm_unreachable("Unsupported opcode for ExtNegatedMulAccReduction");
+    }
     [[fallthrough]];
   case ExpressionTypes::ExtMulAccReduction: {
     auto *RedR = cast<VPReductionRecipe>(ExpressionRecipes.back());
@@ -3083,6 +3091,7 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
           RedTy->isFloatingPointTy() ? std::optional{RedR->getFastMathFlags()}
                                      : std::nullopt);
     }
+    assert(Opcode != Instruction::FSub && "Only integer types are supported");
     return Ctx.TTI.getMulAccReductionCost(
         cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
             Instruction::ZExt,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 5e6d1bbcd5a7c..58060bf0670c0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -5895,6 +5895,17 @@ createPartialReductionExpression(VPReductionRecipe *Red) {
     return new VPExpressionRecipe(ExtA, ExtB, Mul, Red);
   }
 
+  // reduce.fadd(fneg(fmul(fpext(a), fpext(b))))
+  //  -> VPExpressionRecipe(a, b, fmul, fneg, red)
+  if (match(VecOp,
+            m_FNeg(m_FMul(m_FPExt(m_VPValue()), m_FPExt(m_VPValue()))))) {
+    auto *FNeg = cast<VPWidenRecipe>(VecOp);
+    auto *FMul = cast<VPWidenRecipe>(FNeg->getOperand(0));
+    auto *ExtA = cast<VPWidenCastRecipe>(FMul->getOperand(0));
+    auto *ExtB = cast<VPWidenCastRecipe>(FMul->getOperand(1));
+    return new VPExpressionRecipe(ExtA, ExtB, FMul, FNeg, Red);
+  }
+
   // reduce.add(neg(mul(ext(a), ext(b))))
   //  -> VPExpressionRecipe(a, b, mul, sub, red)
   if (match(VecOp, m_Sub(m_ZeroInt(), m_Mul(m_ZExtOrSExt(m_VPValue()),
@@ -5935,14 +5946,23 @@ static void transformToPartialReduction(const VPPartialReductionChain &Chain,
   // It's therefore better to choose option (2) such that the partial
   // reduction is always positive (starting at '0') and to do a final
   // subtract in the middle block.
-  if (WidenRecipe->getOpcode() == Instruction::Sub &&
-      Chain.RK != RecurKind::Sub) {
+  if ((WidenRecipe->getOpcode() == Instruction::Sub &&
+       Chain.RK != RecurKind::Sub) ||
+      // FIXME: We don't have a RecurKind::FSub yet.
+      WidenRecipe->getOpcode() == Instruction::FSub) {
     VPBuilder Builder(WidenRecipe);
     Type *ElemTy = TypeInfo.inferScalarType(ExtendedOp);
-    auto *Zero = Plan.getZero(ElemTy);
-    auto *NegRecipe =
-        new VPWidenRecipe(Instruction::Sub, {Zero, ExtendedOp}, VPIRFlags(),
-                          VPIRMetadata(), DebugLoc::getUnknown());
+    VPWidenRecipe *NegRecipe; // fneg x (FP) or sub 0, x (integer).
+    if (WidenRecipe->getOpcode() == Instruction::FSub) {
+      NegRecipe =
+          new VPWidenRecipe(Instruction::FNeg, {ExtendedOp}, VPIRFlags(),
+                            VPIRMetadata(), DebugLoc::getUnknown());
+    } else {
+      auto *Zero = Plan.getZero(ElemTy);
+      NegRecipe =
+          new VPWidenRecipe(Instruction::Sub, {Zero, ExtendedOp}, VPIRFlags(),
+                            VPIRMetadata(), DebugLoc::getUnknown());
+    }
     Builder.insert(NegRecipe);
     ExtendedOp = NegRecipe;
   }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
index f7b73371345bc..dfddba3a0025e 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-fdot-product.ll
@@ -1000,6 +1000,73 @@ exit:
   ret double %add
 }
 
+; Test that an 'fneg' is generated for the fsub.
+define float @fadd_fsub_chain_f16_f32(ptr %a, ptr %b) #0 {
+; CHECK-LABEL: define float @fadd_fsub_chain_f16_f32(
+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i64 [[TMP0]], 3
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[FOR_EXIT:.*]], label %[[VECTOR_PH:.*]]
+; CHECK:       [[VECTOR_PH]]:
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw i64 [[TMP2]], 3
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK:       [[FOR_BODY]]:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi <vscale x 4 x float> [ insertelement (<vscale x 4 x float> splat (float -0.000000e+00), float 0.000000e+00, i32 0), %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE9:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[GEP_A:%.*]] = getelementptr half, ptr [[A]], i64 [[IV]]
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 8 x half>, ptr [[GEP_A]], align 1
+; CHECK-NEXT:    [[GEP_B:%.*]] = getelementptr half, ptr [[B]], i64 [[IV]]
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 8 x half>, ptr [[GEP_B]], align 1
+; CHECK-NEXT:    [[TMP12:%.*]] = fpext <vscale x 8 x half> [[WIDE_LOAD4]] to <vscale x 8 x float>
+; CHECK-NEXT:    [[TMP13:%.*]] = fpext <vscale x 8 x half> [[WIDE_LOAD2]] to <vscale x 8 x float>
+; CHECK-NEXT:    [[TMP14:%.*]] = fmul <vscale x 8 x float> [[TMP12]], [[TMP13]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE5:%.*]] = call reassoc contract <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x float> [[VEC_PHI1]], <vscale x 8 x float> [[TMP14]])
+; CHECK-NEXT:    [[WIDE_LOAD7:%.*]] = load <vscale x 8 x half>, ptr [[GEP_A]], align 1
+; CHECK-NEXT:    [[TMP17:%.*]] = fpext <vscale x 8 x half> [[WIDE_LOAD7]] to <vscale x 8 x float>
+; CHECK-NEXT:    [[TMP18:%.*]] = fmul <vscale x 8 x float> [[TMP17]], [[TMP12]]
+; CHECK-NEXT:    [[TMP22:%.*]] = fneg <vscale x 8 x float> [[TMP18]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE9]] = call reassoc contract <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x float> [[PARTIAL_REDUCE5]], <vscale x 8 x float> [[TMP22]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[IV]], [[TMP3]]
+; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP19]], label %[[MIDDLE_BLOCK:.*]], label %[[FOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
+; CHECK:       [[MIDDLE_BLOCK]]:
+; CHECK-NEXT:    [[TMP15:%.*]] = call reassoc contract float @llvm.vector.reduce.fadd.nxv4f32(float -0.000000e+00, <vscale x 4 x float> [[PARTIAL_REDUCE9]])
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], [[FOR_EXIT1:label %.*]], label %[[FOR_EXIT]]
+; CHECK:       [[FOR_EXIT]]:
+;
+entry:
+  br label %for.body
+
+for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi float [ 0.0, %entry ], [ %sub, %for.body ]
+  %gep.a = getelementptr half, ptr %a, i64 %iv
+  %load.a = load half, ptr %gep.a, align 1
+  %ext.a = fpext half %load.a to float
+  %gep.b = getelementptr half, ptr %b, i64 %iv
+  %load.b = load half, ptr %gep.b, align 1
+  %ext.b = fpext half %load.b to float
+  %mul = fmul float %ext.b, %ext.a
+  %add = fadd reassoc contract float %mul, %accum
+  %gep.a2 = getelementptr half, ptr %a, i64 %iv
+  %load.a2 = load half, ptr %gep.a2, align 1
+  %ext.a2 = fpext half %load.a2 to float
+  %mul2 = fmul float %ext.a2, %ext.b
+  %sub = fsub reassoc contract float %add, %mul2
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %for.exit, label %for.body, !llvm.loop !0
+
+for.exit:
+  ret float %sub
+}
+
 attributes #0 = { "target-features"="+sve2p1,+dotprod" }
 attributes #1 = { "target-features"="+sve" }
 attributes #2 = { "target-features"="+dotprod" }

``````````

</details>


https://github.com/llvm/llvm-project/pull/195116


More information about the llvm-branch-commits mailing list