[llvm] a6dec9f - [AArch64][SVE] Add patterns to select masked FP arith

Cullen Rhodes via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 8 01:44:57 PDT 2022


Author: Cullen Rhodes
Date: 2022-08-08T08:44:13Z
New Revision: a6dec9f5b2840b77dcf3c1731a68428893501ade

URL: https://github.com/llvm/llvm-project/commit/a6dec9f5b2840b77dcf3c1731a68428893501ade
DIFF: https://github.com/llvm/llvm-project/commit/a6dec9f5b2840b77dcf3c1731a68428893501ade.diff

LOG: [AArch64][SVE] Add patterns to select masked FP arith

Add patterns to select predicated instructions when lowering:

  fadd(a, select(mask, b, splat(0)))
  fsub(a, select(mask, b, splat(0)))

The 'fadd' case is unsafe unless the no-signed-zeros ('nsz') fast-math
flag is set: for inactive lanes the unpredicated form computes a + 0.0,
and since

  -0.0 + 0.0 = 0.0

this flips the sign of a negative zero, whereas the predicated
instruction passes 'a' through unchanged. The 'fsub' case is always
safe, as x - 0.0 == x even for x == -0.0.
Alive2: https://alive2.llvm.org/ce/z/wbhJh_
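
For illustration, IR in the style of the new tests below, with 'nsz'
set:

  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %b, <vscale x 4 x float> zeroinitializer
  %fadd = fadd nsz <vscale x 4 x float> %a, %sel

now selects a single predicated instruction:

  fadd z0.s, p0/m, z0.s, z1.s

instead of materialising a zero vector, performing a 'sel' and an
unpredicated 'fadd'.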

Also adds FMA patterns for:

  fadd(a, select(mask, mul(b, c), splat(0))) -> fmla(a, mask, b, c)
  fsub(a, select(mask, mul(b, c), splat(0))) -> fmls(a, mask, b, c)

These patterns require the 'contract' fast-math flag to be set, and the
fadd case additionally requires 'nsz', as above.
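
Illustrating again with IR in the style of the new tests (the positive
fadd tests carry both 'contract' and 'nsz' on the fadd):

  %fmul = fmul <vscale x 4 x float> %b, %c
  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %fmul, <vscale x 4 x float> zeroinitializer
  %fadd = fadd contract nsz <vscale x 4 x float> %a, %sel

becomes:

  fmla z0.s, p0/m, z1.s, z2.s

Without 'contract' only the predicated 'fadd' is formed; without 'nsz'
no combine fires at all (see the new negative tests at the end of
sve-fp-combine.ll).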

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D130564

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-fp-combine.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 9b040860cc3c3..cd424ff07d184 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -199,6 +199,11 @@ def AArch64umax_p : SDNode<"AArch64ISD::UMAX_PRED", SDT_AArch64Arith>;
 def AArch64umin_p : SDNode<"AArch64ISD::UMIN_PRED", SDT_AArch64Arith>;
 def AArch64umulh_p : SDNode<"AArch64ISD::MULHU_PRED", SDT_AArch64Arith>;
 
+def AArch64fadd_p_nsz : PatFrag<(ops node:$op1, node:$op2, node:$op3),
+                                (AArch64fadd_p node:$op1, node:$op2, node:$op3), [{
+  return N->getFlags().hasNoSignedZeros();
+}]>;
+
 def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3,i32>,
   SDTCVecEltisVT<1,i1>, SDTCisSameAs<0,2>
@@ -242,8 +247,16 @@ def AArch64cnot_mt : PatFrags<(ops node:$pg, node:$op, node:$pt), [(int_aarch64_
 def AArch64not_mt  : PatFrags<(ops node:$pg, node:$op, node:$pt), [(int_aarch64_sve_not  node:$pt, node:$pg, node:$op)]>;
 
 def AArch64fmul_m1 : EitherVSelectOrPassthruPatFrags<int_aarch64_sve_fmul, AArch64fmul_p>;
-def AArch64fadd_m1 : EitherVSelectOrPassthruPatFrags<int_aarch64_sve_fadd, AArch64fadd_p>;
-def AArch64fsub_m1 : EitherVSelectOrPassthruPatFrags<int_aarch64_sve_fsub, AArch64fsub_p>;
+def AArch64fadd_m1 : PatFrags<(ops node:$pg, node:$op1, node:$op2), [
+    (int_aarch64_sve_fadd node:$pg, node:$op1, node:$op2),
+    (vselect node:$pg, (AArch64fadd_p (SVEAllActive), node:$op1, node:$op2), node:$op1),
+    (AArch64fadd_p_nsz (SVEAllActive), node:$op1, (vselect node:$pg, node:$op2, (SVEDup0)))
+]>;
+def AArch64fsub_m1 : PatFrags<(ops node:$pg, node:$op1, node:$op2), [
+    (int_aarch64_sve_fsub node:$pg, node:$op1, node:$op2),
+    (vselect node:$pg, (AArch64fsub_p (SVEAllActive), node:$op1, node:$op2), node:$op1),
+    (AArch64fsub_p (SVEAllActive), node:$op1, (vselect node:$pg, node:$op2, (SVEDup0)))
+]>;
 
 def AArch64saba : PatFrags<(ops node:$op1, node:$op2, node:$op3),
                            [(int_aarch64_sve_saba node:$op1, node:$op2, node:$op3),
@@ -308,6 +321,12 @@ def AArch64mul_p_oneuse : PatFrag<(ops node:$pred, node:$src1, node:$src2),
   return N->hasOneUse();
 }]>;
 
+def AArch64fmul_p_oneuse : PatFrag<(ops node:$pred, node:$src1, node:$src2),
+                                   (AArch64fmul_p node:$pred, node:$src1, node:$src2), [{
+  return N->hasOneUse();
+}]>;
+
+
 def AArch64fabd_p : PatFrag<(ops node:$pg, node:$op1, node:$op2),
                             (AArch64fabs_mt node:$pg, (AArch64fsub_p node:$pg, node:$op1, node:$op2), undef)>;
 
@@ -356,6 +375,20 @@ def AArch64mls_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
                               // sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c)
                               (sub node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
 
+class fma_patfrags<SDPatternOperator intrinsic, SDPatternOperator sdnode>
+    : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
+               [(intrinsic node:$pred, node:$op1, node:$op2, node:$op3),
+                (sdnode (SVEAllActive), node:$op1, (vselect node:$pred, (AArch64fmul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))],
+               [{
+  if ((N->getOpcode() != AArch64ISD::FADD_PRED) &&
+      (N->getOpcode() != AArch64ISD::FSUB_PRED))
+    return true;  // it's the intrinsic
+  return N->getFlags().hasAllowContract();
+}]>;
+
+def AArch64fmla_m1 : fma_patfrags<int_aarch64_sve_fmla, AArch64fadd_p_nsz>;
+def AArch64fmls_m1 : fma_patfrags<int_aarch64_sve_fmls, AArch64fsub_p>;
+
 let Predicates = [HasSVE] in {
   defm RDFFR_PPz  : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>;
   def  RDFFRS_PPz : sve_int_rdffr_pred<0b1, "rdffrs">;
@@ -592,8 +625,8 @@ let Predicates = [HasSVEorSME] in {
   defm FCADD_ZPmZ : sve_fp_fcadd<"fcadd", int_aarch64_sve_fcadd>;
   defm FCMLA_ZPmZZ : sve_fp_fcmla<"fcmla", int_aarch64_sve_fcmla>;
 
-  defm FMLA_ZPmZZ  : sve_fp_3op_p_zds_a<0b00, "fmla",  "FMLA_ZPZZZ", int_aarch64_sve_fmla, "FMAD_ZPmZZ">;
-  defm FMLS_ZPmZZ  : sve_fp_3op_p_zds_a<0b01, "fmls",  "FMLS_ZPZZZ", int_aarch64_sve_fmls, "FMSB_ZPmZZ">;
+  defm FMLA_ZPmZZ  : sve_fp_3op_p_zds_a<0b00, "fmla",  "FMLA_ZPZZZ", AArch64fmla_m1, "FMAD_ZPmZZ">;
+  defm FMLS_ZPmZZ  : sve_fp_3op_p_zds_a<0b01, "fmls",  "FMLS_ZPZZZ", AArch64fmls_m1, "FMSB_ZPmZZ">;
   defm FNMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b10, "fnmla", "FNMLA_ZPZZZ", int_aarch64_sve_fnmla, "FNMAD_ZPmZZ">;
   defm FNMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b11, "fnmls", "FNMLS_ZPZZZ", int_aarch64_sve_fnmls, "FNMSB_ZPmZZ">;
 

diff --git a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
index 98cb7b6f93659..9fc066b877143 100644
--- a/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp-combine.ll
@@ -826,9 +826,7 @@ define <vscale x 2 x double> @fnmsb_d(<vscale x 2 x double> %m1, <vscale x 2 x d
 define <vscale x 8 x half> @fadd_h_sel(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: fadd_h_sel:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.h, #0 // =0x0
-; CHECK-NEXT:    sel z1.h, p0, z1.h, z2.h
-; CHECK-NEXT:    fadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    fadd z0.h, p0/m, z0.h, z1.h
 ; CHECK-NEXT:    ret
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %b, <vscale x 8 x half> zeroinitializer
   %fadd = fadd nsz <vscale x 8 x half> %a, %sel
@@ -838,9 +836,7 @@ define <vscale x 8 x half> @fadd_h_sel(<vscale x 8 x half> %a, <vscale x 8 x hal
 define <vscale x 4 x float> @fadd_s_sel(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: fadd_s_sel:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.s, #0 // =0x0
-; CHECK-NEXT:    sel z1.s, p0, z1.s, z2.s
-; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
 ; CHECK-NEXT:    ret
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %b, <vscale x 4 x float> zeroinitializer
   %fadd = fadd nsz <vscale x 4 x float> %a, %sel
@@ -850,9 +846,7 @@ define <vscale x 4 x float> @fadd_s_sel(<vscale x 4 x float> %a, <vscale x 4 x f
 define <vscale x 2 x double> @fadd_d_sel(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: fadd_d_sel:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.d, #0 // =0x0
-; CHECK-NEXT:    sel z1.d, p0, z1.d, z2.d
-; CHECK-NEXT:    fadd z0.d, z0.d, z1.d
+; CHECK-NEXT:    fadd z0.d, p0/m, z0.d, z1.d
 ; CHECK-NEXT:    ret
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %b, <vscale x 2 x double> zeroinitializer
   %fadd = fadd nsz <vscale x 2 x double> %a, %sel
@@ -862,9 +856,7 @@ define <vscale x 2 x double> @fadd_d_sel(<vscale x 2 x double> %a, <vscale x 2 x
 define <vscale x 8 x half> @fsub_h_sel(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: fsub_h_sel:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.h, #0 // =0x0
-; CHECK-NEXT:    sel z1.h, p0, z1.h, z2.h
-; CHECK-NEXT:    fsub z0.h, z0.h, z1.h
+; CHECK-NEXT:    fsub z0.h, p0/m, z0.h, z1.h
 ; CHECK-NEXT:    ret
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %b, <vscale x 8 x half> zeroinitializer
   %fsub = fsub <vscale x 8 x half> %a, %sel
@@ -874,9 +866,7 @@ define <vscale x 8 x half> @fsub_h_sel(<vscale x 8 x half> %a, <vscale x 8 x hal
 define <vscale x 4 x float> @fsub_s_sel(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: fsub_s_sel:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.s, #0 // =0x0
-; CHECK-NEXT:    sel z1.s, p0, z1.s, z2.s
-; CHECK-NEXT:    fsub z0.s, z0.s, z1.s
+; CHECK-NEXT:    fsub z0.s, p0/m, z0.s, z1.s
 ; CHECK-NEXT:    ret
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %b, <vscale x 4 x float> zeroinitializer
   %fsub = fsub <vscale x 4 x float> %a, %sel
@@ -886,9 +876,7 @@ define <vscale x 4 x float> @fsub_s_sel(<vscale x 4 x float> %a, <vscale x 4 x f
 define <vscale x 2 x double> @fsub_d_sel(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: fsub_d_sel:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.d, #0 // =0x0
-; CHECK-NEXT:    sel z1.d, p0, z1.d, z2.d
-; CHECK-NEXT:    fsub z0.d, z0.d, z1.d
+; CHECK-NEXT:    fsub z0.d, p0/m, z0.d, z1.d
 ; CHECK-NEXT:    ret
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %b, <vscale x 2 x double> zeroinitializer
   %fsub = fsub <vscale x 2 x double> %a, %sel
@@ -898,10 +886,7 @@ define <vscale x 2 x double> @fsub_d_sel(<vscale x 2 x double> %a, <vscale x 2 x
 define <vscale x 8 x half> @fadd_sel_fmul_h(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: fadd_sel_fmul_h:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.h, #0 // =0x0
-; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    sel z1.h, p0, z1.h, z3.h
-; CHECK-NEXT:    fadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    fmla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %fmul = fmul <vscale x 8 x half> %b, %c
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %fmul, <vscale x 8 x half> zeroinitializer
@@ -912,10 +897,7 @@ define <vscale x 8 x half> @fadd_sel_fmul_h(<vscale x 8 x half> %a, <vscale x 8
 define <vscale x 4 x float> @fadd_sel_fmul_s(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: fadd_sel_fmul_s:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.s, #0 // =0x0
-; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
-; CHECK-NEXT:    sel z1.s, p0, z1.s, z3.s
-; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    fmla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
   %fmul = fmul <vscale x 4 x float> %b, %c
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %fmul, <vscale x 4 x float> zeroinitializer
@@ -926,10 +908,7 @@ define <vscale x 4 x float> @fadd_sel_fmul_s(<vscale x 4 x float> %a, <vscale x
 define <vscale x 2 x double> @fadd_sel_fmul_d(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: fadd_sel_fmul_d:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.d, #0 // =0x0
-; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
-; CHECK-NEXT:    sel z1.d, p0, z1.d, z3.d
-; CHECK-NEXT:    fadd z0.d, z0.d, z1.d
+; CHECK-NEXT:    fmla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
   %fmul = fmul <vscale x 2 x double> %b, %c
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %fmul, <vscale x 2 x double> zeroinitializer
@@ -940,10 +919,7 @@ define <vscale x 2 x double> @fadd_sel_fmul_d(<vscale x 2 x double> %a, <vscale
 define <vscale x 8 x half> @fsub_sel_fmul_h(<vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: fsub_sel_fmul_h:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.h, #0 // =0x0
-; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    sel z1.h, p0, z1.h, z3.h
-; CHECK-NEXT:    fsub z0.h, z0.h, z1.h
+; CHECK-NEXT:    fmls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %fmul = fmul <vscale x 8 x half> %b, %c
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %fmul, <vscale x 8 x half> zeroinitializer
@@ -954,10 +930,7 @@ define <vscale x 8 x half> @fsub_sel_fmul_h(<vscale x 8 x half> %a, <vscale x 8
 define <vscale x 4 x float> @fsub_sel_fmul_s(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: fsub_sel_fmul_s:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.s, #0 // =0x0
-; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
-; CHECK-NEXT:    sel z1.s, p0, z1.s, z3.s
-; CHECK-NEXT:    fsub z0.s, z0.s, z1.s
+; CHECK-NEXT:    fmls z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
   %fmul = fmul <vscale x 4 x float> %b, %c
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %fmul, <vscale x 4 x float> zeroinitializer
@@ -968,13 +941,38 @@ define <vscale x 4 x float> @fsub_sel_fmul_s(<vscale x 4 x float> %a, <vscale x
 define <vscale x 2 x double> @fsub_sel_fmul_d(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x double> %c, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: fsub_sel_fmul_d:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z3.d, #0 // =0x0
-; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
-; CHECK-NEXT:    sel z1.d, p0, z1.d, z3.d
-; CHECK-NEXT:    fsub z0.d, z0.d, z1.d
+; CHECK-NEXT:    fmls z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
   %fmul = fmul <vscale x 2 x double> %b, %c
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %fmul, <vscale x 2 x double> zeroinitializer
   %fsub = fsub contract <vscale x 2 x double> %a, %sel
   ret <vscale x 2 x double> %fsub
 }
+
+; Verify combine requires contract fast-math flag.
+define <vscale x 4 x float> @fadd_sel_fmul_no_contract_s(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: fadd_sel_fmul_no_contract_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
+; CHECK-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %fmul = fmul <vscale x 4 x float> %b, %c
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %fmul, <vscale x 4 x float> zeroinitializer
+  %fadd = fadd nsz <vscale x 4 x float> %a, %sel
+  ret <vscale x 4 x float> %fadd
+}
+
+; Verify combine requires no-signed zeros fast-math flag.
+define <vscale x 4 x float> @fadd_sel_fmul_no_nsz_s(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: fadd_sel_fmul_no_nsz_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z3.s, #0 // =0x0
+; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
+; CHECK-NEXT:    sel z1.s, p0, z1.s, z3.s
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %fmul = fmul <vscale x 4 x float> %b, %c
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %fmul, <vscale x 4 x float> zeroinitializer
+  %fadd = fadd contract <vscale x 4 x float> %a, %sel
+  ret <vscale x 4 x float> %fadd
+}


        

