[llvm] AMDGPU: Make v4bf16 a legal type (PR #76217)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 22 00:40:43 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Matt Arsenault (arsenm)
<details>
<summary>Changes</summary>
This change yields a few code quality improvements. A few cases are worse
because load narrowing is lost.
Depends on #76213, #76214, and #76215.
---
Patch is 1.08 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76217.diff
27 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+28-8)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td (+13-13)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+1)
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+42-14)
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+101-15)
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.h (+1)
- (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+59-6)
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+31-31)
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+11-11)
- (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+5121-6738)
- (modified) llvm/test/CodeGen/AMDGPU/fcopysign.f32.ll (+5-7)
- (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2)
- (modified) llvm/test/CodeGen/AMDGPU/fneg-modifier-casting.ll (+5-5)
- (modified) llvm/test/CodeGen/AMDGPU/function-args-inreg.ll (+2-2)
- (modified) llvm/test/CodeGen/AMDGPU/function-args.ll (+70-88)
- (modified) llvm/test/CodeGen/AMDGPU/function-returns.ll (+19-46)
- (modified) llvm/test/CodeGen/AMDGPU/gfx-callable-argument-types.ll (+48-80)
- (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+6-10)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.exp.ll (+13-12)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.exp10.ll (+13-12)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.exp2.ll (+43-15)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+157-346)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+7-4)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+7-4)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+46-16)
- (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+2-2)
- (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+362-629)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index a483b8028fda9e..296ed3a3c3dc11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
return true;
}
break;
- case ISD::FP_ROUND:
+ case ISD::FP_ROUND: {
+ EVT VT = Node->getValueType(0);
+ if (VT.getScalarType() == MVT::bf16) {
+ Results.push_back(
+ DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+ break;
+ }
+
+ LLVM_FALLTHROUGH;
+ }
case ISD::BITCAST:
if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
return true;
}
break;
- case ISD::FP_EXTEND:
- if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
- Node->getOperand(0).getValueType(),
- Node->getValueType(0), dl)))
+ case ISD::FP_EXTEND: {
+ SDValue Op = Node->getOperand(0);
+ EVT SrcVT = Op.getValueType();
+ EVT DstVT = Node->getValueType(0);
+ if (SrcVT.getScalarType() == MVT::bf16) {
+ Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
+ break;
+ }
+
+ if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
Results.push_back(Tmp1);
break;
+ }
case ISD::BF16_TO_FP: {
// Always expand bf16 to f32 casts, they lower to ext + shift.
//
@@ -4908,7 +4924,9 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
static MVT getPromotedVectorElementType(const TargetLowering &TLI,
MVT EltVT, MVT NewEltVT) {
unsigned OldEltsPerNewElt = EltVT.getSizeInBits() / NewEltVT.getSizeInBits();
- MVT MidVT = MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
+ MVT MidVT = OldEltsPerNewElt == 1
+ ? NewEltVT
+ : MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
assert(TLI.isTypeLegal(MidVT) && "unexpected");
return MidVT;
}
@@ -5395,7 +5413,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() &&
"Invalid promote type for build_vector");
- assert(NewEltVT.bitsLT(EltVT) && "not handled");
+ assert(NewEltVT.bitsLE(EltVT) && "not handled");
MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT);
@@ -5406,7 +5424,9 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
}
SDLoc SL(Node);
- SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SL, NVT, NewOps);
+ SDValue Concat =
+ DAG.getNode(MidVT == NewEltVT ? ISD::BUILD_VECTOR : ISD::CONCAT_VECTORS,
+ SL, NVT, NewOps);
SDValue CvtVec = DAG.getNode(ISD::BITCAST, SL, OVT, Concat);
Results.push_back(CvtVec);
break;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
index 9036b26a6f6bcb..c5207228dc913f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
// 32 is reserved for the stack pointer
// 33 is reserved for the frame pointer
// 34 is reserved for the base pointer
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
SGPR4, SGPR5, SGPR6, SGPR7,
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
]>>>,
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
]>>>,
- CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+ CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
]>;
def RetCC_SI_Gfx : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
def CC_SI_SHADER : CallingConv<[
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
]>>>,
// 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
]>>,
// 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
- CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
+ CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(0, 30), !cast<Register>("SGPR"#i)) // SGPR0-29
>>>,
- CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
+ CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
- CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+ CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
]>;
// Calling convention for leaf functions
def RetCC_AMDGPU_Func : CallingConv<[
CCIfType<[i1], CCPromoteToType<i32>>,
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
- CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
+ CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
]>;
def CC_AMDGPU_CS_CHAIN : CallingConv<[
- CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+ CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(105), !cast<Register>("SGPR"#i))
>>>,
- CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+ CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
!foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
>>>
]>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index b0eac567ec9f18..40a49cbe3f518f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
switch (N->getOpcode()) {
case ISD::BUILD_VECTOR:
+ // TODO: Match load d16 from shl (extload:i16), 16
MadeChange |= matchLoadD16FromBuildVector(N);
break;
default:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 4bf4707553e5fe..055684281d5955 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -389,15 +389,16 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
Custom);
setOperationAction(
ISD::EXTRACT_SUBVECTOR,
- {MVT::v2f16, MVT::v2i16, MVT::v4f16, MVT::v4i16, MVT::v2f32,
- MVT::v2i32, MVT::v3f32, MVT::v3i32, MVT::v4f32, MVT::v4i32,
- MVT::v5f32, MVT::v5i32, MVT::v6f32, MVT::v6i32, MVT::v7f32,
- MVT::v7i32, MVT::v8f32, MVT::v8i32, MVT::v9f32, MVT::v9i32,
- MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32,
- MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
- MVT::v32f32, MVT::v32i32, MVT::v2f64, MVT::v2i64, MVT::v3f64,
- MVT::v3i64, MVT::v4f64, MVT::v4i64, MVT::v8f64, MVT::v8i64,
- MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
+ {MVT::v2f16, MVT::v2i16, MVT::v2bf16, MVT::v4f16, MVT::v4i16,
+ MVT::v4bf16, MVT::v2f32, MVT::v2i32, MVT::v3f32, MVT::v3i32,
+ MVT::v4f32, MVT::v4i32, MVT::v5f32, MVT::v5i32, MVT::v6f32,
+ MVT::v6i32, MVT::v7f32, MVT::v7i32, MVT::v8f32, MVT::v8i32,
+ MVT::v9f32, MVT::v9i32, MVT::v10i32, MVT::v10f32, MVT::v11i32,
+ MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16i16,
+ MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32, MVT::v2f64,
+ MVT::v2i64, MVT::v3f64, MVT::v3i64, MVT::v4f64, MVT::v4i64,
+ MVT::v8f64, MVT::v8i64, MVT::v16f64, MVT::v16i64, MVT::v32i16,
+ MVT::v32f16},
Custom);
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
@@ -3274,7 +3275,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
}
- assert(SrcVT == MVT::i64 && "operation should be legal");
+ if (DestVT == MVT::bf16) {
+ SDLoc SL(Op);
+ SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+ SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+ return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+ }
+
+ if (SrcVT != MVT::i64)
+ return Op;
if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
SDLoc DL(Op);
@@ -3312,7 +3321,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
}
- assert(SrcVT == MVT::i64 && "operation should be legal");
+ if (DestVT == MVT::bf16) {
+ SDLoc SL(Op);
+ SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+ SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+ return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+ }
+
+ if (SrcVT != MVT::i64)
+ return Op;
// TODO: Factor out code common with LowerUINT_TO_FP.
@@ -3510,7 +3527,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
}
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(0);
unsigned OpOpcode = Op.getOpcode();
@@ -3521,6 +3538,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
if (SrcVT == MVT::f16 && DestVT == MVT::i16)
return Op;
+ if (SrcVT == MVT::bf16) {
+ SDLoc DL(Op);
+ SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+ return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+ }
+
// Promote i16 to i32
if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
SDLoc DL(Op);
@@ -3529,6 +3552,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
}
+ if (DestVT != MVT::i64)
+ return Op;
+
if (SrcVT == MVT::f16 ||
(SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
SDLoc DL(Op);
@@ -3539,7 +3565,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
}
- if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+ if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
return SDValue();
@@ -4940,7 +4966,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
// vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
if (DestVT.isVector()) {
SDValue Src = N->getOperand(0);
- if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+ if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+ (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+ isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
EVT SrcVT = Src.getValueType();
unsigned NElts = DestVT.getVectorNumElements();
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index fc119aa61d01a2..333f0cdd6d6541 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -151,16 +151,20 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
if (Subtarget->useRealTrue16Insts()) {
addRegisterClass(MVT::i16, &AMDGPU::VGPR_16RegClass);
addRegisterClass(MVT::f16, &AMDGPU::VGPR_16RegClass);
+ addRegisterClass(MVT::bf16, &AMDGPU::VGPR_16RegClass);
} else {
addRegisterClass(MVT::i16, &AMDGPU::SReg_32RegClass);
addRegisterClass(MVT::f16, &AMDGPU::SReg_32RegClass);
+ addRegisterClass(MVT::bf16, &AMDGPU::SReg_32RegClass);
}
// Unless there are also VOP3P operations, not operations are really legal.
addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32RegClass);
addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32RegClass);
+ addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
+ addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
@@ -196,6 +200,41 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
MVT::i1, MVT::v32i32},
Custom);
+ if (isTypeLegal(MVT::bf16)) {
+ for (unsigned Opc :
+ {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV,
+ ISD::FREM, ISD::FMA, ISD::FMINNUM, ISD::FMAXNUM,
+ ISD::FMINIMUM, ISD::FMAXIMUM, ISD::FSQRT, ISD::FCBRT,
+ ISD::FSIN, ISD::FCOS, ISD::FPOW, ISD::FPOWI,
+ ISD::FLDEXP, ISD::FFREXP, ISD::FLOG, ISD::FLOG2,
+ ISD::FLOG10, ISD::FEXP, ISD::FEXP2, ISD::FEXP10,
+ ISD::FCEIL, ISD::FTRUNC, ISD::FRINT, ISD::FNEARBYINT,
+ ISD::FROUND, ISD::FROUNDEVEN, ISD::FFLOOR, ISD::FCANONICALIZE,
+ ISD::SETCC}) {
+ // FIXME: The promoted to type shouldn't need to be explicit
+ setOperationAction(Opc, MVT::bf16, Promote);
+ AddPromotedToType(Opc, MVT::bf16, MVT::f32);
+ }
+
+ setOperationAction(ISD::FP_ROUND, MVT::bf16, Expand);
+
+ setOperationAction(ISD::SELECT, MVT::bf16, Promote);
+ AddPromotedToType(ISD::SELECT, MVT::bf16, MVT::i16);
+
+ // TODO: Could make these legal
+ setOperationAction(ISD::FABS, MVT::bf16, Expand);
+ setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+ // We only need to custom lower because we can't specify an action for bf16
+ // sources.
+ setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
+ setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
+
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Promote);
+ AddPromotedToType(ISD::BUILD_VECTOR, MVT::v2bf16, MVT::v2i16);
+ }
+
setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
setTruncStoreAction(MVT::v3i32, MVT::v3i16, Expand);
setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
@@ -274,10 +313,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
- MVT::v4f16, MVT::v3i64, MVT::v3f64, MVT::v6i32, MVT::v6f32,
- MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64, MVT::v8i16,
- MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
- MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
+ MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
+ MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
+ MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64,
+ MVT::v16f64, MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
switch (Op) {
case ISD::LOAD:
@@ -383,13 +422,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
Expand);
- setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
+ setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
+ Custom);
// Avoid stack access for these.
// TODO: Generalize to more vector types.
setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
- {MVT::v2i16, MVT::v2f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
- MVT::v4i16, MVT::v4f16},
+ {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
+ MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16},
Custom);
// Deal with vec3 vector operations when widened to vec4.
@@ -498,6 +538,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
+ // Custom lower these because we can't specify a rule based on an illegal
+ // source bf16.
+ setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f32, Custom);
+ setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f64, Custom);
+
if (Subtarget->has16BitInsts()) {
setOperationAction({ISD::Constant, ISD::SMIN, ISD::SMAX, ISD::UMIN,
ISD::UMAX, ISD::UADDSAT, ISD::USUBSAT},
@@ -524,9 +569,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
AddPromotedToType(ISD::FP_TO_FP16, MVT::i16, MVT::i32);
setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::i16, Custom);
+ setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+ setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+
+ setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i32, Custom);
// F16 - Constant Actions.
setOperationA...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/76217
More information about the llvm-commits mailing list