[llvm-branch-commits] [llvm] WIP: AMDGPU: Always select the VGPR version of MFMAs (PR #145025)

Matt Arsenault via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jun 26 22:26:29 PDT 2025


https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/145025

>From 96e2c10e532af88c9b18f03bacd701d946f9fe84 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Mon, 9 Dec 2024 15:41:44 -0600
Subject: [PATCH] WIP: AMDGPU: Always select the VGPR version of MFMAs

We do not want to use AGPRs unless absolutely required due
to register pressure. Rely on a post-regalloc pass to replace
VGPR MFMAs with the AGPR version if it avoids the copies introduced
due to live range splitting.
---
 .../Target/AMDGPU/AMDGPURegisterBankInfo.cpp  | 10 ++--
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp     | 20 +------
 .../Target/AMDGPU/SIMachineFunctionInfo.cpp   |  6 --
 .../lib/Target/AMDGPU/SIMachineFunctionInfo.h |  6 --
 llvm/lib/Target/AMDGPU/VOP3PInstructions.td   | 55 ++++++++++---------
 5 files changed, 35 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index b20760c356263..a8cabcb5831c8 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4867,31 +4867,29 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       // for srcA/srcB?
       //
       // vdst, srcA, srcB, srcC
-      const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
       OpdsMapping[0] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
       OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
       OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
       OpdsMapping[4] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
       break;
     }
     case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
     case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
-      const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
       OpdsMapping[0] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
 
       OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
       OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
       OpdsMapping[4] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
 
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 8d7dcf8c4a064..405535b7d27db 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -16227,7 +16227,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
 
   MachineFunction *MF = MI.getParent()->getParent();
   MachineRegisterInfo &MRI = MF->getRegInfo();
-  SIMachineFunctionInfo *Info = MF->getInfo<SIMachineFunctionInfo>();
 
   if (TII->isVOP3(MI.getOpcode())) {
     // Make sure constant bus requirements are respected.
@@ -16238,7 +16237,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
     // use between vgpr and agpr as agpr tuples tend to be big.
     if (!MI.getDesc().operands().empty()) {
       unsigned Opc = MI.getOpcode();
-      bool HasAGPRs = Info->mayNeedAGPRs();
       const SIRegisterInfo *TRI = Subtarget->getRegisterInfo();
       int16_t Src2Idx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2);
       for (auto I :
@@ -16246,7 +16244,7 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
             AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src1), Src2Idx}) {
         if (I == -1)
           break;
-        if ((I == Src2Idx) && (HasAGPRs))
+        if (I == Src2Idx)
           break;
         MachineOperand &Op = MI.getOperand(I);
         if (!Op.isReg() || !Op.getReg().isVirtual())
@@ -16280,22 +16278,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
             TII->legalizeOpWithMove(MI, Src1Idx);
         }
       }
-
-      if (!HasAGPRs)
-        return;
-
-      // Resolve the rest of AV operands to AGPRs.
-      if (auto *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2)) {
-        if (Src2->isReg() && Src2->getReg().isVirtual()) {
-          auto *RC = TRI->getRegClassForReg(MRI, Src2->getReg());
-          if (TRI->isVectorSuperClass(RC)) {
-            auto *NewRC = TRI->getEquivalentAGPRClass(RC);
-            MRI.setRegClass(Src2->getReg(), NewRC);
-            if (Src2->isTied())
-              MRI.setRegClass(MI.getOperand(0).getReg(), NewRC);
-          }
-        }
-      }
     }
 
     return;
diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
index 67ad28661da43..a33dafac85b08 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
@@ -69,12 +69,6 @@ SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F,
     PSInputAddr = AMDGPU::getInitialPSInputAddr(F);
   }
 
-  MayNeedAGPRs = ST.hasMAIInsts();
-  if (ST.hasGFX90AInsts() &&
-      ST.getMaxNumVGPRs(F) <= AMDGPU::VGPR_32RegClass.getNumRegs() &&
-      !mayUseAGPRs(F))
-    MayNeedAGPRs = false; // We will select all MAI with VGPR operands.
-
   if (AMDGPU::isChainCC(CC)) {
     // Chain functions don't receive an SP from their caller, but are free to
     // set one up. For now, we can use s32 to match what amdgpu_gfx functions
diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
index 274a60adb8d07..1b45737f45106 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
@@ -497,8 +497,6 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,
   // user arguments. This is an offset from the KernargSegmentPtr.
   bool ImplicitArgPtr : 1;
 
-  bool MayNeedAGPRs : 1;
-
   // The hard-wired high half of the address of the global information table
   // for AMDPAL OS type. 0xffffffff represents no hard-wired high half, since
   // current hardware only allows a 16 bit value.
@@ -1172,10 +1170,6 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,
 
   unsigned getMaxMemoryClusterDWords() const { return MaxMemoryClusterDWords; }
 
-  bool mayNeedAGPRs() const {
-    return MayNeedAGPRs;
-  }
-
   // \returns true if a function has a use of AGPRs via inline asm or
   // has a call which may use it.
   bool mayUseAGPRs(const Function &F) const;
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index e8db879ca5077..6b6b74234cfef 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -856,17 +856,11 @@ defvar MayNotNeedAGPRs_gisel = [{
   return !MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
 }];
 
-class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
-                  bit Scaled = false> :
-  MAIFrag<Op, MayNeedAGPRs, HasAbid, Scaled> {
-  let GISelPredicateCode = MayNeedAGPRs_gisel;
-}
+class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false>
+    : MAIFrag<Op, [{}], HasAbid, Scaled> {}
 
-class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
-                   bit Scaled = false> :
-  MAIFrag<Op, MayNotNeedAGPRs, HasAbid, Scaled> {
-  let GISelPredicateCode = MayNotNeedAGPRs_gisel;
-}
+class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false>
+    : MAIFrag<Op, [{}], HasAbid, Scaled> {}
 
 let isAsCheapAsAMove = 1, isReMaterializable = 1 in {
   defm V_ACCVGPR_READ_B32  : VOP3Inst<"v_accvgpr_read_b32",  VOPProfileAccRead>;
@@ -917,10 +911,14 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
                          !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
                  MFMATable<0, "AGPR", NAME # "_e64">;
 
-      let OtherPredicates = [isGFX90APlus], Mnemonic = OpName in
-      def _vgprcd_e64 : MAIInst<OpName # "_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
-                                !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
-                        MFMATable<0, "VGPR", NAME # "_vgprcd_e64", NAME # "_e64">;
+      let OtherPredicates = [isGFX90APlus], Mnemonic = OpName,
+          AddedComplexity = 10 in def _vgprcd_e64
+          : MAIInst<OpName#"_vgprcd",
+                    !cast<VOPProfileMAI>("VOPProfileMAI_"#P#"_VCD"),
+                    !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag,
+                        VgprMAIFrag<node, HasAbid, Scaled>),
+                    Scaled>,
+          MFMATable<0, "VGPR", NAME#"_vgprcd_e64", NAME#"_e64">;
     }
 
     if NoDstOverlap then {
@@ -931,16 +929,22 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
                                  !if(!eq(node, null_frag), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
                          MFMATable<1, "AGPR", NAME # "_e64", NAME # "_mac_e64">;
 
-        let OtherPredicates = [isGFX90APlus] in
-        def _mac_vgprcd_e64 : MAIInst<OpName # "_mac_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
-                                      !if(!eq(node, null_frag), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
-                              MFMATable<1, "VGPR", NAME # "_vgprcd_e64", NAME # "_mac_e64">;
+        let OtherPredicates = [isGFX90APlus],
+            AddedComplexity = 10 in def _mac_vgprcd_e64
+            : MAIInst<OpName#"_mac_vgprcd",
+                      !cast<VOPProfileMAI>("VOPProfileMAI_"#P#"_VCD"),
+                      !if(!eq(node, null_frag), null_frag,
+                          VgprMAIFrag<node, HasAbid, Scaled>),
+                      Scaled>,
+            MFMATable<1, "VGPR", NAME#"_vgprcd_e64", NAME#"_mac_e64">;
       }
     }
   } // End isConvergent = 1, mayRaiseFPException = 0, ReadsModeReg = 1
 }
 
-// Provide a wrapper around MAIInst that provides the appended operands from V_MFMA_LD_SCALE_B32
+// Provide a wrapper around MAIInst that provides the appended operands from
+// V_MFMA_LD_SCALE_B32. AGPR variants are never selected; VGPR is selected and
+// may later be rewritten to AGPR.
 multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOperator node> {
   defvar VariantSuffix = !subst(!toupper(OpName), "", NAME); // Drop the main opcode name prefix to get the "_fN_fM" suffix.
   defvar UnscaledOpName = UnscaledOpName_#VariantSuffix;
@@ -949,9 +953,9 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper
 
   defvar NoDstOverlap = !cast<VOPProfileMAI>(!cast<MAIInst>(UnscaledOpName#"_e64").Pfl).NoDstOverlap;
 
-  def _e64 : ScaledMAIInst<OpName,
-        !cast<MAIInst>(UnscaledOpName#"_e64"), !if(NoDstOverlap, null_frag, AgprMAIFrag<node, HasAbid, true>)>,
-      MFMATable<0, "AGPR", NAME # "_e64">;
+  def _e64
+      : ScaledMAIInst<OpName, !cast<MAIInst>(UnscaledOpName#"_e64"), null_frag>,
+        MFMATable<0, "AGPR", NAME#"_e64">;
 
   def _vgprcd_e64 : ScaledMAIInst<OpName # "_vgprcd",
           !cast<MAIInst>(UnscaledOpName#"_vgprcd_e64"), !if(NoDstOverlap, null_frag, VgprMAIFrag<node, HasAbid, true>)>,
@@ -961,9 +965,10 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper
    let Constraints = !if(NoDstOverlap, "$vdst = $src2", ""),
        isConvertibleToThreeAddress = NoDstOverlap,
        Mnemonic = UnscaledOpName_ in {
-     def _mac_e64 : ScaledMAIInst<OpName # "_mac",
-          !cast<MAIInst>(UnscaledOpName # "_mac_e64"), AgprMAIFrag<node, HasAbid, true>>,
-        MFMATable<1, "AGPR", NAME # "_e64">;
+     def _mac_e64
+         : ScaledMAIInst<OpName#"_mac",
+                         !cast<MAIInst>(UnscaledOpName#"_mac_e64"), null_frag>,
+           MFMATable<1, "AGPR", NAME#"_e64">;
 
      def _mac_vgprcd_e64 : ScaledMAIInst<OpName # " _mac_vgprcd",
           !cast<MAIInst>(UnscaledOpName # "_mac_vgprcd_e64"), VgprMAIFrag<node, HasAbid, true>>,



More information about the llvm-branch-commits mailing list