[llvm-branch-commits] [clang] [clang][OpenMP] Add AST node for root of compound directive (PR #118878)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Mar 25 14:44:39 PDT 2025


https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/118878

>From 1447ec21597f752b29e367a46f06eecdf9c81dd7 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 30 Oct 2024 13:34:21 -0500
Subject: [PATCH 1/2] [clang][OpenMP] Add AST node for root of compound
 directive

This will be used to print the original directive source from the AST
after splitting compound directives.
---
 clang/bindings/python/clang/cindex.py         |  3 +
 clang/include/clang-c/Index.h                 |  4 ++
 clang/include/clang/AST/RecursiveASTVisitor.h |  3 +
 clang/include/clang/AST/StmtOpenMP.h          | 60 +++++++++++++++++++
 clang/include/clang/AST/TextNodeDumper.h      |  1 +
 clang/include/clang/Basic/StmtNodes.td        |  1 +
 clang/include/clang/Sema/SemaOpenMP.h         |  6 ++
 .../include/clang/Serialization/ASTBitCodes.h |  1 +
 clang/lib/AST/StmtOpenMP.cpp                  | 15 +++++
 clang/lib/AST/StmtPrinter.cpp                 |  6 ++
 clang/lib/AST/StmtProfile.cpp                 |  5 ++
 clang/lib/AST/TextNodeDumper.cpp              |  7 +++
 clang/lib/CodeGen/CGStmt.cpp                  |  4 ++
 clang/lib/Sema/SemaExceptionSpec.cpp          |  1 +
 clang/lib/Sema/SemaOpenMP.cpp                 |  7 +++
 clang/lib/Sema/TreeTransform.h                |  8 +++
 clang/lib/Serialization/ASTReaderStmt.cpp     | 15 +++++
 clang/lib/Serialization/ASTWriterStmt.cpp     |  7 +++
 clang/lib/StaticAnalyzer/Core/ExprEngine.cpp  |  1 +
 19 files changed, 155 insertions(+)

diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py
index 3ae7c47915369..5174e16f28f06 100644
--- a/clang/bindings/python/clang/cindex.py
+++ b/clang/bindings/python/clang/cindex.py
@@ -1416,6 +1416,9 @@ def is_unexposed(self):
     # OpenMP opaque loop-associated directive.
     OMP_OPAQUE_LOOP_DIRECTIVE = 311
 
+    # OpenMP compound root directive.
+    OMP_COMPOUND_ROOT_DIRECTIVE = 312
+
     # OpenACC Compute Construct.
     OPEN_ACC_COMPUTE_DIRECTIVE = 320
 
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 5d1db153aaafe..02ce2b7690ef0 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2166,6 +2166,10 @@ enum CXCursorKind {
    */
   CXCursor_OMPOpaqueLoopDirective = 311,
 
+  /** OpenMP compound root directive.
+   */
+  CXCursor_OMPCompoundRootDirective = 312,
+
   /** OpenACC Compute Construct.
    */
   CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index e6fe46acb5fbc..2881604ec781a 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3026,6 +3026,9 @@ RecursiveASTVisitor<Derived>::TraverseOMPLoopDirective(OMPLoopDirective *S) {
   return TraverseOMPExecutableDirective(S);
 }
 
+DEF_TRAVERSE_STMT(OMPCompoundRootDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 DEF_TRAVERSE_STMT(OMPOpaqueBlockDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index 65434967142c8..4a3c2a53377d6 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -1560,6 +1560,66 @@ class OMPLoopDirective : public OMPLoopBasedDirective {
   }
 };
 
+/// This represents the root of the tree of a broken-up compound directive.
+/// It is used to implement pretty-printing consistent with the original
+/// source. This is a pass-through directive for the purposes of semantic
+/// analysis and code generation.
+/// getDirectiveKind() returns the id of the original compound directive.
+/// The associated statement is the outermost of the constituent
+/// directives, and it is always present.
+class OMPCompoundRootDirective final : public OMPExecutableDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+
+  /// Build directive with the given start and end location.
+  ///
+  /// \param DKind The OpenMP directive kind.
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending location of the directive.
+  ///
+  OMPCompoundRootDirective(OpenMPDirectiveKind DKind, SourceLocation StartLoc,
+                           SourceLocation EndLoc)
+      : OMPExecutableDirective(OMPCompoundRootDirectiveClass, DKind, StartLoc,
+                               EndLoc) {}
+
+  /// Build an empty directive.
+  ///
+  /// \param DKind The OpenMP directive kind.
+  ///
+  explicit OMPCompoundRootDirective(OpenMPDirectiveKind DKind)
+      : OMPExecutableDirective(OMPCompoundRootDirectiveClass, DKind,
+                               SourceLocation(), SourceLocation()) {}
+
+public:
+  /// Creates directive with a list of \a Clauses.
+  ///
+  /// \param C AST context.
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending Location of the directive.
+  /// \param DKind The OpenMP directive kind.
+  /// \param Clauses List of clauses.
+  /// \param AssociatedStmt Statement, associated with the directive.
+  static OMPCompoundRootDirective *
+  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+         OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
+         Stmt *AssociatedStmt);
+
+  /// Creates an empty directive with the place for \a NumClauses
+  /// clauses.
+  ///
+  /// \param C AST context.
+  /// \param DKind The OpenMP directive kind.
+  /// \param NumClauses Number of clauses.
+  ///
+  static OMPCompoundRootDirective *
+  CreateEmpty(const ASTContext &C, OpenMPDirectiveKind DKind,
+              unsigned NumClauses, EmptyShell);
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPCompoundRootDirectiveClass;
+  }
+};
+
 /// This represents any executable OpenMP directive that is not loop-
 /// associated (usually block-associated).
 class OMPOpaqueBlockDirective final : public OMPExecutableDirective {
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
index 6606cce3ff085..ea0fafb884cb4 100644
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -366,6 +366,7 @@ class TextNodeDumper
   void VisitPragmaCommentDecl(const PragmaCommentDecl *D);
   void VisitPragmaDetectMismatchDecl(const PragmaDetectMismatchDecl *D);
   void VisitOMPExecutableDirective(const OMPExecutableDirective *D);
+  void VisitOMPCompoundRootDirective(const OMPCompoundRootDirective *D);
   void VisitOMPOpaqueBlockDirective(const OMPOpaqueBlockDirective *D);
   void VisitOMPOpaqueLoopDirective(const OMPOpaqueLoopDirective *D);
   void VisitOMPDeclareReductionDecl(const OMPDeclareReductionDecl *D);
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index da8c91043c814..e365dea0dc6f1 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -225,6 +225,7 @@ def OMPExecutableDirective : StmtNode<Stmt, 1>;
 def OMPMetaDirective : StmtNode<OMPExecutableDirective>;
 def OMPLoopBasedDirective : StmtNode<OMPExecutableDirective, 1>;
 def OMPLoopDirective : StmtNode<OMPLoopBasedDirective, 1>;
+def OMPCompoundRootDirective : StmtNode<OMPExecutableDirective>;
 def OMPOpaqueBlockDirective : StmtNode<OMPExecutableDirective>;
 def OMPOpaqueLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPParallelDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 18103a72c9cd3..45ddde3e1ade9 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -849,6 +849,12 @@ class SemaOpenMP : public SemaBase {
       ArrayRef<OMPInteropInfo> AppendArgs, SourceLocation AdjustArgsLoc,
       SourceLocation AppendArgsLoc, SourceRange SR);
 
+  StmtResult ActOnOpenMPCompoundRootDirective(OpenMPDirectiveKind DKind,
+                                              ArrayRef<OMPClause *> Clauses,
+                                              Stmt *AStmt,
+                                              SourceLocation StartLoc,
+                                              SourceLocation EndLoc);
+
   StmtResult ActOnOpenMPOpaqueBlockDirective(
       OpenMPDirectiveKind Kind, ArrayRef<OMPClause *> Clauses, Stmt *AStmt,
       OpenMPDirectiveKind CancelRegion, const DeclarationNameInfo &DirName,
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index ac332f6982635..fabe9b64cfa75 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1904,6 +1904,7 @@ enum StmtCode {
   STMT_SEH_TRY,                     // SEHTryStmt
 
   // OpenMP directives
+  STMT_OMP_COMPOUND_ROOT_DIRECTIVE,
   STMT_OMP_OPAQUE_BLOCK_DIRECTIVE,
   STMT_OMP_OPAQUE_LOOP_DIRECTIVE,
   STMT_OMP_META_DIRECTIVE,
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 3bd88acb9a3c6..7550aca792912 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -259,6 +259,21 @@ void OMPLoopDirective::setFinalsConditions(ArrayRef<Expr *> A) {
   llvm::copy(A, getFinalsConditions().begin());
 }
 
+OMPCompoundRootDirective *OMPCompoundRootDirective::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+    OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
+    Stmt *AssociatedStmt) {
+  return createDirective<OMPCompoundRootDirective>(
+      C, Clauses, AssociatedStmt, /*NumChildren=*/0, DKind, StartLoc, EndLoc);
+}
+
+OMPCompoundRootDirective *OMPCompoundRootDirective::CreateEmpty(
+    const ASTContext &C, OpenMPDirectiveKind DKind, unsigned NumClauses,
+    EmptyShell) {
+  return createEmptyDirective<OMPCompoundRootDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/0, DKind);
+}
+
 OMPOpaqueBlockDirective *OMPOpaqueBlockDirective::Create(
     const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
     OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index cd4f91337ef42..b756b1c47c606 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -735,6 +735,12 @@ void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S,
     PrintStmt(S->getRawStmt());
 }
 
+void StmtPrinter::VisitOMPCompoundRootDirective(
+    OMPCompoundRootDirective *Node) {
+  OS << "OMPCompoundRootDirective\n";
+  PrintStmt(Node, /*ForceNoStmt=*/false);
+}
+
 void StmtPrinter::VisitOMPOpaqueBlockDirective(OMPOpaqueBlockDirective *Node) {
   OpenMPDirectiveKind DKind = Node->getDirectiveKind();
   bool ForceNoStmt = false;
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index a4f032c3f78c5..8880a9e277db6 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -967,6 +967,11 @@ StmtProfiler::VisitOMPExecutableDirective(const OMPExecutableDirective *S) {
       P.Visit(*I);
 }
 
+void StmtProfiler::VisitOMPCompoundRootDirective(
+    const OMPCompoundRootDirective *S) {
+  VisitOMPExecutableDirective(S);
+}
+
 void StmtProfiler::VisitOMPOpaqueBlockDirective(
     const OMPOpaqueBlockDirective *S) {
   VisitOMPExecutableDirective(S);
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 09d3802c5e06d..d1378cba65d76 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -2372,6 +2372,13 @@ void TextNodeDumper::VisitOMPExecutableDirective(
     OS << " openmp_standalone_directive";
 }
 
+void TextNodeDumper::VisitOMPCompoundRootDirective(
+    const OMPCompoundRootDirective *D) {
+  VisitOMPExecutableDirective(D);
+  OpenMPDirectiveKind DKind = D->getDirectiveKind();
+  OS << " '" << llvm::omp::getOpenMPDirectiveName(DKind) << '\'';
+}
+
 void TextNodeDumper::VisitOMPOpaqueBlockDirective(
     const OMPOpaqueBlockDirective *D) {
   VisitOMPExecutableDirective(D);
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 42e9490ebbaa6..82781eae764ab 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -204,6 +204,10 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
   case Stmt::SEHTryStmtClass:
     EmitSEHTryStmt(cast<SEHTryStmt>(*S));
     break;
+  case Stmt::OMPCompoundRootDirectiveClass:
+    // Skip this node, go straight through to the associated statement.
+    EmitStmt(cast<OMPCompoundRootDirective>(*S).getAssociatedStmt(), Attrs);
+    break;
   case Stmt::OMPOpaqueBlockDirectiveClass:
   case Stmt::OMPOpaqueLoopDirectiveClass:
     // These are catch-all nodes for executable OpenMP directives in templates.
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 467eb89afba52..b3743b8fd1c74 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1452,6 +1452,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Stmt::OMPMaskedTaskLoopDirectiveClass:
   case Stmt::OMPMasterTaskLoopSimdDirectiveClass:
   case Stmt::OMPMaskedTaskLoopSimdDirectiveClass:
+  case Stmt::OMPCompoundRootDirectiveClass:
   case Stmt::OMPOpaqueBlockDirectiveClass:
   case Stmt::OMPOpaqueLoopDirectiveClass:
   case Stmt::OMPOrderedDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index fbed2952de3fb..3a59839a0b182 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -23309,6 +23309,13 @@ StmtResult SemaOpenMP::ActOnOpenMPScopeDirective(ArrayRef<OMPClause *> Clauses,
 
 static bool checkScanScope(Sema &S, Scope *CurrentS, SourceLocation Loc);
 
+StmtResult SemaOpenMP::ActOnOpenMPCompoundRootDirective(
+    OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses, Stmt *AStmt,
+    SourceLocation StartLoc, SourceLocation EndLoc) {
+  return OMPCompoundRootDirective::Create(
+      getASTContext(), StartLoc, EndLoc, DKind, Clauses, AStmt);
+}
+
 StmtResult SemaOpenMP::ActOnOpenMPOpaqueBlockDirective(
     OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses, Stmt *AStmt,
     OpenMPDirectiveKind CancelRegion, const DeclarationNameInfo &DirName,
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 3a49681167678..cefaa9bd716e9 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -9406,6 +9406,14 @@ StmtResult TreeTransform<Derived>::TransformOMPInformationalDirective(
       D->getBeginLoc(), D->getEndLoc());
 }
 
+template <typename Derived>
+StmtResult TreeTransform<Derived>::TransformOMPCompoundRootDirective(
+    OMPCompoundRootDirective *D) {
+  // This directive should never appear in a template: directive splitting
+  // only happens in non-template functions.
+  llvm_unreachable("TransformOMPCompoundRootDirective in a template");
+}
+
 template <typename Derived>
 StmtResult TreeTransform<Derived>::TransformOMPOpaqueBlockDirective(
     OMPOpaqueBlockDirective *D) {
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 4585693953be4..1754c2cd89cdd 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2409,6 +2409,13 @@ void ASTStmtReader::VisitOMPLoopDirective(OMPLoopDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void ASTStmtReader::VisitOMPCompoundRootDirective(OMPCompoundRootDirective *D) {
+  VisitStmt(D);
+  // The DKind was read in ReadStmtFromStream.
+  Record.skipInts(1);
+  VisitOMPExecutableDirective(D);
+}
+
 void ASTStmtReader::VisitOMPOpaqueBlockDirective(OMPOpaqueBlockDirective *D) {
   VisitStmt(D);
   // The DKind was read in ReadStmtFromStream.
@@ -3513,6 +3520,14 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) {
       S = OMPCanonicalLoop::createEmpty(Context);
       break;
 
+    case STMT_OMP_COMPOUND_ROOT_DIRECTIVE: {
+      unsigned DKind = Record[ASTStmtReader::NumStmtFields];
+      unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1];
+      S = OMPCompoundRootDirective::CreateEmpty(
+          Context, static_cast<OpenMPDirectiveKind>(DKind), NumClauses, Empty);
+      break;
+    }
+
     case STMT_OMP_OPAQUE_BLOCK_DIRECTIVE: {
       unsigned DKind = Record[ASTStmtReader::NumStmtFields];
       unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1];
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index 47968dd78454a..ed31af10d7417 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2409,6 +2409,13 @@ void ASTStmtWriter::VisitOMPLoopDirective(OMPLoopDirective *D) {
   VisitOMPLoopBasedDirective(D);
 }
 
+void ASTStmtWriter::VisitOMPCompoundRootDirective(OMPCompoundRootDirective *D) {
+  VisitStmt(D);
+  Record.writeUInt32(static_cast<unsigned>(D->getDirectiveKind()));
+  VisitOMPExecutableDirective(D);
+  Code = serialization::STMT_OMP_COMPOUND_ROOT_DIRECTIVE;
+}
+
 void ASTStmtWriter::VisitOMPOpaqueBlockDirective(OMPOpaqueBlockDirective *D) {
   VisitStmt(D);
   Record.writeUInt32(static_cast<unsigned>(D->getDirectiveKind()));
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index b980d6d25f9a2..cd3b49749c4ce 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1748,6 +1748,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
     case Stmt::SEHLeaveStmtClass:
     case Stmt::SEHFinallyStmtClass:
     case Stmt::OMPCanonicalLoopClass:
+    case Stmt::OMPCompoundRootDirectiveClass:
     case Stmt::OMPOpaqueBlockDirectiveClass:
     case Stmt::OMPOpaqueLoopDirectiveClass:
     case Stmt::OMPParallelDirectiveClass:

>From 283b0b527836bee29cf633964fa0474f931a9bfc Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 25 Mar 2025 15:38:33 -0500
Subject: [PATCH 2/2] Store original Stmt in OMPCompoundRootDirective for
 unparsing

That statement will be ignored for all purposes, except for regenerating
the original source code.
---
 clang/include/clang/AST/StmtOpenMP.h  |  8 +++++++-
 clang/include/clang/Sema/SemaOpenMP.h |  2 +-
 clang/lib/AST/StmtOpenMP.cpp          | 10 ++++++----
 clang/lib/AST/StmtPrinter.cpp         | 15 +++++++++++++--
 clang/lib/AST/StmtProfile.cpp         |  1 +
 clang/lib/CodeGen/CGStmt.cpp          |  1 +
 clang/lib/Sema/SemaOpenMP.cpp         |  4 ++--
 7 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index 4a3c2a53377d6..0d05dfd7956d2 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -1571,6 +1571,8 @@ class OMPCompoundRootDirective final : public OMPExecutableDirective {
   friend class ASTStmtReader;
   friend class OMPExecutableDirective;
 
+  void setUnparseStmt(Stmt *S) { Data->getChildren()[0] = S; }
+
   /// Build directive with the given start and end location.
   ///
   /// \param DKind The OpenMP directive kind.
@@ -1602,7 +1604,7 @@ class OMPCompoundRootDirective final : public OMPExecutableDirective {
   static OMPCompoundRootDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
-         Stmt *AssociatedStmt);
+         Stmt *AssociatedStmt, Stmt *UnparseStmt);
 
   /// Creates an empty directive with the place for \a NumClauses
   /// clauses.
@@ -1618,6 +1620,10 @@ class OMPCompoundRootDirective final : public OMPExecutableDirective {
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPCompoundRootDirectiveClass;
   }
+
+  Stmt *getUnparseStmt() const {
+    return cast_or_null<Stmt>(Data->getChildren()[0]);
+  }
 };
 
 /// This represents any executable OpenMP directive that is not loop-
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 45ddde3e1ade9..f4a9d8a3bdac4 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -851,7 +851,7 @@ class SemaOpenMP : public SemaBase {
 
   StmtResult ActOnOpenMPCompoundRootDirective(OpenMPDirectiveKind DKind,
                                               ArrayRef<OMPClause *> Clauses,
-                                              Stmt *AStmt,
+                                              Stmt *AStmt, Stmt *UStmt,
                                               SourceLocation StartLoc,
                                               SourceLocation EndLoc);
 
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 7550aca792912..f86bbff57d896 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -262,16 +262,18 @@ void OMPLoopDirective::setFinalsConditions(ArrayRef<Expr *> A) {
 OMPCompoundRootDirective *OMPCompoundRootDirective::Create(
     const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
     OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses,
-    Stmt *AssociatedStmt) {
-  return createDirective<OMPCompoundRootDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/0, DKind, StartLoc, EndLoc);
+    Stmt *AssociatedStmt, Stmt *UnparseStmt) {
+  auto *Dir = createDirective<OMPCompoundRootDirective>(
+      C, Clauses, AssociatedStmt, /*NumChildren=*/1, DKind, StartLoc, EndLoc);
+  Dir->setUnparseStmt(UnparseStmt);
+  return Dir;
 }
 
 OMPCompoundRootDirective *OMPCompoundRootDirective::CreateEmpty(
     const ASTContext &C, OpenMPDirectiveKind DKind, unsigned NumClauses,
     EmptyShell) {
   return createEmptyDirective<OMPCompoundRootDirective>(
-      C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/0, DKind);
+      C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/1, DKind);
 }
 
 OMPOpaqueBlockDirective *OMPOpaqueBlockDirective::Create(
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index b756b1c47c606..5188d9990e1af 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -737,8 +737,19 @@ void StmtPrinter::PrintOMPExecutableDirective(OMPExecutableDirective *S,
 
 void StmtPrinter::VisitOMPCompoundRootDirective(
     OMPCompoundRootDirective *Node) {
-  OS << "OMPCompoundRootDirective\n";
-  PrintStmt(Node, /*ForceNoStmt=*/false);
+  llvm::omp::Directive CompKind = Node->getDirectiveKind();
+
+  Indent() << "#pragma omp " << llvm::omp::getOpenMPDirectiveName(CompKind);
+  OMPClausePrinter Printer(OS, Policy);
+  for (auto *Clause : Node->clauses()) {
+    if (Clause && !Clause->isImplicit()) {
+      OS << ' ';
+      Printer.Visit(Clause);
+    }
+  }
+  OS << NL;
+
+  PrintStmt(Node->getUnparseStmt(), /*SubIndent=*/0);
 }
 
 void StmtPrinter::VisitOMPOpaqueBlockDirective(OMPOpaqueBlockDirective *Node) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 8880a9e277db6..15b4dda81c9fc 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -969,6 +969,7 @@ StmtProfiler::VisitOMPExecutableDirective(const OMPExecutableDirective *S) {
 
 void StmtProfiler::VisitOMPCompoundRootDirective(
     const OMPCompoundRootDirective *S) {
+  // Ignore UnparseStmt.
   VisitOMPExecutableDirective(S);
 }
 
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 82781eae764ab..a9019a725c6e5 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -206,6 +206,7 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
     break;
   case Stmt::OMPCompoundRootDirectiveClass:
     // Skip this node, go straight through to the associated statement.
+    // Ignore the unparse statement.
     EmitStmt(cast<OMPCompoundRootDirective>(*S).getAssociatedStmt(), Attrs);
     break;
   case Stmt::OMPOpaqueBlockDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 3a59839a0b182..a19a692edb0bb 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -23311,9 +23311,9 @@ static bool checkScanScope(Sema &S, Scope *CurrentS, SourceLocation Loc);
 
 StmtResult SemaOpenMP::ActOnOpenMPCompoundRootDirective(
     OpenMPDirectiveKind DKind, ArrayRef<OMPClause *> Clauses, Stmt *AStmt,
-    SourceLocation StartLoc, SourceLocation EndLoc) {
+    Stmt *UStmt, SourceLocation StartLoc, SourceLocation EndLoc) {
   return OMPCompoundRootDirective::Create(
-      getASTContext(), StartLoc, EndLoc, DKind, Clauses, AStmt);
+      getASTContext(), StartLoc, EndLoc, DKind, Clauses, AStmt, UStmt);
 }
 
 StmtResult SemaOpenMP::ActOnOpenMPOpaqueBlockDirective(



More information about the llvm-branch-commits mailing list