[llvm] 6082051 - [AArch64][SVE] Add patterns to select mla/mls
Cullen Rhodes via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 26 01:15:41 PDT 2022
Author: Cullen Rhodes
Date: 2022-07-26T07:52:44Z
New Revision: 6082051da158699864fc873df494ad66e271ee22
URL: https://github.com/llvm/llvm-project/commit/6082051da158699864fc873df494ad66e271ee22
DIFF: https://github.com/llvm/llvm-project/commit/6082051da158699864fc873df494ad66e271ee22.diff
LOG: [AArch64][SVE] Add patterns to select mla/mls
Adds patterns for:
add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c)
sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c)
Reviewed By: paulwalker-arm
Differential Revision: https://reviews.llvm.org/D130492
Added:
Modified:
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
llvm/lib/Target/AArch64/SVEInstrFormats.td
llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 686ffe9f5d8f1..9b040860cc3c3 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -345,6 +345,16 @@ def AArch64add_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2),
def AArch64sub_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2),
[(int_aarch64_sve_sub node:$pred, node:$op1, node:$op2),
(sub node:$op1, (vselect node:$pred, node:$op2, (SVEDup0)))]>;
+def AArch64mla_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3), // operands: mask, addend a, multiplicands b and c
+ [(int_aarch64_sve_mla node:$pred, node:$op1, node:$op2, node:$op3), // the SVE mla intrinsic, matched directly
+ (add node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)), // add(a, mul_p(mask, b, c)) -> mla(a, mask, b, c)
+ // add(a, select(mask, mul(b, c), splat(0))) -> mla(a, mask, b, c); lanes where mask is false compute a + 0 = a
+ (add node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>; // the mul itself is all-active; the mask is applied by the vselect
+def AArch64mls_m1 : PatFrags<(ops node:$pred, node:$op1, node:$op2, node:$op3), // operands: mask, minuend a, multiplicands b and c
+ [(int_aarch64_sve_mls node:$pred, node:$op1, node:$op2, node:$op3), // the SVE mls intrinsic, matched directly
+ (sub node:$op1, (AArch64mul_p_oneuse node:$pred, node:$op2, node:$op3)), // sub(a, mul_p(mask, b, c)) -> mls(a, mask, b, c)
+ // sub(a, select(mask, mul(b, c), splat(0))) -> mls(a, mask, b, c); lanes where mask is false compute a - 0 = a
+ (sub node:$op1, (vselect node:$pred, (AArch64mul_p_oneuse (SVEAllActive), node:$op2, node:$op3), (SVEDup0)))]>; // the mul itself is all-active; the mask is applied by the vselect
let Predicates = [HasSVE] in {
defm RDFFR_PPz : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>;
@@ -399,8 +409,8 @@ let Predicates = [HasSVEorSME] in {
defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", int_aarch64_sve_mad>;
defm MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb", int_aarch64_sve_msb>;
- defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", int_aarch64_sve_mla, add, AArch64mul_p_oneuse>;
- defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", int_aarch64_sve_mls, sub, AArch64mul_p_oneuse>;
+ defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", AArch64mla_m1>; // all MLA selection forms (intrinsic, add+mul, add+masked mul) come from AArch64mla_m1
+ defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", AArch64mls_m1>; // all MLS selection forms (intrinsic, sub+mul, sub+masked mul) come from AArch64mls_m1
// SVE predicated integer reductions.
defm SADDV_VPZ : sve_int_reduce_0_saddv<0b000, "saddv", AArch64saddv_p>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 7cdd4c4af95ec..36daecf634d71 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2958,8 +2958,7 @@ class sve_int_mlas_vvv_pred<bits<2> sz8_64, bits<1> opc, string asm,
let ElementSize = zprty.ElementSize;
}
-multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op,
- SDPatternOperator outerop, SDPatternOperator mulop> {
+multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op> {
def _B : sve_int_mlas_vvv_pred<0b00, opc, asm, ZPR8>;
def _H : sve_int_mlas_vvv_pred<0b01, opc, asm, ZPR16>;
def _S : sve_int_mlas_vvv_pred<0b10, opc, asm, ZPR32>;
@@ -2969,15 +2968,6 @@ multiclass sve_int_mlas_vvv_pred<bits<1> opc, string asm, SDPatternOperator op,
def : SVE_4_Op_Pat<nxv8i16, op, nxv8i1, nxv8i16, nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
def : SVE_4_Op_Pat<nxv4i32, op, nxv4i1, nxv4i32, nxv4i32, nxv4i32, !cast<Instruction>(NAME # _S)>;
def : SVE_4_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
-
- def : Pat<(outerop nxv16i8:$Op1, (mulop nxv16i1:$pred, nxv16i8:$Op2, nxv16i8:$Op3)),
- (!cast<Instruction>(NAME # _B) $pred, $Op1, $Op2, $Op3)>;
- def : Pat<(outerop nxv8i16:$Op1, (mulop nxv8i1:$pred, nxv8i16:$Op2, nxv8i16:$Op3)),
- (!cast<Instruction>(NAME # _H) $pred, $Op1, $Op2, $Op3)>;
- def : Pat<(outerop nxv4i32:$Op1, (mulop nxv4i1:$pred, nxv4i32:$Op2, nxv4i32:$Op3)),
- (!cast<Instruction>(NAME # _S) $pred, $Op1, $Op2, $Op3)>;
- def : Pat<(outerop nxv2i64:$Op1, (mulop nxv2i1:$pred, nxv2i64:$Op2, nxv2i64:$Op3)),
- (!cast<Instruction>(NAME # _D) $pred, $Op1, $Op2, $Op3)>;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll b/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
index b28f7e2d5f5a1..6ec4d50a38743 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-int-arith.ll
@@ -96,9 +96,7 @@ define <vscale x 2 x i64> @masked_sub_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2
define <vscale x 16 x i8> @masked_mla_nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i1> %mask) {
; CHECK-LABEL: masked_mla_nxv16i8:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.b
-; CHECK-NEXT: mul z1.b, p1/m, z1.b, z2.b
-; CHECK-NEXT: add z0.b, p0/m, z0.b, z1.b
+; CHECK-NEXT: mla z0.b, p0/m, z1.b, z2.b
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 16 x i8> %b, %c
%sel = select <vscale x 16 x i1> %mask, <vscale x 16 x i8> %mul, <vscale x 16 x i8> zeroinitializer
@@ -109,9 +107,7 @@ define <vscale x 16 x i8> @masked_mla_nxv16i8(<vscale x 16 x i8> %a, <vscale x 1
define <vscale x 8 x i16> @masked_mla_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: masked_mla_nxv8i16:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.h
-; CHECK-NEXT: mul z1.h, p1/m, z1.h, z2.h
-; CHECK-NEXT: add z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: mla z0.h, p0/m, z1.h, z2.h
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 8 x i16> %b, %c
%sel = select <vscale x 8 x i1> %mask, <vscale x 8 x i16> %mul, <vscale x 8 x i16> zeroinitializer
@@ -122,9 +118,7 @@ define <vscale x 8 x i16> @masked_mla_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8
define <vscale x 4 x i32> @masked_mla_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: masked_mla_nxv4i32:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.s
-; CHECK-NEXT: mul z1.s, p1/m, z1.s, z2.s
-; CHECK-NEXT: add z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 4 x i32> %b, %c
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x i32> %mul, <vscale x 4 x i32> zeroinitializer
@@ -135,9 +129,7 @@ define <vscale x 4 x i32> @masked_mla_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4
define <vscale x 2 x i64> @masked_mla_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b, <vscale x 2 x i64> %c, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: masked_mla_nxv2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.d
-; CHECK-NEXT: mul z1.d, p1/m, z1.d, z2.d
-; CHECK-NEXT: add z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 2 x i64> %b, %c
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x i64> %mul, <vscale x 2 x i64> zeroinitializer
@@ -152,9 +144,7 @@ define <vscale x 2 x i64> @masked_mla_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2
define <vscale x 16 x i8> @masked_mls_nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c, <vscale x 16 x i1> %mask) {
; CHECK-LABEL: masked_mls_nxv16i8:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.b
-; CHECK-NEXT: mul z1.b, p1/m, z1.b, z2.b
-; CHECK-NEXT: sub z0.b, p0/m, z0.b, z1.b
+; CHECK-NEXT: mls z0.b, p0/m, z1.b, z2.b
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 16 x i8> %b, %c
%sel = select <vscale x 16 x i1> %mask, <vscale x 16 x i8> %mul, <vscale x 16 x i8> zeroinitializer
@@ -165,9 +155,7 @@ define <vscale x 16 x i8> @masked_mls_nxv16i8(<vscale x 16 x i8> %a, <vscale x 1
define <vscale x 8 x i16> @masked_mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: masked_mls_nxv8i16:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.h
-; CHECK-NEXT: mul z1.h, p1/m, z1.h, z2.h
-; CHECK-NEXT: sub z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: mls z0.h, p0/m, z1.h, z2.h
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 8 x i16> %b, %c
%sel = select <vscale x 8 x i1> %mask, <vscale x 8 x i16> %mul, <vscale x 8 x i16> zeroinitializer
@@ -178,9 +166,7 @@ define <vscale x 8 x i16> @masked_mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8
define <vscale x 4 x i32> @masked_mls_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: masked_mls_nxv4i32:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.s
-; CHECK-NEXT: mul z1.s, p1/m, z1.s, z2.s
-; CHECK-NEXT: sub z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: mls z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 4 x i32> %b, %c
%sel = select <vscale x 4 x i1> %mask, <vscale x 4 x i32> %mul, <vscale x 4 x i32> zeroinitializer
@@ -191,9 +177,7 @@ define <vscale x 4 x i32> @masked_mls_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4
define <vscale x 2 x i64> @masked_mls_nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b, <vscale x 2 x i64> %c, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: masked_mls_nxv2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: ptrue p1.d
-; CHECK-NEXT: mul z1.d, p1/m, z1.d, z2.d
-; CHECK-NEXT: sub z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: mls z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
%mul = mul nsw <vscale x 2 x i64> %b, %c
%sel = select <vscale x 2 x i1> %mask, <vscale x 2 x i64> %mul, <vscale x 2 x i64> zeroinitializer
More information about the llvm-commits mailing list