[llvm] [AMDGPU] Add scheduling stage to rewrite MFMA from VGPR to AGPR (PR #149367)
Lucas Ramirez via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 26 09:18:32 PDT 2025
================
@@ -1642,6 +1755,534 @@ void GCNSchedStage::revertScheduling() {
DAG.Regions[RegionIdx] = std::pair(DAG.RegionBegin, DAG.RegionEnd);
}
+/// Returns true iff \p MI is an MAI instruction (as reported by
+/// SIInstrInfo::isMAI) and AMDGPU::getMFMASrcCVDstAGPROp yields a value other
+/// than -1 for its opcode.
+// NOTE(review): presumably -1 means "no variant with AGPR src2/dst exists" --
+// confirm against the definition of getMFMASrcCVDstAGPROp.
+bool RewriteScheduleStage::isRewriteCandidate(MachineInstr *MI) const {
+
+ // Anything that is not an MAI instruction can never be rewritten.
+ if (!static_cast<const SIInstrInfo *>(DAG.TII)->isMAI(*MI))
+ return false;
+ // Candidate only if the opcode lookup does not report -1.
+ return AMDGPU::getMFMASrcCVDstAGPROp(MI->getOpcode()) != -1;
+}
+
+/// Walks every instruction of MF and, for each rewrite candidate (see
+/// isRewriteCandidate), eagerly applies the rewrite while collecting the data
+/// the caller's heuristics need:
+///  - \p RewriteCands receives (instruction, original opcode) for every
+///    instruction whose descriptor is switched to the replacement opcode.
+///  - \p CopyForUse receives, per basic block, the registers read by non-MAI
+///    users of a candidate's result (operand 0).
+///  - \p CopyForDef receives the non-MAI instructions that define either a
+///    candidate's src2 register or a register read by such a non-MAI user.
+/// The register class of the result, and of src2 when it is a register, is
+/// changed to the equivalent AGPR class.
+///
+/// Note that this mutates both the instructions and MRI; the original opcode
+/// is only preserved in \p RewriteCands.
+///
+/// \returns true unconditionally.
+bool RewriteScheduleStage::initHeuristics(
+ std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+ DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+ SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+ // Prepare for the heuristics
+ for (auto &MBB : MF) {
+ for (auto &MI : MBB) {
+ if (isRewriteCandidate(&MI)) {
+ // NOTE(review): isRewriteCandidate already rejects opcodes for which
+ // this lookup yields -1, so the check below looks redundant -- confirm.
+ int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode());
+ if (ReplacementOp == -1)
+ continue;
+
+ // Remember the original opcode before overwriting the descriptor.
+ RewriteCands.push_back({&MI, MI.getOpcode()});
+ MI.setDesc(TII->get(ReplacementOp));
+
+ // NOTE(review): Src2 is dereferenced without a null check; this assumes
+ // every candidate has a src2 operand -- confirm.
+ MachineOperand *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2);
+ if (Src2->isReg()) {
+ SmallVector<SlotIndex, 8> Src2ReachingDefs;
+ findReachingDefs(*Src2, DAG.LIS, Src2ReachingDefs);
+
+ // For any definition of the src2 register which is non-MFMA, we
+ // insert a copy.
+ for (SlotIndex RDIdx : Src2ReachingDefs) {
+ MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIdx);
+ if (!TII->isMAI(*RD))
+ CopyForDef.insert(RD);
+ }
+ }
+
+ // Operand 0 is treated as the result register of the candidate.
+ MachineOperand &Dst = MI.getOperand(0);
+ SmallVector<MachineOperand *, 8> DstReachingUses;
+
+ findReachingUses(&MI, DAG.LIS, DstReachingUses);
+
+ for (MachineOperand *RUOp : DstReachingUses) {
+ // Users that are themselves MAI instructions need no copy.
+ if (TII->isMAI(*RUOp->getParent()))
+ continue;
+
+ // For any user of the result of the MFMA which is not an MFMA, we
+ // insert a copy. For a given register, we will only insert one copy
+ // per user block.
+ CopyForUse[RUOp->getParent()->getParent()].insert(RUOp->getReg());
+
+ SmallVector<SlotIndex, 8> DstUsesReachingDefs;
+ findReachingDefs(*RUOp, DAG.LIS, DstUsesReachingDefs);
+
+ for (auto RDIndex : DstUsesReachingDefs) {
+ MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIndex);
+ if (TII->isMAI(*RD))
+ continue;
+
+ // For any definition of the user of the MFMA which is not an MFMA,
+ // we insert a copy. We do this to transform all the reaching defs
+ // of this use to AGPR. By doing this, we can insert a copy from
+ // AGPR to VGPR at the user rather than after the MFMA.
+ CopyForDef.insert(RD);
+ }
+ }
+
+ // Do the rewrite to allow for updated RP calculation.
+ // NOTE(review): src2 receives the AGPR class derived from the result's
+ // class; this assumes src2 and the result share a register class --
+ // confirm.
+ const TargetRegisterClass *VGPRRC = DAG.MRI.getRegClass(Dst.getReg());
+ const TargetRegisterClass *AGPRRC = SRI->getEquivalentAGPRClass(VGPRRC);
+ DAG.MRI.setRegClass(Dst.getReg(), AGPRRC);
+ if (Src2->isReg())
+ DAG.MRI.setRegClass(Src2->getReg(), AGPRRC);
+ }
+ }
+ }
+
+ return true;
+}
+
+int64_t RewriteScheduleStage::getRewriteCost(
+ std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+ DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+ SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+ MBFI.calculate(MF, MBPI, *DAG.MLI);
+ int64_t BestSpillCost = 0;
+ int64_t Cost = 0;
+
+ for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+ if (!RegionsWithExcessArchVGPR[Region])
+ continue;
+
+ auto PressureBefore = DAG.Pressure[Region];
+ unsigned SpillCostBefore = PressureBefore.getVGPRSpills(ST, MF);
+
+ // For the cases we care about (i.e. ArchVGPR usage is greater than the
+ // addressable limit), rewriting alone should bring pressure to manageable
+ // level. If we find any such region, then the rewrite is potentially
+ // beneficial.
+ auto PressureAfter = DAG.getRealRegPressure(Region);
+ unsigned SpillCostAfter = PressureAfter.getVGPRSpills(ST, MF);
+
+ uint64_t EntryFreq = MBFI.getEntryFreq().getFrequency();
----------------
lucas-rami wrote:
Hoist the `EntryFreq` computation outside the loop (it does not depend on `Region`) and remove the redefinition below.
https://github.com/llvm/llvm-project/pull/149367
More information about the llvm-commits
mailing list