[llvm] [AMDGPU] Refine GCNHazardRecognizer hasHazard() (PR #138841)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 7 03:23:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Carl Ritson (perlfu)
<details>
<summary>Changes</summary>
Remove recursion to avoid stack overflow on large CFGs.
Avoid worklist for hazard search within single MachineBasicBlock.
Ensure predecessors are visited for all state combinations.
---
Full diff: https://github.com/llvm/llvm-project/pull/138841.diff
1 File Affected:
- (modified) llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp (+48-31)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index aaefe27b1324f..644fbb77a495a 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -436,42 +436,55 @@ using IsExpiredFn = function_ref<bool(const MachineInstr &, int WaitStates)>;
using GetNumWaitStatesFn = function_ref<unsigned int(const MachineInstr &)>;
// Search for a hazard in a block and its predecessors.
+// StateT must implement getHashValue().
template <typename StateT>
static bool
-hasHazard(StateT State,
+hasHazard(StateT InitialState,
function_ref<HazardFnResult(StateT &, const MachineInstr &)> IsHazard,
function_ref<void(StateT &, const MachineInstr &)> UpdateState,
- const MachineBasicBlock *MBB,
- MachineBasicBlock::const_reverse_instr_iterator I,
- DenseSet<const MachineBasicBlock *> &Visited) {
- for (auto E = MBB->instr_rend(); I != E; ++I) {
- // No need to look at parent BUNDLE instructions.
- if (I->isBundle())
- continue;
+ const MachineBasicBlock *InitialMBB,
+ MachineBasicBlock::const_reverse_instr_iterator InitialI) {
+ SmallVector<std::pair<const MachineBasicBlock *, StateT>> Worklist;
+ DenseSet<std::pair<const MachineBasicBlock *, unsigned>> Visited;
+ const MachineBasicBlock *MBB = InitialMBB;
+ StateT State = InitialState;
+ auto I = InitialI;
+
+ for (;;) {
+ bool Expired = false;
+ for (auto E = MBB->instr_rend(); I != E; ++I) {
+ // No need to look at parent BUNDLE instructions.
+ if (I->isBundle())
+ continue;
- switch (IsHazard(State, *I)) {
- case HazardFound:
- return true;
- case HazardExpired:
- return false;
- default:
- // Continue search
- break;
- }
+ auto Result = IsHazard(State, *I);
+ if (Result == HazardFound)
+ return true;
+ if (Result == HazardExpired) {
+ Expired = true;
+ break;
+ }
- if (I->isInlineAsm() || I->isMetaInstruction())
- continue;
+ if (I->isInlineAsm() || I->isMetaInstruction())
+ continue;
- UpdateState(State, *I);
- }
+ UpdateState(State, *I);
+ }
- for (MachineBasicBlock *Pred : MBB->predecessors()) {
- if (!Visited.insert(Pred).second)
- continue;
+ if (!Expired) {
+ unsigned StateHash = State.getHashValue();
+ for (MachineBasicBlock *Pred : MBB->predecessors()) {
+ if (!Visited.insert(std::pair(Pred, StateHash)).second)
+ continue;
+ Worklist.emplace_back(Pred, State);
+ }
+ }
- if (hasHazard(State, IsHazard, UpdateState, Pred, Pred->instr_rbegin(),
- Visited))
- return true;
+ if (Worklist.empty())
+ break;
+
+ std::tie(MBB, State) = Worklist.pop_back_val();
+ I = MBB->instr_rbegin();
}
return false;
@@ -1624,6 +1637,10 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
SmallDenseMap<Register, int, 4> DefPos;
int ExecPos = std::numeric_limits<int>::max();
int VALUs = 0;
+
+ unsigned getHashValue() const {
+ return hash_combine(ExecPos, VALUs, hash_combine_range(DefPos));
+ }
};
StateType State;
@@ -1718,9 +1735,8 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
State.VALUs += 1;
};
- DenseSet<const MachineBasicBlock *> Visited;
if (!hasHazard<StateType>(State, IsHazardFn, UpdateStateFn, MI->getParent(),
- std::next(MI->getReverseIterator()), Visited))
+ std::next(MI->getReverseIterator())))
return false;
BuildMI(*MI->getParent(), MI, MI->getDebugLoc(),
@@ -1761,6 +1777,8 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
struct StateType {
int VALUs = 0;
int TRANS = 0;
+
+ unsigned getHashValue() const { return hash_combine(VALUs, TRANS); }
};
StateType State;
@@ -1796,9 +1814,8 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
State.TRANS += 1;
};
- DenseSet<const MachineBasicBlock *> Visited;
if (!hasHazard<StateType>(State, IsHazardFn, UpdateStateFn, MI->getParent(),
- std::next(MI->getReverseIterator()), Visited))
+ std::next(MI->getReverseIterator())))
return false;
// Hazard is observed - insert a wait on va_dst counter to ensure hazard is
``````````
</details>
https://github.com/llvm/llvm-project/pull/138841
More information about the llvm-commits
mailing list