[clang] [OpenACC] Implement 'if' clause for Compute Constructs (PR #88411)
Erich Keane via cfe-commits
cfe-commits at lists.llvm.org
Thu Apr 11 09:34:00 PDT 2024
https://github.com/erichkeane created https://github.com/llvm/llvm-project/pull/88411
Like with the 'default' clause, this is being applied to only Compute Constructs for now. The 'if' clause takes a condition expression which is used as a runtime value.
This is not a particularly complex semantic implementation, as there isn't much to this clause, other than its interactions with 'self',
which will be managed in the patch to implement that.
>From 408f39f8ed0ee121aeaeb15c02603bb127e8cb73 Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Wed, 10 Apr 2024 07:56:30 -0700
Subject: [PATCH] [OpenACC] Implement 'if' clause for Compute Constructs
Like with the 'default' clause, this is being applied to only Compute
Constructs for now. The 'if' clause takes a condition expression which
is used as a runtime value.
This is not a particularly complex semantic implementation, as there
isn't much to this clause, other than its interactions with 'self',
which will be managed in the patch to implement that.
---
clang/include/clang/AST/ASTNodeTraverser.h | 3 +-
clang/include/clang/AST/OpenACCClause.h | 83 +++++++++++-
clang/include/clang/Basic/OpenACCClauses.def | 21 +++
clang/include/clang/Parse/Parser.h | 6 +-
clang/include/clang/Sema/SemaOpenACC.h | 28 +++-
clang/lib/AST/OpenACCClause.cpp | 41 ++++++
clang/lib/AST/StmtProfile.cpp | 19 ++-
clang/lib/AST/TextNodeDumper.cpp | 5 +
clang/lib/Parse/ParseOpenACC.cpp | 32 +++--
clang/lib/Sema/SemaOpenACC.cpp | 73 +++++++++--
clang/lib/Sema/TreeTransform.h | 13 ++
clang/lib/Serialization/ASTReader.cpp | 7 +-
clang/lib/Serialization/ASTWriter.cpp | 7 +-
clang/test/ParserOpenACC/parse-clauses.c | 2 -
.../compute-construct-clause-ast.cpp | 120 +++++++++++++++++-
.../SemaOpenACC/compute-construct-if-clause.c | 62 +++++++++
.../compute-construct-if-clause.cpp | 33 +++++
17 files changed, 517 insertions(+), 38 deletions(-)
create mode 100644 clang/include/clang/Basic/OpenACCClauses.def
create mode 100644 clang/test/SemaOpenACC/compute-construct-if-clause.c
create mode 100644 clang/test/SemaOpenACC/compute-construct-if-clause.cpp
diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index 94e7dd817809dd..37fe030fb8e5a3 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -243,7 +243,8 @@ class ASTNodeTraverser
void Visit(const OpenACCClause *C) {
getNodeDelegate().AddChild([=] {
getNodeDelegate().Visit(C);
- // TODO OpenACC: Switch on clauses that have children, and add them.
+ for (const auto *S : C->children())
+ Visit(S);
});
}
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index 27e4e1a12c9837..6e3c00614168e7 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -15,6 +15,7 @@
#define LLVM_CLANG_AST_OPENACCCLAUSE_H
#include "clang/AST/ASTContext.h"
#include "clang/Basic/OpenACCKinds.h"
+#include "clang/AST/StmtIterator.h"
namespace clang {
/// This is the base type for all OpenACC Clauses.
@@ -34,6 +35,17 @@ class OpenACCClause {
static bool classof(const OpenACCClause *) { return true; }
+ using child_iterator = StmtIterator;
+ using const_child_iterator = ConstStmtIterator;
+ using child_range = llvm::iterator_range<child_iterator>;
+ using const_child_range = llvm::iterator_range<const_child_iterator>;
+
+ child_range children();
+ const_child_range children() const {
+ auto Children = const_cast<OpenACCClause *>(this)->children();
+ return const_child_range(Children.begin(), Children.end());
+ }
+
virtual ~OpenACCClause() = default;
};
@@ -49,6 +61,14 @@ class OpenACCClauseWithParams : public OpenACCClause {
public:
SourceLocation getLParenLoc() const { return LParenLoc; }
+
+ child_range children() {
+ return child_range(child_iterator(), child_iterator());
+ }
+ const_child_range children() const {
+ return const_child_range(const_child_iterator(), const_child_iterator());
+ }
+
};
/// A 'default' clause, has the optional 'none' or 'present' argument.
@@ -81,6 +101,52 @@ class OpenACCDefaultClause : public OpenACCClauseWithParams {
SourceLocation EndLoc);
};
+/// Represents one of the handful of classes that has an optional/required
+/// 'condition' expression as an argument.
+class OpenACCClauseWithCondition : public OpenACCClauseWithParams {
+ Expr *ConditionExpr;
+
+ protected:
+ OpenACCClauseWithCondition(OpenACCClauseKind K, SourceLocation BeginLoc,
+ SourceLocation LParenLoc,
+ Expr *ConditionExpr, SourceLocation EndLoc)
+ : OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
+ ConditionExpr(ConditionExpr) {}
+
+ public:
+ bool hasConditionExpr() const { return ConditionExpr; }
+ const Expr *getConditionExpr() const { return ConditionExpr; }
+ Expr *getConditionExpr() { return ConditionExpr; }
+
+ child_range children() {
+ if (ConditionExpr)
+ return child_range(reinterpret_cast<Stmt **>(&ConditionExpr),
+ reinterpret_cast<Stmt **>(&ConditionExpr + 1));
+ return child_range(child_iterator(), child_iterator());
+ }
+
+ const_child_range children() const {
+ if (ConditionExpr)
+ return const_child_range(
+ reinterpret_cast<Stmt *const *>(&ConditionExpr),
+ reinterpret_cast<Stmt *const *>(&ConditionExpr + 1));
+ return const_child_range(const_child_iterator(), const_child_iterator());
+ }
+};
+
+/// An 'if' clause, which has a required condition expression.
+class OpenACCIfClause : public OpenACCClauseWithCondition {
+protected:
+ OpenACCIfClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+ Expr *ConditionExpr, SourceLocation EndLoc);
+
+public:
+ static OpenACCIfClause *Create(const ASTContext &C, SourceLocation BeginLoc,
+ SourceLocation LParenLoc,
+ Expr *ConditionExpr,
+ SourceLocation EndLoc);
+};
+
template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }
@@ -98,6 +164,9 @@ template <class Impl> class OpenACCClauseVisitor {
case OpenACCClauseKind::Default:
VisitOpenACCDefaultClause(*cast<OpenACCDefaultClause>(C));
return;
+ case OpenACCClauseKind::If:
+ VisitOpenACCIfClause(*cast<OpenACCIfClause>(C));
+ return;
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
@@ -106,7 +175,6 @@ template <class Impl> class OpenACCClauseVisitor {
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::NoHost:
- case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::UseDevice:
@@ -145,9 +213,13 @@ template <class Impl> class OpenACCClauseVisitor {
llvm_unreachable("Invalid Clause kind");
}
- void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause) {
- return getDerived().VisitOpenACCDefaultClause(Clause);
+#define VISIT_CLAUSE(CLAUSE_NAME) \
+ void VisitOpenACC##CLAUSE_NAME##Clause( \
+ const OpenACC##CLAUSE_NAME##Clause &Clause) {\
+ return getDerived().VisitOpenACC##CLAUSE_NAME##Clause(Clause); \
}
+
+#include "clang/Basic/OpenACCClauses.def"
};
class OpenACCClausePrinter final
@@ -165,7 +237,10 @@ class OpenACCClausePrinter final
}
OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}
- void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
+#define VISIT_CLAUSE(CLAUSE_NAME) \
+ void VisitOpenACC##CLAUSE_NAME##Clause( \
+ const OpenACC##CLAUSE_NAME##Clause &Clause);
+#include "clang/Basic/OpenACCClauses.def"
};
} // namespace clang
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
new file mode 100644
index 00000000000000..7fd2720e02ce22
--- /dev/null
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -0,0 +1,21 @@
+//===-- OpenACCClauses.def - List of implemented OpenACC Clauses -- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a list of currently implemented OpenACC Clauses (and
+// eventually, the entire list) in a way that makes generating 'visitor' and
+// other lists easier.
+//
+// The primary macro is a single-argument version taking the name of the Clause
+// as used in Clang source (so `Default` instead of `default`).
+//
+// VISIT_CLAUSE(CLAUSE_NAME)
+
+VISIT_CLAUSE(Default)
+VISIT_CLAUSE(If)
+
+#undef VISIT_CLAUSE
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 3a055c10ffb387..9d83a52929789e 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3611,6 +3611,9 @@ class Parser : public CodeCompletionHandler {
OpenACCClauseParseResult OpenACCCannotContinue();
OpenACCClauseParseResult OpenACCSuccess(OpenACCClause *Clause);
+ using OpenACCConditionExprParseResult =
+ std::pair<ExprResult, OpenACCParseCanContinue>;
+
/// Parses the OpenACC directive (the entire pragma) including the clause
/// list, but does not produce the main AST node.
OpenACCDirectiveParseInfo ParseOpenACCDirective();
@@ -3657,7 +3660,8 @@ class Parser : public CodeCompletionHandler {
bool ParseOpenACCGangArgList();
/// Parses a 'gang-arg', used for the 'gang' clause.
bool ParseOpenACCGangArg();
-
+ /// Parses a 'condition' expr, ensuring it is valid as a boolean condition.
+ ExprResult ParseOpenACCConditionExpr();
private:
//===--------------------------------------------------------------------===//
// C++ 14: Templates [temp]
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 27aaee164a2880..c1fe0f5b9c0f6b 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -40,7 +40,11 @@ class SemaOpenACC : public SemaBase {
OpenACCDefaultClauseKind DefaultClauseKind;
};
- std::variant<DefaultDetails> Details;
+ struct ConditionDetails {
+ Expr *ConditionExpr;
+ };
+
+ std::variant<DefaultDetails, ConditionDetails> Details;
public:
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
@@ -63,6 +67,16 @@ class SemaOpenACC : public SemaBase {
return std::get<DefaultDetails>(Details).DefaultClauseKind;
}
+ const Expr *getConditionExpr() const {
+ return const_cast<OpenACCParsedClause *>(this)->getConditionExpr();
+ }
+
+ Expr *getConditionExpr() {
+ assert(ClauseKind == OpenACCClauseKind::If &&
+ "Parsed clause kind does not have a condition expr");
+ return std::get<ConditionDetails>(Details).ConditionExpr;
+ }
+
void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }
@@ -71,6 +85,18 @@ class SemaOpenACC : public SemaBase {
"Parsed clause is not a default clause");
Details = DefaultDetails{DefKind};
}
+
+ void setConditionDetails(Expr *ConditionExpr) {
+ assert(ClauseKind == OpenACCClauseKind::If &&
+ "Parsed clause kind does not have a condition expr");
+ // In C++ we can count on this being a 'bool', but in C this gets left as
+ // some sort of scalar that codegen will have to take care of converting.
+ assert((!ConditionExpr || ConditionExpr->isInstantiationDependent() ||
+ ConditionExpr->getType()->isScalarType()) &&
+ "Condition expression type not scalar/dependent");
+
+ Details = ConditionDetails{ConditionExpr};
+ }
};
SemaOpenACC(Sema &S);
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index c83128b60e3acc..0a512d48253a8c 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "clang/AST/OpenACCClause.h"
+#include "clang/AST/Expr.h"
#include "clang/AST/ASTContext.h"
using namespace clang;
@@ -27,6 +28,41 @@ OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
return new (Mem) OpenACCDefaultClause(K, BeginLoc, LParenLoc, EndLoc);
}
+OpenACCIfClause *OpenACCIfClause::Create(const ASTContext &C,
+ SourceLocation BeginLoc,
+ SourceLocation LParenLoc,
+ Expr *ConditionExpr,
+ SourceLocation EndLoc) {
+ void *Mem = C.Allocate(sizeof(OpenACCIfClause), alignof(OpenACCIfClause));
+ return new (Mem) OpenACCIfClause(BeginLoc, LParenLoc, ConditionExpr, EndLoc);
+}
+
+OpenACCIfClause::OpenACCIfClause(SourceLocation BeginLoc,
+ SourceLocation LParenLoc,
+ Expr *ConditionExpr,
+ SourceLocation EndLoc)
+ : OpenACCClauseWithCondition(OpenACCClauseKind::If, BeginLoc, LParenLoc,
+ ConditionExpr, EndLoc) {
+ assert(ConditionExpr && "if clause requires condition expr");
+ assert((ConditionExpr->isInstantiationDependent() ||
+ ConditionExpr->getType()->isScalarType()) &&
+ "Condition expression type not scalar/dependent");
+}
+
+OpenACCClause::child_range OpenACCClause::children() {
+ switch (getClauseKind()) {
+ default:
+ assert(false && "Clause children function not implemented");
+ break;
+#define VISIT_CLAUSE(CLAUSE_NAME) \
+ case OpenACCClauseKind::CLAUSE_NAME: \
+ return cast<OpenACC##CLAUSE_NAME##Clause>(this)->children();
+
+#include "clang/Basic/OpenACCClauses.def"
+ }
+ return child_range(child_iterator(), child_iterator());
+}
+
//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
@@ -34,3 +70,8 @@ void OpenACCClausePrinter::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &C) {
OS << "default(" << C.getDefaultClauseKind() << ")";
}
+
+void OpenACCClausePrinter::VisitOpenACCIfClause(
+ const OpenACCIfClause &C) {
+ OS << "if(" << C.getConditionExpr() << ")";
+}
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 01e1d1cc8289bf..24593fd2f4d405 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2445,9 +2445,10 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
namespace {
class OpenACCClauseProfiler
: public OpenACCClauseVisitor<OpenACCClauseProfiler> {
+ StmtProfiler &Profiler;
public:
- OpenACCClauseProfiler() = default;
+ OpenACCClauseProfiler(StmtProfiler &P) :Profiler(P) {}
void VisitOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses) {
for (const OpenACCClause *Clause : Clauses) {
@@ -2456,12 +2457,24 @@ class OpenACCClauseProfiler
Visit(Clause);
}
}
- void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
+
+#define VISIT_CLAUSE(CLAUSE_NAME) \
+ void VisitOpenACC##CLAUSE_NAME##Clause( \
+ const OpenACC##CLAUSE_NAME##Clause &Clause);
+
+#include "clang/Basic/OpenACCClauses.def"
};
/// Nothing to do here, there are no sub-statements.
void OpenACCClauseProfiler::VisitOpenACCDefaultClause(
const OpenACCDefaultClause &Clause) {}
+
+void OpenACCClauseProfiler::VisitOpenACCIfClause(
+ const OpenACCIfClause &Clause) {
+ assert(Clause.hasConditionExpr() &&
+ "if clause requires a valid condition expr");
+ Profiler.VisitStmt(Clause.getConditionExpr());
+ }
} // namespace
void StmtProfiler::VisitOpenACCComputeConstruct(
@@ -2469,7 +2482,7 @@ void StmtProfiler::VisitOpenACCComputeConstruct(
// VisitStmt handles children, so the AssociatedStmt is handled.
VisitStmt(S);
- OpenACCClauseProfiler P;
+ OpenACCClauseProfiler P{*this};
P.VisitOpenACCClauseList(S->clauses());
}
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 085a7f51ce99ad..56650f99134d45 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -397,6 +397,11 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
case OpenACCClauseKind::Default:
OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
break;
+ case OpenACCClauseKind::If:
+ // The condition expression will be printed as a part of the 'children',
+ // but print 'clause' here so it is clear what is happening from the dump.
+ OS << " clause";
+ break;
default:
// Nothing to do here.
break;
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index b487a1968d1ec8..6192afa8541cad 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -535,14 +535,6 @@ bool ClauseHasRequiredParens(OpenACCDirectiveKind DirKind,
return getClauseParensKind(DirKind, Kind) == ClauseParensKind::Required;
}
-ExprResult ParseOpenACCConditionalExpr(Parser &P) {
- // FIXME: It isn't clear if the spec saying 'condition' means the same as
- // it does in an if/while/etc (See ParseCXXCondition), however as it was
- // written with Fortran/C in mind, we're going to assume it just means an
- // 'expression evaluating to boolean'.
- return P.getActions().CorrectDelayedTyposInExpr(P.ParseExpression());
-}
-
// Skip until we see the end of pragma token, but don't consume it. This is us
// just giving up on the rest of the pragma so we can continue executing. We
// have to do this because 'SkipUntil' considers paren balancing, which isn't
@@ -595,6 +587,23 @@ Parser::OpenACCClauseParseResult Parser::OpenACCSuccess(OpenACCClause *Clause) {
return {Clause, OpenACCParseCanContinue::Can};
}
+ExprResult Parser::ParseOpenACCConditionExpr() {
+ // FIXME: It isn't clear if the spec saying 'condition' means the same as
+ // it does in an if/while/etc (See ParseCXXCondition), however as it was
+ // written with Fortran/C in mind, we're going to assume it just means an
+ // 'expression evaluating to boolean'.
+ ExprResult ER = getActions().CorrectDelayedTyposInExpr(ParseExpression());
+
+ if (!ER.isUsable())
+ return ER;
+
+ Sema::ConditionResult R =
+ getActions().ActOnCondition(getCurScope(), ER.get()->getExprLoc(),
+ ER.get(), Sema::ConditionKind::Boolean);
+
+ return R.isInvalid() ? ExprError () : R.get().second;
+}
+
// OpenACC 3.3, section 1.7:
// To simplify the specification and convey appropriate constraint information,
// a pqr-list is a comma-separated list of pdr items. The one exception is a
@@ -842,12 +851,15 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
break;
}
case OpenACCClauseKind::If: {
- ExprResult CondExpr = ParseOpenACCConditionalExpr(*this);
+ ExprResult CondExpr = ParseOpenACCConditionExpr();
+ ParsedClause.setConditionDetails(
+ CondExpr.isUsable() ? CondExpr.get() : nullptr);
if (CondExpr.isInvalid()) {
Parens.skipToEnd();
return OpenACCCanContinue();
}
+
break;
}
case OpenACCClauseKind::CopyIn:
@@ -964,7 +976,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
switch (ClauseKind) {
case OpenACCClauseKind::Self: {
assert(DirKind != OpenACCDirectiveKind::Update);
- ExprResult CondExpr = ParseOpenACCConditionalExpr(*this);
+ ExprResult CondExpr = ParseOpenACCConditionExpr();
if (CondExpr.isInvalid()) {
Parens.skipToEnd();
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index a6f4453e525d01..8e98f3ae913325 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -55,12 +55,49 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
default:
return false;
}
+ case OpenACCClauseKind::If:
+ switch (DirectiveKind) {
+ case OpenACCDirectiveKind::Parallel:
+ case OpenACCDirectiveKind::Serial:
+ case OpenACCDirectiveKind::Kernels:
+ case OpenACCDirectiveKind::Data:
+ case OpenACCDirectiveKind::EnterData:
+ case OpenACCDirectiveKind::ExitData:
+ case OpenACCDirectiveKind::HostData:
+ case OpenACCDirectiveKind::Init:
+ case OpenACCDirectiveKind::Shutdown:
+ case OpenACCDirectiveKind::Set:
+ case OpenACCDirectiveKind::Update:
+ case OpenACCDirectiveKind::Wait:
+ case OpenACCDirectiveKind::ParallelLoop:
+ case OpenACCDirectiveKind::SerialLoop:
+ case OpenACCDirectiveKind::KernelsLoop:
+ return true;
+ default:
+ return false;
+ }
default:
// Do nothing so we can go to the 'unimplemented' diagnostic instead.
return true;
}
llvm_unreachable("Invalid clause kind");
}
+
+bool checkAlreadyHasClauseOfKind(
+ SemaOpenACC &S, ArrayRef<const OpenACCClause *> ExistingClauses,
+ SemaOpenACC::OpenACCParsedClause &Clause) {
+ auto Itr = llvm::find_if(ExistingClauses, [&](const OpenACCClause *C) {
+ return C->getClauseKind() == Clause.getClauseKind();
+ });
+ if (Itr != ExistingClauses.end()) {
+ S.Diag(Clause.getBeginLoc(), diag::err_acc_duplicate_clause_disallowed)
+ << Clause.getDirectiveKind() << Clause.getClauseKind();
+ S.Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here);
+ return true;
+ }
+ return false;
+}
+
} // namespace
SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -97,22 +134,38 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
// At most one 'default' clause may appear, and it must have a value of
// either 'none' or 'present'.
// Second half of the sentence is diagnosed during parsing.
- auto Itr = llvm::find_if(ExistingClauses, [](const OpenACCClause *C) {
- return C->getClauseKind() == OpenACCClauseKind::Default;
- });
-
- if (Itr != ExistingClauses.end()) {
- Diag(Clause.getBeginLoc(),
- diag::err_acc_duplicate_clause_disallowed)
- << Clause.getDirectiveKind() << Clause.getClauseKind();
- Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here);
+ if (checkAlreadyHasClauseOfKind(*this, ExistingClauses, Clause))
return nullptr;
- }
return OpenACCDefaultClause::Create(
getASTContext(), Clause.getDefaultClauseKind(), Clause.getBeginLoc(),
Clause.getLParenLoc(), Clause.getEndLoc());
}
+
+ case OpenACCClauseKind::If: {
+ // Restrictions only properly implemented on 'compute' constructs, and
+ // 'compute' constructs are the only construct that can do anything with
+ // this yet, so skip/treat as unimplemented in this case.
+ if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Parallel &&
+ Clause.getDirectiveKind() != OpenACCDirectiveKind::Serial &&
+ Clause.getDirectiveKind() != OpenACCDirectiveKind::Kernels)
+ break;
+
+ // There is no prose in the standard that says duplicates aren't allowed,
+ // but this diagnostic is present in other compilers, as well as makes
+ // sense.
+ if (checkAlreadyHasClauseOfKind(*this, ExistingClauses, Clause))
+ return nullptr;
+
+ // The parser has ensured that we have a proper condition expr, so there
+ // isn't really much to do here.
+
+ // TODO OpenACC: When we implement 'self', this clause causes us to
+ // 'ignore' the self clause, so we should implement a warning here.
+ return OpenACCIfClause::Create(
+ getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
+ Clause.getConditionExpr(), Clause.getEndLoc());
+ }
default:
break;
}
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 33a9356e82f409..8e6f95c7a7292e 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -11099,6 +11099,19 @@ OpenACCClause *TreeTransform<Derived>::TransformOpenACCClause(
ParsedClause.setDefaultDetails(
cast<OpenACCDefaultClause>(OldClause)->getDefaultClauseKind());
break;
+ case OpenACCClauseKind::If: {
+ Expr *Cond = const_cast<Expr *>(
+ cast<OpenACCIfClause>(OldClause)->getConditionExpr());
+ Sema::ConditionResult Res =
+ TransformCondition(Cond->getExprLoc(), /*Var=*/nullptr, Cond,
+ Sema::ConditionKind::Boolean);
+
+ if (Res.isInvalid() || !Res.get().second)
+ return nullptr;
+
+ ParsedClause.setConditionDetails(Res.get().second);
+ break;
+ }
default:
assert(false && "Unhandled OpenACC clause in TreeTransform");
return nullptr;
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 4f6987f92fc82e..e9946a2eea02ff 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11764,6 +11764,12 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
return OpenACCDefaultClause::Create(getContext(), DCK, BeginLoc, LParenLoc,
EndLoc);
}
+ case OpenACCClauseKind::If: {
+ SourceLocation LParenLoc = readSourceLocation();
+ Expr *CondExpr = readSubExpr();
+ return OpenACCIfClause::Create(getContext(), BeginLoc, LParenLoc, CondExpr,
+ EndLoc);
+ }
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
@@ -11772,7 +11778,6 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::NoHost:
- case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::UseDevice:
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index ffc53292e39124..a5cb7c68122018 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7425,6 +7425,12 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
writeEnum(DC->getDefaultClauseKind());
return;
}
+ case OpenACCClauseKind::If: {
+ const auto *IC = cast<OpenACCIfClause>(C);
+ writeSourceLocation(IC->getLParenLoc());
+ AddStmt(const_cast<Expr*>(IC->getConditionExpr()));
+ return;
+ }
case OpenACCClauseKind::Finalize:
case OpenACCClauseKind::IfPresent:
case OpenACCClauseKind::Seq:
@@ -7433,7 +7439,6 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
case OpenACCClauseKind::Worker:
case OpenACCClauseKind::Vector:
case OpenACCClauseKind::NoHost:
- case OpenACCClauseKind::If:
case OpenACCClauseKind::Self:
case OpenACCClauseKind::Copy:
case OpenACCClauseKind::UseDevice:
diff --git a/clang/test/ParserOpenACC/parse-clauses.c b/clang/test/ParserOpenACC/parse-clauses.c
index b363a0cb1362b0..2369df58308a72 100644
--- a/clang/test/ParserOpenACC/parse-clauses.c
+++ b/clang/test/ParserOpenACC/parse-clauses.c
@@ -283,11 +283,9 @@ void IfClause() {
int i, j;
- // expected-warning@+1{{OpenACC clause 'if' not yet implemented, clause ignored}}
#pragma acc serial if(i > j)
for(;;){}
- // expected-warning@+2{{OpenACC clause 'if' not yet implemented, clause ignored}}
// expected-warning@+1{{OpenACC clause 'seq' not yet implemented, clause ignored}}
#pragma acc serial if(1+5>3), seq
for(;;){}
diff --git a/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp b/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp
index bd80103445028a..018f0b68c78109 100644
--- a/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp
+++ b/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp
@@ -6,8 +6,10 @@
#ifndef PCH_HELPER
#define PCH_HELPER
-void NormalFunc() {
+void NormalFunc(int i, float f) {
// CHECK: FunctionDecl{{.*}}NormalFunc
+ // CHECK-NEXT: ParmVarDecl
+ // CHECK-NEXT: ParmVarDecl
// CHECK-NEXT: CompoundStmt
#pragma acc parallel default(none)
while(true);
@@ -24,6 +26,20 @@ void NormalFunc() {
// CHECK-NEXT: WhileStmt
// CHECK-NEXT: CXXBoolLiteralExpr
// CHECK-NEXT: NullStmt
+
+#pragma acc kernels if( i < f)
+ while(true);
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: BinaryOperator{{.*}} 'bool' '<'
+ // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <IntegralToFloating>
+ // CHECK-NEXT: ImplicitCastExpr{{.*}} 'int' <LValueToRValue>
+ // CHECK-NEXT: DeclRefExpr{{.*}} 'int' lvalue ParmVar{{.*}} 'i' 'int'
+ // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <LValueToRValue>
+ // CHECK-NEXT: DeclRefExpr{{.*}} 'float' lvalue ParmVar{{.*}} 'f' 'float'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
}
template<typename T>
@@ -51,24 +67,120 @@ void TemplFunc() {
// CHECK-NEXT: CXXBoolLiteralExpr
// CHECK-NEXT: NullStmt
+#pragma acc parallel if(T::SomeFloat < typename T::IntTy{})
+ while(true);
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: BinaryOperator{{.*}} '<dependent type>' '<'
+ // CHECK-NEXT: DependentScopeDeclRefExpr{{.*}} '<dependent type>' lvalue
+ // CHECK-NEXT: NestedNameSpecifier TypeSpec 'T'
+ // CHECK-NEXT: CXXUnresolvedConstructExpr{{.*}} 'typename T::IntTy' 'typename T::IntTy'
+ // CHECK-NEXT: InitListExpr{{.*}} 'void'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
+
+#pragma acc serial if(typename T::IntTy{})
+ while(true);
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}serial
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: CXXUnresolvedConstructExpr{{.*}} 'typename T::IntTy' 'typename T::IntTy'
+ // CHECK-NEXT: InitListExpr{{.*}} 'void'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
+
+#pragma acc kernels if(T::SomeFloat)
+ while(true);
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: DependentScopeDeclRefExpr{{.*}} '<dependent type>' lvalue
+ // CHECK-NEXT: NestedNameSpecifier TypeSpec 'T'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
+
+#pragma acc parallel if(T::BC)
+ while(true);
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: DependentScopeDeclRefExpr{{.*}} '<dependent type>' lvalue
+ // CHECK-NEXT: NestedNameSpecifier TypeSpec 'T'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
+
// Match the instantiation:
// CHECK: FunctionDecl{{.*}}TemplFunc{{.*}}implicit_instantiation
- // CHECK-NEXT: TemplateArgument type 'int'
- // CHECK-NEXT: BuiltinType
+ // CHECK-NEXT: TemplateArgument type 'InstTy'
+ // CHECK-NEXT: RecordType{{.*}} 'InstTy'
+ // CHECK-NEXT: CXXRecord{{.*}} 'InstTy'
// CHECK-NEXT: CompoundStmt
+
// CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
// CHECK-NEXT: default(none)
// CHECK-NEXT: WhileStmt
// CHECK-NEXT: CXXBoolLiteralExpr
// CHECK-NEXT: NullStmt
+
// CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
// CHECK-NEXT: default(present)
// CHECK-NEXT: WhileStmt
// CHECK-NEXT: CXXBoolLiteralExpr
// CHECK-NEXT: NullStmt
+
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: BinaryOperator{{.*}} 'bool' '<'
+ // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <LValueToRValue>
+ // CHECK-NEXT: DeclRefExpr{{.*}} 'const float' lvalue Var{{.*}} 'SomeFloat' 'const float'
+ // CHECK-NEXT: NestedNameSpecifier TypeSpec 'InstTy'
+ // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <IntegralToFloating>
+ // CHECK-NEXT: CXXFunctionalCastExpr{{.*}}'typename InstTy::IntTy':'int' functional cast to typename struct InstTy::IntTy <NoOp>
+ // CHECK-NEXT: InitListExpr {{.*}}'typename InstTy::IntTy':'int'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
+
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}serial
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: ImplicitCastExpr{{.*}}'bool' <IntegralToBoolean>
+ // CHECK-NEXT: CXXFunctionalCastExpr{{.*}}'typename InstTy::IntTy':'int' functional cast to typename struct InstTy::IntTy <NoOp>
+ // CHECK-NEXT: InitListExpr {{.*}}'typename InstTy::IntTy':'int'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
+
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: ImplicitCastExpr{{.*}}'bool' <FloatingToBoolean>
+ // CHECK-NEXT: ImplicitCastExpr{{.*}}'float' <LValueToRValue>
+ // CHECK-NEXT: DeclRefExpr{{.*}} 'const float' lvalue Var{{.*}} 'SomeFloat' 'const float'
+ // CHECK-NEXT: NestedNameSpecifier TypeSpec 'InstTy'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
+
+ // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+ // CHECK-NEXT: if clause
+ // CHECK-NEXT: ImplicitCastExpr{{.*}} 'bool' <UserDefinedConversion>
+ // CHECK-NEXT: CXXMemberCallExpr{{.*}} 'bool'
+ // CHECK-NEXT: MemberExpr{{.*}} .operator bool
+ // CHECK-NEXT: DeclRefExpr{{.*}} 'const BoolConversion' lvalue Var{{.*}} 'BC' 'const BoolConversion'
+ // CHECK-NEXT: NestedNameSpecifier TypeSpec 'InstTy'
+ // CHECK-NEXT: WhileStmt
+ // CHECK-NEXT: CXXBoolLiteralExpr
+ // CHECK-NEXT: NullStmt
}
+struct BoolConversion{ operator bool() const;};
+struct InstTy {
+ using IntTy = int;
+ static constexpr float SomeFloat = 5.0;
+ static constexpr BoolConversion BC;
+};
+
void Instantiate() {
- TemplFunc<int>();
+ TemplFunc<InstTy>();
}
#endif
diff --git a/clang/test/SemaOpenACC/compute-construct-if-clause.c b/clang/test/SemaOpenACC/compute-construct-if-clause.c
new file mode 100644
index 00000000000000..767b8414b3a68a
--- /dev/null
+++ b/clang/test/SemaOpenACC/compute-construct-if-clause.c
@@ -0,0 +1,62 @@
+// RUN: %clang_cc1 %s -fopenacc -verify
+
+void BoolExpr(int *I, float *F) {
+
+ typedef struct {} SomeStruct;
+ int Array[5];
+
+ struct C{};
+ // expected-error@+1{{expected expression}}
+#pragma acc parallel if (struct C f())
+ while(0);
+
+ // expected-error@+1{{unexpected type name 'SomeStruct': expected expression}}
+#pragma acc serial if (SomeStruct)
+ while(0);
+
+ // expected-error@+1{{unexpected type name 'SomeStruct': expected expression}}
+#pragma acc serial if (SomeStruct())
+ while(0);
+
+ SomeStruct S;
+ // expected-error@+1{{statement requires expression of scalar type ('SomeStruct' invalid)}}
+#pragma acc serial if (S)
+ while(0);
+
+ // expected-warning@+1{{address of array 'Array' will always evaluate to 'true'}}
+#pragma acc kernels if (Array)
+ while(0);
+
+ // expected-warning@+4{{incompatible pointer types assigning to 'int *' from 'float *'}}
+ // expected-warning@+3{{using the result of an assignment as a condition without parentheses}}
+ // expected-note@+2{{place parentheses around the assignment to silence this warning}}
+ // expected-note@+1{{use '==' to turn this assignment into an equality comparison}}
+#pragma acc kernels if (I = F)
+ while(0);
+
+#pragma acc parallel if (I)
+ while(0);
+
+#pragma acc serial if (F)
+ while(0);
+
+#pragma acc kernels if (*I < *F)
+ while(0);
+
+ // expected-warning@+2{{OpenACC construct 'data' not yet implemented}}
+ // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc data if (*I < *F)
+ while(0);
+ // expected-warning@+2{{OpenACC construct 'parallel loop' not yet implemented}}
+ // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc parallel loop if (*I < *F)
+ while(0);
+ // expected-warning@+2{{OpenACC construct 'serial loop' not yet implemented}}
+ // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc serial loop if (*I < *F)
+ while(0);
+ // expected-warning@+2{{OpenACC construct 'kernels loop' not yet implemented}}
+ // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc kernels loop if (*I < *F)
+ while(0);
+}
diff --git a/clang/test/SemaOpenACC/compute-construct-if-clause.cpp b/clang/test/SemaOpenACC/compute-construct-if-clause.cpp
new file mode 100644
index 00000000000000..2a9bb8a3d5c4ea
--- /dev/null
+++ b/clang/test/SemaOpenACC/compute-construct-if-clause.cpp
@@ -0,0 +1,33 @@
+// RUN: %clang_cc1 %s -fopenacc -verify
+
+struct NoBoolConversion{};
+struct BoolConversion{
+ operator bool();
+};
+
+template <typename T, typename U>
+void BoolExpr() {
+
+ // expected-error@+1{{value of type 'NoBoolConversion' is not contextually convertible to 'bool'}}
+#pragma acc parallel if (NoBoolConversion{})
+ while(0);
+
+ // expected-error@+2{{no member named 'NotValid' in 'NoBoolConversion'}}
+ // expected-note@#INST{{in instantiation of function template specialization}}
+#pragma acc parallel if (T::NotValid)
+ while(0);
+
+#pragma acc parallel if (BoolConversion{})
+ while(0);
+
+ // expected-error@+1{{value of type 'NoBoolConversion' is not contextually convertible to 'bool'}}
+#pragma acc parallel if (T{})
+ while(0);
+
+#pragma acc parallel if (U{})
+ while(0);
+}
+
+void Instantiate() {
+ BoolExpr<NoBoolConversion, BoolConversion>(); // #INST
+}
More information about the cfe-commits
mailing list