[llvm] [AMDGPU] GCNHazardRecognizer: refactor getWaitStatesSince (NFCI) (PR #108347)

Carl Ritson via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 12 01:35:53 PDT 2024


https://github.com/perlfu created https://github.com/llvm/llvm-project/pull/108347

Refactor getWaitStatesSince:
* Remove recursion to avoid excess stack usage with newer hazards
* Ensure the algorithm always returns the minimum number of wait states to a hazard by allowing a block to be revisited when a shorter path to it is encountered.
* Reduce the search space by actively pruning deeper search after a minimum is established.

Note: in edge cases this might be slightly slower as it now searches to find the true minimum number of wait states.

From 067dbfa3bad9ea7e881febc143d956e4959f6a5f Mon Sep 17 00:00:00 2001
From: Carl Ritson <carl.ritson at amd.com>
Date: Thu, 12 Sep 2024 14:07:31 +0900
Subject: [PATCH] [AMDGPU] GCNHazardRecognizer: refactor getWaitStatesSince
 (NFCI)

Refactor getWaitStatesSince:
* Remove recursion to avoid excess stack usage with newer hazards
* Ensure the algorithm always returns the minimum number of wait
  states to a hazard by allowing a block to be revisited when a
  shorter path to it is encountered.
* Reduce the search space by actively pruning deeper search after
  a minimum is established.

Note: in edge cases this might be slightly slower as it now
searches to find the true minimum number of wait states.
---
 .../lib/Target/AMDGPU/GCNHazardRecognizer.cpp | 96 ++++++++++++++-----
 1 file changed, 74 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index cc39fd1740683f..5150fabd173fa6 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -19,6 +19,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/ScheduleDAG.h"
 #include "llvm/TargetParser/TargetParser.h"
+#include <queue>
 
 using namespace llvm;
 
@@ -495,21 +496,18 @@ hasHazard(StateT State,
   return false;
 }
 
-// Returns a minimum wait states since \p I walking all predecessors.
-// Only scans until \p IsExpired does not return true.
-// Can only be run in a hazard recognizer mode.
-static int getWaitStatesSince(
+// Update \p WaitStates while iterating from \p I to hazard in \p MBB.
+static HazardFnResult countWaitStatesSince(
     GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
-    MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
-    IsExpiredFn IsExpired, DenseSet<const MachineBasicBlock *> &Visited,
-    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+    MachineBasicBlock::const_reverse_instr_iterator I, int &WaitStates,
+    IsExpiredFn IsExpired, GetNumWaitStatesFn GetNumWaitStates) {
   for (auto E = MBB->instr_rend(); I != E; ++I) {
     // Don't add WaitStates for parent BUNDLE instructions.
     if (I->isBundle())
       continue;
 
     if (IsHazard(*I))
-      return WaitStates;
+      return HazardFound;
 
     if (I->isInlineAsm())
       continue;
@@ -517,29 +515,85 @@ static int getWaitStatesSince(
     WaitStates += GetNumWaitStates(*I);
 
     if (IsExpired(*I, WaitStates))
-      return std::numeric_limits<int>::max();
+      return HazardExpired;
+  }
+
+  return NoHazardFound;
+}
+
+// Implements predecessor search for getWaitStatesSince.
+static int getWaitStatesSinceImpl(
+    GCNHazardRecognizer::IsHazardFn IsHazard,
+    const MachineBasicBlock *InitialMBB, int InitialWaitStates,
+    IsExpiredFn IsExpired,
+    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+  DenseMap<const MachineBasicBlock *, int> Visited;
+
+  // Build worklist of predecessors.
+  // Note: use queue so search is breadth first, which reduces search space
+  // when a hazard is found.
+  std::queue<const MachineBasicBlock *> Worklist;
+  for (MachineBasicBlock *Pred : InitialMBB->predecessors()) {
+    Visited[Pred] = InitialWaitStates;
+    Worklist.push(Pred);
   }
 
+  // Find minimum wait states to hazard or determine that all paths expire.
   int MinWaitStates = std::numeric_limits<int>::max();
-  for (MachineBasicBlock *Pred : MBB->predecessors()) {
-    if (!Visited.insert(Pred).second)
-      continue;
+  while (!Worklist.empty()) {
+    const MachineBasicBlock *MBB = Worklist.front();
+    int WaitStates = Visited[MBB];
+    Worklist.pop();
 
-    int W = getWaitStatesSince(IsHazard, Pred, Pred->instr_rbegin(), WaitStates,
-                               IsExpired, Visited, GetNumWaitStates);
+    // No reason to search blocks when wait states exceed established minimum.
+    if (WaitStates >= MinWaitStates)
+      continue;
 
-    MinWaitStates = std::min(MinWaitStates, W);
+    // Search for hazard
+    auto Search = countWaitStatesSince(IsHazard, MBB, MBB->instr_rbegin(),
+                                       WaitStates, IsExpired, GetNumWaitStates);
+    if (Search == HazardFound) {
+      // Update minimum.
+      MinWaitStates = std::min(MinWaitStates, WaitStates);
+    } else if (Search == NoHazardFound && WaitStates < MinWaitStates) {
+      // Search predecessors.
+      for (MachineBasicBlock *Pred : MBB->predecessors()) {
+        if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) {
+          // Store lowest wait states required to visit this block.
+          Visited[Pred] = WaitStates;
+          Worklist.push(Pred);
+        }
+      }
+    }
   }
 
   return MinWaitStates;
 }
 
+// Returns minimum wait states since \p I walking all predecessors.
+// Stops scanning a path once \p IsExpired returns true.
+// Can only be run in a hazard recognizer mode.
+static int getWaitStatesSince(
+    GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
+    MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
+    IsExpiredFn IsExpired,
+    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+  // Scan this block from I.
+  auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates,
+                                         IsExpired, GetNumWaitStates);
+  if (InitSearch == HazardFound)
+    return WaitStates;
+  else if (InitSearch == HazardExpired)
+    return std::numeric_limits<int>::max();
+  else
+    return getWaitStatesSinceImpl(IsHazard, MBB, WaitStates, IsExpired,
+                                  GetNumWaitStates);
+}
+
 static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
                               const MachineInstr *MI, IsExpiredFn IsExpired) {
-  DenseSet<const MachineBasicBlock *> Visited;
   return getWaitStatesSince(IsHazard, MI->getParent(),
-                            std::next(MI->getReverseIterator()),
-                            0, IsExpired, Visited);
+                            std::next(MI->getReverseIterator()), 0, IsExpired);
 }
 
 int GCNHazardRecognizer::getWaitStatesSince(IsHazardFn IsHazard, int Limit) {
@@ -1524,10 +1578,9 @@ bool GCNHazardRecognizer::fixLdsDirectVALUHazard(MachineInstr *MI) {
     return SIInstrInfo::isVALU(MI) ? 1 : 0;
   };
 
-  DenseSet<const MachineBasicBlock *> Visited;
   auto Count = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
                                     std::next(MI->getReverseIterator()), 0,
-                                    IsExpiredFn, Visited, GetWaitStatesFn);
+                                    IsExpiredFn, GetWaitStatesFn);
 
   // Transcendentals can execute in parallel to other VALUs.
   // This makes va_vdst count unusable with a mixture of VALU and TRANS.
@@ -3234,10 +3287,9 @@ bool GCNHazardRecognizer::fixVALUReadSGPRHazard(MachineInstr *MI) {
   };
 
   // Check for the hazard.
-  DenseSet<const MachineBasicBlock *> Visited;
   int WaitStates = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
                                         std::next(MI->getReverseIterator()), 0,
-                                        IsExpiredFn, Visited, WaitStatesFn);
+                                        IsExpiredFn, WaitStatesFn);
 
   if (WaitStates >= SALUExpiryCount)
     return false;



More information about the llvm-commits mailing list