[clang] b737b81 - [OPENMP]Allow cancellation constructs in target parallel regions.

Alexey Bataev via cfe-commits cfe-commits at lists.llvm.org
Thu Apr 30 12:11:10 PDT 2020


Author: Alexey Bataev
Date: 2020-04-30T15:10:52-04:00
New Revision: b737b814feea66379f026550d088fafcd46cb3d2

URL: https://github.com/llvm/llvm-project/commit/b737b814feea66379f026550d088fafcd46cb3d2
DIFF: https://github.com/llvm/llvm-project/commit/b737b814feea66379f026550d088fafcd46cb3d2.diff

LOG: [OPENMP]Allow cancellation constructs in target parallel regions.

Summary:
The 'omp cancellation point parallel' and 'omp cancel parallel' directives are
allowed in 'target parallel' regions.

Reviewers: jdoerfert

Subscribers: yaxunl, guansong, caomhin, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D78941

Added: 
    

Modified: 
    clang/include/clang/AST/StmtOpenMP.h
    clang/lib/AST/StmtOpenMP.cpp
    clang/lib/CodeGen/CGOpenMPRuntime.cpp
    clang/lib/Sema/SemaOpenMP.cpp
    clang/lib/Serialization/ASTReaderStmt.cpp
    clang/lib/Serialization/ASTWriterStmt.cpp
    clang/test/OpenMP/target_parallel_ast_print.cpp
    clang/test/OpenMP/target_parallel_codegen.cpp
    clang/test/OpenMP/target_parallel_messages.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index b390bf0042f9..7fb63cf9ae3d 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -2805,6 +2805,9 @@ class OMPTargetExitDataDirective : public OMPExecutableDirective {
 ///
 class OMPTargetParallelDirective : public OMPExecutableDirective {
   friend class ASTStmtReader;
+  /// true if the construct has inner cancel directive.
+  bool HasCancel = false;
+
   /// Build directive with the given start and end location.
   ///
   /// \param StartLoc Starting location of the directive kind.
@@ -2827,6 +2830,9 @@ class OMPTargetParallelDirective : public OMPExecutableDirective {
                                SourceLocation(), SourceLocation(), NumClauses,
                                /*NumChildren=*/1) {}
 
+  /// Set cancel state.
+  void setHasCancel(bool Has) { HasCancel = Has; }
+
 public:
   /// Creates directive with a list of \a Clauses.
   ///
@@ -2835,10 +2841,11 @@ class OMPTargetParallelDirective : public OMPExecutableDirective {
   /// \param EndLoc Ending Location of the directive.
   /// \param Clauses List of clauses.
   /// \param AssociatedStmt Statement, associated with the directive.
+  /// \param HasCancel true if this directive has inner cancel directive.
   ///
   static OMPTargetParallelDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
-         ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt);
+         ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, bool HasCancel);
 
   /// Creates an empty directive with the place for \a NumClauses
   /// clauses.
@@ -2849,6 +2856,9 @@ class OMPTargetParallelDirective : public OMPExecutableDirective {
   static OMPTargetParallelDirective *
   CreateEmpty(const ASTContext &C, unsigned NumClauses, EmptyShell);
 
+  /// Return true if current directive has inner cancel directive.
+  bool hasCancel() const { return HasCancel; }
+
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPTargetParallelDirectiveClass;
   }

diff  --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 995f710876af..57cfbc505d95 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -887,7 +887,7 @@ OMPTargetDirective *OMPTargetDirective::CreateEmpty(const ASTContext &C,
 
 OMPTargetParallelDirective *OMPTargetParallelDirective::Create(
     const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
-    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt) {
+    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, bool HasCancel) {
   unsigned Size =
       llvm::alignTo(sizeof(OMPTargetParallelDirective), alignof(OMPClause *));
   void *Mem =
@@ -896,6 +896,7 @@ OMPTargetParallelDirective *OMPTargetParallelDirective::Create(
       new (Mem) OMPTargetParallelDirective(StartLoc, EndLoc, Clauses.size());
   Dir->setClauses(Clauses);
   Dir->setAssociatedStmt(AssociatedStmt);
+  Dir->setHasCancel(HasCancel);
   return Dir;
 }
 

diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 4ad276e40c99..e7525937416e 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -1459,6 +1459,8 @@ static llvm::Function *emitParallelOrTeamsOutlinedFunction(
   bool HasCancel = false;
   if (const auto *OPD = dyn_cast<OMPParallelDirective>(&D))
     HasCancel = OPD->hasCancel();
+  else if (const auto *OPD = dyn_cast<OMPTargetParallelDirective>(&D))
+    HasCancel = OPD->hasCancel();
   else if (const auto *OPSD = dyn_cast<OMPParallelSectionsDirective>(&D))
     HasCancel = OPSD->hasCancel();
   else if (const auto *OPFD = dyn_cast<OMPParallelForDirective>(&D))

diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 0ae7c5a0e713..41bfca9c72f8 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -9903,7 +9903,7 @@ Sema::ActOnOpenMPTargetParallelDirective(ArrayRef<OMPClause *> Clauses,
   setFunctionHasBranchProtectedScope();
 
   return OMPTargetParallelDirective::Create(Context, StartLoc, EndLoc, Clauses,
-                                            AStmt);
+                                            AStmt, DSAStack->isCancelRegion());
 }
 
 StmtResult Sema::ActOnOpenMPTargetParallelForDirective(

diff  --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 71d862c82d7c..2c91af31ee14 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2489,6 +2489,7 @@ void ASTStmtReader::VisitOMPTargetParallelDirective(
   VisitStmt(D);
   Record.skipInts(1);
   VisitOMPExecutableDirective(D);
+  D->setHasCancel(Record.readBool());
 }
 
 void ASTStmtReader::VisitOMPTargetParallelForDirective(

diff  --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index b7e46bc40c77..fc96cffcb2ec 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2336,6 +2336,7 @@ void ASTStmtWriter::VisitOMPTargetParallelDirective(
   VisitStmt(D);
   Record.push_back(D->getNumClauses());
   VisitOMPExecutableDirective(D);
+  Record.writeBool(D->hasCancel());
   Code = serialization::STMT_OMP_TARGET_PARALLEL_DIRECTIVE;
 }
 

diff  --git a/clang/test/OpenMP/target_parallel_ast_print.cpp b/clang/test/OpenMP/target_parallel_ast_print.cpp
index 775d8da8a650..aa101386b22a 100644
--- a/clang/test/OpenMP/target_parallel_ast_print.cpp
+++ b/clang/test/OpenMP/target_parallel_ast_print.cpp
@@ -225,8 +225,16 @@ int main (int argc, char **argv) {
 
 #pragma omp target parallel defaultmap(tofrom: scalar) reduction(task, +:argc)
 // CHECK-NEXT: #pragma omp target parallel defaultmap(tofrom: scalar) reduction(task, +: argc)
+  {
   foo();
+#pragma omp cancellation point parallel
+#pragma omp cancel parallel
+  }
+// CHECK-NEXT: {
 // CHECK-NEXT: foo();
+// CHECK-NEXT: #pragma omp cancellation point parallel
+// CHECK-NEXT: #pragma omp cancel parallel
+// CHECK-NEXT: }
 
   return tmain<int, 5>(argc, &argc) + tmain<char, 1>(argv[0][0], argv[0]);
 }

diff  --git a/clang/test/OpenMP/target_parallel_codegen.cpp b/clang/test/OpenMP/target_parallel_codegen.cpp
index 12026696ba94..2e094c294dfa 100644
--- a/clang/test/OpenMP/target_parallel_codegen.cpp
+++ b/clang/test/OpenMP/target_parallel_codegen.cpp
@@ -134,6 +134,7 @@ int foo(int n) {
   #pragma omp target parallel if(target: 1)
   {
     aa += 1;
+#pragma omp cancel parallel
   }
 
   // CHECK:       [[IF:%.+]] = icmp sgt i32 {{[^,]+}}, 10
@@ -360,6 +361,12 @@ int foo(int n) {
 // CHECK:       store i[[SZ]] %{{.+}}, i[[SZ]]* [[AA_ADDR]], align
 // CHECK:       [[AA_CADDR:%.+]] = bitcast i[[SZ]]* [[AA_ADDR]] to i16*
 // CHECK:       [[AA:%.+]] = load i16, i16* [[AA_CADDR]], align
+// CHECK:       [[IS_CANCEL:%.+]] = call i32 @__kmpc_cancel(%struct.ident_t* @{{.+}}, i32 %{{.+}}, i32 1)
+// CHECK:       [[CMP:%.+]] = icmp ne i32 [[IS_CANCEL]], 0
+// CHECK:       br i1 [[CMP]], label %[[EXIT:.+]], label %[[CONTINUE:[^,]+]]
+// CHECK:       [[EXIT]]:
+// CHECK:       br label %[[CONTINUE]]
+// CHECK:       [[CONTINUE]]:
 // CHECK:       ret void
 // CHECK-NEXT:  }
 

diff  --git a/clang/test/OpenMP/target_parallel_messages.cpp b/clang/test/OpenMP/target_parallel_messages.cpp
index 25e63bb72b59..1ccdf647db9c 100644
--- a/clang/test/OpenMP/target_parallel_messages.cpp
+++ b/clang/test/OpenMP/target_parallel_messages.cpp
@@ -76,6 +76,20 @@ int main(int argc, char **argv) {
   #pragma omp target parallel copyin(pvt) // expected-error {{unexpected OpenMP clause 'copyin' in directive '#pragma omp target parallel'}}
   foo();
 
+  #pragma omp target parallel
+  {
+#pragma omp cancel // expected-error {{one of 'for', 'parallel', 'sections' or 'taskgroup' is expected}}
+#pragma omp cancellation point // expected-error {{one of 'for', 'parallel', 'sections' or 'taskgroup' is expected}}
+#pragma omp cancel for // expected-error {{region cannot be closely nested inside 'target parallel' region}}
+#pragma omp cancellation point for // expected-error {{region cannot be closely nested inside 'target parallel' region}}
+#pragma omp cancel sections // expected-error {{region cannot be closely nested inside 'target parallel' region}}
+#pragma omp cancellation point sections // expected-error {{region cannot be closely nested inside 'target parallel' region}}
+#pragma omp cancel taskgroup // expected-error {{region cannot be closely nested inside 'target parallel' region}}
+#pragma omp cancellation point taskgroup // expected-error {{region cannot be closely nested inside 'target parallel' region}}
+#pragma omp cancel parallel
+#pragma omp cancellation point parallel
+  }
+
   return 0;
 }
 


        


More information about the cfe-commits mailing list