[llvm] [AMDGPU] Simplify `GCNRegPressure::RegKind` (NFC) (PR #142682)
Lucas Ramirez via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 3 15:22:29 PDT 2025
https://github.com/lucas-rami created https://github.com/llvm/llvm-project/pull/142682
This NFC change simplifies the `GCNRegPressure::RegKind` enum so that, instead of containing a pair of values for each type of register (one for non-tuple registers and one for tuple registers of that type), it contains only one value representing all registers of that type.
The `GCNRegPressure::Value` array is still sized as before, though all elements corresponding to tuple kinds now come after all elements corresponding to non-tuple kinds instead of the two being interleaved. This makes it possible to simplify the `GCNRegPressure::inc` logic, eliminating the switch entirely.
>From 17f3b400450b9d53315ceebc110545daddb14e66 Mon Sep 17 00:00:00 2001
From: Lucas Ramirez <lucas.rami at proton.me>
Date: Tue, 3 Jun 2025 15:31:13 +0000
Subject: [PATCH] Merge regular/tuple registers in same kind (NFC)
---
llvm/lib/Target/AMDGPU/GCNRegPressure.cpp | 50 ++++++++---------------
llvm/lib/Target/AMDGPU/GCNRegPressure.h | 50 +++++++++++------------
2 files changed, 40 insertions(+), 60 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
index f74d12cfab0c0..ce213b91b1f7e 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp
@@ -35,17 +35,9 @@ bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
///////////////////////////////////////////////////////////////////////////////
// GCNRegPressure
-unsigned GCNRegPressure::getRegKind(Register Reg,
- const MachineRegisterInfo &MRI) {
- assert(Reg.isVirtual());
- const auto *const RC = MRI.getRegClass(Reg);
- const auto *STI =
- static_cast<const SIRegisterInfo *>(MRI.getTargetRegisterInfo());
- return STI->isSGPRClass(RC)
- ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
- : STI->isAGPRClass(RC)
- ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
- : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
+unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC,
+ const SIRegisterInfo *STI) {
+ return STI->isSGPRClass(RC) ? SGPR : (STI->isAGPRClass(RC) ? AGPR : VGPR);
}
void GCNRegPressure::inc(unsigned Reg,
@@ -61,32 +53,22 @@ void GCNRegPressure::inc(unsigned Reg,
std::swap(NewMask, PrevMask);
Sign = -1;
}
+ assert(PrevMask < NewMask && "prev mask should always be lesser than new");
- switch (auto Kind = getRegKind(Reg, MRI)) {
- case SGPR32:
- case VGPR32:
- case AGPR32:
- Value[Kind] += Sign;
- break;
-
- case SGPR_TUPLE:
- case VGPR_TUPLE:
- case AGPR_TUPLE:
- assert(PrevMask < NewMask);
-
- Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
- Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
-
+ const TargetRegisterClass *RC = MRI.getRegClass(Reg);
+ const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
+ const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI);
+ unsigned RegKind = getRegKind(RC, STI);
+ if (TRI->getRegSizeInBits(*RC) != 32) {
+ // Reg is from a tuple register class.
if (PrevMask.none()) {
- assert(NewMask.any());
- const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
- Value[Kind] +=
- Sign * TRI->getRegClassWeight(MRI.getRegClass(Reg)).RegWeight;
+ unsigned TupleIdx = TOTAL_KINDS + RegKind;
+ Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight;
}
- break;
-
- default: llvm_unreachable("Unknown register kind");
+ // Pressure scales with number of new registers covered by the new mask.
+ Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
}
+ Value[RegKind] += Sign;
}
bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
@@ -226,7 +208,7 @@ bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
return Printable([&RP, ST](raw_ostream &OS) {
- OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
+ OS << "VGPRs: " << RP.getArchVGPRNum() << ' '
<< "AGPRs: " << RP.getAGPRNum();
if (ST)
OS << "(O"
diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.h b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
index 4d1e96ff04e8d..f3d7983056cfc 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.h
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
@@ -29,34 +29,26 @@ class raw_ostream;
class SlotIndex;
struct GCNRegPressure {
- enum RegKind {
- SGPR32,
- SGPR_TUPLE,
- VGPR32,
- VGPR_TUPLE,
- AGPR32,
- AGPR_TUPLE,
- TOTAL_KINDS
- };
+ enum RegKind { SGPR, VGPR, AGPR, TOTAL_KINDS };
GCNRegPressure() {
clear();
}
- bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; }
+ bool empty() const { return !Value[SGPR] && !Value[VGPR] && !Value[AGPR]; }
- void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); }
+ void clear() { std::fill(&Value[0], &Value[ValueArraySize], 0); }
/// \returns the SGPR32 pressure
- unsigned getSGPRNum() const { return Value[SGPR32]; }
+ unsigned getSGPRNum() const { return Value[SGPR]; }
/// \returns the aggregated ArchVGPR32, AccVGPR32 pressure dependent upon \p
/// UnifiedVGPRFile
unsigned getVGPRNum(bool UnifiedVGPRFile) const {
if (UnifiedVGPRFile) {
- return Value[AGPR32] ? getUnifiedVGPRNum(Value[VGPR32], Value[AGPR32])
- : Value[VGPR32] + Value[AGPR32];
+ return Value[AGPR] ? getUnifiedVGPRNum(Value[VGPR], Value[AGPR])
+ : Value[VGPR];
}
- return std::max(Value[VGPR32], Value[AGPR32]);
+ return std::max(Value[VGPR], Value[AGPR]);
}
/// Returns the aggregated VGPR pressure, assuming \p NumArchVGPRs ArchVGPRs
@@ -68,13 +60,14 @@ struct GCNRegPressure {
}
/// \returns the ArchVGPR32 pressure
- unsigned getArchVGPRNum() const { return Value[VGPR32]; }
+ unsigned getArchVGPRNum() const { return Value[VGPR]; }
/// \returns the AccVGPR32 pressure
- unsigned getAGPRNum() const { return Value[AGPR32]; }
+ unsigned getAGPRNum() const { return Value[AGPR]; }
- unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE],
- Value[AGPR_TUPLE]); }
- unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
+ unsigned getVGPRTuplesWeight() const {
+ return std::max(Value[TOTAL_KINDS + VGPR], Value[TOTAL_KINDS + AGPR]);
+ }
+ unsigned getSGPRTuplesWeight() const { return Value[TOTAL_KINDS + SGPR]; }
unsigned getOccupancy(const GCNSubtarget &ST) const {
return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()),
@@ -106,7 +99,7 @@ struct GCNRegPressure {
unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
bool operator==(const GCNRegPressure &O) const {
- return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value);
+ return std::equal(&Value[0], &Value[ValueArraySize], O.Value);
}
bool operator!=(const GCNRegPressure &O) const {
@@ -114,13 +107,13 @@ struct GCNRegPressure {
}
GCNRegPressure &operator+=(const GCNRegPressure &RHS) {
- for (unsigned I = 0; I < TOTAL_KINDS; ++I)
+ for (unsigned I = 0; I < ValueArraySize; ++I)
Value[I] += RHS.Value[I];
return *this;
}
GCNRegPressure &operator-=(const GCNRegPressure &RHS) {
- for (unsigned I = 0; I < TOTAL_KINDS; ++I)
+ for (unsigned I = 0; I < ValueArraySize; ++I)
Value[I] -= RHS.Value[I];
return *this;
}
@@ -128,9 +121,14 @@ struct GCNRegPressure {
void dump() const;
private:
- unsigned Value[TOTAL_KINDS];
+ static constexpr unsigned ValueArraySize = TOTAL_KINDS * 2;
+
+ /// Pressure for all register kinds (first all regular registers kinds, then
+ /// all tuple register kinds).
+ unsigned Value[ValueArraySize];
- static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
+ static unsigned getRegKind(const TargetRegisterClass *RC,
+ const SIRegisterInfo *STI);
friend GCNRegPressure max(const GCNRegPressure &P1,
const GCNRegPressure &P2);
@@ -140,7 +138,7 @@ struct GCNRegPressure {
inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
GCNRegPressure Res;
- for (unsigned I = 0; I < GCNRegPressure::TOTAL_KINDS; ++I)
+ for (unsigned I = 0; I < GCNRegPressure::ValueArraySize; ++I)
Res.Value[I] = std::max(P1.Value[I], P2.Value[I]);
return Res;
}
More information about the llvm-commits
mailing list