[llvm] [AMDGPU] GCNHazardRecognizer: refactor getWaitStatesSince (NFCI) (PR #108347)
Carl Ritson via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 12 01:35:53 PDT 2024
https://github.com/perlfu created https://github.com/llvm/llvm-project/pull/108347
Refactor getWaitStatesSince:
* Remove recursion to avoid excess stack usage with newer hazards
* Ensure the algorithm always returns the minimum number of wait states to a hazard by allowing a block to be revisited if a shorter path to it is encountered.
* Reduce the search space by actively pruning deeper search after a minimum is established.
Note: in edge cases this might be slightly slower as it now searches to find the true minimum number of wait states.
>From 067dbfa3bad9ea7e881febc143d956e4959f6a5f Mon Sep 17 00:00:00 2001
From: Carl Ritson <carl.ritson at amd.com>
Date: Thu, 12 Sep 2024 14:07:31 +0900
Subject: [PATCH] [AMDGPU] GCNHazardRecognizer: refactor getWaitStatesSince
(NFCI)
Refactor getWaitStatesSince:
* Remove recursion to avoid excess stack usage with newer hazards
* Ensure the algorithm always returns the minimum wait states to a
hazard by allowing blocks to be revisited if a shorter path is found.
* Reduce the search space by actively pruning deeper search after
a minimum is established.
Note: in edge cases this might be slightly slower as it now
searches to find the true minimum number of wait states.
---
.../lib/Target/AMDGPU/GCNHazardRecognizer.cpp | 96 ++++++++++++++-----
1 file changed, 74 insertions(+), 22 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index cc39fd1740683f..5150fabd173fa6 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -19,6 +19,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/ScheduleDAG.h"
#include "llvm/TargetParser/TargetParser.h"
+#include <queue>
using namespace llvm;
@@ -495,21 +496,18 @@ hasHazard(StateT State,
return false;
}
-// Returns a minimum wait states since \p I walking all predecessors.
-// Only scans until \p IsExpired does not return true.
-// Can only be run in a hazard recognizer mode.
-static int getWaitStatesSince(
+// Update \p WaitStates while iterating from \p I to hazard in \p MBB.
+static HazardFnResult countWaitStatesSince(
GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
- MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
- IsExpiredFn IsExpired, DenseSet<const MachineBasicBlock *> &Visited,
- GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+ MachineBasicBlock::const_reverse_instr_iterator I, int &WaitStates,
+ IsExpiredFn IsExpired, GetNumWaitStatesFn GetNumWaitStates) {
for (auto E = MBB->instr_rend(); I != E; ++I) {
// Don't add WaitStates for parent BUNDLE instructions.
if (I->isBundle())
continue;
if (IsHazard(*I))
- return WaitStates;
+ return HazardFound;
if (I->isInlineAsm())
continue;
@@ -517,29 +515,85 @@ static int getWaitStatesSince(
WaitStates += GetNumWaitStates(*I);
if (IsExpired(*I, WaitStates))
- return std::numeric_limits<int>::max();
+ return HazardExpired;
+ }
+
+ return NoHazardFound;
+}
+
+// Implements predecessor search for getWaitStatesSince.
+static int getWaitStatesSinceImpl(
+ GCNHazardRecognizer::IsHazardFn IsHazard,
+ const MachineBasicBlock *InitialMBB, int InitialWaitStates,
+ IsExpiredFn IsExpired,
+ GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+ DenseMap<const MachineBasicBlock *, int> Visited;
+
+ // Build worklist of predecessors.
+ // Note: use queue so search is breadth first, which reduces search space
+ // when a hazard is found.
+ std::queue<const MachineBasicBlock *> Worklist;
+ for (MachineBasicBlock *Pred : InitialMBB->predecessors()) {
+ Visited[Pred] = InitialWaitStates;
+ Worklist.push(Pred);
}
+ // Find minimum wait states to hazard or determine that all paths expire.
int MinWaitStates = std::numeric_limits<int>::max();
- for (MachineBasicBlock *Pred : MBB->predecessors()) {
- if (!Visited.insert(Pred).second)
- continue;
+ while (!Worklist.empty()) {
+ const MachineBasicBlock *MBB = Worklist.front();
+ int WaitStates = Visited[MBB];
+ Worklist.pop();
- int W = getWaitStatesSince(IsHazard, Pred, Pred->instr_rbegin(), WaitStates,
- IsExpired, Visited, GetNumWaitStates);
+ // No reason to search blocks when wait states exceed established minimum.
+ if (WaitStates >= MinWaitStates)
+ continue;
- MinWaitStates = std::min(MinWaitStates, W);
+ // Search for hazard
+ auto Search = countWaitStatesSince(IsHazard, MBB, MBB->instr_rbegin(),
+ WaitStates, IsExpired, GetNumWaitStates);
+ if (Search == HazardFound) {
+ // Update minimum.
+ MinWaitStates = std::min(MinWaitStates, WaitStates);
+ } else if (Search == NoHazardFound && WaitStates < MinWaitStates) {
+ // Search predecessors.
+ for (MachineBasicBlock *Pred : MBB->predecessors()) {
+ if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) {
+ // Store lowest wait states required to visit this block.
+ Visited[Pred] = WaitStates;
+ Worklist.push(Pred);
+ }
+ }
+ }
}
return MinWaitStates;
}
+// Returns minimum wait states since \p I walking all predecessors.
+// Only scans until \p IsExpired returns true.
+// Can only be run in a hazard recognizer mode.
+static int getWaitStatesSince(
+ GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
+ MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
+ IsExpiredFn IsExpired,
+ GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+ // Scan this block from I.
+ auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates,
+ IsExpired, GetNumWaitStates);
+ if (InitSearch == HazardFound)
+ return WaitStates;
+ else if (InitSearch == HazardExpired)
+ return std::numeric_limits<int>::max();
+ else
+ return getWaitStatesSinceImpl(IsHazard, MBB, WaitStates, IsExpired,
+ GetNumWaitStates);
+}
+
static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
const MachineInstr *MI, IsExpiredFn IsExpired) {
- DenseSet<const MachineBasicBlock *> Visited;
return getWaitStatesSince(IsHazard, MI->getParent(),
- std::next(MI->getReverseIterator()),
- 0, IsExpired, Visited);
+ std::next(MI->getReverseIterator()), 0, IsExpired);
}
int GCNHazardRecognizer::getWaitStatesSince(IsHazardFn IsHazard, int Limit) {
@@ -1524,10 +1578,9 @@ bool GCNHazardRecognizer::fixLdsDirectVALUHazard(MachineInstr *MI) {
return SIInstrInfo::isVALU(MI) ? 1 : 0;
};
- DenseSet<const MachineBasicBlock *> Visited;
auto Count = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
std::next(MI->getReverseIterator()), 0,
- IsExpiredFn, Visited, GetWaitStatesFn);
+ IsExpiredFn, GetWaitStatesFn);
// Transcendentals can execute in parallel to other VALUs.
// This makes va_vdst count unusable with a mixture of VALU and TRANS.
@@ -3234,10 +3287,9 @@ bool GCNHazardRecognizer::fixVALUReadSGPRHazard(MachineInstr *MI) {
};
// Check for the hazard.
- DenseSet<const MachineBasicBlock *> Visited;
int WaitStates = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
std::next(MI->getReverseIterator()), 0,
- IsExpiredFn, Visited, WaitStatesFn);
+ IsExpiredFn, WaitStatesFn);
if (WaitStates >= SALUExpiryCount)
return false;
More information about the llvm-commits
mailing list