[clang] 1d5711c - [OpenMP] Unified entry point for SPMD & generic kernels in the device RTL

Johannes Doerfert via cfe-commits cfe-commits at lists.llvm.org
Sat Jul 10 10:33:36 PDT 2021


Author: Johannes Doerfert
Date: 2021-07-10T12:32:50-05:00
New Revision: 1d5711c3eeb62098b46d4d383f2e849b9756105d

URL: https://github.com/llvm/llvm-project/commit/1d5711c3eeb62098b46d4d383f2e849b9756105d
DIFF: https://github.com/llvm/llvm-project/commit/1d5711c3eeb62098b46d4d383f2e849b9756105d.diff

LOG: [OpenMP] Unified entry point for SPMD & generic kernels in the device RTL

In the spirit of TRegions [0], this patch provides a simpler and more uniform
interface for a kernel to set up the device runtime. The OMPIRBuilder is
used so that the implementation can be reused by Flang. A custom state
machine will be generated in a follow-up patch.

The "surplus" threads of the "master warp" no longer exit early, so we
need to use non-aligned barriers. The new runtime will not have an extra
warp, but it will also require these non-aligned barriers.

[0] https://link.springer.com/chapter/10.1007/978-3-030-28596-8_11

This was partly extracted from D59319.

Reviewed By: ABataev, JonChesterfield

Differential Revision: https://reviews.llvm.org/D101976

Added: 
    openmp/libomptarget/deviceRTLs/common/include/target.h

Modified: 
    clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
    clang/lib/CodeGen/CGOpenMPRuntimeGPU.h
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/lib/Transforms/IPO/OpenMPOpt.cpp
    llvm/test/Transforms/OpenMP/replace_globalization.ll
    llvm/test/Transforms/OpenMP/single_threaded_execution.ll
    openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
    openmp/libomptarget/deviceRTLs/common/src/parallel.cu
    openmp/libomptarget/deviceRTLs/interface.h
    openmp/libomptarget/deviceRTLs/nvptx/src/target_impl.cu

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 965b3f1534d67..1cb367ec71885 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -553,63 +553,6 @@ static llvm::Value *getNVPTXLaneID(CodeGenFunction &CGF) {
                        "nvptx_lane_id");
 }
 
-/// Get the value of the thread_limit clause in the teams directive.
-/// For the 'generic' execution mode, the runtime encodes thread_limit in
-/// the launch parameters, always starting thread_limit+warpSize threads per
-/// CTA. The threads in the last warp are reserved for master execution.
-/// For the 'spmd' execution mode, all threads in a CTA are part of the team.
-static llvm::Value *getThreadLimit(CodeGenFunction &CGF,
-                                   bool IsInSPMDExecutionMode = false) {
-  CGBuilderTy &Bld = CGF.Builder;
-  auto &RT = static_cast<CGOpenMPRuntimeGPU &>(CGF.CGM.getOpenMPRuntime());
-  llvm::Value *ThreadLimit = nullptr;
-  if (IsInSPMDExecutionMode)
-    ThreadLimit = RT.getGPUNumThreads(CGF);
-  else {
-    llvm::Value *GPUNumThreads = RT.getGPUNumThreads(CGF);
-    llvm::Value *GPUWarpSize = RT.getGPUWarpSize(CGF);
-    ThreadLimit = Bld.CreateNUWSub(GPUNumThreads, GPUWarpSize, "thread_limit");
-  }
-  assert(ThreadLimit != nullptr && "Expected non-null ThreadLimit");
-  return ThreadLimit;
-}
-
-/// Get the thread id of the OMP master thread.
-/// The master thread id is the first thread (lane) of the last warp in the
-/// GPU block.  Warp size is assumed to be some power of 2.
-/// Thread id is 0 indexed.
-/// E.g: If NumThreads is 33, master id is 32.
-///      If NumThreads is 64, master id is 32.
-///      If NumThreads is 1024, master id is 992.
-static llvm::Value *getMasterThreadID(CodeGenFunction &CGF) {
-  CGBuilderTy &Bld = CGF.Builder;
-  auto &RT = static_cast<CGOpenMPRuntimeGPU &>(CGF.CGM.getOpenMPRuntime());
-  llvm::Value *NumThreads = RT.getGPUNumThreads(CGF);
-  // We assume that the warp size is a power of 2.
-  llvm::Value *Mask = Bld.CreateNUWSub(RT.getGPUWarpSize(CGF), Bld.getInt32(1));
-
-  llvm::Value *NumThreadsSubOne = Bld.CreateNUWSub(NumThreads, Bld.getInt32(1));
-  return Bld.CreateAnd(NumThreadsSubOne, Bld.CreateNot(Mask), "master_tid");
-}
-
-CGOpenMPRuntimeGPU::WorkerFunctionState::WorkerFunctionState(
-    CodeGenModule &CGM, SourceLocation Loc)
-    : WorkerFn(nullptr), CGFI(CGM.getTypes().arrangeNullaryFunction()),
-      Loc(Loc) {
-  createWorkerFunction(CGM);
-}
-
-void CGOpenMPRuntimeGPU::WorkerFunctionState::createWorkerFunction(
-    CodeGenModule &CGM) {
-  // Create an worker function with no arguments.
-
-  WorkerFn = llvm::Function::Create(
-      CGM.getTypes().GetFunctionType(CGFI), llvm::GlobalValue::InternalLinkage,
-      /*placeholder=*/"_worker", &CGM.getModule());
-  CGM.SetInternalFunctionAttributes(GlobalDecl(), WorkerFn, CGFI);
-  WorkerFn->setDoesNotRecurse();
-}
-
 CGOpenMPRuntimeGPU::ExecutionMode
 CGOpenMPRuntimeGPU::getExecutionMode() const {
   return CurrentExecutionMode;
@@ -1073,23 +1016,19 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
                                              const RegionCodeGenTy &CodeGen) {
   ExecutionRuntimeModesRAII ModeRAII(CurrentExecutionMode);
   EntryFunctionState EST;
-  WorkerFunctionState WST(CGM, D.getBeginLoc());
-  Work.clear();
   WrapperFunctionsMap.clear();
 
   // Emit target region as a standalone region.
   class NVPTXPrePostActionTy : public PrePostActionTy {
     CGOpenMPRuntimeGPU::EntryFunctionState &EST;
-    CGOpenMPRuntimeGPU::WorkerFunctionState &WST;
 
   public:
-    NVPTXPrePostActionTy(CGOpenMPRuntimeGPU::EntryFunctionState &EST,
-                         CGOpenMPRuntimeGPU::WorkerFunctionState &WST)
-        : EST(EST), WST(WST) {}
+    NVPTXPrePostActionTy(CGOpenMPRuntimeGPU::EntryFunctionState &EST)
+        : EST(EST) {}
     void Enter(CodeGenFunction &CGF) override {
       auto &RT =
           static_cast<CGOpenMPRuntimeGPU &>(CGF.CGM.getOpenMPRuntime());
-      RT.emitNonSPMDEntryHeader(CGF, EST, WST);
+      RT.emitKernelInit(CGF, EST, /* IsSPMD */ false);
       // Skip target region initialization.
       RT.setLocThreadIdInsertPt(CGF, /*AtCurrentPoint=*/true);
     }
@@ -1097,93 +1036,33 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
       auto &RT =
           static_cast<CGOpenMPRuntimeGPU &>(CGF.CGM.getOpenMPRuntime());
       RT.clearLocThreadIdInsertPt(CGF);
-      RT.emitNonSPMDEntryFooter(CGF, EST);
+      RT.emitKernelDeinit(CGF, EST, /* IsSPMD */ false);
     }
-  } Action(EST, WST);
+  } Action(EST);
   CodeGen.setAction(Action);
   IsInTTDRegion = true;
   emitTargetOutlinedFunctionHelper(D, ParentName, OutlinedFn, OutlinedFnID,
                                    IsOffloadEntry, CodeGen);
   IsInTTDRegion = false;
-
-  // Now change the name of the worker function to correspond to this target
-  // region's entry function.
-  WST.WorkerFn->setName(Twine(OutlinedFn->getName(), "_worker"));
-
-  // Create the worker function
-  emitWorkerFunction(WST);
 }
 
-// Setup NVPTX threads for master-worker OpenMP scheme.
-void CGOpenMPRuntimeGPU::emitNonSPMDEntryHeader(CodeGenFunction &CGF,
-                                                  EntryFunctionState &EST,
-                                                  WorkerFunctionState &WST) {
+void CGOpenMPRuntimeGPU::emitKernelInit(CodeGenFunction &CGF,
+                                        EntryFunctionState &EST, bool IsSPMD) {
   CGBuilderTy &Bld = CGF.Builder;
-
-  llvm::BasicBlock *WorkerBB = CGF.createBasicBlock(".worker");
-  llvm::BasicBlock *MasterCheckBB = CGF.createBasicBlock(".mastercheck");
-  llvm::BasicBlock *MasterBB = CGF.createBasicBlock(".master");
-  EST.ExitBB = CGF.createBasicBlock(".exit");
-
-  auto &RT = static_cast<CGOpenMPRuntimeGPU &>(CGF.CGM.getOpenMPRuntime());
-  llvm::Value *GPUThreadID = RT.getGPUThreadID(CGF);
-  llvm::Value *ThreadLimit = getThreadLimit(CGF);
-  llvm::Value *IsWorker = Bld.CreateICmpULT(GPUThreadID, ThreadLimit);
-  Bld.CreateCondBr(IsWorker, WorkerBB, MasterCheckBB);
-
-  CGF.EmitBlock(WorkerBB);
-  emitCall(CGF, WST.Loc, WST.WorkerFn);
-  CGF.EmitBranch(EST.ExitBB);
-
-  CGF.EmitBlock(MasterCheckBB);
-  GPUThreadID = RT.getGPUThreadID(CGF);
-  llvm::Value *MasterThreadID = getMasterThreadID(CGF);
-  llvm::Value *IsMaster = Bld.CreateICmpEQ(GPUThreadID, MasterThreadID);
-  Bld.CreateCondBr(IsMaster, MasterBB, EST.ExitBB);
-
-  CGF.EmitBlock(MasterBB);
-  IsInTargetMasterThreadRegion = true;
-  // SEQUENTIAL (MASTER) REGION START
-  // First action in sequential region:
-  // Initialize the state of the OpenMP runtime library on the GPU.
-  // TODO: Optimize runtime initialization and pass in correct value.
-  llvm::Value *Args[] = {getThreadLimit(CGF),
-                         Bld.getInt16(/*RequiresOMPRuntime=*/1)};
-  CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-                          CGM.getModule(), OMPRTL___kmpc_kernel_init),
-                      Args);
-
-  emitGenericVarsProlog(CGF, WST.Loc);
+  Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, requiresFullRuntime()));
+  IsInTargetMasterThreadRegion = IsSPMD;
+  if (!IsSPMD)
+    emitGenericVarsProlog(CGF, EST.Loc);
 }
 
-void CGOpenMPRuntimeGPU::emitNonSPMDEntryFooter(CodeGenFunction &CGF,
-                                                  EntryFunctionState &EST) {
-  IsInTargetMasterThreadRegion = false;
-  if (!CGF.HaveInsertPoint())
-    return;
-
-  emitGenericVarsEpilog(CGF);
-
-  if (!EST.ExitBB)
-    EST.ExitBB = CGF.createBasicBlock(".exit");
-
-  llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".termination.notifier");
-  CGF.EmitBranch(TerminateBB);
+void CGOpenMPRuntimeGPU::emitKernelDeinit(CodeGenFunction &CGF,
+                                          EntryFunctionState &EST,
+                                          bool IsSPMD) {
+  if (!IsSPMD)
+    emitGenericVarsEpilog(CGF);
 
-  CGF.EmitBlock(TerminateBB);
-  // Signal termination condition.
-  // TODO: Optimize runtime initialization and pass in correct value.
-  llvm::Value *Args[] = {CGF.Builder.getInt16(/*IsOMPRuntimeInitialized=*/1)};
-  CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-                          CGM.getModule(), OMPRTL___kmpc_kernel_deinit),
-                      Args);
-  // Barrier to terminate worker threads.
-  syncCTAThreads(CGF);
-  // Master thread jumps to exit point.
-  CGF.EmitBranch(EST.ExitBB);
-
-  CGF.EmitBlock(EST.ExitBB);
-  EST.ExitBB = nullptr;
+  CGBuilderTy &Bld = CGF.Builder;
+  OMPBuilder.createTargetDeinit(Bld, IsSPMD, requiresFullRuntime());
 }
 
 void CGOpenMPRuntimeGPU::emitSPMDKernel(const OMPExecutableDirective &D,
@@ -1202,23 +1081,21 @@ void CGOpenMPRuntimeGPU::emitSPMDKernel(const OMPExecutableDirective &D,
   class NVPTXPrePostActionTy : public PrePostActionTy {
     CGOpenMPRuntimeGPU &RT;
     CGOpenMPRuntimeGPU::EntryFunctionState &EST;
-    const OMPExecutableDirective &D;
 
   public:
     NVPTXPrePostActionTy(CGOpenMPRuntimeGPU &RT,
-                         CGOpenMPRuntimeGPU::EntryFunctionState &EST,
-                         const OMPExecutableDirective &D)
-        : RT(RT), EST(EST), D(D) {}
+                         CGOpenMPRuntimeGPU::EntryFunctionState &EST)
+        : RT(RT), EST(EST) {}
     void Enter(CodeGenFunction &CGF) override {
-      RT.emitSPMDEntryHeader(CGF, EST, D);
+      RT.emitKernelInit(CGF, EST, /* IsSPMD */ true);
       // Skip target region initialization.
       RT.setLocThreadIdInsertPt(CGF, /*AtCurrentPoint=*/true);
     }
     void Exit(CodeGenFunction &CGF) override {
       RT.clearLocThreadIdInsertPt(CGF);
-      RT.emitSPMDEntryFooter(CGF, EST);
+      RT.emitKernelDeinit(CGF, EST, /* IsSPMD */ true);
     }
-  } Action(*this, EST, D);
+  } Action(*this, EST);
   CodeGen.setAction(Action);
   IsInTTDRegion = true;
   emitTargetOutlinedFunctionHelper(D, ParentName, OutlinedFn, OutlinedFnID,
@@ -1226,54 +1103,6 @@ void CGOpenMPRuntimeGPU::emitSPMDKernel(const OMPExecutableDirective &D,
   IsInTTDRegion = false;
 }
 
-void CGOpenMPRuntimeGPU::emitSPMDEntryHeader(
-    CodeGenFunction &CGF, EntryFunctionState &EST,
-    const OMPExecutableDirective &D) {
-  CGBuilderTy &Bld = CGF.Builder;
-
-  // Setup BBs in entry function.
-  llvm::BasicBlock *ExecuteBB = CGF.createBasicBlock(".execute");
-  EST.ExitBB = CGF.createBasicBlock(".exit");
-
-  llvm::Value *Args[] = {getThreadLimit(CGF, /*IsInSPMDExecutionMode=*/true),
-                         /*RequiresOMPRuntime=*/
-                         Bld.getInt16(RequiresFullRuntime ? 1 : 0)};
-  CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-                          CGM.getModule(), OMPRTL___kmpc_spmd_kernel_init),
-                      Args);
-
-  CGF.EmitBranch(ExecuteBB);
-
-  CGF.EmitBlock(ExecuteBB);
-
-  IsInTargetMasterThreadRegion = true;
-}
-
-void CGOpenMPRuntimeGPU::emitSPMDEntryFooter(CodeGenFunction &CGF,
-                                               EntryFunctionState &EST) {
-  IsInTargetMasterThreadRegion = false;
-  if (!CGF.HaveInsertPoint())
-    return;
-
-  if (!EST.ExitBB)
-    EST.ExitBB = CGF.createBasicBlock(".exit");
-
-  llvm::BasicBlock *OMPDeInitBB = CGF.createBasicBlock(".omp.deinit");
-  CGF.EmitBranch(OMPDeInitBB);
-
-  CGF.EmitBlock(OMPDeInitBB);
-  // DeInitialize the OMP state in the runtime; called by all active threads.
-  llvm::Value *Args[] = {/*RequiresOMPRuntime=*/
-                         CGF.Builder.getInt16(RequiresFullRuntime ? 1 : 0)};
-  CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-                          CGM.getModule(), OMPRTL___kmpc_spmd_kernel_deinit_v2),
-                      Args);
-  CGF.EmitBranch(EST.ExitBB);
-
-  CGF.EmitBlock(EST.ExitBB);
-  EST.ExitBB = nullptr;
-}
-
 // Create a unique global variable to indicate the execution mode of this target
 // region. The execution mode is either 'generic', or 'spmd' depending on the
 // target directive. This variable is picked up by the offload library to setup
@@ -1290,137 +1119,6 @@ static void setPropertyExecutionMode(CodeGenModule &CGM, StringRef Name,
   CGM.addCompilerUsedGlobal(GVMode);
 }
 
-void CGOpenMPRuntimeGPU::emitWorkerFunction(WorkerFunctionState &WST) {
-  ASTContext &Ctx = CGM.getContext();
-
-  CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
-  CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, WST.WorkerFn, WST.CGFI, {},
-                    WST.Loc, WST.Loc);
-  emitWorkerLoop(CGF, WST);
-  CGF.FinishFunction();
-}
-
-void CGOpenMPRuntimeGPU::emitWorkerLoop(CodeGenFunction &CGF,
-                                        WorkerFunctionState &WST) {
-  //
-  // The workers enter this loop and wait for parallel work from the master.
-  // When the master encounters a parallel region it sets up the work + variable
-  // arguments, and wakes up the workers.  The workers first check to see if
-  // they are required for the parallel region, i.e., within the # of requested
-  // parallel threads.  The activated workers load the variable arguments and
-  // execute the parallel work.
-  //
-
-  CGBuilderTy &Bld = CGF.Builder;
-
-  llvm::BasicBlock *AwaitBB = CGF.createBasicBlock(".await.work");
-  llvm::BasicBlock *SelectWorkersBB = CGF.createBasicBlock(".select.workers");
-  llvm::BasicBlock *ExecuteBB = CGF.createBasicBlock(".execute.parallel");
-  llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".terminate.parallel");
-  llvm::BasicBlock *BarrierBB = CGF.createBasicBlock(".barrier.parallel");
-  llvm::BasicBlock *ExitBB = CGF.createBasicBlock(".exit");
-
-  CGF.EmitBranch(AwaitBB);
-
-  // Workers wait for work from master.
-  CGF.EmitBlock(AwaitBB);
-  // Wait for parallel work
-  syncCTAThreads(CGF);
-
-  Address WorkFn =
-      CGF.CreateDefaultAlignTempAlloca(CGF.Int8PtrTy, /*Name=*/"work_fn");
-  Address ExecStatus =
-      CGF.CreateDefaultAlignTempAlloca(CGF.Int8Ty, /*Name=*/"exec_status");
-  CGF.InitTempAlloca(ExecStatus, Bld.getInt8(/*C=*/0));
-  CGF.InitTempAlloca(WorkFn, llvm::Constant::getNullValue(CGF.Int8PtrTy));
-
-  // TODO: Optimize runtime initialization and pass in correct value.
-  llvm::Value *Args[] = {WorkFn.getPointer()};
-  llvm::Value *Ret =
-      CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-                              CGM.getModule(), OMPRTL___kmpc_kernel_parallel),
-                          Args);
-  Bld.CreateStore(Bld.CreateZExt(Ret, CGF.Int8Ty), ExecStatus);
-
-  // On termination condition (workid == 0), exit loop.
-  llvm::Value *WorkID = Bld.CreateLoad(WorkFn);
-  llvm::Value *ShouldTerminate = Bld.CreateIsNull(WorkID, "should_terminate");
-  Bld.CreateCondBr(ShouldTerminate, ExitBB, SelectWorkersBB);
-
-  // Activate requested workers.
-  CGF.EmitBlock(SelectWorkersBB);
-  llvm::Value *IsActive =
-      Bld.CreateIsNotNull(Bld.CreateLoad(ExecStatus), "is_active");
-  Bld.CreateCondBr(IsActive, ExecuteBB, BarrierBB);
-
-  // Signal start of parallel region.
-  CGF.EmitBlock(ExecuteBB);
-  // Skip initialization.
-  setLocThreadIdInsertPt(CGF, /*AtCurrentPoint=*/true);
-
-  // Process work items: outlined parallel functions.
-  for (llvm::Function *W : Work) {
-    // Try to match this outlined function.
-    llvm::Value *ID = Bld.CreatePointerBitCastOrAddrSpaceCast(W, CGM.Int8PtrTy);
-
-    llvm::Value *WorkFnMatch =
-        Bld.CreateICmpEQ(Bld.CreateLoad(WorkFn), ID, "work_match");
-
-    llvm::BasicBlock *ExecuteFNBB = CGF.createBasicBlock(".execute.fn");
-    llvm::BasicBlock *CheckNextBB = CGF.createBasicBlock(".check.next");
-    Bld.CreateCondBr(WorkFnMatch, ExecuteFNBB, CheckNextBB);
-
-    // Execute this outlined function.
-    CGF.EmitBlock(ExecuteFNBB);
-
-    // Insert call to work function via shared wrapper. The shared
-    // wrapper takes two arguments:
-    //   - the parallelism level;
-    //   - the thread ID;
-    emitCall(CGF, WST.Loc, W,
-             {Bld.getInt16(/*ParallelLevel=*/0), getThreadID(CGF, WST.Loc)});
-
-    // Go to end of parallel region.
-    CGF.EmitBranch(TerminateBB);
-
-    CGF.EmitBlock(CheckNextBB);
-  }
-  // Default case: call to outlined function through pointer if the target
-  // region makes a declare target call that may contain an orphaned parallel
-  // directive.
-  auto *ParallelFnTy =
-      llvm::FunctionType::get(CGM.VoidTy, {CGM.Int16Ty, CGM.Int32Ty},
-                              /*isVarArg=*/false);
-  llvm::Value *WorkFnCast =
-      Bld.CreateBitCast(WorkID, ParallelFnTy->getPointerTo());
-  // Insert call to work function via shared wrapper. The shared
-  // wrapper takes two arguments:
-  //   - the parallelism level;
-  //   - the thread ID;
-  emitCall(CGF, WST.Loc, {ParallelFnTy, WorkFnCast},
-           {Bld.getInt16(/*ParallelLevel=*/0), getThreadID(CGF, WST.Loc)});
-  // Go to end of parallel region.
-  CGF.EmitBranch(TerminateBB);
-
-  // Signal end of parallel region.
-  CGF.EmitBlock(TerminateBB);
-  CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
-                          CGM.getModule(), OMPRTL___kmpc_kernel_end_parallel),
-                      llvm::None);
-  CGF.EmitBranch(BarrierBB);
-
-  // All active and inactive workers wait at a barrier after parallel region.
-  CGF.EmitBlock(BarrierBB);
-  // Barrier after parallel region.
-  syncCTAThreads(CGF);
-  CGF.EmitBranch(AwaitBB);
-
-  // Exit target region.
-  CGF.EmitBlock(ExitBB);
-  // Skip initialization.
-  clearLocThreadIdInsertPt(CGF);
-}
-
 void CGOpenMPRuntimeGPU::createOffloadEntry(llvm::Constant *ID,
                                               llvm::Constant *Addr,
                                               uint64_t Size, int32_t,
@@ -1806,11 +1504,8 @@ void CGOpenMPRuntimeGPU::emitParallelCall(CodeGenFunction &CGF,
     CGBuilderTy &Bld = CGF.Builder;
     llvm::Function *WFn = WrapperFunctionsMap[OutlinedFn];
     llvm::Value *ID = llvm::ConstantPointerNull::get(CGM.Int8PtrTy);
-    if (WFn) {
+    if (WFn)
       ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy);
-      // Remember for post-processing in worker loop.
-      Work.emplace_back(WFn);
-    }
     llvm::Value *FnPtr = Bld.CreateBitOrPointerCast(OutlinedFn, CGM.Int8PtrTy);
 
     // Create a private scope that will globalize the arguments

diff  --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h
index 3decf48cbb932..464af1294b46e 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.h
@@ -38,19 +38,7 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
   llvm::SmallVector<llvm::Function *, 16> Work;
 
   struct EntryFunctionState {
-    llvm::BasicBlock *ExitBB = nullptr;
-  };
-
-  class WorkerFunctionState {
-  public:
-    llvm::Function *WorkerFn;
-    const CGFunctionInfo &CGFI;
     SourceLocation Loc;
-
-    WorkerFunctionState(CodeGenModule &CGM, SourceLocation Loc);
-
-  private:
-    void createWorkerFunction(CodeGenModule &CGM);
   };
 
   ExecutionMode getExecutionMode() const;
@@ -60,20 +48,13 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
   /// Get barrier to synchronize all threads in a block.
   void syncCTAThreads(CodeGenFunction &CGF);
 
-  /// Emit the worker function for the current target region.
-  void emitWorkerFunction(WorkerFunctionState &WST);
+  /// Helper for target directive initialization.
+  void emitKernelInit(CodeGenFunction &CGF, EntryFunctionState &EST,
+                      bool IsSPMD);
 
-  /// Helper for worker function. Emit body of worker loop.
-  void emitWorkerLoop(CodeGenFunction &CGF, WorkerFunctionState &WST);
-
-  /// Helper for non-SPMD target entry function. Guide the master and
-  /// worker threads to their respective locations.
-  void emitNonSPMDEntryHeader(CodeGenFunction &CGF, EntryFunctionState &EST,
-                              WorkerFunctionState &WST);
-
-  /// Signal termination of OMP execution for non-SPMD target entry
-  /// function.
-  void emitNonSPMDEntryFooter(CodeGenFunction &CGF, EntryFunctionState &EST);
+  /// Helper for target directive finalization.
+  void emitKernelDeinit(CodeGenFunction &CGF, EntryFunctionState &EST,
+                        bool IsSPMD);
 
   /// Helper for generic variables globalization prolog.
   void emitGenericVarsProlog(CodeGenFunction &CGF, SourceLocation Loc,
@@ -82,13 +63,6 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
   /// Helper for generic variables globalization epilog.
   void emitGenericVarsEpilog(CodeGenFunction &CGF, bool WithSPMDCheck = false);
 
-  /// Helper for SPMD mode target directive's entry function.
-  void emitSPMDEntryHeader(CodeGenFunction &CGF, EntryFunctionState &EST,
-                           const OMPExecutableDirective &D);
-
-  /// Signal termination of SPMD mode execution.
-  void emitSPMDEntryFooter(CodeGenFunction &CGF, EntryFunctionState &EST);
-
   //
   // Base class overrides.
   //

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 0a249b3e25749..a92c3ba381c67 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -779,6 +779,29 @@ class OpenMPIRBuilder {
                                       llvm::ConstantInt *Size,
                                       const llvm::Twine &Name = Twine(""));
 
+  /// The `omp target` interface
+  ///
+  /// For more information about the usage of this interface,
+  /// \see openmp/libomptarget/deviceRTLs/common/include/target.h
+  ///
+  ///{
+
+  /// Create a runtime call for kmpc_target_init
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
+  /// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
+  InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);
+
+  /// Create a runtime call for kmpc_target_deinit
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
+  /// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
+  void createTargetDeinit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);
+
+  ///}
+
   /// Declarations for LLVM-IR types (simple, array, function and structure) are
   /// generated below. Their names are defined and used in OpenMPKinds.def. Here
   /// we provide the declarations, the initializeTypes function will provide the

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index 1804cfeef7b8d..2003f44e34e9c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -409,10 +409,8 @@ __OMP_RTL(__kmpc_task_allow_completion_event, false, VoidPtr, IdentPtr,
           /* Int */ Int32, /* kmp_task_t */ VoidPtr)
 
 /// OpenMP Device runtime functions
-__OMP_RTL(__kmpc_kernel_init, false, Void, Int32, Int16)
-__OMP_RTL(__kmpc_kernel_deinit, false, Void, Int16)
-__OMP_RTL(__kmpc_spmd_kernel_init, false, Void, Int32, Int16)
-__OMP_RTL(__kmpc_spmd_kernel_deinit_v2, false, Void, Int16)
+__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1)
+__OMP_RTL(__kmpc_target_deinit, false, Void, IdentPtr, Int1, Int1)
 __OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
 __OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
           VoidPtr, VoidPtr, VoidPtrPtr, SizeTy)

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 1020de5f30ee9..60d71805c758f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Value.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -2191,6 +2192,70 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
   return Builder.CreateCall(Fn, Args);
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
+  ConstantInt *UseGenericStateMachine =
+      ConstantInt::getBool(Int32->getContext(), !IsSPMD);
+  ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
+
+  Function *Fn = getOrCreateRuntimeFunctionPtr(
+      omp::RuntimeFunction::OMPRTL___kmpc_target_init);
+
+  CallInst *ThreadKind =
+      Builder.CreateCall(Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});
+
+  Value *ExecUserCode = Builder.CreateICmpEQ(
+      ThreadKind, ConstantInt::get(ThreadKind->getType(), -1), "exec_user_code");
+
+  // ThreadKind = __kmpc_target_init(...)
+  // if (ThreadKind == -1)
+  //   user_code
+  // else
+  //   return;
+
+  auto *UI = Builder.CreateUnreachable();
+  BasicBlock *CheckBB = UI->getParent();
+  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
+
+  BasicBlock *WorkerExitBB = BasicBlock::Create(
+      CheckBB->getContext(), "worker.exit", CheckBB->getParent());
+  Builder.SetInsertPoint(WorkerExitBB);
+  Builder.CreateRetVoid();
+
+  auto *CheckBBTI = CheckBB->getTerminator();
+  Builder.SetInsertPoint(CheckBBTI);
+  Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
+
+  CheckBBTI->eraseFromParent();
+  UI->eraseFromParent();
+
+  // Continue in the "user_code" block, see diagram above and in
+  // openmp/libomptarget/deviceRTLs/common/include/target.h .
+  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
+}
+
+void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
+                                         bool IsSPMD, bool RequiresFullRuntime) {
+  if (!updateToLocation(Loc))
+    return;
+
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
+  ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
+
+  Function *Fn = getOrCreateRuntimeFunctionPtr(
+      omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
+
+  Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal});
+}
+
 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
                                                    StringRef FirstSeparator,
                                                    StringRef Separator) {

diff  --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 0127f9da43083..b1230b96dd6ae 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -26,9 +26,6 @@
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/IntrinsicsAMDGPU.h"
-#include "llvm/IR/IntrinsicsNVPTX.h"
-#include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/IPO.h"
@@ -37,7 +34,6 @@
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
 #include "llvm/Transforms/Utils/CodeExtractor.h"
 
-using namespace llvm::PatternMatch;
 using namespace llvm;
 using namespace omp;
 
@@ -2341,10 +2337,12 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
                               AllCallSitesKnown))
     SingleThreadedBBs.erase(&F->getEntryBlock());
 
-  // Check if the edge into the successor block compares a thread-id function to
-  // a constant zero.
-  // TODO: Use AAValueSimplify to simplify and propogate constants.
-  // TODO: Check more than a single use for thread ID's.
+  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
+
+  // Check if the edge into the successor block compares the __kmpc_target_init
+  // result with -1. If we are in non-SPMD-mode that signals only the main
+  // thread will execute the edge.
   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
     if (!Edge || !Edge->isConditional())
       return false;
@@ -2355,31 +2353,23 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
       return false;
 
-    // Temporarily match the pattern generated by clang for teams regions.
-    // TODO: Remove this once the new runtime is in place.
-    ConstantInt *One, *NegOne;
-    CmpInst::Predicate Pred;
-    auto &&m_ThreadID = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_tid_x>();
-    auto &&m_WarpSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_warpsize>();
-    auto &&m_BlockSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_ntid_x>();
-    if (match(Cmp, m_Cmp(Pred, m_ThreadID,
-                         m_And(m_Sub(m_BlockSize, m_ConstantInt(One)),
-                               m_Xor(m_Sub(m_WarpSize, m_ConstantInt(One)),
-                                     m_ConstantInt(NegOne))))))
-      if (One->isOne() && NegOne->isMinusOne() &&
-          Pred == CmpInst::Predicate::ICMP_EQ)
-        return true;
-
     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
-    if (!C || !C->isZero())
+    if (!C)
       return false;
 
-    if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
-      if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
-        return true;
-    if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
-      if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
-        return true;
+    // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
+    if (C->isAllOnesValue()) {
+      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
+      // RFI.Declaration is null when the module does not declare
+      // __kmpc_target_init; an indirect call (null callee) must not match it.
+      if (!CB || !RFI.Declaration ||
+          CB->getCalledFunction() != RFI.Declaration)
+        return false;
+      const int InitIsSPMDArgNo = 1;
+      auto *IsSPMDModeCI =
+          dyn_cast<ConstantInt>(CB->getArgOperand(InitIsSPMDArgNo));
+      return IsSPMDModeCI && IsSPMDModeCI->isZero();
+    }
 
     return false;
   };
@@ -2394,7 +2384,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
          PredBB != PredEndBB; ++PredBB) {
       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
-                              BB))
+                               BB))
         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
     }
 

diff  --git a/llvm/test/Transforms/OpenMP/replace_globalization.ll b/llvm/test/Transforms/OpenMP/replace_globalization.ll
index cb96fc3832a92..06224e6d4068e 100644
--- a/llvm/test/Transforms/OpenMP/replace_globalization.ll
+++ b/llvm/test/Transforms/OpenMP/replace_globalization.ll
@@ -3,10 +3,16 @@
 target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
 target triple = "nvptx64"
 
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
 @S = external local_unnamed_addr global i8*
+@0 = private unnamed_addr constant [113 x i8] c";llvm/test/Transforms/OpenMP/custom_state_machines_remarks.c;__omp_offloading_2a_d80d3d_test_fallback_l11;11;1;;\00", align 1
+@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([113 x i8], [113 x i8]* @0, i32 0, i32 0) }, align 8
 
 ; CHECK-REMARKS: remark: replace_globalization.c:5:7: Replaced globalized variable with 16 bytes of shared memory
 ; CHECK-REMARKS: remark: replace_globalization.c:5:14: Replaced globalized variable with 4 bytes of shared memory
+; CHECK-REMARKS-NOT: 6 bytes
+
 ; CHECK: [[SHARED_X:@.+]] = internal addrspace(3) global [16 x i8] undef
 ; CHECK: [[SHARED_Y:@.+]] = internal addrspace(3) global [4 x i8] undef
 
@@ -25,14 +31,15 @@ entry:
 define void @bar() {
   call void @baz()
   call void @qux()
+  call void @negative_qux_spmd()
   ret void
 }
 
 ; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([16 x i8], [16 x i8] addrspace(3)* [[SHARED_X]], i32 0, i32 0) to i8*))
 define internal void @baz() {
 entry:
-  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
-  %cmp = icmp eq i32 %tid, 0
+  %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 true)
+  %cmp = icmp eq i32 %call, -1
   br i1 %cmp, label %master, label %exit
 master:
   %x = call i8* @__kmpc_alloc_shared(i64 16), !dbg !11
@@ -48,20 +55,32 @@ exit:
 ; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* [[SHARED_Y]], i32 0, i32 0) to i8*))
 define internal void @qux() {
 entry:
-  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
-  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
-  %warpsize = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
-  %0 = sub nuw i32 %warpsize, 1
-  %1 = sub nuw i32 %ntid, 1
-  %2 = xor i32 %0, -1
-  %master_tid = and i32 %1, %2
-  %3 = icmp eq i32 %tid, %master_tid
-  br i1 %3, label %master, label %exit
+  %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true)
+  %0 = icmp eq i32 %call, -1
+  br i1 %0, label %master, label %exit
 master:
   %y = call i8* @__kmpc_alloc_shared(i64 4), !dbg !12
   %y_on_stack = bitcast i8* %y to [4 x i32]*
-  %4 = bitcast [4 x i32]* %y_on_stack to i8*
-  call void @use(i8* %4)
+  %1 = bitcast [4 x i32]* %y_on_stack to i8*
+  call void @use(i8* %1)
+  call void @__kmpc_free_shared(i8* %y)
+  br label %exit
+exit:
+  ret void
+}
+
+; Negative test: in SPMD mode (IsSPMD = true) __kmpc_target_init returns -1 in
+; every thread, so this allocation must not be replaced (no 6-byte remark).
+define internal void @negative_qux_spmd() {
+entry:
+  %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 true, i1 true, i1 true)
+  %0 = icmp eq i32 %call, -1
+  br i1 %0, label %master, label %exit
+master:
+  %y = call i8* @__kmpc_alloc_shared(i64 6), !dbg !12
+  %y_on_stack = bitcast i8* %y to [6 x i32]*
+  %1 = bitcast [6 x i32]* %y_on_stack to i8*
+  call void @use(i8* %1)
   call void @__kmpc_free_shared(i8* %y)
   br label %exit
 exit:
@@ -85,6 +104,7 @@ declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
 
 declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
 
+declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)
 
 !llvm.dbg.cu = !{!0}
 !llvm.module.flags = !{!3, !4, !5, !6}

diff  --git a/llvm/test/Transforms/OpenMP/single_threaded_execution.ll b/llvm/test/Transforms/OpenMP/single_threaded_execution.ll
index 5fff563d364d8..ae56477902b12 100644
--- a/llvm/test/Transforms/OpenMP/single_threaded_execution.ll
+++ b/llvm/test/Transforms/OpenMP/single_threaded_execution.ll
@@ -3,8 +3,13 @@
 ; REQUIRES: asserts
 ; ModuleID = 'single_threaded_exeuction.c'
 
-define weak void @kernel() {
-  call void @__kmpc_kernel_init(i32 512, i16 1)
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+@0 = private unnamed_addr constant [1 x i8] c"\00", align 1
+@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([1 x i8], [1 x i8]* @0, i32 0, i32 0) }, align 8
+
+define void @kernel() {
+  call void @__kmpc_kernel_prepare_parallel(i8* null)
   call void @nvptx()
   call void @amdgcn()
   ret void
@@ -19,8 +24,9 @@ define weak void @kernel() {
 ; Function Attrs: noinline
 define internal void @nvptx() {
 entry:
-  %call = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
-  %cmp = icmp eq i32 %call, 0
+  ; Generic mode (IsSPMD = false): only the main thread gets -1 here.
+  %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 false)
+  %cmp = icmp eq i32 %call, -1
   br i1 %cmp, label %if.then, label %if.end
 
 if.then:
@@ -40,8 +46,9 @@ if.end:
 ; Function Attrs: noinline
 define internal void @amdgcn() {
 entry:
-  %call = call i32 @llvm.amdgcn.workitem.id.x()
-  %cmp = icmp eq i32 %call, 0
+  ; Generic mode (IsSPMD = false): only the main thread gets -1 here.
+  %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true)
+  %cmp = icmp eq i32 %call, -1
   br i1 %cmp, label %if.then, label %if.end
 
 if.then:
@@ -87,7 +94,9 @@ declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
 
 declare i32 @llvm.amdgcn.workitem.id.x()
 
-declare void @__kmpc_kernel_init(i32, i16)
+declare void @__kmpc_kernel_prepare_parallel(i8*)
+
+declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)
 
 attributes #0 = { cold noinline }
 

diff  --git a/openmp/libomptarget/deviceRTLs/common/include/target.h b/openmp/libomptarget/deviceRTLs/common/include/target.h
new file mode 100644
index 0000000000000..997e93b924e24
--- /dev/null
+++ b/openmp/libomptarget/deviceRTLs/common/include/target.h
@@ -0,0 +1,111 @@
+//===-- target.h ---------- OpenMP device runtime target implementation ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Target region interfaces are simple interfaces designed to allow middle-end
+// (=LLVM) passes to analyze and transform the code. To achieve good performance
+// it may be required to run the associated passes. However, implementations of
+// this interface shall always provide a correct implementation as close to the
+// user expected code as possible.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_OPENMP_LIBOMPTARGET_DEVICERTLS_COMMON_TARGET_H
+#define LLVM_OPENMP_LIBOMPTARGET_DEVICERTLS_COMMON_TARGET_H
+
+#include <stdint.h>
+
+extern "C" {
+
+/// Forward declaration of the source location identifier "ident".
+typedef struct ident ident_t;
+
+/// The target region _kernel_ interface for GPUs
+///
+/// This deliberately simple interface provides the middle-end (=LLVM) with
+/// easier means to reason about the semantics of the code and transform it as
+/// well. The runtime calls are therefore also designed to carry sufficient
+/// information necessary for optimizations.
+///
+///
+/// Intended usage:
+///
+/// \code
+/// void kernel(...) {
+///   ThreadKind = __kmpc_target_init(Ident, /* IsSPMD */ false,
+///                                   /* UseGenericStateMachine */ true,
+///                                   /* RequiresFullRuntime */ ... );
+///   if (ThreadKind == -1) {
+///     // User defined kernel code.
+///   }
+///   __kmpc_target_deinit(...);
+/// }
+/// \endcode
+///
+/// Which can be transformed to:
+///
+/// \code
+/// void kernel(...) {
+///   ThreadKind = __kmpc_target_init(Ident, /* IsSPMD */ false,
+///                                   /* UseGenericStateMachine */ false,
+///                                   /* RequiresFullRuntime */ ... );
+///   if (ThreadKind == -1) {
+///     // User defined kernel code.
+///   } else {
+///     assume(ThreadKind == ThreadId);
+///     // Custom, kernel-specific state machine code.
+///   }
+///   __kmpc_target_deinit(...);
+/// }
+/// \endcode
+///
+///
+///{
+
+/// Initialization
+///
+/// Must be called by all threads.
+///
+/// \param Ident               Source location identification, can be NULL.
+/// \param IsSPMD              True for SPMD-mode kernels, false for generic
+///                            (main thread + worker threads) kernels.
+/// \param UseGenericStateMachine
+///                            Only relevant in generic mode: if true, worker
+///                            threads run the runtime's generic state machine
+///                            before returning; if false, they return right
+///                            away so the caller can run a custom one.
+/// \param RequiresFullRuntime If false, SPMD kernels may skip initializing
+///                            the full runtime state.
+/// \returns -1 in every thread that should execute the user code (all threads
+///          in SPMD mode, only the main thread in generic mode), otherwise
+///          the calling thread's id.
+///
+int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD,
+                           bool UseGenericStateMachine,
+                           bool RequiresFullRuntime);
+
+/// De-Initialization
+///
+/// In generic mode, must be called by the main thread and may be called by
+/// all threads. In SPMD mode, must be called by all threads.
+///
+/// In non-SPMD mode, this function releases the workers trapped in a state
+/// machine and also any memory dynamically allocated by the runtime.
+///
+/// \param Ident               Source location identification, can be NULL.
+/// \param IsSPMD              Selects SPMD vs. generic de-initialization;
+///                            expected to match the value passed to
+///                            __kmpc_target_init.
+/// \param RequiresFullRuntime Expected to match the value passed to
+///                            __kmpc_target_init.
+///
+void __kmpc_target_deinit(ident_t *Ident, bool IsSPMD,
+                          bool RequiresFullRuntime);
+
+///}
+}
+#endif

diff  --git a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
index c117c7e00bf28..34af243fab541 100644
--- a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
+++ b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
@@ -12,6 +12,7 @@
 #pragma omp declare target
 
 #include "common/omptarget.h"
+#include "common/support.h"
 #include "target_impl.h"
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -26,16 +27,18 @@ extern omptarget_nvptx_Queue<omptarget_nvptx_ThreadPrivateContext,
 // init entry points
 ////////////////////////////////////////////////////////////////////////////////
 
-EXTERN void __kmpc_kernel_init(int ThreadLimit, int16_t RequiresOMPRuntime) {
+static void __kmpc_generic_kernel_init() {
   PRINT(LD_IO, "call to __kmpc_kernel_init with version %f\n",
         OMPTARGET_NVPTX_VERSION);
-  ASSERT0(LT_FUSSY, RequiresOMPRuntime,
-          "Generic always requires initialized runtime.");
-  setExecutionParameters(Generic, RuntimeInitialized);
-  for (int I = 0; I < MAX_THREADS_PER_TEAM / WARPSIZE; ++I)
-    parallelLevel[I] = 0;
+  // Lane 0 of every warp resets that warp's parallel nesting level.
+  if (GetLaneId() == 0)
+    parallelLevel[GetWarpId()] = 0;
 
   int threadIdInBlock = GetThreadIdInBlock();
+  if (threadIdInBlock != GetMasterThreadID())
+    return;
+
+  setExecutionParameters(Generic, RuntimeInitialized);
-  ASSERT0(LT_FUSSY, threadIdInBlock == GetMasterThreadID(),
-          "__kmpc_kernel_init() must be called by team master warp only!");
+  // All other threads returned above; only the main thread sets up the
+  // team's runtime state below.
   PRINT0(LD_IO, "call to __kmpc_kernel_init for master\n");
@@ -47,7 +50,7 @@ EXTERN void __kmpc_kernel_init(int ThreadLimit, int16_t RequiresOMPRuntime) {
       omptarget_nvptx_device_State[slot].Dequeue();
 
   // init thread private
-  int threadId = GetLogicalThreadIdInBlock(/*isSPMDExecutionMode=*/false);
+  int threadId = 0;
   omptarget_nvptx_threadPrivateContext->InitThreadPrivateContext(threadId);
 
   // init team context
@@ -62,20 +65,17 @@ EXTERN void __kmpc_kernel_init(int ThreadLimit, int16_t RequiresOMPRuntime) {
   // set number of threads and thread limit in team to started value
   omptarget_nvptx_TaskDescr *currTaskDescr =
       omptarget_nvptx_threadPrivateContext->GetTopLevelTaskDescr(threadId);
-  nThreads = GetNumberOfThreadsInBlock();
-  threadLimit = ThreadLimit;
+  nThreads = GetNumberOfWorkersInTeam();
+  threadLimit = nThreads;
 
-  if (!__kmpc_is_spmd_exec_mode())
-    omptarget_nvptx_globalArgs.Init();
+  omptarget_nvptx_globalArgs.Init();
 
   __kmpc_data_sharing_init_stack();
   __kmpc_impl_target_init();
 }
 
-EXTERN void __kmpc_kernel_deinit(int16_t IsOMPRuntimeInitialized) {
+static void __kmpc_generic_kernel_deinit() {
   PRINT0(LD_IO, "call to __kmpc_kernel_deinit\n");
-  ASSERT0(LT_FUSSY, IsOMPRuntimeInitialized,
-          "Generic always requires initialized runtime.");
   // Enqueue omp state object for use by another team.
   int slot = usedSlotIdx;
   omptarget_nvptx_device_State[slot].Enqueue(
@@ -84,12 +84,11 @@ EXTERN void __kmpc_kernel_deinit(int16_t IsOMPRuntimeInitialized) {
   omptarget_nvptx_workFn = 0;
 }
 
-EXTERN void __kmpc_spmd_kernel_init(int ThreadLimit,
-                                    int16_t RequiresOMPRuntime) {
+static void __kmpc_spmd_kernel_init(bool RequiresFullRuntime) {
   PRINT0(LD_IO, "call to __kmpc_spmd_kernel_init\n");
 
-  setExecutionParameters(Spmd, RequiresOMPRuntime ? RuntimeInitialized
-                                                  : RuntimeUninitialized);
+  setExecutionParameters(Spmd, RequiresFullRuntime ? RuntimeInitialized
+                                                   : RuntimeUninitialized);
   int threadId = GetThreadIdInBlock();
   if (threadId == 0) {
     usedSlotIdx = __kmpc_impl_smid() % MAX_SM;
@@ -100,11 +99,8 @@ EXTERN void __kmpc_spmd_kernel_init(int ThreadLimit,
         1 + (GetNumberOfThreadsInBlock() > 1 ? OMP_ACTIVE_PARALLEL_LEVEL : 0);
   }
   __kmpc_data_sharing_init_stack();
-  if (!RequiresOMPRuntime) {
-    // Runtime is not required - exit.
-    __kmpc_impl_syncthreads();
+  if (!RequiresFullRuntime)
     return;
-  }
 
   //
   // Team Context Initialization.
@@ -138,16 +134,17 @@ EXTERN void __kmpc_spmd_kernel_init(int ThreadLimit,
                                                              newTaskDescr);
 
   // init thread private from init value
+  int ThreadLimit = GetNumberOfProcsInTeam(/* IsSPMD */ true);
   PRINT(LD_PAR,
         "thread will execute parallel region with id %d in a team of "
         "%d threads\n",
         (int)newTaskDescr->ThreadId(), (int)ThreadLimit);
 }
 
-EXTERN void __kmpc_spmd_kernel_deinit_v2(int16_t RequiresOMPRuntime) {
+static void __kmpc_spmd_kernel_deinit(bool RequiresFullRuntime) {
   // We're not going to pop the task descr stack of each thread since
   // there are no more parallel regions in SPMD mode.
-  if (!RequiresOMPRuntime)
+  if (!RequiresFullRuntime)
     return;
 
   __kmpc_impl_syncthreads();
@@ -165,4 +162,82 @@ EXTERN int8_t __kmpc_is_spmd_exec_mode() {
   return (execution_param & ModeMask) == Spmd;
 }
 
+EXTERN bool __kmpc_kernel_parallel(void **WorkFn);
+
+/// Generic-mode worker loop: wait until the main thread publishes a work
+/// function, run it if this thread is active, and repeat until a null work
+/// function signals that the kernel is done.
+static void __kmpc_target_region_state_machine(ident_t *Ident) {
+
+  int TId = GetThreadIdInBlock();
+  do {
+    void *WorkFn = nullptr;
+
+    // Wait for the signal that we have a new work function.
+    __kmpc_barrier_simple_spmd(Ident, TId);
+
+    // Retrieve the work function from the runtime.
+    bool IsActive = __kmpc_kernel_parallel(&WorkFn);
+
+    // If there is nothing more to do, break out of the state machine by
+    // returning to the caller.
+    if (!WorkFn)
+      return;
+
+    if (IsActive) {
+      ((void (*)(uint32_t, uint32_t))WorkFn)(0, TId);
+      __kmpc_kernel_end_parallel();
+    }
+
+    __kmpc_barrier_simple_spmd(Ident, TId);
+
+  } while (true);
+}
+
+/// Unified kernel entry point (contract in common/include/target.h).
+///
+/// Returns -1 in every thread that must execute the user code (all threads in
+/// SPMD mode, only the main thread in generic mode) and the thread id in all
+/// other threads.
+EXTERN
+int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD,
+                           bool UseGenericStateMachine,
+                           bool RequiresFullRuntime) {
+  int TId = GetThreadIdInBlock();
+  if (IsSPMD)
+    __kmpc_spmd_kernel_init(RequiresFullRuntime);
+  else
+    __kmpc_generic_kernel_init();
+
+  if (IsSPMD) {
+    __kmpc_barrier_simple_spmd(Ident, TId);
+    return -1;
+  }
+
+  if (TId == GetMasterThreadID())
+    return -1;
+
+  // Only threads that can become active workers may enter the state machine.
+  // GetNumberOfWorkersInTeam() excludes the main thread's whole warp, so the
+  // "surplus" threads of that warp return here instead. On some GPUs
+  // (reportedly pre-Volta NVIDIA and AMD) one thread can satisfy a barrier on
+  // behalf of its whole warp; a surplus thread reaching the state machine's
+  // first barrier early could then release the workers before the main thread
+  // has published a work function, and they would quit without doing work.
+  if (UseGenericStateMachine && TId < GetNumberOfWorkersInTeam())
+    __kmpc_target_region_state_machine(Ident);
+
+  return TId;
+}
+
+EXTERN
+void __kmpc_target_deinit(ident_t *Ident, bool IsSPMD,
+                          bool RequiresFullRuntime) {
+  if (IsSPMD)
+    __kmpc_spmd_kernel_deinit(RequiresFullRuntime);
+  else
+    __kmpc_generic_kernel_deinit();
+}
+
+
 #pragma omp end declare target

diff  --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
index 29a6db85d172c..ea885b806aca4 100644
--- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
+++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
@@ -331,7 +331,7 @@ EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
         (1 + (IsActiveParallelRegion ? OMP_ACTIVE_PARALLEL_LEVEL : 0));
 
   // Master signals work to activate workers.
-  __kmpc_barrier_simple_spmd(nullptr, 0);
+  __kmpc_barrier_simple_spmd(ident, 0);
 
   // OpenMP [2.5, Parallel Construct, p.49]
   // There is an implied barrier at the end of a parallel region. After the
@@ -339,7 +339,7 @@ EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
   // execution of the enclosing task region.
   //
   // The master waits at this barrier until all workers are done.
-  __kmpc_barrier_simple_spmd(nullptr, 0);
+  __kmpc_barrier_simple_spmd(ident, 0);
 
   // Decrement parallel level for non-SPMD warps.
   for (int I = 0; I < NumWarps; ++I)

diff  --git a/openmp/libomptarget/deviceRTLs/interface.h b/openmp/libomptarget/deviceRTLs/interface.h
index 082b6b9d11090..e0c433060c85b 100644
--- a/openmp/libomptarget/deviceRTLs/interface.h
+++ b/openmp/libomptarget/deviceRTLs/interface.h
@@ -416,11 +416,11 @@ EXTERN int32_t __kmpc_cancel(kmp_Ident *loc, int32_t global_tid,
                              int32_t cancelVal);
 
 // non standard
-EXTERN void __kmpc_kernel_init(int ThreadLimit, int16_t RequiresOMPRuntime);
-EXTERN void __kmpc_kernel_deinit(int16_t IsOMPRuntimeInitialized);
-EXTERN void __kmpc_spmd_kernel_init(int ThreadLimit,
-                                    int16_t RequiresOMPRuntime);
-EXTERN void __kmpc_spmd_kernel_deinit_v2(int16_t RequiresOMPRuntime);
+EXTERN int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD,
+                                  bool UseGenericStateMachine,
+                                  bool RequiresFullRuntime);
+EXTERN void __kmpc_target_deinit(ident_t *Ident, bool IsSPMD,
+                                 bool RequiresFullRuntime);
 EXTERN void __kmpc_kernel_prepare_parallel(void *WorkFn);
 EXTERN bool __kmpc_kernel_parallel(void **WorkFn);
 EXTERN void __kmpc_kernel_end_parallel();

diff  --git a/openmp/libomptarget/deviceRTLs/nvptx/src/target_impl.cu b/openmp/libomptarget/deviceRTLs/nvptx/src/target_impl.cu
index eafa73426a950..35324f070e4d6 100644
--- a/openmp/libomptarget/deviceRTLs/nvptx/src/target_impl.cu
+++ b/openmp/libomptarget/deviceRTLs/nvptx/src/target_impl.cu
@@ -60,7 +60,18 @@ EXTERN __kmpc_impl_lanemask_t __kmpc_impl_activemask() {
   return Mask;
 }
 
-EXTERN void __kmpc_impl_syncthreads() { __syncthreads(); }
+EXTERN void __kmpc_impl_syncthreads() {
+  // Block-wide barrier using the non-aligned `barrier.sync` rather than
+  // __syncthreads() (an aligned barrier): the surplus threads of the main
+  // warp no longer exit early, so threads of one warp may reach a barrier
+  // from different program points. Id 2 keeps this barrier distinct from
+  // barrier 1, which __kmpc_impl_named_sync uses.
+  int barrier = 2;
+  asm volatile("barrier.sync %0;"
+               :
+               : "r"(barrier)
+               : "memory");
+}
 
 EXTERN void __kmpc_impl_syncwarp(__kmpc_impl_lanemask_t Mask) {
   __nvvm_bar_warp_sync(Mask);
@@ -75,7 +86,7 @@ EXTERN void __kmpc_impl_named_sync(uint32_t num_threads) {
   // The named barrier for active parallel threads of a team in an L1 parallel
   // region to synchronize with each other.
   int barrier = 1;
-  asm volatile("bar.sync %0, %1;"
+  asm volatile("barrier.sync %0, %1;"
                :
                : "r"(barrier), "r"(num_threads)
                : "memory");


        


More information about the cfe-commits mailing list