[llvm] [AMDGPU][SDAG] Use the f16 lowering for bf16 safe divisions. (PR #147530)

Ivan Kosarev via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 8 07:04:15 PDT 2025


https://github.com/kosarev created https://github.com/llvm/llvm-project/pull/147530

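(No PR description was provided. Summary of the change: bf16 FDIV is now marked Custom alongside f16 and routed through LowerFDIV16. For bf16, the quotient is computed in f32, DIV_FIXUP is applied in f32, and only the fixed-up result is rounded to bf16. In bf16.ll, the v_div_scale/v_div_fmas sequence is replaced by the shorter v_rcp-based expansion already used for f16.)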

From 1d9dabbc528285746067cdb214056e023a98ac3e Mon Sep 17 00:00:00 2001
From: Ivan Kosarev <ivan.kosarev at amd.com>
Date: Tue, 8 Jul 2025 14:09:55 +0100
Subject: [PATCH] [AMDGPU][SDAG] Use the f16 lowering for bf16 safe divisions.

---
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp |  15 ++-
 llvm/test/CodeGen/AMDGPU/bf16.ll          | 120 ++++++++++------------
 2 files changed, 64 insertions(+), 71 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index b083a9014737b..bd0bb38570a8f 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -624,7 +624,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                        Expand);
     setOperationAction({ISD::FLDEXP, ISD::STRICT_FLDEXP}, MVT::f16, Custom);
     setOperationAction(ISD::FFREXP, MVT::f16, Custom);
-    setOperationAction(ISD::FDIV, MVT::f16, Custom);
+    setOperationAction(ISD::FDIV, {MVT::f16, MVT::bf16}, Custom);
 
     // F16 - VOP3 Actions.
     setOperationAction(ISD::FMA, MVT::f16, Legal);
@@ -11229,6 +11229,7 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
   SDLoc SL(Op);
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
+  EVT VT = Op.getValueType();
 
   // a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32
   // b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32
@@ -11265,10 +11266,14 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
                         DAG.getConstant(0xff800000, SL, MVT::i32));
   Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
   Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags());
-  SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot,
+
+  EVT FixupVT = VT == MVT::bf16 ? MVT::f32 : VT;
+  SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, FixupVT, Quot,
                              DAG.getTargetConstant(0, SL, MVT::i32));
-  return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS,
-                     Op->getFlags());
+  SDValue Fixup = DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, FixupVT, RDst, RHS, LHS,
+                              Op->getFlags());
+  return DAG.getNode(ISD::FP_ROUND, SL, VT, Fixup,
+                     DAG.getTargetConstant(0, SL, MVT::i32));
 }
 
 // Faster 2.5 ULP division that does not support denormals.
@@ -11531,7 +11536,7 @@ SDValue SITargetLowering::LowerFDIV(SDValue Op, SelectionDAG &DAG) const {
   if (VT == MVT::f64)
     return LowerFDIV64(Op, DAG);
 
-  if (VT == MVT::f16)
+  if (VT == MVT::f16 || VT == MVT::bf16)
     return LowerFDIV16(Op, DAG);
 
   llvm_unreachable("Unexpected type for fdiv");
diff --git a/llvm/test/CodeGen/AMDGPU/bf16.ll b/llvm/test/CodeGen/AMDGPU/bf16.ll
index 2bdf994496421..2724c16acbcc9 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16.ll
@@ -18494,18 +18494,16 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
 ; GFX8-LABEL: v_fdiv_bf16:
 ; GFX8:       ; %bb.0:
 ; GFX8-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX8-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
-; GFX8-NEXT:    v_lshlrev_b32_e32 v1, 16, v1
-; GFX8-NEXT:    v_div_scale_f32 v2, s[4:5], v1, v1, v0
-; GFX8-NEXT:    v_div_scale_f32 v3, vcc, v0, v1, v0
-; GFX8-NEXT:    v_rcp_f32_e32 v4, v2
-; GFX8-NEXT:    v_fma_f32 v5, -v2, v4, 1.0
-; GFX8-NEXT:    v_fma_f32 v4, v5, v4, v4
-; GFX8-NEXT:    v_mul_f32_e32 v5, v3, v4
-; GFX8-NEXT:    v_fma_f32 v6, -v2, v5, v3
-; GFX8-NEXT:    v_fma_f32 v5, v6, v4, v5
-; GFX8-NEXT:    v_fma_f32 v2, -v2, v5, v3
-; GFX8-NEXT:    v_div_fmas_f32 v2, v2, v4, v5
+; GFX8-NEXT:    v_lshlrev_b32_e32 v2, 16, v1
+; GFX8-NEXT:    v_rcp_f32_e32 v3, v2
+; GFX8-NEXT:    v_lshlrev_b32_e32 v4, 16, v0
+; GFX8-NEXT:    v_mul_f32_e32 v5, v4, v3
+; GFX8-NEXT:    v_mad_f32 v6, -v2, v5, v4
+; GFX8-NEXT:    v_mac_f32_e32 v5, v6, v3
+; GFX8-NEXT:    v_mad_f32 v2, -v2, v5, v4
+; GFX8-NEXT:    v_mul_f32_e32 v2, v2, v3
+; GFX8-NEXT:    v_and_b32_e32 v2, 0xff800000, v2
+; GFX8-NEXT:    v_add_f32_e32 v2, v2, v5
 ; GFX8-NEXT:    v_div_fixup_f32 v0, v2, v1, v0
 ; GFX8-NEXT:    v_bfe_u32 v1, v0, 16, 1
 ; GFX8-NEXT:    v_add_u32_e32 v1, vcc, v1, v0
@@ -18519,23 +18517,21 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
 ; GFX9-LABEL: v_fdiv_bf16:
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX9-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
-; GFX9-NEXT:    v_lshlrev_b32_e32 v1, 16, v1
-; GFX9-NEXT:    v_div_scale_f32 v2, s[4:5], v1, v1, v0
-; GFX9-NEXT:    v_div_scale_f32 v3, vcc, v0, v1, v0
+; GFX9-NEXT:    v_lshlrev_b32_e32 v2, 16, v1
+; GFX9-NEXT:    v_rcp_f32_e32 v3, v2
+; GFX9-NEXT:    v_lshlrev_b32_e32 v4, 16, v0
 ; GFX9-NEXT:    s_movk_i32 s4, 0x7fff
-; GFX9-NEXT:    v_rcp_f32_e32 v4, v2
-; GFX9-NEXT:    v_fma_f32 v5, -v2, v4, 1.0
-; GFX9-NEXT:    v_fma_f32 v4, v5, v4, v4
-; GFX9-NEXT:    v_mul_f32_e32 v5, v3, v4
-; GFX9-NEXT:    v_fma_f32 v6, -v2, v5, v3
-; GFX9-NEXT:    v_fma_f32 v5, v6, v4, v5
-; GFX9-NEXT:    v_fma_f32 v2, -v2, v5, v3
-; GFX9-NEXT:    v_div_fmas_f32 v2, v2, v4, v5
+; GFX9-NEXT:    v_mul_f32_e32 v5, v4, v3
+; GFX9-NEXT:    v_mad_f32 v6, -v2, v5, v4
+; GFX9-NEXT:    v_mac_f32_e32 v5, v6, v3
+; GFX9-NEXT:    v_mad_f32 v2, -v2, v5, v4
+; GFX9-NEXT:    v_mul_f32_e32 v2, v2, v3
+; GFX9-NEXT:    v_and_b32_e32 v2, 0xff800000, v2
+; GFX9-NEXT:    v_add_f32_e32 v2, v2, v5
 ; GFX9-NEXT:    v_div_fixup_f32 v0, v2, v1, v0
 ; GFX9-NEXT:    v_bfe_u32 v1, v0, 16, 1
-; GFX9-NEXT:    v_or_b32_e32 v2, 0x400000, v0
 ; GFX9-NEXT:    v_add3_u32 v1, v1, v0, s4
+; GFX9-NEXT:    v_or_b32_e32 v2, 0x400000, v0
 ; GFX9-NEXT:    v_cmp_u_f32_e32 vcc, v0, v0
 ; GFX9-NEXT:    v_cndmask_b32_e32 v0, v1, v2, vcc
 ; GFX9-NEXT:    v_lshrrev_b32_e32 v0, 16, v0
@@ -18544,18 +18540,16 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
 ; GFX10-LABEL: v_fdiv_bf16:
 ; GFX10:       ; %bb.0:
 ; GFX10-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX10-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
-; GFX10-NEXT:    v_lshlrev_b32_e32 v1, 16, v1
-; GFX10-NEXT:    v_div_scale_f32 v2, s4, v1, v1, v0
-; GFX10-NEXT:    v_div_scale_f32 v5, vcc_lo, v0, v1, v0
+; GFX10-NEXT:    v_lshlrev_b32_e32 v2, 16, v1
+; GFX10-NEXT:    v_lshlrev_b32_e32 v4, 16, v0
 ; GFX10-NEXT:    v_rcp_f32_e32 v3, v2
-; GFX10-NEXT:    v_fma_f32 v4, -v2, v3, 1.0
-; GFX10-NEXT:    v_fmac_f32_e32 v3, v4, v3
-; GFX10-NEXT:    v_mul_f32_e32 v4, v5, v3
-; GFX10-NEXT:    v_fma_f32 v6, -v2, v4, v5
-; GFX10-NEXT:    v_fmac_f32_e32 v4, v6, v3
-; GFX10-NEXT:    v_fma_f32 v2, -v2, v4, v5
-; GFX10-NEXT:    v_div_fmas_f32 v2, v2, v3, v4
+; GFX10-NEXT:    v_mul_f32_e32 v5, v4, v3
+; GFX10-NEXT:    v_mad_f32 v6, -v2, v5, v4
+; GFX10-NEXT:    v_mac_f32_e32 v5, v6, v3
+; GFX10-NEXT:    v_mad_f32 v2, -v2, v5, v4
+; GFX10-NEXT:    v_mul_f32_e32 v2, v2, v3
+; GFX10-NEXT:    v_and_b32_e32 v2, 0xff800000, v2
+; GFX10-NEXT:    v_add_f32_e32 v2, v2, v5
 ; GFX10-NEXT:    v_div_fixup_f32 v0, v2, v1, v0
 ; GFX10-NEXT:    v_bfe_u32 v1, v0, 16, 1
 ; GFX10-NEXT:    v_or_b32_e32 v2, 0x400000, v0
@@ -18568,64 +18562,58 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
 ; GFX11TRUE16-LABEL: v_fdiv_bf16:
 ; GFX11TRUE16:       ; %bb.0:
 ; GFX11TRUE16-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX11TRUE16-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
-; GFX11TRUE16-NEXT:    v_lshlrev_b32_e32 v1, 16, v1
-; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT:    v_div_scale_f32 v2, null, v1, v1, v0
+; GFX11TRUE16-NEXT:    v_lshlrev_b32_e32 v4, 16, v0
+; GFX11TRUE16-NEXT:    v_lshlrev_b32_e32 v2, 16, v1
+; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_1)
 ; GFX11TRUE16-NEXT:    v_rcp_f32_e32 v3, v2
 ; GFX11TRUE16-NEXT:    s_waitcnt_depctr 0xfff
-; GFX11TRUE16-NEXT:    v_fma_f32 v4, -v2, v3, 1.0
-; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT:    v_fmac_f32_e32 v3, v4, v3
-; GFX11TRUE16-NEXT:    v_div_scale_f32 v5, vcc_lo, v0, v1, v0
-; GFX11TRUE16-NEXT:    v_mul_f32_e32 v4, v5, v3
+; GFX11TRUE16-NEXT:    v_mul_f32_e32 v5, v4, v3
+; GFX11TRUE16-NEXT:    v_fma_f32 v6, -v2, v5, v4
 ; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT:    v_fma_f32 v6, -v2, v4, v5
-; GFX11TRUE16-NEXT:    v_fmac_f32_e32 v4, v6, v3
+; GFX11TRUE16-NEXT:    v_fmac_f32_e32 v5, v6, v3
+; GFX11TRUE16-NEXT:    v_fma_f32 v2, -v2, v5, v4
 ; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT:    v_fma_f32 v2, -v2, v4, v5
-; GFX11TRUE16-NEXT:    v_div_fmas_f32 v2, v2, v3, v4
+; GFX11TRUE16-NEXT:    v_mul_f32_e32 v2, v2, v3
+; GFX11TRUE16-NEXT:    v_and_b32_e32 v2, 0xff800000, v2
 ; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11TRUE16-NEXT:    v_add_f32_e32 v2, v2, v5
 ; GFX11TRUE16-NEXT:    v_div_fixup_f32 v0, v2, v1, v0
+; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_3)
 ; GFX11TRUE16-NEXT:    v_bfe_u32 v1, v0, 16, 1
 ; GFX11TRUE16-NEXT:    v_or_b32_e32 v2, 0x400000, v0
 ; GFX11TRUE16-NEXT:    v_cmp_u_f32_e32 vcc_lo, v0, v0
-; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1)
 ; GFX11TRUE16-NEXT:    v_add3_u32 v1, v1, v0, 0x7fff
+; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
 ; GFX11TRUE16-NEXT:    v_cndmask_b32_e32 v0, v1, v2, vcc_lo
-; GFX11TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1)
 ; GFX11TRUE16-NEXT:    v_mov_b16_e32 v0.l, v0.h
 ; GFX11TRUE16-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX11FAKE16-LABEL: v_fdiv_bf16:
 ; GFX11FAKE16:       ; %bb.0:
 ; GFX11FAKE16-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX11FAKE16-NEXT:    v_lshlrev_b32_e32 v0, 16, v0
-; GFX11FAKE16-NEXT:    v_lshlrev_b32_e32 v1, 16, v1
-; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT:    v_div_scale_f32 v2, null, v1, v1, v0
+; GFX11FAKE16-NEXT:    v_lshlrev_b32_e32 v4, 16, v0
+; GFX11FAKE16-NEXT:    v_lshlrev_b32_e32 v2, 16, v1
+; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_1)
 ; GFX11FAKE16-NEXT:    v_rcp_f32_e32 v3, v2
 ; GFX11FAKE16-NEXT:    s_waitcnt_depctr 0xfff
-; GFX11FAKE16-NEXT:    v_fma_f32 v4, -v2, v3, 1.0
-; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT:    v_fmac_f32_e32 v3, v4, v3
-; GFX11FAKE16-NEXT:    v_div_scale_f32 v5, vcc_lo, v0, v1, v0
-; GFX11FAKE16-NEXT:    v_mul_f32_e32 v4, v5, v3
+; GFX11FAKE16-NEXT:    v_mul_f32_e32 v5, v4, v3
+; GFX11FAKE16-NEXT:    v_fma_f32 v6, -v2, v5, v4
 ; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT:    v_fma_f32 v6, -v2, v4, v5
-; GFX11FAKE16-NEXT:    v_fmac_f32_e32 v4, v6, v3
+; GFX11FAKE16-NEXT:    v_fmac_f32_e32 v5, v6, v3
+; GFX11FAKE16-NEXT:    v_fma_f32 v2, -v2, v5, v4
 ; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT:    v_fma_f32 v2, -v2, v4, v5
-; GFX11FAKE16-NEXT:    v_div_fmas_f32 v2, v2, v3, v4
+; GFX11FAKE16-NEXT:    v_mul_f32_e32 v2, v2, v3
+; GFX11FAKE16-NEXT:    v_and_b32_e32 v2, 0xff800000, v2
 ; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11FAKE16-NEXT:    v_add_f32_e32 v2, v2, v5
 ; GFX11FAKE16-NEXT:    v_div_fixup_f32 v0, v2, v1, v0
+; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_3)
 ; GFX11FAKE16-NEXT:    v_bfe_u32 v1, v0, 16, 1
 ; GFX11FAKE16-NEXT:    v_or_b32_e32 v2, 0x400000, v0
 ; GFX11FAKE16-NEXT:    v_cmp_u_f32_e32 vcc_lo, v0, v0
-; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1)
 ; GFX11FAKE16-NEXT:    v_add3_u32 v1, v1, v0, 0x7fff
+; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
 ; GFX11FAKE16-NEXT:    v_cndmask_b32_e32 v0, v1, v2, vcc_lo
-; GFX11FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1)
 ; GFX11FAKE16-NEXT:    v_lshrrev_b32_e32 v0, 16, v0
 ; GFX11FAKE16-NEXT:    s_setpc_b64 s[30:31]
   %op = fdiv bfloat %a, %b



More information about the llvm-commits mailing list