[clang] [OpenACC][NFC] Add OpenACC Clause AST Nodes/infrastructure (PR #87675)

Erich Keane via cfe-commits cfe-commits at lists.llvm.org
Fri Apr 5 07:17:07 PDT 2024


https://github.com/erichkeane updated https://github.com/llvm/llvm-project/pull/87675

>From 2e05458175478002a14c9316a7fde66f7301dd94 Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Thu, 4 Apr 2024 10:59:34 -0700
Subject: [PATCH 1/2] [OpenACC][NFC] Add OpenACC Clause AST
 Nodes/infrastructure

As a first step in adding clause support for OpenACC to Semantic
Analysis, this patch adds the 'base' AST nodes required for clauses.

This patch has no functional effect at the moment, but followup patches
will add the semantic analysis of clauses (plus individual clauses).
---
 clang/include/clang/AST/ASTNodeTraverser.h    |  13 ++
 clang/include/clang/AST/JSONNodeDumper.h      |   1 +
 clang/include/clang/AST/OpenACCClause.h       | 135 ++++++++++++++++++
 clang/include/clang/AST/RecursiveASTVisitor.h |  13 +-
 clang/include/clang/AST/StmtOpenACC.h         |  53 +++++--
 clang/include/clang/AST/TextNodeDumper.h      |   2 +
 .../clang/Serialization/ASTRecordReader.h     |   7 +
 .../clang/Serialization/ASTRecordWriter.h     |   7 +
 clang/lib/AST/CMakeLists.txt                  |   1 +
 clang/lib/AST/JSONNodeDumper.cpp              |   2 +
 clang/lib/AST/OpenACCClause.cpp               |  17 +++
 clang/lib/AST/StmtOpenACC.cpp                 |  19 +--
 clang/lib/AST/StmtPrinter.cpp                 |   8 +-
 clang/lib/AST/StmtProfile.cpp                 |  21 ++-
 clang/lib/AST/TextNodeDumper.cpp              |  23 ++-
 clang/lib/Sema/SemaOpenACC.cpp                |   3 +-
 clang/lib/Serialization/ASTReader.cpp         |  66 +++++++++
 clang/lib/Serialization/ASTReaderStmt.cpp     |  10 +-
 clang/lib/Serialization/ASTWriter.cpp         |  62 ++++++++
 clang/lib/Serialization/ASTWriterStmt.cpp     |   3 +-
 20 files changed, 438 insertions(+), 28 deletions(-)
 create mode 100644 clang/include/clang/AST/OpenACCClause.h
 create mode 100644 clang/lib/AST/OpenACCClause.cpp

diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index 06d67e9cba9536..94e7dd817809dd 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -53,6 +53,7 @@ struct {
   void Visit(TypeLoc);
   void Visit(const Decl *D);
   void Visit(const CXXCtorInitializer *Init);
+  void Visit(const OpenACCClause *C);
   void Visit(const OMPClause *C);
   void Visit(const BlockDecl::Capture &C);
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
@@ -239,6 +240,13 @@ class ASTNodeTraverser
     });
   }
 
+  void Visit(const OpenACCClause *C) {
+    getNodeDelegate().AddChild([=] {
+      getNodeDelegate().Visit(C);
+      // TODO OpenACC: Switch on clauses that have children, and add them.
+    });
+  }
+
   void Visit(const OMPClause *C) {
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(C);
@@ -799,6 +807,11 @@ class ASTNodeTraverser
       Visit(C);
   }
 
+  void VisitOpenACCConstructStmt(const OpenACCConstructStmt *Node) {
+    for (const auto *C : Node->clauses())
+      Visit(C);
+  }
+
   void VisitInitListExpr(const InitListExpr *ILE) {
     if (auto *Filler = ILE->getArrayFiller()) {
       Visit(Filler, "array_filler");
diff --git a/clang/include/clang/AST/JSONNodeDumper.h b/clang/include/clang/AST/JSONNodeDumper.h
index dde70dde2fa2be..7a60f362650ca0 100644
--- a/clang/include/clang/AST/JSONNodeDumper.h
+++ b/clang/include/clang/AST/JSONNodeDumper.h
@@ -203,6 +203,7 @@ class JSONNodeDumper
   void Visit(const TemplateArgument &TA, SourceRange R = {},
              const Decl *From = nullptr, StringRef Label = {});
   void Visit(const CXXCtorInitializer *Init);
+  void Visit(const OpenACCClause *C);
   void Visit(const OMPClause *C);
   void Visit(const BlockDecl::Capture &C);
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h
new file mode 100644
index 00000000000000..06a0098bbda4cd
--- /dev/null
+++ b/clang/include/clang/AST/OpenACCClause.h
@@ -0,0 +1,135 @@
+//===- OpenACCClause.h - Classes for OpenACC clauses ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This file defines OpenACC AST classes for clauses.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_OPENACCCLAUSE_H
+#define LLVM_CLANG_AST_OPENACCCLAUSE_H
+#include "clang/AST/ASTContext.h"
+#include "clang/Basic/OpenACCKinds.h"
+
+namespace clang {
+/// This is the base type for all OpenACC Clauses.
+class OpenACCClause {
+  OpenACCClauseKind Kind;
+  SourceRange Location;
+
+protected:
+  OpenACCClause(OpenACCClauseKind K, SourceLocation BeginLoc,
+                SourceLocation EndLoc)
+      : Kind(K), Location(BeginLoc, EndLoc) {}
+
+public:
+  OpenACCClauseKind getClauseKind() const { return Kind; }
+  SourceLocation getBeginLoc() const { return Location.getBegin(); }
+  SourceLocation getEndLoc() const { return Location.getEnd(); }
+
+  static bool classof(const OpenACCClause *) { return true; }
+
+  virtual ~OpenACCClause() = default;
+};
+
+/// Represents a clause that has a list of parameters.
+class OpenACCClauseWithParams : public OpenACCClause {
+  /// Location of the '('.
+  SourceLocation LParenLoc;
+
+protected:
+  OpenACCClauseWithParams(OpenACCClauseKind K, SourceLocation BeginLoc,
+                          SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OpenACCClause(K, BeginLoc, EndLoc), LParenLoc(LParenLoc) {}
+
+public:
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+};
+
+template <class Impl> class OpenACCClauseVisitor {
+  Impl &getDerived() { return static_cast<Impl &>(*this); }
+
+public:
+  void VisitClauseList(ArrayRef<const OpenACCClause *> List) {
+    for (const OpenACCClause *Clause : List)
+      Visit(Clause);
+  }
+
+  void Visit(const OpenACCClause *C) {
+    if (!C)
+      return;
+
+    switch (C->getClauseKind()) {
+    case OpenACCClauseKind::Default:
+    case OpenACCClauseKind::Finalize:
+    case OpenACCClauseKind::IfPresent:
+    case OpenACCClauseKind::Seq:
+    case OpenACCClauseKind::Independent:
+    case OpenACCClauseKind::Auto:
+    case OpenACCClauseKind::Worker:
+    case OpenACCClauseKind::Vector:
+    case OpenACCClauseKind::NoHost:
+    case OpenACCClauseKind::If:
+    case OpenACCClauseKind::Self:
+    case OpenACCClauseKind::Copy:
+    case OpenACCClauseKind::UseDevice:
+    case OpenACCClauseKind::Attach:
+    case OpenACCClauseKind::Delete:
+    case OpenACCClauseKind::Detach:
+    case OpenACCClauseKind::Device:
+    case OpenACCClauseKind::DevicePtr:
+    case OpenACCClauseKind::DeviceResident:
+    case OpenACCClauseKind::FirstPrivate:
+    case OpenACCClauseKind::Host:
+    case OpenACCClauseKind::Link:
+    case OpenACCClauseKind::NoCreate:
+    case OpenACCClauseKind::Present:
+    case OpenACCClauseKind::Private:
+    case OpenACCClauseKind::CopyOut:
+    case OpenACCClauseKind::CopyIn:
+    case OpenACCClauseKind::Create:
+    case OpenACCClauseKind::Reduction:
+    case OpenACCClauseKind::Collapse:
+    case OpenACCClauseKind::Bind:
+    case OpenACCClauseKind::VectorLength:
+    case OpenACCClauseKind::NumGangs:
+    case OpenACCClauseKind::NumWorkers:
+    case OpenACCClauseKind::DeviceNum:
+    case OpenACCClauseKind::DefaultAsync:
+    case OpenACCClauseKind::DeviceType:
+    case OpenACCClauseKind::DType:
+    case OpenACCClauseKind::Async:
+    case OpenACCClauseKind::Tile:
+    case OpenACCClauseKind::Gang:
+    case OpenACCClauseKind::Wait:
+    case OpenACCClauseKind::Invalid:
+      llvm_unreachable("Clause visitor not yet implemented");
+    }
+    llvm_unreachable("Invalid Clause kind");
+  }
+};
+
+class OpenACCClausePrinter final
+    : public OpenACCClauseVisitor<OpenACCClausePrinter> {
+  raw_ostream &OS;
+
+public:
+  void VisitClauseList(ArrayRef<const OpenACCClause *> List) {
+    for (const OpenACCClause *Clause : List) {
+      Visit(Clause);
+
+      if (Clause != List.back())
+        OS << ' ';
+    }
+  }
+  OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}
+};
+
+} // namespace clang
+
+#endif // LLVM_CLANG_AST_OPENACCCLAUSE_H
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8630317795a9ad..7eb92e304a3856 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -509,6 +509,7 @@ template <typename Derived> class RecursiveASTVisitor {
   bool TraverseOpenACCConstructStmt(OpenACCConstructStmt *S);
   bool
   TraverseOpenACCAssociatedStmtConstruct(OpenACCAssociatedStmtConstruct *S);
+  bool VisitOpenACCClauseList(ArrayRef<const OpenACCClause *>);
 };
 
 template <typename Derived>
@@ -3936,8 +3937,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPXBareClause(OMPXBareClause *C) {
 
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseOpenACCConstructStmt(
-    OpenACCConstructStmt *) {
-  // TODO OpenACC: When we implement clauses, ensure we traverse them here.
+    OpenACCConstructStmt *C) {
+  TRY_TO(VisitOpenACCClauseList(C->clauses()));
   return true;
 }
 
@@ -3949,6 +3950,14 @@ bool RecursiveASTVisitor<Derived>::TraverseOpenACCAssociatedStmtConstruct(
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOpenACCClauseList(
+    ArrayRef<const OpenACCClause *>) {
+  // TODO OpenACC: When we have Clauses with expressions, we should visit them
+  // here.
+  return true;
+}
+
 DEF_TRAVERSE_STMT(OpenACCComputeConstruct,
                   { TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })
 
diff --git a/clang/include/clang/AST/StmtOpenACC.h b/clang/include/clang/AST/StmtOpenACC.h
index 19da66832c7374..504330b50af916 100644
--- a/clang/include/clang/AST/StmtOpenACC.h
+++ b/clang/include/clang/AST/StmtOpenACC.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_AST_STMTOPENACC_H
 #define LLVM_CLANG_AST_STMTOPENACC_H
 
+#include "clang/AST/OpenACCClause.h"
 #include "clang/AST/Stmt.h"
 #include "clang/Basic/OpenACCKinds.h"
 #include "clang/Basic/SourceLocation.h"
@@ -30,13 +31,23 @@ class OpenACCConstructStmt : public Stmt {
   /// the directive.
   SourceRange Range;
 
-  // TODO OPENACC: Clauses should probably be collected in this class.
+  /// The list of clauses.  This is stored here as an ArrayRef, as this is the
+  /// most convenient place to access the list, however the list itself should
+  /// be stored in leaf nodes, likely in trailing-storage.
+  MutableArrayRef<const OpenACCClause *> Clauses;
 
 protected:
   OpenACCConstructStmt(StmtClass SC, OpenACCDirectiveKind K,
                        SourceLocation Start, SourceLocation End)
       : Stmt(SC), Kind(K), Range(Start, End) {}
 
+  // Used only for initialization; the leaf class can point this at its
+  // trailing storage.
+  void setClauseList(MutableArrayRef<const OpenACCClause *> NewClauses) {
+    assert(Clauses.empty() && "Cannot change clause list");
+    Clauses = NewClauses;
+  }
+
 public:
   OpenACCDirectiveKind getDirectiveKind() const { return Kind; }
 
@@ -47,6 +58,7 @@ class OpenACCConstructStmt : public Stmt {
 
   SourceLocation getBeginLoc() const { return Range.getBegin(); }
   SourceLocation getEndLoc() const { return Range.getEnd(); }
+  const ArrayRef<const OpenACCClause *> clauses() const { return Clauses; }
 
   child_range children() {
     return child_range(child_iterator(), child_iterator());
@@ -101,17 +113,31 @@ class OpenACCAssociatedStmtConstruct : public OpenACCConstructStmt {
 /// those three, as they are semantically identical, and have only minor
 /// differences in the permitted list of clauses, which can be differentiated by
 /// the 'Kind'.
-class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
+class OpenACCComputeConstruct final
+    : public OpenACCAssociatedStmtConstruct,
+      public llvm::TrailingObjects<OpenACCComputeConstruct,
+                                   const OpenACCClause *> {
   friend class ASTStmtWriter;
   friend class ASTStmtReader;
   friend class ASTContext;
-  OpenACCComputeConstruct()
-      : OpenACCAssociatedStmtConstruct(
-            OpenACCComputeConstructClass, OpenACCDirectiveKind::Invalid,
-            SourceLocation{}, SourceLocation{}, /*AssociatedStmt=*/nullptr) {}
+  OpenACCComputeConstruct(unsigned NumClauses)
+      : OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass,
+                                       OpenACCDirectiveKind::Invalid,
+                                       SourceLocation{}, SourceLocation{},
+                                       /*AssociatedStmt=*/nullptr) {
+    // We cannot send the TrailingObjects storage to the base class (which holds
+    // a reference to the data) until it is constructed, so we have to set it
+    // separately here.
+    memset(getTrailingObjects<const OpenACCClause *>(), 0,
+           NumClauses * sizeof(const OpenACCClause *));
+    setClauseList(MutableArrayRef(getTrailingObjects<const OpenACCClause *>(),
+                                  NumClauses));
+  }
 
   OpenACCComputeConstruct(OpenACCDirectiveKind K, SourceLocation Start,
-                          SourceLocation End, Stmt *StructuredBlock)
+                          SourceLocation End,
+                          ArrayRef<const OpenACCClause *> Clauses,
+                          Stmt *StructuredBlock)
       : OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass, K, Start,
                                        End, StructuredBlock) {
     assert((K == OpenACCDirectiveKind::Parallel ||
@@ -119,6 +145,13 @@ class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
             K == OpenACCDirectiveKind::Kernels) &&
            "Only parallel, serial, and kernels constructs should be "
            "represented by this type");
+
+    // Initialize the trailing storage.
+    for (unsigned I = 0; I < Clauses.size(); ++I)
+      *(getTrailingObjects<const OpenACCClause *>() + I) = Clauses[I];
+
+    setClauseList(MutableArrayRef(getTrailingObjects<const OpenACCClause *>(),
+                                  Clauses.size()));
   }
 
   void setStructuredBlock(Stmt *S) { setAssociatedStmt(S); }
@@ -128,10 +161,12 @@ class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
     return T->getStmtClass() == OpenACCComputeConstructClass;
   }
 
-  static OpenACCComputeConstruct *CreateEmpty(const ASTContext &C, EmptyShell);
+  static OpenACCComputeConstruct *CreateEmpty(const ASTContext &C,
+                                              unsigned NumClauses);
   static OpenACCComputeConstruct *
   Create(const ASTContext &C, OpenACCDirectiveKind K, SourceLocation BeginLoc,
-         SourceLocation EndLoc, Stmt *StructuredBlock);
+         SourceLocation EndLoc, ArrayRef<const OpenACCClause *> Clauses,
+         Stmt *StructuredBlock);
 
   Stmt *getStructuredBlock() { return getAssociatedStmt(); }
   const Stmt *getStructuredBlock() const {
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
index efb5bfe7f83d40..1fede6e462e925 100644
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -189,6 +189,8 @@ class TextNodeDumper
 
   void Visit(const OMPClause *C);
 
+  void Visit(const OpenACCClause *C);
+
   void Visit(const BlockDecl::Capture &C);
 
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
diff --git a/clang/include/clang/Serialization/ASTRecordReader.h b/clang/include/clang/Serialization/ASTRecordReader.h
index 5d3e95cb5d630f..7dd1140106e47c 100644
--- a/clang/include/clang/Serialization/ASTRecordReader.h
+++ b/clang/include/clang/Serialization/ASTRecordReader.h
@@ -24,6 +24,7 @@
 #include "llvm/ADT/APSInt.h"
 
 namespace clang {
+class OpenACCClause;
 class OMPTraitInfo;
 class OMPChildren;
 
@@ -278,6 +279,12 @@ class ASTRecordReader
   /// Read an OpenMP children, advancing Idx.
   void readOMPChildren(OMPChildren *Data);
 
+  /// Read an OpenACC clause, advancing Idx.
+  OpenACCClause *readOpenACCClause();
+
+  /// Read a list of OpenACC clauses into the passed MutableArrayRef.
+  void readOpenACCClauseList(MutableArrayRef<const OpenACCClause *> Clauses);
+
   /// Read a source location, advancing Idx.
   SourceLocation readSourceLocation(LocSeq *Seq = nullptr) {
     return Reader->ReadSourceLocation(*F, Record, Idx, Seq);
diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h
index e007d4a70843a1..1feb8fcbacf772 100644
--- a/clang/include/clang/Serialization/ASTRecordWriter.h
+++ b/clang/include/clang/Serialization/ASTRecordWriter.h
@@ -21,6 +21,7 @@
 
 namespace clang {
 
+class OpenACCClause;
 class TypeLoc;
 
 /// An object for streaming information to a record.
@@ -292,6 +293,12 @@ class ASTRecordWriter
   /// Writes data related to the OpenMP directives.
   void writeOMPChildren(OMPChildren *Data);
 
+  /// Writes out a single OpenACC Clause.
+  void writeOpenACCClause(const OpenACCClause *C);
+
+  /// Writes out a list of OpenACC clauses.
+  void writeOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses);
+
   /// Emit a string.
   void AddString(StringRef Str) {
     return Writer->AddString(Str, *Record);
diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
index 3fba052d916c9e..3faefb54f599fb 100644
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -98,6 +98,7 @@ add_clang_library(clangAST
   NSAPI.cpp
   ODRDiagsEmitter.cpp
   ODRHash.cpp
+  OpenACCClause.cpp
   OpenMPClause.cpp
   OSLog.cpp
   ParentMap.cpp
diff --git a/clang/lib/AST/JSONNodeDumper.cpp b/clang/lib/AST/JSONNodeDumper.cpp
index 5861d5a7ea0dd2..fb3494393f7559 100644
--- a/clang/lib/AST/JSONNodeDumper.cpp
+++ b/clang/lib/AST/JSONNodeDumper.cpp
@@ -187,6 +187,8 @@ void JSONNodeDumper::Visit(const CXXCtorInitializer *Init) {
     llvm_unreachable("Unknown initializer type");
 }
 
+void JSONNodeDumper::Visit(const OpenACCClause *C) {}
+
 void JSONNodeDumper::Visit(const OMPClause *C) {}
 
 void JSONNodeDumper::Visit(const BlockDecl::Capture &C) {
diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp
new file mode 100644
index 00000000000000..e1db872f25c322
--- /dev/null
+++ b/clang/lib/AST/OpenACCClause.cpp
@@ -0,0 +1,17 @@
+//===---- OpenACCClause.cpp - Classes for OpenACC Clauses  ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the subclasses of the OpenACCClause class declared in
+// OpenACCClause.h
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/OpenACCClause.h"
+#include "clang/AST/ASTContext.h"
+
+using namespace clang;
diff --git a/clang/lib/AST/StmtOpenACC.cpp b/clang/lib/AST/StmtOpenACC.cpp
index e6191bc6db7080..a381a8dd7b62c3 100644
--- a/clang/lib/AST/StmtOpenACC.cpp
+++ b/clang/lib/AST/StmtOpenACC.cpp
@@ -15,20 +15,23 @@
 using namespace clang;
 
 OpenACCComputeConstruct *
-OpenACCComputeConstruct::CreateEmpty(const ASTContext &C, EmptyShell) {
-  void *Mem = C.Allocate(sizeof(OpenACCComputeConstruct),
-                         alignof(OpenACCComputeConstruct));
-  auto *Inst = new (Mem) OpenACCComputeConstruct;
+OpenACCComputeConstruct::CreateEmpty(const ASTContext &C, unsigned NumClauses) {
+  void *Mem = C.Allocate(
+      OpenACCComputeConstruct::totalSizeToAlloc<const OpenACCClause *>(
+          NumClauses));
+  auto *Inst = new (Mem) OpenACCComputeConstruct(NumClauses);
   return Inst;
 }
 
 OpenACCComputeConstruct *
 OpenACCComputeConstruct::Create(const ASTContext &C, OpenACCDirectiveKind K,
                                 SourceLocation BeginLoc, SourceLocation EndLoc,
+                                ArrayRef<const OpenACCClause *> Clauses,
                                 Stmt *StructuredBlock) {
-  void *Mem = C.Allocate(sizeof(OpenACCComputeConstruct),
-                         alignof(OpenACCComputeConstruct));
-  auto *Inst =
-      new (Mem) OpenACCComputeConstruct(K, BeginLoc, EndLoc, StructuredBlock);
+  void *Mem = C.Allocate(
+      OpenACCComputeConstruct::totalSizeToAlloc<const OpenACCClause *>(
+          Clauses.size()));
+  auto *Inst = new (Mem)
+      OpenACCComputeConstruct(K, BeginLoc, EndLoc, Clauses, StructuredBlock);
   return Inst;
 }
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index d66c3ccce2094c..74b18e50bf1f40 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -1142,7 +1142,13 @@ void StmtPrinter::VisitOMPTargetParallelGenericLoopDirective(
 //===----------------------------------------------------------------------===//
 void StmtPrinter::VisitOpenACCComputeConstruct(OpenACCComputeConstruct *S) {
   Indent() << "#pragma acc " << S->getDirectiveKind();
-  // TODO OpenACC: Print Clauses.
+
+  if (!S->clauses().empty()) {
+    OS << ' ';
+    OpenACCClausePrinter Printer(OS);
+    Printer.VisitClauseList(S->clauses());
+  }
+
   PrintStmt(S->getStructuredBlock());
 }
 
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index b545ff472e5a2b..d68547f444c52f 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2441,11 +2441,30 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
   }
 }
 
+namespace {
+class OpenACCClauseProfiler
+    : public OpenACCClauseVisitor<OpenACCClauseProfiler> {
+
+public:
+  OpenACCClauseProfiler() = default;
+
+  void VisitOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses) {
+    for (const OpenACCClause *Clause : Clauses) {
+      // TODO OpenACC: When we have clauses with expressions, we should
+      // profile them too.
+      Visit(Clause);
+    }
+  }
+};
+} // namespace
+
 void StmtProfiler::VisitOpenACCComputeConstruct(
     const OpenACCComputeConstruct *S) {
   // VisitStmt handles children, so the AssociatedStmt is handled.
   VisitStmt(S);
-  // TODO OpenACC: Visit Clauses.
+
+  OpenACCClauseProfiler P;
+  P.VisitOpenACCClauseList(S->clauses());
 }
 
 void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 413e452146bdb2..0ffbf47c9a2f4e 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -381,6 +381,28 @@ void TextNodeDumper::Visit(const OMPClause *C) {
     OS << " <implicit>";
 }
 
+void TextNodeDumper::Visit(const OpenACCClause *C) {
+  if (!C) {
+    ColorScope Color(OS, ShowColors, NullColor);
+    OS << "<<<NULL>>> OpenACCClause";
+    return;
+  }
+  {
+    ColorScope Color(OS, ShowColors, AttrColor);
+    OS << C->getClauseKind();
+
+    // Handle clauses with parens for types that have no children, likely
+    // because there is no sub expression.
+    switch (C->getClauseKind()) {
+    default:
+      // Nothing to do here.
+      break;
+    }
+  }
+  dumpPointer(C);
+  dumpSourceRange(SourceRange(C->getBeginLoc(), C->getEndLoc()));
+}
+
 void TextNodeDumper::Visit(const GenericSelectionExpr::ConstAssociation &A) {
   const TypeSourceInfo *TSI = A.getTypeSourceInfo();
   if (TSI) {
@@ -2684,5 +2706,4 @@ void TextNodeDumper::VisitHLSLBufferDecl(const HLSLBufferDecl *D) {
 
 void TextNodeDumper::VisitOpenACCConstructStmt(const OpenACCConstructStmt *S) {
   OS << " " << S->getDirectiveKind();
-  // TODO OpenACC: Dump clauses as well.
 }
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index 2ac994cac71e19..7d4b84f0eae9d2 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -94,9 +94,10 @@ StmtResult SemaOpenACC::ActOnEndStmtDirective(OpenACCDirectiveKind K,
   case OpenACCDirectiveKind::Parallel:
   case OpenACCDirectiveKind::Serial:
   case OpenACCDirectiveKind::Kernels:
+    // TODO OpenACC: Add clauses to the construct here.
     return OpenACCComputeConstruct::Create(
         getASTContext(), K, StartLoc, EndLoc,
-        AssocStmt.isUsable() ? AssocStmt.get() : nullptr);
+        /*Clauses=*/{}, AssocStmt.isUsable() ? AssocStmt.get() : nullptr);
   }
   llvm_unreachable("Unhandled case in directive handling?");
 }
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 9a39e7d3826e7d..800043bfe456bb 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -32,6 +32,7 @@
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/ODRDiagsEmitter.h"
 #include "clang/AST/ODRHash.h"
+#include "clang/AST/OpenACCClause.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/RawCommentList.h"
 #include "clang/AST/TemplateBase.h"
@@ -53,6 +54,7 @@
 #include "clang/Basic/LangOptions.h"
 #include "clang/Basic/Module.h"
 #include "clang/Basic/ObjCRuntime.h"
+#include "clang/Basic/OpenACCKinds.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/OperatorKinds.h"
 #include "clang/Basic/PragmaKinds.h"
@@ -11751,3 +11753,67 @@ void ASTRecordReader::readOMPChildren(OMPChildren *Data) {
   for (unsigned I = 0, E = Data->getNumChildren(); I < E; ++I)
     Data->getChildren()[I] = readStmt();
 }
+
+OpenACCClause *ASTRecordReader::readOpenACCClause() {
+  OpenACCClauseKind ClauseKind = readEnum<OpenACCClauseKind>();
+  SourceLocation BeginLoc = readSourceLocation();
+  SourceLocation EndLoc = readSourceLocation();
+
+  // TODO OpenACC: We don't have these used anywhere, but eventually we should
+  // be constructing the Clauses with them, so these casts can go away.
+  (void)BeginLoc;
+  (void)EndLoc;
+  switch (ClauseKind) {
+  case OpenACCClauseKind::Default:
+  case OpenACCClauseKind::Finalize:
+  case OpenACCClauseKind::IfPresent:
+  case OpenACCClauseKind::Seq:
+  case OpenACCClauseKind::Independent:
+  case OpenACCClauseKind::Auto:
+  case OpenACCClauseKind::Worker:
+  case OpenACCClauseKind::Vector:
+  case OpenACCClauseKind::NoHost:
+  case OpenACCClauseKind::If:
+  case OpenACCClauseKind::Self:
+  case OpenACCClauseKind::Copy:
+  case OpenACCClauseKind::UseDevice:
+  case OpenACCClauseKind::Attach:
+  case OpenACCClauseKind::Delete:
+  case OpenACCClauseKind::Detach:
+  case OpenACCClauseKind::Device:
+  case OpenACCClauseKind::DevicePtr:
+  case OpenACCClauseKind::DeviceResident:
+  case OpenACCClauseKind::FirstPrivate:
+  case OpenACCClauseKind::Host:
+  case OpenACCClauseKind::Link:
+  case OpenACCClauseKind::NoCreate:
+  case OpenACCClauseKind::Present:
+  case OpenACCClauseKind::Private:
+  case OpenACCClauseKind::CopyOut:
+  case OpenACCClauseKind::CopyIn:
+  case OpenACCClauseKind::Create:
+  case OpenACCClauseKind::Reduction:
+  case OpenACCClauseKind::Collapse:
+  case OpenACCClauseKind::Bind:
+  case OpenACCClauseKind::VectorLength:
+  case OpenACCClauseKind::NumGangs:
+  case OpenACCClauseKind::NumWorkers:
+  case OpenACCClauseKind::DeviceNum:
+  case OpenACCClauseKind::DefaultAsync:
+  case OpenACCClauseKind::DeviceType:
+  case OpenACCClauseKind::DType:
+  case OpenACCClauseKind::Async:
+  case OpenACCClauseKind::Tile:
+  case OpenACCClauseKind::Gang:
+  case OpenACCClauseKind::Wait:
+  case OpenACCClauseKind::Invalid:
+    llvm_unreachable("Clause serialization not yet implemented");
+  }
+  llvm_unreachable("Invalid Clause Kind");
+}
+
+void ASTRecordReader::readOpenACCClauseList(
+    MutableArrayRef<const OpenACCClause *> Clauses) {
+  for (unsigned I = 0; I < Clauses.size(); ++I)
+    Clauses[I] = readOpenACCClause();
+}
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index bbeb6db011646f..f0984c3e469603 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2784,9 +2784,10 @@ void ASTStmtReader::VisitOMPTargetParallelGenericLoopDirective(
 // OpenACC Constructs/Directives.
 //===----------------------------------------------------------------------===//
 void ASTStmtReader::VisitOpenACCConstructStmt(OpenACCConstructStmt *S) {
+  (void)Record.readInt();
   S->Kind = Record.readEnum<OpenACCDirectiveKind>();
   S->Range = Record.readSourceRange();
-  // TODO OpenACC: Deserialize Clauses.
+  Record.readOpenACCClauseList(S->Clauses);
 }
 
 void ASTStmtReader::VisitOpenACCAssociatedStmtConstruct(
@@ -4218,10 +4219,11 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) {
       S = new (Context) ConceptSpecializationExpr(Empty);
       break;
     }
-    case STMT_OPENACC_COMPUTE_CONSTRUCT:
-      S = OpenACCComputeConstruct::CreateEmpty(Context, Empty);
+    case STMT_OPENACC_COMPUTE_CONSTRUCT: {
+      unsigned NumClauses = Record[ASTStmtReader::NumStmtFields];
+      S = OpenACCComputeConstruct::CreateEmpty(Context, NumClauses);
       break;
-
+    }
     case EXPR_REQUIRES:
       unsigned numLocalParameters = Record[ASTStmtReader::NumExprFields];
       unsigned numRequirement = Record[ASTStmtReader::NumExprFields + 1];
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index ba6a8a5e16e4e7..baf03f69d73065 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -29,6 +29,7 @@
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/LambdaCapture.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/OpenACCClause.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/RawCommentList.h"
 #include "clang/AST/TemplateName.h"
@@ -44,6 +45,7 @@
 #include "clang/Basic/LangOptions.h"
 #include "clang/Basic/Module.h"
 #include "clang/Basic/ObjCRuntime.h"
+#include "clang/Basic/OpenACCKinds.h"
 #include "clang/Basic/OpenCLOptions.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
@@ -7397,3 +7399,63 @@ void ASTRecordWriter::writeOMPChildren(OMPChildren *Data) {
   for (unsigned I = 0, E = Data->getNumChildren(); I < E; ++I)
     AddStmt(Data->getChildren()[I]);
 }
+
+void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
+  writeEnum(C->getClauseKind());
+  writeSourceLocation(C->getBeginLoc());
+  writeSourceLocation(C->getEndLoc());
+
+  switch (C->getClauseKind()) {
+  case OpenACCClauseKind::Default:
+  case OpenACCClauseKind::Finalize:
+  case OpenACCClauseKind::IfPresent:
+  case OpenACCClauseKind::Seq:
+  case OpenACCClauseKind::Independent:
+  case OpenACCClauseKind::Auto:
+  case OpenACCClauseKind::Worker:
+  case OpenACCClauseKind::Vector:
+  case OpenACCClauseKind::NoHost:
+  case OpenACCClauseKind::If:
+  case OpenACCClauseKind::Self:
+  case OpenACCClauseKind::Copy:
+  case OpenACCClauseKind::UseDevice:
+  case OpenACCClauseKind::Attach:
+  case OpenACCClauseKind::Delete:
+  case OpenACCClauseKind::Detach:
+  case OpenACCClauseKind::Device:
+  case OpenACCClauseKind::DevicePtr:
+  case OpenACCClauseKind::DeviceResident:
+  case OpenACCClauseKind::FirstPrivate:
+  case OpenACCClauseKind::Host:
+  case OpenACCClauseKind::Link:
+  case OpenACCClauseKind::NoCreate:
+  case OpenACCClauseKind::Present:
+  case OpenACCClauseKind::Private:
+  case OpenACCClauseKind::CopyOut:
+  case OpenACCClauseKind::CopyIn:
+  case OpenACCClauseKind::Create:
+  case OpenACCClauseKind::Reduction:
+  case OpenACCClauseKind::Collapse:
+  case OpenACCClauseKind::Bind:
+  case OpenACCClauseKind::VectorLength:
+  case OpenACCClauseKind::NumGangs:
+  case OpenACCClauseKind::NumWorkers:
+  case OpenACCClauseKind::DeviceNum:
+  case OpenACCClauseKind::DefaultAsync:
+  case OpenACCClauseKind::DeviceType:
+  case OpenACCClauseKind::DType:
+  case OpenACCClauseKind::Async:
+  case OpenACCClauseKind::Tile:
+  case OpenACCClauseKind::Gang:
+  case OpenACCClauseKind::Wait:
+  case OpenACCClauseKind::Invalid:
+    llvm_unreachable("Clause serialization not yet implemented");
+  }
+  llvm_unreachable("Invalid Clause Kind");
+}
+
+void ASTRecordWriter::writeOpenACCClauseList(
+    ArrayRef<const OpenACCClause *> Clauses) {
+  for (const OpenACCClause *Clause : Clauses)
+    writeOpenACCClause(Clause);
+}
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index 22e190450d3918..0651614e2ce548 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2839,9 +2839,10 @@ void ASTStmtWriter::VisitOMPTargetParallelGenericLoopDirective(
 // OpenACC Constructs/Directives.
 //===----------------------------------------------------------------------===//
 void ASTStmtWriter::VisitOpenACCConstructStmt(OpenACCConstructStmt *S) {
+  Record.push_back(S->clauses().size());
   Record.writeEnum(S->Kind);
   Record.AddSourceRange(S->Range);
-  // TODO OpenACC: Serialize Clauses.
+  Record.writeOpenACCClauseList(S->clauses());
 }
 
 void ASTStmtWriter::VisitOpenACCAssociatedStmtConstruct(

>From f539e8841c44d5b3141967b09f337a7661df7276 Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Fri, 5 Apr 2024 06:59:14 -0700
Subject: [PATCH 2/2] Incorporate review suggestions

---
 clang/include/clang/AST/StmtOpenACC.h | 12 +++++++-----
 clang/lib/Sema/SemaOpenACC.cpp        |  3 ++-
 clang/lib/Serialization/ASTReader.cpp | 11 +++++------
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/clang/include/clang/AST/StmtOpenACC.h b/clang/include/clang/AST/StmtOpenACC.h
index 504330b50af916..419cb6cada0bc7 100644
--- a/clang/include/clang/AST/StmtOpenACC.h
+++ b/clang/include/clang/AST/StmtOpenACC.h
@@ -17,6 +17,7 @@
 #include "clang/AST/Stmt.h"
 #include "clang/Basic/OpenACCKinds.h"
 #include "clang/Basic/SourceLocation.h"
+#include <memory>
 
 namespace clang {
 /// This is the base class for an OpenACC statement-level construct, other
@@ -58,7 +59,7 @@ class OpenACCConstructStmt : public Stmt {
 
   SourceLocation getBeginLoc() const { return Range.getBegin(); }
   SourceLocation getEndLoc() const { return Range.getEnd(); }
-  const ArrayRef<const OpenACCClause *> clauses() const { return Clauses; }
+  ArrayRef<const OpenACCClause *> clauses() const { return Clauses; }
 
   child_range children() {
     return child_range(child_iterator(), child_iterator());
@@ -128,8 +129,9 @@ class OpenACCComputeConstruct final
     // We cannot send the TrailingObjects storage to the base class (which holds
     // a reference to the data) until it is constructed, so we have to set it
     // separately here.
-    memset(getTrailingObjects<const OpenACCClause *>(), 0,
-           NumClauses * sizeof(const OpenACCClause *));
+    std::uninitialized_value_construct(
+        getTrailingObjects<const OpenACCClause *>(),
+        getTrailingObjects<const OpenACCClause *>() + NumClauses);
     setClauseList(MutableArrayRef(getTrailingObjects<const OpenACCClause *>(),
                                   NumClauses));
   }
@@ -147,8 +149,8 @@ class OpenACCComputeConstruct final
            "represented by this type");
 
     // Initialize the trailing storage.
-    for (unsigned I = 0; I < Clauses.size(); ++I)
-      *(getTrailingObjects<const OpenACCClause *>() + I) = Clauses[I];
+    std::uninitialized_copy(Clauses.begin(), Clauses.end(),
+                            getTrailingObjects<const OpenACCClause *>());
 
     setClauseList(MutableArrayRef(getTrailingObjects<const OpenACCClause *>(),
                                   Clauses.size()));
diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp
index 7d4b84f0eae9d2..86ffa5ad74c130 100644
--- a/clang/lib/Sema/SemaOpenACC.cpp
+++ b/clang/lib/Sema/SemaOpenACC.cpp
@@ -97,7 +97,8 @@ StmtResult SemaOpenACC::ActOnEndStmtDirective(OpenACCDirectiveKind K,
     // TODO OpenACC: Add clauses to the construct here.
     return OpenACCComputeConstruct::Create(
         getASTContext(), K, StartLoc, EndLoc,
-        /*Clauses=*/{}, AssocStmt.isUsable() ? AssocStmt.get() : nullptr);
+        /*Clauses=*/std::nullopt,
+        AssocStmt.isUsable() ? AssocStmt.get() : nullptr);
   }
   llvm_unreachable("Unhandled case in directive handling?");
 }
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 800043bfe456bb..9c0364b3934b55 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11756,13 +11756,12 @@ void ASTRecordReader::readOMPChildren(OMPChildren *Data) {
 
 OpenACCClause *ASTRecordReader::readOpenACCClause() {
   OpenACCClauseKind ClauseKind = readEnum<OpenACCClauseKind>();
-  SourceLocation BeginLoc = readSourceLocation();
-  SourceLocation EndLoc = readSourceLocation();
-
   // TODO OpenACC: We don't have these used anywhere, but eventually we should
-  // be constructing the Clauses with them, so these casts can go away.
-  (void)BeginLoc;
-  (void)EndLoc;
+  // be constructing the Clauses with them, so these attributes can go away at
+  // that point.
+  [[maybe_unused]] SourceLocation BeginLoc = readSourceLocation();
+  [[maybe_unused]] SourceLocation EndLoc = readSourceLocation();
+
   switch (ClauseKind) {
   case OpenACCClauseKind::Default:
   case OpenACCClauseKind::Finalize:



More information about the cfe-commits mailing list