[llvm] AMDGPU: Rewrite VGPR MFMAs to AGPR when directly copied to AGPR class (PR #152480)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 7 04:15:27 PDT 2025
https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/152480
>From f304598dffcbf0465eb4cf1be36b082c2770519e Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Thu, 7 Aug 2025 19:54:55 +0900
Subject: [PATCH] AMDGPU: Rewrite VGPR MFMAs to AGPR when directly copied to
AGPR class
Previously we were specifically looking for AV_* class registers,
and checking if the physreg assignment was to an AGPR. Handle the case where the
copy is to an AGPR in the first place. In hindsight it would have
been way easier to handle this first, and this makes writing tests
a lot easier for the mechanical transforms.
---
.../AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp | 13 +-
.../AMDGPU/rewrite-vgpr-mfma-to-agpr.ll | 195 +++++++++++++++++-
2 files changed, 202 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
index f580f4368110f..c21a9a1894a32 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
@@ -109,12 +109,17 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
// Find AV_* registers assigned to AGPRs.
const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
- if (!TRI.isVectorSuperClass(VirtRegRC))
+ if (!TRI.hasAGPRs(VirtRegRC))
continue;
- const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
- if (!TRI.isAGPRClass(AssignedRC))
- continue;
+ const TargetRegisterClass *AssignedRC = VirtRegRC;
+ if (TRI.hasVGPRs(VirtRegRC)) {
+ // If this is an AV register, we have to check if the actual assignment is
+ // to an AGPR
+ AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
+ if (!TRI.isAGPRClass(AssignedRC))
+ continue;
+ }
LiveInterval &LI = LIS.getInterval(VReg);
diff --git a/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll b/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll
index 0b43ff2e0bb4e..b35a74e4a80c3 100644
--- a/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll
+++ b/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll
@@ -200,8 +200,199 @@ bb:
ret void
}
-declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32 immarg, i32 immarg, i32 immarg) #1
-declare noundef i32 @llvm.amdgcn.workitem.id.x() #2
+; The inline asm requires the value be copied to an AGPR class, not
+; the AV_* pseudo we usually expect for register allocator live range
+; splits.
+define amdgpu_kernel void @test_rewrite_mfma_direct_copy_to_agpr_class(ptr addrspace(1) %arg) #0 {
+; CHECK-LABEL: test_rewrite_mfma_direct_copy_to_agpr_class:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_load_dwordx2 s[0:1], s[4:5], 0x0
+; CHECK-NEXT: v_and_b32_e32 v0, 0x3ff, v0
+; CHECK-NEXT: v_lshlrev_b32_e32 v0, 7, v0
+; CHECK-NEXT: v_mov_b32_e32 v32, 2.0
+; CHECK-NEXT: v_mov_b32_e32 v33, 4.0
+; CHECK-NEXT: s_waitcnt lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 a[28:31], v0, s[0:1] offset:112
+; CHECK-NEXT: global_load_dwordx4 a[24:27], v0, s[0:1] offset:96
+; CHECK-NEXT: global_load_dwordx4 a[20:23], v0, s[0:1] offset:80
+; CHECK-NEXT: global_load_dwordx4 a[16:19], v0, s[0:1] offset:64
+; CHECK-NEXT: global_load_dwordx4 a[12:15], v0, s[0:1] offset:48
+; CHECK-NEXT: global_load_dwordx4 a[8:11], v0, s[0:1] offset:32
+; CHECK-NEXT: global_load_dwordx4 a[4:7], v0, s[0:1] offset:16
+; CHECK-NEXT: global_load_dwordx4 a[0:3], v0, s[0:1]
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 a[0:31], v32, v33, a[0:31]
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:31]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_endpgm
+bb:
+ %id = call i32 @llvm.amdgcn.workitem.id.x()
+ %gep = getelementptr <32 x float>, ptr addrspace(1) %arg, i32 %id
+ %in = load <32 x float>, ptr addrspace(1) %gep, align 128
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 2.0, float 4.0, <32 x float> %in, i32 0, i32 0, i32 0)
+ call void asm sideeffect "; use $0", "a"(<32 x float> %mai)
+ ret void
+}
+
+; TODO: Handle rewriting this case
+define void @test_rewrite_mfma_imm_src2(float %arg0, float %arg1) #0 {
+; CHECK-LABEL: test_rewrite_mfma_imm_src2:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[0:31], v0, v1, 2.0
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v0
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v1
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v2
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v3
+; CHECK-NEXT: v_accvgpr_write_b32 a4, v4
+; CHECK-NEXT: v_accvgpr_write_b32 a5, v5
+; CHECK-NEXT: v_accvgpr_write_b32 a6, v6
+; CHECK-NEXT: v_accvgpr_write_b32 a7, v7
+; CHECK-NEXT: v_accvgpr_write_b32 a8, v8
+; CHECK-NEXT: v_accvgpr_write_b32 a9, v9
+; CHECK-NEXT: v_accvgpr_write_b32 a10, v10
+; CHECK-NEXT: v_accvgpr_write_b32 a11, v11
+; CHECK-NEXT: v_accvgpr_write_b32 a12, v12
+; CHECK-NEXT: v_accvgpr_write_b32 a13, v13
+; CHECK-NEXT: v_accvgpr_write_b32 a14, v14
+; CHECK-NEXT: v_accvgpr_write_b32 a15, v15
+; CHECK-NEXT: v_accvgpr_write_b32 a16, v16
+; CHECK-NEXT: v_accvgpr_write_b32 a17, v17
+; CHECK-NEXT: v_accvgpr_write_b32 a18, v18
+; CHECK-NEXT: v_accvgpr_write_b32 a19, v19
+; CHECK-NEXT: v_accvgpr_write_b32 a20, v20
+; CHECK-NEXT: v_accvgpr_write_b32 a21, v21
+; CHECK-NEXT: v_accvgpr_write_b32 a22, v22
+; CHECK-NEXT: v_accvgpr_write_b32 a23, v23
+; CHECK-NEXT: v_accvgpr_write_b32 a24, v24
+; CHECK-NEXT: v_accvgpr_write_b32 a25, v25
+; CHECK-NEXT: v_accvgpr_write_b32 a26, v26
+; CHECK-NEXT: v_accvgpr_write_b32 a27, v27
+; CHECK-NEXT: v_accvgpr_write_b32 a28, v28
+; CHECK-NEXT: v_accvgpr_write_b32 a29, v29
+; CHECK-NEXT: v_accvgpr_write_b32 a30, v30
+; CHECK-NEXT: v_accvgpr_write_b32 a31, v31
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:31]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> splat (float 2.0), i32 0, i32 0, i32 0)
+ call void asm sideeffect "; use $0", "a"(<32 x float> %mai)
+ ret void
+}
+
+; TODO: Handle rewriting this case
+define void @test_rewrite_mfma_subreg_extract0(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
+; CHECK-LABEL: test_rewrite_mfma_subreg_extract0:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
+; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
+; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
+; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
+; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
+; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
+; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
+; CHECK-NEXT: s_nop 0
+; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v2
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v3
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v4
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v5
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:3]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %src2 = load <32 x float>, ptr addrspace(1) %ptr
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
+ %extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
+ ret void
+}
+
+define void @test_rewrite_mfma_subreg_extract1(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
+; CHECK-LABEL: test_rewrite_mfma_subreg_extract1:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
+; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
+; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
+; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
+; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
+; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
+; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
+; CHECK-NEXT: s_nop 0
+; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v6
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v7
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v8
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v9
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:3]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %src2 = load <32 x float>, ptr addrspace(1) %ptr
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
+ %extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+ call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
+ ret void
+}
+
+; odd offset
+define void @test_rewrite_mfma_subreg_extract2(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
+; CHECK-LABEL: test_rewrite_mfma_subreg_extract2:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
+; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
+; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
+; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
+; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
+; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
+; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
+; CHECK-NEXT: s_nop 0
+; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v3
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v4
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v5
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v6
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:3]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %src2 = load <32 x float>, ptr addrspace(1) %ptr
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
+ %extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+ call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
+ ret void
+}
+
+declare <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half>, <4 x half>, <4 x float>, i32 immarg, i32 immarg, i32 immarg) #2
+declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32 immarg, i32 immarg, i32 immarg) #2
+declare noundef range(i32 0, 1024) i32 @llvm.amdgcn.workitem.id.x() #3
attributes #0 = { nounwind "amdgpu-flat-work-group-size"="1,256" "amdgpu-waves-per-eu"="4,4" }
attributes #1 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
More information about the llvm-commits
mailing list