[clang] [OpenACC] Implement AST for OpenACC Compute Constructs (PR #81188)

Erich Keane via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 9 06:37:05 PST 2024


https://github.com/erichkeane updated https://github.com/llvm/llvm-project/pull/81188

>From b7ca554663c4d0994ac255ee17ac016ef94f6778 Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Thu, 8 Feb 2024 07:56:30 -0800
Subject: [PATCH 1/3] [OpenACC] Implement AST for OpenACC Compute Constructs

'serial', 'parallel', and 'kernels' constructs are all considered
'Compute' constructs. This patch creates the AST type, plus the required
infrastructure for such a type, plus some base types that will be useful
in the future for breaking this up.

The only difference between the three is the 'kind' (plus some minor
clause legalization rules, but those can be differentiated easily
enough), so rather than representing them as separate AST nodes, it
makes sense to represent all three with the same node.

Additionally, no clause AST functionality is being implemented yet, as
that fits better in a separate patch, and this is enough to get the
'naked' constructs implemented.

This is otherwise an 'NFC' patch, as it doesn't alter execution at all,
so there aren't any tests.  I did this to break up the review workload
and to get feedback on the layout.
---
 clang/include/clang-c/Index.h                 |   6 +-
 clang/include/clang/AST/RecursiveASTVisitor.h |  22 +++
 clang/include/clang/AST/StmtOpenACC.h         | 141 ++++++++++++++++++
 clang/include/clang/AST/StmtVisitor.h         |   1 +
 clang/include/clang/AST/TextNodeDumper.h      |   1 +
 clang/include/clang/Basic/OpenACCKinds.h      |  31 +++-
 clang/include/clang/Basic/StmtNodes.td        |   6 +
 .../include/clang/Serialization/ASTBitCodes.h |   3 +
 clang/lib/AST/ASTStructuralEquivalence.cpp    |   1 +
 clang/lib/AST/CMakeLists.txt                  |   1 +
 clang/lib/AST/Stmt.cpp                        |   1 +
 clang/lib/AST/StmtOpenACC.cpp                 |  33 ++++
 clang/lib/AST/StmtPrinter.cpp                 |   9 ++
 clang/lib/AST/StmtProfile.cpp                 |   8 +
 clang/lib/AST/TextNodeDumper.cpp              |   5 +
 clang/lib/CodeGen/CGStmt.cpp                  |   3 +
 clang/lib/CodeGen/CodeGenFunction.h           |  10 ++
 clang/lib/Sema/SemaExceptionSpec.cpp          |   1 +
 clang/lib/Sema/TreeTransform.h                |  22 +++
 clang/lib/Serialization/ASTReaderStmt.cpp     |  24 +++
 clang/lib/Serialization/ASTWriterStmt.cpp     |  22 +++
 clang/lib/StaticAnalyzer/Core/ExprEngine.cpp  |   1 +
 clang/tools/libclang/CIndex.cpp               |   2 +
 clang/tools/libclang/CXCursor.cpp             |   3 +
 24 files changed, 352 insertions(+), 5 deletions(-)
 create mode 100644 clang/include/clang/AST/StmtOpenACC.h
 create mode 100644 clang/lib/AST/StmtOpenACC.cpp

diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 6af41424ba89a..8d939e8d21901 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2145,7 +2145,11 @@ enum CXCursorKind {
    */
   CXCursor_OMPScopeDirective = 306,
 
-  CXCursor_LastStmt = CXCursor_OMPScopeDirective,
+  /** OpenACC Compute Construct.
+   */
+  CXCursor_OpenACCComputeConstruct = 307,
+
+  CXCursor_LastStmt = CXCursor_OpenACCComputeConstruct,
 
   /**
    * Cursor that represents the translation unit itself.
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 9da5206a21c34..5080551ada4fc 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -34,6 +34,7 @@
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
+#include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
@@ -505,6 +506,9 @@ template <typename Derived> class RecursiveASTVisitor {
   bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node);
 
   bool PostVisitStmt(Stmt *S);
+  bool TraverseOpenACCConstructStmt(OpenACCConstructStmt *S);
+  bool
+  TraverseOpenACCAssociatedStmtConstruct(OpenACCAssociatedStmtConstruct *S);
 };
 
 template <typename Derived>
@@ -3910,6 +3914,24 @@ bool RecursiveASTVisitor<Derived>::VisitOMPXBareClause(OMPXBareClause *C) {
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::TraverseOpenACCConstructStmt(
+    OpenACCConstructStmt *) {
+  // TODO OpenACC: traverse clauses here once they exist; nothing to do yet.
+  return true;
+}
+
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::TraverseOpenACCAssociatedStmtConstruct(
+    OpenACCAssociatedStmtConstruct *S) {
+  TRY_TO(TraverseOpenACCConstructStmt(S));
+  // The associated stmt is in children(); DEF_TRAVERSE_STMT walks it once.
+  return true;
+}
+
+DEF_TRAVERSE_STMT(OpenACCComputeConstruct, // Shared OpenACC base walk.
+                  { TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })
+
 // FIXME: look at the following tricky-seeming exprs to see if we
 // need to recurse on anything.  These are ones that have methods
 // returning decls or qualtypes or nestednamespecifier -- though I'm
diff --git a/clang/include/clang/AST/StmtOpenACC.h b/clang/include/clang/AST/StmtOpenACC.h
new file mode 100644
index 0000000000000..1e14d599ff451
--- /dev/null
+++ b/clang/include/clang/AST/StmtOpenACC.h
@@ -0,0 +1,141 @@
+//===- StmtOpenACC.h - Classes for OpenACC directives  ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines OpenACC AST classes for statement-level constructs.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_STMTOPENACC_H
+#define LLVM_CLANG_AST_STMTOPENACC_H
+
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/OpenACCKinds.h"
+#include "clang/Basic/SourceLocation.h"
+
+namespace clang {
+/// Base class for every OpenACC statement-level construct; concrete construct
+/// classes derive from it, either directly or via an intermediate base.
+class OpenACCConstructStmt : public Stmt {
+  friend class ASTStmtWriter;
+  friend class ASTStmtReader;
+  /// The directive kind. Each concrete subclass is expected to accept only the
+  /// specific kinds it represents.
+  OpenACCDirectiveKind Kind = OpenACCDirectiveKind::Invalid;
+  /// The location of the directive statement, from the '#' to the last token of
+  /// the directive; getBeginLoc()/getEndLoc() report exactly this range.
+  SourceRange Range;
+
+  // TODO OPENACC: Clauses should probably be collected in this class.
+
+protected:
+  OpenACCConstructStmt(StmtClass SC, OpenACCDirectiveKind K,
+                       SourceLocation Start, SourceLocation End)
+      : Stmt(SC), Kind(K), Range(Start, End) {}
+
+public:
+  OpenACCDirectiveKind getDirectiveKind() const { return Kind; }
+
+  static bool classof(const Stmt *S) { // Any OpenACC construct class.
+    return S->getStmtClass() >= firstOpenACCConstructStmtConstant &&
+           S->getStmtClass() <= lastOpenACCConstructStmtConstant;
+  }
+
+  SourceLocation getBeginLoc() const { return Range.getBegin(); }
+  SourceLocation getEndLoc() const { return Range.getEnd(); }
+
+  child_range children() { // No children at this level of the hierarchy.
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_cast<OpenACCConstructStmt *>(this)->children();
+  }
+};
+
+/// Base class for OpenACC statement-level constructs that have an associated
+/// statement. It is abstract (never instantiated directly) and only exists to
+/// hold that statement and expose it as the node's only child when non-null.
+class OpenACCAssociatedStmtConstruct : public OpenACCConstructStmt {
+  friend class ASTStmtWriter;
+  friend class ASTStmtReader;
+  template<typename Derived>
+  friend class RecursiveASTVisitor;
+  Stmt *AssociatedStmt = nullptr;
+
+protected:
+  OpenACCAssociatedStmtConstruct(StmtClass SC, OpenACCDirectiveKind K,
+                                 SourceLocation Start, SourceLocation End)
+      : OpenACCConstructStmt(SC, K, Start, End) {}
+
+  void setAssociatedStmt(Stmt *S) { AssociatedStmt = S; }
+  Stmt *getAssociatedStmt() { return AssociatedStmt; }
+  const Stmt *getAssociatedStmt() const {
+    return const_cast<OpenACCAssociatedStmtConstruct *>(this)
+        ->getAssociatedStmt();
+  }
+
+public:
+  child_range children() { // Zero or one child: the associated statement.
+    if (getAssociatedStmt())
+      return child_range(&AssociatedStmt, &AssociatedStmt + 1);
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_cast<OpenACCAssociatedStmtConstruct *>(this)->children();
+  }
+};
+/// This class represents a compute construct, whose 'Kind' is one of
+/// 'parallel', 'serial', or 'kernels'. These constructs are associated with a
+/// 'structured block', defined as:
+///
+///  in C or C++, an executable statement, possibly compound, with a single
+///  entry at the top and a single exit at the bottom
+///
+/// At the moment there is no real motivation to have a different AST node for
+/// those three, as they are semantically identical, and have only minor
+/// differences in the permitted list of clauses, which can be differentiated by
+/// the 'Kind'.
+class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
+  friend class ASTStmtWriter;
+  friend class ASTStmtReader;
+  OpenACCComputeConstruct()
+      : OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass,
+                                       OpenACCDirectiveKind::Invalid,
+                                       SourceLocation{}, SourceLocation{}) {}
+
+  OpenACCComputeConstruct(OpenACCDirectiveKind K, SourceLocation Start,
+                          SourceLocation End)
+      : OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass, K, Start,
+                                       End) {
+    assert((K == OpenACCDirectiveKind::Parallel ||
+            K == OpenACCDirectiveKind::Serial ||
+            K == OpenACCDirectiveKind::Kernels) &&
+           "Only parallel, serial, and kernels constructs should be "
+           "represented by this type");
+  }
+
+public:
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OpenACCComputeConstructClass;
+  }
+
+  static OpenACCComputeConstruct *CreateEmpty(const ASTContext &C, EmptyShell);
+  static OpenACCComputeConstruct *Create(const ASTContext &C,
+                                         OpenACCDirectiveKind K,
+                                         SourceLocation BeginLoc,
+                                         SourceLocation EndLoc);
+
+  void setStructuredBlock(Stmt *S) { setAssociatedStmt(S); }
+  Stmt *getStructuredBlock() { return getAssociatedStmt(); }
+  const Stmt *getStructuredBlock() const {
+    return const_cast<OpenACCComputeConstruct *>(this)->getStructuredBlock();
+  }
+};
+} // namespace clang
+#endif // LLVM_CLANG_AST_STMTOPENACC_H
diff --git a/clang/include/clang/AST/StmtVisitor.h b/clang/include/clang/AST/StmtVisitor.h
index 3e5155199eace..b94a8e62aa37d 100644
--- a/clang/include/clang/AST/StmtVisitor.h
+++ b/clang/include/clang/AST/StmtVisitor.h
@@ -20,6 +20,7 @@
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
+#include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/Basic/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
index 3c4283f657efa..de67f0b571484 100644
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -401,6 +401,7 @@ class TextNodeDumper
   void
   VisitLifetimeExtendedTemporaryDecl(const LifetimeExtendedTemporaryDecl *D);
   void VisitHLSLBufferDecl(const HLSLBufferDecl *D);
+  void VisitOpenACCConstructStmt(const OpenACCConstructStmt *S);
 };
 
 } // namespace clang
diff --git a/clang/include/clang/Basic/OpenACCKinds.h b/clang/include/clang/Basic/OpenACCKinds.h
index afdd0e8983c9e..4456f4afd142d 100644
--- a/clang/include/clang/Basic/OpenACCKinds.h
+++ b/clang/include/clang/Basic/OpenACCKinds.h
@@ -16,6 +16,7 @@
 
 #include "clang/Basic/Diagnostic.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace clang {
 // Represents the Construct/Directive kind of a pragma directive. Note the
@@ -65,8 +66,9 @@ enum class OpenACCDirectiveKind {
   Invalid,
 };
 
-inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
-                                             OpenACCDirectiveKind K) {
+template <typename StreamTy>
+inline StreamTy &PrintOpenACCDirectiveKind(StreamTy &Out,
+                                           OpenACCDirectiveKind K) {
   switch (K) {
   case OpenACCDirectiveKind::Parallel:
     return Out << "parallel";
@@ -134,6 +136,16 @@ inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
   llvm_unreachable("Uncovered directive kind");
 }
 
+inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
+                                             OpenACCDirectiveKind K) {
+  return PrintOpenACCDirectiveKind(Out, K); // Diagnostic spelling.
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
+                                     OpenACCDirectiveKind K) {
+  return PrintOpenACCDirectiveKind(Out, K); // Same spelling for raw_ostream.
+}
+
 enum class OpenACCAtomicKind {
   Read,
   Write,
@@ -253,8 +265,8 @@ enum class OpenACCClauseKind {
   Invalid,
 };
 
-inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
-                                             OpenACCClauseKind K) {
+template <typename StreamTy>
+inline StreamTy &PrintOpenACCClauseKind(StreamTy &Out, OpenACCClauseKind K) {
   switch (K) {
   case OpenACCClauseKind::Finalize:
     return Out << "finalize";
@@ -387,6 +399,17 @@ inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
   }
   llvm_unreachable("Uncovered clause kind");
 }
+
+inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
+                                             OpenACCClauseKind K) {
+  return PrintOpenACCClauseKind(Out, K); // Diagnostic spelling.
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
+                                     OpenACCClauseKind K) {
+  return PrintOpenACCClauseKind(Out, K); // Same spelling for raw_ostream.
+}
+
 enum class OpenACCDefaultClauseKind {
   /// 'none' option.
   None,
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index 9d03800840fcd..b4e3ae573b95e 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -296,3 +296,9 @@ def OMPTargetTeamsGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPTargetParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPErrorDirective : StmtNode<OMPExecutableDirective>;
+
+// OpenACC Constructs: abstract bases first, then the concrete construct nodes.
+def OpenACCConstructStmt : StmtNode<Stmt, /*abstract=*/1>;
+def OpenACCAssociatedStmtConstruct
+    : StmtNode<OpenACCConstructStmt, /*abstract=*/1>;
+def OpenACCComputeConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 9de925163599d..f31efa5117f0d 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -2018,6 +2018,9 @@ enum StmtCode {
 
   // SYCLUniqueStableNameExpr
   EXPR_SYCL_UNIQUE_STABLE_NAME,
+
+  // OpenACC Constructs
+  STMT_OPENACC_COMPUTE_CONSTRUCT,
 };
 
 /// The kinds of designators that can occur in a
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index 3b7ebbbd89ea2..fe6e03ce174e5 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -74,6 +74,7 @@
 #include "clang/AST/ExprOpenMP.h"
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/StmtObjC.h"
+#include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
index ebcb3952198a5..49dcf2e4da3e7 100644
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -112,6 +112,7 @@ add_clang_library(clangAST
   StmtCXX.cpp
   StmtIterator.cpp
   StmtObjC.cpp
+  StmtOpenACC.cpp
   StmtOpenMP.cpp
   StmtPrinter.cpp
   StmtProfile.cpp
diff --git a/clang/lib/AST/Stmt.cpp b/clang/lib/AST/Stmt.cpp
index afd05881cb162..fe59d6070b3e8 100644
--- a/clang/lib/AST/Stmt.cpp
+++ b/clang/lib/AST/Stmt.cpp
@@ -23,6 +23,7 @@
 #include "clang/AST/ExprOpenMP.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
+#include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/CharInfo.h"
diff --git a/clang/lib/AST/StmtOpenACC.cpp b/clang/lib/AST/StmtOpenACC.cpp
new file mode 100644
index 0000000000000..1a99c24638183
--- /dev/null
+++ b/clang/lib/AST/StmtOpenACC.cpp
@@ -0,0 +1,33 @@
+//===--- StmtOpenACC.cpp - Classes for OpenACC Constructs -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the subclasses of Stmt declared in StmtOpenACC.h.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/StmtOpenACC.h"
+#include "clang/AST/ASTContext.h"
+using namespace clang;
+
+OpenACCComputeConstruct *
+OpenACCComputeConstruct::CreateEmpty(const ASTContext &C, EmptyShell) {
+  // Placeholder node; ASTStmtReader fills in its state after construction.
+  void *Storage = C.Allocate(sizeof(OpenACCComputeConstruct),
+                             alignof(OpenACCComputeConstruct));
+  return new (Storage) OpenACCComputeConstruct();
+}
+
+OpenACCComputeConstruct *
+OpenACCComputeConstruct::Create(const ASTContext &C, OpenACCDirectiveKind K,
+                                SourceLocation BeginLoc,
+                                SourceLocation EndLoc) {
+  // Construct in ASTContext-provided storage; no structured block yet.
+  void *Storage = C.Allocate(sizeof(OpenACCComputeConstruct),
+                             alignof(OpenACCComputeConstruct));
+  return new (Storage) OpenACCComputeConstruct(K, BeginLoc, EndLoc);
+}
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 1df040e06db35..d66c3ccce2094 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -1137,6 +1137,15 @@ void StmtPrinter::VisitOMPTargetParallelGenericLoopDirective(
   PrintOMPExecutableDirective(Node);
 }
 
+//===----------------------------------------------------------------------===//
+//  OpenACC construct printing methods
+//===----------------------------------------------------------------------===//
+void StmtPrinter::VisitOpenACCComputeConstruct(OpenACCComputeConstruct *S) {
+  // TODO OpenACC: Print Clauses between the directive kind and the newline.
+  Indent() << "#pragma acc " << S->getDirectiveKind() << NL;
+  PrintStmt(S->getStructuredBlock());
+}
+
 //===----------------------------------------------------------------------===//
 //  Expr printing methods.
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 1b817cf58b999..d224dd2e20159 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2441,6 +2441,13 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
   }
 }
 
+void StmtProfiler::VisitOpenACCComputeConstruct(
+    const OpenACCComputeConstruct *S) {
+  VisitStmt(S); // Handles children; TODO OpenACC: Visit Clauses.
+  // All kinds share one StmtClass, so hash the kind to keep them distinct.
+  ID.AddInteger(static_cast<unsigned>(S->getDirectiveKind()));
+}
+
 void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
                    bool Canonical, bool ProfileLambdaExpr) const {
   StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);
@@ -2452,3 +2459,4 @@ void Stmt::ProcessODRHash(llvm::FoldingSetNodeID &ID,
   StmtProfilerWithoutPointers Profiler(ID, Hash);
   Profiler.Visit(this);
 }
+
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 0000d26dd49eb..b683eb1edd8f1 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -2668,3 +2668,8 @@ void TextNodeDumper::VisitHLSLBufferDecl(const HLSLBufferDecl *D) {
     OS << " tbuffer";
   dumpName(D);
 }
+
+void TextNodeDumper::VisitOpenACCConstructStmt(const OpenACCConstructStmt *S) {
+  OS << " " << S->getDirectiveKind(); // e.g. " parallel"
+  // TODO OpenACC: Dump clauses as well.
+}
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index beff0ad9da270..af51875782c9f 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -435,6 +435,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
   case Stmt::OMPParallelMaskedDirectiveClass:
     EmitOMPParallelMaskedDirective(cast<OMPParallelMaskedDirective>(*S));
     break;
+  case Stmt::OpenACCComputeConstructClass:
+    EmitOpenACCComputeConstruct(cast<OpenACCComputeConstruct>(*S));
+    break;
   }
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 143ad64e8816b..53003b00d19bc 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -26,6 +26,7 @@
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/ExprOpenMP.h"
+#include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/ABI.h"
@@ -3837,6 +3838,15 @@ class CodeGenFunction : public CodeGenTypeCache {
   void EmitSections(const OMPExecutableDirective &S);
 
 public:
+  //===--------------------------------------------------------------------===//
+  //                         OpenACC Emission
+  //===--------------------------------------------------------------------===//
+  void EmitOpenACCComputeConstruct(const OpenACCComputeConstruct &S) {
+    // TODO OpenACC: Implement this. For now the construct is a 'no-op': only
+    // its structured block is emitted, exactly as if the pragma were absent.
+    // Dedicated OpenACC IR generation is expected to replace this later.
+    EmitStmt(S.getStructuredBlock());
+  }
 
   //===--------------------------------------------------------------------===//
   //                         LValue Expression Emission
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 8d58ef5ee16d5..3563b4f683f07 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1423,6 +1423,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
     llvm_unreachable("Invalid class for expression");
 
     // Most statements can throw if any substatement can throw.
+  case Stmt::OpenACCComputeConstructClass:
   case Stmt::AttributedStmtClass:
   case Stmt::BreakStmtClass:
   case Stmt::CapturedStmtClass:
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 3ed17c3360a83..3e2ef47f12199 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -27,6 +27,7 @@
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
+#include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
 #include "clang/Basic/DiagnosticParse.h"
 #include "clang/Basic/OpenMPKinds.h"
@@ -3995,6 +3996,12 @@ class TreeTransform {
     return getSema().CreateRecoveryExpr(BeginLoc, EndLoc, SubExprs, Type);
   }
 
+  StmtResult RebuildOpenACCComputeConstruct(OpenACCDirectiveKind K,
+                                            SourceLocation BeginLoc,
+                                            SourceLocation EndLoc, // Stub.
+                                            StmtResult StrBlock) {
+    llvm_unreachable("Not yet implemented!");
+  }
 private:
   TypeLoc TransformTypeInObjectScope(TypeLoc TL,
                                      QualType ObjectType,
@@ -10993,6 +11000,21 @@ OMPClause *TreeTransform<Derived>::TransformOMPXBareClause(OMPXBareClause *C) {
   return getDerived().RebuildOMPXBareClause(C->getBeginLoc(), C->getEndLoc());
 }
 
+//===----------------------------------------------------------------------===//
+// OpenACC transformation
+//===----------------------------------------------------------------------===//
+template <typename Derived>
+StmtResult TreeTransform<Derived>::TransformOpenACCComputeConstruct(
+    OpenACCComputeConstruct *C) {
+  // TODO OpenACC: Transform clauses.
+  // Transform Structured Block; propagate failure instead of rebuilding.
+  StmtResult StrBlock = getDerived().TransformStmt(C->getStructuredBlock());
+  if (StrBlock.isInvalid())
+    return StmtError();
+  return getDerived().RebuildOpenACCComputeConstruct(
+      C->getDirectiveKind(), C->getBeginLoc(), C->getEndLoc(), StrBlock);
+}
+
 //===----------------------------------------------------------------------===//
 // Expression transformation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index d79f194fd16c6..440ec84f2788a 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2788,6 +2788,27 @@ void ASTStmtReader::VisitOMPTargetParallelGenericLoopDirective(
   VisitOMPLoopDirective(D);
 }
 
+//===----------------------------------------------------------------------===//
+// OpenACC Constructs/Directives.
+//===----------------------------------------------------------------------===//
+void ASTStmtReader::VisitOpenACCConstructStmt(OpenACCConstructStmt *S) {
+  S->Kind = Record.readEnum<OpenACCDirectiveKind>(); // Same order as writer.
+  S->Range = Record.readSourceRange();
+  // TODO OpenACC: Deserialize Clauses.
+}
+
+void ASTStmtReader::VisitOpenACCAssociatedStmtConstruct(
+    OpenACCAssociatedStmtConstruct *S) {
+  VisitOpenACCConstructStmt(S);
+  S->setAssociatedStmt(Record.readSubStmt()); // Pairs with writer's AddStmt.
+}
+
+void ASTStmtReader::VisitOpenACCComputeConstruct(OpenACCComputeConstruct *S) {
+  VisitStmt(S);
+  VisitOpenACCAssociatedStmtConstruct(S); // Kind, range and structured block.
+}
+
+
 //===----------------------------------------------------------------------===//
 // ASTReader Implementation
 //===----------------------------------------------------------------------===//
@@ -4206,6 +4227,9 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) {
       S = new (Context) ConceptSpecializationExpr(Empty);
       break;
     }
+    case STMT_OPENACC_COMPUTE_CONSTRUCT:
+      S = OpenACCComputeConstruct::CreateEmpty(Context, Empty);
+      break;
 
     case EXPR_REQUIRES:
       unsigned numLocalParameters = Record[ASTStmtReader::NumExprFields];
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index 5b0b90234c410..aaa5c8dbc5860 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2838,6 +2838,28 @@ void ASTStmtWriter::VisitOMPTargetParallelGenericLoopDirective(
   Code = serialization::STMT_OMP_TARGET_PARALLEL_GENERIC_LOOP_DIRECTIVE;
 }
 
+//===----------------------------------------------------------------------===//
+// OpenACC Constructs/Directives.
+//===----------------------------------------------------------------------===//
+void ASTStmtWriter::VisitOpenACCConstructStmt(OpenACCConstructStmt *S) {
+  Record.writeEnum(S->Kind);
+  Record.AddSourceRange(S->Range);
+  // TODO OpenACC: Serialize Clauses.
+}
+
+void ASTStmtWriter::VisitOpenACCAssociatedStmtConstruct(
+    OpenACCAssociatedStmtConstruct *S) {
+  VisitOpenACCConstructStmt(S);
+  Record.AddStmt(S->getAssociatedStmt());
+}
+
+void ASTStmtWriter::VisitOpenACCComputeConstruct(OpenACCComputeConstruct *S) {
+  VisitStmt(S);
+  VisitOpenACCAssociatedStmtConstruct(S); // Kind, range and structured block.
+  Code = serialization::STMT_OPENACC_COMPUTE_CONSTRUCT;
+}
+
+
 //===----------------------------------------------------------------------===//
 // ASTWriter Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index ccc3c0f1e0c10..09c69f9612d96 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1821,6 +1821,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
     case Stmt::OMPParallelGenericLoopDirectiveClass:
     case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
     case Stmt::CapturedStmtClass:
+    case Stmt::OpenACCComputeConstructClass:
     case Stmt::OMPUnrollDirectiveClass:
     case Stmt::OMPMetaDirectiveClass: {
       const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState());
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index e5c0971996e01..4ded92cbe9aea 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -6114,6 +6114,8 @@ CXString clang_getCursorKindSpelling(enum CXCursorKind Kind) {
     return cxstring::createRef("attribute(aligned)");
   case CXCursor_ConceptDecl:
     return cxstring::createRef("ConceptDecl");
+  case CXCursor_OpenACCComputeConstruct:
+    return cxstring::createRef("OpenACCComputeConstruct");
   }
 
   llvm_unreachable("Unhandled CXCursorKind");
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
index 01b8a23f6eac3..454bf75498618 100644
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -870,6 +870,9 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
   case Stmt::OMPParallelGenericLoopDirectiveClass:
     K = CXCursor_OMPParallelGenericLoopDirective;
     break;
+  case Stmt::OpenACCComputeConstructClass:
+    K = CXCursor_OpenACCComputeConstruct;
+    break;
   case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
     K = CXCursor_OMPTargetParallelGenericLoopDirective;
     break;

>From 2369b6ded87a72ab4394fdda8cfa0f393956cecb Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Thu, 8 Feb 2024 12:46:01 -0800
Subject: [PATCH 2/3] Clang-format

---
 clang/include/clang/AST/StmtOpenACC.h     | 3 +--
 clang/include/clang/AST/StmtVisitor.h     | 2 +-
 clang/lib/AST/StmtProfile.cpp             | 1 -
 clang/lib/Sema/TreeTransform.h            | 1 +
 clang/lib/Serialization/ASTReaderStmt.cpp | 1 -
 clang/lib/Serialization/ASTWriterStmt.cpp | 1 -
 6 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/clang/include/clang/AST/StmtOpenACC.h b/clang/include/clang/AST/StmtOpenACC.h
index 1e14d599ff451..67dd2b665a218 100644
--- a/clang/include/clang/AST/StmtOpenACC.h
+++ b/clang/include/clang/AST/StmtOpenACC.h
@@ -63,8 +63,7 @@ class OpenACCConstructStmt : public Stmt {
 class OpenACCAssociatedStmtConstruct : public OpenACCConstructStmt {
   friend class ASTStmtWriter;
   friend class ASTStmtReader;
-  template<typename Derived>
-  friend class RecursiveASTVisitor;
+  template <typename Derived> friend class RecursiveASTVisitor;
   Stmt *AssociatedStmt = nullptr;
 
 protected:
diff --git a/clang/include/clang/AST/StmtVisitor.h b/clang/include/clang/AST/StmtVisitor.h
index b94a8e62aa37d..990aa2df180d4 100644
--- a/clang/include/clang/AST/StmtVisitor.h
+++ b/clang/include/clang/AST/StmtVisitor.h
@@ -13,8 +13,8 @@
 #ifndef LLVM_CLANG_AST_STMTVISITOR_H
 #define LLVM_CLANG_AST_STMTVISITOR_H
 
-#include "clang/AST/ExprConcepts.h"
 #include "clang/AST/ExprCXX.h"
+#include "clang/AST/ExprConcepts.h"
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/ExprOpenMP.h"
 #include "clang/AST/Stmt.h"
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index d224dd2e20159..b545ff472e5a2 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2459,4 +2459,3 @@ void Stmt::ProcessODRHash(llvm::FoldingSetNodeID &ID,
   StmtProfilerWithoutPointers Profiler(ID, Hash);
   Profiler.Visit(this);
 }
-
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 3e2ef47f12199..6e5ae123a6ba2 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -4002,6 +4002,7 @@ class TreeTransform {
                                              StmtResult StrBlock) {
     llvm_unreachable("Not yet implemented!");
   }
+
 private:
   TypeLoc TransformTypeInObjectScope(TypeLoc TL,
                                      QualType ObjectType,
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 440ec84f2788a..3da44ffccc38a 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2808,7 +2808,6 @@ void ASTStmtReader::VisitOpenACCComputeConstruct(OpenACCComputeConstruct *S) {
   VisitOpenACCAssociatedStmtConstruct(S); // Kind, range and structured block.
 }
 
-
 //===----------------------------------------------------------------------===//
 // ASTReader Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index aaa5c8dbc5860..484621ae81309 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2859,7 +2859,6 @@ void ASTStmtWriter::VisitOpenACCComputeConstruct(OpenACCComputeConstruct *S) {
   Code = serialization::STMT_OPENACC_COMPUTE_CONSTRUCT;
 }
 
-
 //===----------------------------------------------------------------------===//
 // ASTWriter Implementation
 //===----------------------------------------------------------------------===//

>From d872d7d46798a583e7cd73b02ee5b51f69daae49 Mon Sep 17 00:00:00 2001
From: erichkeane <ekeane at nvidia.com>
Date: Fri, 9 Feb 2024 06:36:47 -0800
Subject: [PATCH 3/3] Make setStructuredBlock private in ComputeConstruct

---
 clang/include/clang/AST/StmtOpenACC.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/AST/StmtOpenACC.h b/clang/include/clang/AST/StmtOpenACC.h
index 67dd2b665a218..9424f4f080785 100644
--- a/clang/include/clang/AST/StmtOpenACC.h
+++ b/clang/include/clang/AST/StmtOpenACC.h
@@ -103,6 +103,7 @@ class OpenACCAssociatedStmtConstruct : public OpenACCConstructStmt {
 class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
   friend class ASTStmtWriter;
   friend class ASTStmtReader;
+  friend class ASTContext;
   OpenACCComputeConstruct()
       : OpenACCAssociatedStmtConstruct(OpenACCComputeConstructClass,
                                        OpenACCDirectiveKind::Invalid,
@@ -119,6 +120,8 @@ class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
            "represented by this type");
   }
 
+  void setStructuredBlock(Stmt *S) { setAssociatedStmt(S); }
+
 public:
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OpenACCComputeConstructClass;
@@ -130,7 +133,6 @@ class OpenACCComputeConstruct : public OpenACCAssociatedStmtConstruct {
                                          SourceLocation BeginLoc,
                                          SourceLocation EndLoc);
 
-  void setStructuredBlock(Stmt *S) { setAssociatedStmt(S); }
   Stmt *getStructuredBlock() { return getAssociatedStmt(); }
   const Stmt *getStructuredBlock() const {
     return const_cast<OpenACCComputeConstruct *>(this)->getStructuredBlock();



More information about the cfe-commits mailing list