[llvm] c217243 - [AArch64] Implements FP8 SVE intrinsics for dot-product (#118125)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 13 06:07:00 PST 2024
Author: Momchil Velikov
Date: 2024-12-13T14:06:54Z
New Revision: c2172431c72c9b249bf5bdfcc0c239fbfe64fa9b
URL: https://github.com/llvm/llvm-project/commit/c2172431c72c9b249bf5bdfcc0c239fbfe64fa9b
DIFF: https://github.com/llvm/llvm-project/commit/c2172431c72c9b249bf5bdfcc0c239fbfe64fa9b.diff
LOG: [AArch64] Implements FP8 SVE intrinsics for dot-product (#118125)
This patch adds the following intrinsics:
* 8-bit floating-point dot product to single-precision.
// Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT4) ||
__ARM_FEATURE_SSVE_FP8DOT4
svfloat32_t svdot[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn,
svmfloat8_t zm, fpm_t fpm);
svfloat32_t svdot[_n_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn,
mfloat8_t zm, fpm_t fpm);
* 8-bit floating-point indexed dot product to single-precision.
// Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT4) ||
__ARM_FEATURE_SSVE_FP8DOT4
svfloat32_t svdot_lane[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn,
svmfloat8_t zm,
uint64_t imm0_3, fpm_t fpm);
* 8-bit floating-point dot product to half-precision.
// Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT2) ||
__ARM_FEATURE_SSVE_FP8DOT2
svfloat16_t svdot[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn,
svmfloat8_t zm, fpm_t fpm);
svfloat16_t svdot[_n_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn,
mfloat8_t zm, fpm_t fpm);
* 8-bit floating-point indexed dot product to half-precision.
// Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT2) ||
__ARM_FEATURE_SSVE_FP8DOT2
svfloat16_t svdot_lane[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn,
svmfloat8_t zm,
uint64_t imm0_7, fpm_t fpm);
Added:
clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sve2_fp8_fdot.c
llvm/test/CodeGen/AArch64/fp8-sve-fdot.ll
Modified:
clang/include/clang/Basic/arm_sve.td
clang/include/clang/Basic/arm_sve_sme_incl.td
clang/lib/CodeGen/CGBuiltin.cpp
clang/test/Sema/aarch64-sve2-intrinsics/acle_sve2_fp8.c
clang/utils/TableGen/SveEmitter.cpp
llvm/include/llvm/IR/IntrinsicsAArch64.td
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
llvm/lib/Target/AArch64/SVEInstrFormats.td
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/arm_sve.td b/clang/include/clang/Basic/arm_sve.td
index b9f40faf0b18e6..2c8ca8014387d3 100644
--- a/clang/include/clang/Basic/arm_sve.td
+++ b/clang/include/clang/Basic/arm_sve.td
@@ -2476,3 +2476,22 @@ let SVETargetGuard = "sve2,fp8", SMETargetGuard = "sme2,fp8" in {
def SVFCVTNB : SInst<"svcvtnb_mf8[_f32_x2]_fpm", "~2>", "f", MergeNone, "aarch64_sve_fp8_cvtnb", [VerifyRuntimeMode, SetsFPMR]>;
def SVFCVTNT : SInst<"svcvtnt_mf8[_f32_x2]_fpm", "~~2>", "f", MergeNone, "aarch64_sve_fp8_cvtnt", [VerifyRuntimeMode, SetsFPMR]>;
}
+
+let SVETargetGuard = "sve2,fp8dot2", SMETargetGuard ="sme,ssve-fp8dot2" in {
+ // 8-bit floating-point dot product to half-precision (vectors)
+ def SVFDOT_2WAY : SInst<"svdot[_f16_mf8]_fpm", "dd~~>", "h", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
+ def SVFDOT_N_2WAY : SInst<"svdot[_n_f16_mf8]_fpm", "dd~!>", "h", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
+
+ // 8-bit floating-point dot product to half-precision (indexed)
+ def SVFDOT_LANE_2WAY : SInst<"svdot_lane[_f16_mf8]_fpm", "dd~~i>", "h", MergeNone, "aarch64_sve_fp8_fdot_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_7>]>;
+}
+
+let SVETargetGuard = "sve2,fp8dot4", SMETargetGuard ="sme,ssve-fp8dot4" in {
+ // 8-bit floating-point dot product to single-precision (vectors)
+ def SVFDOT_4WAY : SInst<"svdot[_f32_mf8]_fpm", "dd~~>", "f", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
+ def SVFDOT_N_4WAY : SInst<"svdot[_n_f32_mf8]_fpm", "dd~!>", "f", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
+
+ // 8-bit floating-point dot product to single-precision (indexed)
+ def SVFDOT_LANE_4WAY : SInst<"svdot_lane[_f32_mf8]_fpm", "dd~~i>", "f", MergeNone, "aarch64_sve_fp8_fdot_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_3>]>;
+}
+
diff --git a/clang/include/clang/Basic/arm_sve_sme_incl.td b/clang/include/clang/Basic/arm_sve_sme_incl.td
index e7cc40db7dca6c..ee899209ad832b 100644
--- a/clang/include/clang/Basic/arm_sve_sme_incl.td
+++ b/clang/include/clang/Basic/arm_sve_sme_incl.td
@@ -89,6 +89,7 @@ include "arm_immcheck_incl.td"
// j: element type promoted to 64bits (splat to vector type)
// K: element type bitcast to a signed integer (splat to vector type)
// L: element type bitcast to an unsigned integer (splat to vector type)
+// !: mfloat8_t (splat to svmfloat8_t)
//
// i: constant uint64_t
// k: int32_t
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 49a4c1ecc825e7..84048a4beac2c5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10719,7 +10719,16 @@ Value *CodeGenFunction::EmitSVEDupX(Value *Scalar, llvm::Type *Ty) {
cast<llvm::VectorType>(Ty)->getElementCount(), Scalar);
}
-Value *CodeGenFunction::EmitSVEDupX(Value* Scalar) {
+Value *CodeGenFunction::EmitSVEDupX(Value *Scalar) {
+ if (auto *Ty = Scalar->getType(); Ty->isVectorTy()) {
+#ifndef NDEBUG
+ auto *VecTy = cast<llvm::VectorType>(Ty);
+ ElementCount EC = VecTy->getElementCount();
+ assert(EC.isScalar() && VecTy->getElementType() == Int8Ty &&
+ "Only <1 x i8> expected");
+#endif
+ Scalar = Builder.CreateExtractElement(Scalar, uint64_t(0));
+ }
return EmitSVEDupX(Scalar, getSVEVectorForElementType(Scalar->getType()));
}
diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sve2_fp8_fdot.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sve2_fp8_fdot.c
new file mode 100644
index 00000000000000..950a19115811ec
--- /dev/null
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sve2_fp8_fdot.c
@@ -0,0 +1,149 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sve -target-feature +sve2 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s
+// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +fp8 -target-feature +ssve-fp8dot2 -target-feature +ssve-fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s -check-prefix=CHECK-CXX
+
+// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sve -target-feature +sve2 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s
+// RUN: %clang_cc1 -x c++ -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +fp8 -target-feature +ssve-fp8dot2 -target-feature +ssve-fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s -check-prefix=CHECK-CXX
+
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sve -target-feature +sve2 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -target-feature +fp8 -target-feature +ssve-fp8dot2 -target-feature +ssve-fp8dot4 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s
+
+// REQUIRES: aarch64-registered-target
+
+#ifdef __ARM_FEATURE_SME
+#include <arm_sme.h>
+#else
+#include <arm_sve.h>
+#endif
+
+#ifdef SVE_OVERLOADED_FORMS
+#define SVE_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3
+#else
+#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
+#endif
+
+#ifdef __ARM_FEATURE_SME
+#define STREAMING __arm_streaming
+#else
+#define STREAMING
+#endif
+
+// CHECK-LABEL: define dso_local <vscale x 4 x float> @test_svdot_f32_mf8(
+// CHECK-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-NEXT: ret <vscale x 4 x float> [[TMP0]]
+//
+// CHECK-CXX-LABEL: define dso_local <vscale x 4 x float> @_Z18test_svdot_f32_mf8u13__SVFloat32_tu13__SVMfloat8_tS0_m(
+// CHECK-CXX-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-CXX-NEXT: ret <vscale x 4 x float> [[TMP0]]
+//
+svfloat32_t test_svdot_f32_mf8(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
+ return SVE_ACLE_FUNC(svdot,_f32_mf8,_fpm)(zda, zn, zm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <vscale x 4 x float> @test_svdot_n_f32_mf8(
+// CHECK-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
+// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
+// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
+// CHECK-NEXT: ret <vscale x 4 x float> [[TMP1]]
+//
+// CHECK-CXX-LABEL: define dso_local <vscale x 4 x float> @_Z20test_svdot_n_f32_mf8u13__SVFloat32_tu13__SVMfloat8_tu6__mfp8m(
+// CHECK-CXX-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
+// CHECK-CXX-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
+// CHECK-CXX-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+// CHECK-CXX-NEXT: [[TMP1:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
+// CHECK-CXX-NEXT: ret <vscale x 4 x float> [[TMP1]]
+//
+svfloat32_t test_svdot_n_f32_mf8(svfloat32_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm) STREAMING {
+ return SVE_ACLE_FUNC(svdot,_n_f32_mf8,_fpm)(zda, zn, zm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <vscale x 8 x half> @test_svdot_f16_mf8(
+// CHECK-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-NEXT: ret <vscale x 8 x half> [[TMP0]]
+//
+// CHECK-CXX-LABEL: define dso_local <vscale x 8 x half> @_Z18test_svdot_f16_mf8u13__SVFloat16_tu13__SVMfloat8_tS0_m(
+// CHECK-CXX-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-CXX-NEXT: ret <vscale x 8 x half> [[TMP0]]
+//
+svfloat16_t test_svdot_f16_mf8(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
+ return SVE_ACLE_FUNC(svdot,_f16_mf8,_fpm)(zda, zn, zm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <vscale x 8 x half> @test_svdot_n_f16_mf8(
+// CHECK-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
+// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
+// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
+// CHECK-NEXT: ret <vscale x 8 x half> [[TMP1]]
+//
+// CHECK-CXX-LABEL: define dso_local <vscale x 8 x half> @_Z20test_svdot_n_f16_mf8u13__SVFloat16_tu13__SVMfloat8_tu6__mfp8m(
+// CHECK-CXX-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
+// CHECK-CXX-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
+// CHECK-CXX-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+// CHECK-CXX-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
+// CHECK-CXX-NEXT: ret <vscale x 8 x half> [[TMP1]]
+//
+svfloat16_t test_svdot_n_f16_mf8(svfloat16_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm) STREAMING {
+ return SVE_ACLE_FUNC(svdot,_n_f16_mf8,_fpm)(zda, zn, zm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <vscale x 4 x float> @test_svdot_lane_f32_mf8(
+// CHECK-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.lane.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 3)
+// CHECK-NEXT: ret <vscale x 4 x float> [[TMP0]]
+//
+// CHECK-CXX-LABEL: define dso_local <vscale x 4 x float> @_Z23test_svdot_lane_f32_mf8u13__SVFloat32_tu13__SVMfloat8_tS0_m(
+// CHECK-CXX-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.lane.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 3)
+// CHECK-CXX-NEXT: ret <vscale x 4 x float> [[TMP0]]
+//
+svfloat32_t test_svdot_lane_f32_mf8(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
+ return SVE_ACLE_FUNC(svdot_lane,_f32_mf8,_fpm)(zda, zn, zm, 3, fpm);
+}
+
+// CHECK-LABEL: define dso_local <vscale x 8 x half> @test_svdot_lane_f16_mf8(
+// CHECK-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.lane.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 7)
+// CHECK-NEXT: ret <vscale x 8 x half> [[TMP0]]
+//
+// CHECK-CXX-LABEL: define dso_local <vscale x 8 x half> @_Z23test_svdot_lane_f16_mf8u13__SVFloat16_tu13__SVMfloat8_tS0_m(
+// CHECK-CXX-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.lane.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 7)
+// CHECK-CXX-NEXT: ret <vscale x 8 x half> [[TMP0]]
+//
+svfloat16_t test_svdot_lane_f16_mf8(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
+ return SVE_ACLE_FUNC(svdot_lane,_f16_mf8,_fpm)(zda, zn, zm, 7, fpm);
+}
diff --git a/clang/test/Sema/aarch64-sve2-intrinsics/acle_sve2_fp8.c b/clang/test/Sema/aarch64-sve2-intrinsics/acle_sve2_fp8.c
index e47efccf480433..d76e729b6a39c4 100644
--- a/clang/test/Sema/aarch64-sve2-intrinsics/acle_sve2_fp8.c
+++ b/clang/test/Sema/aarch64-sve2-intrinsics/acle_sve2_fp8.c
@@ -4,7 +4,7 @@
#include <arm_sve.h>
-void test_features(svmfloat8_t zn, fpm_t fpm) {
+void test_features(svmfloat8_t zn, svmfloat8_t zm, mfloat8_t x, fpm_t fpm) {
svcvt1_bf16_mf8_fpm(zn, fpm);
// expected-error@-1 {{'svcvt1_bf16_mf8_fpm' needs target feature (sve,sve2,fp8)|(sme,sme2,fp8)}}
svcvt2_bf16_mf8_fpm(zn, fpm);
@@ -30,4 +30,25 @@ void test_features(svmfloat8_t zn, fpm_t fpm) {
// expected-error@-1 {{'svcvtnb_mf8_f32_x2_fpm' needs target feature (sve,sve2,fp8)|(sme,sme2,fp8)}}
svcvtnt_mf8_f32_x2_fpm(zn, svcreate2(svundef_f32(), svundef_f32()), fpm);
// expected-error@-1 {{'svcvtnt_mf8_f32_x2_fpm' needs target feature (sve,sve2,fp8)|(sme,sme2,fp8)}}
+
+ svdot_f32_mf8_fpm(svundef_f32(), zn, zm, fpm);
+// expected-error@-1 {{'svdot_f32_mf8_fpm' needs target feature (sve,sve2,fp8dot4)|(sme,ssve-fp8dot4)}}
+ svdot_n_f32_mf8_fpm(svundef_f32(), zn, x, fpm);
+// expected-error@-1 {{'svdot_n_f32_mf8_fpm' needs target feature (sve,sve2,fp8dot4)|(sme,ssve-fp8dot4)}}
+ svdot_f16_mf8_fpm(svundef_f16(), zn, zm, fpm);
+// expected-error@-1 {{'svdot_f16_mf8_fpm' needs target feature (sve,sve2,fp8dot2)|(sme,ssve-fp8dot2)}}
+ svdot_n_f16_mf8_fpm(svundef_f16(), zn, x, fpm);
+// expected-error@-1 {{'svdot_n_f16_mf8_fpm' needs target feature (sve,sve2,fp8dot2)|(sme,ssve-fp8dot2)}}
+ svdot_lane_f32_mf8_fpm(svundef_f32(), zn, zm, 3, fpm);
+// expected-error@-1 {{'svdot_lane_f32_mf8_fpm' needs target feature (sve,sve2,fp8dot4)|(sme,ssve-fp8dot4)}}
+ svdot_lane_f16_mf8_fpm(svundef_f16(), zn, zm, 7, fpm);
+// expected-error@-1 {{'svdot_lane_f16_mf8_fpm' needs target feature (sve,sve2,fp8dot2)|(sme,ssve-fp8dot2)}}
}
+
+
+void test_imm_range(svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) {
+ svdot_lane_f32_mf8_fpm(svundef_f32(), zn, zm, -1, fpm);
+// expected-error@-1 {{argument value 18446744073709551615 is outside the valid range [0, 3]}}
+ svdot_lane_f16_mf8_fpm(svundef_f16(), zn, zm, -1, fpm);
+// expected-error@-1 {{argument value 18446744073709551615 is outside the valid range [0, 7]}}
+}
\ No newline at end of file
diff --git a/clang/utils/TableGen/SveEmitter.cpp b/clang/utils/TableGen/SveEmitter.cpp
index 2d9f5c3381018a..14e5637f62517e 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -253,7 +253,7 @@ class Intrinsic {
/// Return true if the intrinsic takes a splat operand.
bool hasSplat() const {
// These prototype modifiers are described in arm_sve.td.
- return Proto.find_first_of("ajfrKLR@") != std::string::npos;
+ return Proto.find_first_of("ajfrKLR@!") != std::string::npos;
}
/// Return the parameter index of the splat operand.
@@ -262,7 +262,7 @@ class Intrinsic {
for (; I < Proto.size(); ++I, ++Param) {
if (Proto[I] == 'a' || Proto[I] == 'j' || Proto[I] == 'f' ||
Proto[I] == 'r' || Proto[I] == 'K' || Proto[I] == 'L' ||
- Proto[I] == 'R' || Proto[I] == '@')
+ Proto[I] == 'R' || Proto[I] == '@' || Proto[I] == '!')
break;
// Multivector modifier can be skipped
@@ -910,6 +910,11 @@ void SVEType::applyModifier(char Mod) {
Kind = MFloat8;
ElementBitwidth = 8;
break;
+ case '!':
+ Kind = MFloat8;
+ Bitwidth = ElementBitwidth = 8;
+ NumVectors = 0;
+ break;
case '.':
llvm_unreachable(". is never a type in itself");
break;
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 654bc64a30bd89..eeecc5bb75cc1e 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3911,6 +3911,22 @@ let TargetPrefix = "aarch64" in {
[llvm_nxv16i8_ty, llvm_anyvector_ty, LLVMMatchType<0>],
[IntrReadMem, IntrInaccessibleMemOnly]>;
+ // Dot product
+ class SVE2_FP8_FMLA_FDOT
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [LLVMMatchType<0>,
+ llvm_nxv16i8_ty, llvm_nxv16i8_ty],
+ [IntrReadMem, IntrInaccessibleMemOnly]>;
+
+ class SVE2_FP8_FMLA_FDOT_Lane
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [LLVMMatchType<0>,
+ llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_i32_ty],
+ [IntrReadMem, IntrInaccessibleMemOnly, ImmArg<ArgIndex<3>>]>;
+
+ def int_aarch64_sve_fp8_fdot : SVE2_FP8_FMLA_FDOT;
+ def int_aarch64_sve_fp8_fdot_lane : SVE2_FP8_FMLA_FDOT_Lane;
+
class SME2_FP8_CVT_X2_Single_Intrinsic
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
[llvm_nxv16i8_ty],
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1a5be28dce4a0c..6971aae6dbe5b7 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4423,18 +4423,17 @@ let Predicates = [HasSVE2, HasF8F16MM] in {
let Predicates = [HasSSVE_FP8DOT2] in {
// FP8 Widening Dot-Product - Indexed Group
-defm FDOT_ZZZI_BtoH : sve2_fp8_dot_indexed_h<"fdot">;
+defm FDOT_ZZZI_BtoH : sve2_fp8_dot_indexed_h<"fdot", int_aarch64_sve_fp8_fdot_lane>;
// FP8 Widening Dot-Product - Group
-// TODO: Replace nxv16i8 by nxv16f8
-defm FDOT_ZZZ_BtoH : sve_fp8_dot<0b0, ZPR16, "fdot">;
+defm FDOT_ZZZ_BtoH : sve_fp8_dot<0b0, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fp8_fdot>;
}
// TODO: Replace nxv16i8 by nxv16f8
let Predicates = [HasSSVE_FP8DOT4] in {
// FP8 Widening Dot-Product - Indexed Group
-defm FDOT_ZZZI_BtoS : sve2_fp8_dot_indexed_s<"fdot">;
+defm FDOT_ZZZI_BtoS : sve2_fp8_dot_indexed_s<"fdot", int_aarch64_sve_fp8_fdot_lane>;
// FP8 Widening Dot-Product - Group
-defm FDOT_ZZZ_BtoS : sve_fp8_dot<0b1, ZPR32, "fdot">;
+defm FDOT_ZZZ_BtoS : sve_fp8_dot<0b1, ZPR32, "fdot", nxv4f32, int_aarch64_sve_fp8_fdot>;
}
let Predicates = [HasSVE2orSME2, HasLUT] in {
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 22b56f7f3e9a9e..abb12487dc80ce 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9267,10 +9267,16 @@ multiclass sve_float_dot<bit bf, bit o2, ZPRRegOp dst_ty, ZPRRegOp src_ty,
def : SVE_3_Op_Pat<nxv4f32, op, nxv4f32, InVT, InVT, !cast<Instruction>(NAME)>;
}
-multiclass sve_fp8_dot<bit bf, ZPRRegOp dst_ty, string asm> {
- def NAME : sve_float_dot<bf, 0b1, dst_ty, ZPR8, asm>{
+multiclass sve_fp8_dot<bit bf, ZPRRegOp dstrc, string asm, ValueType vt,
+ SDPatternOperator op> {
+ def NAME : sve_float_dot<bf, 0b1, dstrc, ZPR8, asm> {
let Uses = [FPMR, FPCR];
+
+ let mayLoad = 1;
+ let mayStore = 0;
}
+
+ def : SVE_3_Op_Pat<vt, op, vt, nxv16i8, nxv16i8, !cast<Instruction>(NAME)>;
}
class sve_float_dot_indexed<bit bf, ZPRRegOp dst_ty, ZPRRegOp src1_ty,
@@ -10955,24 +10961,31 @@ class sve_fp8_dot_indexed<bits<4> opc, ZPRRegOp dst_ty, Operand iop_ty, string m
let DestructiveInstType = DestructiveOther;
let hasSideEffects = 0;
let mayRaiseFPException = 1;
+
+ let mayLoad = 1;
+ let mayStore = 0;
}
// FP8 Widening Dot-Product - Indexed Group
-multiclass sve2_fp8_dot_indexed_h<string asm>{
- def NAME : sve_fp8_dot_indexed<{0, ?, ?, ?}, ZPR16, VectorIndexH, asm> {
+multiclass sve2_fp8_dot_indexed_h<string asm, SDPatternOperator op> {
+ def NAME : sve_fp8_dot_indexed<{0, ?, ?, ?}, ZPR16, VectorIndexH32b, asm> {
bits<3> iop;
let Inst{20-19} = iop{2-1};
let Inst{11} = iop{0};
}
+
+ def : SVE_4_Op_Pat<nxv8f16, op, nxv8f16, nxv16i8, nxv16i8, i32, !cast<Instruction>(NAME)>;
}
-multiclass sve2_fp8_dot_indexed_s<string asm>{
+multiclass sve2_fp8_dot_indexed_s<string asm, SDPatternOperator op> {
def NAME : sve_fp8_dot_indexed<{1, ?, ?, 0}, ZPR32, VectorIndexS32b, asm> {
bits<2> iop;
let Inst{20-19} = iop{1-0};
}
+
+ def : SVE_4_Op_Pat<nxv4f32, op, nxv4f32, nxv16i8, nxv16i8, i32, !cast<Instruction>(NAME)>;
}
// FP8 Look up table
diff --git a/llvm/test/CodeGen/AArch64/fp8-sve-fdot.ll b/llvm/test/CodeGen/AArch64/fp8-sve-fdot.ll
new file mode 100644
index 00000000000000..0cead19a74bfd5
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp8-sve-fdot.ll
@@ -0,0 +1,41 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve2,+fp8,+fp8dot2,+fp8dot4 < %s | FileCheck %s
+; RUN: llc -mattr=+sme,+fp8,+ssve-fp8dot2,+ssve-fp8dot4 --force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-linux"
+
+define <vscale x 4 x float> @fdot_4way(<vscale x 4 x float> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2) {
+; CHECK-LABEL: fdot_4way:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %r = call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2)
+ ret <vscale x 4 x float> %r
+}
+
+define <vscale x 8 x half> @fdot_2way(<vscale x 8 x half> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2) {
+; CHECK-LABEL: fdot_2way:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot z0.h, z1.b, z2.b
+; CHECK-NEXT: ret
+ %r = call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2)
+ ret <vscale x 8 x half> %r
+}
+
+define <vscale x 4 x float> @fdot_4way_lane(<vscale x 4 x float> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2) {
+; CHECK-LABEL: fdot_4way_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot z0.s, z1.b, z2.b[3]
+; CHECK-NEXT: ret
+ %r = call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.lane.nxv4f32(<vscale x 4 x float> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2, i32 3)
+ ret <vscale x 4 x float> %r
+}
+
+define <vscale x 8 x half> @fdot_2way_lane(<vscale x 8 x half> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2) {
+; CHECK-LABEL: fdot_2way_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot z0.h, z1.b, z2.b[5]
+; CHECK-NEXT: ret
+ %r = call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.lane.nxv8f16(<vscale x 8 x half> %a, <vscale x 16 x i8> %s1, <vscale x 16 x i8> %s2, i32 5)
+ ret <vscale x 8 x half> %r
+}
More information about the llvm-commits
mailing list