[llvm] [AMDGPU] AsmPrinter: Unify arg handling (PR #151672)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 1 01:52:18 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Diana Picus (rovka)

<details>
<summary>Changes</summary>

When computing the number of registers required by entry functions, the
`AMDGPUAsmPrinter` needs to take into account both the register usage
computed by the `AMDGPUResourceUsageAnalysis` pass and the number
of registers initialized by the hardware. At the moment, the latter is
computed differently for graphics and compute, due to differences in the
implementation. For kernels, all the information needed is available in
`SIMachineFunctionInfo`, but for graphics shaders we iterate over the
`Function` arguments in the `AMDGPUAsmPrinter`. This effectively forces
us to keep the IR `Function` around forever and repeats some of the logic
from instruction selection.

This patch introduces two new members to `SIMachineFunctionInfo`, one
for SGPRs and one for VGPRs. Both are computed during instruction
selection and then used by `AMDGPUAsmPrinter`, removing the need
to refer to the `Function` when printing assembly.
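
As a rough illustration of the idea, here is a self-contained toy model (not the LLVM code). `FunctionInfoModel`, `firstUnallocated`, `isGapFree` and `recordWaveDispatchRegs` are hypothetical stand-ins for `SIMachineFunctionInfo`, `CCInfo.getFirstUnallocated` and the lowering step. The model relies on the same assumption the patch states in its comments: argument registers are allocated in ascending order with no gaps, so the index of the first unallocated register equals the number of registers in use.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the two counters the patch adds to
// SIMachineFunctionInfo.
struct FunctionInfoModel {
  unsigned NumWaveDispatchSGPRs = 0;
  unsigned NumWaveDispatchVGPRs = 0;
};

// Toy "first unallocated register": index of the first register that is not
// marked allocated, or the register count if every register is allocated.
inline unsigned firstUnallocated(const std::vector<bool> &Allocated) {
  for (std::size_t I = 0; I < Allocated.size(); ++I)
    if (!Allocated[I])
      return static_cast<unsigned>(I);
  return static_cast<unsigned>(Allocated.size());
}

// True if no allocated register follows the first unallocated one, i.e. the
// "ascending order with no gaps" assumption holds.
inline bool isGapFree(const std::vector<bool> &Allocated) {
  for (std::size_t I = firstUnallocated(Allocated); I < Allocated.size(); ++I)
    if (Allocated[I])
      return false;
  return true;
}

// Models the lowering step: record the counts once, while the allocation
// state is still available, so that a later stage only reads two integers.
inline FunctionInfoModel
recordWaveDispatchRegs(const std::vector<bool> &SGPRs,
                       const std::vector<bool> &VGPRs) {
  assert(isGapFree(SGPRs) && isGapFree(VGPRs) &&
         "model assumes contiguous allocation");
  FunctionInfoModel Info;
  Info.NumWaveDispatchSGPRs = firstUnallocated(SGPRs);
  Info.NumWaveDispatchVGPRs = firstUnallocated(VGPRs);
  return Info;
}
```

With the first 17 SGPRs and the first 5 VGPRs marked allocated (the counts mentioned in `wave_dispatch_regs.ll`), this model records 17 and 5.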

This patch is NFC except that we now add the extra SGPRs (VCC,
XNACK, etc.) to the number of SGPRs computed for graphics entry points.
I'm not sure why these weren't included before. It would be nice if someone
could confirm whether that was just an oversight or whether we have some docs
somewhere that I haven't managed to find. Only one test is affected (its SGPR
usage increases because we now take into account the XNACK registers).
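
The arithmetic behind the changed test expectation can be sketched as follows. This is a simplified model using plain integers rather than `MCExpr` nodes. The names `requiredSGPRs` and `checkWaveDispatchExample` are hypothetical, the extra-SGPR count of 4 for GFX9 is inferred from the test delta in this patch (`0x11` to `0x15`) rather than taken from documentation, and the explicitly used count of 8 is made up for illustration.

```cpp
#include <algorithm>
#include <cassert>

// Simplified model of the SGPR requirement for an entry point.
//   UsedSGPRs:         SGPRs found by resource usage analysis (assumed value)
//   WaveDispatchSGPRs: SGPRs initialized by the hardware at wave dispatch
//   ExtraSGPRs:        VCC, XNACK, etc.
//   AddExtra:          false models the old graphics path, true the new one
inline unsigned requiredSGPRs(unsigned UsedSGPRs, unsigned WaveDispatchSGPRs,
                              unsigned ExtraSGPRs, bool AddExtra) {
  // Mirrors the `if (WaveDispatchNumSGPR)` guard in the patch.
  if (WaveDispatchSGPRs == 0)
    return UsedSGPRs;
  unsigned Dispatch = WaveDispatchSGPRs + (AddExtra ? ExtraSGPRs : 0u);
  return std::max(UsedSGPRs, Dispatch);
}

// Worked example with the numbers from wave_dispatch_regs.ll on GFX9.
inline void checkWaveDispatchExample() {
  assert(requiredSGPRs(8, 17, 4, /*AddExtra=*/false) == 0x11); // before
  assert(requiredSGPRs(8, 17, 4, /*AddExtra=*/true) == 0x15);  // after
}
```

Under these assumptions, 17 wave dispatch SGPRs plus 4 extra SGPRs give 21 (`0x15`), which matches the updated GFX9 check line.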

---
Full diff: https://github.com/llvm/llvm-project/pull/151672.diff


6 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp (+10-75) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (+12) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+9) 
- (modified) llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h (+11) 
- (modified) llvm/test/CodeGen/AMDGPU/ps-shader-arg-count.ll (+4-2) 
- (modified) llvm/test/CodeGen/AMDGPU/wave_dispatch_regs.ll (+7-4) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp
index 668139383f56c..ca0164d682a4c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp
@@ -993,89 +993,24 @@ void AMDGPUAsmPrinter::getSIProgramInfo(SIProgramInfo &ProgInfo,
   const Function &F = MF.getFunction();
 
   // Ensure there are enough SGPRs and VGPRs for wave dispatch, where wave
-  // dispatch registers are function args.
-  unsigned WaveDispatchNumSGPR = 0, WaveDispatchNumVGPR = 0;
-
-  if (isShader(F.getCallingConv())) {
-    bool IsPixelShader =
-        F.getCallingConv() == CallingConv::AMDGPU_PS && !STM.isAmdHsaOS();
-
-    // Calculate the number of VGPR registers based on the SPI input registers
-    uint32_t InputEna = 0;
-    uint32_t InputAddr = 0;
-    unsigned LastEna = 0;
-
-    if (IsPixelShader) {
-      // Note for IsPixelShader:
-      // By this stage, all enabled inputs are tagged in InputAddr as well.
-      // We will use InputAddr to determine whether the input counts against the
-      // vgpr total and only use the InputEnable to determine the last input
-      // that is relevant - if extra arguments are used, then we have to honour
-      // the InputAddr for any intermediate non-enabled inputs.
-      InputEna = MFI->getPSInputEnable();
-      InputAddr = MFI->getPSInputAddr();
-
-      // We only need to consider input args up to the last used arg.
-      assert((InputEna || InputAddr) &&
-             "PSInputAddr and PSInputEnable should "
-             "never both be 0 for AMDGPU_PS shaders");
-      // There are some rare circumstances where InputAddr is non-zero and
-      // InputEna can be set to 0. In this case we default to setting LastEna
-      // to 1.
-      LastEna = InputEna ? llvm::Log2_32(InputEna) + 1 : 1;
-    }
+  // dispatch registers as function args.
+  unsigned WaveDispatchNumSGPR = MFI->getNumWaveDispatchSGPRs(),
+           WaveDispatchNumVGPR = MFI->getNumWaveDispatchVGPRs();
 
-    // FIXME: We should be using the number of registers determined during
-    // calling convention lowering to legalize the types.
-    const DataLayout &DL = F.getDataLayout();
-    unsigned PSArgCount = 0;
-    unsigned IntermediateVGPR = 0;
-    for (auto &Arg : F.args()) {
-      unsigned NumRegs = (DL.getTypeSizeInBits(Arg.getType()) + 31) / 32;
-      if (Arg.hasAttribute(Attribute::InReg)) {
-        WaveDispatchNumSGPR += NumRegs;
-      } else {
-        // If this is a PS shader and we're processing the PS Input args (first
-        // 16 VGPR), use the InputEna and InputAddr bits to define how many
-        // VGPRs are actually used.
-        // Any extra VGPR arguments are handled as normal arguments (and
-        // contribute to the VGPR count whether they're used or not).
-        if (IsPixelShader && PSArgCount < 16) {
-          if ((1 << PSArgCount) & InputAddr) {
-            if (PSArgCount < LastEna)
-              WaveDispatchNumVGPR += NumRegs;
-            else
-              IntermediateVGPR += NumRegs;
-          }
-          PSArgCount++;
-        } else {
-          // If there are extra arguments we have to include the allocation for
-          // the non-used (but enabled with InputAddr) input arguments
-          if (IntermediateVGPR) {
-            WaveDispatchNumVGPR += IntermediateVGPR;
-            IntermediateVGPR = 0;
-          }
-          WaveDispatchNumVGPR += NumRegs;
-        }
-      }
-    }
+  if (WaveDispatchNumSGPR) {
     ProgInfo.NumSGPR = AMDGPUMCExpr::createMax(
-        {ProgInfo.NumSGPR, CreateExpr(WaveDispatchNumSGPR)}, Ctx);
+        {ProgInfo.NumSGPR,
+         MCBinaryExpr::createAdd(CreateExpr(WaveDispatchNumSGPR), ExtraSGPRs,
+                                 Ctx)},
+        Ctx);
+  }
 
+  if (WaveDispatchNumVGPR) {
     ProgInfo.NumArchVGPR = AMDGPUMCExpr::createMax(
         {ProgInfo.NumVGPR, CreateExpr(WaveDispatchNumVGPR)}, Ctx);
 
     ProgInfo.NumVGPR = AMDGPUMCExpr::createTotalNumVGPR(
         ProgInfo.NumAccVGPR, ProgInfo.NumArchVGPR, Ctx);
-  } else if (isKernel(F.getCallingConv()) &&
-             MFI->getNumKernargPreloadedSGPRs()) {
-    // Consider cases where the total number of UserSGPRs with trailing
-    // allocated preload SGPRs, is greater than the number of explicitly
-    // referenced SGPRs.
-    const MCExpr *UserPlusExtraSGPRs = MCBinaryExpr::createAdd(
-        CreateExpr(MFI->getNumUserSGPRs()), ExtraSGPRs, Ctx);
-    ProgInfo.NumSGPR =
-        AMDGPUMCExpr::createMax({ProgInfo.NumSGPR, UserPlusExtraSGPRs}, Ctx);
   }
 
   // Adjust number of registers used to meet default/requested minimum/maximum
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
index 3d8d274f06246..64a9bde4e26e9 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
@@ -580,6 +580,9 @@ bool AMDGPUCallLowering::lowerFormalArgumentsKernel(
     ++i;
   }
 
+  if (Info->getNumKernargPreloadedSGPRs())
+    Info->setNumWaveDispatchSGPRs(Info->getNumUserSGPRs());
+
   TLI.allocateSpecialEntryInputVGPRs(CCInfo, MF, *TRI, *Info);
   TLI.allocateSystemSGPRs(CCInfo, MF, *Info, F.getCallingConv(), false);
   return true;
@@ -743,6 +746,15 @@ bool AMDGPUCallLowering::lowerFormalArguments(
   if (!determineAssignments(Assigner, SplitArgs, CCInfo))
     return false;
 
+  if (IsEntryFunc) {
+    // This assumes the registers are allocated by CCInfo in ascending order
+    // with no gaps.
+    Info->setNumWaveDispatchSGPRs(
+        CCInfo.getFirstUnallocated(AMDGPU::SGPR_32RegClass.getRegisters()));
+    Info->setNumWaveDispatchVGPRs(
+        CCInfo.getFirstUnallocated(AMDGPU::VGPR_32RegClass.getRegisters()));
+  }
+
   FormalArgHandler Handler(B, MRI);
   if (!handleAssignments(Handler, SplitArgs, CCInfo, ArgLocs, B))
     return false;
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 4d67e4a5cbcf9..9d2cb79b25d46 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -3099,6 +3099,15 @@ SDValue SITargetLowering::LowerFormalArguments(
   if (!IsKernel) {
     CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, isVarArg);
     CCInfo.AnalyzeFormalArguments(Splits, AssignFn);
+
+    // This assumes the registers are allocated by CCInfo in ascending order
+    // with no gaps.
+    Info->setNumWaveDispatchSGPRs(
+        CCInfo.getFirstUnallocated(AMDGPU::SGPR_32RegClass.getRegisters()));
+    Info->setNumWaveDispatchVGPRs(
+        CCInfo.getFirstUnallocated(AMDGPU::VGPR_32RegClass.getRegisters()));
+  } else if (Info->getNumKernargPreloadedSGPRs()) {
+    Info->setNumWaveDispatchSGPRs(Info->getNumUserSGPRs());
   }
 
   SmallVector<SDValue, 16> Chains;
diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
index 08b0206d244fb..23166aed3a9ec 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
@@ -465,6 +465,9 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,
   unsigned NumUserSGPRs = 0;
   unsigned NumSystemSGPRs = 0;
 
+  unsigned NumWaveDispatchSGPRs = 0;
+  unsigned NumWaveDispatchVGPRs = 0;
+
   bool HasSpilledSGPRs = false;
   bool HasSpilledVGPRs = false;
   bool HasNonSpillStackObjects = false;
@@ -991,6 +994,14 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,
     return UserSGPRInfo.getNumKernargPreloadSGPRs();
   }
 
+  unsigned getNumWaveDispatchSGPRs() const { return NumWaveDispatchSGPRs; }
+
+  void setNumWaveDispatchSGPRs(unsigned Count) { NumWaveDispatchSGPRs = Count; }
+
+  unsigned getNumWaveDispatchVGPRs() const { return NumWaveDispatchVGPRs; }
+
+  void setNumWaveDispatchVGPRs(unsigned Count) { NumWaveDispatchVGPRs = Count; }
+
   Register getPrivateSegmentWaveByteOffsetSystemSGPR() const {
     return ArgInfo.PrivateSegmentWaveByteOffset.getRegister();
   }
diff --git a/llvm/test/CodeGen/AMDGPU/ps-shader-arg-count.ll b/llvm/test/CodeGen/AMDGPU/ps-shader-arg-count.ll
index 013b68a40f44b..99e5d0017f30b 100644
--- a/llvm/test/CodeGen/AMDGPU/ps-shader-arg-count.ll
+++ b/llvm/test/CodeGen/AMDGPU/ps-shader-arg-count.ll
@@ -1,5 +1,7 @@
-;RUN: llc < %s -mtriple=amdgcn-pal -mcpu=gfx1010 | FileCheck %s --check-prefixes=CHECK
-;RUN: llc < %s -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 | FileCheck %s --check-prefixes=CHECK
+;RUN: llc -global-isel=1 < %s -mtriple=amdgcn-pal -mcpu=gfx1010 | FileCheck %s --check-prefixes=CHECK
+;RUN: llc -global-isel=1 < %s -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 | FileCheck %s --check-prefixes=CHECK
+;RUN: llc -global-isel=0 < %s -mtriple=amdgcn-pal -mcpu=gfx1010 | FileCheck %s --check-prefixes=CHECK
+;RUN: llc -global-isel=0 < %s -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 | FileCheck %s --check-prefixes=CHECK
 
 ; ;CHECK-LABEL: {{^}}_amdgpu_ps_1_arg:
 ; ;CHECK: NumVgprs: 4
diff --git a/llvm/test/CodeGen/AMDGPU/wave_dispatch_regs.ll b/llvm/test/CodeGen/AMDGPU/wave_dispatch_regs.ll
index 76c331cdc8303..e2ef60bb80153 100644
--- a/llvm/test/CodeGen/AMDGPU/wave_dispatch_regs.ll
+++ b/llvm/test/CodeGen/AMDGPU/wave_dispatch_regs.ll
@@ -1,6 +1,9 @@
-; RUN: llc -mtriple=amdgcn--amdpal < %s | FileCheck -check-prefix=GCN -check-prefix=SI -enable-var-scope %s
-; RUN: llc -mtriple=amdgcn--amdpal -mcpu=tonga < %s | FileCheck -check-prefix=GCN -check-prefix=VI -enable-var-scope %s
-; RUN: llc -mtriple=amdgcn--amdpal -mcpu=gfx900 < %s | FileCheck -check-prefix=GCN -check-prefix=GFX9 -enable-var-scope %s
+; RUN: llc -global-isel=1 -mtriple=amdgcn--amdpal < %s | FileCheck -check-prefix=GCN -check-prefix=SI -enable-var-scope %s
+; RUN: llc -global-isel=1 -mtriple=amdgcn--amdpal -mcpu=tonga < %s | FileCheck -check-prefix=GCN -check-prefix=VI -enable-var-scope %s
+; RUN: llc -global-isel=1 -mtriple=amdgcn--amdpal -mcpu=gfx900 < %s | FileCheck -check-prefix=GCN -check-prefix=GFX9 -enable-var-scope %s
+; RUN: llc -global-isel=0 -mtriple=amdgcn--amdpal < %s | FileCheck -check-prefix=GCN -check-prefix=SI -enable-var-scope %s
+; RUN: llc -global-isel=0 -mtriple=amdgcn--amdpal -mcpu=tonga < %s | FileCheck -check-prefix=GCN -check-prefix=VI -enable-var-scope %s
+; RUN: llc -global-isel=0 -mtriple=amdgcn--amdpal -mcpu=gfx900 < %s | FileCheck -check-prefix=GCN -check-prefix=GFX9 -enable-var-scope %s
 
 ; This compute shader has input args that claim that it has 17 sgprs and 5 vgprs
 ; in wave dispatch. Ensure that the sgpr and vgpr counts in COMPUTE_PGM_RSRC1
@@ -17,7 +20,7 @@
 ; GCN-NEXT:         .scratch_memory_size: 0
 ; SI-NEXT:          .sgpr_count:     0x11
 ; VI-NEXT:          .sgpr_count:     0x60
-; GFX9-NEXT:        .sgpr_count:     0x11
+; GFX9-NEXT:        .sgpr_count:     0x15
 ; SI-NEXT:          .vgpr_count:     0x5
 ; VI-NEXT:          .vgpr_count:     0x5
 ; GFX9-NEXT:        .vgpr_count:     0x5

``````````

</details>


https://github.com/llvm/llvm-project/pull/151672

