[flang-commits] [flang] [Flang][OpenMP][Sema] Module support for REQUIRES directive (PR #77082)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 5 03:24:06 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch adds support for passing REQUIRES clauses across Fortran modules via USE statements through changes to directive resolution, the directive rewrite pass and semantics checks. `.mod` files are also extended to include `!$omp requires` directives so this information can be parsed from external modules and added to the module's symbol.

This is patch 5/5 of a series splitting [D149337](https://reviews.llvm.org/D149337) to simplify review.

Re-created from Phabricator review: https://reviews.llvm.org/D158168

---
Full diff: https://github.com/llvm/llvm-project/pull/77082.diff


9 Files Affected:

- (modified) flang/lib/Semantics/check-omp-structure.cpp (+16) 
- (modified) flang/lib/Semantics/check-omp-structure.h (+2) 
- (modified) flang/lib/Semantics/mod-file.cpp (+43) 
- (modified) flang/lib/Semantics/resolve-directives.cpp (+37-1) 
- (modified) flang/lib/Semantics/rewrite-directives.cpp (+24) 
- (added) flang/test/Semantics/Inputs/requires_module.f90 (+3) 
- (added) flang/test/Semantics/OpenMP/requires10.f90 (+14) 
- (added) flang/test/Semantics/OpenMP/requires11.f90 (+17) 
- (added) flang/test/Semantics/OpenMP/requires12.f90 (+19) 


``````````diff
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index c430375d5ed011..fb7afdcb5cff66 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2055,6 +2055,22 @@ void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) {
   dirContext_.pop_back();
 }
 
+void OmpStructureChecker::Enter(const parser::UseStmt &x) {
+  semantics::Symbol *symbol{x.moduleName.symbol};
+  if (!symbol) {
+    // Cannot check used module if it wasn't resolved.
+    return;
+  }
+  // Only ModuleDetails carry REQUIRES flags; any other details are skipped.
+  const auto *details{symbol->detailsIf<ModuleDetails>()};
+  if (details && details->has_ompRequires() && deviceConstructFound_) {
+    context_.Say(x.moduleName.source,
+        "'%s' module containing device-related REQUIRES directive imported "
+        "lexically after device construct"_err_en_US,
+        x.moduleName.ToString());
+  }
+}
+
 // Clauses
 // Mainly categorized as
 // 1. Checks on 'OmpClauseList' from 'parse-tree.h'.
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 33243d926cf167..6ac7930e1aa5fe 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -127,6 +127,8 @@ class OmpStructureChecker
   void Enter(const parser::OmpAtomicCapture &);
   void Leave(const parser::OmpAtomic &);
 
+  void Enter(const parser::UseStmt &);
+
 #define GEN_FLANG_CLAUSE_CHECK_ENTER
 #include "llvm/Frontend/OpenMP/OMP.inc"
 
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 70b6bbf8b557ac..75b8ec9091e3ae 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -58,6 +58,8 @@ static void PutShape(
 static llvm::raw_ostream &PutAttr(llvm::raw_ostream &, Attr);
 static llvm::raw_ostream &PutType(llvm::raw_ostream &, const DeclTypeSpec &);
 static llvm::raw_ostream &PutLower(llvm::raw_ostream &, std::string_view);
+static llvm::raw_ostream &PutOmpRequires(
+    llvm::raw_ostream &, const WithOmpDeclarative &);
 static std::error_code WriteFile(
     const std::string &, const std::string &, bool = true);
 static bool FileContentsMatch(
@@ -163,6 +165,7 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) {
   uses_.str().clear();
   all << useExtraAttrs_.str();
   useExtraAttrs_.str().clear();
+  PutOmpRequires(all, details);
   all << decls_.str();
   decls_.str().clear();
   auto str{contains_.str()};
@@ -604,6 +607,8 @@ void ModFileWriter::PutSubprogram(const Symbol &symbol) {
     }
   }
   os << '\n';
+  // print OpenMP requires
+  PutOmpRequires(os, details);
   // walk symbols, collect ones needed for interface
   const Scope &scope{
       details.entryScope() ? *details.entryScope() : DEREF(symbol.scope())};
@@ -995,6 +1000,44 @@ llvm::raw_ostream &PutLower(llvm::raw_ostream &os, std::string_view str) {
   return os;
 }
 
+llvm::raw_ostream &PutOmpRequires(
+    llvm::raw_ostream &os, const WithOmpDeclarative &details) {
+  if (details.has_ompRequires() || details.has_ompAtomicDefaultMemOrder()) {
+    os << "!$omp requires";
+    if (auto *flags{details.ompRequires()}) {
+      if (flags->test(WithOmpDeclarative::RequiresFlag::ReverseOffload)) {
+        os << " reverse_offload";
+      }
+      if (flags->test(WithOmpDeclarative::RequiresFlag::UnifiedAddress)) {
+        os << " unified_address";
+      }
+      if (flags->test(WithOmpDeclarative::RequiresFlag::UnifiedSharedMemory)) {
+        os << " unified_shared_memory";
+      }
+      if (flags->test(WithOmpDeclarative::RequiresFlag::DynamicAllocators)) {
+        os << " dynamic_allocators";
+      }
+    }
+    if (auto *memOrder{details.ompAtomicDefaultMemOrder()}) {
+      os << " atomic_default_mem_order(";
+      switch (*memOrder) {
+      case common::OmpAtomicDefaultMemOrderType::SeqCst:
+        os << "seq_cst";
+        break;
+      case common::OmpAtomicDefaultMemOrderType::AcqRel:
+        os << "acq_rel";
+        break;
+      case common::OmpAtomicDefaultMemOrderType::Relaxed:
+        os << "relaxed";
+        break;
+      }
+      os << ')';
+    }
+    os << '\n';
+  }
+  return os;
+}
+
 void PutOpenACCDirective(llvm::raw_ostream &os, const Symbol &symbol) {
   if (symbol.test(Symbol::Flag::AccDeclare)) {
     os << "!$acc declare ";
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index da6c865ad56a3b..bef8ea60faceff 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -549,6 +549,42 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
 
   void Post(const parser::Name &);
 
+  void Post(const parser::SpecificationPart &x) {
+    // Look at USE statements from here so that we also have access to a source
+    // object, needed to tie the statement to a scope.
+    for (auto &stmt :
+        std::get<
+            std::list<parser::Statement<common::Indirection<parser::UseStmt>>>>(
+            x.t)) {
+      const parser::UseStmt &useStmt = stmt.statement.value();
+      Symbol *moduleSym{useStmt.moduleName.symbol};
+      if (!moduleSym) {
+        continue;
+      }
+
+      // Gather information from the imported module's symbol details.
+      WithOmpDeclarative::RequiresFlags flags;
+      std::optional<common::OmpAtomicDefaultMemOrderType> memOrder;
+      common::visit(
+          [&](auto &details) {
+            if constexpr (std::is_base_of_v<ModuleDetails,
+                              std::decay_t<decltype(details)>>) {
+              if (details.has_ompRequires()) {
+                flags = *details.ompRequires();
+              }
+              if (details.has_ompAtomicDefaultMemOrder()) {
+                memOrder = *details.ompAtomicDefaultMemOrder();
+              }
+            }
+          },
+          moduleSym->details());
+
+      // Merge requires clauses into USE statement's parents.
+      Scope &scope = context_.FindScope(stmt.source);
+      AddOmpRequiresToScope(scope, flags, memOrder);
+    }
+  }
+
   // Keep track of labels in the statements that causes jumps to target labels
   void Post(const parser::GotoStmt &gotoStmt) { CheckSourceLabel(gotoStmt.v); }
   void Post(const parser::ComputedGotoStmt &computedGotoStmt) {
@@ -2279,7 +2315,7 @@ void ResolveOmpTopLevelParts(
 
   // Gather REQUIRES clauses from all non-module top-level program unit symbols,
   // combine them together ensuring compatibility and apply them to all these
-  // program units. Modules are skipped because their REQUIRES clauses should be
+  // program units. Modules are skipped because their REQUIRES clauses are
   // propagated via USE statements instead.
   WithOmpDeclarative::RequiresFlags combinedFlags;
   std::optional<common::OmpAtomicDefaultMemOrderType> combinedMemOrder;
diff --git a/flang/lib/Semantics/rewrite-directives.cpp b/flang/lib/Semantics/rewrite-directives.cpp
index bab91d25308225..76fe17c059739b 100644
--- a/flang/lib/Semantics/rewrite-directives.cpp
+++ b/flang/lib/Semantics/rewrite-directives.cpp
@@ -44,6 +44,7 @@ class OmpRewriteMutator : public DirectiveRewriteMutator {
 
   bool Pre(parser::OpenMPAtomicConstruct &);
   bool Pre(parser::OpenMPRequiresConstruct &);
+  void Post(parser::UseStmt &);
 
 private:
   bool atomicDirectiveDefaultOrderFound_{false};
@@ -165,6 +166,29 @@ bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) {
   return false;
 }
 
+// Check that a module containing a REQUIRES directive with the
+// `atomic_default_mem_order` clause is not USEd after an atomic operation
+// without a memory order clause.
+void OmpRewriteMutator::Post(parser::UseStmt &x) {
+  semantics::Symbol *symbol{x.moduleName.symbol};
+  if (!symbol) {
+    // Cannot check used module if it wasn't resolved.
+    return;
+  }
+
+  auto *details = symbol->detailsIf<ModuleDetails>();
+  if (atomicDirectiveDefaultOrderFound_ && details &&
+      details->has_ompAtomicDefaultMemOrder()) {
+    context_.Say(x.moduleName.source,
+        "'%s' module containing '%s' REQUIRES clause imported lexically after "
+        "atomic operation without a memory order clause"_err_en_US,
+        x.moduleName.ToString(),
+        parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+            llvm::omp::OMPC_atomic_default_mem_order)
+                                       .str()));
+  }
+}
+
 bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) {
   if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
     return true;
diff --git a/flang/test/Semantics/Inputs/requires_module.f90 b/flang/test/Semantics/Inputs/requires_module.f90
new file mode 100644
index 00000000000000..2ce0d03a17d61c
--- /dev/null
+++ b/flang/test/Semantics/Inputs/requires_module.f90
@@ -0,0 +1,3 @@
+module requires_module
+  !$omp requires atomic_default_mem_order(seq_cst), unified_shared_memory
+end module
diff --git a/flang/test/Semantics/OpenMP/requires10.f90 b/flang/test/Semantics/OpenMP/requires10.f90
new file mode 100644
index 00000000000000..1c5a60f1899360
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires10.f90
@@ -0,0 +1,14 @@
+! RUN: rm -rf %t && mkdir %t
+! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90'
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! All atomic_default_mem_order clauses in REQUIRES directives found within a
+! compilation unit must specify the same ordering. Test that this is propagated
+! from imported modules
+
+!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit
+use requires_module
+!$omp requires atomic_default_mem_order(relaxed)
+
+end program
diff --git a/flang/test/Semantics/OpenMP/requires11.f90 b/flang/test/Semantics/OpenMP/requires11.f90
new file mode 100644
index 00000000000000..56b4e676fe8d65
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires11.f90
@@ -0,0 +1,17 @@
+! RUN: rm -rf %t && mkdir %t
+! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90'
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! Target-related clauses in REQUIRES directives must come strictly before any
+! device constructs, such as declare target with extended list. Test that this
+! is propagated from imported modules.
+
+subroutine f
+  !$omp declare target (f)
+end subroutine f
+
+program requires
+  !ERROR: 'requires_module' module containing device-related REQUIRES directive imported lexically after device construct
+  use requires_module
+end program requires
diff --git a/flang/test/Semantics/OpenMP/requires12.f90 b/flang/test/Semantics/OpenMP/requires12.f90
new file mode 100644
index 00000000000000..9762823fd0c3f7
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires12.f90
@@ -0,0 +1,19 @@
+! RUN: rm -rf %t && mkdir %t
+! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90'
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! atomic_default_mem_order clauses in REQUIRES directives must come strictly
+! before any atomic constructs with no explicit memory order set. Test that this
+! is propagated from imported modules.
+
+subroutine f
+  integer :: a = 0
+  !$omp atomic
+  a = a + 1
+end subroutine f
+
+program requires
+  !ERROR: 'requires_module' module containing 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clause imported lexically after atomic operation without a memory order clause
+  use requires_module
+end program requires

``````````

</details>


https://github.com/llvm/llvm-project/pull/77082


More information about the flang-commits mailing list