[llvm] Split vgpr regalloc pipeline (PR #93526)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Fri May 31 04:55:30 PDT 2024


================
@@ -259,51 +271,86 @@ bool SILowerSGPRSpills::spillCalleeSavedRegs(
   return false;
 }
 
-void SILowerSGPRSpills::extendWWMVirtRegLiveness(MachineFunction &MF,
-                                                 LiveIntervals *LIS) {
-  // TODO: This is a workaround to avoid the unmodelled liveness computed with
-  // whole-wave virtual registers when allocated together with the regular VGPR
-  // virtual registers. Presently, the liveness computed during the regalloc is
-  // only uniform (or single lane aware) and it doesn't take account of the
-  // divergent control flow that exists for our GPUs. Since the WWM registers
-  // can modify inactive lanes, the wave-aware liveness should be computed for
-  // the virtual registers to accurately plot their interferences. Without
-  // having the divergent CFG for the function, it is difficult to implement the
-  // wave-aware liveness info. Until then, we conservatively extend the liveness
-  // of the wwm registers into the entire function so that they won't be reused
-  // without first spilling/splitting their liveranges.
-  SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
-
-  // Insert the IMPLICIT_DEF for the wwm-registers in the entry blocks.
-  for (auto Reg : MFI->getSGPRSpillVGPRs()) {
-    for (MachineBasicBlock *SaveBlock : SaveBlocks) {
-      MachineBasicBlock::iterator InsertBefore = SaveBlock->begin();
-      DebugLoc DL = SaveBlock->findDebugLoc(InsertBefore);
-      auto MIB = BuildMI(*SaveBlock, InsertBefore, DL,
-                         TII->get(AMDGPU::IMPLICIT_DEF), Reg);
-      MFI->setFlag(Reg, AMDGPU::VirtRegFlag::WWM_REG);
-      // Set SGPR_SPILL asm printer flag
-      MIB->setAsmPrinterFlag(AMDGPU::SGPR_SPILL);
-      if (LIS) {
-        LIS->InsertMachineInstrInMaps(*MIB);
+void SILowerSGPRSpills::updateLaneVGPRDomInstr(
+    int FI, MachineBasicBlock *MBB, MachineBasicBlock::iterator InsertPt,
+    DenseMap<Register, MachineBasicBlock::iterator> &LaneVGPRDomInstr) {
+  // For the Def of a virtual LaneVPGR to dominate all its uses, we should
+  // insert an IMPLICIT_DEF before the dominating spill. Switching to a
+  // depth first order doesn't really help since the machine function can be in
+  // the unstructured control flow post-SSA. For each virtual register, hence
+  // finding the common dominator to get either the dominating spill or a block
+  // dominating all spills.
+  SIMachineFunctionInfo *FuncInfo =
+      MBB->getParent()->getInfo<SIMachineFunctionInfo>();
+  ArrayRef<SIRegisterInfo::SpilledReg> VGPRSpills =
+      FuncInfo->getSGPRSpillToVirtualVGPRLanes(FI);
+  Register PrevLaneVGPR;
+  for (auto &Spill : VGPRSpills) {
+    if (PrevLaneVGPR == Spill.VGPR)
+      continue;
+
+    PrevLaneVGPR = Spill.VGPR;
+    auto I = LaneVGPRDomInstr.find(Spill.VGPR);
+    if (Spill.Lane == 0 && I == LaneVGPRDomInstr.end()) {
+      // Initially add the spill instruction itself for Insertion point.
+      LaneVGPRDomInstr[Spill.VGPR] = InsertPt;
+    } else {
+      assert(I != LaneVGPRDomInstr.end());
+      auto PrevInsertPt = I->second;
+      MachineBasicBlock *DomMBB = PrevInsertPt->getParent();
+      if (DomMBB == MBB) {
+        // The insertion point earlier selected in a predecessor block whose
+        // spills are currently being lowered. The earlier InsertPt would be
+        // the one just before the block terminator and it should be changed
+        // if we insert any new spill in it.
+        if (MDT->dominates(&*InsertPt, &*PrevInsertPt))
+          I->second = InsertPt;
+
+        continue;
       }
+
+      // Find the common dominator block between PrevInsertPt and the
+      // current spill.
+      DomMBB = MDT->findNearestCommonDominator(DomMBB, MBB);
+      if (DomMBB == MBB)
+        I->second = InsertPt;
+      else if (DomMBB != PrevInsertPt->getParent())
+        I->second = &(*DomMBB->getFirstTerminator());
     }
   }
+}
 
-  // Insert the KILL in the return blocks to extend their liveness untill the
-  // end of function. Insert a separate KILL for each VGPR.
-  for (MachineBasicBlock *RestoreBlock : RestoreBlocks) {
-    MachineBasicBlock::iterator InsertBefore =
-        RestoreBlock->getFirstTerminator();
-    DebugLoc DL = RestoreBlock->findDebugLoc(InsertBefore);
-    for (auto Reg : MFI->getSGPRSpillVGPRs()) {
-      auto MIB = BuildMI(*RestoreBlock, InsertBefore, DL,
-                         TII->get(TargetOpcode::KILL));
-      MIB.addReg(Reg);
-      if (LIS)
-        LIS->InsertMachineInstrInMaps(*MIB);
+void SILowerSGPRSpills::determineRegsForWWMAllocation(MachineFunction &MF,
+                                                      BitVector &RegMask) {
+  // Determine an optimal number of VGPRs for WWM allocation. The complement
+  // list will be available for allocating other VGPR virtual registers.
+  SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  BitVector ReservedRegs = TRI->getReservedRegs(MF);
+  BitVector NonWwmAllocMask(TRI->getNumRegs());
+
+  // FIXME: MaxNumVGPRsForWwmAllocation might need to be adjusted in the future
+  // to have a balanced allocation between WWM values and per-thread vector
+  // register operands.
+  unsigned NumRegs = MaxNumVGPRsForWwmAllocation;
+  NumRegs =
+      std::min(static_cast<unsigned>(MFI->getSGPRSpillVGPRs().size()), NumRegs);
+
+  auto [MaxNumVGPRs, MaxNumAGPRs] = TRI->getMaxNumVectorRegs(MF);
+  // Try to use the highest available registers for now. Later after
+  // vgpr-regalloc, they can be shifted to the lowest range.
+  unsigned I = 0;
+  for (unsigned Reg = AMDGPU::VGPR0 + MaxNumVGPRs - 1;
+       (I < NumRegs) && (Reg >= AMDGPU::VGPR0); --Reg) {
+    if (!ReservedRegs.test(Reg) &&
+        !MRI.isPhysRegUsed(Reg, /*SkipRegMaskTest=*/true)) {
+      TRI->markSuperRegs(RegMask, Reg);
+      ++I;
     }
   }
+
+  assert(I == NumRegs &&
+         "Failed to find enough VGPRs for whole-wave register allocation");
----------------
arsenm wrote:

This probably needs to be report_fatal_error (or better, emitError and just pick a register). You can easily run into this with some asm.

https://github.com/llvm/llvm-project/pull/93526


More information about the llvm-commits mailing list