[clang] [Clang][OpenMP] Allow `num_teams` to accept multiple expressions (PR #99732)
via cfe-commits
cfe-commits at lists.llvm.org
Fri Jul 19 19:08:43 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Shilei Tian (shiltian)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/99732.diff
7 Files Affected:
- (modified) clang/include/clang/AST/OpenMPClause.h (+48-31)
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1-1)
- (modified) clang/include/clang/Sema/SemaOpenMP.h (+2-1)
- (modified) clang/lib/AST/OpenMPClause.cpp (+23-3)
- (modified) clang/lib/AST/StmtProfile.cpp (+1-2)
- (modified) clang/lib/Parse/ParseOpenMP.cpp (+1-1)
- (modified) clang/lib/Sema/SemaOpenMP.cpp (+23-18)
``````````diff
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 325a1baa44614..2e82ccac28dc8 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -6131,43 +6131,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'num_teams'
/// with single expression 'n'.
-class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
- friend class OMPClauseReader;
+///
+/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
+/// can accept up to three expressions.
+///
+/// \code
+/// #pragma omp target teams ompx_bare num_teams(x, y, z)
+/// \endcode
+class OMPNumTeamsClause final
+ : public OMPVarListClause<OMPNumTeamsClause>,
+ public OMPClauseWithPreInit,
+ private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
+ friend OMPVarListClause;
+ friend TrailingObjects;
/// Location of '('.
SourceLocation LParenLoc;
- /// NumTeams number.
- Stmt *NumTeams = nullptr;
+ OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
+ : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
+ N),
+ OMPClauseWithPreInit(this) {}
- /// Set the NumTeams number.
- ///
- /// \param E NumTeams number.
- void setNumTeams(Expr *E) { NumTeams = E; }
+ /// Build an empty clause.
+ OMPNumTeamsClause(unsigned N)
+ : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
+ SourceLocation(), SourceLocation(), N),
+ OMPClauseWithPreInit(this) {}
public:
- /// Build 'num_teams' clause.
+ /// Creates clause with a list of variables \a VL.
///
- /// \param E Expression associated with this clause.
- /// \param HelperE Helper Expression associated with this clause.
- /// \param CaptureRegion Innermost OpenMP region where expressions in this
- /// clause must be captured.
+ /// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
- OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
- SourceLocation StartLoc, SourceLocation LParenLoc,
- SourceLocation EndLoc)
- : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
- OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
- setPreInitStmt(HelperE, CaptureRegion);
- }
+ /// \param VL List of references to the variables.
+ /// \param PreInit Pre-init statement for captured expressions; may be null.
+ static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc,
+ SourceLocation EndLoc, ArrayRef<Expr *> VL,
+ Stmt *PreInit);
- /// Build an empty clause.
- OMPNumTeamsClause()
- : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
- SourceLocation()),
- OMPClauseWithPreInit(this) {}
+ /// Creates an empty clause with \a N variables.
+ ///
+ /// \param C AST context.
+ /// \param N The number of variables.
+ static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -6175,16 +6186,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }
- /// Return NumTeams number.
- Expr *getNumTeams() { return cast<Expr>(NumTeams); }
+ /// Return NumTeams number. By default, we return the first expression.
+ Expr *getNumTeams() { return getVarRefs().front(); }
- /// Return NumTeams number.
- Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
+ /// Return NumTeams number. By default, we return the first expression.
+ Expr *getNumTeams() const {
+ return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
+ }
- child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
+ child_range children() {
+ return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
+ reinterpret_cast<Stmt **>(varlist_end()));
+ }
const_child_range children() const {
- return const_child_range(&NumTeams, &NumTeams + 1);
+ auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
+ return const_child_range(Children.begin(), Children.end());
}
child_range used_children() {
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index e3c0cb46799f7..beb7b3597c2a8 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
OMPNumTeamsClause *C) {
+ TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
- TRY_TO(TraverseStmt(C->getNumTeams()));
return true;
}
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 54d81f91ffebc..bf5fbc670b05c 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase {
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
/// Called on well-formed 'num_teams' clause.
- OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+ OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
+ SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 042a5df5906ca..ee9e9a0d39a92 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
return *It;
}
+OMPNumTeamsClause *
+OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc, SourceLocation EndLoc,
+ ArrayRef<Expr *> VL, Stmt *PreInit) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
+ OMPNumTeamsClause *Clause =
+ new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
+ Clause->setVarRefs(VL);
+ Clause->setPreInitStmt(PreInit);
+ return Clause;
+}
+
+OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
+ unsigned N) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
+ return new (Mem) OMPNumTeamsClause(N);
+}
+
//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
}
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
- OS << "num_teams(";
- Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
- OS << ")";
+ if (!Node->varlist_empty()) {
+ OS << "num_teams";
+ VisitOMPClauseList(Node, '(');
+ OS << ")";
+ }
}
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 89d2a422509d8..b782a4ab8367e 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+ VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
- if (C->getNumTeams())
- Profiler->VisitStmt(C->getNumTeams());
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index f5b44d210680c..e851bb4ac7fef 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -3279,6 +3278,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_affinity:
case OMPC_doacross:
case OMPC_enter:
+ case OMPC_num_teams:
if (getLangOpts().OpenMP >= 52 && DKind == OMPD_ordered &&
CKind == OMPC_depend)
Diag(Tok, diag::warn_omp_depend_in_ordered_deprecated);
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 3bd981cb442aa..a4e0ce730ae05 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -15041,9 +15041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_ordered:
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
- case OMPC_num_teams:
- Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
- break;
case OMPC_thread_limit:
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
@@ -15147,6 +15144,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_affinity:
case OMPC_when:
case OMPC_bind:
+ case OMPC_num_teams:
default:
llvm_unreachable("Clause is not allowed.");
}
@@ -17010,6 +17008,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
break;
+ case OMPC_num_teams:
+ Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
+ break;
case OMPC_if:
case OMPC_depobj:
case OMPC_final:
@@ -17040,7 +17041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
case OMPC_device:
case OMPC_threads:
case OMPC_simd:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -21703,32 +21703,37 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
}
-OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
+OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
- Expr *ValExpr = NumTeams;
- Stmt *HelperValStmt = nullptr;
-
- // OpenMP [teams Constrcut, Restrictions]
- // The num_teams expression must evaluate to a positive integer value.
- if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
- /*StrictlyPositive=*/true))
+ if (VarList.empty())
return nullptr;
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
DKind, OMPC_num_teams, getLangOpts().OpenMP);
- if (CaptureRegion != OMPD_unknown &&
- !SemaRef.CurContext->isDependentContext()) {
+
+ if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+ return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
+ EndLoc, VarList, /*PreInit=*/nullptr);
+
+ llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+ SmallVector<Expr *, 3> Vars;
+ for (Expr *ValExpr : VarList) {
+ // OpenMP [teams Construct, Restrictions]
+ // The num_teams expression must evaluate to a positive integer value.
+ if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
+ /*StrictlyPositive=*/true))
+ return nullptr;
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
- llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
- HelperValStmt = buildPreInits(getASTContext(), Captures);
+ Vars.push_back(ValExpr);
}
- return new (getASTContext()) OMPNumTeamsClause(
- ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+ Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+ return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
+ Vars, PreInit);
}
OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
``````````
</details>
https://github.com/llvm/llvm-project/pull/99732
More information about the cfe-commits mailing list