[llvm] 6082051 - [AArch64][SVE] Add patterns to select mla/mls

Cullen Rhodes via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 26 01:15:41 PDT 2022


Author: Cullen Rhodes
Date: 2022-07-26T07:52:44Z
New Revision: 6082051da158699864fc873df494ad66e271ee22

URL: https://github.com/llvm/llvm-project/commit/6082051da158699864fc873df494ad66e271ee22
DIFF: https://github.com/llvm/llvm-project/commit/6082051da158699864fc873df494ad66e271ee22.diff

LOG: [AArch64][SVE] Add patterns to select mla/mls

Adds patterns for:

  add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
  sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c)
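
For illustration, a minimal IR sketch of the shape the first pattern matches,
mirroring the masked_mla_nxv16i8 test updated below (the function name is made
up, and the trailing add/ret lines are assumed since the test's diff context
cuts them off). Previously this lowered to a ptrue/mul/add sequence; with these
patterns it selects to a single predicated mla:

  define <vscale x 16 x i8> @masked_mla_sketch(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i1> %mask) {
    ; mul(b, c), unpredicated; predication comes from the select below
    %mul = mul <vscale x 16 x i8> %b, %c
    ; select(mask, mul(b, c), splat(0))
    %sel = select <vscale x 16 x i1> %mask, <vscale x 16 x i8> %mul, <vscale x 16 x i8> zeroinitializer
    ; add(a, ...); assumed tail -- now selects to "mla z0.b, p0/m, z1.b, z2.b"
    %add = add <vscale x 16 x i8> %a, %sel
    ret <vscale x 16 x i8> %add
  }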

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D130492

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/lib/Target/AArch64/SVEInstrFormats.td
    llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 686ffe9f5d8f1..9b040860cc3c3 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -345,6 +345,16 @@ def AArch64add_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2),
 def AArch64sub_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2),
                              [(int_aarch64_sve_sub node:$pred, node:$op1, node:$op2),
                               (sub node:$op1, (vselect node:$pred, node:$op2, (SVEDup0)))]>;
+def AArch64mla_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
+                             [(int_aarch64_sve_mla node:$pred, node:$op1, node:$op2, node:$op3),
+                              (add node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)),
+                              // add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
+                              (add node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
+def AArch64mls_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3),
+                             [(int_aarch64_sve_mls node:$pred, node:$op1, node:$op2, node:$op3),
+                              (sub node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)),
+                              // sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c)
+                              (sub node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>;
 
 let Predicates = [HasSVE] in {
   defm RDFFR_PPz  : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>;
@@ -399,8 +409,8 @@ let Predicates = [HasSVEorSME] in {
 
   defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", int_aarch64_sve_mad>;
   defm MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb", int_aarch64_sve_msb>;
-  defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", int_aarch64_sve_mla, add, AArch64mul_p_oneuse>;
-  defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", int_aarch64_sve_mls, sub, AArch64mul_p_oneuse>;
+  defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", AArch64mla_m1>;
+  defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", AArch64mls_m1>;
 
   // SVE predicated integer reductions.
   defm SADDV_VPZ : sve_int_reduce_0_saddv<0b000, "saddv", AArch64saddv_p>;

diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 7cdd4c4af95ec..36daecf634d71 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2958,8 +2958,7 @@ class sve_int_mlas_vvv_pred<bits<2> sz8_64, bits<1> opc, string asm,
   let ElementSize = zprty.ElementSize;
 }
 
-multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op,
-                                 SDPatternOperator outerop, SDPatternOperator mulop> {
+multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op> {
   def _B : sve_int_mlas_vvv_pred<0b00, opc, asm, ZPR8>;
   def _H : sve_int_mlas_vvv_pred<0b01, opc, asm, ZPR16>;
   def _S : sve_int_mlas_vvv_pred<0b10, opc, asm, ZPR32>;
@@ -2969,15 +2968,6 @@ multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op,
   def : SVE_4_Op_Pat<nxv8i16, op, nxv8i1, nxv8i16, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
   def : SVE_4_Op_Pat<nxv4i32, op, nxv4i1, nxv4i32, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>;
   def : SVE_4_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
-
-  def : Pat<(outerop nxv16i8:$Op1, (mulop nxv16i1:$pred, nxv16i8:$Op2, nxv16i8:$Op3)),
-            (!cast<Instruction>(NAME # _B) $pred, $Op1, $Op2, $Op3)>;
-  def : Pat<(outerop nxv8i16:$Op1, (mulop nxv8i1:$pred, nxv8i16:$Op2, nxv8i16:$Op3)),
-            (!cast<Instruction>(NAME # _H) $pred, $Op1, $Op2, $Op3)>;
-  def : Pat<(outerop nxv4i32:$Op1, (mulop nxv4i1:$pred, nxv4i32:$Op2, nxv4i32:$Op3)),
-            (!cast<Instruction>(NAME # _S) $pred, $Op1, $Op2, $Op3)>;
-  def : Pat<(outerop nxv2i64:$Op1, (mulop nxv2i1:$pred, nxv2i64:$Op2, nxv2i64:$Op3)),
-            (!cast<Instruction>(NAME # _D) $pred, $Op1, $Op2, $Op3)>;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll b/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
index b28f7e2d5f5a1..6ec4d50a38743 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
@@ -96,9 +96,7 @@ define <vscale x 2 x i64> @masked_sub_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2
 define <vscale x 16 x i8> @masked_mla_nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv16i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.b
-; CHECK-NEXT:    mul z1.b, p1/m, z1.b, z2.b
-; CHECK-NEXT:    add z0.b, p0/m, z0.b, z1.b
+; CHECK-NEXT:    mla z0.b, p0/m, z1.b, z2.b
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 16 x i8> %b, %c
   %sel = select <vscale x 16 x i1> %mask, <vscale x 16 x i8> %mul, <vscale x 16 x i8> zeroinitializer
@@ -109,9 +107,7 @@ define <vscale x 16 x i8> @masked_mla_nxv16i8(<vscale x 16 x i8> %a, <vscale x 1
 define <vscale x 8 x i16> @masked_mla_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    mul z1.h, p1/m, z1.h, z2.h
-; CHECK-NEXT:    add z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mla z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 8 x i16> %b, %c
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x i16> %mul, <vscale x 8 x i16> zeroinitializer
@@ -122,9 +118,7 @@ define <vscale x 8 x i16> @masked_mla_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8
 define <vscale x 4 x i32> @masked_mla_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv4i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    mul z1.s, p1/m, z1.s, z2.s
-; CHECK-NEXT:    add z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 4 x i32> %b, %c
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x i32> %mul, <vscale x 4 x i32> zeroinitializer
@@ -135,9 +129,7 @@ define <vscale x 4 x i32> @masked_mla_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4
 define <vscale x 2 x i64> @masked_mla_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b, <vscale x 2 x i64> %c, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: masked_mla_nxv2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    mul z1.d, p1/m, z1.d, z2.d
-; CHECK-NEXT:    add z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 2 x i64> %b, %c
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x i64> %mul, <vscale x 2 x i64> zeroinitializer
@@ -152,9 +144,7 @@ define <vscale x 2 x i64> @masked_mla_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2
 define <vscale x 16 x i8> @masked_mls_nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv16i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.b
-; CHECK-NEXT:    mul z1.b, p1/m, z1.b, z2.b
-; CHECK-NEXT:    sub z0.b, p0/m, z0.b, z1.b
+; CHECK-NEXT:    mls z0.b, p0/m, z1.b, z2.b
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 16 x i8> %b, %c
   %sel = select <vscale x 16 x i1> %mask, <vscale x 16 x i8> %mul, <vscale x 16 x i8> zeroinitializer
@@ -165,9 +155,7 @@ define <vscale x 16 x i8> @masked_mls_nxv16i8(<vscale x 16 x i8> %a, <vscale x 1
 define <vscale x 8 x i16> @masked_mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c, <vscale x 8 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    mul z1.h, p1/m, z1.h, z2.h
-; CHECK-NEXT:    sub z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT:    mls z0.h, p0/m, z1.h, z2.h
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 8 x i16> %b, %c
   %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x i16> %mul, <vscale x 8 x i16> zeroinitializer
@@ -178,9 +166,7 @@ define <vscale x 8 x i16> @masked_mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8
 define <vscale x 4 x i32> @masked_mls_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c, <vscale x 4 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv4i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    mul z1.s, p1/m, z1.s, z2.s
-; CHECK-NEXT:    sub z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    mls z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 4 x i32> %b, %c
   %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x i32> %mul, <vscale x 4 x i32> zeroinitializer
@@ -191,9 +177,7 @@ define <vscale x 4 x i32> @masked_mls_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4
 define <vscale x 2 x i64> @masked_mls_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b, <vscale x 2 x i64> %c, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: masked_mls_nxv2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ptrue p1.d
-; CHECK-NEXT:    mul z1.d, p1/m, z1.d, z2.d
-; CHECK-NEXT:    sub z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
   %mul = mul nsw <vscale x 2 x i64> %b, %c
   %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x i64> %mul, <vscale x 2 x i64> zeroinitializer


        

