[llvm] b170f17 - [AMDGPU] Add support for safe bfloat16 fdiv on targets with bf16 trans instructions (#154373)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 19 13:03:49 PDT 2025


Author: Shilei Tian
Date: 2025-08-19T16:03:45-04:00
New Revision: b170f17861a12e91e67e9fe5951e2118cd3db164

URL: https://github.com/llvm/llvm-project/commit/b170f17861a12e91e67e9fe5951e2118cd3db164
DIFF: https://github.com/llvm/llvm-project/commit/b170f17861a12e91e67e9fe5951e2118cd3db164.diff

LOG: [AMDGPU] Add support for safe bfloat16 fdiv on targets with bf16 trans instructions (#154373)

Recent changes introduced custom lowering for bf16 fdiv on targets that
support bf16 trans instructions, but only covered the unsafe version.
This PR extends that support to the safe variant.

For the safe variant, the bf16 fdiv is lowered by extending both operands
to f32, performing the division in f32, and rounding the result back to
bf16. This matches the lowering used on targets that don't support bf16
trans instructions.

Fixes SWDEV-550381.

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/SIISelLowering.cpp
    llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index a2084074263da..561019bb65549 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -11540,9 +11540,22 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
     return FastLowered;
 
   SDLoc SL(Op);
+  EVT VT = Op.getValueType();
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
 
+  SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
+  SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
+
+  if (VT == MVT::bf16) {
+    SDValue ExtDiv =
+        DAG.getNode(ISD::FDIV, SL, MVT::f32, LHSExt, RHSExt, Op->getFlags());
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ExtDiv,
+                       DAG.getTargetConstant(0, SL, MVT::i32));
+  }
+
+  assert(VT == MVT::f16);
+
   // a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32
   // b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32
   // r32.u = opx(V_RCP_F32, b32.u); // rcp = 1 / d
@@ -11559,9 +11572,6 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
   // We will use ISD::FMA on targets that don't support ISD::FMAD.
   unsigned FMADOpCode =
       isOperationLegal(ISD::FMAD, MVT::f32) ? ISD::FMAD : ISD::FMA;
-
-  SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
-  SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
   SDValue NegRHSExt = DAG.getNode(ISD::FNEG, SL, MVT::f32, RHSExt);
   SDValue Rcp =
       DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, RHSExt, Op->getFlags());

diff  --git a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
index 91831a8d4fecb..00cde422a2297 100644
--- a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
@@ -2,12 +2,68 @@
 ; RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -mattr=+real-true16 -denormal-fp-math-f32=preserve-sign < %s | FileCheck -check-prefixes=GFX1250-TRUE16 %s
 ; RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -mattr=-real-true16 -denormal-fp-math-f32=preserve-sign < %s | FileCheck -check-prefixes=GFX1250-FAKE16 %s
 
-/* TODO: Support safe bf16 fdiv lowering.
 define bfloat @v_fdiv_bf16(bfloat %x, bfloat %y) {
+; GFX1250-TRUE16-LABEL: v_fdiv_bf16:
+; GFX1250-TRUE16:       ; %bb.0:
+; GFX1250-TRUE16-NEXT:    s_wait_loadcnt_dscnt 0x0
+; GFX1250-TRUE16-NEXT:    s_wait_kmcnt 0x0
+; GFX1250-TRUE16-NEXT:    v_mov_b16_e32 v2.l, 0
+; GFX1250-TRUE16-NEXT:    v_mov_b16_e32 v2.h, v1.l
+; GFX1250-TRUE16-NEXT:    v_mov_b16_e32 v1.h, v0.l
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_mov_b16_e32 v1.l, v2.l
+; GFX1250-TRUE16-NEXT:    v_div_scale_f32 v0, null, v2, v2, v1
+; GFX1250-TRUE16-NEXT:    v_div_scale_f32 v4, vcc_lo, v1, v2, v1
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(SKIP_2) | instid1(TRANS32_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_rcp_f32_e32 v3, v0
+; GFX1250-TRUE16-NEXT:    s_denorm_mode 15
+; GFX1250-TRUE16-NEXT:    v_nop
+; GFX1250-TRUE16-NEXT:    v_fma_f32 v5, -v0, v3, 1.0
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_fmac_f32_e32 v3, v5, v3
+; GFX1250-TRUE16-NEXT:    v_mul_f32_e32 v5, v4, v3
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_fma_f32 v6, -v0, v5, v4
+; GFX1250-TRUE16-NEXT:    v_fmac_f32_e32 v5, v6, v3
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_fma_f32 v0, -v0, v5, v4
+; GFX1250-TRUE16-NEXT:    s_denorm_mode 12
+; GFX1250-TRUE16-NEXT:    v_div_fmas_f32 v0, v0, v3, v5
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_div_fixup_f32 v0, v0, v2, v1
+; GFX1250-TRUE16-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
+; GFX1250-TRUE16-NEXT:    s_set_pc_i64 s[30:31]
+;
+; GFX1250-FAKE16-LABEL: v_fdiv_bf16:
+; GFX1250-FAKE16:       ; %bb.0:
+; GFX1250-FAKE16-NEXT:    s_wait_loadcnt_dscnt 0x0
+; GFX1250-FAKE16-NEXT:    s_wait_kmcnt 0x0
+; GFX1250-FAKE16-NEXT:    v_dual_lshlrev_b32 v1, 16, v1 :: v_dual_lshlrev_b32 v0, 16, v0
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_2)
+; GFX1250-FAKE16-NEXT:    v_div_scale_f32 v2, null, v1, v1, v0
+; GFX1250-FAKE16-NEXT:    v_div_scale_f32 v4, vcc_lo, v0, v1, v0
+; GFX1250-FAKE16-NEXT:    v_rcp_f32_e32 v3, v2
+; GFX1250-FAKE16-NEXT:    s_denorm_mode 15
+; GFX1250-FAKE16-NEXT:    v_nop
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT:    v_fma_f32 v5, -v2, v3, 1.0
+; GFX1250-FAKE16-NEXT:    v_fmac_f32_e32 v3, v5, v3
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT:    v_mul_f32_e32 v5, v4, v3
+; GFX1250-FAKE16-NEXT:    v_fma_f32 v6, -v2, v5, v4
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT:    v_fmac_f32_e32 v5, v6, v3
+; GFX1250-FAKE16-NEXT:    v_fma_f32 v2, -v2, v5, v4
+; GFX1250-FAKE16-NEXT:    s_denorm_mode 12
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT:    v_div_fmas_f32 v2, v2, v3, v5
+; GFX1250-FAKE16-NEXT:    v_div_fixup_f32 v0, v2, v1, v0
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
+; GFX1250-FAKE16-NEXT:    s_set_pc_i64 s[30:31]
   %fdiv = fdiv bfloat %x, %y
   ret bfloat %fdiv
 }
-*/
 
 define bfloat @v_rcp_bf16(bfloat %x) {
 ; GFX1250-TRUE16-LABEL: v_rcp_bf16:


        


More information about the llvm-commits mailing list