[llvm] 908e306 - [AArch64] Implement intrinsics for FP8 SME FMLAL/FMLALL (multi) (#119546)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 17 03:47:24 PST 2024


Author: SpencerAbson
Date: 2024-12-17T11:47:20Z
New Revision: 908e30658ddf634f7c929f0c7e78dd40405c795a

URL: https://github.com/llvm/llvm-project/commit/908e30658ddf634f7c929f0c7e78dd40405c795a
DIFF: https://github.com/llvm/llvm-project/commit/908e30658ddf634f7c929f0c7e78dd40405c795a.diff

LOG: [AArch64] Implement intrinsics for FP8 SME FMLAL/FMLALL (multi) (#119546)

This patch implements the following intrinsics:

Multi-vector 8-bit floating-point multiply-add long (multiple vectors).

``` c
// Only if __ARM_FEATURE_SME_F8F16 != 0
void svmla_za16[_mf8]_vg2x2_fpm(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm,
                                fpm_t fpm) __arm_streaming __arm_inout("za");

void svmla_za16[_mf8]_vg2x4_fpm(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm,
                                fpm_t fpm) __arm_streaming __arm_inout("za");
// Only if __ARM_FEATURE_SME_F8F32 != 0
void svmla_za32[_mf8]_vg4x2_fpm(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm,
                                fpm_t fpm) __arm_streaming __arm_inout("za");

void svmla_za32[_mf8]_vg4x4_fpm(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm,
                                fpm_t fpm) __arm_streaming __arm_inout("za");                              
```

In accordance with https://github.com/ARM-software/acle/pull/323

Added: 
    

Modified: 
    clang/include/clang/Basic/arm_sme.td
    clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c
    clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c
    llvm/include/llvm/IR/IntrinsicsAArch64.td
    llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
    llvm/lib/Target/AArch64/SMEInstrFormats.td
    llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td
index 859d5fdfea504d..6b31dec004a1e2 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -873,6 +873,11 @@ let SMETargetGuard = "sme-f8f32" in {
                                          [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
   def SVMLA_FP8_SINGLE_ZA32_VG4x4 : Inst<"svmla[_single]_za32[_mf8]_vg4x4_fpm", "vm4d>", "m", MergeNone, "aarch64_sme_fp8_fmlall_single_za32_vg4x4",
                                          [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
+  // FMLALL (multiple)
+  def SVMLA_FP8_MULTI_ZA32_VG4x2 : Inst<"svmla_za32[_mf8]_vg4x2_fpm", "vm22>", "m", MergeNone, "aarch64_sme_fp8_fmlall_multi_za32_vg4x2",
+                                        [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
+  def SVMLA_FP8_MULTI_ZA32_VG4x4 : Inst<"svmla_za32[_mf8]_vg4x4_fpm", "vm44>", "m", MergeNone, "aarch64_sme_fp8_fmlall_multi_za32_vg4x4",
+                                        [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
 }
 
 let SMETargetGuard = "sme-f8f16" in {
@@ -892,6 +897,11 @@ let SMETargetGuard = "sme-f8f16" in {
                                          [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
   def SVMLA_FP8_SINGLE_ZA16_VG2x4 : Inst<"svmla[_single]_za16[_mf8]_vg2x4_fpm", "vm4d>", "m", MergeNone, "aarch64_sme_fp8_fmlal_single_za16_vg2x4",
                                          [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
+  // FMLAL (multiple)
+  def SVMLA_FP8_MULTI_ZA16_VG2x2 : Inst<"svmla_za16[_mf8]_vg2x2_fpm", "vm22>", "m", MergeNone, "aarch64_sme_fp8_fmlal_multi_za16_vg2x2",
+                                        [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
+  def SVMLA_FP8_MULTI_ZA16_VG2x4 : Inst<"svmla_za16[_mf8]_vg2x4_fpm", "vm44>", "m", MergeNone, "aarch64_sme_fp8_fmlal_multi_za16_vg2x4",
+                                        [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
 }
 
 } // let SVETargetGuard = InvalidMode

diff  --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c
index 942c4a24f77812..d603045edf2824 100644
--- a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c
@@ -1,4 +1,3 @@
-
 // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 // REQUIRES: aarch64-registered-target
 
@@ -239,3 +238,79 @@ void test_svmla_single_za32_vg4x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8_t
 void test_svmla_single_za32_vg4x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
     SME_ACLE_FUNC(svmla,_single,_za32,_mf8,_vg4x4_fpm)(slice, zn, zm, fpm);
 }
+
+// FMLAL (multi)
+
+// CHECK-LABEL: define dso_local void @test_svmla_multi_za16_vg2x2(
+// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
+// CHECK-NEXT:    ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za16_vg2x2j13svmfloat8x2_tS_m(
+// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
+// CPP-CHECK-NEXT:  [[ENTRY:.*:]]
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
+// CPP-CHECK-NEXT:    ret void
+//
+void test_svmla_multi_za16_vg2x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
+    SME_ACLE_FUNC(svmla_za16,_mf8,_vg2x2_fpm,,)(slice, zn, zm, fpm);
+}
+
+// CHECK-LABEL: define dso_local void @test_svmla_multi_za16_vg2x4(
+// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
+// CHECK-NEXT:    ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za16_vg2x4j13svmfloat8x4_tS_m(
+// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CPP-CHECK-NEXT:  [[ENTRY:.*:]]
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
+// CPP-CHECK-NEXT:    ret void
+//
+void test_svmla_multi_za16_vg2x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
+    SME_ACLE_FUNC(svmla_za16,_mf8,_vg2x4_fpm,,)(slice, zn, zm, fpm);
+}
+
+// FMLALL (multi)
+
+// CHECK-LABEL: define dso_local void @test_svmla_multi_za32_vg4x2(
+// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
+// CHECK-NEXT:    ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za32_vg4x2j13svmfloat8x2_tS_m(
+// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CPP-CHECK-NEXT:  [[ENTRY:.*:]]
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
+// CPP-CHECK-NEXT:    ret void
+//
+void test_svmla_multi_za32_vg4x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
+    SME_ACLE_FUNC(svmla_za32,_mf8,_vg4x2_fpm,,)(slice, zn, zm, fpm);
+}
+
+// CHECK-LABEL: define dso_local void @test_svmla_multi_za32_vg4x4(
+// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
+// CHECK-NEXT:    ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za32_vg4x4j13svmfloat8x4_tS_m(
+// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CPP-CHECK-NEXT:  [[ENTRY:.*:]]
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
+// CPP-CHECK-NEXT:    ret void
+//
+void test_svmla_multi_za32_vg4x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
+    SME_ACLE_FUNC(svmla_za32,_mf8,_vg4x4_fpm,,)(slice, zn, zm, fpm);
+}

diff  --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c
index 9204ca0ae5d4cf..1dbc30196a5544 100644
--- a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c
@@ -41,4 +41,16 @@ void test_svmla(uint32_t slice, svmfloat8_t zn, svmfloat8x2_t znx2, svmfloat8x4_
 
     // expected-error@+1 {{'svmla_single_za32_mf8_vg4x4_fpm' needs target feature sme,sme-f8f32}}
     svmla_single_za32_mf8_vg4x4_fpm(slice, znx4, zn, fpmr);
+
+    // expected-error@+1 {{'svmla_za16_mf8_vg2x2_fpm' needs target feature sme,sme-f8f16}}
+    svmla_za16_mf8_vg2x2_fpm(slice, znx2, znx2, fpmr);
+
+    // expected-error@+1 {{'svmla_za16_mf8_vg2x4_fpm' needs target feature sme,sme-f8f16}}
+    svmla_za16_mf8_vg2x4_fpm(slice, znx4, znx4, fpmr);
+
+    // expected-error@+1 {{'svmla_za32_mf8_vg4x2_fpm' needs target feature sme,sme-f8f32}}
+    svmla_za32_mf8_vg4x2_fpm(slice, znx2, znx2, fpmr);
+
+    // expected-error@+1 {{'svmla_za32_mf8_vg4x4_fpm' needs target feature sme,sme-f8f32}}
+    svmla_za32_mf8_vg4x4_fpm(slice, znx4, znx4, fpmr);
 }
\ No newline at end of file

diff  --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 7fa0d421babc5d..53a66099a92bda 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3856,29 +3856,6 @@ def int_aarch64_sve_famin_u : AdvSIMD_Pred2VectorArg_Intrinsic;
 def int_aarch64_neon_famax :  AdvSIMD_2VectorArg_Intrinsic;
 def int_aarch64_neon_famin :  AdvSIMD_2VectorArg_Intrinsic;
 
-
-// SME FP8 FDOT intrinsics
-let TargetPrefix = "aarch64" in {
-
-class SME2_FP8_FDOT_MULTI_VG1x2 :
-    DefaultAttrsIntrinsic<[], [llvm_i32_ty,
-                               llvm_nxv16i8_ty, llvm_nxv16i8_ty,
-                               llvm_nxv16i8_ty, llvm_nxv16i8_ty],
-                          [IntrInaccessibleMemOnly, IntrHasSideEffects]>;
-
-class SME2_FP8_FDOT_MULTI_VG1x4 :
-    DefaultAttrsIntrinsic<[], [llvm_i32_ty,
-                               llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty,
-                               llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty],
-                          [IntrInaccessibleMemOnly, IntrHasSideEffects]>;
-
-  def int_aarch64_sme_fp8_fdot_multi_za16_vg1x2 : SME2_FP8_FDOT_MULTI_VG1x2;
-  def int_aarch64_sme_fp8_fdot_multi_za16_vg1x4 : SME2_FP8_FDOT_MULTI_VG1x4;
-
-  def int_aarch64_sme_fp8_fdot_multi_za32_vg1x2 : SME2_FP8_FDOT_MULTI_VG1x2;
-  def int_aarch64_sme_fp8_fdot_multi_za32_vg1x4 : SME2_FP8_FDOT_MULTI_VG1x4;
-}
-
 //
 // FP8 Intrinsics
 //
@@ -3997,6 +3974,18 @@ let TargetPrefix = "aarch64" in {
     : DefaultAttrsIntrinsic<[], [llvm_i32_ty,
                                 llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty,
                                 llvm_nxv16i8_ty],
+                              [IntrInaccessibleMemOnly, IntrHasSideEffects]>;
+
+  class SME_FP8_ZA_MULTI_VGx2_Intrinsic
+    : DefaultAttrsIntrinsic<[], [llvm_i32_ty,
+                                 llvm_nxv16i8_ty, llvm_nxv16i8_ty,
+                                 llvm_nxv16i8_ty, llvm_nxv16i8_ty],
+                            [IntrInaccessibleMemOnly, IntrHasSideEffects]>;
+
+  class SME_FP8_ZA_MULTI_VGx4_Intrinsic
+    : DefaultAttrsIntrinsic<[], [llvm_i32_ty,
+                                 llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty,
+                                 llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty],
                             [IntrInaccessibleMemOnly, IntrHasSideEffects]>;
   //
   // CVT from FP8 to half-precision/BFloat16 multi-vector
@@ -4036,6 +4025,9 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sme_fp8_fmlal_single_za16_vg2x1 : SME_FP8_ZA_SINGLE_VGx1_Intrinsic;
   def int_aarch64_sme_fp8_fmlal_single_za16_vg2x2 : SME_FP8_ZA_SINGLE_VGx2_Intrinsic;
   def int_aarch64_sme_fp8_fmlal_single_za16_vg2x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic;
+  // Multi
+  def int_aarch64_sme_fp8_fmlal_multi_za16_vg2x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic;
+  def int_aarch64_sme_fp8_fmlal_multi_za16_vg2x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic;
 
   // Quad-vector groups (F8F32)
   def int_aarch64_sme_fp8_fmlall_lane_za32_vg4x1 : SME_FP8_ZA_LANE_VGx1_Intrinsic;
@@ -4045,6 +4037,9 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sme_fp8_fmlall_single_za32_vg4x1 : SME_FP8_ZA_SINGLE_VGx1_Intrinsic;
   def int_aarch64_sme_fp8_fmlall_single_za32_vg4x2 : SME_FP8_ZA_SINGLE_VGx2_Intrinsic;
   def int_aarch64_sme_fp8_fmlall_single_za32_vg4x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic;
+  // Multi
+  def int_aarch64_sme_fp8_fmlall_multi_za32_vg4x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic;
+  def int_aarch64_sme_fp8_fmlall_multi_za32_vg4x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic;
 
   //
   // FP8 FDOT intrinsics
@@ -4061,6 +4056,12 @@ let TargetPrefix = "aarch64" in {
 
   def int_aarch64_sme_fp8_fdot_single_za16_vg1x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic;
   def int_aarch64_sme_fp8_fdot_single_za32_vg1x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic;
+  // Multi
+  def int_aarch64_sme_fp8_fdot_multi_za16_vg1x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic;
+  def int_aarch64_sme_fp8_fdot_multi_za32_vg1x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic;
+
+  def int_aarch64_sme_fp8_fdot_multi_za16_vg1x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic;
+  def int_aarch64_sme_fp8_fdot_multi_za32_vg1x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic;
 
   // FVDOT
   def int_aarch64_sme_fp8_fvdot_lane_za16_vg1x2  : SME_FP8_ZA_LANE_VGx2_Intrinsic;

diff  --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 22daaa12e29b98..fedf761f53b647 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -1003,8 +1003,9 @@ defm FMLAL_VG2_MZZ_BtoH  : sme2_fp8_fmlal_single_za16<"fmlal", int_aarch64_sme_f
 defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal",  0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8,  int_aarch64_sme_fp8_fmlal_single_za16_vg2x2, [FPMR, FPCR]>;
 defm FMLAL_VG4_M4ZZ_BtoH : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x4, [FPMR, FPCR]>;
 
-defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal",   0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
-defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal",   0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
+// FP8 FMLAL (multi)
+defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8,   int_aarch64_sme_fp8_fmlal_multi_za16_vg2x2, [FPMR, FPCR]>;
+defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlal_multi_za16_vg2x4, [FPMR, FPCR]>;
 
 defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>;
 } //[HasSMEF8F16]
@@ -1030,8 +1031,9 @@ defm FMLALL_MZZ_BtoS       : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixO
 defm FMLALL_VG2_M2ZZ_BtoS  : sme2_mla_ll_array_vg2_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x2, [FPMR, FPCR]>;
 defm FMLALL_VG4_M4ZZ_BtoS  : sme2_mla_ll_array_vg4_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x4, [FPMR, FPCR]>;
 
-defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall",   0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
-defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall",   0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
+// FP8 FMLALL (multi)
+defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8,   int_aarch64_sme_fp8_fmlall_multi_za32_vg4x2, [FPMR, FPCR]>;
+defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlall_multi_za32_vg4x4, [FPMR, FPCR]>;
 
 defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>;
 } //[HasSMEF8F32]

diff  --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 69780b0765c5b0..81004e70dc179b 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -2262,11 +2262,12 @@ class sme2_mla_long_array_vg2_multi<string mnemonic, bits<2> op0, bits<3> op,
 }
 
 multiclass sme2_fp_mla_long_array_vg2_multi<string mnemonic, bits<3> op, MatrixOperand matrix_ty,
-                                            RegisterOperand multi_vector_ty,
-                                            ValueType zpr_ty, SDPatternOperator intrinsic> {
-
+                                            RegisterOperand multi_vector_ty, ValueType zpr_ty,
+                                            SDPatternOperator intrinsic, list<Register> uses=[]> {
   def NAME : sme2_mla_long_array_vg2_multi<mnemonic, 0b10, op, matrix_ty, multi_vector_ty>,
-                                           SMEPseudo2Instr<NAME, 1>;
+                                           SMEPseudo2Instr<NAME, 1> {
+    let Uses = uses;
+  }
 
   def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm2s2range, multi_vector_ty, SMEMatrixArray>;
 
@@ -2309,9 +2310,11 @@ class sme2_mla_long_array_vg4_multi<string mnemonic, bits<2> op0, bits<3> op,
 
 multiclass sme2_fp_mla_long_array_vg4_multi<string mnemonic, bits<3> op, MatrixOperand matrix_ty,
                                             RegisterOperand multi_vector_ty, ValueType zpr_ty,
-                                            SDPatternOperator intrinsic> {
+                                            SDPatternOperator intrinsic, list<Register> uses=[]> {
   def NAME : sme2_mla_long_array_vg4_multi<mnemonic, 0b10, op, matrix_ty, multi_vector_ty>,
-                                           SMEPseudo2Instr<NAME, 1>;
+                                           SMEPseudo2Instr<NAME, 1> {
+    let Uses = uses;
+  }
 
   def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm2s2range, multi_vector_ty, SMEMatrixArray>;
 
@@ -3341,9 +3344,11 @@ class sme2_mla_ll_array_vg2_multi<bits<5> op, MatrixOperand matrix_ty,
 
 multiclass sme2_mla_ll_array_vg2_multi<string mnemonic, bits<5> op,
                                        MatrixOperand matrix_ty,
-                                       RegisterOperand vector_ty,
-                                       ValueType vt, SDPatternOperator intrinsic> {
-  def NAME : sme2_mla_ll_array_vg2_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1>;
+                                       RegisterOperand vector_ty, ValueType vt,
+                                       SDPatternOperator intrinsic, list<Register> uses=[]> {
+  def NAME : sme2_mla_ll_array_vg2_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+    let Uses = uses;
+  }
 
   def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm1s4range, vector_ty, SMEMatrixArray>;
 
@@ -3385,9 +3390,11 @@ class sme2_mla_ll_array_vg4_multi<bits<5> op,MatrixOperand matrix_ty,
 
 multiclass sme2_mla_ll_array_vg4_multi<string mnemonic, bits<5> op,
                                        MatrixOperand matrix_ty,
-                                       RegisterOperand vector_ty,
-                                       ValueType vt, SDPatternOperator intrinsic> {
-  def NAME : sme2_mla_ll_array_vg4_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1>;
+                                       RegisterOperand vector_ty, ValueType vt,
+                                       SDPatternOperator intrinsic, list<Register> uses=[]> {
+  def NAME : sme2_mla_ll_array_vg4_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+    let Uses = uses;
+  }
 
   def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm1s4range, vector_ty, SMEMatrixArray>;
 

diff  --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll
index 41422e2cf9f560..80009d097e9a5b 100644
--- a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-mla.ll
@@ -213,3 +213,77 @@ define void @test_fmlall_single_vg4x4(i32 %slice, <vscale x 16 x i8> %zn0, <vsca
                                                              <vscale x 16 x i8> %zm)
     ret void
 }
+
+; FMLAL (multi)
+
+define void @test_fmlal_multi_vg2x2(i32 %slice, <vscale x 16 x i8> %zn0,  <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zm0,  <vscale x 16 x i8> %zm1) {
+; CHECK-LABEL: test_fmlal_multi_vg2x2:
+; CHECK:  // %bb.0:
+; CHECK:    mov w8, w0
+; CHECK:    fmlal za.h[w8, 0:1, vgx2], { z0.b, z1.b }, { z2.b, z3.b }
+; CHECK:    fmlal za.h[w8, 6:7, vgx2], { z0.b, z1.b }, { z2.b, z3.b }
+; CHECK:    ret
+    call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 %slice,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1)
+    %add = add i32 %slice, 6
+    call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 %add,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1)
+    ret void
+}
+
+define void @test_fmlal_multi_vg2x4(i32 %slice,  <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
+; CHECK-LABEL: test_fmlal_multi_vg2x4:
+; CHECK:  // %bb.0:
+; CHECK:    mov w8, w0
+; CHECK:    fmlal za.h[w8, 0:1, vgx4], { z0.b - z3.b }, { z4.b - z7.b }
+; CHECK:    fmlal za.h[w8, 6:7, vgx4], { z0.b - z3.b }, { z4.b - z7.b }
+; CHECK:    ret
+                                    <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1, <vscale x 16 x i8> %zm2, <vscale x 16 x i8> %zm3) {
+    call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 %slice,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1, <vscale x 16 x i8> %zm2, <vscale x 16 x i8> %zm3)
+    %add = add i32 %slice, 6
+    call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 %add,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1, <vscale x 16 x i8> %zm2, <vscale x 16 x i8> %zm3)
+    ret void
+}
+
+; FMLALL (multi)
+
+define void @test_fmlal_multi_vg4x2(i32 %slice, <vscale x 16 x i8> %zn0,  <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zm0,  <vscale x 16 x i8> %zm1) {
+; CHECK-LABEL: test_fmlal_multi_vg4x2:
+; CHECK:  // %bb.0:
+; CHECK:    mov w8, w0
+; CHECK:    fmlall za.s[w8, 0:3, vgx2], { z0.b, z1.b }, { z2.b, z3.b }
+; CHECK:    fmlall za.s[w8, 4:7, vgx2], { z0.b, z1.b }, { z2.b, z3.b }
+; CHECK:    ret
+    call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 %slice,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1)
+    %add = add i32 %slice, 4
+    call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 %add,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1)
+    ret void
+}
+
+define void @test_fmlal_multi_vg4x4(i32 %slice,  <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
+; CHECK-LABEL: test_fmlal_multi_vg4x4:
+; CHECK:  // %bb.0:
+; CHECK:    mov w8, w0
+; CHECK:    fmlall za.s[w8, 0:3, vgx4], { z0.b - z3.b }, { z4.b - z7.b }
+; CHECK:    fmlall za.s[w8, 4:7, vgx4], { z0.b - z3.b }, { z4.b - z7.b }
+; CHECK:    ret
+                                    <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1, <vscale x 16 x i8> %zm2, <vscale x 16 x i8> %zm3) {
+    call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 %slice,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1, <vscale x 16 x i8> %zm2, <vscale x 16 x i8> %zm3)
+    %add = add i32 %slice, 4
+    call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 %add,
+                                                           <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
+                                                           <vscale x 16 x i8> %zm0, <vscale x 16 x i8> %zm1, <vscale x 16 x i8> %zm2, <vscale x 16 x i8> %zm3)
+    ret void
+}


        


More information about the llvm-commits mailing list