[llvm-branch-commits] [clang] [clang][OpenMP] Add AST node for root of compound directive (PR #118878)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 5 13:49:49 PST 2024


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/118878

This node will be used to print the original (unsplit) directive source from the AST after compound directives have been split into their constituent directives.

>From 1447ec21597f752b29e367a46f06eecdf9c81dd7 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 30 Oct 2024 13:34:21 -0500
Subject: [PATCH] [clang][OpenMP] Add AST node for root of compound directive

This will be used to print the original directive source from the AST
after splitting compound directives.
---
 clang/bindings/python/clang/cindex.py         |  3 +
 clang/include/clang-c/Index.h                 |  4 ++
 clang/include/clang/AST/RecursiveASTVisitor.h |  3 +
 clang/include/clang/AST/StmtOpenMP.h          | 60 +++++++++++++++++++
 clang/include/clang/AST/TextNodeDumper.h      |  1 +
 clang/include/clang/Basic/StmtNodes.td        |  1 +
 clang/include/clang/Sema/SemaOpenMP.h         |  6 ++
 .../include/clang/Serialization/ASTBitCodes.h |  1 +
 clang/lib/AST/StmtOpenMP.cpp                  | 15 +++++
 clang/lib/AST/StmtPrinter.cpp                 |  6 ++
 clang/lib/AST/StmtProfile.cpp                 |  5 ++
 clang/lib/AST/TextNodeDumper.cpp              |  7 +++
 clang/lib/CodeGen/CGStmt.cpp                  |  4 ++
 clang/lib/Sema/SemaExceptionSpec.cpp          |  1 +
 clang/lib/Sema/SemaOpenMP.cpp                 |  7 +++
 clang/lib/Sema/TreeTransform.h                |  8 +++
 clang/lib/Serialization/ASTReaderStmt.cpp     | 15 +++++
 clang/lib/Serialization/ASTWriterStmt.cpp     |  7 +++
 clang/lib/StaticAnalyzer/Core/ExprEngine.cpp  |  1 +
 19 files changed, 155 insertions(+)

diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py
index 3ae7c479153690..5174e16f28f06d 100644
--- a/clang/bindings/python/clang/cindex.py
+++ b/clang/bindings/python/clang/cindex.py
@@ -1416,6 +1416,9 @@ def is_unexposed(self):
     # OpenMP opaque loop-associated directive.
     OMP_OPAQUE_LOOP_DIRECTIVE = 311
 
+    # OpenMP compound root directive (root of a split compound directive).
+    OMP_COMPOUND_ROOT_DIRECTIVE = 312
+
     # OpenACC Compute Construct.
     OPEN_ACC_COMPUTE_DIRECTIVE = 320
 
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 5d1db153aaafed..02ce2b7690ef06 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2166,6 +2166,10 @@ enum CXCursorKind {
    */
   CXCursor_OMPOpaqueLoopDirective = 311,
 
+  /** OpenMP compound root directive: the root of a split compound directive.
+   */
+  CXCursor_OMPCompoundRootDirective = 312,
+
   /** OpenACC Compute Construct.
    */
   CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index e6fe46acb5fbc5..2881604ec781a3 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3026,6 +3026,9 @@ RecursiveASTVisitor<Derived>::TraverseOMPLoopDirective(OMPLoopDirective *S) {
   return TraverseOMPExecutableDirective(S);
 }
 
+DEF_TRAVERSE_STMT(OMPCompoundRootDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 DEF_TRAVERSE_STMT(OMPOpaqueBlockDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index 65434967142c84..4a3c2a53377d69 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -1560,6 +1560,66 @@ class OMPLoopDirective : public OMPLoopBasedDirective {
   }
 };
 
+/// This represents the root of the tree of a broken-up compound directive.
+/// It is used to implement pretty-printing consistent with the original
+/// source. This is a pass-through directive for the purposes of semantic
+/// analysis and code generation.
+/// getDirectiveKind() returns the id of the original compound directive.
+/// The associated statement is the outermost of the constituent
+/// directives, and it is always present.
+class OMPCompoundRootDirective final : public OMPExecutableDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+
+  /// Build directive with the given start and end location.
+  ///
+  /// \param DKind The OpenMP directive kind.
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending location of the directive.
+  ///
+  OMPCompoundRootDirective(OpenMPDirectiveKind DKind, SourceLocation StartLoc,
+                           SourceLocation EndLoc)
+      : OMPExecutableDirective(OMPCompoundRootDirectiveClass, DKind, StartLoc,
+                               EndLoc) {}
+
+  /// Build an empty directive.
+  ///
+  /// \param DKind The OpenMP directive kind.
+  ///
+  explicit OMPCompoundRootDirective(OpenMPDirectiveKind DKind)
+      : OMPExecutableDirective(OMPCompoundRootDirectiveClass, DKind,
+                               SourceLocation(), SourceLocation()) {}
+
+public:
+  /// Creates a directive with a list of \a Clauses.
+  ///
+  /// \param C AST context.
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending location of the directive.
+  /// \param DKind The OpenMP directive kind.
+  /// \param Clauses List of clauses.
+  /// \param AssociatedStmt Statement associated with the directive.
+  static OMPCompoundRootDirective *
+  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+         OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
+         Stmt *AssociatedStmt);
+
+  /// Creates an empty directive with room for \a NumClauses clauses and
+  /// an associated statement (used by the AST reader).
+  ///
+  /// \param C AST context.
+  /// \param DKind The OpenMP directive kind.
+  /// \param NumClauses Number of clauses.
+  ///
+  static OMPCompoundRootDirective *
+  CreateEmpty(const ASTContext &C, OpenMPDirectiveKind DKind,
+              unsigned NumClauses, EmptyShell);
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPCompoundRootDirectiveClass;
+  }
+};
+
 /// This represents any executable OpenMP directive that is not loop-
 /// associated (usually block-associated).
 class OMPOpaqueBlockDirective final : public OMPExecutableDirective {
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
index 6606cce3ff0852..ea0fafb884cb44 100644
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -366,6 +366,7 @@ class TextNodeDumper
   void VisitPragmaCommentDecl(const PragmaCommentDecl *D);
   void VisitPragmaDetectMismatchDecl(const PragmaDetectMismatchDecl *D);
   void VisitOMPExecutableDirective(const OMPExecutableDirective *D);
+  void VisitOMPCompoundRootDirective(const OMPCompoundRootDirective *D);
   void VisitOMPOpaqueBlockDirective(const OMPOpaqueBlockDirective *D);
   void VisitOMPOpaqueLoopDirective(const OMPOpaqueLoopDirective *D);
   void VisitOMPDeclareReductionDecl(const OMPDeclareReductionDecl *D);
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index da8c91043c814e..e365dea0dc6f14 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -225,6 +225,7 @@ def OMPExecutableDirective : StmtNode<Stmt, 1>;
 def OMPMetaDirective : StmtNode<OMPExecutableDirective>;
 def OMPLoopBasedDirective : StmtNode<OMPExecutableDirective, 1>;
 def OMPLoopDirective : StmtNode<OMPLoopBasedDirective, 1>;
+def OMPCompoundRootDirective : StmtNode<OMPExecutableDirective>;
 def OMPOpaqueBlockDirective : StmtNode<OMPExecutableDirective>;
 def OMPOpaqueLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPParallelDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 18103a72c9cd31..45ddde3e1ade9b 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -849,6 +849,12 @@ class SemaOpenMP : public SemaBase {
       ArrayRef<OMPInteropInfo> AppendArgs, SourceLocation AdjustArgsLoc,
       SourceLocation AppendArgsLoc, SourceRange SR);
 
+  StmtResult ActOnOpenMPCompoundRootDirective(OpenMPDirectiveKind DKind,
+                                              ArrayRef<OMPClause *> Clauses,
+                                              Stmt *AStmt,
+                                              SourceLocation StartLoc,
+                                              SourceLocation EndLoc);
+
   StmtResult ActOnOpenMPOpaqueBlockDirective(
       OpenMPDirectiveKind Kind, ArrayRef<OMPClause *> Clauses, Stmt *AStmt,
       OpenMPDirectiveKind CancelRegion, const DeclarationNameInfo &DirName,
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index ac332f69826357..fabe9b64cfa75e 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1904,6 +1904,7 @@ enum StmtCode {
   STMT_SEH_TRY,                     // SEHTryStmt
 
   // OpenMP directives
+  STMT_OMP_COMPOUND_ROOT_DIRECTIVE, // OMPCompoundRootDirective
   STMT_OMP_OPAQUE_BLOCK_DIRECTIVE,
   STMT_OMP_OPAQUE_LOOP_DIRECTIVE,
   STMT_OMP_META_DIRECTIVE,
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 3bd88acb9a3c6e..7550aca7929127 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -259,6 +259,21 @@ void OMPLoopDirective::setFinalsConditions(ArrayRef<Expr *> A) {
   llvm::copy(A, getFinalsConditions().begin());
 }
 
+OMPCompoundRootDirective *OMPCompoundRootDirective::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+    OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
+    Stmt *AssociatedStmt) {
+  return createDirective<OMPCompoundRootDirective>(
+      C, Clauses, AssociatedStmt, /*NumChildren=*/0, DKind, StartLoc, EndLoc);
+}
+
+OMPCompoundRootDirective *OMPCompoundRootDirective::CreateEmpty(
+    const ASTContext &C, OpenMPDirectiveKind DKind, unsigned NumClauses,
+    EmptyShell) {
+  return createEmptyDirective<OMPCompoundRootDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/0, DKind);
+}
+
 OMPOpaqueBlockDirective *OMPOpaqueBlockDirective::Create(
     const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
     OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index cd4f91337ef42c..b756b1c47c6060 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -735,6 +735,15 @@ void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S,
     PrintStmt(S->getRawStmt());
 }
 
+void StmtPrinter::VisitOMPCompoundRootDirective(
+    OMPCompoundRootDirective *Node) {
+  // Treat the root as a pass-through node and print its associated statement
+  // (the outermost constituent directive) at the current indentation. Do not
+  // pass Node itself to PrintStmt: that dispatches back to this visitor.
+  // TODO: Print the original compound pragma instead of its constituents.
+  PrintStmt(Node->getAssociatedStmt(), /*SubIndent=*/0);
+}
+
 void StmtPrinter::VisitOMPOpaqueBlockDirective(OMPOpaqueBlockDirective *Node) {
   OpenMPDirectiveKind DKind = Node->getDirectiveKind();
   bool ForceNoStmt = false;
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index a4f032c3f78c5f..8880a9e277db6a 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -967,6 +967,11 @@ StmtProfiler::VisitOMPExecutableDirective(const OMPExecutableDirective *S) {
       P.Visit(*I);
 }
 
+void StmtProfiler::VisitOMPCompoundRootDirective(
+    const OMPCompoundRootDirective *S) {
+  VisitOMPExecutableDirective(S);
+}
+
 void StmtProfiler::VisitOMPOpaqueBlockDirective(
     const OMPOpaqueBlockDirective *S) {
   VisitOMPExecutableDirective(S);
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 09d3802c5e06dc..d1378cba65d761 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -2372,6 +2372,13 @@ void TextNodeDumper::VisitOMPExecutableDirective(
     OS << " openmp_standalone_directive";
 }
 
+void TextNodeDumper::VisitOMPCompoundRootDirective(
+    const OMPCompoundRootDirective *D) {
+  VisitOMPExecutableDirective(D);
+  OpenMPDirectiveKind DKind = D->getDirectiveKind();
+  OS << " '" << llvm::omp::getOpenMPDirectiveName(DKind) << '\'';
+}
+
 void TextNodeDumper::VisitOMPOpaqueBlockDirective(
     const OMPOpaqueBlockDirective *D) {
   VisitOMPExecutableDirective(D);
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 42e9490ebbaa6e..82781eae764ab1 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -204,6 +204,10 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
   case Stmt::SEHTryStmtClass:
     EmitSEHTryStmt(cast<SEHTryStmt>(*S));
     break;
+  case Stmt::OMPCompoundRootDirectiveClass:
+    // Pass-through node: emit its associated statement with the same Attrs.
+    EmitStmt(cast<OMPCompoundRootDirective>(*S).getAssociatedStmt(), Attrs);
+    break;
   case Stmt::OMPOpaqueBlockDirectiveClass:
   case Stmt::OMPOpaqueLoopDirectiveClass:
     // These are catch-all nodes for executable OpenMP directives in templates.
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 467eb89afba52f..b3743b8fd1c74c 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1452,6 +1452,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Stmt::OMPMaskedTaskLoopDirectiveClass:
   case Stmt::OMPMasterTaskLoopSimdDirectiveClass:
   case Stmt::OMPMaskedTaskLoopSimdDirectiveClass:
+  case Stmt::OMPCompoundRootDirectiveClass:
   case Stmt::OMPOpaqueBlockDirectiveClass:
   case Stmt::OMPOpaqueLoopDirectiveClass:
   case Stmt::OMPOrderedDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index fbed2952de3fb6..3a59839a0b1828 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -23309,6 +23309,17 @@ StmtResult SemaOpenMP::ActOnOpenMPScopeDirective(ArrayRef<OMPClause *> Clauses,
 
 static bool checkScanScope(Sema &S, Scope *CurrentS, SourceLocation Loc);
 
+StmtResult SemaOpenMP::ActOnOpenMPCompoundRootDirective(
+    OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses, Stmt *AStmt,
+    SourceLocation StartLoc, SourceLocation EndLoc) {
+  // The root must always have an associated statement (the outermost
+  // constituent directive), so fail instead of creating a node without one.
+  if (!AStmt)
+    return StmtError();
+  return OMPCompoundRootDirective::Create(
+      getASTContext(), StartLoc, EndLoc, DKind, Clauses, AStmt);
+}
+
 StmtResult SemaOpenMP::ActOnOpenMPOpaqueBlockDirective(
     OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses, Stmt *AStmt,
     OpenMPDirectiveKind CancelRegion, const DeclarationNameInfo &DirName,
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 3a496811676787..cefaa9bd716e98 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -9406,6 +9406,14 @@ StmtResult TreeTransform<Derived>::TransformOMPInformationalDirective(
       D->getBeginLoc(), D->getEndLoc());
 }
 
+template <typename Derived>
+StmtResult TreeTransform<Derived>::TransformOMPCompoundRootDirective(
+    OMPCompoundRootDirective *D) {
+  // This node should never be found in a template. Directive splitting
+  // only happens in non-template functions.
+  llvm_unreachable("TransformOMPCompoundRootDirective in a template");
+}
+
 template <typename Derived>
 StmtResult TreeTransform<Derived>::TransformOMPOpaqueBlockDirective(
     OMPOpaqueBlockDirective *D) {
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 4585693953be4d..1754c2cd89cdd4 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2409,6 +2409,13 @@ void ASTStmtReader::VisitOMPLoopDirective(OMPLoopDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void ASTStmtReader::VisitOMPCompoundRootDirective(OMPCompoundRootDirective *D) {
+  VisitStmt(D);
+  // The DKind was read in ReadStmtFromStream.
+  Record.skipInts(1);
+  VisitOMPExecutableDirective(D);
+}
+
 void ASTStmtReader::VisitOMPOpaqueBlockDirective(OMPOpaqueBlockDirective *D) {
   VisitStmt(D);
   // The DKind was read in ReadStmtFromStream.
@@ -3513,6 +3520,14 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) {
       S = OMPCanonicalLoop::createEmpty(Context);
       break;
 
+    case STMT_OMP_COMPOUND_ROOT_DIRECTIVE: {
+      unsigned DKind = Record[ASTStmtReader::NumStmtFields];
+      unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1];
+      S = OMPCompoundRootDirective::CreateEmpty(
+          Context, static_cast<OpenMPDirectiveKind>(DKind), NumClauses, Empty);
+      break;
+    }
+
     case STMT_OMP_OPAQUE_BLOCK_DIRECTIVE: {
       unsigned DKind = Record[ASTStmtReader::NumStmtFields];
       unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1];
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index 47968dd78454a4..ed31af10d7417d 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2409,6 +2409,13 @@ void ASTStmtWriter::VisitOMPLoopDirective(OMPLoopDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void ASTStmtWriter::VisitOMPCompoundRootDirective(OMPCompoundRootDirective *D) {
+  VisitStmt(D);
+  Record.writeUInt32(static_cast<unsigned>(D->getDirectiveKind()));
+  VisitOMPExecutableDirective(D);
+  Code = serialization::STMT_OMP_COMPOUND_ROOT_DIRECTIVE;
+}
+
 void ASTStmtWriter::VisitOMPOpaqueBlockDirective(OMPOpaqueBlockDirective *D) {
   VisitStmt(D);
   Record.writeUInt32(static_cast<unsigned>(D->getDirectiveKind()));
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index b980d6d25f9a24..cd3b49749c4ce8 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1748,6 +1748,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
     case Stmt::SEHLeaveStmtClass:
     case Stmt::SEHFinallyStmtClass:
     case Stmt::OMPCanonicalLoopClass:
+    case Stmt::OMPCompoundRootDirectiveClass:
     case Stmt::OMPOpaqueBlockDirectiveClass:
     case Stmt::OMPOpaqueLoopDirectiveClass:
     case Stmt::OMPParallelDirectiveClass:



More information about the llvm-branch-commits mailing list