[llvm] [AMDGPU] NFC: Provide RPTracker interface for external iterators (PR #93088)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 22 12:00:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Jeffrey Byrnes (jrbyrnes)

<details>
<summary>Changes</summary>

This is part of a series of PRs that enable using the GCNRPTrackers during scheduling. I've split them up to (hopefully) make reviewing easier. For context, see https://github.com/llvm/llvm-project/pull/88797.

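This PR adds an interface to the existing GCNRPTrackers that allows the trackers to be used with an externally managed iterator. Concretely:
- `reset(const MachineRegisterInfo &, const LiveRegSet &)` moves from `GCNUpwardRPTracker` into the `GCNRPTracker` base class.
- `recede` gains an optional `ShouldTrackIt` flag.
- `advanceToNext` gains optional `MachineInstr *` and `ShouldTrackIt` parameters.
- `advanceBeforeNext` and `advance` gain optional `MachineInstr *`, `ShouldTrackIt` and `LiveIntervals *` parameters.

The defaults (`ShouldTrackIt = true`) keep the existing iterator-driven behavior.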

---
Full diff: https://github.com/llvm/llvm-project/pull/93088.diff


2 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (+46-24) 
- (modified) llvm/lib/Target/AMDGPU/GCNRegPressure.h (+10-8) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
index 5c394e6d6296d..f1c4c8b397ddc 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
@@ -343,24 +343,25 @@ void GCNRPTracker::reset(const MachineInstr &MI,
   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
 }
 
-////////////////////////////////////////////////////////////////////////////////
-// GCNUpwardRPTracker
-
-void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
-                               const LiveRegSet &LiveRegs_) {
+void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
+                         const LiveRegSet &LiveRegs_) {
   MRI = &MRI_;
   LiveRegs = LiveRegs_;
   LastTrackedMI = nullptr;
   MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
 }
 
-void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
+////////////////////////////////////////////////////////////////////////////////
+// GCNUpwardRPTracker
+
+bool GCNUpwardRPTracker::recede(const MachineInstr &MI, bool ShouldTrackIt) {
   assert(MRI && "call reset first");
 
-  LastTrackedMI = &MI;
+  if (ShouldTrackIt)
+    LastTrackedMI = &MI;
 
   if (MI.isDebugInstr())
-    return;
+    return false;
 
   // Kill all defs.
   GCNRegPressure DefPressure, ECDefPressure;
@@ -412,6 +413,7 @@ void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
                           : max(CurPressure, MaxPressure);
 
   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
+  return false;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -430,28 +432,44 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
   return true;
 }
 
-bool GCNDownwardRPTracker::advanceBeforeNext() {
+bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
+                                             bool ShouldTrackIt,
+                                             LiveIntervals *TheLIS) {
   assert(MRI && "call reset first");
-  if (!LastTrackedMI)
-    return NextMI == MBBEnd;
+  SlotIndex SI;
+  LiveIntervals *CurrLIS;
+  MachineInstr *CurrMI;
+  if (ShouldTrackIt) {
+    if (!LastTrackedMI)
+      return NextMI == MBBEnd;
+
+    assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
+    CurrLIS = const_cast<LiveIntervals *>(&LIS);
+    CurrMI = const_cast<MachineInstr *>(LastTrackedMI);
+
+    SI = NextMI == MBBEnd
+             ? CurrLIS->getInstructionIndex(*LastTrackedMI).getDeadSlot()
+             : CurrLIS->getInstructionIndex(*NextMI).getBaseIndex();
+  }
 
-  assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
+  else { //! ShouldTrackIt
+    CurrLIS = TheLIS;
+    SI = CurrLIS->getInstructionIndex(*MI).getBaseIndex();
+    CurrMI = MI;
+  }
 
-  SlotIndex SI = NextMI == MBBEnd
-                     ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
-                     : LIS.getInstructionIndex(*NextMI).getBaseIndex();
   assert(SI.isValid());
 
   // Remove dead registers or mask bits.
   SmallSet<Register, 8> SeenRegs;
-  for (auto &MO : LastTrackedMI->operands()) {
+  for (auto &MO : CurrMI->operands()) {
     if (!MO.isReg() || !MO.getReg().isVirtual())
       continue;
     if (MO.isUse() && !MO.readsReg())
       continue;
     if (!SeenRegs.insert(MO.getReg()).second)
       continue;
-    const LiveInterval &LI = LIS.getInterval(MO.getReg());
+    const LiveInterval &LI = CurrLIS->getInterval(MO.getReg());
     if (LI.hasSubRanges()) {
       auto It = LiveRegs.end();
       for (const auto &S : LI.subranges()) {
@@ -481,15 +499,18 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {
 
   LastTrackedMI = nullptr;
 
-  return NextMI == MBBEnd;
+  return ShouldTrackIt && (NextMI == MBBEnd);
 }
 
-void GCNDownwardRPTracker::advanceToNext() {
+void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI, bool ShouldTrackIt) {
   LastTrackedMI = &*NextMI++;
   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
 
+  MachineInstr *CurrMI =
+      ShouldTrackIt ? const_cast<MachineInstr *>(LastTrackedMI) : MI;
+
   // Add new registers or mask bits.
-  for (const auto &MO : LastTrackedMI->all_defs()) {
+  for (const auto &MO : CurrMI->all_defs()) {
     Register Reg = MO.getReg();
     if (!Reg.isVirtual())
       continue;
@@ -502,11 +523,12 @@ void GCNDownwardRPTracker::advanceToNext() {
   MaxPressure = max(MaxPressure, CurPressure);
 }
 
-bool GCNDownwardRPTracker::advance() {
-  if (NextMI == MBBEnd)
+bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool ShouldTrackIt,
+                                   LiveIntervals *TheLIS) {
+  if (ShouldTrackIt && NextMI == MBBEnd)
     return false;
-  advanceBeforeNext();
-  advanceToNext();
+  advanceBeforeNext(MI, ShouldTrackIt, TheLIS);
+  advanceToNext(MI, ShouldTrackIt);
   return true;
 }
 
diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.h b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
index 752f53752fa68..8abbce138cf16 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.h
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
@@ -160,6 +160,9 @@ class GCNRPTracker {
              bool After);
 
 public:
+  // reset tracker and set live register set to the specified value.
+  void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
+
   // live regs for the current state
   const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
   const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
@@ -180,12 +183,9 @@ class GCNUpwardRPTracker : public GCNRPTracker {
 public:
   GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
 
-  // reset tracker and set live register set to the specified value.
-  void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
-
   // reset tracker at the specified slot index.
   void reset(const MachineRegisterInfo &MRI, SlotIndex SI) {
-    reset(MRI, llvm::getLiveRegs(SI, LIS, MRI));
+    GCNRPTracker::reset(MRI, llvm::getLiveRegs(SI, LIS, MRI));
   }
 
   // reset tracker to the end of the MBB.
@@ -200,7 +200,7 @@ class GCNUpwardRPTracker : public GCNRPTracker {
   }
 
   // move to the state just before the MI (in program order).
-  void recede(const MachineInstr &MI);
+  bool recede(const MachineInstr &MI, bool ShouldTrackIt = true);
 
   // checks whether the tracker's state after receding MI corresponds
   // to reported by LIS.
@@ -242,13 +242,15 @@ class GCNDownwardRPTracker : public GCNRPTracker {
 
   // Move to the state right before the next MI or after the end of MBB.
   // Returns false if reached end of the block.
-  bool advanceBeforeNext();
+  bool advanceBeforeNext(MachineInstr *MI = nullptr, bool ShouldTrackIt = true,
+                         LiveIntervals *TheLIS = nullptr);
 
   // Move to the state at the MI, advanceBeforeNext has to be called first.
-  void advanceToNext();
+  void advanceToNext(MachineInstr *MI = nullptr, bool ShouldTrackIt = true);
 
   // Move to the state at the next MI. Returns false if reached end of block.
-  bool advance();
+  bool advance(MachineInstr *MI = nullptr, bool ShouldTrackIt = true,
+               LiveIntervals *TheLIS = nullptr);
 
   // Advance instructions until before End.
   bool advance(MachineBasicBlock::const_iterator End);

``````````

</details>


https://github.com/llvm/llvm-project/pull/93088


More information about the llvm-commits mailing list