[llvm] [OpenMPIRBuilder] Added `if` clause for `teams` (PR #69139)

via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 15 20:57:18 PDT 2023


https://github.com/shraiysh created https://github.com/llvm/llvm-project/pull/69139

This patch adds support for the `if` clause on the `teams` construct. The value of the argument must be an integer value. If the value evaluates to a non-zero (true) integer, then the number of teams is determined by the `num_teams` clause (or by the default and ICV if `num_teams` is absent). When the condition evaluates to false (zero), both bounds on the number of teams are set to 1.

This essentially means that
```
upperbound = ifexpr ? upperbound : 1
lowerbound = ifexpr ? lowerbound : 1
```

>From 66edac69884f19acdc614f4596efedf6d0cda546 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Sun, 15 Oct 2023 22:53:01 -0500
Subject: [PATCH] [OpenMPIRBuilder] Added `if` clause for `teams`

This patch adds support for the `if` clause on the `teams` construct. The
value of the argument must be an integer value. If the value evaluates
to a non-zero (true) integer, then the number of teams is determined by
the `num_teams` clause (or by the default and ICV if `num_teams` is
absent). When the condition evaluates to false (zero), both bounds on
the number of teams are set to 1.

This essentially means that
```
upperbound = ifexpr ? upperbound : 1
lowerbound = ifexpr ? lowerbound : 1
```
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  11 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  21 ++-
 .../Frontend/OpenMPIRBuilderTest.cpp          | 146 +++++++++++++++++-
 3 files changed, 165 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9d2adf229b78654..00b4707a7f820d7 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1923,11 +1923,12 @@ class OpenMPIRBuilder {
   /// \param NumTeamsUpper Upper bound on the number of teams.
   /// \param ThreadLimit on the number of threads that may participate in a
   ///        contention group created by each team.
-  InsertPointTy createTeams(const LocationDescription &Loc,
-                            BodyGenCallbackTy BodyGenCB,
-                            Value *NumTeamsLower = nullptr,
-                            Value *NumTeamsUpper = nullptr,
-                            Value *ThreadLimit = nullptr);
+  /// \param IfExpr is the integer argument value of the if clause on the
+  ///        teams construct.
+  InsertPointTy
+  createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
+              Value *NumTeamsLower = nullptr, Value *NumTeamsUpper = nullptr,
+              Value *ThreadLimit = nullptr, Value *IfExpr = nullptr);
 
   /// Generate conditional branch and relevant BasicBlocks through which private
   /// threads copy the 'copyin' variables from Master copy to threadprivate
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a658990f2d45355..5b24e9fe2e0c5bd 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5734,7 +5734,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
-                             Value *NumTeamsUpper, Value *ThreadLimit) {
+                             Value *NumTeamsUpper, Value *ThreadLimit,
+                             Value *IfExpr) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -5773,7 +5774,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
 
   // Push num_teams
-  if (NumTeamsLower || NumTeamsUpper || ThreadLimit) {
+  if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) {
     assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
            "if lowerbound is non-null, then upperbound must also be non-null "
            "for bounds on num_teams");
@@ -5784,6 +5785,22 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
     if (NumTeamsLower == nullptr)
       NumTeamsLower = NumTeamsUpper;
 
+    if (IfExpr) {
+      assert(IfExpr->getType()->isIntegerTy() &&
+             "argument to if clause must be an integer value");
+
+      // upper = ifexpr ? upper : 1
+      if (IfExpr->getType() != Int1)
+        IfExpr = Builder.CreateICmpNE(IfExpr,
+                                      ConstantInt::get(IfExpr->getType(), 0));
+      NumTeamsUpper = Builder.CreateSelect(
+          IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");
+
+      // lower = ifexpr ? lower : 1
+      NumTeamsLower = Builder.CreateSelect(
+          IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
+    }
+
     if (ThreadLimit == nullptr)
       ThreadLimit = Builder.getInt32(0);
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d770facc1730252..97cfc339675f657 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
   };
 
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
-  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
+  Builder.restoreIP(OMPBuilder.createTeams(
+      Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
+      /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
 
   OMPBuilder.finalize();
   Builder.CreateRetVoid();
@@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
   Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
                                            /*NumTeamsLower=*/nullptr,
                                            /*NumTeamsUpper=*/nullptr,
-                                           /*ThreadLimit=*/F->arg_begin()));
+                                           /*ThreadLimit=*/F->arg_begin(),
+                                           /*IfExpr=*/nullptr));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
@@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
   // `num_teams`
   Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
                                            /*NumTeamsLower=*/nullptr,
-                                           /*NumTeamsUpper=*/F->arg_begin()));
+                                           /*NumTeamsUpper=*/F->arg_begin(),
+                                           /*ThreadLimit=*/nullptr,
+                                           /*IfExpr=*/nullptr));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
@@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
   // `F` already has an integer argument, so we use that as upper bound to
   // `num_teams`
   Builder.restoreIP(
-      OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
+      OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
+                             /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
@@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
   };
 
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
-  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
-                                           NumTeamsUpper, ThreadLimit));
+  Builder.restoreIP(OMPBuilder.createTeams(
+      Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
@@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
             OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(),
+                                     Builder.CreateAlloca(Builder.getInt1Ty()));
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  // Only the `if` condition is supplied here; the `num_teams` bounds and
+  // the thread limit are all passed as nullptr.
+  Builder.restoreIP(OMPBuilder.createTeams(
+      Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
+      /*ThreadLimit=*/nullptr, IfExpr));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+  Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2);
+  Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3);
+  Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4);
+
+  // Check the lower_bound
+  ASSERT_NE(NumTeamsLower, nullptr);
+  SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
+  ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
+  EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr);
+  EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0));
+  EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
+
+  // Check the upper_bound
+  ASSERT_NE(NumTeamsUpper, nullptr);
+  SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
+  ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
+  EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr);
+  EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0));
+  EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
+
+  // Check thread_limit
+  EXPECT_EQ(ThreadLimit, Builder.getInt32(0));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> &Builder = OMPBuilder.Builder;
+  Builder.SetInsertPoint(BB);
+
+  Value *IfExpr = Builder.CreateLoad(
+      Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty()));
+  Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5));
+  Value *NumTeamsUpper =
+      Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10));
+  Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20));
+
+  Function *FakeFunction =
+      Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                       GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateCall(FakeFunction, {});
+  };
+
+  // `F` already has an integer argument, so we use that as upper bound to
+  // `num_teams`
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
+                                           NumTeamsUpper, ThreadLimit, IfExpr));
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  ASSERT_FALSE(verifyModule(*M));
+
+  CallInst *PushNumTeamsCallInst =
+      findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
+  ASSERT_NE(PushNumTeamsCallInst, nullptr);
+  Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2);
+  Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3);
+  Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4);
+
+  // Get the boolean conversion of if expression
+  ASSERT_EQ(IfExpr->getNumUses(), 1U);
+  User *IfExprInst = IfExpr->user_back();
+  ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
+  ASSERT_NE(IfExprCmpInst, nullptr);
+  EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE);
+  EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr);
+  EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0));
+
+  // Check the lower_bound
+  ASSERT_NE(NumTeamsLowerArg, nullptr);
+  SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
+  ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
+  EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst);
+  EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower);
+  EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
+
+  // Check the upper_bound
+  ASSERT_NE(NumTeamsUpperArg, nullptr);
+  SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
+  ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
+  EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst);
+  EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper);
+  EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
+
+  // Check thread_limit
+  EXPECT_EQ(ThreadLimitArg, ThreadLimit);
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>



More information about the llvm-commits mailing list