[llvm-branch-commits] [llvm] f9e9adf - [CodeGen] Fix incorrect patterns for FMLA_* pseudo instructions
Tobias Hieta via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Aug 10 00:09:34 PDT 2023
Author: Igor Kirillov
Date: 2023-08-10T09:04:50+02:00
New Revision: f9e9adf9b90db8a646fc67154a33e0edf0b37e10
URL: https://github.com/llvm/llvm-project/commit/f9e9adf9b90db8a646fc67154a33e0edf0b37e10
DIFF: https://github.com/llvm/llvm-project/commit/f9e9adf9b90db8a646fc67154a33e0edf0b37e10.diff
LOG: [CodeGen] Fix incorrect patterns for FMLA_* pseudo instructions
* Remove the incorrect patterns from AArch64fmla_p/AArch64fmls_p
* Add correct patterns to AArch64fmla_m1/AArch64fmls_m1
* Replace the fma_patfrags class with plain PatFrags definitions
Fixes https://github.com/llvm/llvm-project/issues/64419
Differential Revision: https://reviews.llvm.org/D157095
(cherry picked from commit 84d444f90900d1b9d6c08be61f8d62090df28042)
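For context, the affected pattern is a predicated multiply-add written as an unpredicated llvm.fma whose result is selected against the accumulator, so inactive lanes must yield the accumulator value. A minimal IR sketch of that shape (the function name is illustrative; it mirrors the fma_sel_* tests in sve-fp-combine.ll below):

  define <vscale x 4 x float> @masked_fma(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
    %fma = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc)
    %res = select <vscale x 4 x i1> %pred, <vscale x 4 x float> %fma, <vscale x 4 x float> %acc
    ret <vscale x 4 x float> %res
  }
  declare <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)

As the updated CHECK lines below show, codegen now merges into the accumulator register with fmla (plus a mov) instead of using fmad, whose inactive lanes would keep the wrong operand.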
Added:
Modified:
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
llvm/lib/Target/AArch64/SVEInstrFormats.td
llvm/test/CodeGen/AArch64/sve-fp-combine.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index ad404e8dab2ad4..0710c654a95df6 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -204,10 +204,18 @@ def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
+def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+ (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
+ return N->getFlags().hasAllowContract();
+}]>;
def AArch64fadd_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
(AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
return N->getFlags().hasNoSignedZeros();
}]>;
+def AArch64fsub_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+ (AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
+ return N->getFlags().hasAllowContract();
+}]>;
def AArch64fsub_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
(AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
return N->getFlags().hasNoSignedZeros();
@@ -363,14 +371,12 @@ def AArch64fabd_p : PatFrags<(ops node:$pg, node:$op1, node:$op2),
(AArch64fabs_mt node:$pg, (AArch64fsub_p node:$pg, node:$op1, node:$op2), undef)]>;
def AArch64fmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
- [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za),
- (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), node:$zn, node:$zm, node:$za), node:$za)]>;
+ [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za)]>;
def AArch64fmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
[(int_aarch64_sve_fmls_u node:$pg, node:$za, node:$zn, node:$zm),
(AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, node:$za),
- (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za),
- (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), (AArch64fneg_mt (AArch64ptrue 31), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
+ (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za)]>;
def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
[(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
@@ -423,18 +429,15 @@ def AArch64eor3 : PatFrags<(ops node:$op1, node:$op2, node:$op3),
[(int_aarch64_sve_eor3 node:$op1, node:$op2, node:$op3),
(xor node:$op1, (xor node:$op2, node:$op3))]>;
-class fma_patfrags<SDPatternOperator intrinsic, SDPatternOperator add>
- : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
- [(intrinsic node:$pred, node:$op1, node:$op2, node:$op3),
- (vselect node:$pred, (add (SVEAllActive), node:$op1, (AArch64fmul_p_oneuse (SVEAllActive), node:$op2, node:$op3)), node:$op1)],
-[{
- if (N->getOpcode() == ISD::VSELECT)
- return N->getOperand(1)->getFlags().hasAllowContract();
- return true; // it's the intrinsic
-}]>;
+def AArch64fmla_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+ [(int_aarch64_sve_fmla node:$pg, node:$za, node:$zn, node:$zm),
+ (vselect node:$pg, (AArch64fadd_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+ (vselect node:$pg, (AArch64fma_p (SVEAllActive), node:$zn, node:$zm, node:$za), node:$za)]>;
-def AArch64fmla_m1 : fma_patfrags<int_aarch64_sve_fmla, AArch64fadd_p>;
-def AArch64fmls_m1 : fma_patfrags<int_aarch64_sve_fmls, AArch64fsub_p>;
+def AArch64fmls_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+ [(int_aarch64_sve_fmls node:$pg, node:$za, node:$zn, node:$zm),
+ (vselect node:$pg, (AArch64fsub_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+ (vselect node:$pg, (AArch64fma_p (SVEAllActive), (AArch64fneg_mt (SVEAllActive), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
def AArch64add_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_add, add>;
def AArch64sub_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_sub, sub>;
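The new AArch64fadd_p_contract and AArch64fsub_p_contract fragments only match when the node carries the contract fast-math flag, so a separate multiply and add/sub is merged into a predicated FMLA/FMLS only when contraction is permitted, preserving the check that fma_patfrags used to perform. A hypothetical IR sketch of the foldable case (names are illustrative):

  define <vscale x 4 x float> @masked_fmla_contract(<vscale x 4 x i1> %pred, <vscale x 4 x float> %acc, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2) {
    %mul = fmul contract <vscale x 4 x float> %m1, %m2
    %add = fadd contract <vscale x 4 x float> %acc, %mul
    %res = select <vscale x 4 x i1> %pred, <vscale x 4 x float> %add, <vscale x 4 x float> %acc
    ret <vscale x 4 x float> %res
  }

Without the contract flags the vselect alternatives in AArch64fmla_m1/AArch64fmls_m1 do not apply, which is what tests such as fadd_sel_fmul_no_contract_s (referenced in the hunk below) rely on.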
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 118862b8c317cb..c4c0dca114ce7c 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2317,7 +2317,10 @@ multiclass sve_fp_3op_p_zds_a<bits<2> opc, string asm, string Ps,
SVEPseudo2Instr<Ps # _D, 1>, SVEInstr2Rev<NAME # _D, revname # _D, isReverseInstr>;
def : SVE_4_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+ def : SVE_4_Op_Pat<nxv4f16, op, nxv4i1, nxv4f16, nxv4f16, nxv4f16, !cast<Instruction>(NAME # _H)>;
+ def : SVE_4_Op_Pat<nxv2f16, op, nxv2i1, nxv2f16, nxv2f16, nxv2f16, !cast<Instruction>(NAME # _H)>;
def : SVE_4_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
+ def : SVE_4_Op_Pat<nxv2f32, op, nxv2i1, nxv2f32, nxv2f32, nxv2f32, !cast<Instruction>(NAME # _S)>;
def : SVE_4_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
}
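The extra SVE_4_Op_Pat entries cover the unpacked element types (nxv4f16, nxv2f16 and nxv2f32), which presumably previously matched only via the vselect alternatives removed from AArch64fmla_p/AArch64fmls_p above. A hypothetical sketch of such an unpacked case (the function name is illustrative):

  define <vscale x 4 x half> @masked_fma_unpacked(<vscale x 4 x i1> %pred, <vscale x 4 x half> %m1, <vscale x 4 x half> %m2, <vscale x 4 x half> %acc) {
    %fma = call <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half> %m1, <vscale x 4 x half> %m2, <vscale x 4 x half> %acc)
    %res = select <vscale x 4 x i1> %pred, <vscale x 4 x half> %fma, <vscale x 4 x half> %acc
    ret <vscale x 4 x half> %res
  }
  declare <vscale x 4 x half> @llvm.fma.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>, <vscale x 4 x half>)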
diff --git a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
index 14471584bf2863..e53f76f6512127 100644
--- a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
@@ -1271,7 +1271,8 @@ define <vscale x 4 x float> @fadd_sel_fmul_no_contract_s(<vscale x 4 x float> %a
define <vscale x 8 x half> @fma_sel_h_different_arg_order(<vscale x 8 x i1> %pred, <vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc) {
; CHECK-LABEL: fma_sel_h_different_arg_order:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%mul.add = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc)
%masked.mul.add = select <vscale x 8 x i1> %pred, <vscale x 8 x half> %mul.add, <vscale x 8 x half> %acc
@@ -1281,7 +1282,8 @@ define <vscale x 8 x half> @fma_sel_h_different_arg_order(<vscale x 8 x i1> %pre
define <vscale x 4 x float> @fma_sel_s_different_arg_order(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
; CHECK-LABEL: fma_sel_s_different_arg_order:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: fmla z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%mul.add = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc)
%masked.mul.add = select <vscale x 4 x i1> %pred, <vscale x 4 x float> %mul.add, <vscale x 4 x float> %acc
@@ -1291,7 +1293,8 @@ define <vscale x 4 x float> @fma_sel_s_different_arg_order(<vscale x 4 x i1> %pr
define <vscale x 2 x double> @fma_sel_d_different_arg_order(<vscale x 2 x i1> %pred, <vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc) {
; CHECK-LABEL: fma_sel_d_different_arg_order:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: fmla z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%mul.add = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc)
%masked.mul.add = select <vscale x 2 x i1> %pred, <vscale x 2 x double> %mul.add, <vscale x 2 x double> %acc
@@ -1301,7 +1304,8 @@ define <vscale x 2 x double> @fma_sel_d_different_arg_order(<vscale x 2 x i1> %p
define <vscale x 8 x half> @fnma_sel_h_different_arg_order(<vscale x 8 x i1> %pred, <vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc) {
; CHECK-LABEL: fnma_sel_h_different_arg_order:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT: fmls z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%neg_m1 = fneg contract <vscale x 8 x half> %m1
%mul.add = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %neg_m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc)
@@ -1312,7 +1316,8 @@ define <vscale x 8 x half> @fnma_sel_h_different_arg_order(<vscale x 8 x i1> %pr
define <vscale x 4 x float> @fnma_sel_s_different_arg_order(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
; CHECK-LABEL: fnma_sel_s_different_arg_order:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: fmls z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%neg_m1 = fneg contract <vscale x 4 x float> %m1
%mul.add = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %neg_m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc)
@@ -1323,7 +1328,8 @@ define <vscale x 4 x float> @fnma_sel_s_different_arg_order(<vscale x 4 x i1> %p
define <vscale x 2 x double> @fnma_sel_d_different_arg_order(<vscale x 2 x i1> %pred, <vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc) {
; CHECK-LABEL: fnma_sel_d_different_arg_order:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: fmls z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: ret
%neg_m1 = fneg contract <vscale x 2 x double> %m1
%mul.add = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %neg_m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc)