[clang] [Clang][OpenMP] Allow `num_teams` to accept multiple expressions (PR #99732)

via cfe-commits cfe-commits at lists.llvm.org
Fri Jul 19 19:08:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Shilei Tian (shiltian)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/99732.diff


7 Files Affected:

- (modified) clang/include/clang/AST/OpenMPClause.h (+48-31) 
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1-1) 
- (modified) clang/include/clang/Sema/SemaOpenMP.h (+2-1) 
- (modified) clang/lib/AST/OpenMPClause.cpp (+23-3) 
- (modified) clang/lib/AST/StmtProfile.cpp (+1-2) 
- (modified) clang/lib/Parse/ParseOpenMP.cpp (+1-1) 
- (modified) clang/lib/Sema/SemaOpenMP.cpp (+23-18) 


``````````diff
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 325a1baa44614..2e82ccac28dc8 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -6131,43 +6131,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
 /// \endcode
 /// In this example directive '#pragma omp teams' has clause 'num_teams'
 /// with single expression 'n'.
-class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
-  friend class OMPClauseReader;
+///
+/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
+/// can accept up to three expressions.
+///
+/// \code
+/// #pragma omp target teams ompx_bare num_teams(x, y, z)
+/// \endcode
+class OMPNumTeamsClause final
+    : public OMPVarListClause<OMPNumTeamsClause>,
+      public OMPClauseWithPreInit,
+      private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
+  friend OMPVarListClause;
+  friend TrailingObjects;
 
   /// Location of '('.
   SourceLocation LParenLoc;
 
-  /// NumTeams number.
-  Stmt *NumTeams = nullptr;
+  OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
+                    SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
+                         N),
+        OMPClauseWithPreInit(this) {}
 
-  /// Set the NumTeams number.
-  ///
-  /// \param E NumTeams number.
-  void setNumTeams(Expr *E) { NumTeams = E; }
+  /// Build an empty clause.
+  OMPNumTeamsClause(unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
+                         SourceLocation(), SourceLocation(), N),
+        OMPClauseWithPreInit(this) {}
 
 public:
-  /// Build 'num_teams' clause.
+  /// Creates clause with a list of variables \a VL.
   ///
-  /// \param E Expression associated with this clause.
-  /// \param HelperE Helper Expression associated with this clause.
-  /// \param CaptureRegion Innermost OpenMP region where expressions in this
-  /// clause must be captured.
+  /// \param C AST context.
   /// \param StartLoc Starting location of the clause.
   /// \param LParenLoc Location of '('.
   /// \param EndLoc Ending location of the clause.
-  OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
-                    SourceLocation StartLoc, SourceLocation LParenLoc,
-                    SourceLocation EndLoc)
-      : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
-        OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
-    setPreInitStmt(HelperE, CaptureRegion);
-  }
+  /// \param VL List of references to the variables.
+  /// \param PreInit Statement to set as this clause's pre-init statement.
+  static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
+                                   SourceLocation LParenLoc,
+                                   SourceLocation EndLoc, ArrayRef<Expr *> VL,
+                                   Stmt *PreInit);
 
-  /// Build an empty clause.
-  OMPNumTeamsClause()
-      : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
-                  SourceLocation()),
-        OMPClauseWithPreInit(this) {}
+  /// Creates an empty clause with \a N variables.
+  ///
+  /// \param C AST context.
+  /// \param N The number of variables.
+  static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
 
   /// Sets the location of '('.
   void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -6175,16 +6186,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
   /// Returns the location of '('.
   SourceLocation getLParenLoc() const { return LParenLoc; }
 
-  /// Return NumTeams number.
-  Expr *getNumTeams() { return cast<Expr>(NumTeams); }
+  /// Return NumTeams number. By default, we return the first expression.
+  Expr *getNumTeams() { return getVarRefs().front(); }
 
-  /// Return NumTeams number.
-  Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
+  /// Return NumTeams number. By default, we return the first expression.
+  Expr *getNumTeams() const {
+    return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
+  }
 
-  child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
+  child_range children() {
+    return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
+                       reinterpret_cast<Stmt **>(varlist_end()));
+  }
 
   const_child_range children() const {
-    return const_child_range(&NumTeams, &NumTeams + 1);
+    auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
   }
 
   child_range used_children() {
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index e3c0cb46799f7..beb7b3597c2a8 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
     OMPNumTeamsClause *C) {
+  TRY_TO(VisitOMPClauseList(C));
   TRY_TO(VisitOMPClauseWithPreInit(C));
-  TRY_TO(TraverseStmt(C->getNumTeams()));
   return true;
 }
 
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 54d81f91ffebc..bf5fbc670b05c 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase {
       const OMPVarListLocTy &Locs, bool NoDiagnose = false,
       ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
   /// Called on well-formed 'num_teams' clause.
-  OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+  OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
+                                       SourceLocation StartLoc,
                                        SourceLocation LParenLoc,
                                        SourceLocation EndLoc);
   /// Called on well-formed 'thread_limit' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 042a5df5906ca..ee9e9a0d39a92 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
   return *It;
 }
 
+OMPNumTeamsClause *
+OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
+                          SourceLocation LParenLoc, SourceLocation EndLoc,
+                          ArrayRef<Expr *> VL, Stmt *PreInit) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
+  OMPNumTeamsClause *Clause =
+      new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
+  Clause->setVarRefs(VL);
+  Clause->setPreInitStmt(PreInit);
+  return Clause;
+}
+
+OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
+                                                  unsigned N) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
+  return new (Mem) OMPNumTeamsClause(N);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenMP clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
 }
 
 void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
-  OS << "num_teams(";
-  Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
-  OS << ")";
+  if (!Node->varlist_empty()) {
+    OS << "num_teams";
+    VisitOMPClauseList(Node, '(');
+    OS << ")";
+  }
 }
 
 void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 89d2a422509d8..b782a4ab8367e 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
   VisitOMPClauseList(C);
 }
 void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+  VisitOMPClauseList(C);
   VistOMPClauseWithPreInit(C);
-  if (C->getNumTeams())
-    Profiler->VisitStmt(C->getNumTeams());
 }
 void OMPClauseProfiler::VisitOMPThreadLimitClause(
     const OMPThreadLimitClause *C) {
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index f5b44d210680c..e851bb4ac7fef 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_simdlen:
   case OMPC_collapse:
   case OMPC_ordered:
-  case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
@@ -3279,6 +3278,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_affinity:
   case OMPC_doacross:
   case OMPC_enter:
+  case OMPC_num_teams:
     if (getLangOpts().OpenMP >= 52 && DKind == OMPD_ordered &&
         CKind == OMPC_depend)
       Diag(Tok, diag::warn_omp_depend_in_ordered_deprecated);
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 3bd981cb442aa..a4e0ce730ae05 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -15041,9 +15041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_ordered:
     Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
     break;
-  case OMPC_num_teams:
-    Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
-    break;
   case OMPC_thread_limit:
     Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
@@ -15147,6 +15144,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_affinity:
   case OMPC_when:
   case OMPC_bind:
+  case OMPC_num_teams:
   default:
     llvm_unreachable("Clause is not allowed.");
   }
@@ -17010,6 +17008,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
         static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
         ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_num_teams:
+    Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_if:
   case OMPC_depobj:
   case OMPC_final:
@@ -17040,7 +17041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
   case OMPC_device:
   case OMPC_threads:
   case OMPC_simd:
-  case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
@@ -21703,32 +21703,37 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
   return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
 }
 
-OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
+OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
                                                  SourceLocation StartLoc,
                                                  SourceLocation LParenLoc,
                                                  SourceLocation EndLoc) {
-  Expr *ValExpr = NumTeams;
-  Stmt *HelperValStmt = nullptr;
-
-  // OpenMP [teams Constrcut, Restrictions]
-  // The num_teams expression must evaluate to a positive integer value.
-  if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
-                                 /*StrictlyPositive=*/true))
+  if (VarList.empty())
     return nullptr;
 
   OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
   OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
       DKind, OMPC_num_teams, getLangOpts().OpenMP);
-  if (CaptureRegion != OMPD_unknown &&
-      !SemaRef.CurContext->isDependentContext()) {
+
+  if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+    return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
+                                     EndLoc, VarList, /*PreInit=*/nullptr);
+
+  llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+  SmallVector<Expr *, 3> Vars;
+  for (Expr *ValExpr : VarList) {
+    // OpenMP [teams Construct, Restrictions]
+    // The num_teams expression must evaluate to a positive integer value.
+    if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
+                                   /*StrictlyPositive=*/true))
+      return nullptr;
     ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
-    llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
     ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
-    HelperValStmt = buildPreInits(getASTContext(), Captures);
+    Vars.push_back(ValExpr);
   }
 
-  return new (getASTContext()) OMPNumTeamsClause(
-      ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+  Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+  return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
+                                   Vars, PreInit);
 }
 
 OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,

``````````

</details>


https://github.com/llvm/llvm-project/pull/99732


More information about the cfe-commits mailing list