[llvm] b7899ba - [OPENMP51]Initial support for the dispatch directive.

Mike Rice via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 30 14:22:26 PDT 2021


Author: Mike Rice
Date: 2021-03-30T14:12:53-07:00
New Revision: b7899ba0e8b7a7b8fbab06a8b3f2d70f16d7a250

URL: https://github.com/llvm/llvm-project/commit/b7899ba0e8b7a7b8fbab06a8b3f2d70f16d7a250
DIFF: https://github.com/llvm/llvm-project/commit/b7899ba0e8b7a7b8fbab06a8b3f2d70f16d7a250.diff

LOG: [OPENMP51]Initial support for the dispatch directive.

Added basic parsing/sema/serialization support for dispatch directive.

Differential Revision: https://reviews.llvm.org/D99537

Added: 
    clang/test/OpenMP/dispatch_ast_print.cpp
    clang/test/OpenMP/dispatch_messages.cpp

Modified: 
    clang/include/clang-c/Index.h
    clang/include/clang/AST/RecursiveASTVisitor.h
    clang/include/clang/AST/StmtOpenMP.h
    clang/include/clang/Basic/DiagnosticSemaKinds.td
    clang/include/clang/Basic/StmtNodes.td
    clang/include/clang/Sema/Sema.h
    clang/include/clang/Serialization/ASTBitCodes.h
    clang/lib/AST/StmtOpenMP.cpp
    clang/lib/AST/StmtPrinter.cpp
    clang/lib/AST/StmtProfile.cpp
    clang/lib/Basic/OpenMPKinds.cpp
    clang/lib/CodeGen/CGStmt.cpp
    clang/lib/Parse/ParseOpenMP.cpp
    clang/lib/Sema/SemaExceptionSpec.cpp
    clang/lib/Sema/SemaOpenMP.cpp
    clang/lib/Sema/TreeTransform.h
    clang/lib/Serialization/ASTReaderStmt.cpp
    clang/lib/Serialization/ASTWriterStmt.cpp
    clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
    clang/tools/libclang/CIndex.cpp
    clang/tools/libclang/CXCursor.cpp
    llvm/include/llvm/Frontend/OpenMP/OMP.td
    llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Removed: 
    


################################################################################
diff  --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index bcc063051b8c3..696e271a72313 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2580,7 +2580,11 @@ enum CXCursorKind {
    */
   CXCursor_OMPInteropDirective = 290,
 
-  CXCursor_LastStmt = CXCursor_OMPInteropDirective,
+  /** OpenMP dispatch directive.
+   */
+  CXCursor_OMPDispatchDirective = 291,
+
+  CXCursor_LastStmt = CXCursor_OMPDispatchDirective,
 
   /**
    * Cursor that represents the translation unit itself.

diff  --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 256f73338bd27..23864819bc07e 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2971,6 +2971,9 @@ DEF_TRAVERSE_STMT(OMPTargetTeamsDistributeSimdDirective,
 DEF_TRAVERSE_STMT(OMPInteropDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(OMPDispatchDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 // OpenMP clauses.
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseOMPClause(OMPClause *C) {

diff  --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index 6ddc29446a8c8..a59b36c302a3d 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -5139,6 +5139,72 @@ class OMPInteropDirective final : public OMPExecutableDirective {
   }
 };
 
+/// This represents '#pragma omp dispatch' directive.
+///
+/// \code
+/// #pragma omp dispatch device(dnum)
+/// \endcode
+/// This example shows a directive '#pragma omp dispatch' with a
+/// device clause with variable 'dnum'.
+///
+class OMPDispatchDirective final : public OMPExecutableDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+
+  /// Location of the target-call: the call in the associated statement.
+  SourceLocation TargetCallLoc;
+
+  /// Set the location of the target-call (used by ASTStmtReader).
+  void setTargetCallLoc(SourceLocation Loc) { TargetCallLoc = Loc; }
+
+  /// Build directive with the given start and end location.
+  ///
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending location of the directive.
+  ///
+  OMPDispatchDirective(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPExecutableDirective(OMPDispatchDirectiveClass,
+                               llvm::omp::OMPD_dispatch, StartLoc, EndLoc) {}
+
+  /// Build an empty directive.
+  ///
+  explicit OMPDispatchDirective()
+      : OMPExecutableDirective(OMPDispatchDirectiveClass,
+                               llvm::omp::OMPD_dispatch, SourceLocation(),
+                               SourceLocation()) {}
+
+public:
+  /// Creates directive with a list of \a Clauses.
+  ///
+  /// \param C AST context.
+  /// \param StartLoc Starting location of the directive kind.
+  /// \param EndLoc Ending Location of the directive.
+  /// \param Clauses List of clauses.
+  /// \param AssociatedStmt Statement, associated with the directive.
+  /// \param TargetCallLoc Location of the target-call.
+  ///
+  static OMPDispatchDirective *
+  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+         ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+         SourceLocation TargetCallLoc);
+
+  /// Creates an empty directive with room for \a NumClauses clauses and an
+  /// associated statement; used when deserializing.
+  ///
+  /// \param C AST context.
+  /// \param NumClauses Number of clauses.
+  ///
+  static OMPDispatchDirective *CreateEmpty(const ASTContext &C,
+                                           unsigned NumClauses, EmptyShell);
+
+  /// Return the target-call location; invalid if built in a dependent context.
+  SourceLocation getTargetCallLoc() const { return TargetCallLoc; }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPDispatchDirectiveClass;
+  }
+};
+
 } // end namespace clang
 
 #endif

diff  --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index ad592d5520300..099e73ed013a0 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10616,6 +10616,9 @@ def err_omp_interop_bad_depend_clause : Error<
   "'depend' clause requires the 'targetsync' interop type">;
 def err_omp_interop_var_multiple_actions : Error<
   "interop variable %0 used in multiple action clauses">;
+def err_omp_dispatch_statement_call
+    : Error<"statement after '#pragma omp dispatch' must be a direct call"
+            " to a target function or an assignment to one">;
 } // end of OpenMP category
 
 let CategoryName = "Related Result Type Issue" in {

diff  --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index 4aac8a32ef976..d2e8fa69a5840 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -276,3 +276,4 @@ def OMPTargetTeamsDistributeParallelForDirective : StmtNode<OMPLoopDirective>;
 def OMPTargetTeamsDistributeParallelForSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPTargetTeamsDistributeSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPInteropDirective : StmtNode<OMPExecutableDirective>;
+def OMPDispatchDirective : StmtNode<OMPExecutableDirective>;

diff  --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 8e1bc3f2dbdab..b5c8d3d292fcf 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -10804,6 +10804,11 @@ class Sema final {
   StmtResult ActOnOpenMPInteropDirective(ArrayRef<OMPClause *> Clauses,
                                          SourceLocation StartLoc,
                                          SourceLocation EndLoc);
+  /// Called on well-formed '\#pragma omp dispatch' after parsing of the
+  /// associated statement.
+  StmtResult ActOnOpenMPDispatchDirective(ArrayRef<OMPClause *> Clauses,
+                                          Stmt *AStmt, SourceLocation StartLoc,
+                                          SourceLocation EndLoc);
 
   /// Checks correctness of linear modifiers.
   bool CheckOpenMPLinearModifier(OpenMPLinearClauseKind LinKind,

diff  --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index 64f15e75bc2c8..17690b901eacf 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1941,6 +1941,7 @@ enum StmtCode {
   STMT_OMP_TARGET_TEAMS_DISTRIBUTE_PARALLEL_FOR_SIMD_DIRECTIVE,
   STMT_OMP_TARGET_TEAMS_DISTRIBUTE_SIMD_DIRECTIVE,
   STMT_OMP_INTEROP_DIRECTIVE,
+  STMT_OMP_DISPATCH_DIRECTIVE,
   EXPR_OMP_ARRAY_SECTION,
   EXPR_OMP_ARRAY_SHAPING,
   EXPR_OMP_ITERATOR,

diff  --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index fa77b862f3d0e..7dc43ea924508 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -1959,3 +1959,21 @@ OMPInteropDirective *OMPInteropDirective::CreateEmpty(const ASTContext &C,
                                                       EmptyShell) {
   return createEmptyDirective<OMPInteropDirective>(C, NumClauses);
 }
+
+OMPDispatchDirective *OMPDispatchDirective::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+    SourceLocation TargetCallLoc) {
+  auto *Dir = createDirective<OMPDispatchDirective>(
+      C, Clauses, AssociatedStmt, /*NumChildren=*/0, StartLoc, EndLoc);
+  Dir->setTargetCallLoc(TargetCallLoc);
+  return Dir;
+}
+// Empty shell for deserialization; always reserves an associated-stmt slot.
+OMPDispatchDirective *OMPDispatchDirective::CreateEmpty(const ASTContext &C,
+                                                        unsigned NumClauses,
+                                                        EmptyShell) {
+  return createEmptyDirective<OMPDispatchDirective>(C, NumClauses,
+                                                    /*HasAssociatedStmt=*/true,
+                                                    /*NumChildren=*/0);
+}

diff  --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index ca35c6dccbf8c..5993268971f8b 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -967,6 +967,11 @@ void StmtPrinter::VisitOMPInteropDirective(OMPInteropDirective *Node) {
   PrintOMPExecutableDirective(Node);
 }
 
+void StmtPrinter::VisitOMPDispatchDirective(OMPDispatchDirective *Node) {
+  Indent() << "#pragma omp dispatch";
+  PrintOMPExecutableDirective(Node);
+}
+
 //===----------------------------------------------------------------------===//
 //  Expr printing methods.
 //===----------------------------------------------------------------------===//

diff  --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index bf130ed4ff3d3..fa3ab3ce977a5 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -1144,6 +1144,10 @@ void StmtProfiler::VisitOMPInteropDirective(const OMPInteropDirective *S) {
   VisitOMPExecutableDirective(S);
 }
 
+void StmtProfiler::VisitOMPDispatchDirective(const OMPDispatchDirective *S) {
+  VisitOMPExecutableDirective(S);
+}
+
 void StmtProfiler::VisitExpr(const Expr *S) {
   VisitStmt(S);
 }

diff  --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index e289e953d47ff..00abb3062ca61 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -660,6 +660,7 @@ void clang::getOpenMPCaptureRegions(
   case OMPD_atomic:
   case OMPD_target_data:
   case OMPD_distribute_simd:
+  case OMPD_dispatch:
     CaptureRegions.push_back(OMPD_unknown);
     break;
   case OMPD_tile:

diff  --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index fb719efb1a35d..b4ddfa586d1aa 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -378,6 +378,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
   case Stmt::OMPInteropDirectiveClass:
     llvm_unreachable("Interop directive not supported yet.");
     break;
+  case Stmt::OMPDispatchDirectiveClass:
+    llvm_unreachable("Dispatch directive not supported yet.");
+    break;
   }
 }
 

diff  --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 1ea01409d3d36..54e24e94e6611 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2207,6 +2207,7 @@ Parser::DeclGroupPtrTy Parser::ParseOpenMPDeclarativeDirectiveWithExtDecl(
   case OMPD_target_teams_distribute_parallel_for:
   case OMPD_target_teams_distribute_parallel_for_simd:
   case OMPD_target_teams_distribute_simd:
+  case OMPD_dispatch:
     Diag(Tok, diag::err_omp_unexpected_directive)
         << 1 << getOpenMPDirectiveName(DKind);
     break;
@@ -2430,7 +2431,8 @@ Parser::ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx) {
   case OMPD_target_teams_distribute:
   case OMPD_target_teams_distribute_parallel_for:
   case OMPD_target_teams_distribute_parallel_for_simd:
-  case OMPD_target_teams_distribute_simd: {
+  case OMPD_target_teams_distribute_simd:
+  case OMPD_dispatch: {
     // Special processing for flush and depobj clauses.
     Token ImplicitTok;
     bool ImplicitClauseAllowed = false;

diff  --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index a9a487bd47a38..12f47cb0630d0 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1487,6 +1487,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Stmt::OMPTeamsDistributeParallelForSimdDirectiveClass:
   case Stmt::OMPTeamsDistributeSimdDirectiveClass:
   case Stmt::OMPInteropDirectiveClass:
+  case Stmt::OMPDispatchDirectiveClass:
   case Stmt::ReturnStmtClass:
   case Stmt::SEHExceptStmtClass:
   case Stmt::SEHFinallyStmtClass:

diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index fcb95e3a8442c..48fe19e50ee71 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -3979,7 +3979,8 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
   case OMPD_distribute:
   case OMPD_distribute_simd:
   case OMPD_ordered:
-  case OMPD_target_data: {
+  case OMPD_target_data:
+  case OMPD_dispatch: {
     Sema::CapturedParamNameType Params[] = {
         std::make_pair(StringRef(), QualType()) // __context with shared vars
     };
@@ -6120,6 +6121,10 @@ StmtResult Sema::ActOnOpenMPExecutableDirective(
            "No associated statement allowed for 'omp interop' directive");
     Res = ActOnOpenMPInteropDirective(ClausesWithImplicit, StartLoc, EndLoc);
     break;
+  case OMPD_dispatch:
+    Res = ActOnOpenMPDispatchDirective(ClausesWithImplicit, AStmt, StartLoc,
+                                       EndLoc);
+    break;
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_threadprivate:
@@ -9758,6 +9763,64 @@ StmtResult Sema::ActOnOpenMPSectionDirective(Stmt *AStmt,
                                      DSAStack->isCancelRegion());
 }
 
+static Expr *getDirectCallExpr(Expr *E) {
+  E = E->IgnoreParenCasts()->IgnoreImplicit();
+  if (auto *CE = dyn_cast<CallExpr>(E))
+    if (CE->getDirectCallee())
+      return E;
+  return nullptr;
+}
+
+StmtResult Sema::ActOnOpenMPDispatchDirective(ArrayRef<OMPClause *> Clauses,
+                                              Stmt *AStmt,
+                                              SourceLocation StartLoc,
+                                              SourceLocation EndLoc) {
+  if (!AStmt)
+    return StmtError();
+  // AStmt is the captured region; validate the statement it wraps.
+  Stmt *S = cast<CapturedStmt>(AStmt)->getCapturedStmt();
+
+  // OpenMP 5.1: the statement associated with 'dispatch' must be an
+  // expression statement of one of the following forms:
+  //   expression = target-call ( [expression-list] );
+  //   target-call ( [expression-list] );
+
+  SourceLocation TargetCallLoc;
+  // Skipped in dependent contexts, where TargetCallLoc stays invalid.
+  if (!CurContext->isDependentContext()) {
+    Expr *TargetCall = nullptr;
+    // Only an expression statement can take either accepted form.
+    auto *E = dyn_cast<Expr>(S);
+    if (!E) {
+      Diag(S->getBeginLoc(), diag::err_omp_dispatch_statement_call);
+      return StmtError();
+    }
+    // Look through parentheses, casts and implicit nodes before matching.
+    E = E->IgnoreParenCasts()->IgnoreImplicit();
+    // Match 'lhs = call(...)' first; else the whole expression as a call.
+    if (auto *BO = dyn_cast<BinaryOperator>(E)) {
+      if (BO->getOpcode() == BO_Assign)
+        TargetCall = getDirectCallExpr(BO->getRHS());
+    } else {
+      if (auto *COCE = dyn_cast<CXXOperatorCallExpr>(E))
+        if (COCE->getOperator() == OO_Equal)
+          TargetCall = getDirectCallExpr(COCE->getArg(1));
+      if (!TargetCall)
+        TargetCall = getDirectCallExpr(E);
+    }
+    if (!TargetCall) {
+      Diag(E->getBeginLoc(), diag::err_omp_dispatch_statement_call);
+      return StmtError();
+    }
+    TargetCallLoc = TargetCall->getExprLoc();
+  }
+  // NOTE(review): presumably so jumps into the region get scope-checked.
+  setFunctionHasBranchProtectedScope();
+
+  return OMPDispatchDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+                                      TargetCallLoc);
+}
+
 StmtResult Sema::ActOnOpenMPSingleDirective(ArrayRef<OMPClause *> Clauses,
                                             Stmt *AStmt,
                                             SourceLocation StartLoc,
@@ -13349,6 +13412,7 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause(
     case OMPD_target_parallel_for_simd:
     case OMPD_target_teams_distribute_parallel_for:
     case OMPD_target_teams_distribute_parallel_for_simd:
+    case OMPD_dispatch:
       CaptureRegion = OMPD_task;
       break;
     case OMPD_target_data:

diff  --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 6fa91035a0c48..86a25149a23c1 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -9069,6 +9069,17 @@ TreeTransform<Derived>::TransformOMPInteropDirective(OMPInteropDirective *D) {
   return Res;
 }
 
+template <typename Derived>
+StmtResult
+TreeTransform<Derived>::TransformOMPDispatchDirective(OMPDispatchDirective *D) {
+  DeclarationNameInfo DirName;
+  getDerived().getSema().StartOpenMPDSABlock(OMPD_dispatch, DirName, nullptr,
+                                             D->getBeginLoc());
+  StmtResult Res = getDerived().TransformOMPExecutableDirective(D);
+  getDerived().getSema().EndOpenMPDSABlock(Res.get());
+  return Res;
+}
+
 //===----------------------------------------------------------------------===//
 // OpenMP clause transformation
 //===----------------------------------------------------------------------===//

diff  --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 22e1f57e3313e..6e552e9a2441f 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2593,6 +2593,12 @@ void ASTStmtReader::VisitOMPInteropDirective(OMPInteropDirective *D) {
   VisitOMPExecutableDirective(D);
 }
 
+void ASTStmtReader::VisitOMPDispatchDirective(OMPDispatchDirective *D) {
+  VisitStmt(D);
+  VisitOMPExecutableDirective(D);
+  D->setTargetCallLoc(Record.readSourceLocation());
+}
+
 //===----------------------------------------------------------------------===//
 // ASTReader Implementation
 //===----------------------------------------------------------------------===//
@@ -3513,6 +3519,11 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) {
           Context, Record[ASTStmtReader::NumStmtFields], Empty);
       break;
 
+    case STMT_OMP_DISPATCH_DIRECTIVE:
+      S = OMPDispatchDirective::CreateEmpty(
+          Context, Record[ASTStmtReader::NumStmtFields], Empty);
+      break;
+
     case EXPR_CXX_OPERATOR_CALL:
       S = CXXOperatorCallExpr::CreateEmpty(
           Context, /*NumArgs=*/Record[ASTStmtReader::NumExprFields],

diff  --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index ed14058aeb0d6..97ecd5a773ea1 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2547,6 +2547,13 @@ void ASTStmtWriter::VisitOMPInteropDirective(OMPInteropDirective *D) {
   Code = serialization::STMT_OMP_INTEROP_DIRECTIVE;
 }
 
+void ASTStmtWriter::VisitOMPDispatchDirective(OMPDispatchDirective *D) {
+  VisitStmt(D);
+  VisitOMPExecutableDirective(D);
+  Record.AddSourceLocation(D->getTargetCallLoc());
+  Code = serialization::STMT_OMP_DISPATCH_DIRECTIVE;
+}
+
 //===----------------------------------------------------------------------===//
 // ASTWriter Implementation
 //===----------------------------------------------------------------------===//

diff  --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index fbd5fd87fcf0e..1498efd135d2c 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1295,6 +1295,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
     case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass:
     case Stmt::OMPTileDirectiveClass:
     case Stmt::OMPInteropDirectiveClass:
+    case Stmt::OMPDispatchDirectiveClass:
     case Stmt::CapturedStmtClass: {
       const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState());
       Engine.addAbortedBlock(node, currBldrCtx->getBlock());

diff  --git a/clang/test/OpenMP/dispatch_ast_print.cpp b/clang/test/OpenMP/dispatch_ast_print.cpp
new file mode 100644
index 0000000000000..3dea4cbcd3f97
--- /dev/null
+++ b/clang/test/OpenMP/dispatch_ast_print.cpp
@@ -0,0 +1,215 @@
+// RUN: %clang_cc1 -triple=x86_64-pc-win32 -fopenmp -fopenmp-version=51     \
+// RUN:   -fsyntax-only -verify %s
+
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 \
+// RUN:   -fsyntax-only -verify %s
+
+// expected-no-diagnostics
+
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 \
+// RUN:   -ast-print %s | FileCheck %s --check-prefix=PRINT
+
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 \
+// RUN:   -ast-dump  %s | FileCheck %s --check-prefix=DUMP
+
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 \
+// RUN:   -emit-pch -o %t %s
+
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 \
+// RUN:   -include-pch %t -ast-dump-all %s | FileCheck %s --check-prefix=DUMP
+
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 \
+// RUN:   -include-pch %t -ast-print %s | FileCheck %s --check-prefix=PRINT
+
+#ifndef HEADER
+#define HEADER
+
+int foo_gpu(int A, int *B) { return 0;}
+//PRINT: #pragma omp declare variant(foo_gpu)
+//DUMP: FunctionDecl{{.*}} foo
+//DUMP: OMPDeclareVariantAttr {{.*}}Implicit construct{{.*}}
+#pragma omp declare variant(foo_gpu) \
+    match(construct={dispatch}, device={arch(arm)})
+int foo(int, int*);
+
+template <typename T, typename TP>
+void fooTemp() {
+  T a;
+  TP b;
+  //PRINT: #pragma omp dispatch nowait
+  //DUMP: OMPDispatchDirective
+  //DUMP: OMPNowaitClause
+  #pragma omp dispatch nowait
+  foo(a, b);
+}
+
+int *get_device_ptr();
+int get_device();
+int other();
+
+//DUMP: FunctionDecl{{.*}} test_one
+void test_one()
+{
+  int aaa, bbb, var;
+  //PRINT: #pragma omp dispatch depend(in : var) nowait
+  //DUMP: OMPDispatchDirective
+  //DUMP: OMPDependClause
+  //DUMP: OMPNowaitClause
+  #pragma omp dispatch depend(in:var) nowait
+  foo(aaa, &bbb);
+
+  int *dp = get_device_ptr();
+  int dev = get_device();
+  //PRINT: #pragma omp dispatch device(dev) is_device_ptr(dp)
+  //DUMP: OMPDispatchDirective
+  //DUMP: OMPDeviceClause
+  //DUMP: OMPIs_device_ptrClause
+  #pragma omp dispatch device(dev) is_device_ptr(dp)
+  foo(aaa, dp);
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: foo(other(), &bbb);
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  foo(other(), &bbb);
+
+  fooTemp<int, int*>();
+}
+
+struct Obj {
+  Obj();
+  ~Obj();
+  int disp_method_variant1();
+  #pragma omp declare variant(disp_method_variant1)                            \
+    match(construct={dispatch}, device={arch(arm)})
+  int disp_method1();
+
+  static int disp_method_variant2() { return 1; }
+  #pragma omp declare variant(disp_method_variant2)                            \
+    match(construct={dispatch}, device={arch(arm)})
+  static int disp_method2() { return 2; }
+};
+
+Obj foo_vari();
+#pragma omp declare variant(foo_vari) \
+  match(construct={dispatch}, device={arch(arm)})
+Obj foo_obj();
+
+//DUMP: FunctionDecl{{.*}} test_two
+void test_two(Obj o1, Obj &o2, Obj *o3)
+{
+  //PRINT: #pragma omp dispatch
+  //PRINT: o1.disp_method1();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  o1.disp_method1();
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: o2.disp_method1();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  o2.disp_method1();
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: o3->disp_method1();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  o3->disp_method1();
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: Obj::disp_method2();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  Obj::disp_method2();
+
+  int ret;
+  //PRINT: #pragma omp dispatch
+  //PRINT: ret = o1.disp_method1();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  ret = o1.disp_method1();
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: ret = o2.disp_method1();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  ret = o2.disp_method1();
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: ret = o3->disp_method1();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  ret = o3->disp_method1();
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: ret = Obj::disp_method2();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  ret = Obj::disp_method2();
+
+  //PRINT: #pragma omp dispatch
+  //PRINT: (void)Obj::disp_method2();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  (void)Obj::disp_method2();
+
+  // Full C++ operator= case with temps and EH.
+  Obj o;
+  //PRINT: #pragma omp dispatch
+  //PRINT: o = foo_obj();
+  //DUMP: OMPDispatchDirective
+  #pragma omp dispatch
+  o = foo_obj();
+}
+
+struct A {
+  A& disp_operator(A other);
+  #pragma omp declare variant(disp_operator)                            \
+    match(construct={dispatch}, device={arch(arm)})
+  A& operator=(A other);
+};
+
+struct Obj2 {
+  A xx;
+  Obj2& disp_operator(Obj2 other);
+  #pragma omp declare variant(disp_operator)                            \
+    match(construct={dispatch}, device={arch(arm)})
+  Obj2& operator=(Obj2 other);
+
+  void foo() {
+    Obj2 z;
+    //PRINT: #pragma omp dispatch
+    //PRINT: z = z;
+    //DUMP: OMPDispatchDirective
+    #pragma omp dispatch
+    z = z;
+    //PRINT: #pragma omp dispatch
+    //PRINT: z.operator=(z);
+    //DUMP: OMPDispatchDirective
+    #pragma omp dispatch
+    z.operator=(z);
+  }
+  void bar() {
+    Obj2 j;
+    //PRINT: #pragma omp dispatch
+    //PRINT: j = {this->xx};
+    //DUMP: OMPDispatchDirective
+    #pragma omp dispatch
+    j = {this->xx};
+    //PRINT: #pragma omp dispatch
+    //PRINT: j.operator=({this->xx});
+    //DUMP: OMPDispatchDirective
+    #pragma omp dispatch
+    j.operator=({this->xx});
+  }
+};
+
+void test_three()
+{
+  Obj2 z1, z;
+  #pragma omp dispatch
+  z1 = z;
+  #pragma omp dispatch
+  z1.operator=(z);
+}
+#endif // HEADER

diff  --git a/clang/test/OpenMP/dispatch_messages.cpp b/clang/test/OpenMP/dispatch_messages.cpp
new file mode 100644
index 0000000000000..221c7b784c00e
--- /dev/null
+++ b/clang/test/OpenMP/dispatch_messages.cpp
@@ -0,0 +1,82 @@
+// RUN: %clang_cc1 -triple=x86_64-pc-win32 -verify -fopenmp    \
+// RUN:   -x c++ -std=c++14 -fexceptions -fcxx-exceptions %s
+
+// RUN: %clang_cc1 -triple=x86_64-pc-linux-gnu -verify -fopenmp \
+// RUN:   -x c++ -std=c++14 -fexceptions -fcxx-exceptions %s
+
+int disp_variant();
+#pragma omp declare variant(disp_variant) \
+    match(construct = {dispatch}, device = {arch(arm)})
+int disp_call();
+
+struct Obj {
+  int disp_method_variant1();
+  #pragma omp declare variant(disp_method_variant1)                            \
+    match(construct={dispatch}, device={arch(arm)})
+  int disp_method1();
+  int disp_method_variant2();
+  #pragma omp declare variant(disp_method_variant2)                            \
+    match(construct={dispatch}, device={arch(arm)})
+  int disp_method2();
+};
+
+void testit_one(int dnum) {
+  // expected-error@+1 {{cannot contain more than one 'device' clause}}
+  #pragma omp dispatch device(dnum) device(3)
+  disp_call();
+
+  // expected-error@+1 {{cannot contain more than one 'nowait' clause}}
+  #pragma omp dispatch nowait device(dnum) nowait
+  disp_call();
+}
+
+void testit_two() {
+  //expected-error@+2 {{cannot return from OpenMP region}}
+  #pragma omp dispatch
+  return disp_call();
+}
+
+void testit_three(int (*fptr)(void), Obj *obj, int (Obj::*mptr)(void)) {
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  fptr();
+
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  (obj->*mptr)();
+
+  int ret;
+
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  ret = fptr();
+
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  ret = (obj->*mptr)();
+}
+
+void testit_four(int *x, int y, Obj *obj)
+{
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  *x = y;
+
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  y = disp_call() + disp_call();
+
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  y = (y = disp_call());
+
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  y += disp_call();
+
+  //expected-error@+2 {{statement after '#pragma omp dispatch' must be a direct call to a target function or an assignment to one}}
+  #pragma omp dispatch
+  for (int I = 0; I < 8; ++I) {
+    disp_call();
+  }
+}

diff  --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 841b36a6036c9..88b31e78feec8 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -5668,6 +5668,8 @@ CXString clang_getCursorKindSpelling(enum CXCursorKind Kind) {
     return cxstring::createRef("OMPTargetTeamsDistributeSimdDirective");
   case CXCursor_OMPInteropDirective:
     return cxstring::createRef("OMPInteropDirective");
+  case CXCursor_OMPDispatchDirective:
+    return cxstring::createRef("OMPDispatchDirective");
   case CXCursor_OverloadCandidate:
     return cxstring::createRef("OverloadCandidate");
   case CXCursor_TypeAliasTemplateDecl:

diff  --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
index 2f8d1e35936ec..d715d761f6911 100644
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -810,6 +810,9 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
   case Stmt::OMPInteropDirectiveClass:
     K = CXCursor_OMPInteropDirective;
     break;
+  case Stmt::OMPDispatchDirectiveClass:
+    K = CXCursor_OMPDispatchDirective;
+    break;
   case Stmt::BuiltinBitCastExprClass:
     K = CXCursor_BuiltinBitCastExpr;
   }

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 7845e4bc98dbf..6e494da564384 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1655,6 +1655,14 @@ def OMP_interop : Directive<"interop"> {
     VersionedClause<OMPC_Use>,
   ];
 }
+def OMP_dispatch : Directive<"dispatch"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_Depend>
+  ];
+}
 def OMP_Unknown : Directive<"unknown"> {
   let isDefault = true;
 }

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index 9a77dd7f098d6..31221b2f26c99 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -1078,6 +1078,7 @@ __OMP_TRAIT_PROPERTY(user, condition, true)
 __OMP_TRAIT_PROPERTY(user, condition, false)
 __OMP_TRAIT_PROPERTY(user, condition, unknown)
 
+__OMP_TRAIT_SELECTOR_AND_PROPERTY(construct, dispatch)
 
 // Note that we put isa last so that the other conditions are checked first.
 // This allows us to issue warnings wrt. isa only if we match otherwise.


        


More information about the llvm-commits mailing list