[llvm] 9458bae - [NVPTX] Custom lower integer<->bf16 conversions for sm_80 (#74827)

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 11 12:06:54 PST 2023


Author: Benjamin Kramer
Date: 2023-12-11T21:06:46+01:00
New Revision: 9458bae553c82438e1817b4a5b1d003a8de064c3

URL: https://github.com/llvm/llvm-project/commit/9458bae553c82438e1817b4a5b1d003a8de064c3
DIFF: https://github.com/llvm/llvm-project/commit/9458bae553c82438e1817b4a5b1d003a8de064c3.diff

LOG: [NVPTX] Custom lower integer<->bf16 conversions for sm_80 (#74827)

sm_80 only has conversions between f32 and bf16; the remaining integer
conversions arrived with sm_90 (PTX 7.8). Use a two-step conversion through
f32 for sm_80.

There doesn't seem to be a way to express this promotion directly within
the legalization framework, so fall back on Custom lowering.

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.h
    llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 61285c6ba98dff..f5d8abaf847a22 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -766,6 +766,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
+  // sm_80 only has conversions between f32 and bf16. Custom lower all other
+  // bf16 conversions.
+  if (STI.hasBF16Math() &&
+      (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+    for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
+      setOperationAction(
+          {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+          VT, Custom);
+    }
+  }
+
   setOperationAction(ISD::FROUND, MVT::f16, Promote);
   setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
   setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
@@ -2580,6 +2591,37 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
+  // NOTE(review): int->f32->bf16 double-rounds i32/i64 (e.g. 0x01010001).
+  if (Op.getValueType() == MVT::bf16) {
+    SDLoc Loc(Op);
+    return DAG.getNode(
+        ISD::FP_ROUND, Loc, MVT::bf16,
+        DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
+        DAG.getIntPtrConstant(0, Loc));
+  }
+
+  // Any other result type is considered legal; return the node unchanged.
+  return Op;
+}
+
+SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
+  // Only bf16 sources need rewriting; any other source type is legal.
+  SDValue Src = Op.getOperand(0);
+  if (Src.getValueType() != MVT::bf16)
+    return Op;
+
+  // Widen to f32 first (bf16 -> f32 is exact), then convert that value to
+  // the requested integer type with the original opcode.
+  SDLoc Loc(Op);
+  SDValue Widened = DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Src);
+  return DAG.getNode(Op.getOpcode(), Loc, Op.getValueType(), Widened);
+}
+
 static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   if (Op.getValueType() != MVT::v2i16)
@@ -2636,6 +2678,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerSelect(Op, DAG);
   case ISD::FROUND:
     return LowerFROUND(Op, DAG);
+  case ISD::SINT_TO_FP:
+  case ISD::UINT_TO_FP:
+    return LowerINT_TO_FP(Op, DAG);
+  case ISD::FP_TO_SINT:
+  case ISD::FP_TO_UINT:
+    return LowerFP_TO_INT(Op, DAG);
   case ISD::VAARG:
     return LowerVAARG(Op, DAG);
   case ISD::VASTART:

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 54e34dedc6675e..cd6bcb048c5fe2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -607,6 +607,9 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
 

diff  --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 5a6ab2926b40cf..a9faa130d6379f 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -227,3 +227,112 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
   %res = fpext <8 x bfloat> %load to <8 x float>
   ret <8 x float> %res
 }
+
+; Integer <-> bf16 conversions. sm_80 can only convert between f32 and
+; bf16, so there each conversion takes two steps through f32; sm_90 with
+; PTX 7.8 converts directly.
+; NOTE(review): for i32/i64 sources the two-step sm_80 path rounds twice,
+; so it is not always correctly rounded (e.g. uitofp i32 16842753).
+
+; CHECK-LABEL: test_fptosi_i16(
+; CHECK:      ld.param.b16     [[A:%rs[0-9]+]], [test_fptosi_i16_param_0];
+; SM80:       cvt.f32.bf16     [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rzi.s16.f32  [[C:%rs[0-9]+]], [[B]];
+; SM80:       cvt.u32.u16      [[R:%r[0-9]+]], [[C]];
+; SM90:       cvt.rzi.s16.bf16 [[B:%rs[0-9]+]], [[A]];
+; SM90:       cvt.u32.u16      [[R:%r[0-9]+]], [[B]];
+; CHECK:      st.param.b32     [func_retval0+0], [[R]];
+; CHECK:      ret;
+define i16 @test_fptosi_i16(bfloat %a) {
+  %r = fptosi bfloat %a to i16
+  ret i16 %r
+}
+
+; CHECK-LABEL: test_fptoui_i16(
+; CHECK:      ld.param.b16     [[A:%rs[0-9]+]], [test_fptoui_i16_param_0];
+; SM80:       cvt.f32.bf16     [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rzi.u16.f32  [[C:%rs[0-9]+]], [[B]];
+; SM80:       cvt.u32.u16      [[R:%r[0-9]+]], [[C]];
+; SM90:       cvt.rzi.u16.bf16 [[B:%rs[0-9]+]], [[A]];
+; SM90:       cvt.u32.u16      [[R:%r[0-9]+]], [[B]];
+; CHECK:      st.param.b32     [func_retval0+0], [[R]];
+; CHECK:      ret;
+define i16 @test_fptoui_i16(bfloat %a) {
+  %r = fptoui bfloat %a to i16
+  ret i16 %r
+}
+
+; CHECK-LABEL: test_sitofp_i16(
+; CHECK:      ld.param.u16    [[A:%rs[0-9]+]], [test_sitofp_i16_param_0];
+; SM80:       cvt.rn.f32.s16  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.s16 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_sitofp_i16(i16 %a) {
+  %r = sitofp i16 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i8(
+; CHECK:      ld.param.u8     [[A:%rs[0-9]+]], [test_uitofp_i8_param_0];
+; SM80:       cvt.rn.f32.u16  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i8(i8 %a) {
+  %r = uitofp i8 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i1(
+; CHECK:      ld.param.u8     [[A:%rs[0-9]+]], [test_uitofp_i1_param_0];
+; CHECK:      and.b16         [[B:%rs[0-9]+]], [[A]], 1;
+; CHECK:      setp.eq.b16     [[C:%p[0-9]+]], [[B]], 1;
+; CHECK:      selp.u32        [[D:%r[0-9]+]], 1, 0, [[C]];
+; SM80:       cvt.rn.f32.u32  [[E:%f[0-9]+]], [[D]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[E]];
+; SM90:       cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[D]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i1(i1 %a) {
+  %r = uitofp i1 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i16(
+; CHECK:      ld.param.u16    [[A:%rs[0-9]+]], [test_uitofp_i16_param_0];
+; SM80:       cvt.rn.f32.u16  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i16(i16 %a) {
+  %r = uitofp i16 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i32(
+; CHECK:      ld.param.u32    [[A:%r[0-9]+]], [test_uitofp_i32_param_0];
+; SM80:       cvt.rn.f32.u32  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i32(i32 %a) {
+  %r = uitofp i32 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i64(
+; CHECK:      ld.param.u64    [[A:%rd[0-9]+]], [test_uitofp_i64_param_0];
+; SM80:       cvt.rn.f32.u64  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u64 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i64(i64 %a) {
+  %r = uitofp i64 %a to bfloat
+  ret bfloat %r
+}


        


More information about the llvm-commits mailing list