[llvm-branch-commits] [llvm] [AMDGPU][Attributor] Rework update of waves per eu (PR #123995)
Shilei Tian via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 23 11:54:11 PST 2025
================
@@ -1109,74 +1109,38 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
Function *F = getAssociatedFunction();
auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
- auto TakeRange = [&](std::pair<unsigned, unsigned> R) {
- auto [Min, Max] = R;
- ConstantRange Range(APInt(32, Min), APInt(32, Max + 1));
- IntegerRangeState RangeState(Range);
- clampStateAndIndicateChange(this->getState(), RangeState);
- indicateOptimisticFixpoint();
- };
-
- std::pair<unsigned, unsigned> MaxWavesPerEURange{
- 1U, InfoCache.getMaxWavesPerEU(*F)};
-
// If the attribute exists, we will honor it if it is not the default.
if (auto Attr = InfoCache.getWavesPerEUAttr(*F)) {
+ std::pair<unsigned, unsigned> MaxWavesPerEURange{
+ 1U, InfoCache.getMaxWavesPerEU(*F)};
if (*Attr != MaxWavesPerEURange) {
- TakeRange(*Attr);
+ auto [Min, Max] = *Attr;
+ ConstantRange Range(APInt(32, Min), APInt(32, Max + 1));
+ IntegerRangeState RangeState(Range);
+ clampStateAndIndicateChange(this->getState(), RangeState);
+ indicateOptimisticFixpoint();
return;
}
}
- // Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
- // calculation of waves per EU involves flat work group size, we can't
- // simply use an assumed flat work group size as a start point, because the
- // update of flat work group size is in an inverse direction of waves per
- // EU. However, we can still do something if it is an entry function. Since
- // an entry function is a terminal node, and flat work group size either
- // from attribute or default will be used anyway, we can take that value and
- // calculate the waves per EU based on it. This result can't be updated by
- // no means, but that could still allow us to propagate it.
- if (AMDGPU::isEntryFunctionCC(F->getCallingConv())) {
- std::pair<unsigned, unsigned> FlatWorkGroupSize;
- if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr(*F))
- FlatWorkGroupSize = *Attr;
- else
- FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize(*F);
- TakeRange(InfoCache.getEffectiveWavesPerEU(*F, MaxWavesPerEURange,
- FlatWorkGroupSize));
- }
+ if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
+ indicatePessimisticFixpoint();
}
ChangeStatus updateImpl(Attributor &A) override {
- auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
ChangeStatus Change = ChangeStatus::UNCHANGED;
auto CheckCallSite = [&](AbstractCallSite CS) {
Function *Caller = CS.getInstruction()->getFunction();
- Function *Func = getAssociatedFunction();
- LLVM_DEBUG(dbgs() << '[' << getName() << "] Call " << Caller->getName()
- << "->" << Func->getName() << '\n');
-
const auto *CallerInfo = A.getAAFor<AAAMDWavesPerEU>(
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
- const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
- *this, IRPosition::function(*Func), DepClassTy::REQUIRED);
- if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState() ||
- !AssumedGroupSize->isValidState())
+ if (!CallerInfo || !CallerInfo->isValidState())
return false;
-
- unsigned Min, Max;
- std::tie(Min, Max) = InfoCache.getEffectiveWavesPerEU(
- *Caller,
- {CallerInfo->getAssumed().getLower().getZExtValue(),
- CallerInfo->getAssumed().getUpper().getZExtValue() - 1},
- {AssumedGroupSize->getAssumed().getLower().getZExtValue(),
- AssumedGroupSize->getAssumed().getUpper().getZExtValue() - 1});
- ConstantRange CallerRange(APInt(32, Min), APInt(32, Max + 1));
+ unsigned Min = CallerInfo->getAssumed().getLower().getZExtValue();
+ unsigned Max = CallerInfo->getAssumed().getUpper().getZExtValue();
+ ConstantRange CallerRange(APInt(32, Min), APInt(32, Max));
IntegerRangeState CallerRangeState(CallerRange);
Change |= clampStateAndIndicateChange(this->getState(), CallerRangeState);
----------------
shiltian wrote:
@arsenm did I get this part correctly?
https://github.com/llvm/llvm-project/pull/123995
More information about the llvm-branch-commits mailing list