[llvm] AMDGPU: Propagate amdgpu-max-num-workgroups attribute (PR #113018)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 10:47:44 PST 2024


================
@@ -821,6 +826,152 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
       "AAAMDFlatWorkGroupSize is only valid for function position");
 }
 
+/// Abstract state made of three DecIntegerState<uint32_t> components, X, Y
+/// and Z. Every query and state transition is forwarded to all three.
+struct TupleDecIntegerRangeState : public AbstractState {
+  DecIntegerState<uint32_t> X, Y, Z;
+
+  /// The tuple is valid only if each of X, Y and Z is valid.
+  bool isValidState() const override {
+    return X.isValidState() && Y.isValidState() && Z.isValidState();
+  }
+
+  /// The tuple is at a fixpoint only if each of X, Y and Z is.
+  bool isAtFixpoint() const override {
+    return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint();
+  }
+
+  /// Indicate an optimistic fixpoint on all components and combine the
+  /// resulting change statuses with `|`.
+  ChangeStatus indicateOptimisticFixpoint() override {
+    return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() |
+           Z.indicateOptimisticFixpoint();
+  }
+
+  /// Indicate a pessimistic fixpoint on all components and combine the
+  /// resulting change statuses with `|`.
+  ChangeStatus indicatePessimisticFixpoint() override {
+    return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() |
+           Z.indicatePessimisticFixpoint();
+  }
+
+  /// Apply `^=` component-wise with \p Other.
+  ///
+  /// Returns a reference to this state (not a copy), which is the
+  /// conventional result of a compound assignment and avoids copying the
+  /// whole tuple on every use.
+  TupleDecIntegerRangeState &
+  operator^=(const TupleDecIntegerRangeState &Other) {
+    X ^= Other.X;
+    Y ^= Other.Y;
+    Z ^= Other.Z;
+    return *this;
+  }
+
+  /// Two tuples compare equal when all three components compare equal.
+  bool operator==(const TupleDecIntegerRangeState &Other) const {
+    return X == Other.X && Y == Other.Y && Z == Other.Z;
+  }
+
+  /// The tuple serves as its own "assumed" view.
+  TupleDecIntegerRangeState &getAssumed() { return *this; }
+  const TupleDecIntegerRangeState &getAssumed() const { return *this; }
+};
+
+// State wrapper type for the max-num-workgroups attribute. No trailing
+// constructor-argument types are listed, because TupleDecIntegerRangeState
+// declares no constructor that takes any.
+using AAAMDMaxNumWorkgroupsState =
+    StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
+
+/// Propagate amdgpu-max-num-workgroups attribute.
+struct AAAMDMaxNumWorkgroups
+    : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
+  using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
+
+  AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Seed the known X/Y/Z values from what the information cache reports for
+  /// the associated function. If the function uses an entry-function calling
+  /// convention, a pessimistic fixpoint is indicated right away.
+  void initialize(Attributor &A) override {
+    auto &Cache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
+    Function *Fn = getAssociatedFunction();
+
+    SmallVector<unsigned> Limits = Cache.getMaxNumWorkGroups(*Fn);
+
+    // FIXME: What is the interpretation of 0?
+    // For now every 0 entry is replaced by the largest uint32_t value.
+    constexpr uint32_t LargestValue = std::numeric_limits<uint32_t>::max();
+    for (unsigned Idx = 0, End = Limits.size(); Idx != End; ++Idx) {
+      if (Limits[Idx] == 0)
+        Limits[Idx] = LargestValue;
+    }
+
+    X.takeKnownMinimum(Limits[0]);
+    Y.takeKnownMinimum(Limits[1]);
+    Z.takeKnownMinimum(Limits[2]);
+
+    if (!AMDGPU::isEntryFunctionCC(Fn->getCallingConv()))
+      return;
+    indicatePessimisticFixpoint();
+  }
+
+  /// For each call site handed to the callback by checkForAllCallSites, look
+  /// up this attribute on the calling function and clamp our state with the
+  /// caller's state. If the call-site walk fails, indicate a pessimistic
+  /// fixpoint instead.
+  ChangeStatus updateImpl(Attributor &A) override {
+    ChangeStatus Result = ChangeStatus::UNCHANGED;
+
+    auto VisitCallSite = [&](AbstractCallSite ACS) {
+      Function *CallerFn = ACS.getInstruction()->getFunction();
+      LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call "
+                        << CallerFn->getName() << "->"
+                        << getAssociatedFunction()->getName() << '\n');
+
+      const auto *CallerAA = A.getAAFor<AAAMDMaxNumWorkgroups>(
+          *this, IRPosition::function(*CallerFn), DepClassTy::REQUIRED);
+      if (CallerAA && CallerAA->isValidState()) {
+        Result |= clampStateAndIndicateChange(this->getState(),
+                                              CallerAA->getState());
+        return true;
+      }
+
+      // No usable caller information: abort the walk.
+      return false;
+    };
+
+    bool AllCallSitesKnown = true;
+    const bool WalkSucceeded = A.checkForAllCallSites(
+        VisitCallSite, *this, /*RequireAllCallSites=*/true, AllCallSitesKnown);
+    if (WalkSucceeded)
+      return Result;
+
+    return indicatePessimisticFixpoint();
+  }
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAAMDMaxNumWorkgroups &createForPosition(const IRPosition &IRP,
+                                                  Attributor &A);
+
+  ChangeStatus manifest(Attributor &A) override {
+    Function *F = getAssociatedFunction();
+    // TODO: Skip adding if worst case?
----------------
shiltian wrote:

Pessimistic is only invalid if known is the worst state; because you update known, it is apparently no longer the worst state, and thus no longer invalid.

With your current approach, everything is valid and the manifest is bound to happen.

https://github.com/llvm/llvm-project/pull/113018


More information about the llvm-commits mailing list