[llvm] [OpenMPIRBuilder] Add ThreadLimit and NumTeamsUpper clauses to teams clause (PR #68364)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 5 16:05:29 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
<details>
<summary>Changes</summary>
This patch adds support for the `thread_limit` clause and for an upper bound on the `num_teams` clause of the teams construct in OpenMP.
It also adds unit tests covering each clause individually and both together.
---
Full diff: https://github.com/llvm/llvm-project/pull/68364.diff
3 Files Affected:
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+6-1)
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+13-1)
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+157)
``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 75da461cfd8d95e..95f30d4a0d5cd5d 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1917,8 +1917,13 @@ class OpenMPIRBuilder {
///
/// \param Loc The location where the teams construct was encountered.
/// \param BodyGenCB Callback that will generate the region code.
+ /// \param NumTeamsUpper Upper bound on the number of teams.
+ /// \param ThreadLimit Limit on the number of threads that may participate in a
+ /// contention group created by each team.
InsertPointTy createTeams(const LocationDescription &Loc,
- BodyGenCallbackTy BodyGenCB);
+ BodyGenCallbackTy BodyGenCB,
+ Value *NumTeamsUpper = nullptr,
+ Value *ThreadLimit = nullptr);
/// Generate conditional branch and relevant BasicBlocks through which private
/// threads copy the 'copyin' variables from Master copy to threadprivate
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 72e1af55fe63f60..b6d79e068f61e6e 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5733,7 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
- BodyGenCallbackTy BodyGenCB) {
+ BodyGenCallbackTy BodyGenCB, Value *NumTeamsUpper,
+ Value *ThreadLimit) {
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -5771,6 +5772,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BasicBlock *AllocaBB =
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+ // Push num_teams and thread_limit (0 for whichever one is not specified)
+ if (NumTeamsUpper || ThreadLimit) {
+ NumTeamsUpper =
+ NumTeamsUpper == nullptr ? Builder.getInt32(0) : NumTeamsUpper;
+ ThreadLimit = ThreadLimit == nullptr ? Builder.getInt32(0) : ThreadLimit;
+ Value *ThreadNum = getOrCreateThreadID(Ident);
+ Builder.CreateCall(
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams),
+ {Ident, ThreadNum, NumTeamsUpper, ThreadLimit});
+ }
+
OutlineInfo OI;
OI.EntryBB = AllocaBB;
OI.ExitBB = ExitBB;
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..fb87389023910c2 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4082,6 +4082,163 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
[](Instruction &inst) { return isa<ICmpInst>(&inst); }));
}
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> &Builder = OMPBuilder.Builder;
+ Builder.SetInsertPoint(BB);
+
+ Function *FakeFunction =
+ Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateCall(FakeFunction, {});
+ };
+
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+ Builder.restoreIP(
+ OMPBuilder.createTeams(Builder, BodyGenCB, nullptr, F->arg_begin()));
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ ASSERT_FALSE(verifyModule(*M));
+
+ Function *PushNumTeamsRTL =
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+ EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+ CallInst *PushNumTeamsCallInst =
+ findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+ ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+ EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0));
+ EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
+
+ // Verifying that the next instruction to execute is kmpc_fork_teams
+ BranchInst *BrInst =
+ dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+ ASSERT_NE(BrInst, nullptr);
+ ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+ Instruction *NextInstruction =
+ BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+ CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+ ASSERT_NE(ForkTeamsCI, nullptr);
+ EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> &Builder = OMPBuilder.Builder;
+ Builder.SetInsertPoint(BB);
+
+ Function *FakeFunction =
+ Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateCall(FakeFunction, {});
+ };
+
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+ Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, F->arg_begin()));
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ ASSERT_FALSE(verifyModule(*M));
+
+ Function *PushNumTeamsRTL =
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+ EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+ CallInst *PushNumTeamsCallInst =
+ findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+ ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+ EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin());
+ EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
+
+ // Verifying that the next instruction to execute is kmpc_fork_teams
+ BranchInst *BrInst =
+ dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+ ASSERT_NE(BrInst, nullptr);
+ ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+ Instruction *NextInstruction =
+ BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+ CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+ ASSERT_NE(ForkTeamsCI, nullptr);
+ EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> &Builder = OMPBuilder.Builder;
+ Builder.SetInsertPoint(BB);
+
+ BasicBlock *CodegenBB = splitBB(Builder, true);
+ Builder.SetInsertPoint(CodegenBB);
+
+ Value *NumTeamsUpper =
+ Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
+ Value *ThreadLimit =
+ Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20), "threadLimit");
+
+ Function *FakeFunction =
+ Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateCall(FakeFunction, {});
+ };
+
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+ Builder.restoreIP(
+ OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsUpper, ThreadLimit));
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ ASSERT_FALSE(verifyModule(*M));
+
+ Function *PushNumTeamsRTL =
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+ EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+ CallInst *PushNumTeamsCallInst =
+ findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+ ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+ EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsUpper);
+ EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), ThreadLimit);
+
+ // Verifying that the next instruction to execute is kmpc_fork_teams
+ BranchInst *BrInst =
+ dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+ ASSERT_NE(BrInst, nullptr);
+ ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+ Instruction *NextInstruction =
+ BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+ CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+ ASSERT_NE(ForkTeamsCI, nullptr);
+ EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>
``````````
</details>
https://github.com/llvm/llvm-project/pull/68364
More information about the llvm-commits mailing list