[llvm] [NVPTX] Custom lower integer<->bf16 conversions for sm_80 (PR #74827)

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 8 03:23:18 PST 2023


https://github.com/d0k created https://github.com/llvm/llvm-project/pull/74827

sm_80 only has f32->bf16 conversions; the remaining integer conversions arrived with sm_90. Use a two-step conversion (through f32) for sm_80.

There doesn't seem to be a way to express this promotion directly within the legalization framework, so fall back on Custom lowering.

From ceda94078a6a10d5af84889defbf4bc96d8b413a Mon Sep 17 00:00:00 2001
From: Benjamin Kramer <benny.kra at googlemail.com>
Date: Fri, 8 Dec 2023 12:16:53 +0100
Subject: [PATCH] [NVPTX] Custom lower integer<->bf16 conversions for sm_80

sm_80 only has f32->bf16 conversions; the remaining integer conversions
arrived with sm_90. Use a two-step conversion for sm_80.

There doesn't seem to be a way to express this promotion directly within
the legalization framework, so fall back on Custom lowering.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  |  43 ++++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h    |   3 +
 llvm/test/CodeGen/NVPTX/bf16-instructions.ll | 103 +++++++++++++++++++
 3 files changed, 149 insertions(+)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 61285c6ba98df..fc7e077da36ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -766,6 +766,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
+  for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
+    setOperationAction(
+        {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+        VT, Custom);
+  }
+
   setOperationAction(ISD::FROUND, MVT::f16, Promote);
   setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
   setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
@@ -2580,6 +2586,37 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
   return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
 }
 
+SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // No direct int -> bf16 before sm_90 / PTX version 78: go via f32 and round.
+  if (Op.getValueType() == MVT::bf16 &&
+      (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+    SDLoc Loc(Op);
+    return DAG.getNode(
+        ISD::FP_ROUND, Loc, MVT::bf16,
+        DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
+        DAG.getIntPtrConstant(0, Loc));
+  }
+
+  // Any other conversion is returned unchanged, i.e. treated as legal.
+  return Op;
+}
+
+SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // No direct bf16 -> int before sm_90 / PTX version 78: extend to f32 first.
+  if (Op.getOperand(0).getValueType() == MVT::bf16 &&
+      (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+    SDLoc Loc(Op);
+    return DAG.getNode(
+        Op.getOpcode(), Loc, Op.getValueType(),
+        DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
+  }
+
+  // Any other conversion is returned unchanged, i.e. treated as legal.
+  return Op;
+}
+
 static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   if (Op.getValueType() != MVT::v2i16)
@@ -2636,6 +2673,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerSelect(Op, DAG);
   case ISD::FROUND:
     return LowerFROUND(Op, DAG);
+  case ISD::SINT_TO_FP:
+  case ISD::UINT_TO_FP:
+    return LowerINT_TO_FP(Op, DAG);
+  case ISD::FP_TO_SINT:
+  case ISD::FP_TO_UINT:
+    return LowerFP_TO_INT(Op, DAG);
   case ISD::VAARG:
     return LowerVAARG(Op, DAG);
   case ISD::VASTART:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 54e34dedc6675..cd6bcb048c5fe 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -607,6 +607,9 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 5a6ab2926b40c..a9faa130d6379 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -227,3 +227,106 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
   %res = fpext <8 x bfloat> %load to <8 x float>
   ret <8 x float> %res
 }
+
+; CHECK-LABEL: test_fptosi_i16(
+; CHECK:      ld.param.b16     [[A:%rs[0-9]+]], [test_fptosi_i16_param_0];
+; SM80:       cvt.f32.bf16     [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rzi.s16.f32  [[C:%rs[0-9]+]], [[B]];
+; SM80:       cvt.u32.u16      [[R:%r[0-9]+]], [[C]];
+; SM90:       cvt.rzi.s16.bf16 [[B:%rs[0-9]+]], [[A]];
+; SM90:       cvt.u32.u16      [[R:%r[0-9]+]], [[B]];
+; CHECK:      st.param.b32     [func_retval0+0], [[R]];
+; CHECK:      ret;
+define i16 @test_fptosi_i16(bfloat %a) {
+  %r = fptosi bfloat %a to i16
+  ret i16 %r
+}
+
+; CHECK-LABEL: test_fptoui_i16(
+; CHECK:      ld.param.b16     [[A:%rs[0-9]+]], [test_fptoui_i16_param_0];
+; SM80:       cvt.f32.bf16     [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rzi.u16.f32  [[C:%rs[0-9]+]], [[B]];
+; SM80:       cvt.u32.u16      [[R:%r[0-9]+]], [[C]];
+; SM90:       cvt.rzi.u16.bf16 [[B:%rs[0-9]+]], [[A]];
+; SM90:       cvt.u32.u16      [[R:%r[0-9]+]], [[B]];
+; CHECK:      st.param.b32     [func_retval0+0], [[R]];
+; CHECK:      ret;
+define i16 @test_fptoui_i16(bfloat %a) {
+  %r = fptoui bfloat %a to i16
+  ret i16 %r
+}
+
+; CHECK-LABEL: test_sitofp_i16(
+; CHECK:      ld.param.u16    [[A:%rs[0-9]+]], [test_sitofp_i16_param_0];
+; SM80:       cvt.rn.f32.s16  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.s16 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_sitofp_i16(i16 %a) {
+  %r = sitofp i16 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i8(
+; CHECK:      ld.param.u8     [[A:%rs[0-9]+]], [test_uitofp_i8_param_0];
+; SM80:       cvt.rn.f32.u16  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i8(i8 %a) {
+  %r = uitofp i8 %a to bfloat ; checks expect a u16 cvt on the loaded i8 value
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i1(
+; CHECK:      ld.param.u8     [[A:%rs[0-9]+]], [test_uitofp_i1_param_0];
+; CHECK:      and.b16         [[B:%rs[0-9]+]], [[A]], 1;
+; CHECK:      setp.eq.b16     [[C:%p[0-9]+]], [[B]], 1;
+; CHECK:      selp.u32        [[D:%r[0-9]+]], 1, 0, [[C]];
+; SM80:       cvt.rn.f32.u32  [[E:%f[0-9]+]], [[D]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[E]];
+; SM90:       cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[D]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i1(i1 %a) {
+  %r = uitofp i1 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i16(
+; CHECK:      ld.param.u16    [[A:%rs[0-9]+]], [test_uitofp_i16_param_0];
+; SM80:       cvt.rn.f32.u16  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i16(i16 %a) {
+  %r = uitofp i16 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i32(
+; CHECK:      ld.param.u32    [[A:%r[0-9]+]], [test_uitofp_i32_param_0];
+; SM80:       cvt.rn.f32.u32  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i32(i32 %a) {
+  %r = uitofp i32 %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_uitofp_i64(
+; CHECK:      ld.param.u64    [[A:%rd[0-9]+]], [test_uitofp_i64_param_0];
+; SM80:       cvt.rn.f32.u64  [[B:%f[0-9]+]], [[A]];
+; SM80:       cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90:       cvt.rn.bf16.u64 [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_uitofp_i64(i64 %a) {
+  %r = uitofp i64 %a to bfloat
+  ret bfloat %r
+}



More information about the llvm-commits mailing list