[llvm] [AMDGPU] Rewrite GFX12 SGPR hazard handling to dedicated pass (PR #118750)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 5 16:32:13 PST 2024


================
@@ -0,0 +1,487 @@
+//===- AMDGPUWaitSGPRHazards.cpp - Insert waits for SGPR read hazards -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Insert s_wait_alu instructions to mitigate SGPR read hazards on GFX12.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "GCNSubtarget.h"
+#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
+#include "SIInstrInfo.h"
+#include "llvm/ADT/SetVector.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "amdgpu-wait-sgpr-hazards"
+
+// Master enable for inserting the s_wait_alu instructions required by SGPR
+// hazards. On by default.
+static cl::opt<bool> GlobalEnableSGPRHazardWaits(
+    "amdgpu-sgpr-hazard-wait",
+    cl::desc("Enable required s_wait_alu on SGPR hazards"), cl::Hidden,
+    cl::init(true));
+
+// Whether tracked hazards are culled at function boundaries. Off by default.
+static cl::opt<bool> GlobalCullSGPRHazardsOnFunctionBoundary(
+    "amdgpu-sgpr-hazard-boundary-cull",
+    cl::desc("Cull hazards on function boundaries"), cl::Hidden,
+    cl::init(false));
+
+// Whether tracked hazards are culled at memory waits. Off by default.
+static cl::opt<bool> GlobalCullSGPRHazardsAtMemWait(
+    "amdgpu-sgpr-hazard-mem-wait-cull",
+    cl::desc("Cull hazards on memory waits"), cl::Hidden, cl::init(false));
+
+// Threshold on the number of tracked SGPRs before a hazard cull is initiated
+// at a memory wait. Defaults to 8.
+static cl::opt<unsigned> GlobalCullSGPRHazardsMemWaitThreshold(
+    "amdgpu-sgpr-hazard-mem-wait-cull-threshold",
+    cl::desc("Number of tracked SGPRs before initiating hazard cull on memory "
+             "wait"),
+    cl::Hidden, cl::init(8));
+
+namespace {
+
+class AMDGPUWaitSGPRHazards : public MachineFunctionPass {
+public:
+  static char ID;
+
+  // Per-function target state. These are default-initialized so that a use
+  // before they are assigned (presumably in runOnMachineFunction, which is
+  // not shown here) reads a well-defined value instead of indeterminate
+  // memory.
+  const SIInstrInfo *TII = nullptr;
+  const SIRegisterInfo *TRI = nullptr;
+  const MachineRegisterInfo *MRI = nullptr;
+  bool Wave64 = false;
+
+  // Effective settings for the current function. They presumably mirror the
+  // Global* cl::opt values above; TODO confirm where they are assigned.
+  bool EnableSGPRHazardWaits = false;
+  bool CullSGPRHazardsOnFunctionBoundary = false;
+  bool CullSGPRHazardsAtMemWait = false;
+  unsigned CullSGPRHazardsMemWaitThreshold = 0;
+
+  AMDGPUWaitSGPRHazards() : MachineFunctionPass(ID) {}
+
+  // Declare that this pass preserves the CFG, then add the analysis
+  // requirements common to every MachineFunctionPass.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  // Map an SGPR to its numeric ID in the range 0-127, taken from its register
+  // encoding value. M0, EXEC (and its 32-bit halves), SGPR_NULL and
+  // SGPR_NULL64 have no ID. Neither does any register whose encoding value
+  // exceeds 127. std::nullopt is returned in those cases.
+  static std::optional<unsigned> sgprNumber(Register Reg,
+                                            const SIRegisterInfo &TRI) {
+    const bool IsExcluded =
+        Reg == AMDGPU::M0 || Reg == AMDGPU::EXEC || Reg == AMDGPU::EXEC_LO ||
+        Reg == AMDGPU::EXEC_HI || Reg == AMDGPU::SGPR_NULL ||
+        Reg == AMDGPU::SGPR_NULL64;
+    if (IsExcluded)
+      return std::nullopt;
+
+    const unsigned Encoding = TRI.getEncodingValue(Reg);
+    if (Encoding <= 127)
+      return Encoding;
+    return std::nullopt;
+  }
+
+  // True for the 64-bit VCC register or either of its 32-bit halves
+  // (VCC_LO / VCC_HI). A member function defined in the class is implicitly
+  // inline.
+  static bool IsVCC(Register Reg) {
+    switch (Reg) {
+    case AMDGPU::VCC:
+    case AMDGPU::VCC_LO:
+    case AMDGPU::VCC_HI:
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  // Called after NewMI has been inserted. If NewMI sits inside a bundle whose
+  // first real instruction is S_GETPC_B64, the global operands of every
+  // bundled instruction after NewMI get their offsets increased by NewBytes.
+  // Unbundled instructions and other kinds of bundle are left untouched.
+  static void updateGetPCBundle(MachineInstr *NewMI) {
+    if (!NewMI->isBundled())
+      return;
+
+    // Rewind to the first instruction of the bundle, stepping past the BUNDLE
+    // header instruction if there is one.
+    auto Head = NewMI->getIterator();
+    while (Head->isBundledWithPred())
+      --Head;
+    if (Head->isBundle())
+      ++Head;
+
+    if (Head->getOpcode() != AMDGPU::S_GETPC_B64)
+      return;
+
+    // Only S_WAITCNT_DEPCTR is expected to be inserted into such a bundle;
+    // the offset shift below assumes it occupies 4 bytes.
+    assert(NewMI->getOpcode() == AMDGPU::S_WAITCNT_DEPCTR &&
+           "Unexpected instruction insertion in bundle");
+    const unsigned NewBytes = 4;
+
+    const auto BlockEnd = NewMI->getParent()->end();
+    for (auto Cur = std::next(NewMI->getIterator());
+         Cur != BlockEnd && Cur->isBundledWithPred(); ++Cur) {
+      for (auto &Op : Cur->operands()) {
+        if (Op.isGlobal())
+          Op.setOffset(Op.getOffset() + NewBytes);
+      }
+    }
+  }
+
+  // Snapshot of the hazard-tracking state. Two states are combined with
+  // merge(), which unions every field.
+  struct HazardState {
+    // Bit flags stored in VCCHazard, identifying which unit wrote VCC.
+    static constexpr unsigned None = 0;
+    static constexpr unsigned SALU = 1u << 0;
+    static constexpr unsigned VALU = 1u << 1;
+
+    std::bitset<64> Tracked;      // SGPR banks ever read by VALU
+    std::bitset<128> SALUHazards; // SGPRs with uncommitted values from SALU
+    std::bitset<128> VALUHazards; // SGPRs with uncommitted values from VALU
+    unsigned VCCHazard = None;    // Source of current VCC writes
+    bool ActiveFlat = false;      // Has unwaited flat instructions
+
+    // OR every field of RHS into this state. Returns true iff this state
+    // changed as a result.
+    bool merge(const HazardState &RHS) {
+      // A bitwise OR changes a field exactly when RHS supplies a bit that is
+      // not already set here, so test for that before applying the union.
+      const bool Grows = (RHS.Tracked & ~Tracked).any() ||
+                         (RHS.SALUHazards & ~SALUHazards).any() ||
+                         (RHS.VALUHazards & ~VALUHazards).any() ||
+                         (RHS.VCCHazard & ~VCCHazard) != 0 ||
+                         (RHS.ActiveFlat && !ActiveFlat);
+
+      Tracked |= RHS.Tracked;
+      SALUHazards |= RHS.SALUHazards;
+      VALUHazards |= RHS.VALUHazards;
+      VCCHazard |= RHS.VCCHazard;
+      ActiveFlat = ActiveFlat || RHS.ActiveFlat;
+      return Grows;
+    }
+
+    // Field-wise equality. The cheap scalar fields are compared first.
+    bool operator==(const HazardState &RHS) const {
+      if (VCCHazard != RHS.VCCHazard || ActiveFlat != RHS.ActiveFlat)
+        return false;
+      return Tracked == RHS.Tracked && SALUHazards == RHS.SALUHazards &&
+             VALUHazards == RHS.VALUHazards;
+    }
+    bool operator!=(const HazardState &RHS) const { return !operator==(RHS); }
+  };
+
+  // Pair of hazard states recorded for one basic block.
+  struct BlockHazardState {
+    HazardState In;  // State the block's scan starts from
+    HazardState Out; // Presumably the state at block exit; TODO confirm
+  };
+
+  // Per-basic-block hazard states, keyed by block pointer.
+  DenseMap<const MachineBasicBlock *, BlockHazardState> BlockState;
+
+  // Number of DS_NOP instructions insertHazardCull emits for wave32 and
+  // wave64 respectively.
+  static constexpr unsigned WAVE32_NOPS = 4;
+  static constexpr unsigned WAVE64_NOPS = 8;
+
+  // Insert a run of DS_NOP instructions immediately before MI: WAVE64_NOPS
+  // of them in wave64 mode, otherwise WAVE32_NOPS. Each one takes MI's debug
+  // location. MI must not be bundled, and the iterator itself is not
+  // modified.
+  void insertHazardCull(MachineBasicBlock &MBB,
+                        MachineBasicBlock::instr_iterator &MI) {
+    assert(!MI->isBundled());
+    const unsigned NumNops = Wave64 ? WAVE64_NOPS : WAVE32_NOPS;
+    const DebugLoc &DL = MI->getDebugLoc();
+    for (unsigned N = 0; N != NumNops; ++N)
+      BuildMI(MBB, MI, DL, TII->get(AMDGPU::DS_NOP));
+  }
+
+  bool runOnMachineBasicBlock(MachineBasicBlock &MBB, bool Emit) {
+    enum { WA_VALU = 0x1, WA_SALU = 0x2, WA_VCC = 0x4 };
+
+    HazardState State = BlockState[&MBB].In;
+    SmallSet<Register, 8> SeenRegs;
+    bool Emitted = false;
+    unsigned DsNops = 0;
+
+    for (MachineBasicBlock::instr_iterator MI = MBB.instr_begin(),
+                                           E = MBB.instr_end();
+         MI != E; ++MI) {
+      // Clear tracked SGPRs if sufficient DS_NOPs occur
----------------
arsenm wrote:

Need to skip meta instructions?

https://github.com/llvm/llvm-project/pull/118750


More information about the llvm-commits mailing list