[llvm] X86: implement lowerings for shuffles on `bf16` element type. (PR #76076)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 22 12:39:39 PST 2023


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/76076

>From 730bf96877f802a2ee7827c16f2ce8e648a862fe Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Mon, 18 Dec 2023 11:12:11 -0500
Subject: [PATCH] bf16

---
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 34 +++++++++++--------
 .../X86/avx512-shuffles/shuffle-bf16.ll       | 26 ++++++++++++++
 2 files changed, 45 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/avx512-shuffles/shuffle-bf16.ll

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 49112862a31425..fcaa72ac49dab4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -12348,7 +12348,8 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   MVT EltVT = VT.getVectorElementType();
   if (!((Subtarget.hasSSE3() && VT == MVT::v2f64) ||
         (Subtarget.hasAVX() && (EltVT == MVT::f64 || EltVT == MVT::f32)) ||
-        (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16))))
+        (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16)) ||
+        (Subtarget.hasBF16() && EltVT == MVT::bf16)))
     return SDValue();
 
   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
@@ -13932,28 +13933,30 @@ static SDValue lowerV8F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                  const APInt &Zeroable, SDValue V1, SDValue V2,
                                  const X86Subtarget &Subtarget,
                                  SelectionDAG &DAG) {
-  assert(V1.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
-  assert(V2.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
+  assert((V1.getSimpleValueType() == MVT::v8f16 ||
+          V1.getSimpleValueType() == MVT::v8bf16) &&
+         "Bad operand type!");
+  assert(V2.getSimpleValueType() == V1.getSimpleValueType());
   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
   int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
-
-  if (Subtarget.hasFP16()) {
+  if ((V1.getSimpleValueType() == MVT::v8f16 && Subtarget.hasFP16()) ||
+      (V1.getSimpleValueType() == MVT::v8bf16 && Subtarget.hasBF16())) {
     if (NumV2Elements == 0) {
       // Check for being able to broadcast a single element.
-      if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
-                                                      Mask, Subtarget, DAG))
+      if (SDValue Broadcast = lowerShuffleAsBroadcast(
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Subtarget, DAG))
         return Broadcast;
     }
     if (NumV2Elements == 1 && Mask[0] >= 8)
       if (SDValue V = lowerShuffleAsElementInsertion(
-              DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Zeroable, Subtarget,
+              DAG))
         return V;
   }
-
-  V1 = DAG.getBitcast(MVT::v8i16, V1);
-  V2 = DAG.getBitcast(MVT::v8i16, V2);
-  return DAG.getBitcast(MVT::v8f16,
-                        DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
+  return DAG.getBitcast(
+      V1.getSimpleValueType(),
+      DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
+                           DAG.getBitcast(MVT::v8i16, V2), Mask));
 }
 
 // Lowers unary/binary shuffle as VPERMV/VPERMV3, for non-VLX targets,
@@ -14376,6 +14379,7 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   case MVT::v8i16:
     return lowerV8I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8f16:
+  case MVT::v8bf16:
     return lowerV8F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i8:
     return lowerV16I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
@@ -17091,14 +17095,14 @@ static SDValue lower512BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
     return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG, /*SimpleOnly*/ false);
   }
 
-  if (VT == MVT::v32f16) {
+  if (VT == MVT::v32f16 || VT == MVT::v32bf16) {
     if (!Subtarget.hasBWI())
       return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG,
                                   /*SimpleOnly*/ false);
 
     V1 = DAG.getBitcast(MVT::v32i16, V1);
     V2 = DAG.getBitcast(MVT::v32i16, V2);
-    return DAG.getBitcast(MVT::v32f16,
+    return DAG.getBitcast(VT,
                           DAG.getVectorShuffle(MVT::v32i16, DL, V1, V2, Mask));
   }
 
diff --git a/llvm/test/CodeGen/X86/avx512-shuffles/shuffle-bf16.ll b/llvm/test/CodeGen/X86/avx512-shuffles/shuffle-bf16.ll
new file mode 100644
index 00000000000000..1a7ade61284b01
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avx512-shuffles/shuffle-bf16.ll
@@ -0,0 +1,26 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f,+avx512vl,+avx512bw,+avx512bf16 | FileCheck %s
+
+target triple = "x86_64-unknown-linux-gnu"
+
+define <32 x bfloat> @shuffle_v32bf16(<32 x bfloat> %a) {
+; CHECK-LABEL: shuffle_v32bf16:
+; CHECK:       vmovdqa64 {{.*}} zmm1 = [0,16,1,17,2,18,3,19,4,20,5,21,6,22,7,23,8,24,9,25,10,26,11,27,12,28,13,29,14,30,15,31]
+; CHECK:       vpermw %zmm0, %zmm1, %zmm0
+  %s = shufflevector <32 x bfloat> %a, <32 x bfloat> zeroinitializer, <32 x i32> <i32 0, i32 16, i32 1, i32 17, i32 2, i32 18, i32 3, i32 19, i32 4, i32 20, i32 5, i32 21, i32 6, i32 22, i32 7, i32 23, i32 8, i32 24, i32 9, i32 25, i32 10, i32 26, i32 11, i32 27, i32 12, i32 28, i32 13, i32 29, i32 14, i32 30, i32 15, i32 31>
+  ret <32 x bfloat> %s
+}
+
+define <16 x bfloat> @shuffle_v16bf16(<16 x bfloat> %a) {
+; CHECK-LABEL: shuffle_v16bf16:
+; CHECK:       vmovdqa {{.*}} ymm1 = [0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15]
+; CHECK:       vpermw %ymm0, %ymm1, %ymm0
+  %s = shufflevector <16 x bfloat> %a, <16 x bfloat> zeroinitializer, <16 x i32> <i32 0, i32 8, i32 1, i32 9, i32 2, i32 10, i32 3, i32 11, i32 4, i32 12, i32 5, i32 13, i32 6, i32 14, i32 7, i32 15>
+  ret <16 x bfloat> %s
+}
+
+define <8 x bfloat> @shuffle_v8bf16(<8 x bfloat> %a) {
+; CHECK-LABEL: shuffle_v8bf16:
+; CHECK:       vpshufb {{.*}} %xmm0, %xmm0 {{.*}} xmm0[0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15]
+  %s = shufflevector <8 x bfloat> %a, <8 x bfloat> zeroinitializer, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+  ret <8 x bfloat> %s
+}



More information about the llvm-commits mailing list