[llvm] c328c5d - [AMDGPU] Combine to bf16 reciprocal square root. (#154185)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 18 13:07:24 PDT 2025
Author: Stanislav Mekhanoshin
Date: 2025-08-18T13:07:20-07:00
New Revision: c328c5d9117c19555793c548ebccfedc0b972398
URL: https://github.com/llvm/llvm-project/commit/c328c5d9117c19555793c548ebccfedc0b972398
DIFF: https://github.com/llvm/llvm-project/commit/c328c5d9117c19555793c548ebccfedc0b972398.diff
LOG: [AMDGPU] Combine to bf16 reciprocal square root. (#154185)
Co-authored-by: Ivan Kosarev <Ivan.Kosarev at amd.com>
Co-authored-by: Ivan Kosarev <Ivan.Kosarev at amd.com>
Added:
Modified:
llvm/lib/Target/AMDGPU/SIISelLowering.cpp
llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index f58fde421f77d..072fb9cc547b0 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -15729,7 +15729,7 @@ SDValue SITargetLowering::performFDivCombine(SDNode *N,
SelectionDAG &DAG = DCI.DAG;
SDLoc SL(N);
EVT VT = N->getValueType(0);
- if (VT != MVT::f16 || !Subtarget->has16BitInsts())
+ if ((VT != MVT::f16 && VT != MVT::bf16) || !Subtarget->has16BitInsts())
return SDValue();
SDValue LHS = N->getOperand(0);
diff --git a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
index 01ebe7d71428b..91831a8d4fecb 100644
--- a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
@@ -82,67 +82,59 @@ define bfloat @v_rcp_bf16_neg(bfloat %x) {
ret bfloat %fdiv
}
-; TODO: Support lowering to v_rsq_bf16.
define bfloat @v_rsq_bf16(bfloat %x) {
; GFX1250-TRUE16-LABEL: v_rsq_bf16:
; GFX1250-TRUE16: ; %bb.0:
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
-; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
;
; GFX1250-FAKE16-LABEL: v_rsq_bf16:
; GFX1250-FAKE16: ; %bb.0:
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
-; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
-; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
%sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
%fdiv = fdiv contract bfloat 1.0, %sqrt
ret bfloat %fdiv
}
-; TODO: Support lowering to v_rsq_bf16.
define bfloat @v_rsq_bf16_neg(bfloat %x) {
; GFX1250-TRUE16-LABEL: v_rsq_bf16_neg:
; GFX1250-TRUE16: ; %bb.0:
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT: v_nop
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-TRUE16-NEXT: v_rcp_bf16_e64 v0.l, -v0.l
+; GFX1250-TRUE16-NEXT: v_xor_b16 v0.l, 0x8000, v0.l
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
;
; GFX1250-FAKE16-LABEL: v_rsq_bf16_neg:
; GFX1250-FAKE16: ; %bb.0:
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
-; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT: v_nop
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT: v_rcp_bf16_e64 v0, -v0
+; GFX1250-FAKE16-NEXT: v_xor_b32_e32 v0, 0x8000, v0
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
%sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
%fdiv = fdiv contract bfloat -1.0, %sqrt
ret bfloat %fdiv
}
-; TODO: Support lowering to v_rsq_bf16.
define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
; GFX1250-TRUE16-LABEL: v_rsq_bf16_multi_use:
; GFX1250-TRUE16: ; %bb.0:
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.l, v0.l
-; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_1)
-; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v1.l, v1.l
-; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v1.h, v1.l
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
+; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v1.h, v1.l
; GFX1250-TRUE16-NEXT: v_nop
-; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.l, v0.l
-; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instid1(VALU_DEP_1)
; GFX1250-TRUE16-NEXT: v_mov_b32_e32 v0, v1
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
;
@@ -150,10 +142,9 @@ define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
; GFX1250-FAKE16: ; %bb.0:
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
-; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v1, v0
-; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v1, v1
+; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v1, v0
; GFX1250-FAKE16-NEXT: v_nop
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
; GFX1250-FAKE16-NEXT: v_perm_b32 v0, v1, v0, 0x5040100
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
%sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
@@ -163,7 +154,6 @@ define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
ret <2 x bfloat> %r2
}
-; TODO: Support lowering to v_rsq_bf16.
define bfloat @v_rsq_bf16_missing_contract0(bfloat %x) {
; GFX1250-TRUE16-LABEL: v_rsq_bf16_missing_contract0:
; GFX1250-TRUE16: ; %bb.0:
@@ -187,7 +177,6 @@ define bfloat @v_rsq_bf16_missing_contract0(bfloat %x) {
ret bfloat %fdiv
}
-; TODO: Support lowering to v_rsq_bf16.
define bfloat @v_rsq_bf16_missing_contract1(bfloat %x) {
; GFX1250-TRUE16-LABEL: v_rsq_bf16_missing_contract1:
; GFX1250-TRUE16: ; %bb.0:
@@ -211,7 +200,6 @@ define bfloat @v_rsq_bf16_missing_contract1(bfloat %x) {
ret bfloat %fdiv
}
-; TODO: Support lowering to v_rsq_bf16.
define bfloat @v_neg_rsq_bf16_missing_contract1(bfloat %x) {
; GFX1250-TRUE16-LABEL: v_neg_rsq_bf16_missing_contract1:
; GFX1250-TRUE16: ; %bb.0:
@@ -240,11 +228,8 @@ define <2 x bfloat> @v_rsq_v2bf16(<2 x bfloat> %a) {
; GFX1250-TRUE16: ; %bb.0:
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.h, v0.h
-; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
-; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v0.h, v0.h
-; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.h, v0.h
+; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
;
; GFX1250-FAKE16-LABEL: v_rsq_v2bf16:
@@ -252,12 +237,9 @@ define <2 x bfloat> @v_rsq_v2bf16(<2 x bfloat> %a) {
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
; GFX1250-FAKE16-NEXT: v_lshrrev_b32_e32 v1, 16, v0
-; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
-; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v1, v1
-; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v0, v0
-; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v1, v1
+; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
+; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v1, v1
; GFX1250-FAKE16-NEXT: v_nop
; GFX1250-FAKE16-NEXT: v_perm_b32 v0, v1, v0, 0x5040100
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
@@ -271,11 +253,11 @@ define <2 x bfloat> @v_neg_rsq_v2bf16(<2 x bfloat> %a) {
; GFX1250-TRUE16: ; %bb.0:
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
-; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.h, v0.h
-; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
-; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-TRUE16-NEXT: v_rcp_bf16_e64 v0.h, -v0.h
-; GFX1250-TRUE16-NEXT: v_rcp_bf16_e64 v0.l, -v0.l
+; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.h, v0.h
+; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_1)
+; GFX1250-TRUE16-NEXT: v_xor_b16 v0.h, 0x8000, v0.h
+; GFX1250-TRUE16-NEXT: v_xor_b16 v0.l, 0x8000, v0.l
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
;
; GFX1250-FAKE16-LABEL: v_neg_rsq_v2bf16:
@@ -283,13 +265,12 @@ define <2 x bfloat> @v_neg_rsq_v2bf16(<2 x bfloat> %a) {
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
; GFX1250-FAKE16-NEXT: v_lshrrev_b32_e32 v1, 16, v0
-; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
+; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_2)
-; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v1, v1
-; GFX1250-FAKE16-NEXT: v_rcp_bf16_e64 v0, -v0
-; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
-; GFX1250-FAKE16-NEXT: v_rcp_bf16_e64 v1, -v1
-; GFX1250-FAKE16-NEXT: v_nop
+; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v1, v1
+; GFX1250-FAKE16-NEXT: v_xor_b32_e32 v0, 0x8000, v0
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT: v_xor_b32_e32 v1, 0x8000, v1
; GFX1250-FAKE16-NEXT: v_perm_b32 v0, v1, v0, 0x5040100
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
%sqrt = call contract <2 x bfloat> @llvm.sqrt.v2bf16(<2 x bfloat> %a)
More information about the llvm-commits mailing list