[llvm] 95eb510 - [OpenMP][IRBuilder] Added if clause to task

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 22 18:39:52 PDT 2022


Author: Shraiysh Vaishay
Date: 2022-09-23T01:39:41Z
New Revision: 95eb5109afa4c5d26a9432e01eb31dbfc16355c5

URL: https://github.com/llvm/llvm-project/commit/95eb5109afa4c5d26a9432e01eb31dbfc16355c5
DIFF: https://github.com/llvm/llvm-project/commit/95eb5109afa4c5d26a9432e01eb31dbfc16355c5.diff

LOG: [OpenMP][IRBuilder] Added if clause to task

This patch adds support for the `if` clause on the `task` construct in the
OpenMP IRBuilder.

Reviewed By: raghavendhra

Differential Revision: https://reviews.llvm.org/D130615

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 369392df3fb7c..ba63353edd5ff 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -647,9 +647,16 @@ class OpenMPIRBuilder {
   /// \param Tied True if the task is tied, false if the task is untied.
   /// \param Final i1 value which is `true` if the task is final, `false` if the
   ///              task is not final.
+  /// \param IfCondition i1 value. If it evaluates to `false`, an undeferred
+  ///                    task is generated, and the encountering thread must
+  ///                    suspend the current task region, for which execution
+  ///                    cannot be resumed until execution of the structured
+  ///                    block that is associated with the generated task is
+  ///                    completed.
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                           bool Tied = true, Value *Final = nullptr);
+                           bool Tied = true, Value *Final = nullptr,
+                           Value *IfCondition = nullptr);
 
   /// Generator for the taskgroup construct
   ///

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 92f6289087e1e..de2ac6e1f9561 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -24,6 +24,7 @@
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/MDBuilder.h"
@@ -1289,7 +1290,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied, Value *Final) {
+                            bool Tied, Value *Final, Value *IfCondition) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1321,7 +1322,8 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final](Function &OutlinedFn) {
+  OI.PostOutlineCB = [this, Ident, Tied, Final,
+                      IfCondition](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1431,6 +1433,44 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                            TaskSize);
     }
 
+    // In the presence of the `if` clause, the following IR is generated:
+    //    ...
+    //    %data = call @__kmpc_omp_task_alloc(...)
+    //    br i1 %if_condition, label %then, label %else
+    //  then:
+    //    call @__kmpc_omp_task(...)
+    //    br label %exit
+    //  else:
+    //    call @__kmpc_omp_task_begin_if0(...)
+    //    call @wrapper_fn(...)
+    //    call @__kmpc_omp_task_complete_if0(...)
+    //    br label %exit
+    //  exit:
+    //    ...
+    if (IfCondition) {
+      // `SplitBlockAndInsertIfThenElse` requires the block to have a
+      // terminator.
+      BasicBlock *NewBasicBlock =
+          splitBB(Builder, /*CreateBranch=*/true, "if.end");
+      Instruction *IfTerminator =
+          NewBasicBlock->getSinglePredecessor()->getTerminator();
+      Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
+      Builder.SetInsertPoint(IfTerminator);
+      SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
+                                    &ElseTI);
+      Builder.SetInsertPoint(ElseTI);
+      Function *TaskBeginFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
+      Function *TaskCompleteFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
+      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
+      if (HasTaskData)
+        Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
+      else
+        Builder.CreateCall(WrapperFunc, {ThreadID});
+      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
+      Builder.SetInsertPoint(ThenTI);
+    }
     // Emit the @__kmpc_omp_task runtime call to spawn the task
     Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
     Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index aa120c1a08878..92a118be4ff60 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5040,6 +5040,71 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  Builder.SetInsertPoint(BodyBB);
+  Value *IfCondition = Builder.CreateICmp(
+      CmpInst::Predicate::ICMP_EQ, F->getArg(0),
+      ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
+  OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+  Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
+                                          /*Tied=*/false, /*Final=*/nullptr,
+                                          IfCondition));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+
+  // The block holding the task_alloc call must end in a branch on IfCondition.
+  BranchInst *IfConditionBranchInst =
+      dyn_cast<BranchInst>(TaskAllocCall->getParent()->getTerminator());
+  ASSERT_NE(IfConditionBranchInst, nullptr);
+  ASSERT_TRUE(IfConditionBranchInst->isConditional());
+  EXPECT_EQ(IfConditionBranchInst->getCondition(), IfCondition);
+
+  // `__kmpc_omp_task` (the deferred path) must be called in the then branch.
+  CallInst *TaskCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+          ->user_back());
+  ASSERT_NE(TaskCall, nullptr);
+  EXPECT_EQ(TaskCall->getParent(), IfConditionBranchInst->getSuccessor(0));
+
+  // The `if`-clause runtime calls must be in the else branch, with the wrapper
+  // call placed immediately between `__kmpc_omp_task_begin_if0` and
+  // `__kmpc_omp_task_complete_if0`, ignoring debug instructions.
+  CallInst *TaskBeginIfCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0)
+          ->user_back());
+  CallInst *TaskCompleteCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0)
+          ->user_back());
+  ASSERT_NE(TaskBeginIfCall, nullptr);
+  ASSERT_NE(TaskCompleteCall, nullptr);
+  Function *WrapperFunc =
+      dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+  ASSERT_NE(WrapperFunc, nullptr);
+  CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
+  ASSERT_NE(WrapperFuncCall, nullptr);
+  EXPECT_EQ(TaskBeginIfCall->getParent(),
+            IfConditionBranchInst->getSuccessor(1));
+  EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
+  EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);


        


More information about the llvm-commits mailing list