[clang] [OpenACC] Implement 'if' clause for Compute Constructs (PR #88411)

Erich Keane via cfe-commits cfe-commits at lists.llvm.org
Thu Apr 11 09:34:00 PDT 2024


https://github.com/erichkeane created https://github.com/llvm/llvm-project/pull/88411

Like with the 'default' clause, this is being applied to only Compute Constructs for now.  The 'if' clause takes a condition expression which is used as a runtime value.

This is not a particularly complex semantic implementation, as there isn't much to this clause other than its interactions with 'self', which will be handled in the patch that implements 'self'.

>From 408f39f8ed0ee121aeaeb15c02603bb127e8cb73 Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Wed, 10 Apr 2024 07:56:30 -0700
Subject: [PATCH] [OpenACC] Implement 'if' clause for Compute Constructs

Like with the 'default' clause, this is being applied to only Compute
Constructs for now.  The 'if' clause takes a condition expression which
is used as a runtime value.

This is not a particularly complex semantic implementation, as there
isn't much to this clause other than its interactions with 'self',
which will be handled in the patch that implements 'self'.
---
 clang/include/clang/AST/ASTNodeTraverser.h    |   3 +-
 clang/include/clang/AST/OpenACCClause.h       |  83 +++++++++++-
 clang/include/clang/Basic/OpenACCClauses.def  |  21 +++
 clang/include/clang/Parse/Parser.h            |   6 +-
 clang/include/clang/Sema/SemaOpenACC.h        |  28 +++-
 clang/lib/AST/OpenACCClause.cpp               |  41 ++++++
 clang/lib/AST/StmtProfile.cpp                 |  19 ++-
 clang/lib/AST/TextNodeDumper.cpp              |   5 +
 clang/lib/Parse/ParseOpenACC.cpp              |  32 +++--
 clang/lib/Sema/SemaOpenACC.cpp                |  73 +++++++++--
 clang/lib/Sema/TreeTransform.h                |  13 ++
 clang/lib/Serialization/ASTReader.cpp         |   7 +-
 clang/lib/Serialization/ASTWriter.cpp         |   7 +-
 clang/test/ParserOpenACC/parse-clauses.c      |   2 -
 .../compute-construct-clause-ast.cpp          | 120 +++++++++++++++++-
 .../SemaOpenACC/compute-construct-if-clause.c |  62 +++++++++
 .../compute-construct-if-clause.cpp           |  33 +++++
 17 files changed, 517 insertions(+), 38 deletions(-)
 create mode 100644 clang/include/clang/Basic/OpenACCClauses.def
 create mode 100644 clang/test/SemaOpenACC/compute-construct-if-clause.c
 create mode 100644 clang/test/SemaOpenACC/compute-construct-if-clause.cpp

diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index 94e7dd817809dd..37fe030fb8e5a3 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -243,7 +243,9 @@ class ASTNodeTraverser
   void Visit(const OpenACCClause *C) {
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(C);
-      // TODO OpenACC: Switch on clauses that have children, and add them.
+      // Visit any sub-statements owned by the clause (e.g. an 'if' condition).
+      for (const auto *S : C->children())
+        Visit(S);
     });
   }
 
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
index 27e4e1a12c9837..6e3c00614168e7 100644
--- a/clang/include/clang/AST/OpenACCClause.h
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -15,6 +15,7 @@
 #define LLVM_CLANG_AST_OPENACCCLAUSE_H
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/StmtIterator.h"
 #include "clang/Basic/OpenACCKinds.h"
 
 namespace clang {
 /// This is the base type for all OpenACC Clauses.
@@ -34,6 +35,17 @@ class OpenACCClause {
 
   static bool classof(const OpenACCClause *) { return true; }
 
+  using child_iterator = StmtIterator;
+  using const_child_iterator = ConstStmtIterator;
+  using child_range = llvm::iterator_range<child_iterator>;
+  using const_child_range = llvm::iterator_range<const_child_iterator>;
+
+  child_range children();
+  const_child_range children() const {
+    auto Children = const_cast<OpenACCClause *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+
   virtual ~OpenACCClause() = default;
 };
 
@@ -49,6 +61,14 @@ class OpenACCClauseWithParams : public OpenACCClause {
 
 public:
   SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
 };
 
 /// A 'default' clause, has the optional 'none' or 'present' argument.
@@ -81,6 +101,52 @@ class OpenACCDefaultClause : public OpenACCClauseWithParams {
                                       SourceLocation EndLoc);
 };
 
+/// Represents one of the handful of clauses that have an optional/required
+/// 'condition' expression as an argument.
+class OpenACCClauseWithCondition : public OpenACCClauseWithParams {
+  /// The condition; may be null for clauses where it is optional.
+  Expr *ConditionExpr;
+
+protected:
+  // Only constructed as the base of a concrete clause (e.g. OpenACCIfClause).
+  OpenACCClauseWithCondition(OpenACCClauseKind K, SourceLocation BeginLoc,
+                             SourceLocation LParenLoc, Expr *ConditionExpr,
+                             SourceLocation EndLoc)
+      : OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
+        ConditionExpr(ConditionExpr) {}
+
+public:
+  bool hasConditionExpr() const { return ConditionExpr != nullptr; }
+  const Expr *getConditionExpr() const { return ConditionExpr; }
+  Expr *getConditionExpr() { return ConditionExpr; }
+
+  /// The condition expression, when present, is the clause's only child.
+  child_range children() {
+    if (ConditionExpr)
+      return child_range(reinterpret_cast<Stmt **>(&ConditionExpr),
+                         reinterpret_cast<Stmt **>(&ConditionExpr + 1));
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    auto Children = const_cast<OpenACCClauseWithCondition *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+};
+
+/// An 'if' clause, which has a required condition expression.
+class OpenACCIfClause : public OpenACCClauseWithCondition {
+protected:
+  OpenACCIfClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
+                  Expr *ConditionExpr, SourceLocation EndLoc);
+
+public:
+  /// Allocates the clause in \p C; \p ConditionExpr must be non-null.
+  static OpenACCIfClause *Create(const ASTContext &C, SourceLocation BeginLoc,
+                                 SourceLocation LParenLoc, Expr *ConditionExpr,
+                                 SourceLocation EndLoc);
+};
+
 template <class Impl> class OpenACCClauseVisitor {
   Impl &getDerived() { return static_cast<Impl &>(*this); }
 
@@ -98,6 +164,9 @@ template <class Impl> class OpenACCClauseVisitor {
     case OpenACCClauseKind::Default:
       VisitOpenACCDefaultClause(*cast<OpenACCDefaultClause>(C));
       return;
+    case OpenACCClauseKind::If:
+      VisitOpenACCIfClause(*cast<OpenACCIfClause>(C));
+      return;
     case OpenACCClauseKind::Finalize:
     case OpenACCClauseKind::IfPresent:
     case OpenACCClauseKind::Seq:
@@ -106,7 +175,6 @@ template <class Impl> class OpenACCClauseVisitor {
     case OpenACCClauseKind::Worker:
     case OpenACCClauseKind::Vector:
     case OpenACCClauseKind::NoHost:
-    case OpenACCClauseKind::If:
     case OpenACCClauseKind::Self:
     case OpenACCClauseKind::Copy:
     case OpenACCClauseKind::UseDevice:
@@ -145,9 +213,13 @@ template <class Impl> class OpenACCClauseVisitor {
     llvm_unreachable("Invalid Clause kind");
   }
 
-  void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause) {
-    return getDerived().VisitOpenACCDefaultClause(Clause);
+#define VISIT_CLAUSE(CLAUSE_NAME)                                              \
+  void VisitOpenACC##CLAUSE_NAME##Clause(                                      \
+                                  const OpenACC##CLAUSE_NAME##Clause &Clause) {\
+  return getDerived().VisitOpenACC##CLAUSE_NAME##Clause(Clause);               \
   }
+
+#include "clang/Basic/OpenACCClauses.def"
 };
 
 class OpenACCClausePrinter final
@@ -165,7 +237,10 @@ class OpenACCClausePrinter final
   }
   OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}
 
-  void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
+#define VISIT_CLAUSE(CLAUSE_NAME)                                              \
+  void VisitOpenACC##CLAUSE_NAME##Clause(                                      \
+                                   const OpenACC##CLAUSE_NAME##Clause &Clause);
+#include "clang/Basic/OpenACCClauses.def"
 };
 
 } // namespace clang
diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def
new file mode 100644
index 00000000000000..7fd2720e02ce22
--- /dev/null
+++ b/clang/include/clang/Basic/OpenACCClauses.def
@@ -0,0 +1,21 @@
+//===-- OpenACCClauses.def - List of implemented OpenACC Clauses -- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a list of currently implemented OpenACC Clauses (and
+// eventually, the entire list) in a way that makes generating 'visitor' and
+// other lists easier.
+//
+// The primary macro is a single-argument version taking the name of the Clause
+// as used in Clang source (so `Default` instead of `default`).
+//
+// VISIT_CLAUSE(CLAUSE_NAME)
+
+VISIT_CLAUSE(Default)
+VISIT_CLAUSE(If)
+
+#undef VISIT_CLAUSE
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 3a055c10ffb387..9d83a52929789e 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3611,6 +3611,9 @@ class Parser : public CodeCompletionHandler {
   OpenACCClauseParseResult OpenACCCannotContinue();
   OpenACCClauseParseResult OpenACCSuccess(OpenACCClause *Clause);
 
+  using OpenACCConditionExprParseResult =
+      std::pair<ExprResult, OpenACCParseCanContinue>;
+
   /// Parses the OpenACC directive (the entire pragma) including the clause
   /// list, but does not produce the main AST node.
   OpenACCDirectiveParseInfo ParseOpenACCDirective();
@@ -3657,7 +3660,9 @@ class Parser : public CodeCompletionHandler {
   bool ParseOpenACCGangArgList();
   /// Parses a 'gang-arg', used for the 'gang' clause.
   bool ParseOpenACCGangArg();
+  /// Parses a 'condition' expr, ensuring it is usable as a boolean condition.
+  ExprResult ParseOpenACCConditionExpr();

 private:
   //===--------------------------------------------------------------------===//
   // C++ 14: Templates [temp]
diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h
index 27aaee164a2880..c1fe0f5b9c0f6b 100644
--- a/clang/include/clang/Sema/SemaOpenACC.h
+++ b/clang/include/clang/Sema/SemaOpenACC.h
@@ -40,7 +40,11 @@ class SemaOpenACC : public SemaBase {
       OpenACCDefaultClauseKind DefaultClauseKind;
     };
 
-    std::variant<DefaultDetails> Details;
+    struct ConditionDetails {
+      Expr *ConditionExpr;
+    };
+
+    std::variant<DefaultDetails, ConditionDetails> Details;
 
   public:
     OpenACCParsedClause(OpenACCDirectiveKind DirKind,
@@ -63,6 +67,16 @@ class SemaOpenACC : public SemaBase {
       return std::get<DefaultDetails>(Details).DefaultClauseKind;
     }
 
+    const Expr *getConditionExpr() const {
+      return const_cast<OpenACCParsedClause *>(this)->getConditionExpr();
+    }
+
+    Expr *getConditionExpr() {
+      assert(ClauseKind == OpenACCClauseKind::If &&
+             "Parsed clause kind does not have a condition expr");
+      return std::get<ConditionDetails>(Details).ConditionExpr;
+    }
+
     void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
     void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }
 
@@ -71,6 +85,18 @@ class SemaOpenACC : public SemaBase {
              "Parsed clause is not a default clause");
       Details = DefaultDetails{DefKind};
     }
+
+    void setConditionDetails(Expr *ConditionExpr) {
+      assert(ClauseKind == OpenACCClauseKind::If &&
+             "Parsed clause kind does not have a condition expr");
+      // In C++ we can count on this being a 'bool', but in C this gets left as
+      // some sort of scalar that codegen will have to take care of converting.
+      assert((!ConditionExpr || ConditionExpr->isInstantiationDependent() ||
+              ConditionExpr->getType()->isScalarType()) &&
+             "Condition expression type not scalar/dependent");
+
+      Details = ConditionDetails{ConditionExpr};
+    }
   };
 
   SemaOpenACC(Sema &S);
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
index c83128b60e3acc..0a512d48253a8c 100644
--- a/clang/lib/AST/OpenACCClause.cpp
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/AST/OpenACCClause.h"
+#include "clang/AST/Expr.h"
 #include "clang/AST/ASTContext.h"
 
 using namespace clang;
@@ -27,6 +28,41 @@ OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
   return new (Mem) OpenACCDefaultClause(K, BeginLoc, LParenLoc, EndLoc);
 }
 
+OpenACCIfClause *OpenACCIfClause::Create(const ASTContext &C,
+                                         SourceLocation BeginLoc,
+                                         SourceLocation LParenLoc,
+                                         Expr *ConditionExpr,
+                                         SourceLocation EndLoc) {
+  void *Mem = C.Allocate(sizeof(OpenACCIfClause), alignof(OpenACCIfClause));
+  return new (Mem) OpenACCIfClause(BeginLoc, LParenLoc, ConditionExpr, EndLoc);
+}
+
+OpenACCIfClause::OpenACCIfClause(SourceLocation BeginLoc,
+                                 SourceLocation LParenLoc,
+                                 Expr *ConditionExpr,
+                                 SourceLocation EndLoc)
+    : OpenACCClauseWithCondition(OpenACCClauseKind::If, BeginLoc, LParenLoc,
+                                 ConditionExpr, EndLoc) {
+  assert(ConditionExpr && "if clause requires condition expr");
+  assert((ConditionExpr->isInstantiationDependent() ||
+          ConditionExpr->getType()->isScalarType()) &&
+         "Condition expression type not scalar/dependent");
+}
+
+OpenACCClause::child_range OpenACCClause::children() {
+  switch (getClauseKind()) {
+    default:
+      assert(false && "Clause children function not implemented");
+      break;
+#define VISIT_CLAUSE(CLAUSE_NAME)                                              \
+    case OpenACCClauseKind::CLAUSE_NAME:                                       \
+      return cast<OpenACC##CLAUSE_NAME##Clause>(this)->children();
+
+#include "clang/Basic/OpenACCClauses.def"
+  }
+  return child_range(child_iterator(), child_iterator());
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenACC clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -34,3 +70,10 @@ void OpenACCClausePrinter::VisitOpenACCDefaultClause(
     const OpenACCDefaultClause &C) {
   OS << "default(" << C.getDefaultClauseKind() << ")";
 }
+
+void OpenACCClausePrinter::VisitOpenACCIfClause(const OpenACCIfClause &C) {
+  OS << "if(";
+  // TODO OpenACC: Thread the caller's PrintingPolicy through the printer.
+  C.getConditionExpr()->printPretty(OS, nullptr, PrintingPolicy(LangOptions()));
+  OS << ")";
+}
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 01e1d1cc8289bf..24593fd2f4d405 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2445,9 +2445,10 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
 namespace {
 class OpenACCClauseProfiler
     : public OpenACCClauseVisitor<OpenACCClauseProfiler> {
+  StmtProfiler &Profiler;

 public:
-  OpenACCClauseProfiler() = default;
+  explicit OpenACCClauseProfiler(StmtProfiler &P) : Profiler(P) {}
 
   void VisitOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses) {
     for (const OpenACCClause *Clause : Clauses) {
@@ -2456,12 +2457,24 @@ class OpenACCClauseProfiler
       Visit(Clause);
     }
   }
-  void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
+
+#define VISIT_CLAUSE(CLAUSE_NAME)                                              \
+  void VisitOpenACC##CLAUSE_NAME##Clause(                                      \
+      const OpenACC##CLAUSE_NAME##Clause &Clause);
+
+#include "clang/Basic/OpenACCClauses.def"
 };
 
 /// Nothing to do here, there are no sub-statements.
 void OpenACCClauseProfiler::VisitOpenACCDefaultClause(
     const OpenACCDefaultClause &Clause) {}
+
+void OpenACCClauseProfiler::VisitOpenACCIfClause(
+    const OpenACCIfClause &Clause) {
+  assert(Clause.hasConditionExpr() &&
+         "if clause requires a valid condition expr");
+  Profiler.Visit(Clause.getConditionExpr()); // VisitStmt skips the root's data.
+}
 } // namespace
 
 void StmtProfiler::VisitOpenACCComputeConstruct(
@@ -2469,7 +2482,7 @@ void StmtProfiler::VisitOpenACCComputeConstruct(
   // VisitStmt handles children, so the AssociatedStmt is handled.
   VisitStmt(S);
 
-  OpenACCClauseProfiler P;
+  OpenACCClauseProfiler P{*this};
   P.VisitOpenACCClauseList(S->clauses());
 }
 
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 085a7f51ce99ad..56650f99134d45 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -397,6 +397,11 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
     case OpenACCClauseKind::Default:
       OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
       break;
+    case OpenACCClauseKind::If:
+      // The condition expression will be printed as a part of the 'children',
+      // but print 'clause' here so it is clear what is happening from the dump.
+      OS << " clause";
+      break;
     default:
       // Nothing to do here.
       break;
diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp
index b487a1968d1ec8..6192afa8541cad 100644
--- a/clang/lib/Parse/ParseOpenACC.cpp
+++ b/clang/lib/Parse/ParseOpenACC.cpp
@@ -535,14 +535,6 @@ bool ClauseHasRequiredParens(OpenACCDirectiveKind DirKind,
   return getClauseParensKind(DirKind, Kind) == ClauseParensKind::Required;
 }
 
-ExprResult ParseOpenACCConditionalExpr(Parser &P) {
-  // FIXME: It isn't clear if the spec saying 'condition' means the same as
-  // it does in an if/while/etc (See ParseCXXCondition), however as it was
-  // written with Fortran/C in mind, we're going to assume it just means an
-  // 'expression evaluating to boolean'.
-  return P.getActions().CorrectDelayedTyposInExpr(P.ParseExpression());
-}
-
 // Skip until we see the end of pragma token, but don't consume it. This is us
 // just giving up on the rest of the pragma so we can continue executing. We
 // have to do this because 'SkipUntil' considers paren balancing, which isn't
@@ -595,6 +587,23 @@ Parser::OpenACCClauseParseResult Parser::OpenACCSuccess(OpenACCClause *Clause) {
   return {Clause, OpenACCParseCanContinue::Can};
 }
 
+ExprResult Parser::ParseOpenACCConditionExpr() {
+  // FIXME: It isn't clear if the spec saying 'condition' means the same as
+  // it does in an if/while/etc (See ParseCXXCondition), however as it was
+  // written with Fortran/C in mind, we're going to assume it just means an
+  // 'expression evaluating to boolean'.
+  ExprResult ER = getActions().CorrectDelayedTyposInExpr(ParseExpression());
+
+  if (!ER.isUsable())
+    return ER;
+
+  // Check it as a boolean condition; in C++ this also converts it to 'bool'.
+  Sema::ConditionResult R =
+      getActions().ActOnCondition(getCurScope(), ER.get()->getExprLoc(),
+                                  ER.get(), Sema::ConditionKind::Boolean);
+  return R.isInvalid() ? ExprError() : R.get().second;
+}
+
 // OpenACC 3.3, section 1.7:
 // To simplify the specification and convey appropriate constraint information,
 // a pqr-list is a comma-separated list of pdr items. The one exception is a
@@ -842,12 +851,15 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
       break;
     }
     case OpenACCClauseKind::If: {
-      ExprResult CondExpr = ParseOpenACCConditionalExpr(*this);
+      ExprResult CondExpr = ParseOpenACCConditionExpr();
+      ParsedClause.setConditionDetails(
+          CondExpr.isUsable() ? CondExpr.get() : nullptr);
 
       if (CondExpr.isInvalid()) {
         Parens.skipToEnd();
         return OpenACCCanContinue();
       }
+
       break;
     }
     case OpenACCClauseKind::CopyIn:
@@ -964,7 +976,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
       switch (ClauseKind) {
       case OpenACCClauseKind::Self: {
         assert(DirKind != OpenACCDirectiveKind::Update);
-        ExprResult CondExpr = ParseOpenACCConditionalExpr(*this);
+        ExprResult CondExpr = ParseOpenACCConditionExpr();
 
         if (CondExpr.isInvalid()) {
           Parens.skipToEnd();
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index a6f4453e525d01..8e98f3ae913325 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -55,12 +55,49 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
     default:
       return false;
     }
+  case OpenACCClauseKind::If:
+    switch (DirectiveKind) {
+    case OpenACCDirectiveKind::Parallel:
+    case OpenACCDirectiveKind::Serial:
+    case OpenACCDirectiveKind::Kernels:
+    case OpenACCDirectiveKind::Data:
+    case OpenACCDirectiveKind::EnterData:
+    case OpenACCDirectiveKind::ExitData:
+    case OpenACCDirectiveKind::HostData:
+    case OpenACCDirectiveKind::Init:
+    case OpenACCDirectiveKind::Shutdown:
+    case OpenACCDirectiveKind::Set:
+    case OpenACCDirectiveKind::Update:
+    case OpenACCDirectiveKind::Wait:
+    case OpenACCDirectiveKind::ParallelLoop:
+    case OpenACCDirectiveKind::SerialLoop:
+    case OpenACCDirectiveKind::KernelsLoop:
+      return true;
+    default:
+      return false;
+    }
   default:
     // Do nothing so we can go to the 'unimplemented' diagnostic instead.
     return true;
   }
   llvm_unreachable("Invalid clause kind");
 }
+
+bool checkAlreadyHasClauseOfKind(
+    SemaOpenACC &S, ArrayRef<const OpenACCClause *> ExistingClauses,
+    SemaOpenACC::OpenACCParsedClause &Clause) {
+  auto Itr = llvm::find_if(ExistingClauses, [&](const OpenACCClause *C) {
+    return C->getClauseKind() == Clause.getClauseKind();
+  });
+  if (Itr != ExistingClauses.end()) {
+    S.Diag(Clause.getBeginLoc(), diag::err_acc_duplicate_clause_disallowed)
+        << Clause.getDirectiveKind() << Clause.getClauseKind();
+    S.Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here);
+    return true;
+  }
+  return false;
+}
+
 } // namespace
 
 SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -97,22 +134,38 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClauses,
     // At most one 'default' clause may appear, and it must have a value of
     // either 'none' or 'present'.
     // Second half of the sentence is diagnosed during parsing.
-    auto Itr = llvm::find_if(ExistingClauses, [](const OpenACCClause *C) {
-      return C->getClauseKind() == OpenACCClauseKind::Default;
-    });
-
-    if (Itr != ExistingClauses.end()) {
-      Diag(Clause.getBeginLoc(),
-                   diag::err_acc_duplicate_clause_disallowed)
-          << Clause.getDirectiveKind() << Clause.getClauseKind();
-      Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here);
+    if (checkAlreadyHasClauseOfKind(*this, ExistingClauses, Clause))
       return nullptr;
-    }
 
     return OpenACCDefaultClause::Create(
         getASTContext(), Clause.getDefaultClauseKind(), Clause.getBeginLoc(),
         Clause.getLParenLoc(), Clause.getEndLoc());
   }
+
+  case OpenACCClauseKind::If: {
+    // Restrictions are only properly implemented on 'compute' constructs, and
+    // 'compute' constructs are the only constructs that can do anything with
+    // this clause yet, so skip/treat as unimplemented in this case.
+    if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Parallel &&
+        Clause.getDirectiveKind() != OpenACCDirectiveKind::Serial &&
+        Clause.getDirectiveKind() != OpenACCDirectiveKind::Kernels)
+      break;
+
+    // There is no prose in the standard that says duplicates aren't allowed,
+    // but this diagnostic is present in other compilers, and it also makes
+    // sense.
+    if (checkAlreadyHasClauseOfKind(*this, ExistingClauses, Clause))
+      return nullptr;
+
+    // The parser has ensured that we have a proper condition expr, so there
+    // isn't really much to do here.
+
+    // TODO OpenACC: When we implement 'self', this clause causes us to
+    // 'ignore' the self clause, so we should implement a warning here.
+    return OpenACCIfClause::Create(
+        getASTContext(), Clause.getBeginLoc(), Clause.getLParenLoc(),
+        Clause.getConditionExpr(), Clause.getEndLoc());
+  }
   default:
     break;
   }
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 33a9356e82f409..8e6f95c7a7292e 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -11099,6 +11099,19 @@ OpenACCClause *TreeTransform<Derived>::TransformOpenACCClause(
     ParsedClause.setDefaultDetails(
         cast<OpenACCDefaultClause>(OldClause)->getDefaultClauseKind());
     break;
+  case OpenACCClauseKind::If: {
+    Expr *Cond = const_cast<Expr *>(
+        cast<OpenACCIfClause>(OldClause)->getConditionExpr());
+    Sema::ConditionResult Res =
+        TransformCondition(Cond->getExprLoc(), /*Var=*/nullptr, Cond,
+                           Sema::ConditionKind::Boolean);
+
+    if (Res.isInvalid() || !Res.get().second)
+      return nullptr;
+
+    ParsedClause.setConditionDetails(Res.get().second);
+    break;
+  }
   default:
     assert(false && "Unhandled OpenACC clause in TreeTransform");
     return nullptr;
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 4f6987f92fc82e..e9946a2eea02ff 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11764,6 +11764,12 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
     return OpenACCDefaultClause::Create(getContext(), DCK, BeginLoc, LParenLoc,
                                         EndLoc);
   }
+  case OpenACCClauseKind::If: {
+    SourceLocation LParenLoc = readSourceLocation();
+    Expr *CondExpr = readSubExpr();
+    return OpenACCIfClause::Create(getContext(), BeginLoc, LParenLoc, CondExpr,
+                                   EndLoc);
+  }
   case OpenACCClauseKind::Finalize:
   case OpenACCClauseKind::IfPresent:
   case OpenACCClauseKind::Seq:
@@ -11772,7 +11778,6 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
   case OpenACCClauseKind::Worker:
   case OpenACCClauseKind::Vector:
   case OpenACCClauseKind::NoHost:
-  case OpenACCClauseKind::If:
   case OpenACCClauseKind::Self:
   case OpenACCClauseKind::Copy:
   case OpenACCClauseKind::UseDevice:
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index ffc53292e39124..a5cb7c68122018 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -7425,6 +7425,12 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
     writeEnum(DC->getDefaultClauseKind());
     return;
   }
+  case OpenACCClauseKind::If: {
+    const auto *IC = cast<OpenACCIfClause>(C);
+    writeSourceLocation(IC->getLParenLoc());
+    AddStmt(const_cast<Expr*>(IC->getConditionExpr()));
+    return;
+  }
   case OpenACCClauseKind::Finalize:
   case OpenACCClauseKind::IfPresent:
   case OpenACCClauseKind::Seq:
@@ -7433,7 +7439,6 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
   case OpenACCClauseKind::Worker:
   case OpenACCClauseKind::Vector:
   case OpenACCClauseKind::NoHost:
-  case OpenACCClauseKind::If:
   case OpenACCClauseKind::Self:
   case OpenACCClauseKind::Copy:
   case OpenACCClauseKind::UseDevice:
diff --git a/clang/test/ParserOpenACC/parse-clauses.c b/clang/test/ParserOpenACC/parse-clauses.c
index b363a0cb1362b0..2369df58308a72 100644
--- a/clang/test/ParserOpenACC/parse-clauses.c
+++ b/clang/test/ParserOpenACC/parse-clauses.c
@@ -283,11 +283,9 @@ void IfClause() {
 
   int i, j;
 
-  // expected-warning at +1{{OpenACC clause 'if' not yet implemented, clause ignored}}
 #pragma acc serial if(i > j)
   for(;;){}
 
-  // expected-warning at +2{{OpenACC clause 'if' not yet implemented, clause ignored}}
   // expected-warning at +1{{OpenACC clause 'seq' not yet implemented, clause ignored}}
 #pragma acc serial if(1+5>3), seq
   for(;;){}
diff --git a/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp b/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp
index bd80103445028a..018f0b68c78109 100644
--- a/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp
+++ b/clang/test/SemaOpenACC/compute-construct-clause-ast.cpp
@@ -6,8 +6,10 @@
 
 #ifndef PCH_HELPER
 #define PCH_HELPER
-void NormalFunc() {
+void NormalFunc(int i, float f) {
   // CHECK: FunctionDecl{{.*}}NormalFunc
+  // CHECK-NEXT: ParmVarDecl
+  // CHECK-NEXT: ParmVarDecl
   // CHECK-NEXT: CompoundStmt
 #pragma acc parallel default(none)
   while(true);
@@ -24,6 +26,20 @@ void NormalFunc() {
   // CHECK-NEXT: WhileStmt
   // CHECK-NEXT: CXXBoolLiteralExpr
   // CHECK-NEXT: NullStmt
+
+#pragma acc kernels if( i < f)
+  while(true);
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: BinaryOperator{{.*}} 'bool' '<'
+  // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <IntegralToFloating>
+  // CHECK-NEXT: ImplicitCastExpr{{.*}} 'int' <LValueToRValue>
+  // CHECK-NEXT: DeclRefExpr{{.*}} 'int' lvalue ParmVar{{.*}} 'i' 'int'
+  // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <LValueToRValue>
+  // CHECK-NEXT: DeclRefExpr{{.*}} 'float' lvalue ParmVar{{.*}} 'f' 'float'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
 }
 
 template<typename T>
@@ -51,24 +67,120 @@ void TemplFunc() {
   // CHECK-NEXT: CXXBoolLiteralExpr
   // CHECK-NEXT: NullStmt
 
+#pragma acc parallel if(T::SomeFloat < typename T::IntTy{})
+  while(true);
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: BinaryOperator{{.*}} '<dependent type>' '<'
+  // CHECK-NEXT: DependentScopeDeclRefExpr{{.*}} '<dependent type>' lvalue
+  // CHECK-NEXT: NestedNameSpecifier TypeSpec 'T'
+  // CHECK-NEXT: CXXUnresolvedConstructExpr{{.*}} 'typename T::IntTy' 'typename T::IntTy'
+  // CHECK-NEXT: InitListExpr{{.*}} 'void'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
+
+#pragma acc serial if(typename T::IntTy{})
+  while(true);
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}serial
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: CXXUnresolvedConstructExpr{{.*}} 'typename T::IntTy' 'typename T::IntTy'
+  // CHECK-NEXT: InitListExpr{{.*}} 'void'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
+
+#pragma acc kernels if(T::SomeFloat)
+  while(true);
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: DependentScopeDeclRefExpr{{.*}} '<dependent type>' lvalue
+  // CHECK-NEXT: NestedNameSpecifier TypeSpec 'T'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
+
+#pragma acc parallel if(T::BC)
+  while(true);
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: DependentScopeDeclRefExpr{{.*}} '<dependent type>' lvalue
+  // CHECK-NEXT: NestedNameSpecifier TypeSpec 'T'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
+
   // Match the instantiation:
   // CHECK: FunctionDecl{{.*}}TemplFunc{{.*}}implicit_instantiation
-  // CHECK-NEXT: TemplateArgument type 'int'
-  // CHECK-NEXT: BuiltinType
+  // CHECK-NEXT: TemplateArgument type 'InstTy'
+  // CHECK-NEXT: RecordType{{.*}} 'InstTy'
+  // CHECK-NEXT: CXXRecord{{.*}} 'InstTy'
   // CHECK-NEXT: CompoundStmt
+
   // CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
   // CHECK-NEXT: default(none)
   // CHECK-NEXT: WhileStmt
   // CHECK-NEXT: CXXBoolLiteralExpr
   // CHECK-NEXT: NullStmt
+
   // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
   // CHECK-NEXT: default(present)
   // CHECK-NEXT: WhileStmt
   // CHECK-NEXT: CXXBoolLiteralExpr
   // CHECK-NEXT: NullStmt
+
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: BinaryOperator{{.*}} 'bool' '<'
+  // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <LValueToRValue>
+  // CHECK-NEXT: DeclRefExpr{{.*}} 'const float' lvalue Var{{.*}} 'SomeFloat' 'const float'
+  // CHECK-NEXT: NestedNameSpecifier TypeSpec 'InstTy'
+  // CHECK-NEXT: ImplicitCastExpr{{.*}} 'float' <IntegralToFloating>
+  // CHECK-NEXT: CXXFunctionalCastExpr{{.*}}'typename InstTy::IntTy':'int' functional cast to typename struct InstTy::IntTy <NoOp>
+  // CHECK-NEXT: InitListExpr {{.*}}'typename InstTy::IntTy':'int'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
+
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}serial
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: ImplicitCastExpr{{.*}}'bool' <IntegralToBoolean>
+  // CHECK-NEXT: CXXFunctionalCastExpr{{.*}}'typename InstTy::IntTy':'int' functional cast to typename struct InstTy::IntTy <NoOp>
+  // CHECK-NEXT: InitListExpr {{.*}}'typename InstTy::IntTy':'int'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
+
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}kernels
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: ImplicitCastExpr{{.*}}'bool' <FloatingToBoolean>
+  // CHECK-NEXT: ImplicitCastExpr{{.*}}'float' <LValueToRValue>
+  // CHECK-NEXT: DeclRefExpr{{.*}} 'const float' lvalue Var{{.*}} 'SomeFloat' 'const float'
+  // CHECK-NEXT: NestedNameSpecifier TypeSpec 'InstTy'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
+
+  // CHECK-NEXT: OpenACCComputeConstruct{{.*}}parallel
+  // CHECK-NEXT: if clause
+  // CHECK-NEXT: ImplicitCastExpr{{.*}} 'bool' <UserDefinedConversion>
+  // CHECK-NEXT: CXXMemberCallExpr{{.*}} 'bool'
+  // CHECK-NEXT: MemberExpr{{.*}} .operator bool
+  // CHECK-NEXT: DeclRefExpr{{.*}} 'const BoolConversion' lvalue Var{{.*}} 'BC' 'const BoolConversion'
+  // CHECK-NEXT: NestedNameSpecifier TypeSpec 'InstTy'
+  // CHECK-NEXT: WhileStmt
+  // CHECK-NEXT: CXXBoolLiteralExpr
+  // CHECK-NEXT: NullStmt
 }
 
+struct BoolConversion{ operator bool() const;};
+struct InstTy {
+  using IntTy = int;
+  static constexpr float SomeFloat = 5.0;
+  static constexpr BoolConversion BC;
+};
+
 void Instantiate() {
-  TemplFunc<int>();
+  TemplFunc<InstTy>();
 }
 #endif
diff --git a/clang/test/SemaOpenACC/compute-construct-if-clause.c b/clang/test/SemaOpenACC/compute-construct-if-clause.c
new file mode 100644
index 00000000000000..767b8414b3a68a
--- /dev/null
+++ b/clang/test/SemaOpenACC/compute-construct-if-clause.c
@@ -0,0 +1,62 @@
+// RUN: %clang_cc1 %s -fopenacc -verify
+
+void BoolExpr(int *I, float *F) {
+
+  typedef struct {} SomeStruct;
+  int Array[5];
+
+  struct C{};
+  // expected-error@+1{{expected expression}}
+#pragma acc parallel if (struct C f())
+  while(0);
+
+  // expected-error@+1{{unexpected type name 'SomeStruct': expected expression}}
+#pragma acc serial if (SomeStruct)
+  while(0);
+
+  // expected-error@+1{{unexpected type name 'SomeStruct': expected expression}}
+#pragma acc serial if (SomeStruct())
+  while(0);
+
+  SomeStruct S;
+  // expected-error@+1{{statement requires expression of scalar type ('SomeStruct' invalid)}}
+#pragma acc serial if (S)
+  while(0);
+
+  // expected-warning@+1{{address of array 'Array' will always evaluate to 'true'}}
+#pragma acc kernels if (Array)
+  while(0);
+
+  // expected-warning@+4{{incompatible pointer types assigning to 'int *' from 'float *'}}
+  // expected-warning@+3{{using the result of an assignment as a condition without parentheses}}
+  // expected-note@+2{{place parentheses around the assignment to silence this warning}}
+  // expected-note@+1{{use '==' to turn this assignment into an equality comparison}}
+#pragma acc kernels if (I = F)
+  while(0);
+
+#pragma acc parallel if (I)
+  while(0);
+
+#pragma acc serial if (F)
+  while(0);
+
+#pragma acc kernels if (*I < *F)
+  while(0);
+
+  // expected-warning@+2{{OpenACC construct 'data' not yet implemented}}
+  // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc data if (*I < *F)
+  while(0);
+  // expected-warning@+2{{OpenACC construct 'parallel loop' not yet implemented}}
+  // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc parallel loop if (*I < *F)
+  while(0);
+  // expected-warning@+2{{OpenACC construct 'serial loop' not yet implemented}}
+  // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc serial loop if (*I < *F)
+  while(0);
+  // expected-warning@+2{{OpenACC construct 'kernels loop' not yet implemented}}
+  // expected-warning@+1{{OpenACC clause 'if' not yet implemented}}
+#pragma acc kernels loop if (*I < *F)
+  while(0);
+}
diff --git a/clang/test/SemaOpenACC/compute-construct-if-clause.cpp b/clang/test/SemaOpenACC/compute-construct-if-clause.cpp
new file mode 100644
index 00000000000000..2a9bb8a3d5c4ea
--- /dev/null
+++ b/clang/test/SemaOpenACC/compute-construct-if-clause.cpp
@@ -0,0 +1,33 @@
+// RUN: %clang_cc1 %s -fopenacc -verify
+
+struct NoBoolConversion{};
+struct BoolConversion{
+  operator bool();
+};
+
+template <typename T, typename U>
+void BoolExpr() {
+
+  // expected-error@+1{{value of type 'NoBoolConversion' is not contextually convertible to 'bool'}}
+#pragma acc parallel if (NoBoolConversion{})
+  while(0);
+
+  // expected-error@+2{{no member named 'NotValid' in 'NoBoolConversion'}}
+  // expected-note@#INST{{in instantiation of function template specialization}}
+#pragma acc parallel if (T::NotValid)
+  while(0);
+
+#pragma acc parallel if (BoolConversion{})
+  while(0);
+
+  // expected-error@+1{{value of type 'NoBoolConversion' is not contextually convertible to 'bool'}}
+#pragma acc parallel if (T{})
+  while(0);
+
+#pragma acc parallel if (U{})
+  while(0);
+}
+
+void Instantiate() {
+  BoolExpr<NoBoolConversion, BoolConversion>(); // #INST
+}



More information about the cfe-commits mailing list