[llvm] 7604c59 - [OpenMP][IRBuilder] `omp task` support

Shraiysh Vaishay via llvm-commits llvm-commits at lists.llvm.org
Mon May 23 21:52:22 PDT 2022


Author: Shraiysh Vaishay
Date: 2022-05-24T10:22:11+05:30
New Revision: 7604c59bd2336ebb34f28de3e6c883abbdd3f7c7

URL: https://github.com/llvm/llvm-project/commit/7604c59bd2336ebb34f28de3e6c883abbdd3f7c7
DIFF: https://github.com/llvm/llvm-project/commit/7604c59bd2336ebb34f28de3e6c883abbdd3f7c7.diff

LOG: [OpenMP][IRBuilder] `omp task` support

This patch adds basic support for `omp task` to the OpenMPIRBuilder.

The outlined function after code extraction is called from a wrapper function with appropriate arguments. This wrapper function is passed to the runtime calls for task allocation.

This approach is different from the Clang approach - clang directly emits the runtime call to the outlined function. The outlining utility (OutlineInfo) simply outlines the code and generates a function call to the outlined function. After the function has been generated by the outlining utility, there is no easy way to alter the function arguments without meddling with the outlining itself. Hence the wrapper function approach is taken.

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D71989

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7f2805c2e84e..93b1297a7930 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -618,6 +618,16 @@ class OpenMPIRBuilder {
   /// \param Loc The location where the taskyield directive was encountered.
   void createTaskyield(const LocationDescription &Loc);
 
+  /// Generator for `#omp task`. Returns an insertion point in the block that
+  /// follows the task construct, where the caller continues emitting code.
+  /// \param Loc The location where the task construct was encountered.
+  /// \param AllocaIP The insertion point to be used for alloca instructions.
+  /// \param BodyGenCB Callback that will generate the region code.
+  /// \param Tied True if the task is tied, false if the task is untied.
+  InsertPointTy createTask(const LocationDescription &Loc,
+                           InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
+                           bool Tied = true);
+
   /// Functions used to generate reductions. Such functions take two Values
   /// representing LHS and RHS of the reduction, respectively, and a reference
   /// to the value that is updated to refer to the reduction result.

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8a7ce1010364..dfa20d4cabd1 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1253,6 +1253,172 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
   emitTaskyieldImpl(Loc);
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTask(const LocationDescription &Loc,
+                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
+                            bool Tied) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  // Create the ident eagerly: PostOutlineCB is stored by addOutlineInfo() and
+  // invoked later, so it must not hold a reference to the caller-owned `Loc`.
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+  // The current basic block is split into four basic blocks. After outlining,
+  // they will be mapped as follows:
+  // ```
+  // def current_fn() {
+  //   current_basic_block:
+  //     br label %task.exit
+  //   task.exit:
+  //     ; instructions after task
+  // }
+  // def outlined_fn() {
+  //   task.alloca:
+  //     br label %task.body
+  //   task.body:
+  //     ret void
+  // }
+  // ```
+  BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
+  BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
+  BasicBlock *TaskAllocaBB =
+      splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
+
+  OutlineInfo OI;
+  OI.EntryBB = TaskAllocaBB;
+  OI.OuterAllocaBB = AllocaIP.getBlock();
+  OI.ExitBB = TaskExitBB;
+  OI.PostOutlineCB = [this, Ident, Tied](Function &OutlinedFn) {
+    // The input IR here looks like the following-
+    // ```
+    // func @current_fn() {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+    //
+    // This is changed to the following-
+    //
+    // ```
+    // func @current_fn() {
+    //   runtime_call(..., wrapper_fn, ...)
+    // }
+    // func @wrapper_fn(..., %args) {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+
+    // The stale call is replaced by runtime calls that use a wrapper function.
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+
+    // HasTaskData is true iff the outlined region captures any variables.
+    bool HasTaskData = StaleCI->arg_size() > 0;
+    Builder.SetInsertPoint(StaleCI);
+
+    // Gather the arguments for emitting the runtime call for
+    // @__kmpc_omp_task_alloc
+    Function *TaskAllocFn =
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+
+    // Argument - `gtid` (ThreadID); `loc_ref` is the Ident created above.
+    Value *ThreadID = getOrCreateThreadID(Ident);
+
+    // Argument - `flags`
+    // If task is tied, then (Flags & 1) == 1.
+    // If task is untied, then (Flags & 1) == 0.
+    // TODO: Handle the other flags.
+    Value *Flags = Builder.getInt32(Tied);
+
+    // Argument - `sizeof_kmp_task_t` (TaskSize)
+    // Tasksize refers to the size in bytes of kmp_task_t data structure
+    // including private vars accessed in task.
+    Value *TaskSize = Builder.getInt64(0);
+    if (HasTaskData) {
+      AllocaInst *ArgStructAlloca =
+          dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
+      assert(ArgStructAlloca &&
+             "Unable to find the alloca instruction corresponding to arguments "
+             "for extracted function");
+      StructType *ArgStructType =
+          dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+      assert(ArgStructType && "Unable to find struct type corresponding to "
+                              "arguments for extracted function");
+      TaskSize =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+    }
+
+    // TODO: Argument - sizeof_shareds
+
+    // Argument - task_entry (the wrapper function)
+    // If the outlined function has some captured variables (i.e. HasTaskData is
+    // true), then the wrapper function will have an additional argument (the
+    // struct containing captured variables). Otherwise, no such argument will
+    // be present.
+    SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
+    if (HasTaskData)
+      WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
+    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+        (Twine(OutlinedFn.getName()) + ".wrapper").str(),
+        FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
+    Function *WrapperFunc = cast<Function>(WrapperFuncVal.getCallee());
+    PointerType *WrapperFuncBitcastType =
+        FunctionType::get(Builder.getInt32Ty(),
+                          {Builder.getInt32Ty(), Builder.getInt8PtrTy()}, false)
+            ->getPointerTo();
+    Value *WrapperFuncBitcast =
+        ConstantExpr::getBitCast(WrapperFunc, WrapperFuncBitcastType);
+
+    // Emit the @__kmpc_omp_task_alloc runtime call. It returns a pointer to the
+    // area where the captured variables must be copied before the task is run.
+    CallInst *NewTaskData = Builder.CreateCall(
+        TaskAllocFn,
+        {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
+         /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
+         /*task_func=*/WrapperFuncBitcast});
+
+    // Copy the arguments for outlined function
+    if (HasTaskData) {
+      Value *TaskData = StaleCI->getArgOperand(0);
+      Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
+      Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
+                           TaskSize);
+    }
+
+    // Emit the @__kmpc_omp_task runtime call to spawn the task
+    Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+    Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+
+    StaleCI->eraseFromParent();
+
+    // Emit the body for wrapper function
+    BasicBlock *WrapperEntryBB =
+        BasicBlock::Create(M.getContext(), "", WrapperFunc);
+    Builder.SetInsertPoint(WrapperEntryBB);
+    if (HasTaskData)
+      Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
+    else
+      Builder.CreateCall(&OutlinedFn);
+    Builder.CreateRet(Builder.getInt32(0));
+  };
+
+  addOutlineInfo(std::move(OI));
+
+  InsertPointTy TaskAllocaIP =
+      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
+  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
+  BodyGenCB(TaskAllocaIP, TaskBodyIP);
+  // Resume at the start of task.exit, ahead of any instructions already in it.
+  Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
+
+  return Builder.saveIP();
+}
+
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 5f7e3228ae57..aef5992c93ea 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4412,4 +4412,169 @@ TEST_F(OpenMPIRBuilderTest, EmitMapperCall) {
   EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy());
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTask) {
+  // Checks the IR emitted for a tied task whose body captures outer values.
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+  Value *Val128 =
+      Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+                                                "bodygen.alloca128");
+
+    Builder.restoreIP(CodeGenIP);
+    // Loading and storing captured pointer and values
+    Builder.CreateStore(Val128, Local128);
+    Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+                                      "bodygen.load32");
+
+    LoadInst *PrivLoad128 = Builder.CreateLoad(
+        Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+    Value *Cmp = Builder.CreateICmpNE(
+        Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+    Instruction *ThenTerm, *ElseTerm;
+    SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+                                  &ThenTerm, &ElseTerm);
+  };
+
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+      BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+
+  // Verify the Ident argument
+  auto *Ident = dyn_cast<GlobalVariable>(TaskAllocCall->getArgOperand(0));
+  ASSERT_NE(Ident, nullptr);
+  EXPECT_TRUE(Ident->hasInitializer());
+  Constant *Initializer = Ident->getInitializer();
+  GlobalVariable *SrcStrGlob =
+      dyn_cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+  ASSERT_NE(SrcStrGlob, nullptr);
+  auto *SrcSrc = dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+  ASSERT_NE(SrcSrc, nullptr);
+
+  // Verify the gtid (global thread id) argument.
+  CallInst *GTID = dyn_cast<CallInst>(TaskAllocCall->getArgOperand(1));
+  ASSERT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->arg_size(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+
+  // Verify the flags
+  // TODO: Check for others flags. Currently testing only for tiedness.
+  ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
+  ASSERT_NE(Flags, nullptr);
+  EXPECT_EQ(Flags->getSExtValue(), 1);
+
+  // Verify the data size
+  auto *DataSize = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+  ASSERT_NE(DataSize, nullptr);
+  EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer
+
+  // TODO: Verify size of shared clause variables
+
+  // Verify Wrapper function
+  Function *WrapperFunc =
+      dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+  ASSERT_NE(WrapperFunc, nullptr);
+  EXPECT_FALSE(WrapperFunc->isDeclaration());
+  CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin());
+  ASSERT_NE(OutlinedFnCall, nullptr);
+  EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
+  EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1));
+
+  // Verify the presence of `trunc` and `icmp` instructions in Outlined function
+  Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
+  ASSERT_NE(OutlinedFn, nullptr);
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+
+  // Verify the execution of the task
+  CallInst *TaskCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+          ->user_back());
+  ASSERT_NE(TaskCall, nullptr);
+  EXPECT_EQ(TaskCall->getArgOperand(0), Ident);
+  EXPECT_EQ(TaskCall->getArgOperand(1), GTID);
+  EXPECT_EQ(TaskCall->getArgOperand(2), TaskAllocCall);
+
+  // Verify that a memcpy into the task allocation exists (data was copied)
+  EXPECT_TRUE(any_of(TaskAllocCall->users(), [&](User *U) {
+    auto *Copy = dyn_cast<MemCpyInst>(U);
+    return Copy && Copy->getDest() == TaskAllocCall;
+  }));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
+  // A task whose body generates no code must still produce a valid module.
+  using IPTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  auto EmptyBodyCB = [&](IPTy AllocaIP, IPTy CodeGenIP) {};
+
+  BasicBlock *EntryBB = Builder.GetInsertBlock();
+  BasicBlock *TaskBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  IPTy OuterAllocaIP(EntryBB, EntryBB->getFirstInsertionPt());
+  OpenMPIRBuilder::LocationDescription Loc(
+      IPTy(TaskBB, TaskBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTask(Loc, OuterAllocaIP, EmptyBodyCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
+  using IPTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto EmptyBodyCB = [&](IPTy AllocaIP, IPTy CodeGenIP) {};
+  BasicBlock *EntryBB = Builder.GetInsertBlock();
+  BasicBlock *TaskBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  IPTy OuterAllocaIP(EntryBB, EntryBB->getFirstInsertionPt());
+  OpenMPIRBuilder::LocationDescription Loc(
+      IPTy(TaskBB, TaskBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(
+      OMPBuilder.createTask(Loc, OuterAllocaIP, EmptyBodyCB, /*Tied=*/false));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // With Tied == false, bit 0 of the task-alloc `flags` operand must be clear.
+  Function *TaskAllocFn =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+  auto *TaskAllocCall = dyn_cast<CallInst>(TaskAllocFn->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+  auto *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
+  ASSERT_NE(Flags, nullptr);
+  EXPECT_EQ(Flags->getZExtValue() & 1U, 0U);
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 } // namespace


        


More information about the llvm-commits mailing list