[llvm] [CodeGen][RISCV] Add [u]int{64, 128} to bf16 libcalls (PR #70933)
Alex Bradbury via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 1 06:06:35 PDT 2023
https://github.com/asb created https://github.com/llvm/llvm-project/pull/70933
As noted in D157509 <https://reviews.llvm.org/D157509>, the default lowering for {S,U}INT_TO_FP for bf16 that first converts to float and then converts from float to bf16 is semantically incorrect as it introduces double rounding. This patch doesn't fix/alter that for the case where bf16 is not a legal type (to be handled in a follow-up), but does fix it for the case where bf16 is a legal type. This is currently only exercised by the RISC-V target.
The libcall names are the same as provided by libgcc. A separate patch will add them to compiler-rt.
>From 1a20e2848d6be3f7c1b1c047e433563fdb9fc100 Mon Sep 17 00:00:00 2001
From: Alex Bradbury <asb at igalia.com>
Date: Wed, 1 Nov 2023 12:57:37 +0000
Subject: [PATCH] [CodeGen][RISCV] Add [u]int{64,128} to bf16 libcalls
As noted in D157509 <https://reviews.llvm.org/D157509>, the default
lowering for {S,U}INT_TO_FP for bf16, which first converts to float and
then converts from float to bf16, is semantically incorrect as it introduces
double rounding. This patch doesn't fix/alter that for the case where
bf16 is not a legal type (to be handled in a follow-up), but does fix it
for the case where bf16 is a legal type. This is currently only
exercised by the RISC-V target.
The libcall names are the same as provided by libgcc. A separate patch
will add them to compiler-rt.
---
llvm/include/llvm/IR/RuntimeLibcalls.def | 4 +
llvm/lib/CodeGen/TargetLoweringBase.cpp | 8 +
llvm/test/CodeGen/RISCV/bfloat-convert.ll | 244 +++++++++++++++++++++-
3 files changed, 247 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/IR/RuntimeLibcalls.def b/llvm/include/llvm/IR/RuntimeLibcalls.def
index 6ec98e278988428..61377091f952539 100644
--- a/llvm/include/llvm/IR/RuntimeLibcalls.def
+++ b/llvm/include/llvm/IR/RuntimeLibcalls.def
@@ -371,12 +371,14 @@ HANDLE_LIBCALL(SINTTOFP_I32_F64, "__floatsidf")
HANDLE_LIBCALL(SINTTOFP_I32_F80, "__floatsixf")
HANDLE_LIBCALL(SINTTOFP_I32_F128, "__floatsitf")
HANDLE_LIBCALL(SINTTOFP_I32_PPCF128, "__gcc_itoq")
+HANDLE_LIBCALL(SINTTOFP_I64_BF16, "__floatdibf")
HANDLE_LIBCALL(SINTTOFP_I64_F16, "__floatdihf")
HANDLE_LIBCALL(SINTTOFP_I64_F32, "__floatdisf")
HANDLE_LIBCALL(SINTTOFP_I64_F64, "__floatdidf")
HANDLE_LIBCALL(SINTTOFP_I64_F80, "__floatdixf")
HANDLE_LIBCALL(SINTTOFP_I64_F128, "__floatditf")
HANDLE_LIBCALL(SINTTOFP_I64_PPCF128, "__floatditf")
+HANDLE_LIBCALL(SINTTOFP_I128_BF16, "__floattibf")
HANDLE_LIBCALL(SINTTOFP_I128_F16, "__floattihf")
HANDLE_LIBCALL(SINTTOFP_I128_F32, "__floattisf")
HANDLE_LIBCALL(SINTTOFP_I128_F64, "__floattidf")
@@ -389,12 +391,14 @@ HANDLE_LIBCALL(UINTTOFP_I32_F64, "__floatunsidf")
HANDLE_LIBCALL(UINTTOFP_I32_F80, "__floatunsixf")
HANDLE_LIBCALL(UINTTOFP_I32_F128, "__floatunsitf")
HANDLE_LIBCALL(UINTTOFP_I32_PPCF128, "__gcc_utoq")
+HANDLE_LIBCALL(UINTTOFP_I64_BF16, "__floatundibf")
HANDLE_LIBCALL(UINTTOFP_I64_F16, "__floatundihf")
HANDLE_LIBCALL(UINTTOFP_I64_F32, "__floatundisf")
HANDLE_LIBCALL(UINTTOFP_I64_F64, "__floatundidf")
HANDLE_LIBCALL(UINTTOFP_I64_F80, "__floatundixf")
HANDLE_LIBCALL(UINTTOFP_I64_F128, "__floatunditf")
HANDLE_LIBCALL(UINTTOFP_I64_PPCF128, "__floatunditf")
+HANDLE_LIBCALL(UINTTOFP_I128_BF16, "__floatuntibf")
HANDLE_LIBCALL(UINTTOFP_I128_F16, "__floatuntihf")
HANDLE_LIBCALL(UINTTOFP_I128_F32, "__floatuntisf")
HANDLE_LIBCALL(UINTTOFP_I128_F64, "__floatuntidf")
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 722cefb1eddb3c5..91959b0cb6cc6b5 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -430,6 +430,8 @@ RTLIB::Libcall RTLIB::getSINTTOFP(EVT OpVT, EVT RetVT) {
if (RetVT == MVT::ppcf128)
return SINTTOFP_I32_PPCF128;
} else if (OpVT == MVT::i64) {
+ if (RetVT == MVT::bf16)
+ return SINTTOFP_I64_BF16;
if (RetVT == MVT::f16)
return SINTTOFP_I64_F16;
if (RetVT == MVT::f32)
@@ -443,6 +445,8 @@ RTLIB::Libcall RTLIB::getSINTTOFP(EVT OpVT, EVT RetVT) {
if (RetVT == MVT::ppcf128)
return SINTTOFP_I64_PPCF128;
} else if (OpVT == MVT::i128) {
+ if (RetVT == MVT::bf16)
+ return SINTTOFP_I128_BF16;
if (RetVT == MVT::f16)
return SINTTOFP_I128_F16;
if (RetVT == MVT::f32)
@@ -476,6 +480,8 @@ RTLIB::Libcall RTLIB::getUINTTOFP(EVT OpVT, EVT RetVT) {
if (RetVT == MVT::ppcf128)
return UINTTOFP_I32_PPCF128;
} else if (OpVT == MVT::i64) {
+ if (RetVT == MVT::bf16)
+ return UINTTOFP_I64_BF16;
if (RetVT == MVT::f16)
return UINTTOFP_I64_F16;
if (RetVT == MVT::f32)
@@ -489,6 +495,8 @@ RTLIB::Libcall RTLIB::getUINTTOFP(EVT OpVT, EVT RetVT) {
if (RetVT == MVT::ppcf128)
return UINTTOFP_I64_PPCF128;
} else if (OpVT == MVT::i128) {
+ if (RetVT == MVT::bf16)
+ return UINTTOFP_I128_BF16;
if (RetVT == MVT::f16)
return UINTTOFP_I128_F16;
if (RetVT == MVT::f32)
diff --git a/llvm/test/CodeGen/RISCV/bfloat-convert.ll b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
index 8a0c4240d161bfb..092a32f54b54992 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-convert.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
@@ -1127,17 +1127,243 @@ define bfloat @fcvt_bf16_wu_load(ptr %p) nounwind {
ret bfloat %1
}
-; TODO: The following tests error on rv32 with zfbfmin enabled.
+; TODO: Other than the RV32 zfbfmin case, semantically incorrect double
+; rounding is currently used.
+define bfloat @fcvt_bf16_l(i64 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_l:
+; CHECK32ZFBFMIN: # %bb.0:
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, -16
+; CHECK32ZFBFMIN-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT: call __floatdibf at plt
+; CHECK32ZFBFMIN-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, 16
+; CHECK32ZFBFMIN-NEXT: ret
+;
+; RV32ID-LABEL: fcvt_bf16_l:
+; RV32ID: # %bb.0:
+; RV32ID-NEXT: addi sp, sp, -16
+; RV32ID-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32ID-NEXT: call __floatdisf at plt
+; RV32ID-NEXT: call __truncsfbf2 at plt
+; RV32ID-NEXT: fmv.x.w a0, fa0
+; RV32ID-NEXT: lui a1, 1048560
+; RV32ID-NEXT: or a0, a0, a1
+; RV32ID-NEXT: fmv.w.x fa0, a0
+; RV32ID-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32ID-NEXT: addi sp, sp, 16
+; RV32ID-NEXT: ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_l:
+; CHECK64ZFBFMIN: # %bb.0:
+; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
+; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
+; CHECK64ZFBFMIN-NEXT: ret
+;
+; RV64ID-LABEL: fcvt_bf16_l:
+; RV64ID: # %bb.0:
+; RV64ID-NEXT: addi sp, sp, -16
+; RV64ID-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT: fcvt.s.l fa0, a0
+; RV64ID-NEXT: call __truncsfbf2 at plt
+; RV64ID-NEXT: fmv.x.w a0, fa0
+; RV64ID-NEXT: lui a1, 1048560
+; RV64ID-NEXT: or a0, a0, a1
+; RV64ID-NEXT: fmv.w.x fa0, a0
+; RV64ID-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT: addi sp, sp, 16
+; RV64ID-NEXT: ret
+ %1 = sitofp i64 %a to bfloat
+ ret bfloat %1
+}
-; define bfloat @fcvt_bf16_l(i64 %a) nounwind {
-; %1 = sitofp i64 %a to bfloat
-; ret bfloat %1
-; }
+; TODO: Other than the RV32 zfbfmin case, semantically incorrect double
+; rounding is currently used.
+define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_lu:
+; CHECK32ZFBFMIN: # %bb.0:
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, -16
+; CHECK32ZFBFMIN-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT: call __floatundibf at plt
+; CHECK32ZFBFMIN-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, 16
+; CHECK32ZFBFMIN-NEXT: ret
+;
+; RV32ID-LABEL: fcvt_bf16_lu:
+; RV32ID: # %bb.0:
+; RV32ID-NEXT: addi sp, sp, -16
+; RV32ID-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
+; RV32ID-NEXT: call __floatundisf at plt
+; RV32ID-NEXT: call __truncsfbf2 at plt
+; RV32ID-NEXT: fmv.x.w a0, fa0
+; RV32ID-NEXT: lui a1, 1048560
+; RV32ID-NEXT: or a0, a0, a1
+; RV32ID-NEXT: fmv.w.x fa0, a0
+; RV32ID-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
+; RV32ID-NEXT: addi sp, sp, 16
+; RV32ID-NEXT: ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_lu:
+; CHECK64ZFBFMIN: # %bb.0:
+; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
+; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
+; CHECK64ZFBFMIN-NEXT: ret
+;
+; RV64ID-LABEL: fcvt_bf16_lu:
+; RV64ID: # %bb.0:
+; RV64ID-NEXT: addi sp, sp, -16
+; RV64ID-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT: fcvt.s.lu fa0, a0
+; RV64ID-NEXT: call __truncsfbf2 at plt
+; RV64ID-NEXT: fmv.x.w a0, fa0
+; RV64ID-NEXT: lui a1, 1048560
+; RV64ID-NEXT: or a0, a0, a1
+; RV64ID-NEXT: fmv.w.x fa0, a0
+; RV64ID-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT: addi sp, sp, 16
+; RV64ID-NEXT: ret
+ %1 = uitofp i64 %a to bfloat
+ ret bfloat %1
+}
+
+; TODO: Other than the RV32 and RV64 zfbfmin cases, semantically incorrect
+; double rounding is currently used.
+define bfloat @fcvt_bf16_ll(i128 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_ll:
+; CHECK32ZFBFMIN: # %bb.0:
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, -32
+; CHECK32ZFBFMIN-NEXT: sw ra, 28(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT: lw a1, 0(a0)
+; CHECK32ZFBFMIN-NEXT: lw a2, 4(a0)
+; CHECK32ZFBFMIN-NEXT: lw a3, 8(a0)
+; CHECK32ZFBFMIN-NEXT: lw a0, 12(a0)
+; CHECK32ZFBFMIN-NEXT: sw a0, 20(sp)
+; CHECK32ZFBFMIN-NEXT: sw a3, 16(sp)
+; CHECK32ZFBFMIN-NEXT: sw a2, 12(sp)
+; CHECK32ZFBFMIN-NEXT: addi a0, sp, 8
+; CHECK32ZFBFMIN-NEXT: sw a1, 8(sp)
+; CHECK32ZFBFMIN-NEXT: call __floattibf at plt
+; CHECK32ZFBFMIN-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, 32
+; CHECK32ZFBFMIN-NEXT: ret
+;
+; RV32ID-LABEL: fcvt_bf16_ll:
+; RV32ID: # %bb.0:
+; RV32ID-NEXT: addi sp, sp, -32
+; RV32ID-NEXT: sw ra, 28(sp) # 4-byte Folded Spill
+; RV32ID-NEXT: lw a1, 0(a0)
+; RV32ID-NEXT: lw a2, 4(a0)
+; RV32ID-NEXT: lw a3, 8(a0)
+; RV32ID-NEXT: lw a0, 12(a0)
+; RV32ID-NEXT: sw a0, 20(sp)
+; RV32ID-NEXT: sw a3, 16(sp)
+; RV32ID-NEXT: sw a2, 12(sp)
+; RV32ID-NEXT: addi a0, sp, 8
+; RV32ID-NEXT: sw a1, 8(sp)
+; RV32ID-NEXT: call __floattisf at plt
+; RV32ID-NEXT: call __truncsfbf2 at plt
+; RV32ID-NEXT: fmv.x.w a0, fa0
+; RV32ID-NEXT: lui a1, 1048560
+; RV32ID-NEXT: or a0, a0, a1
+; RV32ID-NEXT: fmv.w.x fa0, a0
+; RV32ID-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
+; RV32ID-NEXT: addi sp, sp, 32
+; RV32ID-NEXT: ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_ll:
+; CHECK64ZFBFMIN: # %bb.0:
+; CHECK64ZFBFMIN-NEXT: addi sp, sp, -16
+; CHECK64ZFBFMIN-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; CHECK64ZFBFMIN-NEXT: call __floattibf at plt
+; CHECK64ZFBFMIN-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; CHECK64ZFBFMIN-NEXT: addi sp, sp, 16
+; CHECK64ZFBFMIN-NEXT: ret
+;
+; RV64ID-LABEL: fcvt_bf16_ll:
+; RV64ID: # %bb.0:
+; RV64ID-NEXT: addi sp, sp, -16
+; RV64ID-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT: call __floattisf at plt
+; RV64ID-NEXT: call __truncsfbf2 at plt
+; RV64ID-NEXT: fmv.x.w a0, fa0
+; RV64ID-NEXT: lui a1, 1048560
+; RV64ID-NEXT: or a0, a0, a1
+; RV64ID-NEXT: fmv.w.x fa0, a0
+; RV64ID-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT: addi sp, sp, 16
+; RV64ID-NEXT: ret
+ %1 = sitofp i128 %a to bfloat
+ ret bfloat %1
+}
-; define bfloat @fcvt_bf16_lu(i64 %a) nounwind {
-; %1 = uitofp i64 %a to bfloat
-; ret bfloat %1
-; }
+; TODO: Other than the RV32 and RV64 zfbfmin cases, semantically incorrect
+; double rounding is currently used.
+define bfloat @fcvt_bf16_llu(i128 %a) nounwind {
+; CHECK32ZFBFMIN-LABEL: fcvt_bf16_llu:
+; CHECK32ZFBFMIN: # %bb.0:
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, -32
+; CHECK32ZFBFMIN-NEXT: sw ra, 28(sp) # 4-byte Folded Spill
+; CHECK32ZFBFMIN-NEXT: lw a1, 0(a0)
+; CHECK32ZFBFMIN-NEXT: lw a2, 4(a0)
+; CHECK32ZFBFMIN-NEXT: lw a3, 8(a0)
+; CHECK32ZFBFMIN-NEXT: lw a0, 12(a0)
+; CHECK32ZFBFMIN-NEXT: sw a0, 20(sp)
+; CHECK32ZFBFMIN-NEXT: sw a3, 16(sp)
+; CHECK32ZFBFMIN-NEXT: sw a2, 12(sp)
+; CHECK32ZFBFMIN-NEXT: addi a0, sp, 8
+; CHECK32ZFBFMIN-NEXT: sw a1, 8(sp)
+; CHECK32ZFBFMIN-NEXT: call __floatuntibf at plt
+; CHECK32ZFBFMIN-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
+; CHECK32ZFBFMIN-NEXT: addi sp, sp, 32
+; CHECK32ZFBFMIN-NEXT: ret
+;
+; RV32ID-LABEL: fcvt_bf16_llu:
+; RV32ID: # %bb.0:
+; RV32ID-NEXT: addi sp, sp, -32
+; RV32ID-NEXT: sw ra, 28(sp) # 4-byte Folded Spill
+; RV32ID-NEXT: lw a1, 0(a0)
+; RV32ID-NEXT: lw a2, 4(a0)
+; RV32ID-NEXT: lw a3, 8(a0)
+; RV32ID-NEXT: lw a0, 12(a0)
+; RV32ID-NEXT: sw a0, 20(sp)
+; RV32ID-NEXT: sw a3, 16(sp)
+; RV32ID-NEXT: sw a2, 12(sp)
+; RV32ID-NEXT: addi a0, sp, 8
+; RV32ID-NEXT: sw a1, 8(sp)
+; RV32ID-NEXT: call __floatuntisf at plt
+; RV32ID-NEXT: call __truncsfbf2 at plt
+; RV32ID-NEXT: fmv.x.w a0, fa0
+; RV32ID-NEXT: lui a1, 1048560
+; RV32ID-NEXT: or a0, a0, a1
+; RV32ID-NEXT: fmv.w.x fa0, a0
+; RV32ID-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
+; RV32ID-NEXT: addi sp, sp, 32
+; RV32ID-NEXT: ret
+;
+; CHECK64ZFBFMIN-LABEL: fcvt_bf16_llu:
+; CHECK64ZFBFMIN: # %bb.0:
+; CHECK64ZFBFMIN-NEXT: addi sp, sp, -16
+; CHECK64ZFBFMIN-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; CHECK64ZFBFMIN-NEXT: call __floatuntibf at plt
+; CHECK64ZFBFMIN-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; CHECK64ZFBFMIN-NEXT: addi sp, sp, 16
+; CHECK64ZFBFMIN-NEXT: ret
+;
+; RV64ID-LABEL: fcvt_bf16_llu:
+; RV64ID: # %bb.0:
+; RV64ID-NEXT: addi sp, sp, -16
+; RV64ID-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
+; RV64ID-NEXT: call __floatuntisf at plt
+; RV64ID-NEXT: call __truncsfbf2 at plt
+; RV64ID-NEXT: fmv.x.w a0, fa0
+; RV64ID-NEXT: lui a1, 1048560
+; RV64ID-NEXT: or a0, a0, a1
+; RV64ID-NEXT: fmv.w.x fa0, a0
+; RV64ID-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
+; RV64ID-NEXT: addi sp, sp, 16
+; RV64ID-NEXT: ret
+ %1 = uitofp i128 %a to bfloat
+ ret bfloat %1
+}
define bfloat @fcvt_bf16_s(float %a) nounwind {
; CHECK32ZFBFMIN-LABEL: fcvt_bf16_s:
More information about the llvm-commits
mailing list