[flang-commits] [flang] 1d7960a - [Flang][OpenMP][OpenACC] Add function for mapping parser clause classes with the corresponding clause kind.

via flang-commits flang-commits at lists.llvm.org
Tue Mar 16 23:53:57 PDT 2021


Author: Praveen
Date: 2021-03-17T12:20:43+05:30
New Revision: 1d7960a601fc83b3847f83681573019271e7516f

URL: https://github.com/llvm/llvm-project/commit/1d7960a601fc83b3847f83681573019271e7516f
DIFF: https://github.com/llvm/llvm-project/commit/1d7960a601fc83b3847f83681573019271e7516f.diff

LOG: [Flang][OpenMP][OpenACC] Add function for mapping parser clause classes with the corresponding clause kind.

1. Use TableGen to generate, for OpenMP and OpenACC, the mapping from each
   clause's parser class to its corresponding clause kind.

2. Add a common function that returns the OmpObjectList of an OpenMP
   clause, to avoid repeating the same extraction code.

Reviewed by: Kiranchandramohan @kiranchandramohan, Valentin Clement @clementval

Differential Revision: https://reviews.llvm.org/D98603

Added: 
    

Modified: 
    flang/lib/Semantics/check-omp-structure.cpp
    flang/lib/Semantics/check-omp-structure.h
    llvm/test/TableGen/directive1.td
    llvm/utils/TableGen/DirectiveEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index b23ae51b50948..baa31fd1ecb92 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -840,7 +840,6 @@ void OmpStructureChecker::CheckReductionArraySection(
 
 void OmpStructureChecker::CheckMultipleAppearanceAcrossContext(
     const parser::OmpObjectList &redObjectList) {
-  const parser::OmpObjectList *objList{nullptr};
   //  TODO: Verify the assumption here that the immediately enclosing region is
   //  the parallel region to which the worksharing construct having reduction
   //  binds to.
@@ -848,43 +847,29 @@ void OmpStructureChecker::CheckMultipleAppearanceAcrossContext(
     for (auto it : enclosingContext->clauseInfo) {
       llvmOmpClause type = it.first;
       const auto *clause = it.second;
-      if (type == llvm::omp::Clause::OMPC_private) {
-        const auto &pClause{std::get<parser::OmpClause::Private>(clause->u)};
-        objList = &pClause.v;
-      } else if (type == llvm::omp::Clause::OMPC_firstprivate) {
-        const auto &fpClause{
-            std::get<parser::OmpClause::Firstprivate>(clause->u)};
-        objList = &fpClause.v;
-      } else if (type == llvm::omp::Clause::OMPC_lastprivate) {
-        const auto &lpClause{
-            std::get<parser::OmpClause::Lastprivate>(clause->u)};
-        objList = &lpClause.v;
-      } else if (type == llvm::omp::Clause::OMPC_reduction) {
-        const auto &rClause{std::get<parser::OmpClause::Reduction>(clause->u)};
-        const auto &olist{std::get<1>(rClause.v.t)};
-        objList = &olist;
-      }
-      if (objList) {
-        for (const auto &ompObject : objList->v) {
-          if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
-            if (const auto *symbol{name->symbol}) {
-              for (const auto &redOmpObject : redObjectList.v) {
-                if (const auto *rname{
-                        parser::Unwrap<parser::Name>(redOmpObject)}) {
-                  if (const auto *rsymbol{rname->symbol}) {
-                    if (rsymbol->name() == symbol->name()) {
-                      context_.Say(GetContext().clauseSource,
-                          "%s variable '%s' is %s in outer context must"
-                          " be shared in the parallel regions to which any"
-                          " of the worksharing regions arising from the "
-                          "worksharing"
-                          " construct bind."_err_en_US,
-                          parser::ToUpperCaseLetters(
-                              getClauseName(llvm::omp::Clause::OMPC_reduction)
-                                  .str()),
-                          symbol->name(),
-                          parser::ToUpperCaseLetters(
-                              getClauseName(type).str()));
+      if (llvm::omp::privateReductionSet.test(type)) {
+        if (const auto *objList{GetOmpObjectList(*clause)}) {
+          for (const auto &ompObject : objList->v) {
+            if (const auto *name{parser::Unwrap<parser::Name>(ompObject)}) {
+              if (const auto *symbol{name->symbol}) {
+                for (const auto &redOmpObject : redObjectList.v) {
+                  if (const auto *rname{
+                          parser::Unwrap<parser::Name>(redOmpObject)}) {
+                    if (const auto *rsymbol{rname->symbol}) {
+                      if (rsymbol->name() == symbol->name()) {
+                        context_.Say(GetContext().clauseSource,
+                            "%s variable '%s' is %s in outer context must"
+                            " be shared in the parallel regions to which any"
+                            " of the worksharing regions arising from the "
+                            "worksharing"
+                            " construct bind."_err_en_US,
+                            parser::ToUpperCaseLetters(
+                                getClauseName(llvm::omp::Clause::OMPC_reduction)
+                                    .str()),
+                            symbol->name(),
+                            parser::ToUpperCaseLetters(
+                                getClauseName(type).str()));
+                      }
                     }
                   }
                 }
@@ -1213,7 +1198,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Lastprivate &x) {
   DirectivesClauseTriple dirClauseTriple;
   SymbolSourceMap currSymbols;
   GetSymbolsInObjectList(x.v, currSymbols);
-  CheckDefinableObjects(currSymbols, llvm::omp::Clause::OMPC_lastprivate);
+  CheckDefinableObjects(currSymbols, GetClauseKindForParserClass(x));
 
   // Check lastprivate variables in worksharing constructs
   dirClauseTriple.emplace(llvm::omp::Directive::OMPD_do,
@@ -1224,7 +1209,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Lastprivate &x) {
           llvm::omp::Directive::OMPD_parallel, llvm::omp::privateReductionSet));
 
   CheckPrivateSymbolsInOuterCxt(
-      currSymbols, dirClauseTriple, llvm::omp::Clause::OMPC_lastprivate);
+      currSymbols, dirClauseTriple, GetClauseKindForParserClass(x));
 }
 
 llvm::StringRef OmpStructureChecker::getClauseName(llvm::omp::Clause clause) {
@@ -1368,40 +1353,11 @@ void OmpStructureChecker::CheckPrivateSymbolsInOuterCxt(
     if (auto *enclosingContext{GetEnclosingContextWithDir(enclosingDir)}) {
       for (auto it{enclosingContext->clauseInfo.begin()};
            it != enclosingContext->clauseInfo.end(); ++it) {
-        // TODO: Replace the hard-coded clause names by using autogen checks or
-        // a function which maps parser::OmpClause::<name> to the corresponding
-        // llvm::omp::Clause::OMPC_<name>
-        std::visit(common::visitors{
-                       [&](const parser::OmpClause::Private &x) {
-                         if (enclosingClauseSet.test(
-                                 llvm::omp::Clause::OMPC_private)) {
-                           GetSymbolsInObjectList(x.v, enclosingSymbols);
-                         }
-                       },
-                       [&](const parser::OmpClause::Firstprivate &x) {
-                         if (enclosingClauseSet.test(
-                                 llvm::omp::Clause::OMPC_firstprivate)) {
-                           GetSymbolsInObjectList(x.v, enclosingSymbols);
-                         }
-                       },
-                       [&](const parser::OmpClause::Lastprivate &x) {
-                         if (enclosingClauseSet.test(
-                                 llvm::omp::Clause::OMPC_lastprivate)) {
-                           GetSymbolsInObjectList(x.v, enclosingSymbols);
-                         }
-                       },
-                       [&](const parser::OmpClause::Reduction &x) {
-                         if (enclosingClauseSet.test(
-                                 llvm::omp::Clause::OMPC_reduction)) {
-                           const auto &ompObjectList{
-                               std::get<parser::OmpObjectList>(x.v.t)};
-                           GetSymbolsInObjectList(
-                               ompObjectList, enclosingSymbols);
-                         }
-                       },
-                       [&](const auto &) {},
-                   },
-            it->second->u);
+        if (enclosingClauseSet.test(it->first)) {
+          if (const auto *ompObjectList{GetOmpObjectList(*it->second)}) {
+            GetSymbolsInObjectList(*ompObjectList, enclosingSymbols);
+          }
+        }
       }
 
       // Check if the symbols in current context are private in outer context
@@ -1497,4 +1453,37 @@ void OmpStructureChecker::CheckWorkshareBlockStmts(
   }
 }
 
+const parser::OmpObjectList *OmpStructureChecker::GetOmpObjectList(
+    const parser::OmpClause &clause) {
+  // Returns the OmpObjectList held by 'clause', or nullptr if it has none.
+  // Clauses whose data member 'v' is itself the OmpObjectList
+  using MemberObjectListClauses = std::tuple<parser::OmpClause::Copyprivate,
+      parser::OmpClause::Copyin, parser::OmpClause::Firstprivate,
+      parser::OmpClause::From, parser::OmpClause::Lastprivate,
+      parser::OmpClause::Link, parser::OmpClause::Private,
+      parser::OmpClause::Shared, parser::OmpClause::To>;
+
+  // Clauses whose 'v.t' tuple contains an OmpObjectList element
+  using TupleObjectListClauses = std::tuple<parser::OmpClause::Allocate,
+      parser::OmpClause::Map, parser::OmpClause::Reduction>;
+
+  // TODO: Generate the tuples using TableGen.
+  // Also handle other constructs with OmpObjectList, e.g. OpenMPThreadprivate.
+  return std::visit(
+      common::visitors{
+          [&](const auto &x) -> const parser::OmpObjectList * {
+            using Ty = std::decay_t<decltype(x)>;
+            if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
+              return &x.v;
+            } else if constexpr (common::HasMember<Ty,
+                                     TupleObjectListClauses>) {
+              return &(std::get<parser::OmpObjectList>(x.v.t));
+            } else {
+              return nullptr; // Clause kinds not listed in either tuple.
+            }
+          },
+      },
+      clause.u);
+}
+
 } // namespace Fortran::semantics

diff  --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h
index cd560dd1cd796..f11ddc66b401e 100644
--- a/flang/lib/Semantics/check-omp-structure.h
+++ b/flang/lib/Semantics/check-omp-structure.h
@@ -153,6 +153,13 @@ class OmpStructureChecker
 #define GEN_FLANG_CLAUSE_CHECK_ENTER
 #include "llvm/Frontend/OpenMP/OMP.inc"
 
+  // Map parser clause class A to its llvm::omp::Clause; body from OMP.inc.
+  template <typename A>
+  llvm::omp::Clause GetClauseKindForParserClass(const A &) {
+#define GEN_FLANG_CLAUSE_PARSER_KIND_MAP
+#include "llvm/Frontend/OpenMP/OMP.inc"
+  }
+
 private:
   bool HasInvalidWorksharingNesting(
       const parser::CharBlock &, const OmpDirectiveSet &);
@@ -197,6 +204,7 @@ class OmpStructureChecker
       const parser::Name &name, const llvm::omp::Clause clause);
   void CheckMultipleAppearanceAcrossContext(
       const parser::OmpObjectList &ompObjectList);
+  const parser::OmpObjectList *GetOmpObjectList(const parser::OmpClause &);
 };
 } // namespace Fortran::semantics
 #endif // FORTRAN_SEMANTICS_CHECK_OMP_STRUCTURE_H_

diff  --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td
index dbf9b6c03d3a2..a69958175267e 100644
--- a/llvm/test/TableGen/directive1.td
+++ b/llvm/test/TableGen/directive1.td
@@ -256,3 +256,23 @@ def TDL_DirA : Directive<"dira"> {
 // GEN-NEXT:  }
 // GEN-EMPTY:
 // GEN-NEXT:  #endif // GEN_FLANG_CLAUSE_UNPARSE
+// GEN-EMPTY:
+// GEN-NEXT:  #ifdef GEN_FLANG_CLAUSE_CHECK_ENTER
+// GEN-NEXT:  #undef GEN_FLANG_CLAUSE_CHECK_ENTER
+// GEN-EMPTY:
+// GEN-NEXT:  void Enter(const parser::TdlClause::Clausea &);
+// GEN-NEXT:  void Enter(const parser::TdlClause::Clauseb &);
+// GEN-EMPTY:
+// GEN-NEXT:  #endif // GEN_FLANG_CLAUSE_CHECK_ENTER
+// GEN-EMPTY:
+// GEN-NEXT:  #ifdef GEN_FLANG_CLAUSE_PARSER_KIND_MAP
+// GEN-NEXT:  #undef GEN_FLANG_CLAUSE_PARSER_KIND_MAP
+// GEN-EMPTY:
+// GEN-NEXT:  if constexpr (std::is_same_v<A, parser::TdlClause::Clausea>)
+// GEN-NEXT:    return llvm::tdl::Clause::TDLC_clausea;
+// GEN-NEXT:  if constexpr (std::is_same_v<A, parser::TdlClause::Clauseb>)
+// GEN-NEXT:    return llvm::tdl::Clause::TDLC_clauseb;
+// GEN-NEXT:  llvm_unreachable("Invalid Tdl Parser clause");
+// GEN-EMPTY:
+// GEN-NEXT:  #endif // GEN_FLANG_CLAUSE_PARSER_KIND_MAP
+// GEN-EMPTY:

diff  --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index deb51a082649f..b331fd9c0613e 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -647,6 +647,29 @@ void GenerateFlangClauseCheckPrototypes(const DirectiveLanguage &DirLang,
   }
 }
 
+// Generate an `if constexpr` chain mapping each clause's Flang parser class
+// (template parameter A) to its clause kind, ending in llvm_unreachable.
+void GenerateFlangClauseParserKindMap(const DirectiveLanguage &DirLang,
+                                      raw_ostream &OS) {
+  // Output is wrapped in #ifdef/#undef of GEN_FLANG_CLAUSE_PARSER_KIND_MAP.
+  IfDefScope Scope("GEN_FLANG_CLAUSE_PARSER_KIND_MAP", OS);
+
+  OS << "\n";
+  for (const auto &C : DirLang.getClauses()) {
+    Clause Clause{C};
+    OS << "if constexpr (std::is_same_v<A, parser::"
+       << DirLang.getFlangClauseBaseClass()
+       << "::" << Clause.getFormattedParserClassName();
+    OS << ">)\n";
+    OS << "  return llvm::" << DirLang.getCppNamespace()
+       << "::Clause::" << DirLang.getClausePrefix() << Clause.getFormattedName()
+       << ";\n";
+  }
+
+  OS << "llvm_unreachable(\"Invalid " << DirLang.getName()
+     << " Parser clause\");\n"; // Reached only if A matched no clause class.
+}
+
 // Generate the implementation section for the enumeration in the directive
 // language
 void EmitDirectivesFlangImpl(const DirectiveLanguage &DirLang,
@@ -665,6 +688,8 @@ void EmitDirectivesFlangImpl(const DirectiveLanguage &DirLang,
   GenerateFlangClauseUnparse(DirLang, OS);
 
   GenerateFlangClauseCheckPrototypes(DirLang, OS);
+
+  GenerateFlangClauseParserKindMap(DirLang, OS);
 }
 
 void GenerateClauseClassMacro(const DirectiveLanguage &DirLang,


        


More information about the flang-commits mailing list