[llvm] c66dee4 - [AMDGPU] Refactor several functions for merging with downstream work. (#110562)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 1 08:28:58 PDT 2024
Author: Gang Chen
Date: 2024-10-01T08:28:55-07:00
New Revision: c66dee4c6bd650ef20105532a311a95abb25ece5
URL: https://github.com/llvm/llvm-project/commit/c66dee4c6bd650ef20105532a311a95abb25ece5
DIFF: https://github.com/llvm/llvm-project/commit/c66dee4c6bd650ef20105532a311a95abb25ece5.diff
LOG: [AMDGPU] Refactor several functions for merging with downstream work. (#110562)
For setScore, the root function is setScoreByInterval with RegInterval
input
For determineWait, the root function is determineWait with RegInterval
input
Added:
Modified:
llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index 80a7529002ac90..e64b35d230d486 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -310,7 +310,14 @@ class WaitcntBrackets {
bool counterOutOfOrder(InstCounterType T) const;
void simplifyWaitcnt(AMDGPU::Waitcnt &Wait) const;
void simplifyWaitcnt(InstCounterType T, unsigned &Count) const;
- void determineWait(InstCounterType T, int RegNo, AMDGPU::Waitcnt &Wait) const;
+
+ void determineWait(InstCounterType T, RegInterval Interval,
+ AMDGPU::Waitcnt &Wait) const;
+ void determineWait(InstCounterType T, int RegNo,
+ AMDGPU::Waitcnt &Wait) const {
+ determineWait(T, {RegNo, RegNo + 1}, Wait);
+ }
+
void applyWaitcnt(const AMDGPU::Waitcnt &Wait);
void applyWaitcnt(InstCounterType T, unsigned Count);
void updateByEvent(const SIInstrInfo *TII, const SIRegisterInfo *TRI,
@@ -345,16 +352,22 @@ class WaitcntBrackets {
LastFlat[DS_CNT] = ScoreUBs[DS_CNT];
}
- // Return true if there might be pending writes to the specified vgpr by VMEM
+ // Return true if there might be pending writes to the vgpr-interval by VMEM
// instructions with types different from V.
- bool hasOtherPendingVmemTypes(int GprNo, VmemType V) const {
- assert(GprNo < NUM_ALL_VGPRS);
- return VgprVmemTypes[GprNo] & ~(1 << V);
+ bool hasOtherPendingVmemTypes(RegInterval Interval, VmemType V) const {
+ for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+ assert(RegNo < NUM_ALL_VGPRS);
+ if (VgprVmemTypes[RegNo] & ~(1 << V))
+ return true;
+ }
+ return false;
}
- void clearVgprVmemTypes(int GprNo) {
- assert(GprNo < NUM_ALL_VGPRS);
- VgprVmemTypes[GprNo] = 0;
+ void clearVgprVmemTypes(RegInterval Interval) {
+ for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+ assert(RegNo < NUM_ALL_VGPRS);
+ VgprVmemTypes[RegNo] = 0;
+ }
}
void setStateOnFunctionEntryOrReturn() {
@@ -396,19 +409,16 @@ class WaitcntBrackets {
}
void setRegScore(int GprNo, InstCounterType T, unsigned Val) {
- if (GprNo < NUM_ALL_VGPRS) {
- VgprUB = std::max(VgprUB, GprNo);
- VgprScores[T][GprNo] = Val;
- } else {
- assert(T == SmemAccessCounter);
- SgprUB = std::max(SgprUB, GprNo - NUM_ALL_VGPRS);
- SgprScores[GprNo - NUM_ALL_VGPRS] = Val;
- }
+ setScoreByInterval({GprNo, GprNo + 1}, T, Val);
}
- void setExpScore(const MachineInstr *MI, const SIRegisterInfo *TRI,
- const MachineRegisterInfo *MRI, const MachineOperand &Op,
- unsigned Val);
+ void setScoreByInterval(RegInterval Interval, InstCounterType CntTy,
+ unsigned Score);
+
+ void setScoreByOperand(const MachineInstr *MI, const SIRegisterInfo *TRI,
+ const MachineRegisterInfo *MRI,
+ const MachineOperand &Op, InstCounterType CntTy,
+ unsigned Val);
const GCNSubtarget *ST = nullptr;
InstCounterType MaxCounter = NUM_EXTENDED_INST_CNTS;
@@ -772,17 +782,30 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
return Result;
}
-void WaitcntBrackets::setExpScore(const MachineInstr *MI,
- const SIRegisterInfo *TRI,
- const MachineRegisterInfo *MRI,
- const MachineOperand &Op, unsigned Val) {
- RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
- assert(TRI->isVectorRegister(*MRI, Op.getReg()));
+void WaitcntBrackets::setScoreByInterval(RegInterval Interval,
+ InstCounterType CntTy,
+ unsigned Score) {
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
- setRegScore(RegNo, EXP_CNT, Val);
+ if (RegNo < NUM_ALL_VGPRS) {
+ VgprUB = std::max(VgprUB, RegNo);
+ VgprScores[CntTy][RegNo] = Score;
+ } else {
+ assert(CntTy == SmemAccessCounter);
+ SgprUB = std::max(SgprUB, RegNo - NUM_ALL_VGPRS);
+ SgprScores[RegNo - NUM_ALL_VGPRS] = Score;
+ }
}
}
+void WaitcntBrackets::setScoreByOperand(const MachineInstr *MI,
+ const SIRegisterInfo *TRI,
+ const MachineRegisterInfo *MRI,
+ const MachineOperand &Op,
+ InstCounterType CntTy, unsigned Score) {
+ RegInterval Interval = getRegInterval(MI, MRI, TRI, Op);
+ setScoreByInterval(Interval, CntTy, Score);
+}
+
void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI,
@@ -806,57 +829,61 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
// All GDS operations must protect their address register (same as
// export.)
if (const auto *AddrOp = TII->getNamedOperand(Inst, AMDGPU::OpName::addr))
- setExpScore(&Inst, TRI, MRI, *AddrOp, CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, *AddrOp, EXP_CNT, CurrScore);
if (Inst.mayStore()) {
if (const auto *Data0 =
TII->getNamedOperand(Inst, AMDGPU::OpName::data0))
- setExpScore(&Inst, TRI, MRI, *Data0, CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, *Data0, EXP_CNT, CurrScore);
if (const auto *Data1 =
TII->getNamedOperand(Inst, AMDGPU::OpName::data1))
- setExpScore(&Inst, TRI, MRI, *Data1, CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, *Data1, EXP_CNT, CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst) && !SIInstrInfo::isGWS(Inst) &&
Inst.getOpcode() != AMDGPU::DS_APPEND &&
Inst.getOpcode() != AMDGPU::DS_CONSUME &&
Inst.getOpcode() != AMDGPU::DS_ORDERED_COUNT) {
for (const MachineOperand &Op : Inst.all_uses()) {
if (TRI->isVectorRegister(*MRI, Op.getReg()))
- setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
}
}
} else if (TII->isFLAT(Inst)) {
if (Inst.mayStore()) {
- setExpScore(&Inst, TRI, MRI,
- *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
- CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI,
+ *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+ EXP_CNT, CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst)) {
- setExpScore(&Inst, TRI, MRI,
- *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
- CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI,
+ *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+ EXP_CNT, CurrScore);
}
} else if (TII->isMIMG(Inst)) {
if (Inst.mayStore()) {
- setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+ CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst)) {
- setExpScore(&Inst, TRI, MRI,
- *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
- CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI,
+ *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+ EXP_CNT, CurrScore);
}
} else if (TII->isMTBUF(Inst)) {
if (Inst.mayStore())
- setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+ CurrScore);
} else if (TII->isMUBUF(Inst)) {
if (Inst.mayStore()) {
- setExpScore(&Inst, TRI, MRI, Inst.getOperand(0), CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, Inst.getOperand(0), EXP_CNT,
+ CurrScore);
} else if (SIInstrInfo::isAtomicRet(Inst)) {
- setExpScore(&Inst, TRI, MRI,
- *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
- CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI,
+ *TII->getNamedOperand(Inst, AMDGPU::OpName::data),
+ EXP_CNT, CurrScore);
}
} else if (TII->isLDSDIR(Inst)) {
// LDSDIR instructions attach the score to the destination.
- setExpScore(&Inst, TRI, MRI,
- *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst), CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI,
+ *TII->getNamedOperand(Inst, AMDGPU::OpName::vdst),
+ EXP_CNT, CurrScore);
} else {
if (TII->isEXP(Inst)) {
// For export the destination registers are really temps that
@@ -865,15 +892,13 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
// score.
for (MachineOperand &DefMO : Inst.all_defs()) {
if (TRI->isVGPR(*MRI, DefMO.getReg())) {
- setRegScore(
- TRI->getEncodingValue(AMDGPU::getMCReg(DefMO.getReg(), *ST)),
- EXP_CNT, CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, DefMO, EXP_CNT, CurrScore);
}
}
}
for (const MachineOperand &Op : Inst.all_uses()) {
if (TRI->isVectorRegister(*MRI, Op.getReg()))
- setExpScore(&Inst, TRI, MRI, Op, CurrScore);
+ setScoreByOperand(&Inst, TRI, MRI, Op, EXP_CNT, CurrScore);
}
}
} else /* LGKM_CNT || EXP_CNT || VS_CNT || NUM_INST_CNTS */ {
@@ -901,9 +926,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
VgprVmemTypes[RegNo] |= 1 << V;
}
}
- for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
- setRegScore(RegNo, T, CurrScore);
- }
+ setScoreByInterval(Interval, T, CurrScore);
}
if (Inst.mayStore() &&
(TII->isDS(Inst) || TII->mayWriteLDSThroughDMA(Inst))) {
@@ -1034,31 +1057,34 @@ void WaitcntBrackets::simplifyWaitcnt(InstCounterType T,
Count = ~0u;
}
-void WaitcntBrackets::determineWait(InstCounterType T, int RegNo,
+void WaitcntBrackets::determineWait(InstCounterType T, RegInterval Interval,
AMDGPU::Waitcnt &Wait) const {
- unsigned ScoreToWait = getRegScore(RegNo, T);
-
- // If the score of src_operand falls within the bracket, we need an
- // s_waitcnt instruction.
const unsigned LB = getScoreLB(T);
const unsigned UB = getScoreUB(T);
- if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
- if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
- !ST->hasFlatLgkmVMemCountInOrder()) {
- // If there is a pending FLAT operation, and this is a VMem or LGKM
- // waitcnt and the target can report early completion, then we need
- // to force a waitcnt 0.
- addWait(Wait, T, 0);
- } else if (counterOutOfOrder(T)) {
- // Counter can get decremented out-of-order when there
- // are multiple types event in the bracket. Also emit an s_wait counter
- // with a conservative value of 0 for the counter.
- addWait(Wait, T, 0);
- } else {
- // If a counter has been maxed out avoid overflow by waiting for
- // MAX(CounterType) - 1 instead.
- unsigned NeededWait = std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
- addWait(Wait, T, NeededWait);
+ for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+ unsigned ScoreToWait = getRegScore(RegNo, T);
+
+ // If the score of src_operand falls within the bracket, we need an
+ // s_waitcnt instruction.
+ if ((UB >= ScoreToWait) && (ScoreToWait > LB)) {
+ if ((T == LOAD_CNT || T == DS_CNT) && hasPendingFlat() &&
+ !ST->hasFlatLgkmVMemCountInOrder()) {
+ // If there is a pending FLAT operation, and this is a VMem or LGKM
+ // waitcnt and the target can report early completion, then we need
+ // to force a waitcnt 0.
+ addWait(Wait, T, 0);
+ } else if (counterOutOfOrder(T)) {
+ // Counter can get decremented out-of-order when there
+ // are multiple types event in the bracket. Also emit an s_wait counter
+ // with a conservative value of 0 for the counter.
+ addWait(Wait, T, 0);
+ } else {
+ // If a counter has been maxed out avoid overflow by waiting for
+ // MAX(CounterType) - 1 instead.
+ unsigned NeededWait =
+ std::min(UB - ScoreToWait, getWaitCountMax(T) - 1);
+ addWait(Wait, T, NeededWait);
+ }
}
}
}
@@ -1670,18 +1696,16 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
RegInterval CallAddrOpInterval =
ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOp);
- for (int RegNo = CallAddrOpInterval.first;
- RegNo < CallAddrOpInterval.second; ++RegNo)
- ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+ ScoreBrackets.determineWait(SmemAccessCounter, CallAddrOpInterval,
+ Wait);
if (const auto *RtnAddrOp =
TII->getNamedOperand(MI, AMDGPU::OpName::dst)) {
RegInterval RtnAddrOpInterval =
ScoreBrackets.getRegInterval(&MI, MRI, TRI, *RtnAddrOp);
- for (int RegNo = RtnAddrOpInterval.first;
- RegNo < RtnAddrOpInterval.second; ++RegNo)
- ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+ ScoreBrackets.determineWait(SmemAccessCounter, RtnAddrOpInterval,
+ Wait);
}
}
} else {
@@ -1750,36 +1774,34 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, Op);
const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
- for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
- if (IsVGPR) {
- // Implicit VGPR defs and uses are never a part of the memory
- // instructions description and usually present to account for
- // super-register liveness.
- // TODO: Most of the other instructions also have implicit uses
- // for the liveness accounting only.
- if (Op.isImplicit() && MI.mayLoadOrStore())
- continue;
-
- // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
- // previous write and this write are the same type of VMEM
- // instruction, in which case they are (in some architectures)
- // guaranteed to write their results in order anyway.
- if (Op.isUse() || !updateVMCntOnly(MI) ||
- ScoreBrackets.hasOtherPendingVmemTypes(RegNo,
- getVmemType(MI)) ||
- !ST->hasVmemWriteVgprInOrder()) {
- ScoreBrackets.determineWait(LOAD_CNT, RegNo, Wait);
- ScoreBrackets.determineWait(SAMPLE_CNT, RegNo, Wait);
- ScoreBrackets.determineWait(BVH_CNT, RegNo, Wait);
- ScoreBrackets.clearVgprVmemTypes(RegNo);
- }
- if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
- ScoreBrackets.determineWait(EXP_CNT, RegNo, Wait);
- }
- ScoreBrackets.determineWait(DS_CNT, RegNo, Wait);
- } else {
- ScoreBrackets.determineWait(SmemAccessCounter, RegNo, Wait);
+ if (IsVGPR) {
+ // Implicit VGPR defs and uses are never a part of the memory
+ // instructions description and usually present to account for
+ // super-register liveness.
+ // TODO: Most of the other instructions also have implicit uses
+ // for the liveness accounting only.
+ if (Op.isImplicit() && MI.mayLoadOrStore())
+ continue;
+
+ // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
+ // previous write and this write are the same type of VMEM
+ // instruction, in which case they are (in some architectures)
+ // guaranteed to write their results in order anyway.
+ if (Op.isUse() || !updateVMCntOnly(MI) ||
+ ScoreBrackets.hasOtherPendingVmemTypes(Interval,
+ getVmemType(MI)) ||
+ !ST->hasVmemWriteVgprInOrder()) {
+ ScoreBrackets.determineWait(LOAD_CNT, Interval, Wait);
+ ScoreBrackets.determineWait(SAMPLE_CNT, Interval, Wait);
+ ScoreBrackets.determineWait(BVH_CNT, Interval, Wait);
+ ScoreBrackets.clearVgprVmemTypes(Interval);
+ }
+ if (Op.isDef() || ScoreBrackets.hasPendingEvent(EXP_LDS_ACCESS)) {
+ ScoreBrackets.determineWait(EXP_CNT, Interval, Wait);
}
+ ScoreBrackets.determineWait(DS_CNT, Interval, Wait);
+ } else {
+ ScoreBrackets.determineWait(SmemAccessCounter, Interval, Wait);
}
}
}
More information about the llvm-commits
mailing list