[llvm] 5b0581a - [OpenMP] Replace function pointer uses in GPU state machine

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 10 23:50:15 PDT 2020


Author: Johannes Doerfert
Date: 2020-07-11T01:44:00-05:00
New Revision: 5b0581aedc2252481462970503d1085dc27e65eb

URL: https://github.com/llvm/llvm-project/commit/5b0581aedc2252481462970503d1085dc27e65eb
DIFF: https://github.com/llvm/llvm-project/commit/5b0581aedc2252481462970503d1085dc27e65eb.diff

LOG: [OpenMP] Replace function pointer uses in GPU state machine

In non-SPMD mode we create state-machine-like code to identify the
parallel region that the GPU worker threads should execute next. The
identification uses the parallel region function pointer as that allows
it to work even if the kernel (=target region) and the parallel region
are in separate TUs. However, taking the address of a function comes
with various downsides. With this patch we will identify the most common
situation and replace the function pointer use with a dummy global
symbol (for identification purposes only). That means, if the parallel
region is only called from a single target region (or kernel), we do not
use the function pointer of the parallel region to identify it but a new
global symbol.

Fixes PR46450.

Reviewed By: JonChesterfield

Differential Revision: https://reviews.llvm.org/D83271

Added: 
    llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll

Modified: 
    llvm/lib/Transforms/IPO/OpenMPOpt.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 38647b5eae68..4df65f81912b 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -52,6 +52,9 @@ STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
           "Number of OpenMP runtime function uses identified");
 STATISTIC(NumOpenMPTargetRegionKernels,
           "Number of OpenMP target region entry points (=kernels) identified");
+// Counts parallel regions whose function-pointer uses in the GPU state machine
+// were redirected to a private identification global.
+STATISTIC(
+    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
+    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
 
 #if !defined(NDEBUG)
 static constexpr auto TAG = "[" DEBUG_TYPE "]";
@@ -496,6 +499,8 @@ struct OpenMPOpt {
     if (PrintOpenMPKernels)
       printKernels();
 
+    Changed |= rewriteDeviceCodeStateMachine();
+
     Changed |= runAttributor();
 
     // Recollect uses, in case Attributor deleted any.
@@ -849,6 +854,31 @@ struct OpenMPOpt {
       AddUserArgs(*GTIdArgs[u]);
   }
 
+  /// Kernel (=GPU) optimizations and helper queries.
+  ///
+  ///{{
+
+  /// Return true if \p F is a kernel, i.e., an entry point for target
+  /// offloading.
+  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F) != 0; }
+
+  /// Memoized answers of getUniqueKernelFor, keyed by function. An entry that
+  /// holds nullptr means "no unique kernel" (or a query still in progress).
+  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
+
+  /// Return the only kernel known to execute \p F, or nullptr if there is none.
+  Kernel getUniqueKernelFor(Function &F);
+
+  /// Return the only kernel known to execute the function containing \p I, or
+  /// nullptr if there is none.
+  Kernel getUniqueKernelFor(Instruction &I) {
+    Function &EnclosingFn = *I.getFunction();
+    return getUniqueKernelFor(EnclosingFn);
+  }
+
+  /// Rewrite the device (=GPU) state machine created in non-SPMD mode so that,
+  /// where possible, parallel regions are not identified by taking the address
+  /// of a function. Returns true if anything was rewritten.
+  bool rewriteDeviceCodeStateMachine();
+
+  ///
+  ///}}
+
   /// Emit a remark generically
   ///
   /// This template function can be used to generically emit a remark. The
@@ -930,6 +960,140 @@ struct OpenMPOpt {
   }
 };
 
+/// Return the unique kernel (=target region entry point) that executes \p F,
+/// or nullptr if there is none or it cannot be determined.
+///
+/// Results are memoized in UniqueKernelMap. A kernel resolves to itself. A
+/// non-kernel without local linkage may be referenced from outside this
+/// module and resolves to nullptr. For a local function every use is mapped to
+/// a kernel: direct calls, equality comparisons, and arguments of
+/// __kmpc_kernel_prepare_parallel calls resolve through the function that
+/// contains the use; any other use resolves to nullptr.
+Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
+  // Functions outside the module slice are not analyzed.
+  if (!OMPInfoCache.ModuleSlice.count(&F))
+    return nullptr;
+
+  // Use a scope to keep the lifetime of the CachedKernel short.
+  // CachedKernel refers into UniqueKernelMap; the recursive queries below may
+  // insert into the DenseMap, which can invalidate references into it.
+  {
+    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
+    // A present entry, even one holding nullptr, is returned as-is.
+    if (CachedKernel)
+      return *CachedKernel;
+
+    // TODO: We should use an AA to create an (optimistic and callback
+    //       call-aware) call graph. For now we stick to simple patterns that
+    //       are less powerful, basically the worst fixpoint.
+    if (isKernel(F)) {
+      CachedKernel = Kernel(&F);
+      return *CachedKernel;
+    }
+
+    // Seed the cache with nullptr before recursing so that a cycle in the use
+    // graph terminates, answering "no unique kernel" for the revisited node.
+    CachedKernel = nullptr;
+    // Without local linkage there may be uses we cannot see.
+    if (!F.hasLocalLinkage())
+      return nullptr;
+  }
+
+  // Map a single use of F to the unique kernel executing its user, or nullptr
+  // if the use is not one of the recognized patterns.
+  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
+    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
+      // Allow use in equality comparisons.
+      if (Cmp->isEquality())
+        return getUniqueKernelFor(*Cmp);
+      return nullptr;
+    }
+    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
+      // Allow direct calls.
+      if (CB->isCallee(&U))
+        return getUniqueKernelFor(*CB);
+      // Allow the use in __kmpc_kernel_prepare_parallel calls.
+      if (Function *Callee = CB->getCalledFunction())
+        if (Callee->getName() == "__kmpc_kernel_prepare_parallel")
+          return getUniqueKernelFor(*CB);
+      return nullptr;
+    }
+    // Disallow every other use.
+    return nullptr;
+  };
+
+  // TODO: In the future we want to track more than just a unique kernel.
+  // Unrecognized uses contribute nullptr to the set, so the result below is
+  // non-null only if every use resolves to the same non-null kernel. A
+  // function without uses yields an empty set and therefore nullptr.
+  SmallPtrSet<Kernel, 2> PotentialKernels;
+  foreachUse(F, [&](const Use &U) {
+    PotentialKernels.insert(GetUniqueKernelForUse(U));
+  });
+
+  Kernel K = nullptr;
+  if (PotentialKernels.size() == 1)
+    K = *PotentialKernels.begin();
+
+  // Cache the result (overwriting the nullptr seed stored above).
+  UniqueKernelMap[&F] = K;
+
+  return K;
+}
+
+/// Rewrite the non-SPMD GPU state machine so that a parallel region reachable
+/// from a single kernel is identified by a private global ID instead of its
+/// function address.
+///
+/// For every function F in the current SCC all uses are classified:
+///  - direct calls (F is the callee), counted in NumDirectCalls;
+///  - equality comparisons, i.e., the state machine's dispatch `icmp`s;
+///  - the work-function argument of a __kmpc_kernel_prepare_parallel call.
+/// Any other use makes the rewrite unsafe and F is skipped. If F has exactly
+/// one direct call, exactly two replaceable uses, and a unique kernel, the
+/// replaceable uses are redirected to a new `<F>.ID` global.
+///
+/// \returns true if any use was rewritten.
+bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
+  // Operand index of the work function in __kmpc_kernel_prepare_parallel.
+  constexpr unsigned KMPC_KERNEL_PARALLEL_WORK_FN_PTR_ARG_NO = 0;
+
+  OMPInformationCache::RuntimeFunctionInfo &KernelPrepareParallelRFI =
+      OMPInfoCache.RFIs[OMPRTL___kmpc_kernel_prepare_parallel];
+
+  bool Changed = false;
+  // Without the runtime function in the module there is no state machine.
+  if (!KernelPrepareParallelRFI)
+    return Changed;
+
+  for (Function *F : SCC) {
+
+    // Check if the function is used in a __kmpc_kernel_prepare_parallel call
+    // at all, and collect the uses that identify it in the state machine.
+    bool UnknownUse = false;
+    unsigned NumDirectCalls = 0;
+
+    SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
+    foreachUse(*F, [&](Use &U) {
+      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
+        if (CB->isCallee(&U)) {
+          ++NumDirectCalls;
+          return;
+        }
+
+      // Only equality comparisons are pure identity tests; this matches the
+      // comparisons getUniqueKernelFor accepts.
+      if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser()))
+        if (Cmp->isEquality()) {
+          ToBeReplacedStateMachineUses.push_back(&U);
+          return;
+        }
+
+      // Only the work-function operand of the runtime call may be replaced.
+      CallInst *CI = OpenMPOpt::getCallIfRegularCall(*U.getUser(),
+                                                     &KernelPrepareParallelRFI);
+      if (CI && CI->isArgOperand(&U) &&
+          CI->getArgOperandNo(&U) == KMPC_KERNEL_PARALLEL_WORK_FN_PTR_ARG_NO) {
+        ToBeReplacedStateMachineUses.push_back(&U);
+        return;
+      }
+
+      UnknownUse = true;
+    });
+
+    // If this ever hits, we should investigate.
+    if (UnknownUse || NumDirectCalls != 1)
+      continue;
+
+    // TODO: This is not a necessary restriction and should be lifted.
+    if (ToBeReplacedStateMachineUses.size() != 2)
+      continue;
+
+    // Even if we have __kmpc_kernel_prepare_parallel calls, we (for now) give
+    // up if the function is not called from a unique kernel.
+    Kernel K = getUniqueKernelFor(*F);
+    if (!K)
+      continue;
+
+    // We now know F is a parallel body function called only from the kernel K.
+    // We also identified the state machine uses in which we replace the
+    // function pointer by a new global symbol for identification purposes. This
+    // ensures only direct calls to the function are left.
+
+    Module &M = *F->getParent();
+    Type *Int8Ty = Type::getInt8Ty(M.getContext());
+
+    auto *ID = new GlobalVariable(
+        M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
+        UndefValue::get(Int8Ty), F->getName() + ".ID");
+
+    // The ID is created without an explicit address space, which need not
+    // match the pointer type of the replaced operand; a plain bitcast cannot
+    // change address spaces. When they match this is an ordinary bitcast.
+    for (Use *U : ToBeReplacedStateMachineUses)
+      U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+          ID, U->get()->getType()));
+
+    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
+
+    Changed = true;
+  }
+
+  return Changed;
+}
+
 /// Abstract Attribute for tracking ICV values.
 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
   using Base = StateWrapper<BooleanState, AbstractAttribute>;

diff  --git a/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll b/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll
new file mode 100644
index 000000000000..0a8d7a9d231a
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll
@@ -0,0 +1,153 @@
+; RUN: opt -S -passes=openmpopt -pass-remarks=openmp-opt -openmp-print-gpu-kernels < %s | FileCheck %s
+; RUN: opt -S        -openmpopt -pass-remarks=openmp-opt -openmp-print-gpu-kernels < %s | FileCheck %s
+
+; C input used for this test:
+
+; void bar(void) {
+;     #pragma omp parallel
+;     { }
+; }
+; void foo(void) {
+;   #pragma omp target teams
+;   {
+;     #pragma omp parallel
+;     {}
+;     bar();
+;     #pragma omp parallel
+;     {}
+;   }
+; }
+
+; Verify we replace the function pointer uses for the first and last outlined
+; regions (1 and 3) but not for the middle one (2): it is referenced from @bar,
+; which is not internal and could therefore be called from another kernel.
+
+; CHECK-DAG: @__omp_outlined__1_wrapper.ID = private constant i8 undef
+; CHECK-DAG: @__omp_outlined__3_wrapper.ID = private constant i8 undef
+
+; CHECK-DAG:   icmp eq i8* %5, @__omp_outlined__1_wrapper.ID
+; CHECK-DAG:   icmp eq i8* %7, @__omp_outlined__3_wrapper.ID
+
+; CHECK-DAG:   call void @__kmpc_kernel_prepare_parallel(i8* @__omp_outlined__1_wrapper.ID)
+; CHECK-DAG:   call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2_wrapper to i8*))
+; CHECK-DAG:   call void @__kmpc_kernel_prepare_parallel(i8* @__omp_outlined__3_wrapper.ID)
+
+
+; Runtime descriptor type; presumably the OpenMP source-location struct (the
+; `%loc_ref` parameter below). The test only ever passes null pointers to it.
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+; Worker state machine of the generic-mode (non-SPMD) kernel. Each iteration
+; waits at a barrier, obtains the next work function, and dispatches it: known
+; outlined wrappers are compared by address and called directly, anything else
+; is called indirectly. A null work function leaves the loop.
+define internal void @__omp_offloading_35_a1e179_foo_l7_worker() {
+entry:
+  %work_fn = alloca i8*, align 8
+  %exec_status = alloca i8, align 1
+  store i8* null, i8** %work_fn, align 8
+  store i8 0, i8* %exec_status, align 1
+  br label %.await.work
+
+.await.work:                                      ; preds = %.barrier.parallel, %entry
+  call void @__kmpc_barrier_simple_spmd(%struct.ident_t* null, i32 0)
+  ; The work function is read back from %work_fn after this call (presumably
+  ; written by the runtime); the i1 result becomes the exec status.
+  %0 = call i1 @__kmpc_kernel_parallel(i8** %work_fn)
+  %1 = zext i1 %0 to i8
+  store i8 %1, i8* %exec_status, align 1
+  %2 = load i8*, i8** %work_fn, align 8
+  %should_terminate = icmp eq i8* %2, null
+  br i1 %should_terminate, label %.exit, label %.select.workers
+
+.select.workers:                                  ; preds = %.await.work
+  %3 = load i8, i8* %exec_status, align 1
+  %is_active = icmp ne i8 %3, 0
+  br i1 %is_active, label %.execute.parallel, label %.barrier.parallel
+
+.execute.parallel:                                ; preds = %.select.workers
+  %4 = call i32 @__kmpc_global_thread_num(%struct.ident_t* null)
+  ; Dispatch: the icmp operands below are the state machine uses that
+  ; OpenMPOpt may replace with `<wrapper>.ID` globals.
+  %5 = load i8*, i8** %work_fn, align 8
+  %work_match = icmp eq i8* %5, bitcast (void ()* @__omp_outlined__1_wrapper to i8*)
+  br i1 %work_match, label %.execute.fn, label %.check.next
+
+.execute.fn:                                      ; preds = %.execute.parallel
+  call void @__omp_outlined__1_wrapper()
+  br label %.terminate.parallel
+
+.check.next:                                      ; preds = %.execute.parallel
+  %6 = load i8*, i8** %work_fn, align 8
+  %work_match1 = icmp eq i8* %6, bitcast (void ()* @__omp_outlined__2_wrapper to i8*)
+  br i1 %work_match1, label %.execute.fn2, label %.check.next3
+
+.execute.fn2:                                     ; preds = %.check.next
+  call void @__omp_outlined__2_wrapper()
+  br label %.terminate.parallel
+
+.check.next3:                                     ; preds = %.check.next
+  %7 = load i8*, i8** %work_fn, align 8
+  %work_match4 = icmp eq i8* %7, bitcast (void ()* @__omp_outlined__3_wrapper to i8*)
+  br i1 %work_match4, label %.execute.fn5, label %.check.next6
+
+.execute.fn5:                                     ; preds = %.check.next3
+  call void @__omp_outlined__3_wrapper()
+  br label %.terminate.parallel
+
+.check.next6:                                     ; preds = %.check.next3
+  ; Fallback for work functions not matched above: indirect call.
+  %8 = bitcast i8* %2 to void ()*
+  call void %8()
+  br label %.terminate.parallel
+
+.terminate.parallel:                              ; preds = %.check.next6, %.execute.fn5, %.execute.fn2, %.execute.fn
+  call void @__kmpc_kernel_end_parallel()
+  br label %.barrier.parallel
+
+.barrier.parallel:                                ; preds = %.terminate.parallel, %.select.workers
+  call void @__kmpc_barrier_simple_spmd(%struct.ident_t* null, i32 0)
+  br label %.await.work
+
+.exit:                                            ; preds = %.await.work
+  ret void
+}
+
+; Kernel entry point (annotated as a kernel in !nvvm.annotations). It calls the
+; worker state machine and then the target-region body; there is no
+; thread-id based split between them here.
+define weak void @__omp_offloading_35_a1e179_foo_l7() {
+  call void @__omp_offloading_35_a1e179_foo_l7_worker()
+  call void @__omp_outlined__()
+  ret void
+}
+
+; Target-region body: announces parallel regions 1 and 3 directly and reaches
+; region 2 only through @bar.
+define internal void @__omp_outlined__() {
+  call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__1_wrapper to i8*))
+  call void @bar()
+  call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__3_wrapper to i8*))
+  ret void
+}
+
+; Parallel region 1 and its wrapper; the wrapper's address is what the worker
+; state machine compares against and what is passed to the runtime.
+define internal void @__omp_outlined__1() {
+  ret void
+}
+
+define internal void @__omp_outlined__1_wrapper() {
+  call void @__omp_outlined__1()
+  ret void
+}
+
+; @bar is `hidden`, not internal, so the pass cannot prove that region 2 is
+; only reached from this kernel; its function-pointer uses are expected to stay.
+define hidden void @bar() {
+  call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2_wrapper to i8*))
+  ret void
+}
+
+; Empty wrappers for parallel regions 2 and 3.
+define internal void @__omp_outlined__2_wrapper() {
+  ret void
+}
+
+define internal void @__omp_outlined__3_wrapper() {
+  ret void
+}
+
+; OpenMP device runtime entry points referenced by this test (declarations
+; only).
+declare void @__kmpc_kernel_prepare_parallel(i8* %WorkFn)
+
+declare zeroext i1 @__kmpc_kernel_parallel(i8** nocapture %WorkFn)
+
+declare void @__kmpc_kernel_end_parallel()
+
+declare void @__kmpc_barrier_simple_spmd(%struct.ident_t* nocapture readnone %loc_ref, i32 %tid)
+
+declare i32 @__kmpc_global_thread_num(%struct.ident_t* nocapture readnone)
+
+
+; Annotate @__omp_offloading_35_a1e179_foo_l7 as a kernel; presumably this is
+; how OpenMPOpt discovers target-region entry points (confirm against the
+; kernel collection in OMPInformationCache).
+!nvvm.annotations = !{!0}
+
+!0 = !{void ()* @__omp_offloading_35_a1e179_foo_l7, !"kernel", i32 1}


        


More information about the llvm-commits mailing list