[llvm] c328c5d - [AMDGPU] Combine to bf16 reciprocal square root. (#154185)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 18 13:07:24 PDT 2025


Author: Stanislav Mekhanoshin
Date: 2025-08-18T13:07:20-07:00
New Revision: c328c5d9117c19555793c548ebccfedc0b972398

URL: https://github.com/llvm/llvm-project/commit/c328c5d9117c19555793c548ebccfedc0b972398
DIFF: https://github.com/llvm/llvm-project/commit/c328c5d9117c19555793c548ebccfedc0b972398.diff

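LOG: [AMDGPU] Combine to bf16 reciprocal square root. (#154185)

Extend performFDivCombine to accept bf16 as well as f16. As a result, `fdiv contract 1.0, (sqrt contract x)` on bfloat now selects a single v_rsq_bf16 instead of v_sqrt_bf16 followed by v_rcp_bf16. The -1.0 numerator form becomes v_rsq_bf16 followed by a sign-bit flip (v_xor with 0x8000).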

Co-authored-by: Ivan Kosarev <Ivan.Kosarev at amd.com>

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/SIISelLowering.cpp
    llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index f58fde421f77d..072fb9cc547b0 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -15729,7 +15729,7 @@ SDValue SITargetLowering::performFDivCombine(SDNode *N,
   SelectionDAG &DAG = DCI.DAG;
   SDLoc SL(N);
   EVT VT = N->getValueType(0);
-  if (VT != MVT::f16 || !Subtarget->has16BitInsts())
+  if ((VT != MVT::f16 && VT != MVT::bf16) || !Subtarget->has16BitInsts())
     return SDValue();
 
   SDValue LHS = N->getOperand(0);

diff  --git a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
index 01ebe7d71428b..91831a8d4fecb 100644
--- a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
@@ -82,67 +82,59 @@ define bfloat @v_rcp_bf16_neg(bfloat %x) {
   ret bfloat %fdiv
 }
 
-; TODO: Support lowering to v_rsq_bf16.
 define bfloat @v_rsq_bf16(bfloat %x) {
 ; GFX1250-TRUE16-LABEL: v_rsq_bf16:
 ; GFX1250-TRUE16:       ; %bb.0:
 ; GFX1250-TRUE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-TRUE16-NEXT:    s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT:    v_sqrt_bf16_e32 v0.l, v0.l
-; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-TRUE16-NEXT:    v_rcp_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT:    v_rsq_bf16_e32 v0.l, v0.l
 ; GFX1250-TRUE16-NEXT:    s_set_pc_i64 s[30:31]
 ;
 ; GFX1250-FAKE16-LABEL: v_rsq_bf16:
 ; GFX1250-FAKE16:       ; %bb.0:
 ; GFX1250-FAKE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-FAKE16-NEXT:    s_wait_kmcnt 0x0
-; GFX1250-FAKE16-NEXT:    v_sqrt_bf16_e32 v0, v0
-; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT:    v_rcp_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT:    v_rsq_bf16_e32 v0, v0
 ; GFX1250-FAKE16-NEXT:    s_set_pc_i64 s[30:31]
   %sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
   %fdiv = fdiv contract bfloat 1.0, %sqrt
   ret bfloat %fdiv
 }
 
-; TODO: Support lowering to v_rsq_bf16.
 define bfloat @v_rsq_bf16_neg(bfloat %x) {
 ; GFX1250-TRUE16-LABEL: v_rsq_bf16_neg:
 ; GFX1250-TRUE16:       ; %bb.0:
 ; GFX1250-TRUE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-TRUE16-NEXT:    s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT:    v_sqrt_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT:    v_rsq_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT:    v_nop
 ; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-TRUE16-NEXT:    v_rcp_bf16_e64 v0.l, -v0.l
+; GFX1250-TRUE16-NEXT:    v_xor_b16 v0.l, 0x8000, v0.l
 ; GFX1250-TRUE16-NEXT:    s_set_pc_i64 s[30:31]
 ;
 ; GFX1250-FAKE16-LABEL: v_rsq_bf16_neg:
 ; GFX1250-FAKE16:       ; %bb.0:
 ; GFX1250-FAKE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-FAKE16-NEXT:    s_wait_kmcnt 0x0
-; GFX1250-FAKE16-NEXT:    v_sqrt_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT:    v_rsq_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT:    v_nop
 ; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT:    v_rcp_bf16_e64 v0, -v0
+; GFX1250-FAKE16-NEXT:    v_xor_b32_e32 v0, 0x8000, v0
 ; GFX1250-FAKE16-NEXT:    s_set_pc_i64 s[30:31]
   %sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
   %fdiv = fdiv contract bfloat -1.0, %sqrt
   ret bfloat %fdiv
 }
 
-; TODO: Support lowering to v_rsq_bf16.
 define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
 ; GFX1250-TRUE16-LABEL: v_rsq_bf16_multi_use:
 ; GFX1250-TRUE16:       ; %bb.0:
 ; GFX1250-TRUE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-TRUE16-NEXT:    s_wait_kmcnt 0x0
 ; GFX1250-TRUE16-NEXT:    v_mov_b16_e32 v1.l, v0.l
-; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_1)
-; GFX1250-TRUE16-NEXT:    v_sqrt_bf16_e32 v1.l, v1.l
-; GFX1250-TRUE16-NEXT:    v_rcp_bf16_e32 v1.h, v1.l
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_rsq_bf16_e32 v1.h, v1.l
 ; GFX1250-TRUE16-NEXT:    v_nop
-; GFX1250-TRUE16-NEXT:    v_mov_b16_e32 v1.l, v0.l
-; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1) | instid1(VALU_DEP_1)
 ; GFX1250-TRUE16-NEXT:    v_mov_b32_e32 v0, v1
 ; GFX1250-TRUE16-NEXT:    s_set_pc_i64 s[30:31]
 ;
@@ -150,10 +142,9 @@ define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
 ; GFX1250-FAKE16:       ; %bb.0:
 ; GFX1250-FAKE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-FAKE16-NEXT:    s_wait_kmcnt 0x0
-; GFX1250-FAKE16-NEXT:    v_sqrt_bf16_e32 v1, v0
-; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT:    v_rcp_bf16_e32 v1, v1
+; GFX1250-FAKE16-NEXT:    v_rsq_bf16_e32 v1, v0
 ; GFX1250-FAKE16-NEXT:    v_nop
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1)
 ; GFX1250-FAKE16-NEXT:    v_perm_b32 v0, v1, v0, 0x5040100
 ; GFX1250-FAKE16-NEXT:    s_set_pc_i64 s[30:31]
   %sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
@@ -163,7 +154,6 @@ define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
   ret <2 x bfloat> %r2
 }
 
-; TODO: Support lowering to v_rsq_bf16.
 define bfloat @v_rsq_bf16_missing_contract0(bfloat %x) {
 ; GFX1250-TRUE16-LABEL: v_rsq_bf16_missing_contract0:
 ; GFX1250-TRUE16:       ; %bb.0:
@@ -187,7 +177,6 @@ define bfloat @v_rsq_bf16_missing_contract0(bfloat %x) {
   ret bfloat %fdiv
 }
 
-; TODO: Support lowering to v_rsq_bf16.
 define bfloat @v_rsq_bf16_missing_contract1(bfloat %x) {
 ; GFX1250-TRUE16-LABEL: v_rsq_bf16_missing_contract1:
 ; GFX1250-TRUE16:       ; %bb.0:
@@ -211,7 +200,6 @@ define bfloat @v_rsq_bf16_missing_contract1(bfloat %x) {
   ret bfloat %fdiv
 }
 
-; TODO: Support lowering to v_rsq_bf16.
 define bfloat @v_neg_rsq_bf16_missing_contract1(bfloat %x) {
 ; GFX1250-TRUE16-LABEL: v_neg_rsq_bf16_missing_contract1:
 ; GFX1250-TRUE16:       ; %bb.0:
@@ -240,11 +228,8 @@ define <2 x bfloat> @v_rsq_v2bf16(<2 x bfloat> %a) {
 ; GFX1250-TRUE16:       ; %bb.0:
 ; GFX1250-TRUE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-TRUE16-NEXT:    s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT:    v_sqrt_bf16_e32 v0.h, v0.h
-; GFX1250-TRUE16-NEXT:    v_sqrt_bf16_e32 v0.l, v0.l
-; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-TRUE16-NEXT:    v_rcp_bf16_e32 v0.h, v0.h
-; GFX1250-TRUE16-NEXT:    v_rcp_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT:    v_rsq_bf16_e32 v0.h, v0.h
+; GFX1250-TRUE16-NEXT:    v_rsq_bf16_e32 v0.l, v0.l
 ; GFX1250-TRUE16-NEXT:    s_set_pc_i64 s[30:31]
 ;
 ; GFX1250-FAKE16-LABEL: v_rsq_v2bf16:
@@ -252,12 +237,9 @@ define <2 x bfloat> @v_rsq_v2bf16(<2 x bfloat> %a) {
 ; GFX1250-FAKE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-FAKE16-NEXT:    s_wait_kmcnt 0x0
 ; GFX1250-FAKE16-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; GFX1250-FAKE16-NEXT:    v_sqrt_bf16_e32 v0, v0
-; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-FAKE16-NEXT:    v_sqrt_bf16_e32 v1, v1
-; GFX1250-FAKE16-NEXT:    v_rcp_bf16_e32 v0, v0
-; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_2) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT:    v_rcp_bf16_e32 v1, v1
+; GFX1250-FAKE16-NEXT:    v_rsq_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
+; GFX1250-FAKE16-NEXT:    v_rsq_bf16_e32 v1, v1
 ; GFX1250-FAKE16-NEXT:    v_nop
 ; GFX1250-FAKE16-NEXT:    v_perm_b32 v0, v1, v0, 0x5040100
 ; GFX1250-FAKE16-NEXT:    s_set_pc_i64 s[30:31]
@@ -271,11 +253,11 @@ define <2 x bfloat> @v_neg_rsq_v2bf16(<2 x bfloat> %a) {
 ; GFX1250-TRUE16:       ; %bb.0:
 ; GFX1250-TRUE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-TRUE16-NEXT:    s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT:    v_sqrt_bf16_e32 v0.h, v0.h
-; GFX1250-TRUE16-NEXT:    v_sqrt_bf16_e32 v0.l, v0.l
-; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-TRUE16-NEXT:    v_rcp_bf16_e64 v0.h, -v0.h
-; GFX1250-TRUE16-NEXT:    v_rcp_bf16_e64 v0.l, -v0.l
+; GFX1250-TRUE16-NEXT:    v_rsq_bf16_e32 v0.h, v0.h
+; GFX1250-TRUE16-NEXT:    v_rsq_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_1)
+; GFX1250-TRUE16-NEXT:    v_xor_b16 v0.h, 0x8000, v0.h
+; GFX1250-TRUE16-NEXT:    v_xor_b16 v0.l, 0x8000, v0.l
 ; GFX1250-TRUE16-NEXT:    s_set_pc_i64 s[30:31]
 ;
 ; GFX1250-FAKE16-LABEL: v_neg_rsq_v2bf16:
@@ -283,13 +265,12 @@ define <2 x bfloat> @v_neg_rsq_v2bf16(<2 x bfloat> %a) {
 ; GFX1250-FAKE16-NEXT:    s_wait_loadcnt_dscnt 0x0
 ; GFX1250-FAKE16-NEXT:    s_wait_kmcnt 0x0
 ; GFX1250-FAKE16-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; GFX1250-FAKE16-NEXT:    v_sqrt_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT:    v_rsq_bf16_e32 v0, v0
 ; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-FAKE16-NEXT:    v_sqrt_bf16_e32 v1, v1
-; GFX1250-FAKE16-NEXT:    v_rcp_bf16_e64 v0, -v0
-; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_2) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT:    v_rcp_bf16_e64 v1, -v1
-; GFX1250-FAKE16-NEXT:    v_nop
+; GFX1250-FAKE16-NEXT:    v_rsq_bf16_e32 v1, v1
+; GFX1250-FAKE16-NEXT:    v_xor_b32_e32 v0, 0x8000, v0
+; GFX1250-FAKE16-NEXT:    s_delay_alu instid0(TRANS32_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT:    v_xor_b32_e32 v1, 0x8000, v1
 ; GFX1250-FAKE16-NEXT:    v_perm_b32 v0, v1, v0, 0x5040100
 ; GFX1250-FAKE16-NEXT:    s_set_pc_i64 s[30:31]
   %sqrt = call contract <2 x bfloat> @llvm.sqrt.v2bf16(<2 x bfloat> %a)


        


More information about the llvm-commits mailing list