[llvm] [AMDGPU][SDAG] Use the f16 lowering for bf16 safe divisions. (PR #147530)
Ivan Kosarev via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 8 07:04:15 PDT 2025
https://github.com/kosarev created https://github.com/llvm/llvm-project/pull/147530
None
>From 1d9dabbc528285746067cdb214056e023a98ac3e Mon Sep 17 00:00:00 2001
From: Ivan Kosarev <ivan.kosarev at amd.com>
Date: Tue, 8 Jul 2025 14:09:55 +0100
Subject: [PATCH] [AMDGPU][SDAG] Use the f16 lowering for bf16 safe divisions.
---
llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 15 ++-
llvm/test/CodeGen/AMDGPU/bf16.ll | 120 ++++++++++------------
2 files changed, 64 insertions(+), 71 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index b083a9014737b..bd0bb38570a8f 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -624,7 +624,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
Expand);
setOperationAction({ISD::FLDEXP, ISD::STRICT_FLDEXP}, MVT::f16, Custom);
setOperationAction(ISD::FFREXP, MVT::f16, Custom);
- setOperationAction(ISD::FDIV, MVT::f16, Custom);
+ setOperationAction(ISD::FDIV, {MVT::f16, MVT::bf16}, Custom);
// F16 - VOP3 Actions.
setOperationAction(ISD::FMA, MVT::f16, Legal);
@@ -11229,6 +11229,7 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
SDLoc SL(Op);
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);
+ EVT VT = Op.getValueType();
// a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32
// b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32
@@ -11265,10 +11266,14 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
DAG.getConstant(0xff800000, SL, MVT::i32));
Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags());
- SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot,
+
+ EVT FixupVT = VT == MVT::bf16 ? MVT::f32 : VT;
+ SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, FixupVT, Quot,
DAG.getTargetConstant(0, SL, MVT::i32));
- return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS,
- Op->getFlags());
+ SDValue Fixup = DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, FixupVT, RDst, RHS, LHS,
+ Op->getFlags());
+ return DAG.getNode(ISD::FP_ROUND, SL, VT, Fixup,
+ DAG.getTargetConstant(0, SL, MVT::i32));
}
// Faster 2.5 ULP division that does not support denormals.
@@ -11531,7 +11536,7 @@ SDValue SITargetLowering::LowerFDIV(SDValue Op, SelectionDAG &DAG) const {
if (VT == MVT::f64)
return LowerFDIV64(Op, DAG);
- if (VT == MVT::f16)
+ if (VT == MVT::f16 || VT == MVT::bf16)
return LowerFDIV16(Op, DAG);
llvm_unreachable("Unexpected type for fdiv");
diff --git a/llvm/test/CodeGen/AMDGPU/bf16.ll b/llvm/test/CodeGen/AMDGPU/bf16.ll
index 2bdf994496421..2724c16acbcc9 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16.ll
@@ -18494,18 +18494,16 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
; GFX8-LABEL: v_fdiv_bf16:
; GFX8: ; %bb.0:
; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX8-NEXT: v_lshlrev_b32_e32 v0, 16, v0
-; GFX8-NEXT: v_lshlrev_b32_e32 v1, 16, v1
-; GFX8-NEXT: v_div_scale_f32 v2, s[4:5], v1, v1, v0
-; GFX8-NEXT: v_div_scale_f32 v3, vcc, v0, v1, v0
-; GFX8-NEXT: v_rcp_f32_e32 v4, v2
-; GFX8-NEXT: v_fma_f32 v5, -v2, v4, 1.0
-; GFX8-NEXT: v_fma_f32 v4, v5, v4, v4
-; GFX8-NEXT: v_mul_f32_e32 v5, v3, v4
-; GFX8-NEXT: v_fma_f32 v6, -v2, v5, v3
-; GFX8-NEXT: v_fma_f32 v5, v6, v4, v5
-; GFX8-NEXT: v_fma_f32 v2, -v2, v5, v3
-; GFX8-NEXT: v_div_fmas_f32 v2, v2, v4, v5
+; GFX8-NEXT: v_lshlrev_b32_e32 v2, 16, v1
+; GFX8-NEXT: v_rcp_f32_e32 v3, v2
+; GFX8-NEXT: v_lshlrev_b32_e32 v4, 16, v0
+; GFX8-NEXT: v_mul_f32_e32 v5, v4, v3
+; GFX8-NEXT: v_mad_f32 v6, -v2, v5, v4
+; GFX8-NEXT: v_mac_f32_e32 v5, v6, v3
+; GFX8-NEXT: v_mad_f32 v2, -v2, v5, v4
+; GFX8-NEXT: v_mul_f32_e32 v2, v2, v3
+; GFX8-NEXT: v_and_b32_e32 v2, 0xff800000, v2
+; GFX8-NEXT: v_add_f32_e32 v2, v2, v5
; GFX8-NEXT: v_div_fixup_f32 v0, v2, v1, v0
; GFX8-NEXT: v_bfe_u32 v1, v0, 16, 1
; GFX8-NEXT: v_add_u32_e32 v1, vcc, v1, v0
@@ -18519,23 +18517,21 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
; GFX9-LABEL: v_fdiv_bf16:
; GFX9: ; %bb.0:
; GFX9-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX9-NEXT: v_lshlrev_b32_e32 v0, 16, v0
-; GFX9-NEXT: v_lshlrev_b32_e32 v1, 16, v1
-; GFX9-NEXT: v_div_scale_f32 v2, s[4:5], v1, v1, v0
-; GFX9-NEXT: v_div_scale_f32 v3, vcc, v0, v1, v0
+; GFX9-NEXT: v_lshlrev_b32_e32 v2, 16, v1
+; GFX9-NEXT: v_rcp_f32_e32 v3, v2
+; GFX9-NEXT: v_lshlrev_b32_e32 v4, 16, v0
; GFX9-NEXT: s_movk_i32 s4, 0x7fff
-; GFX9-NEXT: v_rcp_f32_e32 v4, v2
-; GFX9-NEXT: v_fma_f32 v5, -v2, v4, 1.0
-; GFX9-NEXT: v_fma_f32 v4, v5, v4, v4
-; GFX9-NEXT: v_mul_f32_e32 v5, v3, v4
-; GFX9-NEXT: v_fma_f32 v6, -v2, v5, v3
-; GFX9-NEXT: v_fma_f32 v5, v6, v4, v5
-; GFX9-NEXT: v_fma_f32 v2, -v2, v5, v3
-; GFX9-NEXT: v_div_fmas_f32 v2, v2, v4, v5
+; GFX9-NEXT: v_mul_f32_e32 v5, v4, v3
+; GFX9-NEXT: v_mad_f32 v6, -v2, v5, v4
+; GFX9-NEXT: v_mac_f32_e32 v5, v6, v3
+; GFX9-NEXT: v_mad_f32 v2, -v2, v5, v4
+; GFX9-NEXT: v_mul_f32_e32 v2, v2, v3
+; GFX9-NEXT: v_and_b32_e32 v2, 0xff800000, v2
+; GFX9-NEXT: v_add_f32_e32 v2, v2, v5
; GFX9-NEXT: v_div_fixup_f32 v0, v2, v1, v0
; GFX9-NEXT: v_bfe_u32 v1, v0, 16, 1
-; GFX9-NEXT: v_or_b32_e32 v2, 0x400000, v0
; GFX9-NEXT: v_add3_u32 v1, v1, v0, s4
+; GFX9-NEXT: v_or_b32_e32 v2, 0x400000, v0
; GFX9-NEXT: v_cmp_u_f32_e32 vcc, v0, v0
; GFX9-NEXT: v_cndmask_b32_e32 v0, v1, v2, vcc
; GFX9-NEXT: v_lshrrev_b32_e32 v0, 16, v0
@@ -18544,18 +18540,16 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
; GFX10-LABEL: v_fdiv_bf16:
; GFX10: ; %bb.0:
; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX10-NEXT: v_lshlrev_b32_e32 v0, 16, v0
-; GFX10-NEXT: v_lshlrev_b32_e32 v1, 16, v1
-; GFX10-NEXT: v_div_scale_f32 v2, s4, v1, v1, v0
-; GFX10-NEXT: v_div_scale_f32 v5, vcc_lo, v0, v1, v0
+; GFX10-NEXT: v_lshlrev_b32_e32 v2, 16, v1
+; GFX10-NEXT: v_lshlrev_b32_e32 v4, 16, v0
; GFX10-NEXT: v_rcp_f32_e32 v3, v2
-; GFX10-NEXT: v_fma_f32 v4, -v2, v3, 1.0
-; GFX10-NEXT: v_fmac_f32_e32 v3, v4, v3
-; GFX10-NEXT: v_mul_f32_e32 v4, v5, v3
-; GFX10-NEXT: v_fma_f32 v6, -v2, v4, v5
-; GFX10-NEXT: v_fmac_f32_e32 v4, v6, v3
-; GFX10-NEXT: v_fma_f32 v2, -v2, v4, v5
-; GFX10-NEXT: v_div_fmas_f32 v2, v2, v3, v4
+; GFX10-NEXT: v_mul_f32_e32 v5, v4, v3
+; GFX10-NEXT: v_mad_f32 v6, -v2, v5, v4
+; GFX10-NEXT: v_mac_f32_e32 v5, v6, v3
+; GFX10-NEXT: v_mad_f32 v2, -v2, v5, v4
+; GFX10-NEXT: v_mul_f32_e32 v2, v2, v3
+; GFX10-NEXT: v_and_b32_e32 v2, 0xff800000, v2
+; GFX10-NEXT: v_add_f32_e32 v2, v2, v5
; GFX10-NEXT: v_div_fixup_f32 v0, v2, v1, v0
; GFX10-NEXT: v_bfe_u32 v1, v0, 16, 1
; GFX10-NEXT: v_or_b32_e32 v2, 0x400000, v0
@@ -18568,64 +18562,58 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) {
; GFX11TRUE16-LABEL: v_fdiv_bf16:
; GFX11TRUE16: ; %bb.0:
; GFX11TRUE16-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v0, 16, v0
-; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v1, 16, v1
-; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT: v_div_scale_f32 v2, null, v1, v1, v0
+; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v4, 16, v0
+; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v2, 16, v1
+; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_1)
; GFX11TRUE16-NEXT: v_rcp_f32_e32 v3, v2
; GFX11TRUE16-NEXT: s_waitcnt_depctr 0xfff
-; GFX11TRUE16-NEXT: v_fma_f32 v4, -v2, v3, 1.0
-; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT: v_fmac_f32_e32 v3, v4, v3
-; GFX11TRUE16-NEXT: v_div_scale_f32 v5, vcc_lo, v0, v1, v0
-; GFX11TRUE16-NEXT: v_mul_f32_e32 v4, v5, v3
+; GFX11TRUE16-NEXT: v_mul_f32_e32 v5, v4, v3
+; GFX11TRUE16-NEXT: v_fma_f32 v6, -v2, v5, v4
; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT: v_fma_f32 v6, -v2, v4, v5
-; GFX11TRUE16-NEXT: v_fmac_f32_e32 v4, v6, v3
+; GFX11TRUE16-NEXT: v_fmac_f32_e32 v5, v6, v3
+; GFX11TRUE16-NEXT: v_fma_f32 v2, -v2, v5, v4
; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11TRUE16-NEXT: v_fma_f32 v2, -v2, v4, v5
-; GFX11TRUE16-NEXT: v_div_fmas_f32 v2, v2, v3, v4
+; GFX11TRUE16-NEXT: v_mul_f32_e32 v2, v2, v3
+; GFX11TRUE16-NEXT: v_and_b32_e32 v2, 0xff800000, v2
; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11TRUE16-NEXT: v_add_f32_e32 v2, v2, v5
; GFX11TRUE16-NEXT: v_div_fixup_f32 v0, v2, v1, v0
+; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_3)
; GFX11TRUE16-NEXT: v_bfe_u32 v1, v0, 16, 1
; GFX11TRUE16-NEXT: v_or_b32_e32 v2, 0x400000, v0
; GFX11TRUE16-NEXT: v_cmp_u_f32_e32 vcc_lo, v0, v0
-; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX11TRUE16-NEXT: v_add3_u32 v1, v1, v0, 0x7fff
+; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX11TRUE16-NEXT: v_cndmask_b32_e32 v0, v1, v2, vcc_lo
-; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1)
; GFX11TRUE16-NEXT: v_mov_b16_e32 v0.l, v0.h
; GFX11TRUE16-NEXT: s_setpc_b64 s[30:31]
;
; GFX11FAKE16-LABEL: v_fdiv_bf16:
; GFX11FAKE16: ; %bb.0:
; GFX11FAKE16-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v0, 16, v0
-; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v1, 16, v1
-; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT: v_div_scale_f32 v2, null, v1, v1, v0
+; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v4, 16, v0
+; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v2, 16, v1
+; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_1)
; GFX11FAKE16-NEXT: v_rcp_f32_e32 v3, v2
; GFX11FAKE16-NEXT: s_waitcnt_depctr 0xfff
-; GFX11FAKE16-NEXT: v_fma_f32 v4, -v2, v3, 1.0
-; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT: v_fmac_f32_e32 v3, v4, v3
-; GFX11FAKE16-NEXT: v_div_scale_f32 v5, vcc_lo, v0, v1, v0
-; GFX11FAKE16-NEXT: v_mul_f32_e32 v4, v5, v3
+; GFX11FAKE16-NEXT: v_mul_f32_e32 v5, v4, v3
+; GFX11FAKE16-NEXT: v_fma_f32 v6, -v2, v5, v4
; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT: v_fma_f32 v6, -v2, v4, v5
-; GFX11FAKE16-NEXT: v_fmac_f32_e32 v4, v6, v3
+; GFX11FAKE16-NEXT: v_fmac_f32_e32 v5, v6, v3
+; GFX11FAKE16-NEXT: v_fma_f32 v2, -v2, v5, v4
; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11FAKE16-NEXT: v_fma_f32 v2, -v2, v4, v5
-; GFX11FAKE16-NEXT: v_div_fmas_f32 v2, v2, v3, v4
+; GFX11FAKE16-NEXT: v_mul_f32_e32 v2, v2, v3
+; GFX11FAKE16-NEXT: v_and_b32_e32 v2, 0xff800000, v2
; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11FAKE16-NEXT: v_add_f32_e32 v2, v2, v5
; GFX11FAKE16-NEXT: v_div_fixup_f32 v0, v2, v1, v0
+; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_3)
; GFX11FAKE16-NEXT: v_bfe_u32 v1, v0, 16, 1
; GFX11FAKE16-NEXT: v_or_b32_e32 v2, 0x400000, v0
; GFX11FAKE16-NEXT: v_cmp_u_f32_e32 vcc_lo, v0, v0
-; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX11FAKE16-NEXT: v_add3_u32 v1, v1, v0, 0x7fff
+; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX11FAKE16-NEXT: v_cndmask_b32_e32 v0, v1, v2, vcc_lo
-; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1)
; GFX11FAKE16-NEXT: v_lshrrev_b32_e32 v0, 16, v0
; GFX11FAKE16-NEXT: s_setpc_b64 s[30:31]
%op = fdiv bfloat %a, %b
More information about the llvm-commits mailing list