[flang-commits] [flang] [OpenMPIRBuilder] Add ThreadLimit and NumTeamsUpper clauses to teams clause (PR #68364)

via flang-commits flang-commits at lists.llvm.org
Thu Oct 5 16:04:17 PDT 2023


https://github.com/shraiysh created https://github.com/llvm/llvm-project/pull/68364

This patch adds support for the `thread_limit` clause and for an upper bound on the `num_teams` clause of the OpenMP teams construct.

Test cases covering both clauses are added.

>From 2d3b34476df53f39d6cc6b7eee02b9d0d33e7a04 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Wed, 4 Oct 2023 15:55:55 -0500
Subject: [PATCH 1/3] [OpenMPIRBuilder] Add clauses to teams

This patch adds the `num_teams` (upper bound) and `thread_limit` clauses to
`OpenMPIRBuilder::createTeams`.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   7 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  14 ++-
 .../Frontend/OpenMPIRBuilderTest.cpp          | 115 ++++++++++++++++++
 3 files changed, 134 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 1699ed3aeab7661..8745b6df9e86330 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1893,8 +1893,13 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc The location where the teams construct was encountered.
   /// \param BodyGenCB Callback that will generate the region code.
+  /// \param NumTeamsUpper Upper bound on the number of teams.
+  /// \param ThreadLimit Limit on the number of threads that may participate in
+  ///        a contention group created by each team.
   InsertPointTy createTeams(const LocationDescription &Loc,
-                            BodyGenCallbackTy BodyGenCB);
+                            BodyGenCallbackTy BodyGenCB,
+                            Value *NumTeamsUpper = nullptr,
+                            Value *ThreadLimit = nullptr);
 
   /// Generate conditional branch and relevant BasicBlocks through which private
   /// threads copy the 'copyin' variables from Master copy to threadprivate
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9c70d384e55db2b..62bc7b3d40ca43a 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5733,7 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
 
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
-                             BodyGenCallbackTy BodyGenCB) {
+                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsUpper,
+                             Value *ThreadLimit) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -5771,6 +5772,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   BasicBlock *AllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
 
+  // Push num_teams and thread_limit (0 is passed for whichever is unset)
+  if (NumTeamsUpper || ThreadLimit) {
+    NumTeamsUpper =
+        NumTeamsUpper == nullptr ? Builder.getInt32(0) : NumTeamsUpper;
+    ThreadLimit = ThreadLimit == nullptr ? Builder.getInt32(0) : ThreadLimit;
+    Value *ThreadNum = getOrCreateThreadID(Ident);
+    Builder.CreateCall(
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams),
+        {Ident, ThreadNum, NumTeamsUpper, ThreadLimit});
+  }
+
   OutlineInfo OI;
   OI.EntryBB = AllocaBB;
   OI.ExitBB = ExitBB;
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..88b7e4b397e46de 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4082,6 +4082,121 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
                      [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(
+      OMPBuilder.createTeams(Builder, BodyGenCB, nullptr, F->arg_begin()));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0));
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, F->arg_begin()));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  // M->print(dbgs(), nullptr);
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  BasicBlock *CodegenBB = splitBB(Builder, true);
+  Builder.SetInsertPoint(CodegenBB);
+
+  Value *NumTeamsUpper =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
+  Value *ThreadLimit =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20), "threadLimit");
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(
+      OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsUpper, ThreadLimit));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  // M->print(dbgs(), nullptr);
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>

>From 8393f14fb9a5b9f2cf2b8745cebe3d0b702c9541 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Thu, 5 Oct 2023 17:57:11 -0500
Subject: [PATCH 2/3] Add testcases

---
 .../Frontend/OpenMPIRBuilderTest.cpp          | 47 +++++++++++++++++--
 1 file changed, 44 insertions(+), 3 deletions(-)

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 88b7e4b397e46de..496c60ba38605ce 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4157,7 +4157,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
 
   ASSERT_FALSE(verifyModule(*M));
 
-  // M->print(dbgs(), nullptr);
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin());
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
 }
 
 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
@@ -4194,8 +4215,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
 
   ASSERT_FALSE(verifyModule(*M));
 
-  // M->print(dbgs(), nullptr);
-}
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsUpper);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), ThreadLimit);
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));}
 
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.

>From 9f368708a33a87dd9fea8944082c54ac37dd85c9 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Thu, 5 Oct 2023 17:58:02 -0500
Subject: [PATCH 3/3] Formatting

---
 llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 496c60ba38605ce..fb87389023910c2 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4236,7 +4236,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
   CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
   ASSERT_NE(ForkTeamsCI, nullptr);
   EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
-            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));}
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
 
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.



More information about the flang-commits mailing list