[llvm] [WIP][AMDGPU] Improve the handling of `inreg` arguments (PR #133614)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 2 22:14:16 PDT 2025


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/133614

From ff4fc411ecffceda44639757447b7326c43fc911 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Thu, 3 Apr 2025 01:12:49 -0400
Subject: [PATCH] [WIP][AMDGPU] Improve the handling of `inreg` arguments

When the SGPRs available for `inreg` argument passing run out, the compiler
silently falls back to passing those arguments in whole VGPRs. Instead of
occupying a full VGPR per argument, we should pack the overflowing `inreg`
arguments into individual lanes of VGPRs.

This PR introduces `InregVGPRSpillerCallSite` and `InregVGPRSpillerCallee` to
handle this packing: `v_writelane` places each `inreg` argument into a specific
VGPR lane at the call site, and `v_readlane` extracts it in the callee.
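
As a rough sketch (register choices below are purely illustrative, not taken
from actual codegen), two trailing i32 `inreg` arguments that no longer fit
into SGPR4-SGPR29 would be lowered along these lines:

  ; at the call site, assuming the two values live in s6 and s7:
  v_writelane_b32 v0, s6, 0
  v_writelane_b32 v0, s7, 1
  ; in the callee:
  v_readlane_b32 s0, v0, 0
  v_readlane_b32 s1, v0, 1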

Fixes #130443 and #129071.
---
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp    | 123 +++++++++++++++++--
 llvm/lib/Target/AMDGPU/SIISelLowering.h      |   2 +-
 llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll |  63 ++++++++++
 3 files changed, 179 insertions(+), 9 deletions(-)
 create mode 100644 llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll

diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index a583a5cb990e7..251269be592b0 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -2841,6 +2841,91 @@ void SITargetLowering::insertCopiesSplitCSR(
   }
 }
 
+/// Classes for spilling inreg VGPR arguments.
+///
+/// When an argument marked inreg is pushed to a VGPR, it indicates that the
+/// available SGPRs for argument passing have been exhausted. In such cases, it
+/// is preferable to pack multiple inreg arguments into individual lanes of
+/// VGPRs instead of assigning each directly to separate VGPRs.
+///
+/// Spilling involves two parts: the caller side (call site) and the callee
+/// side. Both must use the same scheme for selecting registers and lanes, so
+/// that an argument written at the call site is read back from the same
+/// register and lane in the callee.
+///
+/// The spilling class for the callee side that lowers the unpacking of formal
+/// arguments.
+class InregVGPRSpillerCallee {
+  CCState &State;
+  SelectionDAG &DAG;
+  MachineFunction &MF;
+
+  Register SrcReg;
+  SDValue SrcVal;
+  unsigned CurLane = 0;
+
+public:
+  InregVGPRSpillerCallee(SelectionDAG &DAG, MachineFunction &MF, CCState &State)
+      : State(State), DAG(DAG), MF(MF) {}
+
+  SDValue read(SDValue Chain, const SDLoc &SL, Register &Reg, EVT VT) {
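+    // All spilled inreg arguments share the first VGPR the calling convention
+    // assigned; once it has been claimed, release the VGPRs allocated for the
+    // remaining inreg arguments and read their values from successive lanes
+    // of the shared register instead.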
+    if (SrcVal) {
+      State.DeallocateReg(Reg);
+    } else {
+      Reg = MF.addLiveIn(Reg, &AMDGPU::VGPR_32RegClass);
+      SrcReg = Reg;
+      SrcVal = DAG.getCopyFromReg(Chain, SL, Reg, VT);
+    }
+    // According to the calling convention, only SGPR4–SGPR29 should be used for
+    // passing 'inreg' function arguments. Therefore, the number of 'inreg' VGPR
+    // arguments must not exceed 26.
+    assert(CurLane < 26 && "more VGPR inreg arguments than expected");
+    SmallVector<SDValue, 4> Operands{
+        DAG.getTargetConstant(Intrinsic::amdgcn_readlane, SL, MVT::i32),
+        DAG.getRegister(SrcReg, VT),
+        DAG.getTargetConstant(CurLane++, SL, MVT::i32)};
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, VT, Operands);
+  }
+};
+
+/// The spilling class for the caller side that lowers the packing of call
+/// site arguments.
+class InregVGPRSpillerCallSite {
+  CCState &State;
+
+  Register DstReg;
+  SDValue Glue;
+  unsigned CurLane = 0;
+
+  SelectionDAG &DAG;
+  MachineFunction &MF;
+
+public:
+  InregVGPRSpillerCallSite(SelectionDAG &DAG, MachineFunction &MF,
+                           CCState &State)
+      : State(State), DAG(DAG), MF(MF) {}
+
+  std::pair<SDValue, SDValue> write(SDValue Chain, const SDLoc &SL,
+                                    Register &Reg, SDValue Val, SDValue InGlue,
+                                    EVT VT) {
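+    // The first spilled inreg argument's VGPR becomes the shared destination;
+    // every later argument is redirected to it and written into the next free
+    // lane with v_writelane.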
+    if (DstReg.isValid()) {
+      Reg = DstReg;
+    } else {
+      DstReg = Reg;
+      Glue = DAG.getCopyToReg(Chain, SL, Reg, Val, InGlue).getValue(1);
+    }
+    // According to the calling convention, only SGPR4–SGPR29 should be used for
+    // passing 'inreg' function arguments. Therefore, the number of 'inreg' VGPR
+    // arguments must not exceed 26.
+    assert(CurLane < 26 && "more VGPR inreg arguments than expected");
+    SmallVector<SDValue, 4> Operands{
+        DAG.getTargetConstant(Intrinsic::amdgcn_writelane, SL, MVT::i32),
+        DAG.getRegister(DstReg, VT), Val,
+        DAG.getTargetConstant(CurLane++, SL, MVT::i32)};
+    return {DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, VT, Operands), Glue};
+  }
+};
+
 SDValue SITargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -2963,6 +3048,7 @@ SDValue SITargetLowering::LowerFormalArguments(
   // FIXME: Alignment of explicit arguments totally broken with non-0 explicit
   // kern arg offset.
   const Align KernelArgBaseAlign = Align(16);
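+  // Reads back inreg arguments that overflowed into VGPRs from individual
+  // lanes of a shared VGPR (see InregVGPRSpillerCallee above).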
+  InregVGPRSpillerCallee Spiller(DAG, MF, CCInfo);
 
   for (unsigned i = 0, e = Ins.size(), ArgIdx = 0; i != e; ++i) {
     const ISD::InputArg &Arg = Ins[i];
@@ -3130,8 +3216,17 @@ SDValue SITargetLowering::LowerFormalArguments(
       llvm_unreachable("Unexpected register class in LowerFormalArguments!");
     EVT ValVT = VA.getValVT();
 
-    Reg = MF.addLiveIn(Reg, RC);
-    SDValue Val = DAG.getCopyFromReg(Chain, DL, Reg, VT);
+    SDValue Val;
+    // If an argument is marked inreg but gets pushed to a VGPR, it indicates
+    // we've run out of SGPRs for argument passing. In such cases, we'd prefer
+    // to start packing inreg arguments into individual lanes of VGPRs, rather
+    // than placing them directly into VGPRs.
+    if (RC == &AMDGPU::VGPR_32RegClass && Arg.Flags.isInReg()) {
+      Val = Spiller.read(Chain, DL, Reg, VT);
+    } else {
+      Reg = MF.addLiveIn(Reg, RC);
+      Val = DAG.getCopyFromReg(Chain, DL, Reg, VT);
+    }
 
     if (Arg.Flags.isSRet()) {
       // The return object should be reasonably addressable.
@@ -3373,7 +3468,7 @@ SDValue SITargetLowering::LowerCallResult(
 // from the explicit user arguments present in the IR.
 void SITargetLowering::passSpecialInputs(
     CallLoweringInfo &CLI, CCState &CCInfo, const SIMachineFunctionInfo &Info,
-    SmallVectorImpl<std::pair<unsigned, SDValue>> &RegsToPass,
+    SmallVectorImpl<std::pair<Register, SDValue>> &RegsToPass,
     SmallVectorImpl<SDValue> &MemOpChains, SDValue Chain) const {
   // If we don't have a call site, this was a call inserted by
   // legalization. These can never use special inputs.
@@ -3817,7 +3912,7 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
   }
 
   const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
-  SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
+  SmallVector<std::pair<Register, SDValue>, 8> RegsToPass;
   SmallVector<SDValue, 8> MemOpChains;
 
   // Analyze operands of the call, assigning locations to each operand.
@@ -3875,6 +3970,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   MVT PtrVT = MVT::i32;
 
+  InregVGPRSpillerCallSite Spiller(DAG, MF, CCInfo);
+
   // Walk the register/memloc assignments, inserting copies/loads.
   for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
     CCValAssign &VA = ArgLocs[i];
@@ -3988,8 +4085,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
   SDValue InGlue;
 
   unsigned ArgIdx = 0;
-  for (auto [Reg, Val] : RegsToPass) {
-    if (ArgIdx++ >= NumSpecialInputs &&
+  for (auto &[Reg, Val] : RegsToPass) {
+    if (ArgIdx >= NumSpecialInputs &&
         (IsChainCallConv || !Val->isDivergent()) && TRI->isSGPRPhysReg(Reg)) {
       // For chain calls, the inreg arguments are required to be
       // uniform. Speculatively Insert a readfirstlane in case we cannot prove
@@ -4008,8 +4105,18 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
                         ReadfirstlaneArgs);
     }
 
-    Chain = DAG.getCopyToReg(Chain, DL, Reg, Val, InGlue);
-    InGlue = Chain.getValue(1);
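+    // Mirror the formal-argument handling: an inreg argument assigned to a
+    // VGPR means the SGPR budget is exhausted, so write it into a lane of a
+    // shared VGPR instead of occupying the whole register.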
+    if (ArgIdx >= NumSpecialInputs &&
+        Outs[ArgIdx - NumSpecialInputs].Flags.isInReg() &&
+        AMDGPU::VGPR_32RegClass.contains(Reg)) {
+      std::tie(Chain, InGlue) =
+          Spiller.write(Chain, DL, Reg, Val, InGlue,
+                        ArgLocs[ArgIdx - NumSpecialInputs].getLocVT());
+    } else {
+      Chain = DAG.getCopyToReg(Chain, DL, Reg, Val, InGlue);
+      InGlue = Chain.getValue(1);
+    }
+
+    ++ArgIdx;
   }
 
   // We don't usually want to end the call-sequence here because we would tidy
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
index dc0634331caf9..0990d818783fc 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -406,7 +406,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
     CallLoweringInfo &CLI,
     CCState &CCInfo,
     const SIMachineFunctionInfo &Info,
-    SmallVectorImpl<std::pair<unsigned, SDValue>> &RegsToPass,
+    SmallVectorImpl<std::pair<Register, SDValue>> &RegsToPass,
     SmallVectorImpl<SDValue> &MemOpChains,
     SDValue Chain) const;
 
diff --git a/llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll b/llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll
new file mode 100644
index 0000000000000..dbbe34fe351f7
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx950 -o - %s | FileCheck %s
+
+; arg3 is in v0 and arg4 is in v1. These should be packed into lanes of a single VGPR and extracted with v_readlane.
+define i32 @callee(<8 x i32> inreg %arg0, <8 x i32> inreg %arg1, <2 x i32> inreg %arg2, i32 inreg %arg3, i32 inreg %arg4) {
+; CHECK-LABEL: callee:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_readlane_b32 s0, v0, 1
+; CHECK-NEXT:    v_readlane_b32 s1, v0, 0
+; CHECK-NEXT:    s_sub_i32 s0, s1, s0
+; CHECK-NEXT:    v_mov_b32_e32 v0, s0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %sub = sub i32 %arg3, %arg4
+  ret i32 %sub
+}
+
+define amdgpu_kernel void @kernel(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4, ptr %p) {
+; CHECK-LABEL: kernel:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_mov_b32 s12, s8
+; CHECK-NEXT:    s_add_u32 s8, s4, 0x58
+; CHECK-NEXT:    s_mov_b32 s13, s9
+; CHECK-NEXT:    s_addc_u32 s9, s5, 0
+; CHECK-NEXT:    s_load_dwordx16 s[36:51], s[4:5], 0x0
+; CHECK-NEXT:    s_load_dwordx4 s[28:31], s[4:5], 0x40
+; CHECK-NEXT:    s_load_dwordx2 s[34:35], s[4:5], 0x50
+; CHECK-NEXT:    s_getpc_b64 s[4:5]
+; CHECK-NEXT:    s_add_u32 s4, s4, callee@gotpcrel32@lo+4
+; CHECK-NEXT:    s_addc_u32 s5, s5, callee@gotpcrel32@hi+12
+; CHECK-NEXT:    s_load_dwordx2 s[52:53], s[4:5], 0x0
+; CHECK-NEXT:    s_mov_b32 s14, s10
+; CHECK-NEXT:    s_mov_b64 s[10:11], s[6:7]
+; CHECK-NEXT:    s_mov_b64 s[4:5], s[0:1]
+; CHECK-NEXT:    s_mov_b64 s[6:7], s[2:3]
+; CHECK-NEXT:    v_mov_b32_e32 v31, v0
+; CHECK-NEXT:    s_waitcnt lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s0, s36
+; CHECK-NEXT:    s_mov_b32 s1, s37
+; CHECK-NEXT:    s_mov_b32 s2, s38
+; CHECK-NEXT:    s_mov_b32 s3, s39
+; CHECK-NEXT:    s_mov_b32 s16, s40
+; CHECK-NEXT:    s_mov_b32 s17, s41
+; CHECK-NEXT:    s_mov_b32 s18, s42
+; CHECK-NEXT:    s_mov_b32 s19, s43
+; CHECK-NEXT:    s_mov_b32 s20, s44
+; CHECK-NEXT:    s_mov_b32 s21, s45
+; CHECK-NEXT:    s_mov_b32 s22, s46
+; CHECK-NEXT:    s_mov_b32 s23, s47
+; CHECK-NEXT:    s_mov_b32 s24, s48
+; CHECK-NEXT:    s_mov_b32 s25, s49
+; CHECK-NEXT:    s_mov_b32 s26, s50
+; CHECK-NEXT:    s_mov_b32 s27, s51
+; CHECK-NEXT:    v_mov_b32_e32 v0, s30
+; CHECK-NEXT:    s_mov_b32 s32, 0
+; CHECK-NEXT:    s_swappc_b64 s[30:31], s[52:53]
+; CHECK-NEXT:    v_mov_b64_e32 v[2:3], s[34:35]
+; CHECK-NEXT:    flat_store_dword v[2:3], v0
+; CHECK-NEXT:    s_endpgm
+  %ret = call i32 @callee(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4)
+  store i32 %ret, ptr %p
+  ret void
+}


