[llvm] [WIP][AMDGPU] Improve the handling of `inreg` arguments (PR #133614)
Shilei Tian via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 3 14:58:02 PDT 2025
https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/133614
>From bf0af2599672a03a67f39532fd30709058286120 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Thu, 3 Apr 2025 17:57:45 -0400
Subject: [PATCH] [WIP][AMDGPU] Improve the handling of `inreg` arguments
When SGPRs available for `inreg` argument passing run out, the compiler silently
falls back to using whole VGPRs to pass those arguments. Ideally, instead of
using whole VGPRs, we should pack `inreg` arguments into individual lanes of
VGPRs.
This PR introduces a pair of helper classes, `InregVPGRSpillerCallSite` (caller
side) and `InregVPGRSpillerCallee` (callee side), which handle this packing. The
call site uses `v_writelane` to place `inreg` arguments into specific lanes of a
single VGPR, and the callee extracts them using `v_readlane`.
Fixes #130443 and #129071.
---
llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 129 ++++++++++++++++++-
llvm/lib/Target/AMDGPU/SIISelLowering.h | 2 +-
llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll | 107 +++++++++++++++
3 files changed, 230 insertions(+), 8 deletions(-)
create mode 100644 llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index a583a5cb990e7..daf9589c1666c 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -2841,6 +2841,95 @@ void SITargetLowering::insertCopiesSplitCSR(
}
}
+/// Classes for spilling inreg VGPR arguments.
+///
+/// When an argument marked inreg is pushed to a VGPR, it indicates that the
+/// available SGPRs for argument passing have been exhausted. In such cases, it
+/// is preferable to pack multiple inreg arguments into individual lanes of
+/// VGPRs instead of assigning each directly to separate VGPRs.
+///
+/// Spilling involves two parts: the caller-side (call site) and the
+/// callee-side. Both must follow the same method for selecting registers and
+/// lanes, ensuring that an argument written at the call site matches exactly
+/// with the one read at the callee.
+
+/// The spilling class for the callee-side that lowers unpacking of formal
+/// arguments: each 'inreg' argument that was spilled to a VGPR is extracted
+/// with v_readlane from its own lane of the single packed VGPR, using the same
+/// lane order in which the call site wrote them.
+class InregVPGRSpillerCallee {
+  CCState &State;
+  SelectionDAG &DAG;
+  MachineFunction &MF;
+
+  /// Live-in virtual register holding the packed VGPR; set by the first
+  /// readLane call.
+  Register SrcReg;
+  /// Copy out of SrcReg; non-null once the packed VGPR has been claimed.
+  SDValue SrcVal;
+  /// Lane that the next spilled argument is read from.
+  unsigned CurLane = 0;
+
+public:
+  InregVPGRSpillerCallee(SelectionDAG &DAG, MachineFunction &MF, CCState &State)
+      : State(State), DAG(DAG), MF(MF) {}
+
+  /// Returns the value of the next spilled 'inreg' argument, of type \p VT.
+  ///
+  /// \p Reg is the VGPR that the calling convention assigned to this argument.
+  /// For the first spilled argument it becomes the packed source register and
+  /// is replaced with its live-in virtual register. For every later argument
+  /// the assigned VGPR is not read; it is released back to \p State and \p Reg
+  /// is left unchanged.
+  SDValue readLane(SDValue Chain, const SDLoc &SL, Register &Reg, EVT VT) {
+    if (SrcVal) {
+      // NOTE(review): the call-site side performs no matching deallocation;
+      // confirm that CCState is not consulted for further VGPR assignment
+      // after this point, otherwise caller and callee could disagree.
+      State.DeallocateReg(Reg);
+    } else {
+      Reg = MF.addLiveIn(Reg, &AMDGPU::VGPR_32RegClass);
+      SrcReg = Reg;
+      SrcVal = DAG.getCopyFromReg(Chain, SL, Reg, VT);
+    }
+    // The calling convention assigns at most 32 arguments to VGPRs
+    // (VGPR0-VGPR31), whether they are ordinary arguments or 'inreg'
+    // arguments that ran out of SGPRs, so every spilled argument gets its own
+    // lane of a single VGPR even in wave32. The SGPR count does not bound
+    // this number: the spilled arguments are exactly those that did not fit
+    // in SGPRs.
+    assert(CurLane < 32 && "more than expected VGPR inreg arguments");
+    SmallVector<SDValue, 4> Operands{
+        DAG.getTargetConstant(Intrinsic::amdgcn_readlane, SL, MVT::i32),
+        DAG.getRegister(SrcReg, VT),
+        DAG.getTargetConstant(CurLane++, SL, MVT::i32)};
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, VT, Operands);
+  }
+};
+
+/// The spilling class for the caller-side that lowers packing of call site
+/// arguments: each 'inreg' argument that the calling convention assigned to a
+/// VGPR is written with v_writelane into its own lane of a single VGPR (the
+/// one assigned to the first such argument). finalize() then copies the packed
+/// value into that physical register.
+class InregVPGRSpillerCallSite {
+  /// Physical VGPR that carries the packed arguments; set by the first
+  /// writeLane call.
+  Register DstReg;
+  /// Result of the most recent v_writelane; null until the first write.
+  SDValue LastWrite;
+  /// Lane that the next spilled argument is written to.
+  unsigned CurLane = 0;
+
+  SelectionDAG &DAG;
+  MachineFunction &MF;
+
+public:
+  InregVPGRSpillerCallSite(SelectionDAG &DAG, MachineFunction &MF)
+      : DAG(DAG), MF(MF) {}
+
+  /// Writes \p Val (of type \p VT) into the next lane of the packed VGPR.
+  ///
+  /// \p Reg is the VGPR that the calling convention assigned to this argument.
+  /// The first call adopts it as the packed register; later calls overwrite
+  /// \p Reg with that packed register.
+  void writeLane(const SDLoc &SL, Register &Reg, SDValue Val, EVT VT) {
+    if (DstReg.isValid())
+      Reg = DstReg;
+    else
+      DstReg = Reg;
+    // The calling convention assigns at most 32 arguments to VGPRs
+    // (VGPR0-VGPR31), whether they are ordinary arguments or 'inreg'
+    // arguments that ran out of SGPRs, so every spilled argument gets its own
+    // lane of a single VGPR even in wave32. The SGPR count does not bound
+    // this number: the spilled arguments are exactly those that did not fit
+    // in SGPRs.
+    assert(CurLane < 32 && "more than expected VGPR inreg arguments");
+    SmallVector<SDValue, 4> Operands{
+        DAG.getTargetConstant(Intrinsic::amdgcn_writelane, SL, MVT::i32), Val,
+        DAG.getTargetConstant(CurLane++, SL, MVT::i32)};
+    if (!LastWrite) {
+      // Lanes that are never written are never read by the callee, so the
+      // incoming value of the other lanes does not matter. Use the calling
+      // function's live-in value of DstReg when it has one. Otherwise
+      // getLiveInVirtReg returns an invalid register, so fall back to undef
+      // instead of building a register node for it.
+      Register VReg = MF.getRegInfo().getLiveInVirtReg(DstReg);
+      Operands.push_back(VReg.isValid() ? DAG.getRegister(VReg, VT)
+                                        : DAG.getUNDEF(VT));
+    } else {
+      Operands.push_back(LastWrite);
+    }
+    LastWrite = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, VT, Operands);
+  }
+
+  /// Copies the packed VGPR value into DstReg, glued to \p InGlue.
+  /// Returns a null SDValue when no argument was spilled.
+  SDValue finalize(SDValue Chain, const SDLoc &SL, SDValue InGlue) {
+    if (!LastWrite)
+      return SDValue();
+    return DAG.getCopyToReg(Chain, SL, DstReg, LastWrite, InGlue);
+  }
+};
+
SDValue SITargetLowering::LowerFormalArguments(
SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -2963,6 +3052,7 @@ SDValue SITargetLowering::LowerFormalArguments(
// FIXME: Alignment of explicit arguments totally broken with non-0 explicit
// kern arg offset.
const Align KernelArgBaseAlign = Align(16);
+ InregVPGRSpillerCallee Spiller(DAG, MF, CCInfo);
for (unsigned i = 0, e = Ins.size(), ArgIdx = 0; i != e; ++i) {
const ISD::InputArg &Arg = Ins[i];
@@ -3130,8 +3220,17 @@ SDValue SITargetLowering::LowerFormalArguments(
llvm_unreachable("Unexpected register class in LowerFormalArguments!");
EVT ValVT = VA.getValVT();
- Reg = MF.addLiveIn(Reg, RC);
- SDValue Val = DAG.getCopyFromReg(Chain, DL, Reg, VT);
+ SDValue Val;
+ // If an argument is marked inreg but gets pushed to a VGPR, it indicates
+ // we've run out of SGPRs for argument passing. In such cases, we'd prefer
+ // to start packing inreg arguments into individual lanes of VGPRs, rather
+ // than placing them directly into VGPRs.
+ if (RC == &AMDGPU::VGPR_32RegClass && Arg.Flags.isInReg()) {
+ Val = Spiller.readLane(Chain, DL, Reg, VT);
+ } else {
+ Reg = MF.addLiveIn(Reg, RC);
+ Val = DAG.getCopyFromReg(Chain, DL, Reg, VT);
+ }
if (Arg.Flags.isSRet()) {
// The return object should be reasonably addressable.
@@ -3373,7 +3472,7 @@ SDValue SITargetLowering::LowerCallResult(
// from the explicit user arguments present in the IR.
void SITargetLowering::passSpecialInputs(
CallLoweringInfo &CLI, CCState &CCInfo, const SIMachineFunctionInfo &Info,
- SmallVectorImpl<std::pair<unsigned, SDValue>> &RegsToPass,
+ SmallVectorImpl<std::pair<Register, SDValue>> &RegsToPass,
SmallVectorImpl<SDValue> &MemOpChains, SDValue Chain) const {
// If we don't have a call site, this was a call inserted by
// legalization. These can never use special inputs.
@@ -3817,7 +3916,7 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
}
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
- SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
+ SmallVector<std::pair<Register, SDValue>, 8> RegsToPass;
SmallVector<SDValue, 8> MemOpChains;
// Analyze operands of the call, assigning locations to each operand.
@@ -3875,6 +3974,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
MVT PtrVT = MVT::i32;
+ InregVPGRSpillerCallSite Spiller(DAG, MF);
+
// Walk the register/memloc assignments, inserting copies/loads.
for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
CCValAssign &VA = ArgLocs[i];
@@ -3988,8 +4089,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue InGlue;
unsigned ArgIdx = 0;
- for (auto [Reg, Val] : RegsToPass) {
- if (ArgIdx++ >= NumSpecialInputs &&
+ for (auto &[Reg, Val] : RegsToPass) {
+ if (ArgIdx >= NumSpecialInputs &&
(IsChainCallConv || !Val->isDivergent()) && TRI->isSGPRPhysReg(Reg)) {
// For chain calls, the inreg arguments are required to be
// uniform. Speculatively Insert a readfirstlane in case we cannot prove
@@ -4008,7 +4109,21 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
ReadfirstlaneArgs);
}
- Chain = DAG.getCopyToReg(Chain, DL, Reg, Val, InGlue);
+ if (ArgIdx >= NumSpecialInputs &&
+ Outs[ArgIdx - NumSpecialInputs].Flags.isInReg() &&
+ AMDGPU::VGPR_32RegClass.contains(Reg)) {
+ Spiller.writeLane(DL, Reg, Val,
+ ArgLocs[ArgIdx - NumSpecialInputs].getLocVT());
+ } else {
+ Chain = DAG.getCopyToReg(Chain, DL, Reg, Val, InGlue);
+ InGlue = Chain.getValue(1);
+ }
+
+ ++ArgIdx;
+ }
+
+ if (SDValue R = Spiller.finalize(Chain, DL, InGlue)) {
+ Chain = R;
InGlue = Chain.getValue(1);
}
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
index dc0634331caf9..0990d818783fc 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -406,7 +406,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
CallLoweringInfo &CLI,
CCState &CCInfo,
const SIMachineFunctionInfo &Info,
- SmallVectorImpl<std::pair<unsigned, SDValue>> &RegsToPass,
+ SmallVectorImpl<std::pair<Register, SDValue>> &RegsToPass,
SmallVectorImpl<SDValue> &MemOpChains,
SDValue Chain) const;
diff --git a/llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll b/llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll
new file mode 100644
index 0000000000000..b2effe42176c8
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/inreg-vgpr-spill.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx950 -o - %s | FileCheck %s
+
+; arg0-arg2 use up the SGPRs available for inreg arguments, so the calling
+; convention assigns arg3 to v0 and arg4 to v1. Instead, both should be packed
+; into lanes 0 and 1 of v0 and extracted with v_readlane.
+define i32 @callee(<8 x i32> inreg %arg0, <8 x i32> inreg %arg1, <2 x i32> inreg %arg2, i32 inreg %arg3, i32 inreg %arg4) {
+; CHECK-LABEL: callee:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_readlane_b32 s0, v0, 1
+; CHECK-NEXT: v_readlane_b32 s1, v0, 0
+; CHECK-NEXT: s_sub_i32 s0, s1, s0
+; CHECK-NEXT: v_mov_b32_e32 v0, s0
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %add = sub i32 %arg3, %arg4
+ ret i32 %add
+}
+
+; Call site in a kernel: arg3 and arg4 are written into lanes 0 and 1 of a
+; single VGPR with v_writelane, which is then passed in v0.
+define amdgpu_kernel void @kernel(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4, ptr %p) {
+; CHECK-LABEL: kernel:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_load_dwordx16 s[36:51], s[4:5], 0x0
+; CHECK-NEXT: s_load_dwordx4 s[28:31], s[4:5], 0x40
+; CHECK-NEXT: s_load_dwordx2 s[34:35], s[4:5], 0x50
+; CHECK-NEXT: s_mov_b32 s12, s8
+; CHECK-NEXT: s_add_u32 s8, s4, 0x58
+; CHECK-NEXT: s_mov_b32 s13, s9
+; CHECK-NEXT: s_addc_u32 s9, s5, 0
+; CHECK-NEXT: v_mov_b32_e32 v1, v0
+; CHECK-NEXT: s_waitcnt lgkmcnt(0)
+; CHECK-NEXT: v_writelane_b32 v1, s30, 0
+; CHECK-NEXT: s_getpc_b64 s[4:5]
+; CHECK-NEXT: s_add_u32 s4, s4, callee at gotpcrel32@lo+4
+; CHECK-NEXT: s_addc_u32 s5, s5, callee at gotpcrel32@hi+12
+; CHECK-NEXT: v_writelane_b32 v1, s31, 1
+; CHECK-NEXT: s_load_dwordx2 s[30:31], s[4:5], 0x0
+; CHECK-NEXT: s_mov_b32 s14, s10
+; CHECK-NEXT: s_mov_b64 s[10:11], s[6:7]
+; CHECK-NEXT: s_mov_b64 s[4:5], s[0:1]
+; CHECK-NEXT: s_mov_b64 s[6:7], s[2:3]
+; CHECK-NEXT: v_mov_b32_e32 v31, v0
+; CHECK-NEXT: s_mov_b32 s0, s36
+; CHECK-NEXT: s_mov_b32 s1, s37
+; CHECK-NEXT: s_mov_b32 s2, s38
+; CHECK-NEXT: s_mov_b32 s3, s39
+; CHECK-NEXT: s_mov_b32 s16, s40
+; CHECK-NEXT: s_mov_b32 s17, s41
+; CHECK-NEXT: s_mov_b32 s18, s42
+; CHECK-NEXT: s_mov_b32 s19, s43
+; CHECK-NEXT: s_mov_b32 s20, s44
+; CHECK-NEXT: s_mov_b32 s21, s45
+; CHECK-NEXT: s_mov_b32 s22, s46
+; CHECK-NEXT: s_mov_b32 s23, s47
+; CHECK-NEXT: s_mov_b32 s24, s48
+; CHECK-NEXT: s_mov_b32 s25, s49
+; CHECK-NEXT: s_mov_b32 s26, s50
+; CHECK-NEXT: s_mov_b32 s27, s51
+; CHECK-NEXT: v_mov_b32_e32 v0, v1
+; CHECK-NEXT: s_mov_b32 s32, 0
+; CHECK-NEXT: s_waitcnt lgkmcnt(0)
+; CHECK-NEXT: s_swappc_b64 s[30:31], s[30:31]
+; CHECK-NEXT: v_mov_b64_e32 v[2:3], s[34:35]
+; CHECK-NEXT: flat_store_dword v[2:3], v0
+; CHECK-NEXT: s_endpgm
+ %ret = call i32 @callee(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4)
+ store i32 %ret, ptr %p
+ ret void
+}
+
+; Call site in a non-entry function whose own arg3/arg4 arrive packed in lanes
+; 0 and 1 of v0: they are extracted with v_readlane and repacked with
+; v_writelane for the outgoing call.
+define i32 @caller(<8 x i32> inreg %arg0, <8 x i32> inreg %arg1, <2 x i32> inreg %arg2, i32 inreg %arg3, i32 inreg %arg4) {
+; CHECK-LABEL: caller:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: s_mov_b32 s42, s33
+; CHECK-NEXT: s_mov_b32 s33, s32
+; CHECK-NEXT: s_xor_saveexec_b64 s[40:41], -1
+; CHECK-NEXT: scratch_store_dword off, v1, s33 ; 4-byte Folded Spill
+; CHECK-NEXT: s_mov_b64 exec, s[40:41]
+; CHECK-NEXT: v_readlane_b32 s41, v0, 0
+; CHECK-NEXT: s_add_i32 s32, s32, 16
+; CHECK-NEXT: v_readlane_b32 s40, v0, 1
+; CHECK-NEXT: v_writelane_b32 v0, s41, 0
+; CHECK-NEXT: v_writelane_b32 v1, s30, 0
+; CHECK-NEXT: v_writelane_b32 v0, s40, 1
+; CHECK-NEXT: s_getpc_b64 s[40:41]
+; CHECK-NEXT: s_add_u32 s40, s40, callee at gotpcrel32@lo+4
+; CHECK-NEXT: s_addc_u32 s41, s41, callee at gotpcrel32@hi+12
+; CHECK-NEXT: s_load_dwordx2 s[40:41], s[40:41], 0x0
+; CHECK-NEXT: v_writelane_b32 v1, s31, 1
+; CHECK-NEXT: s_waitcnt lgkmcnt(0)
+; CHECK-NEXT: s_swappc_b64 s[30:31], s[40:41]
+; CHECK-NEXT: v_readlane_b32 s31, v1, 1
+; CHECK-NEXT: v_readlane_b32 s30, v1, 0
+; CHECK-NEXT: s_mov_b32 s32, s33
+; CHECK-NEXT: s_xor_saveexec_b64 s[0:1], -1
+; CHECK-NEXT: scratch_load_dword v1, off, s33 ; 4-byte Folded Reload
+; CHECK-NEXT: s_mov_b64 exec, s[0:1]
+; CHECK-NEXT: s_mov_b32 s33, s42
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %ret = call i32 @callee(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4)
+ ret i32 %ret
+}
+
+; Tail-call variant of @caller.
+; NOTE(review): this function has no CHECK lines; regenerate the assertions
+; with utils/update_llc_test_checks.py so the tail-call path is covered.
+define i32 @tail_caller(<8 x i32> inreg %arg0, <8 x i32> inreg %arg1, <2 x i32> inreg %arg2, i32 inreg %arg3, i32 inreg %arg4) {
+ %ret = tail call i32 @callee(<8 x i32> %arg0, <8 x i32> %arg1, <2 x i32> %arg2, i32 %arg3, i32 %arg4)
+ ret i32 %ret
+}
More information about the llvm-commits
mailing list