[clang] 1c26992 - [Clang][Sema][OpenMP] Allow `thread_limit` to accept multiple expressions (#102715)

via cfe-commits cfe-commits at lists.llvm.org
Sat Aug 10 06:55:02 PDT 2024


Author: Shilei Tian
Date: 2024-08-10T09:54:58-04:00
New Revision: 1c269929d03e4a664a1f05d494b8fefe291ef8c0

URL: https://github.com/llvm/llvm-project/commit/1c269929d03e4a664a1f05d494b8fefe291ef8c0
DIFF: https://github.com/llvm/llvm-project/commit/1c269929d03e4a664a1f05d494b8fefe291ef8c0.diff

LOG: [Clang][Sema][OpenMP] Allow `thread_limit` to accept multiple expressions (#102715)

Added: 
    

Modified: 
    clang/docs/OpenMPSupport.rst
    clang/docs/ReleaseNotes.rst
    clang/include/clang/AST/OpenMPClause.h
    clang/include/clang/AST/RecursiveASTVisitor.h
    clang/include/clang/Sema/SemaOpenMP.h
    clang/lib/AST/OpenMPClause.cpp
    clang/lib/AST/StmtProfile.cpp
    clang/lib/CodeGen/CGOpenMPRuntime.cpp
    clang/lib/CodeGen/CGStmtOpenMP.cpp
    clang/lib/Parse/ParseOpenMP.cpp
    clang/lib/Sema/SemaOpenMP.cpp
    clang/lib/Sema/TreeTransform.h
    clang/lib/Serialization/ASTReader.cpp
    clang/lib/Serialization/ASTWriter.cpp
    clang/test/OpenMP/target_teams_ast_print.cpp
    clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
    clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
    clang/test/OpenMP/teams_num_teams_messages.cpp
    clang/tools/libclang/CIndex.cpp

Removed: 
    


################################################################################
diff  --git a/clang/docs/OpenMPSupport.rst b/clang/docs/OpenMPSupport.rst
index 3fc74cdd07f71c..cdbd69520e5bb5 100644
--- a/clang/docs/OpenMPSupport.rst
+++ b/clang/docs/OpenMPSupport.rst
@@ -363,7 +363,8 @@ considered for standardization. Please post on the
 | device extension             | `'ompx_bare' clause on 'target teams' construct                                   | :good:`prototyped`       | #66844, #70612                                         |
 |                              | <https://www.osti.gov/servlets/purl/2205717>`_                                    |                          |                                                        |
 +------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
-| device extension             | Multi-dim 'num_teams' clause on 'target teams ompx_bare' construct                | :good:`partial`          | #99732, #101407                                        |
+| device extension             | Multi-dim 'num_teams' and 'thread_limit' clauses on 'target teams ompx_bare'      | :good:`partial`          | #99732, #101407, #102715                               |
+|                              | construct                                                                         |                          |                                                        |
 +------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
 
 .. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35

diff  --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 351b41b1c0c588..602f3edaf121cb 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -360,8 +360,9 @@ Improvements
 ^^^^^^^^^^^^
 - Improve the handling of mapping array-section for struct containing nested structs with user defined mappers
 
-- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
-  This allows the target region to be launched with multi-dim grid on GPUs.
+- ``num_teams`` and ``thread_limit`` now accept multiple expressions when they
+  are used in a ``target teams ompx_bare`` construct. This allows the target
+  region to be launched with a multi-dimensional grid on GPUs.
 
 Additional Information
 ======================

diff  --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 1e830b14727c19..c1b9e0dbafb6c3 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -6462,44 +6462,55 @@ class OMPNumTeamsClause final
 /// \endcode
 /// In this example directive '#pragma omp teams' has clause 'thread_limit'
 /// with single expression 'n'.
-class OMPThreadLimitClause : public OMPClause, public OMPClauseWithPreInit {
-  friend class OMPClauseReader;
+///
+/// When 'ompx_bare' clause exists on a 'target' directive, 'thread_limit'
+/// clause can accept up to three expressions.
+///
+/// \code
+/// #pragma omp target teams ompx_bare thread_limit(x, y, z)
+/// \endcode
+class OMPThreadLimitClause final
+    : public OMPVarListClause<OMPThreadLimitClause>,
+      public OMPClauseWithPreInit,
+      private llvm::TrailingObjects<OMPThreadLimitClause, Expr *> {
+  friend OMPVarListClause;
+  friend TrailingObjects;
 
   /// Location of '('.
   SourceLocation LParenLoc;
 
-  /// ThreadLimit number.
-  Stmt *ThreadLimit = nullptr;
+  OMPThreadLimitClause(const ASTContext &C, SourceLocation StartLoc,
+                       SourceLocation LParenLoc, SourceLocation EndLoc,
+                       unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_thread_limit, StartLoc, LParenLoc,
+                         EndLoc, N),
+        OMPClauseWithPreInit(this) {}
 
-  /// Set the ThreadLimit number.
-  ///
-  /// \param E ThreadLimit number.
-  void setThreadLimit(Expr *E) { ThreadLimit = E; }
+  /// Build an empty clause.
+  OMPThreadLimitClause(unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
+                         SourceLocation(), SourceLocation(), N),
+        OMPClauseWithPreInit(this) {}
 
 public:
-  /// Build 'thread_limit' clause.
+  /// Creates clause with a list of variables \a VL.
   ///
-  /// \param E Expression associated with this clause.
-  /// \param HelperE Helper Expression associated with this clause.
-  /// \param CaptureRegion Innermost OpenMP region where expressions in this
-  /// clause must be captured.
+  /// \param C AST context.
   /// \param StartLoc Starting location of the clause.
   /// \param LParenLoc Location of '('.
   /// \param EndLoc Ending location of the clause.
-  OMPThreadLimitClause(Expr *E, Stmt *HelperE,
-                       OpenMPDirectiveKind CaptureRegion,
-                       SourceLocation StartLoc, SourceLocation LParenLoc,
-                       SourceLocation EndLoc)
-      : OMPClause(llvm::omp::OMPC_thread_limit, StartLoc, EndLoc),
-        OMPClauseWithPreInit(this), LParenLoc(LParenLoc), ThreadLimit(E) {
-    setPreInitStmt(HelperE, CaptureRegion);
-  }
+  /// \param VL List of references to the variables.
+  /// \param PreInit
+  static OMPThreadLimitClause *
+  Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
+         SourceLocation StartLoc, SourceLocation LParenLoc,
+         SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);
 
-  /// Build an empty clause.
-  OMPThreadLimitClause()
-      : OMPClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
-                  SourceLocation()),
-        OMPClauseWithPreInit(this) {}
+  /// Creates an empty clause with \a N variables.
+  ///
+  /// \param C AST context.
+  /// \param N The number of variables.
+  static OMPThreadLimitClause *CreateEmpty(const ASTContext &C, unsigned N);
 
   /// Sets the location of '('.
   void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -6507,16 +6518,22 @@ class OMPThreadLimitClause : public OMPClause, public OMPClauseWithPreInit {
   /// Returns the location of '('.
   SourceLocation getLParenLoc() const { return LParenLoc; }
 
-  /// Return ThreadLimit number.
-  Expr *getThreadLimit() { return cast<Expr>(ThreadLimit); }
+  /// Return ThreadLimit expressions.
+  ArrayRef<Expr *> getThreadLimit() { return getVarRefs(); }
 
-  /// Return ThreadLimit number.
-  Expr *getThreadLimit() const { return cast<Expr>(ThreadLimit); }
+  /// Return ThreadLimit expressions.
+  ArrayRef<Expr *> getThreadLimit() const {
+    return const_cast<OMPThreadLimitClause *>(this)->getThreadLimit();
+  }
 
-  child_range children() { return child_range(&ThreadLimit, &ThreadLimit + 1); }
+  child_range children() {
+    return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
+                       reinterpret_cast<Stmt **>(varlist_end()));
+  }
 
   const_child_range children() const {
-    return const_child_range(&ThreadLimit, &ThreadLimit + 1);
+    auto Children = const_cast<OMPThreadLimitClause *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
   }
 
   child_range used_children() {

diff  --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index b505c746cc7dc2..2b35997bd539ac 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3836,8 +3836,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPThreadLimitClause(
     OMPThreadLimitClause *C) {
+  TRY_TO(VisitOMPClauseList(C));
   TRY_TO(VisitOMPClauseWithPreInit(C));
-  TRY_TO(TraverseStmt(C->getThreadLimit()));
   return true;
 }
 

diff  --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 0ceb5fc07765c4..e55731212c4a41 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1264,7 +1264,7 @@ class SemaOpenMP : public SemaBase {
                                        SourceLocation LParenLoc,
                                        SourceLocation EndLoc);
   /// Called on well-formed 'thread_limit' clause.
-  OMPClause *ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
+  OMPClause *ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList,
                                           SourceLocation StartLoc,
                                           SourceLocation LParenLoc,
                                           SourceLocation EndLoc);

diff  --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 6bdc86f6167920..7e73c076239410 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1773,6 +1773,24 @@ OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
   return new (Mem) OMPNumTeamsClause(N);
 }
 
+OMPThreadLimitClause *OMPThreadLimitClause::Create(
+    const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
+    SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
+    ArrayRef<Expr *> VL, Stmt *PreInit) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
+  OMPThreadLimitClause *Clause =
+      new (Mem) OMPThreadLimitClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
+  Clause->setVarRefs(VL);
+  Clause->setPreInitStmt(PreInit, CaptureRegion);
+  return Clause;
+}
+
+OMPThreadLimitClause *OMPThreadLimitClause::CreateEmpty(const ASTContext &C,
+                                                        unsigned N) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
+  return new (Mem) OMPThreadLimitClause(N);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenMP clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -2081,9 +2099,11 @@ void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
 }
 
 void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
-  OS << "thread_limit(";
-  Node->getThreadLimit()->printPretty(OS, nullptr, Policy, 0);
-  OS << ")";
+  if (!Node->varlist_empty()) {
+    OS << "thread_limit";
+    VisitOMPClauseList(Node, '(');
+    OS << ")";
+  }
 }
 
 void OMPClausePrinter::VisitOMPPriorityClause(OMPPriorityClause *Node) {

diff  --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index bf46984e94a85d..35d8b0706fe3ce 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -862,9 +862,8 @@ void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
 }
 void OMPClauseProfiler::VisitOMPThreadLimitClause(
     const OMPThreadLimitClause *C) {
+  VisitOMPClauseList(C);
   VistOMPClauseWithPreInit(C);
-  if (C->getThreadLimit())
-    Profiler->VisitStmt(C->getThreadLimit());
 }
 void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) {
   VistOMPClauseWithPreInit(C);

diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index be8ab2d121277e..8c5e4aa9c037e2 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6332,7 +6332,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
           CGOpenMPInnerExprInfo CGInfo(CGF, *CS);
           CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
           CodeGenFunction::LexicalScope Scope(
-              CGF, ThreadLimitClause->getThreadLimit()->getSourceRange());
+              CGF,
+              ThreadLimitClause->getThreadLimit().front()->getSourceRange());
           if (const auto *PreInit =
                   cast_or_null<DeclStmt>(ThreadLimitClause->getPreInitStmt())) {
             for (const auto *I : PreInit->decls()) {
@@ -6349,7 +6350,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
       }
     }
     if (ThreadLimitClause)
-      CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
+      CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
+                        ThreadLimitExpr);
     if (const auto *Dir = dyn_cast_or_null<OMPExecutableDirective>(Child)) {
       if (isOpenMPTeamsDirective(Dir->getDirectiveKind()) &&
           !isOpenMPDistributeDirective(Dir->getDirectiveKind())) {
@@ -6370,7 +6372,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
     if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
       CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
       const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
-      CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
+      CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
+                        ThreadLimitExpr);
     }
     const CapturedStmt *CS = D.getInnermostCapturedStmt();
     getNumThreads(CGF, CS, NTPtr, UpperBound, UpperBoundOnly, CondVal);
@@ -6388,7 +6391,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
     if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
       CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
       const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
-      CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
+      CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
+                        ThreadLimitExpr);
     }
     getNumThreads(CGF, D.getInnermostCapturedStmt(), NTPtr, UpperBound,
                   UpperBoundOnly, CondVal);
@@ -6424,7 +6428,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
     if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
       CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
       const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
-      CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
+      CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
+                        ThreadLimitExpr);
     }
     if (D.hasClausesOfKind<OMPNumThreadsClause>()) {
       CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF);

diff  --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 6841ceb3b41548..8afe2abf2cc494 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -5259,7 +5259,7 @@ void CodeGenFunction::EmitOMPTargetTaskBasedDirective(
       // enclosing this target region. This will indirectly set the thread_limit
       // for every applicable construct within target region.
       CGF.CGM.getOpenMPRuntime().emitThreadLimitClause(
-          CGF, TL->getThreadLimit(), S.getBeginLoc());
+          CGF, TL->getThreadLimit().front(), S.getBeginLoc());
     }
     BodyGen(CGF);
   };
@@ -6860,7 +6860,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
   const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
   if (NT || TL) {
     const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
-    const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
+    const Expr *ThreadLimit = TL ? TL->getThreadLimit().front() : nullptr;
 
     CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
                                                   S.getBeginLoc());

diff  --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 5732ee7add7c03..61aa72c30a4654 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_simdlen:
   case OMPC_collapse:
   case OMPC_ordered:
-  case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
   case OMPC_num_tasks:
@@ -3332,6 +3331,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
                  : ParseOpenMPClause(CKind, WrongDirective);
     break;
   case OMPC_num_teams:
+  case OMPC_thread_limit:
     if (!FirstClause) {
       Diag(Tok, diag::err_omp_more_one_clause)
           << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;

diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index b5978ddde24651..87d81dfaad601b 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -13061,6 +13061,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses,
     return StmtError();
 
   if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) ||
+      !checkNumExprsInClause<OMPThreadLimitClause>(
           *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
     return StmtError();
 
@@ -13843,7 +13845,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
                         ? diag::err_ompx_more_than_three_expr_not_allowed
                         : diag::err_omp_multi_expr_not_allowed;
   if (!checkNumExprsInClause<OMPNumTeamsClause>(*this, Clauses,
-                                                ClauseMaxNumExprs, DiagNo))
+                                                ClauseMaxNumExprs, DiagNo) ||
+      !checkNumExprsInClause<OMPThreadLimitClause>(*this, Clauses,
+                                                   ClauseMaxNumExprs, DiagNo))
     return StmtError();
 
   return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
@@ -13857,6 +13861,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective(
     return StmtError();
 
   if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) ||
+      !checkNumExprsInClause<OMPThreadLimitClause>(
           *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
     return StmtError();
 
@@ -13887,6 +13893,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective(
     return StmtError();
 
   if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) ||
+      !checkNumExprsInClause<OMPThreadLimitClause>(
           *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
     return StmtError();
 
@@ -13918,6 +13926,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective(
     return StmtError();
 
   if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) ||
+      !checkNumExprsInClause<OMPThreadLimitClause>(
           *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
     return StmtError();
 
@@ -13952,6 +13962,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
     return StmtError();
 
   if (!checkNumExprsInClause<OMPNumTeamsClause>(
+          *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) ||
+      !checkNumExprsInClause<OMPThreadLimitClause>(
           *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
     return StmtError();
 
@@ -15002,9 +15014,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_ordered:
     Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
     break;
-  case OMPC_thread_limit:
-    Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
-    break;
   case OMPC_priority:
     Res = ActOnOpenMPPriorityClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
@@ -15109,6 +15118,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_when:
   case OMPC_bind:
   case OMPC_num_teams:
+  case OMPC_thread_limit:
   default:
     llvm_unreachable("Clause is not allowed.");
   }
@@ -16975,6 +16985,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
   case OMPC_num_teams:
     Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_thread_limit:
+    Res = ActOnOpenMPThreadLimitClause(VarList, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_if:
   case OMPC_depobj:
   case OMPC_final:
@@ -17005,7 +17018,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
   case OMPC_device:
   case OMPC_threads:
   case OMPC_simd:
-  case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
   case OMPC_nogroup:
@@ -21919,32 +21931,40 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
                                    LParenLoc, EndLoc, Vars, PreInit);
 }
 
-OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
+OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList,
                                                     SourceLocation StartLoc,
                                                     SourceLocation LParenLoc,
                                                     SourceLocation EndLoc) {
-  Expr *ValExpr = ThreadLimit;
-  Stmt *HelperValStmt = nullptr;
-
-  // OpenMP [teams Constrcut, Restrictions]
-  // The thread_limit expression must evaluate to a positive integer value.
-  if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_thread_limit,
-                                 /*StrictlyPositive=*/true))
+  if (VarList.empty())
     return nullptr;
 
+  // isNonNegativeIntegerValue() may rewrite ValExpr; keep the rewritten exprs.
+  SmallVector<Expr *, 3> Vars(VarList.begin(), VarList.end());
+  for (Expr *&ValExpr : Vars) {
+    // OpenMP [teams Construct, Restrictions]
+    // The thread_limit expression must evaluate to a positive integer value.
+    if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_thread_limit,
+                                   /*StrictlyPositive=*/true))
+      return nullptr;
+  }
+
   OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
   OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
       DKind, OMPC_thread_limit, getLangOpts().OpenMP);
-  if (CaptureRegion != OMPD_unknown &&
-      !SemaRef.CurContext->isDependentContext()) {
+  if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+    return OMPThreadLimitClause::Create(getASTContext(), CaptureRegion,
+                                        StartLoc, LParenLoc, EndLoc, Vars,
+                                        /*PreInit=*/nullptr);
+
+  llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+  for (Expr *&ValExpr : Vars) {
     ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
-    llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
     ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
-    HelperValStmt = buildPreInits(getASTContext(), Captures);
   }
 
-  return new (getASTContext()) OMPThreadLimitClause(
-      ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+  Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+  return OMPThreadLimitClause::Create(getASTContext(), CaptureRegion, StartLoc,
+                                      LParenLoc, EndLoc, Vars, PreInit);
 }
 
 OMPClause *SemaOpenMP::ActOnOpenMPPriorityClause(Expr *Priority,

diff  --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 8f6f30434af65e..78ec964037dfe9 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2091,12 +2091,12 @@ class TreeTransform {
   ///
   /// By default, performs semantic analysis to build the new statement.
   /// Subclasses may override this routine to provide different behavior.
-  OMPClause *RebuildOMPThreadLimitClause(Expr *ThreadLimit,
+  OMPClause *RebuildOMPThreadLimitClause(ArrayRef<Expr *> VarList,
                                          SourceLocation StartLoc,
                                          SourceLocation LParenLoc,
                                          SourceLocation EndLoc) {
-    return getSema().OpenMP().ActOnOpenMPThreadLimitClause(
-        ThreadLimit, StartLoc, LParenLoc, EndLoc);
+    return getSema().OpenMP().ActOnOpenMPThreadLimitClause(VarList, StartLoc,
+                                                           LParenLoc, EndLoc);
   }
 
   /// Build a new OpenMP 'priority' clause.
@@ -11028,11 +11028,16 @@ TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) {
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPThreadLimitClause(OMPThreadLimitClause *C) {
-  ExprResult E = getDerived().TransformExpr(C->getThreadLimit());
-  if (E.isInvalid())
-    return nullptr;
+  llvm::SmallVector<Expr *, 3> Vars;
+  Vars.reserve(C->varlist_size());
+  for (auto *VE : C->varlist()) {
+    ExprResult EVar = getDerived().TransformExpr(cast<Expr>(VE));
+    if (EVar.isInvalid())
+      return nullptr;
+    Vars.push_back(EVar.get());
+  }
   return getDerived().RebuildOMPThreadLimitClause(
-      E.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
+      Vars, C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
 }
 
 template <typename Derived>

diff  --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index ad8d6c336f2780..e1d554ee7db224 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -10645,7 +10645,7 @@ OMPClause *OMPClauseReader::readClause() {
     C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt());
     break;
   case llvm::omp::OMPC_thread_limit:
-    C = new (Context) OMPThreadLimitClause();
+    C = OMPThreadLimitClause::CreateEmpty(Context, Record.readInt());
     break;
   case llvm::omp::OMPC_priority:
     C = new (Context) OMPPriorityClause();
@@ -11477,8 +11477,13 @@ void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
 
 void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
   VisitOMPClauseWithPreInit(C);
-  C->setThreadLimit(Record.readSubExpr());
   C->setLParenLoc(Record.readSourceLocation());
+  unsigned NumVars = C->varlist_size();
+  SmallVector<Expr *, 16> Vars;
+  Vars.reserve(NumVars);
+  for (auto _ : llvm::seq<unsigned>(NumVars))
+    Vars.push_back(Record.readSubExpr());
+  C->setVarRefs(Vars);
 }
 
 void OMPClauseReader::VisitOMPPriorityClause(OMPPriorityClause *C) {

diff  --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 25e50e4bdc5f80..b5d487465541b8 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7589,9 +7589,11 @@ void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
 }
 
 void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
+  Record.push_back(C->varlist_size());
   VisitOMPClauseWithPreInit(C);
-  Record.AddStmt(C->getThreadLimit());
   Record.AddSourceLocation(C->getLParenLoc());
+  for (auto *VE : C->varlist())
+    Record.AddStmt(VE);
 }
 
 void OMPClauseWriter::VisitOMPPriorityClause(OMPPriorityClause *C) {

diff  --git a/clang/test/OpenMP/target_teams_ast_print.cpp b/clang/test/OpenMP/target_teams_ast_print.cpp
index 1590a996289f8f..ca5d26822ec96d 100644
--- a/clang/test/OpenMP/target_teams_ast_print.cpp
+++ b/clang/test/OpenMP/target_teams_ast_print.cpp
@@ -115,8 +115,8 @@ int main (int argc, char **argv) {
 // CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1) thread_limit(32)
   a=3;
 // CHECK-NEXT: a = 3;
-#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(32)
-// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1,2,3) thread_limit(32)
+#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(2, 4, 6)
+// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1,2,3) thread_limit(2,4,6)
   a=4;
 // CHECK-NEXT: a = 4;
 #pragma omp target teams default(none), private(argc,b) num_teams(f) firstprivate(argv) reduction(| : c, d) reduction(* : e) thread_limit(f+g)

diff  --git a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
index b489e6a860d672..8bf388f0b5da98 100644
--- a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
+++ b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
@@ -47,9 +47,15 @@ T tmain(T argc) {
 #pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}}
+  for (int i=0; i<100; i++) foo();
+
 #pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 'thread_limit' clause in 'target teams ompx_bare' construct}}
+  for (int i=0; i<100; i++) foo();
+
   return 0;
 }
 
@@ -94,8 +100,14 @@ int main(int argc, char **argv) {
 #pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}}
+  for (int i=0; i<100; i++) foo();
+
 #pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 'thread_limit' clause in 'target teams ompx_bare' construct}}
+  for (int i=0; i<100; i++) foo();
+
   return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
 }

diff  --git a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
index fa6e8f5887f834..092e0137d250d8 100644
--- a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
+++ b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
@@ -45,6 +45,8 @@ T tmain(T argc) {
   for (int i=0; i<100; i++) foo();
 #pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
   for (int i=0; i<100; i++) foo();
+#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}}
+  for (int i=0; i<100; i++) foo();
 
   return 0;
 }
@@ -90,5 +92,8 @@ int main(int argc, char **argv) {
 #pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
   for (int i=0; i<100; i++) foo();
 
+#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}}
+  for (int i=0; i<100; i++) foo();
+
   return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
 }

diff  --git a/clang/test/OpenMP/teams_num_teams_messages.cpp b/clang/test/OpenMP/teams_num_teams_messages.cpp
index 0cfecc5e117438..615bf0be0d8147 100644
--- a/clang/test/OpenMP/teams_num_teams_messages.cpp
+++ b/clang/test/OpenMP/teams_num_teams_messages.cpp
@@ -60,6 +60,9 @@ T tmain(T argc) {
 #pragma omp target
 #pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
   foo();
+#pragma omp target
+#pragma omp teams thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}}
+  foo();
 
   return 0;
 }
@@ -118,5 +121,9 @@ int main(int argc, char **argv) {
 #pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}}
   foo();
 
+#pragma omp target
+#pragma omp teams thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}}
+  foo();
+
   return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
 }

diff  --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 48b34e025729c8..66636f2c665feb 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2522,8 +2522,8 @@ void OMPClauseEnqueue::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
 
 void OMPClauseEnqueue::VisitOMPThreadLimitClause(
     const OMPThreadLimitClause *C) {
+  VisitOMPClauseList(C);
   VisitOMPClauseWithPreInit(C);
-  Visitor->AddStmt(C->getThreadLimit());
 }
 
 void OMPClauseEnqueue::VisitOMPPriorityClause(const OMPPriorityClause *C) {


        


More information about the cfe-commits mailing list