[clang] [llvm] Assume (PR #97535)

via cfe-commits cfe-commits at lists.llvm.org
Wed Jul 3 11:58:32 PDT 2024


https://github.com/SunilKuravinakop updated https://github.com/llvm/llvm-project/pull/97535

>From a331d2343a43a9ca2030ef07a25ca93ff26a5778 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe28vega.us.cray.com>
Date: Mon, 1 Jul 2024 03:29:45 -0500
Subject: [PATCH 1/4] Support for "#pragma omp assume" Directive in Parse and
 AST.

  Changes to be committed:
 	modified:   clang/include/clang/Parse/Parser.h
 	modified:   clang/include/clang/Sema/Sema.h
 	modified:   clang/lib/Basic/OpenMPKinds.cpp
 	modified:   clang/lib/Parse/ParseOpenMP.cpp
 	modified:   clang/lib/Sema/SemaOpenMP.cpp
 	modified:   llvm/include/llvm/Frontend/OpenMP/OMP.td
---
 clang/include/clang/Parse/Parser.h       |  4 +-
 clang/include/clang/Sema/Sema.h          |  3 ++
 clang/lib/Basic/OpenMPKinds.cpp          |  3 ++
 clang/lib/Parse/ParseOpenMP.cpp          | 33 +++++++++++-
 clang/lib/Sema/SemaOpenMP.cpp            | 67 +++++++++++++++++++++++-
 llvm/include/llvm/Frontend/OpenMP/OMP.td |  3 ++
 6 files changed, 108 insertions(+), 5 deletions(-)

diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 5950dd74cfe83..7568c053eb77f 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3380,8 +3380,8 @@ class Parser : public CodeCompletionHandler {
                                      SourceLocation Loc);
 
   /// Parse 'omp [begin] assume[s]' directive.
-  void ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
-                                   SourceLocation Loc);
+  StmtResult ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
+                                         SourceLocation Loc);
 
   /// Parse 'omp end assumes' directive.
   void ParseOpenMPEndAssumesDirective(SourceLocation Loc);
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index c6035445e9019..652df4a16e62d 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -12966,6 +12966,9 @@ class Sema final : public SemaBase {
   /// Act on \p D, a function definition inside of an `omp [begin/end] assumes`.
   void ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(Decl *D);
 
+  /// Act on the statement associated with an `omp assume` directive.
+  StmtResult ActOnFinishedStatementInOpenMPAssumeScope(Stmt *);
+
   /// Can we exit an OpenMP declare variant scope at the moment.
   bool isInOpenMPDeclareVariantScope() const {
     return !OMPDeclareVariantScopes.empty();
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index b3e9affbb3e58..97a0dd508c0a5 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -721,6 +721,9 @@ void clang::getOpenMPCaptureRegions(
   case OMPD_parallel_loop:
     CaptureRegions.push_back(OMPD_parallel);
     break;
+  case OMPD_assume:
+    CaptureRegions.push_back(OMPD_assume);
+    break;
   case OMPD_target_teams:
   case OMPD_target_teams_distribute:
   case OMPD_target_teams_distribute_simd:
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 814126e321d3b..3d1a9564897af 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -1673,8 +1673,8 @@ void Parser::ParseOpenMPClauses(OpenMPDirectiveKind DKind,
 ///     'no_openmp_routines'
 ///     'no_parallelism'
 ///
-void Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
-                                         SourceLocation Loc) {
+StmtResult Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
+                                               SourceLocation Loc) {
   SmallVector<std::string, 4> Assumptions;
   bool SkippedClauses = false;
 
@@ -1750,7 +1750,33 @@ void Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
     Assumptions.push_back(Assumption);
   }
 
+  StmtResult AssociatedStmt;
+
+  // Fix the scope for assume.
+  if (DKind == llvm::omp::Directive::OMPD_assume) {
+
+    if (Tok.getKind() == clang::tok::annot_pragma_openmp_end)
+      ConsumeAnyToken();
+
+    DeclarationNameInfo DirName;
+    Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc);
+  }
+
   Actions.ActOnOpenMPAssumesDirective(Loc, DKind, Assumptions, SkippedClauses);
+
+  if (DKind == llvm::omp::Directive::OMPD_assume) {
+
+    AssociatedStmt = ParseStatement();
+    AssociatedStmt =
+        Actions.ActOnFinishedStatementInOpenMPAssumeScope(AssociatedStmt.get());
+
+    // End the scope for assume.
+    ParseOpenMPEndAssumesDirective(Loc);
+    Actions.EndOpenMPDSABlock(nullptr);
+    if (Tok.getKind() == clang::tok::annot_pragma_openmp_end)
+      ConsumeAnyToken();
+  }
+  return AssociatedStmt;
 }
 
 void Parser::ParseOpenMPEndAssumesDirective(SourceLocation Loc) {
@@ -2759,6 +2785,9 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
     }
     break;
   }
+  case OMPD_assume:
+    Directive = ParseOpenMPAssumesDirective(DKind, ConsumeToken());
+    break;
   case OMPD_flush:
   case OMPD_depobj:
   case OMPD_scan:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index e9efb4721133f..91d84d57589dc 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -3497,7 +3497,8 @@ void Sema::ActOnOpenMPAssumesDirective(SourceLocation Loc,
         << llvm::omp::getOpenMPDirectiveName(DKind);
 
   auto *AA = OMPAssumeAttr::Create(Context, llvm::join(Assumptions, ","), Loc);
-  if (DKind == llvm::omp::Directive::OMPD_begin_assumes) {
+  if (DKind == llvm::omp::Directive::OMPD_begin_assumes ||
+      DKind == llvm::omp::Directive::OMPD_assume) {
     OMPAssumeScoped.push_back(AA);
     return;
   }
@@ -4195,6 +4196,7 @@ static void handleDeclareVariantConstructTrait(DSAStackTy *Stack,
 
 void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
   switch (DKind) {
+  case OMPD_assume:
   case OMPD_parallel:
   case OMPD_parallel_for:
   case OMPD_parallel_for_simd:
@@ -7357,6 +7359,69 @@ void Sema::ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(Decl *D) {
     FD->addAttr(AA);
 }
 
+namespace {
+/// Walks the statement associated with `#pragma omp assume` and attaches the
+/// scoped OMPAssumeAttrs to the CapturedDecl of every OpenMP directive found.
+class OMPAssumeStmtVisitor : public StmtVisitor<OMPAssumeStmtVisitor> {
+  SmallVector<OMPAssumeAttr *, 4> *OMPAssumeScoped;
+
+public:
+  OMPAssumeStmtVisitor(SmallVector<OMPAssumeAttr *, 4> *OMPAssumeScoped)
+      : OMPAssumeScoped(OMPAssumeScoped) {}
+
+  /// Attach every currently scoped `omp assume` attribute to the
+  /// CapturedDecl that holds the outlined body of \p CS.
+  void VisitCapturedStmt(CapturedStmt *CS) {
+    CapturedDecl *CD = CS->getCapturedDecl();
+    if (!CD)
+      return;
+    for (OMPAssumeAttr *AA : *OMPAssumeScoped)
+      CD->addAttr(AA);
+  }
+
+  /// Fallback for all other statements and expressions (compound
+  /// statements, loops, branches, operators, ...): walk every child.
+  /// OpenMP directives are statements rather than expressions, so
+  /// visiting only the Expr children of a CompoundStmt would miss a
+  /// directive nested in a `{ ... }` block after `#pragma omp assume`.
+  void VisitStmt(Stmt *S) {
+    for (Stmt *Child : S->children())
+      Visit(Child);
+  }
+
+  /// Entry point for every node. An OpenMP executable directive has its
+  /// CapturedStmt children annotated and its other children walked; every
+  /// other node goes through the regular StmtVisitor dispatch, which ends
+  /// in VisitCapturedStmt or VisitStmt.
+  void Visit(Stmt *S) {
+    // Optional sub-statements (e.g. a missing for-init) are null.
+    if (!S)
+      return;
+    if (!isa<OMPExecutableDirective>(S)) {
+      StmtVisitor<OMPAssumeStmtVisitor>::Visit(S);
+      return;
+    }
+    for (Stmt *Child : S->children()) {
+      if (auto *CS = dyn_cast_or_null<CapturedStmt>(Child))
+        VisitCapturedStmt(CS);
+      else
+        Visit(Child);
+    }
+  }
+};
+} // namespace
+
+StmtResult
+Sema::ActOnFinishedStatementInOpenMPAssumeScope(Stmt *AssociatedStmt) {
+  // Null (e.g. the associated statement failed to parse): nothing to annotate.
+  if (!AssociatedStmt)
+    return AssociatedStmt;
+  // Propagate the `omp assume` attributes to the directives inside the block.
+  OMPAssumeStmtVisitor Annotator(&OMPAssumeScoped);
+  Annotator.Visit(AssociatedStmt);
+  return AssociatedStmt;
+}
+
 Sema::OMPDeclareVariantScope::OMPDeclareVariantScope(OMPTraitInfo &TI)
     : TI(&TI), NameSuffix(TI.getMangledName()) {}
 
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index e91169e8da1aa..3660917e3e176 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -2089,6 +2089,9 @@ def OMP_Scan : Directive<"scan"> {
   ];
   let association = AS_Separating;
 }
+def OMP_Assume : Directive<"assume"> {
+  let association = AS_Block;
+}
 def OMP_Assumes : Directive<"assumes"> {
   let association = AS_None;
 }

>From 4115f97fc0cf99fa8bc6369635f54ad185a6ffe3 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe28vega.us.cray.com>
Date: Tue, 2 Jul 2024 05:08:09 -0500
Subject: [PATCH 2/4] Add test for checking -ast-dump. New file:
 clang/test/OpenMP/assume_ast.cpp

---
 clang/test/OpenMP/assume_ast.cpp | 43 ++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 clang/test/OpenMP/assume_ast.cpp

diff --git a/clang/test/OpenMP/assume_ast.cpp b/clang/test/OpenMP/assume_ast.cpp
new file mode 100644
index 0000000000000..9f63ac87391a5
--- /dev/null
+++ b/clang/test/OpenMP/assume_ast.cpp
@@ -0,0 +1,43 @@
+// Check no warnings/errors
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fsyntax-only -verify %s
+// expected-no-diagnostics
+
+// Check AST and unparsing
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -ast-dump  %s | FileCheck %s --check-prefix=DUMP
+
+// Check same results after serialization round-trip
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -emit-pch -o %t %s
+// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -include-pch %t -ast-dump-all %s | FileCheck %s --check-prefix=DUMP
+
+#ifndef HEADER
+#define HEADER
+
+#define N 12
+int A[N];
+int B[N];
+
+
+// DUMP-LABEL:  FunctionDecl {{.*}} main
+int
+main() {
+
+  for (int i = 0; i < N; ++i) {
+    A[i] = 0;
+  }
+
+
+  // assume is for the "simd" region
+  // DUMP:      OMPSimdDirective
+  // DUMP-NEXT:   CapturedStmt
+  // DUMP-NEXT:   CapturedDecl
+  #pragma omp assume no_openmp
+  #pragma omp simd
+  for (int i = 0; i < N; ++i){
+    A[i] += B[i];
+  }
+  // DUMP:  OMPAssumeAttr {{.*}} "omp_no_openmp"
+
+  return 0;
+}
+
+#endif

>From 311d5134cab92f18568892bb0b9428e29aa57421 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe28vega.us.cray.com>
Date: Wed, 3 Jul 2024 02:47:36 -0500
Subject: [PATCH 3/4] Fix git-clang-format errors. Modified:
 clang/lib/Parse/ParseOpenMP.cpp

---
 clang/lib/Parse/ParseOpenMP.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 9267e8792f0bf..78c0211e2058a 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -1768,7 +1768,8 @@ StmtResult Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
       ConsumeAnyToken();
 
     DeclarationNameInfo DirName;
-    Actions.OpenMP().StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc);
+    Actions.OpenMP().StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(),
+                                         Loc);
   }
 
   Actions.OpenMP().ActOnOpenMPAssumesDirective(Loc, DKind, Assumptions,
@@ -1777,8 +1778,8 @@ StmtResult Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
   if (DKind == llvm::omp::Directive::OMPD_assume) {
 
     AssociatedStmt = ParseStatement();
-    AssociatedStmt =
-        Actions.OpenMP().ActOnFinishedStatementInOpenMPAssumeScope(AssociatedStmt.get());
+    AssociatedStmt = Actions.OpenMP().ActOnFinishedStatementInOpenMPAssumeScope(
+        AssociatedStmt.get());
 
     // End the scope for assume.
     ParseOpenMPEndAssumesDirective(Loc);

>From b7d3afffdde20601b498ebad37c6e46a52d660e1 Mon Sep 17 00:00:00 2001
From: Sunil Kuravinakop <kuravina at pe28vega.us.cray.com>
Date: Wed, 3 Jul 2024 13:55:27 -0500
Subject: [PATCH 4/4] Address review comments: 1) "assume ends up with a simple
 assume intrinsic, right? Most probably no need to emit a captured function,
 so it should be OMPD_unknown"; 2) "Fix the scope for assume" reworded to
 "Begin marking the scope for assume". Modified:
 clang/lib/Basic/OpenMPKinds.cpp, clang/lib/Parse/ParseOpenMP.cpp,
 clang/test/OpenMP/assume_ast.cpp

---
 clang/lib/Basic/OpenMPKinds.cpp  | 2 +-
 clang/lib/Parse/ParseOpenMP.cpp  | 4 ++--
 clang/test/OpenMP/assume_ast.cpp | 1 -
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 696a819086af1..a44d19d0eda57 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -761,7 +761,7 @@ void clang::getOpenMPCaptureRegions(
       CaptureRegions.push_back(OMPD_parallel);
       break;
     case OMPD_assume:
-      CaptureRegions.push_back(OMPD_assume);
+      CaptureRegions.push_back(OMPD_unknown);
       break;
     case OMPD_target:
       CaptureRegions.push_back(OMPD_task);
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index 78c0211e2058a..e4b8a958d2366 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -1761,7 +1761,7 @@ StmtResult Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
 
   StmtResult AssociatedStmt;
 
-  // Fix the scope for assume.
+  // Begin marking the scope for assume.
   if (DKind == llvm::omp::Directive::OMPD_assume) {
 
     if (Tok.getKind() == clang::tok::annot_pragma_openmp_end)
@@ -1781,7 +1781,7 @@ StmtResult Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
     AssociatedStmt = Actions.OpenMP().ActOnFinishedStatementInOpenMPAssumeScope(
         AssociatedStmt.get());
 
-    // End the scope for assume.
+    // End marking the scope for assume.
     ParseOpenMPEndAssumesDirective(Loc);
     Actions.OpenMP().EndOpenMPDSABlock(nullptr);
     if (Tok.getKind() == clang::tok::annot_pragma_openmp_end)
diff --git a/clang/test/OpenMP/assume_ast.cpp b/clang/test/OpenMP/assume_ast.cpp
index 9f63ac87391a5..078be70c83817 100644
--- a/clang/test/OpenMP/assume_ast.cpp
+++ b/clang/test/OpenMP/assume_ast.cpp
@@ -25,7 +25,6 @@ main() {
     A[i] = 0;
   }
 
-
   // assume is for the "simd" region
   // DUMP:      OMPSimdDirective
   // DUMP-NEXT:   CapturedStmt



More information about the cfe-commits mailing list