[llvm] [OpenMPOpt] Allow indirect calls in AAKernelInfoCallSite (PR #65836)
Shilei Tian via llvm-commits
llvm-commits at lists.llvm.org
Sun Sep 10 16:59:31 PDT 2023
================
@@ -4797,204 +4801,241 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// we will handle them explicitly in the switch below. If it is not, we
// will use an AAKernelInfo object on the callee to gather information and
// merge that into the current state. The latter happens in the updateImpl.
- Function *Callee = getAssociatedFunction();
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
- const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
- if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
- // Unknown caller or declarations are not analyzable, we give up.
- if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
-
- // Unknown callees might contain parallel regions, except if they have
- // an appropriate assumption attached.
- if (!AssumptionAA ||
- !(AssumptionAA->hasAssumption("omp_no_openmp") ||
- AssumptionAA->hasAssumption("omp_no_parallelism")))
- ReachedUnknownParallelRegions.insert(&CB);
-
- // If SPMDCompatibilityTracker is not fixed, we need to give up on the
- // idea we can run something unknown in SPMD-mode.
- if (!SPMDCompatibilityTracker.isAtFixpoint()) {
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- }
+ auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
+ if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+ // Unknown caller or declarations are not analyzable, we give up.
+ if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
+
+ // Unknown callees might contain parallel regions, except if they have
+ // an appropriate assumption attached.
+ if (!AssumptionAA ||
+ !(AssumptionAA->hasAssumption("omp_no_openmp") ||
+ AssumptionAA->hasAssumption("omp_no_parallelism")))
+ ReachedUnknownParallelRegions.insert(&CB);
+
+ // If SPMDCompatibilityTracker is not fixed, we need to give up on the
+ // idea we can run something unknown in SPMD-mode.
+ if (!SPMDCompatibilityTracker.isAtFixpoint()) {
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ }
- // We have updated the state for this unknown call properly, there won't
- // be any change so we indicate a fixpoint.
- indicateOptimisticFixpoint();
+ // We have updated the state for this unknown call properly, there
+ // won't be any change so we indicate a fixpoint.
+ indicateOptimisticFixpoint();
+ }
+ // If the callee is known and can be used in IPO, we will update the
+ // state based on the callee state in updateImpl.
+ return;
+ }
+ if (NumCallees > 1) {
+ indicatePessimisticFixpoint();
+ return;
}
- // If the callee is known and can be used in IPO, we will update the state
- // based on the callee state in updateImpl.
- return;
- }
- RuntimeFunction RF = It->getSecond();
- switch (RF) {
- // All the functions we know are compatible with SPMD mode.
- case OMPRTL___kmpc_is_spmd_exec_mode:
- case OMPRTL___kmpc_distribute_static_fini:
- case OMPRTL___kmpc_for_static_fini:
- case OMPRTL___kmpc_global_thread_num:
- case OMPRTL___kmpc_get_hardware_num_threads_in_block:
- case OMPRTL___kmpc_get_hardware_num_blocks:
- case OMPRTL___kmpc_single:
- case OMPRTL___kmpc_end_single:
- case OMPRTL___kmpc_master:
- case OMPRTL___kmpc_end_master:
- case OMPRTL___kmpc_barrier:
- case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
- case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
- case OMPRTL___kmpc_nvptx_end_reduce_nowait:
- case OMPRTL___kmpc_error:
- case OMPRTL___kmpc_flush:
- case OMPRTL___kmpc_get_hardware_thread_id_in_block:
- case OMPRTL___kmpc_get_warp_size:
- case OMPRTL_omp_get_thread_num:
- case OMPRTL_omp_get_num_threads:
- case OMPRTL_omp_get_max_threads:
- case OMPRTL_omp_in_parallel:
- case OMPRTL_omp_get_dynamic:
- case OMPRTL_omp_get_cancellation:
- case OMPRTL_omp_get_nested:
- case OMPRTL_omp_get_schedule:
- case OMPRTL_omp_get_thread_limit:
- case OMPRTL_omp_get_supported_active_levels:
- case OMPRTL_omp_get_max_active_levels:
- case OMPRTL_omp_get_level:
- case OMPRTL_omp_get_ancestor_thread_num:
- case OMPRTL_omp_get_team_size:
- case OMPRTL_omp_get_active_level:
- case OMPRTL_omp_in_final:
- case OMPRTL_omp_get_proc_bind:
- case OMPRTL_omp_get_num_places:
- case OMPRTL_omp_get_num_procs:
- case OMPRTL_omp_get_place_proc_ids:
- case OMPRTL_omp_get_place_num:
- case OMPRTL_omp_get_partition_num_places:
- case OMPRTL_omp_get_partition_place_nums:
- case OMPRTL_omp_get_wtime:
- break;
- case OMPRTL___kmpc_distribute_static_init_4:
- case OMPRTL___kmpc_distribute_static_init_4u:
- case OMPRTL___kmpc_distribute_static_init_8:
- case OMPRTL___kmpc_distribute_static_init_8u:
- case OMPRTL___kmpc_for_static_init_4:
- case OMPRTL___kmpc_for_static_init_4u:
- case OMPRTL___kmpc_for_static_init_8:
- case OMPRTL___kmpc_for_static_init_8u: {
- // Check the schedule and allow static schedule in SPMD mode.
- unsigned ScheduleArgOpNo = 2;
- auto *ScheduleTypeCI =
- dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
- unsigned ScheduleTypeVal =
- ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
- switch (OMPScheduleType(ScheduleTypeVal)) {
- case OMPScheduleType::UnorderedStatic:
- case OMPScheduleType::UnorderedStaticChunked:
- case OMPScheduleType::OrderedDistribute:
- case OMPScheduleType::OrderedDistributeChunked:
+ RuntimeFunction RF = It->getSecond();
+ switch (RF) {
+ // All the functions we know are compatible with SPMD mode.
+ case OMPRTL___kmpc_is_spmd_exec_mode:
+ case OMPRTL___kmpc_distribute_static_fini:
+ case OMPRTL___kmpc_for_static_fini:
+ case OMPRTL___kmpc_global_thread_num:
+ case OMPRTL___kmpc_get_hardware_num_threads_in_block:
+ case OMPRTL___kmpc_get_hardware_num_blocks:
+ case OMPRTL___kmpc_single:
+ case OMPRTL___kmpc_end_single:
+ case OMPRTL___kmpc_master:
+ case OMPRTL___kmpc_end_master:
+ case OMPRTL___kmpc_barrier:
+ case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
+ case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
+ case OMPRTL___kmpc_nvptx_end_reduce_nowait:
+ case OMPRTL___kmpc_error:
+ case OMPRTL___kmpc_flush:
+ case OMPRTL___kmpc_get_hardware_thread_id_in_block:
+ case OMPRTL___kmpc_get_warp_size:
+ case OMPRTL_omp_get_thread_num:
+ case OMPRTL_omp_get_num_threads:
+ case OMPRTL_omp_get_max_threads:
+ case OMPRTL_omp_in_parallel:
+ case OMPRTL_omp_get_dynamic:
+ case OMPRTL_omp_get_cancellation:
+ case OMPRTL_omp_get_nested:
+ case OMPRTL_omp_get_schedule:
+ case OMPRTL_omp_get_thread_limit:
+ case OMPRTL_omp_get_supported_active_levels:
+ case OMPRTL_omp_get_max_active_levels:
+ case OMPRTL_omp_get_level:
+ case OMPRTL_omp_get_ancestor_thread_num:
+ case OMPRTL_omp_get_team_size:
+ case OMPRTL_omp_get_active_level:
+ case OMPRTL_omp_in_final:
+ case OMPRTL_omp_get_proc_bind:
+ case OMPRTL_omp_get_num_places:
+ case OMPRTL_omp_get_num_procs:
+ case OMPRTL_omp_get_place_proc_ids:
+ case OMPRTL_omp_get_place_num:
+ case OMPRTL_omp_get_partition_num_places:
+ case OMPRTL_omp_get_partition_place_nums:
+ case OMPRTL_omp_get_wtime:
break;
+ case OMPRTL___kmpc_distribute_static_init_4:
+ case OMPRTL___kmpc_distribute_static_init_4u:
+ case OMPRTL___kmpc_distribute_static_init_8:
+ case OMPRTL___kmpc_distribute_static_init_8u:
+ case OMPRTL___kmpc_for_static_init_4:
+ case OMPRTL___kmpc_for_static_init_4u:
+ case OMPRTL___kmpc_for_static_init_8:
+ case OMPRTL___kmpc_for_static_init_8u: {
+ // Check the schedule and allow static schedule in SPMD mode.
+ unsigned ScheduleArgOpNo = 2;
+ auto *ScheduleTypeCI =
+ dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
+ unsigned ScheduleTypeVal =
+ ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
+ switch (OMPScheduleType(ScheduleTypeVal)) {
+ case OMPScheduleType::UnorderedStatic:
+ case OMPScheduleType::UnorderedStaticChunked:
+ case OMPScheduleType::OrderedDistribute:
+ case OMPScheduleType::OrderedDistributeChunked:
+ break;
+ default:
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ };
+ } break;
+ case OMPRTL___kmpc_target_init:
+ KernelInitCB = &CB;
+ break;
+ case OMPRTL___kmpc_target_deinit:
+ KernelDeinitCB = &CB;
+ break;
+ case OMPRTL___kmpc_parallel_51:
+ if (!handleParallel51(A, CB))
+ indicatePessimisticFixpoint();
+ return;
+ case OMPRTL___kmpc_omp_task:
+ // We do not look into tasks right now, just give up.
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ ReachedUnknownParallelRegions.insert(&CB);
+ break;
+ case OMPRTL___kmpc_alloc_shared:
+ case OMPRTL___kmpc_free_shared:
+ // Return without setting a fixpoint, to be resolved in updateImpl.
+ return;
default:
+ // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
+ // generally. However, they do not hide parallel regions.
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
SPMDCompatibilityTracker.insert(&CB);
break;
- };
- } break;
- case OMPRTL___kmpc_target_init:
- KernelInitCB = &CB;
- break;
- case OMPRTL___kmpc_target_deinit:
- KernelDeinitCB = &CB;
- break;
- case OMPRTL___kmpc_parallel_51:
- if (!handleParallel51(A, CB))
- indicatePessimisticFixpoint();
- return;
- case OMPRTL___kmpc_omp_task:
- // We do not look into tasks right now, just give up.
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- ReachedUnknownParallelRegions.insert(&CB);
- break;
- case OMPRTL___kmpc_alloc_shared:
- case OMPRTL___kmpc_free_shared:
- // Return without setting a fixpoint, to be resolved in updateImpl.
+ }
+ // All other OpenMP runtime calls will not reach parallel regions so they
+ // can be safely ignored for now. Since it is a known OpenMP runtime call
+ // we have now modeled all effects and there is no need for any update.
+ indicateOptimisticFixpoint();
+ };
+
+ const auto *AACE =
+ A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+ CheckCallee(getAssociatedFunction(), 1);
return;
- default:
- // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
- // generally. However, they do not hide parallel regions.
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- break;
}
- // All other OpenMP runtime calls will not reach parallel regions so they
- // can be safely ignored for now. Since it is a known OpenMP runtime call we
- // have now modeled all effects and there is no need for any update.
- indicateOptimisticFixpoint();
+ const auto &OptimisticEdges = AACE->getOptimisticEdges();
+ for (auto *Callee : OptimisticEdges) {
+ CheckCallee(Callee, OptimisticEdges.size());
+ if (isAtFixpoint())
+ break;
+ }
}
ChangeStatus updateImpl(Attributor &A) override {
// TODO: Once we have call site specific value information we can provide
// call site specific liveness information and then it makes
// sense to specialize attributes for call sites arguments instead of
// redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
-
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
- const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
-
- // If F is not a runtime function, propagate the AAKernelInfo of the callee.
- if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
- const IRPosition &FnPos = IRPosition::function(*F);
- auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
- if (!FnAA)
- return indicatePessimisticFixpoint();
- if (getState() == FnAA->getState())
- return ChangeStatus::UNCHANGED;
- getState() = FnAA->getState();
- return ChangeStatus::CHANGED;
- }
-
KernelInfoState StateBefore = getState();
- CallBase &CB = cast<CallBase>(getAssociatedValue());
- if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
- if (!handleParallel51(A, CB))
- return indicatePessimisticFixpoint();
- return StateBefore == getState() ? ChangeStatus::UNCHANGED
- : ChangeStatus::CHANGED;
- }
-
- // F is a runtime function that allocates or frees memory, check
- // AAHeapToStack and AAHeapToShared.
- assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
- It->getSecond() == OMPRTL___kmpc_free_shared) &&
- "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
+ auto CheckCallee = [&](Function *F, int NumCallees) {
+ const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
+
+ // If F is not a runtime function, propagate the AAKernelInfo of the
+ // callee.
+ if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto *FnAA =
+ A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+ if (!FnAA)
+ return indicatePessimisticFixpoint();
+ if (getState() == FnAA->getState())
+ return ChangeStatus::UNCHANGED;
+ getState() = FnAA->getState();
+ return ChangeStatus::CHANGED;
+ }
+ if (NumCallees > 1)
+ return indicatePessimisticFixpoint();
- auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
- *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
- auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
- *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
-
- RuntimeFunction RF = It->getSecond();
+ CallBase &CB = cast<CallBase>(getAssociatedValue());
+ if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
+ if (!handleParallel51(A, CB))
+ return indicatePessimisticFixpoint();
+ return StateBefore == getState() ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
- switch (RF) {
- // If neither HeapToStack nor HeapToShared assume the call is removed,
- // assume SPMD incompatibility.
- case OMPRTL___kmpc_alloc_shared:
- if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
- (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
- SPMDCompatibilityTracker.insert(&CB);
- break;
- case OMPRTL___kmpc_free_shared:
- if ((!HeapToStackAA ||
- !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
- (!HeapToSharedAA ||
- !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+ // F is a runtime function that allocates or frees memory, check
+ // AAHeapToStack and AAHeapToShared.
+ assert(
+ (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
+ It->getSecond() == OMPRTL___kmpc_free_shared) &&
+ "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
+
+ auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
+ *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+ auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
+ *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+
+ RuntimeFunction RF = It->getSecond();
+
+ switch (RF) {
+ // If neither HeapToStack nor HeapToShared assume the call is removed,
+ // assume SPMD incompatibility.
+ case OMPRTL___kmpc_alloc_shared:
+ if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
+ (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ case OMPRTL___kmpc_free_shared:
+ if ((!HeapToStackAA ||
+ !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
+ (!HeapToSharedAA ||
+ !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ default:
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
SPMDCompatibilityTracker.insert(&CB);
- break;
- default:
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
+ }
+ return ChangeStatus::CHANGED;
+ };
+
+ const auto *AACE =
+ A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+ CheckCallee(getAssociatedFunction(), 1);
----------------
shiltian wrote:
```suggestion
CheckCallee(getAssociatedFunction(), /*NumCallees=*/1);
```
https://github.com/llvm/llvm-project/pull/65836
More information about the llvm-commits mailing list