[llvm] [AMDGPU] GCNRegPressure printing pass for testing. (PR #70031)

Valery Pykhtin via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 06:17:05 PDT 2023


================
@@ -483,7 +496,125 @@ Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
   });
 }
 
-LLVM_DUMP_METHOD
 void GCNRegPressure::dump() const { dbgs() << print(*this); }
 
-#endif
+static cl::opt<bool> UseDownwardTracker(
+    "amdgpu-print-rp-downward",
+    cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
+    cl::init(false), cl::Hidden);
+
+char llvm::GCNRegPressurePrinter::ID = 0;
+char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;
+
+INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
+
+bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
+  const MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
+  const LiveIntervals &LIS = getAnalysis<LiveIntervals>();
+
+  auto &OS = dbgs();
+
+// Leading spaces are important for YAML syntax.
+#define PFX "  "
+
+  OS << "---\nname: " << MF.getName() << "\nbody:             |\n";
+
+  auto printRP = [](const GCNRegPressure &RP) {
+    return Printable([&RP](raw_ostream &OS) {
+      OS << format(PFX "  %-5d", RP.getSGPRNum())
+         << format(" %-5d", RP.getVGPRNum(false));
+    });
+  };
+
+  auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
+                                    const GCNRPTracker::LiveRegSet &LISLR) {
+    if (LISLR != TrackedLR) {
+      OS << PFX "  mis LIS: " << llvm::print(LISLR, MRI)
+         << reportMismatch(LISLR, TrackedLR, TRI, PFX "    ");
+    }
+  };
+
+  // Register pressure before and at an instruction (in program order).
+  SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;
+
+  for (auto &MBB : MF) {
+    RP.clear();
+    RP.reserve(MBB.size());
+
+    OS << PFX;
+    MBB.printName(OS);
+    OS << ":\n";
+
+    SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB);
+    SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB);
+
+    GCNRPTracker::LiveRegSet LRAtMBBBegin, LRAtMBBEnd;
+    GCNRegPressure RPAtMBBEnd;
+
+    if (UseDownwardTracker) {
+      if (MBB.empty()) {
+        LRAtMBBBegin = LRAtMBBEnd = getLiveRegs(MBBStartSlot, LIS, MRI);
+        RPAtMBBEnd = getRegPressure(MRI, LRAtMBBBegin);
+      } else {
+        GCNDownwardRPTracker RPT(LIS);
+        RPT.reset(MBB.front());
+
+        LRAtMBBBegin = RPT.getLiveRegs();
+
+        while (!RPT.advanceBeforeNext()) {
+          GCNRegPressure RPBeforeMI = RPT.getPressure();
+          RPT.advanceToNext();
+          RP.emplace_back(RPBeforeMI, RPT.getPressure());
+        }
+
+        LRAtMBBEnd = RPT.getLiveRegs();
+        RPAtMBBEnd = RPT.getPressure();
+      }
+    } else {
+      GCNUpwardRPTracker RPT(LIS);
+      RPT.reset(MRI, MBBEndSlot);
+      RPT.moveMaxPressure(); // Clear max pressure.
+
+      LRAtMBBEnd = RPT.getLiveRegs();
+      RPAtMBBEnd = RPT.getPressure();
+
+      for (auto &MI : reverse(MBB)) {
+        RPT.recede(MI);
+        if (!MI.isDebugInstr())
+          RP.emplace_back(RPT.getPressure(), RPT.moveMaxPressure());
+      }
+
+      LRAtMBBBegin = RPT.getLiveRegs();
+    }
+
+    OS << PFX "  Live-in: " << llvm::print(LRAtMBBBegin, MRI);
+    if (!UseDownwardTracker)
+      ReportLISMismatchIfAny(LRAtMBBBegin, getLiveRegs(MBBStartSlot, LIS, MRI));
+
+    OS << PFX "  SGPR  VGPR\n";
+    int I = 0;
+    for (auto &MI : MBB) {
+      if (!MI.isDebugInstr()) {
+        auto &[RPBeforeInstr, RPAtInstr] =
+            RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
+        ++I;
+        OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << "  ";
+      } else
+        OS << PFX "               ";
+      MI.print(OS);
+    }
+    OS << printRP(RPAtMBBEnd) << '\n';
+
+    OS << PFX "  Live-out:" << llvm::print(LRAtMBBEnd, MRI);
+    if (UseDownwardTracker)
+      ReportLISMismatchIfAny(LRAtMBBEnd, getLiveRegs(MBBEndSlot, LIS, MRI));
+
+    GCNRPTracker::LiveRegSet LRThr = getIntersection(LRAtMBBBegin, LRAtMBBEnd);
+    OS << PFX "  Live-thr:" << llvm::print(LRThr, MRI);
----------------
vpykhtin wrote:

Sorry, I mean that, using the registers from the intersection, I would need to perform an additional pass over the defs/uses for a given block.

https://github.com/llvm/llvm-project/pull/70031


More information about the llvm-commits mailing list