[llvm] [AMDGPU] Add scheduling stage to rewrite MFMA from VGPR to AGPR (PR #149367)
Lucas Ramirez via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 17 07:53:52 PDT 2025
================
@@ -1642,6 +1751,532 @@ void GCNSchedStage::revertScheduling() {
DAG.Regions[RegionIdx] = std::pair(DAG.RegionBegin, DAG.RegionEnd);
}
+bool RewriteScheduleStage::isRewriteCandidate(MachineInstr *MI) const {
+
+ if (!static_cast<const SIInstrInfo *>(DAG.TII)->isMAI(*MI))
+ return false;
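+ // A rewrite candidate must have an equivalent opcode that takes its dst (and
+ // src2) in AGPRs; getMFMASrcCVDstAGPROp returns -1 when no such variant
+ // exists.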
+ return AMDGPU::getMFMASrcCVDstAGPROp(MI->getOpcode()) != -1;
+}
+
+bool RewriteScheduleStage::initHeuristics(
+ std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+ DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+ SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+ // Prepare for the heuristics
+ for (auto &MBB : MF) {
+ for (auto &MI : MBB) {
+ if (!isRewriteCandidate(&MI))
+ continue;
+
+ int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode());
+ if (ReplacementOp == -1)
+ continue;
+
+ RewriteCands.push_back({&MI, MI.getOpcode()});
+ MI.setDesc(TII->get(ReplacementOp));
+
+ MachineOperand *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2);
+ if (Src2->isReg()) {
+ SmallVector<SlotIndex, 8> Src2ReachingDefs;
+ findReachingDefs(*Src2, DAG.LIS, Src2ReachingDefs);
+
+ // For any definition of the src2 register which is non-MFMA, we
+ // insert a copy.
+ for (SlotIndex RDIdx : Src2ReachingDefs) {
+ MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIdx);
+ if (!TII->isMAI(*RD))
+ CopyForDef.insert(RD);
+ }
+ }
+
+ MachineOperand &Dst = MI.getOperand(0);
+ SmallVector<MachineOperand *, 8> DstReachingUses;
+
+ findReachingUses(&MI, DAG.LIS, DstReachingUses);
+
+ for (MachineOperand *RUOp : DstReachingUses) {
+ if (TII->isMAI(*RUOp->getParent()))
+ continue;
+
+ // For any user of the result of the MFMA which is not an MFMA, we
+ // insert a copy. For a given register, we will only insert one copy
+ // per user block.
+ CopyForUse[RUOp->getParent()->getParent()].insert(RUOp->getReg());
+
+ SmallVector<SlotIndex, 8> DstUsesReachingDefs;
+ findReachingDefs(*RUOp, DAG.LIS, DstUsesReachingDefs);
+
+ for (auto RDIndex : DstUsesReachingDefs) {
+ MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIndex);
+ if (TII->isMAI(*RD))
+ continue;
+
+ // For any definition of the user of the MFMA which is not an MFMA,
+ // we insert a copy. We do this to transform all the reaching defs
+ // of this use to AGPR. By doing this, we can insert a copy from
+ // AGPR to VGPR at the user rather than after the MFMA.
+ CopyForDef.insert(RD);
+ }
+ }
+
+ // Do the rewrite to allow for updated RP calculation.
+ const TargetRegisterClass *VGPRRC = DAG.MRI.getRegClass(Dst.getReg());
+ const TargetRegisterClass *AGPRRC = SRI->getEquivalentAGPRClass(VGPRRC);
+ DAG.MRI.setRegClass(Dst.getReg(), AGPRRC);
+ if (Src2->isReg())
+ DAG.MRI.setRegClass(Src2->getReg(), AGPRRC);
+ }
+ }
+
+ return true;
+}
+
+int64_t RewriteScheduleStage::getRewriteCost(
+ const std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+ const DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+ const SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+ MachineBranchProbabilityInfo MBPI;
+ MachineBlockFrequencyInfo MBFI;
+
+ MBFI.calculate(MF, MBPI, *DAG.MLI);
+ int64_t BestSpillCost = 0;
+ int64_t Cost = 0;
+
+ uint64_t EntryFreq = MBFI.getEntryFreq().getFrequency();
+
+ for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+ if (!RegionsWithExcessArchVGPR[Region])
+ continue;
+
+ GCNRegPressure &PressureBefore = DAG.Pressure[Region];
+ unsigned SpillCostBefore = PressureBefore.getVGPRSpills(MF);
+
+ // For the cases we care about (i.e. ArchVGPR usage is greater than the
+ // addressable limit), rewriting alone should bring pressure to a manageable
+ // level. If we find any such region, then the rewrite is potentially
+ // beneficial.
+ GCNRegPressure PressureAfter = DAG.getRealRegPressure(Region);
+ unsigned SpillCostAfter = PressureAfter.getVGPRSpills(MF);
+
+ uint64_t BlockFreq =
+ MBFI.getBlockFreq(DAG.Regions[Region].first->getParent())
+ .getFrequency();
+
+ bool RelativeFreqIsDenom = EntryFreq > BlockFreq;
+ uint64_t RelativeFreq = EntryFreq && BlockFreq
+ ? (RelativeFreqIsDenom ? EntryFreq / BlockFreq
+ : BlockFreq / EntryFreq)
+ : 1;
+
+ // This assumes perfect spilling / splitting -- using one spill / copy
+ // instruction and one restoreFrom / copy for each excess register.
+ int64_t SpillCost = ((int)SpillCostAfter - (int)SpillCostBefore) * 2;
+
+ // Also account for the block frequency.
+ if (RelativeFreqIsDenom)
+ SpillCost /= (int64_t)RelativeFreq;
+ else
+ SpillCost *= (int64_t)RelativeFreq;
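+
+ // Illustrative arithmetic (assumed numbers): if the region runs ~4x per
+ // kernel entry (BlockFreq / EntryFreq == 4) and rewriting removes all 3
+ // spills in it (SpillCostBefore == 3, SpillCostAfter == 0), then
+ // SpillCost == (0 - 3) * 2 * 4 == -24.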
+
+ // If we have increased spilling in any block, just bail.
+ if (SpillCost > 0)
+ return SpillCost;
+
+ if (SpillCost < BestSpillCost)
+ BestSpillCost = SpillCost;
+ }
+
+ // Set the cost to the largest decrease in spill cost so that we do not
+ // double-count spill reductions.
+ Cost = BestSpillCost;
+
+ assert(Cost <= 0);
+
+ unsigned CopyCost = 0;
+
+ // For each CopyForDef, increase the cost by the number of 32-bit registers
+ // copied, weighted by block frequency.
+ for (auto *DefMI : CopyForDef) {
+ auto DefReg = DefMI->getOperand(0).getReg();
+ uint64_t DefFreq =
+ EntryFreq
+ ? MBFI.getBlockFreq(DefMI->getParent()).getFrequency() / EntryFreq
+ : 1;
+
+ unsigned RegSize = DAG.TRI->getRegSizeInBits(*DAG.MRI.getRegClass(DefReg));
+ unsigned NumRegs = std::max(RegSize / 32, (unsigned)1);
+ CopyCost += NumRegs * DefFreq;
+ }
+
+ // Account for CopyForUse copies in each block where the register is used.
+ for (auto &[UseBlock, UseRegs] : CopyForUse) {
+ uint64_t UseFreq =
+ EntryFreq ? MBFI.getBlockFreq(UseBlock).getFrequency() / EntryFreq : 1;
+
+ for (auto UseReg : UseRegs) {
+ unsigned RegSize =
+ DAG.TRI->getRegSizeInBits(*DAG.MRI.getRegClass(UseReg));
+ unsigned NumRegs = std::max(RegSize / 32, (unsigned)1);
+ CopyCost += NumRegs * UseFreq;
+ }
+ }
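+
+ // Illustrative arithmetic (assumed numbers): a 128-bit register (4 x 32-bit)
+ // copied for a def in a block running once per entry adds 4 to CopyCost; the
+ // same-sized register copied for a use in a block running 8x per entry adds
+ // 4 * 8 == 32.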
+
+ Cost += CopyCost;
+
+ // Reset to the VGPR form. We must do rewriting after copy-insertion, as some
+ // defs of the register may require VGPR.
+ for (auto &[MI, OriginalOpcode] : RewriteCands) {
+ assert(TII->isMAI(*MI));
+ const TargetRegisterClass *AGPRRC =
+ DAG.MRI.getRegClass(MI->getOperand(0).getReg());
+ const TargetRegisterClass *VGPRRC = SRI->getEquivalentVGPRClass(AGPRRC);
+
+ MachineOperand *Src2 = TII->getNamedOperand(*MI, AMDGPU::OpName::src2);
+ assert(Src2);
+
+ if (Src2->isReg())
+ DAG.MRI.setRegClass(Src2->getReg(), VGPRRC);
+ DAG.MRI.setRegClass(MI->getOperand(0).getReg(), VGPRRC);
+ MI->setDesc(TII->get(OriginalOpcode));
+ }
+
+ return Cost;
+}
+
+bool RewriteScheduleStage::rewrite(
+ const std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands) {
+ DenseMap<MachineInstr *, unsigned> FirstMIToRegion;
+ DenseMap<MachineInstr *, unsigned> LastMIToRegion;
+
+ for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+ auto Entry = DAG.Regions[Region];
+ if (Entry.first == Entry.second)
+ continue;
+
+ FirstMIToRegion[&*Entry.first] = Region;
+ if (Entry.second != Entry.first->getParent()->end())
+ LastMIToRegion[&*Entry.second] = Region;
+ }
+
+ // Rewrite the MFMAs to AGPR, and insert any copies as needed.
+ // The general assumption of the algorithm (and the previous cost calculation)
+ // is that it is better to insert the copies in the MBB of the def of the src2
+ // operands, and in the MBB of the user of the dest operands. This is based on
+ // the assumption that the MFMAs are likely to appear in loop bodies, while
+ // the src2 and dest operands are live-in / live-out of the loop. Due to this
+ // design, the algorithm for finding copy insertion points is more
+ // complicated.
+ //
+ // There are three main cases to handle: 1. the reaching defs of the src2
+ // operands, 2. the reaching uses of the dst operands, and 3. the reaching
+ // defs of the reaching uses of the dst operand.
+ //
+ // In the first case, we simply insert copies after each of the reaching
+ // definitions. In the second case, we collect all the uses of a given dest
+ // and organize them by MBB. Then, we insert 1 copy for each MBB before the
+ // earliest use. Since the use may have multiple reaching defs, and since we
+ // want to replace the register it is using with the result of the copy, we
+ // must handle case 3. In the third case, we simply insert a copy after each
+ // of the reaching defs to connect to the copy of the reaching uses of the dst
+ // reg. This allows us to avoid inserting copies next to the MFMAs.
+ //
+ // While inserting the copies, we maintain a map of operands which will use
+ // different regs (i.e. the result of the copies). For example, a case 1 src2
+ // operand will use the register result of the copies after the reaching defs,
+ // as opposed to the original register. Now that we have completed our copy
+ // analysis and placement, we can bulk update the registers. We do this
+ // separately so as to avoid complicating the reachingDef and reachingUse
+ // queries.
+ //
+ // While inserting the copies, we also maintain a list of registers which we
+ // will want to reclassify as AGPR. After doing the copy insertion and the
+ // register replacement, we can finally do the reclassification. This uses the
+ // redef map, as the registers we are interested in reclassifying may be
+ // replaced by the result of a copy. We must do this after the copy analysis
+ // and placement as we must have an accurate redef map -- otherwise we may end
+ // up creating illegal instructions.
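+ //
+ // Schematic example (illustrative only):
+ //   bb.0: %a = <non-MFMA def>          ; case 1: copy inserted after this def
+ //   bb.1: %d = V_MFMA ..., %a (src2)   ; rewritten to its AGPR-dst form
+ //   bb.2: ... = <non-MFMA use of %d>   ; case 2: one copy before the earliest
+ //                                      ;         use of %d in bb.2
+ // If that use also has another, non-MFMA reaching def of %d, a case 3 copy is
+ // inserted after that def so the use reads the copied register uniformly.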
+
+ // The original registers of the MFMA that need to be reclassified as AGPR.
+ std::set<Register> RewriteRegs;
+ // The map of an original register in the MFMA to a new register (result of a
+ // copy) that it should be replaced with.
+ DenseMap<Register, Register> RedefMap;
+ // The map of the original MFMA registers to the relevant MFMA operands.
+ DenseMap<Register, std::set<MachineOperand *>> ReplaceMap;
+ // The map of reaching defs for a given register -- to avoid duplicate copies.
+ DenseMap<Register, SmallPtrSet<MachineInstr *, 8>> ReachingDefCopyMap;
+ // The map of reaching uses for a given register by basic block -- to avoid
+ // duplicate copies and to calculate per-MBB insert points.
+ DenseMap<unsigned, DenseMap<Register, SmallPtrSet<MachineOperand *, 8>>>
+ ReachingUseTracker;
+
+ for (auto &[MI, OriginalOpcode] : RewriteCands) {
+
+ int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI->getOpcode());
+ if (ReplacementOp == -1)
+ continue;
+ MI->setDesc(TII->get(ReplacementOp));
+
+ // Case 1: insert copies for the reaching defs of the Src2Reg.
+ MachineOperand *Src2 = TII->getNamedOperand(*MI, AMDGPU::OpName::src2);
+
+ if (Src2->isReg()) {
+ Register Src2Reg = Src2->getReg();
+ if (!Src2Reg.isVirtual())
+ return false;
+
+ Register MappedReg = Src2->getReg();
+ SmallVector<SlotIndex, 8> Src2ReachingDefs;
+ findReachingDefs(*Src2, DAG.LIS, Src2ReachingDefs);
+ SmallVector<MachineInstr *, 8> Src2DefsReplace;
+
+ for (auto RDIndex : Src2ReachingDefs) {
+ MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIndex);
+ if (TII->isMAI(*RD))
+ continue;
+
+ // If there is a non-MAI reaching def, then we need a copy.
+ if (find(Src2DefsReplace, RD) == Src2DefsReplace.end())
+ Src2DefsReplace.push_back(RD);
+ }
+
+ if (!Src2DefsReplace.empty()) {
+ if (RedefMap.contains(Src2Reg))
+ MappedReg = RedefMap[Src2Reg];
+ else {
+ assert(!ReachingDefCopyMap.contains(Src2Reg));
+ const TargetRegisterClass *Src2RC = DAG.MRI.getRegClass(Src2Reg);
+ const TargetRegisterClass *VGPRRC =
+ SRI->getEquivalentVGPRClass(Src2RC);
+
+ // Track the mapping of the original register to the new register.
+ MappedReg = DAG.MRI.createVirtualRegister(VGPRRC);
+ RedefMap[Src2Reg] = MappedReg;
+ }
+
+ // If none exists, create a copy from this reaching def.
+ // We may have inserted a copy already in an earlier iteration.
+ for (MachineInstr *RD : Src2DefsReplace) {
+ // Do not create redundant copies.
+ if (ReachingDefCopyMap[Src2Reg].insert(RD).second) {
+ MachineInstrBuilder VGPRCopy =
+ BuildMIAfter(*RD->getParent(), RD->getIterator(),
+ RD->getDebugLoc(), TII->get(TargetOpcode::COPY))
+ .addDef(MappedReg, 0, 0)
+ .addUse(Src2Reg, 0, 0);
+ DAG.LIS->InsertMachineInstrInMaps(*VGPRCopy);
+
+ // If this reaching def was the last MI in the region, update the
+ // region boundaries.
+ if (LastMIToRegion.contains(RD)) {
+ unsigned UpdateRegion = LastMIToRegion[RD];
+ DAG.Regions[UpdateRegion].second = VGPRCopy;
+ LastMIToRegion.erase(RD);
+ }
+ }
+ }
+ }
+
+ // Track the register for reclassification
+ RewriteRegs.insert(Src2Reg);
+
+ // Always insert the operand for replacement. If this corresponds with a
+ // chain of tied-defs, we may not see the VGPR requirement until later.
+ ReplaceMap[Src2Reg].insert(Src2);
+ }
+
+ // Case 2 and Case 3: insert copies before the reaching uses of the dsts,
+ // and after the reaching defs of the reaching uses of the dsts.
+
+ MachineOperand *Dst = &MI->getOperand(0);
+ Register DstReg = Dst->getReg();
+ if (!DstReg.isVirtual())
+ return false;
+
+ Register MappedReg = DstReg;
+ SmallVector<MachineOperand *, 8> DstReachingUses;
+
+ SmallVector<MachineOperand *, 8> DstReachingUseCopies;
+ SmallVector<MachineInstr *, 8> DstUseDefsReplace;
+
+ findReachingUses(MI, DAG.LIS, DstReachingUses);
+
+ for (MachineOperand *RUOp : DstReachingUses) {
+ if (TII->isMAI(*RUOp->getParent()))
+ continue;
+
+ // If there is a non-MAI reaching use, then we need a copy.
+ if (find(DstReachingUseCopies, RUOp) == DstReachingUseCopies.end())
+ DstReachingUseCopies.push_back(RUOp);
+ SmallVector<SlotIndex, 8> DstUsesReachingDefs;
+ findReachingDefs(*RUOp, DAG.LIS, DstUsesReachingDefs);
+
+ for (auto RDIndex : DstUsesReachingDefs) {
+ MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIndex);
+ if (TII->isMAI(*RD))
+ continue;
+
+ // If there is a non-MAI reaching def of this reaching use, then we will
+ // need a copy.
+ if (find(DstUseDefsReplace, RD) == DstUseDefsReplace.end())
+ DstUseDefsReplace.push_back(RD);
+ }
+ }
+
+ if (!DstUseDefsReplace.empty()) {
+ if (RedefMap.contains(DstReg))
+ MappedReg = RedefMap[DstReg];
+ else {
+ assert(!ReachingDefCopyMap.contains(DstReg));
+ const TargetRegisterClass *DstRC = DAG.MRI.getRegClass(DstReg);
+ const TargetRegisterClass *VGPRRC = SRI->getEquivalentVGPRClass(DstRC);
+
+ // Track the mapping of the original register to the new register.
+ MappedReg = DAG.MRI.createVirtualRegister(VGPRRC);
+ RedefMap[DstReg] = MappedReg;
+ }
+
+ // If none exists, create a copy from this reaching def.
+ // We may have inserted a copy already in an earlier iteration.
+ for (MachineInstr *RD : DstUseDefsReplace) {
+ // Do not create redundant copies.
+ if (ReachingDefCopyMap[DstReg].insert(RD).second) {
+ MachineInstrBuilder VGPRCopy =
+ BuildMIAfter(*RD->getParent(), RD->getIterator(),
+ RD->getDebugLoc(), TII->get(TargetOpcode::COPY))
+ .addDef(MappedReg, 0, 0)
+ .addUse(DstReg, 0, 0);
+ DAG.LIS->InsertMachineInstrInMaps(*VGPRCopy);
+
+ // If this reaching def was the last MI in the region, update the
+ // region boundaries.
+ if (LastMIToRegion.contains(RD)) {
+ unsigned UpdateRegion = LastMIToRegion[RD];
+ DAG.Regions[UpdateRegion].second = VGPRCopy;
+ LastMIToRegion.erase(RD);
+ }
+ }
+ }
+ }
+
+ for (MachineOperand *RU : DstReachingUseCopies) {
+ MachineBasicBlock *RUBlock = RU->getParent()->getParent();
+ // Just keep track of the reaching use of this register by block. After we
+ // have scanned all the MFMAs, we can find optimal insert points.
+ if (RUBlock != MI->getParent()) {
+ ReachingUseTracker[RUBlock->getNumber()][DstReg].insert(RU);
+ continue;
+ }
+
+ // Special case: the use is in the same block as the MFMA. Insert the copy
+ // just before the use.
+ const TargetRegisterClass *DstRC = DAG.MRI.getRegClass(DstReg);
+ const TargetRegisterClass *VGPRRC = SRI->getEquivalentVGPRClass(DstRC);
+ Register NewUseReg = DAG.MRI.createVirtualRegister(VGPRRC);
+ MachineInstr *UseInst = RU->getParent();
+ MachineInstrBuilder VGPRCopy =
+ BuildMI(*UseInst->getParent(), UseInst->getIterator(),
+ UseInst->getDebugLoc(), TII->get(TargetOpcode::COPY))
+ .addDef(NewUseReg, 0, 0)
+ .addUse(DstReg, 0, 0);
+ DAG.LIS->InsertMachineInstrInMaps(*VGPRCopy);
+ // Since we know this use has only one reaching def, we can replace the
+ // use reg.
+ RU->setReg(NewUseReg);
+ // Track the copy source operand for replacement.
+ ReplaceMap[DstReg].insert(&VGPRCopy->getOperand(1));
+ }
+
+ // Track the register for reclassification
+ RewriteRegs.insert(DstReg);
+ // Insert the dst operand for replacement. If this dst is in a chain of
+ // tied-def MFMAs, and the first src2 needs to be replaced with a new reg,
+ // all the corresponding operands need to be replaced.
+ ReplaceMap[DstReg].insert(Dst);
+ }
+
+ // Handle the copies for dst uses.
+ for (auto RUBlockEntry : ReachingUseTracker) {
+ for (auto RUDst : RUBlockEntry.second) {
+ MachineOperand *OpBegin = *RUDst.second.begin();
+ SlotIndex InstPt = DAG.LIS->getInstructionIndex(*OpBegin->getParent());
+
+ // Find the earliest use in this block.
+ for (auto User : RUDst.second) {
----------------
lucas-rami wrote:
```suggestion
for (auto *User : RUDst.second) {
```
https://github.com/llvm/llvm-project/pull/149367