[llvm] [OpenMPOpt] Allow indirect calls in AAKernelInfoCallSite (PR #65836)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 8 22:43:01 PDT 2023


llvmbot wrote:

@llvm/pr-subscribers-openmp

Changes:
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 63493eb78c451a6..44aed2697842201 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3532,6 +3532,10 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
 
+  /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
+  /// unknown callees.
+  static bool requiresCalleeForCallBase() { return false; }
+
   /// Statistics are tracked as part of manifest for now.
   void trackStatistics() const override {}
 
@@ -4797,139 +4801,157 @@ struct AAKernelInfoCallSite : AAKernelInfo {
     // we will handle them explicitly in the switch below. If it is not, we
     // will use an AAKernelInfo object on the callee to gather information and
     // merge that into the current state. The latter happens in the updateImpl.
-    Function *Callee = getAssociatedFunction();
-    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
-    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
-    if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
-      // Unknown caller or declarations are not analyzable, we give up.
-      if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
-
-        // Unknown callees might contain parallel regions, except if they have
-        // an appropriate assumption attached.
-        if (!AssumptionAA ||
-            !(AssumptionAA->hasAssumption("omp_no_openmp") ||
-              AssumptionAA->hasAssumption("omp_no_parallelism")))
-          ReachedUnknownParallelRegions.insert(&CB);
-
-        // If SPMDCompatibilityTracker is not fixed, we need to give up on the
-        // idea we can run something unknown in SPMD-mode.
-        if (!SPMDCompatibilityTracker.isAtFixpoint()) {
-          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-          SPMDCompatibilityTracker.insert(&CB);
-        }
+    auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
+      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
+      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+        // Unknown caller or declarations are not analyzable, we give up.
+        if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
+
+          // Unknown callees might contain parallel regions, except if they have
+          // an appropriate assumption attached.
+          if (!AssumptionAA ||
+              !(AssumptionAA->hasAssumption("omp_no_openmp") ||
+                AssumptionAA->hasAssumption("omp_no_parallelism")))
+            ReachedUnknownParallelRegions.insert(&CB);
+
+          // If SPMDCompatibilityTracker is not fixed, we need to give up on the
+          // idea we can run something unknown in SPMD-mode.
+          if (!SPMDCompatibilityTracker.isAtFixpoint()) {
+            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+            SPMDCompatibilityTracker.insert(&CB);
+          }
 
-        // We have updated the state for this unknown call properly, there won't
-        // be any change so we indicate a fixpoint.
-        indicateOptimisticFixpoint();
+          // We have updated the state for this unknown call properly, there
+          // won't be any change so we indicate a fixpoint.
+          indicateOptimisticFixpoint();
+        }
+        // If the callee is known and can be used in IPO, we will update the
+        // state based on the callee state in updateImpl.
+        return;
+      }
+      if (NumCallees > 1) {
+        indicatePessimisticFixpoint();
+        return;
       }
-      // If the callee is known and can be used in IPO, we will update the state
-      // based on the callee state in updateImpl.
-      return;
-    }
 
-    RuntimeFunction RF = It->getSecond();
-    switch (RF) {
-    // All the functions we know are compatible with SPMD mode.
-    case OMPRTL___kmpc_is_spmd_exec_mode:
-    case OMPRTL___kmpc_distribute_static_fini:
-    case OMPRTL___kmpc_for_static_fini:
-    case OMPRTL___kmpc_global_thread_num:
-    case OMPRTL___kmpc_get_hardware_num_threads_in_block:
-    case OMPRTL___kmpc_get_hardware_num_blocks:
-    case OMPRTL___kmpc_single:
-    case OMPRTL___kmpc_end_single:
-    case OMPRTL___kmpc_master:
-    case OMPRTL___kmpc_end_master:
-    case OMPRTL___kmpc_barrier:
-    case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
-    case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
-    case OMPRTL___kmpc_nvptx_end_reduce_nowait:
-    case OMPRTL___kmpc_error:
-    case OMPRTL___kmpc_flush:
-    case OMPRTL___kmpc_get_hardware_thread_id_in_block:
-    case OMPRTL___kmpc_get_warp_size:
-    case OMPRTL_omp_get_thread_num:
-    case OMPRTL_omp_get_num_threads:
-    case OMPRTL_omp_get_max_threads:
-    case OMPRTL_omp_in_parallel:
-    case OMPRTL_omp_get_dynamic:
-    case OMPRTL_omp_get_cancellation:
-    case OMPRTL_omp_get_nested:
-    case OMPRTL_omp_get_schedule:
-    case OMPRTL_omp_get_thread_limit:
-    case OMPRTL_omp_get_supported_active_levels:
-    case OMPRTL_omp_get_max_active_levels:
-    case OMPRTL_omp_get_level:
-    case OMPRTL_omp_get_ancestor_thread_num:
-    case OMPRTL_omp_get_team_size:
-    case OMPRTL_omp_get_active_level:
-    case OMPRTL_omp_in_final:
-    case OMPRTL_omp_get_proc_bind:
-    case OMPRTL_omp_get_num_places:
-    case OMPRTL_omp_get_num_procs:
-    case OMPRTL_omp_get_place_proc_ids:
-    case OMPRTL_omp_get_place_num:
-    case OMPRTL_omp_get_partition_num_places:
-    case OMPRTL_omp_get_partition_place_nums:
-    case OMPRTL_omp_get_wtime:
-      break;
-    case OMPRTL___kmpc_distribute_static_init_4:
-    case OMPRTL___kmpc_distribute_static_init_4u:
-    case OMPRTL___kmpc_distribute_static_init_8:
-    case OMPRTL___kmpc_distribute_static_init_8u:
-    case OMPRTL___kmpc_for_static_init_4:
-    case OMPRTL___kmpc_for_static_init_4u:
-    case OMPRTL___kmpc_for_static_init_8:
-    case OMPRTL___kmpc_for_static_init_8u: {
-      // Check the schedule and allow static schedule in SPMD mode.
-      unsigned ScheduleArgOpNo = 2;
-      auto *ScheduleTypeCI =
-          dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
-      unsigned ScheduleTypeVal =
-          ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
-      switch (OMPScheduleType(ScheduleTypeVal)) {
-      case OMPScheduleType::UnorderedStatic:
-      case OMPScheduleType::UnorderedStaticChunked:
-      case OMPScheduleType::OrderedDistribute:
-      case OMPScheduleType::OrderedDistributeChunked:
+      RuntimeFunction RF = It->getSecond();
+      switch (RF) {
+      // All the functions we know are compatible with SPMD mode.
+      case OMPRTL___kmpc_is_spmd_exec_mode:
+      case OMPRTL___kmpc_distribute_static_fini:
+      case OMPRTL___kmpc_for_static_fini:
+      case OMPRTL___kmpc_global_thread_num:
+      case OMPRTL___kmpc_get_hardware_num_threads_in_block:
+      case OMPRTL___kmpc_get_hardware_num_blocks:
+      case OMPRTL___kmpc_single:
+      case OMPRTL___kmpc_end_single:
+      case OMPRTL___kmpc_master:
+      case OMPRTL___kmpc_end_master:
+      case OMPRTL___kmpc_barrier:
+      case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
+      case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
+      case OMPRTL___kmpc_nvptx_end_reduce_nowait:
+      case OMPRTL___kmpc_error:
+      case OMPRTL___kmpc_flush:
+      case OMPRTL___kmpc_get_hardware_thread_id_in_block:
+      case OMPRTL___kmpc_get_warp_size:
+      case OMPRTL_omp_get_thread_num:
+      case OMPRTL_omp_get_num_threads:
+      case OMPRTL_omp_get_max_threads:
+      case OMPRTL_omp_in_parallel:
+      case OMPRTL_omp_get_dynamic:
+      case OMPRTL_omp_get_cancellation:
+      case OMPRTL_omp_get_nested:
+      case OMPRTL_omp_get_schedule:
+      case OMPRTL_omp_get_thread_limit:
+      case OMPRTL_omp_get_supported_active_levels:
+      case OMPRTL_omp_get_max_active_levels:
+      case OMPRTL_omp_get_level:
+      case OMPRTL_omp_get_ancestor_thread_num:
+      case OMPRTL_omp_get_team_size:
+      case OMPRTL_omp_get_active_level:
+      case OMPRTL_omp_in_final:
+      case OMPRTL_omp_get_proc_bind:
+      case OMPRTL_omp_get_num_places:
+      case OMPRTL_omp_get_num_procs:
+      case OMPRTL_omp_get_place_proc_ids:
+      case OMPRTL_omp_get_place_num:
+      case OMPRTL_omp_get_partition_num_places:
+      case OMPRTL_omp_get_partition_place_nums:
+      case OMPRTL_omp_get_wtime:
         break;
+      case OMPRTL___kmpc_distribute_static_init_4:
+      case OMPRTL___kmpc_distribute_static_init_4u:
+      case OMPRTL___kmpc_distribute_static_init_8:
+      case OMPRTL___kmpc_distribute_static_init_8u:
+      case OMPRTL___kmpc_for_static_init_4:
+      case OMPRTL___kmpc_for_static_init_4u:
+      case OMPRTL___kmpc_for_static_init_8:
+      case OMPRTL___kmpc_for_static_init_8u: {
+        // Check the schedule and allow static schedule in SPMD mode.
+        unsigned ScheduleArgOpNo = 2;
+        auto *ScheduleTypeCI =
+            dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
+        unsigned ScheduleTypeVal =
+            ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
+        switch (OMPScheduleType(ScheduleTypeVal)) {
+        case OMPScheduleType::UnorderedStatic:
+        case OMPScheduleType::UnorderedStaticChunked:
+        case OMPScheduleType::OrderedDistribute:
+        case OMPScheduleType::OrderedDistributeChunked:
+          break;
+        default:
+          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+          SPMDCompatibilityTracker.insert(&CB);
+          break;
+        };
+      } break;
+      case OMPRTL___kmpc_target_init:
+        KernelInitCB = &CB;
+        break;
+      case OMPRTL___kmpc_target_deinit:
+        KernelDeinitCB = &CB;
+        break;
+      case OMPRTL___kmpc_parallel_51:
+        if (!handleParallel51(A, CB))
+          indicatePessimisticFixpoint();
+        return;
+      case OMPRTL___kmpc_omp_task:
+        // We do not look into tasks right now, just give up.
+        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+        SPMDCompatibilityTracker.insert(&CB);
+        ReachedUnknownParallelRegions.insert(&CB);
+        break;
+      case OMPRTL___kmpc_alloc_shared:
+      case OMPRTL___kmpc_free_shared:
+        // Return without setting a fixpoint, to be resolved in updateImpl.
+        return;
       default:
+        // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
+        // generally. However, they do not hide parallel regions.
         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
         SPMDCompatibilityTracker.insert(&CB);
         break;
-      };
-    } break;
-    case OMPRTL___kmpc_target_init:
-      KernelInitCB = &CB;
-      break;
-    case OMPRTL___kmpc_target_deinit:
-      KernelDeinitCB = &CB;
-      break;
-    case OMPRTL___kmpc_parallel_51:
-      if (!handleParallel51(A, CB))
-        indicatePessimisticFixpoint();
-      return;
-    case OMPRTL___kmpc_omp_task:
-      // We do not look into tasks right now, just give up.
-      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-      SPMDCompatibilityTracker.insert(&CB);
-      ReachedUnknownParallelRegions.insert(&CB);
-      break;
-    case OMPRTL___kmpc_alloc_shared:
-    case OMPRTL___kmpc_free_shared:
-      // Return without setting a fixpoint, to be resolved in updateImpl.
+      }
+      // All other OpenMP runtime calls will not reach parallel regions so they
+      // can be safely ignored for now. Since it is a known OpenMP runtime call
+      // we have now modeled all effects and there is no need for any update.
+      indicateOptimisticFixpoint();
+    };
+
+    const auto *AACE =
+        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+      CheckCallee(getAssociatedFunction(), 1);
       return;
-    default:
-      // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
-      // generally. However, they do not hide parallel regions.
-      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-      SPMDCompatibilityTracker.insert(&CB);
-      break;
     }
-    // All other OpenMP runtime calls will not reach parallel regions so they
-    // can be safely ignored for now. Since it is a known OpenMP runtime call we
-    // have now modeled all effects and there is no need for any update.
-    indicateOptimisticFixpoint();
+    const auto &OptimisticEdges = AACE->getOptimisticEdges();
+    for (auto *Callee : OptimisticEdges) {
+      CheckCallee(Callee, OptimisticEdges.size());
+      if (isAtFixpoint())
+        break;
+    }
   }
 
   ChangeStatus updateImpl(Attributor &A) override {
@@ -4937,64 +4959,83 @@ struct AAKernelInfoCallSite : AAKernelInfo {
     //       call site specific liveness information and then it makes
     //       sense to specialize attributes for call sites arguments instead of
     //       redirecting requests to the callee argument.
-    Function *F = getAssociatedFunction();
-
     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
-    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
-
-    // If F is not a runtime function, propagate the AAKernelInfo of the callee.
-    if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
-      const IRPosition &FnPos = IRPosition::function(*F);
-      auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
-      if (!FnAA)
-        return indicatePessimisticFixpoint();
-      if (getState() == FnAA->getState())
-        return ChangeStatus::UNCHANGED;
-      getState() = FnAA->getState();
-      return ChangeStatus::CHANGED;
-    }
-
     KernelInfoState StateBefore = getState();
-    CallBase &CB = cast<CallBase>(getAssociatedValue());
-    if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
-      if (!handleParallel51(A, CB))
-        return indicatePessimisticFixpoint();
-      return StateBefore == getState() ? ChangeStatus::UNCHANGED
-                                       : ChangeStatus::CHANGED;
-    }
-
-    // F is a runtime function that allocates or frees memory, check
-    // AAHeapToStack and AAHeapToShared.
-    assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
-            It->getSecond() == OMPRTL___kmpc_free_shared) &&
-           "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
 
+    auto CheckCallee = [&](Function *F, int NumCallees) {
+      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
+
+      // If F is not a runtime function, propagate the AAKernelInfo of the
+      // callee.
+      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+        const IRPosition &FnPos = IRPosition::function(*F);
+        auto *FnAA =
+            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+        if (!FnAA)
+          return indicatePessimisticFixpoint();
+        if (getState() == FnAA->getState())
+          return ChangeStatus::UNCHANGED;
+        getState() = FnAA->getState();
+        return ChangeStatus::CHANGED;
+      }
+      if (NumCallees > 1)
+        return indicatePessimisticFixpoint();
 
-    auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
-        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
-    auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
-        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
-
-    RuntimeFunction RF = It->getSecond();
+      CallBase &CB = cast<CallBase>(getAssociatedValue());
+      if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
+        if (!handleParallel51(A, CB))
+          return indicatePessimisticFixpoint();
+        return StateBefore == getState() ? ChangeStatus::UNCHANGED
+                                         : ChangeStatus::CHANGED;
+      }
 
-    switch (RF) {
-    // If neither HeapToStack nor HeapToShared assume the call is removed,
-    // assume SPMD incompatibility.
-    case OMPRTL___kmpc_alloc_shared:
-      if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
-          (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
-        SPMDCompatibilityTracker.insert(&CB);
-      break;
-    case OMPRTL___kmpc_free_shared:
-      if ((!HeapToStackAA ||
-           !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
-          (!HeapToSharedAA ||
-           !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+      // F is a runtime function that allocates or frees memory, check
+      // AAHeapToStack and AAHeapToShared.
+      assert(
+          (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
+           It->getSecond() == OMPRTL___kmpc_free_shared) &&
+          "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
+
+      auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
+          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+      auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
+          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+
+      RuntimeFunction RF = It->getSecond();
+
+      switch (RF) {
+      // If neither HeapToStack nor HeapToShared assume the call is removed,
+      // assume SPMD incompatibility.
+      case OMPRTL___kmpc_alloc_shared:
+        if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
+            (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
+          SPMDCompatibilityTracker.insert(&CB);
+        break;
+      case OMPRTL___kmpc_free_shared:
+        if ((!HeapToStackAA ||
+             !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
+            (!HeapToSharedAA ||
+             !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
+          SPMDCompatibilityTracker.insert(&CB);
+        break;
+      default:
+        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
         SPMDCompatibilityTracker.insert(&CB);
-      break;
-    default:
-      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-      SPMDCompatibilityTracker.insert(&CB);
+      }
+      return ChangeStatus::CHANGED;
+    };
+
+    const auto *AACE =
+        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
+    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
+      CheckCallee(getAssociatedFunction(), 1);
+    } else {
+      const auto &OptimisticEdges = AACE->getOptimisticEdges();
+      for (auto *Callee : OptimisticEdges) {
+        CheckCallee(Callee, OptimisticEdges.size());
+        if (isAtFixpoint())
+          break;
+      }
     }
 
     return StateBefore == getState() ? ChangeStatus::UNCHANGED
diff --git a/llvm/test/Transforms/OpenMP/spmdization_indirect.ll b/llvm/test/Transforms/OpenMP/spmdization_indirect.ll
index 4ca646470eabe1a..04b0e50d4bce4a1 100644
--- a/llvm/test/Transforms/OpenMP/spmdization_indirect.ll
+++ b/llvm/test/Transforms/OpenMP/spmdization_indirect.ll
@@ -16,15 +16,15 @@
 ;.
 ; AMDGPU: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [23 x i8] c"
 ; AMDGPU: @[[GLOB1:[0-9]+]] = private unnamed_addr constant [[STRUCT_IDENT_T:%.*]] { i32 0, i32 2, i32 0, i32 0, ptr @[[GLOB0]] }, align 8
-; AMDGPU: @[[SPMD_CALLEES_KERNEL_ENVIRONMENT:[a-zA-Z0-9_$"\\.-]+]] = local_unnamed_addr constant [[STRUCT_KERNELENVIRONMENTTY:%.*]] { [[STRUCT_CONFIGURATIONENVIRONMENTTY:%.*]] { i8 0, i8 0, i8 1 }, ptr @[[GLOB1]], ptr null }
-; AMDGPU: @[[SPMD_CALLEES_METADATA_KERNEL_ENVIRONMENT:[a-zA-Z0-9_$"\\.-]+]] = local_unnamed_addr constant [[STRUCT_KERNELENVIRONMENTTY:%.*]] { [[STRUCT_CONFIGURATIONENVIRONMENTTY:%.*]] { i8 0, i8 0, i8 1 }, ptr @[[GLOB1]], ptr null }
+; AMDGPU: @[[SPMD_CALLEES_KERNEL_ENVIRONMENT:[a-zA-Z0-9_$"\
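The first hunk is the enabling switch: the Attributor only instantiates call-site AAs for calls without a statically known callee if the AA class overrides requiresCalleeForCallBase() to return false, which AAKernelInfo now does. A rough, hypothetical sketch of the framework-side gate this hook feeds into (names and plumbing approximate; the real check lives in the Attributor framework and may look different):

  // Hypothetical, simplified sketch of the gate consulting the hook.
  template <typename AAType>
  static bool canCreateCallSiteAA(const IRPosition &IRP) {
    if (IRP.getPositionKind() != IRPosition::IRP_CALL_SITE)
      return true;
    auto &CB = cast<CallBase>(IRP.getAssociatedValue());
    // AAs that require a known callee are skipped for indirect calls;
    // AAKernelInfo now opts out of that requirement.
    return !AAType::requiresCalleeForCallBase() || CB.getCalledFunction();
  }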
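Both the initializer and updateImpl hunks then share one shape: the previously inline single-callee logic moves into a CheckCallee lambda, and the call site is resolved through AACallEdges, so an indirect call with a fully known optimistic callee set is analyzed callee by callee instead of forcing an immediate pessimistic fixpoint. Condensed from the initializer variant of the patch (not compilable outside AAKernelInfoCallSite; updateImpl follows the same shape but returns a ChangeStatus):

  auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
    // Per-callee logic from the patch: unknown or non-IPO-amendable
    // callees pessimize the SPMD tracker, known OpenMP runtime calls are
    // dispatched by RuntimeFunction ID, and NumCallees > 1 forces
    // indicatePessimisticFixpoint() because the runtime-call handling is
    // only sound for a unique callee.
  };

  const auto *AACE =
      A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
  if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
    // No usable edge information: fall back to the statically associated
    // callee, which may be null for an indirect call.
    CheckCallee(getAssociatedFunction(), 1);
    return;
  }
  // Otherwise merge the contribution of every optimistically reachable
  // callee, stopping early once this AA has reached a fixpoint.
  const auto &OptimisticEdges = AACE->getOptimisticEdges();
  for (auto *Callee : OptimisticEdges) {
    CheckCallee(Callee, OptimisticEdges.size());
    if (isAtFixpoint())
      break;
  }

Passing NumCallees explicitly lets the same lambda serve both paths: the fallback always passes 1, while the loop passes the size of the optimistic edge set so the runtime-call handling can refuse ambiguous resolutions.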
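The updated spmdization_indirect.ll test exercises the payoff: a kernel whose work is reached through a function pointer can still be SPMDized when every optimistic callee is SPMD-compatible. A schematic C++ analogue of that situation, with all names hypothetical, simplified relative to the actual test IR, and the declare-target boilerplate omitted:

  // Schematic only: both candidate callees are visible in the module, so
  // AACallEdges reports {work_a, work_b} for the indirect call and
  // AAKernelInfo merges their states instead of giving up.
  typedef void (*work_fn)(int *);

  void work_a(int *p) { *p += 1; } // SPMD-compatible leaf
  void work_b(int *p) { *p += 2; } // SPMD-compatible leaf

  void kernel(int *p, int flag) {
  #pragma omp target teams
    {
      work_fn fn = flag ? work_a : work_b;
      fn(p); // indirect call: previously forced generic mode here
    }
  }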

https://github.com/llvm/llvm-project/pull/65836

