[llvm] 7634905 - [X86][BF16] Share FP16 vector ABI with BF16
Phoebe Wang via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 8 18:42:48 PDT 2023
Author: Phoebe Wang
Date: 2023-06-09T09:04:56+08:00
New Revision: 7634905a73652d26a7b36ec63c6511cc732aa7e7
URL: https://github.com/llvm/llvm-project/commit/7634905a73652d26a7b36ec63c6511cc732aa7e7
DIFF: https://github.com/llvm/llvm-project/commit/7634905a73652d26a7b36ec63c6511cc732aa7e7.diff
LOG: [X86][BF16] Share FP16 vector ABI with BF16
The ABI of BF16 is identical to FP16 rather than i16.
Fixes #62997
Reviewed By: RKSimon
Differential Revision: https://reviews.llvm.org/D151710
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/bfloat.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 83aacd69a6b6f..0f2059b0cdcd6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -417,6 +417,10 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
return Val;
if (PartEVT.isInteger() && ValueVT.isFloatingPoint())
return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+
+ // Vector/Vector bitcast (e.g. <2 x bfloat> -> <2 x half>).
+ if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
+ return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
}
// Promoted vector extract
@@ -622,6 +626,8 @@ static SDValue widenVectorToPartType(SelectionDAG &DAG, SDValue Val,
return SDValue();
EVT ValueVT = Val.getValueType();
+ EVT PartEVT = PartVT.getVectorElementType();
+ EVT ValueEVT = ValueVT.getVectorElementType();
ElementCount PartNumElts = PartVT.getVectorElementCount();
ElementCount ValueNumElts = ValueVT.getVectorElementCount();
@@ -629,22 +635,30 @@ static SDValue widenVectorToPartType(SelectionDAG &DAG, SDValue Val,
// fixed/scalable properties. If a target needs to widen a fixed-length type
// to a scalable one, it should be possible to use INSERT_SUBVECTOR below.
if (ElementCount::isKnownLE(PartNumElts, ValueNumElts) ||
- PartNumElts.isScalable() != ValueNumElts.isScalable() ||
- PartVT.getVectorElementType() != ValueVT.getVectorElementType())
+ PartNumElts.isScalable() != ValueNumElts.isScalable())
return SDValue();
+ // Have a try for bf16 because some targets share its ABI with fp16.
+ if (ValueEVT == MVT::bf16 && PartEVT == MVT::f16) {
+ assert(DAG.getTargetLoweringInfo().isTypeLegal(PartVT) &&
+ "Cannot widen to illegal type");
+ Val = DAG.getNode(ISD::BITCAST, DL,
+ ValueVT.changeVectorElementType(MVT::f16), Val);
+ } else if (PartEVT != ValueEVT) {
+ return SDValue();
+ }
+
// Widening a scalable vector to another scalable vector is done by inserting
// the vector into a larger undef one.
if (PartNumElts.isScalable())
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
Val, DAG.getVectorIdxConstant(0, DL));
- EVT ElementVT = PartVT.getVectorElementType();
// Vector widening case, e.g. <2 x float> -> <4 x float>. Shuffle in
// undef elements.
SmallVector<SDValue, 16> Ops;
DAG.ExtractVectorElements(Val, Ops);
- SDValue EltUndef = DAG.getUNDEF(ElementVT);
+ SDValue EltUndef = DAG.getUNDEF(PartEVT);
Ops.append((PartNumElts - ValueNumElts).getFixedValue(), EltUndef);
// FIXME: Use CONCAT for 2x -> 4x.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 170396dc7ba9e..0bab667b49db3 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2608,7 +2608,7 @@ MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getRegisterTypeForCallingConv(Context, CC,
- VT.changeVectorElementTypeToInteger());
+ VT.changeVectorElementType(MVT::f16));
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}
@@ -2643,7 +2643,7 @@ unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
if (VT.isVector() && VT.getVectorElementType() == MVT::bf16)
return getNumRegistersForCallingConv(Context, CC,
- VT.changeVectorElementTypeToInteger());
+ VT.changeVectorElementType(MVT::f16));
return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index ca338235f1fd6..c67c947c730b9 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -317,13 +317,13 @@ define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
; SSE2-NEXT: movq %rdx, %rax
; SSE2-NEXT: shrq $48, %rax
; SSE2-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
-; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
+; SSE2-NEXT: punpckhqdq {{.*#+}} xmm0 = xmm0[1,1]
; SSE2-NEXT: movq %xmm0, %r12
; SSE2-NEXT: movq %r12, %rax
; SSE2-NEXT: shrq $32, %rax
; SSE2-NEXT: movq %rax, (%rsp) # 8-byte Spill
-; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
-; SSE2-NEXT: movq %xmm0, %r14
+; SSE2-NEXT: punpckhqdq {{.*#+}} xmm1 = xmm1[1,1]
+; SSE2-NEXT: movq %xmm1, %r14
; SSE2-NEXT: movq %r14, %rbp
; SSE2-NEXT: shrq $32, %rbp
; SSE2-NEXT: movq %r12, %r15
@@ -543,3 +543,25 @@ define <8 x bfloat> @addv(<8 x bfloat> %a, <8 x bfloat> %b) nounwind {
%add = fadd <8 x bfloat> %a, %b
ret <8 x bfloat> %add
}
+
+define <2 x bfloat> @pr62997(bfloat %a, bfloat %b) {
+; SSE2-LABEL: pr62997:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movd %xmm0, %eax
+; SSE2-NEXT: movd %xmm1, %ecx
+; SSE2-NEXT: pinsrw $0, %ecx, %xmm1
+; SSE2-NEXT: pinsrw $0, %eax, %xmm0
+; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
+; SSE2-NEXT: retq
+;
+; BF16-LABEL: pr62997:
+; BF16: # %bb.0:
+; BF16-NEXT: vmovd %xmm1, %eax
+; BF16-NEXT: vmovd %xmm0, %ecx
+; BF16-NEXT: vmovd %ecx, %xmm0
+; BF16-NEXT: vpinsrw $1, %eax, %xmm0, %xmm0
+; BF16-NEXT: retq
+ %1 = insertelement <2 x bfloat> undef, bfloat %a, i64 0
+ %2 = insertelement <2 x bfloat> %1, bfloat %b, i64 1
+ ret <2 x bfloat> %2
+}
More information about the llvm-commits
mailing list