[llvm] [TableGen][SIInsertWaitcnts] use RegIntervals for AMDGPU (PR #174888)

Ryan Mitchell via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 13 11:45:40 PST 2026


https://github.com/RyanRio updated https://github.com/llvm/llvm-project/pull/174888

>From adaa750379d1a7e94aafe863a573306d960c2868 Mon Sep 17 00:00:00 2001
From: Ryan Mitchell <Ryan.Mitchell at amd.com>
Date: Wed, 7 Jan 2026 15:23:57 -0800
Subject: [PATCH 1/2] [TableGen][SIInsertWaitcnts] use RegIntervals for AMDGPU

---
 llvm/include/llvm/MC/MCRegisterInfo.h         |  18 +++-
 llvm/include/llvm/Target/Target.td            |   5 +
 llvm/lib/Target/AMDGPU/AMDGPU.td              |   1 +
 llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp   |  76 +++++++------
 .../TableGen/regunit-intervals-impossible.td  |  35 ++++++
 llvm/test/TableGen/regunit-intervals.td       |  73 +++++++++++++
 .../TableGen/Common/CodeGenRegisters.cpp      | 101 ++++++++++++++++++
 llvm/utils/TableGen/Common/CodeGenRegisters.h |   9 ++
 llvm/utils/TableGen/RegisterInfoEmitter.cpp   |  17 +++
 9 files changed, 300 insertions(+), 35 deletions(-)
 create mode 100644 llvm/test/TableGen/regunit-intervals-impossible.td
 create mode 100644 llvm/test/TableGen/regunit-intervals.td

diff --git a/llvm/include/llvm/MC/MCRegisterInfo.h b/llvm/include/llvm/MC/MCRegisterInfo.h
index f4897b6a406fb..76ef62da3d35b 100644
--- a/llvm/include/llvm/MC/MCRegisterInfo.h
+++ b/llvm/include/llvm/MC/MCRegisterInfo.h
@@ -180,6 +180,7 @@ class LLVM_ABI MCRegisterInfo {
   unsigned NumSubRegIndices;                  // Number of subreg indices.
   const uint16_t *RegEncodingTable;           // Pointer to array of register
                                               // encodings.
+  const unsigned (*RegUnitIntervals)[2]; // Per-register unit ranges, or null.
 
   unsigned L2DwarfRegsSize;
   unsigned EHL2DwarfRegsSize;
@@ -286,7 +287,8 @@ class LLVM_ABI MCRegisterInfo {
                           const int16_t *DL, const LaneBitmask *RUMS,
                           const char *Strings, const char *ClassStrings,
                           const uint16_t *SubIndices, unsigned NumIndices,
-                          const uint16_t *RET) {
+                          const uint16_t *RET,
+                          const unsigned (*RUI)[2] = nullptr) {
     Desc = D;
     NumRegs = NR;
     RAReg = RA;
@@ -302,6 +304,7 @@ class LLVM_ABI MCRegisterInfo {
     SubRegIndices = SubIndices;
     NumSubRegIndices = NumIndices;
     RegEncodingTable = RET;
+    RegUnitIntervals = RUI;
 
     // Initialize DWARF register mapping variables
     EHL2DwarfRegs = nullptr;
@@ -511,6 +514,19 @@ class LLVM_ABI MCRegisterInfo {
 
   /// Returns true if the two registers are equal or alias each other.
   bool regsOverlap(MCRegister RegA, MCRegister RegB) const;
+
+  /// Returns true if a regunit interval table was supplied at init time.
+  bool hasRegUnitIntervals() const { return RegUnitIntervals != nullptr; }
+
+  /// Returns the half-open range [begin, end) of native regunits covered by
+  /// \p Reg, as recorded in the RegUnitIntervals table.
+  iota_range<unsigned> regunits_interval(MCRegister Reg) const {
+    assert(hasRegUnitIntervals() &&
+           "Target does not support regunit intervals");
+    assert(Reg.id() < NumRegs && "Invalid register number");
+    return seq<unsigned>(RegUnitIntervals[Reg.id()][0],
+                         RegUnitIntervals[Reg.id()][1]);
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td
index 315de55b75510..45ed2a674860b 100644
--- a/llvm/include/llvm/Target/Target.td
+++ b/llvm/include/llvm/Target/Target.td
@@ -1935,6 +1935,11 @@ class Target {
   // setting hasExtraDefRegAllocReq and hasExtraSrcRegAllocReq to 1
   // for all opcodes if this flag is set to 0.
   int AllowRegisterRenaming = 0;
+
+  // RegistersAreIntervals - If set, TableGen renumbers register units so that
+  // the native units of every register form a contiguous range, and emits a
+  // RegUnitIntervals table. It is an error if the renumbering cannot do this.
+  bit RegistersAreIntervals = 0;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index f015353ee4f3a..f8ed85ebb70e5 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -2327,6 +2327,7 @@ def AMDGPU : Target {
                                 VOP3_DPPAsmParserVariant];
   let AssemblyWriters = [AMDGPUAsmWriter];
   let AllowRegisterRenaming = 1;
+  let RegistersAreIntervals = 1; // Required by SIInsertWaitcnts.
 }
 
 // Dummy Instruction itineraries for pseudo instructions
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 20ba44e3b8e59..cea3d3aef27ae 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -133,11 +133,15 @@ static unsigned getWaitCountMax(const AMDGPU::HardwareLimits &Limits,
     return 0;
   }
 }
+// Index of a register unit. AMDGPU sets RegistersAreIntervals, so the units
+// covered by any register form the contiguous range that regunits_interval()
+// returns.
+using StorageUnit = unsigned;
 
 /// Integer IDs used to track vector memory locations we may have to wait on.
 /// Encoded as u16 chunks:
 ///
-///   [0,               REGUNITS_END ): MCRegUnit
+///   [0,               REGUNITS_END ): RegUnit IDs
 ///   [LDSDMA_BEGIN,    LDSDMA_END  ) : LDS DMA IDs
 ///
 /// NOTE: The choice of encoding these as "u16 chunks" is arbitrary.
@@ -152,7 +156,7 @@ enum : VMEMID {
   TRACKINGID_RANGE_LEN = (1 << 16),
 
-  // Important: MCRegUnits must always be tracked starting from 0, as we
-  // need to be able to convert between a MCRegUnit and a VMEMID freely.
+  // Important: StorageUnits must always be tracked starting from 0, as we
+  // need to be able to convert between a StorageUnit and a VMEMID freely.
   REGUNITS_BEGIN = 0,
   REGUNITS_END = REGUNITS_BEGIN + TRACKINGID_RANGE_LEN,
 
@@ -164,10 +168,5 @@ enum : VMEMID {
   LDSDMA_END = LDSDMA_BEGIN + NUM_LDSDMA,
 };
 
-/// Convert a MCRegUnit to a VMEMID.
-static constexpr VMEMID toVMEMID(MCRegUnit RU) {
-  return static_cast<unsigned>(RU);
-}
-
#define AMDGPU_DECLARE_WAIT_EVENTS(DECL)                                       \
   DECL(VMEM_ACCESS) /* vmem read & write (pre-gfx10), vmem read (gfx10+) */    \
@@ -708,7 +718,7 @@ class WaitcntBrackets {
     return getScoreUB(T) - getScoreLB(T);
   }
 
-  unsigned getSGPRScore(MCRegUnit RU, InstCounterType T) const {
+  unsigned getSGPRScore(StorageUnit RU, InstCounterType T) const {
     auto It = SGPRs.find(RU);
     return It != SGPRs.end() ? It->second.Scores[getSgprScoresIdx(T)] : 0;
   }
@@ -784,8 +794,8 @@ class WaitcntBrackets {
   // Return true if there might be pending writes to the vgpr-interval by VMEM
   // instructions with types different from V.
   bool hasOtherPendingVmemTypes(MCPhysReg Reg, VmemType V) const {
-    for (MCRegUnit RU : regunits(Reg)) {
-      auto It = VMem.find(toVMEMID(RU));
+    for (StorageUnit RU : regunits_interval(Reg)) {
+      auto It = VMem.find(RU);
       if (It != VMem.end() && (It->second.VMEMTypes & ~(1 << V)))
         return true;
     }
@@ -793,8 +803,8 @@ class WaitcntBrackets {
   }
 
   void clearVgprVmemTypes(MCPhysReg Reg) {
-    for (MCRegUnit RU : regunits(Reg)) {
-      if (auto It = VMem.find(toVMEMID(RU)); It != VMem.end()) {
+    for (StorageUnit RU : regunits_interval(Reg)) {
+      if (auto It = VMem.find(RU); It != VMem.end()) {
         It->second.VMEMTypes = 0;
         if (It->second.empty())
           VMem.erase(It);
@@ -837,15 +847,15 @@ class WaitcntBrackets {
   static bool mergeScore(const MergeInfo &M, unsigned &Score,
                          unsigned OtherScore);
 
-  iterator_range<MCRegUnitIterator> regunits(MCPhysReg Reg) const {
+  iota_range<StorageUnit> regunits_interval(MCPhysReg Reg) const {
     assert(Reg != AMDGPU::SCC && "Shouldn't be used on SCC");
     if (!Context->TRI->isInAllocatableClass(Reg))
-      return {{}, {}};
+      return seq<StorageUnit>(0, 0);
     const TargetRegisterClass *RC = Context->TRI->getPhysRegBaseClass(Reg);
     unsigned Size = Context->TRI->getRegSizeInBits(*RC);
     if (Size == 16 && Context->ST->hasD16Writes32BitVgpr())
       Reg = Context->TRI->get32BitRegister(Reg);
-    return Context->TRI->regunits(Reg);
+    return Context->TRI->regunits_interval(Reg);
   }
 
   void setScoreLB(InstCounterType T, unsigned Val) {
@@ -870,11 +880,11 @@ class WaitcntBrackets {
     if (Reg == AMDGPU::SCC) {
       SCCScore = Val;
     } else if (TRI->isVectorRegister(*Context->MRI, Reg)) {
-      for (MCRegUnit RU : regunits(Reg))
-        VMem[toVMEMID(RU)].Scores[T] = Val;
+      for (StorageUnit RU : regunits_interval(Reg))
+        VMem[RU].Scores[T] = Val;
     } else if (TRI->isSGPRReg(*Context->MRI, Reg)) {
       auto STy = getSgprScoresIdx(T);
-      for (MCRegUnit RU : regunits(Reg))
+      for (StorageUnit RU : regunits_interval(Reg))
         SGPRs[RU].Scores[STy] = Val;
     } else {
       llvm_unreachable("Register cannot be tracked/unknown register!");
@@ -933,7 +943,7 @@ class WaitcntBrackets {
   };
 
   DenseMap<VMEMID, VMEMInfo> VMem; // VGPR + LDS DMA
-  DenseMap<MCRegUnit, SGPRInfo> SGPRs;
+  DenseMap<StorageUnit, SGPRInfo> SGPRs;
 
   // Reg score for SCC.
   unsigned SCCScore = 0;
@@ -1137,8 +1147,8 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
           // this with another potential dependency
           if (hasPointSampleAccel(Inst))
             TypesMask |= 1 << VMEM_NOSAMPLER;
-          for (MCRegUnit RU : regunits(Op.getReg().asMCReg()))
-            VMem[toVMEMID(RU)].VMEMTypes |= TypesMask;
+          for (StorageUnit RU : regunits_interval(Op.getReg().asMCReg()))
+            VMem[RU].VMEMTypes |= TypesMask;
         }
       }
       setScoreByOperand(Op, T, CurrScore);
@@ -1264,7 +1274,7 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
 
       // Also need to print sgpr scores for lgkm_cnt or xcnt.
       if (isSmemCounter(T)) {
-        SmallVector<MCRegUnit> SortedSMEMIDs(SGPRs.keys());
+        SmallVector<StorageUnit> SortedSMEMIDs(SGPRs.keys());
         sort(SortedSMEMIDs);
         for (auto ID : SortedSMEMIDs) {
           unsigned RegScore = SGPRs.at(ID).Scores[getSgprScoresIdx(T)];
@@ -1380,10 +1390,9 @@ void WaitcntBrackets::determineWaitForPhysReg(InstCounterType T, MCPhysReg Reg,
     determineWaitForScore(T, SCCScore, Wait);
   } else {
     bool IsVGPR = Context->TRI->isVectorRegister(*Context->MRI, Reg);
-    for (MCRegUnit RU : regunits(Reg))
+    for (StorageUnit RU : regunits_interval(Reg))
       determineWaitForScore(
-          T, IsVGPR ? getVMemScore(toVMEMID(RU), T) : getSGPRScore(RU, T),
-          Wait);
+          T, IsVGPR ? getVMemScore(RU, T) : getSGPRScore(RU, T), Wait);
   }
 }
 
@@ -3114,9 +3123,9 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
   bool VMemInvalidated = false;
   // DS optimization only applies to GFX12+ where DS_CNT is separate.
   bool DSInvalidated = !ST->hasExtendedWaitCounts();
-  DenseSet<MCRegUnit> VgprUse;
-  DenseSet<MCRegUnit> VgprDefVMEM;
-  DenseSet<MCRegUnit> VgprDefDS;
+  DenseSet<StorageUnit> VgprUse;
+  DenseSet<StorageUnit> VgprDefVMEM;
+  DenseSet<StorageUnit> VgprDefDS;
 
   for (MachineBasicBlock *MBB : ML->blocks()) {
     bool SeenDSStoreInCurrMBB = false;
@@ -3137,7 +3146,7 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
         if (Op.isDebug() || !TRI->isVectorRegister(*MRI, Op.getReg()))
           continue;
         // Vgpr use
-        for (MCRegUnit RU : TRI->regunits(Op.getReg().asMCReg())) {
+        for (StorageUnit RU : TRI->regunits_interval(Op.getReg().asMCReg())) {
           // If we find a register that is loaded inside the loop, 1. and 2.
           // are invalidated.
           if (VgprDefVMEM.contains(RU))
@@ -3154,20 +3163,19 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
           VgprUse.insert(RU);
           // Check if this register has a pending VMEM load from outside the
           // loop (value loaded outside and used inside).
-          VMEMID ID = toVMEMID(RU);
           bool HasPendingVMEM =
-              Brackets.getVMemScore(ID, LOAD_CNT) >
+              Brackets.getVMemScore(RU, LOAD_CNT) >
                   Brackets.getScoreLB(LOAD_CNT) ||
-              Brackets.getVMemScore(ID, SAMPLE_CNT) >
+              Brackets.getVMemScore(RU, SAMPLE_CNT) >
                   Brackets.getScoreLB(SAMPLE_CNT) ||
-              Brackets.getVMemScore(ID, BVH_CNT) > Brackets.getScoreLB(BVH_CNT);
+              Brackets.getVMemScore(RU, BVH_CNT) > Brackets.getScoreLB(BVH_CNT);
           if (HasPendingVMEM)
             UsesVgprLoadedOutsideVMEM = true;
           // Check if loaded outside the loop via DS (not VMEM/FLAT).
           // Only consider it a DS load if there's no pending VMEM load for
           // this register, since FLAT can set both counters.
           if (!HasPendingVMEM &&
-              Brackets.getVMemScore(ID, DS_CNT) > Brackets.getScoreLB(DS_CNT))
+              Brackets.getVMemScore(RU, DS_CNT) > Brackets.getScoreLB(DS_CNT))
             UsesVgprLoadedOutsideDS = true;
         }
       }
@@ -3175,7 +3183,7 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
       // VMem load vgpr def
       if (isVMEMOrFlatVMEM(MI) && MI.mayLoad()) {
         for (const MachineOperand &Op : MI.all_defs()) {
-          for (MCRegUnit RU : TRI->regunits(Op.getReg().asMCReg())) {
+          for (StorageUnit RU : TRI->regunits_interval(Op.getReg().asMCReg())) {
             // If we find a register that is loaded inside the loop, 1. and 2.
             // are invalidated.
             if (VgprUse.contains(RU))
diff --git a/llvm/test/TableGen/regunit-intervals-impossible.td b/llvm/test/TableGen/regunit-intervals-impossible.td
new file mode 100644
index 0000000000000..9b4f6561178b7
--- /dev/null
+++ b/llvm/test/TableGen/regunit-intervals-impossible.td
@@ -0,0 +1,35 @@
+// RUN: not llvm-tblgen -gen-register-info -I %p/../../include %s 2>&1 | FileCheck %s
+
+include "llvm/Target/Target.td"
+
+def TestInstrInfo : InstrInfo;
+def TestTarget : Target {
+  let InstructionSet = TestInstrInfo;
+  let RegistersAreIntervals = 1;
+}
+
+def sub_lo : SubRegIndex<32>;
+def sub_hi : SubRegIndex<32, 32>;
+
+let Namespace = "Test" in {
+  def R1 : Register<"r1">;  // unit 0
+  def R2 : Register<"r2">;  // unit 1
+  def R3 : Register<"r3">;  // unit 2
+
+  // First composite: units {0, 1} - contiguous, OK
+  def R1_R2 : Register<"r1_r2"> {
+    let SubRegs = [R1, R2];
+    let SubRegIndices = [sub_lo, sub_hi];
+  }
+
+  // Second composite: units {0, 2} - non-contiguous!
+  // Fixing it swaps units 1 and 2, which turns R1_R2 into {0, 2}.
+  def R1_R3 : Register<"r1_r3"> {
+    let SubRegs = [R1, R3];
+    let SubRegIndices = [sub_lo, sub_hi];
+  }
+}
+
+def GPR32 : RegisterClass<"Test", [i32], 32, (add R1, R2, R3)>;
+
+// CHECK: Cannot enforce regunit intervals, final renumbering did not produce contiguous units for register R1_R2
diff --git a/llvm/test/TableGen/regunit-intervals.td b/llvm/test/TableGen/regunit-intervals.td
new file mode 100644
index 0000000000000..3d5dcd7c404dc
--- /dev/null
+++ b/llvm/test/TableGen/regunit-intervals.td
@@ -0,0 +1,73 @@
+// RUN: llvm-tblgen -gen-register-info -I %p/../../include %s | FileCheck %s
+
+include "llvm/Target/Target.td"
+
+def TestInstrInfo : InstrInfo;
+def TestTarget : Target {
+  let InstructionSet = TestInstrInfo;
+  let RegistersAreIntervals = 1;
+}
+
+def sub_lo : SubRegIndex<32>;
+def sub_hi : SubRegIndex<32, 32>;
+
+let Namespace = "Test" in {
+  // Simple 32-bit registers (each gets 1 regunit)
+  def R0 : Register<"r0">;
+  def R1 : Register<"r1">;
+  def R2 : Register<"r2">;
+  def R3 : Register<"r3">;
+
+  // 64-bit register composed of R0:R1 (gets 2 regunits)
+  def R0_R1 : Register<"r0_r1"> {
+    let SubRegs = [R0, R1];
+    let SubRegIndices = [sub_lo, sub_hi];
+  }
+
+  // 64-bit register composed of R2:R3 (gets 2 regunits)
+  def R2_R3 : Register<"r2_r3"> {
+    let SubRegs = [R2, R3];
+    let SubRegIndices = [sub_lo, sub_hi];
+  }
+}
+
+// CHECK-LABEL: extern const unsigned TestRegUnitIntervals[][2] = {
+// CHECK:      { 0, 1 },
+// CHECK-NEXT: { 1, 2 },
+// CHECK-NEXT: { 2, 3 },
+// CHECK-NEXT: { 3, 4 },
+
+let Namespace = "Test" in {
+  def R4 : Register<"r4">;      // Gets unit 4
+  def R5 : Register<"r5">;      // Gets unit 5
+  def R6 : Register<"r6">;      // Gets unit 6
+  def R7 : Register<"r7">;      // Gets unit 7
+
+  // This register skips R5, creating non-contiguous units {4, 6}
+  def R4_R6 : Register<"r4_r6"> {
+    let SubRegs = [R4, R6];
+    let SubRegIndices = [sub_lo, sub_hi];
+  }
+
+  // This register skips R6, creating non-contiguous units {5, 7}
+  def R5_R7 : Register<"r5_r7"> {
+    let SubRegs = [R5, R7];
+    let SubRegIndices = [sub_lo, sub_hi];
+  }
+}
+
+
+def GPR32 : RegisterClass<"Test", [i32], 32, (add R0, R1, R2, R3)>;
+def GPR64 : RegisterClass<"Test", [i64], 64, (add R0_R1, R2_R3)>;
+
+// R4_R6 forces units 5 and 6 to swap, so R5 gets [6, 7) and R6 gets [5, 6).
+// CHECK-NEXT: { 4, 5 },
+// CHECK-NEXT: { 6, 7 },
+// CHECK-NEXT: { 5, 6 },
+// CHECK-NEXT: { 7, 8 },
+// All contiguous
+// CHECK-NEXT: { 0, 2 },
+// CHECK-NEXT: { 2, 4 },
+// CHECK-NEXT: { 4, 6 },
+// CHECK-NEXT: { 6, 8 },
+// CHECK-NEXT: };
diff --git a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
index 65a2594859e69..a707afaafb1d2 100644
--- a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
@@ -1929,6 +1929,122 @@ void CodeGenRegBank::computeRegUnitWeights() {
   }
 }
 
+// Enforce that all registers are intervals of regunits if the target
+// requests this property. This will renumber regunits to ensure the
+// interval property holds, or error out if it cannot be satisfied.
+void CodeGenRegBank::enforceRegUnitIntervals() {
+  std::vector<const Record *> Targets =
+      Records.getAllDerivedDefinitions("Target");
+
+  if (Targets.empty())
+    return;
+
+  const Record *Target = Targets[0];
+  if (!Target->getValueAsBit("RegistersAreIntervals"))
+    return;
+
+  LLVM_DEBUG(dbgs() << "Enforcing regunit intervals for target\n");
+  // Forward map: original unit -> current unit (~0u means unchanged).
+  std::vector<unsigned> RegUnitRenumbering(RegUnits.size(), ~0u);
+  // Inverse map: current unit -> original unit that now carries that number.
+  // Swaps go through this map; indexing RegUnitRenumbering with an already
+  // renumbered value would let two original units end up sharing a number.
+  std::vector<unsigned> OriginalUnit(RegUnits.size());
+  for (unsigned I = 0, E = RegUnits.size(); I != E; ++I)
+    OriginalUnit[I] = I;
+
+  // RegUnits that have been renumbered from X -> Y. Y is what is marked so that
+  // it doesn't create a chain of swaps.
+  SparseBitVector DontRenumberUnits;
+
+  auto GetRenumberedUnit =
+      [&](unsigned RegUnit) -> unsigned {
+    if (RegUnitRenumbering[RegUnit] != ~0u)
+      return RegUnitRenumbering[RegUnit];
+    return RegUnit;
+  };
+
+  auto IsContiguous =
+      [&](CodeGenRegister::RegUnitList &Units) -> bool {
+    unsigned LastUnit = Units.find_first();
+    for (auto ThisUnit : llvm::make_range(++Units.begin(), Units.end())) {
+      if (ThisUnit != LastUnit + 1)
+        return false;
+      LastUnit = ThisUnit;
+    }
+    return true;
+  };
+
+  // Process registers in the order the register bank stores them.
+  for (CodeGenRegister &Reg : Registers) {
+    LLVM_DEBUG(dbgs() << "Processing register " << Reg.getName() << "\n");
+    const auto &Units = Reg.getNativeRegUnits();
+    if (Units.empty())
+      continue;
+    SparseBitVector RenumberedUnits;
+    // First renumber all the units for this register according to previous
+    // renumbering.
+    LLVM_DEBUG(dbgs() << "  Original (Renumbered) units:");
+    for (unsigned U : Units) {
+      LLVM_DEBUG(dbgs() << " " << U << "(" << GetRenumberedUnit(U) << "), ");
+      RenumberedUnits.set(GetRenumberedUnit(U));
+    }
+    LLVM_DEBUG(dbgs() << "\n");
+
+    unsigned LastUnit = RenumberedUnits.find_first();
+    for (auto ThisUnit :
+         llvm::make_range(++RenumberedUnits.begin(), RenumberedUnits.end())) {
+      if (ThisUnit != LastUnit + 1) {
+        if (DontRenumberUnits.test(LastUnit + 1)) {
+          PrintFatalError("Cannot enforce regunit intervals for register " +
+                  Reg.getName() + ": unit " + Twine(LastUnit + 1) +
+                  " (root: " + RegUnits[LastUnit + 1].Roots[0]->getName() +
+                  ") has already been renumbered and cannot be swapped");
+        }
+        LLVM_DEBUG(dbgs() << "  Renumbering unit " << ThisUnit << " to "
+                          << (LastUnit + 1) << "\n");
+        // Swap the original units currently numbered ThisUnit and
+        // LastUnit + 1, updating both maps so they stay inverse permutations.
+        unsigned NewUnit = LastUnit + 1;
+        unsigned MovedOrig = OriginalUnit[ThisUnit];
+        unsigned DisplacedOrig = OriginalUnit[NewUnit];
+        RegUnitRenumbering[MovedOrig] = NewUnit;
+        RegUnitRenumbering[DisplacedOrig] = ThisUnit;
+        OriginalUnit[NewUnit] = MovedOrig;
+        OriginalUnit[ThisUnit] = DisplacedOrig;
+        DontRenumberUnits.set(NewUnit);
+        ThisUnit = NewUnit;
+      }
+      LastUnit = ThisUnit;
+    }
+  }
+
+  // Apply the renumbering to all registers
+  for (CodeGenRegister &Reg : Registers) {
+    CodeGenRegister::RegUnitList NewRegUnits;
+    for (unsigned OldUnit : Reg.getRegUnits())
+      NewRegUnits.set(GetRenumberedUnit(OldUnit));
+    Reg.setNewRegUnits(NewRegUnits);
+
+    CodeGenRegister::RegUnitList NewNativeUnits;
+    for (unsigned OldUnit : Reg.getNativeRegUnits())
+      NewNativeUnits.set(GetRenumberedUnit(OldUnit));
+    // Report through PrintFatalError like the check above, so an unsatisfiable
+    // target yields a normal TableGen error rather than an internal abort.
+    if (!NewNativeUnits.empty() && !IsContiguous(NewNativeUnits))
+      PrintFatalError("Cannot enforce regunit intervals, final renumbering did "
+                      "not produce contiguous units for register " +
+                      Reg.getName());
+    Reg.NativeRegUnits = NewNativeUnits;
+  }
+
+  // Move per-unit data (weights, roots) along with the unit numbers so that
+  // RegUnits[N] still describes unit N after renumbering.
+  const auto OldRegUnits = RegUnits;
+  for (unsigned Orig = 0, E = OldRegUnits.size(); Orig != E; ++Orig)
+    RegUnits[GetRenumberedUnit(Orig)] = OldRegUnits[Orig];
+}
+
 // Find a set in UniqueSets with the same elements as Set.
 // Return an iterator into UniqueSets.
 static std::vector<RegUnitSet>::const_iterator
@@ -2209,6 +2325,11 @@ void CodeGenRegBank::computeDerivedInfo() {
   computeRegUnitWeights();
   Records.getTimer().stopTimer();
 
+  // Enforce regunit intervals if requested by the target.
+  Records.getTimer().startTimer("Enforce regunit intervals");
+  enforceRegUnitIntervals();
+  Records.getTimer().stopTimer();
+
   // Compute a unique set of RegUnitSets. One for each RegClass and inferred
   // supersets for the union of overlapping sets.
   computeRegUnitSets();
diff --git a/llvm/utils/TableGen/Common/CodeGenRegisters.h b/llvm/utils/TableGen/Common/CodeGenRegisters.h
index a3ad0b797a704..176cb9dafd22b 100644
--- a/llvm/utils/TableGen/Common/CodeGenRegisters.h
+++ b/llvm/utils/TableGen/Common/CodeGenRegisters.h
@@ -169,6 +169,8 @@ class CodeGenSubRegIndex {
 
 /// CodeGenRegister - Represents a register definition.
 class CodeGenRegister {
+  friend class CodeGenRegBank; // Rewrites NativeRegUnits when renumbering.
+
 public:
   const Record *TheDef;
   unsigned EnumValue;
@@ -257,6 +259,10 @@ class CodeGenRegister {
   // This is only valid after computeSubRegs() completes.
   const RegUnitList &getRegUnits() const { return RegUnits; }
 
+  void setNewRegUnits(const RegUnitList &NewRegUnits) {
+    RegUnits = NewRegUnits;
+  }
+
   ArrayRef<LaneBitmask> getRegUnitLaneMasks() const {
     return ArrayRef(RegUnitLaneMasks).slice(0, NativeRegUnits.count());
   }
@@ -693,6 +699,9 @@ class CodeGenRegBank {
   // Compute a weight for each register unit created during getSubRegs.
   void computeRegUnitWeights();
 
+  // If the target sets RegistersAreIntervals, renumber regunits into intervals.
+  void enforceRegUnitIntervals();
+
   // Create a RegUnitSet for each RegClass and infer superclasses.
   void computeRegUnitSets();
 
diff --git a/llvm/utils/TableGen/RegisterInfoEmitter.cpp b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
index 02fd8648302f1..552b32736509c 100644
--- a/llvm/utils/TableGen/RegisterInfoEmitter.cpp
+++ b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
@@ -1044,6 +1044,29 @@ void RegisterInfoEmitter::runMCDesc(raw_ostream &OS, raw_ostream &MainOS,
   }
   OS << "};\n\n";
 
+  // Emit the table of register unit intervals. MCRegisterInfo indexes it by
+  // register number, so entry 0 is a placeholder for NoRegister (no units)
+  // and entry N is the half-open range [First, Last + 1) of the native
+  // regunits of the register whose number is N.
+  // NOTE(review): MCRegisterInfo only sees this table if it is also passed as
+  // the RUI argument of InitMCRegisterInfo; confirm that call is updated.
+  if (Target.getTargetRecord()->getValueAsBit("RegistersAreIntervals")) {
+    OS << "extern const unsigned " << TargetName
+       << "RegUnitIntervals[][2] = {\n";
+    OS << "  { 0, 0 },\n";
+    for (const auto &Reg : Regs) {
+      const auto &Units = Reg.getNativeRegUnits();
+      if (Units.empty()) {
+        OS << "  { 0, 0 },\n";
+      } else {
+        unsigned First = Units.find_first();
+        unsigned Last = Units.find_last();
+        OS << "  { " << First << ", " << Last + 1 << " },\n";
+      }
+    }
+    OS << "};\n\n";
+  }
+
   const auto &RegisterClasses = RegBank.getRegClasses();
 
   // Loop over all of the register classes... emitting each one.

>From 807a9cef83cc70d551c18239bc95a51385e85369 Mon Sep 17 00:00:00 2001
From: Ryan Mitchell <Ryan.Mitchell at amd.com>
Date: Tue, 13 Jan 2026 11:45:25 -0800
Subject: [PATCH 2/2] Format

---
 llvm/utils/TableGen/Common/CodeGenRegisters.cpp | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
index a707afaafb1d2..bf1a2e767370d 100644
--- a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
@@ -1950,15 +1950,13 @@ void CodeGenRegBank::enforceRegUnitIntervals() {
   // it doesn't create a chain of swaps.
   SparseBitVector DontRenumberUnits;
 
-  auto GetRenumberedUnit =
-      [&](unsigned RegUnit) -> unsigned {
+  auto GetRenumberedUnit = [&](unsigned RegUnit) -> unsigned {
     if (RegUnitRenumbering[RegUnit] != ~0u)
       return RegUnitRenumbering[RegUnit];
     return RegUnit;
   };
 
-  auto IsContiguous =
-      [&](CodeGenRegister::RegUnitList &Units) -> bool {
+  auto IsContiguous = [&](CodeGenRegister::RegUnitList &Units) -> bool {
     unsigned LastUnit = Units.find_first();
     for (auto ThisUnit : llvm::make_range(++Units.begin(), Units.end())) {
       if (ThisUnit != LastUnit + 1)
@@ -1989,10 +1987,11 @@ void CodeGenRegBank::enforceRegUnitIntervals() {
          llvm::make_range(++RenumberedUnits.begin(), RenumberedUnits.end())) {
       if (ThisUnit != LastUnit + 1) {
         if (DontRenumberUnits.test(LastUnit + 1)) {
-          PrintFatalError("Cannot enforce regunit intervals for register " +
-                  Reg.getName() + ": unit " + Twine(LastUnit + 1) +
-                  " (root: " + RegUnits[LastUnit + 1].Roots[0]->getName() +
-                  ") has already been renumbered and cannot be swapped");
+          PrintFatalError(
+              "Cannot enforce regunit intervals for register " + Reg.getName() +
+              ": unit " + Twine(LastUnit + 1) +
+              " (root: " + RegUnits[LastUnit + 1].Roots[0]->getName() +
+              ") has already been renumbered and cannot be swapped");
         }
         LLVM_DEBUG(dbgs() << "  Renumbering unit " << ThisUnit << " to "
                           << (LastUnit + 1) << "\n");



More information about the llvm-commits mailing list