[flang-commits] [flang] [Flang][OpenMP][Sema] Module support for REQUIRES directive (PR #77082)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Fri Jan 5 03:23:39 PST 2024


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/77082

This patch adds support for passing REQUIRES clauses across Fortran modules via USE statements through changes to directive resolution, the directive rewrite pass and semantic checks. `.mod` files are also extended to include `!$omp requires` directives so this information can be parsed from external modules and added to the module's symbol.

This is patch 5/5 of a series splitting [D149337](https://reviews.llvm.org/D149337) to simplify review.

Re-created from Phabricator review: https://reviews.llvm.org/D158168

>From 2e66f9dbd35303049da7aaf05cad4fa1826082ae Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 17 Aug 2023 11:19:07 +0100
Subject: [PATCH] [Flang][OpenMP][Sema] Module support for REQUIRES directive

This patch adds support for passing REQUIRES clauses across Fortran modules via
USE statements through changes to directive resolution, the directive rewrite
pass and semantic checks. `.mod` files are also extended to include `!$omp
requires` directives so this information can be parsed from external modules
and added to the module's symbol.

This is patch 5/5 of a series splitting
[D149337](https://reviews.llvm.org/D149337) to simplify review.

Re-created from Phabricator review: https://reviews.llvm.org/D158168
---
 flang/lib/Semantics/check-omp-structure.cpp   | 16 +++++++
 flang/lib/Semantics/check-omp-structure.h     |  2 +
 flang/lib/Semantics/mod-file.cpp              | 43 +++++++++++++++++++
 flang/lib/Semantics/resolve-directives.cpp    | 38 +++++++++++++++-
 flang/lib/Semantics/rewrite-directives.cpp    | 24 +++++++++++
 .../test/Semantics/Inputs/requires_module.f90 |  3 ++
 flang/test/Semantics/OpenMP/requires10.f90    | 14 ++++++
 flang/test/Semantics/OpenMP/requires11.f90    | 17 ++++++++
 flang/test/Semantics/OpenMP/requires12.f90    | 19 ++++++++
 9 files changed, 175 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/Semantics/Inputs/requires_module.f90
 create mode 100644 flang/test/Semantics/OpenMP/requires10.f90
 create mode 100644 flang/test/Semantics/OpenMP/requires11.f90
 create mode 100644 flang/test/Semantics/OpenMP/requires12.f90

diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index c430375d5ed011..fb7afdcb5cff66 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2055,6 +2055,22 @@ void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) {
   dirContext_.pop_back();
 }
 
+void OmpStructureChecker::Enter(const parser::UseStmt &x) {
+  // Diagnose a module with device-related REQUIRES clauses USEd after a device
+  // construct. Unresolved or non-module symbols are silently skipped.
+  const semantics::Symbol *symbol{x.moduleName.symbol};
+  const auto *details{symbol ? symbol->detailsIf<ModuleDetails>() : nullptr};
+  if (!details) {
+    return;
+  }
+  if (details->has_ompRequires() && deviceConstructFound_) {
+    context_.Say(x.moduleName.source,
+        "'%s' module containing device-related REQUIRES directive imported "
+        "lexically after device construct"_err_en_US,
+        x.moduleName.ToString());
+  }
+}
+
 // Clauses
 // Mainly categorized as
 // 1. Checks on 'OmpClauseList' from 'parse-tree.h'.
diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index 33243d926cf167..6ac7930e1aa5fe 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -127,6 +127,8 @@ class OmpStructureChecker
   void Enter(const parser::OmpAtomicCapture &);
   void Leave(const parser::OmpAtomic &);
 
+  void Enter(const parser::UseStmt &);
+
 #define GEN_FLANG_CLAUSE_CHECK_ENTER
 #include "llvm/Frontend/OpenMP/OMP.inc"
 
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 70b6bbf8b557ac..75b8ec9091e3ae 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -58,6 +58,8 @@ static void PutShape(
 static llvm::raw_ostream &PutAttr(llvm::raw_ostream &, Attr);
 static llvm::raw_ostream &PutType(llvm::raw_ostream &, const DeclTypeSpec &);
 static llvm::raw_ostream &PutLower(llvm::raw_ostream &, std::string_view);
+static llvm::raw_ostream &PutOmpRequires(
+    llvm::raw_ostream &, const WithOmpDeclarative &);
 static std::error_code WriteFile(
     const std::string &, const std::string &, bool = true);
 static bool FileContentsMatch(
@@ -163,6 +165,7 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) {
   uses_.str().clear();
   all << useExtraAttrs_.str();
   useExtraAttrs_.str().clear();
+  PutOmpRequires(all, details);
   all << decls_.str();
   decls_.str().clear();
   auto str{contains_.str()};
@@ -604,6 +607,8 @@ void ModFileWriter::PutSubprogram(const Symbol &symbol) {
     }
   }
   os << '\n';
+  // print OpenMP requires
+  PutOmpRequires(os, details);
   // walk symbols, collect ones needed for interface
   const Scope &scope{
       details.entryScope() ? *details.entryScope() : DEREF(symbol.scope())};
@@ -995,6 +1000,44 @@ llvm::raw_ostream &PutLower(llvm::raw_ostream &os, std::string_view str) {
   return os;
 }
 
+llvm::raw_ostream &PutOmpRequires(
+    llvm::raw_ostream &os, const WithOmpDeclarative &details) {
+  if (!details.has_ompRequires() && !details.has_ompAtomicDefaultMemOrder()) {
+    return os; // No REQUIRES information recorded: emit nothing.
+  }
+  os << "!$omp requires";
+  if (const auto *flags{details.ompRequires()}) {
+    if (flags->test(WithOmpDeclarative::RequiresFlag::ReverseOffload)) {
+      os << " reverse_offload";
+    }
+    if (flags->test(WithOmpDeclarative::RequiresFlag::UnifiedAddress)) {
+      os << " unified_address";
+    }
+    if (flags->test(WithOmpDeclarative::RequiresFlag::UnifiedSharedMemory)) {
+      os << " unified_shared_memory";
+    }
+    if (flags->test(WithOmpDeclarative::RequiresFlag::DynamicAllocators)) {
+      os << " dynamic_allocators";
+    }
+  }
+  if (const auto *memOrder{details.ompAtomicDefaultMemOrder()}) {
+    os << " atomic_default_mem_order(";
+    switch (*memOrder) {
+    case common::OmpAtomicDefaultMemOrderType::SeqCst:
+      os << "seq_cst";
+      break;
+    case common::OmpAtomicDefaultMemOrderType::AcqRel:
+      os << "acq_rel";
+      break;
+    case common::OmpAtomicDefaultMemOrderType::Relaxed:
+      os << "relaxed";
+      break;
+    }
+    os << ')';
+  }
+  return os << '\n';
+}
+
 void PutOpenACCDirective(llvm::raw_ostream &os, const Symbol &symbol) {
   if (symbol.test(Symbol::Flag::AccDeclare)) {
     os << "!$acc declare ";
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index da6c865ad56a3b..bef8ea60faceff 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -549,6 +549,42 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
 
   void Post(const parser::Name &);
 
+  void Post(const parser::SpecificationPart &x) {
+    // Propagate the OpenMP REQUIRES clauses recorded on each module named in
+    // a USE statement of this specification part into the using scope.
+    // USE statements are handled here rather than in a UseStmt visitor
+    // because the enclosing parser::Statement provides the source position
+    // needed to find the scope that the statement belongs to.
+    using UseStmtList =
+        std::list<parser::Statement<common::Indirection<parser::UseStmt>>>;
+
+    for (const auto &stmt : std::get<UseStmtList>(x.t)) {
+      const parser::Name &moduleName{stmt.statement.value().moduleName};
+      const Symbol *moduleSym{moduleName.symbol};
+      if (moduleSym == nullptr) {
+        continue; // Unresolved module: nothing to propagate.
+      }
+
+      // REQUIRES information recorded on the imported module. Both stay
+      // empty when the symbol does not hold ModuleDetails.
+      WithOmpDeclarative::RequiresFlags moduleFlags;
+      std::optional<common::OmpAtomicDefaultMemOrderType> moduleMemOrder;
+      if (const auto *details{moduleSym->detailsIf<ModuleDetails>()}) {
+        if (details->has_ompRequires()) {
+          moduleFlags = *details->ompRequires();
+        }
+        if (details->has_ompAtomicDefaultMemOrder()) {
+          moduleMemOrder = *details->ompAtomicDefaultMemOrder();
+        }
+      }
+
+      // Merge the module's REQUIRES clauses into the scope of the USE
+      // statement and its parents.
+      Scope &useScope{context_.FindScope(stmt.source)};
+      AddOmpRequiresToScope(useScope, moduleFlags, moduleMemOrder);
+    }
+  }
+
   // Keep track of labels in the statements that causes jumps to target labels
   void Post(const parser::GotoStmt &gotoStmt) { CheckSourceLabel(gotoStmt.v); }
   void Post(const parser::ComputedGotoStmt &computedGotoStmt) {
@@ -2279,7 +2315,7 @@ void ResolveOmpTopLevelParts(
 
   // Gather REQUIRES clauses from all non-module top-level program unit symbols,
   // combine them together ensuring compatibility and apply them to all these
-  // program units. Modules are skipped because their REQUIRES clauses should be
+  // program units. Modules are skipped because their REQUIRES clauses are
   // propagated via USE statements instead.
   WithOmpDeclarative::RequiresFlags combinedFlags;
   std::optional<common::OmpAtomicDefaultMemOrderType> combinedMemOrder;
diff --git a/flang/lib/Semantics/rewrite-directives.cpp b/flang/lib/Semantics/rewrite-directives.cpp
index bab91d25308225..76fe17c059739b 100644
--- a/flang/lib/Semantics/rewrite-directives.cpp
+++ b/flang/lib/Semantics/rewrite-directives.cpp
@@ -44,6 +44,7 @@ class OmpRewriteMutator : public DirectiveRewriteMutator {
 
   bool Pre(parser::OpenMPAtomicConstruct &);
   bool Pre(parser::OpenMPRequiresConstruct &);
+  void Post(parser::UseStmt &);
 
 private:
   bool atomicDirectiveDefaultOrderFound_{false};
@@ -165,6 +166,29 @@ bool OmpRewriteMutator::Pre(parser::OpenMPRequiresConstruct &x) {
   return false;
 }
 
+// Diagnose a USE of a module whose REQUIRES directive specifies the
+// `atomic_default_mem_order` clause when it appears lexically after an atomic
+// construct that had no explicit memory order clause.
+void OmpRewriteMutator::Post(parser::UseStmt &x) {
+  if (!atomicDirectiveDefaultOrderFound_) {
+    return; // No atomic construct without memory order seen so far.
+  }
+  const semantics::Symbol *symbol{x.moduleName.symbol};
+  // Unresolved USEs and non-module symbols cannot be checked.
+  const auto *details{symbol ? symbol->detailsIf<ModuleDetails>() : nullptr};
+  if (!details || !details->has_ompAtomicDefaultMemOrder()) {
+    return;
+  }
+  context_.Say(x.moduleName.source,
+      "'%s' module containing '%s' REQUIRES clause imported lexically after "
+      "atomic operation without a memory order clause"_err_en_US,
+      x.moduleName.ToString(),
+      parser::ToUpperCaseLetters(
+          llvm::omp::getOpenMPClauseName(
+              llvm::omp::OMPC_atomic_default_mem_order)
+              .str()));
+}
+
 bool RewriteOmpParts(SemanticsContext &context, parser::Program &program) {
   if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
     return true;
diff --git a/flang/test/Semantics/Inputs/requires_module.f90 b/flang/test/Semantics/Inputs/requires_module.f90
new file mode 100644
index 00000000000000..2ce0d03a17d61c
--- /dev/null
+++ b/flang/test/Semantics/Inputs/requires_module.f90
@@ -0,0 +1,3 @@
+module requires_module
+  !$omp requires atomic_default_mem_order(seq_cst), unified_shared_memory
+end module
diff --git a/flang/test/Semantics/OpenMP/requires10.f90 b/flang/test/Semantics/OpenMP/requires10.f90
new file mode 100644
index 00000000000000..1c5a60f1899360
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires10.f90
@@ -0,0 +1,14 @@
+! RUN: rm -rf %t && mkdir %t
+! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90'
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! All atomic_default_mem_order clauses in REQUIRES directives found within a
+! compilation unit must specify the same ordering. Test that this is propagated
+! from imported modules.
+
+!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit
+use requires_module
+!$omp requires atomic_default_mem_order(relaxed)
+
+end program
diff --git a/flang/test/Semantics/OpenMP/requires11.f90 b/flang/test/Semantics/OpenMP/requires11.f90
new file mode 100644
index 00000000000000..56b4e676fe8d65
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires11.f90
@@ -0,0 +1,17 @@
+! RUN: rm -rf %t && mkdir %t
+! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90'
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! Target-related clauses in REQUIRES directives must come strictly before any
+! device constructs, such as declare target with extended list. Test that this
+! is propagated from imported modules.
+
+subroutine f
+  !$omp declare target (f)
+end subroutine f
+
+program requires
+  !ERROR: 'requires_module' module containing device-related REQUIRES directive imported lexically after device construct
+  use requires_module
+end program requires
diff --git a/flang/test/Semantics/OpenMP/requires12.f90 b/flang/test/Semantics/OpenMP/requires12.f90
new file mode 100644
index 00000000000000..9762823fd0c3f7
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires12.f90
@@ -0,0 +1,19 @@
+! RUN: rm -rf %t && mkdir %t
+! RUN: %flang_fc1 -fsyntax-only -fopenmp -module-dir %t '%S/../Inputs/requires_module.f90'
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp -module-dir %t
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! atomic_default_mem_order clauses in REQUIRES directives must come strictly
+! before any atomic constructs with no explicit memory order set. Test that this
+! is propagated from imported modules.
+
+subroutine f
+  integer :: a = 0
+  !$omp atomic
+  a = a + 1
+end subroutine f
+
+program requires
+  !ERROR: 'requires_module' module containing 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clause imported lexically after atomic operation without a memory order clause
+  use requires_module
+end program requires



More information about the flang-commits mailing list