[flang-commits] [flang] [OpenMPIRBuilder] Add ThreadLimit and NumTeamsUpper clauses to teams clause (PR #68364)

via flang-commits flang-commits at lists.llvm.org
Thu Oct 5 16:05:28 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

<details>
<summary>Changes</summary>

This patch adds support for the `thread_limit` clause and for an upper bound on the `num_teams` clause of the `teams` construct in OpenMP.

Unit tests covering both clauses are included.

---
Full diff: https://github.com/llvm/llvm-project/pull/68364.diff


3 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+6-1) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+13-1) 
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+157) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 75da461cfd8d95e..95f30d4a0d5cd5d 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1917,8 +1917,13 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc The location where the teams construct was encountered.
   /// \param BodyGenCB Callback that will generate the region code.
+  /// \param NumTeamsUpper Upper bound on the number of teams.
+  /// \param ThreadLimit Limit on the number of threads that may participate in
+  ///        a contention group created by each team.
   InsertPointTy createTeams(const LocationDescription &Loc,
-                            BodyGenCallbackTy BodyGenCB);
+                            BodyGenCallbackTy BodyGenCB,
+                            Value *NumTeamsUpper = nullptr,
+                            Value *ThreadLimit = nullptr);
 
   /// Generate conditional branch and relevant BasicBlocks through which private
   /// threads copy the 'copyin' variables from Master copy to threadprivate
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 72e1af55fe63f60..b6d79e068f61e6e 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5733,7 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
 
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
-                             BodyGenCallbackTy BodyGenCB) {
+                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsUpper,
+                             Value *ThreadLimit) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -5771,6 +5772,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   BasicBlock *AllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
 
+  // Push num_teams
+  if (NumTeamsUpper || ThreadLimit) {
+    NumTeamsUpper =
+        NumTeamsUpper == nullptr ? Builder.getInt32(0) : NumTeamsUpper;
+    ThreadLimit = ThreadLimit == nullptr ? Builder.getInt32(0) : ThreadLimit;
+    Value *ThreadNum = getOrCreateThreadID(Ident);
+    Builder.CreateCall(
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams),
+        {Ident, ThreadNum, NumTeamsUpper, ThreadLimit});
+  }
+
   OutlineInfo OI;
   OI.EntryBB = AllocaBB;
   OI.ExitBB = ExitBB;
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..fb87389023910c2 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4082,6 +4082,163 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
                      [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(
+      OMPBuilder.createTeams(Builder, BodyGenCB, nullptr, F->arg_begin()));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0));
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, F->arg_begin()));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin());
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  BasicBlock *CodegenBB = splitBB(Builder, true);
+  Builder.SetInsertPoint(CodegenBB);
+
+  Value *NumTeamsUpper =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
+  Value *ThreadLimit =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20), "threadLimit");
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(
+      OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsUpper, ThreadLimit));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsUpper);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), ThreadLimit);
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>

``````````

</details>


https://github.com/llvm/llvm-project/pull/68364


More information about the flang-commits mailing list