[llvm] X86: implement lowerings for shuffles on `bf16` element type. (PR #76076)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 20 09:09:30 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
Author: None (bjacob)
<details>
<summary>Changes</summary>
These lowerings were apparently just unimplemented.
Would you prefer to take this over, or to point me to how/where tests for this should be added?
---
Full diff: https://github.com/llvm/llvm-project/pull/76076.diff
2 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+11)
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+55-14)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 0917d0e4eb3e26..4dd8bfc76395d6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -927,6 +927,17 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
Chain = Result.getValue(1);
break;
}
+ if (SrcVT.getScalarType() == MVT::bf16) {
+ EVT ISrcVT = SrcVT.changeTypeToInteger();
+ EVT IDestVT = DestVT.changeTypeToInteger();
+ EVT ILoadVT = TLI.getRegisterType(IDestVT.getSimpleVT());
+ // Zero-extending integer load of the raw bf16 bits, then convert to DestVT.
+ SDValue Result = DAG.getExtLoad(ISD::ZEXTLOAD, dl, ILoadVT, Chain,
+ Ptr, ISrcVT, LD->getMemOperand());
+ Value = DAG.getNode(ISD::BF16_TO_FP, dl, DestVT, Result);
+ Chain = Result.getValue(1);
+ break;
+ }
}
assert(!SrcVT.isVector() &&
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index db5e4fe84f410a..b7123256f57dd6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -12349,7 +12349,8 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
MVT EltVT = VT.getVectorElementType();
if (!((Subtarget.hasSSE3() && VT == MVT::v2f64) ||
(Subtarget.hasAVX() && (EltVT == MVT::f64 || EltVT == MVT::f32)) ||
- (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16))))
+ (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16)) ||
+ (Subtarget.hasBF16() && EltVT == MVT::bf16)))
return SDValue();
// With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
@@ -13933,28 +13935,31 @@ static SDValue lowerV8F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable, SDValue V1, SDValue V2,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
- assert(V1.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
- assert(V2.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
+ MVT VT = V1.getSimpleValueType();
+ assert((VT == MVT::v8f16 || VT == MVT::v8bf16) && "Bad operand type!");
+ assert(V2.getSimpleValueType() == VT && "Bad operand type!");
assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
-
- if (Subtarget.hasFP16()) {
+ // Only try the broadcast / element-insertion paths when the subtarget has
+ // the FP16 feature (for f16) or the BF16 feature (for bf16).
+ if ((VT == MVT::v8f16 && Subtarget.hasFP16()) ||
+ (VT == MVT::v8bf16 && Subtarget.hasBF16())) {
if (NumV2Elements == 0) {
// Check for being able to broadcast a single element.
- if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
- Mask, Subtarget, DAG))
+ if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, VT, V1, V2, Mask,
+ Subtarget, DAG))
return Broadcast;
}
if (NumV2Elements == 1 && Mask[0] >= 8)
if (SDValue V = lowerShuffleAsElementInsertion(
- DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
+ DL, VT, V1, V2, Mask, Zeroable, Subtarget, DAG))
return V;
}
-
+ // Fallback: shuffle the 16-bit lanes as v8i16 and bitcast back to VT.
V1 = DAG.getBitcast(MVT::v8i16, V1);
V2 = DAG.getBitcast(MVT::v8i16, V2);
- return DAG.getBitcast(MVT::v8f16,
- DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
+ return DAG.getBitcast(VT,
+ DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
}
// Lowers unary/binary shuffle as VPERMV/VPERMV3, for non-VLX targets,
@@ -14377,6 +14381,7 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
case MVT::v8i16:
return lowerV8I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8f16:
+ case MVT::v8bf16: // lowerV8F16Shuffle accepts both v8f16 and v8bf16.
return lowerV8F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16i8:
return lowerV16I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
@@ -16295,6 +16300,27 @@ static SDValue lowerV16I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
Subtarget, DAG);
}
+/// Handle lowering of 16-lane 16-bit floating-point (f16 or bf16) shuffles.
+///
+/// Both operands are reinterpreted as v16i16 and a generic v16i16 shuffle is
+/// emitted, mirroring the fallback in lowerV8F16Shuffle. The integer shuffle
+/// is then lowered through the normal entry point instead of calling
+/// lowerV16I16Shuffle directly, which may rely on preconditions (e.g. AVX2)
+/// that its regular callers establish.
+static SDValue lowerV16F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+ const APInt &Zeroable, SDValue V1, SDValue V2,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ MVT VT = V1.getSimpleValueType();
+ assert((VT == MVT::v16f16 || VT == MVT::v16bf16) && "Bad operand type!");
+ assert(V2.getSimpleValueType() == VT && "Bad operand type!");
+ assert(Mask.size() == 16 && "Unexpected mask size for v16 shuffle!");
+ V1 = DAG.getBitcast(MVT::v16i16, V1);
+ V2 = DAG.getBitcast(MVT::v16i16, V2);
+ return DAG.getBitcast(VT,
+ DAG.getVectorShuffle(MVT::v16i16, DL, V1, V2, Mask));
+}
+
/// Handle lowering of 32-lane 8-bit integer shuffles.
///
/// This routine is only called when we have AVX2 and thus a reasonable
@@ -16480,6 +16500,10 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
return lowerV4I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8f32:
return lowerV8F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+ // 256-bit 16-bit FP vectors; v8f16/v8bf16 are handled by lower128BitShuffle.
+ case MVT::v16f16:
+ case MVT::v16bf16:
+ return lowerV16F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8i32:
return lowerV8I32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16i16:
@@ -16953,6 +16976,27 @@ static SDValue lowerV32I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, Subtarget, DAG);
}
+/// Handle lowering of 32-lane 16-bit floating-point (f16 or bf16) shuffles.
+///
+/// Both operands are reinterpreted as v32i16 and a generic v32i16 shuffle is
+/// emitted, mirroring the fallback in lowerV8F16Shuffle. The integer shuffle
+/// is then lowered through the normal entry point instead of calling
+/// lowerV32I16Shuffle directly, which may rely on subtarget preconditions
+/// that its regular callers establish.
+static SDValue lowerV32F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+ const APInt &Zeroable, SDValue V1, SDValue V2,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ MVT VT = V1.getSimpleValueType();
+ assert((VT == MVT::v32f16 || VT == MVT::v32bf16) && "Bad operand type!");
+ assert(V2.getSimpleValueType() == VT && "Bad operand type!");
+ assert(Mask.size() == 32 && "Unexpected mask size for v32 shuffle!");
+ V1 = DAG.getBitcast(MVT::v32i16, V1);
+ V2 = DAG.getBitcast(MVT::v32i16, V2);
+ return DAG.getBitcast(VT,
+ DAG.getVectorShuffle(MVT::v32i16, DL, V1, V2, Mask));
+}
+
/// Handle lowering of 64-lane 8-bit integer shuffles.
static SDValue lowerV64I8Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17112,6 +17150,9 @@ static SDValue lower512BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerV8F64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16f32:
return lowerV16F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+ case MVT::v32f16:
+ case MVT::v32bf16: // lowerV32F16Shuffle accepts both v32f16 and v32bf16.
+ return lowerV32F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8i64:
return lowerV8I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16i32:
``````````
</details>
https://github.com/llvm/llvm-project/pull/76076
More information about the llvm-commits mailing list