[clang] 8127162 - [X86][AMX] Support AMX-FP8 (#113850)

via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 30 19:14:30 PDT 2024


Author: Feng Zou
Date: 2024-10-31T10:14:25+08:00
New Revision: 8127162427c5f8c28d6292e1d4b4ce8a00b2d5a2

URL: https://github.com/llvm/llvm-project/commit/8127162427c5f8c28d6292e1d4b4ce8a00b2d5a2
DIFF: https://github.com/llvm/llvm-project/commit/8127162427c5f8c28d6292e1d4b4ce8a00b2d5a2.diff

LOG: [X86][AMX] Support AMX-FP8 (#113850)

Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368

Added: 
    clang/lib/Headers/amxfp8intrin.h
    clang/test/CodeGen/X86/amx_fp8.c
    clang/test/CodeGen/X86/amx_fp8_errors.c
    clang/test/CodeGen/X86/amx_fp8_inline_asm.c
    llvm/test/CodeGen/X86/amx_fp8_intrinsics.ll
    llvm/test/MC/Disassembler/X86/AMX/amx-fp8.txt
    llvm/test/MC/X86/AMX/amx-fp8-att.s
    llvm/test/MC/X86/AMX/amx-fp8-intel.s

Modified: 
    clang/docs/ReleaseNotes.rst
    clang/include/clang/Basic/BuiltinsX86_64.def
    clang/include/clang/Driver/Options.td
    clang/lib/Basic/Targets/X86.cpp
    clang/lib/Basic/Targets/X86.h
    clang/lib/Headers/CMakeLists.txt
    clang/lib/Headers/immintrin.h
    clang/lib/Sema/SemaX86.cpp
    llvm/include/llvm/IR/IntrinsicsX86.td
    llvm/include/llvm/TargetParser/X86TargetParser.def
    llvm/lib/Target/X86/X86.td
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/lib/Target/X86/X86InstrAMX.td
    llvm/lib/Target/X86/X86InstrPredicates.td
    llvm/lib/TargetParser/Host.cpp
    llvm/lib/TargetParser/X86TargetParser.cpp

Removed: 
    


################################################################################
diff  --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 402203f89e23a0..145786bcc59b45 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -661,6 +661,7 @@ X86 Support
 
 - Supported intrinsics for ``MOVRS AND AVX10.2``.
   * Supported intrinsics of ``_mm(256|512)_(mask(z))_loadrs_epi(8|16|32|64)``.
+- Supported ISA of ``AMX-FP8``.
 
 Arm and AArch64 Support
 ^^^^^^^^^^^^^^^^^^^^^^^

diff  --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
index e1e613560167ac..68904ae8abcd15 100644
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -155,6 +155,12 @@ TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiv*SLLiSLLiIi", "n", "cmpccxadd")
 // AMX_FP16 FP16
 TARGET_BUILTIN(__builtin_ia32_tdpfp16ps, "vIUcIUcIUc", "n", "amx-fp16")
 
+// AMX FP8
+TARGET_BUILTIN(__builtin_ia32_tdpbf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphbf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphf8ps, "vIUcIUcIUc", "n", "amx-fp8")
+
 // RAO-INT
 TARGET_BUILTIN(__builtin_ia32_aadd64, "vv*SOi", "n", "raoint")
 TARGET_BUILTIN(__builtin_ia32_aand64, "vv*SOi", "n", "raoint")

diff  --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 9d595984b63c4b..2b9ee1a0e669ed 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -6300,6 +6300,8 @@ def mamx_fp16 : Flag<["-"], "mamx-fp16">, Group<m_x86_Features_Group>;
 def mno_amx_fp16 : Flag<["-"], "mno-amx-fp16">, Group<m_x86_Features_Group>;
 def mamx_int8 : Flag<["-"], "mamx-int8">, Group<m_x86_Features_Group>;
 def mno_amx_int8 : Flag<["-"], "mno-amx-int8">, Group<m_x86_Features_Group>;
+def mamx_fp8 : Flag<["-"], "mamx-fp8">, Group<m_x86_Features_Group>;
+def mno_amx_fp8 : Flag<["-"], "mno-amx-fp8">, Group<m_x86_Features_Group>;
 def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>;
 def mno_amx_tile : Flag<["-"], "mno-amx-tile">, Group<m_x86_Features_Group>;
 def mcmpccxadd : Flag<["-"], "mcmpccxadd">, Group<m_x86_Features_Group>;

diff  --git a/clang/lib/Basic/Targets/X86.cpp b/clang/lib/Basic/Targets/X86.cpp
index 82d29ea9fea5c4..4988682a22f019 100644
--- a/clang/lib/Basic/Targets/X86.cpp
+++ b/clang/lib/Basic/Targets/X86.cpp
@@ -428,6 +428,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
       HasAMXTILE = true;
     } else if (Feature == "+amx-complex") {
       HasAMXCOMPLEX = true;
+    } else if (Feature == "+amx-fp8") {
+      HasAMXFP8 = true;
     } else if (Feature == "+cmpccxadd") {
       HasCMPCCXADD = true;
     } else if (Feature == "+raoint") {
@@ -947,6 +949,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
     Builder.defineMacro("__AMX_FP16__");
   if (HasAMXCOMPLEX)
     Builder.defineMacro("__AMX_COMPLEX__");
+  if (HasAMXFP8)
+    Builder.defineMacro("__AMX_FP8__");
   if (HasCMPCCXADD)
     Builder.defineMacro("__CMPCCXADD__");
   if (HasRAOINT)
@@ -1077,6 +1081,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
       .Case("amx-fp16", true)
       .Case("amx-int8", true)
       .Case("amx-tile", true)
+      .Case("amx-fp8", true)
       .Case("avx", true)
       .Case("avx10.1-256", true)
       .Case("avx10.1-512", true)
@@ -1195,6 +1200,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
       .Case("amx-fp16", HasAMXFP16)
       .Case("amx-int8", HasAMXINT8)
       .Case("amx-tile", HasAMXTILE)
+      .Case("amx-fp8", HasAMXFP8)
       .Case("avx", SSELevel >= AVX)
       .Case("avx10.1-256", HasAVX10_1)
       .Case("avx10.1-512", HasAVX10_1_512)

diff  --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h
index e8aad3ec5a74b1..a1b2a0cec209ab 100644
--- a/clang/lib/Basic/Targets/X86.h
+++ b/clang/lib/Basic/Targets/X86.h
@@ -157,6 +157,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
   bool HasAMXINT8 = false;
   bool HasAMXBF16 = false;
   bool HasAMXCOMPLEX = false;
+  bool HasAMXFP8 = false;
   bool HasSERIALIZE = false;
   bool HasTSXLDTRK = false;
   bool HasUSERMSR = false;

diff  --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt
index 0211d1870b30a0..818de5a8e1d231 100644
--- a/clang/lib/Headers/CMakeLists.txt
+++ b/clang/lib/Headers/CMakeLists.txt
@@ -149,6 +149,7 @@ set(x86_files
   amxcomplexintrin.h
   amxfp16intrin.h
   amxintrin.h
+  amxfp8intrin.h
   avx10_2_512bf16intrin.h
   avx10_2_512convertintrin.h
   avx10_2_512minmaxintrin.h

diff  --git a/clang/lib/Headers/amxfp8intrin.h b/clang/lib/Headers/amxfp8intrin.h
new file mode 100644
index 00000000000000..0f5ddc87e5a752
--- /dev/null
+++ b/clang/lib/Headers/amxfp8intrin.h
@@ -0,0 +1,95 @@
+/*===------------- amxfp8intrin.h - AMX intrinsics -*- C++ -*----------------===
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *===------------------------------------------------------------------------===
+ */
+
+#ifndef __IMMINTRIN_H
+#error "Never use <amxfp8intrin.h> directly; include <immintrin.h> instead."
+#endif /* __IMMINTRIN_H */
+
+#ifndef __AMXFP8INTRIN_H
+#define __AMXFP8INTRIN_H
+#ifdef __x86_64__
+
+/// Perform the dot product of a BF8 value \a a by a BF8 value \a b accumulating
+/// into a Single Precision (FP32) source/dest \a dst.
+///
+/// \headerfile <immintrin.h>
+///
+/// \code
+/// void _tile_dpbf8ps (__tile dst, __tile a, __tile b)
+/// \endcode
+///
+/// This intrinsic corresponds to the \c TDPBF8PS instruction.
+///
+/// \param dst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param a
+///    The 1st source tile. Max size is 1024 Bytes.
+/// \param b
+///    The 2nd source tile. Max size is 1024 Bytes.
+#define _tile_dpbf8ps(dst, a, b) __builtin_ia32_tdpbf8ps((dst), (a), (b))
+
+/// Perform the dot product of a BF8 value \a a by an HF8 value \a b
+/// accumulating into a Single Precision (FP32) source/dest \a dst.
+///
+/// \headerfile <immintrin.h>
+///
+/// \code
+/// void _tile_dpbhf8ps (__tile dst, __tile a, __tile b)
+/// \endcode
+///
+/// This intrinsic corresponds to the \c TDPBHF8PS instruction.
+///
+/// \param dst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param a
+///    The 1st source tile. Max size is 1024 Bytes.
+/// \param b
+///    The 2nd source tile. Max size is 1024 Bytes.
+#define _tile_dpbhf8ps(dst, a, b) __builtin_ia32_tdpbhf8ps((dst), (a), (b))
+
+/// Perform the dot product of an HF8 value \a a by a BF8 value \a b
+/// accumulating into a Single Precision (FP32) source/dest \a dst.
+///
+/// \headerfile <immintrin.h>
+///
+/// \code
+/// void _tile_dphbf8ps (__tile dst, __tile a, __tile b)
+/// \endcode
+///
+/// This intrinsic corresponds to the \c TDPHBF8PS instruction.
+///
+/// \param dst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param a
+///    The 1st source tile. Max size is 1024 Bytes.
+/// \param b
+///    The 2nd source tile. Max size is 1024 Bytes.
+#define _tile_dphbf8ps(dst, a, b) __builtin_ia32_tdphbf8ps((dst), (a), (b))
+
+/// Perform the dot product of an HF8 value \a a by an HF8 value \a b
+/// accumulating into a Single Precision (FP32) source/dest \a dst.
+///
+/// \headerfile <immintrin.h>
+///
+/// \code
+/// void _tile_dphf8ps (__tile dst, __tile a, __tile b)
+/// \endcode
+///
+/// This intrinsic corresponds to the \c TDPHF8PS instruction.
+///
+/// \param dst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param a
+///    The 1st source tile. Max size is 1024 Bytes.
+/// \param b
+///    The 2nd source tile. Max size is 1024 Bytes.
+#define _tile_dphf8ps(dst, a, b) __builtin_ia32_tdphf8ps((dst), (a), (b))
+
+#endif /* __x86_64__ */
+#endif /* __AMXFP8INTRIN_H */

diff  --git a/clang/lib/Headers/immintrin.h b/clang/lib/Headers/immintrin.h
index 65ad72bc479f49..6184e9c8479695 100644
--- a/clang/lib/Headers/immintrin.h
+++ b/clang/lib/Headers/immintrin.h
@@ -648,6 +648,10 @@ _storebe_i64(void * __P, long long __D) {
 #include <amxcomplexintrin.h>
 #endif
 
+#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_FP8__)
+#include <amxfp8intrin.h>
+#endif
+
 #if !defined(__SCE__) || __has_feature(modules) ||                             \
     defined(__AVX512VP2INTERSECT__)
 #include <avx512vp2intersectintrin.h>

diff  --git a/clang/lib/Sema/SemaX86.cpp b/clang/lib/Sema/SemaX86.cpp
index 6a4d78f0ca9084..0e43b030e70d41 100644
--- a/clang/lib/Sema/SemaX86.cpp
+++ b/clang/lib/Sema/SemaX86.cpp
@@ -640,6 +640,10 @@ bool SemaX86::CheckBuiltinTileArguments(unsigned BuiltinID, CallExpr *TheCall) {
   case X86::BI__builtin_ia32_tdpfp16ps:
   case X86::BI__builtin_ia32_tcmmimfp16ps:
   case X86::BI__builtin_ia32_tcmmrlfp16ps:
+  case X86::BI__builtin_ia32_tdpbf8ps:
+  case X86::BI__builtin_ia32_tdpbhf8ps:
+  case X86::BI__builtin_ia32_tdphbf8ps:
+  case X86::BI__builtin_ia32_tdphf8ps:
     return CheckBuiltinTileRangeAndDuplicate(TheCall, {0, 1, 2});
   }
 }

diff  --git a/clang/test/CodeGen/X86/amx_fp8.c b/clang/test/CodeGen/X86/amx_fp8.c
new file mode 100644
index 00000000000000..9c79514f891299
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8.c
@@ -0,0 +1,27 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +amx-fp8  \
+// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
+#include <immintrin.h>
+
+void test_amx(void *data) {
+  //CHECK-LABEL: @test_amx
+  //CHECK: call void @llvm.x86.tdpbf8ps(i8 1, i8 2, i8 3)
+  _tile_dpbf8ps(1, 2, 3);
+}
+
+void test_amx2(void *data) {
+  //CHECK-LABEL: @test_amx2
+  //CHECK: call void @llvm.x86.tdpbhf8ps(i8 1, i8 2, i8 3)
+  _tile_dpbhf8ps(1, 2, 3);
+}
+
+void test_amx3(void *data) {
+  //CHECK-LABEL: @test_amx3
+  //CHECK: call void @llvm.x86.tdphbf8ps(i8 1, i8 2, i8 3)
+  _tile_dphbf8ps(1, 2, 3);
+}
+
+void test_amx4(void *data) {
+  //CHECK-LABEL: @test_amx4
+  //CHECK: call void @llvm.x86.tdphf8ps(i8 1, i8 2, i8 3)
+  _tile_dphf8ps(1, 2, 3);
+}

diff  --git a/clang/test/CodeGen/X86/amx_fp8_errors.c b/clang/test/CodeGen/X86/amx_fp8_errors.c
new file mode 100644
index 00000000000000..77cbd34905b8ba
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8_errors.c
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tile -target-feature +amx-fp8 -verify
+
+#include <immintrin.h>
+
+void test_amx(void *data) {
+  _tile_dpbf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+  _tile_dpbhf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+  _tile_dphbf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+  _tile_dphf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
+}

diff  --git a/clang/test/CodeGen/X86/amx_fp8_inline_asm.c b/clang/test/CodeGen/X86/amx_fp8_inline_asm.c
new file mode 100644
index 00000000000000..49331bd9d368ab
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8_inline_asm.c
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +amx-fp8 -emit-llvm -o - -Wall -Werror -pedantic | FileCheck %s
+
+void f_tilemul(short a)
+{
+  //CHECK:  call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0   \0A\09tileloadd 0(%rdx,%r14,4), %tmm6   \0A\09tdpbf8ps %tmm6, %tmm0, %tmm7    \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+  __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0   \n\t"
+                    "tileloadd 0(%%rdx,%%r14,4), %%tmm6   \n\t"
+                    "tdpbf8ps %%tmm6, %%tmm0, %%tmm7    \n\t"
+                    "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+          ::: "memory", "tmm0", "tmm6", "tmm7");
+
+  //CHECK:  call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0   \0A\09tileloadd 0(%rdx,%r14,4), %tmm6   \0A\09tdpbhf8ps %tmm6, %tmm0, %tmm7    \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+  __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0   \n\t"
+                    "tileloadd 0(%%rdx,%%r14,4), %%tmm6   \n\t"
+                    "tdpbhf8ps %%tmm6, %%tmm0, %%tmm7    \n\t"
+                    "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+          ::: "memory", "tmm0", "tmm6", "tmm7");
+
+  //CHECK:  call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0   \0A\09tileloadd 0(%rdx,%r14,4), %tmm6   \0A\09tdphbf8ps %tmm6, %tmm0, %tmm7    \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+  __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0   \n\t"
+                    "tileloadd 0(%%rdx,%%r14,4), %%tmm6   \n\t"
+                    "tdphbf8ps %%tmm6, %%tmm0, %%tmm7    \n\t"
+                    "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+          ::: "memory", "tmm0", "tmm6", "tmm7");
+
+  //CHECK:  call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0   \0A\09tileloadd 0(%rdx,%r14,4), %tmm6   \0A\09tdphf8ps %tmm6, %tmm0, %tmm7    \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+  __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0   \n\t"
+                    "tileloadd 0(%%rdx,%%r14,4), %%tmm6   \n\t"
+                    "tdphf8ps %%tmm6, %%tmm0, %%tmm7    \n\t"
+                    "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+          ::: "memory", "tmm0", "tmm6", "tmm7");
+}

diff  --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index 0ecca157077fdc..d1807d26a874ba 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5994,6 +5994,23 @@ let TargetPrefix = "x86" in {
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+
+  def int_x86_tdpbf8ps : ClangBuiltin<"__builtin_ia32_tdpbf8ps">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+                        [ImmArg<ArgIndex<0>>,
+                         ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+  def int_x86_tdpbhf8ps : ClangBuiltin<"__builtin_ia32_tdpbhf8ps">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+                        [ImmArg<ArgIndex<0>>,
+                         ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+  def int_x86_tdphbf8ps : ClangBuiltin<"__builtin_ia32_tdphbf8ps">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+                        [ImmArg<ArgIndex<0>>,
+                         ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+  def int_x86_tdphf8ps : ClangBuiltin<"__builtin_ia32_tdphf8ps">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+                        [ImmArg<ArgIndex<0>>,
+                         ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/TargetParser/X86TargetParser.def b/llvm/include/llvm/TargetParser/X86TargetParser.def
index 073e19f8187c65..19e8e0013ef6a0 100644
--- a/llvm/include/llvm/TargetParser/X86TargetParser.def
+++ b/llvm/include/llvm/TargetParser/X86TargetParser.def
@@ -264,6 +264,7 @@ X86_FEATURE_COMPAT(AVX10_2_512,     "avx10.2-512",            0)
 //FIXME: make MOVRS _COMPAT defined when gcc landed relate patch.
 X86_FEATURE       (MOVRS,           "movrs")
 X86_FEATURE       (ZU,              "zu")
+X86_FEATURE       (AMX_FP8,         "amx-fp8")
 // These features aren't really CPU features, but the frontend can set them.
 X86_FEATURE       (RETPOLINE_EXTERNAL_THUNK,    "retpoline-external-thunk")
 X86_FEATURE       (RETPOLINE_INDIRECT_BRANCHES, "retpoline-indirect-branches")

diff  --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td
index 6bedf9e1d13ac3..c7882acc044e04 100644
--- a/llvm/lib/Target/X86/X86.td
+++ b/llvm/lib/Target/X86/X86.td
@@ -270,6 +270,9 @@ def FeatureAMXFP16     : SubtargetFeature<"amx-fp16", "HasAMXFP16", "true",
 def FeatureAMXCOMPLEX : SubtargetFeature<"amx-complex", "HasAMXCOMPLEX", "true",
                                          "Support AMX-COMPLEX instructions",
                                          [FeatureAMXTILE]>;
+def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",
+                                     "Support AMX-FP8 instructions",
+                                     [FeatureAMXTILE]>;
 def FeatureCMPCCXADD : SubtargetFeature<"cmpccxadd", "HasCMPCCXADD", "true",
                                         "Support CMPCCXADD instructions">;
 def FeatureRAOINT : SubtargetFeature<"raoint", "HasRAOINT", "true",

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 22cba69af41f51..58598fefe0e796 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37420,7 +37420,11 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
   case X86::PTDPBUSD:
   case X86::PTDPBUUD:
   case X86::PTDPBF16PS:
-  case X86::PTDPFP16PS: {
+  case X86::PTDPFP16PS:
+  case X86::PTDPBF8PS:
+  case X86::PTDPBHF8PS:
+  case X86::PTDPHBF8PS:
+  case X86::PTDPHF8PS: {
     unsigned Opc;
     switch (MI.getOpcode()) {
     // clang-format off
@@ -37431,6 +37435,10 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
     case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;
     case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;
     case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break;
+    case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;
+    case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;
+    case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;
+    case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break;
     // clang-format on
     }
 

diff  --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
index 99deacc811a170..202232ccb8bc72 100644
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -267,3 +267,42 @@ let Predicates = [HasAMXCOMPLEX, In64BitMode] in {
     }
   } // SchedRW = [WriteSystem]
 }
+
+// AMX-FP8
+let Predicates = [HasAMXFP8, In64BitMode] in {
+  let SchedRW = [WriteSystem] in {
+    let Constraints = "$src1 = $dst" in {
+      class AMX_FP8_BASE<bits<8> Opcode, string Opstr> :
+        I<Opcode, MRMSrcReg4VOp3, (outs TILE:$dst),
+          (ins TILE:$src1, TILE:$src2, TILE:$src3),
+          !strconcat(Opstr, "\t{$src3, $src2, $dst|$dst, $src2, $src3}"),
+          []>, VEX, VVVV;
+    }
+
+    def TDPBF8PS : AMX_FP8_BASE<0xfd, "tdpbf8ps">, T_MAP5, PS;
+    def TDPBHF8PS : AMX_FP8_BASE<0xfd, "tdpbhf8ps">, T_MAP5, XD;
+    def TDPHBF8PS : AMX_FP8_BASE<0xfd, "tdphbf8ps">, T_MAP5, XS;
+    def TDPHF8PS : AMX_FP8_BASE<0xfd, "tdphf8ps">, T_MAP5, PD;
+
+    let usesCustomInserter = 1 in {
+      // Pseudo instructions, using immediates instead of tile registers.
+      // To be translated to the actual instructions in X86ISelLowering.cpp
+      def PTDPBF8PS : PseudoI<(outs),
+                              (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
+                              [(int_x86_tdpbf8ps timm:$src1, timm:$src2,
+                                timm:$src3)]>;
+      def PTDPBHF8PS : PseudoI<(outs),
+                               (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
+                               [(int_x86_tdpbhf8ps timm:$src1, timm:$src2,
+                                 timm:$src3)]>;
+      def PTDPHBF8PS : PseudoI<(outs),
+                               (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
+                               [(int_x86_tdphbf8ps timm:$src1, timm:$src2,
+                                 timm:$src3)]>;
+      def PTDPHF8PS : PseudoI<(outs),
+                              (ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
+                              [(int_x86_tdphf8ps timm:$src1, timm:$src2,
+                                timm:$src3)]>;
+    }
+  }
+}

diff  --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td
index 7fb566fba51818..5b659d3b072dca 100644
--- a/llvm/lib/Target/X86/X86InstrPredicates.td
+++ b/llvm/lib/Target/X86/X86InstrPredicates.td
@@ -183,6 +183,7 @@ def HasAMXTILE   : Predicate<"Subtarget->hasAMXTILE()">;
 def HasAMXBF16   : Predicate<"Subtarget->hasAMXBF16()">;
 def HasAMXINT8   : Predicate<"Subtarget->hasAMXINT8()">;
 def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;
+def HasAMXFP8    : Predicate<"Subtarget->hasAMXFP8()">;
 def HasUINTR     : Predicate<"Subtarget->hasUINTR()">;
 def HasUSERMSR   : Predicate<"Subtarget->hasUSERMSR()">;
 def HasCRC32     : Predicate<"Subtarget->hasCRC32()">;

diff  --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp
index 5c4e3a9dc52b0f..fd34a276cf3ce5 100644
--- a/llvm/lib/TargetParser/Host.cpp
+++ b/llvm/lib/TargetParser/Host.cpp
@@ -1876,6 +1876,10 @@ const StringMap<bool> sys::getHostCPUFeatures() {
       MaxLevel >= 0x19 && !getX86CpuIDAndInfo(0x19, &EAX, &EBX, &ECX, &EDX);
   Features["widekl"] = HasLeaf7 && HasLeaf19 && ((EBX >> 2) & 1);
 
+  bool HasLeaf1E = MaxLevel >= 0x1e &&
+                   !getX86CpuIDAndInfoEx(0x1e, 0x1, &EAX, &EBX, &ECX, &EDX);
+  Features["amx-fp8"] = HasLeaf1E && ((EAX >> 4) & 1) && HasAMXSave;
+
   bool HasLeaf24 =
       MaxLevel >= 0x24 && !getX86CpuIDAndInfo(0x24, &EAX, &EBX, &ECX, &EDX);
 

diff  --git a/llvm/lib/TargetParser/X86TargetParser.cpp b/llvm/lib/TargetParser/X86TargetParser.cpp
index 586df5748aa822..7d60b81d4bb1c3 100644
--- a/llvm/lib/TargetParser/X86TargetParser.cpp
+++ b/llvm/lib/TargetParser/X86TargetParser.cpp
@@ -598,6 +598,7 @@ constexpr FeatureBitset ImpliedFeaturesAMX_BF16 = FeatureAMX_TILE;
 constexpr FeatureBitset ImpliedFeaturesAMX_FP16 = FeatureAMX_TILE;
 constexpr FeatureBitset ImpliedFeaturesAMX_INT8 = FeatureAMX_TILE;
 constexpr FeatureBitset ImpliedFeaturesAMX_COMPLEX = FeatureAMX_TILE;
+constexpr FeatureBitset ImpliedFeaturesAMX_FP8 = FeatureAMX_TILE;
 constexpr FeatureBitset ImpliedFeaturesHRESET = {};
 
 constexpr FeatureBitset ImpliedFeaturesPREFETCHI = {};

diff  --git a/llvm/test/CodeGen/X86/amx_fp8_intrinsics.ll b/llvm/test/CodeGen/X86/amx_fp8_intrinsics.ll
new file mode 100644
index 00000000000000..f5d3f6ec9ec298
--- /dev/null
+++ b/llvm/test/CodeGen/X86/amx_fp8_intrinsics.ll
@@ -0,0 +1,20 @@
+; RUN: llc < %s -O0 -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-fp8 | FileCheck %s
+
+; CHECK-LABEL: test_amx:
+; CHECK:       # %bb.0:
+; CHECK:    tdpbf8ps        %tmm3, %tmm2, %tmm1
+; CHECK:    tdpbhf8ps        %tmm3, %tmm2, %tmm1
+; CHECK:    tdphbf8ps        %tmm3, %tmm2, %tmm1
+; CHECK:    tdphf8ps        %tmm3, %tmm2, %tmm1
+
+define void @test_amx(){
+call void @llvm.x86.tdpbf8ps(i8 1, i8 2, i8 3)
+call void @llvm.x86.tdpbhf8ps(i8 1, i8 2, i8 3)
+call void @llvm.x86.tdphbf8ps(i8 1, i8 2, i8 3)
+call void @llvm.x86.tdphf8ps(i8 1, i8 2, i8 3)
+ret void
+}
+declare void @llvm.x86.tdpbf8ps(i8 %tile0, i8 %tile1, i8 %tile2)
+declare void @llvm.x86.tdpbhf8ps(i8 %tile0, i8 %tile1, i8 %tile2)
+declare void @llvm.x86.tdphbf8ps(i8 %tile0, i8 %tile1, i8 %tile2)
+declare void @llvm.x86.tdphf8ps(i8 %tile0, i8 %tile1, i8 %tile2)

diff  --git a/llvm/test/MC/Disassembler/X86/AMX/amx-fp8.txt b/llvm/test/MC/Disassembler/X86/AMX/amx-fp8.txt
new file mode 100644
index 00000000000000..e714a52d2c31a7
--- /dev/null
+++ b/llvm/test/MC/Disassembler/X86/AMX/amx-fp8.txt
@@ -0,0 +1,34 @@
+# RUN: llvm-mc --disassemble %s -triple=x86_64 | FileCheck %s --check-prefixes=ATT
+# RUN: llvm-mc --disassemble %s -triple=x86_64 -x86-asm-syntax=intel --output-asm-variant=1 | FileCheck %s --check-prefixes=INTEL
+
+# ATT:   tdpbf8ps %tmm4, %tmm5, %tmm6
+# INTEL: tdpbf8ps tmm6, tmm5, tmm4
+0xc4,0xe5,0x58,0xfd,0xf5
+
+# ATT:   tdpbf8ps %tmm1, %tmm2, %tmm3
+# INTEL: tdpbf8ps tmm3, tmm2, tmm1
+0xc4,0xe5,0x70,0xfd,0xda
+
+# ATT:   tdpbhf8ps %tmm4, %tmm5, %tmm6
+# INTEL: tdpbhf8ps tmm6, tmm5, tmm4
+0xc4,0xe5,0x5b,0xfd,0xf5
+
+# ATT:   tdpbhf8ps %tmm1, %tmm2, %tmm3
+# INTEL: tdpbhf8ps tmm3, tmm2, tmm1
+0xc4,0xe5,0x73,0xfd,0xda
+
+# ATT:   tdphbf8ps %tmm4, %tmm5, %tmm6
+# INTEL: tdphbf8ps tmm6, tmm5, tmm4
+0xc4,0xe5,0x5a,0xfd,0xf5
+
+# ATT:   tdphbf8ps %tmm1, %tmm2, %tmm3
+# INTEL: tdphbf8ps tmm3, tmm2, tmm1
+0xc4,0xe5,0x72,0xfd,0xda
+
+# ATT:   tdphf8ps %tmm4, %tmm5, %tmm6
+# INTEL: tdphf8ps tmm6, tmm5, tmm4
+0xc4,0xe5,0x59,0xfd,0xf5
+
+# ATT:   tdphf8ps %tmm1, %tmm2, %tmm3
+# INTEL: tdphf8ps tmm3, tmm2, tmm1
+0xc4,0xe5,0x71,0xfd,0xda

diff  --git a/llvm/test/MC/X86/AMX/amx-fp8-att.s b/llvm/test/MC/X86/AMX/amx-fp8-att.s
new file mode 100644
index 00000000000000..904539ec4917fe
--- /dev/null
+++ b/llvm/test/MC/X86/AMX/amx-fp8-att.s
@@ -0,0 +1,33 @@
+// RUN: llvm-mc -triple x86_64 --show-encoding %s | FileCheck %s
+
+// CHECK: tdpbf8ps %tmm4, %tmm5, %tmm6
+// CHECK: encoding: [0xc4,0xe5,0x58,0xfd,0xf5]
+          tdpbf8ps %tmm4, %tmm5, %tmm6
+
+// CHECK: tdpbf8ps %tmm1, %tmm2, %tmm3
+// CHECK: encoding: [0xc4,0xe5,0x70,0xfd,0xda]
+          tdpbf8ps %tmm1, %tmm2, %tmm3
+
+// CHECK: tdpbhf8ps %tmm4, %tmm5, %tmm6
+// CHECK: encoding: [0xc4,0xe5,0x5b,0xfd,0xf5]
+          tdpbhf8ps %tmm4, %tmm5, %tmm6
+
+// CHECK: tdpbhf8ps %tmm1, %tmm2, %tmm3
+// CHECK: encoding: [0xc4,0xe5,0x73,0xfd,0xda]
+          tdpbhf8ps %tmm1, %tmm2, %tmm3
+
+// CHECK: tdphbf8ps %tmm4, %tmm5, %tmm6
+// CHECK: encoding: [0xc4,0xe5,0x5a,0xfd,0xf5]
+          tdphbf8ps %tmm4, %tmm5, %tmm6
+
+// CHECK: tdphbf8ps %tmm1, %tmm2, %tmm3
+// CHECK: encoding: [0xc4,0xe5,0x72,0xfd,0xda]
+          tdphbf8ps %tmm1, %tmm2, %tmm3
+
+// CHECK: tdphf8ps %tmm4, %tmm5, %tmm6
+// CHECK: encoding: [0xc4,0xe5,0x59,0xfd,0xf5]
+          tdphf8ps %tmm4, %tmm5, %tmm6
+
+// CHECK: tdphf8ps %tmm1, %tmm2, %tmm3
+// CHECK: encoding: [0xc4,0xe5,0x71,0xfd,0xda]
+          tdphf8ps %tmm1, %tmm2, %tmm3

diff  --git a/llvm/test/MC/X86/AMX/amx-fp8-intel.s b/llvm/test/MC/X86/AMX/amx-fp8-intel.s
new file mode 100644
index 00000000000000..4191ae6f5cd133
--- /dev/null
+++ b/llvm/test/MC/X86/AMX/amx-fp8-intel.s
@@ -0,0 +1,33 @@
+// RUN: llvm-mc -triple x86_64 -x86-asm-syntax=intel -output-asm-variant=1 --show-encoding %s | FileCheck %s
+
+// CHECK: tdpbf8ps tmm6, tmm5, tmm4
+// CHECK: encoding: [0xc4,0xe5,0x58,0xfd,0xf5]
+          tdpbf8ps tmm6, tmm5, tmm4
+
+// CHECK: tdpbf8ps tmm3, tmm2, tmm1
+// CHECK: encoding: [0xc4,0xe5,0x70,0xfd,0xda]
+          tdpbf8ps tmm3, tmm2, tmm1
+
+// CHECK: tdpbhf8ps tmm6, tmm5, tmm4
+// CHECK: encoding: [0xc4,0xe5,0x5b,0xfd,0xf5]
+          tdpbhf8ps tmm6, tmm5, tmm4
+
+// CHECK: tdpbhf8ps tmm3, tmm2, tmm1
+// CHECK: encoding: [0xc4,0xe5,0x73,0xfd,0xda]
+          tdpbhf8ps tmm3, tmm2, tmm1
+
+// CHECK: tdphbf8ps tmm6, tmm5, tmm4
+// CHECK: encoding: [0xc4,0xe5,0x5a,0xfd,0xf5]
+          tdphbf8ps tmm6, tmm5, tmm4
+
+// CHECK: tdphbf8ps tmm3, tmm2, tmm1
+// CHECK: encoding: [0xc4,0xe5,0x72,0xfd,0xda]
+          tdphbf8ps tmm3, tmm2, tmm1
+
+// CHECK: tdphf8ps tmm6, tmm5, tmm4
+// CHECK: encoding: [0xc4,0xe5,0x59,0xfd,0xf5]
+          tdphf8ps tmm6, tmm5, tmm4
+
+// CHECK: tdphf8ps tmm3, tmm2, tmm1
+// CHECK: encoding: [0xc4,0xe5,0x71,0xfd,0xda]
+          tdphf8ps tmm3, tmm2, tmm1


        


More information about the cfe-commits mailing list