[llvm-branch-commits] [llvm] [AMDGPU][Attributor] Rework update of waves per eu (PR #123995)

Shilei Tian via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 23 11:54:11 PST 2025


================
@@ -1109,74 +1109,38 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
     Function *F = getAssociatedFunction();
     auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
 
-    auto TakeRange = [&](std::pair<unsigned, unsigned> R) {
-      auto [Min, Max] = R;
-      ConstantRange Range(APInt(32, Min), APInt(32, Max + 1));
-      IntegerRangeState RangeState(Range);
-      clampStateAndIndicateChange(this->getState(), RangeState);
-      indicateOptimisticFixpoint();
-    };
-
-    std::pair<unsigned, unsigned> MaxWavesPerEURange{
-        1U, InfoCache.getMaxWavesPerEU(*F)};
-
     // If the attribute exists, we will honor it if it is not the default.
     if (auto Attr = InfoCache.getWavesPerEUAttr(*F)) {
+      std::pair<unsigned, unsigned> MaxWavesPerEURange{
+          1U, InfoCache.getMaxWavesPerEU(*F)};
       if (*Attr != MaxWavesPerEURange) {
-        TakeRange(*Attr);
+        auto [Min, Max] = *Attr;
+        ConstantRange Range(APInt(32, Min), APInt(32, Max + 1));
+        IntegerRangeState RangeState(Range);
+        clampStateAndIndicateChange(this->getState(), RangeState);
+        indicateOptimisticFixpoint();
         return;
       }
     }
 
-    // Unlike AAAMDFlatWorkGroupSize, it's getting trickier here. Since the
-    // calculation of waves per EU involves flat work group size, we can't
-    // simply use an assumed flat work group size as a start point, because the
-    // update of flat work group size is in an inverse direction of waves per
-    // EU. However, we can still do something if it is an entry function. Since
-    // an entry function is a terminal node, and flat work group size either
-    // from attribute or default will be used anyway, we can take that value and
-    // calculate the waves per EU based on it. This result can't be updated by
-    // no means, but that could still allow us to propagate it.
-    if (AMDGPU::isEntryFunctionCC(F->getCallingConv())) {
-      std::pair<unsigned, unsigned> FlatWorkGroupSize;
-      if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr(*F))
-        FlatWorkGroupSize = *Attr;
-      else
-        FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize(*F);
-      TakeRange(InfoCache.getEffectiveWavesPerEU(*F, MaxWavesPerEURange,
-                                                 FlatWorkGroupSize));
-    }
+    if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
+      indicatePessimisticFixpoint();
   }
 
   ChangeStatus updateImpl(Attributor &A) override {
-    auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
     ChangeStatus Change = ChangeStatus::UNCHANGED;
 
     auto CheckCallSite = [&](AbstractCallSite CS) {
       Function *Caller = CS.getInstruction()->getFunction();
-      Function *Func = getAssociatedFunction();
-      LLVM_DEBUG(dbgs() << '[' << getName() << "] Call " << Caller->getName()
-                        << "->" << Func->getName() << '\n');
-
       const auto *CallerInfo = A.getAAFor<AAAMDWavesPerEU>(
           *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
-      const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
-          *this, IRPosition::function(*Func), DepClassTy::REQUIRED);
-      if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState() ||
-          !AssumedGroupSize->isValidState())
+      if (!CallerInfo || !CallerInfo->isValidState())
         return false;
-
-      unsigned Min, Max;
-      std::tie(Min, Max) = InfoCache.getEffectiveWavesPerEU(
-          *Caller,
-          {CallerInfo->getAssumed().getLower().getZExtValue(),
-           CallerInfo->getAssumed().getUpper().getZExtValue() - 1},
-          {AssumedGroupSize->getAssumed().getLower().getZExtValue(),
-           AssumedGroupSize->getAssumed().getUpper().getZExtValue() - 1});
-      ConstantRange CallerRange(APInt(32, Min), APInt(32, Max + 1));
+      unsigned Min = CallerInfo->getAssumed().getLower().getZExtValue();
+      unsigned Max = CallerInfo->getAssumed().getUpper().getZExtValue();
+      ConstantRange CallerRange(APInt(32, Min), APInt(32, Max));
       IntegerRangeState CallerRangeState(CallerRange);
       Change |= clampStateAndIndicateChange(this->getState(), CallerRangeState);
----------------
shiltian wrote:

@arsenm did I get this part right?

https://github.com/llvm/llvm-project/pull/123995


More information about the llvm-branch-commits mailing list