[llvm] [CodeGen][MachinePipeliner] Limit register pressure when scheduling (PR #74807)
Ryotaro Kasuga via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 4 03:51:30 PST 2024
================
@@ -1102,6 +1121,359 @@ struct FuncUnitSorter {
}
};
+/// Calculate the maximum register pressure of the scheduled instruction stream.
+class HighRegisterPressureDetector {
+ // The loop block under analysis. computeLiveIn() scans it, and a virtual
+ // register counts as loop-local when its definition is in this block.
+ MachineBasicBlock *OrigMBB;
+ const MachineFunction &MF;
+ const MachineRegisterInfo &MRI;
+ const TargetRegisterInfo *TRI;
+
+ // Number of register pressure sets; the PSet-indexed vectors in this class
+ // are accessed with IDs in [0, PSetNum).
+ const unsigned PSetNum;
+
+ // Indexed by PSet ID.
+ // InitSetPressure accounts for the register pressure of live-in registers.
+ // It does not depend on how the loop is scheduled, so it is enough to
+ // calculate it once at the beginning.
+ std::vector<unsigned> InitSetPressure;
+
+ // Indexed by PSet ID.
+ // Upper limit for each register pressure set.
+ std::vector<unsigned> PressureSetLimit;
+
+ // Maps an instruction to the set of registers whose last use it is.
+ using Instr2LastUsesTy = DenseMap<MachineInstr *, SmallDenseSet<Register, 4>>;
+
+public:
+ // Instructions in the order in which they are simulated.
+ using OrderedInstsTy = std::vector<MachineInstr *>;
+ // Maps each instruction to its pipeline stage.
+ using Instr2StageTy = DenseMap<MachineInstr *, unsigned>;
+
+private:
+  // Print the pressure values to dbgs() as a space-separated list enclosed in
+  // brackets, e.g. "[3 0 7]". An empty vector prints as "[]".
+  static void dumpRegisterPressures(const std::vector<unsigned> &Pressures) {
+    dbgs() << '[';
+    bool First = true;
+    for (unsigned Value : Pressures) {
+      if (!First)
+        dbgs() << ' ';
+      First = false;
+      dbgs() << Value;
+    }
+    dbgs() << ']';
+  }
+
+  // Print Reg followed by the IDs of all pressure sets it contributes to,
+  // terminated by a newline.
+  void dumpPSet(Register Reg) const {
+    auto &OS = dbgs();
+    OS << "Reg=" << printReg(Reg, TRI, 0, &MRI) << " PSet=";
+    auto It = MRI.getPressureSets(Reg);
+    while (It.isValid()) {
+      OS << *It << ' ';
+      ++It;
+    }
+    OS << '\n';
+  }
+
+  // Add the weight of Reg to every pressure set that Reg belongs to.
+  // Pressure is indexed by PSet ID.
+  void increaseRegisterPressure(std::vector<unsigned> &Pressure,
+                                Register Reg) const {
+    auto It = MRI.getPressureSets(Reg);
+    const unsigned Delta = It.getWeight();
+    while (It.isValid()) {
+      Pressure[*It] += Delta;
+      ++It;
+    }
+  }
+
+  // Subtract the weight of Reg from every pressure set that Reg belongs to.
+  // Asserts (in +Asserts builds) that no pressure value would underflow.
+  void decreaseRegisterPressure(std::vector<unsigned> &Pressure,
+                                Register Reg) const {
+    auto It = MRI.getPressureSets(Reg);
+    const unsigned Delta = It.getWeight();
+    while (It.isValid()) {
+      unsigned &Current = Pressure[*It];
+      assert(Delta <= Current &&
+             "register pressure must be greater or equal than weight");
+      Current -= Delta;
+      ++It;
+    }
+  }
+
+  // Return true if Reg is a physical register that the target reports as
+  // fixed (for example, the stack pointer). Virtual registers are never fixed.
+  bool isFixedRegister(Register Reg) const {
+    if (!Reg.isPhysical())
+      return false;
+    return TRI->isFixedRegister(MF, Reg.asMCReg());
+  }
+
+  // Return true if Reg is a virtual register whose defining instruction (as
+  // returned by MRI.getVRegDef) is in the original loop block.
+  // NOTE(review): getVRegDef() returns null for a vreg with no definition;
+  // this assumes every virtual register queried here has one.
+  bool isDefinedInThisLoop(Register Reg) const {
+    if (!Reg.isVirtual())
+      return false;
+    const MachineInstr *Def = MRI.getVRegDef(Reg);
+    return Def->getParent() == OrigMBB;
+  }
+
+  // Search for live-in registers: registers read in the loop block that are
+  // neither fixed nor defined in the block. Their pressure is added to
+  // InitSetPressure, i.e. it is accounted for from the beginning regardless
+  // of the schedule.
+  void computeLiveIn() {
+    DenseSet<Register> Used;
+    for (const auto &MI : *OrigMBB) {
+      // Debug instructions must not influence code generation, so their
+      // operands do not make a register live-in.
+      if (MI.isDebugInstr())
+        continue;
+      for (const auto &MO : MI.all_uses()) {
+        Register Use = MO.getReg();
+        // Operands that name no register (e.g. $noreg) have no pressure.
+        if (!Use.isValid())
+          continue;
+        // Ignore the register that appears only on the non-loop-carried side
+        // of a phi instruction because it is used only by the first
+        // iteration.
+        if (MI.isPHI() && Use != getLoopPhiReg(MI, OrigMBB))
+          continue;
+        if (isFixedRegister(Use) || isDefinedInThisLoop(Use))
+          continue;
+        Used.insert(Use);
+      }
+    }
+
+    for (Register LiveIn : Used)
+      increaseRegisterPressure(InitSetPressure, LiveIn);
+  }
+
+  // Calculate the upper limit of each pressure set.
+  void computePressureSetLimit(const RegisterClassInfo &RCI) {
+    for (unsigned PSet = 0; PSet < PSetNum; PSet++)
+      PressureSetLimit[PSet] = RCI.getRegPressureSetLimit(PSet);
+
+    // We assume fixed registers, such as the stack pointer, are always in use.
+    // Therefore the weight of the fixed registers is subtracted from the limit
+    // of each affected pressure set in advance.
+    SmallDenseSet<Register, 8> FixedRegs;
+    for (const TargetRegisterClass *TRC : TRI->regclasses())
+      for (const MCPhysReg Reg : *TRC)
+        if (isFixedRegister(Reg))
+          FixedRegs.insert(Reg); // insert() already ignores duplicates.
+
+    // For a physical register, MachineRegisterInfo::getPressureSets() expects
+    // a register *unit*, not the register number itself. Collect the units of
+    // all fixed registers; deduplicating by unit also keeps aliasing fixed
+    // registers (e.g. a register and one of its sub-registers) from being
+    // subtracted more than once.
+    SmallDenseSet<unsigned, 8> FixedRegUnits;
+    for (Register Reg : FixedRegs) {
+      LLVM_DEBUG(dbgs() << "fixed register: " << printReg(Reg, TRI, 0, &MRI)
+                        << "\n");
+      for (unsigned Unit : TRI->regunits(Reg.asMCReg()))
+        FixedRegUnits.insert(Unit);
+    }
+
+    for (unsigned Unit : FixedRegUnits) {
+      auto PSetIter = MRI.getPressureSets(Unit);
+      unsigned Weight = PSetIter.getWeight();
+      LLVM_DEBUG(dbgs() << "fixed register unit: " << printRegUnit(Unit, TRI)
+                        << "\n");
+      for (; PSetIter.isValid(); ++PSetIter) {
+        unsigned &Limit = PressureSetLimit[*PSetIter];
+        assert(Weight <= Limit &&
+               "register pressure limit must be greater or equal than weight");
+        Limit -= Weight;
+        LLVM_DEBUG(dbgs() << "PSet=" << *PSetIter << " ("
+                          << TRI->getRegPressureSetName(*PSetIter)
+                          << ") Limit=" << Limit << " (decreased by " << Weight
+                          << ")\n");
+      }
+    }
+  }
+
+  // A register's last use is one of two kinds:
+  //  - a use by an instruction of the current iteration, or
+  //  - a use by a phi instruction of the next iteration (loop-carried value).
+  //
+  // The following two groups of instructions execute simultaneously:
+  //  - the next iteration's phi instructions in the i-th stage
+  //  - the current iteration's instructions in the (i+1)-th stage
+  // so a phi in stage i is ranked like a non-phi in stage i+1.
+  //
+  // Returns, for each instruction, the registers whose last use it is.
+  // Only virtual registers defined in this loop and used by some instruction
+  // in OrderedInsts are tracked; live-in registers and registers that are
+  // defined but not used in the loop (potentially live-out) are ignored. A
+  // phi contributes only its loop-carried input to the tracked set.
+  Instr2LastUsesTy computeLastUses(const OrderedInstsTy &OrderedInsts,
+                                   Instr2StageTy &Stages) const {
+    DenseSet<Register> TargetRegs;
+    for (MachineInstr *MI : OrderedInsts) {
+      if (MI->isPHI()) {
+        Register Reg = getLoopPhiReg(*MI, OrigMBB);
+        if (isDefinedInThisLoop(Reg))
+          TargetRegs.insert(Reg);
+        continue;
+      }
+      // all_uses() yields register operands only.
+      for (const auto &MO : MI->all_uses()) {
+        Register Reg = MO.getReg();
+        if (isDefinedInThisLoop(Reg))
+          TargetRegs.insert(Reg);
+      }
+    }
+
+    // A phi ranks one stage later than the stage it is scheduled in.
+    const auto InstrScore = [&Stages](MachineInstr *MI) {
+      return Stages[MI] + MI->isPHI();
+    };
+
+    // Walk backwards: among equally ranked uses, the one latest in
+    // OrderedInsts is kept, and an earlier use replaces it only when it ranks
+    // strictly higher.
+    DenseMap<Register, MachineInstr *> LastUseMI;
+    for (MachineInstr *MI : llvm::reverse(OrderedInsts)) {
+      for (const auto &MO : MI->all_uses()) {
+        Register Reg = MO.getReg();
+        if (!TargetRegs.contains(Reg))
+          continue;
+        auto [It, Inserted] = LastUseMI.try_emplace(Reg, MI);
+        if (!Inserted && InstrScore(It->second) < InstrScore(MI))
+          It->second = MI;
+      }
+    }
+
+    Instr2LastUsesTy LastUses;
+    for (const auto &[Reg, UserMI] : LastUseMI)
+      LastUses[UserMI].insert(Reg);
+    return LastUses;
+  }
+
+  // Compute the maximum register pressure of the kernel. We simulate
+  // StageCount iterations and check the register pressure at the point where
+  // all stages overlap.
+  //
+  // An example of the unrolled loop where StageCount is 4:
+  //   Iter   i+0 i+1 i+2 i+3
+  //   ------------------------
+  //   Stage   0
+  //   Stage   1   0
+  //   Stage   2   1   0
+  //   Stage   3   2   1   0   <- All stages overlap
+  //
+  // An instruction of stage S is executed for iterations
+  // 0..StageCount-1-S. Registers it defines are added to that iteration's
+  // live set and removed at their last use (as computed by
+  // computeLastUses()).
+  //
+  // Returns the per-pressure-set maximum observed during the simulation;
+  // the live-in pressure in InitSetPressure is included.
+  std::vector<unsigned> exec(const OrderedInstsTy &OrderedInsts,
+                             Instr2StageTy &Stages,
+                             const unsigned StageCount) const {
+    using RegSetTy = SmallDenseSet<Register, 16>;
+
+    // Indexed by #Iter. To treat "local" variables of each stage separately,
+    // we manage the liveness of the registers independently per iteration.
+    SmallVector<RegSetTy> LiveRegSets(StageCount);
+
+    auto CurSetPressure = InitSetPressure;
+    auto MaxSetPressure = InitSetPressure;
+    auto LastUses = computeLastUses(OrderedInsts, Stages);
+
+    LLVM_DEBUG({
+      dbgs() << "Ordered instructions:\n";
+      for (MachineInstr *MI : OrderedInsts) {
+        dbgs() << "Stage " << Stages[MI] << ": ";
+        MI->dump();
+      }
+    });
+
+    const auto InsertReg = [this, &CurSetPressure](RegSetTy &RegSet,
+                                                    Register Reg) {
+      if (!Reg.isValid() || isFixedRegister(Reg))
+        return;
+
+      bool Inserted = RegSet.insert(Reg).second;
+      if (!Inserted)
+        return;
+
+      LLVM_DEBUG(dbgs() << "insert " << printReg(Reg, TRI, 0, &MRI) << "\n");
+      increaseRegisterPressure(CurSetPressure, Reg);
+      LLVM_DEBUG(dumpPSet(Reg));
+    };
+
+    const auto EraseReg = [this, &CurSetPressure](RegSetTy &RegSet,
+                                                   Register Reg) {
+      if (!Reg.isValid() || isFixedRegister(Reg))
+        return;
+
+      // Not defined by this iteration (e.g. a live-in register): nothing to
+      // release.
+      if (!RegSet.contains(Reg))
+        return;
+
+      LLVM_DEBUG(dbgs() << "erase " << printReg(Reg, TRI, 0, &MRI) << "\n");
+      RegSet.erase(Reg);
+      decreaseRegisterPressure(CurSetPressure, Reg);
+      LLVM_DEBUG(dumpPSet(Reg));
+    };
+
+    for (unsigned I = 0; I < StageCount; I++) {
+      for (MachineInstr *MI : OrderedInsts) {
+        const auto Stage = Stages[MI];
+        if (I < Stage)
+          continue;
+
+        const unsigned Iter = I - Stage;
+
+        for (auto &MO : MI->all_defs())
+          InsertReg(LiveRegSets[Iter], MO.getReg());
+
+        // Use find() rather than operator[], which would insert an empty
+        // entry for every instruction that is not a last use of anything.
+        auto LastUseIt = LastUses.find(MI);
+        if (LastUseIt != LastUses.end()) {
+          for (Register LastUse : LastUseIt->second) {
+            if (!MI->isPHI()) {
+              EraseReg(LiveRegSets[Iter], LastUse);
+            } else if (Iter != 0) {
+              // A phi reads the loop-carried value produced by the previous
+              // iteration; iteration 0 has no previous simulated iteration.
+              EraseReg(LiveRegSets[Iter - 1], LastUse);
+            }
+          }
+        }
+
+        for (unsigned PSet = 0; PSet < PSetNum; PSet++)
+          MaxSetPressure[PSet] =
+              std::max(MaxSetPressure[PSet], CurSetPressure[PSet]);
+
+        LLVM_DEBUG({
+          dbgs() << "CurSetPressure=";
+          dumpRegisterPressures(CurSetPressure);
+          dbgs() << " iter=" << Iter << " stage=" << Stage << ":";
+          MI->dump();
+        });
+      }
+    }
+
+    return MaxSetPressure;
+  }
+
+public:
+ // Default construction is disabled.
+ HighRegisterPressureDetector() = delete;
----------------
kasuga-fj wrote:
You're right. Thanks.
https://github.com/llvm/llvm-project/pull/74807
More information about the llvm-commits
mailing list