[llvm] [AArch64] Combine vector FNEG+FMA into `FNML[A|S]` (PR #167900)
Damian Heaton via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 05:17:44 PST 2025
https://github.com/dheaton-arm updated https://github.com/llvm/llvm-project/pull/167900
>From 9aa8dd52093712c9f01f28fa29487c00a21dac83 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Thu, 13 Nov 2025 15:42:11 +0000
Subject: [PATCH 1/2] Combine vector FNEG+FMA into `FNML[A|S]`
This allows for FNEG + FMA sequences to be combined into a
single operation, with `FNML[A|S]`, `FNMAD`, or `FNMSB` selected
depending on the operand order.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 50 ++++
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 8 +-
llvm/test/CodeGen/AArch64/sve-fmsub.ll | 276 ++++++++++++++++++
3 files changed, 332 insertions(+), 2 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sve-fmsub.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7b51f453b4974..79625dd766085 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1176,6 +1176,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
setTargetDAGCombine(ISD::CTPOP);
+ setTargetDAGCombine(ISD::FMA);
+
// In case of strict alignment, avoid an excessive number of byte wide stores.
MaxStoresPerMemsetOptSize = 8;
MaxStoresPerMemset =
@@ -20444,6 +20446,52 @@ static SDValue performFADDCombine(SDNode *N,
return SDValue();
}
+static SDValue performFMACombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
+ SelectionDAG &DAG = DCI.DAG;
+ SDValue Op1 = N->getOperand(0);
+ SDValue Op2 = N->getOperand(1);
+ SDValue Op3 = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ // fma(a, b, neg(c)) -> fnmls(a, b, c)
+ // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
+ // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
+ if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+ (Subtarget->hasSVE() || Subtarget->hasSME())) {
+ if (Op3.getOpcode() == ISD::FNEG) {
+ unsigned int Opcode;
+ if (Op1.getOpcode() == ISD::FNEG) {
+ Op1 = Op1.getOperand(0);
+ Opcode = AArch64ISD::FNMLA_PRED;
+ } else if (Op2.getOpcode() == ISD::FNEG) {
+ Op2 = Op2.getOperand(0);
+ Opcode = AArch64ISD::FNMLA_PRED;
+ } else {
+ Opcode = AArch64ISD::FNMLS_PRED;
+ }
+ Op3 = Op3.getOperand(0);
+ auto Pg = getPredicateForVector(DAG, DL, VT);
+ if (VT.isFixedLengthVector()) {
+ assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+ "Expected only legal fixed-width types");
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+ Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
+ Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
+ Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
+ auto ScalableRes =
+ DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+ }
+ return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
+ }
+ }
+
+ return SDValue();
+}
+
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
switch (Opcode) {
case ISD::STRICT_FADD:
@@ -27977,6 +28025,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performANDCombine(N, DCI);
case ISD::FADD:
return performFADDCombine(N, DCI);
+ case ISD::FMA:
+ return performFMACombine(N, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
return performIntrinsicCombine(N, DCI, Subtarget);
case ISD::ANY_EXTEND:
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index c8c21c4822ffe..4640719cda43c 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -240,6 +240,8 @@ def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
+def AArch64fnmla_p_node : SDNode<"AArch64ISD::FNMLA_PRED", SDT_AArch64FMA>;
+def AArch64fnmls_p_node : SDNode<"AArch64ISD::FNMLS_PRED", SDT_AArch64FMA>;
def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
(AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
@@ -460,12 +462,14 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
- [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
+ [(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
+ (int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
- [(int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
+ [(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
+ (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
(AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
new file mode 100644
index 0000000000000..721066038769c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -0,0 +1,276 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+
+define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fmsub_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 2 x double> %c
+ %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %neg)
+ ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fmsub_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 4 x float> %c
+ %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %neg)
+ ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fmsub_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 8 x half> %c
+ %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %neg)
+ ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fmsub_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fmsub_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fmsub_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+ ret <8 x half> %0
+}
+
+
+define <2 x double> @fmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fmsub_flipped_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %neg)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %a, <4 x float> %b, <4 x float> %neg)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fmsub_flipped_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %a, <8 x half> %b, <8 x half> %neg)
+ ret <8 x half> %0
+}
+
+define <vscale x 2 x double> @fnmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
+; CHECK-LABEL: fnmsub_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 2 x double> %a
+ %neg1 = fneg <vscale x 2 x double> %c
+ %0 = tail call <vscale x 2 x double> @llvm.fmuladd(<vscale x 2 x double> %neg, <vscale x 2 x double> %b, <vscale x 2 x double> %neg1)
+ ret <vscale x 2 x double> %0
+}
+
+define <vscale x 4 x float> @fnmsub_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c) {
+; CHECK-LABEL: fnmsub_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 4 x float> %a
+ %neg1 = fneg <vscale x 4 x float> %c
+ %0 = tail call <vscale x 4 x float> @llvm.fmuladd(<vscale x 4 x float> %neg, <vscale x 4 x float> %b, <vscale x 4 x float> %neg1)
+ ret <vscale x 4 x float> %0
+}
+
+define <vscale x 8 x half> @fnmsub_nxv8f16(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) {
+; CHECK-LABEL: fnmsub_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 8 x half> %a
+ %neg1 = fneg <vscale x 8 x half> %c
+ %0 = tail call <vscale x 8 x half> @llvm.fmuladd(<vscale x 8 x half> %neg, <vscale x 8 x half> %b, <vscale x 8 x half> %neg1)
+ ret <vscale x 8 x half> %0
+}
+
+define <2 x double> @fnmsub_v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: fnmsub_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %a
+ %neg1 = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: fnmsub_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %a
+ %neg1 = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
+; CHECK-LABEL: fnmsub_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %a
+ %neg1 = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+ ret <8 x half> %0
+}
+
+define <2 x double> @fnmsub_flipped_v2f64(<2 x double> %c, <2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: fnmsub_flipped_v2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <2 x double> %a
+ %neg1 = fneg <2 x double> %c
+ %0 = tail call <2 x double> @llvm.fmuladd(<2 x double> %neg, <2 x double> %b, <2 x double> %neg1)
+ ret <2 x double> %0
+}
+
+define <4 x float> @fnmsub_flipped_v4f32(<4 x float> %c, <4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: fnmsub_flipped_v4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <4 x float> %a
+ %neg1 = fneg <4 x float> %c
+ %0 = tail call <4 x float> @llvm.fmuladd(<4 x float> %neg, <4 x float> %b, <4 x float> %neg1)
+ ret <4 x float> %0
+}
+
+define <8 x half> @fnmsub_flipped_v8f16(<8 x half> %c, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_v8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <8 x half> %a
+ %neg1 = fneg <8 x half> %c
+ %0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
+ ret <8 x half> %0
+}
>From d4e4360f27eb9536c3f74b1e8802cd600216bba2 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Tue, 18 Nov 2025 12:54:41 +0000
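Subject: [PATCH 2/2] Address review comments
Restructure performFMACombine around a single early return, rename its
operands to OpA/OpB/OpC, gate it on isSVEorStreamingSVEAvailable(), and drop
the fixed-length assert that duplicated the isTypeLegal() check. Remove the
AArch64fma_p/AArch64fneg_mt alternatives from the AArch64fnmla_p and
AArch64fnmls_p PatFrags, and extend sve-fmsub.ll with SVE and streaming-SME
RUN lines plus tests for illegal vector types.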
---
.../Target/AArch64/AArch64ISelLowering.cpp | 61 +++++-----
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 4 +-
llvm/test/CodeGen/AArch64/sve-fmsub.ll | 115 +++++++++++++++++-
3 files changed, 143 insertions(+), 37 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 79625dd766085..08aec2c2cb79b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20450,46 +20450,57 @@ static SDValue performFMACombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  SDValue Op1 = N->getOperand(0);
-  SDValue Op2 = N->getOperand(1);
-  SDValue Op3 = N->getOperand(2);
+  SDValue OpA = N->getOperand(0);
+  SDValue OpB = N->getOperand(1);
+  SDValue OpC = N->getOperand(2);
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
   // fma(a, b, neg(c)) -> fnmls(a, b, c)
   // fma(neg(a), b, neg(c)) -> fnmla(a, b, c)
   // fma(a, neg(b), neg(c)) -> fnmla(a, b, c)
-  if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
-      (Subtarget->hasSVE() || Subtarget->hasSME())) {
-    if (Op3.getOpcode() == ISD::FNEG) {
-      unsigned int Opcode;
-      if (Op1.getOpcode() == ISD::FNEG) {
-        Op1 = Op1.getOperand(0);
-        Opcode = AArch64ISD::FNMLA_PRED;
-      } else if (Op2.getOpcode() == ISD::FNEG) {
-        Op2 = Op2.getOperand(0);
-        Opcode = AArch64ISD::FNMLA_PRED;
-      } else {
-        Opcode = AArch64ISD::FNMLS_PRED;
-      }
-      Op3 = Op3.getOperand(0);
-      auto Pg = getPredicateForVector(DAG, DL, VT);
-      if (VT.isFixedLengthVector()) {
-        assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
-               "Expected only legal fixed-width types");
-        EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
-        Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
-        Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
-        Op3 = convertToScalableVector(DAG, ContainerVT, Op3);
-        auto ScalableRes =
-            DAG.getNode(Opcode, DL, ContainerVT, Pg, Op1, Op2, Op3);
-        return convertFromScalableVector(DAG, VT, ScalableRes);
-      }
-      return DAG.getNode(Opcode, DL, VT, Pg, Op1, Op2, Op3);
-    }
+  // fma(neg(a), neg(b), neg(c)) only strips neg(a), giving
+  // fnmla(a, neg(b), c) == a*b - c, which still matches the original FMA.
+  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+      !Subtarget->isSVEorStreamingSVEAvailable() ||
+      OpC.getOpcode() != ISD::FNEG)
+    return SDValue();
+
+  // SVE provides FNMLA/FNMLS only for half, single and double precision
+  // elements. Other legal FP vector types (e.g. bf16) have no FNML[A|S]
+  // form, so emitting the predicated node for them would fail to select.
+  EVT EltVT = VT.getVectorElementType();
+  if (EltVT != MVT::f16 && EltVT != MVT::f32 && EltVT != MVT::f64)
+    return SDValue();
+
+  unsigned Opcode;
+  if (OpA.getOpcode() == ISD::FNEG) {
+    OpA = OpA.getOperand(0);
+    Opcode = AArch64ISD::FNMLA_PRED;
+  } else if (OpB.getOpcode() == ISD::FNEG) {
+    OpB = OpB.getOperand(0);
+    Opcode = AArch64ISD::FNMLA_PRED;
+  } else {
+    Opcode = AArch64ISD::FNMLS_PRED;
   }
-
-  return SDValue();
+  OpC = OpC.getOperand(0);
+
+  SDValue Pg = getPredicateForVector(DAG, DL, VT);
+  if (VT.isFixedLengthVector()) {
+    // Operate on the fixed-length value inside an SVE container; Pg enables
+    // only the lanes that belong to VT (e.g. "ptrue p0.d, vl2" for v2f64).
+    // NOTE(review): for v1f64 in streaming mode the result currently
+    // round-trips through the stack (see the CHECK-SME output); consider
+    // whether single-element vectors are better left to scalar lowering.
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+    OpA = convertToScalableVector(DAG, ContainerVT, OpA);
+    OpB = convertToScalableVector(DAG, ContainerVT, OpB);
+    OpC = convertToScalableVector(DAG, ContainerVT, OpC);
+    SDValue ScalableRes =
+        DAG.getNode(Opcode, DL, ContainerVT, Pg, OpA, OpB, OpC);
+    return convertFromScalableVector(DAG, VT, ScalableRes);
+  }
+  return DAG.getNode(Opcode, DL, VT, Pg, OpA, OpB, OpC);
 }
 
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 4640719cda43c..2d90123d37e01 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -464,13 +464,11 @@ def AArch64fmlsidx : PatFrags<(ops node:$acc, node:$op1, node:$op2, node:$idx),
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
[(AArch64fnmla_p_node node:$pg, node:$zn, node:$zm, node:$za),
(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
- (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef))),
(AArch64fneg_mt_nsz node:$pg, (AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za), (undef))]>;
def AArch64fnmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
[(AArch64fnmls_p_node node:$pg, node:$zn, node:$zm, node:$za),
- (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm),
- (AArch64fma_p node:$pg, node:$zn, node:$zm, (AArch64fneg_mt node:$pg, node:$za, (undef)))]>;
+ (int_aarch64_sve_fnmls_u node:$pg, node:$za, node:$zn, node:$zm)]>;
def AArch64fsubr_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
(AArch64fsub_p node:$pg, node:$op2, node:$op1)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-fmsub.ll b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
index 721066038769c..29dbb87f1b875 100644
--- a/llvm/test/CodeGen/AArch64/sve-fmsub.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fmsub.ll
@@ -1,5 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
-; RUN: llc -mtriple=aarch64 -mattr=+v9a,+sve2,+crypto,+bf16,+sm4,+i8mm,+sve2-bitperm,+sve2-sha3,+sve2-aes,+sve2-sm4 %s -o - | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
+; RUN: llc -mattr=+sme -force-streaming %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SME
+
+target triple = "aarch64"
define <vscale x 2 x double> @fmsub_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c) {
; CHECK-LABEL: fmsub_nxv2f64:
@@ -274,3 +277,113 @@ entry:
%0 = tail call <8 x half> @llvm.fmuladd(<8 x half> %neg, <8 x half> %b, <8 x half> %neg1)
ret <8 x half> %0
}
+
+; Illegal types
+
+define <vscale x 3 x float> @fmsub_illegal_nxv3f32(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %c) {
+; CHECK-LABEL: fmsub_illegal_nxv3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fnmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 3 x float> %c
+ %0 = tail call <vscale x 3 x float> @llvm.fmuladd(<vscale x 3 x float> %a, <vscale x 3 x float> %b, <vscale x 3 x float> %neg)
+ ret <vscale x 3 x float> %0
+}
+
+define <1 x double> @fmsub_illegal_v1f64(<1 x double> %a, <1 x double> %b, <1 x double> %c) {
+; CHECK-SVE-LABEL: fmsub_illegal_v1f64:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: ptrue p0.d, vl1
+; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-SVE-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-SVE-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SME-LABEL: fmsub_illegal_v1f64:
+; CHECK-SME: // %bb.0: // %entry
+; CHECK-SME-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SME-NEXT: addvl sp, sp, #-1
+; CHECK-SME-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x38, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-SME-NEXT: .cfi_offset w29, -16
+; CHECK-SME-NEXT: ptrue p0.d, vl1
+; CHECK-SME-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-SME-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-SME-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-SME-NEXT: fnmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-SME-NEXT: str z0, [sp]
+; CHECK-SME-NEXT: ldr d0, [sp]
+; CHECK-SME-NEXT: addvl sp, sp, #1
+; CHECK-SME-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-SME-NEXT: ret
+entry:
+ %neg = fneg <1 x double> %c
+ %0 = tail call <1 x double> @llvm.fmuladd(<1 x double> %a, <1 x double> %b, <1 x double> %neg)
+ ret <1 x double> %0
+}
+
+define <3 x float> @fmsub_flipped_illegal_v3f32(<3 x float> %c, <3 x float> %a, <3 x float> %b) {
+; CHECK-LABEL: fmsub_flipped_illegal_v3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmls z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <3 x float> %c
+ %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %a, <3 x float> %b, <3 x float> %neg)
+ ret <3 x float> %0
+}
+
+define <vscale x 7 x half> @fnmsub_illegal_nxv7f16(<vscale x 7 x half> %a, <vscale x 7 x half> %b, <vscale x 7 x half> %c) {
+; CHECK-LABEL: fnmsub_illegal_nxv7f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: fnmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <vscale x 7 x half> %a
+ %neg1 = fneg <vscale x 7 x half> %c
+ %0 = tail call <vscale x 7 x half> @llvm.fmuladd(<vscale x 7 x half> %neg, <vscale x 7 x half> %b, <vscale x 7 x half> %neg1)
+ ret <vscale x 7 x half> %0
+}
+
+define <3 x float> @fnmsub_illegal_v3f32(<3 x float> %a, <3 x float> %b, <3 x float> %c) {
+; CHECK-LABEL: fnmsub_illegal_v3f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <3 x float> %a
+ %neg1 = fneg <3 x float> %c
+ %0 = tail call <3 x float> @llvm.fmuladd(<3 x float> %neg, <3 x float> %b, <3 x float> %neg1)
+ ret <3 x float> %0
+}
+
+define <7 x half> @fnmsub_flipped_illegal_v7f16(<7 x half> %c, <7 x half> %a, <7 x half> %b) {
+; CHECK-LABEL: fnmsub_flipped_illegal_v7f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h, vl8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: fnmla z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+entry:
+ %neg = fneg <7 x half> %a
+ %neg1 = fneg <7 x half> %c
+ %0 = tail call <7 x half> @llvm.fmuladd(<7 x half> %neg, <7 x half> %b, <7 x half> %neg1)
+ ret <7 x half> %0
+}
More information about the llvm-commits mailing list