[clang] [llvm] [X86][AMX] Support AMX-TF32 (PR #115625)

Feng Zou via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 9 17:36:43 PST 2024


https://github.com/fzou1 created https://github.com/llvm/llvm-project/pull/115625

Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368

>From b1d9799b99b45b5af2b63868c4c3b139dbf9378c Mon Sep 17 00:00:00 2001
From: Feng Zou <feng.zou at intel.com>
Date: Sat, 26 Oct 2024 18:44:32 +0800
Subject: [PATCH] [X86][AMX] Support AMX-TF32

Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368
---
 clang/docs/ReleaseNotes.rst                   |   1 +
 clang/include/clang/Basic/BuiltinsX86_64.def  |  15 +-
 clang/include/clang/Driver/Options.td         |   2 +
 clang/lib/Basic/Targets/X86.cpp               |   6 +
 clang/lib/Basic/Targets/X86.h                 |   1 +
 clang/lib/Headers/CMakeLists.txt              |   1 +
 clang/lib/Headers/amxtf32intrin.h             | 194 ++++++++++++++++++
 clang/lib/Headers/immintrin.h                 |   4 +
 clang/lib/Sema/SemaX86.cpp                    |   2 +
 clang/test/CodeGen/X86/amx_tf32.c             |  17 ++
 clang/test/CodeGen/X86/amx_tf32_api.c         |  27 +++
 clang/test/CodeGen/X86/amx_tf32_errors.c      |  23 +++
 clang/test/CodeGen/X86/amx_tf32_inline_asm.c  |  18 ++
 clang/test/Driver/x86-target-features.c       |   7 +
 clang/test/Preprocessor/x86_target_features.c |   9 +
 llvm/include/llvm/IR/IntrinsicsX86.td         |  19 ++
 .../llvm/TargetParser/X86TargetParser.def     |   1 +
 llvm/lib/Target/X86/X86.td                    |   3 +
 llvm/lib/Target/X86/X86ExpandPseudo.cpp       |  11 +-
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  22 ++
 llvm/lib/Target/X86/X86InstrAMX.td            |  52 +++++
 llvm/lib/Target/X86/X86InstrPredicates.td     |   1 +
 llvm/lib/Target/X86/X86LowerAMXType.cpp       |  20 +-
 llvm/lib/Target/X86/X86RegisterInfo.cpp       |   4 +-
 llvm/lib/TargetParser/Host.cpp                |   1 +
 llvm/lib/TargetParser/X86TargetParser.cpp     |   1 +
 llvm/test/CodeGen/X86/amx-tf32-internal.ll    |  47 +++++
 llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll  |  23 +++
 .../Disassembler/X86/AMX/x86-64-amx-tf32.txt  |  19 ++
 llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s    |  17 ++
 llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s  |  17 ++
 31 files changed, 578 insertions(+), 7 deletions(-)
 create mode 100644 clang/lib/Headers/amxtf32intrin.h
 create mode 100644 clang/test/CodeGen/X86/amx_tf32.c
 create mode 100644 clang/test/CodeGen/X86/amx_tf32_api.c
 create mode 100644 clang/test/CodeGen/X86/amx_tf32_errors.c
 create mode 100644 clang/test/CodeGen/X86/amx_tf32_inline_asm.c
 create mode 100644 llvm/test/CodeGen/X86/amx-tf32-internal.ll
 create mode 100644 llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll
 create mode 100644 llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt
 create mode 100644 llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s
 create mode 100644 llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index c3424e0e6f34c9..e235a04f78112b 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -740,6 +740,7 @@ X86 Support
 - Support ISA of ``AMX-FP8``.
 - Support ISA of ``AMX-TRANSPOSE``.
 - Support ISA of ``AMX-AVX512``.
+- Support ISA of ``AMX-TF32``.
 
 Arm and AArch64 Support
 ^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
index 9f7462b1e0d962..25c10d39df32e2 100644
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -139,6 +139,9 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2pbf16l_internal, "V32yUsUsV256iUi", "n",
 TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phh_internal, "V32xUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
 TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl_internal, "V32xUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
 TARGET_BUILTIN(__builtin_ia32_tilemovrow_internal, "V16iUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
+TARGET_BUILTIN(__builtin_ia32_tmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32")
+TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32,amx-transpose")
+
 // AMX
 TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")
@@ -172,10 +175,6 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phh, "V32xIUcUi", "n", "amx-avx512,avx10
 TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl, "V32xIUcUi", "n", "amx-avx512,avx10.2-512")
 TARGET_BUILTIN(__builtin_ia32_tilemovrow, "V16iIUcUi", "n", "amx-avx512,avx10.2-512")
 
-TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi")
-TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd")
-TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiSLLi*SLLiSLLiIi", "n", "cmpccxadd")
-
 // AMX_FP16 FP16
 TARGET_BUILTIN(__builtin_ia32_tdpfp16ps, "vIUcIUcIUc", "n", "amx-fp16")
 
@@ -185,6 +184,14 @@ TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps, "vIUcUIcUIc", "n", "amx-fp8")
 TARGET_BUILTIN(__builtin_ia32_tdphbf8ps, "vIUcUIcUIc", "n", "amx-fp8")
 TARGET_BUILTIN(__builtin_ia32_tdphf8ps, "vIUcUIcUIc", "n", "amx-fp8")
 
+// AMX TF32
+TARGET_BUILTIN(__builtin_ia32_tmmultf32ps, "vIUcIUcIUc", "n", "amx-tf32")
+TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps, "vIUcIUcIUc", "n", "amx-tf32,amx-transpose")
+
+TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi")
+TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd")
+TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiSLLi*SLLiSLLiIi", "n", "cmpccxadd")
+
 // RAO-INT
 TARGET_BUILTIN(__builtin_ia32_aadd64, "vv*SOi", "n", "raoint")
 TARGET_BUILTIN(__builtin_ia32_aand64, "vv*SOi", "n", "raoint")
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 0dba5672c5a85d..1304ef3c5a228b 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -6297,6 +6297,8 @@ def mamx_int8 : Flag<["-"], "mamx-int8">, Group<m_x86_Features_Group>;
 def mno_amx_int8 : Flag<["-"], "mno-amx-int8">, Group<m_x86_Features_Group>;
 def mamx_fp8 : Flag<["-"], "mamx-fp8">, Group<m_x86_Features_Group>;
 def mno_amx_fp8 : Flag<["-"], "mno-amx-fp8">, Group<m_x86_Features_Group>;
+def mamx_tf32 : Flag<["-"], "mamx-tf32">, Group<m_x86_Features_Group>;
+def mno_amx_tf32 : Flag<["-"], "mno-amx-tf32">, Group<m_x86_Features_Group>;
 def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>;
 def mno_amx_tile : Flag<["-"], "mno-amx-tile">, Group<m_x86_Features_Group>;
 def mamx_transpose : Flag<["-"], "mamx-transpose">, Group<m_x86_Features_Group>;
diff --git a/clang/lib/Basic/Targets/X86.cpp b/clang/lib/Basic/Targets/X86.cpp
index 3c3dbfa13e452b..dc85e9aa77cd3d 100644
--- a/clang/lib/Basic/Targets/X86.cpp
+++ b/clang/lib/Basic/Targets/X86.cpp
@@ -434,6 +434,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
       HasAMXTRANSPOSE = true;
     } else if (Feature == "+amx-avx512") {
       HasAMXAVX512 = true;
+    } else if (Feature == "+amx-tf32") {
+      HasAMXTF32 = true;
     } else if (Feature == "+cmpccxadd") {
       HasCMPCCXADD = true;
     } else if (Feature == "+raoint") {
@@ -959,6 +961,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
     Builder.defineMacro("__AMX_TRANSPOSE__");
   if (HasAMXAVX512)
     Builder.defineMacro("__AMX_AVX512__");
+  if (HasAMXTF32)
+    Builder.defineMacro("__AMX_TF32__");
   if (HasCMPCCXADD)
     Builder.defineMacro("__CMPCCXADD__");
   if (HasRAOINT)
@@ -1090,6 +1094,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
       .Case("amx-fp16", true)
       .Case("amx-fp8", true)
       .Case("amx-int8", true)
+      .Case("amx-tf32", true)
       .Case("amx-tile", true)
       .Case("amx-transpose", true)
       .Case("avx", true)
@@ -1211,6 +1216,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
       .Case("amx-fp16", HasAMXFP16)
       .Case("amx-fp8", HasAMXFP8)
       .Case("amx-int8", HasAMXINT8)
+      .Case("amx-tf32", HasAMXTF32)
       .Case("amx-tile", HasAMXTILE)
       .Case("amx-transpose", HasAMXTRANSPOSE)
       .Case("avx", SSELevel >= AVX)
diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h
index 70047731b17295..04b1d5d33ea231 100644
--- a/clang/lib/Basic/Targets/X86.h
+++ b/clang/lib/Basic/Targets/X86.h
@@ -160,6 +160,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
   bool HasAMXFP8 = false;
   bool HasAMXTRANSPOSE = false;
   bool HasAMXAVX512 = false;
+  bool HasAMXTF32 = false;
   bool HasSERIALIZE = false;
   bool HasTSXLDTRK = false;
   bool HasUSERMSR = false;
diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt
index 76366ca1f108e9..0ad9596ba9e257 100644
--- a/clang/lib/Headers/CMakeLists.txt
+++ b/clang/lib/Headers/CMakeLists.txt
@@ -151,6 +151,7 @@ set(x86_files
   amxfp16intrin.h
   amxfp8intrin.h
   amxintrin.h
+  amxtf32intrin.h
   amxtransposeintrin.h
   avx10_2_512bf16intrin.h
   avx10_2_512convertintrin.h
diff --git a/clang/lib/Headers/amxtf32intrin.h b/clang/lib/Headers/amxtf32intrin.h
new file mode 100644
index 00000000000000..f11b7c7499e2d5
--- /dev/null
+++ b/clang/lib/Headers/amxtf32intrin.h
@@ -0,0 +1,211 @@
+/*===------------- amxtf32intrin.h - AMX_TF32 intrinsics -*- C++ -*---------===
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *===------------------------------------------------------------------------===
+ */
+
+#ifndef __IMMINTRIN_H
+#error "Never use <amxtf32intrin.h> directly; include <immintrin.h> instead."
+#endif // __IMMINTRIN_H
+
+#ifndef __AMX_TF32INTRIN_H
+#define __AMX_TF32INTRIN_H
+#ifdef __x86_64__
+
+#define __DEFAULT_FN_ATTRS_TF32                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-tf32")))
+
+#define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE                                      \
+  __attribute__((__always_inline__, __nodebug__,                               \
+                 __target__("amx-tf32,amx-transpose")))
+
+/// Do Matrix Multiplication of \a a and \a b, and then do Matrix Plus
+/// with \a srcdst.
+/// All the calculation is based on float32, but the lower 13 mantissa bits
+/// of every input element are set to 0 first.
+///
+/// \headerfile <immintrin.h>
+///
+/// \code
+/// void _tile_mmultf32ps(constexpr int srcdst, constexpr int a, \
+///                       constexpr int b);
+/// \endcode
+///
+/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
+///
+/// \param srcdst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param a
+///    The 1st source tile. Max size is 1024 Bytes.
+/// \param b
+///    The 2nd source tile. Max size is 1024 Bytes.
+///
+/// \code{.operation}
+/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
+///   dword[12:0] := 0
+///   dword[31:13] := x[31:13]
+///   return dword
+/// }
+///
+/// DEFINE silence_snan_fp32(x[31:0]) {
+///   IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
+///     x.fraction[22] := 1
+///   return x
+/// }
+///
+/// elements_a := a.colsb / 4
+/// elements_dest := srcdst.colsb / 4
+///
+/// FOR m := 0 TO (srcdst.rows-1)
+///   tmp[511:0] := 0
+///   FOR k := 0 TO (elements_a-1)
+///     FOR n := 0 TO (elements_dest-1)
+///       af := silence_snan_fp32(a.row[m].fp32[k])
+///       bf := silence_snan_fp32(b.row[k].fp32[n])
+///       tmp.fp32[n] += zero_lower_mantissa_bits_fp32(af)
+///                      * zero_lower_mantissa_bits_fp32(bf)
+///     ENDFOR
+///   ENDFOR
+///
+///   FOR n := 0 TO (elements_dest-1)
+///     tmp.fp32[n] += srcdst.row[m].fp32[n]
+///   ENDFOR
+///   write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
+///
+/// ENDFOR
+///
+/// zero_upper_rows(srcdst, srcdst.rows)
+/// zero_tileconfig_start()
+/// \endcode
+#define _tile_mmultf32ps(srcdst, a, b)                                         \
+  __builtin_ia32_tmmultf32ps((srcdst), (a), (b))
+
+/// Do Matrix Multiplication of the transpose of \a a and \a b, and then do
+/// Matrix Plus with \a srcdst.
+/// All the calculation is based on float32, but the lower 13 mantissa bits
+/// of every input element are set to 0 first.
+///
+/// \headerfile <immintrin.h>
+///
+/// \code
+/// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a, \
+///                        constexpr int b);
+/// \endcode
+///
+/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
+///
+/// \param srcdst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param a
+///    The 1st source tile (transposed left operand). Max size is 1024 Bytes.
+/// \param b
+///    The 2nd source tile. Max size is 1024 Bytes.
+///
+/// \code{.operation}
+/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
+///   dword[12:0] := 0
+///   dword[31:13] := x[31:13]
+///   return dword
+/// }
+///
+/// DEFINE silence_snan_fp32(x[31:0]) {
+///   IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
+///     x.fraction[22] := 1
+///   return x
+/// }
+///
+/// elements_dest := srcdst.colsb / 4
+///
+/// FOR m := 0 TO (srcdst.rows-1)
+///   tmp[511:0] := 0
+///   FOR k := 0 TO (a.rows-1)
+///     FOR n := 0 TO (elements_dest-1)
+///       a1e := silence_snan_fp32(a.row[k].fp32[m])
+///       a2e := silence_snan_fp32(b.row[k].fp32[n])
+///       s1e := zero_lower_mantissa_bits_fp32(a1e)
+///       s2e := zero_lower_mantissa_bits_fp32(a2e)
+///       tmp.fp32[n] += s1e * s2e
+///     ENDFOR
+///   ENDFOR
+///
+///   FOR n := 0 TO (elements_dest-1)
+///     tmp.fp32[n] += srcdst.row[m].fp32[n]
+///   ENDFOR
+///   write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
+///
+/// ENDFOR
+///
+/// zero_upper_rows(srcdst, srcdst.rows)
+/// zero_tileconfig_start()
+/// \endcode
+#define _tile_tmmultf32ps(srcdst, a, b)                                        \
+  __builtin_ia32_ttmmultf32ps((srcdst), (a), (b))
+
+// m: rows of dst and src1; n: bytes per row of dst and src2;
+// k: bytes per row of src1 (src2 has k / 4 rows).
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32
+_tile_mmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                          _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tmmultf32ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Do Matrix Multiplication of src0 and src1, and then do Matrix Plus with dst.
+/// All the calculation is based on float32, but the lower 13 mantissa bits
+/// of every input element are set to 0 first.
+///
+/// \headerfile <immintrin.h>
+///
+/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
+///
+/// \param dst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param src0
+///    The 1st source tile. Max size is 1024 Bytes.
+/// \param src1
+///    The 2nd source tile. Max size is 1024 Bytes.
+__DEFAULT_FN_ATTRS_TF32
+static __inline__ void __tile_mmultf32ps(__tile1024i *dst, __tile1024i src0,
+                                         __tile1024i src1) {
+  dst->tile = _tile_mmultf32ps_internal(src0.row, src1.col, src0.col, dst->tile,
+                                        src0.tile, src1.tile);
+}
+
+// dst = m x n (srcdst), src1 = k x m (transposed), src2 = k x n, where
+// m is the row count of dst, n is the byte width of dst and src2 rows, and
+// k is the byte width of the reduction dimension: src1 and src2 both have
+// k / 4 rows, and src1 has m * 4 bytes per row.
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE
+_tile_tmmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                           _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_ttmmultf32ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Do Matrix Multiplication of the transpose of src0 and src1, and then do
+/// Matrix Plus with dst. All the calculation is based on float32, but the
+/// lower 13 mantissa bits of every input element are set to 0 first.
+///
+/// \headerfile <immintrin.h>
+///
+/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
+///
+/// \param dst
+///    The destination tile. Max size is 1024 Bytes.
+/// \param src0
+///    The 1st source tile (transposed left operand). Max size is 1024 Bytes.
+/// \param src1
+///    The 2nd source tile. Max size is 1024 Bytes.
+__DEFAULT_FN_ATTRS_TF32_TRANSPOSE
+static __inline__ void __tile_tmmultf32ps(__tile1024i *dst, __tile1024i src0,
+                                          __tile1024i src1) {
+  // src0 is stored transposed (k / 4 rows of m * 4 bytes), so m and k are
+  // recovered from its column and row counts; passing src0.row / src0.col
+  // directly is only correct when src0.col == 4 * src0.row.
+  dst->tile = _tile_tmmultf32ps_internal(src0.col / 4, src1.col, src0.row * 4,
+                                         dst->tile, src0.tile, src1.tile);
+}
+
+#endif // __x86_64__
+#endif // __AMX_TF32INTRIN_H
diff --git a/clang/lib/Headers/immintrin.h b/clang/lib/Headers/immintrin.h
index bc240e28d59142..5740da8136ca99 100644
--- a/clang/lib/Headers/immintrin.h
+++ b/clang/lib/Headers/immintrin.h
@@ -660,6 +660,10 @@ _storebe_i64(void * __P, long long __D) {
 #include <amxavx512intrin.h>
 #endif
 
+#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_TF32__)
+#include <amxtf32intrin.h>
+#endif
+
 #if !defined(__SCE__) || __has_feature(modules) ||                             \
     defined(__AVX512VP2INTERSECT__)
 #include <avx512vp2intersectintrin.h>
diff --git a/clang/lib/Sema/SemaX86.cpp b/clang/lib/Sema/SemaX86.cpp
index 1155a5edc73c34..d7c8ed351f410a 100644
--- a/clang/lib/Sema/SemaX86.cpp
+++ b/clang/lib/Sema/SemaX86.cpp
@@ -654,6 +654,8 @@ bool SemaX86::CheckBuiltinTileArguments(unsigned BuiltinID, CallExpr *TheCall) {
   case X86::BI__builtin_ia32_tdpbhf8ps:
   case X86::BI__builtin_ia32_tdphbf8ps:
   case X86::BI__builtin_ia32_tdphf8ps:
+  case X86::BI__builtin_ia32_tmmultf32ps:
+  case X86::BI__builtin_ia32_ttmmultf32ps:
     return CheckBuiltinTileRangeAndDuplicate(TheCall, {0, 1, 2});
   case X86::BI__builtin_ia32_ttransposed:
     return CheckBuiltinTileArgumentsRange(TheCall, {0, 1});
diff --git a/clang/test/CodeGen/X86/amx_tf32.c b/clang/test/CodeGen/X86/amx_tf32.c
new file mode 100644
index 00000000000000..661a9dfbc673b2
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_tf32.c
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tile -target-feature +amx-tf32 \
+// RUN: -target-feature +amx-transpose -emit-llvm -o - -Wall -Werror -pedantic -Wno-gnu-statement-expression | FileCheck %s
+
+#include <immintrin.h>
+#include <stddef.h>
+
+void test_tile_mmultf32ps(void) {
+  // CHECK-LABEL: @test_tile_mmultf32ps(
+  // CHECK: call void @llvm.x86.tmmultf32ps(i8 1, i8 2, i8 3)
+  _tile_mmultf32ps(1, 2, 3);
+}
+
+void test_tile_tmmultf32ps(void) {
+  // CHECK-LABEL: @test_tile_tmmultf32ps(
+  // CHECK: call void @llvm.x86.ttmmultf32ps(i8 1, i8 2, i8 3)
+  _tile_tmmultf32ps(1, 2, 3);
+}
diff --git a/clang/test/CodeGen/X86/amx_tf32_api.c b/clang/test/CodeGen/X86/amx_tf32_api.c
new file mode 100644
index 00000000000000..2ac8489e3e0baf
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_tf32_api.c
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown \
+// RUN: -target-feature +amx-tf32 -target-feature +amx-transpose \
+// RUN: -target-feature +amx-bf16 -target-feature +avx512f \
+// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
+
+#include <immintrin.h>
+
+void test_tile_mmultf32ps(__tile1024i a, __tile1024i b, __tile1024i c) {
+  //CHECK-LABEL: @test_tile_mmultf32ps
+  //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+  //CHECK-DAG: call x86_amx @llvm.x86.tmmultf32ps.internal
+  //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+  __tile_mmultf32ps(&c, a, b);
+}
+
+void test_tile_tmmultf32ps(__tile1024i a, __tile1024i b, __tile1024i c) {
+  //CHECK-LABEL: @test_tile_tmmultf32ps
+  //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+  //CHECK-DAG: call x86_amx @llvm.x86.ttmmultf32ps.internal
+  //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+  __tile_tmmultf32ps(&c, a, b);
+}
diff --git a/clang/test/CodeGen/X86/amx_tf32_errors.c b/clang/test/CodeGen/X86/amx_tf32_errors.c
new file mode 100644
index 00000000000000..45021306921150
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_tf32_errors.c
@@ -0,0 +1,23 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown \
+// RUN: -target-feature +amx-tf32 -target-feature +amx-transpose -verify
+
+#include <immintrin.h>
+#include <stddef.h>
+
+void test_tile_mmultf32ps() {
+  _tile_mmultf32ps(16, 2, 3); // expected-error {{argument value 16 is outside the valid range [0, 7]}}
+  _tile_mmultf32ps(1, 26, 3); // expected-error {{argument value 26 is outside the valid range [0, 7]}}
+  _tile_mmultf32ps(1, 2, 36); // expected-error {{argument value 36 is outside the valid range [0, 7]}}
+  _tile_mmultf32ps(1, 1, 3);  // expected-error {{tile arguments must refer to different tiles}}
+  _tile_mmultf32ps(1, 2, 1);  // expected-error {{tile arguments must refer to different tiles}}
+  _tile_mmultf32ps(1, 3, 3);  // expected-error {{tile arguments must refer to different tiles}}
+}
+
+void test_tile_tmmultf32ps() {
+  _tile_tmmultf32ps(16, 2, 3); // expected-error {{argument value 16 is outside the valid range [0, 7]}}
+  _tile_tmmultf32ps(1, 26, 3); // expected-error {{argument value 26 is outside the valid range [0, 7]}}
+  _tile_tmmultf32ps(1, 2, 36); // expected-error {{argument value 36 is outside the valid range [0, 7]}}
+  _tile_tmmultf32ps(1, 1, 3);  // expected-error {{tile arguments must refer to different tiles}}
+  _tile_tmmultf32ps(1, 2, 1);  // expected-error {{tile arguments must refer to different tiles}}
+  _tile_tmmultf32ps(1, 2, 2);  // expected-error {{tile arguments must refer to different tiles}}
+}
diff --git a/clang/test/CodeGen/X86/amx_tf32_inline_asm.c b/clang/test/CodeGen/X86/amx_tf32_inline_asm.c
new file mode 100644
index 00000000000000..76d164737d88b6
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_tf32_inline_asm.c
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tf32 -target-feature +amx-transpose -emit-llvm -o - -Wall -Werror -pedantic | FileCheck %s
+
+void f_tilemul(short a)
+{
+  //CHECK:  call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0   \0A\09tileloadd 0(%rdx,%r14,4), %tmm6   \0A\09tmmultf32ps %tmm6, %tmm0, %tmm7    \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+  __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0   \n\t"
+                    "tileloadd 0(%%rdx,%%r14,4), %%tmm6   \n\t"
+                    "tmmultf32ps %%tmm6, %%tmm0, %%tmm7   \n\t"
+                    "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+          ::: "memory", "tmm0", "tmm6", "tmm7");
+
+  //CHECK:  call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0   \0A\09tileloadd 0(%rdx,%r14,4), %tmm6   \0A\09ttmmultf32ps %tmm6, %tmm0, %tmm7    \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
+  __asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0   \n\t"
+                    "tileloadd 0(%%rdx,%%r14,4), %%tmm6   \n\t"
+                    "ttmmultf32ps %%tmm6, %%tmm0, %%tmm7  \n\t"
+                    "tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
+          ::: "memory", "tmm0", "tmm6", "tmm7");
+}
diff --git a/clang/test/Driver/x86-target-features.c b/clang/test/Driver/x86-target-features.c
index 822c997f71744f..339f593dc760a8 100644
--- a/clang/test/Driver/x86-target-features.c
+++ b/clang/test/Driver/x86-target-features.c
@@ -318,6 +318,13 @@
 // AMX-AVX512: "-target-feature" "+amx-avx512"
 // NO-AMX-AVX512: "-target-feature" "-amx-avx512"
 
+// RUN: %clang -target x86_64-unknown-linux-gnu -mamx-tf32 %s \
+// RUN: -### -o %t.o 2>&1 | FileCheck -check-prefix=AMX-TF32 %s
+// RUN: %clang -target x86_64-unknown-linux-gnu -mno-amx-tf32 %s \
+// RUN: -### -o %t.o 2>&1 | FileCheck -check-prefix=NO-AMX-TF32 %s
+// AMX-TF32: "-target-feature" "+amx-tf32"
+// NO-AMX-TF32: "-target-feature" "-amx-tf32"
+
 // RUN: %clang --target=i386 -march=i386 -mhreset %s -### 2>&1 | FileCheck -check-prefix=HRESET %s
 // RUN: %clang --target=i386 -march=i386 -mno-hreset %s -### 2>&1 | FileCheck -check-prefix=NO-HRESET %s
 // HRESET: "-target-feature" "+hreset"
diff --git a/clang/test/Preprocessor/x86_target_features.c b/clang/test/Preprocessor/x86_target_features.c
index 8e4ddb1526626e..fa3d0038f05a93 100644
--- a/clang/test/Preprocessor/x86_target_features.c
+++ b/clang/test/Preprocessor/x86_target_features.c
@@ -570,6 +570,15 @@
 
 // NO-AMX-AVX512-NOT: #define __AMX_AVX512__ 1
 
+// RUN: %clang -target x86_64-unknown-linux-gnu -march=x86-64 -mamx-tf32 -x c \
+// RUN: -E -dM -o - %s | FileCheck  -check-prefix=AMX-TF32 %s
+// AMX-TF32: #define __AMX_TF32__ 1
+// RUN: %clang -target x86_64-unknown-linux-gnu -march=x86-64 -mno-amx-tf32 -x c \
+// RUN: -E -dM -o - %s | FileCheck  -check-prefix=NO-AMX-TF32 %s
+// RUN: %clang -target x86_64-unknown-linux-gnu -march=x86-64 -mamx-tf32 -mno-amx-tile \
+// RUN: -x c -E -dM -o - %s | FileCheck  -check-prefix=NO-AMX-TF32 %s
+// NO-AMX-TF32-NOT: #define __AMX_TF32__ 1
+
 // RUN: %clang -target i386-unknown-unknown -march=atom -mavxvnni -x c -E -dM -o - %s | FileCheck -match-full-lines --check-prefix=AVXVNNI %s
 
 // AVXVNNI: #define __AVX2__ 1
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index 3003f9887e239c..ce519daacdb210 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -6101,6 +6101,25 @@ let TargetPrefix = "x86" in {
               Intrinsic<[llvm_v16i32_ty],
                         [llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty, llvm_i32_ty],
                         []>;
+
+  def int_x86_tmmultf32ps : ClangBuiltin<"__builtin_ia32_tmmultf32ps">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+              [ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+  def int_x86_ttmmultf32ps : ClangBuiltin<"__builtin_ia32_ttmmultf32ps">,
+              Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
+              [ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
+  def int_x86_tmmultf32ps_internal :
+              ClangBuiltin<"__builtin_ia32_tmmultf32ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
+  def int_x86_ttmmultf32ps_internal :
+              ClangBuiltin<"__builtin_ia32_ttmmultf32ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/TargetParser/X86TargetParser.def b/llvm/include/llvm/TargetParser/X86TargetParser.def
index 815556e374bef5..3b643563775688 100644
--- a/llvm/include/llvm/TargetParser/X86TargetParser.def
+++ b/llvm/include/llvm/TargetParser/X86TargetParser.def
@@ -267,6 +267,7 @@ X86_FEATURE       (ZU,              "zu")
 X86_FEATURE       (AMX_FP8,         "amx-fp8")
 X86_FEATURE       (AMX_TRANSPOSE,   "amx-transpose")
 X86_FEATURE       (AMX_AVX512,      "amx-avx512")
+X86_FEATURE       (AMX_TF32,        "amx-tf32")
 // These features aren't really CPU features, but the frontend can set them.
 X86_FEATURE       (RETPOLINE_EXTERNAL_THUNK,    "retpoline-external-thunk")
 X86_FEATURE       (RETPOLINE_INDIRECT_BRANCHES, "retpoline-indirect-branches")
diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td
index 59780ba5b99fcf..35bbffdb20942d 100644
--- a/llvm/lib/Target/X86/X86.td
+++ b/llvm/lib/Target/X86/X86.td
@@ -280,6 +280,9 @@ def FeatureAMXAVX512 : SubtargetFeature<"amx-avx512",
                                         "HasAMXAVX512", "true",
                                         "Support AMX-AVX512 instructions",
                                         [FeatureAMXTILE]>;
+def FeatureAMXTF32 : SubtargetFeature<"amx-tf32", "HasAMXTF32", "true",
+                                      "Support AMX-TF32 instructions",
+                                      [FeatureAMXTILE]>;
 def FeatureCMPCCXADD : SubtargetFeature<"cmpccxadd", "HasCMPCCXADD", "true",
                                         "Support CMPCCXADD instructions">;
 def FeatureRAOINT : SubtargetFeature<"raoint", "HasRAOINT", "true",
diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
index a6096e5032e89c..4f045d78f75fb2 100644
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -755,7 +755,9 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
   case X86::PTDPBUSDV:
   case X86::PTDPBUUDV:
   case X86::PTDPBF16PSV:
-  case X86::PTDPFP16PSV: {
+  case X86::PTDPFP16PSV:
+  case X86::PTMMULTF32PSV:
+  case X86::PTTMMULTF32PSV: {
     MI.untieRegOperand(4);
     for (unsigned i = 3; i > 0; --i)
       MI.removeOperand(i);
@@ -769,6 +771,13 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
     case X86::PTDPBUUDV:   Opc = X86::TDPBUUD; break;
     case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break;
     case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break;
+    case X86::PTMMULTF32PSV:
+      Opc = X86::TMMULTF32PS;
+      break;
+    case X86::PTTMMULTF32PSV:
+      Opc = X86::TTMMULTF32PS;
+      break;
+
     default:
       llvm_unreachable("Unexpected Opcode");
     }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 253b768f34a07c..6140dffbe8ce65 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37686,6 +37686,28 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
     MI.eraseFromParent(); // The pseudo is gone now.
     return BB;
   }
+  case X86::PTMMULTF32PS:
+  case X86::PTTMMULTF32PS: {
+    const DebugLoc &DL = MI.getDebugLoc();
+    unsigned Opc;
+    switch (MI.getOpcode()) {
+    default:
+      llvm_unreachable("Unexpected instruction!");
+    case X86::PTMMULTF32PS:
+      Opc = X86::TMMULTF32PS;
+      break;
+    case X86::PTTMMULTF32PS:
+      Opc = X86::TTMMULTF32PS;
+      break;
+    }
+    MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc));
+    MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define);
+    MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef);
+    MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef);
+    MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef);
+    MI.eraseFromParent();
+    return BB;
+  }
   }
 }
 
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
index b954c977f8c6c9..1b579c488c2f00 100644
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -516,3 +516,55 @@ let Predicates = [HasAMXAVX512, HasAVX10_2_512, In64BitMode] in {
                                      TILE:$src3, GR32:$src4))]>;
   }
 }
+
+let Predicates = [HasAMXTF32, In64BitMode] in {
+  let SchedRW = [WriteSystem] in {
+    let Constraints = "$src1 = $dst" in {
+      def TMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst),
+                         (ins TILE:$src1, TILE:$src2, TILE:$src3),
+                         "tmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}",
+                         []>, VEX, VVVV, T8, PD;
+    }
+    let Constraints = "$src4 = $dst" in {
+      def PTMMULTF32PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
+                                  GR16:$src2, GR16:$src3, TILE:$src4,
+                                  TILE:$src5, TILE:$src6),
+                                  [(set TILE:$dst,
+                                  (int_x86_tmmultf32ps_internal GR16:$src1,
+                                  GR16:$src2, GR16:$src3, TILE:$src4,
+                                  TILE:$src5, TILE:$src6))]>;
+    }
+    let usesCustomInserter = 1 in {
+      def PTMMULTF32PS : PseudoI<(outs), (ins u8imm:$src1,
+                                 u8imm:$src2, u8imm:$src3),
+                                 [(int_x86_tmmultf32ps timm:$src1,
+                                   timm:$src2, timm:$src3)]>;
+    }
+  } // SchedRW = [WriteSystem]
+} // HasAMXTF32
+
+let Predicates = [HasAMXTF32, HasAMXTRANSPOSE, In64BitMode] in {
+  let SchedRW = [WriteSystem] in {
+    let Constraints = "$src1 = $dst" in {
+      def TTMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst),
+                         (ins TILE:$src1, TILE:$src2, TILE:$src3),
+                         "ttmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}",
+                         []>, VEX, VVVV, T8, PS;
+    }
+    let Constraints = "$src4 = $dst" in {
+      def PTTMMULTF32PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
+                                  GR16:$src2, GR16:$src3, TILE:$src4,
+                                  TILE:$src5, TILE:$src6),
+                                  [(set TILE:$dst,
+                                  (int_x86_ttmmultf32ps_internal GR16:$src1,
+                                  GR16:$src2, GR16:$src3, TILE:$src4,
+                                  TILE:$src5, TILE:$src6))]>;
+    }
+    let usesCustomInserter = 1 in {
+      def PTTMMULTF32PS : PseudoI<(outs), (ins u8imm:$src1,
+                                 u8imm:$src2, u8imm:$src3),
+                                 [(int_x86_ttmmultf32ps timm:$src1,
+                                   timm:$src2, timm:$src3)]>;
+    }
+  } // SchedRW = [WriteSystem]
+} // HasAMXTF32, HasAMXTRANSPOSE
diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td
index 2eb4e4fb941b29..a9ec5f660ff1d8 100644
--- a/llvm/lib/Target/X86/X86InstrPredicates.td
+++ b/llvm/lib/Target/X86/X86InstrPredicates.td
@@ -186,6 +186,7 @@ def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;
 def HasAMXFP8    : Predicate<"Subtarget->hasAMXFP8()">;
 def HasAMXTRANSPOSE : Predicate<"Subtarget->hasAMXTRANSPOSE()">;
 def HasAMXAVX512 : Predicate<"Subtarget->hasAMXAVX512()">;
+def HasAMXTF32   : Predicate<"Subtarget->hasAMXTF32()">;
 def HasUINTR     : Predicate<"Subtarget->hasUINTR()">;
 def HasUSERMSR   : Predicate<"Subtarget->hasUSERMSR()">;
 def HasCRC32     : Predicate<"Subtarget->hasCRC32()">;
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index 08c065c39ee1e3..0e74cfa75e9606 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -241,7 +241,8 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
   case Intrinsic::x86_tdpbusd_internal:
   case Intrinsic::x86_tdpbuud_internal:
   case Intrinsic::x86_tdpbf16ps_internal:
-  case Intrinsic::x86_tdpfp16ps_internal: {
+  case Intrinsic::x86_tdpfp16ps_internal:
+  case Intrinsic::x86_tmmultf32ps_internal: {
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);
@@ -275,6 +276,23 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
     Col = II->getArgOperand(1);
     break;
   }
+  case Intrinsic::x86_ttmmultf32ps_internal: {
+    switch (OpNo) {
+    case 3:
+      Row = II->getArgOperand(0);
+      Col = II->getArgOperand(1);
+      break;
+    case 4:
+      Row = getRowFromCol(II, II->getArgOperand(2), 4);
+      Col = getColFromRow(II, II->getArgOperand(0), 4);
+      break;
+    case 5:
+      Row = getRowFromCol(II, II->getArgOperand(2), 4);
+      Col = II->getArgOperand(1);
+      break;
+    }
+    break;
+  }
   }
 
   return std::make_pair(Row, Col);
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
index 1b2192e3891fc5..09418c9bb74d34 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -1076,7 +1076,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
   case X86::PTDPFP16PSV:
   case X86::PTCMMIMFP16PSV:
   case X86::PTCMMRLFP16PSV:
-  case X86::PTTRANSPOSEDV: {
+  case X86::PTTRANSPOSEDV:
+  case X86::PTMMULTF32PSV:
+  case X86::PTTMMULTF32PSV: {
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);
     ShapeT Shape(&MO1, &MO2, MRI);
diff --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp
index a973aaaa4806e6..140e565e1686f2 100644
--- a/llvm/lib/TargetParser/Host.cpp
+++ b/llvm/lib/TargetParser/Host.cpp
@@ -1880,6 +1880,7 @@ const StringMap<bool> sys::getHostCPUFeatures() {
                    !getX86CpuIDAndInfoEx(0x1e, 0x1, &EAX, &EBX, &ECX, &EDX);
   Features["amx-fp8"] = HasLeaf1E && ((EAX >> 4) & 1) && HasAMXSave;
   Features["amx-transpose"] = HasLeaf1E && ((EAX >> 5) & 1) && HasAMXSave;
+  Features["amx-tf32"] = HasLeaf1E && ((EAX >> 6) & 1) && HasAMXSave;
   Features["amx-avx512"] = HasLeaf1E && ((EAX >> 7) & 1) && HasAMXSave;
 
   bool HasLeaf24 =
diff --git a/llvm/lib/TargetParser/X86TargetParser.cpp b/llvm/lib/TargetParser/X86TargetParser.cpp
index eb55e6fc9134c8..6b53424833bd47 100644
--- a/llvm/lib/TargetParser/X86TargetParser.cpp
+++ b/llvm/lib/TargetParser/X86TargetParser.cpp
@@ -602,6 +602,7 @@ constexpr FeatureBitset ImpliedFeaturesAMX_FP8 = FeatureAMX_TILE;
 constexpr FeatureBitset ImpliedFeaturesAMX_TRANSPOSE = FeatureAMX_TILE;
 constexpr FeatureBitset ImpliedFeaturesAMX_AVX512 =
     FeatureAMX_TILE | FeatureAVX10_2_512;
+constexpr FeatureBitset ImpliedFeaturesAMX_TF32 = FeatureAMX_TILE;
 constexpr FeatureBitset ImpliedFeaturesHRESET = {};
 
 constexpr FeatureBitset ImpliedFeaturesPREFETCHI = {};
diff --git a/llvm/test/CodeGen/X86/amx-tf32-internal.ll b/llvm/test/CodeGen/X86/amx-tf32-internal.ll
new file mode 100644
index 00000000000000..8094f990828bad
--- /dev/null
+++ b/llvm/test/CodeGen/X86/amx-tf32-internal.ll
@@ -0,0 +1,44 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -verify-machineinstrs \
+; RUN:   -mattr=+amx-tile,+avx512f,+amx-tf32,+amx-transpose | FileCheck %s
+
+define void @test_amx(ptr %pointer, ptr %base, i64 %stride) {
+; CHECK-LABEL: test_amx:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vxorps %xmm0, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %zmm0, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb $1, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb $8, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw $8, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb $8, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw $8, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb $8, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw $8, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw $8, %ax
+; CHECK-NEXT:    tileloadd (%rsi,%rdx), %tmm0
+; CHECK-NEXT:    tilezero %tmm1
+; CHECK-NEXT:    tilezero %tmm2
+; CHECK-NEXT:    tmmultf32ps %tmm1, %tmm0, %tmm2
+; CHECK-NEXT:    ttmmultf32ps %tmm1, %tmm0, %tmm2
+; CHECK-NEXT:    tilestored %tmm2, (%rdi,%rdx)
+; CHECK-NEXT:    tilerelease
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+
+  %a = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, ptr %base, i64 %stride)
+  %b = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8)
+  %c = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8)
+
+  %c1 = call x86_amx @llvm.x86.tmmultf32ps.internal(i16 8, i16 8, i16 8, x86_amx %c, x86_amx %a, x86_amx %b)
+  %c2 = call x86_amx @llvm.x86.ttmmultf32ps.internal(i16 8, i16 8, i16 8, x86_amx %c1, x86_amx %a, x86_amx %b)
+
+  call void @llvm.x86.tilestored64.internal(i16 8, i16 8, ptr %pointer, i64 %stride, x86_amx %c2)
+  ret void
+}
+
+declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
+declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)
+declare x86_amx @llvm.x86.tmmultf32ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.ttmmultf32ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
diff --git a/llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll b/llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll
new file mode 100644
index 00000000000000..af1a7ae1029756
--- /dev/null
+++ b/llvm/test/CodeGen/X86/amx-tf32-intrinsics.ll
@@ -0,0 +1,22 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -O0 -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-tf32,+amx-transpose -verify-machineinstrs | FileCheck %s
+
+define void @test_tmmultf32ps() {
+; CHECK-LABEL: test_tmmultf32ps:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    tmmultf32ps %tmm3, %tmm2, %tmm1
+; CHECK-NEXT:    retq
+  call void @llvm.x86.tmmultf32ps(i8 1, i8 2, i8 3)
+  ret void
+}
+declare void @llvm.x86.tmmultf32ps(i8 %A, i8 %B, i8 %C)
+
+define void @test_ttmmultf32ps() {
+; CHECK-LABEL: test_ttmmultf32ps:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ttmmultf32ps %tmm3, %tmm2, %tmm1
+; CHECK-NEXT:    retq
+  call void @llvm.x86.ttmmultf32ps(i8 1, i8 2, i8 3)
+  ret void
+}
+declare void @llvm.x86.ttmmultf32ps(i8 %A, i8 %B, i8 %C)
diff --git a/llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt b/llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt
new file mode 100644
index 00000000000000..f372c42982b1b6
--- /dev/null
+++ b/llvm/test/MC/Disassembler/X86/AMX/x86-64-amx-tf32.txt
@@ -0,0 +1,19 @@
+# RUN: llvm-mc --disassemble %s -triple=x86_64 | FileCheck -check-prefix=ATT %s
+# RUN: llvm-mc --disassemble %s -triple=x86_64 -x86-asm-syntax=intel --output-asm-variant=1 | FileCheck -check-prefix=INTEL %s
+
+# ATT:      tmmultf32ps %tmm4, %tmm5, %tmm6
+# INTEL:      tmmultf32ps tmm6, tmm5, tmm4
+0xc4,0xe2,0x59,0x48,0xf5
+
+# ATT:      tmmultf32ps %tmm1, %tmm2, %tmm3
+# INTEL:      tmmultf32ps tmm3, tmm2, tmm1
+0xc4,0xe2,0x71,0x48,0xda
+
+# ATT:      ttmmultf32ps %tmm4, %tmm5, %tmm6
+# INTEL:      ttmmultf32ps tmm6, tmm5, tmm4
+0xc4,0xe2,0x58,0x48,0xf5
+
+# ATT:      ttmmultf32ps %tmm1, %tmm2, %tmm3
+# INTEL:      ttmmultf32ps tmm3, tmm2, tmm1
+0xc4,0xe2,0x70,0x48,0xda
+
diff --git a/llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s
new file mode 100644
index 00000000000000..b413597cd9da71
--- /dev/null
+++ b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-att.s
@@ -0,0 +1,17 @@
+// RUN: llvm-mc -triple x86_64-unknown-unknown --show-encoding < %s  | FileCheck %s
+
+// CHECK:      tmmultf32ps %tmm4, %tmm5, %tmm6
+// CHECK: encoding: [0xc4,0xe2,0x59,0x48,0xf5]
+               tmmultf32ps %tmm4, %tmm5, %tmm6
+
+// CHECK:      tmmultf32ps %tmm1, %tmm2, %tmm3
+// CHECK: encoding: [0xc4,0xe2,0x71,0x48,0xda]
+               tmmultf32ps %tmm1, %tmm2, %tmm3
+
+// CHECK:      ttmmultf32ps %tmm4, %tmm5, %tmm6
+// CHECK: encoding: [0xc4,0xe2,0x58,0x48,0xf5]
+               ttmmultf32ps %tmm4, %tmm5, %tmm6
+
+// CHECK:      ttmmultf32ps %tmm1, %tmm2, %tmm3
+// CHECK: encoding: [0xc4,0xe2,0x70,0x48,0xda]
+               ttmmultf32ps %tmm1, %tmm2, %tmm3
diff --git a/llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s
new file mode 100644
index 00000000000000..98f55275716eb0
--- /dev/null
+++ b/llvm/test/MC/X86/AMX/x86-64-amx-tf32-intel.s
@@ -0,0 +1,17 @@
+// RUN: llvm-mc -triple x86_64-unknown-unknown -x86-asm-syntax=intel -output-asm-variant=1 --show-encoding %s | FileCheck %s
+
+// CHECK:      tmmultf32ps tmm6, tmm5, tmm4
+// CHECK: encoding: [0xc4,0xe2,0x59,0x48,0xf5]
+               tmmultf32ps tmm6, tmm5, tmm4
+
+// CHECK:      tmmultf32ps tmm3, tmm2, tmm1
+// CHECK: encoding: [0xc4,0xe2,0x71,0x48,0xda]
+               tmmultf32ps tmm3, tmm2, tmm1
+
+// CHECK:      ttmmultf32ps tmm6, tmm5, tmm4
+// CHECK: encoding: [0xc4,0xe2,0x58,0x48,0xf5]
+               ttmmultf32ps tmm6, tmm5, tmm4
+
+// CHECK:      ttmmultf32ps tmm3, tmm2, tmm1
+// CHECK: encoding: [0xc4,0xe2,0x70,0x48,0xda]
+               ttmmultf32ps tmm3, tmm2, tmm1



More information about the llvm-commits mailing list