[llvm] AMDGPU: Handle rewriting non-tied MFMA to AGPR form (PR #153015)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 14 20:01:16 PDT 2025


================
@@ -154,117 +231,122 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
       if (!DefMI || !DefMI->isFullCopy())
         continue;
 
-      Register CopySrcReg = DefMI->getOperand(1).getReg();
-      if (!CopySrcReg.isVirtual())
+      Register MFMADstReg = DefMI->getOperand(1).getReg();
+      if (!MFMADstReg.isVirtual())
         continue;
 
-      LiveInterval &CopySrcLI = LIS.getInterval(CopySrcReg);
+      LiveInterval &CopySrcLI = LIS.getInterval(MFMADstReg);
       LiveQueryResult LRQ = CopySrcLI.Query(VNI->def.getRegSlot());
-      MachineInstr *CopySrcMI = LIS.getInstructionFromIndex(LRQ.valueIn()->def);
-      if (!CopySrcMI)
+      MachineInstr *MFMA = LIS.getInstructionFromIndex(LRQ.valueIn()->def);
+      if (!MFMA || !isRewriteCandidate(*MFMA))
         continue;
 
-      int AGPROp = AMDGPU::getMFMASrcCVDstAGPROp(CopySrcMI->getOpcode());
-      if (AGPROp == -1)
+      MachineOperand *Src2 = TII.getNamedOperand(*MFMA, AMDGPU::OpName::src2);
+      if (!Src2->isReg())
         continue;
 
-      MachineOperand *Src2 =
-          TII.getNamedOperand(*CopySrcMI, AMDGPU::OpName::src2);
+      Register Src2Reg = Src2->getReg();
+      if (!Src2Reg.isVirtual())
+        continue;
 
       // FIXME: getMinimalPhysRegClass returns a nonsense AV_* subclass instead
       // of an AGPR or VGPR subclass, so we can't simply use the result on the
       // assignment.
 
       LLVM_DEBUG({
-        Register Src2PhysReg = VRM.getPhys(Src2->getReg());
         dbgs() << "Attempting to replace VGPR MFMA with AGPR version:"
                << " Dst=[" << printReg(VReg) << " => "
-               << printReg(PhysReg, &TRI) << "], Src2=["
-               << printReg(Src2->getReg(), &TRI) << " => "
-               << printReg(Src2PhysReg, &TRI) << "]: " << *CopySrcMI;
+               << printReg(PhysReg, &TRI) << ']';
+
+        if (Src2Reg) {
+          Register Src2PhysReg = VRM.getPhys(Src2Reg);
+          dbgs() << ", Src2=[" << printReg(Src2Reg, &TRI) << " => "
+                 << printReg(Src2PhysReg, &TRI) << "]: " << *MFMA;
+        }
       });
 
-      // If the inputs are tied and the same register, we can shortcut and
-      // directly replace the register.
-      if (!Src2->isReg() || Src2->getReg() != CopySrcReg ||
-          Src2->getSubReg() != DefMI->getOperand(1).getSubReg()) {
-        LLVM_DEBUG(
-            dbgs()
-            << "Replacing untied VGPR MFMAs with AGPR form not yet handled\n");
-        // TODO: Only handles the tied case for now. If the input operand is a
-        // different register, we need to also reassign it (either by looking
-        // for a compatible copy-from-AGPR, or by seeing if an available AGPR is
-        // compatible with all other uses.
-
-        // If we can't reassign it, we'd need to introduce a different copy
-        // which is likely worse than the copy we'd be saving.
-        continue;
-      }
+      const TargetRegisterClass *DstVirtRegRC = MRI.getRegClass(MFMADstReg);
 
-      const TargetRegisterClass *Src2VirtRegRC =
-          MRI.getRegClass(Src2->getReg());
+      // src2 and dst have the same physical class constraint; try to preserve
+      // the original src2 subclass if one were to exist.
+      SmallVector<MachineInstr *, 4> RewriteCandidates = {MFMA};
+      SmallSetVector<Register, 4> RewriteRegs;
+
+      // Make sure we reassign the MFMA we found the copy from first. We want
+      // to ensure dst ends up in the physreg we were originally copying to.
+      RewriteRegs.insert(MFMADstReg);
 
       // We've found av = COPY (MFMA), and need to verify that we can trivially
       // rewrite src2 to use the new AGPR. If we can't trivially replace it,
       // we're going to induce as many copies as we would have emitted in the
       // first place, as well as need to assign another register, and need to
       // figure out where to put them. The live range splitting is smarter than
       // anything we're doing here, so trust it did something reasonable.
-      const TargetRegisterClass *Src2ExceptRC =
-          recomputeRegClassExceptRewritable(Src2->getReg(), Src2VirtRegRC,
-                                            VirtRegRC);
-      if (!Src2ExceptRC) {
-        LLVM_DEBUG(dbgs() << "Could not recompute the regclass\n");
+      //
+      // Note recomputeRegClassExceptRewritable will consider the constraints of
+      // this MFMA's src2 as well as the src2/dst of any transitive MFMA users.
+      const TargetRegisterClass *DstExceptRC =
+          recomputeRegClassExceptRewritable(MFMADstReg, DstVirtRegRC, VirtRegRC,
+                                            RewriteCandidates, RewriteRegs);
+      if (!DstExceptRC) {
+        LLVM_DEBUG(dbgs() << "Could not recompute the regclass of "
+                          << printReg(MFMADstReg, &TRI) << '\n');
         continue;
       }
 
-      const TargetRegisterClass *NewSrc2ConstraintRC =
-          TII.getRegClass(TII.get(AGPROp), Src2->getOperandNo(), &TRI, MF);
-
-      // Try to constrain src2 to the replacement instruction candidate's
-      // register class.
-      const TargetRegisterClass *NewSrc2RC =
-          TRI.getCommonSubClass(Src2ExceptRC, NewSrc2ConstraintRC);
-      if (!NewSrc2RC) {
-        LLVM_DEBUG(dbgs() << "Other uses of " << printReg(Src2->getReg(), &TRI)
-                          << " are incompatible with replacement class\n");
-        continue;
+      // If src2 and dst are different registers, we need to also reassign the
+      // input to an available AGPR if it is compatible with all other uses.
+      //
+      // If we can't reassign it, we'd need to introduce a different copy
+      // which is likely worse than the copy we'd be saving.
+      //
+      // It's likely that the MFMA is used in sequence with other MFMAs; if we
+      // cannot migrate the full use/def chain of MFMAs, we would need to
+      // introduce intermediate copies somewhere. So we only make the
+      // transform if all the interfering MFMAs can also be migrated. Collect
+      // the set of rewritable MFMAs and check if we can assign an AGPR at
+      // that point.
+      //
+      // If any of the MFMAs aren't reassignable, we give up and rollback to
+      // the original register assignments.
+
+      using RecoloringStack =
+          SmallVector<std::pair<const LiveInterval *, MCRegister>, 8>;
+      RecoloringStack TentativeReassignments;
+
+      for (Register RewriteReg : RewriteRegs) {
+        LiveInterval &LI = LIS.getInterval(RewriteReg);
+        TentativeReassignments.push_back({&LI, VRM.getPhys(RewriteReg)});
+        LRM.unassign(LI);
       }
 
-      MRI.setRegClass(VReg, AssignedRC);
-      MRI.setRegClass(Src2->getReg(), NewSrc2RC);
-
-      CopySrcMI->setDesc(TII.get(AGPROp));
-
-      // Perform replacement of the register, rewriting the rewritable uses.
-      for (MachineInstr &UseMI :
-           make_early_inc_range(MRI.reg_instructions(CopySrcReg))) {
-        if (TII.isMAI(UseMI)) {
-          // Note the register we need to rewrite may still appear in src0/src1,
-          // but that's fine since those can use A or V anyway.
-          int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(UseMI.getOpcode());
-          if (ReplacementOp != -1)
-            UseMI.setDesc(TII.get(ReplacementOp));
+      if (!attemptReassignmentsToAGPR(RewriteRegs, PhysReg)) {
----------------
arsenm wrote:

Probably, but I'm not sure the hinting is doing anything useful right now, other than acting as a compile-time hint to avoid trying the whole allocation order.

https://github.com/llvm/llvm-project/pull/153015


More information about the llvm-commits mailing list