[llvm] X86: add some missing lowerings for shuffles on `bf16` element type. (PR #76076)

Benoit Jacob via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 3 13:58:45 PST 2024


================
@@ -13932,28 +13933,30 @@ static SDValue lowerV8F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                  const APInt &Zeroable, SDValue V1, SDValue V2,
                                  const X86Subtarget &Subtarget,
                                  SelectionDAG &DAG) {
-  assert(V1.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
-  assert(V2.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
+  assert((V1.getSimpleValueType() == MVT::v8f16 ||
+          V1.getSimpleValueType() == MVT::v8bf16) &&
+         "Bad operand type!");
+  assert(V2.getSimpleValueType() == V1.getSimpleValueType());
   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
   int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
-
-  if (Subtarget.hasFP16()) {
+  if ((V1.getSimpleValueType() == MVT::v8f16 && Subtarget.hasFP16()) ||
+      (V1.getSimpleValueType() == MVT::v8bf16 && Subtarget.hasBF16())) {
     if (NumV2Elements == 0) {
       // Check for being able to broadcast a single element.
-      if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
-                                                      Mask, Subtarget, DAG))
+      if (SDValue Broadcast = lowerShuffleAsBroadcast(
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Subtarget, DAG))
         return Broadcast;
     }
     if (NumV2Elements == 1 && Mask[0] >= 8)
       if (SDValue V = lowerShuffleAsElementInsertion(
-              DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Zeroable, Subtarget,
+              DAG))
         return V;
   }
-
-  V1 = DAG.getBitcast(MVT::v8i16, V1);
-  V2 = DAG.getBitcast(MVT::v8i16, V2);
-  return DAG.getBitcast(MVT::v8f16,
-                        DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
+  return DAG.getBitcast(
+      V1.getSimpleValueType(),
+      DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
+                           DAG.getBitcast(MVT::v8i16, V2), Mask));
----------------
bjacob wrote:

Got it, that works! I had actually tried something like this earlier. However, as in the 256-bit and 512-bit cases, I was doing it for both `bf16` and `f16`, and that caused several test failures. What I was missing is the asymmetry here: `lowerV8F16Shuffle` contains `f16`-specific logic that targets `vmovw`, which only exists with `avx512fp16`. Once I applied this bitcast only in the `bf16` case, it just worked.
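Bitcasting the `bf16` operands to `v8i16` is safe because a shuffle only moves whole 16-bit lanes and never looks at their values. Permuting `bf16` lanes therefore gives the same bits as permuting their `i16` encodings. Below is a minimal, self-contained C++ model of that argument. It is not LLVM code, and it makes a few assumptions. `toBF16Bits`, `fromBF16Bits`, `shuffle8` and `shuffleCommutesWithBitcast` are hypothetical helpers. The float-to-bf16 conversion truncates instead of rounding. Undef mask lanes are modeled as zero.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

// Illustrative model only: a bf16 value is the upper 16 bits of a binary32.
using BF16Bits = std::uint16_t;

// Hypothetical helper: float -> bf16 bits by truncation (no rounding).
BF16Bits toBF16Bits(float F) {
  std::uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return static_cast<BF16Bits>(Bits >> 16);
}

// Hypothetical helper: reinterpret bf16 bits as the float they encode.
float fromBF16Bits(BF16Bits B) {
  std::uint32_t Bits = static_cast<std::uint32_t>(B) << 16;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

// Two-input, 8-lane shuffle using the shufflevector mask convention:
// Mask[I] in [0, 8) picks V1[Mask[I]], [8, 16) picks V2[Mask[I] - 8],
// and a negative entry is an undef lane (modeled here as a zero value).
template <typename T>
std::array<T, 8> shuffle8(const std::array<T, 8> &V1,
                          const std::array<T, 8> &V2,
                          const std::array<int, 8> &Mask) {
  std::array<T, 8> Result{};
  for (int I = 0; I < 8; ++I) {
    int M = Mask[I];
    assert(M < 16 && "mask index out of range");
    if (M >= 0)
      Result[I] = M < 8 ? V1[M] : V2[M - 8];
  }
  return Result;
}

// Shuffle bf16-precision values "as floats", shuffle their raw 16-bit
// patterns "as v8i16", and check that every output lane has identical bits.
bool shuffleCommutesWithBitcast(const std::array<float, 8> &A,
                                const std::array<float, 8> &B,
                                const std::array<int, 8> &Mask) {
  std::array<BF16Bits, 8> ABits, BBits;
  std::array<float, 8> ABF16, BBF16; // inputs truncated to bf16 precision
  for (int I = 0; I < 8; ++I) {
    ABits[I] = toBF16Bits(A[I]);
    BBits[I] = toBF16Bits(B[I]);
    ABF16[I] = fromBF16Bits(ABits[I]);
    BBF16[I] = fromBF16Bits(BBits[I]);
  }
  std::array<float, 8> AsFloat = shuffle8(ABF16, BBF16, Mask);
  std::array<BF16Bits, 8> AsInt = shuffle8(ABits, BBits, Mask);
  for (int I = 0; I < 8; ++I)
    if (toBF16Bits(AsFloat[I]) != AsInt[I])
      return false;
  return true;
}
```

For example, the mask `{8, 1, 2, 3, 4, 5, 6, 7}` takes exactly one element from V2 and puts it in lane 0. That is the shape the diff tries `lowerShuffleAsElementInsertion` on. For such a mask, `shuffleCommutesWithBitcast` returns `true` for ordinary (non-NaN) inputs. The model says nothing about instruction selection. The `f16`-only `vmovw` paths are deliberately not represented.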

Now the PR is much simpler - thanks! PTAL.

https://github.com/llvm/llvm-project/pull/76076

