[llvm] [AMDGPU] GCNHazardRecognizer: refactor getWaitStatesSince (NFCI) (PR #108347)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 12 06:34:16 PDT 2024


================
@@ -495,51 +496,104 @@ hasHazard(StateT State,
   return false;
 }
 
-// Returns a minimum wait states since \p I walking all predecessors.
-// Only scans until \p IsExpired does not return true.
-// Can only be run in a hazard recognizer mode.
-static int getWaitStatesSince(
+// Add to \p WaitStates while scanning \p MBB backwards from \p I to a hazard.
+static HazardFnResult countWaitStatesSince(
     GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
-    MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
-    IsExpiredFn IsExpired, DenseSet<const MachineBasicBlock *> &Visited,
-    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+    MachineBasicBlock::const_reverse_instr_iterator I, int &WaitStates,
+    IsExpiredFn IsExpired, GetNumWaitStatesFn GetNumWaitStates) {
   for (auto E = MBB->instr_rend(); I != E; ++I) {
     // Don't add WaitStates for parent BUNDLE instructions.
     if (I->isBundle())
       continue;
 
     if (IsHazard(*I))
-      return WaitStates;
+      return HazardFound;
 
     if (I->isInlineAsm())
       continue;
 
     WaitStates += GetNumWaitStates(*I);
 
     if (IsExpired(*I, WaitStates))
-      return std::numeric_limits<int>::max();
+      return HazardExpired;
+  }
+
+  return NoHazardFound;
+}
+
+// Implements predecessor search for getWaitStatesSince.
+static int getWaitStatesSinceImpl(
+    GCNHazardRecognizer::IsHazardFn IsHazard,
+    const MachineBasicBlock *InitialMBB, int InitialWaitStates,
+    IsExpiredFn IsExpired,
+    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+  DenseMap<const MachineBasicBlock *, int> Visited;
+
+  // Build worklist of predecessors.
+  // Note: use queue so search is breadth first, which reduces search space
+  // when a hazard is found.
+  std::queue<const MachineBasicBlock *> Worklist;
+  for (MachineBasicBlock *Pred : InitialMBB->predecessors()) {
+    Visited[Pred] = InitialWaitStates;
+    Worklist.push(Pred);
   }
 
+  // Find minimum wait states to hazard or determine that all paths expire.
   int MinWaitStates = std::numeric_limits<int>::max();
-  for (MachineBasicBlock *Pred : MBB->predecessors()) {
-    if (!Visited.insert(Pred).second)
-      continue;
+  while (!Worklist.empty()) {
+    const MachineBasicBlock *MBB = Worklist.front();
+    int WaitStates = Visited[MBB];
+    Worklist.pop();
 
-    int W = getWaitStatesSince(IsHazard, Pred, Pred->instr_rbegin(), WaitStates,
-                               IsExpired, Visited, GetNumWaitStates);
+    // No reason to search blocks when wait states exceed established minimum.
+    if (WaitStates >= MinWaitStates)
+      continue;
 
-    MinWaitStates = std::min(MinWaitStates, W);
+    // Search for hazard
+    auto Search = countWaitStatesSince(IsHazard, MBB, MBB->instr_rbegin(),
+                                       WaitStates, IsExpired, GetNumWaitStates);
+    if (Search == HazardFound) {
+      // Update minimum.
+      MinWaitStates = std::min(MinWaitStates, WaitStates);
+    } else if (Search == NoHazardFound && WaitStates < MinWaitStates) {
+      // Search predecessors.
+      for (MachineBasicBlock *Pred : MBB->predecessors()) {
+        if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) {
+          // Store lowest wait states required to visit this block.
+          Visited[Pred] = WaitStates;
+          Worklist.push(Pred);
+        }
+      }
+    }
   }
 
   return MinWaitStates;
 }
 
+// Returns the minimum wait states since \p I, walking all predecessors.
+// Only scans until \p IsExpired returns true.
+// Can only be run in hazard recognizer mode.
+static int getWaitStatesSince(
+    GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
+    MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
+    IsExpiredFn IsExpired,
+    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+  // Scan this block from I.
+  auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates,
----------------
jayfoad wrote:

Rather than call `countWaitStatesSince` here, can't you immediately call into `getWaitStatesSinceImpl` but initialize the worklist to just `MBB`? I.e. common up this function with the body of the loop in `getWaitStatesSinceImpl`? I guess the complexity is that each item on the worklist would have to be an (MBB, iterator) pair, so that you know at what point within each MBB to start searching.

https://github.com/llvm/llvm-project/pull/108347


More information about the llvm-commits mailing list