[llvm] X86: implement lowerings for shuffles on `bf16` element type. (PR #76076)

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 20 09:12:51 PST 2023


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/76076

>From 8b762ed775e846a3a68232fe41de8c048f2111ee Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Mon, 18 Dec 2023 11:12:11 -0500
Subject: [PATCH] X86: implement lowerings for shuffles on bf16 element type

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 11 ++++
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 66 +++++++++++++++----
 2 files changed, 64 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 0917d0e4eb3e26..4dd8bfc76395d6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -927,6 +927,17 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
           Chain = Result.getValue(1);
           break;
         }
+        if (SrcVT.getScalarType() == MVT::bf16) {
+          EVT ISrcVT = SrcVT.changeTypeToInteger();
+          EVT IDestVT = DestVT.changeTypeToInteger();
+          EVT ILoadVT = TLI.getRegisterType(IDestVT.getSimpleVT());
+          // Load the raw bf16 bits via an integer ZEXTLOAD, then BF16_TO_FP.
+          SDValue Result = DAG.getExtLoad(ISD::ZEXTLOAD, dl, ILoadVT, Chain,
+                                          Ptr, ISrcVT, LD->getMemOperand());
+          Value = DAG.getNode(ISD::BF16_TO_FP, dl, DestVT, Result);
+          Chain = Result.getValue(1);
+          break;
+        }
       }

       assert(!SrcVT.isVector() &&
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index db5e4fe84f410a..dd2b3f50978e53 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -12349,7 +12349,8 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   MVT EltVT = VT.getVectorElementType();
   if (!((Subtarget.hasSSE3() && VT == MVT::v2f64) ||
         (Subtarget.hasAVX() && (EltVT == MVT::f64 || EltVT == MVT::f32)) ||
-        (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16))))
+        (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16)) ||
+        (Subtarget.hasBF16() && EltVT == MVT::bf16)))
     return SDValue();

   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
@@ -13933,28 +13934,30 @@ static SDValue lowerV8F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                  const APInt &Zeroable, SDValue V1, SDValue V2,
                                  const X86Subtarget &Subtarget,
                                  SelectionDAG &DAG) {
-  assert(V1.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
-  assert(V2.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
+  assert((V1.getSimpleValueType() == MVT::v8f16 ||
+          V1.getSimpleValueType() == MVT::v8bf16) && "Bad operand type!");
+  assert(V2.getSimpleValueType() == V1.getSimpleValueType() &&
+         "Bad operand type!");
   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
   int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
-
-  if (Subtarget.hasFP16()) {
+  if ((V1.getSimpleValueType() == MVT::v8f16 && Subtarget.hasFP16()) ||
+      (V1.getSimpleValueType() == MVT::v8bf16 && Subtarget.hasBF16())) {
     if (NumV2Elements == 0) {
       // Check for being able to broadcast a single element.
-      if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
-                                                      Mask, Subtarget, DAG))
+      if (SDValue Broadcast = lowerShuffleAsBroadcast(
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Subtarget, DAG))
         return Broadcast;
     }
     if (NumV2Elements == 1 && Mask[0] >= 8)
       if (SDValue V = lowerShuffleAsElementInsertion(
-              DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Zeroable, Subtarget,
+              DAG))
         return V;
   }
-
-  V1 = DAG.getBitcast(MVT::v8i16, V1);
-  V2 = DAG.getBitcast(MVT::v8i16, V2);
-  return DAG.getBitcast(MVT::v8f16,
-                        DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
+  return DAG.getBitcast(
+      V1.getSimpleValueType(),
+      DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
+                           DAG.getBitcast(MVT::v8i16, V2), Mask));
 }

 // Lowers unary/binary shuffle as VPERMV/VPERMV3, for non-VLX targets,
@@ -14377,6 +14380,7 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   case MVT::v8i16:
     return lowerV8I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8f16:
+  case MVT::v8bf16:
     return lowerV8F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i8:
     return lowerV16I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
@@ -16295,6 +16299,21 @@ static SDValue lowerV16I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                      Subtarget, DAG);
 }

+static SDValue lowerV16F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                  const APInt &Zeroable, SDValue V1, SDValue V2,
+                                  const X86Subtarget &Subtarget,
+                                  SelectionDAG &DAG) {
+  MVT VT = V1.getSimpleValueType();
+  assert((VT == MVT::v16f16 || VT == MVT::v16bf16) && "Bad operand type!");
+  assert(V2.getSimpleValueType() == VT && "Bad operand type!");
+  // Re-issue as a plain v16i16 shuffle (same lanes, same bits) so it takes
+  // the regular v16i16 shuffle lowering path, then bitcast back to VT.
+  return DAG.getBitcast(
+      VT,
+      DAG.getVectorShuffle(MVT::v16i16, DL, DAG.getBitcast(MVT::v16i16, V1),
+                           DAG.getBitcast(MVT::v16i16, V2), Mask));
+}
+
 /// Handle lowering of 32-lane 8-bit integer shuffles.
 ///
 /// This routine is only called when we have AVX2 and thus a reasonable
@@ -16480,6 +16499,9 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
     return lowerV4I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8f32:
     return lowerV8F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+  case MVT::v16f16:
+  case MVT::v16bf16:
+    return lowerV16F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8i32:
     return lowerV8I32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i16:
@@ -16953,6 +16975,21 @@ static SDValue lowerV32I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
   return lowerShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, Subtarget, DAG);
 }

+static SDValue lowerV32F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                  const APInt &Zeroable, SDValue V1, SDValue V2,
+                                  const X86Subtarget &Subtarget,
+                                  SelectionDAG &DAG) {
+  MVT VT = V1.getSimpleValueType();
+  assert((VT == MVT::v32f16 || VT == MVT::v32bf16) && "Bad operand type!");
+  assert(V2.getSimpleValueType() == VT && "Bad operand type!");
+  // Re-issue as a plain v32i16 shuffle (same lanes, same bits) so it takes
+  // the regular v32i16 shuffle lowering path, then bitcast back to VT.
+  return DAG.getBitcast(
+      VT,
+      DAG.getVectorShuffle(MVT::v32i16, DL, DAG.getBitcast(MVT::v32i16, V1),
+                           DAG.getBitcast(MVT::v32i16, V2), Mask));
+}
+
 /// Handle lowering of 64-lane 8-bit integer shuffles.
 static SDValue lowerV64I8Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                  const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17112,6 +17149,9 @@ static SDValue lower512BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
     return lowerV8F64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16f32:
     return lowerV16F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+  case MVT::v32f16:
+  case MVT::v32bf16:
+    return lowerV32F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8i64:
     return lowerV8I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i32:



More information about the llvm-commits mailing list