[llvm] ef9a02c - [CodeGen] Use VirtRegOrUnit where appropriate (NFCI) (#167730)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 13 02:27:02 PST 2025


Author: Sergei Barannikov
Date: 2025-11-13T10:26:58Z
New Revision: ef9a02ce028782684f9a43dcda756804635ba86a

URL: https://github.com/llvm/llvm-project/commit/ef9a02ce028782684f9a43dcda756804635ba86a
DIFF: https://github.com/llvm/llvm-project/commit/ef9a02ce028782684f9a43dcda756804635ba86a.diff

LOG: [CodeGen] Use VirtRegOrUnit where appropriate (NFCI) (#167730)

Use it in `printVRegOrUnit()`, `getPressureSets()`/`PSetIterator`,
and in functions/classes dealing with register pressure.

Static type checking revealed several bugs, mainly in MachinePipeliner.
I'm not very familiar with this pass, so I annotated them with FIXMEs instead of fixing them.

There is one bug in `findUseBetween()` in RegisterPressure.cpp, also
annotated with a FIXME.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/MachineRegisterInfo.h
    llvm/include/llvm/CodeGen/Register.h
    llvm/include/llvm/CodeGen/RegisterPressure.h
    llvm/include/llvm/CodeGen/TargetRegisterInfo.h
    llvm/lib/CodeGen/MachinePipeliner.cpp
    llvm/lib/CodeGen/MachineScheduler.cpp
    llvm/lib/CodeGen/RegisterPressure.cpp
    llvm/lib/CodeGen/TargetRegisterInfo.cpp
    llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
    llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
    llvm/lib/Target/AMDGPU/SIMachineScheduler.h
    llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
    llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/MachineRegisterInfo.h b/llvm/include/llvm/CodeGen/MachineRegisterInfo.h
index 6982dae4718d1..737b74ef3f761 100644
--- a/llvm/include/llvm/CodeGen/MachineRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineRegisterInfo.h
@@ -634,10 +634,9 @@ class MachineRegisterInfo {
   /// function. Writing to a constant register has no effect.
   LLVM_ABI bool isConstantPhysReg(MCRegister PhysReg) const;
 
-  /// Get an iterator over the pressure sets affected by the given physical or
-  /// virtual register. If RegUnit is physical, it must be a register unit (from
-  /// MCRegUnitIterator).
-  PSetIterator getPressureSets(Register RegUnit) const;
+  /// Get an iterator over the pressure sets affected by the virtual register
+  /// or register unit.
+  PSetIterator getPressureSets(VirtRegOrUnit VRegOrUnit) const;
 
   //===--------------------------------------------------------------------===//
   // Virtual Register Info
@@ -1249,15 +1248,16 @@ class PSetIterator {
 public:
   PSetIterator() = default;
 
-  PSetIterator(Register RegUnit, const MachineRegisterInfo *MRI) {
+  PSetIterator(VirtRegOrUnit VRegOrUnit, const MachineRegisterInfo *MRI) {
     const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
-    if (RegUnit.isVirtual()) {
-      const TargetRegisterClass *RC = MRI->getRegClass(RegUnit);
+    if (VRegOrUnit.isVirtualReg()) {
+      const TargetRegisterClass *RC =
+          MRI->getRegClass(VRegOrUnit.asVirtualReg());
       PSet = TRI->getRegClassPressureSets(RC);
       Weight = TRI->getRegClassWeight(RC).RegWeight;
     } else {
-      PSet = TRI->getRegUnitPressureSets(RegUnit);
-      Weight = TRI->getRegUnitWeight(RegUnit);
+      PSet = TRI->getRegUnitPressureSets(VRegOrUnit.asMCRegUnit());
+      Weight = TRI->getRegUnitWeight(VRegOrUnit.asMCRegUnit());
     }
     if (*PSet == -1)
       PSet = nullptr;
@@ -1278,8 +1278,8 @@ class PSetIterator {
 };
 
 inline PSetIterator
-MachineRegisterInfo::getPressureSets(Register RegUnit) const {
-  return PSetIterator(RegUnit, this);
+MachineRegisterInfo::getPressureSets(VirtRegOrUnit VRegOrUnit) const {
+  return PSetIterator(VRegOrUnit, this);
 }
 
 } // end namespace llvm

diff  --git a/llvm/include/llvm/CodeGen/Register.h b/llvm/include/llvm/CodeGen/Register.h
index 790db8a11e390..5e1e12942a019 100644
--- a/llvm/include/llvm/CodeGen/Register.h
+++ b/llvm/include/llvm/CodeGen/Register.h
@@ -206,6 +206,10 @@ class VirtRegOrUnit {
   constexpr bool operator==(const VirtRegOrUnit &Other) const {
     return VRegOrUnit == Other.VRegOrUnit;
   }
+
+  constexpr bool operator<(const VirtRegOrUnit &Other) const {
+    return VRegOrUnit < Other.VRegOrUnit;
+  }
 };
 
 } // namespace llvm

diff  --git a/llvm/include/llvm/CodeGen/RegisterPressure.h b/llvm/include/llvm/CodeGen/RegisterPressure.h
index 261e5b0d73281..20a7e4fa2e9de 100644
--- a/llvm/include/llvm/CodeGen/RegisterPressure.h
+++ b/llvm/include/llvm/CodeGen/RegisterPressure.h
@@ -37,11 +37,11 @@ class MachineRegisterInfo;
 class RegisterClassInfo;
 
 struct VRegMaskOrUnit {
-  Register RegUnit; ///< Virtual register or register unit.
+  VirtRegOrUnit VRegOrUnit;
   LaneBitmask LaneMask;
 
-  VRegMaskOrUnit(Register RegUnit, LaneBitmask LaneMask)
-      : RegUnit(RegUnit), LaneMask(LaneMask) {}
+  VRegMaskOrUnit(VirtRegOrUnit VRegOrUnit, LaneBitmask LaneMask)
+      : VRegOrUnit(VRegOrUnit), LaneMask(LaneMask) {}
 };
 
 /// Base class for register pressure results.
@@ -157,7 +157,7 @@ class PressureDiff {
   const_iterator begin() const { return &PressureChanges[0]; }
   const_iterator end() const { return &PressureChanges[MaxPSets]; }
 
-  LLVM_ABI void addPressureChange(Register RegUnit, bool IsDec,
+  LLVM_ABI void addPressureChange(VirtRegOrUnit VRegOrUnit, bool IsDec,
                                   const MachineRegisterInfo *MRI);
 
   LLVM_ABI void dump(const TargetRegisterInfo &TRI) const;
@@ -279,25 +279,25 @@ class LiveRegSet {
   RegSet Regs;
   unsigned NumRegUnits = 0u;
 
-  unsigned getSparseIndexFromReg(Register Reg) const {
-    if (Reg.isVirtual())
-      return Reg.virtRegIndex() + NumRegUnits;
-    assert(Reg < NumRegUnits);
-    return Reg.id();
+  unsigned getSparseIndexFromVirtRegOrUnit(VirtRegOrUnit VRegOrUnit) const {
+    if (VRegOrUnit.isVirtualReg())
+      return VRegOrUnit.asVirtualReg().virtRegIndex() + NumRegUnits;
+    assert(VRegOrUnit.asMCRegUnit() < NumRegUnits);
+    return VRegOrUnit.asMCRegUnit();
   }
 
-  Register getRegFromSparseIndex(unsigned SparseIndex) const {
+  VirtRegOrUnit getVirtRegOrUnitFromSparseIndex(unsigned SparseIndex) const {
     if (SparseIndex >= NumRegUnits)
-      return Register::index2VirtReg(SparseIndex - NumRegUnits);
-    return Register(SparseIndex);
+      return VirtRegOrUnit(Register::index2VirtReg(SparseIndex - NumRegUnits));
+    return VirtRegOrUnit(SparseIndex);
   }
 
 public:
   LLVM_ABI void clear();
   LLVM_ABI void init(const MachineRegisterInfo &MRI);
 
-  LaneBitmask contains(Register Reg) const {
-    unsigned SparseIndex = getSparseIndexFromReg(Reg);
+  LaneBitmask contains(VirtRegOrUnit VRegOrUnit) const {
+    unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(VRegOrUnit);
     RegSet::const_iterator I = Regs.find(SparseIndex);
     if (I == Regs.end())
       return LaneBitmask::getNone();
@@ -307,7 +307,7 @@ class LiveRegSet {
   /// Mark the \p Pair.LaneMask lanes of \p Pair.Reg as live.
   /// Returns the previously live lanes of \p Pair.Reg.
   LaneBitmask insert(VRegMaskOrUnit Pair) {
-    unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
+    unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
     auto InsertRes = Regs.insert(IndexMaskPair(SparseIndex, Pair.LaneMask));
     if (!InsertRes.second) {
       LaneBitmask PrevMask = InsertRes.first->LaneMask;
@@ -320,7 +320,7 @@ class LiveRegSet {
   /// Clears the \p Pair.LaneMask lanes of \p Pair.Reg (mark them as dead).
   /// Returns the previously live lanes of \p Pair.Reg.
   LaneBitmask erase(VRegMaskOrUnit Pair) {
-    unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
+    unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
     RegSet::iterator I = Regs.find(SparseIndex);
     if (I == Regs.end())
       return LaneBitmask::getNone();
@@ -335,9 +335,9 @@ class LiveRegSet {
 
   void appendTo(SmallVectorImpl<VRegMaskOrUnit> &To) const {
     for (const IndexMaskPair &P : Regs) {
-      Register Reg = getRegFromSparseIndex(P.Index);
+      VirtRegOrUnit VRegOrUnit = getVirtRegOrUnitFromSparseIndex(P.Index);
       if (P.LaneMask.any())
-        To.emplace_back(Reg, P.LaneMask);
+        To.emplace_back(VRegOrUnit, P.LaneMask);
     }
   }
 };
@@ -541,9 +541,11 @@ class RegPressureTracker {
 
   LLVM_ABI void dump() const;
 
-  LLVM_ABI void increaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
+  LLVM_ABI void increaseRegPressure(VirtRegOrUnit VRegOrUnit,
+                                    LaneBitmask PreviousMask,
                                     LaneBitmask NewMask);
-  LLVM_ABI void decreaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
+  LLVM_ABI void decreaseRegPressure(VirtRegOrUnit VRegOrUnit,
+                                    LaneBitmask PreviousMask,
                                     LaneBitmask NewMask);
 
 protected:
@@ -565,9 +567,12 @@ class RegPressureTracker {
   discoverLiveInOrOut(VRegMaskOrUnit Pair,
                       SmallVectorImpl<VRegMaskOrUnit> &LiveInOrOut);
 
-  LLVM_ABI LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const;
-  LLVM_ABI LaneBitmask getLiveLanesAt(Register RegUnit, SlotIndex Pos) const;
-  LLVM_ABI LaneBitmask getLiveThroughAt(Register RegUnit, SlotIndex Pos) const;
+  LLVM_ABI LaneBitmask getLastUsedLanes(VirtRegOrUnit VRegOrUnit,
+                                        SlotIndex Pos) const;
+  LLVM_ABI LaneBitmask getLiveLanesAt(VirtRegOrUnit VRegOrUnit,
+                                      SlotIndex Pos) const;
+  LLVM_ABI LaneBitmask getLiveThroughAt(VirtRegOrUnit VRegOrUnit,
+                                        SlotIndex Pos) const;
 };
 
 LLVM_ABI void dumpRegSetPressure(ArrayRef<unsigned> SetPressure,

diff  --git a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
index dabf0dc5ec173..35b14e8b8fd30 100644
--- a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
@@ -1450,7 +1450,7 @@ LLVM_ABI Printable printRegUnit(MCRegUnit Unit, const TargetRegisterInfo *TRI);
 
 /// Create Printable object to print virtual registers and physical
 /// registers on a \ref raw_ostream.
-LLVM_ABI Printable printVRegOrUnit(unsigned VRegOrUnit,
+LLVM_ABI Printable printVRegOrUnit(VirtRegOrUnit VRegOrUnit,
                                    const TargetRegisterInfo *TRI);
 
 /// Create Printable object to print register classes or register banks

diff  --git a/llvm/lib/CodeGen/MachinePipeliner.cpp b/llvm/lib/CodeGen/MachinePipeliner.cpp
index a717d9e4a618d..e2f7dfc5cadd5 100644
--- a/llvm/lib/CodeGen/MachinePipeliner.cpp
+++ b/llvm/lib/CodeGen/MachinePipeliner.cpp
@@ -1509,7 +1509,11 @@ class HighRegisterPressureDetector {
 
   void dumpPSet(Register Reg) const {
     dbgs() << "Reg=" << printReg(Reg, TRI, 0, &MRI) << " PSet=";
-    for (auto PSetIter = MRI.getPressureSets(Reg); PSetIter.isValid();
+    // FIXME: The static_cast is a bug compensating bugs in the callers.
+    VirtRegOrUnit VRegOrUnit =
+        Reg.isVirtual() ? VirtRegOrUnit(Reg)
+                        : VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
+    for (auto PSetIter = MRI.getPressureSets(VRegOrUnit); PSetIter.isValid();
          ++PSetIter) {
       dbgs() << *PSetIter << ' ';
     }
@@ -1518,7 +1522,11 @@ class HighRegisterPressureDetector {
 
   void increaseRegisterPressure(std::vector<unsigned> &Pressure,
                                 Register Reg) const {
-    auto PSetIter = MRI.getPressureSets(Reg);
+    // FIXME: The static_cast is a bug compensating bugs in the callers.
+    VirtRegOrUnit VRegOrUnit =
+        Reg.isVirtual() ? VirtRegOrUnit(Reg)
+                        : VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
+    auto PSetIter = MRI.getPressureSets(VRegOrUnit);
     unsigned Weight = PSetIter.getWeight();
     for (; PSetIter.isValid(); ++PSetIter)
       Pressure[*PSetIter] += Weight;
@@ -1526,7 +1534,7 @@ class HighRegisterPressureDetector {
 
   void decreaseRegisterPressure(std::vector<unsigned> &Pressure,
                                 Register Reg) const {
-    auto PSetIter = MRI.getPressureSets(Reg);
+    auto PSetIter = MRI.getPressureSets(VirtRegOrUnit(Reg));
     unsigned Weight = PSetIter.getWeight();
     for (; PSetIter.isValid(); ++PSetIter) {
       auto &P = Pressure[*PSetIter];
@@ -1559,7 +1567,11 @@ class HighRegisterPressureDetector {
       if (MI.isDebugInstr())
         continue;
       for (auto &Use : ROMap[&MI].Uses) {
-        auto Reg = Use.RegUnit;
+        // FIXME: The static_cast is a bug.
+        Register Reg =
+            Use.VRegOrUnit.isVirtualReg()
+                ? Use.VRegOrUnit.asVirtualReg()
+                : Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
         // Ignore the variable that appears only on one side of phi instruction
         // because it's used only at the first iteration.
         if (MI.isPHI() && Reg != getLoopPhiReg(MI, OrigMBB))
@@ -1609,8 +1621,14 @@ class HighRegisterPressureDetector {
         Register Reg = getLoopPhiReg(*MI, OrigMBB);
         UpdateTargetRegs(Reg);
       } else {
-        for (auto &Use : ROMap.find(MI)->getSecond().Uses)
-          UpdateTargetRegs(Use.RegUnit);
+        for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
+          // FIXME: The static_cast is a bug.
+          Register Reg = Use.VRegOrUnit.isVirtualReg()
+                             ? Use.VRegOrUnit.asVirtualReg()
+                             : Register(static_cast<unsigned>(
+                                   Use.VRegOrUnit.asMCRegUnit()));
+          UpdateTargetRegs(Reg);
+        }
       }
     }
 
@@ -1621,7 +1639,11 @@ class HighRegisterPressureDetector {
     DenseMap<Register, MachineInstr *> LastUseMI;
     for (MachineInstr *MI : llvm::reverse(OrderedInsts)) {
       for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
-        auto Reg = Use.RegUnit;
+        // FIXME: The static_cast is a bug.
+        Register Reg =
+            Use.VRegOrUnit.isVirtualReg()
+                ? Use.VRegOrUnit.asVirtualReg()
+                : Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
         if (!TargetRegs.contains(Reg))
           continue;
         auto [Ite, Inserted] = LastUseMI.try_emplace(Reg, MI);
@@ -1635,8 +1657,8 @@ class HighRegisterPressureDetector {
     }
 
     Instr2LastUsesTy LastUses;
-    for (auto &Entry : LastUseMI)
-      LastUses[Entry.second].insert(Entry.first);
+    for (auto [Reg, MI] : LastUseMI)
+      LastUses[MI].insert(Reg);
     return LastUses;
   }
 
@@ -1675,7 +1697,12 @@ class HighRegisterPressureDetector {
     });
 
     const auto InsertReg = [this, &CurSetPressure](RegSetTy &RegSet,
-                                                   Register Reg) {
+                                                   VirtRegOrUnit VRegOrUnit) {
+      // FIXME: The static_cast is a bug.
+      Register Reg =
+          VRegOrUnit.isVirtualReg()
+              ? VRegOrUnit.asVirtualReg()
+              : Register(static_cast<unsigned>(VRegOrUnit.asMCRegUnit()));
       if (!Reg.isValid() || isReservedRegister(Reg))
         return;
 
@@ -1712,7 +1739,7 @@ class HighRegisterPressureDetector {
         const unsigned Iter = I - Stage;
 
         for (auto &Def : ROMap.find(MI)->getSecond().Defs)
-          InsertReg(LiveRegSets[Iter], Def.RegUnit);
+          InsertReg(LiveRegSets[Iter], Def.VRegOrUnit);
 
         for (auto LastUse : LastUses[MI]) {
           if (MI->isPHI()) {
@@ -2235,7 +2262,7 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
   const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
   MachineRegisterInfo &MRI = MF.getRegInfo();
   SmallVector<VRegMaskOrUnit, 8> LiveOutRegs;
-  SmallSet<Register, 4> Uses;
+  SmallSet<VirtRegOrUnit, 4> Uses;
   for (SUnit *SU : NS) {
     const MachineInstr *MI = SU->getInstr();
     if (MI->isPHI())
@@ -2243,9 +2270,10 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
     for (const MachineOperand &MO : MI->all_uses()) {
       Register Reg = MO.getReg();
       if (Reg.isVirtual())
-        Uses.insert(Reg);
+        Uses.insert(VirtRegOrUnit(Reg));
       else if (MRI.isAllocatable(Reg))
-        Uses.insert_range(TRI->regunits(Reg.asMCReg()));
+        for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
+          Uses.insert(VirtRegOrUnit(Unit));
     }
   }
   for (SUnit *SU : NS)
@@ -2253,12 +2281,14 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
       if (!MO.isDead()) {
         Register Reg = MO.getReg();
         if (Reg.isVirtual()) {
-          if (!Uses.count(Reg))
-            LiveOutRegs.emplace_back(Reg, LaneBitmask::getNone());
+          if (!Uses.count(VirtRegOrUnit(Reg)))
+            LiveOutRegs.emplace_back(VirtRegOrUnit(Reg),
+                                     LaneBitmask::getNone());
         } else if (MRI.isAllocatable(Reg)) {
           for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
-            if (!Uses.count(Unit))
-              LiveOutRegs.emplace_back(Unit, LaneBitmask::getNone());
+            if (!Uses.count(VirtRegOrUnit(Unit)))
+              LiveOutRegs.emplace_back(VirtRegOrUnit(Unit),
+                                       LaneBitmask::getNone());
         }
       }
   RPTracker.addLiveRegs(LiveOutRegs);

diff  --git a/llvm/lib/CodeGen/MachineScheduler.cpp b/llvm/lib/CodeGen/MachineScheduler.cpp
index 73993705c4a7b..de29a9fab876e 100644
--- a/llvm/lib/CodeGen/MachineScheduler.cpp
+++ b/llvm/lib/CodeGen/MachineScheduler.cpp
@@ -1580,10 +1580,10 @@ updateScheduledPressure(const SUnit *SU,
 /// instruction.
 void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
   for (const VRegMaskOrUnit &P : LiveUses) {
-    Register Reg = P.RegUnit;
     /// FIXME: Currently assuming single-use physregs.
-    if (!Reg.isVirtual())
+    if (!P.VRegOrUnit.isVirtualReg())
       continue;
+    Register Reg = P.VRegOrUnit.asVirtualReg();
 
     if (ShouldTrackLaneMasks) {
       // If the register has just become live then other uses won't change
@@ -1599,7 +1599,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
           continue;
 
         PressureDiff &PDiff = getPressureDiff(&SU);
-        PDiff.addPressureChange(Reg, Decrement, &MRI);
+        PDiff.addPressureChange(VirtRegOrUnit(Reg), Decrement, &MRI);
         if (llvm::any_of(PDiff, [](const PressureChange &Change) {
               return Change.isValid();
             }))
@@ -1611,7 +1611,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
       }
     } else {
       assert(P.LaneMask.any());
-      LLVM_DEBUG(dbgs() << "  LiveReg: " << printVRegOrUnit(Reg, TRI) << "\n");
+      LLVM_DEBUG(dbgs() << "  LiveReg: " << printReg(Reg, TRI) << "\n");
       // This may be called before CurrentBottom has been initialized. However,
       // BotRPTracker must have a valid position. We want the value live into the
       // instruction or live out of the block, so ask for the previous
@@ -1638,7 +1638,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
               LI.Query(LIS->getInstructionIndex(*SU->getInstr()));
           if (LRQ.valueIn() == VNI) {
             PressureDiff &PDiff = getPressureDiff(SU);
-            PDiff.addPressureChange(Reg, true, &MRI);
+            PDiff.addPressureChange(VirtRegOrUnit(Reg), true, &MRI);
             if (llvm::any_of(PDiff, [](const PressureChange &Change) {
                   return Change.isValid();
                 }))
@@ -1814,9 +1814,9 @@ unsigned ScheduleDAGMILive::computeCyclicCriticalPath() {
   unsigned MaxCyclicLatency = 0;
   // Visit each live out vreg def to find def/use pairs that cross iterations.
   for (const VRegMaskOrUnit &P : RPTracker.getPressure().LiveOutRegs) {
-    Register Reg = P.RegUnit;
-    if (!Reg.isVirtual())
+    if (!P.VRegOrUnit.isVirtualReg())
       continue;
+    Register Reg = P.VRegOrUnit.asVirtualReg();
     const LiveInterval &LI = LIS->getInterval(Reg);
     const VNInfo *DefVNI = LI.getVNInfoBefore(LIS->getMBBEndIdx(BB));
     if (!DefVNI)

diff  --git a/llvm/lib/CodeGen/RegisterPressure.cpp b/llvm/lib/CodeGen/RegisterPressure.cpp
index 7d4674b3f74f0..cd431bc7a171c 100644
--- a/llvm/lib/CodeGen/RegisterPressure.cpp
+++ b/llvm/lib/CodeGen/RegisterPressure.cpp
@@ -47,13 +47,14 @@ using namespace llvm;
 
 /// Increase pressure for each pressure set provided by TargetRegisterInfo.
 static void increaseSetPressure(std::vector<unsigned> &CurrSetPressure,
-                                const MachineRegisterInfo &MRI, unsigned Reg,
-                                LaneBitmask PrevMask, LaneBitmask NewMask) {
+                                const MachineRegisterInfo &MRI,
+                                VirtRegOrUnit VRegOrUnit, LaneBitmask PrevMask,
+                                LaneBitmask NewMask) {
   assert((PrevMask & ~NewMask).none() && "Must not remove bits");
   if (PrevMask.any() || NewMask.none())
     return;
 
-  PSetIterator PSetI = MRI.getPressureSets(Reg);
+  PSetIterator PSetI = MRI.getPressureSets(VRegOrUnit);
   unsigned Weight = PSetI.getWeight();
   for (; PSetI.isValid(); ++PSetI)
     CurrSetPressure[*PSetI] += Weight;
@@ -61,13 +62,14 @@ static void increaseSetPressure(std::vector<unsigned> &CurrSetPressure,
 
 /// Decrease pressure for each pressure set provided by TargetRegisterInfo.
 static void decreaseSetPressure(std::vector<unsigned> &CurrSetPressure,
-                                const MachineRegisterInfo &MRI, Register Reg,
-                                LaneBitmask PrevMask, LaneBitmask NewMask) {
+                                const MachineRegisterInfo &MRI,
+                                VirtRegOrUnit VRegOrUnit, LaneBitmask PrevMask,
+                                LaneBitmask NewMask) {
   assert((NewMask & ~PrevMask).none() && "Must not add bits");
   if (NewMask.any() || PrevMask.none())
     return;
 
-  PSetIterator PSetI = MRI.getPressureSets(Reg);
+  PSetIterator PSetI = MRI.getPressureSets(VRegOrUnit);
   unsigned Weight = PSetI.getWeight();
   for (; PSetI.isValid(); ++PSetI) {
     assert(CurrSetPressure[*PSetI] >= Weight && "register pressure underflow");
@@ -93,7 +95,7 @@ void RegisterPressure::dump(const TargetRegisterInfo *TRI) const {
   dumpRegSetPressure(MaxSetPressure, TRI);
   dbgs() << "Live In: ";
   for (const VRegMaskOrUnit &P : LiveInRegs) {
-    dbgs() << printVRegOrUnit(P.RegUnit, TRI);
+    dbgs() << printVRegOrUnit(P.VRegOrUnit, TRI);
     if (!P.LaneMask.all())
       dbgs() << ':' << PrintLaneMask(P.LaneMask);
     dbgs() << ' ';
@@ -101,7 +103,7 @@ void RegisterPressure::dump(const TargetRegisterInfo *TRI) const {
   dbgs() << '\n';
   dbgs() << "Live Out: ";
   for (const VRegMaskOrUnit &P : LiveOutRegs) {
-    dbgs() << printVRegOrUnit(P.RegUnit, TRI);
+    dbgs() << printVRegOrUnit(P.VRegOrUnit, TRI);
     if (!P.LaneMask.all())
       dbgs() << ':' << PrintLaneMask(P.LaneMask);
     dbgs() << ' ';
@@ -148,13 +150,13 @@ void RegPressureDelta::dump() const {
 
 #endif
 
-void RegPressureTracker::increaseRegPressure(Register RegUnit,
+void RegPressureTracker::increaseRegPressure(VirtRegOrUnit VRegOrUnit,
                                              LaneBitmask PreviousMask,
                                              LaneBitmask NewMask) {
   if (PreviousMask.any() || NewMask.none())
     return;
 
-  PSetIterator PSetI = MRI->getPressureSets(RegUnit);
+  PSetIterator PSetI = MRI->getPressureSets(VRegOrUnit);
   unsigned Weight = PSetI.getWeight();
   for (; PSetI.isValid(); ++PSetI) {
     CurrSetPressure[*PSetI] += Weight;
@@ -163,10 +165,10 @@ void RegPressureTracker::increaseRegPressure(Register RegUnit,
   }
 }
 
-void RegPressureTracker::decreaseRegPressure(Register RegUnit,
+void RegPressureTracker::decreaseRegPressure(VirtRegOrUnit VRegOrUnit,
                                              LaneBitmask PreviousMask,
                                              LaneBitmask NewMask) {
-  decreaseSetPressure(CurrSetPressure, *MRI, RegUnit, PreviousMask, NewMask);
+  decreaseSetPressure(CurrSetPressure, *MRI, VRegOrUnit, PreviousMask, NewMask);
 }
 
 /// Clear the result so it can be used for another round of pressure tracking.
@@ -230,10 +232,11 @@ void LiveRegSet::clear() {
   Regs.clear();
 }
 
-static const LiveRange *getLiveRange(const LiveIntervals &LIS, unsigned Reg) {
-  if (Register::isVirtualRegister(Reg))
-    return &LIS.getInterval(Reg);
-  return LIS.getCachedRegUnit(Reg);
+static const LiveRange *getLiveRange(const LiveIntervals &LIS,
+                                     VirtRegOrUnit VRegOrUnit) {
+  if (VRegOrUnit.isVirtualReg())
+    return &LIS.getInterval(VRegOrUnit.asVirtualReg());
+  return LIS.getCachedRegUnit(VRegOrUnit.asMCRegUnit());
 }
 
 void RegPressureTracker::reset() {
@@ -356,17 +359,18 @@ void RegPressureTracker::initLiveThru(const RegPressureTracker &RPTracker) {
   LiveThruPressure.assign(TRI->getNumRegPressureSets(), 0);
   assert(isBottomClosed() && "need bottom-up tracking to intialize.");
   for (const VRegMaskOrUnit &Pair : P.LiveOutRegs) {
-    Register RegUnit = Pair.RegUnit;
-    if (RegUnit.isVirtual() && !RPTracker.hasUntiedDef(RegUnit))
-      increaseSetPressure(LiveThruPressure, *MRI, RegUnit,
+    VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
+    if (VRegOrUnit.isVirtualReg() &&
+        !RPTracker.hasUntiedDef(VRegOrUnit.asVirtualReg()))
+      increaseSetPressure(LiveThruPressure, *MRI, VRegOrUnit,
                           LaneBitmask::getNone(), Pair.LaneMask);
   }
 }
 
 static LaneBitmask getRegLanes(ArrayRef<VRegMaskOrUnit> RegUnits,
-                               Register RegUnit) {
-  auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
-    return Other.RegUnit == RegUnit;
+                               VirtRegOrUnit VRegOrUnit) {
+  auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+    return Other.VRegOrUnit == VRegOrUnit;
   });
   if (I == RegUnits.end())
     return LaneBitmask::getNone();
@@ -375,10 +379,10 @@ static LaneBitmask getRegLanes(ArrayRef<VRegMaskOrUnit> RegUnits,
 
 static void addRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
                         VRegMaskOrUnit Pair) {
-  Register RegUnit = Pair.RegUnit;
+  VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
   assert(Pair.LaneMask.any());
-  auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
-    return Other.RegUnit == RegUnit;
+  auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+    return Other.VRegOrUnit == VRegOrUnit;
   });
   if (I == RegUnits.end()) {
     RegUnits.push_back(Pair);
@@ -388,12 +392,12 @@ static void addRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
 }
 
 static void setRegZero(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
-                       Register RegUnit) {
-  auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
-    return Other.RegUnit == RegUnit;
+                       VirtRegOrUnit VRegOrUnit) {
+  auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+    return Other.VRegOrUnit == VRegOrUnit;
   });
   if (I == RegUnits.end()) {
-    RegUnits.emplace_back(RegUnit, LaneBitmask::getNone());
+    RegUnits.emplace_back(VRegOrUnit, LaneBitmask::getNone());
   } else {
     I->LaneMask = LaneBitmask::getNone();
   }
@@ -401,10 +405,10 @@ static void setRegZero(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
 
 static void removeRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
                            VRegMaskOrUnit Pair) {
-  Register RegUnit = Pair.RegUnit;
+  VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
   assert(Pair.LaneMask.any());
-  auto I = llvm::find_if(RegUnits, [RegUnit](const VRegMaskOrUnit Other) {
-    return Other.RegUnit == RegUnit;
+  auto I = llvm::find_if(RegUnits, [VRegOrUnit](const VRegMaskOrUnit Other) {
+    return Other.VRegOrUnit == VRegOrUnit;
   });
   if (I != RegUnits.end()) {
     I->LaneMask &= ~Pair.LaneMask;
@@ -415,11 +419,11 @@ static void removeRegLanes(SmallVectorImpl<VRegMaskOrUnit> &RegUnits,
 
 static LaneBitmask
 getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
-                     bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
-                     LaneBitmask SafeDefault,
+                     bool TrackLaneMasks, VirtRegOrUnit VRegOrUnit,
+                     SlotIndex Pos, LaneBitmask SafeDefault,
                      bool (*Property)(const LiveRange &LR, SlotIndex Pos)) {
-  if (RegUnit.isVirtual()) {
-    const LiveInterval &LI = LIS.getInterval(RegUnit);
+  if (VRegOrUnit.isVirtualReg()) {
+    const LiveInterval &LI = LIS.getInterval(VRegOrUnit.asVirtualReg());
     LaneBitmask Result;
     if (TrackLaneMasks && LI.hasSubRanges()) {
         for (const LiveInterval::SubRange &SR : LI.subranges()) {
@@ -427,13 +431,14 @@ getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
             Result |= SR.LaneMask;
         }
     } else if (Property(LI, Pos)) {
-      Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
-                              : LaneBitmask::getAll();
+      Result = TrackLaneMasks
+                   ? MRI.getMaxLaneMaskForVReg(VRegOrUnit.asVirtualReg())
+                   : LaneBitmask::getAll();
     }
 
     return Result;
   } else {
-    const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
+    const LiveRange *LR = LIS.getCachedRegUnit(VRegOrUnit.asMCRegUnit());
     // Be prepared for missing liveranges: We usually do not compute liveranges
     // for physical registers on targets with many registers (GPUs).
     if (LR == nullptr)
@@ -444,13 +449,11 @@ getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
 
 static LaneBitmask getLiveLanesAt(const LiveIntervals &LIS,
                                   const MachineRegisterInfo &MRI,
-                                  bool TrackLaneMasks, Register RegUnit,
+                                  bool TrackLaneMasks, VirtRegOrUnit VRegOrUnit,
                                   SlotIndex Pos) {
-  return getLanesWithProperty(LIS, MRI, TrackLaneMasks, RegUnit, Pos,
-                              LaneBitmask::getAll(),
-                              [](const LiveRange &LR, SlotIndex Pos) {
-                                return LR.liveAt(Pos);
-                              });
+  return getLanesWithProperty(
+      LIS, MRI, TrackLaneMasks, VRegOrUnit, Pos, LaneBitmask::getAll(),
+      [](const LiveRange &LR, SlotIndex Pos) { return LR.liveAt(Pos); });
 }
 
 namespace {
@@ -514,10 +517,12 @@ class RegisterOperandsCollector {
 
   void pushReg(Register Reg, SmallVectorImpl<VRegMaskOrUnit> &RegUnits) const {
     if (Reg.isVirtual()) {
-      addRegLanes(RegUnits, VRegMaskOrUnit(Reg, LaneBitmask::getAll()));
+      addRegLanes(RegUnits,
+                  VRegMaskOrUnit(VirtRegOrUnit(Reg), LaneBitmask::getAll()));
     } else if (MRI.isAllocatable(Reg)) {
       for (MCRegUnit Unit : TRI.regunits(Reg.asMCReg()))
-        addRegLanes(RegUnits, VRegMaskOrUnit(Unit, LaneBitmask::getAll()));
+        addRegLanes(RegUnits,
+                    VRegMaskOrUnit(VirtRegOrUnit(Unit), LaneBitmask::getAll()));
     }
   }
 
@@ -549,10 +554,11 @@ class RegisterOperandsCollector {
       LaneBitmask LaneMask = SubRegIdx != 0
                              ? TRI.getSubRegIndexLaneMask(SubRegIdx)
                              : MRI.getMaxLaneMaskForVReg(Reg);
-      addRegLanes(RegUnits, VRegMaskOrUnit(Reg, LaneMask));
+      addRegLanes(RegUnits, VRegMaskOrUnit(VirtRegOrUnit(Reg), LaneMask));
     } else if (MRI.isAllocatable(Reg)) {
       for (MCRegUnit Unit : TRI.regunits(Reg.asMCReg()))
-        addRegLanes(RegUnits, VRegMaskOrUnit(Unit, LaneBitmask::getAll()));
+        addRegLanes(RegUnits,
+                    VRegMaskOrUnit(VirtRegOrUnit(Unit), LaneBitmask::getAll()));
     }
   }
 };
@@ -574,8 +580,7 @@ void RegisterOperands::detectDeadDefs(const MachineInstr &MI,
                                       const LiveIntervals &LIS) {
   SlotIndex SlotIdx = LIS.getInstructionIndex(MI);
   for (auto *RI = Defs.begin(); RI != Defs.end(); /*empty*/) {
-    Register Reg = RI->RegUnit;
-    const LiveRange *LR = getLiveRange(LIS, Reg);
+    const LiveRange *LR = getLiveRange(LIS, RI->VRegOrUnit);
     if (LR != nullptr) {
       LiveQueryResult LRQ = LR->Query(SlotIdx);
       if (LRQ.isDeadDef()) {
@@ -595,14 +600,14 @@ void RegisterOperands::adjustLaneLiveness(const LiveIntervals &LIS,
                                           SlotIndex Pos,
                                           MachineInstr *AddFlagsMI) {
   for (auto *I = Defs.begin(); I != Defs.end();) {
-    LaneBitmask LiveAfter = getLiveLanesAt(LIS, MRI, true, I->RegUnit,
-                                           Pos.getDeadSlot());
+    LaneBitmask LiveAfter =
+        getLiveLanesAt(LIS, MRI, true, I->VRegOrUnit, Pos.getDeadSlot());
     // If the def is all that is live after the instruction, then in case
     // of a subregister def we need a read-undef flag.
-    Register RegUnit = I->RegUnit;
-    if (RegUnit.isVirtual() && AddFlagsMI != nullptr &&
+    VirtRegOrUnit VRegOrUnit = I->VRegOrUnit;
+    if (VRegOrUnit.isVirtualReg() && AddFlagsMI != nullptr &&
         (LiveAfter & ~I->LaneMask).none())
-      AddFlagsMI->setRegisterDefReadUndef(RegUnit);
+      AddFlagsMI->setRegisterDefReadUndef(VRegOrUnit.asVirtualReg());
 
     LaneBitmask ActualDef = I->LaneMask & LiveAfter;
     if (ActualDef.none()) {
@@ -614,18 +619,18 @@ void RegisterOperands::adjustLaneLiveness(const LiveIntervals &LIS,
   }
 
   // For uses just copy the information from LIS.
-  for (auto &[RegUnit, LaneMask] : Uses)
-    LaneMask = getLiveLanesAt(LIS, MRI, true, RegUnit, Pos.getBaseIndex());
+  for (auto &[VRegOrUnit, LaneMask] : Uses)
+    LaneMask = getLiveLanesAt(LIS, MRI, true, VRegOrUnit, Pos.getBaseIndex());
 
   if (AddFlagsMI != nullptr) {
     for (const VRegMaskOrUnit &P : DeadDefs) {
-      Register RegUnit = P.RegUnit;
-      if (!RegUnit.isVirtual())
+      VirtRegOrUnit VRegOrUnit = P.VRegOrUnit;
+      if (!VRegOrUnit.isVirtualReg())
         continue;
-      LaneBitmask LiveAfter = getLiveLanesAt(LIS, MRI, true, RegUnit,
-                                             Pos.getDeadSlot());
+      LaneBitmask LiveAfter =
+          getLiveLanesAt(LIS, MRI, true, VRegOrUnit, Pos.getDeadSlot());
       if (LiveAfter.none())
-        AddFlagsMI->setRegisterDefReadUndef(RegUnit);
+        AddFlagsMI->setRegisterDefReadUndef(VRegOrUnit.asVirtualReg());
     }
   }
 }
@@ -648,16 +653,16 @@ void PressureDiffs::addInstruction(unsigned Idx,
   PressureDiff &PDiff = (*this)[Idx];
   assert(!PDiff.begin()->isValid() && "stale PDiff");
   for (const VRegMaskOrUnit &P : RegOpers.Defs)
-    PDiff.addPressureChange(P.RegUnit, true, &MRI);
+    PDiff.addPressureChange(P.VRegOrUnit, true, &MRI);
 
   for (const VRegMaskOrUnit &P : RegOpers.Uses)
-    PDiff.addPressureChange(P.RegUnit, false, &MRI);
+    PDiff.addPressureChange(P.VRegOrUnit, false, &MRI);
 }
 
 /// Add a change in pressure to the pressure diff of a given instruction.
-void PressureDiff::addPressureChange(Register RegUnit, bool IsDec,
+void PressureDiff::addPressureChange(VirtRegOrUnit VRegOrUnit, bool IsDec,
                                      const MachineRegisterInfo *MRI) {
-  PSetIterator PSetI = MRI->getPressureSets(RegUnit);
+  PSetIterator PSetI = MRI->getPressureSets(VRegOrUnit);
   int Weight = IsDec ? -PSetI.getWeight() : PSetI.getWeight();
   for (; PSetI.isValid(); ++PSetI) {
     // Find an existing entry in the pressure diff for this PSet.
@@ -694,7 +699,7 @@ void RegPressureTracker::addLiveRegs(ArrayRef<VRegMaskOrUnit> Regs) {
   for (const VRegMaskOrUnit &P : Regs) {
     LaneBitmask PrevMask = LiveRegs.insert(P);
     LaneBitmask NewMask = PrevMask | P.LaneMask;
-    increaseRegPressure(P.RegUnit, PrevMask, NewMask);
+    increaseRegPressure(P.VRegOrUnit, PrevMask, NewMask);
   }
 }
 
@@ -702,9 +707,9 @@ void RegPressureTracker::discoverLiveInOrOut(
     VRegMaskOrUnit Pair, SmallVectorImpl<VRegMaskOrUnit> &LiveInOrOut) {
   assert(Pair.LaneMask.any());
 
-  Register RegUnit = Pair.RegUnit;
-  auto I = llvm::find_if(LiveInOrOut, [RegUnit](const VRegMaskOrUnit &Other) {
-    return Other.RegUnit == RegUnit;
+  VirtRegOrUnit VRegOrUnit = Pair.VRegOrUnit;
+  auto I = find_if(LiveInOrOut, [VRegOrUnit](const VRegMaskOrUnit &Other) {
+    return Other.VRegOrUnit == VRegOrUnit;
   });
   LaneBitmask PrevMask;
   LaneBitmask NewMask;
@@ -717,7 +722,7 @@ void RegPressureTracker::discoverLiveInOrOut(
     NewMask = PrevMask | Pair.LaneMask;
     I->LaneMask = NewMask;
   }
-  increaseSetPressure(P.MaxSetPressure, *MRI, RegUnit, PrevMask, NewMask);
+  increaseSetPressure(P.MaxSetPressure, *MRI, VRegOrUnit, PrevMask, NewMask);
 }
 
 void RegPressureTracker::discoverLiveIn(VRegMaskOrUnit Pair) {
@@ -730,16 +735,14 @@ void RegPressureTracker::discoverLiveOut(VRegMaskOrUnit Pair) {
 
 void RegPressureTracker::bumpDeadDefs(ArrayRef<VRegMaskOrUnit> DeadDefs) {
   for (const VRegMaskOrUnit &P : DeadDefs) {
-    Register Reg = P.RegUnit;
-    LaneBitmask LiveMask = LiveRegs.contains(Reg);
+    LaneBitmask LiveMask = LiveRegs.contains(P.VRegOrUnit);
     LaneBitmask BumpedMask = LiveMask | P.LaneMask;
-    increaseRegPressure(Reg, LiveMask, BumpedMask);
+    increaseRegPressure(P.VRegOrUnit, LiveMask, BumpedMask);
   }
   for (const VRegMaskOrUnit &P : DeadDefs) {
-    Register Reg = P.RegUnit;
-    LaneBitmask LiveMask = LiveRegs.contains(Reg);
+    LaneBitmask LiveMask = LiveRegs.contains(P.VRegOrUnit);
     LaneBitmask BumpedMask = LiveMask | P.LaneMask;
-    decreaseRegPressure(Reg, BumpedMask, LiveMask);
+    decreaseRegPressure(P.VRegOrUnit, BumpedMask, LiveMask);
   }
 }
 
@@ -758,17 +761,17 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
   // Kill liveness at live defs.
   // TODO: consider earlyclobbers?
   for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
-    Register Reg = Def.RegUnit;
+    VirtRegOrUnit VRegOrUnit = Def.VRegOrUnit;
 
     LaneBitmask PreviousMask = LiveRegs.erase(Def);
     LaneBitmask NewMask = PreviousMask & ~Def.LaneMask;
 
     LaneBitmask LiveOut = Def.LaneMask & ~PreviousMask;
     if (LiveOut.any()) {
-      discoverLiveOut(VRegMaskOrUnit(Reg, LiveOut));
+      discoverLiveOut(VRegMaskOrUnit(VRegOrUnit, LiveOut));
       // Retroactively model effects on pressure of the live out lanes.
-      increaseSetPressure(CurrSetPressure, *MRI, Reg, LaneBitmask::getNone(),
-                          LiveOut);
+      increaseSetPressure(CurrSetPressure, *MRI, VRegOrUnit,
+                          LaneBitmask::getNone(), LiveOut);
       PreviousMask = LiveOut;
     }
 
@@ -776,10 +779,10 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
       // Add a 0 entry to LiveUses as a marker that the complete vreg has become
       // dead.
       if (TrackLaneMasks && LiveUses != nullptr)
-        setRegZero(*LiveUses, Reg);
+        setRegZero(*LiveUses, VRegOrUnit);
     }
 
-    decreaseRegPressure(Reg, PreviousMask, NewMask);
+    decreaseRegPressure(VRegOrUnit, PreviousMask, NewMask);
   }
 
   SlotIndex SlotIdx;
@@ -788,7 +791,7 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
 
   // Generate liveness for uses.
   for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
-    Register Reg = Use.RegUnit;
+    VirtRegOrUnit VRegOrUnit = Use.VRegOrUnit;
     assert(Use.LaneMask.any());
     LaneBitmask PreviousMask = LiveRegs.insert(Use);
     LaneBitmask NewMask = PreviousMask | Use.LaneMask;
@@ -799,38 +802,38 @@ void RegPressureTracker::recede(const RegisterOperands &RegOpers,
     if (PreviousMask.none()) {
       if (LiveUses != nullptr) {
         if (!TrackLaneMasks) {
-          addRegLanes(*LiveUses, VRegMaskOrUnit(Reg, NewMask));
+          addRegLanes(*LiveUses, VRegMaskOrUnit(VRegOrUnit, NewMask));
         } else {
-          auto I = llvm::find_if(*LiveUses, [Reg](const VRegMaskOrUnit Other) {
-            return Other.RegUnit == Reg;
+          auto I = find_if(*LiveUses, [VRegOrUnit](const VRegMaskOrUnit Other) {
+            return Other.VRegOrUnit == VRegOrUnit;
           });
           bool IsRedef = I != LiveUses->end();
           if (IsRedef) {
             // ignore re-defs here...
             assert(I->LaneMask.none());
-            removeRegLanes(*LiveUses, VRegMaskOrUnit(Reg, NewMask));
+            removeRegLanes(*LiveUses, VRegMaskOrUnit(VRegOrUnit, NewMask));
           } else {
-            addRegLanes(*LiveUses, VRegMaskOrUnit(Reg, NewMask));
+            addRegLanes(*LiveUses, VRegMaskOrUnit(VRegOrUnit, NewMask));
           }
         }
       }
 
       // Discover live outs if this may be the first occurance of this register.
       if (RequireIntervals) {
-        LaneBitmask LiveOut = getLiveThroughAt(Reg, SlotIdx);
+        LaneBitmask LiveOut = getLiveThroughAt(VRegOrUnit, SlotIdx);
         if (LiveOut.any())
-          discoverLiveOut(VRegMaskOrUnit(Reg, LiveOut));
+          discoverLiveOut(VRegMaskOrUnit(VRegOrUnit, LiveOut));
       }
     }
 
-    increaseRegPressure(Reg, PreviousMask, NewMask);
+    increaseRegPressure(VRegOrUnit, PreviousMask, NewMask);
   }
   if (TrackUntiedDefs) {
     for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
-      Register RegUnit = Def.RegUnit;
-      if (RegUnit.isVirtual() &&
-          (LiveRegs.contains(RegUnit) & Def.LaneMask).none())
-        UntiedDefs.insert(RegUnit);
+      VirtRegOrUnit VRegOrUnit = Def.VRegOrUnit;
+      if (VRegOrUnit.isVirtualReg() &&
+          (LiveRegs.contains(VRegOrUnit) & Def.LaneMask).none())
+        UntiedDefs.insert(VRegOrUnit.asVirtualReg());
     }
   }
 }
@@ -898,20 +901,20 @@ void RegPressureTracker::advance(const RegisterOperands &RegOpers) {
   }
 
   for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
-    Register Reg = Use.RegUnit;
-    LaneBitmask LiveMask = LiveRegs.contains(Reg);
+    VirtRegOrUnit VRegOrUnit = Use.VRegOrUnit;
+    LaneBitmask LiveMask = LiveRegs.contains(VRegOrUnit);
     LaneBitmask LiveIn = Use.LaneMask & ~LiveMask;
     if (LiveIn.any()) {
-      discoverLiveIn(VRegMaskOrUnit(Reg, LiveIn));
-      increaseRegPressure(Reg, LiveMask, LiveMask | LiveIn);
-      LiveRegs.insert(VRegMaskOrUnit(Reg, LiveIn));
+      discoverLiveIn(VRegMaskOrUnit(VRegOrUnit, LiveIn));
+      increaseRegPressure(VRegOrUnit, LiveMask, LiveMask | LiveIn);
+      LiveRegs.insert(VRegMaskOrUnit(VRegOrUnit, LiveIn));
     }
     // Kill liveness at last uses.
     if (RequireIntervals) {
-      LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
+      LaneBitmask LastUseMask = getLastUsedLanes(VRegOrUnit, SlotIdx);
       if (LastUseMask.any()) {
-        LiveRegs.erase(VRegMaskOrUnit(Reg, LastUseMask));
-        decreaseRegPressure(Reg, LiveMask, LiveMask & ~LastUseMask);
+        LiveRegs.erase(VRegMaskOrUnit(VRegOrUnit, LastUseMask));
+        decreaseRegPressure(VRegOrUnit, LiveMask, LiveMask & ~LastUseMask);
       }
     }
   }
@@ -920,7 +923,7 @@ void RegPressureTracker::advance(const RegisterOperands &RegOpers) {
   for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
     LaneBitmask PreviousMask = LiveRegs.insert(Def);
     LaneBitmask NewMask = PreviousMask | Def.LaneMask;
-    increaseRegPressure(Def.RegUnit, PreviousMask, NewMask);
+    increaseRegPressure(Def.VRegOrUnit, PreviousMask, NewMask);
   }
 
   // Boost pressure for all dead defs together.
@@ -1047,22 +1050,20 @@ void RegPressureTracker::bumpUpwardPressure(const MachineInstr *MI) {
 
   // Kill liveness at live defs.
   for (const VRegMaskOrUnit &P : RegOpers.Defs) {
-    Register Reg = P.RegUnit;
-    LaneBitmask LiveAfter = LiveRegs.contains(Reg);
-    LaneBitmask UseLanes = getRegLanes(RegOpers.Uses, Reg);
+    LaneBitmask LiveAfter = LiveRegs.contains(P.VRegOrUnit);
+    LaneBitmask UseLanes = getRegLanes(RegOpers.Uses, P.VRegOrUnit);
     LaneBitmask DefLanes = P.LaneMask;
     LaneBitmask LiveBefore = (LiveAfter & ~DefLanes) | UseLanes;
 
     // There may be parts of the register that were dead before the
     // instruction, but became live afterwards.
-    decreaseRegPressure(Reg, LiveAfter, LiveAfter & LiveBefore);
+    decreaseRegPressure(P.VRegOrUnit, LiveAfter, LiveAfter & LiveBefore);
   }
   // Generate liveness for uses. Also handle any uses which overlap with defs.
   for (const VRegMaskOrUnit &P : RegOpers.Uses) {
-    Register Reg = P.RegUnit;
-    LaneBitmask LiveAfter = LiveRegs.contains(Reg);
+    LaneBitmask LiveAfter = LiveRegs.contains(P.VRegOrUnit);
     LaneBitmask LiveBefore = LiveAfter | P.LaneMask;
-    increaseRegPressure(Reg, LiveAfter, LiveBefore);
+    increaseRegPressure(P.VRegOrUnit, LiveAfter, LiveBefore);
   }
 }
 
@@ -1209,11 +1210,17 @@ getUpwardPressureDelta(const MachineInstr *MI, /*const*/ PressureDiff &PDiff,
 /// Helper to find a vreg use between two indices [PriorUseIdx, NextUseIdx).
 /// The query starts with a lane bitmask which gets lanes/bits removed for every
 /// use we find.
-static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
+static LaneBitmask findUseBetween(VirtRegOrUnit VRegOrUnit,
+                                  LaneBitmask LastUseMask,
                                   SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
                                   const MachineRegisterInfo &MRI,
                                   const LiveIntervals *LIS) {
   const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
+  // FIXME: The static_cast is a bug.
+  Register Reg =
+      VRegOrUnit.isVirtualReg()
+          ? VRegOrUnit.asVirtualReg()
+          : Register(static_cast<unsigned>(VRegOrUnit.asMCRegUnit()));
   for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
     if (MO.isUndef())
       continue;
@@ -1230,32 +1237,30 @@ static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
   return LastUseMask;
 }
 
-LaneBitmask RegPressureTracker::getLiveLanesAt(Register RegUnit,
+LaneBitmask RegPressureTracker::getLiveLanesAt(VirtRegOrUnit VRegOrUnit,
                                                SlotIndex Pos) const {
   assert(RequireIntervals);
-  return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos,
-                              LaneBitmask::getAll(),
-      [](const LiveRange &LR, SlotIndex Pos) {
-        return LR.liveAt(Pos);
-      });
+  return getLanesWithProperty(
+      *LIS, *MRI, TrackLaneMasks, VRegOrUnit, Pos, LaneBitmask::getAll(),
+      [](const LiveRange &LR, SlotIndex Pos) { return LR.liveAt(Pos); });
 }
 
-LaneBitmask RegPressureTracker::getLastUsedLanes(Register RegUnit,
+LaneBitmask RegPressureTracker::getLastUsedLanes(VirtRegOrUnit VRegOrUnit,
                                                  SlotIndex Pos) const {
   assert(RequireIntervals);
-  return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit,
-                              Pos.getBaseIndex(), LaneBitmask::getNone(),
-      [](const LiveRange &LR, SlotIndex Pos) {
+  return getLanesWithProperty(
+      *LIS, *MRI, TrackLaneMasks, VRegOrUnit, Pos.getBaseIndex(),
+      LaneBitmask::getNone(), [](const LiveRange &LR, SlotIndex Pos) {
         const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
         return S != nullptr && S->end == Pos.getRegSlot();
       });
 }
 
-LaneBitmask RegPressureTracker::getLiveThroughAt(Register RegUnit,
+LaneBitmask RegPressureTracker::getLiveThroughAt(VirtRegOrUnit VRegOrUnit,
                                                  SlotIndex Pos) const {
   assert(RequireIntervals);
-  return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos,
-                              LaneBitmask::getNone(),
+  return getLanesWithProperty(
+      *LIS, *MRI, TrackLaneMasks, VRegOrUnit, Pos, LaneBitmask::getNone(),
       [](const LiveRange &LR, SlotIndex Pos) {
         const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
         return S != nullptr && S->start < Pos.getRegSlot(true) &&
@@ -1284,8 +1289,8 @@ void RegPressureTracker::bumpDownwardPressure(const MachineInstr *MI) {
 
   if (RequireIntervals) {
     for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
-      Register Reg = Use.RegUnit;
-      LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
+      VirtRegOrUnit VRegOrUnit = Use.VRegOrUnit;
+      LaneBitmask LastUseMask = getLastUsedLanes(VRegOrUnit, SlotIdx);
       if (LastUseMask.none())
         continue;
       // The LastUseMask is queried from the liveness information of instruction
@@ -1294,23 +1299,22 @@ void RegPressureTracker::bumpDownwardPressure(const MachineInstr *MI) {
       // FIXME: allow the caller to pass in the list of vreg uses that remain
       // to be bottom-scheduled to avoid searching uses at each query.
       SlotIndex CurrIdx = getCurrSlot();
-      LastUseMask
-        = findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, LIS);
+      LastUseMask =
+          findUseBetween(VRegOrUnit, LastUseMask, CurrIdx, SlotIdx, *MRI, LIS);
       if (LastUseMask.none())
         continue;
 
-      LaneBitmask LiveMask = LiveRegs.contains(Reg);
+      LaneBitmask LiveMask = LiveRegs.contains(VRegOrUnit);
       LaneBitmask NewMask = LiveMask & ~LastUseMask;
-      decreaseRegPressure(Reg, LiveMask, NewMask);
+      decreaseRegPressure(VRegOrUnit, LiveMask, NewMask);
     }
   }
 
   // Generate liveness for defs.
   for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
-    Register Reg = Def.RegUnit;
-    LaneBitmask LiveMask = LiveRegs.contains(Reg);
+    LaneBitmask LiveMask = LiveRegs.contains(Def.VRegOrUnit);
     LaneBitmask NewMask = LiveMask | Def.LaneMask;
-    increaseRegPressure(Reg, LiveMask, NewMask);
+    increaseRegPressure(Def.VRegOrUnit, LiveMask, NewMask);
   }
 
   // Boost pressure for all dead defs together.

diff  --git a/llvm/lib/CodeGen/TargetRegisterInfo.cpp b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
index a5c81afc57a80..975895809b9de 100644
--- a/llvm/lib/CodeGen/TargetRegisterInfo.cpp
+++ b/llvm/lib/CodeGen/TargetRegisterInfo.cpp
@@ -156,12 +156,13 @@ Printable llvm::printRegUnit(MCRegUnit Unit, const TargetRegisterInfo *TRI) {
   });
 }
 
-Printable llvm::printVRegOrUnit(unsigned Unit, const TargetRegisterInfo *TRI) {
-  return Printable([Unit, TRI](raw_ostream &OS) {
-    if (Register::isVirtualRegister(Unit)) {
-      OS << '%' << Register(Unit).virtRegIndex();
+Printable llvm::printVRegOrUnit(VirtRegOrUnit VRegOrUnit,
+                                const TargetRegisterInfo *TRI) {
+  return Printable([VRegOrUnit, TRI](raw_ostream &OS) {
+    if (VRegOrUnit.isVirtualReg()) {
+      OS << '%' << VRegOrUnit.asVirtualReg().virtRegIndex();
     } else {
-      OS << printRegUnit(Unit, TRI);
+      OS << printRegUnit(VRegOrUnit.asMCRegUnit(), TRI);
     }
   });
 }

diff  --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
index 4e11c4ff3d56e..0c5e3d0837800 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
@@ -282,11 +282,12 @@ collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
 
     Register Reg = MO.getReg();
     auto I = llvm::find_if(VRegMaskOrUnits, [Reg](const VRegMaskOrUnit &RM) {
-      return RM.RegUnit == Reg;
+      return RM.VRegOrUnit.asVirtualReg() == Reg;
     });
 
     auto &P = I == VRegMaskOrUnits.end()
-                  ? VRegMaskOrUnits.emplace_back(Reg, LaneBitmask::getNone())
+                  ? VRegMaskOrUnits.emplace_back(VirtRegOrUnit(Reg),
+                                                 LaneBitmask::getNone())
                   : *I;
 
     P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg())
@@ -295,7 +296,7 @@ collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
 
   SlotIndex InstrSI;
   for (auto &P : VRegMaskOrUnits) {
-    auto &LI = LIS.getInterval(P.RegUnit);
+    auto &LI = LIS.getInterval(P.VRegOrUnit.asVirtualReg());
     if (!LI.hasSubRanges())
       continue;
 
@@ -562,10 +563,10 @@ void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
   SmallVector<VRegMaskOrUnit, 8> RegUses;
   collectVirtualRegUses(RegUses, MI, LIS, *MRI);
   for (const VRegMaskOrUnit &U : RegUses) {
-    LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
+    LaneBitmask &LiveMask = LiveRegs[U.VRegOrUnit.asVirtualReg()];
     LaneBitmask PrevMask = LiveMask;
     LiveMask |= U.LaneMask;
-    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
+    CurPressure.inc(U.VRegOrUnit.asVirtualReg(), PrevMask, LiveMask, *MRI);
   }
 
   // Update MaxPressure with uses plus early-clobber defs pressure.
@@ -748,9 +749,9 @@ GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
   GCNRegPressure TempPressure = CurPressure;
 
   for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
-    Register Reg = Use.RegUnit;
-    if (!Reg.isVirtual())
+    if (!Use.VRegOrUnit.isVirtualReg())
       continue;
+    Register Reg = Use.VRegOrUnit.asVirtualReg();
     LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
     if (LastUseMask.none())
       continue;
@@ -782,9 +783,9 @@ GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
 
   // Generate liveness for defs.
   for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
-    Register Reg = Def.RegUnit;
-    if (!Reg.isVirtual())
+    if (!Def.VRegOrUnit.isVirtualReg())
       continue;
+    Register Reg = Def.VRegOrUnit.asVirtualReg();
     auto It = LiveRegs.find(Reg);
     LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
     LaneBitmask NewMask = LiveMask | Def.LaneMask;
@@ -824,8 +825,7 @@ Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
       Register Reg = Register::index2VirtReg(I);
       auto It = LiveRegs.find(Reg);
       if (It != LiveRegs.end() && It->second.any())
-        OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
-           << PrintLaneMask(It->second);
+        OS << ' ' << printReg(Reg, TRI) << ':' << PrintLaneMask(It->second);
     }
     OS << '\n';
   });

diff  --git a/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp b/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
index fd28abeb887c2..2f3ad39c75fcc 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
+++ b/llvm/lib/Target/AMDGPU/SIMachineScheduler.cpp
@@ -323,8 +323,8 @@ void SIScheduleBlock::initRegPressure(MachineBasicBlock::iterator BeginBlock,
 
   // Do not Track Physical Registers, because it messes up.
   for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) {
-    if (RegMaskPair.RegUnit.isVirtual())
-      LiveInRegs.insert(RegMaskPair.RegUnit);
+    if (RegMaskPair.VRegOrUnit.isVirtualReg())
+      LiveInRegs.insert(RegMaskPair.VRegOrUnit.asVirtualReg());
   }
   LiveOutRegs.clear();
   // There is several possibilities to distinguish:
@@ -350,12 +350,13 @@ void SIScheduleBlock::initRegPressure(MachineBasicBlock::iterator BeginBlock,
   // Comparing to LiveInRegs is not sufficient to differentiate 4 vs 5, 7
   // The use of findDefBetween removes the case 4.
   for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) {
-    Register Reg = RegMaskPair.RegUnit;
-    if (Reg.isVirtual() &&
-        isDefBetween(Reg, LIS->getInstructionIndex(*BeginBlock).getRegSlot(),
+    VirtRegOrUnit VRegOrUnit = RegMaskPair.VRegOrUnit;
+    if (VRegOrUnit.isVirtualReg() &&
+        isDefBetween(VRegOrUnit.asVirtualReg(),
+                     LIS->getInstructionIndex(*BeginBlock).getRegSlot(),
                      LIS->getInstructionIndex(*EndBlock).getRegSlot(), MRI,
                      LIS)) {
-      LiveOutRegs.insert(Reg);
+      LiveOutRegs.insert(VRegOrUnit.asVirtualReg());
     }
   }
 
@@ -578,11 +579,11 @@ void SIScheduleBlock::printDebug(bool full) {
            << LiveOutPressure[AMDGPU::RegisterPressureSets::VGPR_32] << "\n\n";
     dbgs() << "LiveIns:\n";
     for (Register Reg : LiveInRegs)
-      dbgs() << printVRegOrUnit(Reg, DAG->getTRI()) << ' ';
+      dbgs() << printReg(Reg, DAG->getTRI()) << ' ';
 
     dbgs() << "\nLiveOuts:\n";
     for (Register Reg : LiveOutRegs)
-      dbgs() << printVRegOrUnit(Reg, DAG->getTRI()) << ' ';
+      dbgs() << printReg(Reg, DAG->getTRI()) << ' ';
   }
 
   dbgs() << "\nInstructions:\n";
@@ -1446,23 +1447,24 @@ SIScheduleBlockScheduler::SIScheduleBlockScheduler(SIScheduleDAGMI *DAG,
   }
 #endif
 
-  std::set<Register> InRegs = DAG->getInRegs();
+  std::set<VirtRegOrUnit> InRegs = DAG->getInRegs();
   addLiveRegs(InRegs);
 
   // Increase LiveOutRegsNumUsages for blocks
   // producing registers consumed in another
   // scheduling region.
-  for (Register Reg : DAG->getOutRegs()) {
+  for (VirtRegOrUnit VRegOrUnit : DAG->getOutRegs()) {
     for (unsigned i = 0, e = Blocks.size(); i != e; ++i) {
       // Do reverse traversal
       int ID = BlocksStruct.TopDownIndex2Block[Blocks.size()-1-i];
       SIScheduleBlock *Block = Blocks[ID];
       const std::set<Register> &OutRegs = Block->getOutRegs();
 
-      if (OutRegs.find(Reg) == OutRegs.end())
+      if (!VRegOrUnit.isVirtualReg() ||
+          OutRegs.find(VRegOrUnit.asVirtualReg()) == OutRegs.end())
         continue;
 
-      ++LiveOutRegsNumUsages[ID][Reg];
+      ++LiveOutRegsNumUsages[ID][VRegOrUnit.asVirtualReg()];
       break;
     }
   }
@@ -1565,15 +1567,18 @@ SIScheduleBlock *SIScheduleBlockScheduler::pickBlock() {
     maxVregUsage = VregCurrentUsage;
   if (SregCurrentUsage > maxSregUsage)
     maxSregUsage = SregCurrentUsage;
-  LLVM_DEBUG(dbgs() << "Picking New Blocks\n"; dbgs() << "Available: ";
-             for (SIScheduleBlock *Block : ReadyBlocks)
-               dbgs() << Block->getID() << ' ';
-             dbgs() << "\nCurrent Live:\n";
-             for (Register Reg : LiveRegs)
-               dbgs() << printVRegOrUnit(Reg, DAG->getTRI()) << ' ';
-             dbgs() << '\n';
-             dbgs() << "Current VGPRs: " << VregCurrentUsage << '\n';
-             dbgs() << "Current SGPRs: " << SregCurrentUsage << '\n';);
+  LLVM_DEBUG({
+    dbgs() << "Picking New Blocks\n";
+    dbgs() << "Available: ";
+    for (SIScheduleBlock *Block : ReadyBlocks)
+      dbgs() << Block->getID() << ' ';
+    dbgs() << "\nCurrent Live:\n";
+    for (Register Reg : LiveRegs)
+      dbgs() << printReg(Reg, DAG->getTRI()) << ' ';
+    dbgs() << '\n';
+    dbgs() << "Current VGPRs: " << VregCurrentUsage << '\n';
+    dbgs() << "Current SGPRs: " << SregCurrentUsage << '\n';
+  });
 
   Cand.Block = nullptr;
   for (std::vector<SIScheduleBlock*>::iterator I = ReadyBlocks.begin(),
@@ -1625,13 +1630,13 @@ SIScheduleBlock *SIScheduleBlockScheduler::pickBlock() {
 
 // Tracking of currently alive registers to determine VGPR Usage.
 
-void SIScheduleBlockScheduler::addLiveRegs(std::set<Register> &Regs) {
-  for (Register Reg : Regs) {
+void SIScheduleBlockScheduler::addLiveRegs(std::set<VirtRegOrUnit> &Regs) {
+  for (VirtRegOrUnit VRegOrUnit : Regs) {
     // For now only track virtual registers.
-    if (!Reg.isVirtual())
+    if (!VRegOrUnit.isVirtualReg())
       continue;
     // If not already in the live set, then add it.
-    (void) LiveRegs.insert(Reg);
+    (void)LiveRegs.insert(VRegOrUnit.asVirtualReg());
   }
 }
 
@@ -1662,7 +1667,7 @@ void SIScheduleBlockScheduler::releaseBlockSuccs(SIScheduleBlock *Parent) {
 
 void SIScheduleBlockScheduler::blockScheduled(SIScheduleBlock *Block) {
   decreaseLiveRegs(Block, Block->getInRegs());
-  addLiveRegs(Block->getOutRegs());
+  LiveRegs.insert(Block->getOutRegs().begin(), Block->getOutRegs().end());
   releaseBlockSuccs(Block);
   for (const auto &RegP : LiveOutRegsNumUsages[Block->getID()]) {
     // We produce this register, thus it must not be previously alive.
@@ -1689,7 +1694,7 @@ SIScheduleBlockScheduler::checkRegUsageImpact(std::set<Register> &InRegs,
       continue;
     if (LiveRegsConsumers[Reg] > 1)
       continue;
-    PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg);
+    PSetIterator PSetI = DAG->getMRI()->getPressureSets(VirtRegOrUnit(Reg));
     for (; PSetI.isValid(); ++PSetI) {
       DiffSetPressure[*PSetI] -= PSetI.getWeight();
     }
@@ -1699,7 +1704,7 @@ SIScheduleBlockScheduler::checkRegUsageImpact(std::set<Register> &InRegs,
     // For now only track virtual registers.
     if (!Reg.isVirtual())
       continue;
-    PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg);
+    PSetIterator PSetI = DAG->getMRI()->getPressureSets(VirtRegOrUnit(Reg));
     for (; PSetI.isValid(); ++PSetI) {
       DiffSetPressure[*PSetI] += PSetI.getWeight();
     }
@@ -1846,7 +1851,7 @@ SIScheduleDAGMI::fillVgprSgprCost(_Iterator First, _Iterator End,
     // For now only track virtual registers
     if (!Reg.isVirtual())
       continue;
-    PSetIterator PSetI = MRI.getPressureSets(Reg);
+    PSetIterator PSetI = MRI.getPressureSets(VirtRegOrUnit(Reg));
     for (; PSetI.isValid(); ++PSetI) {
       if (*PSetI == AMDGPU::RegisterPressureSets::VGPR_32)
         VgprUsage += PSetI.getWeight();

diff  --git a/llvm/lib/Target/AMDGPU/SIMachineScheduler.h b/llvm/lib/Target/AMDGPU/SIMachineScheduler.h
index b219cbd5672f0..1245774400af1 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineScheduler.h
+++ b/llvm/lib/Target/AMDGPU/SIMachineScheduler.h
@@ -389,7 +389,7 @@ class SIScheduleBlockScheduler {
                             SIBlockSchedCandidate &TryCand);
   SIScheduleBlock *pickBlock();
 
-  void addLiveRegs(std::set<Register> &Regs);
+  void addLiveRegs(std::set<VirtRegOrUnit> &Regs);
   void decreaseLiveRegs(SIScheduleBlock *Block, std::set<Register> &Regs);
   void releaseBlockSuccs(SIScheduleBlock *Parent);
   void blockScheduled(SIScheduleBlock *Block);
@@ -462,18 +462,18 @@ class SIScheduleDAGMI final : public ScheduleDAGMILive {
                                                      unsigned &VgprUsage,
                                                      unsigned &SgprUsage);
 
-  std::set<Register> getInRegs() {
-    std::set<Register> InRegs;
+  std::set<VirtRegOrUnit> getInRegs() {
+    std::set<VirtRegOrUnit> InRegs;
     for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) {
-      InRegs.insert(RegMaskPair.RegUnit);
+      InRegs.insert(RegMaskPair.VRegOrUnit);
     }
     return InRegs;
   }
 
-  std::set<unsigned> getOutRegs() {
-    std::set<unsigned> OutRegs;
+  std::set<VirtRegOrUnit> getOutRegs() {
+    std::set<VirtRegOrUnit> OutRegs;
     for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) {
-      OutRegs.insert(RegMaskPair.RegUnit);
+      OutRegs.insert(RegMaskPair.VRegOrUnit);
     }
     return OutRegs;
   };

diff  --git a/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp b/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
index 6611e1e6507e1..10762edc16264 100644
--- a/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
+++ b/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
@@ -188,8 +188,9 @@ class SIWholeQuadMode {
 
   void markInstruction(MachineInstr &MI, char Flag,
                        std::vector<WorkItem> &Worklist);
-  void markDefs(const MachineInstr &UseMI, LiveRange &LR, Register Reg,
-                unsigned SubReg, char Flag, std::vector<WorkItem> &Worklist);
+  void markDefs(const MachineInstr &UseMI, LiveRange &LR,
+                VirtRegOrUnit VRegOrUnit, unsigned SubReg, char Flag,
+                std::vector<WorkItem> &Worklist);
   void markOperand(const MachineInstr &MI, const MachineOperand &Op, char Flag,
                    std::vector<WorkItem> &Worklist);
   void markInstructionUses(const MachineInstr &MI, char Flag,
@@ -318,8 +319,8 @@ void SIWholeQuadMode::markInstruction(MachineInstr &MI, char Flag,
 
 /// Mark all relevant definitions of register \p Reg in usage \p UseMI.
 void SIWholeQuadMode::markDefs(const MachineInstr &UseMI, LiveRange &LR,
-                               Register Reg, unsigned SubReg, char Flag,
-                               std::vector<WorkItem> &Worklist) {
+                               VirtRegOrUnit VRegOrUnit, unsigned SubReg,
+                               char Flag, std::vector<WorkItem> &Worklist) {
   LLVM_DEBUG(dbgs() << "markDefs " << PrintState(Flag) << ": " << UseMI);
 
   LiveQueryResult UseLRQ = LR.Query(LIS->getInstructionIndex(UseMI));
@@ -331,8 +332,9 @@ void SIWholeQuadMode::markDefs(const MachineInstr &UseMI, LiveRange &LR,
   // cover registers.
   const LaneBitmask UseLanes =
       SubReg ? TRI->getSubRegIndexLaneMask(SubReg)
-             : (Reg.isVirtual() ? MRI->getMaxLaneMaskForVReg(Reg)
-                                : LaneBitmask::getNone());
+             : (VRegOrUnit.isVirtualReg()
+                    ? MRI->getMaxLaneMaskForVReg(VRegOrUnit.asVirtualReg())
+                    : LaneBitmask::getNone());
 
   // Perform a depth-first iteration of the LiveRange graph marking defs.
   // Stop processing of a given branch when all use lanes have been defined.
@@ -382,11 +384,11 @@ void SIWholeQuadMode::markDefs(const MachineInstr &UseMI, LiveRange &LR,
       MachineInstr *MI = LIS->getInstructionFromIndex(Value->def);
       assert(MI && "Def has no defining instruction");
 
-      if (Reg.isVirtual()) {
+      if (VRegOrUnit.isVirtualReg()) {
         // Iterate over all operands to find relevant definitions
         bool HasDef = false;
         for (const MachineOperand &Op : MI->all_defs()) {
-          if (Op.getReg() != Reg)
+          if (Op.getReg() != VRegOrUnit.asVirtualReg())
             continue;
 
           // Compute lanes defined and overlap with use
@@ -453,7 +455,7 @@ void SIWholeQuadMode::markOperand(const MachineInstr &MI,
                     << " for " << MI);
   if (Reg.isVirtual()) {
     LiveRange &LR = LIS->getInterval(Reg);
-    markDefs(MI, LR, Reg, Op.getSubReg(), Flag, Worklist);
+    markDefs(MI, LR, VirtRegOrUnit(Reg), Op.getSubReg(), Flag, Worklist);
   } else {
     // Handle physical registers that we need to track; this is mostly relevant
     // for VCC, which can appear as the (implicit) input of a uniform branch,
@@ -462,7 +464,8 @@ void SIWholeQuadMode::markOperand(const MachineInstr &MI,
       LiveRange &LR = LIS->getRegUnit(Unit);
       const VNInfo *Value = LR.Query(LIS->getInstructionIndex(MI)).valueIn();
       if (Value)
-        markDefs(MI, LR, Unit, AMDGPU::NoSubRegister, Flag, Worklist);
+        markDefs(MI, LR, VirtRegOrUnit(Unit), AMDGPU::NoSubRegister, Flag,
+                 Worklist);
     }
   }
 }

diff  --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
index 6077c18463240..02887ce93c525 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
@@ -6551,7 +6551,7 @@ class ARMPipelinerLoopInfo : public TargetInstrInfo::PipelinerLoopInfo {
   static int constexpr LAST_IS_USE = MAX_STAGES;
   static int constexpr SEEN_AS_LIVE = MAX_STAGES + 1;
   typedef std::bitset<MAX_STAGES + 2> IterNeed;
-  typedef std::map<unsigned, IterNeed> IterNeeds;
+  typedef std::map<Register, IterNeed> IterNeeds;
 
   void bumpCrossIterationPressure(RegPressureTracker &RPT,
                                   const IterNeeds &CIN);
@@ -6625,14 +6625,14 @@ void ARMPipelinerLoopInfo::bumpCrossIterationPressure(RegPressureTracker &RPT,
   for (const auto &N : CIN) {
     int Cnt = N.second.count() - N.second[SEEN_AS_LIVE] * 2;
     for (int I = 0; I < Cnt; ++I)
-      RPT.increaseRegPressure(Register(N.first), LaneBitmask::getNone(),
+      RPT.increaseRegPressure(VirtRegOrUnit(N.first), LaneBitmask::getNone(),
                               LaneBitmask::getAll());
   }
   // Decrease pressure by the amounts in CrossIterationNeeds
   for (const auto &N : CIN) {
     int Cnt = N.second.count() - N.second[SEEN_AS_LIVE] * 2;
     for (int I = 0; I < Cnt; ++I)
-      RPT.decreaseRegPressure(Register(N.first), LaneBitmask::getAll(),
+      RPT.decreaseRegPressure(VirtRegOrUnit(N.first), LaneBitmask::getAll(),
                               LaneBitmask::getNone());
   }
 }


        


More information about the llvm-commits mailing list