[flang-commits] [flang] 3787fd9 - [Flang][OpenMP][Sema] Support propagation of REQUIRES information across program units

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Mon Sep 11 03:48:26 PDT 2023


Author: Sergio Afonso
Date: 2023-09-11T11:48:07+01:00
New Revision: 3787fd942f3927345320cc97a479f13e44355805

URL: https://github.com/llvm/llvm-project/commit/3787fd942f3927345320cc97a479f13e44355805
DIFF: https://github.com/llvm/llvm-project/commit/3787fd942f3927345320cc97a479f13e44355805.diff

LOG: [Flang][OpenMP][Sema] Support propagation of REQUIRES information across program units

This patch adds support for storing OpenMP REQUIRES information in the
semantics symbols for programs/subprograms and modules/submodules, and
populates them during directive resolution. A pass is added to name resolution
that makes sure this information is also propagated across top-level programs,
functions and subroutines.

Storing REQUIRES information inside of semantics symbols will also allow
supporting the propagation of this information across Fortran modules. This
will come as a separate patch.

The `bool OmpAttributeVisitor::Pre(const parser::SpecificationPart &x)`
method is removed since it resulted in specification parts being visited twice.

This is patch 3/5 of a series splitting D149337 to simplify review.

Differential Revision: https://reviews.llvm.org/D157983

Added: 
    flang/test/Semantics/OpenMP/requires09.f90

Modified: 
    flang/examples/FeatureList/FeatureList.cpp
    flang/include/flang/Common/Fortran.h
    flang/include/flang/Parser/dump-parse-tree.h
    flang/include/flang/Parser/parse-tree.h
    flang/include/flang/Semantics/symbol.h
    flang/lib/Parser/openmp-parsers.cpp
    flang/lib/Parser/unparse.cpp
    flang/lib/Semantics/resolve-directives.cpp
    flang/lib/Semantics/resolve-directives.h
    flang/lib/Semantics/resolve-names.cpp

Removed: 
    


################################################################################
diff  --git a/flang/examples/FeatureList/FeatureList.cpp b/flang/examples/FeatureList/FeatureList.cpp
index 7ab294597ee0e0d..6f10553cdcb4c0d 100644
--- a/flang/examples/FeatureList/FeatureList.cpp
+++ b/flang/examples/FeatureList/FeatureList.cpp
@@ -23,6 +23,7 @@
 #include <utility>
 #include <vector>
 
+using namespace Fortran::common;
 using namespace Fortran::frontend;
 using namespace Fortran::parser;
 using namespace Fortran;
@@ -553,7 +554,7 @@ struct NodeVisitor {
   READ_FEATURE(OmpAtomicClause)
   READ_FEATURE(OmpAtomicClauseList)
   READ_FEATURE(OmpAtomicDefaultMemOrderClause)
-  READ_FEATURE(OmpAtomicDefaultMemOrderClause::Type)
+  READ_FEATURE(OmpAtomicDefaultMemOrderType)
   READ_FEATURE(OpenMPFlushConstruct)
   READ_FEATURE(OpenMPLoopConstruct)
   READ_FEATURE(OpenMPExecutableAllocate)

diff  --git a/flang/include/flang/Common/Fortran.h b/flang/include/flang/Common/Fortran.h
index df47e98150ce6e1..15db21bf3473c05 100644
--- a/flang/include/flang/Common/Fortran.h
+++ b/flang/include/flang/Common/Fortran.h
@@ -87,6 +87,9 @@ ENUM_CLASS(CUDASubprogramAttrs, Host, Device, HostDevice, Global, Grid_Global)
 // CUDA data attributes; mutually exclusive
 ENUM_CLASS(CUDADataAttr, Constant, Device, Managed, Pinned, Shared, Texture)
 
+// OpenMP ATOMIC_DEFAULT_MEM_ORDER values, spelled SEQ_CST/ACQ_REL/RELAXED
+ENUM_CLASS(OmpAtomicDefaultMemOrderType, SeqCst, AcqRel, Relaxed)
+
 // Fortran names may have up to 63 characters (See Fortran 2018 C601).
 static constexpr int maxNameLen{63};
 

diff  --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 7a009e8cc708284..e7d74dda71a20cf 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -589,7 +589,7 @@ class ParseTreeDumper {
   NODE(parser, OmpAtomicClause)
   NODE(parser, OmpAtomicClauseList)
   NODE(parser, OmpAtomicDefaultMemOrderClause)
-  NODE_ENUM(OmpAtomicDefaultMemOrderClause, Type)
+  NODE_ENUM(common, OmpAtomicDefaultMemOrderType)
   NODE(parser, OpenMPFlushConstruct)
   NODE(parser, OpenMPLoopConstruct)
   NODE(parser, OpenMPExecutableAllocate)

diff  --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index d8449c8b812ae2f..5d92ecb05841767 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3593,8 +3593,8 @@ struct OmpDependClause {
 //                 ATOMIC_DEFAULT_MEM_ORDER (SEQ_CST | ACQ_REL |
 //                                           RELAXED)
 struct OmpAtomicDefaultMemOrderClause {
-  ENUM_CLASS(Type, SeqCst, AcqRel, Relaxed)
-  WRAPPER_CLASS_BOILERPLATE(OmpAtomicDefaultMemOrderClause, Type);
+  WRAPPER_CLASS_BOILERPLATE(
+      OmpAtomicDefaultMemOrderClause, common::OmpAtomicDefaultMemOrderType);
 };
 
 // OpenMP Clauses

diff  --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 7280a4eaa5fca57..aada3bf94cc1213 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -45,8 +45,38 @@ using SymbolVector = std::vector<SymbolRef>;
 using MutableSymbolRef = common::Reference<Symbol>;
 using MutableSymbolVector = std::vector<MutableSymbolRef>;
 
+// Mixin for details recording OpenMP declarative info (currently REQUIRES).
+class WithOmpDeclarative {
+  using OmpAtomicOrderType = common::OmpAtomicDefaultMemOrderType;
+
+public:
+  ENUM_CLASS(RequiresFlag, ReverseOffload, UnifiedAddress, UnifiedSharedMemory,
+      DynamicAllocators);
+  using RequiresFlags = common::EnumSet<RequiresFlag, RequiresFlag_enumSize>;
+
+  bool has_ompRequires() const { return ompRequires_.has_value(); }
+  const RequiresFlags *ompRequires() const {
+    return ompRequires_ ? &*ompRequires_ : nullptr;
+  }
+  void set_ompRequires(RequiresFlags flags) { ompRequires_ = flags; }
+
+  bool has_ompAtomicDefaultMemOrder() const {
+    return ompAtomicDefaultMemOrder_.has_value();
+  }
+  const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const {
+    return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
+  }
+  void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType order) {
+    ompAtomicDefaultMemOrder_ = order;
+  }
+
+private:
+  std::optional<RequiresFlags> ompRequires_;
+  std::optional<OmpAtomicOrderType> ompAtomicDefaultMemOrder_;
+};
+
 // A module or submodule.
-class ModuleDetails {
+class ModuleDetails : public WithOmpDeclarative {
 public:
   ModuleDetails(bool isSubmodule = false) : isSubmodule_{isSubmodule} {}
   bool isSubmodule() const { return isSubmodule_; }
@@ -63,7 +93,7 @@ class ModuleDetails {
   const Scope *scope_{nullptr};
 };
 
-class MainProgramDetails {
+class MainProgramDetails : public WithOmpDeclarative {
 public:
 private:
 };
@@ -114,7 +144,7 @@ class OpenACCRoutineInfo {
 // A subroutine or function definition, or a subprogram interface defined
 // in an INTERFACE block as part of the definition of a dummy procedure
 // or a procedure pointer (with just POINTER).
-class SubprogramDetails : public WithBindName {
+class SubprogramDetails : public WithBindName, public WithOmpDeclarative {
 public:
   bool isFunction() const { return result_ != nullptr; }
   bool isInterface() const { return isInterface_; }

diff  --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index b30a3a1eb2a151f..b220d578b3bbc7e 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -432,9 +432,9 @@ TYPE_PARSER(sourced(construct<OmpMemoryOrderClause>(
 //                               acq_rel
 //                               relaxed
 TYPE_PARSER(construct<OmpAtomicDefaultMemOrderClause>(
-    "SEQ_CST" >> pure(OmpAtomicDefaultMemOrderClause::Type::SeqCst) ||
-    "ACQ_REL" >> pure(OmpAtomicDefaultMemOrderClause::Type::AcqRel) ||
-    "RELAXED" >> pure(OmpAtomicDefaultMemOrderClause::Type::Relaxed)))
+    "SEQ_CST" >> pure(common::OmpAtomicDefaultMemOrderType::SeqCst) ||
+    "ACQ_REL" >> pure(common::OmpAtomicDefaultMemOrderType::AcqRel) ||
+    "RELAXED" >> pure(common::OmpAtomicDefaultMemOrderType::Relaxed)))
 
 // 2.17.7 Atomic construct
 //        atomic-clause -> memory-order-clause | HINT(hint-expression)

diff  --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index d7626c0ea762937..398545d315e5cd5 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2307,17 +2307,18 @@ class UnparseVisitor {
   }
 
   void Unparse(const OmpAtomicDefaultMemOrderClause &x) {
+    // Enumerator names (SeqCst, AcqRel) differ from the clause spellings.
     switch (x.v) {
-    case OmpAtomicDefaultMemOrderClause::Type::SeqCst:
+    case common::OmpAtomicDefaultMemOrderType::SeqCst:
       Word("SEQ_CST");
       break;
-    case OmpAtomicDefaultMemOrderClause::Type::AcqRel:
+    case common::OmpAtomicDefaultMemOrderType::AcqRel:
       Word("ACQ_REL");
       break;
-    case OmpAtomicDefaultMemOrderClause::Type::Relaxed:
+    case common::OmpAtomicDefaultMemOrderType::Relaxed:
       Word("RELAXED");
       break;
     }
   }
 
   void Unparse(const OmpAtomicClauseList &x) { Walk(" ", x.v, " "); }

diff  --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 38e195d20a34367..8375d095624ad05 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -22,6 +22,13 @@
 #include <map>
 #include <sstream>
 
+template <typename T>
+static Fortran::semantics::Scope *GetScope(
+    Fortran::semantics::SemanticsContext &context, const T &x) {
+  std::optional<Fortran::parser::CharBlock> source{GetSource(x)};
+  return source ? &context.FindScope(*source) : nullptr;
+}
+
 namespace Fortran::semantics {
 
 template <typename T> class DirectiveAttributeVisitor {
@@ -324,11 +331,6 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
     return true;
   }
 
-  bool Pre(const parser::SpecificationPart &x) {
-    Walk(std::get<std::list<parser::OpenMPDeclarativeConstruct>>(x.t));
-    return true;
-  }
-
   bool Pre(const parser::StmtFunctionStmt &x) {
     const auto &parsedExpr{std::get<parser::Scalar<parser::Expr>>(x.t)};
     if (const auto *expr{GetExpr(context_, parsedExpr)}) {
@@ -375,7 +377,38 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
   void Post(const parser::OpenMPDeclareSimdConstruct &) { PopContext(); }
 
   bool Pre(const parser::OpenMPRequiresConstruct &x) {
+    using Flags = WithOmpDeclarative::RequiresFlags;
+    using Requires = WithOmpDeclarative::RequiresFlag;
     PushContext(x.source, llvm::omp::Directive::OMPD_requires);
+
+    // Collect the REQUIRES flags and the default atomic memory order.
+    Flags flags;
+    std::optional<common::OmpAtomicDefaultMemOrderType> memOrder;
+    for (const auto &clause : std::get<parser::OmpClauseList>(x.t).v) {
+      flags |= common::visit(
+          common::visitors{
+              [&memOrder](
+                  const parser::OmpClause::AtomicDefaultMemOrder &atomic) {
+                memOrder = atomic.v.v;
+                return Flags{};
+              },
+              [](const parser::OmpClause::ReverseOffload &) {
+                return Flags{Requires::ReverseOffload};
+              },
+              [](const parser::OmpClause::UnifiedAddress &) {
+                return Flags{Requires::UnifiedAddress};
+              },
+              [](const parser::OmpClause::UnifiedSharedMemory &) {
+                return Flags{Requires::UnifiedSharedMemory};
+              },
+              [](const parser::OmpClause::DynamicAllocators &) {
+                return Flags{Requires::DynamicAllocators};
+              },
+              [](const auto &) { return Flags{}; }},
+          clause.u);
+    }
+    // Record them in the symbols of all enclosing program units.
+    AddOmpRequiresToScope(currScope(), flags, memOrder);
     return true;
   }
   void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); }
@@ -672,6 +705,9 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
 
   bool HasSymbolInEnclosingScope(const Symbol &, Scope &);
   std::int64_t ordCollapseLevel{0};
+
+  void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags,
+      std::optional<common::OmpAtomicDefaultMemOrderType>);
 };
 
 template <typename T>
@@ -2175,6 +2211,85 @@ void ResolveOmpParts(
   }
 }
 
+void ResolveOmpTopLevelParts(
+    SemanticsContext &context, const parser::Program &program) {
+  if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
+    return;
+  }
+
+  // Gather REQUIRES clauses from all non-module top-level program unit symbols,
+  // combine them together ensuring compatibility and apply them to all these
+  // program units. Modules are skipped because their REQUIRES clauses should be
+  // propagated via USE statements instead.
+  WithOmpDeclarative::RequiresFlags combinedFlags;
+  std::optional<common::OmpAtomicDefaultMemOrderType> combinedMemOrder;
+
+  // Function to go through non-module top level program units and extract
+  // REQUIRES information to be processed by a function-like argument.
+  auto processProgramUnits{[&](auto processFn) {
+    for (const parser::ProgramUnit &unit : program.v) {
+      if (!std::holds_alternative<common::Indirection<parser::Module>>(
+              unit.u) &&
+          !std::holds_alternative<common::Indirection<parser::Submodule>>(
+              unit.u)) {
+        Symbol *symbol{common::visit(
+            [&context](auto &x) {
+              Scope *scope{GetScope(context, x.value())};
+              return scope ? scope->symbol() : nullptr;
+            },
+            unit.u)};
+        // GetScope() returns nullptr when the unit has no source, and a scope
+        // need not have a symbol; such units have nowhere to store REQUIRES
+        // information, so skip them instead of dereferencing a null pointer.
+        if (!symbol) {
+          continue;
+        }
+
+        common::visit(
+            [&](auto &details) {
+              if constexpr (std::is_convertible_v<decltype(&details),
+                                WithOmpDeclarative *>) {
+                processFn(*symbol, details);
+              }
+            },
+            symbol->details());
+      }
+    }
+  }};
+
+  // Combine global REQUIRES information from all program units except modules
+  // and submodules.
+  processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) {
+    if (const WithOmpDeclarative::RequiresFlags *
+        flags{details.ompRequires()}) {
+      combinedFlags |= *flags;
+    }
+    if (const common::OmpAtomicDefaultMemOrderType *
+        memOrder{details.ompAtomicDefaultMemOrder()}) {
+      if (combinedMemOrder && *combinedMemOrder != *memOrder) {
+        context.Say(symbol.scope()->sourceRange(),
+            "Conflicting '%s' REQUIRES clauses found in compilation "
+            "unit"_err_en_US,
+            parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+                llvm::omp::Clause::OMPC_atomic_default_mem_order)
+                                           .str()));
+      }
+      combinedMemOrder = *memOrder;
+    }
+  });
+
+  // Update all program units except modules and submodules with the combined
+  // global REQUIRES information.
+  processProgramUnits([&](Symbol &, WithOmpDeclarative &details) {
+    if (combinedFlags.any()) {
+      details.set_ompRequires(combinedFlags);
+    }
+    if (combinedMemOrder) {
+      details.set_ompAtomicDefaultMemOrder(*combinedMemOrder);
+    }
+  });
+}
+
 void OmpAttributeVisitor::CheckDataCopyingClause(
     const parser::Name &name, const Symbol &symbol, Symbol::Flag ompFlag) {
   const auto *checkSymbol{&symbol};
@@ -2322,4 +2437,44 @@ void OmpAttributeVisitor::CheckNameInAllocateStmt(
       parser::ToUpperCaseLetters(
           llvm::omp::getOpenMPDirectiveName(GetContext().directive).str()));
 }
+
+void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope,
+    WithOmpDeclarative::RequiresFlags flags,
+    std::optional<common::OmpAtomicDefaultMemOrderType> memOrder) {
+  // Walk from `scope` outwards, stopping before the global scope.
+  for (Scope *scopeIter{&scope}; !scopeIter->IsGlobal();
+       scopeIter = &scopeIter->parent()) {
+    if (Symbol * symbol{scopeIter->symbol()}) {
+      common::visit(
+          [&](auto &details) {
+            // Store clauses information into the symbol for the parent and
+            // enclosing modules, programs, functions and subroutines.
+            if constexpr (std::is_convertible_v<decltype(&details),
+                              WithOmpDeclarative *>) {
+              if (flags.any()) {
+                if (const WithOmpDeclarative::RequiresFlags *
+                    otherFlags{details.ompRequires()}) {
+                  flags |= *otherFlags;
+                }
+                details.set_ompRequires(flags);
+              }
+              if (memOrder) {
+                if (details.has_ompAtomicDefaultMemOrder() &&
+                    *details.ompAtomicDefaultMemOrder() != *memOrder) {
+                  context_.Say(scopeIter->sourceRange(),
+                      "Conflicting '%s' REQUIRES clauses found in compilation "
+                      "unit"_err_en_US,
+                      parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+                          llvm::omp::Clause::OMPC_atomic_default_mem_order)
+                                                     .str()));
+                }
+                details.set_ompAtomicDefaultMemOrder(*memOrder);
+              }
+            }
+          },
+          symbol->details());
+    }
+  }
+}
+
 } // namespace Fortran::semantics

diff  --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h
index 6ba7a062529421a..839165aaf30eb81 100644
--- a/flang/lib/Semantics/resolve-directives.h
+++ b/flang/lib/Semantics/resolve-directives.h
@@ -11,6 +11,7 @@
 
 namespace Fortran::parser {
 struct Name;
+struct Program;
 struct ProgramUnit;
 } // namespace Fortran::parser
 
@@ -21,6 +22,7 @@ class SemanticsContext;
 // Name resolution for OpenACC and OpenMP directives
 void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &);
 void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
+void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
 
 } // namespace Fortran::semantics
 #endif

diff  --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 0b4b940fa1d1c70..865c198424696a9 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8699,11 +8699,14 @@ void ResolveNamesVisitor::ResolveExecutionParts(const ProgramTree &node) {
   }
 }
 
-void ResolveNamesVisitor::Post(const parser::Program &) {
+void ResolveNamesVisitor::Post(const parser::Program &x) {
   // ensure that all temps were deallocated
   CHECK(!attrs_);
   CHECK(!cudaDataAttr_);
   CHECK(!GetDeclTypeSpec());
+  // Every program unit has been resolved at this point; now propagate OpenMP
+  // REQUIRES information across the top-level (non-module) program units.
+  ResolveOmpTopLevelParts(context(), x);
 }
 
 // A singleton instance of the scope -> IMPLICIT rules mapping is

diff  --git a/flang/test/Semantics/OpenMP/requires09.f90 b/flang/test/Semantics/OpenMP/requires09.f90
new file mode 100644
index 000000000000000..2fa5d950b9c2d8a
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires09.f90
@@ -0,0 +1,14 @@
+! RUN: %python %S/../test_errors.py %s %flang -fopenmp
+! OpenMP Version 5.0
+! 2.4 Requires directive
+! All atomic_default_mem_order clauses in 'requires' directives found within a
+! compilation unit must specify the same ordering.
+
+subroutine f
+  !$omp requires atomic_default_mem_order(seq_cst)
+end subroutine f
+
+!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit
+subroutine g
+  !$omp requires atomic_default_mem_order(relaxed)
+end subroutine g


        


More information about the flang-commits mailing list