[clang] [Clang][Sema][OpenMP] Allow `num_teams` to accept multiple expressions (PR #99732)
Shilei Tian via cfe-commits
cfe-commits at lists.llvm.org
Thu Aug 1 16:01:39 PDT 2024
https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/99732
>From 27422f2fb7f6986589df0a5173899c1a39cc011b Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Fri, 19 Jul 2024 22:07:06 -0400
Subject: [PATCH] [Clang][OpenMP] Allow `num_teams` to accept multiple
expressions
---
clang/docs/ReleaseNotes.rst | 3 +
clang/include/clang/AST/OpenMPClause.h | 79 +++++++++++-------
clang/include/clang/AST/RecursiveASTVisitor.h | 2 +-
.../clang/Basic/DiagnosticSemaKinds.td | 1 +
clang/include/clang/Sema/SemaOpenMP.h | 3 +-
clang/lib/AST/OpenMPClause.cpp | 26 +++++-
clang/lib/AST/StmtProfile.cpp | 3 +-
clang/lib/CodeGen/CGOpenMPRuntime.cpp | 7 +-
clang/lib/CodeGen/CGStmtOpenMP.cpp | 2 +-
clang/lib/Parse/ParseOpenMP.cpp | 8 +-
clang/lib/Sema/SemaOpenMP.cpp | 80 ++++++++++++++-----
clang/lib/Sema/TreeTransform.h | 7 +-
clang/lib/Serialization/ASTReader.cpp | 10 ++-
clang/lib/Serialization/ASTWriter.cpp | 4 +-
clang/test/OpenMP/target_teams_ast_print.cpp | 4 +
...et_teams_distribute_num_teams_messages.cpp | 6 ++
...ribute_parallel_for_num_teams_messages.cpp | 5 ++
.../test/OpenMP/teams_num_teams_messages.cpp | 7 ++
clang/tools/libclang/CIndex.cpp | 2 +-
19 files changed, 191 insertions(+), 68 deletions(-)
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 866adefd5d3c4..4abc4e83ac9bc 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -305,6 +305,9 @@ Python Binding Changes
OpenMP Support
--------------
+- ``num_teams`` now accepts multiple expressions when it is used in a ``target teams ompx_bare`` construct.
+ This allows the target region to be launched with a multi-dimensional grid on GPUs.
+
Additional Information
======================
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index b029c72fa7d8f..50ac1e0ea8db7 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -6131,43 +6131,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'num_teams'
/// with single expression 'n'.
-class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
- friend class OMPClauseReader;
+///
+/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
+/// can accept up to three expressions.
+///
+/// \code
+/// #pragma omp target teams ompx_bare num_teams(x, y, z)
+/// \endcode
+class OMPNumTeamsClause final
+ : public OMPVarListClause<OMPNumTeamsClause>,
+ public OMPClauseWithPreInit,
+ private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
+ friend OMPVarListClause;
+ friend TrailingObjects;
/// Location of '('.
SourceLocation LParenLoc;
- /// NumTeams number.
- Stmt *NumTeams = nullptr;
+ OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
+ : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
+ N),
+ OMPClauseWithPreInit(this) {}
- /// Set the NumTeams number.
- ///
- /// \param E NumTeams number.
- void setNumTeams(Expr *E) { NumTeams = E; }
+ /// Build an empty clause.
+ OMPNumTeamsClause(unsigned N)
+ : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
+ SourceLocation(), SourceLocation(), N),
+ OMPClauseWithPreInit(this) {}
public:
- /// Build 'num_teams' clause.
+ /// Creates clause with a list of variables \a VL.
///
- /// \param E Expression associated with this clause.
- /// \param HelperE Helper Expression associated with this clause.
- /// \param CaptureRegion Innermost OpenMP region where expressions in this
- /// clause must be captured.
+ /// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
- OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
- SourceLocation StartLoc, SourceLocation LParenLoc,
- SourceLocation EndLoc)
- : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
- OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
- setPreInitStmt(HelperE, CaptureRegion);
- }
+ /// \param VL List of references to the variables.
+ /// \param PreInit Pre-init statement stored on the clause together with \a CaptureRegion (may be null).
+ static OMPNumTeamsClause *
+ Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
+ SourceLocation StartLoc, SourceLocation LParenLoc,
+ SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);
- /// Build an empty clause.
- OMPNumTeamsClause()
- : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
- SourceLocation()),
- OMPClauseWithPreInit(this) {}
+ /// Creates an empty clause with \a N variables.
+ ///
+ /// \param C AST context.
+ /// \param N The number of variables.
+ static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -6175,16 +6186,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }
- /// Return NumTeams number.
- Expr *getNumTeams() { return cast<Expr>(NumTeams); }
+ /// Return NumTeams expressions.
+ ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }
- /// Return NumTeams number.
- Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
+ /// Return NumTeams expressions.
+ ArrayRef<Expr *> getNumTeams() const {
+ return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
+ }
- child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
+ child_range children() {
+ return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
+ reinterpret_cast<Stmt **>(varlist_end()));
+ }
const_child_range children() const {
- return const_child_range(&NumTeams, &NumTeams + 1);
+ auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
+ return const_child_range(Children.begin(), Children.end());
}
child_range used_children() {
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index dcf5dbf449f8b..9a6e8a9ea1c7b 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
OMPNumTeamsClause *C) {
+ TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
- TRY_TO(TraverseStmt(C->getNumTeams()));
return true;
}
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 581434d33c5c9..8e98aa028db08 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11639,6 +11639,7 @@ def warn_omp_unterminated_declare_target : Warning<
InGroup<SourceUsesOpenMP>;
def err_ompx_bare_no_grid : Error<
"'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">;
+def err_omp_multi_expr_not_allowed: Error<"only one expression allowed to '%0' clause">;
} // end of OpenMP category
let CategoryName = "Related Result Type Issue" in {
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index aa61dae9415e2..703c1511fc3ae 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1226,7 +1226,8 @@ class SemaOpenMP : public SemaBase {
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
/// Called on well-formed 'num_teams' clause.
- OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+ OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
+ SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 042a5df5906ca..9ec2f593b4477 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
return *It;
}
+OMPNumTeamsClause *OMPNumTeamsClause::Create(
+ const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
+ SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
+ ArrayRef<Expr *> VL, Stmt *PreInit) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
+ OMPNumTeamsClause *Clause =
+ new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
+ Clause->setVarRefs(VL);
+ Clause->setPreInitStmt(PreInit, CaptureRegion);
+ return Clause;
+}
+
+OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
+ unsigned N) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
+ return new (Mem) OMPNumTeamsClause(N);
+}
+
//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
}
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
- OS << "num_teams(";
- Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
- OS << ")";
+ if (!Node->varlist_empty()) {
+ OS << "num_teams";
+ VisitOMPClauseList(Node, '(');
+ OS << ")";
+ }
}
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index f1e723b4242ee..00ba8e490ac4e 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+ VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
- if (C->getNumTeams())
- Profiler->VisitStmt(C->getNumTeams());
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index d869aa3322cce..f229202ae5535 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
- const Expr *NumTeams =
- NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
+ const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
+ ->getNumTeams()
+ .front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant =
NumTeams->getIntegerConstantExpr(CGF.getContext()))
@@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
case OMPD_target_teams_distribute_parallel_for_simd: {
if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
const Expr *NumTeams =
- D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
+ D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
MinTeamsVal = MaxTeamsVal = Constant->getExtValue();
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index b1ac9361957ff..0cb8b7804f644 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
if (NT || TL) {
- const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
+ const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index e975e96c5c7e4..50930aa0e9a4a 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -3252,6 +3251,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
? ParseOpenMPSimpleClause(CKind, WrongDirective)
: ParseOpenMPClause(CKind, WrongDirective);
break;
+ case OMPC_num_teams:
+ if (!FirstClause) {
+ Diag(Tok, diag::err_omp_more_one_clause)
+ << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
+ ErrorFound = true;
+ }
+ [[clang::fallthrough]];
case OMPC_private:
case OMPC_firstprivate:
case OMPC_lastprivate:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 4f50efda155fb..8d25358ef5fa3 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -13004,6 +13004,24 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetUpdateDirective(
Clauses, AStmt);
}
+// This checks whether num_teams clause only has one expression.
+static bool checkNumTeamsClauseSingleExpr(SemaBase &SemaRef,
+ ArrayRef<OMPClause *> Clauses) {
+ auto NumTeamsClauseItr =
+ llvm::find_if(Clauses, llvm::IsaPred<OMPNumTeamsClause>);
+ if (NumTeamsClauseItr != Clauses.end()) {
+ ArrayRef<const Expr *> NumTeams =
+ cast<OMPNumTeamsClause>(*NumTeamsClauseItr)->getNumTeams();
+ if (NumTeams.size() > 1) {
+ SemaRef.Diag(NumTeams[1]->getBeginLoc(),
+ diag::err_omp_multi_expr_not_allowed)
+ << getOpenMPClauseName(OMPC_num_teams);
+ return false;
+ }
+ }
+ return true;
+}
+
StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses,
Stmt *AStmt,
SourceLocation StartLoc,
@@ -13011,6 +13029,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses,
if (!AStmt)
return StmtError();
+ if (!checkNumTeamsClauseSingleExpr(*this, Clauses))
+ return StmtError();
+
// Report affected OpenMP target offloading behavior when in HIP lang-mode.
if (getLangOpts().HIP && (DSAStack->getParentDirective() == OMPD_target))
Diag(StartLoc, diag::warn_hip_omp_target_directives);
@@ -13785,6 +13806,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
return StmtError();
}
+ if (!HasBareClause && !checkNumTeamsClauseSingleExpr(*this, Clauses))
+ return StmtError();
+
return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
Clauses, AStmt);
}
@@ -13795,6 +13819,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumTeamsClauseSingleExpr(*this, Clauses))
+ return StmtError();
+
CapturedStmt *CS =
setBranchProtectedScope(SemaRef, OMPD_target_teams_distribute, AStmt);
@@ -13821,6 +13848,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumTeamsClauseSingleExpr(*this, Clauses))
+ return StmtError();
+
CapturedStmt *CS = setBranchProtectedScope(
SemaRef, OMPD_target_teams_distribute_parallel_for, AStmt);
@@ -13848,6 +13878,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumTeamsClauseSingleExpr(*this, Clauses))
+ return StmtError();
+
CapturedStmt *CS = setBranchProtectedScope(
SemaRef, OMPD_target_teams_distribute_parallel_for_simd, AStmt);
@@ -13878,6 +13911,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumTeamsClauseSingleExpr(*this, Clauses))
+ return StmtError();
+
CapturedStmt *CS = setBranchProtectedScope(
SemaRef, OMPD_target_teams_distribute_simd, AStmt);
@@ -14925,9 +14961,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_ordered:
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
- case OMPC_num_teams:
- Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
- break;
case OMPC_thread_limit:
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
@@ -15031,6 +15064,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_affinity:
case OMPC_when:
case OMPC_bind:
+ case OMPC_num_teams:
default:
llvm_unreachable("Clause is not allowed.");
}
@@ -16894,6 +16928,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
break;
+ case OMPC_num_teams:
+ Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
+ break;
case OMPC_if:
case OMPC_depobj:
case OMPC_final:
@@ -16924,7 +16961,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
case OMPC_device:
case OMPC_threads:
case OMPC_simd:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -21587,32 +21623,40 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
}
-OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
+OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
- Expr *ValExpr = NumTeams;
- Stmt *HelperValStmt = nullptr;
-
- // OpenMP [teams Constrcut, Restrictions]
- // The num_teams expression must evaluate to a positive integer value.
- if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
- /*StrictlyPositive=*/true))
+ if (VarList.empty())
return nullptr;
+ for (Expr *ValExpr : VarList) {
+ // OpenMP [teams Construct, Restrictions]
+ // The num_teams expression must evaluate to a positive integer value.
+ if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
+ /*StrictlyPositive=*/true))
+ return nullptr;
+ }
+
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
DKind, OMPC_num_teams, getLangOpts().OpenMP);
- if (CaptureRegion != OMPD_unknown &&
- !SemaRef.CurContext->isDependentContext()) {
+ if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+ return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc,
+ LParenLoc, EndLoc, VarList,
+ /*PreInit=*/nullptr);
+
+ llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+ SmallVector<Expr *, 3> Vars;
+ for (Expr *ValExpr : VarList) {
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
- llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
- HelperValStmt = buildPreInits(getASTContext(), Captures);
+ Vars.push_back(ValExpr);
}
- return new (getASTContext()) OMPNumTeamsClause(
- ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+ Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+ return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc,
+ LParenLoc, EndLoc, Vars, PreInit);
}
OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 8d3e1edf7a45d..3fbfb2ec989ce 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2065,10 +2065,11 @@ class TreeTransform {
///
/// By default, performs semantic analysis to build the new statement.
/// Subclasses may override this routine to provide different behavior.
- OMPClause *RebuildOMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+ OMPClause *RebuildOMPNumTeamsClause(ArrayRef<Expr *> VarList,
+ SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
- return getSema().OpenMP().ActOnOpenMPNumTeamsClause(NumTeams, StartLoc,
+ return getSema().OpenMP().ActOnOpenMPNumTeamsClause(VarList, StartLoc,
LParenLoc, EndLoc);
}
@@ -10872,7 +10873,7 @@ TreeTransform<Derived>::TransformOMPAllocateClause(OMPAllocateClause *C) {
template <typename Derived>
OMPClause *
TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) {
- ExprResult E = getDerived().TransformExpr(C->getNumTeams());
+ ExprResult E = getDerived().TransformExpr(C->getNumTeams().front());
if (E.isInvalid())
return nullptr;
return getDerived().RebuildOMPNumTeamsClause(
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 85ff3ab8974ee..3f35bcdc70b4f 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -104,6 +104,7 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
@@ -10569,7 +10570,7 @@ OMPClause *OMPClauseReader::readClause() {
break;
}
case llvm::omp::OMPC_num_teams:
- C = new (Context) OMPNumTeamsClause();
+ C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt());
break;
case llvm::omp::OMPC_thread_limit:
C = new (Context) OMPThreadLimitClause();
@@ -11357,8 +11358,13 @@ void OMPClauseReader::VisitOMPAllocateClause(OMPAllocateClause *C) {
void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
VisitOMPClauseWithPreInit(C);
- C->setNumTeams(Record.readSubExpr());
C->setLParenLoc(Record.readSourceLocation());
+ unsigned NumVars = C->varlist_size();
+ SmallVector<Expr *, 16> Vars;
+ Vars.reserve(NumVars);
+ for ([[maybe_unused]] unsigned I : llvm::seq<unsigned>(NumVars))
+ Vars.push_back(Record.readSubExpr());
+ C->setVarRefs(Vars);
}
void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index f0f9d397f1717..657eb6d3d1cc4 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7528,9 +7528,11 @@ void OMPClauseWriter::VisitOMPAllocateClause(OMPAllocateClause *C) {
}
void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
+ Record.push_back(C->varlist_size());
VisitOMPClauseWithPreInit(C);
- Record.AddStmt(C->getNumTeams());
Record.AddSourceLocation(C->getLParenLoc());
+ for (auto *VE : C->varlist())
+ Record.AddStmt(VE);
}
void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
diff --git a/clang/test/OpenMP/target_teams_ast_print.cpp b/clang/test/OpenMP/target_teams_ast_print.cpp
index 2ff34e4498bfe..1590a996289f8 100644
--- a/clang/test/OpenMP/target_teams_ast_print.cpp
+++ b/clang/test/OpenMP/target_teams_ast_print.cpp
@@ -115,6 +115,10 @@ int main (int argc, char **argv) {
// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1) thread_limit(32)
a=3;
// CHECK-NEXT: a = 3;
+#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(32)
+// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1,2,3) thread_limit(32)
+ a=4;
+// CHECK-NEXT: a = 4;
#pragma omp target teams default(none), private(argc,b) num_teams(f) firstprivate(argv) reduction(| : c, d) reduction(* : e) thread_limit(f+g)
// CHECK-NEXT: #pragma omp target teams default(none) private(argc,b) num_teams(f) firstprivate(argv) reduction(|: c,d) reduction(*: e) thread_limit(f + g)
foo();
diff --git a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
index c0a31fa19b282..e8f898f1f25ee 100644
--- a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
+++ b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp
@@ -44,6 +44,9 @@ T tmain(T argc) {
#pragma omp target teams distribute num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}}
for (int i=0; i<100; i++) foo();
+#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}}
+ for (int i=0; i<100; i++) foo();
+
return 0;
}
@@ -85,5 +88,8 @@ int main(int argc, char **argv) {
#pragma omp target teams distribute num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}}
for (int i=0; i<100; i++) foo();
+#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}}
+ for (int i=0; i<100; i++) foo();
+
return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
}
diff --git a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
index d80b6ea380b93..2a2f5ae27ac55 100644
--- a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
+++ b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp
@@ -43,6 +43,8 @@ T tmain(T argc) {
for (int i=0; i<100; i++) foo();
#pragma omp target teams distribute parallel for num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}}
for (int i=0; i<100; i++) foo();
+#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}}
+ for (int i=0; i<100; i++) foo();
return 0;
}
@@ -85,5 +87,8 @@ int main(int argc, char **argv) {
#pragma omp target teams distribute parallel for num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}}
for (int i=0; i<100; i++) foo();
+#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}}
+ for (int i=0; i<100; i++) foo();
+
return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
}
diff --git a/clang/test/OpenMP/teams_num_teams_messages.cpp b/clang/test/OpenMP/teams_num_teams_messages.cpp
index 40da396b01069..09429167ee39e 100644
--- a/clang/test/OpenMP/teams_num_teams_messages.cpp
+++ b/clang/test/OpenMP/teams_num_teams_messages.cpp
@@ -57,6 +57,9 @@ T tmain(T argc) {
#pragma omp target
#pragma omp teams num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}}
foo();
+#pragma omp target
+#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}}
+ foo();
return 0;
}
@@ -111,5 +114,9 @@ int main(int argc, char **argv) {
#pragma omp teams num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}}
foo();
+#pragma omp target
+#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}}
+ foo();
+
return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}}
}
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 937d7ff09e4ee..d34da8b4eb158 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2499,8 +2499,8 @@ void OMPClauseEnqueue::VisitOMPDeviceClause(const OMPDeviceClause *C) {
}
void OMPClauseEnqueue::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+ VisitOMPClauseList(C);
VisitOMPClauseWithPreInit(C);
- Visitor->AddStmt(C->getNumTeams());
}
void OMPClauseEnqueue::VisitOMPThreadLimitClause(
More information about the cfe-commits
mailing list