[llvm] 01f9d8c - [llvm][SVE] IR intrinsics for matrix multiplication instructions.

Francesco Petrogalli via llvm-commits llvm-commits at lists.llvm.org
Mon May 18 15:04:04 PDT 2020


Author: Francesco Petrogalli
Date: 2020-05-18T22:02:19Z
New Revision: 01f9d8ce5c0e394d4a080ed6117f64e284b3f303

URL: https://github.com/llvm/llvm-project/commit/01f9d8ce5c0e394d4a080ed6117f64e284b3f303
DIFF: https://github.com/llvm/llvm-project/commit/01f9d8ce5c0e394d4a080ed6117f64e284b3f303.diff

LOG: [llvm][SVE] IR intrinsics for matrix multiplication instructions.

Summary:
Instructions:

* SMMLA
* UMMLA
* USMMLA
* FMMLA

Reviewers: sdesmalen, efriedma, kmclaughlin

Subscribers: tschuett, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79638

Added: 
    llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll
    llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll
    llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll

Modified: 
    llvm/include/llvm/IR/IntrinsicsAArch64.td
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/lib/Target/AArch64/SVEInstrFormats.td

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 56fb18a28df7..0be3a7f3593d 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1294,6 +1294,11 @@ class SVE_gather_prf_VS
                 ],
                 [IntrInaccessibleMemOrArgMemOnly, ImmArg<3>]>;
 
+class SVE_MatMul_Intrinsic
+    : Intrinsic<[llvm_anyvector_ty],
+                [LLVMMatchType<0>, LLVMSubdivide4VectorType<0>, LLVMSubdivide4VectorType<0>],
+                [IntrNoMem]>;
+
 //
 // Loads
 //
@@ -2254,6 +2259,19 @@ def int_aarch64_sve_bdep_x : AdvSIMD_2VectorArg_Intrinsic;
 def int_aarch64_sve_bext_x : AdvSIMD_2VectorArg_Intrinsic;
 def int_aarch64_sve_bgrp_x : AdvSIMD_2VectorArg_Intrinsic;
 
+
+//
+// SVE ACLE: 7.3. INT8 matrix multiply extensions
+//
+def int_aarch64_sve_ummla : SVE_MatMul_Intrinsic;
+def int_aarch64_sve_smmla : SVE_MatMul_Intrinsic;
+def int_aarch64_sve_usmmla : SVE_MatMul_Intrinsic;
+
+//
+// SVE ACLE: 7.4/5. FP64/FP32 matrix multiply extensions
+//
+def int_aarch64_sve_fmmla : AdvSIMD_3VectorArg_Intrinsic;
+
 }
 
 //

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 81de64757e87..81721926d9dd 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1854,20 +1854,20 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
 }
 
 let Predicates = [HasSVE, HasMatMulInt8] in {
-  def  SMMLA_ZZZ : sve_int_matmul<0b00, "smmla">;
-  def  UMMLA_ZZZ : sve_int_matmul<0b11, "ummla">;
-  def USMMLA_ZZZ : sve_int_matmul<0b10, "usmmla">;
+  defm  SMMLA_ZZZ : sve_int_matmul<0b00, "smmla", int_aarch64_sve_smmla>;
+  defm  UMMLA_ZZZ : sve_int_matmul<0b11, "ummla", int_aarch64_sve_ummla>;
+  defm USMMLA_ZZZ : sve_int_matmul<0b10, "usmmla", int_aarch64_sve_usmmla>;
   def USDOT_ZZZ  : sve_int_dot_mixed<"usdot">;
   def USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot">;
   def SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot">;
 }
 
 let Predicates = [HasSVE, HasMatMulFP32] in {
-  def FMMLA_ZZZ_S : sve_fp_matrix_mla<0, "fmmla", ZPR32>;
+  defm FMMLA_ZZZ_S : sve_fp_matrix_mla<0, "fmmla", ZPR32, int_aarch64_sve_fmmla, nxv4f32>;
 }
 
 let Predicates = [HasSVE, HasMatMulFP64] in {
-  def FMMLA_ZZZ_D : sve_fp_matrix_mla<1, "fmmla", ZPR64>;
+  defm FMMLA_ZZZ_D : sve_fp_matrix_mla<1, "fmmla", ZPR64, int_aarch64_sve_fmmla, nxv2f64>;
   defm LD1RO_B_IMM  : sve_mem_ldor_si<0b00, "ld1rob", Z_b, ZPR8>;
   defm LD1RO_H_IMM  : sve_mem_ldor_si<0b01, "ld1roh", Z_h, ZPR16>;
   defm LD1RO_W_IMM  : sve_mem_ldor_si<0b10, "ld1row", Z_s, ZPR32>;

diff  --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index e874edbb5fe2..992542c0b75c 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -7547,6 +7547,12 @@ class sve_int_matmul<bits<2> uns, string asm>
   let ElementSize = ZPR32.ElementSize;
 }
 
+multiclass sve_int_matmul<bits<2> uns, string asm, SDPatternOperator op> {
+  def NAME : sve_int_matmul<uns, asm>;
+
+  def : SVE_3_Op_Pat<nxv4i32, op , nxv4i32, nxv16i8, nxv16i8, !cast<Instruction>(NAME)>;
+}
+
 //===----------------------------------------------------------------------===//
 // SVE Integer Dot Product Mixed Sign Group
 //===----------------------------------------------------------------------===//
@@ -7615,6 +7621,12 @@ class sve_fp_matrix_mla<bit sz, string asm, ZPRRegOp zprty>
   let ElementSize = zprty.ElementSize;
 }
 
+multiclass sve_fp_matrix_mla<bit sz, string asm, ZPRRegOp zprty, SDPatternOperator op, ValueType vt> {
+  def NAME : sve_fp_matrix_mla<sz, asm, zprty>;
+
+  def : SVE_3_Op_Pat<vt, op , vt, vt, vt, !cast<Instruction>(NAME)>;
+}
+
 //===----------------------------------------------------------------------===//
 // SVE Memory - Contiguous Load And Replicate 256-bit Group
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll
new file mode 100644
index 000000000000..6486b1596d1e
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll
@@ -0,0 +1,13 @@
+; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+f32mm -asm-verbose=0 < %s -o - | FileCheck %s
+
+define <vscale x 4 x float> @fmmla_s(<vscale x 4 x float> %r, <vscale x 4 x float> %a, <vscale x 4 x float> %b) nounwind {
+entry:
+; CHECK-LABEL: fmmla_s:
+; CHECK-NEXT:  fmmla   z0.s, z1.s, z2.s
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x float> @llvm.aarch64.sve.fmmla.nxv4f32(<vscale x 4 x float> %r, <vscale x 4 x float> %a, <vscale x 4 x float> %b)
+  ret <vscale x 4 x float> %val
+}
+
+declare <vscale x 4 x float> @llvm.aarch64.sve.fmmla.nxv4f32(<vscale x 4 x float>,<vscale x 4 x float>,<vscale x 4 x float>)
+

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll
new file mode 100644
index 000000000000..9f6ff187e0c5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll
@@ -0,0 +1,14 @@
+; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+f64mm -asm-verbose=0 < %s -o - | FileCheck %s
+
+
+define <vscale x 2 x double> @fmmla_d(<vscale x 2 x double> %r, <vscale x 2 x double> %a, <vscale x 2 x double> %b) nounwind {
+entry:
+; CHECK-LABEL: fmmla_d:
+; CHECK-NEXT:  fmmla   z0.d, z1.d, z2.d
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 2 x double> @llvm.aarch64.sve.fmmla.nxv2f64(<vscale x 2 x double> %r, <vscale x 2 x double> %a, <vscale x 2 x double> %b)
+  ret <vscale x 2 x double> %val
+}
+
+declare <vscale x 2 x double> @llvm.aarch64.sve.fmmla.nxv2f64(<vscale x 2 x double>,<vscale x 2 x double>,<vscale x 2 x double>)
+

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll
new file mode 100644
index 000000000000..c295aee43975
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll
@@ -0,0 +1,33 @@
+; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+i8mm -asm-verbose=0 < %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @smmla(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) nounwind {
+entry:
+; CHECK-LABEL: smmla:
+; CHECK-NEXT:  smmla   z0.s, z1.b, z2.b
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  ret <vscale x 4 x i32> %val
+}
+
+define <vscale x 4 x i32> @ummla(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) nounwind {
+entry:
+; CHECK-LABEL: ummla:
+; CHECK-NEXT:  ummla   z0.s, z1.b, z2.b
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  ret <vscale x 4 x i32> %val
+}
+
+define <vscale x 4 x i32> @usmmla(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) nounwind {
+entry:
+; CHECK-LABEL: usmmla:
+; CHECK-NEXT:  usmmla   z0.s, z1.b, z2.b
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  ret <vscale x 4 x i32> %val
+}
+
+declare <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+


        


More information about the llvm-commits mailing list