[llvm] [AArch64] Combine vector FNEG+FMA into `FNML[A|S]` (PR #167900)
Damian Heaton via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 1 06:59:20 PST 2025
https://github.com/dheaton-arm updated https://github.com/llvm/llvm-project/pull/167900
From 3c115a2ab6bb3d66dfd740add6145aeaa9b30249 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Thu, 13 Nov 2025 15:42:11 +0000
Subject: [PATCH 1/6] Combine vector FNEG+FMA into `FNML[A|S]`
This allows FNEG + FMA sequences to be combined into a single
operation, selecting `FNMLA`, `FNMLS`, `FNMAD`, or `FNMSB`
depending on the operand order.
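
For example (taken from the tests added below), negating the
accumulator of `llvm.fmuladd` folds into a single `fnmsb`, and
negating both a multiplicand and the accumulator folds into `fnmad`:

  define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
    %neg = fneg <vscale x 2 x double> %c
    %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %neg)
    ret <vscale x 2 x double> %0
  }
  ; with SVE available, selects:
  ;   ptrue p0.d
  ;   fnmsb z0.d, p0/m, z1.d, z2.d

  define <vscale x 2 x double> @fnmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
    %neg = fneg <vscale x 2 x double> %a
    %neg1 = fneg <vscale x 2 x double> %c
    %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %neg, <vscale x 2 x double> %b, <vscale x 2 x double> %neg1)
    ret <vscale x 2 x double> %0
  }
  ; with SVE available, selects:
  ;   ptrue p0.d
  ;   fnmad z0.d, p0/m, z1.d, z2.d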
---
.../Target/AArch64/AArch64ISelLowering.cpp | 50 ++++
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 8 +-
llvm/test/CodeGen/AArch64/sve-fmsub.ll | 276 ++++++++++++++++++
3 files changed, 332 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sve-fmsub.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e91f5a877b35b..f2ddc8aa9a1ca 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1170,6 +1170,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
setTargetDAGCombine(ISD::CTPOP);
+ setTargetDAGCombine(ISD::FMA);
+
// In case of strict alignment, avoid an excessive number of byte wide stores.
MaxStoresPerMemsetOptSize = 8;
MaxStoresPerMemset =
@@ -20692,6 +20694,52 @@ static SDValue performFADDCombine(SDNode *N,
return SDValue();
}
+static SDValue performFMACombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
+ SelectionDAG &DAG = DCI.DAG;
+ SDValue Op1 = N->getOperand(0);
+ SDValue Op2 = N->getOperand(1);
+ SDValue Op3 = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ // fma(a, b, neg(c)) -> fnmls(a, b, c)
+ // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
+ // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
+ if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+ (Subtarget->hasSVE() || Subtarget->hasSME())) {
+ if (Op3.getOpcode() == ISD::FNEG) {
+ unsigned int Opcode;
+ if (Op1.getOpcode() == ISD::FNEG) {
+ Op1 = Op1.getOperand(0);
+ Opcode = AArch64ISD::FNMLA_PRED;
+ } else if (Op2.getOpcode() == ISD::FNEG) {
+ Op2 = Op2.getOperand(0);
+ Opcode = AArch64ISD::FNMLA_PRED;
+ } else {
+ Opcode = AArch64ISD::FNMLS_PRED;
+ }
+ Op3 = Op3.getOperand(0);
+ auto Pg = getPredicateForVector(DAG, DL, VT);
+ if (VT.isFixedLengthVector()) {
+ assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+ "Expected only legal fixed-width types");
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+ Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
+ Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
+ Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
+ auto ScalableRes =
+ DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+ }
+ return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
+ }
+ }
+
+ return SDValue();
+}
+
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
switch (Opcode) {
case ISD::STRICT_FADD:
@@ -28223,6 +28271,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performANDCombine(N, DCI);
case ISD::FADD:
return performFADDCombine(N, DCI);
+ case ISD::FMA:
+ return performFMACombine(N, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
return performIntrinsicCombine(N, DCI, Subtarget);
case ISD::ANY_EXTEND:
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index e99b3f8ff07e0..2291495e1f9e8 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -240,6 +240,8 @@ def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
+def AArch64fnmla_p_node : SDNode<"AArch64ISD::FNMLA_PRED", SDT_AArch64FMA>;
+def AArch64fnmls_p_node : SDNode<"AArch64ISD::FNMLS_PRED", SDT_AArch64FMA>;
def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
(AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
@@ -460,12 +462,14 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
- [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+ [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
+ (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
- [(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
+ [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
+ (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
(AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
new file mode 100644
index 0000000000000..721066038769c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -0,0 +1,276 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+
+define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fmsub_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 2 x double> %c
+ %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %neg)
+ ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fmsub_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 4 x float> %c
+ %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %neg)
+ ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fmsub_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 8 x half> %c
+ %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %neg)
+ ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fmsub_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fmsub_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fmsub_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+ ret <8 x half> %0
+}
+
+
+define <2 x double> @fmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fmsub_flipped_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fmsub_flipped_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+ ret <8 x half> %0
+}
+
+define <vscale x 2 x double> @fnmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fnmsub_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 2 x double> %a
+ %neg1 = fneg <vscale x 2 x double> %c
+ %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %neg, <vscale x 2 x double> %b, <vscale x 2 x double> %neg1)
+ ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fnmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fnmsub_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 4 x float> %a
+ %neg1 = fneg <vscale x 4 x float> %c
+ %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %neg, <vscale x 4 x float> %b, <vscale x 4 x float> %neg1)
+ ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fnmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fnmsub_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 8 x half> %a
+ %neg1 = fneg <vscale x 8 x half> %c
+ %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %neg, <vscale x 8 x half> %b, <vscale x 8 x half> %neg1)
+ ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fnmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fnmsub_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %a
+ %neg1 = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fnmsub_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %a
+ %neg1 = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fnmsub_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %a
+ %neg1 = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+ ret <8 x half> %0
+}
+
+define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fnmsub_flipped_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %a
+ %neg1 = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fnmsub_flipped_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %a
+ %neg1 = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %a
+ %neg1 = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+ ret <8 x half> %0
+}
From 6a530d8cc775fb929c9d0f9d1f1aef1cdd78f0f1 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Tue, 18 Nov 2025 12:54:41 +0000
Subject: [PATCH 2/6] Address review comments
---
.../Target/AArch64/AArch64ISelLowering.cpp | 61 +++++-----
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 4 +-
llvm/test/CodeGen/AArch64/sve-fmsub.ll | 115 +++++++++++++++++-
3 files changed, 143 insertions(+), 37 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2ddc8aa9a1ca..721f9d2cdb367 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20698,46 +20698,41 @@ static SDValue performFMACombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
SelectionDAG &DAG = DCI.DAG;
- SDValue Op1 = N->getOperand(0);
- SDValue Op2 = N->getOperand(1);
- SDValue Op3 = N->getOperand(2);
+ SDValue OpA = N->getOperand(0);
+ SDValue OpB = N->getOperand(1);
+ SDValue OpC = N->getOperand(2);
EVT VT = N->getValueType(0);
SDLoc DL(N);
// fma(a, b, neg(c)) -> fnmls(a, b, c)
// fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
// fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
- if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
- (Subtarget->hasSVE() || Subtarget->hasSME())) {
- if (Op3.getOpcode() == ISD::FNEG) {
- unsigned int Opcode;
- if (Op1.getOpcode() == ISD::FNEG) {
- Op1 = Op1.getOperand(0);
- Opcode = AArch64ISD::FNMLA_PRED;
- } else if (Op2.getOpcode() == ISD::FNEG) {
- Op2 = Op2.getOperand(0);
- Opcode = AArch64ISD::FNMLA_PRED;
- } else {
- Opcode = AArch64ISD::FNMLS_PRED;
- }
- Op3 = Op3.getOperand(0);
- auto Pg = getPredicateForVector(DAG, DL, VT);
- if (VT.isFixedLengthVector()) {
- assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
- "Expected only legal fixed-width types");
- EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
- Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
- Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
- Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
- auto ScalableRes =
- DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
- return convertFromScalableVector(DAG, VT, ScalableRes);
- }
- return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
- }
+ if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+ !Subtarget->isSVEorStreamingSVEAvailable() ||
+ OpC.getOpcode() != ISD::FNEG) {
+ return SDValue();
+ }
+ unsigned int Opcode;
+ if (OpA.getOpcode() == ISD::FNEG) {
+ OpA = OpA.getOperand(0);
+ Opcode = AArch64ISD::FNMLA_PRED;
+ } else if (OpB.getOpcode() == ISD::FNEG) {
+ OpB = OpB.getOperand(0);
+ Opcode = AArch64ISD::FNMLA_PRED;
+ } else {
+ Opcode = AArch64ISD::FNMLS_PRED;
}
-
- return SDValue();
+ OpC = OpC.getOperand(0);
+ auto Pg = getPredicateForVector(DAG, DL, VT);
+ if (VT.isFixedLengthVector()) {
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+ OpA = convertToScalableVector(DAG, ContainerVT, OpA);
+ OpB = convertToScalableVector(DAG, ContainerVT, OpB);
+ OpC = convertToScalableVector(DAG, ContainerVT, OpC);
+ auto ScalableRes = DAG.getNode(Opcode, DL, ContainerVT, Pg, OpA, OpB, OpC);
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+ }
+ return DAG.getNode(Opcode, DL, VT, Pg, OpA, OpB, OpC);
}
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 2291495e1f9e8..3722dcbb13d40 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -464,13 +464,11 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
[(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
- (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
[(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
- (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
- (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
+ (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm)]>;
def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
(AArch64fsub_p node:$pg, node:$op2, node:$op1)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 721066038769c..29dbb87f1b875 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -1,5 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
-; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
+; RUN: llc -mattr=+sme -force-streaming %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SME
+
+target triple = "aarch64"
define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
; CHECK-LABEL: fmsub_nxv2f64:
@@ -274,3 +277,113 @@ entry:
%0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
ret <8 x half> %0
}
+
+; Illegal types
+
+define <vscale x 3 x float> @fmsub_illegal_nxv3f32(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %c) {
+; CHECK-LABEL: fmsub_illegal_nxv3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 3 x float> %c
+ %0 = tail call <vscale x 3 x float> @llvm.fmuladd(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %neg)
+ ret <vscale x 3 x float> %0
+}
+
+define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
+; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: ptrue p0.d, vl1
+; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-SVE-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-SVE-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SME-LABEL: fmsub_illegal_v1f64:
+; CHECK-SME: // %bb.0: // %entry
+; CHECK-SME-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SME-NEXT: addvl sp, sp, #-1
+; CHECK-SME-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-SME-NEXT: .cfi_offset w29, -16
+; CHECK-SME-NEXT: ptrue p0.d, vl1
+; CHECK-SME-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-SME-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-SME-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-SME-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SME-NEXT: str z0, [sp]
+; CHECK-SME-NEXT: ldr d0, [sp]
+; CHECK-SME-NEXT: addvl sp, sp, #1
+; CHECK-SME-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-SME-NEXT: ret
+entry:
+ %neg = fneg <1 x double> %c
+ %0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
+ ret <1 x double> %0
+}
+
+define <3 x float> @fmsub_flipped_illegal_v3f32(<3 x float> %c, <3 x float> %a, <3 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_illegal_v3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <3 x float> %c
+ %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %a, <3 x float> %b, <3 x float> %neg)
+ ret <3 x float> %0
+}
+
+define <vscale x 7 x half> @fnmsub_illegal_nxv7f16(<vscale x 7 x half> %a, <vscale x 7 x half> %b, <vscale x 7 x half> %c) {
+; CHECK-LABEL: fnmsub_illegal_nxv7f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 7 x half> %a
+ %neg1 = fneg <vscale x 7 x half> %c
+ %0 = tail call <vscale x 7 x half> @llvm.fmuladd(<vscale x 7 x half> %neg, <vscale x 7 x half> %b, <vscale x 7 x half> %neg1)
+ ret <vscale x 7 x half> %0
+}
+
+define <3 x float> @fnmsub_illegal_v3f32(<3 x float> %a, <3 x float> %b, <3 x float> %c) {
+; CHECK-LABEL: fnmsub_illegal_v3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <3 x float> %a
+ %neg1 = fneg <3 x float> %c
+ %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %neg, <3 x float> %b, <3 x float> %neg1)
+ ret <3 x float> %0
+}
+
+define <7 x half> @fnmsub_flipped_illegal_v7f16(<7 x half> %c, <7 x half> %a, <7 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_illegal_v7f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <7 x half> %a
+ %neg1 = fneg <7 x half> %c
+ %0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
+ ret <7 x half> %0
+}
From ae2ba4054b463f9b0eb20be028081ea093a99092 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 24 Nov 2025 14:47:42 +0000
Subject: [PATCH 3/6] Use existing TableGen patterns
---
.../Target/AArch64/AArch64ISelLowering.cpp | 41 ++++++++++++-------
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 10 ++---
llvm/test/CodeGen/AArch64/sve-fmsub.ll | 40 ++++++++++++++++++
3 files changed, 70 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 721f9d2cdb367..362806e8b7b99 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20704,6 +20704,7 @@ static SDValue performFMACombine(SDNode *N,
EVT VT = N->getValueType(0);
SDLoc DL(N);
+ // Convert FMA/FNEG nodes to SVE to enable the following patterns:
// fma(a, b, neg(c)) -> fnmls(a, b, c)
// fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
// fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
@@ -20712,27 +20713,37 @@ static SDValue performFMACombine(SDNode *N,
OpC.getOpcode() != ISD::FNEG) {
return SDValue();
}
- unsigned int Opcode;
+
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+ EVT ContainerVT =
+ VT.isFixedLengthVector() ? getContainerForFixedLengthVector(DAG, VT) : VT;
+ OpC = VT.isFixedLengthVector()
+ ? convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0))
+ : OpC->getOperand(0);
+ OpC = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpC,
+ DAG.getUNDEF(ContainerVT));
+
+ if (OpB.getOpcode() == ISD::FNEG) {
+ std::swap(OpA, OpB);
+ }
+
if (OpA.getOpcode() == ISD::FNEG) {
- OpA = OpA.getOperand(0);
- Opcode = AArch64ISD::FNMLA_PRED;
- } else if (OpB.getOpcode() == ISD::FNEG) {
- OpB = OpB.getOperand(0);
- Opcode = AArch64ISD::FNMLA_PRED;
- } else {
- Opcode = AArch64ISD::FNMLS_PRED;
+ OpA = VT.isFixedLengthVector()
+ ? convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0))
+ : OpA->getOperand(0);
+ OpA = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpA,
+ DAG.getUNDEF(ContainerVT));
+ } else if (VT.isFixedLengthVector()) {
+ OpA = convertToScalableVector(DAG, ContainerVT, OpA);
}
- OpC = OpC.getOperand(0);
- auto Pg = getPredicateForVector(DAG, DL, VT);
+
if (VT.isFixedLengthVector()) {
- EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
- OpA = convertToScalableVector(DAG, ContainerVT, OpA);
OpB = convertToScalableVector(DAG, ContainerVT, OpB);
- OpC = convertToScalableVector(DAG, ContainerVT, OpC);
- auto ScalableRes = DAG.getNode(Opcode, DL, ContainerVT, Pg, OpA, OpB, OpC);
+ SDValue ScalableRes =
+ DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
return convertFromScalableVector(DAG, VT, ScalableRes);
}
- return DAG.getNode(Opcode, DL, VT, Pg, OpA, OpB, OpC);
+ return DAG.getNode(AArch64ISD::FMA_PRED, DL, VT, Pg, OpA, OpB, OpC);
}
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 3722dcbb13d40..e99b3f8ff07e0 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -240,8 +240,6 @@ def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
-def AArch64fnmla_p_node : SDNode<"AArch64ISD::FNMLA_PRED", SDT_AArch64FMA>;
-def AArch64fnmls_p_node : SDNode<"AArch64ISD::FNMLS_PRED", SDT_AArch64FMA>;
def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
(AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
@@ -462,13 +460,13 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
- [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
- (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+ [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+ (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
- [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
- (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm)]>;
+ [(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
+ (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
(AArch64fsub_p node:$pg, node:$op2, node:$op1)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 29dbb87f1b875..27827720b2cdb 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -176,6 +176,46 @@ entry:
ret <vscale x 8 x half> %0
}
+
+define <vscale x 2 x double> @fnmsub_negated_b_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fnmsub_negated_b_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 2 x double> %b
+ %neg1 = fneg <vscale x 2 x double> %c
+ %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %neg, <vscale x 2 x double> %neg1)
+ ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fnmsub_negated_b_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fnmsub_negated_b_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 4 x float> %b
+ %neg1 = fneg <vscale x 4 x float> %c
+ %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %a, <vscale x 4 x float> %neg, <vscale x 4 x float> %neg1)
+ ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fnmsub_negated_b_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fnmsub_negated_b_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 8 x half> %b
+ %neg1 = fneg <vscale x 8 x half> %c
+ %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %a, <vscale x 8 x half> %neg, <vscale x 8 x half> %neg1)
+ ret <vscale x 8 x half> %0
+}
+
define <2 x double> @fnmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; CHECK-LABEL: fnmsub_v2f64:
; CHECK: // %bb.0: // %entry
From 9892c8652ae6e1c2e899e610fec0ce8a0710d6a8 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 24 Nov 2025 16:59:54 +0000
Subject: [PATCH 4/6] Remove SVE handling and add pattern to compensate
---
.../Target/AArch64/AArch64ISelLowering.cpp | 49 +++++++------------
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 1 +
2 files changed, 20 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 362806e8b7b99..5b1a6c47c94dd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20708,42 +20708,31 @@ static SDValue performFMACombine(SDNode *N,
// fma(a, b, neg(c)) -> fnmls(a, b, c)
// fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
// fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
- if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+ if (!VT.isFixedLengthVector() ||
+ !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
!Subtarget->isSVEorStreamingSVEAvailable() ||
OpC.getOpcode() != ISD::FNEG) {
return SDValue();
}
SDValue Pg = getPredicateForVector(DAG, DL, VT);
- EVT ContainerVT =
- VT.isFixedLengthVector() ? getContainerForFixedLengthVector(DAG, VT) : VT;
- OpC = VT.isFixedLengthVector()
- ? convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0))
- : OpC->getOperand(0);
- OpC = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpC,
- DAG.getUNDEF(ContainerVT));
-
- if (OpB.getOpcode() == ISD::FNEG) {
- std::swap(OpA, OpB);
- }
-
- if (OpA.getOpcode() == ISD::FNEG) {
- OpA = VT.isFixedLengthVector()
- ? convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0))
- : OpA->getOperand(0);
- OpA = DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg, OpA,
- DAG.getUNDEF(ContainerVT));
- } else if (VT.isFixedLengthVector()) {
- OpA = convertToScalableVector(DAG, ContainerVT, OpA);
- }
-
- if (VT.isFixedLengthVector()) {
- OpB = convertToScalableVector(DAG, ContainerVT, OpB);
- SDValue ScalableRes =
- DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
- return convertFromScalableVector(DAG, VT, ScalableRes);
- }
- return DAG.getNode(AArch64ISD::FMA_PRED, DL, VT, Pg, OpA, OpB, OpC);
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+ OpC =
+ DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
+ convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0)),
+ DAG.getUNDEF(ContainerVT));
+
+ OpA = OpA.getOpcode() == ISD::FNEG
+ ? DAG.getNode(
+ AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
+ convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0)),
+ DAG.getUNDEF(ContainerVT))
+ : convertToScalableVector(DAG, ContainerVT, OpA);
+
+ OpB = convertToScalableVector(DAG, ContainerVT, OpB);
+ SDValue ScalableRes =
+ DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
+ return convertFromScalableVector(DAG, VT, ScalableRes);
}
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index e99b3f8ff07e0..e46afa65d3c05 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -462,6 +462,7 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
[(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
+ (AArch64fma_p node:$pg, node:$zn, (AArch64fneg_mt node:$pg, node:$zm, (undef)), (AArch64fneg_mt node:$pg, node:$za, (undef))),
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
From 2879d78055896a3889a1c33aa50963bb4eede793 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Tue, 25 Nov 2025 16:28:09 +0000
Subject: [PATCH 5/6] Use custom lowering rather than DAG Combiner
---
.../Target/AArch64/AArch64ISelLowering.cpp | 82 ++++++++----------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 1 +
.../complex-deinterleaving-symmetric-fixed.ll | 22 +++--
.../AArch64/sve-fixed-length-fp-arith.ll | 24 ++++--
llvm/test/CodeGen/AArch64/sve-fmsub.ll | 84 +++++++++++++------
5 files changed, 127 insertions(+), 86 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5b1a6c47c94dd..808a974eda694 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1170,8 +1170,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
setTargetDAGCombine(ISD::CTPOP);
- setTargetDAGCombine(ISD::FMA);
-
// In case of strict alignment, avoid an excessive number of byte wide stores.
MaxStoresPerMemsetOptSize = 8;
MaxStoresPerMemset =
@@ -1526,6 +1524,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
+
+ for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) {
+ setOperationAction(ISD::FMA, VT, Custom);
+ }
}
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -7732,6 +7734,37 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
return FCVTNT(VT, BottomBF16, Pg, TopF32);
}
+SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
+ SDValue OpA = Op->getOperand(0);
+ SDValue OpB = Op->getOperand(1);
+ SDValue OpC = Op->getOperand(2);
+ EVT VT = Op.getValueType();
+ SDLoc DL(Op);
+
+ // Bail early if we're definitely not looking to merge FNEGs into the FMA.
+ if (!VT.isFixedLengthVector() || OpC.getOpcode() != ISD::FNEG) {
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+ }
+
+ // Convert FMA/FNEG nodes to SVE to enable the following patterns:
+ // fma(a, b, neg(c)) -> fnmls(a, b, c)
+ // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
+ // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+
+ for (SDValue *Op : {&OpA, &OpB, &OpC}) {
+ // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
+ *Op = Op->getOpcode() == ISD::FNEG
+ ? LowerToPredicatedOp(*Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+ ->getOperand(0)
+ : convertToScalableVector(DAG, ContainerVT, *Op);
+ }
+ SDValue ScalableRes =
+ DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+}
+
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -7808,7 +7841,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::FMUL:
return LowerFMUL(Op, DAG);
case ISD::FMA:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+ return LowerFMA(Op, DAG);
case ISD::FDIV:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
case ISD::FNEG:
@@ -20694,47 +20727,6 @@ static SDValue performFADDCombine(SDNode *N,
return SDValue();
}
-static SDValue performFMACombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- const AArch64Subtarget *Subtarget) {
- SelectionDAG &DAG = DCI.DAG;
- SDValue OpA = N->getOperand(0);
- SDValue OpB = N->getOperand(1);
- SDValue OpC = N->getOperand(2);
- EVT VT = N->getValueType(0);
- SDLoc DL(N);
-
- // Convert FMA/FNEG nodes to SVE to enable the following patterns:
- // fma(a, b, neg(c)) -> fnmls(a, b, c)
- // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
- // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
- if (!VT.isFixedLengthVector() ||
- !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
- !Subtarget->isSVEorStreamingSVEAvailable() ||
- OpC.getOpcode() != ISD::FNEG) {
- return SDValue();
- }
-
- SDValue Pg = getPredicateForVector(DAG, DL, VT);
- EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
- OpC =
- DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
- convertToScalableVector(DAG, ContainerVT, OpC.getOperand(0)),
- DAG.getUNDEF(ContainerVT));
-
- OpA = OpA.getOpcode() == ISD::FNEG
- ? DAG.getNode(
- AArch64ISD::FNEG_MERGE_PASSTHRU, DL, ContainerVT, Pg,
- convertToScalableVector(DAG, ContainerVT, OpA.getOperand(0)),
- DAG.getUNDEF(ContainerVT))
- : convertToScalableVector(DAG, ContainerVT, OpA);
-
- OpB = convertToScalableVector(DAG, ContainerVT, OpB);
- SDValue ScalableRes =
- DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
- return convertFromScalableVector(DAG, VT, ScalableRes);
-}
-
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
switch (Opcode) {
case ISD::STRICT_FADD:
@@ -28266,8 +28258,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performANDCombine(N, DCI);
case ISD::FADD:
return performFADDCombine(N, DCI);
- case ISD::FMA:
- return performFMACombine(N, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
return performIntrinsicCombine(N, DCI, Subtarget);
case ISD::ANY_EXTEND:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ca08eb40c956a..5f8b8adbede43 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -615,6 +615,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFMUL(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFMA(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
index d05b9c6d7662a..49d33b9709e04 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
@@ -7,13 +7,18 @@ define <4 x double> @simple_symmetric_muladd2(<4 x double> %a, <4 x double> %b)
; CHECK-LABEL: simple_symmetric_muladd2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov x8, #-7378697629483820647 // =0x9999999999999999
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q3 killed $q3 def $z3
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: movk x8, #39322
; CHECK-NEXT: movk x8, #16393, lsl #48
; CHECK-NEXT: dup v4.2d, x8
-; CHECK-NEXT: fmla v2.2d, v4.2d, v0.2d
-; CHECK-NEXT: fmla v3.2d, v4.2d, v1.2d
-; CHECK-NEXT: mov v0.16b, v2.16b
-; CHECK-NEXT: mov v1.16b, v3.16b
+; CHECK-NEXT: fmad z0.d, p0/m, z4.d, z2.d
+; CHECK-NEXT: fmad z1.d, p0/m, z4.d, z3.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: // kill: def $q1 killed $q1 killed $z1
; CHECK-NEXT: ret
entry:
%ext00 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
@@ -43,10 +48,11 @@ define <8 x double> @simple_symmetric_muladd4(<8 x double> %a, <8 x double> %b)
; CHECK-NEXT: zip1 v17.2d, v5.2d, v7.2d
; CHECK-NEXT: zip2 v5.2d, v5.2d, v7.2d
; CHECK-NEXT: dup v6.2d, x8
-; CHECK-NEXT: fmla v3.2d, v6.2d, v16.2d
-; CHECK-NEXT: fmla v4.2d, v6.2d, v0.2d
-; CHECK-NEXT: fmla v17.2d, v6.2d, v2.2d
-; CHECK-NEXT: fmla v5.2d, v6.2d, v1.2d
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: fmla z3.d, p0/m, z16.d, z6.d
+; CHECK-NEXT: fmla z4.d, p0/m, z0.d, z6.d
+; CHECK-NEXT: fmla z17.d, p0/m, z2.d, z6.d
+; CHECK-NEXT: fmla z5.d, p0/m, z1.d, z6.d
; CHECK-NEXT: zip1 v0.2d, v3.2d, v4.2d
; CHECK-NEXT: zip2 v2.2d, v3.2d, v4.2d
; CHECK-NEXT: zip1 v1.2d, v17.2d, v5.2d
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
index 2dda03e5c6dab..c6e87f9c3abcc 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
@@ -620,8 +620,12 @@ define <4 x half> @fma_v4f16(<4 x half> %op1, <4 x half> %op2, <4 x half> %op3)
define <8 x half> @fma_v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v8f16:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmla v2.8h, v1.8h, v0.8h
-; CHECK-NEXT: mov v0.16b, v2.16b
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = call <8 x half> @llvm.fma.v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3)
ret <8 x half> %res
@@ -730,8 +734,12 @@ define <2 x float> @fma_v2f32(<2 x float> %op1, <2 x float> %op2, <2 x float> %o
define <4 x float> @fma_v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v4f32:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmla v2.4s, v1.4s, v0.4s
-; CHECK-NEXT: mov v0.16b, v2.16b
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = call <4 x float> @llvm.fma.v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3)
ret <4 x float> %res
@@ -839,8 +847,12 @@ define <1 x double> @fma_v1f64(<1 x double> %op1, <1 x double> %op2, <1 x double
define <2 x double> @fma_v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v2f64:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmla v2.2d, v1.2d, v0.2d
-; CHECK-NEXT: mov v0.16b, v2.16b
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = call <2 x double> @llvm.fma.v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3)
ret <2 x double> %res
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 27827720b2cdb..9b0205f658056 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -267,6 +267,57 @@ entry:
ret <8 x half> %0
}
+define <2 x double> @fnmsub_negated_b_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fnmsub_negated_b_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %b
+ %neg1 = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %neg, <2 x double> %neg1)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_negated_b_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fnmsub_negated_b_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %b
+ %neg1 = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %neg, <4 x float> %neg1)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_negated_b_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fnmsub_negated_b_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %b
+ %neg1 = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %neg, <8 x half> %neg1)
+ ret <8 x half> %0
+}
+
define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: fnmsub_flipped_v2f64:
; CHECK: // %bb.0: // %entry
@@ -333,32 +384,10 @@ entry:
}
define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
-; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
-; CHECK-SVE: // %bb.0: // %entry
-; CHECK-SVE-NEXT: ptrue p0.d, vl1
-; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 def $z0
-; CHECK-SVE-NEXT: // kill: def $d2 killed $d2 def $z2
-; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 def $z1
-; CHECK-SVE-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
-; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $z0
-; CHECK-SVE-NEXT: ret
-;
-; CHECK-SME-LABEL: fmsub_illegal_v1f64:
-; CHECK-SME: // %bb.0: // %entry
-; CHECK-SME-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-SME-NEXT: addvl sp, sp, #-1
-; CHECK-SME-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
-; CHECK-SME-NEXT: .cfi_offset w29, -16
-; CHECK-SME-NEXT: ptrue p0.d, vl1
-; CHECK-SME-NEXT: // kill: def $d0 killed $d0 def $z0
-; CHECK-SME-NEXT: // kill: def $d2 killed $d2 def $z2
-; CHECK-SME-NEXT: // kill: def $d1 killed $d1 def $z1
-; CHECK-SME-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
-; CHECK-SME-NEXT: str z0, [sp]
-; CHECK-SME-NEXT: ldr d0, [sp]
-; CHECK-SME-NEXT: addvl sp, sp, #1
-; CHECK-SME-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
-; CHECK-SME-NEXT: ret
+; CHECK-LABEL: fmsub_illegal_v1f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fnmsub d0, d0, d1, d2
+; CHECK-NEXT: ret
entry:
%neg = fneg <1 x double> %c
%0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
@@ -427,3 +456,6 @@ entry:
%0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
ret <7 x half> %0
}
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK-SME: {{.*}}
+; CHECK-SVE: {{.*}}
From dd43ed98426056ec8f9d2de5710596c7f2f1f541 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 1 Dec 2025 14:58:04 +0000
Subject: [PATCH 6/6] Address feedback
---
.../Target/AArch64/AArch64ISelLowering.cpp | 27 ++++++++++++-------
.../complex-deinterleaving-symmetric-fixed.ll | 22 ++++++---------
.../AArch64/sve-fixed-length-fp-arith.ll | 24 +++++------------
3 files changed, 31 insertions(+), 42 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 808a974eda694..38f834c1b52ac 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1525,9 +1525,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
- for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64}) {
+ for (auto VT : {MVT::v8f16, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::FMA, VT, Custom);
- }
}
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -7743,7 +7742,10 @@ SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
// Bail early if we're definitely not looking to merge FNEGs into the FMA.
if (!VT.isFixedLengthVector() || OpC.getOpcode() != ISD::FNEG) {
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+ if (VT.isScalableVector() || VT.getScalarType() == MVT::bf16 ||
+ useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
+ return Op; // Fallback to NEON lowering.
}
// Convert FMA/FNEG nodes to SVE to enable the following patterns:
@@ -7753,13 +7755,18 @@ SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {
SDValue Pg = getPredicateForVector(DAG, DL, VT);
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
- for (SDValue *Op : {&OpA, &OpB, &OpC}) {
- // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
- *Op = Op->getOpcode() == ISD::FNEG
- ? LowerToPredicatedOp(*Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
- ->getOperand(0)
- : convertToScalableVector(DAG, ContainerVT, *Op);
- }
+ // Reuse `LowerToPredicatedOp` but drop the subsequent `extract_subvector`
+ OpA = OpA.getOpcode() == ISD::FNEG
+ ? LowerToPredicatedOp(OpA, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+ ->getOperand(0)
+ : convertToScalableVector(DAG, ContainerVT, OpA);
+ OpB = OpB.getOpcode() == ISD::FNEG
+ ? LowerToPredicatedOp(OpB, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+ ->getOperand(0)
+ : convertToScalableVector(DAG, ContainerVT, OpB);
+ OpC = LowerToPredicatedOp(OpC, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU)
+ ->getOperand(0);
+
SDValue ScalableRes =
DAG.getNode(AArch64ISD::FMA_PRED, DL, ContainerVT, Pg, OpA, OpB, OpC);
return convertFromScalableVector(DAG, VT, ScalableRes);
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
index 49d33b9709e04..d05b9c6d7662a 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll
@@ -7,18 +7,13 @@ define <4 x double> @simple_symmetric_muladd2(<4 x double> %a, <4 x double> %b)
; CHECK-LABEL: simple_symmetric_muladd2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov x8, #-7378697629483820647 // =0x9999999999999999
-; CHECK-NEXT: ptrue p0.d, vl2
-; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT: // kill: def $q3 killed $q3 def $z3
-; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: movk x8, #39322
; CHECK-NEXT: movk x8, #16393, lsl #48
; CHECK-NEXT: dup v4.2d, x8
-; CHECK-NEXT: fmad z0.d, p0/m, z4.d, z2.d
-; CHECK-NEXT: fmad z1.d, p0/m, z4.d, z3.d
-; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
-; CHECK-NEXT: // kill: def $q1 killed $q1 killed $z1
+; CHECK-NEXT: fmla v2.2d, v4.2d, v0.2d
+; CHECK-NEXT: fmla v3.2d, v4.2d, v1.2d
+; CHECK-NEXT: mov v0.16b, v2.16b
+; CHECK-NEXT: mov v1.16b, v3.16b
; CHECK-NEXT: ret
entry:
%ext00 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
@@ -48,11 +43,10 @@ define <8 x double> @simple_symmetric_muladd4(<8 x double> %a, <8 x double> %b)
; CHECK-NEXT: zip1 v17.2d, v5.2d, v7.2d
; CHECK-NEXT: zip2 v5.2d, v5.2d, v7.2d
; CHECK-NEXT: dup v6.2d, x8
-; CHECK-NEXT: ptrue p0.d, vl2
-; CHECK-NEXT: fmla z3.d, p0/m, z16.d, z6.d
-; CHECK-NEXT: fmla z4.d, p0/m, z0.d, z6.d
-; CHECK-NEXT: fmla z17.d, p0/m, z2.d, z6.d
-; CHECK-NEXT: fmla z5.d, p0/m, z1.d, z6.d
+; CHECK-NEXT: fmla v3.2d, v6.2d, v16.2d
+; CHECK-NEXT: fmla v4.2d, v6.2d, v0.2d
+; CHECK-NEXT: fmla v17.2d, v6.2d, v2.2d
+; CHECK-NEXT: fmla v5.2d, v6.2d, v1.2d
; CHECK-NEXT: zip1 v0.2d, v3.2d, v4.2d
; CHECK-NEXT: zip2 v2.2d, v3.2d, v4.2d
; CHECK-NEXT: zip1 v1.2d, v17.2d, v5.2d
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
index c6e87f9c3abcc..2dda03e5c6dab 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-arith.ll
@@ -620,12 +620,8 @@ define <4 x half> @fma_v4f16(<4 x half> %op1, <4 x half> %op2, <4 x half> %op3)
define <8 x half> @fma_v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v8f16:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p0.h, vl8
-; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
-; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h
-; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: fmla v2.8h, v1.8h, v0.8h
+; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%res = call <8 x half> @llvm.fma.v8f16(<8 x half> %op1, <8 x half> %op2, <8 x half> %op3)
ret <8 x half> %res
@@ -734,12 +730,8 @@ define <2 x float> @fma_v2f32(<2 x float> %op1, <2 x float> %op2, <2 x float> %o
define <4 x float> @fma_v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v4f32:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p0.s, vl4
-; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
-; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT: fmad z0.s, p0/m, z1.s, z2.s
-; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: fmla v2.4s, v1.4s, v0.4s
+; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%res = call <4 x float> @llvm.fma.v4f32(<4 x float> %op1, <4 x float> %op2, <4 x float> %op3)
ret <4 x float> %res
@@ -847,12 +839,8 @@ define <1 x double> @fma_v1f64(<1 x double> %op1, <1 x double> %op2, <1 x double
define <2 x double> @fma_v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3) vscale_range(2,0) #0 {
; CHECK-LABEL: fma_v2f64:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p0.d, vl2
-; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
-; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
-; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT: fmad z0.d, p0/m, z1.d, z2.d
-; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: fmla v2.2d, v1.2d, v0.2d
+; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
%res = call <2 x double> @llvm.fma.v2f64(<2 x double> %op1, <2 x double> %op2, <2 x double> %op3)
ret <2 x double> %res