[clang] [llvm] [AArch64] Implement NEON FP8 intrinsics for fused multiply-add (indexed) (PR #123615)
Momchil Velikov via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 27 12:27:16 PST 2025
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/123615
From ae09723ecc1cc9bc2cbcef300b05aa2ce5ced448 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 17 Dec 2024 11:42:42 +0000
Subject: [PATCH 1/3] [AArch64] Add FP8 Neon intrinsics for dot-product
This patch adds the following intrinsics (a brief usage sketch follows the list):
float16x4_t vdot_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpm)
float16x8_t vdotq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm)
float16x4_t vdot_lane_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x4_t vdot_laneq_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x8_t vdotq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x8_t vdotq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
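For illustration only (not part of the patch), a minimal usage sketch of the f16 variants. It assumes the caller already holds a suitable fpm_t value for the FPMR register (the tests below do the same by taking it as a parameter) and that the code is built with +fp8dot2; names such as dot2_example are hypothetical:

#include <arm_neon.h>

// Hypothetical example: accumulate FP8 dot products into f16 lanes.
float16x4_t dot2_example(float16x4_t acc, mfloat8x8_t vn, mfloat8x8_t vm,
                         mfloat8x16_t vm_q, fpm_t fpm) {
  acc = vdot_f16_mf8_fpm(acc, vn, vm, fpm);            // non-indexed form
  acc = vdot_lane_f16_mf8_fpm(acc, vn, vm, 3, fpm);    // lane is a constant in [0, 3]
  acc = vdot_laneq_f16_mf8_fpm(acc, vn, vm_q, 7, fpm); // laneq takes a 128-bit vm, lane in [0, 7]
  return acc;
}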
[fixup] Remove unneeded argument (NFC)
[fixup] Update intrinsic declarations
[fixup] Add C++ runs to tests, remove some opt passes
---
clang/include/clang/Basic/arm_neon.td | 20 ++
clang/include/clang/Basic/arm_neon_incl.td | 2 +-
clang/lib/CodeGen/CGBuiltin.cpp | 44 +++
clang/lib/CodeGen/CodeGenFunction.h | 4 +
.../fp8-intrinsics/acle_neon_fp8_fdot.c | 254 ++++++++++++++++++
.../acle_neon_fp8_fdot.c | 54 ++++
llvm/include/llvm/IR/IntrinsicsAArch64.td | 21 ++
.../lib/Target/AArch64/AArch64InstrFormats.td | 82 +++---
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 14 +-
llvm/test/CodeGen/AArch64/fp8-neon-fdot.ll | 74 +++++
10 files changed, 529 insertions(+), 40 deletions(-)
create mode 100644 clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fdot.c
create mode 100644 clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fdot.c
create mode 100644 llvm/test/CodeGen/AArch64/fp8-neon-fdot.ll
diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td
index 9a6a77640ef5d3..c6609f312969ee 100644
--- a/clang/include/clang/Basic/arm_neon.td
+++ b/clang/include/clang/Basic/arm_neon.td
@@ -2141,6 +2141,26 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
def VCVTN_F8_F16 : VInst<"vcvt_mf8_f16_fpm", ".(>F)(>F)V", "mQm">;
}
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot2,neon" in {
+ def VDOT_F16_MF8 : VInst<"vdot_f16_mf8_fpm", "(>F)(>F)..V", "mQm">;
+
+ def VDOT_LANE_F16_MF8 : VInst<"vdot_lane_f16_mf8_fpm", "(>F)(>F)..IV", "m", [ImmCheck<3, ImmCheck0_3, 0>]>;
+ def VDOT_LANEQ_F16_MF8 : VInst<"vdot_laneq_f16_mf8_fpm", "(>F)(>F).QIV", "m", [ImmCheck<3, ImmCheck0_7, 0>]>;
+
+ def VDOTQ_LANE_F16_MF8 : VInst<"vdot_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
+ def VDOTQ_LANEQ_F16_MF8 : VInst<"vdot_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
+}
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
+ def VDOT_F32_MF8 : VInst<"vdot_f32_mf8_fpm", "(>>F)(>>F)..V", "mQm">;
+
+ def VDOT_LANE_F32_MF8 : VInst<"vdot_lane_f32_mf8_fpm", "(>>F)(>>F)..IV", "m", [ImmCheck<3, ImmCheck0_1, 0>]>;
+ def VDOT_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F).QIV", "m", [ImmCheck<3, ImmCheck0_3, 0>]>;
+
+ def VDOTQ_LANE_F32_MF8 : VInst<"vdot_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_1, 0>]>;
+ def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
+}
+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;
diff --git a/clang/include/clang/Basic/arm_neon_incl.td b/clang/include/clang/Basic/arm_neon_incl.td
index 91a2bf3020b9a3..b9b9d509c22512 100644
--- a/clang/include/clang/Basic/arm_neon_incl.td
+++ b/clang/include/clang/Basic/arm_neon_incl.td
@@ -302,7 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
-class VInst<string n, string p, string t> : Inst<n, p, t, OP_NONE> {}
+class VInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
// The following instruction classes are implemented via operators
// instead of builtins. As such these declarations are only used for
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0a06ce028a9160..b4b26eb84d5f92 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6766,6 +6766,24 @@ Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
return EmitNeonCall(F, Ops, name);
}
+llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
+ unsigned IID, bool ExtendLane, llvm::Type *RetTy,
+ SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
+
+ const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
+ RetTy->getPrimitiveSizeInBits();
+ llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
+ Ops[1]->getType()};
+ if (ExtendLane) {
+ auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
+ Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
+ Builder.getInt64(0));
+ }
+ llvm::Value *FPM =
+ EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
+ return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
+}
+
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
bool neg) {
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12761,6 +12779,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
unsigned Int;
bool ExtractLow = false;
+ bool ExtendLane = false;
switch (BuiltinID) {
default: return nullptr;
case NEON::BI__builtin_neon_vbsl_v:
@@ -14028,6 +14047,31 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn2, Ty,
Ops[1]->getType(), false, Ops, E, "vfcvtn2");
}
+
+ case NEON::BI__builtin_neon_vdot_f16_mf8_fpm:
+ case NEON::BI__builtin_neon_vdotq_f16_mf8_fpm:
+ return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2, false, HalfTy,
+ Ops, E, "fdot2");
+ case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
+ case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
+ ExtendLane = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
+ case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
+ return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
+ ExtendLane, HalfTy, Ops, E, "fdot2_lane");
+ case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
+ case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
+ return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
+ FloatTy, Ops, E, "fdot4");
+ case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
+ case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
+ ExtendLane = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
+ case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
+ return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
+ ExtendLane, FloatTy, Ops, E, "fdot4_lane");
case NEON::BI__builtin_neon_vamin_f16:
case NEON::BI__builtin_neon_vaminq_f16:
case NEON::BI__builtin_neon_vamin_f32:
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index a5416ab91c8d61..fd6d44b2579b92 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4699,6 +4699,10 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Type *Ty1, bool Extract,
SmallVectorImpl<llvm::Value *> &Ops,
const CallExpr *E, const char *name);
+ llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
+ llvm::Type *RetTy,
+ SmallVectorImpl<llvm::Value *> &Ops,
+ const CallExpr *E, const char *name);
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx,
const llvm::ElementCount &Count);
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx);
diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fdot.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fdot.c
new file mode 100644
index 00000000000000..4d2f5d550c4dcb
--- /dev/null
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fdot.c
@@ -0,0 +1,254 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s
+// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s -check-prefix CHECK-CXX
+
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -O3 -Werror -Wall -S -o /dev/null %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_neon.h>
+
+// CHECK-LABEL: define dso_local <4 x half> @test_vdot_f16(
+// CHECK-SAME: <4 x half> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x half> [[VD]] to <8 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT21_I:%.*]] = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.v4f16.v8i8(<4 x half> [[VD]], <8 x i8> [[VN]], <8 x i8> [[VM]])
+// CHECK-NEXT: ret <4 x half> [[FDOT21_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x half> @_Z13test_vdot_f1613__Float16x4_t13__Mfloat8x8_tS0_m(
+// CHECK-CXX-SAME: <4 x half> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <4 x half> [[VD]] to <8 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT21_I:%.*]] = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.v4f16.v8i8(<4 x half> [[VD]], <8 x i8> [[VN]], <8 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <4 x half> [[FDOT21_I]]
+//
+float16x4_t test_vdot_f16(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpmr) {
+ return vdot_f16_mf8_fpm(vd, vn, vm, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vdotq_f16(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT21_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.v8f16.v16i8(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <8 x half> [[FDOT21_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z14test_vdotq_f1613__Float16x8_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT21_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.v8f16.v16i8(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <8 x half> [[FDOT21_I]]
+//
+float16x8_t test_vdotq_f16(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpmr) {
+ return vdotq_f16_mf8_fpm(vd, vn, vm, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <4 x half> @test_vdot_lane_f16(
+// CHECK-SAME: <4 x half> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x half> [[VD]] to <8 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT2_LANE:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x half>
+// CHECK-NEXT: [[FDOT2_LANE1:%.*]] = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v4f16.v8i8(<4 x half> [[FDOT2_LANE]], <8 x i8> [[VN]], <16 x i8> [[TMP1]], i32 3)
+// CHECK-NEXT: ret <4 x half> [[FDOT2_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x half> @_Z18test_vdot_lane_f1613__Float16x4_t13__Mfloat8x8_tS0_m(
+// CHECK-CXX-SAME: <4 x half> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <4 x half> [[VD]] to <8 x i8>
+// CHECK-CXX-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT2_LANE:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x half>
+// CHECK-CXX-NEXT: [[FDOT2_LANE1:%.*]] = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v4f16.v8i8(<4 x half> [[FDOT2_LANE]], <8 x i8> [[VN]], <16 x i8> [[TMP1]], i32 3)
+// CHECK-CXX-NEXT: ret <4 x half> [[FDOT2_LANE1]]
+//
+float16x4_t test_vdot_lane_f16(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpmr) {
+ return vdot_lane_f16_mf8_fpm(vd, vn, vm, 3, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <4 x half> @test_vdot_laneq_f16(
+// CHECK-SAME: <4 x half> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x half> [[VD]] to <8 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT2_LANE:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x half>
+// CHECK-NEXT: [[FDOT2_LANE1:%.*]] = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v4f16.v8i8(<4 x half> [[FDOT2_LANE]], <8 x i8> [[VN]], <16 x i8> [[VM]], i32 7)
+// CHECK-NEXT: ret <4 x half> [[FDOT2_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x half> @_Z19test_vdot_laneq_f1613__Float16x4_t13__Mfloat8x8_t14__Mfloat8x16_tm(
+// CHECK-CXX-SAME: <4 x half> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <4 x half> [[VD]] to <8 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT2_LANE:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x half>
+// CHECK-CXX-NEXT: [[FDOT2_LANE1:%.*]] = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v4f16.v8i8(<4 x half> [[FDOT2_LANE]], <8 x i8> [[VN]], <16 x i8> [[VM]], i32 7)
+// CHECK-CXX-NEXT: ret <4 x half> [[FDOT2_LANE1]]
+//
+float16x4_t test_vdot_laneq_f16(float16x4_t vd, mfloat8x8_t vn, mfloat8x16_t vm, fpm_t fpmr) {
+ return vdot_laneq_f16_mf8_fpm(vd, vn, vm, 7, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vdotq_lane_f16(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT2_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-NEXT: [[FDOT2_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v8f16.v16i8(<8 x half> [[FDOT2_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 3)
+// CHECK-NEXT: ret <8 x half> [[FDOT2_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z19test_vdotq_lane_f1613__Float16x8_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT2_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-CXX-NEXT: [[FDOT2_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v8f16.v16i8(<8 x half> [[FDOT2_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 3)
+// CHECK-CXX-NEXT: ret <8 x half> [[FDOT2_LANE1]]
+//
+float16x8_t test_vdotq_lane_f16(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpmr) {
+ return vdotq_lane_f16_mf8_fpm(vd, vn, vm, 3, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vdotq_laneq_f16(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT2_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-NEXT: [[FDOT2_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v8f16.v16i8(<8 x half> [[FDOT2_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 7)
+// CHECK-NEXT: ret <8 x half> [[FDOT2_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z20test_vdotq_laneq_f1613__Float16x8_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT2_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-CXX-NEXT: [[FDOT2_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v8f16.v16i8(<8 x half> [[FDOT2_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 7)
+// CHECK-CXX-NEXT: ret <8 x half> [[FDOT2_LANE1]]
+//
+float16x8_t test_vdotq_laneq_f16(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpmr) {
+ return vdotq_laneq_f16_mf8_fpm(vd, vn, vm, 7, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <2 x float> @test_vdot_f32(
+// CHECK-SAME: <2 x float> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT4_I:%.*]] = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.v2f32.v8i8(<2 x float> [[VD]], <8 x i8> [[VN]], <8 x i8> [[VM]])
+// CHECK-NEXT: ret <2 x float> [[FDOT4_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <2 x float> @_Z13test_vdot_f3213__Float32x2_t13__Mfloat8x8_tS0_m(
+// CHECK-CXX-SAME: <2 x float> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT4_I:%.*]] = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.v2f32.v8i8(<2 x float> [[VD]], <8 x i8> [[VN]], <8 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <2 x float> [[FDOT4_I]]
+//
+float32x2_t test_vdot_f32(float32x2_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpmr) {
+ return vdot_f32_mf8_fpm(vd, vn, vm, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vdotq_f32(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT4_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.v4f32.v16i8(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <4 x float> [[FDOT4_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z14test_vdotq_f3213__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT4_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.v4f32.v16i8(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <4 x float> [[FDOT4_I]]
+//
+float32x4_t test_vdotq_f32(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpmr) {
+ return vdotq_f32_mf8_fpm(vd, vn, vm, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <2 x float> @test_vdot_lane_f32(
+// CHECK-SAME: <2 x float> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT4_LANE:%.*]] = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v2f32.v8i8(<2 x float> [[VD]], <8 x i8> [[VN]], <16 x i8> [[TMP0]], i32 1)
+// CHECK-NEXT: ret <2 x float> [[FDOT4_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <2 x float> @_Z18test_vdot_lane_f3213__Float32x2_t13__Mfloat8x8_tS0_m(
+// CHECK-CXX-SAME: <2 x float> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT4_LANE:%.*]] = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v2f32.v8i8(<2 x float> [[VD]], <8 x i8> [[VN]], <16 x i8> [[TMP0]], i32 1)
+// CHECK-CXX-NEXT: ret <2 x float> [[FDOT4_LANE]]
+//
+float32x2_t test_vdot_lane_f32(float32x2_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpmr) {
+ return vdot_lane_f32_mf8_fpm(vd, vn, vm, 1, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <2 x float> @test_vdot_laneq_f32(
+// CHECK-SAME: <2 x float> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT4_LANE:%.*]] = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v2f32.v8i8(<2 x float> [[VD]], <8 x i8> [[VN]], <16 x i8> [[VM]], i32 3)
+// CHECK-NEXT: ret <2 x float> [[FDOT4_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <2 x float> @_Z19test_vdot_laneq_f3213__Float32x2_t13__Mfloat8x8_t14__Mfloat8x16_tm(
+// CHECK-CXX-SAME: <2 x float> noundef [[VD:%.*]], <8 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT4_LANE:%.*]] = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v2f32.v8i8(<2 x float> [[VD]], <8 x i8> [[VN]], <16 x i8> [[VM]], i32 3)
+// CHECK-CXX-NEXT: ret <2 x float> [[FDOT4_LANE]]
+//
+float32x2_t test_vdot_laneq_f32(float32x2_t vd, mfloat8x8_t vn, mfloat8x16_t vm, fpm_t fpmr) {
+ return vdot_laneq_f32_mf8_fpm(vd, vn, vm, 3, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vdotq_lane_f32(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT4_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v4f32.v16i8(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 1)
+// CHECK-NEXT: ret <4 x float> [[FDOT4_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vdotq_lane_f3213__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT4_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v4f32.v16i8(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 1)
+// CHECK-CXX-NEXT: ret <4 x float> [[FDOT4_LANE]]
+//
+float32x4_t test_vdotq_lane_f32(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpmr) {
+ return vdotq_lane_f32_mf8_fpm(vd, vn, vm, 1, fpmr);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vdotq_laneq_f32(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: [[FDOT4_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v4f32.v16i8(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 3)
+// CHECK-NEXT: ret <4 x float> [[FDOT4_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z20test_vdotq_laneq_f3213__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-CXX-NEXT: [[FDOT4_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v4f32.v16i8(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 3)
+// CHECK-CXX-NEXT: ret <4 x float> [[FDOT4_LANE]]
+//
+float32x4_t test_vdotq_laneq_f32(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpmr) {
+ return vdotq_laneq_f32_mf8_fpm(vd, vn, vm, 3, fpmr);
+}
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fdot.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fdot.c
new file mode 100644
index 00000000000000..8bfe3ac26ab2c3
--- /dev/null
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fdot.c
@@ -0,0 +1,54 @@
+// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +faminmax -target-feature +fp8 -emit-llvm -verify %s -o /dev/null
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_neon.h>
+
+void test_features(float16x4_t vd4, float16x8_t vd8, float32x4_t va4, float32x2_t va2,
+ mfloat8x8_t v8, mfloat8x16_t v16, fpm_t fpm) {
+ (void) vdot_f16_mf8_fpm(vd4, v8, v8, fpm);
+// expected-error@-1 {{'vdot_f16_mf8_fpm' requires target feature 'fp8dot2'}}
+ (void) vdotq_f16_mf8_fpm(vd8, v16, v16, fpm);
+// expected-error@-1 {{'vdotq_f16_mf8_fpm' requires target feature 'fp8dot2'}}
+ (void) vdot_lane_f16_mf8_fpm(vd4, v8, v8, 3, fpm);
+// expected-error@-1 {{'__builtin_neon_vdot_lane_f16_mf8_fpm' needs target feature fp8dot2,neon}}
+ (void) vdot_laneq_f16_mf8_fpm(vd4, v8, v16, 7, fpm);
+// expected-error@-1 {{'__builtin_neon_vdot_laneq_f16_mf8_fpm' needs target feature fp8dot2,neon}}
+ (void) vdotq_lane_f16_mf8_fpm(vd8, v16, v8, 3, fpm);
+// expected-error@-1 {{'__builtin_neon_vdotq_lane_f16_mf8_fpm' needs target feature fp8dot2,neon}}
+ (void) vdotq_laneq_f16_mf8_fpm(vd8, v16, v16, 7, fpm);
+// expected-error@-1 {{'__builtin_neon_vdotq_laneq_f16_mf8_fpm' needs target feature fp8dot2,neon}}
+
+ (void) vdot_f32_mf8_fpm(va2, v8, v8, fpm);
+// expected-error@-1 {{'vdot_f32_mf8_fpm' requires target feature 'fp8dot4'}}
+ (void) vdotq_f32_mf8_fpm(va4, v16, v16, fpm);
+// expected-error@-1 {{'vdotq_f32_mf8_fpm' requires target feature 'fp8dot4'}}
+ (void) vdot_lane_f32_mf8_fpm(va2, v8, v8, 1, fpm);
+// expected-error@-1 {{'__builtin_neon_vdot_lane_f32_mf8_fpm' needs target feature fp8dot4,neon}}
+ (void) vdot_laneq_f32_mf8_fpm(va2, v8, v16, 3, fpm);
+// expected-error@-1 {{'__builtin_neon_vdot_laneq_f32_mf8_fpm' needs target feature fp8dot4,neon}}
+ (void) vdotq_lane_f32_mf8_fpm(va4, v16, v8, 1, fpm);
+// expected-error@-1 {{'__builtin_neon_vdotq_lane_f32_mf8_fpm' needs target feature fp8dot4,neon}}
+ (void) vdotq_laneq_f32_mf8_fpm(va4, v16, v16, 3, fpm);
+// expected-error@-1 {{'__builtin_neon_vdotq_laneq_f32_mf8_fpm' needs target feature fp8dot4,neon}}
+}
+
+void test_imm(float16x4_t vd4, float16x8_t vd8, float32x2_t va2, float32x4_t va4,
+ mfloat8x8_t v8, mfloat8x16_t v16, fpm_t fpm) {
+ (void) vdot_lane_f16_mf8_fpm(vd4, v8, v8, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 3]}}
+ (void) vdot_laneq_f16_mf8_fpm(vd4, v8, v16, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+ (void) vdotq_lane_f16_mf8_fpm(vd8, v16, v8, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 3]}}
+ (void) vdotq_laneq_f16_mf8_fpm(vd8, v16, v16, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+ (void) vdot_lane_f32_mf8_fpm(va2, v8, v8, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 1]}}
+ (void) vdot_laneq_f32_mf8_fpm(va2, v8, v16, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 3]}}
+ (void) vdotq_lane_f32_mf8_fpm(va4, v16, v8, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 1]}}
+ (void) vdotq_laneq_f32_mf8_fpm(va4, v16, v16, -1, fpm);
+ // expected-error@-1 {{argument value -1 is outside the valid range [0, 3]}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 31c9546376c820..395db293063f45 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1015,6 +1015,27 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat
llvm_anyvector_ty,
LLVMMatchType<1>],
[IntrReadMem, IntrInaccessibleMemOnly]>;
+
+ // Dot-product
+ class AdvSIMD_FP8_DOT_Intrinsic
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [LLVMMatchType<0>,
+ llvm_anyvector_ty,
+ LLVMMatchType<1>],
+ [IntrReadMem, IntrInaccessibleMemOnly]>;
+ class AdvSIMD_FP8_DOT_LANE_Intrinsic
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [LLVMMatchType<0>,
+ llvm_anyvector_ty,
+ llvm_v16i8_ty,
+ llvm_i32_ty],
+ [IntrReadMem, IntrInaccessibleMemOnly, ImmArg<ArgIndex<3>>]>;
+
+ def int_aarch64_neon_fp8_fdot2 : AdvSIMD_FP8_DOT_Intrinsic;
+ def int_aarch64_neon_fp8_fdot2_lane : AdvSIMD_FP8_DOT_LANE_Intrinsic;
+
+ def int_aarch64_neon_fp8_fdot4 : AdvSIMD_FP8_DOT_Intrinsic;
+ def int_aarch64_neon_fp8_fdot4_lane : AdvSIMD_FP8_DOT_LANE_Intrinsic;
}
def llvm_nxv1i1_ty : LLVMType<nxv1i1>;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 67b43664548457..dea2af16e3184a 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -6585,19 +6585,22 @@ multiclass SIMD_FP8_CVTN_F32<string asm, SDPatternOperator Op> {
(!cast<Instruction>(NAME # 2v16f8) V128:$_Rd, V128:$Rn, V128:$Rm)>;
}
-// TODO: Create a new Value Type v8f8 and v16f8
-multiclass SIMDThreeSameVectorDOT2<string asm> {
- def v4f16 : BaseSIMDThreeSameVectorDot<0b0, 0b0, 0b01, 0b1111, asm, ".4h", ".8b",
- V64, v4f16, v8i8, null_frag>;
- def v8f16 : BaseSIMDThreeSameVectorDot<0b1, 0b0, 0b01, 0b1111, asm, ".8h", ".16b",
- V128, v8f16, v16i8, null_frag>;
+multiclass SIMD_FP8_Dot2<string asm, SDPatternOperator op> {
+ let Uses = [FPMR, FPCR], mayLoad = 1 in {
+ def v4f16 : BaseSIMDThreeSameVectorDot<0b0, 0b0, 0b01, 0b1111, asm, ".4h", ".8b",
+ V64, v4f16, v8i8, op>;
+ def v8f16 : BaseSIMDThreeSameVectorDot<0b1, 0b0, 0b01, 0b1111, asm, ".8h", ".16b",
+ V128, v8f16, v16i8, op>;
+ }
}
-multiclass SIMDThreeSameVectorDOT4<string asm> {
- def v2f32 : BaseSIMDThreeSameVectorDot<0b0, 0b0, 0b00, 0b1111, asm, ".2s", ".8b",
- V64, v2f32, v8i8, null_frag>;
- def v4f32 : BaseSIMDThreeSameVectorDot<0b1, 0b0, 0b00, 0b1111, asm, ".4s", ".16b",
- V128, v4f32, v16i8, null_frag>;
+multiclass SIMD_FP8_Dot4<string asm, SDPatternOperator op> {
+ let Uses = [FPMR, FPCR], mayLoad = 1 in {
+ def v2f32 : BaseSIMDThreeSameVectorDot<0b0, 0b0, 0b00, 0b1111, asm, ".2s", ".8b",
+ V64, v2f32, v8i8, op>;
+ def v4f32 : BaseSIMDThreeSameVectorDot<0b1, 0b0, 0b00, 0b1111, asm, ".4s", ".16b",
+ V128, v4f32, v16i8, op>;
+ }
}
let mayRaiseFPException = 1, Uses = [FPCR] in
@@ -9138,15 +9141,16 @@ class BaseSIMDThreeSameVectorIndexS<bit Q, bit U, bits<2> size, bits<4> opc, str
string dst_kind, string lhs_kind, string rhs_kind,
RegisterOperand RegType,
ValueType AccumType, ValueType InputType,
+ AsmVectorIndexOpnd VIdx,
SDPatternOperator OpNode> :
BaseSIMDIndexedTied<Q, U, 0b0, size, opc, RegType, RegType, V128,
- VectorIndexS, asm, "", dst_kind, lhs_kind, rhs_kind,
+ VIdx, asm, "", dst_kind, lhs_kind, rhs_kind,
[(set (AccumType RegType:$dst),
(AccumType (OpNode (AccumType RegType:$Rd),
(InputType RegType:$Rn),
(InputType (bitconvert (AccumType
(AArch64duplane32 (v4i32 V128:$Rm),
- VectorIndexS:$idx)))))))]> {
+ VIdx:$idx)))))))]> {
bits<2> idx;
let Inst{21} = idx{0}; // L
let Inst{11} = idx{1}; // H
@@ -9155,17 +9159,24 @@ class BaseSIMDThreeSameVectorIndexS<bit Q, bit U, bits<2> size, bits<4> opc, str
multiclass SIMDThreeSameVectorDotIndex<bit U, bit Mixed, bits<2> size, string asm,
SDPatternOperator OpNode> {
def v8i8 : BaseSIMDThreeSameVectorIndexS<0, U, size, {0b111, Mixed}, asm, ".2s", ".8b", ".4b",
- V64, v2i32, v8i8, OpNode>;
+ V64, v2i32, v8i8, VectorIndexS, OpNode>;
def v16i8 : BaseSIMDThreeSameVectorIndexS<1, U, size, {0b111, Mixed}, asm, ".4s", ".16b", ".4b",
- V128, v4i32, v16i8, OpNode>;
+ V128, v4i32, v16i8, VectorIndexS, OpNode>;
}
-// TODO: The vectors v8i8 and v16i8 should be v8f8 and v16f8
-multiclass SIMDThreeSameVectorFP8DOT4Index<string asm> {
- def v8f8 : BaseSIMDThreeSameVectorIndexS<0b0, 0b0, 0b00, 0b0000, asm, ".2s", ".8b", ".4b",
- V64, v2f32, v8i8, null_frag>;
- def v16f8 : BaseSIMDThreeSameVectorIndexS<0b1, 0b0, 0b00, 0b0000, asm, ".4s", ".16b",".4b",
- V128, v4f32, v16i8, null_frag>;
+multiclass SIMD_FP8_Dot4_Index<string asm, SDPatternOperator op> {
+ let Uses = [FPMR, FPCR], mayLoad = 1 in {
+ def v2f32 : BaseSIMDThreeSameVectorIndexS<0b0, 0b0, 0b00, 0b0000, asm, ".2s", ".8b", ".4b",
+ V64, v2f32, v8i8, VectorIndexS32b_timm, null_frag>;
+ def v4f32 : BaseSIMDThreeSameVectorIndexS<0b1, 0b0, 0b00, 0b0000, asm, ".4s", ".16b",".4b",
+ V128, v4f32, v16i8, VectorIndexS32b_timm, null_frag>;
+ }
+
+ def : Pat<(v2f32 (op (v2f32 V64:$Rd), (v8i8 V64:$Rn), (v16i8 V128:$Rm), VectorIndexS32b_timm:$Idx)),
+ (!cast<Instruction>(NAME # v2f32) $Rd, $Rn, $Rm, $Idx)>;
+
+ def : Pat<(v4f32 (op (v4f32 V128:$Rd), (v16i8 V128:$Rn), (v16i8 V128:$Rm), VectorIndexS32b_timm:$Idx)),
+ (!cast<Instruction>(NAME # v4f32) $Rd, $Rn, $Rm, $Idx)>;
}
// ARMv8.2-A Fused Multiply Add-Long Instructions (Indexed)
@@ -9174,14 +9185,15 @@ class BaseSIMDThreeSameVectorIndexH<bit Q, bit U, bits<2> sz, bits<4> opc, strin
string dst_kind, string lhs_kind,
string rhs_kind, RegisterOperand RegType,
RegisterOperand RegType_lo, ValueType AccumType,
- ValueType InputType, SDPatternOperator OpNode> :
+ ValueType InputType, AsmVectorIndexOpnd VIdx,
+ SDPatternOperator OpNode> :
BaseSIMDIndexedTied<Q, U, 0, sz, opc, RegType, RegType, RegType_lo,
- VectorIndexH, asm, "", dst_kind, lhs_kind, rhs_kind,
+ VIdx, asm, "", dst_kind, lhs_kind, rhs_kind,
[(set (AccumType RegType:$dst),
(AccumType (OpNode (AccumType RegType:$Rd),
(InputType RegType:$Rn),
(InputType (AArch64duplane16 (v8f16 V128_lo:$Rm),
- VectorIndexH:$idx)))))]> {
+ VIdx:$idx)))))]> {
// idx = H:L:M
bits<3> idx;
let Inst{11} = idx{2}; // H
@@ -9192,19 +9204,25 @@ class BaseSIMDThreeSameVectorIndexH<bit Q, bit U, bits<2> sz, bits<4> opc, strin
multiclass SIMDThreeSameVectorFMLIndex<bit U, bits<4> opc, string asm,
SDPatternOperator OpNode> {
def v4f16 : BaseSIMDThreeSameVectorIndexH<0, U, 0b10, opc, asm, ".2s", ".2h", ".h",
- V64, V128_lo, v2f32, v4f16, OpNode>;
+ V64, V128_lo, v2f32, v4f16, VectorIndexH, OpNode>;
def v8f16 : BaseSIMDThreeSameVectorIndexH<1, U, 0b10, opc, asm, ".4s", ".4h", ".h",
- V128, V128_lo, v4f32, v8f16, OpNode>;
+ V128, V128_lo, v4f32, v8f16, VectorIndexH, OpNode>;
}
//----------------------------------------------------------------------------
// FP8 Advanced SIMD vector x indexed element
-// TODO: Replace value types v8i8 and v16i8 by v8f8 and v16f8
-multiclass SIMDThreeSameVectorFP8DOT2Index<string asm> {
- def v4f16 : BaseSIMDThreeSameVectorIndexH<0b0, 0b0, 0b01, 0b0000, asm, ".4h", ".8b", ".2b",
- V64, V128_lo, v4f16, v8i8, null_frag>;
- def v8f16 : BaseSIMDThreeSameVectorIndexH<0b1, 0b0, 0b01, 0b0000, asm, ".8h", ".16b", ".2b",
- V128, V128_lo, v8f16, v8i16, null_frag>;
+multiclass SIMD_FP8_Dot2_Index<string asm, SDPatternOperator op> {
+ let Uses = [FPMR, FPCR], mayLoad = 1 in {
+ def v4f16 : BaseSIMDThreeSameVectorIndexH<0b0, 0b0, 0b01, 0b0000, asm, ".4h", ".8b", ".2b",
+ V64, V128_lo, v4f16, v8i8, VectorIndexH32b_timm, null_frag>;
+ def v8f16 : BaseSIMDThreeSameVectorIndexH<0b1, 0b0, 0b01, 0b0000, asm, ".8h", ".16b", ".2b",
+ V128, V128_lo, v8f16, v16i8, VectorIndexH32b_timm, null_frag>;
+ }
+ def : Pat<(v4f16 (op (v4f16 V64:$Rd), (v8i8 V64:$Rn), (v16i8 V128_lo:$Rm), VectorIndexH32b_timm:$Idx)),
+ (!cast<Instruction>(NAME # v4f16) $Rd, $Rn, $Rm, $Idx)>;
+
+ def : Pat<(v8f16 (op (v8f16 V128:$Rd), (v16i8 V128:$Rn), (v16i8 V128_lo:$Rm), VectorIndexH32b_timm:$Idx)),
+ (!cast<Instruction>(NAME # v8f16) $Rd, $Rn, $Rm, $Idx)>;
}
multiclass SIMDFPIndexed<bit U, bits<4> opc, string asm,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 881af6eb951177..364566f63bca10 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1497,7 +1497,7 @@ class BaseSIMDSUDOTIndex<bit Q, string dst_kind, string lhs_kind,
ValueType AccumType, ValueType InputType>
: BaseSIMDThreeSameVectorIndexS<Q, 0, 0b00, 0b1111, "sudot", dst_kind,
lhs_kind, rhs_kind, RegType, AccumType,
- InputType, null_frag> {
+ InputType, VectorIndexS, null_frag> {
let Pattern = [(set (AccumType RegType:$dst),
(AccumType (AArch64usdot (AccumType RegType:$Rd),
(InputType (bitconvert (AccumType
@@ -10369,14 +10369,14 @@ let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in {
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">;
} // End let Predicates = [HasFP8FMA]
-let Uses = [FPMR, FPCR], Predicates = [HasFP8DOT2] in {
- defm FDOTlane : SIMDThreeSameVectorFP8DOT2Index<"fdot">;
- defm FDOT : SIMDThreeSameVectorDOT2<"fdot">;
+let Predicates = [HasFP8DOT2] in {
+ defm FDOTlane : SIMD_FP8_Dot2_Index<"fdot", int_aarch64_neon_fp8_fdot2_lane>;
+ defm FDOT : SIMD_FP8_Dot2<"fdot", int_aarch64_neon_fp8_fdot2>;
} // End let Predicates = [HasFP8DOT2]
-let Uses = [FPMR, FPCR], Predicates = [HasFP8DOT4] in {
- defm FDOTlane : SIMDThreeSameVectorFP8DOT4Index<"fdot">;
- defm FDOT : SIMDThreeSameVectorDOT4<"fdot">;
+let Predicates = [HasFP8DOT4] in {
+ defm FDOTlane : SIMD_FP8_Dot4_Index<"fdot", int_aarch64_neon_fp8_fdot4_lane>;
+ defm FDOT : SIMD_FP8_Dot4<"fdot", int_aarch64_neon_fp8_fdot4>;
} // End let Predicates = [HasFP8DOT4]
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/fp8-neon-fdot.ll b/llvm/test/CodeGen/AArch64/fp8-neon-fdot.ll
new file mode 100644
index 00000000000000..b7a35c5fddf170
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp8-neon-fdot.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux -mattr=+neon,+fp8dot2,+fp8dot4 < %s | FileCheck %s
+
+define <4 x half> @test_fdot_f16(<4 x half> %vd, <8 x i8> %vn, <8 x i8> %vm) {
+; CHECK-LABEL: test_fdot_f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.4h, v1.8b, v2.8b
+; CHECK-NEXT: ret
+ %res = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.v4f16.v8i8(<4 x half> %vd, <8 x i8> %vn, <8 x i8> %vm)
+ ret <4 x half> %res
+}
+
+define <8 x half> @test_fdotq_f16(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fdotq_f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.8h, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %res = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.v8f16.v16i8(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm)
+ ret <8 x half> %res
+}
+
+define <4 x half> @test_fdot_lane_f16(<4 x half> %vd, <8 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fdot_lane_f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.4h, v1.8b, v2.2b[0]
+; CHECK-NEXT: ret
+ %res = call <4 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v4f16.v8i8(<4 x half> %vd, <8 x i8> %vn, <16 x i8> %vm, i32 0)
+ ret <4 x half> %res
+}
+
+define <8 x half> @test_fdotq_lane_f16(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fdotq_lane_f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.8h, v1.16b, v2.2b[7]
+; CHECK-NEXT: ret
+ %res = call <8 x half> @llvm.aarch64.neon.fp8.fdot2.lane.v8f16.v16i8(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 7)
+ ret <8 x half> %res
+}
+
+define <2 x float> @test_fdot_f32(<2 x float> %vd, <8 x i8> %vn, <8 x i8> %vm) {
+; CHECK-LABEL: test_fdot_f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.2s, v1.8b, v2.8b
+; CHECK-NEXT: ret
+ %res = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.v2f32.v8i8(<2 x float> %vd, <8 x i8> %vn, <8 x i8> %vm)
+ ret <2 x float> %res
+}
+
+define <4 x float> @test_fdotq_f32(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fdotq_f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.4s, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %res = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.v4f32.v16i8(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm)
+ ret <4 x float> %res
+}
+
+define <2 x float> @test_fdot_lane_f32(<2 x float> %vd, <8 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fdot_lane_f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.2s, v1.8b, v2.4b[0]
+; CHECK-NEXT: ret
+ %res = call <2 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v2f32.v8i8(<2 x float> %vd, <8 x i8> %vn, <16 x i8> %vm, i32 0)
+ ret <2 x float> %res
+}
+
+define <4 x float> @test_fdotq_lane_f32(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fdotq_lane_f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fdot v0.4s, v1.16b, v2.4b[3]
+; CHECK-NEXT: ret
+ %res = call <4 x float> @llvm.aarch64.neon.fp8.fdot4.lane.v4f32.v16i8(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 3)
+ ret <4 x float> %res
+}
From e677f8302bb65082495600b0a5642e63fa515530 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 17 Dec 2024 17:10:38 +0000
Subject: [PATCH 2/3] [AArch64] Implement NEON FP8 fused multiply-add
intrinsics (non-indexed)
This patch adds the following intrinsics (a brief usage sketch follows the list):
float16x8_t vmlalbq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
float16x8_t vmlaltq_f16_mf8_fpm(float16x8_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
float32x4_t vmlallbbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
float32x4_t vmlallbtq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
float32x4_t vmlalltbq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
float32x4_t vmlallttq_f32_mf8_fpm(float32x4_t, mfloat8x16_t, mfloat8x16_t, fpm_t)
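For illustration only (not part of the patch), a minimal usage sketch assuming +fp8fma and an already-constructed fpm_t, mirroring the tests below; the function names are hypothetical:

#include <arm_neon.h>

// Hypothetical example: the b/t (and bb/bt/tb/tt) suffixes select which FP8
// source elements feed the widened accumulator lanes, matching the
// FMLALB/FMLALT and FMLALL{BB,BT,TB,TT} instructions.
float16x8_t fmla_f16_example(float16x8_t acc, mfloat8x16_t vn, mfloat8x16_t vm,
                             fpm_t fpm) {
  acc = vmlalbq_f16_mf8_fpm(acc, vn, vm, fpm);
  return vmlaltq_f16_mf8_fpm(acc, vn, vm, fpm);
}

float32x4_t fmla_f32_example(float32x4_t acc, mfloat8x16_t vn, mfloat8x16_t vm,
                             fpm_t fpm) {
  acc = vmlallbbq_f32_mf8_fpm(acc, vn, vm, fpm);
  return vmlallttq_f32_mf8_fpm(acc, vn, vm, fpm);
}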
[fixup] Update intrinsic definitions
[fixup] Remove some opt passes from RUN lines
---
clang/include/clang/Basic/arm_neon.td | 10 ++
clang/lib/CodeGen/CGBuiltin.cpp | 43 +++++--
clang/lib/CodeGen/CodeGenFunction.h | 4 +-
.../fp8-intrinsics/acle_neon_fp8_fmla.c | 121 ++++++++++++++++++
.../acle_neon_fp8_fmla.c | 22 ++++
llvm/include/llvm/IR/IntrinsicsAArch64.td | 17 +++
.../lib/Target/AArch64/AArch64InstrFormats.td | 9 +-
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 14 +-
llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll | 56 ++++++++
9 files changed, 274 insertions(+), 22 deletions(-)
create mode 100644 clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c
create mode 100644 clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c
create mode 100644 llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll
diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td
index c6609f312969ee..7e7faa68c55692 100644
--- a/clang/include/clang/Basic/arm_neon.td
+++ b/clang/include/clang/Basic/arm_neon.td
@@ -2161,6 +2161,16 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
}
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
+ def VMLALB_F16_F8 : VInst<"vmlalb_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+ def VMLALT_F16_F8 : VInst<"vmlalt_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+
+ def VMLALLBB_F32_F8 : VInst<"vmlallbb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+ def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+ def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+ def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+}
+
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b4b26eb84d5f92..8dbc8bfff95a4b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6759,11 +6759,14 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
return Builder.CreateCall(F, Ops, name);
}
-Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
+Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
+ ArrayRef<llvm::Type *> Tys,
SmallVectorImpl<Value *> &Ops,
- Value *FPM, const char *name) {
+ const CallExpr *E, const char *name) {
+ llvm::Value *FPM =
+ EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
- return EmitNeonCall(F, Ops, name);
+ return EmitNeonCall(CGM.getIntrinsic(IID, Tys), Ops, name);
}
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
@@ -6779,9 +6782,7 @@ llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
Builder.getInt64(0));
}
- llvm::Value *FPM =
- EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
- return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
+ return EmitFP8NeonCall(IID, Tys, Ops, E, name);
}
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
@@ -6802,9 +6803,7 @@ Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
}
- llvm::Value *FPM =
- EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
- return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
+ return EmitFP8NeonCall(IID, Tys, Ops, E, name);
}
// Right-shift a vector by a constant.
@@ -14072,6 +14071,32 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
+
+ case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
+ return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
+ {llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
+ "vmlal");
+ case NEON::BI__builtin_neon_vmlaltq_f16_mf8_fpm:
+ return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalt,
+ {llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
+ "vmlal");
+ case NEON::BI__builtin_neon_vmlallbbq_f32_mf8_fpm:
+ return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbb,
+ {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
+ "vmlall");
+ case NEON::BI__builtin_neon_vmlallbtq_f32_mf8_fpm:
+ return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbt,
+ {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
+ "vmlall");
+ case NEON::BI__builtin_neon_vmlalltbq_f32_mf8_fpm:
+ return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltb,
+ {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
+ "vmlall");
+ case NEON::BI__builtin_neon_vmlallttq_f32_mf8_fpm:
+ return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
+ {llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
+ "vmlall");
+
case NEON::BI__builtin_neon_vamin_f16:
case NEON::BI__builtin_neon_vaminq_f16:
case NEON::BI__builtin_neon_vamin_f32:
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index fd6d44b2579b92..92aee1a05764cd 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4692,9 +4692,9 @@ class CodeGenFunction : public CodeGenTypeCache {
SmallVectorImpl<llvm::Value*> &O,
const char *name,
unsigned shift = 0, bool rightshift = false);
- llvm::Value *EmitFP8NeonCall(llvm::Function *F,
+ llvm::Value *EmitFP8NeonCall(unsigned IID, ArrayRef<llvm::Type *> Tys,
SmallVectorImpl<llvm::Value *> &O,
- llvm::Value *FPM, const char *name);
+ const CallExpr *E, const char *name);
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
llvm::Type *Ty1, bool Extract,
SmallVectorImpl<llvm::Value *> &Ops,
diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c
new file mode 100644
index 00000000000000..b0b96a4b6725da
--- /dev/null
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c
@@ -0,0 +1,121 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s
+// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s -check-prefix CHECK-CXX
+
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +fp8 -target-feature +fp8fma -disable-O0-optnone -Werror -Wall -S -o /dev/null %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_neon.h>
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalb13__Float16x8_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
+//
+float16x8_t test_vmlalb(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlalbq_f16_mf8_fpm(vd, vn, vm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <8 x half> [[VMLAL1_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z11test_vmlalt13__Float16x8_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLAL1_I:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL1_I]]
+//
+float16x8_t test_vmlalt(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlaltq_f16_mf8_fpm(vd, vn, vm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbb13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+float32x4_t test_vmlallbb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlallbbq_f32_mf8_fpm(vd, vn, vm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlallbt13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+float32x4_t test_vmlallbt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlallbtq_f32_mf8_fpm(vd, vn, vm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltb13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+float32x4_t test_vmlalltb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlalltbq_f32_mf8_fpm(vd, vn, vm, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z13test_vmlalltt13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_I:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]])
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_I]]
+//
+float32x4_t test_vmlalltt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlallttq_f32_mf8_fpm(vd, vn, vm, fpm);
+}
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c
new file mode 100644
index 00000000000000..fcdd14e583101e
--- /dev/null
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -triple aarch64-linux-gnu -target-feature +neon -target-feature +bf16 -target-feature +faminmax -target-feature +fp8 -emit-llvm -verify %s -o /dev/null
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_neon.h>
+
+void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) {
+
+ (void) vmlalbq_f16_mf8_fpm(a, u, u, fpm);
+ // expected-error@-1 {{'vmlalbq_f16_mf8_fpm' requires target feature 'fp8fma'}}
+ (void) vmlaltq_f16_mf8_fpm(a, u, u, fpm);
+ // expected-error@-1 {{'vmlaltq_f16_mf8_fpm' requires target feature 'fp8fma'}}
+ (void) vmlallbbq_f32_mf8_fpm(b, u, u, fpm);
+ // expected-error@-1 {{'vmlallbbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
+ (void) vmlallbtq_f32_mf8_fpm(b, u, u, fpm);
+ // expected-error@-1 {{'vmlallbtq_f32_mf8_fpm' requires target feature 'fp8fma'}}
+ (void) vmlalltbq_f32_mf8_fpm(b, u, u, fpm);
+ // expected-error@-1 {{'vmlalltbq_f32_mf8_fpm' requires target feature 'fp8fma'}}
+ (void) vmlallttq_f32_mf8_fpm(b, u, u, fpm);
+ // expected-error@-1 {{'vmlallttq_f32_mf8_fpm' requires target feature 'fp8fma'}}
+}
+
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 395db293063f45..9244c549dc469b 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1036,6 +1036,23 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat
def int_aarch64_neon_fp8_fdot4 : AdvSIMD_FP8_DOT_Intrinsic;
def int_aarch64_neon_fp8_fdot4_lane : AdvSIMD_FP8_DOT_LANE_Intrinsic;
+
+
+// Fused multiply-add
+ class AdvSIMD_FP8_FMLA_Intrinsic
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [LLVMMatchType<0>,
+ llvm_v16i8_ty,
+ llvm_v16i8_ty],
+ [IntrReadMem, IntrInaccessibleMemOnly]>;
+
+ def int_aarch64_neon_fp8_fmlalb : AdvSIMD_FP8_FMLA_Intrinsic;
+ def int_aarch64_neon_fp8_fmlalt : AdvSIMD_FP8_FMLA_Intrinsic;
+
+ def int_aarch64_neon_fp8_fmlallbb : AdvSIMD_FP8_FMLA_Intrinsic;
+ def int_aarch64_neon_fp8_fmlallbt : AdvSIMD_FP8_FMLA_Intrinsic;
+ def int_aarch64_neon_fp8_fmlalltb : AdvSIMD_FP8_FMLA_Intrinsic;
+ def int_aarch64_neon_fp8_fmlalltt : AdvSIMD_FP8_FMLA_Intrinsic;
}
def llvm_nxv1i1_ty : LLVMType<nxv1i1>;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index dea2af16e3184a..38ab1ea785c706 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -6519,14 +6519,15 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm,
v4f32, v8f16, OpNode>;
}
-multiclass SIMDThreeSameVectorMLA<bit Q, string asm>{
+multiclass SIMDThreeSameVectorMLA<bit Q, string asm, SDPatternOperator op> {
+
def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b",
- V128, v8f16, v16i8, null_frag>;
+ V128, v8f16, v16i8, op>;
}
-multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm>{
+multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm, SDPatternOperator op> {
def v4f32 : BaseSIMDThreeSameVectorDot<Q, 0b0, sz, 0b1000, asm, ".4s", ".16b",
- V128, v4f32, v16i8, null_frag>;
+ V128, v4f32, v16i8, op>;
}
// FP8 assembly/disassembly classes
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 364566f63bca10..ff65cd1fca9e1c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -10353,7 +10353,7 @@ let Predicates = [HasNEON, HasFAMINMAX] in {
defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", AArch64famin>;
} // End let Predicates = [HasNEON, HasFAMINMAX]
-let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in {
+let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in {
defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">;
defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">;
defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">;
@@ -10361,12 +10361,12 @@ let Uses = [FPMR, FPCR], Predicates = [HasFP8FMA] in {
defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">;
defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">;
- defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb">;
- defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt">;
- defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb">;
- defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt">;
- defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb">;
- defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">;
+ defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb>;
+ defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt>;
+ defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb>;
+ defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt", int_aarch64_neon_fp8_fmlallbt>;
+ defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb", int_aarch64_neon_fp8_fmlalltb>;
+ defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt", int_aarch64_neon_fp8_fmlalltt>;
} // End let Predicates = [HasFP8FMA]
let Predicates = [HasFP8DOT2] in {
diff --git a/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll
new file mode 100644
index 00000000000000..008069ff63761f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux -mattr=+neon,+fp8fma < %s | FileCheck %s
+
+define <8 x half> @test_fmlalb(<8 x half> %d, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: test_fmlalb:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalb v0.8h, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.v8f16(<8 x half> %d, <16 x i8> %a, <16 x i8> %b)
+ ret <8 x half> %r
+}
+
+define <8 x half> @test_fmlalt(<8 x half> %d, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: test_fmlalt:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalt v0.8h, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.v8f16(<8 x half> %d, <16 x i8> %a, <16 x i8> %b)
+ ret <8 x half> %r
+}
+
+define <4 x float> @test_fmlallbb(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: test_fmlallbb:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlallbb v0.4s, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b)
+ ret <4 x float> %r
+}
+
+define <4 x float> @test_fmlallbt(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: test_fmlallbt:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlallbt v0.4s, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b)
+ ret <4 x float> %r
+}
+
+define <4 x float> @test_fmlalltb(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: test_fmlalltb:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalltb v0.4s, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b)
+ ret <4 x float> %r
+}
+
+define <4 x float> @test_fmlalltt(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: test_fmlalltt:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalltt v0.4s, v1.16b, v2.16b
+; CHECK-NEXT: ret
+ %r = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b)
+ ret <4 x float> %r
+}
>From 4a79123a00981e8f65b151307c4f658574bd47e4 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 18 Dec 2024 10:52:34 +0000
Subject: [PATCH 3/3] [AArch64] Implement NEON FP8 intrinsics for fused
multiply-add (indexed)
This patch adds the following intrinsics (a brief usage sketch follows below):
* Floating-point multiply-add long to half-precision (vector, by element)
float16x8_t vmlalbq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x8_t vmlalbq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x8_t vmlaltq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x8_t vmlaltq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
* Floating-point multiply-add long-long to single-precision (vector, by element)
float32x4_t vmlallbbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vmlallbbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vmlallbtq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vmlallbtq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vmlalltbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vmlalltbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vmlallttq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vmlallttq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
[fixup] Update intrinsics definitions
[fixup] Regenerate tests
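For illustration only (not part of the patch), a minimal usage sketch of the indexed forms follows. It assumes a compiler with the 'neon' and 'fp8fma' target features enabled and takes the fpm_t mode value as an opaque parameter, since constructing it is outside the scope of this change; the function names example_f16_acc and example_f32_acc are made up for the sketch, while the intrinsic names, types and lane ranges come from the definitions above.

#include <arm_neon.h>

// Hypothetical example: accumulate FP8 products into wider accumulators,
// selecting a single element of the second source vector.
float16x8_t example_f16_acc(float16x8_t acc, mfloat8x16_t vn,
                            mfloat8x8_t vm_lo, mfloat8x16_t vm_full,
                            fpm_t fpm) {
  // _lane variants index a 64-bit vm (lanes 0-7), _laneq variants index a
  // 128-bit vm (lanes 0-15); the lane must be a compile-time constant.
  acc = vmlalbq_lane_f16_mf8_fpm(acc, vn, vm_lo, 3, fpm);
  acc = vmlaltq_laneq_f16_mf8_fpm(acc, vn, vm_full, 11, fpm);
  return acc;
}

float32x4_t example_f32_acc(float32x4_t acc, mfloat8x16_t vn,
                            mfloat8x8_t vm_lo, fpm_t fpm) {
  // Long-long form: FP8 products are accumulated directly into
  // single precision.
  return vmlallbbq_lane_f32_mf8_fpm(acc, vn, vm_lo, 0, fpm);
}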
---
clang/include/clang/Basic/arm_neon.td | 14 +
clang/lib/CodeGen/CGBuiltin.cpp | 66 ++++-
clang/lib/CodeGen/CodeGenFunction.h | 6 +-
.../fp8-intrinsics/acle_neon_fp8_fmla.c | 244 ++++++++++++++++++
.../acle_neon_fp8_fmla.c | 29 ++-
llvm/include/llvm/IR/IntrinsicsAArch64.td | 16 ++
.../lib/Target/AArch64/AArch64InstrFormats.td | 24 +-
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 16 +-
llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll | 54 ++++
9 files changed, 445 insertions(+), 24 deletions(-)
diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td
index 7e7faa68c55692..577379e4921605 100644
--- a/clang/include/clang/Basic/arm_neon.td
+++ b/clang/include/clang/Basic/arm_neon.td
@@ -2169,6 +2169,20 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+
+ def VMLALB_F16_F8_LANE : VInst<"vmlalb_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
+ def VMLALB_F16_F8_LANEQ : VInst<"vmlalb_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+ def VMLALT_F16_F8_LANE : VInst<"vmlalt_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
+ def VMLALT_F16_F8_LANEQ : VInst<"vmlalt_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+
+ def VMLALLBB_F32_F8_LANE : VInst<"vmlallbb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
+ def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+ def VMLALLBT_F32_F8_LANE : VInst<"vmlallbt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
+ def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+ def VMLALLTB_F32_F8_LANE : VInst<"vmlalltb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
+ def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+ def VMLALLTT_F32_F8_LANE : VInst<"vmlalltt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
+ def VMLALLTT_F32_F8_LANEQ : VInst<"vmlalltt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
}
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 8dbc8bfff95a4b..9ae133421d8dac 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6770,14 +6770,14 @@ Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
}
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
- unsigned IID, bool ExtendLane, llvm::Type *RetTy,
+ unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
RetTy->getPrimitiveSizeInBits();
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
Ops[1]->getType()};
- if (ExtendLane) {
+ if (ExtendLaneArg) {
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
Builder.getInt64(0));
@@ -6785,6 +6785,21 @@ llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
}
+llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall(
+ unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
+ SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
+
+ if (ExtendLaneArg) {
+ auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
+ Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
+ Builder.getInt64(0));
+ }
+ const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
+ RetTy->getPrimitiveSizeInBits();
+ return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)},
+ Ops, E, name);
+}
+
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
bool neg) {
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12778,7 +12793,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
unsigned Int;
bool ExtractLow = false;
- bool ExtendLane = false;
+ bool ExtendLaneArg = false;
switch (BuiltinID) {
default: return nullptr;
case NEON::BI__builtin_neon_vbsl_v:
@@ -14053,24 +14068,24 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
Ops, E, "fdot2");
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
- ExtendLane = true;
+ ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
- ExtendLane, HalfTy, Ops, E, "fdot2_lane");
+ ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane");
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
FloatTy, Ops, E, "fdot4");
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
- ExtendLane = true;
+ ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
- ExtendLane, FloatTy, Ops, E, "fdot4_lane");
+ ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane");
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
@@ -14096,7 +14111,42 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
"vmlall");
-
+ case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm:
+ ExtendLaneArg = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm:
+ return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane,
+ ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
+ case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm:
+ ExtendLaneArg = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm:
+ return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane,
+ ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
+ case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm:
+ ExtendLaneArg = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm:
+ return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane,
+ ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
+ case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm:
+ ExtendLaneArg = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm:
+ return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane,
+ ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
+ case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm:
+ ExtendLaneArg = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm:
+ return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane,
+ ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
+ case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm:
+ ExtendLaneArg = true;
+ LLVM_FALLTHROUGH;
+ case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm:
+ return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane,
+ ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
case NEON::BI__builtin_neon_vamin_f16:
case NEON::BI__builtin_neon_vaminq_f16:
case NEON::BI__builtin_neon_vamin_f32:
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 92aee1a05764cd..f70e73fdab0e39 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4699,7 +4699,11 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Type *Ty1, bool Extract,
SmallVectorImpl<llvm::Value *> &Ops,
const CallExpr *E, const char *name);
- llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
+ llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg,
+ llvm::Type *RetTy,
+ SmallVectorImpl<llvm::Value *> &Ops,
+ const CallExpr *E, const char *name);
+ llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg,
llvm::Type *RetTy,
SmallVectorImpl<llvm::Value *> &Ops,
const CallExpr *E, const char *name);
diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c
index b0b96a4b6725da..736538073cb391 100644
--- a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c
@@ -119,3 +119,247 @@ float32x4_t test_vmlalltb(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_
float32x4_t test_vmlalltt(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
return vmlallttq_f32_mf8_fpm(vd, vn, vm, fpm);
}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb_lane(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 0)
+// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z16test_vmlalb_lane13__Float16x8_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 0)
+// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+float16x8_t test_vmlalb_lane(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) {
+ return vmlalbq_lane_f16_mf8_fpm(vd, vn, vm, 0, fpm);
+}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vmlalb_laneq(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0)
+// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z17test_vmlalb_laneq13__Float16x8_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0)
+// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+float16x8_t test_vmlalb_laneq(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlalbq_laneq_f16_mf8_fpm(vd, vn, vm, 0, fpm);
+}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt_lane(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 7)
+// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z16test_vmlalt_lane13__Float16x8_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[TMP1]], i32 7)
+// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+float16x8_t test_vmlalt_lane(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) {
+ return vmlaltq_lane_f16_mf8_fpm(vd, vn, vm, 7, fpm);
+}
+
+// CHECK-LABEL: define dso_local <8 x half> @test_vmlalt_laneq(
+// CHECK-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15)
+// CHECK-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <8 x half> @_Z17test_vmlalt_laneq13__Float16x8_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <8 x half> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x half> [[VD]] to <16 x i8>
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLAL_LANE:%.*]] = bitcast <16 x i8> [[TMP0]] to <8 x half>
+// CHECK-CXX-NEXT: [[VMLAL_LANE1:%.*]] = call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane.v8f16(<8 x half> [[VMLAL_LANE]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15)
+// CHECK-CXX-NEXT: ret <8 x half> [[VMLAL_LANE1]]
+//
+float16x8_t test_vmlalt_laneq(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlaltq_laneq_f16_mf8_fpm(vd, vn, vm, 15, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb_lane(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 0)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlallbb_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 0)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlallbb_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) {
+ return vmlallbbq_lane_f32_mf8_fpm(vd, vn, vm, 0, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbb_laneq(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlallbb_laneq13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 0)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlallbb_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlallbbq_laneq_f32_mf8_fpm(vd, vn, vm, 0, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt_lane(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 3)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlallbt_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 3)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlallbt_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) {
+ return vmlallbtq_lane_f32_mf8_fpm(vd, vn, vm, 3, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlallbt_laneq(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 3)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlallbt_laneq13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 3)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlallbt_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlallbtq_laneq_f32_mf8_fpm(vd, vn, vm, 3, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb_lane(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlalltb_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlalltb_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) {
+ return vmlalltbq_lane_f32_mf8_fpm(vd, vn, vm, 7, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltb_laneq(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 7)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlalltb_laneq13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 7)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlalltb_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlalltbq_laneq_f32_mf8_fpm(vd, vn, vm, 7, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt_lane(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z18test_vmlalltt_lane13__Float32x4_t14__Mfloat8x16_t13__Mfloat8x8_tm(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <8 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = call <16 x i8> @llvm.vector.insert.v16i8.v8i8(<16 x i8> poison, <8 x i8> [[VM]], i64 0)
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[TMP0]], i32 7)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlalltt_lane(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, fpm_t fpm) {
+ return vmlallttq_lane_f32_mf8_fpm(vd, vn, vm, 7, fpm);
+}
+
+// CHECK-LABEL: define dso_local <4 x float> @test_vmlalltt_laneq(
+// CHECK-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15)
+// CHECK-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z19test_vmlalltt_laneq13__Float32x4_t14__Mfloat8x16_tS0_m(
+// CHECK-CXX-SAME: <4 x float> noundef [[VD:%.*]], <16 x i8> [[VN:%.*]], <16 x i8> [[VM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
+// CHECK-CXX-NEXT: [[ENTRY:.*:]]
+// CHECK-CXX-NEXT: call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
+// CHECK-CXX-NEXT: [[VMLALL_LANE:%.*]] = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane.v4f32(<4 x float> [[VD]], <16 x i8> [[VN]], <16 x i8> [[VM]], i32 15)
+// CHECK-CXX-NEXT: ret <4 x float> [[VMLALL_LANE]]
+//
+float32x4_t test_vmlalltt_laneq(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm) {
+ return vmlallttq_laneq_f32_mf8_fpm(vd, vn, vm, 15, fpm);
+}
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c
index fcdd14e583101e..4a507b08040fff 100644
--- a/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c
@@ -5,7 +5,6 @@
#include <arm_neon.h>
void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) {
-
(void) vmlalbq_f16_mf8_fpm(a, u, u, fpm);
// expected-error@-1 {{'vmlalbq_f16_mf8_fpm' requires target feature 'fp8fma'}}
(void) vmlaltq_f16_mf8_fpm(a, u, u, fpm);
@@ -20,3 +19,31 @@ void test_features(float16x8_t a, float32x4_t b, mfloat8x16_t u, fpm_t fpm) {
// expected-error@-1 {{'vmlallttq_f32_mf8_fpm' requires target feature 'fp8fma'}}
}
+void test_imm(float16x8_t d, float32x4_t c, mfloat8x16_t a, mfloat8x8_t b, fpm_t fpm) {
+(void) vmlalbq_lane_f16_mf8_fpm(d, a, b, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+(void) vmlalbq_laneq_f16_mf8_fpm(d, a, a, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}}
+(void) vmlaltq_lane_f16_mf8_fpm(d, a, b, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+(void) vmlaltq_laneq_f16_mf8_fpm(d, a, a, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}}
+
+(void) vmlallbbq_lane_f32_mf8_fpm(c, a, b, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+(void) vmlallbbq_laneq_f32_mf8_fpm(c, a, a, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}}
+(void) vmlallbtq_lane_f32_mf8_fpm(c, a, b, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+(void) vmlallbtq_laneq_f32_mf8_fpm(c, a, a, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}}
+(void) vmlalltbq_lane_f32_mf8_fpm(c, a, b, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+(void) vmlalltbq_laneq_f32_mf8_fpm(c, a, a, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}}
+(void) vmlallttq_lane_f32_mf8_fpm(c, a, b, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 7]}}
+(void) vmlallttq_laneq_f32_mf8_fpm(c, a, a, -1, fpm);
+// expected-error@-1 {{argument value -1 is outside the valid range [0, 15]}}
+}
+
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 9244c549dc469b..6dfc3c8f2a3931 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1046,6 +1046,14 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat
llvm_v16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;
+ class AdvSIMD_FP8_FMLA_LANE_Intrinsic
+ : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [LLVMMatchType<0>,
+ llvm_v16i8_ty,
+ llvm_v16i8_ty,
+ llvm_i32_ty],
+ [IntrReadMem, IntrInaccessibleMemOnly, ImmArg<ArgIndex<3>>]>;
+
def int_aarch64_neon_fp8_fmlalb : AdvSIMD_FP8_FMLA_Intrinsic;
def int_aarch64_neon_fp8_fmlalt : AdvSIMD_FP8_FMLA_Intrinsic;
@@ -1053,6 +1061,14 @@ def int_aarch64_st64bv0: Intrinsic<[llvm_i64_ty], !listconcat([llvm_ptr_ty], dat
def int_aarch64_neon_fp8_fmlallbt : AdvSIMD_FP8_FMLA_Intrinsic;
def int_aarch64_neon_fp8_fmlalltb : AdvSIMD_FP8_FMLA_Intrinsic;
def int_aarch64_neon_fp8_fmlalltt : AdvSIMD_FP8_FMLA_Intrinsic;
+
+ def int_aarch64_neon_fp8_fmlalb_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic;
+ def int_aarch64_neon_fp8_fmlalt_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic;
+
+ def int_aarch64_neon_fp8_fmlallbb_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic;
+ def int_aarch64_neon_fp8_fmlallbt_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic;
+ def int_aarch64_neon_fp8_fmlalltb_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic;
+ def int_aarch64_neon_fp8_fmlalltt_lane : AdvSIMD_FP8_FMLA_LANE_Intrinsic;
}
def llvm_nxv1i1_ty : LLVMType<nxv1i1>;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 38ab1ea785c706..3bb5d3cb4d09de 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -9105,7 +9105,7 @@ class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
RegisterOperand RegType,
RegisterOperand RegType_lo>
: BaseSIMDIndexedTied<Q, U, 0b0, sz, opc,
- RegType, RegType, RegType_lo, VectorIndexB,
+ RegType, RegType, RegType_lo, VectorIndexB32b_timm,
asm, "", dst_kind, ".16b", ".b", []> {
// idx = H:L:M
@@ -9114,14 +9114,24 @@ class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
let Inst{21-19} = idx{2-0};
}
-multiclass SIMDThreeSameVectorMLAIndex<bit Q, string asm> {
- def v8f16 : BaseSIMDThreeSameVectorIndexB<Q, 0b0, 0b11, 0b0000, asm, ".8h",
- V128, V128_0to7>;
+multiclass SIMDThreeSameVectorMLAIndex<bit Q, string asm, SDPatternOperator op> {
+ let Uses = [FPMR, FPCR], mayLoad = 1 in {
+ def v8f16 : BaseSIMDThreeSameVectorIndexB<Q, 0b0, 0b11, 0b0000, asm, ".8h",
+ V128, V128_0to7>;
+ }
+
+ def : Pat<(v8f16 (op (v8f16 V128:$Rd), (v16i8 V128:$Rn), (v16i8 V128_0to7:$Rm), VectorIndexB32b_timm:$Idx)),
+ (!cast<Instruction>(NAME # v8f16) $Rd, $Rn, $Rm, $Idx)>;
}
-multiclass SIMDThreeSameVectorMLALIndex<bit Q, bits<2> sz, string asm> {
- def v4f32 : BaseSIMDThreeSameVectorIndexB<Q, 0b1, sz, 0b1000, asm, ".4s",
- V128, V128_0to7>;
+multiclass SIMDThreeSameVectorMLALIndex<bit Q, bits<2> sz, string asm, SDPatternOperator op> {
+ let Uses = [FPMR, FPCR], mayLoad = 1 in {
+ def v4f32 : BaseSIMDThreeSameVectorIndexB<Q, 0b1, sz, 0b1000, asm, ".4s",
+ V128, V128_0to7>;
+ }
+
+ def : Pat<(v4f32 (op (v4f32 V128:$Rd), (v16i8 V128:$Rn), (v16i8 V128_0to7:$Rm), VectorIndexB32b_timm:$Idx)),
+ (!cast<Instruction>(NAME # v4f32) $Rd, $Rn, $Rm, $Idx)>;
}
//----------------------------------------------------------------------------
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ff65cd1fca9e1c..d112d4f10e47d9 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -10353,14 +10353,16 @@ let Predicates = [HasNEON, HasFAMINMAX] in {
defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", AArch64famin>;
} // End let Predicates = [HasNEON, HasFAMINMAX]
-let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in {
- defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">;
- defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">;
- defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">;
- defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt">;
- defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">;
- defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">;
+let Predicates = [HasFP8FMA] in {
+ defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb_lane>;
+ defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt_lane>;
+ defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb_lane>;
+ defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt", int_aarch64_neon_fp8_fmlallbt_lane>;
+ defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb", int_aarch64_neon_fp8_fmlalltb_lane>;
+ defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt", int_aarch64_neon_fp8_fmlalltt_lane>;
+}
+let Predicates = [HasFP8FMA], Uses = [FPMR, FPCR], mayLoad = 1 in {
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb", int_aarch64_neon_fp8_fmlalb>;
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt", int_aarch64_neon_fp8_fmlalt>;
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb", int_aarch64_neon_fp8_fmlallbb>;
diff --git a/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll
index 008069ff63761f..60957a7c0f2f41 100644
--- a/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll
+++ b/llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll
@@ -54,3 +54,57 @@ define <4 x float> @test_fmlalltt(<4 x float> %d, <16 x i8> %a, <16 x i8> %b) {
%r = call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.v4f32(<4 x float> %d, <16 x i8> %a, <16 x i8> %b)
ret <4 x float> %r
}
+
+define <8 x half> @test_fmlalb_lane(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fmlalb_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalb v0.8h, v1.16b, v2.b[0]
+; CHECK-NEXT: ret
+ %res = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalb.lane(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 0)
+ ret <8 x half> %res
+}
+
+define <8 x half> @test_fmlalt_lane(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fmlalt_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalt v0.8h, v1.16b, v2.b[4]
+; CHECK-NEXT: ret
+ %res = tail call <8 x half> @llvm.aarch64.neon.fp8.fmlalt.lane(<8 x half> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 4)
+ ret <8 x half> %res
+}
+
+define <4 x float> @test_fmlallbb_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fmlallbb_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlallbb v0.4s, v1.16b, v2.b[7]
+; CHECK-NEXT: ret
+ %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbb.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 7)
+ ret <4 x float> %res
+}
+
+define <4 x float> @test_fmlallbt_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fmlallbt_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlallbt v0.4s, v1.16b, v2.b[10]
+; CHECK-NEXT: ret
+ %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlallbt.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 10)
+ ret <4 x float> %res
+}
+
+define <4 x float> @test_fmlalltb_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fmlalltb_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalltb v0.4s, v1.16b, v2.b[13]
+; CHECK-NEXT: ret
+ %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltb.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 13)
+ ret <4 x float> %res
+}
+
+define <4 x float> @test_fmlalltt_lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm) {
+; CHECK-LABEL: test_fmlalltt_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmlalltt v0.4s, v1.16b, v2.b[15]
+; CHECK-NEXT: ret
+ %res = tail call <4 x float> @llvm.aarch64.neon.fp8.fmlalltt.lane(<4 x float> %vd, <16 x i8> %vn, <16 x i8> %vm, i32 15)
+ ret <4 x float> %res
+}