[llvm] AMDGPU: Add round-to-odd rounding during f64 to bf16 conversion (PR #133995)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 1 22:04:00 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Changpeng Fang (changpeng)
<details>
<summary>Changes</summary>
f64 -> bf16 conversion can be lowered to f64 -> f32 followed by f32 -> bf16:
v_cvt_f32_f64_e32 v0, v[0:1]
v_cvt_pk_bf16_f32 v0, v0, s0
Both conversion instructions will do round-to-even rounding, and thus we will have a double rounding issue which may generate incorrect results in some data ranges. We need to add round-to-odd rounding during f64 -> f32 to avoid double rounding.
NOTE: we have the same issue with the f64 -> f16 conversion. We will add round-to-odd rounding for it in a separate patch, which fixes SWDEV-523856.
---
Full diff: https://github.com/llvm/llvm-project/pull/133995.diff
2 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+22-13)
- (modified) llvm/test/CodeGen/AMDGPU/bf16-conversions.ll (+66-6)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 96c113cc5d24c..e9aef744abcd9 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -911,8 +911,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MUL, MVT::i1, Promote);
if (Subtarget->hasBF16ConversionInsts()) {
- setOperationAction(ISD::FP_ROUND, MVT::v2bf16, Legal);
- setOperationAction(ISD::FP_ROUND, MVT::bf16, Legal);
+ setOperationAction(ISD::FP_ROUND, {MVT::bf16, MVT::v2bf16}, Custom);
setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Legal);
}
@@ -6888,23 +6887,33 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
}
SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
- assert(Op.getValueType() == MVT::f16 &&
- "Do not know how to custom lower FP_ROUND for non-f16 type");
-
SDValue Src = Op.getOperand(0);
EVT SrcVT = Src.getValueType();
- if (SrcVT != MVT::f64)
- return Op;
-
- // TODO: Handle strictfp
- if (Op.getOpcode() != ISD::FP_ROUND)
+ if (SrcVT.getScalarType() != MVT::f64)
return Op;
+ EVT DstVT = Op.getValueType();
SDLoc DL(Op);
+ if (DstVT == MVT::f16) {
+ // TODO: Handle strictfp
+ if (Op.getOpcode() != ISD::FP_ROUND)
+ return Op;
+
+ SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
+ return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+ }
+
+ assert(DstVT.getScalarType() == MVT::bf16 &&
+ "custom lower FP_ROUND for f16 or bf16");
+ assert(Subtarget->hasBF16ConversionInsts() && "f32 -> bf16 is legal");
- SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
- return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+ // Round-inexact-to-odd f64 to f32, then do the final rounding using the
+ // hardware f32 -> bf16 instruction.
+ EVT F32VT = SrcVT.isVector() ? SrcVT.changeVectorElementType(MVT::f32) :
+ MVT::f32;
+ SDValue Rod = expandRoundInexactToOdd(F32VT, Src, DL, DAG);
+ return getFPExtOrFPRound(DAG, Rod, DL, DstVT);
}
SDValue SITargetLowering::lowerFMINNUM_FMAXNUM(SDValue Op,
diff --git a/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll b/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
index 4c01e583713a7..3be911ab9e7f4 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
@@ -153,9 +153,34 @@ define amdgpu_ps float @v_test_cvt_v2f64_v2bf16_v(<2 x double> %src) {
;
; GFX-950-LABEL: v_test_cvt_v2f64_v2bf16_v:
; GFX-950: ; %bb.0:
-; GFX-950-NEXT: v_cvt_f32_f64_e32 v2, v[2:3]
-; GFX-950-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, v2
+; GFX-950-NEXT: v_mov_b32_e32 v4, v3
+; GFX-950-NEXT: v_and_b32_e32 v3, 0x7fffffff, v4
+; GFX-950-NEXT: v_mov_b32_e32 v5, v1
+; GFX-950-NEXT: v_cvt_f32_f64_e32 v1, v[2:3]
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[6:7], v1
+; GFX-950-NEXT: v_and_b32_e32 v8, 1, v1
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], v[2:3], v[6:7]
+; GFX-950-NEXT: v_cmp_nlg_f64_e32 vcc, v[2:3], v[6:7]
+; GFX-950-NEXT: v_cmp_eq_u32_e64 s[0:1], 1, v8
+; GFX-950-NEXT: v_cndmask_b32_e64 v2, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v2, v1, v2
+; GFX-950-NEXT: s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT: v_cndmask_b32_e32 v1, v2, v1, vcc
+; GFX-950-NEXT: s_brev_b32 s4, 1
+; GFX-950-NEXT: v_and_or_b32 v4, v4, s4, v1
+; GFX-950-NEXT: v_and_b32_e32 v1, 0x7fffffff, v5
+; GFX-950-NEXT: v_cvt_f32_f64_e32 v6, v[0:1]
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[2:3], v6
+; GFX-950-NEXT: v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], v[0:1], v[2:3]
+; GFX-950-NEXT: v_cmp_nlg_f64_e32 vcc, v[0:1], v[2:3]
+; GFX-950-NEXT: v_cmp_eq_u32_e64 s[0:1], 1, v7
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT: s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT: v_and_or_b32 v0, v5, s4, v0
+; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, v4
; GFX-950-NEXT: ; return to shader part epilog
%res = fptrunc <2 x double> %src to <2 x bfloat>
%cast = bitcast <2 x bfloat> %res to float
@@ -347,7 +372,18 @@ define amdgpu_ps void @fptrunc_f64_to_bf16(double %a, ptr %out) {
;
; GFX-950-LABEL: fptrunc_f64_to_bf16:
; GFX-950: ; %bb.0: ; %entry
-; GFX-950-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
+; GFX-950-NEXT: v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[4:5], v6
+; GFX-950-NEXT: v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_eq_u32_e32 vcc, 1, v7
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT: s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT: s_brev_b32 s0, 1
+; GFX-950-NEXT: v_and_or_b32 v0, v1, s0, v0
; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
; GFX-950-NEXT: flat_store_short v[2:3], v0
; GFX-950-NEXT: s_endpgm
@@ -385,7 +421,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_neg(double %a, ptr %out) {
;
; GFX-950-LABEL: fptrunc_f64_to_bf16_neg:
; GFX-950: ; %bb.0: ; %entry
-; GFX-950-NEXT: v_cvt_f32_f64_e64 v0, -v[0:1]
+; GFX-950-NEXT: v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT: v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT: s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT: s_brev_b32 s4, 1
+; GFX-950-NEXT: v_xor_b32_e32 v6, 0x80000000, v1
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT: v_and_or_b32 v0, v6, s4, v0
; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
; GFX-950-NEXT: flat_store_short v[2:3], v0
; GFX-950-NEXT: s_endpgm
@@ -424,7 +472,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_abs(double %a, ptr %out) {
;
; GFX-950-LABEL: fptrunc_f64_to_bf16_abs:
; GFX-950: ; %bb.0: ; %entry
-; GFX-950-NEXT: v_cvt_f32_f64_e64 v0, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT: v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT: s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT: v_and_b32_e32 v6, 0x7fffffff, v1
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT: s_brev_b32 s0, 1
+; GFX-950-NEXT: v_and_or_b32 v0, v6, s0, v0
; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
; GFX-950-NEXT: flat_store_short v[2:3], v0
; GFX-950-NEXT: s_endpgm
``````````
</details>
https://github.com/llvm/llvm-project/pull/133995
More information about the llvm-commits
mailing list