[clang] [OpenMPIRBuilder] Add ThreadLimit and NumTeamsUpper clauses to teams clause (PR #68364)

via cfe-commits cfe-commits at lists.llvm.org
Mon Oct 9 09:48:57 PDT 2023


https://github.com/shraiysh updated https://github.com/llvm/llvm-project/pull/68364

>From 2d3b34476df53f39d6cc6b7eee02b9d0d33e7a04 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Wed, 4 Oct 2023 15:55:55 -0500
Subject: [PATCH 1/5] [OpenMPIRBuilder] Add clauses to teams

This patch adds support for the `num_teams` (upper bound) and `thread_limit`
clauses of the `teams` construct to `OpenMPIRBuilder::createTeams`.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   7 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  14 ++-
 .../Frontend/OpenMPIRBuilderTest.cpp          | 115 ++++++++++++++++++
 3 files changed, 134 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 1699ed3aeab7661..8745b6df9e86330 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1893,8 +1893,13 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc The location where the teams construct was encountered.
   /// \param BodyGenCB Callback that will generate the region code.
+  /// \param NumTeamsUpper Upper bound on the number of teams.
+  /// \param ThreadLimit on the number of threads that may participate in a
+  ///        contention group created by each team.
   InsertPointTy createTeams(const LocationDescription &Loc,
-                            BodyGenCallbackTy BodyGenCB);
+                            BodyGenCallbackTy BodyGenCB,
+                            Value *NumTeamsUpper = nullptr,
+                            Value *ThreadLimit = nullptr);
 
   /// Generate conditional branch and relevant BasicBlocks through which private
   /// threads copy the 'copyin' variables from Master copy to threadprivate
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9c70d384e55db2b..62bc7b3d40ca43a 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5733,7 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
 
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
-                             BodyGenCallbackTy BodyGenCB) {
+                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsUpper,
+                             Value *ThreadLimit) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -5771,6 +5772,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   BasicBlock *AllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
 
+  // Push num_teams
+  if (NumTeamsUpper || ThreadLimit) {
+    NumTeamsUpper =
+        NumTeamsUpper == nullptr ? Builder.getInt32(0) : NumTeamsUpper;
+    ThreadLimit = ThreadLimit == nullptr ? Builder.getInt32(0) : ThreadLimit;
+    Value *ThreadNum = getOrCreateThreadID(Ident);
+    Builder.CreateCall(
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams),
+        {Ident, ThreadNum, NumTeamsUpper, ThreadLimit});
+  }
+
   OutlineInfo OI;
   OI.EntryBB = AllocaBB;
   OI.ExitBB = ExitBB;
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..88b7e4b397e46de 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4082,6 +4082,121 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
                      [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(
+      OMPBuilder.createTeams(Builder, BodyGenCB, nullptr, F->arg_begin()));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0));
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, F->arg_begin()));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  // M->print(dbgs(), nullptr);
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  BasicBlock *CodegenBB = splitBB(Builder, true);
+  Builder.SetInsertPoint(CodegenBB);
+
+  Value *NumTeamsUpper =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
+  Value *ThreadLimit =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20), "threadLimit");
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(
+      OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsUpper, ThreadLimit));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  // M->print(dbgs(), nullptr);
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>

>From 8393f14fb9a5b9f2cf2b8745cebe3d0b702c9541 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Thu, 5 Oct 2023 17:57:11 -0500
Subject: [PATCH 2/5] Add testcases

---
 .../Frontend/OpenMPIRBuilderTest.cpp          | 47 +++++++++++++++++--
 1 file changed, 44 insertions(+), 3 deletions(-)

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 88b7e4b397e46de..496c60ba38605ce 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4157,7 +4157,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
 
   ASSERT_FALSE(verifyModule(*M));
 
-  // M->print(dbgs(), nullptr);
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin());
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
 }
 
 TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
@@ -4194,8 +4215,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
 
   ASSERT_FALSE(verifyModule(*M));
 
-  // M->print(dbgs(), nullptr);
-}
+  Function *PushNumTeamsRTL =
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
+  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsUpper);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), ThreadLimit);
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));}
 
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.

>From 9f368708a33a87dd9fea8944082c54ac37dd85c9 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Thu, 5 Oct 2023 17:58:02 -0500
Subject: [PATCH 3/5] Formatting

---
 llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 496c60ba38605ce..fb87389023910c2 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4236,7 +4236,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
   CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
   ASSERT_NE(ForkTeamsCI, nullptr);
   EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
-            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));}
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
 
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.

>From 25870b3f64b9a07452e9207c30d15dc960c69f18 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Mon, 9 Oct 2023 09:10:18 -0500
Subject: [PATCH 4/5] Address comments

---
 llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index bd38f5bece16df1..fef592718f79c95 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4091,9 +4091,10 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
     Builder.CreateCall(FakeFunction, {});
   };
 
-  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
-  Builder.restoreIP(
-      OMPBuilder.createTeams(Builder, BodyGenCB, nullptr, F->arg_begin()));
+  // `F` has an argument - an integer, so we use that as the thread limit.
+  Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
+                                           /*NumTeamsUpper=*/nullptr,
+                                           /*ThreadLimit=*/F->arg_begin()));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
@@ -4141,8 +4142,10 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
     Builder.CreateCall(FakeFunction, {});
   };
 
-  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
-  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, F->arg_begin()));
+  // `F` already has an integer argument, so we use that as upper bound to
+  // `num_teams`
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
+                                           /*NumTeamsUpper=*/F->arg_begin()));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
@@ -4184,6 +4187,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
   BasicBlock *CodegenBB = splitBB(Builder, true);
   Builder.SetInsertPoint(CodegenBB);
 
+  // Generate values for `num_teams` and `thread_limit` using the first argument
+  // of the testing function.
   Value *NumTeamsUpper =
       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
   Value *ThreadLimit =

>From c13e639039cd57d5ce9e97c1075dc21f662dbee8 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Mon, 9 Oct 2023 11:45:56 -0500
Subject: [PATCH 5/5] Address comment about lower bound

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  4 +
 .../include/llvm/Frontend/OpenMP/OMPKinds.def |  1 +
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 26 ++++--
 .../Frontend/OpenMPIRBuilderTest.cpp          | 92 ++++++++++++++-----
 4 files changed, 93 insertions(+), 30 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ba679e2998eb413..9d2adf229b78654 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1917,11 +1917,15 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc The location where the teams construct was encountered.
   /// \param BodyGenCB Callback that will generate the region code.
+  /// \param NumTeamsLower Lower bound on the number of teams. If this is
+  ///        nullptr, the lower bound is taken to be equal to the upper bound.
+  ///        If this is non-null, then NumTeamsUpper must also be non-null.
   /// \param NumTeamsUpper Upper bound on the number of teams.
   /// \param ThreadLimit on the number of threads that may participate in a
   ///        contention group created by each team.
   InsertPointTy createTeams(const LocationDescription &Loc,
                             BodyGenCallbackTy BodyGenCB,
+                            Value *NumTeamsLower = nullptr,
                             Value *NumTeamsUpper = nullptr,
                             Value *ThreadLimit = nullptr);
 
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index 176b883fe68f7ad..4823c4cc6b833ec 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -387,6 +387,7 @@ __OMP_RTL(__kmpc_cancellationpoint, false, Int32, IdentPtr, Int32, Int32)
 
 __OMP_RTL(__kmpc_fork_teams, true, Void, IdentPtr, Int32, ParallelTaskPtr)
 __OMP_RTL(__kmpc_push_num_teams, false, Void, IdentPtr, Int32, Int32, Int32)
+__OMP_RTL(__kmpc_push_num_teams_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32)
 __OMP_RTL(__kmpc_set_thread_limit, false, Void, IdentPtr, Int32, Int32)
 
 __OMP_RTL(__kmpc_copyprivate, false, Void, IdentPtr, Int32, SizeTy, VoidPtr,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 315967dc2b2a6f6..a658990f2d45355 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5733,8 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
 
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
-                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsUpper,
-                             Value *ThreadLimit) {
+                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
+                             Value *NumTeamsUpper, Value *ThreadLimit) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -5773,14 +5773,24 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
 
   // Push num_teams
-  if (NumTeamsUpper || ThreadLimit) {
-    NumTeamsUpper =
-        NumTeamsUpper == nullptr ? Builder.getInt32(0) : NumTeamsUpper;
-    ThreadLimit = ThreadLimit == nullptr ? Builder.getInt32(0) : ThreadLimit;
+  if (NumTeamsLower || NumTeamsUpper || ThreadLimit) {
+    assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
+           "if lowerbound is non-null, then upperbound must also be non-null "
+           "for bounds on num_teams");
+
+    if (NumTeamsUpper == nullptr)
+      NumTeamsUpper = Builder.getInt32(0);
+
+    if (NumTeamsLower == nullptr)
+      NumTeamsLower = NumTeamsUpper;
+
+    if (ThreadLimit == nullptr)
+      ThreadLimit = Builder.getInt32(0);
+
     Value *ThreadNum = getOrCreateThreadID(Ident);
     Builder.CreateCall(
-        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams),
-        {Ident, ThreadNum, NumTeamsUpper, ThreadLimit});
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
+        {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
   }
   // Generate the body of teams.
   InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fef592718f79c95..37400a9be0d14a3 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4093,6 +4093,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
 
   // `F` has an argument - an integer, so we use that as the thread limit.
   Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
+                                           /*NumTeamsLower=*/nullptr,
                                            /*NumTeamsUpper=*/nullptr,
                                            /*ThreadLimit=*/F->arg_begin()));
 
@@ -4101,16 +4102,13 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
 
   ASSERT_FALSE(verifyModule(*M));
 
-  Function *PushNumTeamsRTL =
-      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
-  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
-
   CallInst *PushNumTeamsCallInst =
-      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
   ASSERT_NE(PushNumTeamsCallInst, nullptr);
 
   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0));
-  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), &*F->arg_begin());
 
   // Verifying that the next instruction to execute is kmpc_fork_teams
   BranchInst *BrInst =
@@ -4125,7 +4123,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
             OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
 }
 
-TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
@@ -4145,6 +4143,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
   // `F` already has an integer argument, so we use that as upper bound to
   // `num_teams`
   Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
+                                           /*NumTeamsLower=*/nullptr,
                                            /*NumTeamsUpper=*/F->arg_begin()));
 
   Builder.CreateRetVoid();
@@ -4152,16 +4151,66 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) {
 
   ASSERT_FALSE(verifyModule(*M));
 
-  Function *PushNumTeamsRTL =
-      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
-  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
-
   CallInst *PushNumTeamsCallInst =
-      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
   ASSERT_NE(PushNumTeamsCallInst, nullptr);
 
   EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin());
-  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0));
+
+  // Verifying that the next instruction to execute is kmpc_fork_teams
+  BranchInst *BrInst =
+      dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
+  ASSERT_NE(BrInst, nullptr);
+  ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
+  Instruction *NextInstruction =
+      BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
+  CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
+  ASSERT_NE(ForkTeamsCI, nullptr);
+  EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
+            OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  Value *NumTeamsLower =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower");
+  Value *NumTeamsUpper =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  // Use the lower and upper bounds computed above from `F`'s integer argument
+  // as the bounds for `num_teams`.
+  Builder.restoreIP(
+      OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0));
 
   // Verifying that the next instruction to execute is kmpc_fork_teams
   BranchInst *BrInst =
@@ -4189,6 +4238,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
 
   // Generate values for `num_teams` and `thread_limit` using the first argument
   // of the testing function.
+  Value *NumTeamsLower =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower");
   Value *NumTeamsUpper =
       Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
   Value *ThreadLimit =
@@ -4204,24 +4255,21 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
   };
 
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
-  Builder.restoreIP(
-      OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsUpper, ThreadLimit));
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
+                                           NumTeamsUpper, ThreadLimit));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
 
   ASSERT_FALSE(verifyModule(*M));
 
-  Function *PushNumTeamsRTL =
-      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams);
-  EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U);
-
   CallInst *PushNumTeamsCallInst =
-      findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder);
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
   ASSERT_NE(PushNumTeamsCallInst, nullptr);
 
-  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsUpper);
-  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), ThreadLimit);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper);
+  EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), ThreadLimit);
 
   // Verifying that the next instruction to execute is kmpc_fork_teams
   BranchInst *BrInst =



More information about the cfe-commits mailing list