[llvm] ef9a02c - [CodeGen] Use VirtRegOrUnit where appropriate (NFCI) (#167730)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 13 02:27:02 PST 2025
Author: Sergei Barannikov
Date: 2025-11-13T10:26:58Z
New Revision: ef9a02ce028782684f9a43dcda756804635ba86a
URL: https://github.com/llvm/llvm-project/commit/ef9a02ce028782684f9a43dcda756804635ba86a
DIFF: https://github.com/llvm/llvm-project/commit/ef9a02ce028782684f9a43dcda756804635ba86a.diff
LOG: [CodeGen] Use VirtRegOrUnit where appropriate (NFCI) (#167730)
Use it in `printVRegOrUnit()`, `getPressureSets()`/`PSetIterator`,
and in functions/classes dealing with register pressure.
Static type checking revealed several bugs, mainly in MachinePipeliner.
I'm not very familiar with this pass, so I left a bunch of FIXMEs.
There is one bug in `findUseBetween()` in RegisterPressure.cpp, also
annotated with a FIXME.
Added:
Modified:
llvm/include/llvm/CodeGen/MachineRegisterInfo.h
llvm/include/llvm/CodeGen/Register.h
llvm/include/llvm/CodeGen/RegisterPressure.h
llvm/include/llvm/CodeGen/TargetRegisterInfo.h
llvm/lib/CodeGen/MachinePipeliner.cpp
llvm/lib/CodeGen/MachineScheduler.cpp
llvm/lib/CodeGen/RegisterPressure.cpp
llvm/lib/CodeGen/TargetRegisterInfo.cpp
llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
llvm/lib/Target/AMDGPU/SIMachineScheduler.h
llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/MachineRegisterInfo.h b/llvm/include/llvm/CodeGen/MachineRegisterInfo.h
index 6982dae4718d1..737b74ef3f761 100644
--- a/llvm/include/llvm/CodeGen/MachineRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineRegisterInfo.h
@@ -634,10 +634,9 @@ class MachineRegisterInfo {
/// function. Writing to a constant register has no effect.
LLVM_ABI bool isConstantPhysReg(MCRegister PhysReg) const;
- /// Get an iterator over the pressure sets affected by the given physical or
- /// virtual register. If RegUnit is physical, it must be a register unit (from
- /// MCRegUnitIterator).
- PSetIterator getPressureSets(Register RegUnit) const;
+ /// Get an iterator over the pressure sets affected by the virtual register
+ /// or register unit.
+ PSetIterator getPressureSets(VirtRegOrUnit VRegOrUnit) const;
//===--------------------------------------------------------------------===//
// Virtual Register Info
@@ -1249,15 +1248,16 @@ class PSetIterator {
public:
PSetIterator() = default;
- PSetIterator(Register RegUnit, const MachineRegisterInfo *MRI) {
+ PSetIterator(VirtRegOrUnit VRegOrUnit, const MachineRegisterInfo *MRI) {
const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
- if (RegUnit.isVirtual()) {
- const TargetRegisterClass *RC = MRI->getRegClass(RegUnit);
+ if (VRegOrUnit.isVirtualReg()) {
+ const TargetRegisterClass *RC =
+ MRI->getRegClass(VRegOrUnit.asVirtualReg());
PSet = TRI->getRegClassPressureSets(RC);
Weight = TRI->getRegClassWeight(RC).RegWeight;
} else {
- PSet = TRI->getRegUnitPressureSets(RegUnit);
- Weight = TRI->getRegUnitWeight(RegUnit);
+ PSet = TRI->getRegUnitPressureSets(VRegOrUnit.asMCRegUnit());
+ Weight = TRI->getRegUnitWeight(VRegOrUnit.asMCRegUnit());
}
if (*PSet == -1)
PSet = nullptr;
@@ -1278,8 +1278,8 @@ class PSetIterator {
};
inline PSetIterator
-MachineRegisterInfo::getPressureSets(Register RegUnit) const {
- return PSetIterator(RegUnit, this);
+MachineRegisterInfo::getPressureSets(VirtRegOrUnit VRegOrUnit) const {
+ return PSetIterator(VRegOrUnit, this);
}
} // end namespace llvm
diff --git a/llvm/include/llvm/CodeGen/Register.h b/llvm/include/llvm/CodeGen/Register.h
index 790db8a11e390..5e1e12942a019 100644
--- a/llvm/include/llvm/CodeGen/Register.h
+++ b/llvm/include/llvm/CodeGen/Register.h
@@ -206,6 +206,10 @@ class VirtRegOrUnit {
constexpr bool operator==(const VirtRegOrUnit &Other) const {
return VRegOrUnit == Other.VRegOrUnit;
}
+
+ constexpr bool operator<(const VirtRegOrUnit &Other) const {
+ return VRegOrUnit < Other.VRegOrUnit;
+ }
};
} // namespace llvm
diff --git a/llvm/include/llvm/CodeGen/RegisterPressure.h b/llvm/include/llvm/CodeGen/RegisterPressure.h
index 261e5b0d73281..20a7e4fa2e9de 100644
--- a/llvm/include/llvm/CodeGen/RegisterPressure.h
+++ b/llvm/include/llvm/CodeGen/RegisterPressure.h
@@ -37,11 +37,11 @@ class MachineRegisterInfo;
class RegisterClassInfo;
struct VRegMaskOrUnit {
- Register RegUnit; ///< Virtual register or register unit.
+ VirtRegOrUnit VRegOrUnit;
LaneBitmask LaneMask;
- VRegMaskOrUnit(Register RegUnit, LaneBitmask LaneMask)
- : RegUnit(RegUnit), LaneMask(LaneMask) {}
+ VRegMaskOrUnit(VirtRegOrUnit VRegOrUnit, LaneBitmask LaneMask)
+ : VRegOrUnit(VRegOrUnit), LaneMask(LaneMask) {}
};
/// Base class for register pressure results.
@@ -157,7 +157,7 @@ class PressureDiff {
const_iterator begin() const { return &PressureChanges[0]; }
const_iterator end() const { return &PressureChanges[MaxPSets]; }
- LLVM_ABI void addPressureChange(Register RegUnit, bool IsDec,
+ LLVM_ABI void addPressureChange(VirtRegOrUnit VRegOrUnit, bool IsDec,
const MachineRegisterInfo *MRI);
LLVM_ABI void dump(const TargetRegisterInfo &TRI) const;
@@ -279,25 +279,25 @@ class LiveRegSet {
RegSet Regs;
unsigned NumRegUnits = 0u;
- unsigned getSparseIndexFromReg(Register Reg) const {
- if (Reg.isVirtual())
- return Reg.virtRegIndex() + NumRegUnits;
- assert(Reg < NumRegUnits);
- return Reg.id();
+ unsigned getSparseIndexFromVirtRegOrUnit(VirtRegOrUnit VRegOrUnit) const {
+ if (VRegOrUnit.isVirtualReg())
+ return VRegOrUnit.asVirtualReg().virtRegIndex() + NumRegUnits;
+ assert(VRegOrUnit.asMCRegUnit() < NumRegUnits);
+ return VRegOrUnit.asMCRegUnit();
}
- Register getRegFromSparseIndex(unsigned SparseIndex) const {
+ VirtRegOrUnit getVirtRegOrUnitFromSparseIndex(unsigned SparseIndex) const {
if (SparseIndex >= NumRegUnits)
- return Register::index2VirtReg(SparseIndex - NumRegUnits);
- return Register(SparseIndex);
+ return VirtRegOrUnit(Register::index2VirtReg(SparseIndex - NumRegUnits));
+ return VirtRegOrUnit(SparseIndex);
}
public:
LLVM_ABI void clear();
LLVM_ABI void init(const MachineRegisterInfo &MRI);
- LaneBitmask contains(Register Reg) const {
- unsigned SparseIndex = getSparseIndexFromReg(Reg);
+ LaneBitmask contains(VirtRegOrUnit VRegOrUnit) const {
+ unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(VRegOrUnit);
RegSet::const_iterator I = Regs.find(SparseIndex);
if (I == Regs.end())
return LaneBitmask::getNone();
@@ -307,7 +307,7 @@ class LiveRegSet {
/// Mark the \p Pair.LaneMask lanes of \p Pair.Reg as live.
/// Returns the previously live lanes of \p Pair.Reg.
LaneBitmask insert(VRegMaskOrUnit Pair) {
- unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
+ unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
auto InsertRes = Regs.insert(IndexMaskPair(SparseIndex, Pair.LaneMask));
if (!InsertRes.second) {
LaneBitmask PrevMask = InsertRes.first->LaneMask;
@@ -320,7 +320,7 @@ class LiveRegSet {
/// Clears the \p Pair.LaneMask lanes of \p Pair.Reg (mark them as dead).
/// Returns the previously live lanes of \p Pair.Reg.
LaneBitmask erase(VRegMaskOrUnit Pair) {
- unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
+ unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
RegSet::iterator I = Regs.find(SparseIndex);
if (I == Regs.end())
return LaneBitmask::getNone();
@@ -335,9 +335,9 @@ class LiveRegSet {
void appendTo(SmallVectorImpl<VRegMaskOrUnit> &To) const {
for (const IndexMaskPair &P : Regs) {
- Register Reg = getRegFromSparseIndex(P.Index);
+ VirtRegOrUnit VRegOrUnit = getVirtRegOrUnitFromSparseIndex(P.Index);
if (P.LaneMask.any())
- To.emplace_back(Reg, P.LaneMask);
+ To.emplace_back(VRegOrUnit, P.LaneMask);
}
}
};
@@ -541,9 +541,11 @@ class RegPressureTracker {
LLVM_ABI void dump() const;
- LLVM_ABI void increaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
+ LLVM_ABI void increaseRegPressure(VirtRegOrUnit VRegOrUnit,
+ LaneBitmask PreviousMask,
LaneBitmask NewMask);
- LLVM_ABI void decreaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
+ LLVM_ABI void decreaseRegPressure(VirtRegOrUnit VRegOrUnit,
+ LaneBitmask PreviousMask,
LaneBitmask NewMask);
protected:
@@ -565,9 +567,12 @@ class RegPressureTracker {
discoverLiveInOrOut(VRegMaskOrUnit Pair,
SmallVectorImpl<VRegMaskOrUnit> &LiveInOrOut);
- LLVM_ABI LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const;
- LLVM_ABI LaneBitmask getLiveLanesAt(Register RegUnit, SlotIndex Pos) const;
- LLVM_ABI LaneBitmask getLiveThroughAt(Register RegUnit, SlotIndex Pos) const;
+ LLVM_ABI LaneBitmask getLastUsedLanes(VirtRegOrUnit VRegOrUnit,
+ SlotIndex Pos) const;
+ LLVM_ABI LaneBitmask getLiveLanesAt(VirtRegOrUnit VRegOrUnit,
+ SlotIndex Pos) const;
+ LLVM_ABI LaneBitmask getLiveThroughAt(VirtRegOrUnit VRegOrUnit,
+ SlotIndex Pos) const;
};
LLVM_ABI void dumpRegSetPressure(ArrayRef<unsigned> SetPressure,
diff --git a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
index dabf0dc5ec173..35b14e8b8fd30 100644
--- a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
@@ -1450,7 +1450,7 @@ LLVM_ABI Printable printRegUnit(MCRegUnit Unit, const TargetRegisterInfo *TRI);
/// Create Printable object to print virtual registers and physical
/// registers on a \ref raw_ostream.
-LLVM_ABI Printable printVRegOrUnit(unsigned VRegOrUnit,
+LLVM_ABI Printable printVRegOrUnit(VirtRegOrUnit VRegOrUnit,
const TargetRegisterInfo *TRI);
/// Create Printable object to print register classes or register banks
diff --git a/llvm/lib/CodeGen/MachinePipeliner.cpp b/llvm/lib/CodeGen/MachinePipeliner.cpp
index a717d9e4a618d..e2f7dfc5cadd5 100644
--- a/llvm/lib/CodeGen/MachinePipeliner.cpp
+++ b/llvm/lib/CodeGen/MachinePipeliner.cpp
@@ -1509,7 +1509,11 @@ class HighRegisterPressureDetector {
void dumpPSet(Register Reg) const {
dbgs() << "Reg=" << printReg(Reg, TRI, 0, &MRI) << " PSet=";
- for (auto PSetIter = MRI.getPressureSets(Reg); PSetIter.isValid();
+ // FIXME: The static_cast is a bug compensating bugs in the callers.
+ VirtRegOrUnit VRegOrUnit =
+ Reg.isVirtual() ? VirtRegOrUnit(Reg)
+ : VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
+ for (auto PSetIter = MRI.getPressureSets(VRegOrUnit); PSetIter.isValid();
++PSetIter) {
dbgs() << *PSetIter << ' ';
}
@@ -1518,7 +1522,11 @@ class HighRegisterPressureDetector {
void increaseRegisterPressure(std::vector<unsigned> &Pressure,
Register Reg) const {
- auto PSetIter = MRI.getPressureSets(Reg);
+ // FIXME: The static_cast is a bug compensating bugs in the callers.
+ VirtRegOrUnit VRegOrUnit =
+ Reg.isVirtual() ? VirtRegOrUnit(Reg)
+ : VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
+ auto PSetIter = MRI.getPressureSets(VRegOrUnit);
unsigned Weight = PSetIter.getWeight();
for (; PSetIter.isValid(); ++PSetIter)
Pressure[*PSetIter] += Weight;
@@ -1526,7 +1534,7 @@ class HighRegisterPressureDetector {
void decreaseRegisterPressure(std::vector<unsigned> &Pressure,
Register Reg) const {
- auto PSetIter = MRI.getPressureSets(Reg);
+ auto PSetIter = MRI.getPressureSets(VirtRegOrUnit(Reg));
unsigned Weight = PSetIter.getWeight();
for (; PSetIter.isValid(); ++PSetIter) {
auto &P = Pressure[*PSetIter];
@@ -1559,7 +1567,11 @@ class HighRegisterPressureDetector {
if (MI.isDebugInstr())
continue;
for (auto &Use : ROMap[&MI].Uses) {
- auto Reg = Use.RegUnit;
+ // FIXME: The static_cast is a bug.
+ Register Reg =
+ Use.VRegOrUnit.isVirtualReg()
+ ? Use.VRegOrUnit.asVirtualReg()
+ : Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
// Ignore the variable that appears only on one side of phi instruction
// because it's used only at the first iteration.
if (MI.isPHI() && Reg != getLoopPhiReg(MI, OrigMBB))
@@ -1609,8 +1621,14 @@ class HighRegisterPressureDetector {
Register Reg = getLoopPhiReg(*MI, OrigMBB);
UpdateTargetRegs(Reg);
} else {
- for (auto &Use : ROMap.find(MI)->getSecond().Uses)
- UpdateTargetRegs(Use.RegUnit);
+ for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
+ // FIXME: The static_cast is a bug.
+ Register Reg = Use.VRegOrUnit.isVirtualReg()
+ ? Use.VRegOrUnit.asVirtualReg()
+ : Register(static_cast<unsigned>(
+ Use.VRegOrUnit.asMCRegUnit()));
+ UpdateTargetRegs(Reg);
+ }
}
}
@@ -1621,7 +1639,11 @@ class HighRegisterPressureDetector {
DenseMap<Register, MachineInstr *> LastUseMI;
for (MachineInstr *MI : llvm::reverse(OrderedInsts)) {
for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
- auto Reg = Use.RegUnit;
+ // FIXME: The static_cast is a bug.
+ Register Reg =
+ Use.VRegOrUnit.isVirtualReg()
+ ? Use.VRegOrUnit.asVirtualReg()
+ : Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
if (!TargetRegs.contains(Reg))
continue;
auto [Ite, Inserted] = LastUseMI.try_emplace(Reg, MI);
@@ -1635,8 +1657,8 @@ class HighRegisterPressureDetector {
}
Instr2LastUsesTy LastUses;
- for (auto &Entry : LastUseMI)
- LastUses[Entry.second].insert(Entry.first);
+ for (auto [Reg, MI] : LastUseMI)
+ LastUses[MI].insert(Reg);
return LastUses;
}
@@ -1675,7 +1697,12 @@ class HighRegisterPressureDetector {
});
const auto InsertReg = [this, &CurSetPressure](RegSetTy &RegSet,
- Register Reg) {
+ VirtRegOrUnit VRegOrUnit) {
+ // FIXME: The static_cast is a bug.
+ Register Reg =
+ VRegOrUnit.isVirtualReg()
+ ? VRegOrUnit.asVirtualReg()
+ : Register(static_cast<unsigned>(VRegOrUnit.asMCRegUnit()));
if (!Reg.isValid() || isReservedRegister(Reg))
return;
@@ -1712,7 +1739,7 @@ class HighRegisterPressureDetector {
const unsigned Iter = I - Stage;
for (auto &Def : ROMap.find(MI)->getSecond().Defs)
- InsertReg(LiveRegSets[Iter], Def.RegUnit);
+ InsertReg(LiveRegSets[Iter], Def.VRegOrUnit);
for (auto LastUse : LastUses[MI]) {
if (MI->isPHI()) {
@@ -2235,7 +2262,7 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<VRegMaskOrUnit, 8> LiveOutRegs;
- SmallSet<Register, 4> Uses;
+ SmallSet<VirtRegOrUnit, 4> Uses;
for (SUnit *SU : NS) {
const MachineInstr *MI = SU->getInstr();
if (MI->isPHI())
@@ -2243,9 +2270,10 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
for (const MachineOperand &MO : MI->all_uses()) {
Register Reg = MO.getReg();
if (Reg.isVirtual())
- Uses.insert(Reg);
+ Uses.insert(VirtRegOrUnit(Reg));
else if (MRI.isAllocatable(Reg))
- Uses.insert_range(TRI->regunits(Reg.asMCReg()));
+ for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
+ Uses.insert(VirtRegOrUnit(Unit));
}
}
for (SUnit *SU : NS)
@@ -2253,12 +2281,14 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
if (!MO.isDead()) {
Register Reg = MO.getReg();
if (Reg.isVirtual()) {
- if (!Uses.count(Reg))
- LiveOutRegs.emplace_back(Reg, LaneBitmask::getNone());
+ if (!Uses.count(VirtRegOrUnit(Reg)))
+ LiveOutRegs.emplace_back(VirtRegOrUnit(Reg),
+ LaneBitmask::getNone());
} else if (MRI.isAllocatable(Reg)) {
for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
- if (!Uses.count(Unit))
- LiveOutRegs.emplace_back(Unit, LaneBitmask::getNone());
+ if (!Uses.count(VirtRegOrUnit(Unit)))
+ LiveOutRegs.emplace_back(VirtRegOrUnit(Unit),
+ LaneBitmask::getNone());
}
}
RPTracker.addLiveRegs(LiveOutRegs);
diff --git a/llvm/lib/CodeGen/MachineScheduler.cpp b/llvm/lib/CodeGen/MachineScheduler.cpp
index 73993705c4a7b..de29a9fab876e 100644
--- a/llvm/lib/CodeGen/MachineScheduler.cpp
+++ b/llvm/lib/CodeGen/MachineScheduler.cpp
@@ -1580,10 +1580,10 @@ updateScheduledPressure(const SUnit *SU,
/// instruction.
void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
for (const VRegMaskOrUnit &P : LiveUses) {
- Register Reg = P.RegUnit;
/// FIXME: Currently assuming single-use physregs.
- if (!Reg.isVirtual())
+ if (!P.VRegOrUnit.isVirtualReg())
continue;
+ Register Reg = P.VRegOrUnit.asVirtualReg();
if (ShouldTrackLaneMasks) {
// If the register has just become live then other uses won't change
@@ -1599,7 +1599,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
continue;
PressureDiff &PDiff = getPressureDiff(&SU);
- PDiff.addPressureChange(Reg, Decrement, &MRI);
+ PDiff.addPressureChange(VirtRegOrUnit(Reg), Decrement, &MRI);
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
return Change.isValid();
}))
@@ -1611,7 +1611,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
}
} else {
assert(P.LaneMask.any());
- LLVM_DEBUG(dbgs() << " LiveReg: " << printVRegOrUnit(Reg, TRI) << "\n");
+ LLVM_DEBUG(dbgs() << " LiveReg: " << printReg(Reg, TRI) << "\n");
// This may be called before CurrentBottom has been initialized. However,
// BotRPTracker must have a valid position. We want the value live into the
// instruction or live out of the block, so ask for the previous
@@ -1638,7 +1638,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
LI.Query(LIS->getInstructionIndex(*SU->getInstr()));
if (LRQ.valueIn() == VNI) {
PressureDiff &PDiff = getPressureDiff(SU);
- PDiff.addPressureChange(Reg, true, &MRI);
+ PDiff.addPressureChange(VirtRegOrUnit(Reg), true, &MRI);
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
return Change.isValid();
}))
@@ -1814,9 +1814,9 @@ unsigned ScheduleDAGMILive::computeCyclicCriticalPath() {
unsigned MaxCyclicLatency = 0;
// Visit each live out vreg def to find def/use pairs that cross iterations.
for (const VRegMaskOrUnit &P : RPTracker.getPressure().LiveOutRegs) {
- Register Reg = P.RegUnit;
- if (!Reg.isVirtual())
+ if (!P.VRegOrUnit.isVirtualReg())
continue;
+ Register Reg = P.VRegOrUnit.asVirtualReg();
const LiveInterval &LI = LIS->getInterval(Reg);
const VNInfo *DefVNI = LI.getVNInfoBefore(LIS->getMBBEndIdx(BB));
if (!DefVNI)
diff --git a/llvm/lib/CodeGen/RegisterPressure.cpp b/llvm/lib/CodeGen/RegisterPressure.cpp
index 7d4674b3f74f0..cd431bc7a171c 100644
--- a/llvm/lib/CodeGen/RegisterPressure.cpp
+++ b/llvm/lib/CodeGen/RegisterPressure.cpp
@@ -47,13 +47,14 @@ using namespace llvm;
/// Increase pressure for each pressure set provided by TargetRegisterInfo.
static void increaseSetPressure(std::vector<unsigned> &CurrSetPressure,
- const MachineRegisterInfo &MRI, unsigned Reg,
- LaneBitmask PrevMask, LaneBitmask NewMask) {
+ const MachineRegisterInfo &MRI,
+ VirtRegOrUnit VRegOrUnit, LaneBitmask PrevMask,
+ LaneBitmask NewMask) {
assert((PrevMask & ~NewMask).none() && "Must not remove bits");
if (PrevMask.any() || NewMask.none())
return;
- PSetIterator PSetI = MRI.getPressureSets(Reg);
+ PSetIterator PSetI = MRI.getPressureSets(VRegOrUnit);
unsigned Weight = PSetI.getWeight();
for (; PSetI.isValid(); ++PSetI)
CurrSetPressure[*PSetI] += Weight;
@@ -61,13 +62,14 @@ static void increaseSetPressure(std::vector<unsigned> &CurrSetPressure,
/// Decrease pressure for each pressure set provided by TargetRegisterInfo.
static void decreaseSetPressure(std::vector<unsigned> &CurrSetPressure,
- const MachineRegisterInfo &MRI, Register Reg,
- LaneBitmask PrevMask, LaneBitmask NewMask) {
+ const MachineRegisterInfo &MRI,
+ VirtRegOrUnit VRegOrUnit, LaneBitmask PrevMask,
+ LaneBitmask NewMask) {
assert((NewMask & ~PrevMask).none() && "Must not add bits");
if (NewMask.any() || PrevMask.none())
return;
- PSetIterator PSetI = MRI.getPressureSets(Reg);
+ PSetIterator PSetI = MRI.getPressureSets(VRegOrUnit);
unsigned Weight = PSetI.getWeight();
for (; PSetI.isValid(); ++PSetI) {
assert(CurrSetPressure[*PSetI] >= Weight && "register pressure underflow");
@@ -93,7 +95,7 @@ void RegisterPressure::dump(const TargetRegisterInfo *TRI) const {
dumpRegSetPressure(MaxSetPressure, TRI);
dbgs() << "Live In: ";
for (const VRegMaskOrUnit &P : LiveInRegs) {
- dbgs() << printVRegOrUnit(P.RegUnit, TRI);
+ dbgs() << printVRegOrUnit(P.VRegOrUnit, TRI);
if (!P.LaneMask.all())
dbgs() << ':' << PrintLaneMask(P.LaneMask);
dbgs() << ' ';
@@ -101,7 +103,7 @@ void RegisterPressure::dump(const TargetRegisterInfo *TRI) const {
dbgs() << '\n';
dbgs() << "Live Out: ";
for (const VRegMaskOrUnit &P : LiveOutRegs) {
- dbgs() << printVRegOrUnit(P.RegUnit, TRI);
+ dbgs() << printVRegOrUnit(P.VRegOrUnit, TRI);
if (!P.LaneMask.all())
dbgs() << ':' << PrintLaneMask(P.LaneMask);
dbgs() << ' ';
@@ -148,13 +150,13 @@ void RegPressureDelta::dump() const {
#endif
-void RegPressureTracker::increaseRegPressure(Register RegUnit,
+void RegPressureTracker::increaseRegPressure(VirtRegOrUnit VRegOrUnit,
LaneBitmask PreviousMask,
LaneBitmask NewMask) {
if (PreviousMask.any() || NewMask.none())
return;
- PSetIterator PSetI = MRI->getPressureSets(RegUnit);
+ PSetIterator PSetI = MRI->getPressureSets(VRegOrUnit);
unsigned Weight = PSetI.getWeight();
for (; PSetI.isValid(); ++PSetI) {
CurrSetPressure[*PSetI] += Weight;
@@ -163,10 +165,10 @@ void RegPressureTracker::increaseRegPressure(Register RegUnit,
}
}
-void RegPressureTracker::decreaseRegPressure(Register RegUnit,
+void RegPressureTracker::decreaseRegPressure(VirtRegOrUnit VRegOrUnit,
LaneBitmask PreviousMask,
LaneBitmask NewMask) {
- decreaseSetPressure(CurrSetPressure, *MRI, RegUnit, PreviousMask, NewMask);
+ decreaseSetPressure(CurrSetPressure, *MRI, VRegOrUnit, PreviousMask, NewMask);
}
/// Clear the result so it can be used for another round of pressure tracking.
@@ -230,10 +232,11 @@ void LiveRegSet::clear() {
Regs.clear();
}
-static const LiveRange *getLiveRange(const LiveIntervals &LIS, unsigned Reg) {
- if (Register::isVirtualRegister(Reg))
- return &LIS.getInterval(Reg);
- return LIS.getCachedRegUnit(Reg);
+static const LiveRange *getLiveRange(const LiveIntervals &LIS,
+ VirtRegOrUnit VRegOrUnit) {
+ if (VRegOrUnit.isVirtualReg())
+ return &LIS.getInterval(VRegOrUnit.asVirtualReg());
+ return LIS.getCachedRegUnit(VRegOrUnit.asMCRegUnit());
}
void RegPressureTracker::reset() {
@@ -356,17 +359,18 @@ void RegPressureTracker::initLiveThru(const RegPressureTracker &RPTracker) {
LiveThruPressure.assign(TRI->getNumRegPressureSets(), 0);
assert(isBottomClosed() && "need bottom-up tracking to intialize.");
for (const VRegMaskOrUnit &Pair : P.LiveOutRegs) {
- Register RegUnit = Pair.RegUnit;
- if (RegUnit.isVirtual() && !RPTracker.hasUntiedDef(RegUnit))
- increaseSetPressure(LiveThruPressure, *MRI, RegUnit,
+ VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
+ if (VRegOrUnit.isVirtualReg() &&
+ !RPTracker.hasUntiedDef(VRegOrUnit.asVirtualReg()))
+ increaseSetPressure(LiveThruPressure, *MRI, VRegOrUnit,
LaneBitmask::getNone(), Pair.LaneMask);
}
}
static LaneBitmask getRegLanes(ArrayRef<VRegMaskOrUnit> RegUnits,
- Register RegUnit) {
- auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
- return Other.RegUnit == RegUnit;
+ VirtRegOrUnit VRegOrUnit) {
+ auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+ return Other.VRegOrUnit == VRegOrUnit;
});
if (I == RegUnits.end())
return LaneBitmask::getNone();
@@ -375,10 +379,10 @@ static LaneBitmask getRegLanes(ArrayRef<VRegMaskOrUnit> RegUnits,
static void addRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
VRegMaskOrUnit Pair) {
- Register RegUnit = Pair.RegUnit;
+ VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
assert(Pair.LaneMask.any());
- auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
- return Other.RegUnit == RegUnit;
+ auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+ return Other.VRegOrUnit == VRegOrUnit;
});
if (I == RegUnits.end()) {
RegUnits.push_back(Pair);
@@ -388,12 +392,12 @@ static void addRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
}
static void setRegZero(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
- Register RegUnit) {
- auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
- return Other.RegUnit == RegUnit;
+ VirtRegOrUnit VRegOrUnit) {
+ auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+ return Other.VRegOrUnit == VRegOrUnit;
});
if (I == RegUnits.end()) {
- RegUnits.emplace_back(RegUnit, LaneBitmask::getNone());
+ RegUnits.emplace_back(VRegOrUnit, LaneBitmask::getNone());
} else {
I->LaneMask = LaneBitmask::getNone();
}
@@ -401,10 +405,10 @@ static void setRegZero(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
static void removeRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
VRegMaskOrUnit Pair) {
- Register RegUnit = Pair.RegUnit;
+ VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
assert(Pair.LaneMask.any());
- auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
- return Other.RegUnit == RegUnit;
+ auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+ return Other.VRegOrUnit == VRegOrUnit;
});
if (I != RegUnits.end()) {
I->LaneMask &= ~Pair.LaneMask;
@@ -415,11 +419,11 @@ static void removeRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
static LaneBitmask
getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
- bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
- LaneBitmask SafeDefault,
+ bool TrackLaneMasks, VirtRegOrUnit VRegOrUnit,
+ SlotIndex Pos, LaneBitmask SafeDefault,
bool (*Property)(const LiveRange &LR, SlotIndex Pos)) {
- if (RegUnit.isVirtual()) {
- const LiveInterval &LI = LIS.getInterval(RegUnit);
+ if (VRegOrUnit.isVirtualReg()) {
+ const LiveInterval &LI = LIS.getInterval(VRegOrUnit.asVirtualReg());
LaneBitmask Result;
if (TrackLaneMasks && LI.hasSubRanges()) {
for (const LiveInterval::SubRange &SR : LI.subranges()) {
@@ -427,13 +431,14 @@ getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
Result |= SR.LaneMask;
}
} else if (Property(LI, Pos)) {
- Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
- : LaneBitmask::getAll();
+ Result = TrackLaneMasks
+ ? MRI.getMaxLaneMaskForVReg(VRegOrUnit.asVirtualReg())
+ : LaneBitmask::getAll();
}
return Result;
} else {
- const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
+ const LiveRange *LR = LIS.getCachedRegUnit(VRegOrUnit.asMCRegUnit());
// Be prepared for missing liveranges: We usually do not compute liveranges
// for physical registers on targets with many registers (GPUs).
if (LR == nullptr)
@@ -444,13 +449,11 @@ getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
static LaneBitmask getLiveLanesAt(const LiveIntervals &LIS,
const MachineRegisterInfo &MRI,
- bool TrackLaneMasks, Register RegUnit,
+ bool TrackLaneMasks, VirtRegOrUnit VRegOrUnit,
SlotIndex Pos) {
- return getLanesWithProperty(LIS, MRI, TrackLaneMasks, RegUnit, Pos,
- LaneBitmask::getAll(),
- [](const LiveRange &LR, SlotIndex Pos) {
- return LR.liveAt(Pos);
- });
+ return getLanesWithProperty(
+ LIS, MRI, TrackLaneMasks, VRegOrUnit, Pos, LaneBitmask::getAll(),
+ [](const LiveRange &LR, SlotIndex Pos) { return LR.liveAt(Pos); });
}
namespace {
@@ -514,10 +517,12 @@ class RegisterOperandsCollector {
void pushReg(Register Reg, SmallVectorImpl<VRegMaskOrUnit> &RegUnits) const {
if (Reg.isVirtual()) {
- addRegLanes(RegUnits, VRegMaskOrUnit(Reg, LaneBitmask::getAll()));
+ addRegLanes(RegUnits,
+ VRegMaskOrUnit(VirtRegOrUnit(Reg), LaneBitmask::getAll()));
} else if (MRI.isAllocatable(Reg)) {
for (MCRegUnit Unit : TRI.regunits(Reg.asMCReg()))
- addRegLanes(RegUnits, VRegMaskOrUnit(Unit, LaneBitmask::getAll()));
+ addRegLanes(RegUnits,
+ VRegMaskOrUnit(VirtRegOrUnit(Unit), LaneBitmask::getAll()));
}
}
@@ -549,10 +554,11 @@ class RegisterOperandsCollector {
LaneBitmask LaneMask = SubRegIdx != 0
? TRI.getSubRegIndexLaneMask(SubRegIdx)
: MRI.getMaxLaneMaskForVReg(Reg);
- addRegLanes(RegUnits, VRegMaskOrUnit(Reg, LaneMask));
+ addRegLanes(RegUnits, VRegMaskOrUnit(VirtRegOrUnit(Reg), LaneMask));
} else if (MRI.isAllocatable(Reg)) {
for (MCRegUnit Unit : TRI.regunits(Reg.asMCReg()))
- addRegLanes(RegUnits, VRegMaskOrUnit(Unit, LaneBitmask::getAll()));
+ addRegLanes(RegUnits,
+ VRegMaskOrUnit(VirtRegOrUnit(Unit), LaneBitmask::getAll()));
}
}
};
@@ -574,8 +580,7 @@ void RegisterOperands::detectDeadDefs(const MachineInstr &MI,
const LiveIntervals &LIS) {
SlotIndex SlotIdx = LIS.getInstructionIndex(MI);
for (auto *RI = Defs.begin(); RI != Defs.end(); /*empty*/) {
- Register Reg = RI->RegUnit;
- const LiveRange *LR = getLiveRange(LIS, Reg);
+ const LiveRange *LR = getLiveRange(LIS, RI->VRegOrUnit);
if (LR != nullptr) {
LiveQueryResult LRQ = LR->Query(SlotIdx);
if (LRQ.isDeadDef()) {
@@ -595,14 +600,14 @@ void RegisterOperands::adjustLaneLiveness(const LiveIntervals &LIS,
SlotIndex Pos,
MachineInstr *AddFlagsMI) {
for (auto *I = Defs.begin(); I != Defs.end();) {
- LaneBitmask LiveAfter = getLiveLanesAt(LIS, MRI, true, I->RegUnit,
- Pos.getDeadSlot());
+ LaneBitmask LiveAfter =
+ getLiveLanesAt(LIS, MRI, true, I->VRegOrUnit, Pos.getDeadSlot());
// If the def is all that is live after the instruction, then in case
// of a subregister def we need a read-undef flag.
- Register RegUnit = I->RegUnit;
- if (RegUnit.isVirtual() && AddFlagsMI != nullptr &&
+ VirtRegOrUnit VRegOrUnit = I->VRegOrUnit;
+ if (VRegOrUnit.isVirtualReg() && AddFlagsMI != nullptr &&
(LiveAfter & ~I->LaneMask).none())
- AddFlagsMI->setRegisterDefReadUndef(RegUnit);
+ AddFlagsMI->setRegisterDefReadUndef(VRegOrUnit.asVirtualReg());
LaneBitmask ActualDef = I->LaneMask & LiveAfter;
if (ActualDef.none()) {
@@ -614,18 +619,18 @@ void RegisterOperands::adjustLaneLiveness(const LiveIntervals &LIS,
}
// For uses just copy the information from LIS.
- for (auto &[RegUnit, LaneMask] : Uses)
- LaneMask = getLiveLanesAt(LIS, MRI, true, RegUnit, Pos.getBaseIndex());
+ for (auto &[VRegOrUnit, LaneMask] : Uses)
+ LaneMask = getLiveLanesAt(LIS, MRI, true, VRegOrUnit, Pos.getBaseIndex());
if (AddFlagsMI != nullptr) {
for (const VRegMaskOrUnit &P : DeadDefs) {
- Register RegUnit = P.RegUnit;
- if (!RegUnit.isVirtual())
+ VirtRegOrUnit VRegOrUnit = P.VRegOrUnit;
+ if (!VRegOrUnit.isVirtualReg())
continue;
- LaneBitmask LiveAfter = getLiveLanesAt(LIS, MRI, true, RegUnit,
- Pos.getDeadSlot());
+ LaneBitmask LiveAfter =
+ getLiveLanesAt(LIS, MRI, true, VRegOrUnit, Pos.getDeadSlot());
if (LiveAfter.none())
- AddFlagsMI->setRegisterDefReadUndef(RegUnit);
+ AddFlagsMI->setRegisterDefReadUndef(VRegOrUnit.asVirtualReg());
}
}
}
@@ -648,16 +653,16 @@ void PressureDiffs::addInstruction(unsigned Idx,
PressureDiff &PDiff = (*this)[Idx];
assert(!PDiff.begin()->isValid() && "stale PDiff");
for (const VRegMaskOrUnit &P : RegOpers.Defs)
- PDiff.addPressureChange(P.RegUnit, true, &MRI);
+ PDiff.addPressureChange(P.VRegOrUnit, true, &MRI);
for (const VRegMaskOrUnit &P : RegOpers.Uses)
- PDiff.addPressureChange(P.RegUnit, false, &MRI);
+ PDiff.addPressureChange(P.VRegOrUnit, false, &MRI);
}
/// Add a change in pressure to the pressure diff of a given instruction.
-void PressureDiff::addPressureChange(Register RegUnit, bool IsDec,
+void PressureDiff::addPressureChange(VirtRegOrUnit VRegOrUnit, bool IsDec,
const MachineRegisterInfo *MRI) {
- PSetIterator PSetI = MRI->getPressureSets(RegUnit);
+ PSetIterator PSetI = MRI->getPressureSets(VRegOrUnit);
int Weight = IsDec ? -PSetI.getWeight() : PSetI.getWeight();
for (; PSetI.isValid(); ++PSetI) {
// Find an existing entry in the pressure diff for this PSet.
@@ -694,7 +699,7 @@ void RegPressureTracker::addLiveRegs(ArrayRef<VRegMaskOrUnit> Regs) {
for (const VRegMaskOrUnit &P : Regs) {
LaneBitmask PrevMask = LiveRegs.insert(P);
LaneBitmask NewMask = PrevMask | P.LaneMask;
- increaseRegPressure(P.RegUnit, PrevMask, NewMask);
+ increaseRegPressure(P.VRegOrUnit, PrevMask, NewMask);
}
}
@@ -702,9 +707,9 @@ void RegPressureTracker::discoverLiveInOrOut(
VRegMaskOrUnit Pair, SmallVectorImpl<VRegMaskOrUnit> &LiveInOrOut) {
assert(Pair.LaneMask.any());
- Register RegUnit = Pair.RegUnit;
- auto I = llvm::find_if(LiveInOrOut, [RegUnit](const VRegMaskOrUnit &Other) {
- return Other.RegUnit == RegUnit;
+ VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
+ auto I = find_if(LiveInOrOut, [VRegOrUnit](const VRegMaskOrUnit &Other) {
+ return Other.VRegOrUnit == VRegOrUnit;
});
LaneBitmask PrevMask;
LaneBitmask NewMask;
@@ -717,7 +722,7 @@ void RegPressureTracker::discoverLiveInOrOut(
NewMask = PrevMask | Pair.LaneMask;
I->LaneMask = NewMask;
}
- increaseSetPressure(P.MaxSetPressure, *MRI, RegUnit, PrevMask, NewMask);
+ increaseSetPressure(P.MaxSetPressure, *MRI, VRegOrUnit, PrevMask, NewMask);
}
void RegPressureTracker::discoverLiveIn(VRegMaskOrUnit Pair) {
@@ -730,16 +735,14 @@ void RegPressureTracker::discoverLiveOut(VRegMaskOrUnit Pair) {
void RegPressureTracker::bumpDeadDefs(ArrayRef<VRegMaskOrUnit> DeadDefs) {
for (const VRegMaskOrUnit &P : DeadDefs) {
- Register Reg = P.RegUnit;
- LaneBitmask LiveMask = LiveRegs.contains(Reg);
+ LaneBitmask LiveMask = LiveRegs.contains(P.VRegOrUnit);
LaneBitmask BumpedMask = LiveMask | P.LaneMask;
- increaseRegPressure(Reg, LiveMask, BumpedMask);
+ increaseRegPressure(P.VRegOrUnit, LiveMask, BumpedMask);
}
for (const VRegMaskOrUnit &P : DeadDefs) {
- Register Reg = P.RegUnit;
- LaneBitmask LiveMask = LiveRegs.contains(Reg);
+ LaneBitmask LiveMask = LiveRegs.contains(P.VRegOrUnit);
LaneBitmask BumpedMask = LiveMask | P.LaneMask;
- decreaseRegPressure(Reg, BumpedMask, LiveMask);
+ decreaseRegPressure(P.VRegOrUnit, BumpedMask, LiveMask);
}
}
@@ -758,17 +761,17 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
// Kill liveness at live defs.
// TODO: consider earlyclobbers?
for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
- Register Reg = Def.RegUnit;
+ VirtRegOrUnit VRegOrUnit = Def.VRegOrUnit;
LaneBitmask PreviousMask = LiveRegs.erase(Def);
LaneBitmask NewMask = PreviousMask & ~Def.LaneMask;
LaneBitmask LiveOut = Def.LaneMask & ~PreviousMask;
if (LiveOut.any()) {
- discoverLiveOut(VRegMaskOrUnit(Reg, LiveOut));
+ discoverLiveOut(VRegMaskOrUnit(VRegOrUnit, LiveOut));
// Retroactively model effects on pressure of the live out lanes.
- increaseSetPressure(CurrSetPressure, *MRI, Reg, LaneBitmask::getNone(),
- LiveOut);
+ increaseSetPressure(CurrSetPressure, *MRI, VRegOrUnit,
+ LaneBitmask::getNone(), LiveOut);
PreviousMask = LiveOut;
}
@@ -776,10 +779,10 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
// Add a 0 entry to LiveUses as a marker that the complete vreg has become
// dead.
if (TrackLaneMasks && LiveUses != nullptr)
- setRegZero(*LiveUses, Reg);
+ setRegZero(*LiveUses, VRegOrUnit);
}
- decreaseRegPressure(Reg, PreviousMask, NewMask);
+ decreaseRegPressure(VRegOrUnit, PreviousMask, NewMask);
}
SlotIndex SlotIdx;
@@ -788,7 +791,7 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
// Generate liveness for uses.
for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
- Register Reg = Use.RegUnit;
+ VirtRegOrUnit VRegOrUnit = Use.VRegOrUnit;
assert(Use.LaneMask.any());
LaneBitmask PreviousMask = LiveRegs.insert(Use);
LaneBitmask NewMask = PreviousMask | Use.LaneMask;
@@ -799,38 +802,38 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
if (PreviousMask.none()) {
if (LiveUses != nullptr) {
if (!TrackLaneMasks) {
- addRegLanes(*LiveUses, VRegMaskOrUnit(Reg, NewMask));
+ addRegLanes(*LiveUses, VRegMaskOrUnit(VRegOrUnit, NewMask));
} else {
- auto I = llvm::find_if(*LiveUses, [Reg](const VRegMaskOrUnit Other) {
- return Other.RegUnit == Reg;
+ auto I = find_if(*LiveUses, [VRegOrUnit](const VRegMaskOrUnit Other) {
+ return Other.VRegOrUnit == VRegOrUnit;
});
bool IsRedef = I != LiveUses->end();
if (IsRedef) {
// ignore re-defs here...
assert(I->LaneMask.none());
- removeRegLanes(*LiveUses, VRegMaskOrUnit(Reg, NewMask));
+ removeRegLanes(*LiveUses, VRegMaskOrUnit(VRegOrUnit, NewMask));
} else {
- addRegLanes(*LiveUses, VRegMaskOrUnit(Reg, NewMask));
+ addRegLanes(*LiveUses, VRegMaskOrUnit(VRegOrUnit, NewMask));
}
}
}
// Discover live outs if this may be the first occurance of this register.
if (RequireIntervals) {
- LaneBitmask LiveOut = getLiveThroughAt(Reg, SlotIdx);
+ LaneBitmask LiveOut = getLiveThroughAt(VRegOrUnit, SlotIdx);
if (LiveOut.any())
- discoverLiveOut(VRegMaskOrUnit(Reg, LiveOut));
+ discoverLiveOut(VRegMaskOrUnit(VRegOrUnit, LiveOut));
}
}
- increaseRegPressure(Reg, PreviousMask, NewMask);
+ increaseRegPressure(VRegOrUnit, PreviousMask, NewMask);
}
if (TrackUntiedDefs) {
for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
- Register RegUnit = Def.RegUnit;
- if (RegUnit.isVirtual() &&
- (LiveRegs.contains(RegUnit) & Def.LaneMask).none())
- UntiedDefs.insert(RegUnit);
+ VirtRegOrUnit VRegOrUnit = Def.VRegOrUnit;
+ if (VRegOrUnit.isVirtualReg() &&
+ (LiveRegs.contains(VRegOrUnit) & Def.LaneMask).none())
+ UntiedDefs.insert(VRegOrUnit.asVirtualReg());
}
}
}
@@ -898,20 +901,20 @@ void RegPressureTracker::advance(const RegisterOperands &RegOpers) {
}
for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
- Register Reg = Use.RegUnit;
- LaneBitmask LiveMask = LiveRegs.contains(Reg);
+ VirtRegOrUnit VRegOrUnit = Use.VRegOrUnit;
+ LaneBitmask LiveMask = LiveRegs.contains(VRegOrUnit);
LaneBitmask LiveIn = Use.LaneMask & ~LiveMask;
if (LiveIn.any()) {
- discoverLiveIn(VRegMaskOrUnit(Reg, LiveIn));
- increaseRegPressure(Reg, LiveMask, LiveMask | LiveIn);
- LiveRegs.insert(VRegMaskOrUnit(Reg, LiveIn));
+ discoverLiveIn(VRegMaskOrUnit(VRegOrUnit, LiveIn));
+ increaseRegPressure(VRegOrUnit, LiveMask, LiveMask | LiveIn);
+ LiveRegs.insert(VRegMaskOrUnit(VRegOrUnit, LiveIn));
}
// Kill liveness at last uses.
if (RequireIntervals) {
- LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
+ LaneBitmask LastUseMask = getLastUsedLanes(VRegOrUnit, SlotIdx);
if (LastUseMask.any()) {
- LiveRegs.erase(VRegMaskOrUnit(Reg, LastUseMask));
- decreaseRegPressure(Reg, LiveMask, LiveMask & ~LastUseMask);
+ LiveRegs.erase(VRegMaskOrUnit(VRegOrUnit, LastUseMask));
+ decreaseRegPressure(VRegOrUnit, LiveMask, LiveMask & ~LastUseMask);
}
}
}
@@ -920,7 +923,7 @@ void RegPressureTracker::advance(const RegisterOperands &RegOpers) {
for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
LaneBitmask PreviousMask = LiveRegs.insert(Def);
LaneBitmask NewMask = PreviousMask | Def.LaneMask;
- increaseRegPressure(Def.RegUnit, PreviousMask, NewMask);
+ increaseRegPressure(Def.VRegOrUnit, PreviousMask, NewMask);
}
// Boost pressure for all dead defs together.
@@ -1047,22 +1050,20 @@ void RegPressureTracker::bumpUpwardPressure(const MachineInstr *MI) {
// Kill liveness at live defs.
for (const VRegMaskOrUnit &P : RegOpers.Defs) {
- Register Reg = P.RegUnit;
- LaneBitmask LiveAfter = LiveRegs.contains(Reg);
- LaneBitmask UseLanes = getRegLanes(RegOpers.Uses, Reg);
+ LaneBitmask LiveAfter = LiveRegs.contains(P.VRegOrUnit);
+ LaneBitmask UseLanes = getRegLanes(RegOpers.Uses, P.VRegOrUnit);
LaneBitmask DefLanes = P.LaneMask;
LaneBitmask LiveBefore = (LiveAfter & ~DefLanes) | UseLanes;
// There may be parts of the register that were dead before the
// instruction, but became live afterwards.
- decreaseRegPressure(Reg, LiveAfter, LiveAfter & LiveBefore);
+ decreaseRegPressure(P.VRegOrUnit, LiveAfter, LiveAfter & LiveBefore);
}
// Generate liveness for uses. Also handle any uses which overlap with defs.
for (const VRegMaskOrUnit &P : RegOpers.Uses) {
- Register Reg = P.RegUnit;
- LaneBitmask LiveAfter = LiveRegs.contains(Reg);
+ LaneBitmask LiveAfter = LiveRegs.contains(P.VRegOrUnit);
LaneBitmask LiveBefore = LiveAfter | P.LaneMask;
- increaseRegPressure(Reg, LiveAfter, LiveBefore);
+ increaseRegPressure(P.VRegOrUnit, LiveAfter, LiveBefore);
}
}
@@ -1209,11 +1210,17 @@ getUpwardPressureDelta(const MachineInstr *MI, /*const*/ PressureDiff &PDiff,
/// Helper to find a vreg use between two indices [PriorUseIdx, NextUseIdx).
/// The query starts with a lane bitmask which gets lanes/bits removed for every
/// use we find.
-static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
+static LaneBitmask findUseBetween(VirtRegOrUnit VRegOrUnit,
+ LaneBitmask LastUseMask,
SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
const MachineRegisterInfo &MRI,
const LiveIntervals *LIS) {
const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
+ // FIXME: The static_cast is a bug.
+ Register Reg =
+ VRegOrUnit.isVirtualReg()
+ ? VRegOrUnit.asVirtualReg()
+ : Register(static_cast<unsigned>(VRegOrUnit.asMCRegUnit()));
for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
if (MO.isUndef())
continue;
@@ -1230,32 +1237,30 @@ static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
return LastUseMask;
}
-LaneBitmask RegPressureTracker::getLiveLanesAt(Register RegUnit,
+LaneBitmask RegPressureTracker::getLiveLanesAt(VirtRegOrUnit VRegOrUnit,
SlotIndex Pos) const {
assert(RequireIntervals);
- return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos,
- LaneBitmask::getAll(),
- [](const LiveRange &LR, SlotIndex Pos) {
- return LR.liveAt(Pos);
- });
+ return getLanesWithProperty(
+ *LIS, *MRI, TrackLaneMasks, VRegOrUnit, Pos, LaneBitmask::getAll(),
+ [](const LiveRange &LR, SlotIndex Pos) { return LR.liveAt(Pos); });
}
-LaneBitmask RegPressureTracker::getLastUsedLanes(Register RegUnit,
+LaneBitmask RegPressureTracker::getLastUsedLanes(VirtRegOrUnit VRegOrUnit,
SlotIndex Pos) const {
assert(RequireIntervals);
- return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit,
- Pos.getBaseIndex(), LaneBitmask::getNone(),
- [](const LiveRange &LR, SlotIndex Pos) {
+ return getLanesWithProperty(
+ *LIS, *MRI, TrackLaneMasks, VRegOrUnit, Pos.getBaseIndex(),
+ LaneBitmask::getNone(), [](const LiveRange &LR, SlotIndex Pos) {
const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
return S != nullptr && S->end == Pos.getRegSlot();
});
}
-LaneBitmask RegPressureTracker::getLiveThroughAt(Register RegUnit,
+LaneBitmask RegPressureTracker::getLiveThroughAt(VirtRegOrUnit VRegOrUnit,
SlotIndex Pos) const {
assert(RequireIntervals);
- return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos,
- LaneBitmask::getNone(),
+ return getLanesWithProperty(
+ *LIS, *MRI, TrackLaneMasks, VRegOrUnit, Pos, LaneBitmask::getNone(),
[](const LiveRange &LR, SlotIndex Pos) {
const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
return S != nullptr && S->start < Pos.getRegSlot(true) &&
@@ -1284,8 +1289,8 @@ void RegPressureTracker::bumpDownwardPressure(const MachineInstr *MI) {
if (RequireIntervals) {
for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
- Register Reg = Use.RegUnit;
- LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
+ VirtRegOrUnit VRegOrUnit = Use.VRegOrUnit;
+ LaneBitmask LastUseMask = getLastUsedLanes(VRegOrUnit, SlotIdx);
if (LastUseMask.none())
continue;
// The LastUseMask is queried from the liveness information of instruction
@@ -1294,23 +1299,22 @@ void RegPressureTracker::bumpDownwardPressure(const MachineInstr *MI) {
// FIXME: allow the caller to pass in the list of vreg uses that remain
// to be bottom-scheduled to avoid searching uses at each query.
SlotIndex CurrIdx = getCurrSlot();
- LastUseMask
- = findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, LIS);
+ LastUseMask =
+ findUseBetween(VRegOrUnit, LastUseMask, CurrIdx, SlotIdx, *MRI, LIS);
if (LastUseMask.none())
continue;
- LaneBitmask LiveMask = LiveRegs.contains(Reg);
+ LaneBitmask LiveMask = LiveRegs.contains(VRegOrUnit);
LaneBitmask NewMask = LiveMask & ~LastUseMask;
- decreaseRegPressure(Reg, LiveMask, NewMask);
+ decreaseRegPressure(VRegOrUnit, LiveMask, NewMask);
}
}
// Generate liveness for defs.
for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
- Register Reg = Def.RegUnit;
- LaneBitmask LiveMask = LiveRegs.contains(Reg);
+ LaneBitmask LiveMask = LiveRegs.contains(Def.VRegOrUnit);
LaneBitmask NewMask = LiveMask | Def.LaneMask;
- increaseRegPressure(Reg, LiveMask, NewMask);
+ increaseRegPressure(Def.VRegOrUnit, LiveMask, NewMask);
}
// Boost pressure for all dead defs together.
diff --git a/llvm/lib/CodeGen/TargetRegisterInfo.cpp b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
index a5c81afc57a80..975895809b9de 100644
--- a/llvm/lib/CodeGen/TargetRegisterInfo.cpp
+++ b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
@@ -156,12 +156,13 @@ Printable llvm::printRegUnit(MCRegUnit Unit, const TargetRegisterInfo *TRI) {
});
}
-Printable llvm::printVRegOrUnit(unsigned Unit, const TargetRegisterInfo *TRI) {
- return Printable([Unit, TRI](raw_ostream &OS) {
- if (Register::isVirtualRegister(Unit)) {
- OS << '%' << Register(Unit).virtRegIndex();
+Printable llvm::printVRegOrUnit(VirtRegOrUnit VRegOrUnit,
+ const TargetRegisterInfo *TRI) {
+ return Printable([VRegOrUnit, TRI](raw_ostream &OS) {
+ if (VRegOrUnit.isVirtualReg()) {
+ OS << '%' << VRegOrUnit.asVirtualReg().virtRegIndex();
} else {
- OS << printRegUnit(Unit, TRI);
+ OS << printRegUnit(VRegOrUnit.asMCRegUnit(), TRI);
}
});
}
diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
index 4e11c4ff3d56e..0c5e3d0837800 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
@@ -282,11 +282,12 @@ collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
Register Reg = MO.getReg();
auto I = llvm::find_if(VRegMaskOrUnits, [Reg](const VRegMaskOrUnit &RM) {
- return RM.RegUnit == Reg;
+ return RM.VRegOrUnit.asVirtualReg() == Reg;
});
auto &P = I == VRegMaskOrUnits.end()
- ? VRegMaskOrUnits.emplace_back(Reg, LaneBitmask::getNone())
+ ? VRegMaskOrUnits.emplace_back(VirtRegOrUnit(Reg),
+ LaneBitmask::getNone())
: *I;
P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg())
@@ -295,7 +296,7 @@ collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
SlotIndex InstrSI;
for (auto &P : VRegMaskOrUnits) {
- auto &LI = LIS.getInterval(P.RegUnit);
+ auto &LI = LIS.getInterval(P.VRegOrUnit.asVirtualReg());
if (!LI.hasSubRanges())
continue;
@@ -562,10 +563,10 @@ void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
SmallVector<VRegMaskOrUnit, 8> RegUses;
collectVirtualRegUses(RegUses, MI, LIS, *MRI);
for (const VRegMaskOrUnit &U : RegUses) {
- LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
+ LaneBitmask &LiveMask = LiveRegs[U.VRegOrUnit.asVirtualReg()];
LaneBitmask PrevMask = LiveMask;
LiveMask |= U.LaneMask;
- CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
+ CurPressure.inc(U.VRegOrUnit.asVirtualReg(), PrevMask, LiveMask, *MRI);
}
// Update MaxPressure with uses plus early-clobber defs pressure.
@@ -748,9 +749,9 @@ GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
GCNRegPressure TempPressure = CurPressure;
for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
- Register Reg = Use.RegUnit;
- if (!Reg.isVirtual())
+ if (!Use.VRegOrUnit.isVirtualReg())
continue;
+ Register Reg = Use.VRegOrUnit.asVirtualReg();
LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
if (LastUseMask.none())
continue;
@@ -782,9 +783,9 @@ GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
// Generate liveness for defs.
for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
- Register Reg = Def.RegUnit;
- if (!Reg.isVirtual())
+ if (!Def.VRegOrUnit.isVirtualReg())
continue;
+ Register Reg = Def.VRegOrUnit.asVirtualReg();
auto It = LiveRegs.find(Reg);
LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
LaneBitmask NewMask = LiveMask | Def.LaneMask;
@@ -824,8 +825,7 @@ Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
Register Reg = Register::index2VirtReg(I);
auto It = LiveRegs.find(Reg);
if (It != LiveRegs.end() && It->second.any())
- OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
- << PrintLaneMask(It->second);
+ OS << ' ' << printReg(Reg, TRI) << ':' << PrintLaneMask(It->second);
}
OS << '\n';
});
diff --git a/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp b/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
index fd28abeb887c2..2f3ad39c75fcc 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
+++ b/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
@@ -323,8 +323,8 @@ void SIScheduleBlock::initRegPressure(MachineBasicBlock::iterator BeginBlock,
// Do not Track Physical Registers, because it messes up.
for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) {
- if (RegMaskPair.RegUnit.isVirtual())
- LiveInRegs.insert(RegMaskPair.RegUnit);
+ if (RegMaskPair.VRegOrUnit.isVirtualReg())
+ LiveInRegs.insert(RegMaskPair.VRegOrUnit.asVirtualReg());
}
LiveOutRegs.clear();
// There is several possibilities to distinguish:
@@ -350,12 +350,13 @@ void SIScheduleBlock::initRegPressure(MachineBasicBlock::iterator BeginBlock,
// Comparing to LiveInRegs is not sufficient to differentiate 4 vs 5, 7
// The use of findDefBetween removes the case 4.
for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) {
- Register Reg = RegMaskPair.RegUnit;
- if (Reg.isVirtual() &&
- isDefBetween(Reg, LIS->getInstructionIndex(*BeginBlock).getRegSlot(),
+ VirtRegOrUnit VRegOrUnit = RegMaskPair.VRegOrUnit;
+ if (VRegOrUnit.isVirtualReg() &&
+ isDefBetween(VRegOrUnit.asVirtualReg(),
+ LIS->getInstructionIndex(*BeginBlock).getRegSlot(),
LIS->getInstructionIndex(*EndBlock).getRegSlot(), MRI,
LIS)) {
- LiveOutRegs.insert(Reg);
+ LiveOutRegs.insert(VRegOrUnit.asVirtualReg());
}
}
@@ -578,11 +579,11 @@ void SIScheduleBlock::printDebug(bool full) {
<< LiveOutPressure[AMDGPU::RegisterPressureSets::VGPR_32] << "\n\n";
dbgs() << "LiveIns:\n";
for (Register Reg : LiveInRegs)
- dbgs() << printVRegOrUnit(Reg, DAG->getTRI()) << ' ';
+ dbgs() << printReg(Reg, DAG->getTRI()) << ' ';
dbgs() << "\nLiveOuts:\n";
for (Register Reg : LiveOutRegs)
- dbgs() << printVRegOrUnit(Reg, DAG->getTRI()) << ' ';
+ dbgs() << printReg(Reg, DAG->getTRI()) << ' ';
}
dbgs() << "\nInstructions:\n";
@@ -1446,23 +1447,24 @@ SIScheduleBlockScheduler::SIScheduleBlockScheduler(SIScheduleDAGMI *DAG,
}
#endif
- std::set<Register> InRegs = DAG->getInRegs();
+ std::set<VirtRegOrUnit> InRegs = DAG->getInRegs();
addLiveRegs(InRegs);
// Increase LiveOutRegsNumUsages for blocks
// producing registers consumed in another
// scheduling region.
- for (Register Reg : DAG->getOutRegs()) {
+ for (VirtRegOrUnit VRegOrUnit : DAG->getOutRegs()) {
for (unsigned i = 0, e = Blocks.size(); i != e; ++i) {
// Do reverse traversal
int ID = BlocksStruct.TopDownIndex2Block[Blocks.size()-1-i];
SIScheduleBlock *Block = Blocks[ID];
const std::set<Register> &OutRegs = Block->getOutRegs();
- if (OutRegs.find(Reg) == OutRegs.end())
+ if (!VRegOrUnit.isVirtualReg() ||
+ OutRegs.find(VRegOrUnit.asVirtualReg()) == OutRegs.end())
continue;
- ++LiveOutRegsNumUsages[ID][Reg];
+ ++LiveOutRegsNumUsages[ID][VRegOrUnit.asVirtualReg()];
break;
}
}
@@ -1565,15 +1567,18 @@ SIScheduleBlock *SIScheduleBlockScheduler::pickBlock() {
maxVregUsage = VregCurrentUsage;
if (SregCurrentUsage > maxSregUsage)
maxSregUsage = SregCurrentUsage;
- LLVM_DEBUG(dbgs() << "Picking New Blocks\n"; dbgs() << "Available: ";
- for (SIScheduleBlock *Block : ReadyBlocks)
- dbgs() << Block->getID() << ' ';
- dbgs() << "\nCurrent Live:\n";
- for (Register Reg : LiveRegs)
- dbgs() << printVRegOrUnit(Reg, DAG->getTRI()) << ' ';
- dbgs() << '\n';
- dbgs() << "Current VGPRs: " << VregCurrentUsage << '\n';
- dbgs() << "Current SGPRs: " << SregCurrentUsage << '\n';);
+ LLVM_DEBUG({
+ dbgs() << "Picking New Blocks\n";
+ dbgs() << "Available: ";
+ for (SIScheduleBlock *Block : ReadyBlocks)
+ dbgs() << Block->getID() << ' ';
+ dbgs() << "\nCurrent Live:\n";
+ for (Register Reg : LiveRegs)
+ dbgs() << printReg(Reg, DAG->getTRI()) << ' ';
+ dbgs() << '\n';
+ dbgs() << "Current VGPRs: " << VregCurrentUsage << '\n';
+ dbgs() << "Current SGPRs: " << SregCurrentUsage << '\n';
+ });
Cand.Block = nullptr;
for (std::vector<SIScheduleBlock*>::iterator I = ReadyBlocks.begin(),
@@ -1625,13 +1630,13 @@ SIScheduleBlock *SIScheduleBlockScheduler::pickBlock() {
// Tracking of currently alive registers to determine VGPR Usage.
-void SIScheduleBlockScheduler::addLiveRegs(std::set<Register> &Regs) {
- for (Register Reg : Regs) {
+void SIScheduleBlockScheduler::addLiveRegs(std::set<VirtRegOrUnit> &Regs) {
+ for (VirtRegOrUnit VRegOrUnit : Regs) {
// For now only track virtual registers.
- if (!Reg.isVirtual())
+ if (!VRegOrUnit.isVirtualReg())
continue;
// If not already in the live set, then add it.
- (void) LiveRegs.insert(Reg);
+ (void)LiveRegs.insert(VRegOrUnit.asVirtualReg());
}
}
@@ -1662,7 +1667,7 @@ void SIScheduleBlockScheduler::releaseBlockSuccs(SIScheduleBlock *Parent) {
void SIScheduleBlockScheduler::blockScheduled(SIScheduleBlock *Block) {
decreaseLiveRegs(Block, Block->getInRegs());
- addLiveRegs(Block->getOutRegs());
+ LiveRegs.insert(Block->getOutRegs().begin(), Block->getOutRegs().end());
releaseBlockSuccs(Block);
for (const auto &RegP : LiveOutRegsNumUsages[Block->getID()]) {
// We produce this register, thus it must not be previously alive.
@@ -1689,7 +1694,7 @@ SIScheduleBlockScheduler::checkRegUsageImpact(std::set<Register> &InRegs,
continue;
if (LiveRegsConsumers[Reg] > 1)
continue;
- PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg);
+ PSetIterator PSetI = DAG->getMRI()->getPressureSets(VirtRegOrUnit(Reg));
for (; PSetI.isValid(); ++PSetI) {
DiffSetPressure[*PSetI] -= PSetI.getWeight();
}
@@ -1699,7 +1704,7 @@ SIScheduleBlockScheduler::checkRegUsageImpact(std::set<Register> &InRegs,
// For now only track virtual registers.
if (!Reg.isVirtual())
continue;
- PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg);
+ PSetIterator PSetI = DAG->getMRI()->getPressureSets(VirtRegOrUnit(Reg));
for (; PSetI.isValid(); ++PSetI) {
DiffSetPressure[*PSetI] += PSetI.getWeight();
}
@@ -1846,7 +1851,7 @@ SIScheduleDAGMI::fillVgprSgprCost(_Iterator First, _Iterator End,
// For now only track virtual registers
if (!Reg.isVirtual())
continue;
- PSetIterator PSetI = MRI.getPressureSets(Reg);
+ PSetIterator PSetI = MRI.getPressureSets(VirtRegOrUnit(Reg));
for (; PSetI.isValid(); ++PSetI) {
if (*PSetI == AMDGPU::RegisterPressureSets::VGPR_32)
VgprUsage += PSetI.getWeight();
diff --git a/llvm/lib/Target/AMDGPU/SIMachineScheduler.h b/llvm/lib/Target/AMDGPU/SIMachineScheduler.h
index b219cbd5672f0..1245774400af1 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineScheduler.h
+++ b/llvm/lib/Target/AMDGPU/SIMachineScheduler.h
@@ -389,7 +389,7 @@ class SIScheduleBlockScheduler {
SIBlockSchedCandidate &TryCand);
SIScheduleBlock *pickBlock();
- void addLiveRegs(std::set<Register> &Regs);
+ void addLiveRegs(std::set<VirtRegOrUnit> &Regs);
void decreaseLiveRegs(SIScheduleBlock *Block, std::set<Register> &Regs);
void releaseBlockSuccs(SIScheduleBlock *Parent);
void blockScheduled(SIScheduleBlock *Block);
@@ -462,18 +462,18 @@ class SIScheduleDAGMI final : public ScheduleDAGMILive {
unsigned &VgprUsage,
unsigned &SgprUsage);
- std::set<Register> getInRegs() {
- std::set<Register> InRegs;
+ std::set<VirtRegOrUnit> getInRegs() {
+ std::set<VirtRegOrUnit> InRegs;
for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) {
- InRegs.insert(RegMaskPair.RegUnit);
+ InRegs.insert(RegMaskPair.VRegOrUnit);
}
return InRegs;
}
- std::set<unsigned> getOutRegs() {
- std::set<unsigned> OutRegs;
+ std::set<VirtRegOrUnit> getOutRegs() {
+ std::set<VirtRegOrUnit> OutRegs;
for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) {
- OutRegs.insert(RegMaskPair.RegUnit);
+ OutRegs.insert(RegMaskPair.VRegOrUnit);
}
return OutRegs;
};
diff --git a/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp b/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
index 6611e1e6507e1..10762edc16264 100644
--- a/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
+++ b/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
@@ -188,8 +188,9 @@ class SIWholeQuadMode {
void markInstruction(MachineInstr &MI, char Flag,
std::vector<WorkItem> &Worklist);
- void markDefs(const MachineInstr &UseMI, LiveRange &LR, Register Reg,
- unsigned SubReg, char Flag, std::vector<WorkItem> &Worklist);
+ void markDefs(const MachineInstr &UseMI, LiveRange &LR,
+ VirtRegOrUnit VRegOrUnit, unsigned SubReg, char Flag,
+ std::vector<WorkItem> &Worklist);
void markOperand(const MachineInstr &MI, const MachineOperand &Op, char Flag,
std::vector<WorkItem> &Worklist);
void markInstructionUses(const MachineInstr &MI, char Flag,
@@ -318,8 +319,8 @@ void SIWholeQuadMode::markInstruction(MachineInstr &MI, char Flag,
/// Mark all relevant definitions of register \p Reg in usage \p UseMI.
void SIWholeQuadMode::markDefs(const MachineInstr &UseMI, LiveRange &LR,
- Register Reg, unsigned SubReg, char Flag,
- std::vector<WorkItem> &Worklist) {
+ VirtRegOrUnit VRegOrUnit, unsigned SubReg,
+ char Flag, std::vector<WorkItem> &Worklist) {
LLVM_DEBUG(dbgs() << "markDefs " << PrintState(Flag) << ": " << UseMI);
LiveQueryResult UseLRQ = LR.Query(LIS->getInstructionIndex(UseMI));
@@ -331,8 +332,9 @@ void SIWholeQuadMode::markDefs(const MachineInstr &UseMI, LiveRange &LR,
// cover registers.
const LaneBitmask UseLanes =
SubReg ? TRI->getSubRegIndexLaneMask(SubReg)
- : (Reg.isVirtual() ? MRI->getMaxLaneMaskForVReg(Reg)
- : LaneBitmask::getNone());
+ : (VRegOrUnit.isVirtualReg()
+ ? MRI->getMaxLaneMaskForVReg(VRegOrUnit.asVirtualReg())
+ : LaneBitmask::getNone());
// Perform a depth-first iteration of the LiveRange graph marking defs.
// Stop processing of a given branch when all use lanes have been defined.
@@ -382,11 +384,11 @@ void SIWholeQuadMode::markDefs(const MachineInstr &UseMI, LiveRange &LR,
MachineInstr *MI = LIS->getInstructionFromIndex(Value->def);
assert(MI && "Def has no defining instruction");
- if (Reg.isVirtual()) {
+ if (VRegOrUnit.isVirtualReg()) {
// Iterate over all operands to find relevant definitions
bool HasDef = false;
for (const MachineOperand &Op : MI->all_defs()) {
- if (Op.getReg() != Reg)
+ if (Op.getReg() != VRegOrUnit.asVirtualReg())
continue;
// Compute lanes defined and overlap with use
@@ -453,7 +455,7 @@ void SIWholeQuadMode::markOperand(const MachineInstr &MI,
<< " for " << MI);
if (Reg.isVirtual()) {
LiveRange &LR = LIS->getInterval(Reg);
- markDefs(MI, LR, Reg, Op.getSubReg(), Flag, Worklist);
+ markDefs(MI, LR, VirtRegOrUnit(Reg), Op.getSubReg(), Flag, Worklist);
} else {
// Handle physical registers that we need to track; this is mostly relevant
// for VCC, which can appear as the (implicit) input of a uniform branch,
@@ -462,7 +464,8 @@ void SIWholeQuadMode::markOperand(const MachineInstr &MI,
LiveRange &LR = LIS->getRegUnit(Unit);
const VNInfo *Value = LR.Query(LIS->getInstructionIndex(MI)).valueIn();
if (Value)
- markDefs(MI, LR, Unit, AMDGPU::NoSubRegister, Flag, Worklist);
+ markDefs(MI, LR, VirtRegOrUnit(Unit), AMDGPU::NoSubRegister, Flag,
+ Worklist);
}
}
}
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
index 6077c18463240..02887ce93c525 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
@@ -6551,7 +6551,7 @@ class ARMPipelinerLoopInfo : public TargetInstrInfo::PipelinerLoopInfo {
static int constexpr LAST_IS_USE = MAX_STAGES;
static int constexpr SEEN_AS_LIVE = MAX_STAGES + 1;
typedef std::bitset<MAX_STAGES + 2> IterNeed;
- typedef std::map<unsigned, IterNeed> IterNeeds;
+ typedef std::map<Register, IterNeed> IterNeeds;
void bumpCrossIterationPressure(RegPressureTracker &RPT,
const IterNeeds &CIN);
@@ -6625,14 +6625,14 @@ void ARMPipelinerLoopInfo::bumpCrossIterationPressure(RegPressureTracker &RPT,
for (const auto &N : CIN) {
int Cnt = N.second.count() - N.second[SEEN_AS_LIVE] * 2;
for (int I = 0; I < Cnt; ++I)
- RPT.increaseRegPressure(Register(N.first), LaneBitmask::getNone(),
+ RPT.increaseRegPressure(VirtRegOrUnit(N.first), LaneBitmask::getNone(),
LaneBitmask::getAll());
}
// Decrease pressure by the amounts in CrossIterationNeeds
for (const auto &N : CIN) {
int Cnt = N.second.count() - N.second[SEEN_AS_LIVE] * 2;
for (int I = 0; I < Cnt; ++I)
- RPT.decreaseRegPressure(Register(N.first), LaneBitmask::getAll(),
+ RPT.decreaseRegPressure(VirtRegOrUnit(N.first), LaneBitmask::getAll(),
LaneBitmask::getNone());
}
}
More information about the llvm-commits mailing list