[llvm] AMDGPU: Make bf16/v2bf16 legal types (PR #76215)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 22 00:39:20 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes:

Some intrinsics are using i16 vectors in place of bfloat vectors. Move
towards making bf16 vectors legal so these can migrate. Leave the larger
vectors for a later change.
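
As a minimal sketch of what this enables (my illustration, not code from the patch): with bf16 and v2bf16 as legal types, IR like the following can be selected directly, instead of modeling the bf16 payload as i16/v2i16.

```llvm
; Illustrative only. The fadd itself is promoted to f32 by the patch's
; setOperationAction rules; only the value is held as a 16-bit type.
define bfloat @fadd_bf16(bfloat %a, bfloat %b) {
  %r = fadd bfloat %a, %b
  ret bfloat %r
}

; v2bf16 is added to the calling-convention type lists below, so it is
; passed and returned in a single 32-bit SGPR/VGPR, like v2f16.
define <2 x bfloat> @ret_v2bf16(<2 x bfloat> %x) {
  ret <2 x bfloat> %x
}
```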

Depends on #76213 and #76214.

---

Patch is 1.11 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76215.diff


28 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+28-8) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td (+13-13) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+1) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+32-5) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+88-8) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.h (+1) 
- (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+46-6) 
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+31-31) 
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+11-11) 
- (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+6030-7324) 
- (modified) llvm/test/CodeGen/AMDGPU/fcopysign.f32.ll (+5-7) 
- (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2) 
- (modified) llvm/test/CodeGen/AMDGPU/fneg-modifier-casting.ll (+5-5) 
- (modified) llvm/test/CodeGen/AMDGPU/function-args-inreg.ll (+2-2) 
- (modified) llvm/test/CodeGen/AMDGPU/function-args.ll (+78-89) 
- (modified) llvm/test/CodeGen/AMDGPU/function-returns.ll (+19-46) 
- (modified) llvm/test/CodeGen/AMDGPU/gfx-callable-argument-types.ll (+48-80) 
- (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+6-10) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.exp.ll (+13-12) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.exp10.ll (+13-12) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.exp2.ll (+43-15) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+113-289) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+7-4) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+7-4) 
- (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+46-16) 
- (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+2-2) 
- (modified) llvm/test/CodeGen/AMDGPU/select-undef.ll (+2-1) 
- (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+258-592) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index a483b8028fda9e..296ed3a3c3dc11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_ROUND:
+  case ISD::FP_ROUND: {
+    EVT VT = Node->getValueType(0);
+    if (VT.getScalarType() == MVT::bf16) {
+      Results.push_back(
+          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      break;
+    }
+
+    LLVM_FALLTHROUGH;
+  }
   case ISD::BITCAST:
     if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
                                  Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_EXTEND:
-    if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
-                                 Node->getOperand(0).getValueType(),
-                                 Node->getValueType(0), dl)))
+  case ISD::FP_EXTEND: {
+    SDValue Op = Node->getOperand(0);
+    EVT SrcVT = Op.getValueType();
+    EVT DstVT = Node->getValueType(0);
+    if (SrcVT.getScalarType() == MVT::bf16) {
+      Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
+      break;
+    }
+
+    if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
       Results.push_back(Tmp1);
     break;
+  }
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
@@ -4908,7 +4924,9 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
 static MVT getPromotedVectorElementType(const TargetLowering &TLI,
                                         MVT EltVT, MVT NewEltVT) {
   unsigned OldEltsPerNewElt = EltVT.getSizeInBits() / NewEltVT.getSizeInBits();
-  MVT MidVT = MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
+  MVT MidVT = OldEltsPerNewElt == 1
+                  ? NewEltVT
+                  : MVT::getVectorVT(NewEltVT, OldEltsPerNewElt);
   assert(TLI.isTypeLegal(MidVT) && "unexpected");
   return MidVT;
 }
@@ -5395,7 +5413,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
 
     assert(NVT.isVector() && OVT.getSizeInBits() == NVT.getSizeInBits() &&
            "Invalid promote type for build_vector");
-    assert(NewEltVT.bitsLT(EltVT) && "not handled");
+    assert(NewEltVT.bitsLE(EltVT) && "not handled");
 
     MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT);
 
@@ -5406,7 +5424,9 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
     }
 
     SDLoc SL(Node);
-    SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SL, NVT, NewOps);
+    SDValue Concat =
+        DAG.getNode(MidVT == NewEltVT ? ISD::BUILD_VECTOR : ISD::CONCAT_VECTORS,
+                    SL, NVT, NewOps);
     SDValue CvtVec = DAG.getNode(ISD::BITCAST, SL, OVT, Concat);
     Results.push_back(CvtVec);
     break;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
index 9036b26a6f6bcb..c5207228dc913f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
   // 32 is reserved for the stack pointer
   // 33 is reserved for the frame pointer
   // 34 is reserved for the base pointer
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
     SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
   ]>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
   ]>>>,
 
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 def RetCC_SI_Gfx : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
 
 def CC_SI_SHADER : CallingConv<[
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
   ]>>>,
 
   // 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
   ]>>,
 
   // 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
-  CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
+  CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(0, 30), !cast<Register>("SGPR"#i))  // SGPR0-29
   >>>,
 
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 // Calling convention for leaf functions
 def RetCC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
 ]>;
 
 def CC_AMDGPU_CS_CHAIN : CallingConv<[
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(105), !cast<Register>("SGPR"#i))
   >>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
   >>>
 ]>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index b0eac567ec9f18..40a49cbe3f518f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
 
     switch (N->getOpcode()) {
     case ISD::BUILD_VECTOR:
+      // TODO: Match load d16 from shl (extload:i16), 16
       MadeChange |= matchLoadD16FromBuildVector(N);
       break;
     default:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 4bf4707553e5fe..131000830a73d5 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3274,7 +3274,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
     SDLoc DL(Op);
@@ -3312,7 +3320,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   // TODO: Factor out code common with LowerUINT_TO_FP.
 
@@ -3510,7 +3526,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
   return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
 }
 
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
                                              SelectionDAG &DAG) const {
   SDValue Src = Op.getOperand(0);
   unsigned OpOpcode = Op.getOpcode();
@@ -3521,6 +3537,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
   if (SrcVT == MVT::f16 && DestVT == MVT::i16)
     return Op;
 
+  if (SrcVT == MVT::bf16) {
+    SDLoc DL(Op);
+    SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+    return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+  }
+
   // Promote i16 to i32
   if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
     SDLoc DL(Op);
@@ -3529,6 +3551,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
   }
 
+  if (DestVT != MVT::i64)
+    return Op;
+
   if (SrcVT == MVT::f16 ||
       (SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
     SDLoc DL(Op);
@@ -3539,7 +3564,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
   }
 
-  if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+  if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
     return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
 
   return SDValue();
@@ -4940,7 +4965,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
     //   vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
     if (DestVT.isVector()) {
       SDValue Src = N->getOperand(0);
-      if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+      if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+          (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+           isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
         EVT SrcVT = Src.getValueType();
         unsigned NElts = DestVT.getVectorNumElements();
 
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index fc119aa61d01a2..eaf32850a87149 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -151,14 +151,17 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     if (Subtarget->useRealTrue16Insts()) {
       addRegisterClass(MVT::i16, &AMDGPU::VGPR_16RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::VGPR_16RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::VGPR_16RegClass);
     } else {
       addRegisterClass(MVT::i16, &AMDGPU::SReg_32RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::SReg_32RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::SReg_32RegClass);
     }
 
     // Unless there are also VOP3P operations, not operations are really legal.
     addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32RegClass);
+    addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
@@ -196,6 +199,41 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                       MVT::i1,     MVT::v32i32},
                      Custom);
 
+  if (isTypeLegal(MVT::bf16)) {
+    for (unsigned Opc :
+         {ISD::FADD,     ISD::FSUB,       ISD::FMUL,    ISD::FDIV,
+          ISD::FREM,     ISD::FMA,        ISD::FMINNUM, ISD::FMAXNUM,
+          ISD::FMINIMUM, ISD::FMAXIMUM,   ISD::FSQRT,   ISD::FCBRT,
+          ISD::FSIN,     ISD::FCOS,       ISD::FPOW,    ISD::FPOWI,
+          ISD::FLDEXP,   ISD::FFREXP,     ISD::FLOG,    ISD::FLOG2,
+          ISD::FLOG10,   ISD::FEXP,       ISD::FEXP2,   ISD::FEXP10,
+          ISD::FCEIL,    ISD::FTRUNC,     ISD::FRINT,   ISD::FNEARBYINT,
+          ISD::FROUND,   ISD::FROUNDEVEN, ISD::FFLOOR,  ISD::FCANONICALIZE,
+          ISD::SETCC}) {
+      // FIXME: The promoted to type shouldn't need to be explicit
+      setOperationAction(Opc, MVT::bf16, Promote);
+      AddPromotedToType(Opc, MVT::bf16, MVT::f32);
+    }
+
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Expand);
+
+    setOperationAction(ISD::SELECT, MVT::bf16, Promote);
+    AddPromotedToType(ISD::SELECT, MVT::bf16, MVT::i16);
+
+    // TODO: Could make these legal
+    setOperationAction(ISD::FABS, MVT::bf16, Expand);
+    setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+    setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+    // We only need to custom lower because we can't specify an action for bf16
+    // sources.
+    setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
+    setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
+
+    setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Promote);
+    AddPromotedToType(ISD::BUILD_VECTOR, MVT::v2bf16, MVT::v2i16);
+  }
+
   setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
   setTruncStoreAction(MVT::v3i32, MVT::v3i16, Expand);
   setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
@@ -388,8 +426,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   // Avoid stack access for these.
   // TODO: Generalize to more vector types.
   setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
-                     {MVT::v2i16, MVT::v2f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
-                      MVT::v4i16, MVT::v4f16},
+                     {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
+                      MVT::v8i8, MVT::v4i16, MVT::v4f16},
                      Custom);
 
   // Deal with vec3 vector operations when widened to vec4.
@@ -498,6 +536,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
   setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
 
+  // Custom lower these because we can't specify a rule based on an illegal
+  // source bf16.
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f32, Custom);
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f64, Custom);
+
   if (Subtarget->has16BitInsts()) {
     setOperationAction({ISD::Constant, ISD::SMIN, ISD::SMAX, ISD::UMIN,
                         ISD::UMAX, ISD::UADDSAT, ISD::USUBSAT},
@@ -524,9 +567,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     AddPromotedToType(ISD::FP_TO_FP16, MVT::i16, MVT::i32);
 
     setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i32, Custom);
 
     // F16 - Constant Actions.
     setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+    setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
     // F16 - Load/Store Actions.
     setOperationAction(ISD::LOAD, MVT::f16, Promote);
@@ -534,16 +582,23 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::STORE, MVT::f16, Promote);
     AddPromotedToType(ISD::STORE, MVT::f16, MVT::i16);
 
+    // BF16 - Load/Store Actions.
+    setOperationAction(ISD::LOAD, MVT::bf16, Promote);
+    AddPromotedToType(ISD::LOAD, MVT::bf16, MVT::i16);
+    setOperationAction(ISD::STORE, MVT::bf16, Promote);
+    AddPromotedToType(ISD::STORE, MVT::bf16, MVT::i16);
+
     // F16 - VOP1 Actions.
     setOperationAction({ISD::FP_ROUND, ISD::STRICT_FP_ROUND, ISD::FCOS,
                         ISD::FSIN, ISD::FROUND, ISD::FPTRUNC_ROUND},
                        MVT::f16, Custom);
 
-    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
     setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::f16, Promote);
+    setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::bf16, Promote);
 
     // F16 - VOP2 Actions.
-    setOperationAction({ISD::BR_CC, ISD::SELECT_CC}, MVT::f16, Expand);
+    setOperationAction({ISD::BR_CC, ISD::SELECT_CC}, {MVT::f16, MVT::bf16},
+                       Expand);
     setOperationAction({ISD::FLDEXP, ISD::STRICT_FLDEXP}, MVT::f16, Custom);
     setOperationAction(ISD::FFREXP, MVT::f16, Custom);
     setOperationAction(ISD::FDIV, MVT::f16, Custom);
@@ -554,8 +609,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FMAD, MVT::f16, Legal);
 
     for (MVT VT :
-         {MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16, MVT::v8i16,
-          MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v32i16, MVT::v32f16}) {
+         {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v4i16, MVT::v4f16,
+          MVT::v4bf16, MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16,
+          MVT::v16f16, MVT::v16bf16, MVT::v32i16, MVT::v32f16}) {
       for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
         switch (Op) {
         case ISD::LOAD:
@@ -587,7 +643,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     // XXX - Do these do anything? Vector constants turn into build_vector.
     setOperationAction(ISD::Constant, {MVT::v2i16, MVT::v2f16}, Legal);
 
-    setOperationAction(ISD::UNDEF, {MVT::v2i16, MVT::v2f16}, Legal);
+    setOperationAction(ISD::UNDEF, {MVT::v2i16, MVT::v2f16, MVT::v2bf16},
+                       Legal);
 
     setOperationAction(ISD::STORE, MVT::v2i16, Promote);
     AddPromotedToType(ISD::STORE, MVT::v2i16, MVT::i32);
@@ -3901,6 +3958,26 @@ SDValue SITargetLowering::lowerPREFETCH(SDValue Op, SelectionDAG &DAG) const {
   return Op;
 }
 
+// Work around DAG legality rules only based on the result type.
+SDValue SITargetLowering::lowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
+  bool IsStrict = Op.getOpcode() == ISD::STRICT_FP_EXTEND;
+  SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
+  EVT S...
[truncated]

``````````
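
The "ext + shift" note in the LegalizeDAG hunk above refers to the fact that extending bf16 to f32 is just a 16-bit left shift of the raw bits into the high half of an f32. A rough IR-level equivalent, for illustration only:

```llvm
; Sketch of what the legalizer's BF16_TO_FP expansion computes:
; reinterpret the bf16 bits, widen them, and shift them into the
; top 16 bits of an f32.
define float @bf16_to_f32(bfloat %v) {
  %bits = bitcast bfloat %v to i16
  %ext  = zext i16 %bits to i32
  %shl  = shl i32 %ext, 16
  %f    = bitcast i32 %shl to float
  ret float %f
}
```

The opposite direction is not a pure bit shift, since f32 -> bf16 must round, which is presumably why the patch expands FP_ROUND with a bf16 result to FP_TO_BF16 rather than open-coding it.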



https://github.com/llvm/llvm-project/pull/76215


More information about the llvm-commits mailing list