[llvm] AMDGPU: Add round-to-odd rounding during f64 to bf16 conversion (PR #133995)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 1 22:04:00 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Changpeng Fang (changpeng)

<details>
<summary>Changes</summary>

f64 -> bf16 conversion can be lowered to f64 -> f32 followed by f32 -> bf16:
   v_cvt_f32_f64_e32 v0, v[0:1]
   v_cvt_pk_bf16_f32 v0, v0, s0
Both conversion instructions perform round-to-even rounding, so we have a double-rounding issue that may generate incorrect results in some data ranges. We need to use round-to-odd rounding during the f64 -> f32 step to avoid double rounding.

NOTE: We have the same issue with f64 -> f16 conversion. Round-to-odd rounding for it will be added in a separate patch, which fixes SWDEV-523856.

---
Full diff: https://github.com/llvm/llvm-project/pull/133995.diff


2 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+22-13) 
- (modified) llvm/test/CodeGen/AMDGPU/bf16-conversions.ll (+66-6) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 96c113cc5d24c..e9aef744abcd9 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -911,8 +911,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::MUL, MVT::i1, Promote);
 
   if (Subtarget->hasBF16ConversionInsts()) {
-    setOperationAction(ISD::FP_ROUND, MVT::v2bf16, Legal);
-    setOperationAction(ISD::FP_ROUND, MVT::bf16, Legal);
+    setOperationAction(ISD::FP_ROUND, {MVT::bf16, MVT::v2bf16}, Custom);
     setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Legal);
   }
 
@@ -6888,23 +6887,33 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
 }
 
 SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
-  assert(Op.getValueType() == MVT::f16 &&
-         "Do not know how to custom lower FP_ROUND for non-f16 type");
-
   SDValue Src = Op.getOperand(0);
   EVT SrcVT = Src.getValueType();
-  if (SrcVT != MVT::f64)
-    return Op;
-
-  // TODO: Handle strictfp
-  if (Op.getOpcode() != ISD::FP_ROUND)
+  if (SrcVT.getScalarType() != MVT::f64)
     return Op;
 
+  EVT DstVT = Op.getValueType();
   SDLoc DL(Op);
+  if (DstVT == MVT::f16) {
+    // TODO: Handle strictfp
+    if (Op.getOpcode() != ISD::FP_ROUND)
+      return Op;
+
+    SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
+    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
+    return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  }
+
+  assert(DstVT.getScalarType() == MVT::bf16 &&
+          "custom lower FP_ROUND for f16 or bf16");
+  assert(Subtarget->hasBF16ConversionInsts() && "f32 -> bf16 is legal");
 
-  SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
-  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
-  return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  // Round-inexact-to-odd f64 to f32, then do the final rounding using the
+  // hardware f32 -> bf16 instruction.
+  EVT F32VT = SrcVT.isVector() ? SrcVT.changeVectorElementType(MVT::f32) :
+                                 MVT::f32;
+  SDValue Rod = expandRoundInexactToOdd(F32VT, Src, DL, DAG);
+  return getFPExtOrFPRound(DAG, Rod, DL, DstVT);
 }
 
 SDValue SITargetLowering::lowerFMINNUM_FMAXNUM(SDValue Op,
diff --git a/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll b/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
index 4c01e583713a7..3be911ab9e7f4 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
@@ -153,9 +153,34 @@ define amdgpu_ps float @v_test_cvt_v2f64_v2bf16_v(<2 x double> %src) {
 ;
 ; GFX-950-LABEL: v_test_cvt_v2f64_v2bf16_v:
 ; GFX-950:       ; %bb.0:
-; GFX-950-NEXT:    v_cvt_f32_f64_e32 v2, v[2:3]
-; GFX-950-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, v2
+; GFX-950-NEXT:    v_mov_b32_e32 v4, v3
+; GFX-950-NEXT:    v_and_b32_e32 v3, 0x7fffffff, v4
+; GFX-950-NEXT:    v_mov_b32_e32 v5, v1
+; GFX-950-NEXT:    v_cvt_f32_f64_e32 v1, v[2:3]
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[6:7], v1
+; GFX-950-NEXT:    v_and_b32_e32 v8, 1, v1
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], v[2:3], v[6:7]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e32 vcc, v[2:3], v[6:7]
+; GFX-950-NEXT:    v_cmp_eq_u32_e64 s[0:1], 1, v8
+; GFX-950-NEXT:    v_cndmask_b32_e64 v2, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v2, v1, v2
+; GFX-950-NEXT:    s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT:    v_cndmask_b32_e32 v1, v2, v1, vcc
+; GFX-950-NEXT:    s_brev_b32 s4, 1
+; GFX-950-NEXT:    v_and_or_b32 v4, v4, s4, v1
+; GFX-950-NEXT:    v_and_b32_e32 v1, 0x7fffffff, v5
+; GFX-950-NEXT:    v_cvt_f32_f64_e32 v6, v[0:1]
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[2:3], v6
+; GFX-950-NEXT:    v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], v[0:1], v[2:3]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e32 vcc, v[0:1], v[2:3]
+; GFX-950-NEXT:    v_cmp_eq_u32_e64 s[0:1], 1, v7
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT:    s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT:    v_and_or_b32 v0, v5, s4, v0
+; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, v4
 ; GFX-950-NEXT:    ; return to shader part epilog
   %res = fptrunc <2 x double> %src to <2 x bfloat>
   %cast = bitcast <2 x bfloat> %res to float
@@ -347,7 +372,18 @@ define amdgpu_ps void @fptrunc_f64_to_bf16(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16:
 ; GFX-950:       ; %bb.0: ; %entry
-; GFX-950-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
+; GFX-950-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX-950-NEXT:    v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v7
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT:    s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT:    s_brev_b32 s0, 1
+; GFX-950-NEXT:    v_and_or_b32 v0, v1, s0, v0
 ; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT:    flat_store_short v[2:3], v0
 ; GFX-950-NEXT:    s_endpgm
@@ -385,7 +421,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_neg(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16_neg:
 ; GFX-950:       ; %bb.0: ; %entry
-; GFX-950-NEXT:    v_cvt_f32_f64_e64 v0, -v[0:1]
+; GFX-950-NEXT:    v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT:    v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT:    s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT:    s_brev_b32 s4, 1
+; GFX-950-NEXT:    v_xor_b32_e32 v6, 0x80000000, v1
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT:    v_and_or_b32 v0, v6, s4, v0
 ; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT:    flat_store_short v[2:3], v0
 ; GFX-950-NEXT:    s_endpgm
@@ -424,7 +472,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_abs(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16_abs:
 ; GFX-950:       ; %bb.0: ; %entry
-; GFX-950-NEXT:    v_cvt_f32_f64_e64 v0, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT:    v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT:    s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT:    v_and_b32_e32 v6, 0x7fffffff, v1
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT:    s_brev_b32 s0, 1
+; GFX-950-NEXT:    v_and_or_b32 v0, v6, s0, v0
 ; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT:    flat_store_short v[2:3], v0
 ; GFX-950-NEXT:    s_endpgm

``````````

</details>


https://github.com/llvm/llvm-project/pull/133995


More information about the llvm-commits mailing list