[clang] [OpenMP][OMPIRBuilder] Add support to omp target parallel (PR #67000)

Dominik Adamski via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 26 07:01:17 PDT 2023


================
@@ -1126,6 +1133,185 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
   Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
 }
 
+// Callback used to create the OpenMP DeviceRTL runtime calls that implement
+// the omp parallel clause on the device.
+// It replaces the direct call to OutlinedFn inside OuterFn with a call to
+// __kmpc_parallel_51, passing the captured variables through a pointer array
+// that is allocated at the start of OuterFn's entry block.
+//
+// \param OMPIRBuilder Builder whose IRBuilder and Module are used. NOTE: the
+//                     IRBuilder insertion point is left just before PrivTID
+//                     when this function returns.
+// \param OutlinedFn   Outlined region. Its first two parameters are the
+//                     thread id and bound thread id (asserted to exist); every
+//                     further parameter is a captured variable. The user
+//                     returned by user_back() must be a CallInst: that call
+//                     is replaced and erased. (Presumably it is the only
+//                     user; this is not checked here.)
+// \param OuterFn      Function receiving the captured-variable alloca.
+// \param Ident        Source-location identifier, first runtime argument.
+// \param IfCondition  Optional `if` value, sign-extended/truncated to i32;
+//                     the constant 1 is passed when it is null.
+// \param NumThreads   Optional thread count; -1 is passed when it is null.
+// \param PrivTID      Instruction before which the private TID slot is
+//                     initialized (presumably inside OutlinedFn, since the
+//                     stored value is loaded from OutlinedFn's argument).
+// \param PrivTIDAddr  Slot receiving an i32 loaded from OutlinedFn's first
+//                     argument.
+// \param ThreadID     Global thread number passed to the runtime.
+// \param ToBeDeleted  Instructions erased once the rewrite is complete.
+static void
+targetParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
+                       Function *OuterFn, Value *Ident, Value *IfCondition,
+                       Value *NumThreads, Instruction *PrivTID,
+                       AllocaInst *PrivTIDAddr, Value *ThreadID,
+                       const SmallVector<Instruction *, 4> &ToBeDeleted) {
+  Module &M = OMPIRBuilder->M;
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+
+  // Check the arity before touching parameters 0 and 1 below.
+  assert(OutlinedFn.arg_size() >= 2 &&
+         "Expected at least tid and bounded tid as arguments");
+  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
+  // Add some known attributes to the tid / bound tid parameters.
+  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
+  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
+  OutlinedFn.addParamAttr(0, Attribute::NoUndef);
+  OutlinedFn.addParamAttr(1, Attribute::NoUndef);
+  OutlinedFn.addFnAttr(Attribute::NoUnwind);
+
+  // cast<> itself asserts that the user is a CallInst, so no separate
+  // null check is needed.
+  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
+  CI->getParent()->setName("omp_parallel");
+  // Replace direct call to the outlined function by the call to
+  // __kmpc_parallel_51: new instructions are emitted right before CI.
+  Builder.SetInsertPoint(CI);
+
+  // Build call __kmpc_parallel_51
+  PointerType *PtrTy = Type::getInt8PtrTy(M.getContext());
+  ArrayType *ArgsArrayTy = ArrayType::get(PtrTy, NumCapturedVars);
+  // Null pointer passed in the slot right after the outlined function.
+  Value *NullPtr = ConstantPointerNull::get(PtrTy);
+
+  // Add alloca for kernel args. Put this instruction at the beginning of
+  // OuterFn's entry block, then return to the call site.
+  OpenMPIRBuilder::InsertPointTy CurrentIP = Builder.saveIP();
+  Builder.SetInsertPoint(&OuterFn->front(),
+                         OuterFn->front().getFirstInsertionPt());
+  AllocaInst *ArgsAlloca = Builder.CreateAlloca(ArgsArrayTy);
+  Value *Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
+  Builder.restoreIP(CurrentIP);
+
+  // Store the captured vars (the call operands after tid & bound tid) into
+  // the array used by __kmpc_parallel_51. The loop is empty when nothing
+  // is captured.
+  for (unsigned Idx = 0; Idx < NumCapturedVars; ++Idx) {
+    Value *V = CI->getArgOperand(2 + Idx);
+    Value *StoreAddress =
+        Builder.CreateConstInBoundsGEP2_64(ArgsArrayTy, Args, 0, Idx);
+    Builder.CreateStore(V, StoreAddress);
+  }
+
+  Value *Cond = IfCondition ? Builder.CreateSExtOrTrunc(
+                                  IfCondition, Type::getInt32Ty(M.getContext()))
+                            : Builder.getInt32(1);
+  Value *Parallel51CallArgs[] = {
+      /* identifier */ Ident,
+      /* global thread num */ ThreadID,
+      /* if expression */ Cond,
+      /* num threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
+      /* proc bind */ Builder.getInt32(-1),
+      /* outlined function */
+      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
+      NullPtr,
+      /* captured vars array */ Args,
+      /* number of captured vars */ Builder.getInt64(NumCapturedVars)};
+
+  FunctionCallee RTLFn =
+      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
+  // The fixed-size array converts to ArrayRef<Value *> directly; no copy.
+  Builder.CreateCall(RTLFn, Parallel51CallArgs);
+
+  LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
+                    << *Builder.GetInsertBlock()->getParent() << "\n");
+
+  // Initialize the local TID stack location with the argument value.
+  // The insertion point intentionally stays here after returning.
+  Builder.SetInsertPoint(PrivTID);
+  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
+  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
+                      PrivTIDAddr);
+
+  // Remove redundant call to the outlined function.
+  CI->eraseFromParent();
+
+  for (Instruction *I : ToBeDeleted)
+    I->eraseFromParent();
+}
+
+// Callback used to create OpenMP runtime calls to support
+// omp parallel clause for the host.
+// We need to use this callback to replace call to the OutlinedFn in OuterFn
+// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
+static void
+hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
+                     Function *OuterFn, Value *Ident, Value *IfCondition,
+                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
+                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
+  Module &M = OMPIRBuilder->M;
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  FunctionCallee RTLFn;
+  if (IfCondition) {
+    RTLFn =
+        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
+  } else {
+    RTLFn =
+        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
+  }
+  if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
+    if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
----------------
DominikAdamski wrote:

Done (scope of changes: this PR). I did not add any new `llvm::` qualifiers to the OMPIRBuilder file.

https://github.com/llvm/llvm-project/pull/67000


More information about the cfe-commits mailing list