[llvm] AMDGPU: Propagate amdgpu-max-num-workgroups attribute (PR #113018)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 21 16:52:29 PDT 2024


================
@@ -821,6 +826,150 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
       "AAAMDFlatWorkGroupSize is only valid for function position");
 }
 
+/// Abstract state made of one DecIntegerState<uint32_t> per dimension
+/// (X, Y and Z). Every AbstractState query and fixpoint operation is
+/// applied to all three components.
+struct TupleDecIntegerRangeState : public AbstractState {
+  DecIntegerState<uint32_t> X, Y, Z;
+
+  /// The tuple is valid only if every component is valid.
+  bool isValidState() const override {
+    return X.isValidState() && Y.isValidState() && Z.isValidState();
+  }
+
+  /// The tuple is at a fixpoint only once every component is.
+  bool isAtFixpoint() const override {
+    return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint();
+  }
+
+  /// Fix every component optimistically. The per-component ChangeStatus
+  /// results are combined with the non-short-circuiting `|`, so all three
+  /// components are always updated.
+  ChangeStatus indicateOptimisticFixpoint() override {
+    return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() |
+           Z.indicateOptimisticFixpoint();
+  }
+
+  /// Fix every component pessimistically; results are combined with `|`.
+  ChangeStatus indicatePessimisticFixpoint() override {
+    return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() |
+           Z.indicatePessimisticFixpoint();
+  }
+
+  /// Apply `^=` component-wise with \p Other. Returns a reference to this
+  /// state (canonical compound-assignment form) rather than a copy.
+  TupleDecIntegerRangeState &
+  operator^=(const TupleDecIntegerRangeState &Other) {
+    X ^= Other.X;
+    Y ^= Other.Y;
+    Z ^= Other.Z;
+    return *this;
+  }
+
+  /// Component-wise equality.
+  bool operator==(const TupleDecIntegerRangeState &Other) const {
+    return X == Other.X && Y == Other.Y && Z == Other.Z;
+  }
+
+  /// The whole tuple serves as its own assumed value (presumably required by
+  /// the Attributor's generic state helpers -- TODO confirm).
+  TupleDecIntegerRangeState &getAssumed() { return *this; }
+  const TupleDecIntegerRangeState &getAssumed() const { return *this; }
+};
+
+/// State/base type for the AAAMDMaxNumWorkgroups abstract attribute.
+/// StateWrapper forwards any trailing template-argument types as constructor
+/// arguments to the wrapped state. TupleDecIntegerRangeState only has an
+/// implicit default constructor, so no extra argument type (e.g. uint32_t) is
+/// passed here. This matches the base spelled by AAAMDMaxNumWorkgroups.
+using AAAMDMaxNumWorkgroupsState =
+    StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
+
+/// Propagate amdgpu-max-num-workgroups attribute.
+struct AAAMDMaxNumWorkgroups
+    : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
+  using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
+
+  AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  void initialize(Attributor &A) override {
+    Function *F = getAssociatedFunction();
+    auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
+
+    SmallVector<unsigned> MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups(*F);
+
+    // FIXME: What is the interpretation of 0?
+    for (unsigned &Entry : MaxNumWorkgroups) {
+      if (Entry == 0)
+        Entry = std::numeric_limits<uint32_t>::max();
+    }
+
+    X.takeKnownMinimum(MaxNumWorkgroups[0]);
+    Y.takeKnownMinimum(MaxNumWorkgroups[1]);
+    Z.takeKnownMinimum(MaxNumWorkgroups[2]);
+
+    if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
+      indicatePessimisticFixpoint();
+  }
+
+  ChangeStatus updateImpl(Attributor &A) override {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+    auto CheckCallSite = [&](AbstractCallSite CS) {
+      Function *Caller = CS.getInstruction()->getFunction();
+      LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call " << Caller->getName()
+                        << "->" << getAssociatedFunction()->getName() << '\n');
+
+      const auto *CallerInfo = A.getAAFor<AAAMDMaxNumWorkgroups>(
+          *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
+      if (!CallerInfo)
----------------
shiltian wrote:

`if (!CallerInfo || !CallerInfo->isValidState())`

https://github.com/llvm/llvm-project/pull/113018


More information about the llvm-commits mailing list