[llvm] AMDGPU: Propagate amdgpu-max-num-workgroups attribute (PR #113018)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 21 16:01:56 PDT 2024
================
@@ -821,6 +826,150 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
"AAAMDFlatWorkGroupSize is only valid for function position");
}
+ /// Abstract state made of three independent DecIntegerState<uint32_t>
+ /// components, one per X / Y / Z dimension. The tuple is only valid, and only
+ /// at a fixpoint, when all three components are.
+ struct TupleDecIntegerRangeState : public AbstractState {
+   DecIntegerState<uint32_t> X, Y, Z;
+
+   bool isValidState() const override {
+     return X.isValidState() && Y.isValidState() && Z.isValidState();
+   }
+
+   bool isAtFixpoint() const override {
+     return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint();
+   }
+
+   /// Fix every component optimistically. All three calls are always made; the
+   /// per-component results are merged with ChangeStatus' operator|.
+   ChangeStatus indicateOptimisticFixpoint() override {
+     return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() |
+            Z.indicateOptimisticFixpoint();
+   }
+
+   /// Fix every component pessimistically. All three calls are always made;
+   /// the per-component results are merged with ChangeStatus' operator|.
+   ChangeStatus indicatePessimisticFixpoint() override {
+     return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() |
+            Z.indicatePessimisticFixpoint();
+   }
+
+   /// Combine \p Other into this state component-wise, delegating to each
+   /// DecIntegerState's own operator^= (presumably keeping the tighter assumed
+   /// bound -- see DecIntegerState). Returns a reference, as compound
+   /// assignment should, rather than copying the whole tuple on every call.
+   TupleDecIntegerRangeState &
+   operator^=(const TupleDecIntegerRangeState &Other) {
+     X ^= Other.X;
+     Y ^= Other.Y;
+     Z ^= Other.Z;
+     return *this;
+   }
+
+   /// Two tuples are equal when all three components compare equal.
+   bool operator==(const TupleDecIntegerRangeState &Other) const {
+     return X == Other.X && Y == Other.Y && Z == Other.Z;
+   }
+
+   /// The tuple acts as its own "assumed" value. This is presumably required
+   /// by the Attributor's generic state helpers, which snapshot getAssumed()
+   /// and compare it with == to detect changes.
+   TupleDecIntegerRangeState &getAssumed() { return *this; }
+   const TupleDecIntegerRangeState &getAssumed() const { return *this; }
+ };
+
+ /// StateWrapper pairing TupleDecIntegerRangeState with AbstractAttribute.
+ /// No trailing constructor-argument types are passed: TupleDecIntegerRangeState
+ /// declares no constructor taking a uint32_t, so the previous extra `uint32_t`
+ /// argument made this alias unusable. This form matches the base spelled out
+ /// by AAAMDMaxNumWorkgroups.
+ /// NOTE(review): unused in the visible excerpt -- either use it as that
+ /// struct's Base or drop it.
+ using AAAMDMaxNumWorkgroupsState =
+     StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
+
+/// Propagate amdgpu-max-num-workgroups attribute.
+struct AAAMDMaxNumWorkgroups
+ : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
+ using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
+
+ AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+ void initialize(Attributor &A) override {
+ Function *F = getAssociatedFunction();
+ auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
+
+ SmallVector<unsigned> MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups(*F);
+
+ // FIXME: What is the interpretation of 0?
+ for (unsigned &Entry : MaxNumWorkgroups) {
+ if (Entry == 0)
+ Entry = std::numeric_limits<uint32_t>::max();
+ }
+
+ X.takeKnownMinimum(MaxNumWorkgroups[0]);
+ Y.takeKnownMinimum(MaxNumWorkgroups[1]);
+ Z.takeKnownMinimum(MaxNumWorkgroups[2]);
+
+ if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
+ indicatePessimisticFixpoint();
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+ ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+ auto CheckCallSite = [&](AbstractCallSite CS) {
+ Function *Caller = CS.getInstruction()->getFunction();
+ LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call " << Caller->getName()
+ << "->" << getAssociatedFunction()->getName() << '\n');
+
+ const auto *CallerInfo = A.getAAFor<AAAMDMaxNumWorkgroups>(
+ *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
+ if (!CallerInfo)
----------------
arsenm wrote:
What do you mean? I copied this from the uniform-work-group-size case
https://github.com/llvm/llvm-project/pull/113018
More information about the llvm-commits
mailing list