[llvm-branch-commits] [llvm] [AMDGPU][Attributor] Skip update if an AA is at its initial state (PR #114726)
Matt Arsenault via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 11 03:05:21 PST 2024
================
@@ -1145,31 +1169,71 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
ChangeStatus Change = ChangeStatus::UNCHANGED;
+ Function *F = getAssociatedFunction();
+
+ const auto *AAFlatWorkGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
+ *this, IRPosition::function(*F), DepClassTy::REQUIRED);
+ if (!AAFlatWorkGroupSize || !AAFlatWorkGroupSize->isValidState()) {
+ LLVM_DEBUG(
+ dbgs() << '[' << getName()
+ << "] AAAMDFlatWorkGroupSize is unavailable or invalid.\n");
+ return ChangeStatus::UNCHANGED;
+ }
+
+ if (AAFlatWorkGroupSize->isAtInitialState()) {
+ LLVM_DEBUG(dbgs() << '[' << getName()
+ << "] AAAMDFlatWorkGroupSize is still at initial "
+ "state. Skip the update.\n");
+ return ChangeStatus::UNCHANGED;
+ }
+
+ auto CurrentWorkGroupSize = std::make_pair(
+ AAFlatWorkGroupSize->getAssumed().getLower().getZExtValue(),
+ AAFlatWorkGroupSize->getAssumed().getUpper().getZExtValue() - 1);
+
+ auto DoUpdate = [&](std::pair<unsigned, unsigned> WavesPerEU,
+ std::pair<unsigned, unsigned> FlatWorkGroupSize) {
+ auto [Min, Max] =
+ InfoCache.getEffectiveWavesPerEU(*F, WavesPerEU, FlatWorkGroupSize);
+ ConstantRange CR(APInt(32, Min), APInt(32, Max + 1));
+ IntegerRangeState IRS(CR);
+ Change |= clampStateAndIndicateChange(this->getState(), IRS);
+ };
+
+ // We need to clamp once if we are not at the initial state, because
+ // AAAMDFlatWorkGroupSize could have been updated in the last iteration.
+ if (!isAtInitialState()) {
+ auto CurrentWavesPerEU =
+ std::make_pair(getAssumed().getLower().getZExtValue(),
+ getAssumed().getUpper().getZExtValue() - 1);
+ DoUpdate(CurrentWavesPerEU, CurrentWorkGroupSize);
+ }
+
auto CheckCallSite = [&](AbstractCallSite CS) {
Function *Caller = CS.getInstruction()->getFunction();
- Function *Func = getAssociatedFunction();
+
LLVM_DEBUG(dbgs() << '[' << getName() << "] Call " << Caller->getName()
- << "->" << Func->getName() << '\n');
+ << "->" << F->getName() << '\n');
- const auto *CallerInfo = A.getAAFor<AAAMDWavesPerEU>(
+ const auto *AAWavesPerEU = A.getAAFor<AAAMDWavesPerEU>(
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
- const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
- *this, IRPosition::function(*Func), DepClassTy::REQUIRED);
- if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState() ||
- !AssumedGroupSize->isValidState())
+ if (!AAWavesPerEU || !AAWavesPerEU->isValidState()) {
+ LLVM_DEBUG(dbgs() << '[' << getName() << "] Caller "
+ << Caller->getName()
+ << " is unavailable or invalid.\n");
return false;
+ }
+ if (AAWavesPerEU->isAtInitialState()) {
----------------
arsenm wrote:
Same as above
https://github.com/llvm/llvm-project/pull/114726
More information about the llvm-branch-commits
mailing list