[llvm] [AMDGPU] Add scheduling stage to rewrite MFMA from VGPR to AGPR (PR #149367)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 6 15:08:54 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Jeffrey Byrnes (jrbyrnes)

<details>
<summary>Changes</summary>

After https://github.com/llvm/llvm-project/pull/145025 we always produce the VGPR form of MFMAs. This is beneficial in some cases, but the AGPR form is still preferable in others: specifically, when MFMAs contribute high per-iteration register pressure (RP) and their results have no in-loop VGPR users. In such cases, selecting the VGPR form may cause an explosion in VGPR pressure, which degrades scheduling quality. The PostRA MFMA rewriter can improve register allocation (RA) for some of these cases, but it does not help the scheduler.
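
To make the pressure argument concrete, here is a minimal standalone sketch of a spill estimate in the spirit of the `getVGPRSpills` helper that this patch adds to `GCNRegPressure`. The struct and function names are hypothetical, and the limits are passed in by hand rather than queried from the subtarget. The sketch also ignores any alignment the real unified register count may apply.

```cpp
#include <algorithm>

// Hypothetical, hand-supplied limits. In the patch these come from the
// subtarget (getMaxNumVectorRegs / getMaxNumVGPRs), not from a struct.
struct VectorRegLimits {
  unsigned ArchVGPR; // addressable ArchVGPRs
  unsigned AGPR;     // addressable AGPRs
  unsigned Unified;  // combined budget of the unified register file
};

// Number of registers above a limit; zero when the pressure fits.
inline unsigned excessOver(unsigned Pressure, unsigned Limit) {
  return Pressure > Limit ? Pressure - Limit : 0;
}

// Estimated number of spilled vector registers: the larger of the unified
// excess and the sum of the per-class excesses.
inline unsigned estimateVGPRSpills(unsigned ArchPressure,
                                   unsigned AGPRPressure,
                                   const VectorRegLimits &L) {
  unsigned PerClass = excessOver(ArchPressure, L.ArchVGPR) +
                      excessOver(AGPRPressure, L.AGPR);
  unsigned Unified = excessOver(ArchPressure + AGPRPressure, L.Unified);
  return std::max(Unified, PerClass);
}
```

With illustrative limits of 256 ArchVGPRs, 256 AGPRs, and 512 unified registers, a region with 300 ArchVGPRs and 0 AGPRs yields an estimate of 44. Moving 64 of those registers to AGPRs (236 and 64) brings the estimate to 0, which is the effect the rewrite is after.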

This PR performs the rewrite during scheduling, as a separate scheduling stage. The stage only attempts the VGPR -> AGPR rewrite if ArchVGPR pressure exceeds the addressable limit and if the rewrite would not require any cross-register-class (RC) copies inside the loop. The reverse rewrite (AGPR form -> VGPR form) could also be implemented, but the assumption is that we always produce the VGPR form to begin with.
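
The decision described here can be sketched as a small cost model, loosely following `getRewriteCost` in the diff. This is a simplification under stated assumptions, not the patch's implementation: the types `RegionDelta` and `CopyInfo` are hypothetical, block frequency is reduced to an integer multiplier relative to the entry block, and blocks colder than the entry block (which the patch handles by dividing) are not modeled.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical per-region summary: estimated spills before and after the
// rewrite, plus the block frequency relative to the function entry.
struct RegionDelta {
  unsigned SpillsBefore;
  unsigned SpillsAfter;
  uint64_t RelFreq;
};

// Hypothetical summary of one cross-register-class copy the rewrite needs.
struct CopyInfo {
  unsigned NumRegs; // number of 32-bit registers copied
  uint64_t RelFreq; // frequency of the block holding the copy
};

// Negative or zero cost means the rewrite looks worthwhile.
inline int64_t rewriteCost(const std::vector<RegionDelta> &Regions,
                           const std::vector<CopyInfo> &Copies) {
  int64_t BestSpillCost = 0;
  for (const RegionDelta &R : Regions) {
    // One spill plus one reload per excess register, scaled by frequency.
    int64_t SpillCost =
        (int64_t(R.SpillsAfter) - int64_t(R.SpillsBefore)) * 2 *
        int64_t(R.RelFreq);
    // Any region that gets worse vetoes the rewrite outright.
    if (SpillCost > 0)
      return SpillCost;
    // Keep only the largest reduction to avoid double counting.
    BestSpillCost = std::min(BestSpillCost, SpillCost);
  }

  int64_t Cost = BestSpillCost;
  for (const CopyInfo &C : Copies)
    Cost += int64_t(C.NumRegs) * int64_t(C.RelFreq);
  return Cost;
}

inline bool shouldRewrite(const std::vector<RegionDelta> &Regions,
                          const std::vector<CopyInfo> &Copies) {
  return rewriteCost(Regions, Copies) <= 0;
}
```

Under this model, a hot loop region that sheds 44 spilled registers easily pays for a pair of 16-register copies placed outside the loop, whereas the same copies placed against a small saving in a cold block do not.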

This is a WIP:
	Needs more testing
	Still a bit undecided about the heuristic
	Considering making the implementation more generalized for other types of rewriting / transformations, though this may be left as a TODO

Putting up a draft for feedback.



---

Patch is 391.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149367.diff


6 Files Affected:

- (modified) llvm/include/llvm/CodeGen/MachineInstrBuilder.h (+15) 
- (modified) llvm/lib/Target/AMDGPU/GCNRegPressure.h (+30) 
- (modified) llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp (+643-4) 
- (modified) llvm/lib/Target/AMDGPU/GCNSchedStrategy.h (+65-5) 
- (added) llvm/test/CodeGen/AMDGPU/sched_mfma_rewrite_copies.mir (+5591) 
- (added) llvm/test/CodeGen/AMDGPU/sched_mfma_rewrite_cost.mir (+524) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h
index e63e77a8302c0..7a4bc392bfc47 100644
--- a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h
+++ b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h
@@ -454,6 +454,21 @@ inline MachineInstrBuilder BuildMI(MachineBasicBlock &BB,
       .setMMRAMetadata(MIMD.getMMRAMetadata());
 }
 
+/// This version of the builder inserts the newly-built instruction after the
+/// given position in the given MachineBasicBlock, and does NOT take a
+/// destination register.
+inline MachineInstrBuilder BuildMIAfter(MachineBasicBlock &BB,
+                                        MachineBasicBlock::iterator I,
+                                        const MIMetadata &MIMD,
+                                        const MCInstrDesc &MCID) {
+  MachineFunction &MF = *BB.getParent();
+  MachineInstr *MI = MF.CreateMachineInstr(MCID, MIMD.getDL());
+  BB.insertAfter(I, MI);
+  return MachineInstrBuilder(MF, MI)
+      .setPCSections(MIMD.getPCSections())
+      .setMMRAMetadata(MIMD.getMMRAMetadata());
+}
+
 inline MachineInstrBuilder BuildMI(MachineBasicBlock &BB,
                                    MachineBasicBlock::instr_iterator I,
                                    const MIMetadata &MIMD,
diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.h b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
index ea33a229110c1..91691ea96942d 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.h
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
@@ -90,6 +90,36 @@ struct GCNRegPressure {
                                                 DynamicVGPRBlockSize));
   }
 
+  unsigned getVGPRSpills(const GCNSubtarget &ST, MachineFunction &MF) {
+    if (!ST.hasGFX90AInsts())
+      return 0;
+
+    auto MaxVectorRegs = ST.getMaxNumVectorRegs(MF.getFunction());
+    unsigned ArchVGPRThreshold = MaxVectorRegs.first;
+    unsigned AGPRThreshold = MaxVectorRegs.second;
+
+    unsigned ArchPressure = getArchVGPRNum();
+    unsigned AGPRPressure = getAGPRNum();
+
+    unsigned ArchSpill = ArchPressure > ArchVGPRThreshold
+                             ? (ArchPressure - ArchVGPRThreshold)
+                             : 0;
+    unsigned AGPRSpill =
+        AGPRPressure > AGPRThreshold ? (AGPRPressure - AGPRThreshold) : 0;
+
+    unsigned UnifiedSpill = 0;
+
+    if (ST.hasGFX90AInsts()) {
+      unsigned CombinedThreshold = ST.getMaxNumVGPRs(MF);
+      unsigned UnifiedPressure = getVGPRNum(true);
+      UnifiedSpill = UnifiedPressure > CombinedThreshold
+                         ? (UnifiedPressure - CombinedThreshold)
+                         : 0;
+    }
+
+    return std::max(UnifiedSpill, (ArchSpill + AGPRSpill));
+  }
+
   void inc(unsigned Reg,
            LaneBitmask PrevMask,
            LaneBitmask NewMask,
diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
index ce1ce687d0038..564021740b90c 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
@@ -29,6 +29,7 @@
 #include "SIMachineFunctionInfo.h"
 #include "Utils/AMDGPUBaseInfo.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/RegisterClassInfo.h"
 #include "llvm/MC/LaneBitmask.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -528,6 +529,7 @@ GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
     const MachineSchedContext *C, bool IsLegacyScheduler)
     : GCNSchedStrategy(C) {
   SchedStages.push_back(GCNSchedStageID::OccInitialSchedule);
+  SchedStages.push_back(GCNSchedStageID::RewriteSchedule);
   SchedStages.push_back(GCNSchedStageID::UnclusteredHighRPReschedule);
   SchedStages.push_back(GCNSchedStageID::ClusteredLowOccupancyReschedule);
   SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
@@ -778,6 +780,8 @@ GCNScheduleDAGMILive::createSchedStage(GCNSchedStageID SchedStageID) {
   switch (SchedStageID) {
   case GCNSchedStageID::OccInitialSchedule:
     return std::make_unique<OccInitialScheduleStage>(SchedStageID, *this);
+  case GCNSchedStageID::RewriteSchedule:
+    return std::make_unique<RewriteScheduleStage>(SchedStageID, *this);
   case GCNSchedStageID::UnclusteredHighRPReschedule:
     return std::make_unique<UnclusteredHighRPStage>(SchedStageID, *this);
   case GCNSchedStageID::ClusteredLowOccupancyReschedule:
@@ -898,13 +902,11 @@ GCNScheduleDAGMILive::getRegionLiveInMap() const {
   RegionFirstMIs.reserve(Regions.size());
   auto I = Regions.rbegin(), E = Regions.rend();
   do {
-    const MachineBasicBlock *MBB = I->first->getParent();
     auto *MI = &*skipDebugInstructionsForward(I->first, I->second);
     RegionFirstMIs.push_back(MI);
-    do {
-      ++I;
-    } while (I != E && I->first->getParent() == MBB);
+    ++I;
   } while (I != E);
+
   return getLiveRegMap(RegionFirstMIs, /*After=*/false, *LIS);
 }
 
@@ -1003,6 +1005,9 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const GCNSchedStageID &StageID) {
   case GCNSchedStageID::OccInitialSchedule:
     OS << "Max Occupancy Initial Schedule";
     break;
+  case GCNSchedStageID::RewriteSchedule:
+    OS << "Instruction Rewriting Reschedule";
+    break;
   case GCNSchedStageID::UnclusteredHighRPReschedule:
     OS << "Unclustered High Register Pressure Reschedule";
     break;
@@ -1036,6 +1041,112 @@ bool GCNSchedStage::initGCNSchedStage() {
   return true;
 }
 
+SlotIndex
+RewriteScheduleStage::findReachingDefs(MachineOperand &UseMO,
+                                       LiveIntervals *LIS,
+                                       SmallVectorImpl<SlotIndex> &DefIdxs) {
+  assert(UseMO.isReg());
+  MachineInstr *UseMI = UseMO.getParent();
+  LiveInterval &UseLI = LIS->getInterval(UseMO.getReg());
+  auto VNInfo = UseLI.getVNInfoAt(LIS->getInstructionIndex(*UseMI));
+
+  SlotIndex DefMBBStart =
+      LIS->getMBBStartIdx(LIS->getMBBFromIndex(VNInfo->def));
+
+  // If the def is in the block, then it must be the only reaching def.
+  if (DefMBBStart != VNInfo->def) {
+    DefIdxs.push_back(VNInfo->def);
+    return VNInfo->def;
+  }
+
+  SmallPtrSet<MachineBasicBlock *, 8> Visited;
+  SmallVector<MachineBasicBlock *, 8> Worklist;
+
+  Visited.insert(UseMI->getParent());
+
+  // Mark the predecessor blocks for traversal
+  for (auto PredMBB : UseMI->getParent()->predecessors()) {
+    Worklist.push_back(PredMBB);
+    Visited.insert(PredMBB);
+  }
+
+  while (!Worklist.empty()) {
+    MachineBasicBlock *CurrMBB = Worklist.pop_back_val();
+
+    SlotIndex CurrMBBEnd = LIS->getMBBEndIdx(CurrMBB);
+    auto VNInfo = UseLI.getVNInfoAt(CurrMBBEnd.getPrevSlot());
+
+    MachineBasicBlock *DefMBB = LIS->getMBBFromIndex(VNInfo->def);
+    SlotIndex DefMBBStart = LIS->getMBBStartIdx(DefMBB);
+
+    // If there is a def in this block, then add it to the list. This is the
+    // reaching def of this path.
+    if (DefMBBStart != VNInfo->def) {
+      DefIdxs.push_back(VNInfo->def);
+      continue;
+    }
+
+    for (auto PredMBB : DefMBB->predecessors()) {
+      if (Visited.insert(PredMBB).second)
+        Worklist.push_back(PredMBB);
+    }
+  }
+
+  return VNInfo->def;
+}
+
+void RewriteScheduleStage::findReachingUses(
+    MachineInstr *DefMI, LiveIntervals *LIS,
+    SmallVectorImpl<MachineOperand *> &ReachingUses) {
+  SlotIndex DefIdx = LIS->getInstructionIndex(*DefMI);
+  for (auto &UseMO :
+       DAG.MRI.use_nodbg_operands(DefMI->getOperand(0).getReg())) {
+    SmallVector<SlotIndex, 8> ReachingDefIndexes;
+    findReachingDefs(UseMO, LIS, ReachingDefIndexes);
+
+    // If we find a use that contains this DefMI in its reachingDefs, then it is
+    // a reaching use.
+    if (find_if(ReachingDefIndexes, [DefIdx](SlotIndex RDIdx) {
+          return SlotIndex::isSameInstr(RDIdx, DefIdx);
+        }) != ReachingDefIndexes.end())
+      ReachingUses.push_back(&UseMO);
+  }
+}
+
+bool RewriteScheduleStage::initGCNSchedStage() {
+  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
+
+  RegionsWithExcessArchVGPR.resize(DAG.Regions.size());
+  RegionsWithExcessArchVGPR.reset();
+  for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+    auto PressureBefore = DAG.Pressure[Region];
+    if (PressureBefore.getArchVGPRNum() > ST.getAddressableNumArchVGPRs())
+      RegionsWithExcessArchVGPR[Region] = true;
+  }
+
+  if (!ST.hasGFX90AInsts() || RegionsWithExcessArchVGPR.none())
+    return false;
+
+  TII = ST.getInstrInfo();
+  SRI = ST.getRegisterInfo();
+
+  std::vector<std::pair<MachineInstr *, unsigned>> RewriteCands;
+  DenseMap<MachineBasicBlock *, std::set<Register>> CopyForUse;
+  SmallPtrSet<MachineInstr *, 8> CopyForDef;
+
+  if (!initHeuristics(RewriteCands, CopyForUse, CopyForDef))
+    return false;
+
+  int64_t Cost = getRewriteCost(RewriteCands, CopyForUse, CopyForDef);
+
+  // If we haven't found the beneficial conditions, prefer the VGPR form which
+  // may result in less cross RC copies.
+  if (Cost > 0)
+    return false;
+
+  return rewrite(RewriteCands);
+}
+
 bool UnclusteredHighRPStage::initGCNSchedStage() {
   if (DisableUnclusterHighRP)
     return false;
@@ -1642,6 +1753,534 @@ void GCNSchedStage::revertScheduling() {
   DAG.Regions[RegionIdx] = std::pair(DAG.RegionBegin, DAG.RegionEnd);
 }
 
+bool RewriteScheduleStage::isRewriteCandidate(MachineInstr *MI) const {
+
+  if (!static_cast<const SIInstrInfo *>(DAG.TII)->isMAI(*MI))
+    return false;
+  return AMDGPU::getMFMASrcCVDstAGPROp(MI->getOpcode()) != -1;
+}
+
+bool RewriteScheduleStage::initHeuristics(
+    std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+    DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+    SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+  // Prepare for the heuristics
+  for (auto &MBB : MF) {
+    for (auto &MI : MBB) {
+      if (isRewriteCandidate(&MI)) {
+        int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode());
+        if (ReplacementOp == -1)
+          continue;
+
+        RewriteCands.push_back({&MI, MI.getOpcode()});
+        MI.setDesc(TII->get(ReplacementOp));
+
+        MachineOperand *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2);
+        if (Src2->isReg()) {
+          SmallVector<SlotIndex, 8> Src2ReachingDefs;
+          findReachingDefs(*Src2, DAG.LIS, Src2ReachingDefs);
+
+          // For any definition of the src2 register which is non-MFMA, we
+          // insert a copy.
+          for (SlotIndex RDIdx : Src2ReachingDefs) {
+            MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIdx);
+            if (!TII->isMAI(*RD))
+              CopyForDef.insert(RD);
+          }
+        }
+
+        MachineOperand &Dst = MI.getOperand(0);
+        SmallVector<MachineOperand *, 8> DstReachingUses;
+
+        findReachingUses(&MI, DAG.LIS, DstReachingUses);
+
+        for (MachineOperand *RUOp : DstReachingUses) {
+          if (TII->isMAI(*RUOp->getParent()))
+            continue;
+
+          // For any user of the result of the MFMA which is not an MFMA, we
+          // insert a copy. For a given register, we will only insert one copy
+          // per user block.
+          CopyForUse[RUOp->getParent()->getParent()].insert(RUOp->getReg());
+
+          SmallVector<SlotIndex, 8> DstUsesReachingDefs;
+          findReachingDefs(*RUOp, DAG.LIS, DstUsesReachingDefs);
+
+          for (auto RDIndex : DstUsesReachingDefs) {
+            MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIndex);
+            if (TII->isMAI(*RD))
+              continue;
+
+            // For any definition of the user of the MFMA which is not an MFMA,
+            // we insert a copy. We do this to transform all the reaching defs
+            // of this use to AGPR. By doing this, we can insert a copy from
+            // AGPR to VGPR at the user rather than after the MFMA.
+            CopyForDef.insert(RD);
+          }
+        }
+
+        // Do the rewrite to allow for updated RP calculation.
+        const TargetRegisterClass *VGPRRC = DAG.MRI.getRegClass(Dst.getReg());
+        const TargetRegisterClass *AGPRRC = SRI->getEquivalentAGPRClass(VGPRRC);
+        DAG.MRI.setRegClass(Dst.getReg(), AGPRRC);
+        if (Src2->isReg())
+          DAG.MRI.setRegClass(Src2->getReg(), AGPRRC);
+      }
+    }
+  }
+
+  return true;
+}
+
+int64_t RewriteScheduleStage::getRewriteCost(
+    std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+    DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+    SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+  MBFI.calculate(MF, MBPI, *DAG.MLI);
+  int64_t BestSpillCost = 0;
+  int64_t Cost = 0;
+
+  for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+    if (!RegionsWithExcessArchVGPR[Region])
+      continue;
+
+    auto PressureBefore = DAG.Pressure[Region];
+    unsigned SpillCostBefore = PressureBefore.getVGPRSpills(ST, MF);
+
+    // For the cases we care about (i.e. ArchVGPR usage is greater than the
+    // addressable limit), rewriting alone should bring pressure to manageable
+    // level. If we find any such region, then the rewrite is potentially
+    // beneficial.
+    auto PressureAfter = DAG.getRealRegPressure(Region);
+    unsigned SpillCostAfter = PressureAfter.getVGPRSpills(ST, MF);
+
+    uint64_t EntryFreq = MBFI.getEntryFreq().getFrequency();
+    uint64_t BlockFreq =
+        MBFI.getBlockFreq(DAG.Regions[Region].first->getParent())
+            .getFrequency();
+
+    bool RelativeFreqIsDenom = EntryFreq > BlockFreq;
+    uint64_t RelativeFreq = EntryFreq && BlockFreq
+                                ? (RelativeFreqIsDenom ? EntryFreq / BlockFreq
+                                                       : BlockFreq / EntryFreq)
+                                : 1;
+
+    // This assumes perfect spilling / splitting -- using one spill / copy
+    // instruction and one restoreFrom / copy for each excess register,
+    int64_t SpillCost = ((int)SpillCostAfter - (int)SpillCostBefore) * 2;
+
+    // Also account for the block frequency.
+    if (RelativeFreqIsDenom)
+      SpillCost /= (int64_t)RelativeFreq;
+    else
+      SpillCost *= (int64_t)RelativeFreq;
+
+    // If we have increased spilling in any block, just bail.
+    if (SpillCost > 0)
+      return SpillCost;
+
+    if (SpillCost < BestSpillCost)
+      BestSpillCost = SpillCost;
+  }
+
+  // Set the cost to the largest decrease in spill cost in order to not double
+  // count spill reductions.
+  Cost = BestSpillCost;
+
+  assert(Cost <= 0);
+
+  unsigned CopyCost = 0;
+
+  uint64_t EntryFreq = MBFI.getEntryFreq().getFrequency();
+
+  // For each CopyForDef, increase the cost by the register size while
+  // accounting for block frequency.
+  for (auto *DefMI : CopyForDef) {
+    auto DefReg = DefMI->getOperand(0).getReg();
+    uint64_t DefFreq =
+        EntryFreq
+            ? MBFI.getBlockFreq(DefMI->getParent()).getFrequency() / EntryFreq
+            : 1;
+
+    unsigned RegSize = DAG.TRI->getRegSizeInBits(*DAG.MRI.getRegClass(DefReg));
+    unsigned NumRegs = std::max(RegSize / 32, (unsigned)1);
+    CopyCost += NumRegs * DefFreq;
+  }
+
+  // Account for CopyForUse copies in each block that the register is used.
+  for (auto &UseEntry : CopyForUse) {
+    uint64_t UseFreq =
+        EntryFreq ? MBFI.getBlockFreq(UseEntry.first).getFrequency() / EntryFreq
+                  : 1;
+
+    for (auto UseReg : UseEntry.second) {
+      unsigned RegSize =
+          DAG.TRI->getRegSizeInBits(*DAG.MRI.getRegClass(UseReg));
+      unsigned NumRegs = std::max(RegSize / 32, (unsigned)1);
+      CopyCost += NumRegs * UseFreq;
+    }
+  }
+
+  Cost += CopyCost;
+
+  // Reset to the vgpr form. We must do rewriting after copy-insertion, as some
+  // defs of the register may require VGPR.
+  for (auto RI : RewriteCands) {
+    MachineInstr *MI = RI.first;
+
+    assert(TII->isMAI(*MI));
+    const TargetRegisterClass *AGPRRC =
+        DAG.MRI.getRegClass(MI->getOperand(0).getReg());
+    const TargetRegisterClass *VGPRRC = SRI->getEquivalentVGPRClass(AGPRRC);
+
+    MachineOperand *Src2 = TII->getNamedOperand(*MI, AMDGPU::OpName::src2);
+    assert(Src2);
+
+    if (Src2->isReg()) {
+      DAG.MRI.setRegClass(Src2->getReg(), VGPRRC);
+    }
+    DAG.MRI.setRegClass(MI->getOperand(0).getReg(), VGPRRC);
+    MI->setDesc(TII->get(RI.second));
+  }
+
+  return Cost;
+}
+
+bool RewriteScheduleStage::rewrite(
+    std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands) {
+  DenseMap<MachineInstr *, unsigned> FirstMIToRegion;
+  DenseMap<MachineInstr *, unsigned> LastMIToRegion;
+
+  for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+    auto Entry = DAG.Regions[Region];
+    if (Entry.first == Entry.second)
+      continue;
+
+    FirstMIToRegion[&*Entry.first] = Region;
+    if (Entry.second != Entry.first->getParent()->end())
+      LastMIToRegion[&*Entry.second] = Region;
+  }
+
+  // Rewrite the MFMAs to AGPR, and insert any copies as needed.
+  // The general assumption of the algorithm (and the previous cost calculation)
+  // is that it is better to insert the copies in the MBB of the def of the src2
+  // operands, and in the MBB of the user of the dest operands. This is based on
+  // the assumption that the MFMAs are likely to appear in loop bodies, while
+  // the src2 and dest operands are live-in / live-out of the loop. Due to this
+  // design, the algorithm for finding copy insertion points is more
+  // complicated.
+  //
+  // There are three main cases to handle: 1. the reaching defs of the src2
+  // operands, 2. the reaching uses of the dst operands, and 3. the reaching
+  // defs of the reaching uses of the dst operand.
+  //
+  // In the first case, we simply insert copies after each of the reaching
+  // definitions. In the second case, we collect all the uses of a given dest
+  // and organize them by MBB. Then, we insert 1 copy for each MBB before the
+  // earliest use. Since the use may have multiple reaching defs, and since we
+  // want to replace the register it is using with the result of the copy, we
+  // must handle case 3. In the third case, we simply insert a copy after each
+  // of the reaching defs to connect to the copy of the reaching uses of the dst
+  // reg. This allows us to avoid inserting copies next to the MFMAs.
+  //
+  // While inserting the copies, we maintain a map of operands which will use
+  // different regs (i.e. the result of the copies). For example, a case 1 src2
+  // operand will use the register result of the copies after the reaching defs,
+  // as opposed to the original register. Now that we have completed our copy
+  // analysis and placement, we can bulk update the registers. We do this
+  // separately as to avoid complicating the reachingDef and reachingUse
+  // queries.
+  //
+  // While inserting the copies, we also maintain a list of registers which we
+  // will want to reclassify as AGPR. After doing the copy insertion and the
+  // register replacement, we can finally do the reclassification. This uses the
+  // redef map, as the registers we are interested in reclassifying may be
+  // replaced by the result of a copy. We must do this after the copy analysis
+  // and placement as we must have an accurate redef map -- otherwise we may end
+  // up creating illegal instructions.
+
+  // The original registers of the MFMA that need to be reclassified as AGPR
+  std::set<Register> RewriteRegs;
+  // The map of an original register in the MFMA to a new register (result of a
+  // copy) that it should be replaced with.
+  DenseMap<Register, Register> RedefMap;
+  // The map of the original MFMA registers to the relevant MFMA operands.
+  DenseMap<Register, std::set<MachineOperand *>> ReplaceMap;
+  // The map of reaching defs for a given register -- to avoid duplicate copies.
+  DenseMap<Register, SmallPtrSet<MachineInstr *, 8>> ReachingDefCopyMap;
+  // The map of reaching uses for a given register by basic block -- to avoid
+  // duplicate copies and to calculate per MBB insert pts.
+  DenseMap<unsigned, DenseMap<Register, SmallPtrSet<MachineOperand *, 8>>>
+      ReachingUseTracker;
+
+  for (auto &RI : RewriteCands) {
+    MachineInstr &MI = *RI.first;
+
+    int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode());
+    if (ReplacementOp == -1)
+      continue;
+    MI.setDesc(TII->get(ReplacementOp));
+
+    // Case 1: insert copies for the reaching defs of the Src2Reg.
+    MachineOperan...
[truncated]

``````````
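
The backward worklist walk used by `findReachingDefs` in the diff can be illustrated on a toy CFG. The sketch below is a simplified analogue under stated assumptions, not the patch's code: `Block` and `reachingDefsAtBlockEntry` are hypothetical names, each block records at most its last definition of a single register, the use is assumed to sit at the top of its block, and no `LiveIntervals` value numbers are involved.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical toy CFG node; not an LLVM type.
struct Block {
  std::vector<std::size_t> Preds; // indices of predecessor blocks
  int LastDef;                    // id of the last def in this block, -1 if none
};

// Collects the ids of all defs that can reach a use at the top of UseBlock by
// walking predecessors backwards; each path stops at the first def it meets.
inline std::set<int>
reachingDefsAtBlockEntry(const std::vector<Block> &CFG, std::size_t UseBlock) {
  assert(UseBlock < CFG.size() && "use block index out of range");
  std::set<int> Defs;
  std::vector<bool> Visited(CFG.size(), false);
  std::vector<std::size_t> Worklist;

  auto Enqueue = [&](std::size_t B) {
    if (!Visited[B]) {
      Visited[B] = true;
      Worklist.push_back(B);
    }
  };

  for (std::size_t P : CFG[UseBlock].Preds)
    Enqueue(P);

  while (!Worklist.empty()) {
    std::size_t B = Worklist.back();
    Worklist.pop_back();

    // A def in this block is the reaching def for every path through it.
    if (CFG[B].LastDef != -1) {
      Defs.insert(CFG[B].LastDef);
      continue;
    }
    for (std::size_t P : CFG[B].Preds)
      Enqueue(P);
  }
  return Defs;
}
```

For a preheader that defines the accumulator and a self-looping body that redefines it, the use at the top of the body sees both definitions. This is the situation the rewrite stage has to handle when deciding where copies are needed for `src2`.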

</details>


https://github.com/llvm/llvm-project/pull/149367

