[llvm] AMDGPU: Custom lower fptrunc vectors for f32 -> f16 (PR #141883)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed May 28 22:47:11 PDT 2025


================
@@ -12,18 +12,124 @@ define <2 x half> @v_test_cvt_v2f32_v2f16(<2 x float> %src) {
   ret <2 x half> %res
 }
 
-define half @fptrunc_v2f32_v2f16_then_extract(<2 x float> %src) {
-; GFX950-LABEL: fptrunc_v2f32_v2f16_then_extract:
+define <4 x half> @v_test_cvt_v4f32_v4f16(<4 x float> %src) {
+; GFX950-LABEL: v_test_cvt_v4f32_v4f16:
 ; GFX950:       ; %bb.0:
 ; GFX950-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
 ; GFX950-NEXT:    v_cvt_pk_f16_f32 v0, v0, v1
-; GFX950-NEXT:    v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
+; GFX950-NEXT:    v_cvt_pk_f16_f32 v1, v2, v3
+; GFX950-NEXT:    s_setpc_b64 s[30:31]
+  %res = fptrunc <4 x float> %src to <4 x half>  ; checks expect one v_cvt_pk_f16_f32 per pair of input lanes
+  ret <4 x half> %res
+}
+
+define <8 x half> @v_test_cvt_v8f32_v8f16(<8 x float> %src) {
+; GFX950-LABEL: v_test_cvt_v8f32_v8f16:
+; GFX950:       ; %bb.0:
+; GFX950-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-NEXT:    v_cvt_pk_f16_f32 v0, v0, v1
+; GFX950-NEXT:    v_cvt_pk_f16_f32 v1, v2, v3
+; GFX950-NEXT:    v_cvt_pk_f16_f32 v2, v4, v5
+; GFX950-NEXT:    v_cvt_pk_f16_f32 v3, v6, v7
+; GFX950-NEXT:    s_setpc_b64 s[30:31]
+  %res = fptrunc <8 x float> %src to <8 x half>  ; checks expect four v_cvt_pk_f16_f32, one per pair of input lanes
+  ret <8 x half> %res
+}
+
+define half @fptrunc_v2f32_v2f16_extract_uses(<2 x float> %src) {
+; GFX950-LABEL: fptrunc_v2f32_v2f16_extract_uses:
+; GFX950:       ; %bb.0:
+; GFX950-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-NEXT:    v_cvt_pk_f16_f32 v0, v0, v1
+; GFX950-NEXT:    v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
 ; GFX950-NEXT:    s_setpc_b64 s[30:31]
   %vec_half = fptrunc <2 x float> %src to <2 x half>
-  %first = extractelement <2 x half> %vec_half, i64 1
-  %second = extractelement <2 x half> %vec_half, i64 0
-  %res = fadd half %first, %second
-  ret half %res
+  %f0 = extractelement <2 x half> %vec_half, i64 0
+  %f1 = extractelement <2 x half> %vec_half, i64 1
+  %rslt = fadd half %f0, %f1  ; checks expect both operands read from the one packed result (sdwa WORD_1 selects lane 1)
+  ret half %rslt
+}
+
+define half @fptrunc_v4f32_v4f16_extract_uses(<4 x float> %vec_float) {
+; GFX950-SDAG-LABEL: fptrunc_v4f32_v4f16_extract_uses:
+; GFX950-SDAG:       ; %bb.0:
+; GFX950-SDAG-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-SDAG-NEXT:    v_cvt_pk_f16_f32 v2, v2, v3
+; GFX950-SDAG-NEXT:    v_cvt_pk_f16_f32 v0, v0, v1
+; GFX950-SDAG-NEXT:    v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-SDAG-NEXT:    v_add_f16_sdwa v1, v2, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-SDAG-NEXT:    v_add_f16_e32 v0, v0, v1
+; GFX950-SDAG-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX950-GISEL-LABEL: fptrunc_v4f32_v4f16_extract_uses:
+; GFX950-GISEL:       ; %bb.0:
+; GFX950-GISEL-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-GISEL-NEXT:    v_cvt_pk_f16_f32 v0, v0, v1
+; GFX950-GISEL-NEXT:    v_cvt_pk_f16_f32 v1, v2, v3
+; GFX950-GISEL-NEXT:    v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-GISEL-NEXT:    v_add_f16_sdwa v1, v1, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-GISEL-NEXT:    v_add_f16_e32 v0, v0, v1
+; GFX950-GISEL-NEXT:    s_setpc_b64 s[30:31]
+  %trunc = fptrunc <4 x float> %vec_float to <4 x half>  ; sum every truncated lane as a pairwise tree
+  %lane0 = extractelement <4 x half> %trunc, i64 0
+  %lane1 = extractelement <4 x half> %trunc, i64 1
+  %lane2 = extractelement <4 x half> %trunc, i64 2
+  %lane3 = extractelement <4 x half> %trunc, i64 3
+  %pair01 = fadd half %lane0, %lane1
+  %pair23 = fadd half %lane2, %lane3
+  %total = fadd half %pair01, %pair23
+  ret half %total
+}
+
+define half @fptrunc_v8f32_v8f16_extract_uses(<8 x float> %vec_float) {
+; GFX950-SDAG-LABEL: fptrunc_v8f32_v8f16_extract_uses:
+; GFX950-SDAG:       ; %bb.0:
+; GFX950-SDAG-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-SDAG-NEXT:    v_cvt_pk_f16_f32 v6, v6, v7
+; GFX950-SDAG-NEXT:    v_cvt_pk_f16_f32 v4, v4, v5
+; GFX950-SDAG-NEXT:    v_cvt_pk_f16_f32 v2, v2, v3
+; GFX950-SDAG-NEXT:    v_cvt_pk_f16_f32 v0, v0, v1
+; GFX950-SDAG-NEXT:    v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-SDAG-NEXT:    v_add_f16_sdwa v1, v2, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-SDAG-NEXT:    v_add_f16_sdwa v2, v4, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-SDAG-NEXT:    v_add_f16_sdwa v3, v6, v6 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-SDAG-NEXT:    v_add_f16_e32 v0, v0, v1
+; GFX950-SDAG-NEXT:    v_add_f16_e32 v1, v2, v3
+; GFX950-SDAG-NEXT:    v_add_f16_e32 v0, v0, v1
+; GFX950-SDAG-NEXT:    s_setpc_b64 s[30:31]
+;
+; GFX950-GISEL-LABEL: fptrunc_v8f32_v8f16_extract_uses:
+; GFX950-GISEL:       ; %bb.0:
+; GFX950-GISEL-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-GISEL-NEXT:    v_cvt_pk_f16_f32 v0, v0, v1
+; GFX950-GISEL-NEXT:    v_cvt_pk_f16_f32 v1, v2, v3
+; GFX950-GISEL-NEXT:    v_cvt_pk_f16_f32 v2, v4, v5
+; GFX950-GISEL-NEXT:    v_cvt_pk_f16_f32 v3, v6, v7
+; GFX950-GISEL-NEXT:    v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-GISEL-NEXT:    v_add_f16_sdwa v1, v1, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-GISEL-NEXT:    v_add_f16_sdwa v2, v2, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-GISEL-NEXT:    v_add_f16_sdwa v3, v3, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
+; GFX950-GISEL-NEXT:    v_add_f16_e32 v0, v0, v1
+; GFX950-GISEL-NEXT:    v_add_f16_e32 v1, v2, v3
+; GFX950-GISEL-NEXT:    v_add_f16_e32 v0, v0, v1
+; GFX950-GISEL-NEXT:    s_setpc_b64 s[30:31]
+  %trunc = fptrunc <8 x float> %vec_float to <8 x half>  ; sum every truncated lane as a pairwise tree
+  %lane0 = extractelement <8 x half> %trunc, i64 0
+  %lane1 = extractelement <8 x half> %trunc, i64 1
+  %lane2 = extractelement <8 x half> %trunc, i64 2
+  %lane3 = extractelement <8 x half> %trunc, i64 3
+  %lane4 = extractelement <8 x half> %trunc, i64 4
+  %lane5 = extractelement <8 x half> %trunc, i64 5
+  %lane6 = extractelement <8 x half> %trunc, i64 6
+  %lane7 = extractelement <8 x half> %trunc, i64 7
+  %pair01 = fadd half %lane0, %lane1
+  %pair23 = fadd half %lane2, %lane3
+  %pair45 = fadd half %lane4, %lane5
+  %pair67 = fadd half %lane6, %lane7
+  %quad0123 = fadd half %pair01, %pair23
+  %quad4567 = fadd half %pair45, %pair67
+  %total = fadd half %quad0123, %quad4567
+  ret half %total
 }
----------------
arsenm wrote:

Test 16 x as well, i.e. also add a `<16 x float>` to `<16 x half>` case.

https://github.com/llvm/llvm-project/pull/141883


More information about the llvm-commits mailing list