[llvm] [SelectionDAG] Avoid double rounding by using a libcall for int (>i32) to fp conversions (PR #71658)

Alex Bradbury via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 8 03:26:38 PST 2023


https://github.com/asb created https://github.com/llvm/llvm-project/pull/71658

This is the next logical step to #70933. I was going to wait until that is reviewed+committed before posting the next steps, but seeing as that's taking a while, I thought it might help to share this now.

We likely want to hold off on merging until the compiler-rt libcall is added (though this will work fine if linking with libgcc).

>From 52c17ca236e0e76eb6983572e48716d66e7d38f1 Mon Sep 17 00:00:00 2001
From: Alex Bradbury <asb at igalia.com>
Date: Wed, 1 Nov 2023 12:57:37 +0000
Subject: [PATCH 1/2] [CodeGen][RISCV] Add [u]int{64,128} to bf16 libcalls

As noted in D157509 <https://reviews.llvm.org/D157509>, the default
lowering for {S,U}INT_TO_FP for bf16, which first converts to float and then
converts from float to bf16, is semantically incorrect as it introduces
double rounding. This patch doesn't fix/alter that for the case where
bf16 is not a legal type (to be handled in a follow-up), but does fix it
for the case where bf16 is a legal type. This is currently only
exercised by the RISC-V target.

The libcall names are the same as provided by libgcc. A separate patch
will add them to compiler-rt.
---
 llvm/include/llvm/IR/RuntimeLibcalls.def  |   4 +
 llvm/lib/CodeGen/TargetLoweringBase.cpp   |   8 +
 llvm/test/CodeGen/RISCV/bfloat-convert.ll | 244 +++++++++++++++++++++-
 3 files changed, 247 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/IR/RuntimeLibcalls.def b/llvm/include/llvm/IR/RuntimeLibcalls.def
index 6ec98e278988428..61377091f952539 100644
--- a/llvm/include/llvm/IR/RuntimeLibcalls.def
+++ b/llvm/include/llvm/IR/RuntimeLibcalls.def
@@ -371,12 +371,14 @@ HANDLE_LIBCALL(SINTTOFP_I32_F64, "__floatsidf")
 HANDLE_LIBCALL(SINTTOFP_I32_F80, "__floatsixf")
 HANDLE_LIBCALL(SINTTOFP_I32_F128, "__floatsitf")
 HANDLE_LIBCALL(SINTTOFP_I32_PPCF128, "__gcc_itoq")
+HANDLE_LIBCALL(SINTTOFP_I64_BF16, "__floatdibf")
 HANDLE_LIBCALL(SINTTOFP_I64_F16, "__floatdihf")
 HANDLE_LIBCALL(SINTTOFP_I64_F32, "__floatdisf")
 HANDLE_LIBCALL(SINTTOFP_I64_F64, "__floatdidf")
 HANDLE_LIBCALL(SINTTOFP_I64_F80, "__floatdixf")
 HANDLE_LIBCALL(SINTTOFP_I64_F128, "__floatditf")
 HANDLE_LIBCALL(SINTTOFP_I64_PPCF128, "__floatditf")
+HANDLE_LIBCALL(SINTTOFP_I128_BF16, "__floattibf")
 HANDLE_LIBCALL(SINTTOFP_I128_F16, "__floattihf")
 HANDLE_LIBCALL(SINTTOFP_I128_F32, "__floattisf")
 HANDLE_LIBCALL(SINTTOFP_I128_F64, "__floattidf")
@@ -389,12 +391,14 @@ HANDLE_LIBCALL(UINTTOFP_I32_F64, "__floatunsidf")
 HANDLE_LIBCALL(UINTTOFP_I32_F80, "__floatunsixf")
 HANDLE_LIBCALL(UINTTOFP_I32_F128, "__floatunsitf")
 HANDLE_LIBCALL(UINTTOFP_I32_PPCF128, "__gcc_utoq")
+HANDLE_LIBCALL(UINTTOFP_I64_BF16, "__floatundibf")
 HANDLE_LIBCALL(UINTTOFP_I64_F16, "__floatundihf")
 HANDLE_LIBCALL(UINTTOFP_I64_F32, "__floatundisf")
 HANDLE_LIBCALL(UINTTOFP_I64_F64, "__floatundidf")
 HANDLE_LIBCALL(UINTTOFP_I64_F80, "__floatundixf")
 HANDLE_LIBCALL(UINTTOFP_I64_F128, "__floatunditf")
 HANDLE_LIBCALL(UINTTOFP_I64_PPCF128, "__floatunditf")
+HANDLE_LIBCALL(UINTTOFP_I128_BF16, "__floatuntibf")
 HANDLE_LIBCALL(UINTTOFP_I128_F16, "__floatuntihf")
 HANDLE_LIBCALL(UINTTOFP_I128_F32, "__floatuntisf")
 HANDLE_LIBCALL(UINTTOFP_I128_F64, "__floatuntidf")
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 7be1ebe6fa79b04..81e68d40a90d3da 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -430,6 +430,8 @@ RTLIB::Libcall RTLIB::getSINTTOFP(EVT OpVT, EVT RetVT) {
     if (RetVT == MVT::ppcf128)
       return SINTTOFP_I32_PPCF128;
   } else if (OpVT == MVT::i64) {
+    if (RetVT == MVT::bf16)
+      return SINTTOFP_I64_BF16;
     if (RetVT == MVT::f16)
       return SINTTOFP_I64_F16;
     if (RetVT == MVT::f32)
@@ -443,6 +445,8 @@ RTLIB::Libcall RTLIB::getSINTTOFP(EVT OpVT, EVT RetVT) {
     if (RetVT == MVT::ppcf128)
       return SINTTOFP_I64_PPCF128;
   } else if (OpVT == MVT::i128) {
+    if (RetVT == MVT::bf16)
+      return SINTTOFP_I128_BF16;
     if (RetVT == MVT::f16)
       return SINTTOFP_I128_F16;
     if (RetVT == MVT::f32)
@@ -476,6 +480,8 @@ RTLIB::Libcall RTLIB::getUINTTOFP(EVT OpVT, EVT RetVT) {
     if (RetVT == MVT::ppcf128)
       return UINTTOFP_I32_PPCF128;
   } else if (OpVT == MVT::i64) {
+    if (RetVT == MVT::bf16)
+      return UINTTOFP_I64_BF16;
     if (RetVT == MVT::f16)
       return UINTTOFP_I64_F16;
     if (RetVT == MVT::f32)
@@ -489,6 +495,8 @@ RTLIB::Libcall RTLIB::getUINTTOFP(EVT OpVT, EVT RetVT) {
     if (RetVT == MVT::ppcf128)
       return UINTTOFP_I64_PPCF128;
   } else if (OpVT == MVT::i128) {
+    if (RetVT == MVT::bf16)
+      return UINTTOFP_I128_BF16;
     if (RetVT == MVT::f16)
       return UINTTOFP_I128_F16;
     if (RetVT == MVT::f32)
diff --git a/llvm/test/CodeGen/RISCV/bfloat-convert.ll b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
index 8a0c4240d161bfb..092a32f54b54992 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-convert.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
@@ -1127,17 +1127,243 @@ define bfloat @fcvt_bf16_wu_load(ptr %p) nounwind {
   ret bfloat %1
 }
 
-; TODO: The following tests error on rv32 with zfbfmin enabled.
+; TODO: Other than the RV32 zfbfmin case, semantically incorrect double
+; rounding is currently used.
+define bfloat @fcvt_bf16_l(i64 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_l:
+; CHECK32ZFBFMIN:       # %bb.0:
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, -16
+; CHECK32ZFBFMIN-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT:    call __floatdibf at plt
+; CHECK32ZFBFMIN-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, 16
+; CHECK32ZFBFMIN-NEXT:    ret
+;
+; RV32ID-LABEL: fcvt_bf16_l:
+; RV32ID:       # %bb.0:
+; RV32ID-NEXT:    addi sp, sp, -16
+; RV32ID-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32ID-NEXT:    call __floatdisf at plt
+; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    fmv.x.w a0, fa0
+; RV32ID-NEXT:    lui a1, 1048560
+; RV32ID-NEXT:    or a0, a0, a1
+; RV32ID-NEXT:    fmv.w.x fa0, a0
+; RV32ID-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32ID-NEXT:    addi sp, sp, 16
+; RV32ID-NEXT:    ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_l:
+; CHECK64ZFBFMIN:       # %bb.0:
+; CHECK64ZFBFMIN-NEXT:    fcvt.s.l fa5, a0
+; CHECK64ZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
+; CHECK64ZFBFMIN-NEXT:    ret
+;
+; RV64ID-LABEL: fcvt_bf16_l:
+; RV64ID:       # %bb.0:
+; RV64ID-NEXT:    addi sp, sp, -16
+; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT:    fcvt.s.l fa0, a0
+; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    fmv.x.w a0, fa0
+; RV64ID-NEXT:    lui a1, 1048560
+; RV64ID-NEXT:    or a0, a0, a1
+; RV64ID-NEXT:    fmv.w.x fa0, a0
+; RV64ID-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT:    addi sp, sp, 16
+; RV64ID-NEXT:    ret
+  %1 = sitofp i64 %a to bfloat
+  ret bfloat %1
+}
 
-; define bfloat @fcvt_bf16_l(i64 %a) nounwind {
-;   %1 = sitofp i64 %a to bfloat
-;   ret bfloat %1
-; }
+; TODO: Other than the RV32 zfbfmin case, semantically incorrect double
+; rounding is currently used.
+define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_lu:
+; CHECK32ZFBFMIN:       # %bb.0:
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, -16
+; CHECK32ZFBFMIN-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT:    call __floatundibf at plt
+; CHECK32ZFBFMIN-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, 16
+; CHECK32ZFBFMIN-NEXT:    ret
+;
+; RV32ID-LABEL: fcvt_bf16_lu:
+; RV32ID:       # %bb.0:
+; RV32ID-NEXT:    addi sp, sp, -16
+; RV32ID-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32ID-NEXT:    call __floatundisf at plt
+; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    fmv.x.w a0, fa0
+; RV32ID-NEXT:    lui a1, 1048560
+; RV32ID-NEXT:    or a0, a0, a1
+; RV32ID-NEXT:    fmv.w.x fa0, a0
+; RV32ID-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32ID-NEXT:    addi sp, sp, 16
+; RV32ID-NEXT:    ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_lu:
+; CHECK64ZFBFMIN:       # %bb.0:
+; CHECK64ZFBFMIN-NEXT:    fcvt.s.lu fa5, a0
+; CHECK64ZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
+; CHECK64ZFBFMIN-NEXT:    ret
+;
+; RV64ID-LABEL: fcvt_bf16_lu:
+; RV64ID:       # %bb.0:
+; RV64ID-NEXT:    addi sp, sp, -16
+; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT:    fcvt.s.lu fa0, a0
+; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    fmv.x.w a0, fa0
+; RV64ID-NEXT:    lui a1, 1048560
+; RV64ID-NEXT:    or a0, a0, a1
+; RV64ID-NEXT:    fmv.w.x fa0, a0
+; RV64ID-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT:    addi sp, sp, 16
+; RV64ID-NEXT:    ret
+  %1 = uitofp i64 %a to bfloat
+  ret bfloat %1
+}
+
+; TODO: Other than the RV32 and RV64 zfbfmin cases, semantically incorrect
+; double rounding is currently used.
+define bfloat @fcvt_bf16_ll(i128 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_ll:
+; CHECK32ZFBFMIN:       # %bb.0:
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, -32
+; CHECK32ZFBFMIN-NEXT:    sw ra, 28(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT:    lw a1, 0(a0)
+; CHECK32ZFBFMIN-NEXT:    lw a2, 4(a0)
+; CHECK32ZFBFMIN-NEXT:    lw a3, 8(a0)
+; CHECK32ZFBFMIN-NEXT:    lw a0, 12(a0)
+; CHECK32ZFBFMIN-NEXT:    sw a0, 20(sp)
+; CHECK32ZFBFMIN-NEXT:    sw a3, 16(sp)
+; CHECK32ZFBFMIN-NEXT:    sw a2, 12(sp)
+; CHECK32ZFBFMIN-NEXT:    addi a0, sp, 8
+; CHECK32ZFBFMIN-NEXT:    sw a1, 8(sp)
+; CHECK32ZFBFMIN-NEXT:    call __floattibf at plt
+; CHECK32ZFBFMIN-NEXT:    lw ra, 28(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, 32
+; CHECK32ZFBFMIN-NEXT:    ret
+;
+; RV32ID-LABEL: fcvt_bf16_ll:
+; RV32ID:       # %bb.0:
+; RV32ID-NEXT:    addi sp, sp, -32
+; RV32ID-NEXT:    sw ra, 28(sp) # 4-byte Folded Spill
+; RV32ID-NEXT:    lw a1, 0(a0)
+; RV32ID-NEXT:    lw a2, 4(a0)
+; RV32ID-NEXT:    lw a3, 8(a0)
+; RV32ID-NEXT:    lw a0, 12(a0)
+; RV32ID-NEXT:    sw a0, 20(sp)
+; RV32ID-NEXT:    sw a3, 16(sp)
+; RV32ID-NEXT:    sw a2, 12(sp)
+; RV32ID-NEXT:    addi a0, sp, 8
+; RV32ID-NEXT:    sw a1, 8(sp)
+; RV32ID-NEXT:    call __floattisf at plt
+; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    fmv.x.w a0, fa0
+; RV32ID-NEXT:    lui a1, 1048560
+; RV32ID-NEXT:    or a0, a0, a1
+; RV32ID-NEXT:    fmv.w.x fa0, a0
+; RV32ID-NEXT:    lw ra, 28(sp) # 4-byte Folded Reload
+; RV32ID-NEXT:    addi sp, sp, 32
+; RV32ID-NEXT:    ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_ll:
+; CHECK64ZFBFMIN:       # %bb.0:
+; CHECK64ZFBFMIN-NEXT:    addi sp, sp, -16
+; CHECK64ZFBFMIN-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; CHECK64ZFBFMIN-NEXT:    call __floattibf at plt
+; CHECK64ZFBFMIN-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; CHECK64ZFBFMIN-NEXT:    addi sp, sp, 16
+; CHECK64ZFBFMIN-NEXT:    ret
+;
+; RV64ID-LABEL: fcvt_bf16_ll:
+; RV64ID:       # %bb.0:
+; RV64ID-NEXT:    addi sp, sp, -16
+; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT:    call __floattisf at plt
+; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    fmv.x.w a0, fa0
+; RV64ID-NEXT:    lui a1, 1048560
+; RV64ID-NEXT:    or a0, a0, a1
+; RV64ID-NEXT:    fmv.w.x fa0, a0
+; RV64ID-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT:    addi sp, sp, 16
+; RV64ID-NEXT:    ret
+  %1 = sitofp i128 %a to bfloat
+  ret bfloat %1
+}
 
-; define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
-;   %1 = uitofp i64 %a to bfloat
-;   ret bfloat %1
-; }
+; TODO: Other than the RV32 and RV64 zfbfmin cases, semantically incorrect
+; double rounding is currently used.
+define bfloat @fcvt_bf16_llu(i128 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_llu:
+; CHECK32ZFBFMIN:       # %bb.0:
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, -32
+; CHECK32ZFBFMIN-NEXT:    sw ra, 28(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT:    lw a1, 0(a0)
+; CHECK32ZFBFMIN-NEXT:    lw a2, 4(a0)
+; CHECK32ZFBFMIN-NEXT:    lw a3, 8(a0)
+; CHECK32ZFBFMIN-NEXT:    lw a0, 12(a0)
+; CHECK32ZFBFMIN-NEXT:    sw a0, 20(sp)
+; CHECK32ZFBFMIN-NEXT:    sw a3, 16(sp)
+; CHECK32ZFBFMIN-NEXT:    sw a2, 12(sp)
+; CHECK32ZFBFMIN-NEXT:    addi a0, sp, 8
+; CHECK32ZFBFMIN-NEXT:    sw a1, 8(sp)
+; CHECK32ZFBFMIN-NEXT:    call __floatuntibf at plt
+; CHECK32ZFBFMIN-NEXT:    lw ra, 28(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT:    addi sp, sp, 32
+; CHECK32ZFBFMIN-NEXT:    ret
+;
+; RV32ID-LABEL: fcvt_bf16_llu:
+; RV32ID:       # %bb.0:
+; RV32ID-NEXT:    addi sp, sp, -32
+; RV32ID-NEXT:    sw ra, 28(sp) # 4-byte Folded Spill
+; RV32ID-NEXT:    lw a1, 0(a0)
+; RV32ID-NEXT:    lw a2, 4(a0)
+; RV32ID-NEXT:    lw a3, 8(a0)
+; RV32ID-NEXT:    lw a0, 12(a0)
+; RV32ID-NEXT:    sw a0, 20(sp)
+; RV32ID-NEXT:    sw a3, 16(sp)
+; RV32ID-NEXT:    sw a2, 12(sp)
+; RV32ID-NEXT:    addi a0, sp, 8
+; RV32ID-NEXT:    sw a1, 8(sp)
+; RV32ID-NEXT:    call __floatuntisf at plt
+; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    fmv.x.w a0, fa0
+; RV32ID-NEXT:    lui a1, 1048560
+; RV32ID-NEXT:    or a0, a0, a1
+; RV32ID-NEXT:    fmv.w.x fa0, a0
+; RV32ID-NEXT:    lw ra, 28(sp) # 4-byte Folded Reload
+; RV32ID-NEXT:    addi sp, sp, 32
+; RV32ID-NEXT:    ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_llu:
+; CHECK64ZFBFMIN:       # %bb.0:
+; CHECK64ZFBFMIN-NEXT:    addi sp, sp, -16
+; CHECK64ZFBFMIN-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; CHECK64ZFBFMIN-NEXT:    call __floatuntibf at plt
+; CHECK64ZFBFMIN-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; CHECK64ZFBFMIN-NEXT:    addi sp, sp, 16
+; CHECK64ZFBFMIN-NEXT:    ret
+;
+; RV64ID-LABEL: fcvt_bf16_llu:
+; RV64ID:       # %bb.0:
+; RV64ID-NEXT:    addi sp, sp, -16
+; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT:    call __floatuntisf at plt
+; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    fmv.x.w a0, fa0
+; RV64ID-NEXT:    lui a1, 1048560
+; RV64ID-NEXT:    or a0, a0, a1
+; RV64ID-NEXT:    fmv.w.x fa0, a0
+; RV64ID-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT:    addi sp, sp, 16
+; RV64ID-NEXT:    ret
+  %1 = uitofp i128 %a to bfloat
+  ret bfloat %1
+}
 
 define bfloat @fcvt_bf16_s(float %a) nounwind {
 ; CHECK32ZFBFMIN-LABEL: fcvt_bf16_s:

>From 0f7c227853322e6b708934089b9f00f7769a6318 Mon Sep 17 00:00:00 2001
From: Alex Bradbury <asb at igalia.com>
Date: Wed, 8 Nov 2023 11:19:42 +0000
Subject: [PATCH 2/2] [SelectionDAG] Avoid double rounding by using a libcall
 for int (>i32) to fp conversions

This is the next logical step to #70933.
---
 .../SelectionDAG/LegalizeFloatTypes.cpp       | 35 +++++++++++++++++++
 llvm/test/CodeGen/RISCV/bfloat-convert.ll     | 32 +++++------------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 630aa4a07d7b946..2da7be4ee848724 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2953,8 +2953,43 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_SELECT_CC(SDNode *N) {
 SDValue DAGTypeLegalizer::SoftPromoteHalfRes_XINT_TO_FP(SDNode *N) {
   EVT OVT = N->getValueType(0);
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), OVT);
+  EVT SVT = N->getOperand(0).getValueType();
   SDLoc dl(N);
 
+  // For bf16, conversion to f32 and then rounding isn't semantically correct
+  // for types larger than i32 (it double rounds), so produce a libcall and
+  // cast the result to i16 in order to match the expected type for a
+  // soft-promoted half/bf16.
+  if (OVT == MVT::bf16 && SVT.bitsGT(MVT::i32)) {
+    assert(!N->isStrictFPOpcode() && "Unexpected strict opcode\n");
+    bool Signed = N->getOpcode() == ISD::SINT_TO_FP;
+    EVT NSVT = EVT();
+    // If the input is not legal, eg: i1 -> fp, then it needs to be promoted to
+    // a larger type, eg: i8 -> fp.  Even if it is legal, no libcall may exactly
+    // match.  Look for an appropriate libcall.
+    RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
+    for (unsigned t = MVT::FIRST_INTEGER_VALUETYPE;
+         t <= MVT::LAST_INTEGER_VALUETYPE && LC == RTLIB::UNKNOWN_LIBCALL;
+         ++t) {
+      NSVT = (MVT::SimpleValueType)t;
+      // The source needs to be big enough to hold the operand.
+      if (NSVT.bitsGE(SVT))
+        LC = Signed ? RTLIB::getSINTTOFP(NSVT, MVT::bf16)
+                    : RTLIB::getUINTTOFP(NSVT, MVT::bf16);
+    }
+    assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unsupported XINT_TO_FP!");
+
+    // Sign/zero extend the argument if the libcall takes a larger type.
+    SDValue Op = DAG.getNode(Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, dl,
+                             NSVT, N->getOperand(0));
+    TargetLowering::MakeLibCallOptions CallOptions;
+    CallOptions.setSExt(Signed);
+    CallOptions.setTypeListBeforeSoften(SVT, MVT::bf16, true);
+    return DAG.getNode(
+        ISD::BITCAST, dl, MVT::i16,
+        TLI.makeLibCall(DAG, LC, MVT::bf16, Op, CallOptions, dl).first);
+  }
+
   SDValue Res = DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
 
   // Round the value to the softened type.
diff --git a/llvm/test/CodeGen/RISCV/bfloat-convert.ll b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
index 092a32f54b54992..31bc87aca4e44ec 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-convert.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
@@ -1127,8 +1127,6 @@ define bfloat @fcvt_bf16_wu_load(ptr %p) nounwind {
   ret bfloat %1
 }
 
-; TODO: Other than the RV32 zfbfmin case, semantically incorrect double
-; rounding is currently used.
 define bfloat @fcvt_bf16_l(i64 %a) nounwind {
 ; CHECK32ZFBFMIN-LABEL: fcvt_bf16_l:
 ; CHECK32ZFBFMIN:       # %bb.0:
@@ -1143,8 +1141,7 @@ define bfloat @fcvt_bf16_l(i64 %a) nounwind {
 ; RV32ID:       # %bb.0:
 ; RV32ID-NEXT:    addi sp, sp, -16
 ; RV32ID-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
-; RV32ID-NEXT:    call __floatdisf at plt
-; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    call __floatdibf at plt
 ; RV32ID-NEXT:    fmv.x.w a0, fa0
 ; RV32ID-NEXT:    lui a1, 1048560
 ; RV32ID-NEXT:    or a0, a0, a1
@@ -1163,8 +1160,7 @@ define bfloat @fcvt_bf16_l(i64 %a) nounwind {
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    addi sp, sp, -16
 ; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64ID-NEXT:    fcvt.s.l fa0, a0
-; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    call __floatdibf at plt
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
 ; RV64ID-NEXT:    lui a1, 1048560
 ; RV64ID-NEXT:    or a0, a0, a1
@@ -1176,8 +1172,6 @@ define bfloat @fcvt_bf16_l(i64 %a) nounwind {
   ret bfloat %1
 }
 
-; TODO: Other than the RV32 zfbfmin case, semantically incorrect double
-; rounding is currently used.
 define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
 ; CHECK32ZFBFMIN-LABEL: fcvt_bf16_lu:
 ; CHECK32ZFBFMIN:       # %bb.0:
@@ -1192,8 +1186,7 @@ define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
 ; RV32ID:       # %bb.0:
 ; RV32ID-NEXT:    addi sp, sp, -16
 ; RV32ID-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
-; RV32ID-NEXT:    call __floatundisf at plt
-; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    call __floatundibf at plt
 ; RV32ID-NEXT:    fmv.x.w a0, fa0
 ; RV32ID-NEXT:    lui a1, 1048560
 ; RV32ID-NEXT:    or a0, a0, a1
@@ -1212,8 +1205,7 @@ define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    addi sp, sp, -16
 ; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64ID-NEXT:    fcvt.s.lu fa0, a0
-; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    call __floatundibf at plt
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
 ; RV64ID-NEXT:    lui a1, 1048560
 ; RV64ID-NEXT:    or a0, a0, a1
@@ -1225,8 +1217,6 @@ define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
   ret bfloat %1
 }
 
-; TODO: Other than the RV32 and RV64 zfbfmin cases, semantically incorrect
-; double rounding is currently used.
 define bfloat @fcvt_bf16_ll(i128 %a) nounwind {
 ; CHECK32ZFBFMIN-LABEL: fcvt_bf16_ll:
 ; CHECK32ZFBFMIN:       # %bb.0:
@@ -1259,8 +1249,7 @@ define bfloat @fcvt_bf16_ll(i128 %a) nounwind {
 ; RV32ID-NEXT:    sw a2, 12(sp)
 ; RV32ID-NEXT:    addi a0, sp, 8
 ; RV32ID-NEXT:    sw a1, 8(sp)
-; RV32ID-NEXT:    call __floattisf at plt
-; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    call __floattibf at plt
 ; RV32ID-NEXT:    fmv.x.w a0, fa0
 ; RV32ID-NEXT:    lui a1, 1048560
 ; RV32ID-NEXT:    or a0, a0, a1
@@ -1282,8 +1271,7 @@ define bfloat @fcvt_bf16_ll(i128 %a) nounwind {
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    addi sp, sp, -16
 ; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64ID-NEXT:    call __floattisf at plt
-; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    call __floattibf at plt
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
 ; RV64ID-NEXT:    lui a1, 1048560
 ; RV64ID-NEXT:    or a0, a0, a1
@@ -1295,8 +1283,6 @@ define bfloat @fcvt_bf16_ll(i128 %a) nounwind {
   ret bfloat %1
 }
 
-; TODO: Other than the RV32 and RV64 zfbfmin cases, semantically incorrect
-; double rounding is currently used.
 define bfloat @fcvt_bf16_llu(i128 %a) nounwind {
 ; CHECK32ZFBFMIN-LABEL: fcvt_bf16_llu:
 ; CHECK32ZFBFMIN:       # %bb.0:
@@ -1329,8 +1315,7 @@ define bfloat @fcvt_bf16_llu(i128 %a) nounwind {
 ; RV32ID-NEXT:    sw a2, 12(sp)
 ; RV32ID-NEXT:    addi a0, sp, 8
 ; RV32ID-NEXT:    sw a1, 8(sp)
-; RV32ID-NEXT:    call __floatuntisf at plt
-; RV32ID-NEXT:    call __truncsfbf2 at plt
+; RV32ID-NEXT:    call __floatuntibf at plt
 ; RV32ID-NEXT:    fmv.x.w a0, fa0
 ; RV32ID-NEXT:    lui a1, 1048560
 ; RV32ID-NEXT:    or a0, a0, a1
@@ -1352,8 +1337,7 @@ define bfloat @fcvt_bf16_llu(i128 %a) nounwind {
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    addi sp, sp, -16
 ; RV64ID-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64ID-NEXT:    call __floatuntisf at plt
-; RV64ID-NEXT:    call __truncsfbf2 at plt
+; RV64ID-NEXT:    call __floatuntibf at plt
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
 ; RV64ID-NEXT:    lui a1, 1048560
 ; RV64ID-NEXT:    or a0, a0, a1



More information about the llvm-commits mailing list