[llvm] [OpenMPOpt] Allow indirect calls in AAKernelInfoCallSite (PR #65836)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 8 22:43:01 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-openmp
Changes:
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 63493eb78c451a6..44aed2697842201 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3532,6 +3532,10 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+ /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
+ /// unknown callees.
+ static bool requiresCalleeForCallBase() { return false; }
+
/// Statistics are tracked as part of manifest for now.
void trackStatistics() const override {}
@@ -4797,139 +4801,157 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// we will handle them explicitly in the switch below. If it is not, we
// will use an AAKernelInfo object on the callee to gather information and
// merge that into the current state. The latter happens in the updateImpl.
- Function *Callee = getAssociatedFunction();
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
- const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
- if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
- // Unknown caller or declarations are not analyzable, we give up.
- if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
-
- // Unknown callees might contain parallel regions, except if they have
- // an appropriate assumption attached.
- if (!AssumptionAA ||
- !(AssumptionAA->hasAssumption("omp_no_openmp") ||
- AssumptionAA->hasAssumption("omp_no_parallelism")))
- ReachedUnknownParallelRegions.insert(&CB);
-
- // If SPMDCompatibilityTracker is not fixed, we need to give up on the
- // idea we can run something unknown in SPMD-mode.
- if (!SPMDCompatibilityTracker.isAtFixpoint()) {
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- }
+ auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
+ if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+ // Unknown caller or declarations are not analyzable, we give up.
+ if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
+
+ // Unknown callees might contain parallel regions, except if they have
+ // an appropriate assumption attached.
+ if (!AssumptionAA ||
+ !(AssumptionAA->hasAssumption("omp_no_openmp") ||
+ AssumptionAA->hasAssumption("omp_no_parallelism")))
+ ReachedUnknownParallelRegions.insert(&CB);
+
+ // If SPMDCompatibilityTracker is not fixed, we need to give up on the
+ // idea we can run something unknown in SPMD-mode.
+ if (!SPMDCompatibilityTracker.isAtFixpoint()) {
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ }
- // We have updated the state for this unknown call properly, there won't
- // be any change so we indicate a fixpoint.
- indicateOptimisticFixpoint();
+ // We have updated the state for this unknown call properly, there
+ // won't be any change so we indicate a fixpoint.
+ indicateOptimisticFixpoint();
+ }
+ // If the callee is known and can be used in IPO, we will update the
+ // state based on the callee state in updateImpl.
+ return;
+ }
+ if (NumCallees > 1) {
+ indicatePessimisticFixpoint();
+ return;
}
- // If the callee is known and can be used in IPO, we will update the state
- // based on the callee state in updateImpl.
- return;
- }
- RuntimeFunction RF = It->getSecond();
- switch (RF) {
- // All the functions we know are compatible with SPMD mode.
- case OMPRTL___kmpc_is_spmd_exec_mode:
- case OMPRTL___kmpc_distribute_static_fini:
- case OMPRTL___kmpc_for_static_fini:
- case OMPRTL___kmpc_global_thread_num:
- case OMPRTL___kmpc_get_hardware_num_threads_in_block:
- case OMPRTL___kmpc_get_hardware_num_blocks:
- case OMPRTL___kmpc_single:
- case OMPRTL___kmpc_end_single:
- case OMPRTL___kmpc_master:
- case OMPRTL___kmpc_end_master:
- case OMPRTL___kmpc_barrier:
- case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
- case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
- case OMPRTL___kmpc_nvptx_end_reduce_nowait:
- case OMPRTL___kmpc_error:
- case OMPRTL___kmpc_flush:
- case OMPRTL___kmpc_get_hardware_thread_id_in_block:
- case OMPRTL___kmpc_get_warp_size:
- case OMPRTL_omp_get_thread_num:
- case OMPRTL_omp_get_num_threads:
- case OMPRTL_omp_get_max_threads:
- case OMPRTL_omp_in_parallel:
- case OMPRTL_omp_get_dynamic:
- case OMPRTL_omp_get_cancellation:
- case OMPRTL_omp_get_nested:
- case OMPRTL_omp_get_schedule:
- case OMPRTL_omp_get_thread_limit:
- case OMPRTL_omp_get_supported_active_levels:
- case OMPRTL_omp_get_max_active_levels:
- case OMPRTL_omp_get_level:
- case OMPRTL_omp_get_ancestor_thread_num:
- case OMPRTL_omp_get_team_size:
- case OMPRTL_omp_get_active_level:
- case OMPRTL_omp_in_final:
- case OMPRTL_omp_get_proc_bind:
- case OMPRTL_omp_get_num_places:
- case OMPRTL_omp_get_num_procs:
- case OMPRTL_omp_get_place_proc_ids:
- case OMPRTL_omp_get_place_num:
- case OMPRTL_omp_get_partition_num_places:
- case OMPRTL_omp_get_partition_place_nums:
- case OMPRTL_omp_get_wtime:
- break;
- case OMPRTL___kmpc_distribute_static_init_4:
- case OMPRTL___kmpc_distribute_static_init_4u:
- case OMPRTL___kmpc_distribute_static_init_8:
- case OMPRTL___kmpc_distribute_static_init_8u:
- case OMPRTL___kmpc_for_static_init_4:
- case OMPRTL___kmpc_for_static_init_4u:
- case OMPRTL___kmpc_for_static_init_8:
- case OMPRTL___kmpc_for_static_init_8u: {
- // Check the schedule and allow static schedule in SPMD mode.
- unsigned ScheduleArgOpNo = 2;
- auto *ScheduleTypeCI =
- dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
- unsigned ScheduleTypeVal =
- ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
- switch (OMPScheduleType(ScheduleTypeVal)) {
- case OMPScheduleType::UnorderedStatic:
- case OMPScheduleType::UnorderedStaticChunked:
- case OMPScheduleType::OrderedDistribute:
- case OMPScheduleType::OrderedDistributeChunked:
+ RuntimeFunction RF = It->getSecond();
+ switch (RF) {
+ // All the functions we know are compatible with SPMD mode.
+ case OMPRTL___kmpc_is_spmd_exec_mode:
+ case OMPRTL___kmpc_distribute_static_fini:
+ case OMPRTL___kmpc_for_static_fini:
+ case OMPRTL___kmpc_global_thread_num:
+ case OMPRTL___kmpc_get_hardware_num_threads_in_block:
+ case OMPRTL___kmpc_get_hardware_num_blocks:
+ case OMPRTL___kmpc_single:
+ case OMPRTL___kmpc_end_single:
+ case OMPRTL___kmpc_master:
+ case OMPRTL___kmpc_end_master:
+ case OMPRTL___kmpc_barrier:
+ case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
+ case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
+ case OMPRTL___kmpc_nvptx_end_reduce_nowait:
+ case OMPRTL___kmpc_error:
+ case OMPRTL___kmpc_flush:
+ case OMPRTL___kmpc_get_hardware_thread_id_in_block:
+ case OMPRTL___kmpc_get_warp_size:
+ case OMPRTL_omp_get_thread_num:
+ case OMPRTL_omp_get_num_threads:
+ case OMPRTL_omp_get_max_threads:
+ case OMPRTL_omp_in_parallel:
+ case OMPRTL_omp_get_dynamic:
+ case OMPRTL_omp_get_cancellation:
+ case OMPRTL_omp_get_nested:
+ case OMPRTL_omp_get_schedule:
+ case OMPRTL_omp_get_thread_limit:
+ case OMPRTL_omp_get_supported_active_levels:
+ case OMPRTL_omp_get_max_active_levels:
+ case OMPRTL_omp_get_level:
+ case OMPRTL_omp_get_ancestor_thread_num:
+ case OMPRTL_omp_get_team_size:
+ case OMPRTL_omp_get_active_level:
+ case OMPRTL_omp_in_final:
+ case OMPRTL_omp_get_proc_bind:
+ case OMPRTL_omp_get_num_places:
+ case OMPRTL_omp_get_num_procs:
+ case OMPRTL_omp_get_place_proc_ids:
+ case OMPRTL_omp_get_place_num:
+ case OMPRTL_omp_get_partition_num_places:
+ case OMPRTL_omp_get_partition_place_nums:
+ case OMPRTL_omp_get_wtime:
break;
+ case OMPRTL___kmpc_distribute_static_init_4:
+ case OMPRTL___kmpc_distribute_static_init_4u:
+ case OMPRTL___kmpc_distribute_static_init_8:
+ case OMPRTL___kmpc_distribute_static_init_8u:
+ case OMPRTL___kmpc_for_static_init_4:
+ case OMPRTL___kmpc_for_static_init_4u:
+ case OMPRTL___kmpc_for_static_init_8:
+ case OMPRTL___kmpc_for_static_init_8u: {
+ // Check the schedule and allow static schedule in SPMD mode.
+ unsigned ScheduleArgOpNo = 2;
+ auto *ScheduleTypeCI =
+ dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
+ unsigned ScheduleTypeVal =
+ ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
+ switch (OMPScheduleType(ScheduleTypeVal)) {
+ case OMPScheduleType::UnorderedStatic:
+ case OMPScheduleType::UnorderedStaticChunked:
+ case OMPScheduleType::OrderedDistribute:
+ case OMPScheduleType::OrderedDistributeChunked:
+ break;
+ default:
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ };
+ } break;
+ case OMPRTL___kmpc_target_init:
+ KernelInitCB = &CB;
+ break;
+ case OMPRTL___kmpc_target_deinit:
+ KernelDeinitCB = &CB;
+ break;
+ case OMPRTL___kmpc_parallel_51:
+ if (!handleParallel51(A, CB))
+ indicatePessimisticFixpoint();
+ return;
+ case OMPRTL___kmpc_omp_task:
+ // We do not look into tasks right now, just give up.
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.insert(&CB);
+ ReachedUnknownParallelRegions.insert(&CB);
+ break;
+ case OMPRTL___kmpc_alloc_shared:
+ case OMPRTL___kmpc_free_shared:
+ // Return without setting a fixpoint, to be resolved in updateImpl.
+ return;
default:
+ // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
+ // generally. However, they do not hide parallel regions.
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
SPMDCompatibilityTracker.insert(&CB);
break;
- };
- } break;
- case OMPRTL___kmpc_target_init:
- KernelInitCB = &CB;
- break;
- case OMPRTL___kmpc_target_deinit:
- KernelDeinitCB = &CB;
- break;
- case OMPRTL___kmpc_parallel_51:
- if (!handleParallel51(A, CB))
- indicatePessimisticFixpoint();
- return;
- case OMPRTL___kmpc_omp_task:
- // We do not look into tasks right now, just give up.
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- ReachedUnknownParallelRegions.insert(&CB);
- break;
- case OMPRTL___kmpc_alloc_shared:
- case OMPRTL___kmpc_free_shared:
- // Return without setting a fixpoint, to be resolved in updateImpl.
+ }
+ // All other OpenMP runtime calls will not reach parallel regions so they
+ // can be safely ignored for now. Since it is a known OpenMP runtime call
+ // we have now modeled all effects and there is no need for any update.
+ indicateOptimisticFixpoint();
+ };
+
+ const auto *AACE =
+ A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+ CheckCallee(getAssociatedFunction(), 1);
return;
- default:
- // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
- // generally. However, they do not hide parallel regions.
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
- break;
}
- // All other OpenMP runtime calls will not reach parallel regions so they
- // can be safely ignored for now. Since it is a known OpenMP runtime call we
- // have now modeled all effects and there is no need for any update.
- indicateOptimisticFixpoint();
+ const auto &OptimisticEdges = AACE->getOptimisticEdges();
+ for (auto *Callee : OptimisticEdges) {
+ CheckCallee(Callee, OptimisticEdges.size());
+ if (isAtFixpoint())
+ break;
+ }
}
ChangeStatus updateImpl(Attributor &A) override {
@@ -4937,64 +4959,83 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// call site specific liveness information and then it makes
// sense to specialize attributes for call sites arguments instead of
// redirecting requests to the callee argument.
- Function *F = getAssociatedFunction();
-
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
- const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
-
- // If F is not a runtime function, propagate the AAKernelInfo of the callee.
- if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
- const IRPosition &FnPos = IRPosition::function(*F);
- auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
- if (!FnAA)
- return indicatePessimisticFixpoint();
- if (getState() == FnAA->getState())
- return ChangeStatus::UNCHANGED;
- getState() = FnAA->getState();
- return ChangeStatus::CHANGED;
- }
-
KernelInfoState StateBefore = getState();
- CallBase &CB = cast<CallBase>(getAssociatedValue());
- if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
- if (!handleParallel51(A, CB))
- return indicatePessimisticFixpoint();
- return StateBefore == getState() ? ChangeStatus::UNCHANGED
- : ChangeStatus::CHANGED;
- }
-
- // F is a runtime function that allocates or frees memory, check
- // AAHeapToStack and AAHeapToShared.
- assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
- It->getSecond() == OMPRTL___kmpc_free_shared) &&
- "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
+ auto CheckCallee = [&](Function *F, int NumCallees) {
+ const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
+
+ // If F is not a runtime function, propagate the AAKernelInfo of the
+ // callee.
+ if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+ const IRPosition &FnPos = IRPosition::function(*F);
+ auto *FnAA =
+ A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+ if (!FnAA)
+ return indicatePessimisticFixpoint();
+ if (getState() == FnAA->getState())
+ return ChangeStatus::UNCHANGED;
+ getState() = FnAA->getState();
+ return ChangeStatus::CHANGED;
+ }
+ if (NumCallees > 1)
+ return indicatePessimisticFixpoint();
- auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
- *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
- auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
- *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
-
- RuntimeFunction RF = It->getSecond();
+ CallBase &CB = cast<CallBase>(getAssociatedValue());
+ if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
+ if (!handleParallel51(A, CB))
+ return indicatePessimisticFixpoint();
+ return StateBefore == getState() ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
- switch (RF) {
- // If neither HeapToStack nor HeapToShared assume the call is removed,
- // assume SPMD incompatibility.
- case OMPRTL___kmpc_alloc_shared:
- if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
- (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
- SPMDCompatibilityTracker.insert(&CB);
- break;
- case OMPRTL___kmpc_free_shared:
- if ((!HeapToStackAA ||
- !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
- (!HeapToSharedAA ||
- !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+ // F is a runtime function that allocates or frees memory, check
+ // AAHeapToStack and AAHeapToShared.
+ assert(
+ (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
+ It->getSecond() == OMPRTL___kmpc_free_shared) &&
+ "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
+
+ auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
+ *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+ auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
+ *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+
+ RuntimeFunction RF = It->getSecond();
+
+ switch (RF) {
+ // If neither HeapToStack nor HeapToShared assume the call is removed,
+ // assume SPMD incompatibility.
+ case OMPRTL___kmpc_alloc_shared:
+ if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
+ (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ case OMPRTL___kmpc_free_shared:
+ if ((!HeapToStackAA ||
+ !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
+ (!HeapToSharedAA ||
+ !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+ SPMDCompatibilityTracker.insert(&CB);
+ break;
+ default:
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
SPMDCompatibilityTracker.insert(&CB);
- break;
- default:
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.insert(&CB);
+ }
+ return ChangeStatus::CHANGED;
+ };
+
+ const auto *AACE =
+ A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+ if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+ CheckCallee(getAssociatedFunction(), 1);
+ } else {
+ const auto &OptimisticEdges = AACE->getOptimisticEdges();
+ for (auto *Callee : OptimisticEdges) {
+ CheckCallee(Callee, OptimisticEdges.size());
+ if (isAtFixpoint())
+ break;
+ }
}
return StateBefore == getState() ? ChangeStatus::UNCHANGED
diff --git a/llvm/test/Transforms/OpenMP/spmdization_indirect.ll b/llvm/test/Transforms/OpenMP/spmdization_indirect.ll
index 4ca646470eabe1a..04b0e50d4bce4a1 100644
--- a/llvm/test/Transforms/OpenMP/spmdization_indirect.ll
+++ b/llvm/test/Transforms/OpenMP/spmdization_indirect.ll
@@ -16,15 +16,15 @@
;.
; AMDGPU: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [23 x i8] c"
; AMDGPU: @[[GLOB1:[0-9]+]] = private unnamed_addr constant [[STRUCT_IDENT_T:%.*]] { i32 0, i32 2, i32 0, i32 0, ptr @[[GLOB0]] }, align 8
-; AMDGPU: @[[SPMD_CALLEES_KERNEL_ENVIRONMENT:[a-zA-Z0-9_$"\\.-]+]] = local_unnamed_addr constant [[STRUCT_KERNELENVIRONMENTTY:%.*]] { [[STRUCT_CONFIGURATIONENVIRONMENTTY:%.*]] { i8 0, i8 0, i8 1 }, ptr @[[GLOB1]], ptr null }
-; AMDGPU: @[[SPMD_CALLEES_METADATA_KERNEL_ENVIRONMENT:[a-zA-Z0-9_$"\\.-]+]] = local_unnamed_addr constant [[STRUCT_KERNELENVIRONMENTTY:%.*]] { [[STRUCT_CONFIGURATIONENVIRONMENTTY:%.*]] { i8 0, i8 0, i8 1 }, ptr @[[GLOB1]], ptr null }
+; AMDGPU: @[[SPMD_CALLEES_KERNEL_ENVIRONMENT:[a-zA-Z0-9_$"\
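For readers skimming the diff: the patch moves the existing per-callee logic into a CheckCallee lambda and, instead of assuming a single statically known callee, queries AACallEdges for the set of optimistic call targets, running the check over each and bailing to a pessimistic fixpoint when a runtime function is only one of several possible targets. Below is a minimal standalone sketch of that dispatch pattern. It deliberately uses plain C++ stand-ins (CallSiteState, CallEdges, IsRuntimeFn are all hypothetical names, not the real Attributor API) and collapses the patch's large runtime-function switch into a single optimistic result.

// Standalone sketch of the callee-dispatch pattern introduced by the patch.
// Not LLVM code: all types here are simplified stand-ins.
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

enum class Fixpoint { None, Optimistic, Pessimistic };

struct CallSiteState {
  Fixpoint FP = Fixpoint::None;
  bool isAtFixpoint() const { return FP != Fixpoint::None; }
};

// Stand-in for AACallEdges: std::nullopt models an invalid state or an
// unknown callee, in which case we fall back to the associated function,
// as the pre-patch code effectively did.
using CallEdges = std::optional<std::vector<std::string>>;

void initializeCallSite(
    CallSiteState &S, const CallEdges &Edges, const std::string &AssociatedFn,
    const std::function<bool(const std::string &)> &IsRuntimeFn) {
  // Mirrors the CheckCallee lambda in the patch; the per-runtime-function
  // switch and the non-IPO-amendable handling are elided for brevity.
  auto CheckCallee = [&](const std::string &Callee, size_t NumCallees) {
    if (!IsRuntimeFn(Callee))
      return; // Non-runtime callee: resolved via its own state in updateImpl.
    if (NumCallees > 1) {
      // A runtime call that is only one of several indirect targets cannot
      // be modeled per-callee; give up on this call site.
      S.FP = Fixpoint::Pessimistic;
      return;
    }
    // Exactly one known runtime callee: all effects are modeled here.
    S.FP = Fixpoint::Optimistic;
  };

  if (!Edges) { // No usable edge information: use the direct callee.
    CheckCallee(AssociatedFn, 1);
    return;
  }
  for (const std::string &Callee : *Edges) {
    CheckCallee(Callee, Edges->size());
    if (S.isAtFixpoint())
      break;
  }
}

int main() {
  auto IsRT = [](const std::string &F) { return F.rfind("__kmpc_", 0) == 0; };
  std::vector<std::string> Targets = {"__kmpc_barrier", "user_fn"};
  CallSiteState S;
  initializeCallSite(S, Targets, "unused", IsRT);
  std::cout << (S.FP == Fixpoint::Pessimistic ? "pessimistic\n" : "other\n");
}

Fed a call site whose optimistic edges contain both __kmpc_barrier and a user function, the sketch goes pessimistic, mirroring the NumCallees > 1 bail-out that both CheckCallee lambdas in the patch perform.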
https://github.com/llvm/llvm-project/pull/65836
More information about the llvm-commits mailing list