[llvm] [AMDGPU] Add scheduling stage to rewrite MFMA from VGPR to AGPR (PR #170335)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 4 11:18:28 PST 2025


================
@@ -1837,6 +1945,532 @@ void GCNSchedStage::revertScheduling() {
   DAG.Regions[RegionIdx] = std::pair(DAG.RegionBegin, DAG.RegionEnd);
 }
 
+bool RewriteScheduleStage::isRewriteCandidate(MachineInstr *MI) const {
+
+  if (!static_cast<const SIInstrInfo *>(DAG.TII)->isMAI(*MI))
+    return false;
+  return AMDGPU::getMFMASrcCVDstAGPROp(MI->getOpcode()) != -1;
+}
+
+bool RewriteScheduleStage::initHeuristics(
+    std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+    DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+    SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+  // Prepare for the heuristics
+  for (auto &MBB : MF) {
+    for (auto &MI : MBB) {
+      if (!isRewriteCandidate(&MI))
+        continue;
+
+      int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode());
+      assert(ReplacementOp != -1);
+
+      RewriteCands.push_back({&MI, MI.getOpcode()});
+      MI.setDesc(TII->get(ReplacementOp));
+
+      MachineOperand *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2);
+      if (Src2->isReg()) {
+        SmallVector<SlotIndex, 8> Src2ReachingDefs;
+        findReachingDefs(*Src2, DAG.LIS, Src2ReachingDefs);
+
+        // For any definition of the src2 register which is non-MFMA, we
+        // insert a copy.
+        for (SlotIndex RDIdx : Src2ReachingDefs) {
+          MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIdx);
+          if (!TII->isMAI(*RD))
+            CopyForDef.insert(RD);
+        }
+      }
+
+      MachineOperand &Dst = MI.getOperand(0);
+      SmallVector<MachineOperand *, 8> DstReachingUses;
+
+      findReachingUses(&MI, DAG.LIS, DstReachingUses);
+
+      for (MachineOperand *RUOp : DstReachingUses) {
+        if (TII->isMAI(*RUOp->getParent()))
+          continue;
+
+        // For any user of the result of the MFMA which is not an MFMA, we
+        // insert a copy. For a given register, we will only insert one copy
+        // per user block.
+        CopyForUse[RUOp->getParent()->getParent()].insert(RUOp->getReg());
+
+        SmallVector<SlotIndex, 8> DstUsesReachingDefs;
+        findReachingDefs(*RUOp, DAG.LIS, DstUsesReachingDefs);
+
+        for (auto RDIndex : DstUsesReachingDefs) {
----------------
arsenm wrote:

```suggestion
        for (SlotIndex RDIndex : DstUsesReachingDefs) {
```

https://github.com/llvm/llvm-project/pull/170335


More information about the llvm-commits mailing list