[clang] 4bc7c86 - [X86] Support amx-bf16 intrinsic.

via cfe-commits cfe-commits at lists.llvm.org
Wed Feb 24 17:07:26 PST 2021


Author: Liu, Chen3
Date: 2021-02-25T09:06:48+08:00
New Revision: 4bc7c8631ad62487a290dd4b7791848b67635787

URL: https://github.com/llvm/llvm-project/commit/4bc7c8631ad62487a290dd4b7791848b67635787
DIFF: https://github.com/llvm/llvm-project/commit/4bc7c8631ad62487a290dd4b7791848b67635787.diff

LOG: [X86] Support amx-bf16 intrinsic.

Add support for the AMX-BF16 intrinsics.
This patch also fixes a bug where AMX-INT8 instructions were selected with the
wrong predicate (hasAMXTILE() instead of hasAMXINT8()).

Differential Revision: https://reviews.llvm.org/D97358

Added: 
    

Modified: 
    clang/include/clang/Basic/BuiltinsX86_64.def
    clang/lib/Headers/amxintrin.h
    clang/test/CodeGen/X86/amx_api.c
    llvm/include/llvm/IR/IntrinsicsX86.td
    llvm/lib/Target/X86/X86ExpandPseudo.cpp
    llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/lib/Target/X86/X86InstrAMX.td
    llvm/lib/Target/X86/X86LowerAMXType.cpp
    llvm/lib/Target/X86/X86PreTileConfig.cpp
    llvm/lib/Target/X86/X86RegisterInfo.cpp
    llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
index aed46b352342..731f17452cbe 100644
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -108,6 +108,7 @@ TARGET_BUILTIN(__builtin_ia32_tdpbusd_internal, "V256iUsUsUsV256iV256iV256i", "n
 TARGET_BUILTIN(__builtin_ia32_tdpbuud_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
 TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
+TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
 // AMX
 TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")

diff  --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h
index 31a2b64b9ff2..8c276519e362 100644
--- a/clang/lib/Headers/amxintrin.h
+++ b/clang/lib/Headers/amxintrin.h
@@ -15,8 +15,13 @@
 #define __AMXINTRIN_H
 #ifdef __x86_64__
 
+/* Define the default attributes for the functions in this file. */
 #define __DEFAULT_FN_ATTRS_TILE                                                \
   __attribute__((__always_inline__, __nodebug__, __target__("amx-tile")))
+#define __DEFAULT_FN_ATTRS_INT8                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
+#define __DEFAULT_FN_ATTRS_BF16                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
 
 /// Load tile configuration from a 64-byte memory location specified by
 /// "mem_addr". The tile configuration includes the tile type palette, the
@@ -221,10 +226,8 @@ static __inline__ void __DEFAULT_FN_ATTRS_TILE _tile_release(void) {
 #define _tile_dpbf16ps(dst, src0, src1)                                        \
   __builtin_ia32_tdpbf16ps((dst), (src0), (src1))
 
-#define __DEFAULT_FN_ATTRS_INT8                                                \
-  __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
-
 typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
+
 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
 _tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
                      __SIZE_TYPE__ stride) {
@@ -263,6 +266,12 @@ _tile_stored_internal(unsigned short m, unsigned short n, void *base,
                                               (__SIZE_TYPE__)(stride), tile);
 }
 
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16
+_tile_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                         _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+
 typedef struct __tile1024i_str {
   const unsigned short row;
   const unsigned short col;
@@ -313,5 +322,16 @@ static void __tile_zero(__tile1024i *dst) {
   dst->tile = __builtin_ia32_tilezero_internal(dst->row, dst->col);
 }
 
+__DEFAULT_FN_ATTRS_BF16
+static void __tile_tdpbf16ps(__tile1024i *dst, __tile1024i src1,
+                             __tile1024i src2) {
+  dst->tile = _tile_tdpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                       src1.tile, src2.tile);
+}
+
+#undef __DEFAULT_FN_ATTRS_TILE
+#undef __DEFAULT_FN_ATTRS_INT8
+#undef __DEFAULT_FN_ATTRS_BF16
+
 #endif /* __x86_64__ */
 #endif /* __AMXINTRIN_H */

diff  --git a/clang/test/CodeGen/X86/amx_api.c b/clang/test/CodeGen/X86/amx_api.c
index 7120de4c9e88..824a3aec20ec 100644
--- a/clang/test/CodeGen/X86/amx_api.c
+++ b/clang/test/CodeGen/X86/amx_api.c
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +avx512f  -target-feature +amx-int8  \
+// RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +avx512f  -target-feature +amx-int8  \
 // RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK
 
 #include <immintrin.h>
@@ -80,3 +80,10 @@ void test_tile_zero(__tile1024i c) {
   //CHECK-NEXT bitcast x86_amx {{%.*}} to <256 x i32>
   __tile_zero(&c);
 }
+
+void test_tile_tdpbf16ps(__tile1024i a, __tile1024i b, __tile1024i c) {
+  //CHECK-LABEL: @test_tile_tdpbf16ps
+  //CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal
+  //CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32>
+  __tile_tdpbf16ps(&a, b, c);
+}

diff  --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index 2c1202cc2a05..643018b0eedb 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5079,6 +5079,12 @@ let TargetPrefix = "x86" in {
               GCCBuiltin<"__builtin_ia32_tilezero_internal">,
               Intrinsic<[llvm_x86amx_ty], [llvm_i16_ty, llvm_i16_ty],
                         []>;
+  def int_x86_tdpbf16ps_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
index fc4e9eb4a4bb..ab0062e027a7 100644
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -470,16 +470,18 @@ bool X86ExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
   case X86::PTDPBSSDV:
   case X86::PTDPBSUDV:
   case X86::PTDPBUSDV:
-  case X86::PTDPBUUDV: {
+  case X86::PTDPBUUDV:
+  case X86::PTDPBF16PSV: {
     MI.untieRegOperand(4);
     for (unsigned i = 3; i > 0; --i)
       MI.RemoveOperand(i);
     unsigned Opc;
     switch (Opcode) {
-    case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break;
-    case X86::PTDPBSUDV: Opc = X86::TDPBSUD; break;
-    case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break;
-    case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break;
+    case X86::PTDPBSSDV:   Opc = X86::TDPBSSD; break;
+    case X86::PTDPBSUDV:   Opc = X86::TDPBSUD; break;
+    case X86::PTDPBUSDV:   Opc = X86::TDPBUSD; break;
+    case X86::PTDPBUUDV:   Opc = X86::TDPBUUD; break;
+    case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break;
     default: llvm_unreachable("Impossible Opcode!");
     }
     MI.setDesc(TII->get(Opc));

diff  --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index bebd430af6a7..f34d34f8a34c 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4626,7 +4626,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
     case Intrinsic::x86_tdpbsud_internal:
     case Intrinsic::x86_tdpbusd_internal:
     case Intrinsic::x86_tdpbuud_internal: {
-      if (!Subtarget->hasAMXTILE())
+      if (!Subtarget->hasAMXINT8())
         break;
       SDValue Chain = Node->getOperand(0);
       unsigned Opc;

diff  --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
index b93aab30161d..6731599b909f 100644
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -138,6 +138,16 @@ let Predicates = [HasAMXBF16, In64BitMode] in {
                       "tdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}",
                       []>, VEX_4V, T8XS;
 
+    // Pseduo instruction for RA.
+    let Constraints = "$src4 = $dst" in
+      def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
+                                 GR16:$src2, GR16:$src3, TILE:$src4,
+                                 TILE:$src5, TILE:$src6),
+                                 [(set TILE: $dst,
+                                  (int_x86_tdpbf16ps_internal GR16:$src1,
+                                   GR16:$src2, GR16:$src3, TILE:$src4,
+                                   TILE:$src5, TILE:$src6))]>;
+
     let usesCustomInserter = 1 in {
       // Pseudo instructions, using immediates instead of tile registers.
       // To be translated to the actual instructions in X86ISelLowering.cpp

diff  --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index 3fdcf1607d22..5e844a083e7a 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -70,7 +70,8 @@ static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
   case Intrinsic::x86_tdpbssd_internal:
   case Intrinsic::x86_tdpbsud_internal:
   case Intrinsic::x86_tdpbusd_internal:
-  case Intrinsic::x86_tdpbuud_internal: {
+  case Intrinsic::x86_tdpbuud_internal:
+  case Intrinsic::x86_tdpbf16ps_internal: {
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);

diff  --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
index 90b421b44d7a..1c91e87e69d5 100644
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -159,6 +159,7 @@ static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
   case X86::PTDPBUSDV:
   case X86::PTDPBUUDV:
   case X86::PTILEZEROV:
+  case X86::PTDPBF16PSV:
     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
     ShapeT Shape(&MO1, &MO2, MRI);
@@ -256,6 +257,7 @@ static bool isAMXInstruction(MachineBasicBlock::iterator MII) {
   case X86::PTDPBUSDV:
   case X86::PTDPBUUDV:
   case X86::PTILEZEROV:
+  case X86::PTDPBF16PSV:
     return true;
   }
 }

diff  --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
index 00bb73fa2d9a..9865216ee369 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -888,6 +888,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
   case X86::PTDPBUSDV:
   case X86::PTDPBUUDV:
   case X86::PTILEZEROV:
+  case X86::PTDPBF16PSV:
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);
     ShapeT Shape(&MO1, &MO2, MRI);

diff  --git a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll
index ebb6ee5bc231..095eb8e6ea8d 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile -mattr=+avx512f -verify-machineinstrs | FileCheck %s
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-int8,+amx-bf16,+avx512f -verify-machineinstrs | FileCheck %s
 
 define void @test_amx(i8* %pointer, i8* %base, i64 %stride) {
 ; CHECK-LABEL: test_amx:
@@ -22,6 +22,7 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) {
 ; CHECK-NEXT:    tdpbsud %tmm2, %tmm1, %tmm0
 ; CHECK-NEXT:    tdpbusd %tmm2, %tmm1, %tmm0
 ; CHECK-NEXT:    tdpbuud %tmm2, %tmm1, %tmm0
+; CHECK-NEXT:    tdpbf16ps %tmm2, %tmm1, %tmm0
 ; CHECK-NEXT:    tilestored %tmm0, (%rdi,%rdx)
 ; CHECK-NEXT:    tilerelease
 ; CHECK-NEXT:    vzeroupper
@@ -33,7 +34,8 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) {
   %d1 = call x86_amx @llvm.x86.tdpbsud.internal(i16 8, i16 8, i16 8, x86_amx %d0, x86_amx %a, x86_amx %b)
   %d2 = call x86_amx @llvm.x86.tdpbusd.internal(i16 8, i16 8, i16 8, x86_amx %d1, x86_amx %a, x86_amx %b)
   %d3 = call x86_amx @llvm.x86.tdpbuud.internal(i16 8, i16 8, i16 8, x86_amx %d2, x86_amx %a, x86_amx %b)
-  call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d3)
+  %d4 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 8, i16 8, i16 8, x86_amx %d3, x86_amx %a, x86_amx %b)
+  call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d4)
 
   ret void
 }
@@ -44,4 +46,5 @@ declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_
 declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)


        


More information about the cfe-commits mailing list