[llvm] [CodeGen][MachinePipeliner] Limit register pressure when scheduling (PR #74807)

Leandro Lupori via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 3 10:43:44 PST 2024


================
@@ -1102,6 +1121,359 @@ struct FuncUnitSorter {
   }
 };
 
+/// Calculate the maximum register pressure of the scheduled instructions stream
+class HighRegisterPressureDetector {
+  MachineBasicBlock *OrigMBB;
+  const MachineFunction &MF;
+  const MachineRegisterInfo &MRI;
+  const TargetRegisterInfo *TRI;
+
+  const unsigned PSetNum;
+
+  // Indexed by PSet ID
+  // InitSetPressure takes into account the register pressure of live-in
+  // registers. It does not depend on how the loop is scheduled, so it's enough
+  // to calculate it once at the beginning.
+  std::vector<unsigned> InitSetPressure;
+
+  // Indexed by PSet ID
+  // Upper limit for each register pressure set
+  std::vector<unsigned> PressureSetLimit;
+
+  using Instr2LastUsesTy = DenseMap<MachineInstr *, SmallDenseSet<Register, 4>>;
+
+public:
+  using OrderedInstsTy = std::vector<MachineInstr *>;
+  using Instr2StageTy = DenseMap<MachineInstr *, unsigned>;
+
+private:
+  /// Print \p Pressures to the debug stream as a space-separated list enclosed
+  /// in square brackets. An empty vector is printed as "[]".
+  static void dumpRegisterPressures(const std::vector<unsigned> &Pressures) {
+    if (Pressures.empty()) {
+      dbgs() << "[]";
+      return;
+    }
+
+    // The first value is preceded by the opening bracket, every later value by
+    // a single space.
+    char Separator = '[';
+    for (unsigned Value : Pressures) {
+      dbgs() << Separator << Value;
+      Separator = ' ';
+    }
+    dbgs() << ']';
+  }
+
+  /// Print \p Reg and every pressure set ID yielded by
+  /// MRI.getPressureSets(Reg) on a single line of the debug stream. Each ID is
+  /// followed by a space.
+  void dumpPSet(Register Reg) const {
+    dbgs() << "Reg=" << printReg(Reg, TRI, 0, &MRI) << " PSet=";
+    auto Sets = MRI.getPressureSets(Reg);
+    while (Sets.isValid()) {
+      dbgs() << *Sets << ' ';
+      ++Sets;
+    }
+    dbgs() << '\n';
+  }
+
+  /// Add the weight of \p Reg to the entry of \p Pressure for every pressure
+  /// set yielded by MRI.getPressureSets(Reg). \p Pressure is indexed by the
+  /// pressure set ID.
+  void increaseRegisterPressure(std::vector<unsigned> &Pressure,
+                                Register Reg) const {
+    auto Sets = MRI.getPressureSets(Reg);
+    // The weight is read once, before the iterator is advanced.
+    const unsigned RegWeight = Sets.getWeight();
+    while (Sets.isValid()) {
+      Pressure[*Sets] += RegWeight;
+      ++Sets;
+    }
+  }
+
+  /// Subtract the weight of \p Reg from the entry of \p Pressure for every
+  /// pressure set yielded by MRI.getPressureSets(Reg). \p Pressure is indexed
+  /// by the pressure set ID. Asserts that no entry would drop below zero.
+  void decreaseRegisterPressure(std::vector<unsigned> &Pressure,
+                                Register Reg) const {
+    auto Sets = MRI.getPressureSets(Reg);
+    const unsigned RegWeight = Sets.getWeight();
+    while (Sets.isValid()) {
+      unsigned &Current = Pressure[*Sets];
+      assert(Weight <= P &&
+             "register pressure must be greater or equal than weight");
+      Current -= RegWeight;
+      ++Sets;
+    }
+  }
+
+  // Return true if Reg is a fixed register, for example the stack pointer.
+  // Only physical registers are forwarded to the target query; anything else
+  // is reported as not fixed.
+  bool isFixedRegister(Register Reg) const {
+    if (!Reg.isPhysical())
+      return false;
+    return TRI->isFixedRegister(MF, Reg.asMCReg());
+  }
+
+  // Return true if Reg is a virtual register whose defining instruction, as
+  // reported by MRI.getVRegDef, belongs to OrigMBB.
+  bool isDefinedInThisLoop(Register Reg) const {
+    if (!Reg.isVirtual())
+      return false;
+    const MachineInstr *Def = MRI.getVRegDef(Reg);
+    return Def->getParent() == OrigMBB;
+  }
+
+  // Search for live-in variables. They are factored into the register pressure
+  // from the beginning.
+  //
+  // Every register used by an instruction of OrigMBB that is neither fixed nor
+  // defined in this loop is collected once, and its weight is then added to
+  // InitSetPressure.
+  void computeLiveIn() {
+    DenseSet<Register> Used;
+    for (const auto &MI : *OrigMBB) {
+      for (const auto &MO : MI.all_uses()) {
+        auto Use = MO.getReg();
+        // A use operand may hold no register at all. Such an operand must not
+        // be charged to any pressure set; exec() skips invalid registers in
+        // the same way.
+        if (!Use.isValid())
+          continue;
+        // Ignore the variable that appears only on one side of phi instruction
+        // because it's used only at the first iteration.
+        if (MI.isPHI() && Use != getLoopPhiReg(MI, OrigMBB))
+          continue;
+        if (isFixedRegister(Use))
+          continue;
+        if (isDefinedInThisLoop(Use))
+          continue;
+        Used.insert(Use);
+      }
+    }
+
+    // The set removes duplicates, so each live-in register contributes its
+    // weight exactly once.
+    for (auto LiveIn : Used)
+      increaseRegisterPressure(InitSetPressure, LiveIn);
+  }
+
+  // Calculate the upper limit of each pressure set.
+  //
+  // The limits are first taken from RCI, then reduced by the weight of every
+  // fixed register found in any register class of the target.
+  void computePressureSetLimit(const RegisterClassInfo &RCI) {
+    for (unsigned PSet = 0; PSet != PSetNum; ++PSet)
+      PressureSetLimit[PSet] = RCI.getRegPressureSetLimit(PSet);
+
+    // We assume fixed registers, such as stack pointer, are already in use.
+    // Therefore subtracting the weight of the fixed registers from the limit of
+    // each pressure set in advance.
+    // A register may appear in several classes; the set keeps one copy of it.
+    SmallDenseSet<Register, 8> FixedRegs;
+    for (const TargetRegisterClass *TRC : TRI->regclasses())
+      for (const MCPhysReg PhysReg : *TRC)
+        if (isFixedRegister(PhysReg))
+          FixedRegs.insert(PhysReg);
+
+    // NOTE(review): getRegUnitPressureSets is named for register units but is
+    // handed a register here -- confirm this is intended.
+    LLVM_DEBUG({
+      for (auto Fixed : FixedRegs) {
+        dbgs() << printReg(Fixed, TRI, 0, &MRI) << ": [";
+        // The list of pressure set IDs is terminated by -1.
+        for (const int *SetID = TRI->getRegUnitPressureSets(Fixed);
+             *SetID != -1; SetID++)
+          dbgs() << TRI->getRegPressureSetName(*SetID) << ", ";
+        dbgs() << "]\n";
+      }
+    });
+
+    for (auto Fixed : FixedRegs) {
+      LLVM_DEBUG(dbgs() << "fixed register: " << printReg(Fixed, TRI, 0, &MRI)
+                        << "\n");
+      auto Sets = MRI.getPressureSets(Fixed);
+      const unsigned Weight = Sets.getWeight();
+      while (Sets.isValid()) {
+        unsigned &Limit = PressureSetLimit[*Sets];
+        assert(Weight <= Limit &&
+               "register pressure limit must be greater or equal than weight");
+        Limit -= Weight;
+        LLVM_DEBUG(dbgs() << "PSet=" << *Sets << " Limit=" << Limit
+                          << " (decreased by " << Weight << ")\n");
+        ++Sets;
+      }
+    }
+  }
+
+  // There are two patterns of last-use.
+  //   - by an instruction of the current iteration
+  //   - by a phi instruction of the next iteration (loop carried value)
+  //
+  // Furthermore, the following two groups of instructions are executed
+  // simultaneously
+  //   - next iteration's phi instructions in i-th stage
+  //   - current iteration's instructions in i+1-th stage
+  //
+  // This function calculates the last-use of each register while taking into
+  // account the above two patterns. The result maps an instruction to the set
+  // of registers whose last use it is.
+  Instr2LastUsesTy computeLastUses(const OrderedInstsTy &OrderedInsts,
+                                   Instr2StageTy &Stages) const {
+    // We treat virtual registers that are defined and used in this loop.
+    // The following virtual registers are ignored
+    //   - live-in ones
+    //   - ones defined but not used in the loop (potentially live-out)
+    DenseSet<Register> TargetRegs;
+    const auto AddIfLoopDefined = [this, &TargetRegs](Register Reg) {
+      if (isDefinedInThisLoop(Reg))
+        TargetRegs.insert(Reg);
+    };
+    for (MachineInstr *MI : OrderedInsts) {
+      // For a phi only the loop-carried operand is considered.
+      if (MI->isPHI()) {
+        AddIfLoopDefined(getLoopPhiReg(*MI, OrigMBB));
+        continue;
+      }
+      for (const auto &MO : MI->all_uses()) {
+        if (!MO.isReg())
+          continue;
+        AddIfLoopDefined(MO.getReg());
+      }
+    }
+
+    // A phi is ranked one stage later than its own stage, matching the
+    // simultaneous execution described above.
+    const auto RankOf = [&Stages](MachineInstr *MI) {
+      return Stages[MI] + MI->isPHI();
+    };
+
+    // Walk the instructions backwards: the first user seen is the latest one
+    // in program order, and it is only displaced by a user with a strictly
+    // higher rank.
+    DenseMap<Register, MachineInstr *> LastUseMI;
+    for (MachineInstr *MI : llvm::reverse(OrderedInsts)) {
+      for (const auto &MO : MI->all_uses()) {
+        if (!MO.isReg())
+          continue;
+        const Register Reg = MO.getReg();
+        if (!TargetRegs.contains(Reg))
+          continue;
+        auto Slot = LastUseMI.try_emplace(Reg, MI);
+        const bool IsNew = Slot.second;
+        if (!IsNew && RankOf(Slot.first->second) < RankOf(MI))
+          Slot.first->second = MI;
+      }
+    }
+
+    // Invert the register -> instruction map into instruction -> registers.
+    Instr2LastUsesTy LastUses;
+    for (const auto &RegAndUser : LastUseMI)
+      LastUses[RegAndUser.second].insert(RegAndUser.first);
+    return LastUses;
+  }
+
+  // Compute the maximum register pressure of the kernel. We'll simulate #Stage
+  // iterations and check the register pressure at the point where all stages
+  // overlap.
+  //
+  // An example of unrolled loop where #Stage is 4..
+  // Iter   i+0 i+1 i+2 i+3
+  // ------------------------
+  // Stage   0
+  // Stage   1   0
+  // Stage   2   1   0
+  // Stage   3   2   1   0  <- All stages overlap
+  //
+  // OrderedInsts is walked once per simulated step, Stages maps each
+  // instruction to its stage and StageCount is the number of simulated steps.
+  // Returns, for each pressure set ID, the highest pressure observed during
+  // the simulation; the starting point is InitSetPressure.
+  std::vector<unsigned> exec(const OrderedInstsTy &OrderedInsts,
+                             Instr2StageTy &Stages,
+                             const unsigned StageCount) const {
+    using RegSetTy = SmallDenseSet<Register, 16>;
+
+    // Indexed by #Iter. To treat "local" variables of each stage separately, we
+    // manage the liveness of the registers independently by iterations.
+    SmallVector<RegSetTy> LiveRegSets(StageCount);
+
+    auto CurSetPressure = InitSetPressure;
+    auto MaxSetPressure = InitSetPressure;
+    // Initialize directly from the returned temporary; wrapping it in
+    // std::move would only prevent copy elision.
+    auto LastUses = computeLastUses(OrderedInsts, Stages);
+
+    LLVM_DEBUG({
+      dbgs() << "Ordered instructions:\n";
+      for (MachineInstr *MI : OrderedInsts) {
+        dbgs() << "Stage " << Stages[MI] << ": ";
+        MI->dump();
+      }
+    });
+
+    // Mark Reg live in RegSet and raise the current pressure. Invalid and
+    // fixed registers are ignored, as are registers that are already live.
+    const auto InsertReg = [this, &CurSetPressure](RegSetTy &RegSet,
+                                                   Register Reg) {
+      if (!Reg.isValid() || isFixedRegister(Reg))
+        return;
+
+      bool Inserted = RegSet.insert(Reg).second;
+      if (!Inserted)
+        return;
+
+      LLVM_DEBUG(dbgs() << "insert " << printReg(Reg, TRI, 0, &MRI) << "\n");
+      increaseRegisterPressure(CurSetPressure, Reg);
+      LLVM_DEBUG(dumpPSet(Reg));
+    };
+
+    // Remove Reg from RegSet and lower the current pressure. Invalid and fixed
+    // registers are ignored, as are registers that are not in RegSet.
+    const auto EraseReg = [this, &CurSetPressure](RegSetTy &RegSet,
+                                                  Register Reg) {
+      if (!Reg.isValid() || isFixedRegister(Reg))
+        return;
+
+      // live-in register
+      if (!RegSet.contains(Reg))
+        return;
+
+      LLVM_DEBUG(dbgs() << "erase " << printReg(Reg, TRI, 0, &MRI) << "\n");
+      RegSet.erase(Reg);
+      decreaseRegisterPressure(CurSetPressure, Reg);
+      LLVM_DEBUG(dumpPSet(Reg));
+    };
+
+    for (unsigned I = 0; I < StageCount; I++) {
+      for (MachineInstr *MI : OrderedInsts) {
+        const auto Stage = Stages[MI];
+        // At step I no iteration has reached this instruction's stage yet.
+        if (I < Stage)
+          continue;
+
+        // The iteration that executes this stage at step I.
+        const unsigned Iter = I - Stage;
+
+        for (auto &MO : MI->all_defs())
+          InsertReg(LiveRegSets[Iter], MO.getReg());
+
+        for (auto LastUse : LastUses[MI]) {
+          if (MI->isPHI()) {
+            // A phi ends the live range of a value that belongs to the
+            // previous iteration; iteration 0 has no predecessor.
+            if (Iter != 0)
+              EraseReg(LiveRegSets[Iter - 1], LastUse);
+          } else {
+            EraseReg(LiveRegSets[Iter], LastUse);
+          }
+        }
+
+        // Record the running maximum after every simulated instruction.
+        for (unsigned PSet = 0; PSet < PSetNum; PSet++)
+          MaxSetPressure[PSet] =
+              std::max(MaxSetPressure[PSet], CurSetPressure[PSet]);
+
+        LLVM_DEBUG({
+          dbgs() << "CurSetPressure=";
+          dumpRegisterPressures(CurSetPressure);
+          dbgs() << " iter=" << Iter << " stage=" << Stage << ":";
+          MI->dump();
+        });
+      }
+    }
+
+    return MaxSetPressure;
+  }
+
+public:
+  HighRegisterPressureDetector() = delete;
----------------
luporl wrote:

IIRC the default constructor is not implicitly declared once another constructor is declared, so deleting it explicitly is redundant.

https://github.com/llvm/llvm-project/pull/74807


More information about the llvm-commits mailing list