[llvm-branch-commits] [llvm] [AMDGPU][Attributor] Skip update if an AA is at its initial state (PR #114726)

Matt Arsenault via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 11 03:05:21 PST 2024


================
@@ -1145,31 +1169,71 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
     auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
     ChangeStatus Change = ChangeStatus::UNCHANGED;
 
+    Function *F = getAssociatedFunction();
+
+    const auto *AAFlatWorkGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
+        *this, IRPosition::function(*F), DepClassTy::REQUIRED);
+    if (!AAFlatWorkGroupSize || !AAFlatWorkGroupSize->isValidState()) {
+      LLVM_DEBUG(
+          dbgs() << '[' << getName()
+                 << "] AAAMDFlatWorkGroupSize is unavailable or invalid.\n");
+      return ChangeStatus::UNCHANGED;
+    }
+
+    if (AAFlatWorkGroupSize->isAtInitialState()) {
+      LLVM_DEBUG(dbgs() << '[' << getName()
+                        << "] AAAMDFlatWorkGroupSize is still at initial "
+                           "state. Skip the update.\n");
+      return ChangeStatus::UNCHANGED;
+    }
+
+    auto CurrentWorkGroupSize = std::make_pair(
+        AAFlatWorkGroupSize->getAssumed().getLower().getZExtValue(),
+        AAFlatWorkGroupSize->getAssumed().getUpper().getZExtValue() - 1);
+
+    auto DoUpdate = [&](std::pair<unsigned, unsigned> WavesPerEU,
+                        std::pair<unsigned, unsigned> FlatWorkGroupSize) {
+      auto [Min, Max] =
+          InfoCache.getEffectiveWavesPerEU(*F, WavesPerEU, FlatWorkGroupSize);
+      ConstantRange CR(APInt(32, Min), APInt(32, Max + 1));
+      IntegerRangeState IRS(CR);
+      Change |= clampStateAndIndicateChange(this->getState(), IRS);
+    };
+
+    // // We need to clamp once if we are not at initial state, because
+    // // AAAMDFlatWorkGroupSize could be updated in last iteration.
+    if (!isAtInitialState()) {
+      auto CurrentWavesPerEU =
+          std::make_pair(getAssumed().getLower().getZExtValue(),
+                         getAssumed().getUpper().getZExtValue() - 1);
+      DoUpdate(CurrentWavesPerEU, CurrentWorkGroupSize);
+    }
+
     auto CheckCallSite = [&](AbstractCallSite CS) {
       Function *Caller = CS.getInstruction()->getFunction();
-      Function *Func = getAssociatedFunction();
+
       LLVM_DEBUG(dbgs() << '[' << getName() << "] Call " << Caller->getName()
-                        << "->" << Func->getName() << '\n');
+                        << "->" << F->getName() << '\n');
 
-      const auto *CallerInfo = A.getAAFor<AAAMDWavesPerEU>(
+      const auto *AAWavesPerEU = A.getAAFor<AAAMDWavesPerEU>(
           *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
-      const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
-          *this, IRPosition::function(*Func), DepClassTy::REQUIRED);
-      if (!CallerInfo || !AssumedGroupSize || !CallerInfo->isValidState() ||
-          !AssumedGroupSize->isValidState())
+      if (!AAWavesPerEU || !AAWavesPerEU->isValidState()) {
+        LLVM_DEBUG(dbgs() << '[' << getName() << "] Caller "
+                          << Caller->getName()
+                          << " is unavailable or invalid.\n");
         return false;
+      }
+      if (AAWavesPerEU->isAtInitialState()) {
----------------
arsenm wrote:

Same as above 

https://github.com/llvm/llvm-project/pull/114726


More information about the llvm-branch-commits mailing list