[llvm] [AMDGPU] ISel & PEI for whole wave functions (PR #145858)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 2 03:18:38 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-amdgpu

Author: Diana Picus (rovka)

<details>
<summary>Changes</summary>

Whole wave functions are functions that will run with a full EXEC mask.
They will not be invoked directly, but instead will be launched by way
of a new intrinsic, `llvm.amdgcn.call.whole.wave` (to be added in
a future patch). These functions are meant as an alternative to the
`llvm.amdgcn.init.whole.wave` or `llvm.amdgcn.strict.wwm` intrinsics.

Whole wave functions will set EXEC to -1 in the prologue and restore the
original value of EXEC in the epilogue. They must have a special first
argument, `i1 %active`, that is going to be mapped to EXEC. They may
have either the default calling convention or amdgpu_gfx. The inactive
lanes need to be preserved for all registers used, active lanes only for
the CSRs.
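
The EXEC handling described above can be pictured with a toy model. This is a minimal sketch in plain C++, not LLVM code: the `Wave` struct, the `runWholeWave` function and the wave32 width are assumptions made purely for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Toy model of a wave32 EXEC mask (hypothetical, not an LLVM type).
struct Wave {
  uint32_t Exec; // one bit per lane, 1 = lane is active
};

// Hypothetical stand-in for a whole wave function.
// Returns the EXEC value that the function body observes.
uint32_t runWholeWave(Wave &W) {
  const uint32_t Active = W.Exec;     // prologue: %active captures the caller's EXEC
  W.Exec = 0xFFFFFFFFu;               // prologue: EXEC = -1, all lanes enabled
  const uint32_t SeenByBody = W.Exec; // the body would run here with every lane on
  W.Exec = Active;                    // epilogue: restore the original EXEC
  return SeenByBody;
}
```

In this model, a caller with only lanes 0 to 3 active (`Wave W{0xF}`) sees the body run with `0xFFFFFFFF` and gets `W.Exec == 0xF` back afterwards.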

At the IR level, arguments to a whole wave function (other than
`%active`) contain poison in their inactive lanes. Likewise, the return
value for the inactive lanes is poison.
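
One way to picture the per-lane contract is sketched below. It is an illustrative model only: `std::nullopt` stands in for poison (real poison is a property of IR values, not a runtime state), and `LaneValues`, `maskInactive` and the wave32 width are hypothetical names and assumptions.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>

constexpr unsigned NumLanes = 32; // wave32 assumed for brevity

// Per-lane view of one argument; std::nullopt models "poison in this lane".
using LaneValues = std::array<std::optional<int32_t>, NumLanes>;

// Keep the value only in lanes whose bit is set in Active; every other
// lane carries no usable value for the callee.
LaneValues maskInactive(const std::array<int32_t, NumLanes> &Raw,
                        uint32_t Active) {
  LaneValues Out{}; // all lanes start as std::nullopt
  for (unsigned Lane = 0; Lane < NumLanes; ++Lane)
    if (Active & (1u << Lane))
      Out[Lane] = Raw[Lane];
  return Out;
}
```

The same picture applies in the other direction: after the call, only the lanes set in `%active` hold a meaningful return value.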

This patch contains the following work:
* 2 new pseudos, SI_WHOLE_WAVE_FUNC_SETUP and SI_WHOLE_WAVE_FUNC_RETURN,
  used for managing the EXEC mask. SI_WHOLE_WAVE_FUNC_SETUP will return
  a SReg_1 representing `%active`, which needs to be passed into
  SI_WHOLE_WAVE_FUNC_RETURN.
* SelectionDAG support for generating these 2 new pseudos and the
  special handling of %active. Since the return may be in a different
  basic block, it's difficult to add the virtual reg for %active to
  SI_WHOLE_WAVE_FUNC_RETURN, so we initially generate an IMPLICIT_DEF
  which is later replaced via a custom inserter.
* Expansion of the 2 pseudos during prolog/epilog insertion. PEI also
  marks any used VGPRs as WWM registers, which are then spilled and
  restored with the usual logic.
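
The restore order that the `emitCSRSpillRestores` hunk in this patch uses for whole wave functions can be sketched as a small simulation. This is a toy model under stated assumptions, not the backend code: `VGPR`, `restoreUnderExec` and `epilogue` are hypothetical names, a wave32 target is assumed, and a single callee-saved and a single scratch register stand in for the real register lists.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr unsigned NumLanes = 32; // wave32 assumed for brevity
using VGPR = std::array<uint32_t, NumLanes>;

// Copy Slot into Reg only in the lanes enabled by Exec, the way a
// vector reload behaves under the EXEC mask.
void restoreUnderExec(VGPR &Reg, const VGPR &Slot, uint32_t Exec) {
  for (unsigned Lane = 0; Lane < NumLanes; ++Lane)
    if (Exec & (1u << Lane))
      Reg[Lane] = Slot[Lane];
}

// Hypothetical epilogue for one callee-saved and one scratch VGPR.
// Returns the EXEC value handed back to the caller.
uint32_t epilogue(VGPR &CalleeSaved, const VGPR &CalleeSavedSlot,
                  VGPR &Scratch, const VGPR &ScratchSlot, uint32_t OrigExec) {
  uint32_t Exec = 0xFFFFFFFFu;   // the body ran with all lanes enabled
  restoreUnderExec(CalleeSaved, CalleeSavedSlot, Exec); // CSR: every lane
  Exec = OrigExec ^ 0xFFFFFFFFu; // S_XOR with -1: select the inactive lanes
  restoreUnderExec(Scratch, ScratchSlot, Exec); // scratch: inactive lanes only
  Exec = OrigExec;               // S_MOV: restore the caller's EXEC
  return Exec;
}
```

With `OrigExec = 0x3`, the scratch register keeps whatever the body wrote in lanes 0 and 1 and gets its saved contents back in lanes 2 to 31, while the callee-saved register is restored in all 32 lanes. This matches the rule that inactive lanes are preserved for every register used and active lanes only for the CSRs.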

Future patches will include the `llvm.amdgcn.call.whole.wave` intrinsic
and a lot of optimization work (especially in order to reduce spills around
function calls).

---

Patch is 212.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145858.diff


41 Files Affected:

- (modified) llvm/docs/AMDGPUUsage.rst (+14) 
- (modified) llvm/include/llvm/AsmParser/LLToken.h (+1) 
- (modified) llvm/include/llvm/IR/CallingConv.h (+8) 
- (modified) llvm/lib/AsmParser/LLLexer.cpp (+1) 
- (modified) llvm/lib/AsmParser/LLParser.cpp (+3) 
- (modified) llvm/lib/IR/AsmWriter.cpp (+3) 
- (modified) llvm/lib/IR/Function.cpp (+1) 
- (modified) llvm/lib/IR/Verifier.cpp (+10) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp (+30-3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h (+3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUGISel.td (+4) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+4) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h (+6) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td (+11) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp (+4) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+4) 
- (modified) llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp (+1-1) 
- (modified) llvm/lib/Target/AMDGPU/SIFrameLowering.cpp (+77-10) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+34-6) 
- (modified) llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp (+1) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.cpp (+14) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.h (+2) 
- (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+40) 
- (modified) llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp (+7-2) 
- (modified) llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h (+6) 
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp (+2) 
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+2-1) 
- (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUPALMetadata.cpp (+1) 
- (modified) llvm/test/Bitcode/compatibility.ll (+4) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-whole-wave-functions.mir (+40) 
- (added) llvm/test/CodeGen/AMDGPU/irtranslator-whole-wave-functions.ll (+103) 
- (added) llvm/test/CodeGen/AMDGPU/isel-whole-wave-functions.ll (+190) 
- (added) llvm/test/CodeGen/AMDGPU/whole-wave-functions-pei.mir (+448) 
- (added) llvm/test/CodeGen/AMDGPU/whole-wave-functions.ll (+2414) 
- (modified) llvm/test/CodeGen/MIR/AMDGPU/long-branch-reg-all-sgpr-used.ll (+2) 
- (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-after-pei.ll (+1) 
- (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-long-branch-reg-debug.ll (+1) 
- (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-long-branch-reg.ll (+1) 
- (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info-no-ir.mir (+4) 
- (modified) llvm/test/CodeGen/MIR/AMDGPU/machine-function-info.ll (+4) 
- (modified) llvm/test/Verifier/amdgpu-cc.ll (+33) 


``````````diff
diff --git a/llvm/docs/AMDGPUUsage.rst b/llvm/docs/AMDGPUUsage.rst
index c5b9bd9de66e1..19357635ecfc1 100644
--- a/llvm/docs/AMDGPUUsage.rst
+++ b/llvm/docs/AMDGPUUsage.rst
@@ -1844,6 +1844,20 @@ The AMDGPU backend supports the following calling conventions:
                                      ..TODO::
                                      Describe.
 
+     ``amdgpu_gfx_whole_wave``       Used for AMD graphics targets. Functions with this calling convention
+                                     cannot be used as entry points. They must have an i1 as the first argument,
+                                     which will be mapped to the value of EXEC on entry into the function. Other
+                                     arguments will contain poison in their inactive lanes. Similarly, the return
+                                     value for the inactive lanes is poison.
+
+                                     The function will run with all lanes enabled, i.e. EXEC will be set to -1 in the
+                                     prologue and restored to its original value in the epilogue. The inactive lanes
+                                     will be preserved for all the registers used by the function. Active lanes only
+                                     will only be preserved for the callee saved registers.
+
+                                     In all other respects, functions with this calling convention behave like
+                                     ``amdgpu_gfx`` functions.
+
      ``amdgpu_gs``                   Used for Mesa/AMDPAL geometry shaders.
                                      ..TODO::
                                      Describe.
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index c7e4bdf3ff811..a2311d2ac285d 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -181,6 +181,7 @@ enum Kind {
   kw_amdgpu_cs_chain_preserve,
   kw_amdgpu_kernel,
   kw_amdgpu_gfx,
+  kw_amdgpu_gfx_whole_wave,
   kw_tailcc,
   kw_m68k_rtdcc,
   kw_graalcc,
diff --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h
index d68491eb5535c..ef761eb1aed73 100644
--- a/llvm/include/llvm/IR/CallingConv.h
+++ b/llvm/include/llvm/IR/CallingConv.h
@@ -284,6 +284,9 @@ namespace CallingConv {
     RISCV_VLSCall_32768 = 122,
     RISCV_VLSCall_65536 = 123,
 
+    // Calling convention for AMDGPU whole wave functions.
+    AMDGPU_Gfx_WholeWave = 124,
+
     /// The highest possible ID. Must be some 2^k - 1.
     MaxID = 1023
   };
@@ -294,8 +297,13 @@ namespace CallingConv {
 /// directly or indirectly via a call-like instruction.
 constexpr bool isCallableCC(CallingConv::ID CC) {
   switch (CC) {
+  // Called with special intrinsics:
+  // llvm.amdgcn.cs.chain
   case CallingConv::AMDGPU_CS_Chain:
   case CallingConv::AMDGPU_CS_ChainPreserve:
+  // llvm.amdgcn.call.whole.wave
+  case CallingConv::AMDGPU_Gfx_WholeWave:
+  // Hardware entry points:
   case CallingConv::AMDGPU_CS:
   case CallingConv::AMDGPU_ES:
   case CallingConv::AMDGPU_GS:
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index ce813e1d7b1c4..520c6a00a9c07 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -679,6 +679,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(amdgpu_cs_chain_preserve);
   KEYWORD(amdgpu_kernel);
   KEYWORD(amdgpu_gfx);
+  KEYWORD(amdgpu_gfx_whole_wave);
   KEYWORD(tailcc);
   KEYWORD(m68k_rtdcc);
   KEYWORD(graalcc);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index c5e166cef6da6..b09696497cc4e 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -2274,6 +2274,9 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) {
     CC = CallingConv::AMDGPU_CS_ChainPreserve;
     break;
   case lltok::kw_amdgpu_kernel:  CC = CallingConv::AMDGPU_KERNEL; break;
+  case lltok::kw_amdgpu_gfx_whole_wave:
+    CC = CallingConv::AMDGPU_Gfx_WholeWave;
+    break;
   case lltok::kw_tailcc:         CC = CallingConv::Tail; break;
   case lltok::kw_m68k_rtdcc:     CC = CallingConv::M68k_RTD; break;
   case lltok::kw_graalcc:        CC = CallingConv::GRAAL; break;
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 7828ba45ec27f..3ce892ecbff19 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -404,6 +404,9 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) {
     break;
   case CallingConv::AMDGPU_KERNEL: Out << "amdgpu_kernel"; break;
   case CallingConv::AMDGPU_Gfx:    Out << "amdgpu_gfx"; break;
+  case CallingConv::AMDGPU_Gfx_WholeWave:
+    Out << "amdgpu_gfx_whole_wave";
+    break;
   case CallingConv::M68k_RTD:      Out << "m68k_rtdcc"; break;
   case CallingConv::RISCV_VectorCall:
     Out << "riscv_vector_cc";
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 3e7fcbb983738..b1e8dd716063f 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1226,6 +1226,7 @@ bool llvm::CallingConv::supportsNonVoidReturnType(CallingConv::ID CC) {
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::WASM_EmscriptenInvoke:
   case CallingConv::AMDGPU_Gfx:
+  case CallingConv::AMDGPU_Gfx_WholeWave:
   case CallingConv::M68k_INTR:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 9cab88b09779a..32ce1880f2fdd 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2975,6 +2975,16 @@ void Verifier::visitFunction(const Function &F) {
           "perfect forwarding!",
           &F);
     break;
+  case CallingConv::AMDGPU_Gfx_WholeWave:
+    Check(F.arg_size() != 0 && F.arg_begin()->getType()->isIntegerTy(1),
+          "Calling convention requires first argument to be i1", &F);
+    Check(!F.arg_begin()->hasInRegAttr(),
+          "Calling convention requires first argument to not be inreg", &F);
+    Check(!F.isVarArg(),
+          "Calling convention does not support varargs or "
+          "perfect forwarding!",
+          &F);
+    break;
   }
 
   // Check that the argument values match the function type for this function...
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
index 14101e57f5143..b4ea3c81b3b6e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
@@ -374,8 +374,10 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val,
     return true;
   }
 
-  unsigned ReturnOpc =
-      IsShader ? AMDGPU::SI_RETURN_TO_EPILOG : AMDGPU::SI_RETURN;
+  const bool IsWholeWave = MFI->isWholeWaveFunction();
+  unsigned ReturnOpc = IsWholeWave ? AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN
+                       : IsShader  ? AMDGPU::SI_RETURN_TO_EPILOG
+                                   : AMDGPU::SI_RETURN;
   auto Ret = B.buildInstrNoInsert(ReturnOpc);
 
   if (!FLI.CanLowerReturn)
@@ -383,6 +385,10 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val,
   else if (!lowerReturnVal(B, Val, VRegs, Ret))
     return false;
 
+  if (IsWholeWave) {
+    addOriginalExecToReturn(B.getMF(), Ret);
+  }
+
   // TODO: Handle CalleeSavedRegsViaCopy.
 
   B.insertInstr(Ret);
@@ -632,6 +638,17 @@ bool AMDGPUCallLowering::lowerFormalArguments(
     if (DL.getTypeStoreSize(Arg.getType()) == 0)
       continue;
 
+    if (Info->isWholeWaveFunction() && Idx == 0) {
+      assert(VRegs[Idx].size() == 1 && "Expected only one register");
+
+      // The first argument for whole wave functions is the original EXEC value.
+      B.buildInstr(AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP)
+          .addDef(VRegs[Idx][0]);
+
+      ++Idx;
+      continue;
+    }
+
     const bool InReg = Arg.hasAttribute(Attribute::InReg);
 
     if (Arg.hasAttribute(Attribute::SwiftSelf) ||
@@ -1347,6 +1364,7 @@ bool AMDGPUCallLowering::lowerTailCall(
   SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
 
   if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
+      Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave &&
       !AMDGPU::isChainCC(Info.CallConv)) {
     // With a fixed ABI, allocate fixed registers before user arguments.
     if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
@@ -1524,7 +1542,8 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   // after the ordinary user argument registers.
   SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
 
-  if (Info.CallConv != CallingConv::AMDGPU_Gfx) {
+  if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
+      Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave) {
     // With a fixed ABI, allocate fixed registers before user arguments.
     if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
       return false;
@@ -1592,3 +1611,11 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   return true;
 }
+
+void AMDGPUCallLowering::addOriginalExecToReturn(
+    MachineFunction &MF, MachineInstrBuilder &Ret) const {
+  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
+  const SIInstrInfo *TII = ST.getInstrInfo();
+  const MachineInstr *Setup = TII->getWholeWaveFunctionSetup(MF);
+  Ret.addReg(Setup->getOperand(0).getReg());
+}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
index a6e801f2a547b..e0033d59d10bb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h
@@ -37,6 +37,9 @@ class AMDGPUCallLowering final : public CallLowering {
   bool lowerReturnVal(MachineIRBuilder &B, const Value *Val,
                       ArrayRef<Register> VRegs, MachineInstrBuilder &Ret) const;
 
+  void addOriginalExecToReturn(MachineFunction &MF,
+                               MachineInstrBuilder &Ret) const;
+
 public:
   AMDGPUCallLowering(const AMDGPUTargetLowering &TLI);
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index 1b909568fc555..c5063c4de4ad3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -300,6 +300,10 @@ def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_SSHORT, SIsbuffer_load_short>;
 def : GINodeEquiv<G_AMDGPU_S_BUFFER_LOAD_USHORT, SIsbuffer_load_ushort>;
 def : GINodeEquiv<G_AMDGPU_S_BUFFER_PREFETCH, SIsbuffer_prefetch>;
 
+def : GINodeEquiv<G_AMDGPU_WHOLE_WAVE_FUNC_SETUP, AMDGPUwhole_wave_setup>;
+// G_AMDGPU_WHOLE_WAVE_FUNC_RETURN is simpler than AMDGPUwhole_wave_return,
+// so we don't mark it as equivalent.
+
 class GISelSop2Pat <
   SDPatternOperator node,
   Instruction inst,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index d75c7a178b4a8..0421ed87e61f4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -1138,6 +1138,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::Cold:
     return CC_AMDGPU_Func;
   case CallingConv::AMDGPU_Gfx:
+  case CallingConv::AMDGPU_Gfx_WholeWave:
     return CC_SI_Gfx;
   case CallingConv::AMDGPU_KERNEL:
   case CallingConv::SPIR_KERNEL:
@@ -1163,6 +1164,7 @@ CCAssignFn *AMDGPUCallLowering::CCAssignFnForReturn(CallingConv::ID CC,
   case CallingConv::AMDGPU_LS:
     return RetCC_SI_Shader;
   case CallingConv::AMDGPU_Gfx:
+  case CallingConv::AMDGPU_Gfx_WholeWave:
     return RetCC_SI_Gfx;
   case CallingConv::C:
   case CallingConv::Fast:
@@ -5777,6 +5779,8 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(BUFFER_ATOMIC_FMIN)
   NODE_NAME_CASE(BUFFER_ATOMIC_FMAX)
   NODE_NAME_CASE(BUFFER_ATOMIC_COND_SUB_U32)
+  NODE_NAME_CASE(WHOLE_WAVE_SETUP)
+  NODE_NAME_CASE(WHOLE_WAVE_RETURN)
   }
   return nullptr;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 0dd2183b72b24..5716711de3402 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -607,6 +607,12 @@ enum NodeType : unsigned {
   BUFFER_ATOMIC_FMAX,
   BUFFER_ATOMIC_COND_SUB_U32,
   LAST_MEMORY_OPCODE = BUFFER_ATOMIC_COND_SUB_U32,
+
+  // Set up a whole wave function.
+  WHOLE_WAVE_SETUP,
+
+  // Return from a whole wave function.
+  WHOLE_WAVE_RETURN,
 };
 
 } // End namespace AMDGPUISD
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
index ce58e93a15207..e305f08925cc6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
@@ -348,6 +348,17 @@ def AMDGPUfdot2_impl : SDNode<"AMDGPUISD::FDOT2",
 
 def AMDGPUperm_impl : SDNode<"AMDGPUISD::PERM", AMDGPUDTIntTernaryOp, []>;
 
+// Marks the entry into a whole wave function.
+def AMDGPUwhole_wave_setup : SDNode<
+  "AMDGPUISD::WHOLE_WAVE_SETUP", SDTypeProfile<1, 0, [SDTCisInt<0>]>,
+  [SDNPHasChain, SDNPSideEffect]>;
+
+// Marks the return from a whole wave function.
+def AMDGPUwhole_wave_return : SDNode<
+  "AMDGPUISD::WHOLE_WAVE_RETURN", SDTNone,
+  [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
+>;
+
 // SI+ export
 def AMDGPUExportOp : SDTypeProfile<0, 8, [
   SDTCisInt<0>,       // i8 tgt
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index b632b16f5c198..d86e7735a07bc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -4141,6 +4141,10 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) {
     return true;
   case AMDGPU::G_AMDGPU_WAVE_ADDRESS:
     return selectWaveAddress(I);
+  case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN: {
+    I.setDesc(TII.get(AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN));
+    return true;
+  }
   case AMDGPU::G_STACKRESTORE:
     return selectStackRestore(I);
   case AMDGPU::G_PHI:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index b20760c356263..a07699ae1eb23 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -5458,6 +5458,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case AMDGPU::G_PREFETCH:
     OpdsMapping[0] = getSGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
     break;
+  case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP:
+  case AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN:
+    OpdsMapping[0] = AMDGPU::getValueMapping(AMDGPU::VCCRegBankID, 1);
+    break;
   }
 
   return getInstructionMapping(/*ID*/1, /*Cost*/1,
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index bc95d3f040e1d..098c2dc2405df 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -3155,7 +3155,7 @@ bool GCNHazardRecognizer::fixRequiredExportPriority(MachineInstr *MI) {
   // Check entry priority at each export (as there will only be a few).
   // Note: amdgpu_gfx can only be a callee, so defer to caller setprio.
   bool Changed = false;
-  if (CC != CallingConv::AMDGPU_Gfx)
+  if (CC != CallingConv::AMDGPU_Gfx && CC != CallingConv::AMDGPU_Gfx_WholeWave)
     Changed = ensureEntrySetPrio(MF, NormalPriority, TII);
 
   auto NextMI = std::next(It);
diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
index 6a3867937d57f..b88df50c6c999 100644
--- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
@@ -946,8 +946,18 @@ static Register buildScratchExecCopy(LiveRegUnits &LiveUnits,
 
   initLiveUnits(LiveUnits, TRI, FuncInfo, MF, MBB, MBBI, IsProlog);
 
-  ScratchExecCopy = findScratchNonCalleeSaveRegister(
-      MRI, LiveUnits, *TRI.getWaveMaskRegClass());
+  if (FuncInfo->isWholeWaveFunction()) {
+    // Whole wave functions already have a copy of the original EXEC mask that
+    // we can use.
+    assert(IsProlog && "Epilog should look at return, not setup");
+    ScratchExecCopy =
+        TII->getWholeWaveFunctionSetup(MF)->getOperand(0).getReg();
+    assert(ScratchExecCopy && "Couldn't find copy of EXEC");
+  } else {
+    ScratchExecCopy = findScratchNonCalleeSaveRegister(
+        MRI, LiveUnits, *TRI.getWaveMaskRegClass());
+  }
+
   if (!ScratchExecCopy)
     report_fatal_error("failed to find free scratch register");
 
@@ -996,10 +1006,15 @@ void SIFrameLowering::emitCSRSpillStores(
       };
 
   StoreWWMRegisters(WWMScratchRegs);
+
+  auto EnableAllLanes = [&]() {
+    unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
+    BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
+  };
+
   if (!WWMCalleeSavedRegs.empty()) {
     if (ScratchExecCopy) {
-      unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
-      BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
+      EnableAllLanes();
     } else {
       ScratchExecCopy = buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
                                              /*IsProlog*/ true,
@@ -1008,7 +1023,18 @@ void SIFrameLowering::emitCSRSpillStores(
   }
 
   StoreWWMRegisters(WWMCalleeSavedRegs);
-  if (ScratchExecCopy) {
+  if (FuncInfo->isWholeWaveFunction()) {
+    // SI_WHOLE_WAVE_FUNC_SETUP has outlived its purpose, so we can remove
+    // it now. If we have already saved some WWM CSR registers, then the EXEC is
+    // already -1 and we don't need to do anything else. Otherwise, set EXEC to
+    // -1 here.
+    if (!ScratchExecCopy)
+      buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, /*IsProlog*/ true,
+                           /*EnableInactiveLanes*/ true);
+    else if (WWMCalleeSavedRegs.empty())
+      EnableAllLanes();
+    TII->getWholeWaveFunctionSetup(MF)->eraseFromParent();
+  } else if (ScratchExecCopy) {
     // FIXME: Split block and make terminator.
     unsigned ExecMov = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
     BuildMI(MBB, MBBI, DL, TII->get(ExecMov), TRI.getExec())
@@ -1083,11 +1109,6 @@ void SIFrameLowering::emitCSRSpillRestores(
   Register ScratchExecCopy;
   SmallVector<std::pair<Register, int>, 2> WWMCalleeSavedRegs, WWMScratchRegs;
   FuncInfo->splitWWMSpillRegisters(MF, WWMCalleeSavedRegs, WWMScratchRegs);
-  if (!WWMScratchRegs.empty())
-    ScratchExecCopy =
-        buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
-                             /*IsProlog*/ false, /*EnableInactiveLanes*/ true);
-
   auto RestoreWWMRegisters =
       [&](SmallVectorImpl<std::pair<Register, int>> &WWMRegs) {
         for (const auto &Reg : WWMRegs) {
@@ -1098,6 +1119,36 @@ void SIFrameLowering::emitCSRSpillRestores(
         }
       };
 
+  if (FuncInfo->isWholeWaveFunction()) {
+    // For whole wave functions, the EXEC is already -1 at this point.
+    // Therefore, we can restore the CSR WWM registers right away.
+    RestoreWWMRegisters(WWMCalleeSavedRegs);
+
+    // The original EXEC is the first operand of the return instruction.
+    const MachineInstr &Return = MBB.instr_back();
+    assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
+           "Unexpected return inst");
+    Register OrigExec = Return.getOperand(0).getReg();
+
+    if (!WWMScratchRegs.empty()) {
+      unsigned XorOpc = ST.isWave32() ? AMDGPU::S_XOR_B32 : AMDGPU::S_XOR_B64;
+      BuildMI(MBB, MBBI, DL, TII->get(XorOpc), TRI.getExec())
+          .addReg(OrigExec)
+          .addImm(-1);
+      RestoreWWMRegisters(WWMScratchRegs);
+    }
+
+    // Restore original EXEC.
+    unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
+    BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);
+    return;
+  }
+
+  if (!WWMScratchRegs.empty())
+    ScratchExecCopy =
+        buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
+                             /*IsProlog*/ false, /*EnableInactiveLanes*/ true);
+
   RestoreWWMRegisters(WWMScratchRegs);
   if (!WWMCalleeSavedRegs.empty()) {
     if (ScratchExecCopy) {
@@ -1634,6 +1685,7 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF,
         NeedExecCopyReservedReg = true;
       else if (MI....
[truncated]

``````````
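
The three signature rules that the `Verifier.cpp` hunk above adds for `amdgpu_gfx_whole_wave` can be restated as a standalone check. The sketch below does not use the LLVM API: `ArgDesc`, `FuncDesc` and `checkWholeWaveSignature` are hypothetical, simplified stand-ins, and only the diagnostic strings are taken from the patch.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified description of a function signature.
struct ArgDesc {
  unsigned IntBits; // bit width if the type is an integer, 0 otherwise
  bool InReg;       // whether the argument carries the inreg attribute
};

struct FuncDesc {
  std::vector<ArgDesc> Args;
  bool IsVarArg;
};

// Returns an empty string if the signature is accepted, otherwise the
// diagnostic for the first rule that fails.
std::string checkWholeWaveSignature(const FuncDesc &F) {
  if (F.Args.empty() || F.Args.front().IntBits != 1)
    return "Calling convention requires first argument to be i1";
  if (F.Args.front().InReg)
    return "Calling convention requires first argument to not be inreg";
  if (F.IsVarArg)
    return "Calling convention does not support varargs or "
           "perfect forwarding!";
  return "";
}
```

A signature such as `(i1 %active, i32 %x)` passes in this model, whereas one with no arguments, an `inreg` first argument, or varargs is rejected with the corresponding message.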

</details>


https://github.com/llvm/llvm-project/pull/145858

