[llvm] [OpenMP][OMPIRBuilder] Add support to omp target parallel (PR #67000)

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 21 03:57:28 PDT 2023


================
@@ -1235,63 +1244,146 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
   }
 
   OutlineInfo OI;
-  OI.PostOutlineCB = [=](Function &OutlinedFn) {
-    // Add some known attributes.
-    OutlinedFn.addParamAttr(0, Attribute::NoAlias);
-    OutlinedFn.addParamAttr(1, Attribute::NoAlias);
-    OutlinedFn.addFnAttr(Attribute::NoUnwind);
-    OutlinedFn.addFnAttr(Attribute::NoRecurse);
-
-    assert(OutlinedFn.arg_size() >= 2 &&
-           "Expected at least tid and bounded tid as arguments");
-    unsigned NumCapturedVars =
-        OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
-
-    CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
-    CI->getParent()->setName("omp_parallel");
-    Builder.SetInsertPoint(CI);
-
-    // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
-    Value *ForkCallArgs[] = {
-        Ident, Builder.getInt32(NumCapturedVars),
-        Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};
-
-    SmallVector<Value *, 16> RealArgs;
-    RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
-    if (IfCondition) {
-      Value *Cond = Builder.CreateSExtOrTrunc(IfCondition,
-                                              Type::getInt32Ty(M.getContext()));
-      RealArgs.push_back(Cond);
-    }
-    RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
-
-    // __kmpc_fork_call_if always expects a void ptr as the last argument
-    // If there are no arguments, pass a null pointer.
-    auto PtrTy = Type::getInt8PtrTy(M.getContext());
-    if (IfCondition && NumCapturedVars == 0) {
+  if (Config.isTargetDevice()) {
+    // Generate OpenMP target specific runtime call
+    OI.PostOutlineCB = [=](Function &OutlinedFn) {
+      // Add some known attributes.
+      OutlinedFn.addParamAttr(0, Attribute::NoAlias);
+      OutlinedFn.addParamAttr(1, Attribute::NoAlias);
+      OutlinedFn.addParamAttr(0, Attribute::NoUndef);
+      OutlinedFn.addParamAttr(1, Attribute::NoUndef);
+      OutlinedFn.addFnAttr(Attribute::NoUnwind);
+      OutlinedFn.addFnAttr(Attribute::NoRecurse);
+
+      assert(OutlinedFn.arg_size() >= 2 &&
+             "Expected at least tid and bounded tid as arguments");
+      unsigned NumCapturedVars =
+          OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
+      CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
+      assert(CI && "Expected call instruction to outlined function");
+      CI->getParent()->setName("omp_parallel");
+      // Replace direct call to the outlined function by the call to
+      // __kmpc_parallel_51
+      Builder.SetInsertPoint(CI);
+
+      // Build call __kmpc_parallel_51
+      auto PtrTy = Type::getInt8PtrTy(M.getContext());
       llvm::Value *Void = ConstantPointerNull::get(PtrTy);
-      RealArgs.push_back(Void);
-    }
-    if (IfCondition && RealArgs.back()->getType() != PtrTy)
-      RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
+      // Add alloca for kernel args. Put this instruction at the beginning
+      // of the function.
+      InsertPointTy CurrentIP = Builder.saveIP();
+      Builder.SetInsertPoint(&OuterFn->front(),
+                             OuterFn->front().getFirstInsertionPt());
+      AllocaInst *ArgsAlloca =
+          Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
+      Value *Args = Builder.CreatePointerCast(
+          ArgsAlloca, Type::getInt8PtrTy(M.getContext()));
+      Builder.restoreIP(CurrentIP);
+      // Store captured vars which are used by kmpc_parallel_51
+      if (NumCapturedVars) {
+        for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
+          Value *V = *(CI->arg_begin() + 2 + Idx);
+          Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
+              ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
+          Builder.CreateStore(V, StoreAddress);
+        }
+      }
+      Value *Cond = IfCondition
+                        ? Builder.CreateSExtOrTrunc(
+                              IfCondition, Type::getInt32Ty(M.getContext()))
+                        : Builder.getInt32(1);
+      Value *Parallel51CallArgs[] = {
+          /* identifier*/ Ident,
+          /* global thread num*/ ThreadID,
+          /* if expression */ Cond,
+          NumThreads ? NumThreads : Builder.getInt32(-1),
+          /* Proc bind */ Builder.getInt32(-1),
+          /* outlined function */
+          Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr), Void, Args,
+          Builder.getInt64(NumCapturedVars)};
+
+      SmallVector<Value *, 16> RealArgs;
+      RealArgs.append(std::begin(Parallel51CallArgs),
+                      std::end(Parallel51CallArgs));
+
+      Builder.CreateCall(RTLFn, RealArgs);
+
+      LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
+                        << *Builder.GetInsertBlock()->getParent() << "\n");
+
+      InsertPointTy ExitIP(PRegExitBB, PRegExitBB->end());
+
+      // Initialize the local TID stack location with the argument value.
+      Builder.SetInsertPoint(PrivTID);
+      Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
+      Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);
+
+      CI->eraseFromParent();
+
+      for (Instruction *I : ToBeDeleted)
+        I->eraseFromParent();
+    };
+  } else {
+    // Generate OpenMP host runtime call
+    OI.PostOutlineCB = [=](Function &OutlinedFn) {
+      // Add some known attributes.
+      OutlinedFn.addParamAttr(0, Attribute::NoAlias);
+      OutlinedFn.addParamAttr(1, Attribute::NoAlias);
+      OutlinedFn.addFnAttr(Attribute::NoUnwind);
+      OutlinedFn.addFnAttr(Attribute::NoRecurse);
----------------
jdoerfert wrote:

I think the `NoRecurse` attribute was wrong (already before this patch). I don’t see why the outlined function could not reach the same parallel region again from within a parallel region.

https://github.com/llvm/llvm-project/pull/67000


More information about the llvm-commits mailing list