[llvm-branch-commits] [llvm] AMDGPU: Try to unspill VGPRs after rewriting MFMAs to AGPR form (PR #154323)

Matt Arsenault via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 21 06:44:02 PDT 2025


https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/154323

From 5adb38cc0cffff45d976d54e5faddd10814b2b42 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Tue, 19 Aug 2025 13:12:37 +0900
Subject: [PATCH] AMDGPU: Try to unspill VGPRs after rewriting MFMAs to AGPR
 form

After replacing VGPR MFMAs with the AGPR form, we have alleviated VGPR
pressure that may have triggered spills during allocation. Identify
these spill slots, try to reassign them to newly freed VGPRs, and
replace the spill instructions with copies.

Fixes #154260
---
 .../AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp      | 169 +++++++++++++++++-
 .../unspill-vgpr-after-rewrite-vgpr-mfma.ll   |  44 +----
 2 files changed, 172 insertions(+), 41 deletions(-)
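
In outline, the unspilling step added by this patch works as follows. This
is a simplified sketch, not the verbatim implementation: vgprSpillSlotsByWeight,
regClassForSlot, and replaceSpillsWithCopies are hypothetical helpers standing
in for the code in eliminateSpillsOfReassignedVGPRs below, and the
VirtRegMap/LiveStacks bookkeeping is omitted.

  // For each VGPR spill slot, heaviest first, look for a physical register
  // that is free across the slot's whole live range. If one exists, replace
  // the spill stores/reloads with COPYs through a fresh virtual register,
  // assign that register, and delete the stack object.
  for (LiveInterval *SlotLI : vgprSpillSlotsByWeight(LSS, MFI, TRI)) {
    int Slot = SlotLI->reg().stackSlotIndex();
    const TargetRegisterClass *RC = regClassForSlot(LSS, Slot);
    for (MCPhysReg PhysReg : RegClassInfo.getOrder(RC)) {
      if (LRM.checkInterference(*SlotLI, PhysReg) != LiveRegMatrix::IK_Free)
        continue;
      Register NewVReg = MRI.createVirtualRegister(RC);
      replaceSpillsWithCopies(Slot, NewVReg); // spill store/reload -> COPY
      LRM.assign(LIS.createAndComputeVirtRegInterval(NewVReg), PhysReg);
      MFI.RemoveStackObject(Slot);
      break;
    }
  }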

diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
index 639796c6cefff..218dc90d9935f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
@@ -28,6 +28,7 @@
 #include "SIRegisterInfo.h"
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/LiveRegMatrix.h"
+#include "llvm/CodeGen/LiveStacks.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/VirtRegMap.h"
 #include "llvm/InitializePasses.h"
@@ -38,6 +39,9 @@ using namespace llvm;
 
 namespace {
 
+/// Map from spill slot frame index to list of instructions which reference it.
+using SpillReferenceMap = DenseMap<int, SmallVector<MachineInstr *, 4>>;
+
 class AMDGPURewriteAGPRCopyMFMAImpl {
   MachineFunction &MF;
   const GCNSubtarget &ST;
@@ -47,6 +51,7 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
   VirtRegMap &VRM;
   LiveRegMatrix &LRM;
   LiveIntervals &LIS;
+  LiveStacks &LSS;
   const RegisterClassInfo &RegClassInfo;
 
   bool attemptReassignmentsToAGPR(SmallSetVector<Register, 4> &InterferingRegs,
@@ -55,10 +60,11 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 public:
   AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
                                 LiveRegMatrix &LRM, LiveIntervals &LIS,
+                                LiveStacks &LSS,
                                 const RegisterClassInfo &RegClassInfo)
       : MF(MF), ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
         TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
-        LIS(LIS), RegClassInfo(RegClassInfo) {}
+        LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
 
   bool isRewriteCandidate(const MachineInstr &MI) const {
     return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
@@ -106,6 +112,22 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 
   bool tryFoldCopiesToAGPR(Register VReg, MCRegister AssignedAGPR) const;
   bool tryFoldCopiesFromAGPR(Register VReg, MCRegister AssignedAGPR) const;
+
+  /// Replace the spill instruction \p SpillMI, which stores to or reloads
+  /// from \p SpillFI, with a COPY to or from the replacement virtual
+  /// register \p VReg.
+  void replaceSpillWithCopyToVReg(MachineInstr &SpillMI, int SpillFI,
+                                  Register VReg) const;
+
+  /// Build a map from each spill slot frame index to the instructions that
+  /// use it. If any use of a frame index is not a spill instruction, that
+  /// index is not included in the map.
+  void collectSpillIndexUses(ArrayRef<LiveInterval *> StackIntervals,
+                             SpillReferenceMap &Map) const;
+
+  /// Attempt to unspill VGPRs by finding a free register and replacing the
+  /// spill instructions with copies.
+  void eliminateSpillsOfReassignedVGPRs() const;
+
   bool run(MachineFunction &MF) const;
 };
 
@@ -392,6 +414,133 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
   return MadeChange;
 }
 
+void AMDGPURewriteAGPRCopyMFMAImpl::replaceSpillWithCopyToVReg(
+    MachineInstr &SpillMI, int SpillFI, Register VReg) const {
+  const DebugLoc &DL = SpillMI.getDebugLoc();
+  MachineBasicBlock &MBB = *SpillMI.getParent();
+  MachineInstr *NewCopy;
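+  // A spill store becomes a copy of the stored value into VReg; a reload
+  // becomes a copy from VReg into the reloaded register.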
+  if (SpillMI.mayStore()) {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY), VReg)
+                  .add(SpillMI.getOperand(0));
+  } else {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY))
+                  .add(SpillMI.getOperand(0))
+                  .addReg(VReg);
+  }
+
+  LIS.ReplaceMachineInstrInMaps(SpillMI, *NewCopy);
+  SpillMI.eraseFromParent();
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::collectSpillIndexUses(
+    ArrayRef<LiveInterval *> StackIntervals, SpillReferenceMap &Map) const {
+
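+  // Collect the frame indexes of interest up front so a single walk over the
+  // function suffices.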
+  SmallSet<int, 4> NeededFrameIndexes;
+  for (const LiveInterval *LI : StackIntervals)
+    NeededFrameIndexes.insert(LI->reg().stackSlotIndex());
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      for (MachineOperand &MO : MI.operands()) {
+        if (!MO.isFI() || !NeededFrameIndexes.count(MO.getIndex()))
+          continue;
+
+        SmallVector<MachineInstr *, 4> &References = Map[MO.getIndex()];
+        if (TII.isVGPRSpill(MI)) {
+          References.push_back(&MI);
+          break;
+        }
+
+        // This frame index has a use that is not a spill instruction, so it
+        // was not really a spill slot; ignore all of its uses.
+
+        // TODO: This should probably be verifier enforced.
+        NeededFrameIndexes.erase(MO.getIndex());
+        Map.erase(MO.getIndex());
+      }
+    }
+  }
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs() const {
+  unsigned NumSlots = LSS.getNumIntervals();
+  if (NumSlots == 0)
+    return;
+
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  SmallVector<LiveInterval *, 32> StackIntervals;
+  StackIntervals.reserve(NumSlots);
+
+  for (auto I = LSS.begin(), E = LSS.end(); I != E; ++I) {
+    int Slot = I->first;
+    if (!MFI.isSpillSlotObjectIndex(Slot) || MFI.isDeadObjectIndex(Slot))
+      continue;
+
+    LiveInterval &LI = I->second;
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+    if (TRI.hasVGPRs(RC))
+      StackIntervals.push_back(&LI);
+  }
+
+  // Sort the heaviest intervals first to prioritize unspilling them.
+  sort(StackIntervals, [](const LiveInterval *A, const LiveInterval *B) {
+    return A->weight() > B->weight();
+  });
+
+  // FIXME: The APIs for dealing with the LiveInterval of a frame index are
+  // cumbersome. LiveStacks owns the LiveIntervals that refer to stack slots;
+  // we cannot use the usual LiveRegMatrix::assign and unassign on these, and
+  // must create a substitute virtual register to do so. This makes
+  // incremental updating difficult: we have to actually perform the IR
+  // mutation to put the new vreg references in place before we can compute
+  // the register's LiveInterval and perform an assignment that tracks the
+  // new interference correctly; we cannot simply migrate the LiveInterval we
+  // already have.
+  //
+  // To avoid walking the entire function once per frame index, pre-collect
+  // all of the instructions' slot references.
+
+  SpillReferenceMap SpillSlotReferences;
+  collectSpillIndexUses(StackIntervals, SpillSlotReferences);
+
+  for (LiveInterval *LI : StackIntervals) {
+    int Slot = LI->reg().stackSlotIndex();
+    auto SpillReferences = SpillSlotReferences.find(Slot);
+    if (SpillReferences == SpillSlotReferences.end())
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+
+    LLVM_DEBUG(dbgs() << "Trying to eliminate " << printReg(Slot, &TRI)
+                      << " by reassigning\n");
+
+    ArrayRef<MCPhysReg> AllocOrder = RegClassInfo.getOrder(RC);
+
+    for (MCPhysReg PhysReg : AllocOrder) {
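+      // Only use a register that is completely free across the stack slot's
+      // entire live range.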
+      if (LRM.checkInterference(*LI, PhysReg) != LiveRegMatrix::IK_Free)
+        continue;
+
+      LLVM_DEBUG(dbgs() << "Reassigning " << *LI << " to "
+                        << printReg(PhysReg, &TRI) << '\n');
+
+      Register NewVReg = MRI.createVirtualRegister(RC);
+
+      for (MachineInstr *SpillMI : SpillReferences->second)
+        replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);
+
+      // TODO: We should be able to transfer the information from the stack
+      // slot's LiveInterval without recomputing from scratch with the
+      // replacement vreg uses.
+      LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval(NewVReg);
+      VRM.grow();
+      LRM.assign(NewLI, PhysReg);
+      MFI.RemoveStackObject(Slot);
+      break;
+    }
+  }
+}
+
 bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
   // This only applies on subtargets that have a configurable AGPR vs. VGPR
   // allocation.
@@ -418,6 +567,12 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
       MadeChange = true;
   }
 
+  // If we've successfully rewritten some MFMAs, we've alleviated some VGPR
+  // pressure. See if we can eliminate some spills now that those registers are
+  // more available.
+  if (MadeChange)
+    eliminateSpillsOfReassignedVGPRs();
+
   return MadeChange;
 }
 
@@ -441,10 +596,13 @@ class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
     AU.addRequired<LiveIntervalsWrapperPass>();
     AU.addRequired<VirtRegMapWrapperLegacy>();
     AU.addRequired<LiveRegMatrixWrapperLegacy>();
+    AU.addRequired<LiveStacksWrapperLegacy>();
 
     AU.addPreserved<LiveIntervalsWrapperPass>();
     AU.addPreserved<VirtRegMapWrapperLegacy>();
     AU.addPreserved<LiveRegMatrixWrapperLegacy>();
+    AU.addPreserved<LiveStacksWrapperLegacy>();
+
     AU.setPreservesAll();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
@@ -457,6 +615,7 @@ INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
 INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
 INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
+INITIALIZE_PASS_DEPENDENCY(LiveStacksWrapperLegacy)
 INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
                     "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
 
@@ -475,8 +634,8 @@ bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
   auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
   auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
   auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
-
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  auto &LSS = getAnalysis<LiveStacksWrapperLegacy>().getLS();
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   return Impl.run(MF);
 }
 
@@ -486,13 +645,15 @@ AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
   VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(MF);
   LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(MF);
   LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(MF);
+  LiveStacks &LSS = MFAM.getResult<LiveStacksAnalysis>(MF);
   RegisterClassInfo RegClassInfo;
   RegClassInfo.runOnMachineFunction(MF);
 
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   if (!Impl.run(MF))
     return PreservedAnalyses::all();
   auto PA = getMachineFunctionPassPreservedAnalyses();
   PA.preserveSet<CFGAnalyses>();
+  PA.preserve<LiveStacksAnalysis>();
   return PA;
 }
diff --git a/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll b/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
index 122d46b39ff32..3b9e3a00036af 100644
--- a/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
+++ b/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
@@ -101,13 +101,8 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT:    v_accvgpr_read_b32 v2, a2
 ; CHECK-NEXT:    v_accvgpr_read_b32 v3, a3
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[10:13]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    ;;#ASMSTART
 ; CHECK-NEXT:    ; def a[0:31]
@@ -147,12 +142,7 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT:    global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload
@@ -311,26 +301,16 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT:    v_accvgpr_write_b32 a33, v1
 ; CHECK-NEXT:    v_accvgpr_write_b32 a32, v0
 ; CHECK-NEXT:    v_accvgpr_read_b32 v7, a3
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    v_accvgpr_read_b32 v6, a2
 ; CHECK-NEXT:    v_accvgpr_read_b32 v5, a1
 ; CHECK-NEXT:    v_accvgpr_read_b32 v4, a0
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[14:17]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[10:13]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:208 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:212 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:216 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:220 ; 4-byte Folded Spill
-; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    ;;#ASMSTART
 ; CHECK-NEXT:    ; def a[0:31]
 ; CHECK-NEXT:    ;;#ASMEND
@@ -369,19 +349,9 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:208 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:212 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:216 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:220 ; 4-byte Folded Reload
+; CHECK-NEXT:    global_store_dwordx4 v0, v[14:17], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT:    global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload



More information about the llvm-branch-commits mailing list