[llvm] 9458bae - [NVPTX] Custom lower integer<->bf16 conversions for sm_80 (#74827)
Benjamin Kramer via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 11 12:06:54 PST 2023
Author: Benjamin Kramer
Date: 2023-12-11T21:06:46+01:00
New Revision: 9458bae553c82438e1817b4a5b1d003a8de064c3
URL: https://github.com/llvm/llvm-project/commit/9458bae553c82438e1817b4a5b1d003a8de064c3
DIFF: https://github.com/llvm/llvm-project/commit/9458bae553c82438e1817b4a5b1d003a8de064c3.diff
LOG: [NVPTX] Custom lower integer<->bf16 conversions for sm_80 (#74827)
sm_80 only has conversions between f32 and bf16; the remaining integer
conversions arrived with sm_90 (PTX 7.8). Use a two-step conversion through
f32 for sm_80.
There doesn't seem to be a way to express this promotion directly within
the legalization framework, so fall back on Custom lowering.
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.h
llvm/test/CodeGen/NVPTX/bf16-instructions.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 61285c6ba98dff..f5d8abaf847a22 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -766,6 +766,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
+ // sm_80 only has conversions between f32 and bf16. Custom lower all other
+ // bf16 conversions.
+ if (STI.hasBF16Math() &&
+ (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
+ for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
+ setOperationAction(
+ {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
+ VT, Custom);
+ }
+ }
+
setOperationAction(ISD::FROUND, MVT::f16, Promote);
setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
@@ -2580,6 +2591,37 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
}
+/// Custom lowering for ISD::SINT_TO_FP / ISD::UINT_TO_FP on subtargets that
+/// have bf16 math but no direct integer -> bf16 conversion instructions (the
+/// assert mirrors that condition: sm_version < 90 or PTX version < 7.8).
+///
+/// A bf16 result is produced in two steps: the same int-to-fp opcode is
+/// rebuilt with an f32 result, and that f32 value is FP_ROUNDed to bf16.
+/// Any other result type is returned unchanged, i.e. reported as legal.
+///
+/// NOTE(review): for i32/i64 sources (wider than f32's 24-bit significand)
+/// this rounds twice and can differ from a single correctly rounded
+/// int -> bf16 conversion, e.g. u32 0x80800001 becomes 2^31 here rather than
+/// 2^31 + 2^24. Confirm whether that deviation is acceptable.
+SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
+
+ if (Op.getValueType() == MVT::bf16) {
+ SDLoc Loc(Op);
+ // int -> f32 (keeping Op's signed/unsigned opcode), then f32 -> bf16.
+ // The trailing 0 is FP_ROUND's flag operand (presumably "may change the
+ // value" per ISDOpcodes.h - confirm).
+ return DAG.getNode(
+ ISD::FP_ROUND, Loc, MVT::bf16,
+ DAG.getNode(Op.getOpcode(), Loc, MVT::f32, Op.getOperand(0)),
+ DAG.getIntPtrConstant(0, Loc));
+ }
+
+ // The Custom action is registered on the integer types, so conversions
+ // with non-bf16 results also reach this hook.
+ // Everything else is considered legal.
+ return Op;
+}
+
+/// Custom lowering for ISD::FP_TO_SINT / ISD::FP_TO_UINT on subtargets without
+/// direct bf16 -> integer conversions (asserted below: sm_version < 90 or
+/// PTX version < 7.8).
+///
+/// A bf16 source is FP_EXTENDed to f32 and the original conversion is then
+/// applied to the f32 value. bf16 -> f32 extension is exact, so only the final
+/// f32 -> int conversion rounds. Conversions from any other source type are
+/// returned unchanged, i.e. reported as legal.
+SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
+
+ if (Op.getOperand(0).getValueType() == MVT::bf16) {
+ SDLoc Loc(Op);
+ // Same opcode and integer result type; only the source is widened to f32.
+ return DAG.getNode(
+ Op.getOpcode(), Loc, Op.getValueType(),
+ DAG.getNode(ISD::FP_EXTEND, Loc, MVT::f32, Op.getOperand(0)));
+ }
+
+ // The Custom action is registered on the integer types, so conversions
+ // from non-bf16 sources also reach this hook.
+ // Everything else is considered legal.
+ return Op;
+}
+
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
SDLoc DL(Op);
if (Op.getValueType() != MVT::v2i16)
@@ -2636,6 +2678,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerSelect(Op, DAG);
case ISD::FROUND:
return LowerFROUND(Op, DAG);
+ case ISD::SINT_TO_FP:
+ case ISD::UINT_TO_FP:
+ return LowerINT_TO_FP(Op, DAG);
+ case ISD::FP_TO_SINT:
+ case ISD::FP_TO_UINT:
+ return LowerFP_TO_INT(Op, DAG);
case ISD::VAARG:
return LowerVAARG(Op, DAG);
case ISD::VASTART:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 54e34dedc6675e..cd6bcb048c5fe2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -607,6 +607,9 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 5a6ab2926b40cf..a9faa130d6379f 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -227,3 +227,106 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
%res = fpext <8 x bfloat> %load to <8 x float>
ret <8 x float> %res
}
+
+; fptosi bf16 -> i16: sm_80 extends to f32 first (cvt.f32.bf16), sm_90
+; converts the bf16 value directly; the i16 result is returned as 32 bits.
+; CHECK-LABEL: test_fptosi_i16(
+; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fptosi_i16_param_0];
+; SM80: cvt.f32.bf16 [[B:%f[0-9]+]], [[A]];
+; SM80: cvt.rzi.s16.f32 [[C:%rs[0-9]+]], [[B]];
+; SM80: cvt.u32.u16 [[R:%r[0-9]+]], [[C]];
+; SM90: cvt.rzi.s16.bf16 [[B:%rs[0-9]+]], [[A]];
+; SM90: cvt.u32.u16 [[R:%r[0-9]+]], [[B]];
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define i16 @test_fptosi_i16(bfloat %a) {
+ %r = fptosi bfloat %a to i16
+ ret i16 %r
+}
+
+; fptoui bf16 -> i16: unsigned counterpart of test_fptosi_i16 (u16 instead
+; of s16 conversions on both targets).
+; CHECK-LABEL: test_fptoui_i16(
+; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fptoui_i16_param_0];
+; SM80: cvt.f32.bf16 [[B:%f[0-9]+]], [[A]];
+; SM80: cvt.rzi.u16.f32 [[C:%rs[0-9]+]], [[B]];
+; SM80: cvt.u32.u16 [[R:%r[0-9]+]], [[C]];
+; SM90: cvt.rzi.u16.bf16 [[B:%rs[0-9]+]], [[A]];
+; SM90: cvt.u32.u16 [[R:%r[0-9]+]], [[B]];
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define i16 @test_fptoui_i16(bfloat %a) {
+ %r = fptoui bfloat %a to i16
+ ret i16 %r
+}
+
+; sitofp i16 -> bf16: sm_80 goes through f32 (exact for 16-bit values),
+; sm_90 uses a direct cvt.rn.bf16.s16.
+; CHECK-LABEL: test_sitofp_i16(
+; CHECK: ld.param.u16 [[A:%rs[0-9]+]], [test_sitofp_i16_param_0];
+; SM80: cvt.rn.f32.s16 [[B:%f[0-9]+]], [[A]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90: cvt.rn.bf16.s16 [[R:%rs[0-9]+]], [[A]];
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ret;
+define bfloat @test_sitofp_i16(i16 %a) {
+ %r = sitofp i16 %a to bfloat
+ ret bfloat %r
+}
+
+; uitofp i8 -> bf16: the i8 argument lands in a 16-bit register, so both
+; targets convert from u16 (sm_80 via f32, which is exact for 16 bits).
+; The loaded register is captured here; [[A]] would otherwise be undefined
+; or silently reuse a binding left over from an earlier test function.
+; CHECK-LABEL: test_uitofp_i8(
+; CHECK: ld.param.u8 [[A:%rs[0-9]+]], [test_uitofp_i8_param_0];
+; SM80: cvt.rn.f32.u16 [[B:%f[0-9]+]], [[A]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90: cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ret;
+define bfloat @test_uitofp_i8(i8 %a) {
+ %r = uitofp i8 %a to bfloat
+ ret bfloat %r
+}
+
+; uitofp i1 -> bf16: the i1 argument is materialized as a 0/1 u32 (setp +
+; selp) before conversion, so both targets convert from u32.
+; CHECK-LABEL: test_uitofp_i1(
+; CHECK: ld.param.u8 [[A:%rs[0-9]+]], [test_uitofp_i1_param_0];
+; CHECK: and.b16 [[B:%rs[0-9]+]], [[A]], 1;
+; CHECK: setp.eq.b16 [[C:%p[0-9]+]], [[B]], 1;
+; CHECK: selp.u32 [[D:%r[0-9]+]], 1, 0, [[C]];
+; SM80: cvt.rn.f32.u32 [[E:%f[0-9]+]], [[D]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[E]];
+; SM90: cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[D]];
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ret;
+define bfloat @test_uitofp_i1(i1 %a) {
+ %r = uitofp i1 %a to bfloat
+ ret bfloat %r
+}
+
+; uitofp i16 -> bf16: unsigned counterpart of test_sitofp_i16; the f32 step
+; on sm_80 is exact for 16-bit values.
+; CHECK-LABEL: test_uitofp_i16(
+; CHECK: ld.param.u16 [[A:%rs[0-9]+]], [test_uitofp_i16_param_0];
+; SM80: cvt.rn.f32.u16 [[B:%f[0-9]+]], [[A]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90: cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ret;
+define bfloat @test_uitofp_i16(i16 %a) {
+ %r = uitofp i16 %a to bfloat
+ ret bfloat %r
+}
+
+; uitofp i32 -> bf16. NOTE(review): the sm_80 sequence rounds twice
+; (u32 -> f32, then f32 -> bf16) and can disagree with the single rounding
+; done on sm_90, e.g. 0x80800001 gives 2^31 via f32 but 2^31 + 2^24 directly.
+; CHECK-LABEL: test_uitofp_i32(
+; CHECK: ld.param.u32 [[A:%r[0-9]+]], [test_uitofp_i32_param_0];
+; SM80: cvt.rn.f32.u32 [[B:%f[0-9]+]], [[A]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90: cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[A]];
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ret;
+define bfloat @test_uitofp_i32(i32 %a) {
+ %r = uitofp i32 %a to bfloat
+ ret bfloat %r
+}
+
+; uitofp i64 -> bf16. NOTE(review): as with i32, the sm_80 sequence rounds
+; twice (u64 -> f32, then f32 -> bf16), so results for values wider than
+; 24 significant bits may differ from the direct sm_90 conversion.
+; CHECK-LABEL: test_uitofp_i64(
+; CHECK: ld.param.u64 [[A:%rd[0-9]+]], [test_uitofp_i64_param_0];
+; SM80: cvt.rn.f32.u64 [[B:%f[0-9]+]], [[A]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
+; SM90: cvt.rn.bf16.u64 [[R:%rs[0-9]+]], [[A]];
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ret;
+define bfloat @test_uitofp_i64(i64 %a) {
+ %r = uitofp i64 %a to bfloat
+ ret bfloat %r
+}
More information about the llvm-commits
mailing list