[llvm] babdad3 - AMDGPU: Try to unspill VGPRs after rewriting MFMAs to AGPR form (#154323)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 16 08:11:41 PDT 2025


Author: Matt Arsenault
Date: 2025-09-17T00:11:32+09:00
New Revision: babdad3fdbc66b4992654f9b7dc0fa4da85bd843

URL: https://github.com/llvm/llvm-project/commit/babdad3fdbc66b4992654f9b7dc0fa4da85bd843
DIFF: https://github.com/llvm/llvm-project/commit/babdad3fdbc66b4992654f9b7dc0fa4da85bd843.diff

LOG: AMDGPU: Try to unspill VGPRs after rewriting MFMAs to AGPR form (#154323)

After replacing VGPR MFMAs with the AGPR form, we've alleviated VGPR
pressure that may have triggered spills during allocation. Identify these
spill slots, try to reassign them to newly freed VGPRs, and replace the
spill instructions with copies.

Fixes #154260
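
At a high level, the unspilling step this patch adds reduces to the
following condensed sketch (simplified from eliminateSpillsOfReassignedVGPRs
in the diff below; the interval sorting heuristic, debug output, and pass
plumbing are omitted, and the map lookup is abbreviated):

    // For each spill-slot interval holding a VGPR value, look for a
    // physical register that became free after the MFMA rewrite.
    for (LiveInterval *LI : StackIntervals) {
      int Slot = LI->reg().stackSlotIndex();
      const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);

      for (MCPhysReg PhysReg : RegClassInfo.getOrder(RC)) {
        if (LRM.checkInterference(*LI, PhysReg) != LiveRegMatrix::IK_Free)
          continue;

        // Rewrite every spill/reload of this slot into a COPY of a fresh
        // virtual register, then pin that vreg to the free physreg and
        // drop the now-dead stack object.
        Register NewVReg = MRI.createVirtualRegister(RC);
        for (MachineInstr *SpillMI : SpillSlotReferences[Slot])
          replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);

        LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval(NewVReg);
        VRM.grow();
        LRM.assign(NewLI, PhysReg);
        MFI.RemoveStackObject(Slot);
        break;
      }
    }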

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
    llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
index 21cf9cc6878fb..fedb694bfcc2a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
@@ -29,6 +29,8 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/LiveRegMatrix.h"
+#include "llvm/CodeGen/LiveStacks.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/VirtRegMap.h"
 #include "llvm/InitializePasses.h"
@@ -42,6 +44,9 @@ namespace {
 STATISTIC(NumMFMAsRewrittenToAGPR,
           "Number of MFMA instructions rewritten to use AGPR form");
 
+/// Map from spill slot frame index to list of instructions which reference it.
+using SpillReferenceMap = DenseMap<int, SmallVector<MachineInstr *, 4>>;
+
 class AMDGPURewriteAGPRCopyMFMAImpl {
   MachineFunction &MF;
   const GCNSubtarget &ST;
@@ -51,6 +56,7 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
   VirtRegMap &VRM;
   LiveRegMatrix &LRM;
   LiveIntervals &LIS;
+  LiveStacks &LSS;
   const RegisterClassInfo &RegClassInfo;
 
   bool attemptReassignmentsToAGPR(SmallSetVector<Register, 4> &InterferingRegs,
@@ -59,10 +65,11 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 public:
   AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
                                 LiveRegMatrix &LRM, LiveIntervals &LIS,
+                                LiveStacks &LSS,
                                 const RegisterClassInfo &RegClassInfo)
       : MF(MF), ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
         TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
-        LIS(LIS), RegClassInfo(RegClassInfo) {}
+        LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
 
   bool isRewriteCandidate(const MachineInstr &MI) const {
     return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
@@ -103,6 +110,22 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 
   bool tryFoldCopiesToAGPR(Register VReg, MCRegister AssignedAGPR) const;
   bool tryFoldCopiesFromAGPR(Register VReg, MCRegister AssignedAGPR) const;
+
+  /// Replace spill instruction \p SpillMI, which loads/stores from/to
+  /// \p SpillFI, with a COPY to or from the replacement register \p VReg.
+  void replaceSpillWithCopyToVReg(MachineInstr &SpillMI, int SpillFI,
+                                  Register VReg) const;
+
+  /// Create a map from frame index to the spill instructions that use it. If
+  /// any use of a frame index is not a spill instruction, the index is not
+  /// included in the map.
+  void collectSpillIndexUses(ArrayRef<LiveInterval *> StackIntervals,
+                             SpillReferenceMap &Map) const;
+
+  /// Attempt to unspill VGPRs by finding a free register and replacing the
+  /// spill instructions with copies.
+  void eliminateSpillsOfReassignedVGPRs() const;
+
   bool run(MachineFunction &MF) const;
 };
 
@@ -391,6 +414,138 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
   return MadeChange;
 }
 
+void AMDGPURewriteAGPRCopyMFMAImpl::replaceSpillWithCopyToVReg(
+    MachineInstr &SpillMI, int SpillFI, Register VReg) const {
+  const DebugLoc &DL = SpillMI.getDebugLoc();
+  MachineBasicBlock &MBB = *SpillMI.getParent();
+  MachineInstr *NewCopy;
+  if (SpillMI.mayStore()) {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY), VReg)
+                  .add(SpillMI.getOperand(0));
+  } else {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY))
+                  .add(SpillMI.getOperand(0))
+                  .addReg(VReg);
+  }
+
+  LIS.ReplaceMachineInstrInMaps(SpillMI, *NewCopy);
+  SpillMI.eraseFromParent();
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::collectSpillIndexUses(
+    ArrayRef<LiveInterval *> StackIntervals, SpillReferenceMap &Map) const {
+
+  SmallSet<int, 4> NeededFrameIndexes;
+  for (const LiveInterval *LI : StackIntervals)
+    NeededFrameIndexes.insert(LI->reg().stackSlotIndex());
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      for (MachineOperand &MO : MI.operands()) {
+        if (!MO.isFI() || !NeededFrameIndexes.count(MO.getIndex()))
+          continue;
+
+        if (TII.isVGPRSpill(MI)) {
+          SmallVector<MachineInstr *, 4> &References = Map[MO.getIndex()];
+          References.push_back(&MI);
+          break;
+        }
+
+        // Verify this was really a spill instruction; if it's not, just
+        // ignore all uses of this index.
+
+        // TODO: This should probably be verifier enforced.
+        NeededFrameIndexes.erase(MO.getIndex());
+        Map.erase(MO.getIndex());
+      }
+    }
+  }
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs() const {
+  unsigned NumSlots = LSS.getNumIntervals();
+  if (NumSlots == 0)
+    return;
+
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  SmallVector<LiveInterval *, 32> StackIntervals;
+  StackIntervals.reserve(NumSlots);
+
+  for (auto &[Slot, LI] : LSS) {
+    if (!MFI.isSpillSlotObjectIndex(Slot) || MFI.isDeadObjectIndex(Slot))
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+    if (TRI.hasVGPRs(RC))
+      StackIntervals.push_back(&LI);
+  }
+
+  sort(StackIntervals, [](const LiveInterval *A, const LiveInterval *B) {
+    // Sort heaviest intervals first to prioritize their unspilling.
+    if (A->weight() != B->weight())
+      return A->weight() > B->weight();
+
+    if (A->getSize() != B->getSize())
+      return A->getSize() > B->getSize();
+
+    // Tie break on the slot number to avoid needing a stable sort.
+    return A->reg().stackSlotIndex() < B->reg().stackSlotIndex();
+  });
+
+  // FIXME: The APIs for dealing with the LiveInterval of a frame index are
+  // cumbersome. LiveStacks owns its LiveIntervals which refer to stack
+  // slots. We cannot use the usual LiveRegMatrix::assign and unassign on these,
+  // and must create a substitute virtual register to do so. This makes
+  // incremental updating here difficult; we need to actually perform the IR
+  // mutation to get the new vreg references in place before we can compute
+  // the LiveInterval needed to perform an assignment and track the new
+  // interference correctly, and we can't simply migrate the LiveInterval we
+  // already have.
+  //
+  // To avoid walking through the entire function for each index, pre-collect
+  // all of the instructions' slot references.
+
+  SpillReferenceMap SpillSlotReferences;
+  collectSpillIndexUses(StackIntervals, SpillSlotReferences);
+
+  for (LiveInterval *LI : StackIntervals) {
+    int Slot = LI->reg().stackSlotIndex();
+    auto SpillReferences = SpillSlotReferences.find(Slot);
+    if (SpillReferences == SpillSlotReferences.end())
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+
+    LLVM_DEBUG(dbgs() << "Trying to eliminate " << printReg(Slot, &TRI)
+                      << " by reassigning\n");
+
+    ArrayRef<MCPhysReg> AllocOrder = RegClassInfo.getOrder(RC);
+
+    for (MCPhysReg PhysReg : AllocOrder) {
+      if (LRM.checkInterference(*LI, PhysReg) != LiveRegMatrix::IK_Free)
+        continue;
+
+      LLVM_DEBUG(dbgs() << "Reassigning " << *LI << " to "
+                        << printReg(PhysReg, &TRI) << '\n');
+
+      Register NewVReg = MRI.createVirtualRegister(RC);
+
+      for (MachineInstr *SpillMI : SpillReferences->second)
+        replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);
+
+      // TODO: We should be able to transfer the information from the stack
+      // slot's LiveInterval without recomputing from scratch with the
+      // replacement vreg uses.
+      LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval(NewVReg);
+      VRM.grow();
+      LRM.assign(NewLI, PhysReg);
+      MFI.RemoveStackObject(Slot);
+      break;
+    }
+  }
+}
+
 bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
   // This only applies on subtargets that have a configurable AGPR vs. VGPR
   // allocation.
@@ -417,6 +572,12 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
       MadeChange = true;
   }
 
+  // If we've successfully rewritten some MFMAs, we've alleviated some VGPR
+  // pressure. See if we can eliminate some spills now that those registers are
+  // more available.
+  if (MadeChange)
+    eliminateSpillsOfReassignedVGPRs();
+
   return MadeChange;
 }
 
@@ -440,10 +601,13 @@ class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
     AU.addRequired<LiveIntervalsWrapperPass>();
     AU.addRequired<VirtRegMapWrapperLegacy>();
     AU.addRequired<LiveRegMatrixWrapperLegacy>();
+    AU.addRequired<LiveStacksWrapperLegacy>();
 
     AU.addPreserved<LiveIntervalsWrapperPass>();
     AU.addPreserved<VirtRegMapWrapperLegacy>();
     AU.addPreserved<LiveRegMatrixWrapperLegacy>();
+    AU.addPreserved<LiveStacksWrapperLegacy>();
+
     AU.setPreservesAll();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
@@ -456,6 +620,7 @@ INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
 INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
 INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
+INITIALIZE_PASS_DEPENDENCY(LiveStacksWrapperLegacy)
 INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
                     "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
 
@@ -474,8 +639,8 @@ bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
   auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
   auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
   auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
-
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  auto &LSS = getAnalysis<LiveStacksWrapperLegacy>().getLS();
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   return Impl.run(MF);
 }
 
@@ -485,13 +650,15 @@ AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
   VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(MF);
   LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(MF);
   LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(MF);
+  LiveStacks &LSS = MFAM.getResult<LiveStacksAnalysis>(MF);
   RegisterClassInfo RegClassInfo;
   RegClassInfo.runOnMachineFunction(MF);
 
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   if (!Impl.run(MF))
     return PreservedAnalyses::all();
   auto PA = getMachineFunctionPassPreservedAnalyses();
   PA.preserveSet<CFGAnalyses>();
+  PA.preserve<LiveStacksAnalysis>();
   return PA;
 }

diff --git a/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll b/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
index 122d46b39ff32..8878e9b65a088 100644
--- a/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
+++ b/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
@@ -101,13 +101,8 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT:    v_accvgpr_read_b32 v2, a2
 ; CHECK-NEXT:    v_accvgpr_read_b32 v3, a3
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[10:13]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    ;;#ASMSTART
 ; CHECK-NEXT:    ; def a[0:31]
@@ -147,12 +142,7 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT:    global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload
@@ -311,26 +301,16 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT:    v_accvgpr_write_b32 a33, v1
 ; CHECK-NEXT:    v_accvgpr_write_b32 a32, v0
 ; CHECK-NEXT:    v_accvgpr_read_b32 v7, a3
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    v_accvgpr_read_b32 v6, a2
 ; CHECK-NEXT:    v_accvgpr_read_b32 v5, a1
 ; CHECK-NEXT:    v_accvgpr_read_b32 v4, a0
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[10:13]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[14:17]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:208 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:212 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:216 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:220 ; 4-byte Folded Spill
-; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    ;;#ASMSTART
 ; CHECK-NEXT:    ; def a[0:31]
 ; CHECK-NEXT:    ;;#ASMEND
@@ -369,19 +349,9 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:208 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:212 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:216 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:220 ; 4-byte Folded Reload
+; CHECK-NEXT:    global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT:    global_store_dwordx4 v0, v[14:17], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload


        

