[llvm-branch-commits] [llvm] f9e9adf - [CodeGen] Fix incorrect pattern FMLA_* pseudo instructions

Tobias Hieta via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 10 00:09:34 PDT 2023


Author: Igor Kirillov
Date: 2023-08-10T09:04:50+02:00
New Revision: f9e9adf9b90db8a646fc67154a33e0edf0b37e10

URL: https://github.com/llvm/llvm-project/commit/f9e9adf9b90db8a646fc67154a33e0edf0b37e10
DIFF: https://github.com/llvm/llvm-project/commit/f9e9adf9b90db8a646fc67154a33e0edf0b37e10.diff

LOG: [CodeGen] Fix incorrect pattern FMLA_* pseudo instructions

* Remove the incorrect patterns from AArch64fmla_p/AArch64fmls_p
* Add correct patterns to AArch64fmla_m1/AArch64fmls_m1
* Refactor fma_patfrags for the sake of PatFrags

Fixes https://github.com/llvm/llvm-project/issues/64419

Differential Revision: https://reviews.llvm.org/D157095

(cherry picked from commit 84d444f90900d1b9d6c08be61f8d62090df28042)

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/lib/Target/AArch64/SVEInstrFormats.td
    llvm/test/CodeGen/AArch64/sve-fp-combine.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index ad404e8dab2ad4..0710c654a95df6 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -204,10 +204,18 @@ def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
 def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
 def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
 
+def AArch64fadd_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+                                     (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
+  return N->getFlags().hasAllowContract();
+}]>;
 def AArch64fadd_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                 (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
   return N->getFlags().hasNoSignedZeros();
 }]>;
+def AArch64fsub_p_contract : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+                                     (AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
+  return N->getFlags().hasAllowContract();
+}]>;
 def AArch64fsub_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
                                 (AArch64fsub_p node:$op1, node:$op2, node:$op3), [{
   return N->getFlags().hasNoSignedZeros();
@@ -363,14 +371,12 @@ def AArch64fabd_p : PatFrags<(ops node:$pg, node:$op1, node:$op2),
                               (AArch64fabs_mt node:$pg, (AArch64fsub_p node:$pg, node:$op1, node:$op2), undef)]>;
 
 def AArch64fmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
-                             [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za),
-                              (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), node:$zn, node:$zm, node:$za), node:$za)]>;
+                             [(AArch64fma_p node:$pg, node:$zn, node:$zm, node:$za)]>;
 
 def AArch64fmls_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                              [(int_aarch64_sve_fmls_u node:$pg, node:$za, node:$zn, node:$zm),
                               (AArch64fma_p node:$pg, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$zm, node:$za),
-                              (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za),
-                              (vselect node:$pg, (AArch64fma_p (AArch64ptrue 31), (AArch64fneg_mt (AArch64ptrue 31), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
+                              (AArch64fma_p node:$pg, node:$zm, (AArch64fneg_mt node:$pg, node:$zn, (undef)), node:$za)]>;
 
 def AArch64fnmla_p : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
                               [(int_aarch64_sve_fnmla_u node:$pg, node:$za, node:$zn, node:$zm),
@@ -423,18 +429,15 @@ def AArch64eor3 : PatFrags<(ops node:$op1, node:$op2, node:$op3),
                            [(int_aarch64_sve_eor3 node:$op1, node:$op2, node:$op3),
                             (xor node:$op1, (xor node:$op2, node:$op3))]>;
 
-class fma_patfrags<SDPatternOperator intrinsic, SDPatternOperator add>
-    : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
-               [(intrinsic node:$pred, node:$op1, node:$op2, node:$op3),
-                (vselect node:$pred, (add (SVEAllActive), node:$op1, (AArch64fmul_p_oneuse (SVEAllActive), node:$op2, node:$op3)), node:$op1)],
-[{
-  if (N->getOpcode() == ISD::VSELECT)
-    return N->getOperand(1)->getFlags().hasAllowContract();
-  return true;  // it's the intrinsic
-}]>;
+def AArch64fmla_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+                              [(int_aarch64_sve_fmla node:$pg, node:$za, node:$zn, node:$zm),
+                               (vselect node:$pg, (AArch64fadd_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+                               (vselect node:$pg, (AArch64fma_p (SVEAllActive), node:$zn, node:$zm, node:$za), node:$za)]>;
 
-def AArch64fmla_m1 : fma_patfrags<int_aarch64_sve_fmla, AArch64fadd_p>;
-def AArch64fmls_m1 : fma_patfrags<int_aarch64_sve_fmls, AArch64fsub_p>;
+def AArch64fmls_m1 : PatFrags<(ops node:$pg, node:$za, node:$zn, node:$zm),
+                              [(int_aarch64_sve_fmls node:$pg, node:$za, node:$zn, node:$zm),
+                               (vselect node:$pg, (AArch64fsub_p_contract (SVEAllActive), node:$za, (AArch64fmul_p_oneuse (SVEAllActive), node:$zn, node:$zm)), node:$za),
+                               (vselect node:$pg, (AArch64fma_p (SVEAllActive), (AArch64fneg_mt (SVEAllActive), node:$zn, (undef)), node:$zm, node:$za), node:$za)]>;
 
 def AArch64add_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_add, add>;
 def AArch64sub_m1 : VSelectUnpredOrPassthruPatFrags<int_aarch64_sve_sub, sub>;

diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 118862b8c317cb..c4c0dca114ce7c 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2317,7 +2317,10 @@ multiclass sve_fp_3op_p_zds_a<bits<2> opc, string asm, string Ps,
            SVEPseudo2Instr<Ps # _D, 1>, SVEInstr2Rev<NAME # _D, revname # _D, isReverseInstr>;
 
   def : SVE_4_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_4_Op_Pat<nxv4f16, op, nxv4i1, nxv4f16, nxv4f16, nxv4f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_4_Op_Pat<nxv2f16, op, nxv2i1, nxv2f16, nxv2f16, nxv2f16, !cast<Instruction>(NAME # _H)>;
   def : SVE_4_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
+  def : SVE_4_Op_Pat<nxv2f32, op, nxv2i1, nxv2f32, nxv2f32, nxv2f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_4_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
 

diff --git a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
index 14471584bf2863..e53f76f6512127 100644
--- a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
@@ -1271,7 +1271,8 @@ define <vscale x 4 x float> @fadd_sel_fmul_no_contract_s(<vscale x 4 x float> %a
 define <vscale x 8 x half> @fma_sel_h_different_arg_order(<vscale x 8 x i1> %pred, <vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc) {
 ; CHECK-LABEL: fma_sel_h_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    fmla z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul.add = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc)
   %masked.mul.add = select <vscale x 8 x i1> %pred, <vscale x 8 x half> %mul.add, <vscale x 8 x half> %acc
@@ -1281,7 +1282,8 @@ define <vscale x 8 x half> @fma_sel_h_different_arg_order(<vscale x 8 x i1> %pre
 define <vscale x 4 x float> @fma_sel_s_different_arg_order(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
 ; CHECK-LABEL: fma_sel_s_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    fmla z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul.add = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc)
   %masked.mul.add = select <vscale x 4 x i1> %pred, <vscale x 4 x float> %mul.add, <vscale x 4 x float> %acc
@@ -1291,7 +1293,8 @@ define <vscale x 4 x float> @fma_sel_s_different_arg_order(<vscale x 4 x i1> %pr
 define <vscale x 2 x double> @fma_sel_d_different_arg_order(<vscale x 2 x i1> %pred, <vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc) {
 ; CHECK-LABEL: fma_sel_d_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    fmla z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %mul.add = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc)
   %masked.mul.add = select <vscale x 2 x i1> %pred, <vscale x 2 x double> %mul.add, <vscale x 2 x double> %acc
@@ -1301,7 +1304,8 @@ define <vscale x 2 x double> @fma_sel_d_different_arg_order(<vscale x 2 x i1> %p
 define <vscale x 8 x half> @fnma_sel_h_different_arg_order(<vscale x 8 x i1> %pred, <vscale x 8 x half> %m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc) {
 ; CHECK-LABEL: fnma_sel_h_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmsb z0.h, p0/m, z1.h, z2.h
+; CHECK-NEXT:    fmls z2.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %neg_m1 = fneg contract <vscale x 8 x half> %m1
   %mul.add = call <vscale x 8 x half> @llvm.fma.nxv8f16(<vscale x 8 x half> %neg_m1, <vscale x 8 x half> %m2, <vscale x 8 x half> %acc)
@@ -1312,7 +1316,8 @@ define <vscale x 8 x half> @fnma_sel_h_different_arg_order(<vscale x 8 x i1> %pr
 define <vscale x 4 x float> @fnma_sel_s_different_arg_order(<vscale x 4 x i1> %pred, <vscale x 4 x float> %m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc) {
 ; CHECK-LABEL: fnma_sel_s_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmsb z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    fmls z2.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %neg_m1 = fneg contract <vscale x 4 x float> %m1
   %mul.add = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> %neg_m1, <vscale x 4 x float> %m2, <vscale x 4 x float> %acc)
@@ -1323,7 +1328,8 @@ define <vscale x 4 x float> @fnma_sel_s_different_arg_order(<vscale x 4 x i1> %p
 define <vscale x 2 x double> @fnma_sel_d_different_arg_order(<vscale x 2 x i1> %pred, <vscale x 2 x double> %m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc) {
 ; CHECK-LABEL: fnma_sel_d_different_arg_order:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmsb z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    fmls z2.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mov z0.d, z2.d
 ; CHECK-NEXT:    ret
   %neg_m1 = fneg contract <vscale x 2 x double> %m1
   %mul.add = call <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %neg_m1, <vscale x 2 x double> %m2, <vscale x 2 x double> %acc)


        


More information about the llvm-branch-commits mailing list