[flang-commits] [flang] a74e9ce - [OpenMP] atomic compare weak : Parser & AST support (#79475)

via flang-commits flang-commits at lists.llvm.org
Wed Jan 31 03:32:12 PST 2024


Author: SunilKuravinakop
Date: 2024-01-31T06:32:06-05:00
New Revision: a74e9ce5dc5cc746bd625b0ef20524a90a073375

URL: https://github.com/llvm/llvm-project/commit/a74e9ce5dc5cc746bd625b0ef20524a90a073375
DIFF: https://github.com/llvm/llvm-project/commit/a74e9ce5dc5cc746bd625b0ef20524a90a073375.diff

LOG: [OpenMP] atomic compare weak : Parser & AST support (#79475)

This adds support for "#pragma omp atomic compare weak". Only Parser
& AST support is implemented for now.

---------

Authored-by: Sunil Kuravinakop <kuravina at pe28vega.us.cray.com>

Added: 
    

Modified: 
    clang/include/clang/AST/OpenMPClause.h
    clang/include/clang/AST/RecursiveASTVisitor.h
    clang/include/clang/Basic/DiagnosticSemaKinds.td
    clang/include/clang/Sema/Sema.h
    clang/lib/AST/OpenMPClause.cpp
    clang/lib/AST/StmtProfile.cpp
    clang/lib/CodeGen/CGStmtOpenMP.cpp
    clang/lib/Parse/ParseOpenMP.cpp
    clang/lib/Sema/SemaOpenMP.cpp
    clang/lib/Sema/TreeTransform.h
    clang/lib/Serialization/ASTReader.cpp
    clang/lib/Serialization/ASTWriter.cpp
    clang/test/OpenMP/atomic_ast_print.cpp
    clang/test/OpenMP/atomic_messages.cpp
    clang/tools/libclang/CIndex.cpp
    flang/lib/Semantics/check-omp-structure.cpp
    llvm/include/llvm/Frontend/OpenMP/OMP.td

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 924ca189381ba..325a1baa44614 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -2513,6 +2513,46 @@ class OMPRelaxedClause final : public OMPClause {
   }
 };
 
+/// This represents 'weak' clause in the '#pragma omp atomic'
+/// directives.
+///
+/// \code
+/// #pragma omp atomic compare weak
+/// \endcode
+/// In this example directive '#pragma omp atomic' has 'weak' clause.
+class OMPWeakClause final : public OMPClause {
+public:
+  /// Build 'weak' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPWeakClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_weak, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPWeakClause()
+      : OMPClause(llvm::omp::OMPC_weak, SourceLocation(), SourceLocation()) {}
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_weak;
+  }
+};
+
 /// This represents 'fail' clause in the '#pragma omp atomic'
 /// directive.
 ///

diff  --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 16b98b364eaca..9da5206a21c34 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3436,6 +3436,11 @@ bool RecursiveASTVisitor<Derived>::VisitOMPRelaxedClause(OMPRelaxedClause *) {
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPWeakClause(OMPWeakClause *) {
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPThreadsClause(OMPThreadsClause *) {
   return true;

diff  --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1c0ebbe4e6834..7638a7e84c3c0 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11042,7 +11042,8 @@ def note_omp_atomic_compare: Note<
   "expect lvalue for result value|expect scalar value|expect integer value|unexpected 'else' statement|expect '==' operator|expect an assignment statement 'v = x'|"
   "expect a 'if' statement|expect no more than two statements|expect a compound statement|expect 'else' statement|expect a form 'r = x == e; if (r) ...'}0">;
 def err_omp_atomic_fail_wrong_or_no_clauses : Error<"expected a memory order clause">;
-def err_omp_atomic_fail_no_compare : Error<"expected 'compare' clause with the 'fail' modifier">;
+def err_omp_atomic_no_compare : Error<"expected 'compare' clause with the '%0' modifier">;
+def err_omp_atomic_weak_no_equality : Error<"expected '==' operator for 'weak' clause">;
 def err_omp_atomic_several_clauses : Error<
   "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
 def err_omp_several_mem_order_clauses : Error<

diff  --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 490df817ddfa2..1ac2a465a0d59 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -12377,6 +12377,9 @@ class Sema final {
   /// Called on well-formed 'relaxed' clause.
   OMPClause *ActOnOpenMPRelaxedClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'weak' clause.
+  OMPClause *ActOnOpenMPWeakClause(SourceLocation StartLoc,
+                                   SourceLocation EndLoc);
 
   /// Called on well-formed 'init' clause.
   OMPClause *

diff  --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 04f680a8f5c92..042a5df5906ca 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1957,6 +1957,8 @@ void OMPClausePrinter::VisitOMPRelaxedClause(OMPRelaxedClause *) {
   OS << "relaxed";
 }
 
+void OMPClausePrinter::VisitOMPWeakClause(OMPWeakClause *) { OS << "weak"; }
+
 void OMPClausePrinter::VisitOMPThreadsClause(OMPThreadsClause *) {
   OS << "threads";
 }

diff  --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 7459f967da095..1b817cf58b999 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -594,6 +594,8 @@ void OMPClauseProfiler::VisitOMPReleaseClause(const OMPReleaseClause *) {}
 
 void OMPClauseProfiler::VisitOMPRelaxedClause(const OMPRelaxedClause *) {}
 
+void OMPClauseProfiler::VisitOMPWeakClause(const OMPWeakClause *) {}
+
 void OMPClauseProfiler::VisitOMPThreadsClause(const OMPThreadsClause *) {}
 
 void OMPClauseProfiler::VisitOMPSIMDClause(const OMPSIMDClause *) {}

diff  --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index e362c9da51fe3..8fd74697de3c0 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -6546,6 +6546,9 @@ void CodeGenFunction::EmitOMPAtomicDirective(const OMPAtomicDirective &S) {
     // Find first clause (skip seq_cst|acq_rel|aqcuire|release|relaxed clause,
     // if it is first).
     OpenMPClauseKind K = C->getClauseKind();
+    // TBD
+    if (K == OMPC_weak)
+      return;
     if (K == OMPC_seq_cst || K == OMPC_acq_rel || K == OMPC_acquire ||
         K == OMPC_release || K == OMPC_relaxed || K == OMPC_hint)
       continue;

diff  --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index da5f6605c6ffa..bfc31f2653c23 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3314,6 +3314,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_acquire:
   case OMPC_release:
   case OMPC_relaxed:
+  case OMPC_weak:
   case OMPC_threads:
   case OMPC_simd:
   case OMPC_nogroup:

diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index b51f25c791b44..1556f6e854613 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -12708,9 +12708,11 @@ StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
       }
       break;
     }
+    case OMPC_weak:
     case OMPC_fail: {
       if (!EncounteredAtomicKinds.contains(OMPC_compare)) {
-        Diag(C->getBeginLoc(), diag::err_omp_atomic_fail_no_compare)
+        Diag(C->getBeginLoc(), diag::err_omp_atomic_no_compare)
+            << getOpenMPClauseName(C->getClauseKind())
             << SourceRange(C->getBeginLoc(), C->getEndLoc());
         return StmtError();
       }
@@ -13202,6 +13204,27 @@ StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
       E = Checker.getE();
       D = Checker.getD();
       CE = Checker.getCond();
+      // The weak clause may only appear if the resulting atomic operation is
+      // an atomic conditional update for which the comparison tests for
+      // equality. It was not possible to do this check in
+      // OpenMPAtomicCompareChecker::checkStmt() as the check for OMPC_weak
+      // could not be performed (Clauses are not available).
+      auto *It = find_if(Clauses, [](OMPClause *C) {
+        return C->getClauseKind() == llvm::omp::Clause::OMPC_weak;
+      });
+      if (It != Clauses.end()) {
+        auto *Cond = dyn_cast<BinaryOperator>(CE);
+        if (Cond->getOpcode() != BO_EQ) {
+          ErrorInfo.Error = Checker.ErrorTy::NotAnAssignment;
+          ErrorInfo.ErrorLoc = Cond->getExprLoc();
+          ErrorInfo.NoteLoc = Cond->getOperatorLoc();
+          ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+
+          Diag(ErrorInfo.ErrorLoc, diag::err_omp_atomic_weak_no_equality)
+              << ErrorInfo.ErrorRange;
+          return StmtError();
+        }
+      }
       // We reuse IsXLHSInRHSPart to tell if it is in the form 'x ordop expr'.
       IsXLHSInRHSPart = Checker.isXBinopExpr();
     }
@@ -17593,6 +17616,9 @@ OMPClause *Sema::ActOnOpenMPClause(OpenMPClauseKind Kind,
   case OMPC_relaxed:
     Res = ActOnOpenMPRelaxedClause(StartLoc, EndLoc);
     break;
+  case OMPC_weak:
+    Res = ActOnOpenMPWeakClause(StartLoc, EndLoc);
+    break;
   case OMPC_threads:
     Res = ActOnOpenMPThreadsClause(StartLoc, EndLoc);
     break;
@@ -17781,6 +17807,11 @@ OMPClause *Sema::ActOnOpenMPRelaxedClause(SourceLocation StartLoc,
   return new (Context) OMPRelaxedClause(StartLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPWeakClause(SourceLocation StartLoc,
+                                       SourceLocation EndLoc) {
+  return new (Context) OMPWeakClause(StartLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPThreadsClause(SourceLocation StartLoc,
                                           SourceLocation EndLoc) {
   return new (Context) OMPThreadsClause(StartLoc, EndLoc);

diff  --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index e252c3e2ffc0e..4dc36a8a5f4f0 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -10075,6 +10075,12 @@ TreeTransform<Derived>::TransformOMPRelaxedClause(OMPRelaxedClause *C) {
   return C;
 }
 
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPWeakClause(OMPWeakClause *C) {
+  // No need to rebuild this clause, no template-dependent parameters.
+  return C;
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPThreadsClause(OMPThreadsClause *C) {

diff  --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 04ab42fa18826..895037a727d4d 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -10349,6 +10349,9 @@ OMPClause *OMPClauseReader::readClause() {
   case llvm::omp::OMPC_relaxed:
     C = new (Context) OMPRelaxedClause();
     break;
+  case llvm::omp::OMPC_weak:
+    C = new (Context) OMPWeakClause();
+    break;
   case llvm::omp::OMPC_threads:
     C = new (Context) OMPThreadsClause();
     break;
@@ -10747,6 +10750,8 @@ void OMPClauseReader::VisitOMPReleaseClause(OMPReleaseClause *) {}
 
 void OMPClauseReader::VisitOMPRelaxedClause(OMPRelaxedClause *) {}
 
+void OMPClauseReader::VisitOMPWeakClause(OMPWeakClause *) {}
+
 void OMPClauseReader::VisitOMPThreadsClause(OMPThreadsClause *) {}
 
 void OMPClauseReader::VisitOMPSIMDClause(OMPSIMDClause *) {}

diff  --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 9eefa7d1f7015..f33a8f146e609 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6698,6 +6698,8 @@ void OMPClauseWriter::VisitOMPReleaseClause(OMPReleaseClause *) {}
 
 void OMPClauseWriter::VisitOMPRelaxedClause(OMPRelaxedClause *) {}
 
+void OMPClauseWriter::VisitOMPWeakClause(OMPWeakClause *) {}
+
 void OMPClauseWriter::VisitOMPThreadsClause(OMPThreadsClause *) {}
 
 void OMPClauseWriter::VisitOMPSIMDClause(OMPSIMDClause *) {}

diff  --git a/clang/test/OpenMP/atomic_ast_print.cpp b/clang/test/OpenMP/atomic_ast_print.cpp
index 6886ef29a6752..b16a5fcdbdd3e 100644
--- a/clang/test/OpenMP/atomic_ast_print.cpp
+++ b/clang/test/OpenMP/atomic_ast_print.cpp
@@ -232,6 +232,8 @@ T foo(T argc) {
   { if (a < c) { a = c; } }
 #pragma omp atomic compare fail(seq_cst)
   { if (a < c) { a = c; } }
+#pragma omp atomic compare seq_cst weak
+  { if(a == b) { a = c; } }
 #endif
   return T();
 }
@@ -1111,6 +1113,8 @@ int main(int argc, char **argv) {
   if(a < b) { a = b; }
 #pragma omp atomic compare fail(seq_cst)
   if(a < b) { a = b; }
+#pragma omp atomic compare seq_cst weak
+  if(a == b) { a = c; }
 #endif
   // CHECK-NEXT: #pragma omp atomic
   // CHECK-NEXT: a++;
@@ -1453,6 +1457,10 @@ int main(int argc, char **argv) {
   // CHECK-51-NEXT: if (a < b) {
   // CHECK-51-NEXT: a = b;
   // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare seq_cst weak
+  // CHECK-51-NEXT: if (a == b) {
+  // CHECK-51-NEXT: a = c;
+  // CHECK-51-NEXT: }
   // expect-note@+1 {{in instantiation of function template specialization 'foo<int>' requested here}}
   return foo(a);
 }

diff  --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
index b09e3318cd9ac..d492f6ee1e896 100644
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -974,6 +974,17 @@ int mixed() {
 // expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'fail' clause}}
 #pragma omp atomic compare fail(relaxed) fail(seq_cst)
   if(v < a) { v = a; }
+#pragma omp atomic compare seq_cst weak
+  if(v == a) { v = a; }
+// expected-error@+1 {{expected 'compare' clause with the 'weak' modifier}}
+#pragma omp atomic weak
+  if(v < a) { v = a; }
+#pragma omp atomic compare release weak
+// expected-error@+1 {{expected '==' operator for 'weak' clause}}
+  if(v < a) { v = a; }
+// expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'weak' clause}}
+#pragma omp atomic compare release weak fail(seq_cst) weak
+  if(v == a) { v = a; }
 
 
 #endif

diff  --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 0f92d3395a3f5..e5c0971996e01 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2425,6 +2425,8 @@ void OMPClauseEnqueue::VisitOMPReleaseClause(const OMPReleaseClause *) {}
 
 void OMPClauseEnqueue::VisitOMPRelaxedClause(const OMPRelaxedClause *) {}
 
+void OMPClauseEnqueue::VisitOMPWeakClause(const OMPWeakClause *) {}
+
 void OMPClauseEnqueue::VisitOMPThreadsClause(const OMPThreadsClause *) {}
 
 void OMPClauseEnqueue::VisitOMPSIMDClause(const OMPSIMDClause *) {}

diff  --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 22e9d5d4dd799..03423de0c6104 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2264,6 +2264,7 @@ CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute)
 CHECK_SIMPLE_CLAUSE(OmpxBare, OMPC_ompx_bare)
 CHECK_SIMPLE_CLAUSE(Enter, OMPC_enter)
 CHECK_SIMPLE_CLAUSE(Fail, OMPC_fail)
+CHECK_SIMPLE_CLAUSE(Weak, OMPC_weak)
 
 CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize)
 CHECK_REQ_SCALAR_INT_CLAUSE(NumTasks, OMPC_num_tasks)

diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 2388abac81ceb..1481328bf483b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -215,6 +215,7 @@ def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }
 def OMPC_Release : Clause<"release"> { let clangClass = "OMPReleaseClause"; }
 def OMPC_Relaxed : Clause<"relaxed"> { let clangClass = "OMPRelaxedClause"; }
+def OMPC_Weak : Clause<"weak"> { let clangClass = "OMPWeakClause"; }
 def OMPC_Depend : Clause<"depend"> {
   let clangClass = "OMPDependClause";
   let flangClass = "OmpDependClause";
@@ -642,7 +643,8 @@ def OMP_Atomic : Directive<"atomic"> {
     VersionedClause<OMPC_Release, 50>,
     VersionedClause<OMPC_Relaxed, 50>,
     VersionedClause<OMPC_Hint, 50>,
-    VersionedClause<OMPC_Fail, 51>
+    VersionedClause<OMPC_Fail, 51>,
+    VersionedClause<OMPC_Weak, 51>
   ];
 }
 def OMP_Target : Directive<"target"> {


        


More information about the flang-commits mailing list