[llvm] [TableGen][SIInsertWaitcnts] use RegIntervals for AMDGPU (PR #174888)
Ryan Mitchell via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 13 11:45:40 PST 2026
https://github.com/RyanRio updated https://github.com/llvm/llvm-project/pull/174888
>From adaa750379d1a7e94aafe863a573306d960c2868 Mon Sep 17 00:00:00 2001
From: Ryan Mitchell <Ryan.Mitchell at amd.com>
Date: Wed, 7 Jan 2026 15:23:57 -0800
Subject: [PATCH 1/2] [TableGen][SIInsertWaitcnts] use RegIntervals for AMDGPU
---
llvm/include/llvm/MC/MCRegisterInfo.h | 18 +++-
llvm/include/llvm/Target/Target.td | 5 +
llvm/lib/Target/AMDGPU/AMDGPU.td | 1 +
llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp | 76 +++++++------
.../TableGen/regunit-intervals-impossible.td | 35 ++++++
llvm/test/TableGen/regunit-intervals.td | 73 +++++++++++++
.../TableGen/Common/CodeGenRegisters.cpp | 101 ++++++++++++++++++
llvm/utils/TableGen/Common/CodeGenRegisters.h | 9 ++
llvm/utils/TableGen/RegisterInfoEmitter.cpp | 17 +++
9 files changed, 300 insertions(+), 35 deletions(-)
create mode 100644 llvm/test/TableGen/regunit-intervals-impossible.td
create mode 100644 llvm/test/TableGen/regunit-intervals.td
diff --git a/llvm/include/llvm/MC/MCRegisterInfo.h b/llvm/include/llvm/MC/MCRegisterInfo.h
index f4897b6a406fb..76ef62da3d35b 100644
--- a/llvm/include/llvm/MC/MCRegisterInfo.h
+++ b/llvm/include/llvm/MC/MCRegisterInfo.h
@@ -180,6 +180,7 @@ class LLVM_ABI MCRegisterInfo {
unsigned NumSubRegIndices; // Number of subreg indices.
const uint16_t *RegEncodingTable; // Pointer to array of register
// encodings.
+ const unsigned (*RegUnitIntervals)[2]; // [Begin, End) regunits, or null.
unsigned L2DwarfRegsSize;
unsigned EHL2DwarfRegsSize;
@@ -286,7 +287,8 @@ class LLVM_ABI MCRegisterInfo {
const int16_t *DL, const LaneBitmask *RUMS,
const char *Strings, const char *ClassStrings,
const uint16_t *SubIndices, unsigned NumIndices,
- const uint16_t *RET) {
+ const uint16_t *RET,
+ const unsigned (*RUI)[2] = nullptr) {
Desc = D;
NumRegs = NR;
RAReg = RA;
@@ -302,6 +304,7 @@ class LLVM_ABI MCRegisterInfo {
SubRegIndices = SubIndices;
NumSubRegIndices = NumIndices;
RegEncodingTable = RET;
+ RegUnitIntervals = RUI;
// Initialize DWARF register mapping variables
EHL2DwarfRegs = nullptr;
@@ -511,6 +514,19 @@ class LLVM_ABI MCRegisterInfo {
/// Returns true if the two registers are equal or alias each other.
bool regsOverlap(MCRegister RegA, MCRegister RegB) const;
+
+ /// Returns true if InitMCRegisterInfo was given a regunit interval table.
+ bool hasRegUnitIntervals() const { return RegUnitIntervals != nullptr; }
+
+ /// Returns the half-open range [Begin, End) of native regunits of \p Reg,
+ /// as stored at index Reg.id() of the regunit interval table.
+ iota_range<unsigned> regunits_interval(MCRegister Reg) const {
+ assert(hasRegUnitIntervals() &&
+ "Target does not support regunit intervals");
+ assert(Reg.id() < NumRegs && "Invalid register number");
+ return seq<unsigned>(RegUnitIntervals[Reg.id()][0],
+ RegUnitIntervals[Reg.id()][1]);
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td
index 315de55b75510..45ed2a674860b 100644
--- a/llvm/include/llvm/Target/Target.td
+++ b/llvm/include/llvm/Target/Target.td
@@ -1935,6 +1935,11 @@ class Target {
// setting hasExtraDefRegAllocReq and hasExtraSrcRegAllocReq to 1
// for all opcodes if this flag is set to 0.
int AllowRegisterRenaming = 0;
+
+ // RegistersAreIntervals - If set, TableGen renumbers register units so that
+ // each register's native units form one contiguous range, and emits a
+ // <Target>RegUnitIntervals table. It errors out if its renumbering fails.
+ bit RegistersAreIntervals = 0;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index f015353ee4f3a..f8ed85ebb70e5 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -2327,6 +2327,7 @@ def AMDGPU : Target {
VOP3_DPPAsmParserVariant];
let AssemblyWriters = [AMDGPUAsmWriter];
let AllowRegisterRenaming = 1;
+ let RegistersAreIntervals = 1; // Required by SIInsertWaitcnts.
}
// Dummy Instruction itineraries for pseudo instructions
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 20ba44e3b8e59..cea3d3aef27ae 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -133,11 +133,15 @@ static unsigned getWaitCountMax(const AMDGPU::HardwareLimits &Limits,
return 0;
}
}
+// A register unit number as produced by regunits_interval(). AMDGPU sets
+// RegistersAreIntervals, so the units of any register form one contiguous
+// range, and a StorageUnit is used directly as a VMEMID (REGUNITS_BEGIN == 0).
+using StorageUnit = unsigned;
/// Integer IDs used to track vector memory locations we may have to wait on.
/// Encoded as u16 chunks:
///
-/// [0, REGUNITS_END ): MCRegUnit
+/// [0, REGUNITS_END ): StorageUnit (register unit) IDs
/// [LDSDMA_BEGIN, LDSDMA_END ) : LDS DMA IDs
///
/// NOTE: The choice of encoding these as "u16 chunks" is arbitrary.
@@ -152,7 +156,7 @@ enum : VMEMID {
TRACKINGID_RANGE_LEN = (1 << 16),
// Important: MCRegUnits must always be tracked starting from 0, as we
- // need to be able to convert between a MCRegUnit and a VMEMID freely.
+ // need to be able to convert between a StorageUnit and a VMEMID freely.
REGUNITS_BEGIN = 0,
REGUNITS_END = REGUNITS_BEGIN + TRACKINGID_RANGE_LEN,
@@ -164,10 +168,7 @@ enum : VMEMID {
LDSDMA_END = LDSDMA_BEGIN + NUM_LDSDMA,
};
-/// Convert a MCRegUnit to a VMEMID.
-static constexpr VMEMID toVMEMID(MCRegUnit RU) {
- return static_cast<unsigned>(RU);
-}
+// StorageUnit values are used as VMEMIDs as-is (REGUNITS_BEGIN is 0).
#define AMDGPU_DECLARE_WAIT_EVENTS(DECL) \
DECL(VMEM_ACCESS) /* vmem read & write (pre-gfx10), vmem read (gfx10+) */ \
@@ -708,7 +718,7 @@ class WaitcntBrackets {
return getScoreUB(T) - getScoreLB(T);
}
- unsigned getSGPRScore(MCRegUnit RU, InstCounterType T) const {
+ unsigned getSGPRScore(StorageUnit RU, InstCounterType T) const {
auto It = SGPRs.find(RU);
return It != SGPRs.end() ? It->second.Scores[getSgprScoresIdx(T)] : 0;
}
@@ -784,8 +794,8 @@ class WaitcntBrackets {
// Return true if there might be pending writes to the vgpr-interval by VMEM
// instructions with types different from V.
bool hasOtherPendingVmemTypes(MCPhysReg Reg, VmemType V) const {
- for (MCRegUnit RU : regunits(Reg)) {
- auto It = VMem.find(toVMEMID(RU));
+ for (StorageUnit RU : regunits_interval(Reg)) {
+ auto It = VMem.find(RU);
if (It != VMem.end() && (It->second.VMEMTypes & ~(1 << V)))
return true;
}
@@ -793,8 +803,8 @@ class WaitcntBrackets {
}
void clearVgprVmemTypes(MCPhysReg Reg) {
- for (MCRegUnit RU : regunits(Reg)) {
- if (auto It = VMem.find(toVMEMID(RU)); It != VMem.end()) {
+ for (StorageUnit RU : regunits_interval(Reg)) {
+ if (auto It = VMem.find(RU); It != VMem.end()) {
It->second.VMEMTypes = 0;
if (It->second.empty())
VMem.erase(It);
@@ -837,15 +847,15 @@ class WaitcntBrackets {
static bool mergeScore(const MergeInfo &M, unsigned &Score,
unsigned OtherScore);
- iterator_range<MCRegUnitIterator> regunits(MCPhysReg Reg) const {
+ iota_range<StorageUnit> regunits_interval(MCPhysReg Reg) const {
assert(Reg != AMDGPU::SCC && "Shouldn't be used on SCC");
if (!Context->TRI->isInAllocatableClass(Reg))
- return {{}, {}};
+ return seq<StorageUnit>(0, 0); // Not allocatable: nothing to track.
const TargetRegisterClass *RC = Context->TRI->getPhysRegBaseClass(Reg);
unsigned Size = Context->TRI->getRegSizeInBits(*RC);
if (Size == 16 && Context->ST->hasD16Writes32BitVgpr())
Reg = Context->TRI->get32BitRegister(Reg);
- return Context->TRI->regunits(Reg);
+ return Context->TRI->regunits_interval(Reg);
}
void setScoreLB(InstCounterType T, unsigned Val) {
@@ -870,11 +880,11 @@ class WaitcntBrackets {
if (Reg == AMDGPU::SCC) {
SCCScore = Val;
} else if (TRI->isVectorRegister(*Context->MRI, Reg)) {
- for (MCRegUnit RU : regunits(Reg))
- VMem[toVMEMID(RU)].Scores[T] = Val;
+ for (StorageUnit RU : regunits_interval(Reg))
+ VMem[RU].Scores[T] = Val;
} else if (TRI->isSGPRReg(*Context->MRI, Reg)) {
auto STy = getSgprScoresIdx(T);
- for (MCRegUnit RU : regunits(Reg))
+ for (StorageUnit RU : regunits_interval(Reg))
SGPRs[RU].Scores[STy] = Val;
} else {
llvm_unreachable("Register cannot be tracked/unknown register!");
@@ -933,7 +943,7 @@ class WaitcntBrackets {
};
DenseMap<VMEMID, VMEMInfo> VMem; // VGPR + LDS DMA
- DenseMap<MCRegUnit, SGPRInfo> SGPRs;
+ DenseMap<StorageUnit, SGPRInfo> SGPRs;
// Reg score for SCC.
unsigned SCCScore = 0;
@@ -1137,8 +1147,8 @@ void WaitcntBrackets::updateByEvent(WaitEventType E, MachineInstr &Inst) {
// this with another potential dependency
if (hasPointSampleAccel(Inst))
TypesMask |= 1 << VMEM_NOSAMPLER;
- for (MCRegUnit RU : regunits(Op.getReg().asMCReg()))
- VMem[toVMEMID(RU)].VMEMTypes |= TypesMask;
+ for (StorageUnit RU : regunits_interval(Op.getReg().asMCReg()))
+ VMem[RU].VMEMTypes |= TypesMask;
}
}
setScoreByOperand(Op, T, CurrScore);
@@ -1264,7 +1274,7 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
// Also need to print sgpr scores for lgkm_cnt or xcnt.
if (isSmemCounter(T)) {
- SmallVector<MCRegUnit> SortedSMEMIDs(SGPRs.keys());
+ SmallVector<StorageUnit> SortedSMEMIDs(SGPRs.keys());
sort(SortedSMEMIDs);
for (auto ID : SortedSMEMIDs) {
unsigned RegScore = SGPRs.at(ID).Scores[getSgprScoresIdx(T)];
@@ -1380,10 +1390,9 @@ void WaitcntBrackets::determineWaitForPhysReg(InstCounterType T, MCPhysReg Reg,
determineWaitForScore(T, SCCScore, Wait);
} else {
bool IsVGPR = Context->TRI->isVectorRegister(*Context->MRI, Reg);
- for (MCRegUnit RU : regunits(Reg))
+ for (StorageUnit RU : regunits_interval(Reg))
determineWaitForScore(
- T, IsVGPR ? getVMemScore(toVMEMID(RU), T) : getSGPRScore(RU, T),
- Wait);
+ T, IsVGPR ? getVMemScore(RU, T) : getSGPRScore(RU, T), Wait);
}
}
@@ -3114,9 +3123,9 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
bool VMemInvalidated = false;
// DS optimization only applies to GFX12+ where DS_CNT is separate.
bool DSInvalidated = !ST->hasExtendedWaitCounts();
- DenseSet<MCRegUnit> VgprUse;
- DenseSet<MCRegUnit> VgprDefVMEM;
- DenseSet<MCRegUnit> VgprDefDS;
+ DenseSet<StorageUnit> VgprUse;
+ DenseSet<StorageUnit> VgprDefVMEM;
+ DenseSet<StorageUnit> VgprDefDS;
for (MachineBasicBlock *MBB : ML->blocks()) {
bool SeenDSStoreInCurrMBB = false;
@@ -3137,7 +3146,7 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
if (Op.isDebug() || !TRI->isVectorRegister(*MRI, Op.getReg()))
continue;
// Vgpr use
- for (MCRegUnit RU : TRI->regunits(Op.getReg().asMCReg())) {
+ for (StorageUnit RU : TRI->regunits_interval(Op.getReg().asMCReg())) {
// If we find a register that is loaded inside the loop, 1. and 2.
// are invalidated.
if (VgprDefVMEM.contains(RU))
@@ -3154,20 +3163,19 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
VgprUse.insert(RU);
// Check if this register has a pending VMEM load from outside the
// loop (value loaded outside and used inside).
- VMEMID ID = toVMEMID(RU);
bool HasPendingVMEM =
- Brackets.getVMemScore(ID, LOAD_CNT) >
+ Brackets.getVMemScore(RU, LOAD_CNT) >
Brackets.getScoreLB(LOAD_CNT) ||
- Brackets.getVMemScore(ID, SAMPLE_CNT) >
+ Brackets.getVMemScore(RU, SAMPLE_CNT) >
Brackets.getScoreLB(SAMPLE_CNT) ||
- Brackets.getVMemScore(ID, BVH_CNT) > Brackets.getScoreLB(BVH_CNT);
+ Brackets.getVMemScore(RU, BVH_CNT) > Brackets.getScoreLB(BVH_CNT);
if (HasPendingVMEM)
UsesVgprLoadedOutsideVMEM = true;
// Check if loaded outside the loop via DS (not VMEM/FLAT).
// Only consider it a DS load if there's no pending VMEM load for
// this register, since FLAT can set both counters.
if (!HasPendingVMEM &&
- Brackets.getVMemScore(ID, DS_CNT) > Brackets.getScoreLB(DS_CNT))
+ Brackets.getVMemScore(RU, DS_CNT) > Brackets.getScoreLB(DS_CNT))
UsesVgprLoadedOutsideDS = true;
}
}
@@ -3175,7 +3183,7 @@ SIInsertWaitcnts::getPreheaderFlushFlags(MachineLoop *ML,
// VMem load vgpr def
if (isVMEMOrFlatVMEM(MI) && MI.mayLoad()) {
for (const MachineOperand &Op : MI.all_defs()) {
- for (MCRegUnit RU : TRI->regunits(Op.getReg().asMCReg())) {
+ for (StorageUnit RU : TRI->regunits_interval(Op.getReg().asMCReg())) {
// If we find a register that is loaded inside the loop, 1. and 2.
// are invalidated.
if (VgprUse.contains(RU))
diff --git a/llvm/test/TableGen/regunit-intervals-impossible.td b/llvm/test/TableGen/regunit-intervals-impossible.td
new file mode 100644
index 0000000000000..9b4f6561178b7
--- /dev/null
+++ b/llvm/test/TableGen/regunit-intervals-impossible.td
@@ -0,0 +1,38 @@
+// RUN: not llvm-tblgen -gen-register-info -I %p/../../include %s 2>&1 | FileCheck %s
+
+// R1_R2 and R1_R3 could both be contiguous with the unit order r2, r1, r3,
+// but the single greedy pass in enforceRegUnitIntervals() does not find it.
+
+include "llvm/Target/Target.td"
+
+def TestInstrInfo : InstrInfo;
+def TestTarget : Target {
+ let InstructionSet = TestInstrInfo;
+ let RegistersAreIntervals = 1;
+}
+
+def sub_lo : SubRegIndex<32>;
+def sub_hi : SubRegIndex<32, 32>;
+
+let Namespace = "Test" in {
+ def R1 : Register<"r1">; // unit 0
+ def R2 : Register<"r2">; // unit 1
+ def R3 : Register<"r3">; // unit 2
+
+ // First composite: units {0, 1} - contiguous, OK
+ def R1_R2 : Register<"r1_r2"> {
+ let SubRegs = [R1, R2];
+ let SubRegIndices = [sub_lo, sub_hi];
+ }
+
+ // Second composite: units {0, 2} - non-contiguous!
+ // Algorithm will swap 1 and 2, making R1_R2 not contiguous
+ def R1_R3 : Register<"r1_r3"> {
+ let SubRegs = [R1, R3];
+ let SubRegIndices = [sub_lo, sub_hi];
+ }
+}
+
+def GPR32 : RegisterClass<"Test", [i32], 32, (add R1, R2, R3)>;
+
+// CHECK: error: Cannot enforce regunit intervals, final renumbering did not produce contiguous units for register R1_R2
diff --git a/llvm/test/TableGen/regunit-intervals.td b/llvm/test/TableGen/regunit-intervals.td
new file mode 100644
index 0000000000000..3d5dcd7c404dc
--- /dev/null
+++ b/llvm/test/TableGen/regunit-intervals.td
@@ -0,0 +1,76 @@
+// RUN: llvm-tblgen -gen-register-info -I %p/../../include %s | FileCheck %s
+
+include "llvm/Target/Target.td"
+
+def TestInstrInfo : InstrInfo;
+def TestTarget : Target {
+ let InstructionSet = TestInstrInfo;
+ let RegistersAreIntervals = 1;
+}
+
+def sub_lo : SubRegIndex<32>;
+def sub_hi : SubRegIndex<32, 32>;
+
+let Namespace = "Test" in {
+ // Simple 32-bit registers (each gets 1 regunit)
+ def R0 : Register<"r0">;
+ def R1 : Register<"r1">;
+ def R2 : Register<"r2">;
+ def R3 : Register<"r3">;
+
+ // 64-bit register composed of R0:R1 (gets 2 regunits)
+ def R0_R1 : Register<"r0_r1"> {
+ let SubRegs = [R0, R1];
+ let SubRegIndices = [sub_lo, sub_hi];
+ }
+
+ // 64-bit register composed of R2:R3 (gets 2 regunits)
+ def R2_R3 : Register<"r2_r3"> {
+ let SubRegs = [R2, R3];
+ let SubRegIndices = [sub_lo, sub_hi];
+ }
+}
+
+// The table is indexed by register number; entry 0 is NoRegister.
+// CHECK: extern const unsigned TestRegUnitIntervals[][2] = {
+// CHECK-NEXT: { 0, 0 },
+// CHECK-NEXT: { 0, 1 },
+// CHECK-NEXT: { 1, 2 },
+// CHECK-NEXT: { 2, 3 },
+// CHECK-NEXT: { 3, 4 },
+
+let Namespace = "Test" in {
+ def R4 : Register<"r4">; // Gets unit 4
+ def R5 : Register<"r5">; // Gets unit 5
+ def R6 : Register<"r6">; // Gets unit 6
+ def R7 : Register<"r7">; // Gets unit 7
+
+ // This register skips R5, creating non-contiguous units {4, 6}
+ def R4_R6 : Register<"r4_r6"> {
+ let SubRegs = [R4, R6];
+ let SubRegIndices = [sub_lo, sub_hi];
+ }
+
+ // This register skips R6, creating non-contiguous units {5, 7}
+ def R5_R7 : Register<"r5_r7"> {
+ let SubRegs = [R5, R7];
+ let SubRegIndices = [sub_lo, sub_hi];
+ }
+}
+
+
+def GPR32 : RegisterClass<"Test", [i32], 32, (add R0, R1, R2, R3)>;
+def GPR64 : RegisterClass<"Test", [i64], 64, (add R0_R1, R2_R3)>;
+
+// Fixing R4_R6 swaps units 5 and 6, so R5 ends up with [6, 7) and R6 with
+// [5, 6); R5_R7 then becomes contiguous as well.
+// CHECK-NEXT: { 4, 5 },
+// CHECK-NEXT: { 6, 7 },
+// CHECK-NEXT: { 5, 6 },
+// CHECK-NEXT: { 7, 8 },
+// All contiguous
+// CHECK-NEXT: { 0, 2 },
+// CHECK-NEXT: { 2, 4 },
+// CHECK-NEXT: { 4, 6 },
+// CHECK-NEXT: { 6, 8 },
+// CHECK-NEXT: };
diff --git a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
index 65a2594859e69..a707afaafb1d2 100644
--- a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
@@ -1929,6 +1929,136 @@ void CodeGenRegBank::computeRegUnitWeights() {
}
}
+// Enforce that all registers are intervals of regunits if the target
+// requests this property. This will renumber regunits to ensure the
+// interval property holds, or error out if it cannot be satisfied.
+//
+// This is a single pass over the registers in definition order with no
+// backtracking: whenever a register's units have a gap, the unit just above
+// the previous one is swapped with the register's next unit. An error
+// therefore does not prove that no valid numbering exists. Swaps are tracked
+// in both directions so that they compose into a permutation, and RegUnits is
+// permuted alongside so per-unit data (weights, roots) follows its unit.
+// NOTE(review): assumes nothing computed before this pass caches unit numbers
+// other than RegUnits and each register's RegUnits/NativeRegUnits.
+void CodeGenRegBank::enforceRegUnitIntervals() {
+ std::vector<const Record *> Targets =
+ Records.getAllDerivedDefinitions("Target");
+
+ if (Targets.empty())
+ return;
+
+ const Record *Target = Targets[0];
+ if (!Target->getValueAsBit("RegistersAreIntervals"))
+ return;
+
+ LLVM_DEBUG(dbgs() << "Enforcing regunit intervals for target\n");
+ // Original unit number -> current unit number (~0u means unchanged).
+ std::vector<unsigned> RegUnitRenumbering(RegUnits.size(), ~0u);
+ // Current unit number -> original unit number (~0u means unchanged).
+ std::vector<unsigned> RegUnitOrigin(RegUnits.size(), ~0u);
+
+ // RegUnits that have been renumbered from X -> Y. Y is what is marked so that
+ // it doesn't create a chain of swaps.
+ SparseBitVector DontRenumberUnits;
+
+ auto GetRenumberedUnit =
+ [&](unsigned RegUnit) -> unsigned {
+ if (RegUnitRenumbering[RegUnit] != ~0u)
+ return RegUnitRenumbering[RegUnit];
+ return RegUnit;
+ };
+
+ auto IsContiguous =
+ [&](CodeGenRegister::RegUnitList &Units) -> bool {
+ unsigned LastUnit = Units.find_first();
+ for (auto ThisUnit : llvm::make_range(++Units.begin(), Units.end())) {
+ if (ThisUnit != LastUnit + 1)
+ return false;
+ LastUnit = ThisUnit;
+ }
+ return true;
+ };
+
+ // Inverse of GetRenumberedUnit: current unit number -> original unit.
+ auto GetOriginalUnit = [&](unsigned RegUnit) -> unsigned {
+ if (RegUnitOrigin[RegUnit] != ~0u)
+ return RegUnitOrigin[RegUnit];
+ return RegUnit;
+ };
+
+ // Exchange the current unit numbers A and B. The forward map is indexed by
+ // the original unit behind each number (not by A/B themselves), so repeated
+ // swaps compose and the renumbering remains a permutation.
+ auto SwapUnits = [&](unsigned A, unsigned B) {
+ unsigned OrigA = GetOriginalUnit(A);
+ unsigned OrigB = GetOriginalUnit(B);
+ RegUnitRenumbering[OrigA] = B;
+ RegUnitRenumbering[OrigB] = A;
+ RegUnitOrigin[A] = OrigB;
+ RegUnitOrigin[B] = OrigA;
+ std::swap(RegUnits[A], RegUnits[B]);
+ };
+
+ // Process registers in definition order
+ for (CodeGenRegister &Reg : Registers) {
+ LLVM_DEBUG(dbgs() << "Processing register " << Reg.getName() << "\n");
+ const auto &Units = Reg.getNativeRegUnits();
+ if (Units.empty())
+ continue;
+ SparseBitVector RenumberedUnits;
+ // First renumber all the units for this register according to previous
+ // renumbering.
+ LLVM_DEBUG(dbgs() << " Original (Renumbered) units:");
+ for (unsigned U : Units) {
+ LLVM_DEBUG(dbgs() << " " << U << "(" << GetRenumberedUnit(U) << "), ");
+ RenumberedUnits.set(GetRenumberedUnit(U));
+ }
+ LLVM_DEBUG(dbgs() << "\n");
+
+ unsigned LastUnit = RenumberedUnits.find_first();
+ for (auto ThisUnit :
+ llvm::make_range(++RenumberedUnits.begin(), RenumberedUnits.end())) {
+ if (ThisUnit != LastUnit + 1) {
+ if (DontRenumberUnits.test(LastUnit + 1)) {
+ PrintFatalError("Cannot enforce regunit intervals for register " +
+ Reg.getName() + ": unit " + Twine(LastUnit + 1) +
+ " (root: " + RegUnits[LastUnit + 1].Roots[0]->getName() +
+ ") has already been renumbered and cannot be swapped");
+ }
+ LLVM_DEBUG(dbgs() << " Renumbering unit " << ThisUnit << " to "
+ << (LastUnit + 1) << "\n");
+ // The unit now numbered LastUnit + 1 is not one of this register's
+ // units; pull ThisUnit down into that slot and move the other unit up.
+ SwapUnits(LastUnit + 1, ThisUnit);
+ DontRenumberUnits.set(LastUnit + 1);
+ ThisUnit = LastUnit + 1;
+ }
+ LastUnit = ThisUnit;
+ }
+ }
+
+ // Apply the renumbering to all registers
+ for (CodeGenRegister &Reg : Registers) {
+ CodeGenRegister::RegUnitList NewRegUnits;
+ for (unsigned OldUnit : Reg.getRegUnits())
+ NewRegUnits.set(GetRenumberedUnit(OldUnit));
+ Reg.setNewRegUnits(NewRegUnits);
+
+ CodeGenRegister::RegUnitList NewNativeUnits;
+ for (unsigned OldUnit : Reg.getNativeRegUnits())
+ NewNativeUnits.set(GetRenumberedUnit(OldUnit));
+ // Report like the error above: this is a problem with the target's
+ // register definitions, not an internal TableGen failure.
+ if (!IsContiguous(NewNativeUnits)) {
+ PrintFatalError("Cannot enforce regunit intervals, final renumbering did "
+ "not produce contiguous units for register " +
+ Reg.getName());
+ }
+ Reg.NativeRegUnits = NewNativeUnits;
+ }
+}
+
// Find a set in UniqueSets with the same elements as Set.
// Return an iterator into UniqueSets.
static std::vector<RegUnitSet>::const_iterator
@@ -2209,6 +2339,11 @@ void CodeGenRegBank::computeDerivedInfo() {
computeRegUnitWeights();
Records.getTimer().stopTimer();
+ // Enforce regunit intervals if requested by the target.
+ Records.getTimer().startTimer("Enforce regunit intervals");
+ enforceRegUnitIntervals();
+ Records.getTimer().stopTimer();
+
// Compute a unique set of RegUnitSets. One for each RegClass and inferred
// supersets for the union of overlapping sets.
computeRegUnitSets();
diff --git a/llvm/utils/TableGen/Common/CodeGenRegisters.h b/llvm/utils/TableGen/Common/CodeGenRegisters.h
index a3ad0b797a704..176cb9dafd22b 100644
--- a/llvm/utils/TableGen/Common/CodeGenRegisters.h
+++ b/llvm/utils/TableGen/Common/CodeGenRegisters.h
@@ -169,6 +169,8 @@ class CodeGenSubRegIndex {
/// CodeGenRegister - Represents a register definition.
class CodeGenRegister {
+ friend class CodeGenRegBank; // For enforceRegUnitIntervals().
+
public:
const Record *TheDef;
unsigned EnumValue;
@@ -257,6 +259,10 @@ class CodeGenRegister {
// This is only valid after computeSubRegs() completes.
const RegUnitList &getRegUnits() const { return RegUnits; }
+ void setNewRegUnits(const RegUnitList &NewRegUnits) {
+ RegUnits = NewRegUnits; // Replaces the whole set after renumbering.
+ }
+
ArrayRef<LaneBitmask> getRegUnitLaneMasks() const {
return ArrayRef(RegUnitLaneMasks).slice(0, NativeRegUnits.count());
}
@@ -693,6 +699,9 @@ class CodeGenRegBank {
// Compute a weight for each register unit created during getSubRegs.
void computeRegUnitWeights();
+ // Make native regunits contiguous per register (RegistersAreIntervals).
+ void enforceRegUnitIntervals();
+
// Create a RegUnitSet for each RegClass and infer superclasses.
void computeRegUnitSets();
diff --git a/llvm/utils/TableGen/RegisterInfoEmitter.cpp b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
index 02fd8648302f1..552b32736509c 100644
--- a/llvm/utils/TableGen/RegisterInfoEmitter.cpp
+++ b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
@@ -1044,6 +1044,31 @@ void RegisterInfoEmitter::runMCDesc(raw_ostream &OS, raw_ostream &MainOS,
}
OS << "};\n\n";
+ // Emit the table of register unit intervals, indexed by register number.
+ // Each entry is the half-open range [First, Last + 1) of the register's
+ // native regunits, which enforceRegUnitIntervals() made contiguous.
+ // NOTE(review): the generated InitMCRegisterInfo call must also pass this
+ // table as its RUI argument; otherwise hasRegUnitIntervals() is false and
+ // regunits_interval() asserts.
+ if (Target.getTargetRecord()->getValueAsBit("RegistersAreIntervals")) {
+ OS << "extern const unsigned " << TargetName
+ << "RegUnitIntervals[][2] = {\n";
+ // Entry 0 is NoRegister so that MCRegisterInfo::regunits_interval() can
+ // index the table directly with Reg.id().
+ OS << " { 0, 0 },\n";
+ for (const auto &Reg : Regs) {
+ const auto &Units = Reg.getNativeRegUnits();
+ if (Units.empty()) {
+ OS << " { 0, 0 },\n";
+ } else {
+ unsigned First = Units.find_first();
+ unsigned Last = Units.find_last();
+ OS << " { " << First << ", " << Last + 1 << " },\n";
+ }
+ }
+ OS << "};\n\n";
+ }
+
const auto &RegisterClasses = RegBank.getRegClasses();
// Loop over all of the register classes... emitting each one.
>From 807a9cef83cc70d551c18239bc95a51385e85369 Mon Sep 17 00:00:00 2001
From: Ryan Mitchell <Ryan.Mitchell at amd.com>
Date: Tue, 13 Jan 2026 11:45:25 -0800
Subject: [PATCH 2/2] [TableGen] clang-format enforceRegUnitIntervals
---
llvm/utils/TableGen/Common/CodeGenRegisters.cpp | 15 +++++++--------
1 file changed, 7 insertions(+), 8 deletions(-)
diff --git a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
index a707afaafb1d2..bf1a2e767370d 100644
--- a/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenRegisters.cpp
@@ -1950,15 +1950,13 @@ void CodeGenRegBank::enforceRegUnitIntervals() {
// it doesn't create a chain of swaps.
SparseBitVector DontRenumberUnits;
- auto GetRenumberedUnit =
- [&](unsigned RegUnit) -> unsigned {
+ auto GetRenumberedUnit = [&](unsigned RegUnit) -> unsigned {
if (RegUnitRenumbering[RegUnit] != ~0u)
return RegUnitRenumbering[RegUnit];
return RegUnit;
};
- auto IsContiguous =
- [&](CodeGenRegister::RegUnitList &Units) -> bool {
+ auto IsContiguous = [&](CodeGenRegister::RegUnitList &Units) -> bool {
unsigned LastUnit = Units.find_first();
for (auto ThisUnit : llvm::make_range(++Units.begin(), Units.end())) {
if (ThisUnit != LastUnit + 1)
@@ -1989,10 +1987,11 @@ void CodeGenRegBank::enforceRegUnitIntervals() {
llvm::make_range(++RenumberedUnits.begin(), RenumberedUnits.end())) {
if (ThisUnit != LastUnit + 1) {
if (DontRenumberUnits.test(LastUnit + 1)) {
- PrintFatalError("Cannot enforce regunit intervals for register " +
- Reg.getName() + ": unit " + Twine(LastUnit + 1) +
- " (root: " + RegUnits[LastUnit + 1].Roots[0]->getName() +
- ") has already been renumbered and cannot be swapped");
+ PrintFatalError(
+ "Cannot enforce regunit intervals for register " + Reg.getName() +
+ ": unit " + Twine(LastUnit + 1) +
+ " (root: " + RegUnits[LastUnit + 1].Roots[0]->getName() +
+ ") has already been renumbered and cannot be swapped");
}
LLVM_DEBUG(dbgs() << " Renumbering unit " << ThisUnit << " to "
<< (LastUnit + 1) << "\n");
More information about the llvm-commits
mailing list