[clang] [llvm] [X86][AMX] Support AMX-FP8 (PR #113850)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 27 18:59:08 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
Author: Feng Zou (fzou1)
<details>
<summary>Changes</summary>
Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368
---
Patch is 24.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113850.diff
24 Files Affected:
- (modified) clang/docs/ReleaseNotes.rst (+1)
- (modified) clang/include/clang/Basic/BuiltinsX86_64.def (+6)
- (modified) clang/include/clang/Driver/Options.td (+2)
- (modified) clang/lib/Basic/Targets/X86.cpp (+6)
- (modified) clang/lib/Basic/Targets/X86.h (+1)
- (modified) clang/lib/Headers/CMakeLists.txt (+1)
- (added) clang/lib/Headers/amxfp8intrin.h (+24)
- (modified) clang/lib/Headers/immintrin.h (+4)
- (modified) clang/lib/Sema/SemaX86.cpp (+4)
- (added) clang/test/CodeGen/X86/amx_fp8.c (+27)
- (added) clang/test/CodeGen/X86/amx_fp8_errors.c (+10)
- (added) clang/test/CodeGen/X86/amx_fp8_inline_asm.c (+32)
- (modified) llvm/include/llvm/IR/IntrinsicsX86.td (+17)
- (modified) llvm/include/llvm/TargetParser/X86TargetParser.def (+1)
- (modified) llvm/lib/Target/X86/X86.td (+3)
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+23)
- (modified) llvm/lib/Target/X86/X86InstrAMX.td (+39)
- (modified) llvm/lib/Target/X86/X86InstrPredicates.td (+1)
- (modified) llvm/lib/TargetParser/Host.cpp (+4)
- (modified) llvm/lib/TargetParser/X86TargetParser.cpp (+1)
- (added) llvm/test/CodeGen/X86/amx_fp8_intrinsics.ll (+20)
- (added) llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-fp8.txt (+34)
- (added) llvm/test/MC/X86/AMX/x86-64-amx-fp8-att.s (+33)
- (added) llvm/test/MC/X86/AMX/x86-64-amx-fp8-intel.s (+33)
``````````diff
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 6a95337815174b..da0ab888ce200d 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -642,6 +642,7 @@ X86 Support
- Supported intrinsics for ``MOVRS AND AVX10.2``.
* Supported intrinsics of ``_mm(256|512)_(mask(z))_loadrs_epi(8|16|32|64)``.
+- Supported ISA of ``AMX-FP8``.
Arm and AArch64 Support
^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
index e1e613560167ac..68904ae8abcd15 100644
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -155,6 +155,12 @@ TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiv*SLLiSLLiIi", "n", "cmpccxadd")
// AMX_FP16 FP16
TARGET_BUILTIN(__builtin_ia32_tdpfp16ps, "vIUcIUcIUc", "n", "amx-fp16")
+// AMX FP8
+// Three constant unsigned-char tile numbers (dst, a, b); encoded exactly like
+// the AMX_FP16 builtin above.
+TARGET_BUILTIN(__builtin_ia32_tdpbf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphbf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+
// RAO-INT
TARGET_BUILTIN(__builtin_ia32_aadd64, "vv*SOi", "n", "raoint")
TARGET_BUILTIN(__builtin_ia32_aand64, "vv*SOi", "n", "raoint")
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 5df6ddd5e6a0c5..bbada0834526d7 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -6290,6 +6290,8 @@ def mamx_fp16 : Flag<["-"], "mamx-fp16">, Group<m_x86_Features_Group>;
def mno_amx_fp16 : Flag<["-"], "mno-amx-fp16">, Group<m_x86_Features_Group>;
def mamx_int8 : Flag<["-"], "mamx-int8">, Group<m_x86_Features_Group>;
def mno_amx_int8 : Flag<["-"], "mno-amx-int8">, Group<m_x86_Features_Group>;
+def mamx_fp8 : Flag<["-"], "mamx-fp8">, Group<m_x86_Features_Group>;
+def mno_amx_fp8 : Flag<["-"], "mno-amx-fp8">, Group<m_x86_Features_Group>;
def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>;
def mno_amx_tile : Flag<["-"], "mno-amx-tile">, Group<m_x86_Features_Group>;
def mcmpccxadd : Flag<["-"], "mcmpccxadd">, Group<m_x86_Features_Group>;
diff --git a/clang/lib/Basic/Targets/X86.cpp b/clang/lib/Basic/Targets/X86.cpp
index d067ec218b5270..b95261c39a5993 100644
--- a/clang/lib/Basic/Targets/X86.cpp
+++ b/clang/lib/Basic/Targets/X86.cpp
@@ -420,6 +420,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
HasAMXTILE = true;
} else if (Feature == "+amx-complex") {
HasAMXCOMPLEX = true;
+ } else if (Feature == "+amx-fp8") {
+ HasAMXFP8 = true;
} else if (Feature == "+cmpccxadd") {
HasCMPCCXADD = true;
} else if (Feature == "+raoint") {
@@ -939,6 +941,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
Builder.defineMacro("__AMX_FP16__");
if (HasAMXCOMPLEX)
Builder.defineMacro("__AMX_COMPLEX__");
+ if (HasAMXFP8)
+ Builder.defineMacro("__AMX_FP8__");
if (HasCMPCCXADD)
Builder.defineMacro("__CMPCCXADD__");
if (HasRAOINT)
@@ -1069,6 +1073,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
.Case("amx-fp16", true)
.Case("amx-int8", true)
.Case("amx-tile", true)
+ .Case("amx-fp8", true)
.Case("avx", true)
.Case("avx10.1-256", true)
.Case("avx10.1-512", true)
@@ -1187,6 +1192,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
.Case("amx-fp16", HasAMXFP16)
.Case("amx-int8", HasAMXINT8)
.Case("amx-tile", HasAMXTILE)
+ .Case("amx-fp8", HasAMXFP8)
.Case("avx", SSELevel >= AVX)
.Case("avx10.1-256", HasAVX10_1)
.Case("avx10.1-512", HasAVX10_1_512)
diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h
index e8aad3ec5a74b1..a1b2a0cec209ab 100644
--- a/clang/lib/Basic/Targets/X86.h
+++ b/clang/lib/Basic/Targets/X86.h
@@ -157,6 +157,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
bool HasAMXINT8 = false;
bool HasAMXBF16 = false;
bool HasAMXCOMPLEX = false;
+ bool HasAMXFP8 = false;
bool HasSERIALIZE = false;
bool HasTSXLDTRK = false;
bool HasUSERMSR = false;
diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt
index e97953d87a2ff9..142cd01ac5aec0 100644
--- a/clang/lib/Headers/CMakeLists.txt
+++ b/clang/lib/Headers/CMakeLists.txt
@@ -149,6 +149,7 @@ set(x86_files
amxcomplexintrin.h
amxfp16intrin.h
amxintrin.h
+ amxfp8intrin.h
avx10_2_512bf16intrin.h
avx10_2_512convertintrin.h
avx10_2_512minmaxintrin.h
diff --git a/clang/lib/Headers/amxfp8intrin.h b/clang/lib/Headers/amxfp8intrin.h
new file mode 100644
index 00000000000000..d187b5f0421bbb
--- /dev/null
+++ b/clang/lib/Headers/amxfp8intrin.h
@@ -0,0 +1,24 @@
+/*===---------- amxfp8intrin.h - AMX-FP8 intrinsics -*- C++ -*---------------===
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *===------------------------------------------------------------------------===
+ */
+
+#ifndef __IMMINTRIN_H
+#error "Never use <amxfp8intrin.h> directly; include <immintrin.h> instead."
+#endif /* __IMMINTRIN_H */
+
+#ifndef __AMXFP8INTRIN_H
+#define __AMXFP8INTRIN_H
+#ifdef __x86_64__
+
+/* AMX-FP8 tile dot-product intrinsics.
+ *
+ * Each macro forwards directly to the matching __builtin_ia32_* builtin and
+ * takes three tile numbers, in the order (dst, a, b):
+ *  - every argument must be an integer constant expression;
+ *  - the three tiles must be distinct (Sema diagnoses aliasing tiles);
+ *  - dst is both read and written: it is the accumulator input and the
+ *    result tile.
+ * The BF8/HF8 element formats and FP32 accumulation described below follow
+ * the instruction mnemonics. NOTE(review): confirm exact semantics against
+ * the Intel AMX-FP8 ISA reference. */
+
+/// _tile_dpbf8ps(dst, a, b): dot product of BF8 elements of tile \a a with
+/// BF8 elements of tile \a b, accumulated into tile \a dst.
+/// Corresponds to the \c TDPBF8PS instruction.
+#define _tile_dpbf8ps __builtin_ia32_tdpbf8ps
+/// _tile_dpbhf8ps(dst, a, b): dot product of BF8 elements of tile \a a with
+/// HF8 elements of tile \a b, accumulated into tile \a dst.
+/// Corresponds to the \c TDPBHF8PS instruction.
+#define _tile_dpbhf8ps __builtin_ia32_tdpbhf8ps
+/// _tile_dphbf8ps(dst, a, b): dot product of HF8 elements of tile \a a with
+/// BF8 elements of tile \a b, accumulated into tile \a dst.
+/// Corresponds to the \c TDPHBF8PS instruction.
+#define _tile_dphbf8ps __builtin_ia32_tdphbf8ps
+/// _tile_dphf8ps(dst, a, b): dot product of HF8 elements of tile \a a with
+/// HF8 elements of tile \a b, accumulated into tile \a dst.
+/// Corresponds to the \c TDPHF8PS instruction.
+#define _tile_dphf8ps __builtin_ia32_tdphf8ps
+
+#endif /* __x86_64__ */
+#endif /* __AMXFP8INTRIN_H */
diff --git a/clang/lib/Headers/immintrin.h b/clang/lib/Headers/immintrin.h
index 5f296d0a3324d0..5529f99ab0c6b6 100644
--- a/clang/lib/Headers/immintrin.h
+++ b/clang/lib/Headers/immintrin.h
@@ -648,6 +648,10 @@ _storebe_i64(void * __P, long long __D) {
#include <amxcomplexintrin.h>
#endif
+#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_FP8__)
+#include <amxfp8intrin.h>
+#endif
+
#if !defined(__SCE__) || __has_feature(modules) || \
defined(__AVX512VP2INTERSECT__)
#include <avx512vp2intersectintrin.h>
diff --git a/clang/lib/Sema/SemaX86.cpp b/clang/lib/Sema/SemaX86.cpp
index 6a4d78f0ca9084..0e43b030e70d41 100644
--- a/clang/lib/Sema/SemaX86.cpp
+++ b/clang/lib/Sema/SemaX86.cpp
@@ -640,6 +640,10 @@ bool SemaX86::CheckBuiltinTileArguments(unsigned BuiltinID, CallExpr *TheCall) {
case X86::BI__builtin_ia32_tdpfp16ps:
case X86::BI__builtin_ia32_tcmmimfp16ps:
case X86::BI__builtin_ia32_tcmmrlfp16ps:
+ case X86::BI__builtin_ia32_tdpbf8ps:
+ case X86::BI__builtin_ia32_tdpbhf8ps:
+ case X86::BI__builtin_ia32_tdphbf8ps:
+ case X86::BI__builtin_ia32_tdphf8ps:
return CheckBuiltinTileRangeAndDuplicate(TheCall, {0, 1, 2});
}
}
diff --git a/clang/test/CodeGen/X86/amx_fp8.c b/clang/test/CodeGen/X86/amx_fp8.c
new file mode 100644
index 00000000000000..9c79514f891299
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8.c
@@ -0,0 +1,27 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-fp8 \
+// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
+// Verifies that each AMX-FP8 tile intrinsic from <immintrin.h> lowers to the
+// matching llvm.x86.tdp*f8ps intrinsic, with the tile numbers (dst, a, b)
+// passed through unchanged as i8 immediates.
+#include <immintrin.h>
+
+void test_amx(void *data) {
+ //CHECK-LABEL: @test_amx
+ //CHECK: call void @llvm.x86.tdpbf8ps(i8 1, i8 2, i8 3)
+ _tile_dpbf8ps(1, 2, 3);
+}
+
+void test_amx2(void *data) {
+ //CHECK-LABEL: @test_amx2
+ //CHECK: call void @llvm.x86.tdpbhf8ps(i8 1, i8 2, i8 3)
+ _tile_dpbhf8ps(1, 2, 3);
+}
+
+void test_amx3(void *data) {
+ //CHECK-LABEL: @test_amx3
+ //CHECK: call void @llvm.x86.tdphbf8ps(i8 1, i8 2, i8 3)
+ _tile_dphbf8ps(1, 2, 3);
+}
+
+void test_amx4(void *data) {
+ //CHECK-LABEL: @test_amx4
+ //CHECK: call void @llvm.x86.tdphf8ps(i8 1, i8 2, i8 3)
+ _tile_dphf8ps(1, 2, 3);
+}
diff --git a/clang/test/CodeGen/X86/amx_fp8_errors.c b/clang/test/CodeGen/X86/amx_fp8_errors.c
new file mode 100644
index 00000000000000..77cbd34905b8ba
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8_errors.c
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tile -target-feature +amx-fp8 -verify
+
+// Sema must reject AMX-FP8 tile intrinsics whose tile arguments alias: in
+// every call below the two source tiles are both tile 3.
+#include <immintrin.h>
+
+void test_amx(void *data) {
+ _tile_dpbf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+ _tile_dpbhf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+ _tile_dphbf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+ _tile_dphf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+}
diff --git a/clang/test/CodeGen/X86/amx_fp8_inline_asm.c b/clang/test/CodeGen/X86/amx_fp8_inline_asm.c
new file mode 100644
index 00000000000000..49331bd9d368ab
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8_inline_asm.c
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-fp8 -emit-llvm -o - -Wall -Werror -pedantic | FileCheck %s
+
+// Verifies that the four AMX-FP8 mnemonics are accepted in inline asm and
+// emitted verbatim into IR, together with the tmm0/tmm6/tmm7 and memory
+// clobbers declared by each asm statement.
+void f_tilemul(short a)
+{
+ //CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdpbf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+ __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
+ "tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
+ "tdpbf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
+ "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+ ::: "memory", "tmm0", "tmm6", "tmm7");
+
+ //CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdpbhf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+ __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
+ "tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
+ "tdpbhf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
+ "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+ ::: "memory", "tmm0", "tmm6", "tmm7");
+
+ //CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdphbf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+ __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
+ "tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
+ "tdphbf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
+ "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+ ::: "memory", "tmm0", "tmm6", "tmm7");
+
+ //CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdphf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+ __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
+ "tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
+ "tdphf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
+ "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+ ::: "memory", "tmm0", "tmm6", "tmm7");
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index d0083017fb9383..6530051e0d3e9f 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5994,6 +5994,23 @@ let TargetPrefix = "x86" in {
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty,
llvm_x86amx_ty], []>;
+
+ // AMX-FP8 dot products. All three operands are tile numbers (dst, a, b)
+ // and must be immediates (ImmArg); instruction selection matches them to
+ // the PTD*F8PS pseudos, which are expanded to TMM-register forms later.
+ def int_x86_tdpbf8ps : ClangBuiltin<"__builtin_ia32_tdpbf8ps">,
+ Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+ [ImmArg<ArgIndex<0>>,
+ ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+ def int_x86_tdpbhf8ps : ClangBuiltin<"__builtin_ia32_tdpbhf8ps">,
+ Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+ [ImmArg<ArgIndex<0>>,
+ ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+ def int_x86_tdphbf8ps : ClangBuiltin<"__builtin_ia32_tdphbf8ps">,
+ Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+ [ImmArg<ArgIndex<0>>,
+ ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+ def int_x86_tdphf8ps : ClangBuiltin<"__builtin_ia32_tdphf8ps">,
+ Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+ [ImmArg<ArgIndex<0>>,
+ ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/TargetParser/X86TargetParser.def b/llvm/include/llvm/TargetParser/X86TargetParser.def
index 073e19f8187c65..19e8e0013ef6a0 100644
--- a/llvm/include/llvm/TargetParser/X86TargetParser.def
+++ b/llvm/include/llvm/TargetParser/X86TargetParser.def
@@ -264,6 +264,7 @@ X86_FEATURE_COMPAT(AVX10_2_512, "avx10.2-512", 0)
//FIXME: make MOVRS _COMPAT defined when gcc landed relate patch.
X86_FEATURE (MOVRS, "movrs")
X86_FEATURE (ZU, "zu")
+X86_FEATURE (AMX_FP8, "amx-fp8")
// These features aren't really CPU features, but the frontend can set them.
X86_FEATURE (RETPOLINE_EXTERNAL_THUNK, "retpoline-external-thunk")
X86_FEATURE (RETPOLINE_INDIRECT_BRANCHES, "retpoline-indirect-branches")
diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td
index 6bedf9e1d13ac3..79c60402a49f0a 100644
--- a/llvm/lib/Target/X86/X86.td
+++ b/llvm/lib/Target/X86/X86.td
@@ -270,6 +270,9 @@ def FeatureAMXFP16 : SubtargetFeature<"amx-fp16", "HasAMXFP16", "true",
def FeatureAMXCOMPLEX : SubtargetFeature<"amx-complex", "HasAMXCOMPLEX", "true",
"Support AMX-COMPLEX instructions",
[FeatureAMXTILE]>;
+def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",
+ "Support AMX-FP8 instructions",
+ [FeatureAMXTILE]>;
def FeatureCMPCCXADD : SubtargetFeature<"cmpccxadd", "HasCMPCCXADD", "true",
"Support CMPCCXADD instructions">;
def FeatureRAOINT : SubtargetFeature<"raoint", "HasRAOINT", "true",
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a6d77873ec2901..5a7313ac3e1234 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37503,6 +37503,29 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MI.eraseFromParent(); // The pseudo is gone now.
return BB;
}
+ case X86::PTDPBF8PS:
+ case X86::PTDPBHF8PS:
+ case X86::PTDPHBF8PS:
+ case X86::PTDPHF8PS: {
+ // Expand an AMX-FP8 pseudo, whose operands are tile numbers held as
+ // immediates, into the real TDP*F8PS instruction on TMM registers.
+ const DebugLoc &DL = MI.getDebugLoc();
+ unsigned Opc;
+ switch(MI.getOpcode()) {
+ default: llvm_unreachable("Unexpected instruction!");
+ case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;
+ case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;
+ case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;
+ case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break;
+ }
+
+ MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc));
+ // Pseudo operand 0 is emitted twice: once as $dst (Define) and once as
+ // $src1, which the real instruction ties to $dst (the accumulator).
+ // Operands 1 and 2 become $src2 and $src3.
+ MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define);
+ MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef);
+ MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef);
+ MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef);
+
+ MI.eraseFromParent();
+ return BB;
+ }
}
}
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
index 99deacc811a170..d0c91ab7b5e696 100644
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -267,3 +267,42 @@ let Predicates = [HasAMXCOMPLEX, In64BitMode] in {
}
} // SchedRW = [WriteSystem]
}
+
+// AMX-FP8
+// Tile dot products on 8-bit floating-point inputs. All four instructions
+// share opcode 0xFD in opcode map 5 and differ only in their mandatory-prefix
+// class (PS / XD / XS / PD).
+let Predicates = [HasAMXFP8, In64BitMode] in {
+ let SchedRW = [WriteSystem] in {
+ let Constraints = "$src1 = $dst" in {
+ // $src1 is tied to $dst: the destination tile is also the accumulator.
+ // AT&T operand order is src3, src2, dst; Intel order is dst, src2, src3.
+ class AMX_FP8_BASE<bits<8> Opcode, string Opstr> :
+ I<Opcode, MRMSrcReg4VOp3, (outs TILE:$dst),
+ (ins TILE:$src1, TILE:$src2, TILE:$src3),
+ !strconcat(Opstr, "\t{$src3, $src2, $dst|$dst, $src2, $src3}"),
+ []>, VEX, VVVV;
+ }
+
+ def TDPBF8PS : AMX_FP8_BASE<0xfd, "tdpbf8ps">, T_MAP5, PS;
+ def TDPBHF8PS : AMX_FP8_BASE<0xfd, "tdpbhf8ps">, T_MAP5, XD;
+ def TDPHBF8PS : AMX_FP8_BASE<0xfd, "tdphbf8ps">, T_MAP5, XS;
+ def TDPHF8PS : AMX_FP8_BASE<0xfd, "tdphf8ps">, T_MAP5, PD;
+
+ let usesCustomInserter = 1 in {
+ // Pseudo instructions, using immediates instead of tile registers.
+ // To be translated to the actual instructions in X86ISelLowering.cpp
+ def PTDPBF8PS : PseudoI<(outs), (ins u8imm:$src1,
+ u8imm:$src2, u8imm:$src3),
+ [(int_x86_tdpbf8ps timm:$src1,
+ timm:$src2, timm:$src3)]>;
+ def PTDPBHF8PS : PseudoI<(outs), (ins u8imm:$src1,
+ u8imm:$src2, u8imm:$src3),
+ [(int_x86_tdpbhf8ps timm:$src1,
+ timm:$src2, timm:$src3)]>;
+ def PTDPHBF8PS : PseudoI<(outs), (ins u8imm:$src1,
+ u8imm:$src2, u8imm:$src3),
+ [(int_x86_tdphbf8ps timm:$src1,
+ timm:$src2, timm:$src3)]>;
+ def PTDPHF8PS : PseudoI<(outs), (ins u8imm:$src1,
+ u8imm:$src2, u8imm:$src3),
+ [(int_x86_tdphf8ps timm:$src1,
+ timm:$src2, timm:$src3)]>;
+ }
+ }
+}
diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td
index 7fb566fba51818..5b659d3b072dca 100644
--- a/llvm/lib/Target/X86/X86InstrPredicates.td
+++ b/llvm/lib/Target/X86/X86InstrPredicates.td
@@ -183,6 +183,7 @@ def HasAMXTILE : Predicate<"Subtarget->hasAMXTILE()">;
def HasAMXBF16 : Predicate<"Subtarget->hasAMXBF16()">;
def HasAMXINT8 : Predicate<"Subtarget->hasAMXINT8()">;
def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;
+def HasAMXFP8 : Predicate<"Subtarget->hasAMXFP8()">;
def HasUINTR : Predicate<"Subtarget->hasUINTR()">;
def HasUSERMSR : Predicate<"Subtarget->hasUSERMSR()">;
def HasCRC32 : Predicate<"Subtarget->hasCRC32()">;
diff --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp
index 5c4e3a9dc52b0f..78991b3936505c 100644
--- a/llvm/lib/TargetParser/Host.cpp
+++ b/llvm/lib/TargetParser/Host.cpp
@@ -1876,6 +1876,10 @@ const StringMap<bool> sys::getHostCPUFeatures() {
MaxLevel >= 0x19 && !getX86CpuIDAndInfo(0x19, &EAX, &EBX, &ECX, &EDX);
Features["widekl"] = HasLeaf7 && HasLeaf19 && ((EBX >> 2) & 1);
+ // AMX-FP8 is enumerated in CPUID.(EAX=1EH, ECX=01H):EAX[bit 4]. Leaf 0x1E
+ // must therefore be queried with sub-leaf 1; sub-leaf 0 (what the plain
+ // getX86CpuIDAndInfo helper queries) reports TMUL geometry, not features.
+ bool HasLeaf1E = MaxLevel >= 0x1e &&
+ !getX86CpuIDAndInfoEx(0x1e, 0x1, &EAX, &EBX, &ECX, &EDX);
+ Features["amx-fp8"] = HasLeaf1E && ((EAX >> 4) & 1) && HasAMXSave;
+
bool HasLeaf24 =
MaxLevel >= 0x24 && !getX86CpuIDAndInfo(0x24, &EAX, &EBX, &ECX, &EDX);
diff --git a/llvm/lib/TargetParser/X86TargetParser.cpp b/llvm/lib/TargetParser/X86TargetParser.cpp
index 586df5748aa822..7d60b81d4bb1c3 100644
--- a/llvm/lib/TargetParser/X86TargetParser.cpp
+++ b/llvm/lib/TargetParser/X86TargetParser.cpp
@@ -598,6 +598,7 @@ constexpr FeatureBitset ImpliedFeaturesAMX_BF16 = FeatureAMX_TILE;
constexpr FeatureBitset ImpliedFeaturesAMX_FP16 = FeatureAMX_TILE;
constexpr FeatureBitset ImpliedFeaturesAMX_INT8 = FeatureAMX_TILE;
constexpr FeatureBitset ImpliedFeaturesAMX_COMPLEX = FeatureAMX_TILE;
+constexpr FeatureBitset ImpliedFeaturesAMX_FP8 = FeatureAMX_TILE;
constexpr FeatureBitset ImpliedFeaturesHRESET = {};
constexpr FeatureBitset ImpliedFeaturesPREFETCHI = {};
diff --git a/llvm/test/CodeGen/X86/amx_fp8_intrinsics.ll b/llvm/test/CodeGen/X86/amx_fp...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/113850
More information about the llvm-commits
mailing list