[llvm] 0bdde9d - [OpenMP] Make OpenMPOpt aware of the OpenMP runtime's status

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 26 11:23:52 PST 2023


Author: Joseph Huber
Date: 2023-01-26T13:23:44-06:00
New Revision: 0bdde9dfb9b1dbfabee147c196db820e1f5dca1f

URL: https://github.com/llvm/llvm-project/commit/0bdde9dfb9b1dbfabee147c196db820e1f5dca1f
DIFF: https://github.com/llvm/llvm-project/commit/0bdde9dfb9b1dbfabee147c196db820e1f5dca1f.diff

LOG: [OpenMP] Make OpenMPOpt aware of the OpenMP runtime's status

The `OpenMPOpt` pass contains optimizations that generate new calls into
the OpenMP runtime. This causes problems if the runtime has already been
linked in statically: the newly generated calls will never be resolved.
We should indicate whether we are in a "post-link" LTO phase and, if so,
prevent OpenMPOpt from generating new runtime calls.

Generally, it's not desirable for passes to maintain state about the
context in which they're called. But this is the only reasonable
solution to static linking when we have a pass that generates new
runtime calls.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D142646

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
    llvm/lib/Passes/PassBuilderPipelines.cpp
    llvm/lib/Passes/PassRegistry.def
    llvm/lib/Transforms/IPO/OpenMPOpt.cpp
    llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
    llvm/test/Transforms/OpenMP/spmdization.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index bf08336663b6b..73aee47bfef50 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -37,13 +37,29 @@ KernelSet getDeviceKernels(Module &M);
 /// OpenMP optimizations pass.
 class OpenMPOptPass : public PassInfoMixin<OpenMPOptPass> {
 public:
+  /// Runs outside of any LTO phase; new runtime calls may be emitted.
+  OpenMPOptPass() = default;
+  OpenMPOptPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {}
+
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+private:
+  /// LTO phase this pass runs in; decides if new runtime calls are legal.
+  const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
 class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
 public:
+  /// Runs outside of any LTO phase; new runtime calls may be emitted.
+  OpenMPOptCGSCCPass() = default;
+  OpenMPOptCGSCCPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {}
+
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
+private:
+  /// LTO phase this pass runs in; decides if new runtime calls are legal.
+  const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 0762c535f7f58..b94f3b52b8d29 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -1604,7 +1604,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
   }
 
   // Try to run OpenMP optimizations, quick no-op if no OpenMP metadata present.
-  MPM.addPass(OpenMPOptPass());
+  MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
 
   // Remove unused virtual tables to improve the quality of code generated by
   // whole-program devirtualization and bitset lowering.
@@ -1808,7 +1808,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
   addVectorPasses(Level, MainFPM, /* IsFullLTO */ true);
 
   // Run the OpenMPOpt CGSCC pass again late.
-  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(OpenMPOptCGSCCPass()));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
+      OpenMPOptCGSCCPass(ThinOrFullLTOPhase::FullLTOPostLink)));
 
   invokePeepholeEPCallbacks(MainFPM, Level);
   MainFPM.addPass(JumpThreadingPass());

diff  --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index ad44d86ea1a7d..10af4160c5452 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -44,6 +44,7 @@ MODULE_PASS("always-inline", AlwaysInlinerPass())
 MODULE_PASS("attributor", AttributorPass())
 MODULE_PASS("annotation2metadata", Annotation2MetadataPass())
 MODULE_PASS("openmp-opt", OpenMPOptPass())
+MODULE_PASS("openmp-opt-postlink", OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink))
 MODULE_PASS("called-value-propagation", CalledValuePropagationPass())
 MODULE_PASS("canonicalize-aliases", CanonicalizeAliasesPass())
 MODULE_PASS("cg-profile", CGProfilePass())

diff  --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 928380890186f..677cd2749fc2d 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -188,9 +188,9 @@ struct AAICVTracker;
 struct OMPInformationCache : public InformationCache {
   OMPInformationCache(Module &M, AnalysisGetter &AG,
                       BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
-                      KernelSet &Kernels)
+                      KernelSet &Kernels, bool OpenMPPostLink)
       : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
-        Kernels(Kernels) {
+        Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {
 
     OMPBuilder.initialize();
     initializeRuntimeFunctions(M);
@@ -448,6 +448,24 @@ struct OMPInformationCache : public InformationCache {
       CI->setCallingConv(Fn->getCallingConv());
   }
 
+  // Returns true if calls to all of \p Fns may be created. After the runtime
+  // has been linked in, only functions already defined in the module qualify.
+  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
+    // We can always emit calls if we haven't yet linked in the runtime.
+    if (!OpenMPPostLink)
+      return true;
+
+    // Once the runtime has already been linked in, any function that is absent
+    // from the module or only declared can never be resolved, so reject it.
+    for (RuntimeFunction Fn : Fns) {
+      RuntimeFunctionInfo &RFI = RFIs[Fn];
+
+      if (!RFI.Declaration || RFI.Declaration->isDeclaration())
+        return false;
+    }
+    return true;
+  }
+
   /// Helper to initialize all runtime function information for those defined
   /// in OpenMPKinds.def.
   void initializeRuntimeFunctions(Module &M) {
@@ -523,6 +541,9 @@ struct OMPInformationCache : public InformationCache {
 
   /// Collection of known OpenMP runtime functions..
   DenseSet<const Function *> RTLFunctions;
+
+  /// Indicates if we have already linked in the OpenMP device library.
+  bool OpenMPPostLink = false;
 };
 
 template <typename Ty, bool InsertInvalidates = true>
@@ -1412,7 +1433,10 @@ struct OpenMPOpt {
       Changed |= WasSplit;
       return WasSplit;
     };
-    RFI.foreachUse(SCC, SplitMemTransfers);
+    if (OMPInfoCache.runtimeFnsAvailable(
+            {OMPRTL___tgt_target_data_begin_mapper_issue,
+             OMPRTL___tgt_target_data_begin_mapper_wait}))
+      RFI.foreachUse(SCC, SplitMemTransfers);
 
     return Changed;
   }
@@ -3912,6 +3936,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
   bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
 
+    // We cannot change to SPMD mode if the runtime functions aren't available.
+    if (!OMPInfoCache.runtimeFnsAvailable(
+            {OMPRTL___kmpc_get_hardware_thread_id_in_block,
+             OMPRTL___kmpc_barrier_simple_spmd}))
+      return false;
+
     if (!SPMDCompatibilityTracker.isAssumed()) {
       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
         if (!NonCompatibleI)
@@ -4019,6 +4049,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
     if (!ReachedKnownParallelRegions.isValidState())
       return ChangeStatus::UNCHANGED;
 
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    if (!OMPInfoCache.runtimeFnsAvailable(
+            {OMPRTL___kmpc_get_hardware_num_threads_in_block,
+             OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
+             OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
+      return ChangeStatus::UNCHANGED;
+
     const int InitModeArgNo = 1;
     const int InitUseStateMachineArgNo = 2;
 
@@ -4165,7 +4202,6 @@ struct AAKernelInfoFunction : AAKernelInfo {
     BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
 
     Module &M = *Kernel->getParent();
-    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
     FunctionCallee BlockHwSizeFn =
         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
             M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
@@ -5341,7 +5377,10 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
   BumpPtrAllocator Allocator;
   CallGraphUpdater CGUpdater;
 
-  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels);
+  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
+  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels,
+                                PostLink);
 
   unsigned MaxFixpointIterations =
       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
@@ -5415,9 +5454,11 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
   CallGraphUpdater CGUpdater;
   CGUpdater.initialize(CG, C, AM, UR);
 
+  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
+                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
   SetVector<Function *> Functions(SCC.begin(), SCC.end());
   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
-                                /*CGSCC*/ &Functions, Kernels);
+                                /*CGSCC*/ &Functions, Kernels, PostLink);
 
   unsigned MaxFixpointIterations =
       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;

diff  --git a/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll b/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
index 6dae1732b8c8a..eb83c596e2caa 100644
--- a/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
+++ b/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
@@ -2,7 +2,9 @@
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU
 ; RUN: opt --mtriple=nvptx64--         -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU
+; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU
 ; RUN: opt --mtriple=nvptx64--         -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX
+; RUN: opt --mtriple=nvptx64--         -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX
 
 ;; void p0(void);
 ;; void p1(void);

diff  --git a/llvm/test/Transforms/OpenMP/spmdization.ll b/llvm/test/Transforms/OpenMP/spmdization.ll
index e7100dbf06822..277d092880499 100644
--- a/llvm/test/Transforms/OpenMP/spmdization.ll
+++ b/llvm/test/Transforms/OpenMP/spmdization.ll
@@ -2,7 +2,9 @@
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefixes=AMDGPU
 ; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt < %s | FileCheck %s --check-prefixes=NVPTX
 ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED
+; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED
 ; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=NVPTX-DISABLED
+; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX-DISABLED
 
 ;; void unknown(void);
 ;; void spmd_amenable(void) __attribute__((assume("ompx_spmd_amenable")));


        


More information about the llvm-commits mailing list