[llvm] [AArch64] Combine vector FNEG+FMA into `FNML[A|S]` (PR #167900)

Damian Heaton via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 1 06:59:20 PST 2025


https://github.com/dheaton-arm updated https://github.com/llvm/llvm-project/pull/167900

From 3c115a2ab6bb3d66dfd740add6145aeaa9b30249 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Thu, 13 Nov 2025 15:42:11 +0000
Subject: [PATCH 1/6] Combine vector FNEG+FMA into `FNML[A|S]`

This allows FNEG + FMA sequences to be combined into a
single operation, selecting `FNML[A|S]`, `FNMAD`, or `FNMSB`
depending on the operand order.
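
As a sketch of the intended effect (this mirrors the first test added in
sve-fmsub.ll below), negating the accumulator of an `llvm.fmuladd` call
selects a single negated multiply-add when SVE is available:

    %neg = fneg <vscale x 2 x double> %c
    %res = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %neg)
    ; computes a*b - c, which now selects
    ;   fnmsb z0.d, p0/m, z1.d, z2.d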
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  50 ++++
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   8 +-
 llvm/test/CodeGen/AArch64/sve-fmsub.ll        | 276 ++++++++++++++++++
 3 files changed, 332 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve-fmsub.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e91f5a877b35b..f2ddc8aa9a1ca 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1170,6 +1170,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
   setTargetDAGCombine(ISD::CTPOP);
 
+  setTargetDAGCombine(ISD::FMA);
+
   // In case of strict alignment, avoid an excessive number of byte wide stores.
   MaxStoresPerMemsetOptSize = 8;
   MaxStoresPerMemset =
@@ -20692,6 +20694,52 @@ static SDValue performFADDCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performFMACombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+  SDValue Op1 = N->getOperand(0);
+  SDValue Op2 = N->getOperand(1);
+  SDValue Op3 = N->getOperand(2);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // fma(a, b, neg(c)) -> fnmls(a, b, c)
+  // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
+  // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
+  if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+      (Subtarget->hasSVE() || Subtarget->hasSME())) {
+    if (Op3.getOpcode() == ISD::FNEG) {
+      unsigned int Opcode;
+      if (Op1.getOpcode() == ISD::FNEG) {
+        Op1 = Op1.getOperand(0);
+        Opcode = AArch64ISD::FNMLA_PRED;
+      } else if (Op2.getOpcode() == ISD::FNEG) {
+        Op2 = Op2.getOperand(0);
+        Opcode = AArch64ISD::FNMLA_PRED;
+      } else {
+        Opcode = AArch64ISD::FNMLS_PRED;
+      }
+      Op3 = Op3.getOperand(0);
+      auto Pg = getPredicateForVector(DAG, DL, VT);
+      if (VT.isFixedLengthVector()) {
+        assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+               "Expected only legal fixed-width types");
+        EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+        Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
+        Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
+        Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
+        auto ScalableRes =
+            DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
+        return convertFromScalableVector(DAG, VT, ScalableRes);
+      }
+      return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
+    }
+  }
+
+  return SDValue();
+}
+
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
   switch (Opcode) {
   case ISD::STRICT_FADD:
@@ -28223,6 +28271,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performANDCombine(N, DCI);
   case ISD::FADD:
     return performFADDCombine(N, DCI);
+  case ISD::FMA:
+    return performFMACombine(N, DCI, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
     return performIntrinsicCombine(N, DCI, Subtarget);
   case ISD::ANY_EXTEND:
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index e99b3f8ff07e0..2291495e1f9e8 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -240,6 +240,8 @@ def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
 def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
 def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
 def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
+def AArch64fnmla_p_node : SDNode<"AArch64ISD::FNMLA_PRED", SDT_AArch64FMA>;
+def AArch64fnmls_p_node : SDNode<"AArch64ISD::FNMLS_PRED", SDT_AArch64FMA>;
 
 def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                      (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
@@ -460,12 +462,14 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
 
 
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                              [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+                              [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
+                               (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
                                (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
                                (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
 
 def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                              [(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
+                              [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
+                               (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
                                (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
 
 def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
new file mode 100644
index 0000000000000..721066038769c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -0,0 +1,276 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+
+define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fmsub_nxv2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 2 x double> %c
+  %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %neg)
+  ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fmsub_nxv4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 4 x float> %c
+  %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %neg)
+  ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fmsub_nxv8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 8 x half> %c
+  %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %neg)
+  ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fmsub_v2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fmsub_v4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fmsub_v8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+  ret <8 x half> %0
+}
+
+
+define <2 x double> @fmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fmsub_flipped_v2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_v4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fmsub_flipped_v8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmls z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+  ret <8 x half> %0
+}
+
+define <vscale x 2 x double> @fnmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fnmsub_nxv2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 2 x double> %a
+  %neg1 = fneg <vscale x 2 x double> %c
+  %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %neg, <vscale x 2 x double> %b, <vscale x 2 x double> %neg1)
+  ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fnmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fnmsub_nxv4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 4 x float> %a
+  %neg1 = fneg <vscale x 4 x float> %c
+  %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %neg, <vscale x 4 x float> %b, <vscale x 4 x float> %neg1)
+  ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fnmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fnmsub_nxv8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 8 x half> %a
+  %neg1 = fneg <vscale x 8 x half> %c
+  %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %neg, <vscale x 8 x half> %b, <vscale x 8 x half> %neg1)
+  ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fnmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fnmsub_v2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <2 x double> %a
+  %neg1 = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fnmsub_v4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <4 x float> %a
+  %neg1 = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fnmsub_v8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <8 x half> %a
+  %neg1 = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+  ret <8 x half> %0
+}
+
+define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fnmsub_flipped_v2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <2 x double> %a
+  %neg1 = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fnmsub_flipped_v4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <4 x float> %a
+  %neg1 = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_v8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <8 x half> %a
+  %neg1 = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+  ret <8 x half> %0
+}

From 6a530d8cc775fb929c9d0f9d1f1aef1cdd78f0f1 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Tue, 18 Nov 2025 12:54:41 +0000
Subject: [PATCH 2/6] Address review comments

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  61 +++++-----
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   4 +-
 llvm/test/CodeGen/AArch64/sve-fmsub.ll        | 115 +++++++++++++++++-
 3 files changed, 143 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2ddc8aa9a1ca..721f9d2cdb367 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20698,46 +20698,41 @@ static SDValue performFMACombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Op1 = N->getOperand(0);
-  SDValue Op2 = N->getOperand(1);
-  SDValue Op3 = N->getOperand(2);
+  SDValue OpA = N->getOperand(0);
+  SDValue OpB = N->getOperand(1);
+  SDValue OpC = N->getOperand(2);
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
   // fma(a, b, neg(c)) -> fnmls(a, b, c)
   // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
   // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
-  if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
-      (Subtarget->hasSVE() || Subtarget->hasSME())) {
-    if (Op3.getOpcode() == ISD::FNEG) {
-      unsigned int Opcode;
-      if (Op1.getOpcode() == ISD::FNEG) {
-        Op1 = Op1.getOperand(0);
-        Opcode = AArch64ISD::FNMLA_PRED;
-      } else if (Op2.getOpcode() == ISD::FNEG) {
-        Op2 = Op2.getOperand(0);
-        Opcode = AArch64ISD::FNMLA_PRED;
-      } else {
-        Opcode = AArch64ISD::FNMLS_PRED;
-      }
-      Op3 = Op3.getOperand(0);
-      auto Pg = getPredicateForVector(DAG, DL, VT);
-      if (VT.isFixedLengthVector()) {
-        assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
-               "Expected only legal fixed-width types");
-        EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
-        Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
-        Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
-        Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
-        auto ScalableRes =
-            DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
-        return convertFromScalableVector(DAG, VT, ScalableRes);
-      }
-      return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
-    }
+  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+      !Subtarget->isSVEorStreamingSVEAvailable() ||
+      OpC.getOpcode() != ISD::FNEG) {
+    return SDValue();
+  }
+  unsigned int Opcode;
+  if (OpA.getOpcode() == ISD::FNEG) {
+    OpA = OpA.getOperand(0);
+    Opcode = AArch64ISD::FNMLA_PRED;
+  } else if (OpB.getOpcode() == ISD::FNEG) {
+    OpB = OpB.getOperand(0);
+    Opcode = AArch64ISD::FNMLA_PRED;
+  } else {
+    Opcode = AArch64ISD::FNMLS_PRED;
   }
-
-  return SDValue();
+  OpC = OpC.getOperand(0);
+  auto Pg = getPredicateForVector(DAG, DL, VT);
+  if (VT.isFixedLengthVector()) {
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+    OpA = convertToScalableVector(DAG, ContainerVT, OpA);
+    OpB = convertToScalableVector(DAG, ContainerVT, OpB);
+    OpC = convertToScalableVector(DAG, ContainerVT, OpC);
+    auto ScalableRes = DAG.getNode(Opcode, DL, ContainerVT, Pg, OpA, OpB, OpC);
+    return convertFromScalableVector(DAG, VT, ScalableRes);
+  }
+  return DAG.getNode(Opcode, DL, VT, Pg, OpA, OpB, OpC);
 }
 
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 2291495e1f9e8..3722dcbb13d40 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -464,13 +464,11 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                               [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
                                (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
-                               (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
                                (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
 
 def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                               [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
-                               (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
-                               (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
+                               (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm)]>;
 
 def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
                              (AArch64fsub_p node:$pg, node:$op2, node:$op1)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 721066038769c..29dbb87f1b875 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -1,5 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
-; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
+; RUN: llc -mattr=+sme -force-streaming %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SME
+
+target triple = "aarch64"
 
 define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
 ; CHECK-LABEL: fmsub_nxv2f64:
@@ -274,3 +277,113 @@ entry:
   %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
   ret <8 x half> %0
 }
+
+; Illegal types
+
+define <vscale x 3 x float> @fmsub_illegal_nxv3f32(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %c) {
+; CHECK-LABEL: fmsub_illegal_nxv3f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 3 x float> %c
+  %0 = tail call <vscale x 3 x float> @llvm.fmuladd(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %neg)
+  ret <vscale x 3 x float> %0
+}
+
+define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
+; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    ptrue p0.d, vl1
+; CHECK-SVE-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-SVE-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-SVE-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-SVE-NEXT:    fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SVE-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SME-LABEL: fmsub_illegal_v1f64:
+; CHECK-SME:       // %bb.0: // %entry
+; CHECK-SME-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SME-NEXT:    addvl sp, sp, #-1
+; CHECK-SME-NEXT:    .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-SME-NEXT:    .cfi_offset w29, -16
+; CHECK-SME-NEXT:    ptrue p0.d, vl1
+; CHECK-SME-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-SME-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-SME-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-SME-NEXT:    fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SME-NEXT:    str z0, [sp]
+; CHECK-SME-NEXT:    ldr d0, [sp]
+; CHECK-SME-NEXT:    addvl sp, sp, #1
+; CHECK-SME-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-SME-NEXT:    ret
+entry:
+  %neg = fneg <1 x double> %c
+  %0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
+  ret <1 x double> %0
+}
+
+define <3 x float> @fmsub_flipped_illegal_v3f32(<3 x float> %c, <3 x float> %a, <3 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_illegal_v3f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <3 x float> %c
+  %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %a, <3 x float> %b, <3 x float> %neg)
+  ret <3 x float> %0
+}
+
+define <vscale x 7 x half> @fnmsub_illegal_nxv7f16(<vscale x 7 x half> %a, <vscale x 7 x half> %b, <vscale x 7 x half> %c) {
+; CHECK-LABEL: fnmsub_illegal_nxv7f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 7 x half> %a
+  %neg1 = fneg <vscale x 7 x half> %c
+  %0 = tail call <vscale x 7 x half> @llvm.fmuladd(<vscale x 7 x half> %neg, <vscale x 7 x half> %b, <vscale x 7 x half> %neg1)
+  ret <vscale x 7 x half> %0
+}
+
+define <3 x float> @fnmsub_illegal_v3f32(<3 x float> %a, <3 x float> %b, <3 x float> %c) {
+; CHECK-LABEL: fnmsub_illegal_v3f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <3 x float> %a
+  %neg1 = fneg <3 x float> %c
+  %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %neg, <3 x float> %b, <3 x float> %neg1)
+  ret <3 x float> %0
+}
+
+define <7 x half> @fnmsub_flipped_illegal_v7f16(<7 x half> %c, <7 x half> %a, <7 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_illegal_v7f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <7 x half> %a
+  %neg1 = fneg <7 x half> %c
+  %0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
+  ret <7 x half> %0
+}

From ae2ba4054b463f9b0eb20be028081ea093a99092 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 24 Nov 2025 14:47:42 +0000
Subject: [PATCH 3/6] Use existing TableGen patterns

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 41 ++++++++++++-------
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td | 10 ++---
 llvm/test/CodeGen/AArch64/sve-fmsub.ll        | 40 ++++++++++++++++++
 3 files changed, 70 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 721f9d2cdb367..362806e8b7b99 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20704,6 +20704,7 @@ static SDValue performFMACombine(SDNode *N,
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
+  // Convert FMA/FNEG nodes to SVE to enable the following patterns:
   // fma(a, b, neg(c)) -> fnmls(a, b, c)
   // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
   // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
@@ -20712,27 +20713,37 @@ static SDValue performFMACombine(SDNode *N,
       OpC.getOpcode() != ISD::FNEG) {
     return SDValue();
   }
-  unsigned int Opcode;
+
+  SDValue Pg = getPredicateForVector(DAG, DL, VT);
+  EVT ContainerVT =
+      VT.isFixedLengthVector() ? getContainerForFixedLengthVector(DAG, VT) : VT;
+  OpC = VT.isFixedLengthVector()
+            ? convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0))
+            : OpC->getOperand(0);
+  OpC = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpC,
+                    DAG.getUNDEF(ContainerVT));
+
+  if (OpB.getOpcode() == ISD::FNEG) {
+    std::swap(OpA, OpB);
+  }
+
   if (OpA.getOpcode() == ISD::FNEG) {
-    OpA = OpA.getOperand(0);
-    Opcode = AArch64ISD::FNMLA_PRED;
-  } else if (OpB.getOpcode() == ISD::FNEG) {
-    OpB = OpB.getOperand(0);
-    Opcode = AArch64ISD::FNMLA_PRED;
-  } else {
-    Opcode = AArch64ISD::FNMLS_PRED;
+    OpA = VT.isFixedLengthVector()
+              ? convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0))
+              : OpA->getOperand(0);
+    OpA = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpA,
+                      DAG.getUNDEF(ContainerVT));
+  } else if (VT.isFixedLengthVector()) {
+    OpA = convertToScalableVector(DAG, ContainerVT, OpA);
   }
-  OpC = OpC.getOperand(0);
-  auto Pg = getPredicateForVector(DAG, DL, VT);
+
   if (VT.isFixedLengthVector()) {
-    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
-    OpA = convertToScalableVector(DAG, ContainerVT, OpA);
     OpB = convertToScalableVector(DAG, ContainerVT, OpB);
-    OpC = convertToScalableVector(DAG, ContainerVT, OpC);
-    auto ScalableRes = DAG.getNode(Opcode, DL, ContainerVT, Pg, OpA, OpB, OpC);
+    SDValue ScalableRes =
+        DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
     return convertFromScalableVector(DAG, VT, ScalableRes);
   }
-  return DAG.getNode(Opcode, DL, VT, Pg, OpA, OpB, OpC);
+  return DAG.getNode(AArch64ISD::FMA_PRED, DL, VT, Pg, OpA, OpB, OpC);
 }
 
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 3722dcbb13d40..e99b3f8ff07e0 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -240,8 +240,6 @@ def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
 def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
 def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
 def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
-def AArch64fnmla_p_node : SDNode<"AArch64ISD::FNMLA_PRED", SDT_AArch64FMA>;
-def AArch64fnmls_p_node : SDNode<"AArch64ISD::FNMLS_PRED", SDT_AArch64FMA>;
 
 def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                      (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
@@ -462,13 +460,13 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
 
 
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                              [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
-                               (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+                              [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+                               (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
                                (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
 
 def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                              [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
-                               (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm)]>;
+                              [(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
+                               (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
 
 def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
                              (AArch64fsub_p node:$pg, node:$op2, node:$op1)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 29dbb87f1b875..27827720b2cdb 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -176,6 +176,46 @@ entry:
   ret <vscale x 8 x half> %0
 }
 
+
+define <vscale x 2 x double> @fnmsub_negated_b_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fnmsub_negated_b_nxv2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 2 x double> %b
+  %neg1 = fneg <vscale x 2 x double> %c
+  %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %neg, <vscale x 2 x double> %neg1)
+  ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fnmsub_negated_b_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fnmsub_negated_b_nxv4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 4 x float> %b
+  %neg1 = fneg <vscale x 4 x float> %c
+  %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %a, <vscale x 4 x float> %neg, <vscale x 4 x float> %neg1)
+  ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fnmsub_negated_b_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fnmsub_negated_b_nxv8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <vscale x 8 x half> %b
+  %neg1 = fneg <vscale x 8 x half> %c
+  %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %a, <vscale x 8 x half> %neg, <vscale x 8 x half> %neg1)
+  ret <vscale x 8 x half> %0
+}
+
 define <2 x double> @fnmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
 ; CHECK-LABEL: fnmsub_v2f64:
 ; CHECK:       // %bb.0: // %entry

From 9892c8652ae6e1c2e899e610fec0ce8a0710d6a8 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 24 Nov 2025 16:59:54 +0000
Subject: [PATCH 4/6] Remove SVE handling and add pattern to compensate

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 49 +++++++------------
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  1 +
 2 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 362806e8b7b99..5b1a6c47c94dd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20708,42 +20708,31 @@ static SDValue performFMACombine(SDNode *N,
   // fma(a, b, neg(c)) -> fnmls(a, b, c)
   // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
   // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
-  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+  if (!VT.isFixedLengthVector() ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
       !Subtarget->isSVEorStreamingSVEAvailable() ||
       OpC.getOpcode() != ISD::FNEG) {
     return SDValue();
   }
 
   SDValue Pg = getPredicateForVector(DAG, DL, VT);
-  EVT ContainerVT =
-      VT.isFixedLengthVector() ? getContainerForFixedLengthVector(DAG, VT) : VT;
-  OpC = VT.isFixedLengthVector()
-            ? convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0))
-            : OpC->getOperand(0);
-  OpC = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpC,
-                    DAG.getUNDEF(ContainerVT));
-
-  if (OpB.getOpcode() == ISD::FNEG) {
-    std::swap(OpA, OpB);
-  }
-
-  if (OpA.getOpcode() == ISD::FNEG) {
-    OpA = VT.isFixedLengthVector()
-              ? convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0))
-              : OpA->getOperand(0);
-    OpA = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpA,
-                      DAG.getUNDEF(ContainerVT));
-  } else if (VT.isFixedLengthVector()) {
-    OpA = convertToScalableVector(DAG, ContainerVT, OpA);
-  }
-
-  if (VT.isFixedLengthVector()) {
-    OpB = convertToScalableVector(DAG, ContainerVT, OpB);
-    SDValue ScalableRes =
-        DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
-    return convertFromScalableVector(DAG, VT, ScalableRes);
-  }
-  return DAG.getNode(AArch64ISD::FMA_PRED, DL, VT, Pg, OpA, OpB, OpC);
+  EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+  OpC =
+      DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
+                  convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0)),
+                  DAG.getUNDEF(ContainerVT));
+
+  OpA = OpA.getOpcode() == ISD::FNEG
+            ? DAG.getNode(
+                  AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
+                  convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0)),
+                  DAG.getUNDEF(ContainerVT))
+            : convertToScalableVector(DAG, ContainerVT, OpA);
+
+  OpB = convertToScalableVector(DAG, ContainerVT, OpB);
+  SDValue ScalableRes =
+      DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
+  return convertFromScalableVector(DAG, VT, ScalableRes);
 }
 
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index e99b3f8ff07e0..e46afa65d3c05 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -462,6 +462,7 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                               [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
                                (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
+                               (AArch64fma_p node:$pg, node:$zn, (AArch64fneg_mt node:$pg, node:$zm, (undef)), (AArch64fneg_mt node:$pg, node:$za, (undef))),
                                (AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
 
 def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),

From 2879d78055896a3889a1c33aa50963bb4eede793 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Tue, 25 Nov 2025 16:28:09 +0000
Subject: [PATCH 5/6] Use custom lowering rather than DAG Combiner

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 82 ++++++++----------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  1 +
 .../complex-deinterleaving-symmetric-fixed.ll | 22 +++--
 .../AArch64/sve-fixed-length-fp-arith.ll      | 24 ++++--
 llvm/test/CodeGen/AArch64/sve-fmsub.ll        | 84 +++++++++++++------
 5 files changed, 127 insertions(+), 86 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5b1a6c47c94dd..808a974eda694 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1170,8 +1170,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
   setTargetDAGCombine(ISD::CTPOP);
 
-  setTargetDAGCombine(ISD::FMA);
-
   // In case of strict alignment, avoid an excessive number of byte wide stores.
   MaxStoresPerMemsetOptSize = 8;
   MaxStoresPerMemset =
@@ -1526,6 +1524,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
       setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
+
+    for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) {
+      setOperationAction(ISD::FMA, VT, Custom);
+    }
   }
 
   if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -7732,6 +7734,37 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
   return FCVTNT(VT, BottomBF16, Pg, TopF32);
 }
 
+SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
+  SDValue OpA = Op->getOperand(0);
+  SDValue OpB = Op->getOperand(1);
+  SDValue OpC = Op->getOperand(2);
+  EVT VT = Op.getValueType();
+  SDLoc DL(Op);
+
+  // Bail early if we're definitely not looking to merge FNEGs into the FMA.
+  if (!VT.isFixedLengthVector() || OpC.getOpcode() != ISD::FNEG) {
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+  }
+
+  // Convert FMA/FNEG nodes to SVE to enable the following patterns:
+  // fma(a, b, neg(c)) -> fnmls(a, b, c)
+  // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
+  // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
+  SDValue Pg = getPredicateForVector(DAG, DL, VT);
+  EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+
+  for (SDValue *Op : {&OpA, &OpB, &OpC}) {
+    // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
+    *Op = Op->getOpcode() == ISD::FNEG
+              ? LowerToPredicatedOp(*Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+                    ->getOperand(0)
+              : convertToScalableVector(DAG, ContainerVT, *Op);
+  }
+  SDValue ScalableRes =
+      DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
+  return convertFromScalableVector(DAG, VT, ScalableRes);
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7808,7 +7841,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::FMUL:
     return LowerFMUL(Op, DAG);
   case ISD::FMA:
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+    return LowerFMA(Op, DAG);
   case ISD::FDIV:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
   case ISD::FNEG:
@@ -20694,47 +20727,6 @@ static SDValue performFADDCombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performFMACombine(SDNode *N,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const AArch64Subtarget *Subtarget) {
-  SelectionDAG &DAG = DCI.DAG;
-  SDValue OpA = N->getOperand(0);
-  SDValue OpB = N->getOperand(1);
-  SDValue OpC = N->getOperand(2);
-  EVT VT = N->getValueType(0);
-  SDLoc DL(N);
-
-  // Convert FMA/FNEG nodes to SVE to enable the following patterns:
-  // fma(a, b, neg(c)) -> fnmls(a, b, c)
-  // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
-  // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
-  if (!VT.isFixedLengthVector() ||
-      !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
-      !Subtarget->isSVEorStreamingSVEAvailable() ||
-      OpC.getOpcode() != ISD::FNEG) {
-    return SDValue();
-  }
-
-  SDValue Pg = getPredicateForVector(DAG, DL, VT);
-  EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
-  OpC =
-      DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
-                  convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0)),
-                  DAG.getUNDEF(ContainerVT));
-
-  OpA = OpA.getOpcode() == ISD::FNEG
-            ? DAG.getNode(
-                  AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
-                  convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0)),
-                  DAG.getUNDEF(ContainerVT))
-            : convertToScalableVector(DAG, ContainerVT, OpA);
-
-  OpB = convertToScalableVector(DAG, ContainerVT, OpB);
-  SDValue ScalableRes =
-      DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
-  return convertFromScalableVector(DAG, VT, ScalableRes);
-}
-
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
   switch (Opcode) {
   case ISD::STRICT_FADD:
@@ -28266,8 +28258,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performANDCombine(N, DCI);
   case ISD::FADD:
     return performFADDCombine(N, DCI);
-  case ISD::FMA:
-    return performFMACombine(N, DCI, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
     return performIntrinsicCombine(N, DCI, Subtarget);
   case ISD::ANY_EXTEND:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ca08eb40c956a..5f8b8adbede43 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -615,6 +615,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFMA(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
index d05b9c6d7662a..49d33b9709e04 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
@@ -7,13 +7,18 @@ define <4 x double> @simple_symmetric_muladd2(<4 x double> %a, <4 x double> %b)
 ; CHECK-LABEL: simple_symmetric_muladd2:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov x8, #-7378697629483820647 // =0x9999999999999999
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q3 killed $q3 def $z3
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
 ; CHECK-NEXT:    movk x8, #39322
 ; CHECK-NEXT:    movk x8, #16393, lsl #48
 ; CHECK-NEXT:    dup v4.2d, x8
-; CHECK-NEXT:    fmla v2.2d, v4.2d, v0.2d
-; CHECK-NEXT:    fmla v3.2d, v4.2d, v1.2d
-; CHECK-NEXT:    mov v0.16b, v2.16b
-; CHECK-NEXT:    mov v1.16b, v3.16b
+; CHECK-NEXT:    fmad z0.d, p0/m, z4.d, z2.d
+; CHECK-NEXT:    fmad z1.d, p0/m, z4.d, z3.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    // kill: def $q1 killed $q1 killed $z1
 ; CHECK-NEXT:    ret
 entry:
   %ext00 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
@@ -43,10 +48,11 @@ define <8 x double> @simple_symmetric_muladd4(<8 x double> %a, <8 x double> %b)
 ; CHECK-NEXT:    zip1 v17.2d, v5.2d, v7.2d
 ; CHECK-NEXT:    zip2 v5.2d, v5.2d, v7.2d
 ; CHECK-NEXT:    dup v6.2d, x8
-; CHECK-NEXT:    fmla v3.2d, v6.2d, v16.2d
-; CHECK-NEXT:    fmla v4.2d, v6.2d, v0.2d
-; CHECK-NEXT:    fmla v17.2d, v6.2d, v2.2d
-; CHECK-NEXT:    fmla v5.2d, v6.2d, v1.2d
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    fmla z3.d, p0/m, z16.d, z6.d
+; CHECK-NEXT:    fmla z4.d, p0/m, z0.d, z6.d
+; CHECK-NEXT:    fmla z17.d, p0/m, z2.d, z6.d
+; CHECK-NEXT:    fmla z5.d, p0/m, z1.d, z6.d
 ; CHECK-NEXT:    zip1 v0.2d, v3.2d, v4.2d
 ; CHECK-NEXT:    zip2 v2.2d, v3.2d, v4.2d
 ; CHECK-NEXT:    zip1 v1.2d, v17.2d, v5.2d
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
index 2dda03e5c6dab..c6e87f9c3abcc 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
@@ -620,8 +620,12 @@ define <4 x half> @fma_v4f16(<4 x half> %op1, <4 x half> %op2, <4 x half> %op3)
 define <8 x half> @fma_v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3) vscale_range(2,0) #0 {
 ; CHECK-LABEL: fma_v8f16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmla v2.8h, v1.8h, v0.8h
-; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call <8 x half> @llvm.fma.v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3)
   ret <8 x half> %res
@@ -730,8 +734,12 @@ define <2 x float> @fma_v2f32(<2 x float> %op1, <2 x float> %op2, <2 x float> %o
 define <4 x float> @fma_v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3) vscale_range(2,0) #0 {
 ; CHECK-LABEL: fma_v4f32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmla v2.4s, v1.4s, v0.4s
-; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call <4 x float> @llvm.fma.v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3)
   ret <4 x float> %res
@@ -839,8 +847,12 @@ define <1 x double> @fma_v1f64(<1 x double> %op1, <1 x double> %op2, <1 x double
 define <2 x double> @fma_v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3) vscale_range(2,0) #0 {
 ; CHECK-LABEL: fma_v2f64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmla v2.2d, v1.2d, v0.2d
-; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call <2 x double> @llvm.fma.v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3)
   ret <2 x double> %res
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 27827720b2cdb..9b0205f658056 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -267,6 +267,57 @@ entry:
   ret <8 x half> %0
 }
 
+define <2 x double> @fnmsub_negated_b_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fnmsub_negated_b_v2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <2 x double> %b
+  %neg1 = fneg <2 x double> %c
+  %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %neg, <2 x double> %neg1)
+  ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_negated_b_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fnmsub_negated_b_v4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <4 x float> %b
+  %neg1 = fneg <4 x float> %c
+  %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %neg, <4 x float> %neg1)
+  ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_negated_b_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fnmsub_negated_b_v8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %neg = fneg <8 x half> %b
+  %neg1 = fneg <8 x half> %c
+  %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %neg, <8 x half> %neg1)
+  ret <8 x half> %0
+}
+
 define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: fnmsub_flipped_v2f64:
 ; CHECK:       // %bb.0: // %entry
@@ -333,32 +384,10 @@ entry:
 }
 
 define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
-; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
-; CHECK-SVE:       // %bb.0: // %entry
-; CHECK-SVE-NEXT:    ptrue p0.d, vl1
-; CHECK-SVE-NEXT:    // kill: def $d0 killed $d0 def $z0
-; CHECK-SVE-NEXT:    // kill: def $d2 killed $d2 def $z2
-; CHECK-SVE-NEXT:    // kill: def $d1 killed $d1 def $z1
-; CHECK-SVE-NEXT:    fnmsb z0.d, p0/m, z1.d, z2.d
-; CHECK-SVE-NEXT:    // kill: def $d0 killed $d0 killed $z0
-; CHECK-SVE-NEXT:    ret
-;
-; CHECK-SME-LABEL: fmsub_illegal_v1f64:
-; CHECK-SME:       // %bb.0: // %entry
-; CHECK-SME-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-SME-NEXT:    addvl sp, sp, #-1
-; CHECK-SME-NEXT:    .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
-; CHECK-SME-NEXT:    .cfi_offset w29, -16
-; CHECK-SME-NEXT:    ptrue p0.d, vl1
-; CHECK-SME-NEXT:    // kill: def $d0 killed $d0 def $z0
-; CHECK-SME-NEXT:    // kill: def $d2 killed $d2 def $z2
-; CHECK-SME-NEXT:    // kill: def $d1 killed $d1 def $z1
-; CHECK-SME-NEXT:    fnmsb z0.d, p0/m, z1.d, z2.d
-; CHECK-SME-NEXT:    str z0, [sp]
-; CHECK-SME-NEXT:    ldr d0, [sp]
-; CHECK-SME-NEXT:    addvl sp, sp, #1
-; CHECK-SME-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
-; CHECK-SME-NEXT:    ret
+; CHECK-LABEL: fmsub_illegal_v1f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fnmsub d0, d0, d1, d2
+; CHECK-NEXT:    ret
 entry:
   %neg = fneg <1 x double> %c
   %0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
@@ -427,3 +456,6 @@ entry:
   %0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
   ret <7 x half> %0
 }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK-SME: {{.*}}
+; CHECK-SVE: {{.*}}

From dd43ed98426056ec8f9d2de5710596c7f2f1f541 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 1 Dec 2025 14:58:04 +0000
Subject: [PATCH 6/6] Address feedback

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 27 ++++++++++++-------
 .../complex-deinterleaving-symmetric-fixed.ll | 22 ++++++---------
 .../AArch64/sve-fixed-length-fp-arith.ll      | 24 +++++------------
 3 files changed, 31 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 808a974eda694..38f834c1b52ac 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1525,9 +1525,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
       setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
 
-    for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) {
+    for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64})
       setOperationAction(ISD::FMA, VT, Custom);
-    }
   }
 
   if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -7743,7 +7742,10 @@ SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
 
   // Bail early if we're definitely not looking to merge FNEGs into the FMA.
   if (!VT.isFixedLengthVector() || OpC.getOpcode() != ISD::FNEG) {
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+    if (VT.isScalableVector() || VT.getScalarType() == MVT::bf16 ||
+        useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+    return Op; // Fallback to NEON lowering.
   }
 
   // Convert FMA/FNEG nodes to SVE to enable the following patterns:
@@ -7753,13 +7755,18 @@ SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
   SDValue Pg = getPredicateForVector(DAG, DL, VT);
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
 
-  for (SDValue *Op : {&OpA, &OpB, &OpC}) {
-    // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
-    *Op = Op->getOpcode() == ISD::FNEG
-              ? LowerToPredicatedOp(*Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
-                    ->getOperand(0)
-              : convertToScalableVector(DAG, ContainerVT, *Op);
-  }
+  // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
+  OpA = OpA.getOpcode() == ISD::FNEG
+            ? LowerToPredicatedOp(OpA, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+                  ->getOperand(0)
+            : convertToScalableVector(DAG, ContainerVT, OpA);
+  OpB = OpB.getOpcode() == ISD::FNEG
+            ? LowerToPredicatedOp(OpB, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+                  ->getOperand(0)
+            : convertToScalableVector(DAG, ContainerVT, OpB);
+  OpC = LowerToPredicatedOp(OpC, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+            ->getOperand(0);
+
   SDValue ScalableRes =
       DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
   return convertFromScalableVector(DAG, VT, ScalableRes);
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
index 49d33b9709e04..d05b9c6d7662a 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
@@ -7,18 +7,13 @@ define <4 x double> @simple_symmetric_muladd2(<4 x double> %a, <4 x double> %b)
 ; CHECK-LABEL: simple_symmetric_muladd2:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov x8, #-7378697629483820647 // =0x9999999999999999
-; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT:    // kill: def $q3 killed $q3 def $z3
-; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
 ; CHECK-NEXT:    movk x8, #39322
 ; CHECK-NEXT:    movk x8, #16393, lsl #48
 ; CHECK-NEXT:    dup v4.2d, x8
-; CHECK-NEXT:    fmad z0.d, p0/m, z4.d, z2.d
-; CHECK-NEXT:    fmad z1.d, p0/m, z4.d, z3.d
-; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
-; CHECK-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; CHECK-NEXT:    fmla v2.2d, v4.2d, v0.2d
+; CHECK-NEXT:    fmla v3.2d, v4.2d, v1.2d
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    mov v1.16b, v3.16b
 ; CHECK-NEXT:    ret
 entry:
   %ext00 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
@@ -48,11 +43,10 @@ define <8 x double> @simple_symmetric_muladd4(<8 x double> %a, <8 x double> %b)
 ; CHECK-NEXT:    zip1 v17.2d, v5.2d, v7.2d
 ; CHECK-NEXT:    zip2 v5.2d, v5.2d, v7.2d
 ; CHECK-NEXT:    dup v6.2d, x8
-; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    fmla z3.d, p0/m, z16.d, z6.d
-; CHECK-NEXT:    fmla z4.d, p0/m, z0.d, z6.d
-; CHECK-NEXT:    fmla z17.d, p0/m, z2.d, z6.d
-; CHECK-NEXT:    fmla z5.d, p0/m, z1.d, z6.d
+; CHECK-NEXT:    fmla v3.2d, v6.2d, v16.2d
+; CHECK-NEXT:    fmla v4.2d, v6.2d, v0.2d
+; CHECK-NEXT:    fmla v17.2d, v6.2d, v2.2d
+; CHECK-NEXT:    fmla v5.2d, v6.2d, v1.2d
 ; CHECK-NEXT:    zip1 v0.2d, v3.2d, v4.2d
 ; CHECK-NEXT:    zip2 v2.2d, v3.2d, v4.2d
 ; CHECK-NEXT:    zip1 v1.2d, v17.2d, v5.2d
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
index c6e87f9c3abcc..2dda03e5c6dab 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
@@ -620,12 +620,8 @@ define <4 x half> @fma_v4f16(<4 x half> %op1, <4 x half> %op2, <4 x half> %op3)
 define <8 x half> @fma_v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3) vscale_range(2,0) #0 {
 ; CHECK-LABEL: fma_v8f16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p0.h, vl8
-; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    fmad z0.h, p0/m, z1.h, z2.h
-; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    fmla v2.8h, v1.8h, v0.8h
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
   %res = call <8 x half> @llvm.fma.v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3)
   ret <8 x half> %res
@@ -734,12 +730,8 @@ define <2 x float> @fma_v2f32(<2 x float> %op1, <2 x float> %op2, <2 x float> %o
 define <4 x float> @fma_v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3) vscale_range(2,0) #0 {
 ; CHECK-LABEL: fma_v4f32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p0.s, vl4
-; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    fmad z0.s, p0/m, z1.s, z2.s
-; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    fmla v2.4s, v1.4s, v0.4s
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
   %res = call <4 x float> @llvm.fma.v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3)
   ret <4 x float> %res
@@ -847,12 +839,8 @@ define <1 x double> @fma_v1f64(<1 x double> %op1, <1 x double> %op2, <1 x double
 define <2 x double> @fma_v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3) vscale_range(2,0) #0 {
 ; CHECK-LABEL: fma_v2f64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    fmad z0.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    fmla v2.2d, v1.2d, v0.2d
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
   %res = call <2 x double> @llvm.fma.v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3)
   ret <2 x double> %res



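For anyone skimming the diff, here is a minimal standalone reproducer in the spirit of the tests above (illustrative only, not part of the patch; the function name is made up, and the exact final assembly/register allocation is an assumption):

  ; With something like `llc -mtriple=aarch64 -mattr=+sve`, this shape of input
  ; should now select a single fnmls/fnmsb (or scalar fnmsub for illegal types)
  ; rather than an fneg followed by an fmla.
  define <2 x double> @example_fnmls(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
  entry:
    %neg = fneg <2 x double> %c
    %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
    ret <2 x double> %0
  }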