[llvm] AMDGPU: Propagate amdgpu-max-num-workgroups attribute (PR #113018)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 8 09:19:08 PST 2024
================
@@ -821,6 +826,146 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
"AAAMDFlatWorkGroupSize is only valid for function position");
}
+struct TupleDecIntegerRangeState : public AbstractState {
+ DecIntegerState<uint32_t> X, Y, Z;
+
+ bool isValidState() const override {
+ return X.isValidState() && Y.isValidState() && Z.isValidState();
+ }
+
+ bool isAtFixpoint() const override {
+ return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint();
+ }
+
+ ChangeStatus indicateOptimisticFixpoint() override {
+ return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() |
+ Z.indicateOptimisticFixpoint();
+ }
+
+ ChangeStatus indicatePessimisticFixpoint() override {
+ return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() |
+ Z.indicatePessimisticFixpoint();
+ }
+
+ TupleDecIntegerRangeState operator^=(const TupleDecIntegerRangeState &Other) {
+ X ^= Other.X;
+ Y ^= Other.Y;
+ Z ^= Other.Z;
+ return *this;
+ }
+
+ bool operator==(const TupleDecIntegerRangeState &Other) const {
+ return X == Other.X && Y == Other.Y && Z == Other.Z;
+ }
+
+ TupleDecIntegerRangeState &getAssumed() { return *this; }
+ const TupleDecIntegerRangeState &getAssumed() const { return *this; }
+};
+
+using AAAMDMaxNumWorkgroupsState =
+ StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t>;
+
+/// Propagate amdgpu-max-num-workgroups attribute.
+struct AAAMDMaxNumWorkgroups
+ : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
+ using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
+
+ AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+ void initialize(Attributor &A) override {
+ Function *F = getAssociatedFunction();
+ auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
+
+ SmallVector<unsigned> MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups(*F);
+
+ X.takeKnownMinimum(MaxNumWorkgroups[0]);
+ Y.takeKnownMinimum(MaxNumWorkgroups[1]);
+ Z.takeKnownMinimum(MaxNumWorkgroups[2]);
+
+ if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
+ indicatePessimisticFixpoint();
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+ ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+ auto CheckCallSite = [&](AbstractCallSite CS) {
+ Function *Caller = CS.getInstruction()->getFunction();
+ LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call " << Caller->getName()
+ << "->" << getAssociatedFunction()->getName() << '\n');
+
+ const auto *CallerInfo = A.getAAFor<AAAMDMaxNumWorkgroups>(
+ *this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
+ if (!CallerInfo || !CallerInfo->isValidState())
+ return false;
+
+ Change |=
+ clampStateAndIndicateChange(this->getState(), CallerInfo->getState());
+ return true;
+ };
+
+ bool AllCallSitesKnown = true;
+ if (!A.checkForAllCallSites(CheckCallSite, *this,
+ /*RequireAllCallSites=*/true,
+ AllCallSitesKnown))
+ return indicatePessimisticFixpoint();
+
+ return Change;
+ }
+
+ /// Create an abstract attribute view for the position \p IRP.
+ static AAAMDMaxNumWorkgroups &createForPosition(const IRPosition &IRP,
+ Attributor &A);
+
+ ChangeStatus manifest(Attributor &A) override {
+ Function *F = getAssociatedFunction();
+ // TODO: Skip adding if worst case?
+ LLVMContext &Ctx = F->getContext();
----------------
arsenm wrote:
I would expect manifest not to be called in the first place if the worst-case state is considered invalid
https://github.com/llvm/llvm-project/pull/113018
More information about the llvm-commits
mailing list