[llvm] 7ca9e47 - [AMDGPU] Start refactoring GCNSchedStrategy

Austin Kerbow via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 26 08:55:28 PDT 2022


Author: Austin Kerbow
Date: 2022-07-26T08:55:19-07:00
New Revision: 7ca9e471fe5b5ec51d151774e52dd0d5bd8f0ad0

URL: https://github.com/llvm/llvm-project/commit/7ca9e471fe5b5ec51d151774e52dd0d5bd8f0ad0
DIFF: https://github.com/llvm/llvm-project/commit/7ca9e471fe5b5ec51d151774e52dd0d5bd8f0ad0.diff

LOG: [AMDGPU] Start refactoring GCNSchedStrategy

Tries to make the different scheduling stages a bit more self-contained and
modifiable. Intended to be NFC. This is a preface to other changes.

Reviewed By: rampitec

Differential Revision: https://reviews.llvm.org/D130147

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
    llvm/lib/Target/AMDGPU/GCNSchedStrategy.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
index 04da14cc4916..859deae86f35 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
@@ -9,6 +9,18 @@
 /// \file
 /// This contains a MachineSchedStrategy implementation for maximizing wave
 /// occupancy on GCN hardware.
+///
+/// This pass will apply multiple scheduling stages to the same function.
+/// Regions are first recorded in GCNScheduleDAGMILive::schedule. The actual
+/// entry point for the scheduling of those regions is
+/// GCNScheduleDAGMILive::runSchedStages.
+///
+/// Generally, the reason for having multiple scheduling stages is to account
+/// for the kernel-wide effect of register usage on occupancy.  Usually, only a
+/// few scheduling regions will have register pressure high enough to limit
+/// occupancy for the kernel, so constraints can be relaxed to improve ILP in
+/// other regions.
+///
 //===----------------------------------------------------------------------===//
 
 #include "GCNSchedStrategy.h"
@@ -20,9 +32,9 @@
 using namespace llvm;
 
 GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
-    const MachineSchedContext *C) :
-    GenericScheduler(C), TargetOccupancy(0), HasClusteredNodes(false),
-    HasExcessPressure(false), MF(nullptr) { }
+    const MachineSchedContext *C)
+    : GenericScheduler(C), TargetOccupancy(0), MF(nullptr),
+      HasClusteredNodes(false), HasExcessPressure(false) {}
 
 void GCNMaxOccupancySchedStrategy::initialize(ScheduleDAGMI *DAG) {
   GenericScheduler::initialize(DAG);
@@ -302,210 +314,30 @@ SUnit *GCNMaxOccupancySchedStrategy::pickNode(bool &IsTopNode) {
   return SU;
 }
 
-GCNScheduleDAGMILive::GCNScheduleDAGMILive(MachineSchedContext *C,
-                        std::unique_ptr<MachineSchedStrategy> S) :
-  ScheduleDAGMILive(C, std::move(S)),
-  ST(MF.getSubtarget<GCNSubtarget>()),
-  MFI(*MF.getInfo<SIMachineFunctionInfo>()),
-  StartingOccupancy(MFI.getOccupancy()),
-  MinOccupancy(StartingOccupancy), Stage(Collect), RegionIdx(0) {
+GCNScheduleDAGMILive::GCNScheduleDAGMILive(
+    MachineSchedContext *C, std::unique_ptr<MachineSchedStrategy> S)
+    : ScheduleDAGMILive(C, std::move(S)), ST(MF.getSubtarget<GCNSubtarget>()),
+      MFI(*MF.getInfo<SIMachineFunctionInfo>()),
+      StartingOccupancy(MFI.getOccupancy()), MinOccupancy(StartingOccupancy) {
 
   LLVM_DEBUG(dbgs() << "Starting occupancy is " << StartingOccupancy << ".\n");
 }
 
 void GCNScheduleDAGMILive::schedule() {
-  if (Stage == Collect) {
-    // Just record regions at the first pass.
-    Regions.push_back(std::make_pair(RegionBegin, RegionEnd));
-    return;
-  }
-
-  std::vector<MachineInstr*> Unsched;
-  Unsched.reserve(NumRegionInstrs);
-  for (auto &I : *this) {
-    Unsched.push_back(&I);
-  }
-
-  GCNRegPressure PressureBefore;
-  if (LIS) {
-    PressureBefore = Pressure[RegionIdx];
-
-    LLVM_DEBUG(dbgs() << "Pressure before scheduling:\nRegion live-ins:";
-               GCNRPTracker::printLiveRegs(dbgs(), LiveIns[RegionIdx], MRI);
-               dbgs() << "Region live-in pressure:  ";
-               llvm::getRegPressure(MRI, LiveIns[RegionIdx]).print(dbgs());
-               dbgs() << "Region register pressure: ";
-               PressureBefore.print(dbgs()));
-  }
-
-  GCNMaxOccupancySchedStrategy &S = (GCNMaxOccupancySchedStrategy&)*SchedImpl;
-  // Set HasClusteredNodes to true for late stages where we have already
-  // collected it. That way pickNode() will not scan SDep's when not needed.
-  S.HasClusteredNodes = Stage > InitialSchedule;
-  S.HasExcessPressure = false;
-  ScheduleDAGMILive::schedule();
-  Regions[RegionIdx] = std::make_pair(RegionBegin, RegionEnd);
-  RescheduleRegions[RegionIdx] = false;
-  if (Stage == InitialSchedule && S.HasClusteredNodes)
-    RegionsWithClusters[RegionIdx] = true;
-  if (S.HasExcessPressure)
-    RegionsWithHighRP[RegionIdx] = true;
-
-  if (!LIS)
-    return;
-
-  // Check the results of scheduling.
-  auto PressureAfter = getRealRegPressure();
-
-  LLVM_DEBUG(dbgs() << "Pressure after scheduling: ";
-             PressureAfter.print(dbgs()));
-
-  if (PressureAfter.getSGPRNum() <= S.SGPRCriticalLimit &&
-      PressureAfter.getVGPRNum(ST.hasGFX90AInsts()) <= S.VGPRCriticalLimit) {
-    Pressure[RegionIdx] = PressureAfter;
-    RegionsWithMinOcc[RegionIdx] =
-        PressureAfter.getOccupancy(ST) == MinOccupancy;
-
-    LLVM_DEBUG(dbgs() << "Pressure in desired limits, done.\n");
-    return;
-  }
-
-  unsigned WavesAfter =
-      std::min(S.TargetOccupancy, PressureAfter.getOccupancy(ST));
-  unsigned WavesBefore =
-      std::min(S.TargetOccupancy, PressureBefore.getOccupancy(ST));
-  LLVM_DEBUG(dbgs() << "Occupancy before scheduling: " << WavesBefore
-                    << ", after " << WavesAfter << ".\n");
-
-  // We may not be able to keep the current target occupancy because of the just
-  // scheduled region. We might still be able to revert scheduling if the
-  // occupancy before was higher, or if the current schedule has register
-  // pressure higher than the excess limits which could lead to more spilling.
-  unsigned NewOccupancy = std::max(WavesAfter, WavesBefore);
-
-  // Allow memory bound functions to drop to 4 waves if not limited by an
-  // attribute.
-  if (WavesAfter < WavesBefore && WavesAfter < MinOccupancy &&
-      WavesAfter >= MFI.getMinAllowedOccupancy()) {
-    LLVM_DEBUG(dbgs() << "Function is memory bound, allow occupancy drop up to "
-                      << MFI.getMinAllowedOccupancy() << " waves\n");
-    NewOccupancy = WavesAfter;
-  }
-
-  if (NewOccupancy < MinOccupancy) {
-    MinOccupancy = NewOccupancy;
-    MFI.limitOccupancy(MinOccupancy);
-    RegionsWithMinOcc.reset();
-    LLVM_DEBUG(dbgs() << "Occupancy lowered for the function to "
-                      << MinOccupancy << ".\n");
-  }
-
-  unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
-  unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
-  if (PressureAfter.getVGPRNum(false) > MaxVGPRs ||
-      PressureAfter.getAGPRNum() > MaxVGPRs ||
-      PressureAfter.getSGPRNum() > MaxSGPRs) {
-    RescheduleRegions[RegionIdx] = true;
-    RegionsWithHighRP[RegionIdx] = true;
-  }
-
-  // If this condition is true, then either the occupancy before and after
-  // scheduling is the same, or we are allowing the occupancy to drop because
-  // the function is memory bound. Even if we are OK with the current occupancy,
-  // we still need to verify that we will not introduce any extra chance of
-  // spilling.
-  if (WavesAfter >= MinOccupancy) {
-    if (Stage == UnclusteredReschedule &&
-        !PressureAfter.less(ST, PressureBefore)) {
-      LLVM_DEBUG(dbgs() << "Unclustered reschedule did not help.\n");
-    } else if (WavesAfter > MFI.getMinWavesPerEU() ||
-        PressureAfter.less(ST, PressureBefore) ||
-        !RescheduleRegions[RegionIdx]) {
-      Pressure[RegionIdx] = PressureAfter;
-      RegionsWithMinOcc[RegionIdx] =
-          PressureAfter.getOccupancy(ST) == MinOccupancy;
-      if (!RegionsWithClusters[RegionIdx] &&
-          (Stage + 1) == UnclusteredReschedule)
-        RescheduleRegions[RegionIdx] = false;
-      return;
-    } else {
-      LLVM_DEBUG(dbgs() << "New pressure will result in more spilling.\n");
-    }
-  }
-
-  RegionsWithMinOcc[RegionIdx] =
-      PressureBefore.getOccupancy(ST) == MinOccupancy;
-  LLVM_DEBUG(dbgs() << "Attempting to revert scheduling.\n");
-  RescheduleRegions[RegionIdx] = RegionsWithClusters[RegionIdx] ||
-                                 (Stage + 1) != UnclusteredReschedule;
-  RegionEnd = RegionBegin;
-  int SkippedDebugInstr = 0;
-  for (MachineInstr *MI : Unsched) {
-    if (MI->isDebugInstr()) {
-      ++SkippedDebugInstr;
-      continue;
-    }
-
-    if (MI->getIterator() != RegionEnd) {
-      BB->remove(MI);
-      BB->insert(RegionEnd, MI);
-      if (!MI->isDebugInstr())
-        LIS->handleMove(*MI, true);
-    }
-    // Reset read-undef flags and update them later.
-    for (auto &Op : MI->operands())
-      if (Op.isReg() && Op.isDef())
-        Op.setIsUndef(false);
-    RegisterOperands RegOpers;
-    RegOpers.collect(*MI, *TRI, MRI, ShouldTrackLaneMasks, false);
-    if (!MI->isDebugInstr()) {
-      if (ShouldTrackLaneMasks) {
-        // Adjust liveness and add missing dead+read-undef flags.
-        SlotIndex SlotIdx = LIS->getInstructionIndex(*MI).getRegSlot();
-        RegOpers.adjustLaneLiveness(*LIS, MRI, SlotIdx, MI);
-      } else {
-        // Adjust for missing dead-def flags.
-        RegOpers.detectDeadDefs(*MI, *LIS);
-      }
-    }
-    RegionEnd = MI->getIterator();
-    ++RegionEnd;
-    LLVM_DEBUG(dbgs() << "Scheduling " << *MI);
-  }
-
-  // After reverting schedule, debug instrs will now be at the end of the block
-  // and RegionEnd will point to the first debug instr. Increment RegionEnd
-  // pass debug instrs to the actual end of the scheduling region.
-  while (SkippedDebugInstr-- > 0)
-    ++RegionEnd;
-
-  // If Unsched.front() instruction is a debug instruction, this will actually
-  // shrink the region since we moved all debug instructions to the end of the
-  // block. Find the first instruction that is not a debug instruction.
-  RegionBegin = Unsched.front()->getIterator();
-  if (RegionBegin->isDebugInstr()) {
-    for (MachineInstr *MI : Unsched) {
-      if (MI->isDebugInstr())
-        continue;
-      RegionBegin = MI->getIterator();
-      break;
-    }
-  }
-
-  // Then move the debug instructions back into their correct place and set
-  // RegionBegin and RegionEnd if needed.
-  placeDebugValues();
-
-  Regions[RegionIdx] = std::make_pair(RegionBegin, RegionEnd);
+  // Collect all scheduling regions. The actual scheduling is performed in
+  // GCNScheduleDAGMILive::runSchedStages, invoked from finalizeSchedule.
+  Regions.push_back(std::make_pair(RegionBegin, RegionEnd));
 }
 
-GCNRegPressure GCNScheduleDAGMILive::getRealRegPressure() const {
+GCNRegPressure
+GCNScheduleDAGMILive::getRealRegPressure(unsigned RegionIdx) const {
   GCNDownwardRPTracker RPTracker(*LIS);
   RPTracker.advance(begin(), end(), &LiveIns[RegionIdx]);
   return RPTracker.moveMaxPressure();
 }
 
-void GCNScheduleDAGMILive::computeBlockPressure(const MachineBasicBlock *MBB) {
+void GCNScheduleDAGMILive::computeBlockPressure(unsigned RegionIdx,
+                                                const MachineBasicBlock *MBB) {
   GCNDownwardRPTracker RPTracker(*LIS);
 
   // If the block has the only successor then live-ins of that successor are
@@ -542,7 +374,7 @@ void GCNScheduleDAGMILive::computeBlockPressure(const MachineBasicBlock *MBB) {
     RPTracker.reset(*I, &LRS);
   }
 
-  for ( ; ; ) {
+  for (;;) {
     I = RPTracker.getNext();
 
     if (Regions[CurRegion].first == I || NonDbgMI == I) {
@@ -588,8 +420,9 @@ GCNScheduleDAGMILive::getBBLiveInMap() const {
 }
 
 void GCNScheduleDAGMILive::finalizeSchedule() {
-  LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");
-
+  // Start actual scheduling here. This function is called by the base
+  // MachineScheduler after all regions have been recorded by
+  // GCNScheduleDAGMILive::schedule().
   LiveIns.resize(Regions.size());
   Pressure.resize(Regions.size());
   RescheduleRegions.resize(Regions.size());
@@ -601,142 +434,470 @@ void GCNScheduleDAGMILive::finalizeSchedule() {
   RegionsWithHighRP.reset();
   RegionsWithMinOcc.reset();
 
+  runSchedStages();
+}
+
+void GCNScheduleDAGMILive::runSchedStages() {
+  LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");
+  InitialScheduleStage S0(GCNSchedStageID::InitialSchedule, *this);
+  UnclusteredRescheduleStage S1(GCNSchedStageID::UnclusteredReschedule, *this);
+  ClusteredLowOccStage S2(GCNSchedStageID::ClusteredLowOccupancyReschedule,
+                          *this);
+  PreRARematStage S3(GCNSchedStageID::PreRARematerialize, *this);
+  GCNSchedStage *SchedStages[] = {&S0, &S1, &S2, &S3};
+
   if (!Regions.empty())
     BBLiveInMap = getBBLiveInMap();
 
-  std::vector<std::unique_ptr<ScheduleDAGMutation>> SavedMutations;
+  for (auto *Stage : SchedStages) {
+    if (!Stage->initGCNSchedStage())
+      continue;
 
-  do {
-    Stage++;
-    RegionIdx = 0;
-    MachineBasicBlock *MBB = nullptr;
+    for (auto Region : Regions) {
+      RegionBegin = Region.first;
+      RegionEnd = Region.second;
+      // Setup for scheduling the region and check whether it should be skipped.
+      if (!Stage->initGCNRegion()) {
+        Stage->advanceRegion();
+        exitRegion();
+        continue;
+      }
 
-    if (Stage > InitialSchedule) {
-      if (!LIS)
-        break;
+      ScheduleDAGMILive::schedule();
+      Stage->finalizeGCNRegion();
+    }
 
-      // Retry function scheduling if we found resulting occupancy and it is
-      // lower than used for first pass scheduling. This will give more freedom
-      // to schedule low register pressure blocks.
-      // Code is partially copied from MachineSchedulerBase::scheduleRegions().
+    Stage->finalizeGCNSchedStage();
+  }
+}
 
-      if (Stage == UnclusteredReschedule) {
-        if (RescheduleRegions.none())
-          continue;
-        LLVM_DEBUG(dbgs() <<
-          "Retrying function scheduling without clustering.\n");
-      }
+#ifndef NDEBUG
+raw_ostream &llvm::operator<<(raw_ostream &OS, const GCNSchedStageID &StageID) {
+  switch (StageID) {
+  case GCNSchedStageID::InitialSchedule:
+    OS << "Initial Schedule";
+    break;
+  case GCNSchedStageID::UnclusteredReschedule:
+    OS << "Unclustered Reschedule";
+    break;
+  case GCNSchedStageID::ClusteredLowOccupancyReschedule:
+    OS << "Clustered Low Occupancy Reschedule";
+    break;
+  case GCNSchedStageID::PreRARematerialize:
+    OS << "Pre-RA Rematerialize";
+    break;
+  }
+  return OS;
+}
+#endif
 
-      if (Stage == ClusteredLowOccupancyReschedule) {
-        if (StartingOccupancy <= MinOccupancy)
-          break;
+GCNSchedStage::GCNSchedStage(GCNSchedStageID StageID, GCNScheduleDAGMILive &DAG)
+    : DAG(DAG), S(static_cast<GCNMaxOccupancySchedStrategy &>(*DAG.SchedImpl)),
+      MF(DAG.MF), MFI(DAG.MFI), ST(DAG.ST), StageID(StageID) {}
 
-        LLVM_DEBUG(
-            dbgs()
-            << "Retrying function scheduling with lowest recorded occupancy "
-            << MinOccupancy << ".\n");
-      }
+bool GCNSchedStage::initGCNSchedStage() {
+  if (!DAG.LIS)
+    return false;
 
-      if (Stage == PreRARematerialize) {
-        if (RegionsWithMinOcc.none() || Regions.size() == 1)
-          break;
-
-        const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
-        const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
-        // Check maximum occupancy
-        if (ST.computeOccupancy(MF.getFunction(), MFI.getLDSSize()) ==
-            MinOccupancy)
-          break;
-
-        // FIXME: This pass will invalidate cached MBBLiveIns for regions
-        // inbetween the defs and region we sinked the def to. Cached pressure
-        // for regions where a def is sinked from will also be invalidated. Will
-        // need to be fixed if there is another pass after this pass.
-        static_assert(LastStage == PreRARematerialize,
-                      "Passes after PreRARematerialize are not supported");
-
-        collectRematerializableInstructions();
-        if (RematerializableInsts.empty() || !sinkTriviallyRematInsts(ST, TII))
-          break;
-
-        LLVM_DEBUG(
-            dbgs() << "Retrying function scheduling with improved occupancy of "
-                   << MinOccupancy << " from rematerializing\n");
-      }
-    }
+  LLVM_DEBUG(dbgs() << "Starting scheduling stage: " << StageID << "\n");
+  return true;
+}
 
-    if (Stage == UnclusteredReschedule)
-      SavedMutations.swap(Mutations);
+bool UnclusteredRescheduleStage::initGCNSchedStage() {
+  if (!GCNSchedStage::initGCNSchedStage())
+    return false;
 
-    for (auto Region : Regions) {
-      if (((Stage == UnclusteredReschedule || Stage == PreRARematerialize) &&
-           !RescheduleRegions[RegionIdx]) ||
-          (Stage == ClusteredLowOccupancyReschedule &&
-           !RegionsWithClusters[RegionIdx] && !RegionsWithHighRP[RegionIdx])) {
+  if (DAG.RescheduleRegions.none())
+    return false;
 
-        ++RegionIdx;
-        continue;
-      }
+  SavedMutations.swap(DAG.Mutations);
 
-      RegionBegin = Region.first;
-      RegionEnd = Region.second;
+  LLVM_DEBUG(dbgs() << "Retrying function scheduling without clustering.\n");
+  return true;
+}
 
-      if (RegionBegin->getParent() != MBB) {
-        if (MBB) finishBlock();
-        MBB = RegionBegin->getParent();
-        startBlock(MBB);
-        if (Stage == InitialSchedule)
-          computeBlockPressure(MBB);
-      }
+bool ClusteredLowOccStage::initGCNSchedStage() {
+  if (!GCNSchedStage::initGCNSchedStage())
+    return false;
 
-      unsigned NumRegionInstrs = std::distance(begin(), end());
-      enterRegion(MBB, begin(), end(), NumRegionInstrs);
+  // Don't bother trying to improve ILP in lower RP regions if occupancy has not
+  // been dropped. All regions will have already been scheduled with the ideal
+  // occupancy targets.
+  if (DAG.StartingOccupancy <= DAG.MinOccupancy)
+    return false;
 
-      // Skip empty scheduling regions (0 or 1 schedulable instructions).
-      if (begin() == end() || begin() == std::prev(end())) {
-        exitRegion();
-        ++RegionIdx;
-        continue;
-      }
+  LLVM_DEBUG(
+      dbgs() << "Retrying function scheduling with lowest recorded occupancy "
+             << DAG.MinOccupancy << ".\n");
+  return true;
+}
 
-      LLVM_DEBUG(dbgs() << "********** MI Scheduling **********\n");
-      LLVM_DEBUG(dbgs() << MF.getName() << ":" << printMBBReference(*MBB) << " "
-                        << MBB->getName() << "\n  From: " << *begin()
-                        << "    To: ";
-                 if (RegionEnd != MBB->end()) dbgs() << *RegionEnd;
-                 else dbgs() << "End";
-                 dbgs() << " RegionInstrs: " << NumRegionInstrs << '\n');
+bool PreRARematStage::initGCNSchedStage() {
+  if (!GCNSchedStage::initGCNSchedStage())
+    return false;
+
+  if (DAG.RegionsWithMinOcc.none() || DAG.Regions.size() == 1)
+    return false;
 
-      schedule();
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+  // Check maximum occupancy
+  if (ST.computeOccupancy(MF.getFunction(), MFI.getLDSSize()) ==
+      DAG.MinOccupancy)
+    return false;
+
+  // FIXME: This stage will invalidate cached MBBLiveIns for regions
+  // in between the defs and the region we sink the def to. Cached pressure
+  // for regions that a def is sunk from will also be invalidated. This will
+  // need to be fixed if another stage is ever added after this one.
+
+  collectRematerializableInstructions();
+  if (RematerializableInsts.empty() || !sinkTriviallyRematInsts(ST, TII))
+    return false;
 
-      exitRegion();
-      ++RegionIdx;
+  LLVM_DEBUG(
+      dbgs() << "Retrying function scheduling with improved occupancy of "
+             << DAG.MinOccupancy << " from rematerializing\n");
+  return true;
+}
+
+void GCNSchedStage::finalizeGCNSchedStage() {
+  DAG.finishBlock();
+  LLVM_DEBUG(dbgs() << "Ending scheduling stage: " << StageID << "\n");
+}
+
+void UnclusteredRescheduleStage::finalizeGCNSchedStage() {
+  SavedMutations.swap(DAG.Mutations);
+
+  GCNSchedStage::finalizeGCNSchedStage();
+}
+
+bool GCNSchedStage::initGCNRegion() {
+  // Check whether this new region is also a new block.
+  if (DAG.RegionBegin->getParent() != CurrentMBB)
+    setupNewBlock();
+
+  unsigned NumRegionInstrs = std::distance(DAG.begin(), DAG.end());
+  DAG.enterRegion(CurrentMBB, DAG.begin(), DAG.end(), NumRegionInstrs);
+
+  // Skip empty scheduling regions (0 or 1 schedulable instructions).
+  if (DAG.begin() == DAG.end() || DAG.begin() == std::prev(DAG.end()))
+    return false;
+
+  LLVM_DEBUG(dbgs() << "********** MI Scheduling **********\n");
+  LLVM_DEBUG(dbgs() << MF.getName() << ":" << printMBBReference(*CurrentMBB)
+                    << " " << CurrentMBB->getName()
+                    << "\n  From: " << *DAG.begin() << "    To: ";
+             if (DAG.RegionEnd != CurrentMBB->end()) dbgs() << *DAG.RegionEnd;
+             else dbgs() << "End";
+             dbgs() << " RegionInstrs: " << NumRegionInstrs << '\n');
+
+  // Save original instruction order before scheduling for possible revert.
+  Unsched.clear();
+  Unsched.reserve(DAG.NumRegionInstrs);
+  for (auto &I : DAG)
+    Unsched.push_back(&I);
+
+  PressureBefore = DAG.Pressure[RegionIdx];
+
+  LLVM_DEBUG(
+      dbgs() << "Pressure before scheduling:\nRegion live-ins:";
+      GCNRPTracker::printLiveRegs(dbgs(), DAG.LiveIns[RegionIdx], DAG.MRI);
+      dbgs() << "Region live-in pressure:  ";
+      llvm::getRegPressure(DAG.MRI, DAG.LiveIns[RegionIdx]).print(dbgs());
+      dbgs() << "Region register pressure: "; PressureBefore.print(dbgs()));
+
+  // Set HasClusteredNodes to true for late stages where we have already
+  // collected it. That way pickNode() will not scan SDep's when not needed.
+  S.HasClusteredNodes = StageID > GCNSchedStageID::InitialSchedule;
+  S.HasExcessPressure = false;
+
+  return true;
+}
+
+bool UnclusteredRescheduleStage::initGCNRegion() {
+  if (!DAG.RescheduleRegions[RegionIdx])
+    return false;
+
+  return GCNSchedStage::initGCNRegion();
+}
+
+bool ClusteredLowOccStage::initGCNRegion() {
+  // Only reschedule this region if it had clustered nodes in the initial
+  // schedule, or if we found it was testing critical register pressure limits
+  // in the unclustered reschedule stage. The latter is because we may not have
+  // been able to raise the min occupancy in the previous stage, so the region
+  // may be overly constrained even if it was already rescheduled. Regions with
+  // neither property are skipped.
+  if (!DAG.RegionsWithClusters[RegionIdx] && !DAG.RegionsWithHighRP[RegionIdx])
+    return false;
+
+  return GCNSchedStage::initGCNRegion();
+}
+
+bool PreRARematStage::initGCNRegion() {
+  if (!DAG.RescheduleRegions[RegionIdx])
+    return false;
+
+  return GCNSchedStage::initGCNRegion();
+}
+
+void GCNSchedStage::setupNewBlock() {
+  if (CurrentMBB)
+    DAG.finishBlock();
+
+  CurrentMBB = DAG.RegionBegin->getParent();
+  DAG.startBlock(CurrentMBB);
+  // Get real RP for the region if it hasn't been calculated before. After the
+  // initial schedule stage, real RP will be collected after scheduling.
+  if (StageID == GCNSchedStageID::InitialSchedule)
+    DAG.computeBlockPressure(RegionIdx, CurrentMBB);
+}
+
+void GCNSchedStage::finalizeGCNRegion() {
+  DAG.Regions[RegionIdx] = std::make_pair(DAG.RegionBegin, DAG.RegionEnd);
+  DAG.RescheduleRegions[RegionIdx] = false;
+  if (S.HasExcessPressure)
+    DAG.RegionsWithHighRP[RegionIdx] = true;
+
+  // Revert scheduling if we have dropped occupancy or there is some other
+  // reason that the original schedule is better.
+  checkScheduling();
+
+  DAG.exitRegion();
+  RegionIdx++;
+}
+
+void InitialScheduleStage::finalizeGCNRegion() {
+  // Record which regions have clustered nodes for the next unclustered
+  // reschedule stage.
+  assert(nextStage(StageID) == GCNSchedStageID::UnclusteredReschedule);
+  if (S.HasClusteredNodes)
+    DAG.RegionsWithClusters[RegionIdx] = true;
+
+  GCNSchedStage::finalizeGCNRegion();
+}
+
+void GCNSchedStage::checkScheduling() {
+  // Check the results of scheduling.
+  PressureAfter = DAG.getRealRegPressure(RegionIdx);
+  LLVM_DEBUG(dbgs() << "Pressure after scheduling: ";
+             PressureAfter.print(dbgs()));
+
+  if (PressureAfter.getSGPRNum() <= S.SGPRCriticalLimit &&
+      PressureAfter.getVGPRNum(ST.hasGFX90AInsts()) <= S.VGPRCriticalLimit) {
+    DAG.Pressure[RegionIdx] = PressureAfter;
+    DAG.RegionsWithMinOcc[RegionIdx] =
+        PressureAfter.getOccupancy(ST) == DAG.MinOccupancy;
+
+    // Early out if we have achieved the occupancy target.
+    LLVM_DEBUG(dbgs() << "Pressure in desired limits, done.\n");
+    return;
+  }
+
+  unsigned WavesAfter =
+      std::min(S.getTargetOccupancy(), PressureAfter.getOccupancy(ST));
+  unsigned WavesBefore =
+      std::min(S.getTargetOccupancy(), PressureBefore.getOccupancy(ST));
+  LLVM_DEBUG(dbgs() << "Occupancy before scheduling: " << WavesBefore
+                    << ", after " << WavesAfter << ".\n");
+
+  // We may not be able to keep the current target occupancy because of the just
+  // scheduled region. We might still be able to revert scheduling if the
+  // occupancy before was higher, or if the current schedule has register
+  // pressure higher than the excess limits which could lead to more spilling.
+  unsigned NewOccupancy = std::max(WavesAfter, WavesBefore);
+
+  // Allow memory bound functions to drop to 4 waves if not limited by an
+  // attribute.
+  if (WavesAfter < WavesBefore && WavesAfter < DAG.MinOccupancy &&
+      WavesAfter >= MFI.getMinAllowedOccupancy()) {
+    LLVM_DEBUG(dbgs() << "Function is memory bound, allow occupancy drop up to "
+                      << MFI.getMinAllowedOccupancy() << " waves\n");
+    NewOccupancy = WavesAfter;
+  }
+
+  if (NewOccupancy < DAG.MinOccupancy) {
+    DAG.MinOccupancy = NewOccupancy;
+    MFI.limitOccupancy(DAG.MinOccupancy);
+    DAG.RegionsWithMinOcc.reset();
+    LLVM_DEBUG(dbgs() << "Occupancy lowered for the function to "
+                      << DAG.MinOccupancy << ".\n");
+  }
+
+  unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
+  unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
+  if (PressureAfter.getVGPRNum(false) > MaxVGPRs ||
+      PressureAfter.getAGPRNum() > MaxVGPRs ||
+      PressureAfter.getSGPRNum() > MaxSGPRs) {
+    DAG.RescheduleRegions[RegionIdx] = true;
+    DAG.RegionsWithHighRP[RegionIdx] = true;
+  }
+
+  // Revert if this region's schedule would cause a drop in occupancy or
+  // spilling.
+  if (shouldRevertScheduling(WavesAfter)) {
+    revertScheduling();
+  } else {
+    DAG.Pressure[RegionIdx] = PressureAfter;
+    DAG.RegionsWithMinOcc[RegionIdx] =
+        PressureAfter.getOccupancy(ST) == DAG.MinOccupancy;
+  }
+}
+
+bool GCNSchedStage::shouldRevertScheduling(unsigned WavesAfter) {
+  if (WavesAfter < DAG.MinOccupancy)
+    return true;
+
+  return false;
+}
+
+bool InitialScheduleStage::shouldRevertScheduling(unsigned WavesAfter) {
+  if (GCNSchedStage::shouldRevertScheduling(WavesAfter))
+    return true;
+
+  if (mayCauseSpilling(WavesAfter))
+    return true;
+
+  assert(nextStage(StageID) == GCNSchedStageID::UnclusteredReschedule);
+  // Don't reschedule the region in the next stage if it doesn't have clusters.
+  if (!DAG.RegionsWithClusters[RegionIdx])
+    DAG.RescheduleRegions[RegionIdx] = false;
+
+  return false;
+}
+
+bool UnclusteredRescheduleStage::shouldRevertScheduling(unsigned WavesAfter) {
+  if (GCNSchedStage::shouldRevertScheduling(WavesAfter))
+    return true;
+
+  // If RP is not reduced in the unclustered reschedule stage, revert to the
+  // old schedule.
+  if (!PressureAfter.less(ST, PressureBefore)) {
+    LLVM_DEBUG(dbgs() << "Unclustered reschedule did not help.\n");
+    return true;
+  }
+
+  return false;
+}
+
+bool ClusteredLowOccStage::shouldRevertScheduling(unsigned WavesAfter) {
+  if (GCNSchedStage::shouldRevertScheduling(WavesAfter))
+    return true;
+
+  if (mayCauseSpilling(WavesAfter))
+    return true;
+
+  return false;
+}
+
+bool PreRARematStage::shouldRevertScheduling(unsigned WavesAfter) {
+  if (GCNSchedStage::shouldRevertScheduling(WavesAfter))
+    return true;
+
+  if (mayCauseSpilling(WavesAfter))
+    return true;
+
+  return false;
+}
+
+bool GCNSchedStage::mayCauseSpilling(unsigned WavesAfter) {
+  if (WavesAfter <= MFI.getMinWavesPerEU() &&
+      !PressureAfter.less(ST, PressureBefore) &&
+      DAG.RescheduleRegions[RegionIdx]) {
+    LLVM_DEBUG(dbgs() << "New pressure will result in more spilling.\n");
+    return true;
+  }
+
+  return false;
+}
+
+void GCNSchedStage::revertScheduling() {
+  DAG.RegionsWithMinOcc[RegionIdx] =
+      PressureBefore.getOccupancy(ST) == DAG.MinOccupancy;
+  LLVM_DEBUG(dbgs() << "Attempting to revert scheduling.\n");
+  DAG.RescheduleRegions[RegionIdx] =
+      DAG.RegionsWithClusters[RegionIdx] ||
+      (nextStage(StageID)) != GCNSchedStageID::UnclusteredReschedule;
+  DAG.RegionEnd = DAG.RegionBegin;
+  int SkippedDebugInstr = 0;
+  for (MachineInstr *MI : Unsched) {
+    if (MI->isDebugInstr()) {
+      ++SkippedDebugInstr;
+      continue;
+    }
+
+    if (MI->getIterator() != DAG.RegionEnd) {
+      DAG.BB->remove(MI);
+      DAG.BB->insert(DAG.RegionEnd, MI);
+      if (!MI->isDebugInstr())
+        DAG.LIS->handleMove(*MI, true);
+    }
+
+    // Reset read-undef flags and update them later.
+    for (auto &Op : MI->operands())
+      if (Op.isReg() && Op.isDef())
+        Op.setIsUndef(false);
+    RegisterOperands RegOpers;
+    RegOpers.collect(*MI, *DAG.TRI, DAG.MRI, DAG.ShouldTrackLaneMasks, false);
+    if (!MI->isDebugInstr()) {
+      if (DAG.ShouldTrackLaneMasks) {
+        // Adjust liveness and add missing dead+read-undef flags.
+        SlotIndex SlotIdx = DAG.LIS->getInstructionIndex(*MI).getRegSlot();
+        RegOpers.adjustLaneLiveness(*DAG.LIS, DAG.MRI, SlotIdx, MI);
+      } else {
+        // Adjust for missing dead-def flags.
+        RegOpers.detectDeadDefs(*MI, *DAG.LIS);
+      }
     }
-    finishBlock();
+    DAG.RegionEnd = MI->getIterator();
+    ++DAG.RegionEnd;
+    LLVM_DEBUG(dbgs() << "Scheduling " << *MI);
+  }
+
+  // After reverting the schedule, debug instrs will now be at the end of the
+  // block and RegionEnd will point to the first debug instr. Increment
+  // RegionEnd past debug instrs to the actual end of the scheduling region.
+  while (SkippedDebugInstr-- > 0)
+    ++DAG.RegionEnd;
+
+  // If Unsched.front() instruction is a debug instruction, this will actually
+  // shrink the region since we moved all debug instructions to the end of the
+  // block. Find the first instruction that is not a debug instruction.
+  DAG.RegionBegin = Unsched.front()->getIterator();
+  if (DAG.RegionBegin->isDebugInstr()) {
+    for (MachineInstr *MI : Unsched) {
+      if (MI->isDebugInstr())
+        continue;
+      DAG.RegionBegin = MI->getIterator();
+      break;
+    }
+  }
+
+  // Then move the debug instructions back into their correct place and set
+  // RegionBegin and RegionEnd if needed.
+  DAG.placeDebugValues();
 
-    if (Stage == UnclusteredReschedule)
-      SavedMutations.swap(Mutations);
-  } while (Stage != LastStage);
+  DAG.Regions[RegionIdx] = std::make_pair(DAG.RegionBegin, DAG.RegionEnd);
 }
 
-void GCNScheduleDAGMILive::collectRematerializableInstructions() {
-  const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI);
-  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
+void PreRARematStage::collectRematerializableInstructions() {
+  const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(DAG.TRI);
+  for (unsigned I = 0, E = DAG.MRI.getNumVirtRegs(); I != E; ++I) {
     Register Reg = Register::index2VirtReg(I);
-    if (!LIS->hasInterval(Reg))
+    if (!DAG.LIS->hasInterval(Reg))
       continue;
 
     // TODO: Handle AGPR and SGPR rematerialization
-    if (!SRI->isVGPRClass(MRI.getRegClass(Reg)) || !MRI.hasOneDef(Reg) ||
-        !MRI.hasOneNonDBGUse(Reg))
+    if (!SRI->isVGPRClass(DAG.MRI.getRegClass(Reg)) ||
+        !DAG.MRI.hasOneDef(Reg) || !DAG.MRI.hasOneNonDBGUse(Reg))
       continue;
 
-    MachineOperand *Op = MRI.getOneDef(Reg);
+    MachineOperand *Op = DAG.MRI.getOneDef(Reg);
     MachineInstr *Def = Op->getParent();
     if (Op->getSubReg() != 0 || !isTriviallyReMaterializable(*Def))
       continue;
 
-    MachineInstr *UseI = &*MRI.use_instr_nodbg_begin(Reg);
+    MachineInstr *UseI = &*DAG.MRI.use_instr_nodbg_begin(Reg);
     if (Def->getParent() == UseI->getParent())
       continue;
 
@@ -744,10 +905,10 @@ void GCNScheduleDAGMILive::collectRematerializableInstructions() {
     // live-through or used inside regions at MinOccupancy. This means that the
     // register must be in the live-in set for the region.
     bool AddedToRematList = false;
-    for (unsigned I = 0, E = Regions.size(); I != E; ++I) {
-      auto It = LiveIns[I].find(Reg);
-      if (It != LiveIns[I].end() && !It->second.none()) {
-        if (RegionsWithMinOcc[I]) {
+    for (unsigned I = 0, E = DAG.Regions.size(); I != E; ++I) {
+      auto It = DAG.LiveIns[I].find(Reg);
+      if (It != DAG.LiveIns[I].end() && !It->second.none()) {
+        if (DAG.RegionsWithMinOcc[I]) {
           RematerializableInsts[I][Def] = UseI;
           AddedToRematList = true;
         }
@@ -762,8 +923,8 @@ void GCNScheduleDAGMILive::collectRematerializableInstructions() {
   }
 }
 
-bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
-                                                   const TargetInstrInfo *TII) {
+bool PreRARematStage::sinkTriviallyRematInsts(const GCNSubtarget &ST,
+                                              const TargetInstrInfo *TII) {
   // Temporary copies of cached variables we will be modifying and replacing if
   // sinking succeeds.
   SmallVector<
@@ -772,9 +933,10 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
   DenseMap<unsigned, GCNRPTracker::LiveRegSet> NewLiveIns;
   DenseMap<unsigned, GCNRegPressure> NewPressure;
   BitVector NewRescheduleRegions;
+  LiveIntervals *LIS = DAG.LIS;
 
-  NewRegions.resize(Regions.size());
-  NewRescheduleRegions.resize(Regions.size());
+  NewRegions.resize(DAG.Regions.size());
+  NewRescheduleRegions.resize(DAG.Regions.size());
 
   // Collect only regions that has a rematerializable def as a live-in.
   SmallSet<unsigned, 16> ImpactedRegions;
@@ -784,16 +946,16 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
   // Make copies of register pressure and live-ins cache that will be updated
   // as we rematerialize.
   for (auto Idx : ImpactedRegions) {
-    NewPressure[Idx] = Pressure[Idx];
-    NewLiveIns[Idx] = LiveIns[Idx];
+    NewPressure[Idx] = DAG.Pressure[Idx];
+    NewLiveIns[Idx] = DAG.LiveIns[Idx];
   }
-  NewRegions = Regions;
+  NewRegions = DAG.Regions;
   NewRescheduleRegions.reset();
 
   DenseMap<MachineInstr *, MachineInstr *> InsertedMIToOldDef;
   bool Improved = false;
   for (auto I : ImpactedRegions) {
-    if (!RegionsWithMinOcc[I])
+    if (!DAG.RegionsWithMinOcc[I])
       continue;
 
     Improved = false;
@@ -802,12 +964,12 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
 
     // TODO: Handle occupancy drop due to AGPR and SGPR.
     // Check if cause of occupancy drop is due to VGPR usage and not SGPR.
-    if (ST.getOccupancyWithNumSGPRs(SGPRUsage) == MinOccupancy)
+    if (ST.getOccupancyWithNumSGPRs(SGPRUsage) == DAG.MinOccupancy)
       break;
 
     // The occupancy of this region could have been improved by a previous
     // iteration's sinking of defs.
-    if (NewPressure[I].getOccupancy(ST) > MinOccupancy) {
+    if (NewPressure[I].getOccupancy(ST) > DAG.MinOccupancy) {
       NewRescheduleRegions[I] = true;
       Improved = true;
       continue;
@@ -827,7 +989,7 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
     unsigned OptimisticOccupancy = ST.getOccupancyWithNumVGPRs(VGPRsAfterSink);
     // If in the most optimistic scenario, we cannot improve occupancy, then do
     // not attempt to sink any instructions.
-    if (OptimisticOccupancy <= MinOccupancy)
+    if (OptimisticOccupancy <= DAG.MinOccupancy)
       break;
 
     unsigned ImproveOccupancy = 0;
@@ -842,7 +1004,7 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
       // call LiveRangeEdit::allUsesAvailableAt() and
       // LiveRangeEdit::canRematerializeAt().
       TII->reMaterialize(*InsertPos->getParent(), InsertPos, Reg,
-                         Def->getOperand(0).getSubReg(), *Def, *TRI);
+                         Def->getOperand(0).getSubReg(), *Def, *DAG.TRI);
       MachineInstr *NewMI = &*(--InsertPos);
       LIS->InsertMachineInstrInMaps(*NewMI);
       LIS->removeInterval(Reg);
@@ -851,11 +1013,11 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
 
       // Update region boundaries in scheduling region we sinked from since we
       // may sink an instruction that was at the beginning or end of its region
-      updateRegionBoundaries(NewRegions, Def, /*NewMI =*/nullptr,
-                             /*Removing =*/true);
+      DAG.updateRegionBoundaries(NewRegions, Def, /*NewMI =*/nullptr,
+                                 /*Removing =*/true);
 
       // Update region boundaries in region we sinked to.
-      updateRegionBoundaries(NewRegions, InsertPos, NewMI);
+      DAG.updateRegionBoundaries(NewRegions, InsertPos, NewMI);
 
       LaneBitmask PrevMask = NewLiveIns[I][Reg];
       // FIXME: Also update cached pressure for where the def was sinked from.
@@ -863,9 +1025,9 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
       // the reg from all regions as a live-in.
       for (auto Idx : RematDefToLiveInRegions[Def]) {
         NewLiveIns[Idx].erase(Reg);
-        if (InsertPos->getParent() != Regions[Idx].first->getParent()) {
+        if (InsertPos->getParent() != DAG.Regions[Idx].first->getParent()) {
           // Def is live-through and not used in this block.
-          NewPressure[Idx].inc(Reg, PrevMask, LaneBitmask::getNone(), MRI);
+          NewPressure[Idx].inc(Reg, PrevMask, LaneBitmask::getNone(), DAG.MRI);
         } else {
           // Def is used and rematerialized into this block.
           GCNDownwardRPTracker RPT(*LIS);
@@ -879,7 +1041,7 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
 
       SinkedDefs.push_back(Def);
       ImproveOccupancy = NewPressure[I].getOccupancy(ST);
-      if (ImproveOccupancy > MinOccupancy)
+      if (ImproveOccupancy > DAG.MinOccupancy)
         break;
     }
 
@@ -888,7 +1050,7 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
       for (auto TrackedIdx : RematDefToLiveInRegions[Def])
         RematerializableInsts[TrackedIdx].erase(Def);
 
-    if (ImproveOccupancy <= MinOccupancy)
+    if (ImproveOccupancy <= DAG.MinOccupancy)
       break;
 
     NewRescheduleRegions[I] = true;
@@ -917,7 +1079,7 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
     MachineInstr *OldMI = Entry.second;
 
     // Remove OldMI from BBLiveInMap since we are sinking it from its MBB.
-    BBLiveInMap.erase(OldMI);
+    DAG.BBLiveInMap.erase(OldMI);
 
     // Remove OldMI and update LIS
     Register Reg = MI->getOperand(0).getReg();
@@ -929,22 +1091,22 @@ bool GCNScheduleDAGMILive::sinkTriviallyRematInsts(const GCNSubtarget &ST,
 
   // Update live-ins, register pressure, and regions caches.
   for (auto Idx : ImpactedRegions) {
-    LiveIns[Idx] = NewLiveIns[Idx];
-    Pressure[Idx] = NewPressure[Idx];
-    MBBLiveIns.erase(Regions[Idx].first->getParent());
+    DAG.LiveIns[Idx] = NewLiveIns[Idx];
+    DAG.Pressure[Idx] = NewPressure[Idx];
+    DAG.MBBLiveIns.erase(DAG.Regions[Idx].first->getParent());
   }
-  Regions = NewRegions;
-  RescheduleRegions = NewRescheduleRegions;
+  DAG.Regions = NewRegions;
+  DAG.RescheduleRegions = NewRescheduleRegions;
 
   SIMachineFunctionInfo &MFI = *MF.getInfo<SIMachineFunctionInfo>();
-  MFI.increaseOccupancy(MF, ++MinOccupancy);
+  MFI.increaseOccupancy(MF, ++DAG.MinOccupancy);
 
   return true;
 }
 
 // Copied from MachineLICM
-bool GCNScheduleDAGMILive::isTriviallyReMaterializable(const MachineInstr &MI) {
-  if (!TII->isTriviallyReMaterializable(MI))
+bool PreRARematStage::isTriviallyReMaterializable(const MachineInstr &MI) {
+  if (!DAG.TII->isTriviallyReMaterializable(MI))
     return false;
 
   for (const MachineOperand &MO : MI.operands())

diff  --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
index c3db849cf81a..7aadf89e0bf7 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
@@ -28,8 +28,6 @@ class GCNSubtarget;
 /// heuristics to determine excess/critical pressure sets.  Its goal is to
 /// maximize kernel occupancy (i.e. maximum number of waves per simd).
 class GCNMaxOccupancySchedStrategy final : public GenericScheduler {
-  friend class GCNScheduleDAGMILive;
-
   SUnit *pickNodeBidirectional(bool &IsTopNode);
 
   void pickNodeFromQueue(SchedBoundary &Zone, const CandPolicy &ZonePolicy,
@@ -42,15 +40,18 @@ class GCNMaxOccupancySchedStrategy final : public GenericScheduler {
                      unsigned SGPRPressure, unsigned VGPRPressure);
 
   std::vector<unsigned> Pressure;
+
   std::vector<unsigned> MaxPressure;
 
   unsigned SGPRExcessLimit;
+
   unsigned VGPRExcessLimit;
-  unsigned SGPRCriticalLimit;
-  unsigned VGPRCriticalLimit;
 
   unsigned TargetOccupancy;
 
+  MachineFunction *MF;
+
+public:
   // schedule() have seen a clustered memory operation. Set it to false
   // before a region scheduling to know if the region had such clusters.
   bool HasClusteredNodes;
@@ -59,28 +60,53 @@ class GCNMaxOccupancySchedStrategy final : public GenericScheduler {
   // register pressure for actual scheduling heuristics.
   bool HasExcessPressure;
 
-  MachineFunction *MF;
+  unsigned SGPRCriticalLimit;
+
+  unsigned VGPRCriticalLimit;
 
-public:
   GCNMaxOccupancySchedStrategy(const MachineSchedContext *C);
 
   SUnit *pickNode(bool &IsTopNode) override;
 
   void initialize(ScheduleDAGMI *DAG) override;
 
+  unsigned getTargetOccupancy() { return TargetOccupancy; }
+
   void setTargetOccupancy(unsigned Occ) { TargetOccupancy = Occ; }
 };
 
-class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
+enum class GCNSchedStageID : unsigned {
+  InitialSchedule = 0,
+  UnclusteredReschedule = 1,
+  ClusteredLowOccupancyReschedule = 2,
+  PreRARematerialize = 3,
+  LastStage = PreRARematerialize
+};
+
+#ifndef NDEBUG
+raw_ostream &operator<<(raw_ostream &OS, const GCNSchedStageID &StageID);
+#endif
+
+inline GCNSchedStageID &operator++(GCNSchedStageID &Stage, int) {
+  assert(Stage != GCNSchedStageID::PreRARematerialize);
+  Stage = static_cast<GCNSchedStageID>(static_cast<unsigned>(Stage) + 1);
+  return Stage;
+}
+
+inline GCNSchedStageID nextStage(const GCNSchedStageID Stage) {
+  return static_cast<GCNSchedStageID>(static_cast<unsigned>(Stage) + 1);
+}
 
-  enum : unsigned {
-    Collect,
-    InitialSchedule,
-    UnclusteredReschedule,
-    ClusteredLowOccupancyReschedule,
-    PreRARematerialize,
-    LastStage = PreRARematerialize
-  };
+inline bool operator>(GCNSchedStageID &LHS, GCNSchedStageID &RHS) {
+  return static_cast<unsigned>(LHS) > static_cast<unsigned>(RHS);
+}
+
+class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
+  friend class GCNSchedStage;
+  friend class InitialScheduleStage;
+  friend class UnclusteredRescheduleStage;
+  friend class ClusteredLowOccStage;
+  friend class PreRARematStage;
 
   const GCNSubtarget &ST;
 
@@ -92,12 +118,6 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
   // Minimal real occupancy recorder for the function.
   unsigned MinOccupancy;
 
-  // Scheduling stage number.
-  unsigned Stage;
-
-  // Current region index.
-  size_t RegionIdx;
-
   // Vector of regions recorder for later rescheduling
   SmallVector<std::pair<MachineBasicBlock::iterator,
                         MachineBasicBlock::iterator>, 32> Regions;
@@ -121,6 +141,148 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
   // Region pressure cache.
   SmallVector<GCNRegPressure, 32> Pressure;
 
+  // Temporary basic block live-in cache.
+  DenseMap<const MachineBasicBlock *, GCNRPTracker::LiveRegSet> MBBLiveIns;
+
+  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> BBLiveInMap;
+
+  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> getBBLiveInMap() const;
+
+  // Return current region pressure.
+  GCNRegPressure getRealRegPressure(unsigned RegionIdx) const;
+
+  // Compute and cache live-ins and pressure for all regions in block.
+  void computeBlockPressure(unsigned RegionIdx, const MachineBasicBlock *MBB);
+
+  // Update region boundaries when removing MI or inserting NewMI before MI.
+  void updateRegionBoundaries(
+      SmallVectorImpl<std::pair<MachineBasicBlock::iterator,
+                                MachineBasicBlock::iterator>> &RegionBoundaries,
+      MachineBasicBlock::iterator MI, MachineInstr *NewMI,
+      bool Removing = false);
+
+  void runSchedStages();
+
+public:
+  GCNScheduleDAGMILive(MachineSchedContext *C,
+                       std::unique_ptr<MachineSchedStrategy> S);
+
+  void schedule() override;
+
+  void finalizeSchedule() override;
+};
+
+// One of the scheduling stages GCNScheduleDAGMILive applies to a function.
+class GCNSchedStage {
+protected:
+  GCNScheduleDAGMILive &DAG;
+
+  GCNMaxOccupancySchedStrategy &S;
+
+  MachineFunction &MF;
+
+  SIMachineFunctionInfo &MFI;
+
+  const GCNSubtarget &ST;
+
+  const GCNSchedStageID StageID;
+
+  // The current block being scheduled.
+  MachineBasicBlock *CurrentMBB = nullptr;
+
+  // Current region index.
+  unsigned RegionIdx = 0;
+
+  // Record the original order of instructions before scheduling.
+  std::vector<MachineInstr *> Unsched;
+
+  // RP before scheduling the current region.
+  GCNRegPressure PressureBefore;
+
+  // RP after scheduling the current region.
+  GCNRegPressure PressureAfter;
+
+  GCNSchedStage(GCNSchedStageID StageID, GCNScheduleDAGMILive &DAG);
+
+public:
+  // Initialize state for a scheduling stage. Returns false if the current stage
+  // should be skipped.
+  virtual bool initGCNSchedStage();
+
+  // Finalize state after finishing a scheduling pass on the function.
+  virtual void finalizeGCNSchedStage();
+
+  // Setup for scheduling a region. Returns false if the current region should
+  // be skipped.
+  virtual bool initGCNRegion();
+
+  // Track whether a new region is also a new MBB.
+  void setupNewBlock();
+
+  // Finalize state after scheduling a region.
+  virtual void finalizeGCNRegion();
+
+  // Check result of scheduling.
+  void checkScheduling();
+
+  // Returns true if scheduling should be reverted.
+  virtual bool shouldRevertScheduling(unsigned WavesAfter);
+
+  // Returns true if the new schedule may result in more spilling.
+  bool mayCauseSpilling(unsigned WavesAfter);
+
+  // Attempt to revert scheduling for this region.
+  void revertScheduling();
+
+  void advanceRegion() { RegionIdx++; }
+
+  virtual ~GCNSchedStage() = default;
+};
+
+class InitialScheduleStage : public GCNSchedStage {
+public:
+  void finalizeGCNRegion() override;
+
+  bool shouldRevertScheduling(unsigned WavesAfter) override;
+
+  InitialScheduleStage(GCNSchedStageID StageID, GCNScheduleDAGMILive &DAG)
+      : GCNSchedStage(StageID, DAG) {}
+};
+
+class UnclusteredRescheduleStage : public GCNSchedStage {
+private:
+  std::vector<std::unique_ptr<ScheduleDAGMutation>> SavedMutations;
+
+public:
+  bool initGCNSchedStage() override;
+
+  void finalizeGCNSchedStage() override;
+
+  bool initGCNRegion() override;
+
+  bool shouldRevertScheduling(unsigned WavesAfter) override;
+
+  UnclusteredRescheduleStage(GCNSchedStageID StageID, GCNScheduleDAGMILive &DAG)
+      : GCNSchedStage(StageID, DAG) {}
+};
+
+// Retry function scheduling if the resulting occupancy is lower than the
+// occupancy used for the other scheduling passes. This gives more freedom
+// to schedule low register pressure blocks.
+class ClusteredLowOccStage : public GCNSchedStage {
+public:
+  bool initGCNSchedStage() override;
+
+  bool initGCNRegion() override;
+
+  bool shouldRevertScheduling(unsigned WavesAfter) override;
+
+  ClusteredLowOccStage(GCNSchedStageID StageID, GCNScheduleDAGMILive &DAG)
+      : GCNSchedStage(StageID, DAG) {}
+};
+
+class PreRARematStage : public GCNSchedStage {
+private:
   // Each region at MinOccupancy will have their own list of trivially
   // rematerializable instructions we can remat to reduce RP. The list maps an
   // instruction to the position we should remat before, usually the MI using
@@ -132,12 +294,6 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
   // that has the defined reg as a live-in.
   DenseMap<MachineInstr *, SmallVector<unsigned, 4>> RematDefToLiveInRegions;
 
-  // Temporary basic block live-in cache.
-  DenseMap<const MachineBasicBlock*, GCNRPTracker::LiveRegSet> MBBLiveIns;
-
-  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> BBLiveInMap;
-  DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> getBBLiveInMap() const;
-
   // Collect all trivially rematerializable VGPR instructions with a single def
   // and single use outside the defining block into RematerializableInsts.
   void collectRematerializableInstructions();
@@ -150,26 +306,15 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
   bool sinkTriviallyRematInsts(const GCNSubtarget &ST,
                                const TargetInstrInfo *TII);
 
-  // Return current region pressure.
-  GCNRegPressure getRealRegPressure() const;
-
-  // Compute and cache live-ins and pressure for all regions in block.
-  void computeBlockPressure(const MachineBasicBlock *MBB);
-
-  // Update region boundaries when removing MI or inserting NewMI before MI.
-  void updateRegionBoundaries(
-      SmallVectorImpl<std::pair<MachineBasicBlock::iterator,
-                                MachineBasicBlock::iterator>> &RegionBoundaries,
-      MachineBasicBlock::iterator MI, MachineInstr *NewMI,
-      bool Removing = false);
-
 public:
-  GCNScheduleDAGMILive(MachineSchedContext *C,
-                       std::unique_ptr<MachineSchedStrategy> S);
+  bool initGCNSchedStage() override;
 
-  void schedule() override;
+  bool initGCNRegion() override;
 
-  void finalizeSchedule() override;
+  bool shouldRevertScheduling(unsigned WavesAfter) override;
+
+  PreRARematStage(GCNSchedStageID StageID, GCNScheduleDAGMILive &DAG)
+      : GCNSchedStage(StageID, DAG) {}
 };
 
 } // End namespace llvm


        


More information about the llvm-commits mailing list