[llvm] f62badd - [OpenMP][IRBuilder] Add final clause to task

Shraiysh Vaishay via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 10 11:32:47 PDT 2022


Author: Shraiysh Vaishay
Date: 2022-06-11T00:02:18+05:30
New Revision: f62baddac052b4e746a3ec817b31adafec79427c

URL: https://github.com/llvm/llvm-project/commit/f62baddac052b4e746a3ec817b31adafec79427c
DIFF: https://github.com/llvm/llvm-project/commit/f62baddac052b4e746a3ec817b31adafec79427c.diff

LOG: [OpenMP][IRBuilder] Add final clause to task

This patch adds support for the `final` clause to `OpenMPIRBuilder::createTask`.

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D126626

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index c0f3020201d54..8a6b1c7d412df 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -624,9 +624,11 @@ class OpenMPIRBuilder {
   /// \param AllocaIP The insertion point to be used for alloca instructions.
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param Tied True if the task is tied, false if the task is untied.
+  /// \param Final Optional i1 value: `true` if the task is final, `false` if it
+  ///              is not. If null (the default), the task is not final.
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                           bool Tied = true);
+                           bool Tied = true, Value *Final = nullptr);
 
   /// Functions used to generate reductions. Such functions take two Values
   /// representing LHS and RHS of the reduction, respectively, and a reference

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index b00246760ea6d..9b08a24e14d44 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1256,7 +1256,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied) {
+                            bool Tied, Value *Final) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1285,7 +1285,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, &Loc, Tied](Function &OutlinedFn) {
+  OI.PostOutlineCB = [this, &Loc, Tied, Final](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1330,10 +1330,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     Value *ThreadID = getOrCreateThreadID(Ident);
 
     // Argument - `flags`
-    // If task is tied, then (Flags & 1) == 1.
-    // If task is untied, then (Flags & 1) == 0.
+    // Task is tied iff (Flags & 1) == 1.
+    // Task is untied iff (Flags & 1) == 0.
+    // Task is final iff (Flags & 2) == 2.
+    // Task is not final iff (Flags & 2) == 0.
     // TODO: Handle the other flags.
     Value *Flags = Builder.getInt32(Tied);
+    if (Final) {
+      Value *FinalFlag =
+          Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
+      Flags = Builder.CreateOr(FinalFlag, Flags);
+    }
 
     // Argument - `sizeof_kmp_task_t` (TaskSize)
     // Tasksize refers to the size in bytes of kmp_task_t data structure

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 9abce4443c639..2012e5b44fc9a 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4832,4 +4832,57 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  Builder.SetInsertPoint(BodyBB);
+  Value *Final = Builder.CreateICmp(
+      CmpInst::Predicate::ICMP_EQ, F->getArg(0),
+      ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
+  OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+  Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
+                                          /*Tied=*/false, Final));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check the `flags` argument (operand 2) of __kmpc_omp_task_alloc.
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+  BinaryOperator *OrInst =
+      dyn_cast<BinaryOperator>(TaskAllocCall->getArgOperand(2));
+  ASSERT_NE(OrInst, nullptr);
+  EXPECT_EQ(OrInst->getOpcode(), BinaryOperator::BinaryOps::Or);
+
+  // One operand of the `or` is the tied flag: constant 0, since the task was
+  // created with Tied=false.
+  EXPECT_TRUE(any_of(OrInst->operands(), [](Value *op) {
+    if (ConstantInt *TiedValue = dyn_cast<ConstantInt>(op))
+      return TiedValue->getSExtValue() == 0;
+    return false;
+  }));
+
+  // The other operand is `select Final, 2, 0`, which sets the final bit.
+  EXPECT_TRUE(any_of(OrInst->operands(), [Final](Value *op) {
+    if (SelectInst *Select = dyn_cast<SelectInst>(op)) {
+      ConstantInt *TrueValue = dyn_cast<ConstantInt>(Select->getTrueValue());
+      ConstantInt *FalseValue = dyn_cast<ConstantInt>(Select->getFalseValue());
+      if (!TrueValue || !FalseValue)
+        return false;
+      return Select->getCondition() == Final &&
+             TrueValue->getSExtValue() == 2 && FalseValue->getSExtValue() == 0;
+    }
+    return false;
+  }));
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 } // namespace


        


More information about the llvm-commits mailing list