[clang] [Clang][Sema][OpenMP] Allow `num_teams` to accept multiple expressions (PR #99732)

Shilei Tian via cfe-commits cfe-commits at lists.llvm.org
Tue Aug 6 05:31:43 PDT 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/99732

>From 0bd0ff8f905c4e430c7203cede19d7e2179d89ef Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Fri, 19 Jul 2024 22:07:06 -0400
Subject: [PATCH] [Clang][OpenMP] Allow `num_teams` to accept multiple
 expressions

---
 clang/docs/OpenMPSupport.rst                  |  2 +
 clang/docs/ReleaseNotes.rst                   |  3 +
 clang/include/clang/AST/OpenMPClause.h        | 79 +++++++++-------
 clang/include/clang/AST/RecursiveASTVisitor.h |  2 +-
 .../clang/Basic/DiagnosticSemaKinds.td        |  2 +
 clang/include/clang/Sema/SemaOpenMP.h         |  3 +-
 clang/lib/AST/OpenMPClause.cpp                | 26 +++++-
 clang/lib/AST/StmtProfile.cpp                 |  3 +-
 clang/lib/CodeGen/CGOpenMPRuntime.cpp         |  7 +-
 clang/lib/CodeGen/CGStmtOpenMP.cpp            |  2 +-
 clang/lib/Parse/ParseOpenMP.cpp               |  8 +-
 clang/lib/Sema/SemaOpenMP.cpp                 | 91 +++++++++++++++----
 clang/lib/Sema/TreeTransform.h                |  7 +-
 clang/lib/Serialization/ASTReader.cpp         | 10 +-
 clang/lib/Serialization/ASTWriter.cpp         |  4 +-
 clang/test/OpenMP/target_teams_ast_print.cpp  |  4 +
 ...et_teams_distribute_num_teams_messages.cpp | 12 +++
 ...ribute_parallel_for_num_teams_messages.cpp |  5 +
 .../test/OpenMP/teams_num_teams_messages.cpp  |  7 ++
 clang/tools/libclang/CIndex.cpp               |  2 +-
 20 files changed, 211 insertions(+), 68 deletions(-)

diff --git a/clang/docs/OpenMPSupport.rst b/clang/docs/OpenMPSupport.rst
index 0e72b3c6e09c9..3fca3318c1fc4 100644
--- a/clang/docs/OpenMPSupport.rst
+++ b/clang/docs/OpenMPSupport.rst
@@ -363,5 +363,7 @@ considered for standardization. Please post on the
 | device extension             | `'ompx_bare' clause on 'target teams' construct                                   | :good:`prototyped`       | #66844, #70612                                         |
 |                              | <https://www.osti.gov/servlets/purl/2205717>`_                                    |                          |                                                        |
 +------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
+| device extension             | Multi-dim 'num_teams' clause on 'target teams ompx_bare' construct                | :good:`partial`          | #99732, #101407                                        |
++------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
 
 .. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 54decb5dbf230..9d2789b2f94f9 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -331,6 +331,9 @@ Improvements
 ^^^^^^^^^^^^
 - Improve the handling of mapping array-section for struct containing nested structs with user defined mappers
 
+- ``num_teams`` now accepts multiple expressions when it is used in a ``target teams ompx_bare`` construct.
+  This allows the target region to be launched with a multi-dimensional grid on GPUs.
+
 Additional Information
 ======================
 
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 3a8337abfe687..1e830b14727c1 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -6369,43 +6369,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
 /// \endcode
 /// In this example directive '#pragma omp teams' has clause 'num_teams'
 /// with single expression 'n'.
-class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
-  friend class OMPClauseReader;
+///
+/// When the 'ompx_bare' clause is present on a 'target teams' directive, the
+/// 'num_teams' clause can accept up to three expressions.
+///
+/// \code
+/// #pragma omp target teams ompx_bare num_teams(x, y, z)
+/// \endcode
+class OMPNumTeamsClause final
+    : public OMPVarListClause<OMPNumTeamsClause>,
+      public OMPClauseWithPreInit,
+      private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
+  friend OMPVarListClause;
+  friend TrailingObjects;
 
   /// Location of '('.
   SourceLocation LParenLoc;
 
-  /// NumTeams number.
-  Stmt *NumTeams = nullptr;
+  OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
+                    SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
+                         N),
+        OMPClauseWithPreInit(this) {}
 
-  /// Set the NumTeams number.
-  ///
-  /// \param E NumTeams number.
-  void setNumTeams(Expr *E) { NumTeams = E; }
+  /// Build an empty clause.
+  OMPNumTeamsClause(unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
+                         SourceLocation(), SourceLocation(), N),
+        OMPClauseWithPreInit(this) {}
 
 public:
-  /// Build 'num_teams' clause.
+  /// Creates clause with a list of variables \a VL.
   ///
-  /// \param E Expression associated with this clause.
-  /// \param HelperE Helper Expression associated with this clause.
-  /// \param CaptureRegion Innermost OpenMP region where expressions in this
-  /// clause must be captured.
+  /// \param C AST context.
   /// \param StartLoc Starting location of the clause.
   /// \param LParenLoc Location of '('.
   /// \param EndLoc Ending location of the clause.
-  OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
-                    SourceLocation StartLoc, SourceLocation LParenLoc,
-                    SourceLocation EndLoc)
-      : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
-        OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
-    setPreInitStmt(HelperE, CaptureRegion);
-  }
+  /// \param VL List of references to the variables.
+  /// \param PreInit Pre-init statement for the captured expressions (may be null).
+  static OMPNumTeamsClause *
+  Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
+         SourceLocation StartLoc, SourceLocation LParenLoc,
+         SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);
 
-  /// Build an empty clause.
-  OMPNumTeamsClause()
-      : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
-                  SourceLocation()),
-        OMPClauseWithPreInit(this) {}
+  /// Creates an empty clause with \a N variables.
+  ///
+  /// \param C AST context.
+  /// \param N The number of variables.
+  static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
 
   /// Sets the location of '('.
   void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -6413,16 +6424,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
   /// Returns the location of '('.
   SourceLocation getLParenLoc() const { return LParenLoc; }
 
-  /// Return NumTeams number.
-  Expr *getNumTeams() { return cast<Expr>(NumTeams); }
+  /// Return NumTeams expressions.
+  ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }
 
-  /// Return NumTeams number.
-  Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
+  /// Return NumTeams expressions.
+  ArrayRef<Expr *> getNumTeams() const {
+    return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
+  }
 
-  child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
+  child_range children() {
+    return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
+                       reinterpret_cast<Stmt **>(varlist_end()));
+  }
 
   const_child_range children() const {
-    return const_child_range(&NumTeams, &NumTeams + 1);
+    auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
   }
 
   child_range used_children() {
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index f469cd9094558..b505c746cc7dc 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3828,8 +3828,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
     OMPNumTeamsClause *C) {
+  TRY_TO(VisitOMPClauseList(C));
   TRY_TO(VisitOMPClauseWithPreInit(C));
-  TRY_TO(TraverseStmt(C->getNumTeams()));
   return true;
 }
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 581434d33c5c9..02dc0ffe41594 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11639,6 +11639,8 @@ def warn_omp_unterminated_declare_target : Warning<
   InGroup<SourceUsesOpenMP>;
 def err_ompx_bare_no_grid : Error<
   "'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">;
+def err_omp_multi_expr_not_allowed: Error<"only one expression allowed in '%0' clause">;
+def err_ompx_more_than_three_expr_not_allowed: Error<"at most three expressions are allowed in '%0' clause in 'target teams ompx_bare' construct">;
 } // end of OpenMP category
 
 let CategoryName = "Related Result Type Issue" in {
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index e23c0cdd5a089..0ceb5fc07765c 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1259,7 +1259,8 @@ class SemaOpenMP : public SemaBase {
       const OMPVarListLocTy &Locs, bool NoDiagnose = false,
       ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
   /// Called on well-formed 'num_teams' clause.
-  OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+  OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
+                                       SourceLocation StartLoc,
                                        SourceLocation LParenLoc,
                                        SourceLocation EndLoc);
   /// Called on well-formed 'thread_limit' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index ae4fc11404763..6bdc86f616792 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1755,6 +1755,24 @@ OMPContainsClause *OMPContainsClause::CreateEmpty(const ASTContext &C,
   return new (Mem) OMPContainsClause(K);
 }
 
+OMPNumTeamsClause *OMPNumTeamsClause::Create(
+    const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
+    SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
+    ArrayRef<Expr *> VL, Stmt *PreInit) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
+  OMPNumTeamsClause *Clause =
+      new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
+  Clause->setVarRefs(VL);
+  Clause->setPreInitStmt(PreInit, CaptureRegion);
+  return Clause;
+}
+
+OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
+                                                  unsigned N) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
+  return new (Mem) OMPNumTeamsClause(N);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenMP clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -2055,9 +2073,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
 }
 
 void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
-  OS << "num_teams(";
-  Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
-  OS << ")";
+  if (!Node->varlist_empty()) {
+    OS << "num_teams";
+    VisitOMPClauseList(Node, '(');
+    OS << ")";
+  }
 }
 
 void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 09562aadc728a..bf46984e94a85 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -857,9 +857,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
   VisitOMPClauseList(C);
 }
 void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+  VisitOMPClauseList(C);
   VistOMPClauseWithPreInit(C);
-  if (C->getNumTeams())
-    Profiler->VisitStmt(C->getNumTeams());
 }
 void OMPClauseProfiler::VisitOMPThreadLimitClause(
     const OMPThreadLimitClause *C) {
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index de04b7997ca37..fc8fcaed48272 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
             dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
       if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
         if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
-          const Expr *NumTeams =
-              NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
+          const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
+                                     ->getNumTeams()
+                                     .front();
           if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
             if (auto Constant =
                     NumTeams->getIntegerConstantExpr(CGF.getContext()))
@@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
   case OMPD_target_teams_distribute_parallel_for_simd: {
     if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
       const Expr *NumTeams =
-          D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
+          D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
       if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
         if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
           MinTeamsVal = MaxTeamsVal = Constant->getExtValue();
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 96012af7be001..6841ceb3b4154 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
   const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
   const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
   if (NT || TL) {
-    const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
+    const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
     const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
 
     CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index d4dc06dcb9c26..899a56d3d5dfb 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_simdlen:
   case OMPC_collapse:
   case OMPC_ordered:
-  case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
@@ -3332,6 +3331,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
                  ? ParseOpenMPSimpleClause(CKind, WrongDirective)
                  : ParseOpenMPClause(CKind, WrongDirective);
     break;
+  case OMPC_num_teams:
+    if (!FirstClause) {
+      Diag(Tok, diag::err_omp_more_one_clause)
+          << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
+      ErrorFound = true;
+    }
+    [[clang::fallthrough]];
   case OMPC_private:
   case OMPC_firstprivate:
   case OMPC_lastprivate:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 9b60afd9b211f..7d814e6b27a10 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -13034,6 +13034,25 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetUpdateDirective(
                                           Clauses, AStmt);
 }
 
+/// Checks that the first \p ClauseType clause in \p Clauses has at most
+/// \p MaxNum expressions; otherwise emits diagnostic \p Diag and returns false.
+template <typename ClauseType>
+static bool checkNumExprsInClause(SemaBase &SemaRef,
+                                  ArrayRef<OMPClause *> Clauses,
+                                  unsigned MaxNum, unsigned Diag) {
+  auto ClauseItr = llvm::find_if(Clauses, llvm::IsaPred<ClauseType>);
+  if (ClauseItr == Clauses.end())
+    return true;
+  const auto *C = cast<ClauseType>(*ClauseItr);
+  auto VarList = C->getVarRefs();
+  if (VarList.size() > MaxNum) {
+    SemaRef.Diag(VarList[MaxNum]->getBeginLoc(), Diag)
+        << getOpenMPClauseName(C->getClauseKind());
+    return false;
+  }
+  return true;
+}
+
 StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses,
                                                  Stmt *AStmt,
                                                  SourceLocation StartLoc,
@@ -13041,6 +13060,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses,
   if (!AStmt)
     return StmtError();
 
+  if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+    return StmtError();
+
   // Report affected OpenMP target offloading behavior when in HIP lang-mode.
   if (getLangOpts().HIP && (DSAStack->getParentDirective() == OMPD_target))
     Diag(StartLoc, diag::warn_hip_omp_target_directives);
@@ -13815,6 +13838,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
     return StmtError();
   }
 
+  unsigned ClauseMaxNumExprs = HasBareClause ? 3 : 1;
+  unsigned DiagNo = HasBareClause
+                        ? diag::err_ompx_more_than_three_expr_not_allowed
+                        : diag::err_omp_multi_expr_not_allowed;
+  if (!checkNumExprsInClause<OMPNumTeamsClause>(*this, Clauses,
+                                                ClauseMaxNumExprs, DiagNo))
+    return StmtError();
+
   return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
                                          Clauses, AStmt);
 }
@@ -13825,6 +13856,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective(
   if (!AStmt)
     return StmtError();
 
+  if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+    return StmtError();
+
   CapturedStmt *CS =
       setBranchProtectedScope(SemaRef, OMPD_target_teams_distribute, AStmt);
 
@@ -13851,6 +13886,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective(
   if (!AStmt)
     return StmtError();
 
+  if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+    return StmtError();
+
   CapturedStmt *CS = setBranchProtectedScope(
       SemaRef, OMPD_target_teams_distribute_parallel_for, AStmt);
 
@@ -13878,6 +13917,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective(
   if (!AStmt)
     return StmtError();
 
+  if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+    return StmtError();
+
   CapturedStmt *CS = setBranchProtectedScope(
       SemaRef, OMPD_target_teams_distribute_parallel_for_simd, AStmt);
 
@@ -13908,6 +13951,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
   if (!AStmt)
     return StmtError();
 
+  if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+    return StmtError();
+
   CapturedStmt *CS = setBranchProtectedScope(
       SemaRef, OMPD_target_teams_distribute_simd, AStmt);
 
@@ -14955,9 +15002,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_ordered:
     Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
     break;
-  case OMPC_num_teams:
-    Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
-    break;
   case OMPC_thread_limit:
     Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
@@ -15064,6 +15108,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_affinity:
   case OMPC_when:
   case OMPC_bind:
+  case OMPC_num_teams:
   default:
     llvm_unreachable("Clause is not allowed.");
   }
@@ -16927,6 +16972,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
         static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
         ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_num_teams:
+    Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_if:
   case OMPC_depobj:
   case OMPC_final:
@@ -16957,7 +17005,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
   case OMPC_device:
   case OMPC_threads:
   case OMPC_simd:
-  case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
@@ -21834,32 +21881,40 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
   return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
 }
 
-OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
+OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
                                                  SourceLocation StartLoc,
                                                  SourceLocation LParenLoc,
                                                  SourceLocation EndLoc) {
-  Expr *ValExpr = NumTeams;
-  Stmt *HelperValStmt = nullptr;
-
-  // OpenMP [teams Constrcut, Restrictions]
-  // The num_teams expression must evaluate to a positive integer value.
-  if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
-                                 /*StrictlyPositive=*/true))
+  if (VarList.empty())
     return nullptr;
 
+  for (Expr *ValExpr : VarList) {
+    // OpenMP [teams Construct, Restrictions]
+    // The num_teams expression must evaluate to a positive integer value.
+    if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
+                                   /*StrictlyPositive=*/true))
+      return nullptr;
+  }
+
   OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
   OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
       DKind, OMPC_num_teams, getLangOpts().OpenMP);
-  if (CaptureRegion != OMPD_unknown &&
-      !SemaRef.CurContext->isDependentContext()) {
+  if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+    return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc,
+                                     LParenLoc, EndLoc, VarList,
+                                     /*PreInit=*/nullptr);
+
+  llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+  SmallVector<Expr *, 3> Vars;
+  for (Expr *ValExpr : VarList) {
     ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
-    llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
     ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
-    HelperValStmt = buildPreInits(getASTContext(), Captures);
+    Vars.push_back(ValExpr);
   }
 
-  return new (getASTContext()) OMPNumTeamsClause(
-      ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+  Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+  return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc,
+                                   LParenLoc, EndLoc, Vars, PreInit);
 }
 
 OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 7f078d55eeb34..26b3c83c0d136 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2079,10 +2079,11 @@ class TreeTransform {
   ///
   /// By default, performs semantic analysis to build the new statement.
   /// Subclasses may override this routine to provide different behavior.
-  OMPClause *RebuildOMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+  OMPClause *RebuildOMPNumTeamsClause(ArrayRef<Expr *> VarList,
+                                      SourceLocation StartLoc,
                                       SourceLocation LParenLoc,
                                       SourceLocation EndLoc) {
-    return getSema().OpenMP().ActOnOpenMPNumTeamsClause(NumTeams, StartLoc,
+    return getSema().OpenMP().ActOnOpenMPNumTeamsClause(VarList, StartLoc,
                                                         LParenLoc, EndLoc);
   }
 
@@ -11018,7 +11019,7 @@ TreeTransform<Derived>::TransformOMPAllocateClause(OMPAllocateClause *C) {
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) {
-  ExprResult E = getDerived().TransformExpr(C->getNumTeams());
+  ExprResult E = getDerived().TransformExpr(C->getNumTeams().front());
   if (E.isInvalid())
     return nullptr;
   return getDerived().RebuildOMPNumTeamsClause(
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index fccdecef27856..bca7f87540cc1 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -104,6 +104,7 @@
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
@@ -10596,7 +10597,7 @@ OMPClause *OMPClauseReader::readClause() {
     break;
   }
   case llvm::omp::OMPC_num_teams:
-    C = new (Context) OMPNumTeamsClause();
+    C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt());
     break;
   case llvm::omp::OMPC_thread_limit:
     C = new (Context) OMPThreadLimitClause();
@@ -11418,8 +11419,13 @@ void OMPClauseReader::VisitOMPAllocateClause(OMPAllocateClause *C) {
 
 void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
   VisitOMPClauseWithPreInit(C);
-  C->setNumTeams(Record.readSubExpr());
   C->setLParenLoc(Record.readSourceLocation());
+  unsigned NumVars = C->varlist_size();
+  SmallVector<Expr *, 16> Vars;
+  Vars.reserve(NumVars);
+  for (auto _ : llvm::seq<unsigned>(NumVars))
+    Vars.push_back(Record.readSubExpr());
+  C->setVarRefs(Vars);
 }
 
 void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index c167bb8623c13..3082b98f4cd32 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7556,9 +7556,11 @@ void OMPClauseWriter::VisitOMPAllocateClause(OMPAllocateClause *C) {
 }
 
 void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
+  Record.push_back(C->varlist_size());
   VisitOMPClauseWithPreInit(C);
-  Record.AddStmt(C->getNumTeams());
   Record.AddSourceLocation(C->getLParenLoc());
+  for (auto *VE : C->varlist())
+    Record.AddStmt(VE);
 }
 
 void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
diff --git a/clang/test/OpenMP/target_teams_ast_print.cpp b/clang/test/OpenMP/target_teams_ast_print.cpp
index 2ff34e4498bfe..1590a996289f8 100644
--- a/clang/test/OpenMP/target_teams_ast_print.cpp
+++ b/clang/test/OpenMP/target_teams_ast_print.cpp
@@ -115,6 +115,10 @@ int main (int argc, char **argv) {
 // CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1) thread_limit(32)
   a=3;
 // CHECK-NEXT: a = 3;
+#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(32)
+// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1,2,3) thread_limit(32)
+  a=4;
+// CHECK-NEXT: a = 4;
 #pragma omp target teams default(none), private(argc,b) num_teams(f) firstprivate(argv) reduction(| : c, d) reduction(* : e) thread_limit(f+g)
 // CHECK-NEXT: #pragma omp target teams default(none) private(argc,b) num_teams(f) firstprivate(argv) reduction(|: c,d) reduction(*: e) thread_limit(f + g)
   foo();
diff --git a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
index c0a31fa19b282..b489e6a860d67 100644
--- a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
+++ b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
@@ -44,6 +44,12 @@ T tmain(T argc) {
 #pragma omp target teams distribute num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
+  for (int i=0; i<100; i++) foo();
+
+#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}}
+  for (int i=0; i<100; i++) foo();
+
   return 0;
 }
 
@@ -85,5 +91,11 @@ int main(int argc, char **argv) {
 #pragma omp target teams distribute num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
+  for (int i=0; i<100; i++) foo();
+
+#pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}}
+  for (int i=0; i<100; i++) foo();
+
   return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
 }
diff --git a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
index d80b6ea380b93..fa6e8f5887f83 100644
--- a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
+++ b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
@@ -43,6 +43,8 @@ T tmain(T argc) {
   for (int i=0; i<100; i++) foo();
 #pragma omp target teams distribute parallel for num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}}
   for (int i=0; i<100; i++) foo();
+#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
+  for (int i=0; i<100; i++) foo();
 
   return 0;
 }
@@ -85,5 +87,8 @@ int main(int argc, char **argv) {
 #pragma omp target teams distribute parallel for num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
+  for (int i=0; i<100; i++) foo();
+
   return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
 }
diff --git a/clang/test/OpenMP/teams_num_teams_messages.cpp b/clang/test/OpenMP/teams_num_teams_messages.cpp
index 40da396b01069..0cfecc5e11743 100644
--- a/clang/test/OpenMP/teams_num_teams_messages.cpp
+++ b/clang/test/OpenMP/teams_num_teams_messages.cpp
@@ -57,6 +57,9 @@ T tmain(T argc) {
 #pragma omp target
 #pragma omp teams num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}}
   foo();
+#pragma omp target
+#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
+  foo();
 
   return 0;
 }
@@ -111,5 +114,9 @@ int main(int argc, char **argv) {
 #pragma omp teams num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}}
   foo();
 
+#pragma omp target
+#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
+  foo();
+
   return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
 }
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 52d0bd9167b88..48b34e025729c 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2516,8 +2516,8 @@ void OMPClauseEnqueue::VisitOMPDeviceClause(const OMPDeviceClause *C) {
 }
 
 void OMPClauseEnqueue::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+  VisitOMPClauseList(C);
   VisitOMPClauseWithPreInit(C);
-  Visitor->AddStmt(C->getNumTeams());
 }
 
 void OMPClauseEnqueue::VisitOMPThreadLimitClause(



More information about the cfe-commits mailing list