[flang-commits] [flang] [flang] Add support for lowering directives at the CONTAINS level (PR #95123)

via flang-commits flang-commits at lists.llvm.org
Tue Jun 11 07:08:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (vdonaldson)

<details>
<summary>Changes</summary>

There is currently support for lowering directives that appear outside of a module or procedure, or inside the body of a module or procedure. Extend this to support directives at the CONTAINS level of a module or procedure, such as directives 3, 5, 7, 9, and 10 in:

    !dir$ some directive 1
    module m
      !dir$ some directive 2
    contains
      !dir$ some directive 3
      subroutine p
        !dir$ some directive 4
      contains
        !dir$ some directive 5
        subroutine s1
          !dir$ some directive 6
        end subroutine s1
        !dir$ some directive 7
        subroutine s2
          !dir$ some directive 8
        end subroutine s2
        !dir$ some directive 9
      end subroutine p
      !dir$ some directive 10
    end module m
    !dir$ some directive 11

This is done by looking for CONTAINS statements at the module or procedure level, while ignoring CONTAINS statements at the derived type level.

---

Patch is 22.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95123.diff


4 Files Affected:

- (modified) flang/include/flang/Lower/PFTBuilder.h (+10-7) 
- (modified) flang/lib/Lower/Bridge.cpp (+38-30) 
- (modified) flang/lib/Lower/PFTBuilder.cpp (+95-54) 
- (added) flang/test/Lower/pre-fir-tree09.f90 (+100) 


``````````diff
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9913f584133fa..83200eb6351a8 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -31,11 +31,14 @@
 
 namespace Fortran::lower::pft {
 
+struct CompilerDirectiveUnit;
 struct Evaluation;
-struct Program;
-struct ModuleLikeUnit;
 struct FunctionLikeUnit;
+struct ModuleLikeUnit;
+struct Program;
 
+using ContainedUnit = std::variant<CompilerDirectiveUnit, FunctionLikeUnit>;
+using ContainedUnitList = std::list<ContainedUnit>;
 using EvaluationList = std::list<Evaluation>;
 
 /// Provide a variant like container that can hold references. It can hold
@@ -594,8 +597,8 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
 
 void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
 
-/// Function-like units may contain evaluations (executable statements) and
-/// nested function-like units (internal procedures and function statements).
+/// Function-like units may contain evaluations (executable statements),
+/// directives, and internal (nested) function-like units.
 struct FunctionLikeUnit : public ProgramUnit {
   // wrapper statements for function-like syntactic structures
   using FunctionStatement =
@@ -697,10 +700,10 @@ struct FunctionLikeUnit : public ProgramUnit {
   std::optional<FunctionStatement> beginStmt;
   FunctionStatement endStmt;
   const semantics::Scope *scope;
-  EvaluationList evaluationList;
   LabelEvalMap labelEvaluationMap;
   SymbolLabelMap assignSymbolLabelMap;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  ContainedUnitList containedUnitList;
+  EvaluationList evaluationList;
   /// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
   /// is the primary entry point; remaining pairs are alternate entry points.
   /// The primary entry point symbol is Null for an anonymous program.
@@ -746,7 +749,7 @@ struct ModuleLikeUnit : public ProgramUnit {
 
   ModuleStatement beginStmt;
   ModuleStatement endStmt;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  ContainedUnitList containedUnitList;
   EvaluationList evaluationList;
 };
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 202efa57d4a36..9ecbbc73dce07 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -302,28 +302,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     bool hasMainProgram = false;
     const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
     for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
-      std::visit(Fortran::common::visitors{
-                     [&](Fortran::lower::pft::FunctionLikeUnit &f) {
-                       if (f.isMainProgram())
-                         hasMainProgram = true;
-                       declareFunction(f);
-                       if (!globalOmpRequiresSymbol)
-                         globalOmpRequiresSymbol = f.getScope().symbol();
-                     },
-                     [&](Fortran::lower::pft::ModuleLikeUnit &m) {
-                       lowerModuleDeclScope(m);
-                       for (Fortran::lower::pft::FunctionLikeUnit &f :
-                            m.nestedFunctions)
-                         declareFunction(f);
-                     },
-                     [&](Fortran::lower::pft::BlockDataUnit &b) {
-                       if (!globalOmpRequiresSymbol)
-                         globalOmpRequiresSymbol = b.symTab.symbol();
-                     },
-                     [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
-                     [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
-                 },
-                 u);
+      std::visit(
+          Fortran::common::visitors{
+              [&](Fortran::lower::pft::FunctionLikeUnit &f) {
+                if (f.isMainProgram())
+                  hasMainProgram = true;
+                declareFunction(f);
+                if (!globalOmpRequiresSymbol)
+                  globalOmpRequiresSymbol = f.getScope().symbol();
+              },
+              [&](Fortran::lower::pft::ModuleLikeUnit &m) {
+                lowerModuleDeclScope(m);
+                for (Fortran::lower::pft::ContainedUnit &unit :
+                     m.containedUnitList)
+                  if (auto *f =
+                          std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
+                              &unit))
+                    declareFunction(*f);
+              },
+              [&](Fortran::lower::pft::BlockDataUnit &b) {
+                if (!globalOmpRequiresSymbol)
+                  globalOmpRequiresSymbol = b.symTab.symbol();
+              },
+              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+              [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
+          },
+          u);
     }
 
     // Create definitions of intrinsic module constants.
@@ -387,13 +391,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     // Compute the set of host associated entities from the nested functions.
     llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      collectHostAssociatedVariables(f, escapeHost);
+    for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+        collectHostAssociatedVariables(*f, escapeHost);
     funit.setHostAssociatedSymbols(escapeHost);
 
     // Declare internal procedures
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      declareFunction(f);
+    for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+        declareFunction(*f);
   }
 
   /// Get the scope that is defining or using \p sym. The returned scope is not
@@ -5356,8 +5362,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       endNewFunction(funit);
     }
     funit.setActiveEntry(0);
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      lowerFunc(f); // internal procedure
+    for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+        lowerFunc(*f); // internal procedure
   }
 
   /// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -5381,8 +5388,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
-      lowerFunc(f);
+    for (Fortran::lower::pft::ContainedUnit &unit : mod.containedUnitList)
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+        lowerFunc(*f);
   }
 
   void setCurrentPosition(const Fortran::parser::CharBlock &position) {
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index f196b9c5a0cbc..df2c31381a0e7 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -209,6 +209,20 @@ class PFTBuilder {
     }
   }
 
+  bool Pre(const parser::SpecificationPart &) {
+    ++specificationPartLevel;
+    return true;
+  }
+  void Post(const parser::SpecificationPart &) { --specificationPartLevel; }
+
+  bool Pre(const parser::ContainsStmt &) {
+    if (!specificationPartLevel) {
+      assert(containsStmtStack.size() && "empty contains stack");
+      containsStmtStack.back() = true;
+    }
+    return false;
+  }
+
   // Module like
   bool Pre(const parser::Module &node) { return enterModule(node); }
   bool Pre(const parser::Submodule &node) { return enterModule(node); }
@@ -249,15 +263,21 @@ class PFTBuilder {
         whereBody.u);
   }
 
-  // CompilerDirective have special handling in case they are top level
-  // directives (i.e. they do not belong to a ProgramUnit).
+  // A CompilerDirective may appear outside any program unit, after a module
+  // or function contains statement, or inside a module or function.
   bool Pre(const parser::CompilerDirective &directive) {
-    assert(pftParentStack.size() > 0 &&
-           "At least the Program must be a parent");
-    if (pftParentStack.back().isA<lower::pft::Program>()) {
-      addUnit(
-          lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
+    assert(pftParentStack.size() > 0 && "no program");
+    lower::pft::PftNode &node = pftParentStack.back();
+    if (node.isA<lower::pft::Program>()) {
+      addUnit(lower::pft::CompilerDirectiveUnit(directive, node));
       return false;
+    } else if ((node.isA<lower::pft::ModuleLikeUnit>() ||
+                node.isA<lower::pft::FunctionLikeUnit>())) {
+      assert(containsStmtStack.size() && "empty contains stack");
+      if (containsStmtStack.back()) {
+        addContainedUnit(lower::pft::CompilerDirectiveUnit{directive, node});
+        return false;
+      }
     }
     return enterConstructOrDirective(directive);
   }
@@ -277,9 +297,10 @@ class PFTBuilder {
   /// Initialize a new module-like unit and make it the builder's focus.
   template <typename A>
   bool enterModule(const A &mod) {
-    Fortran::lower::pft::ModuleLikeUnit &unit =
+    lower::pft::ModuleLikeUnit &unit =
         addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
-    functionList = &unit.nestedFunctions;
+    containsStmtStack.push_back(false);
+    containedUnitList = &unit.containedUnitList;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -287,6 +308,7 @@ class PFTBuilder {
   }
 
   void exitModule() {
+    containsStmtStack.pop_back();
     if (!evaluationListStack.empty())
       popEvaluationList();
     pftParentStack.pop_back();
@@ -344,12 +366,13 @@ class PFTBuilder {
                      const semantics::SemanticsContext &semanticsContext) {
     cleanModuleEvaluationList();
     endFunctionBody(); // enclosing host subprogram body, if any
-    Fortran::lower::pft::FunctionLikeUnit &unit =
-        addFunction(lower::pft::FunctionLikeUnit{func, pftParentStack.back(),
-                                                 semanticsContext});
+    lower::pft::FunctionLikeUnit &unit =
+        addContainedUnit(lower::pft::FunctionLikeUnit{
+            func, pftParentStack.back(), semanticsContext});
     labelEvaluationMap = &unit.labelEvaluationMap;
     assignSymbolLabelMap = &unit.assignSymbolLabelMap;
-    functionList = &unit.nestedFunctions;
+    containsStmtStack.push_back(false);
+    containedUnitList = &unit.containedUnitList;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -361,6 +384,7 @@ class PFTBuilder {
     endFunctionBody();
     analyzeBranches(nullptr, *evaluationListStack.back()); // add branch links
     processEntryPoints();
+    containsStmtStack.pop_back();
     popEvaluationList();
     labelEvaluationMap = nullptr;
     assignSymbolLabelMap = nullptr;
@@ -371,7 +395,7 @@ class PFTBuilder {
   /// Initialize a new construct or directive and make it the builder's focus.
   template <typename A>
   bool enterConstructOrDirective(const A &constructOrDirective) {
-    Fortran::lower::pft::Evaluation &eval = addEvaluation(
+    lower::pft::Evaluation &eval = addEvaluation(
         lower::pft::Evaluation{constructOrDirective, pftParentStack.back()});
     eval.evaluationList.reset(new lower::pft::EvaluationList);
     pushEvaluationList(eval.evaluationList.get());
@@ -381,7 +405,7 @@ class PFTBuilder {
   }
 
   void exitConstructOrDirective() {
-    auto isOpenMPLoopConstruct = [](Fortran::lower::pft::Evaluation *eval) {
+    auto isOpenMPLoopConstruct = [](lower::pft::Evaluation *eval) {
       if (const auto *ompConstruct = eval->getIf<parser::OpenMPConstruct>())
         if (std::holds_alternative<parser::OpenMPLoopConstruct>(
                 ompConstruct->u))
@@ -396,8 +420,7 @@ class PFTBuilder {
       // construct region must have an exit target inside the region.
       // This is not applicable to the OpenMP loop construct since the
       // end of the loop is an available target inside the region.
-      Fortran::lower::pft::EvaluationList &evaluationList =
-          *eval->evaluationList;
+      lower::pft::EvaluationList &evaluationList = *eval->evaluationList;
       if (!evaluationList.empty() && evaluationList.back().isConstruct()) {
         static const parser::ContinueStmt exitTarget{};
         addEvaluation(
@@ -413,15 +436,15 @@ class PFTBuilder {
   void resetFunctionState() {
     if (!pftParentStack.empty()) {
       pftParentStack.back().visit(common::visitors{
+          [&](lower::pft::ModuleLikeUnit &p) {
+            containedUnitList = &p.containedUnitList;
+          },
           [&](lower::pft::FunctionLikeUnit &p) {
-            functionList = &p.nestedFunctions;
+            containedUnitList = &p.containedUnitList;
             labelEvaluationMap = &p.labelEvaluationMap;
             assignSymbolLabelMap = &p.assignSymbolLabelMap;
           },
-          [&](lower::pft::ModuleLikeUnit &p) {
-            functionList = &p.nestedFunctions;
-          },
-          [&](auto &) { functionList = nullptr; },
+          [&](auto &) { containedUnitList = nullptr; },
       });
     }
   }
@@ -433,12 +456,11 @@ class PFTBuilder {
   }
 
   template <typename A>
-  A &addFunction(A &&func) {
-    if (functionList) {
-      functionList->emplace_back(std::move(func));
-      return functionList->back();
-    }
-    return addUnit(std::move(func));
+  A &addContainedUnit(A &&unit) {
+    if (!containedUnitList)
+      return addUnit(std::move(unit));
+    containedUnitList->emplace_back(std::move(unit));
+    return std::get<A>(containedUnitList->back());
   }
 
   // ActionStmt has a couple of non-conforming cases, explicitly handled here.
@@ -459,7 +481,6 @@ class PFTBuilder {
 
   /// Append an Evaluation to the end of the current list.
   lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
-    assert(functionList && "not in a function");
     assert(!evaluationListStack.empty() && "empty evaluation list stack");
     if (!constructAndDirectiveStack.empty())
       eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,15 +520,15 @@ class PFTBuilder {
 
   /// push a new list on the stack of Evaluation lists
   void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
-    assert(functionList && "not in a function");
     assert(evaluationList && evaluationList->empty() &&
-           "evaluation list isn't correct");
+           "invalid evaluation list");
     evaluationListStack.emplace_back(evaluationList);
   }
 
   /// pop the current list and return to the last Evaluation list
   void popEvaluationList() {
-    assert(functionList && "not in a function");
+    assert(!evaluationListStack.empty() &&
+           "trying to pop an empty evaluationListStack");
     evaluationListStack.pop_back();
   }
 
@@ -1089,9 +1110,8 @@ class PFTBuilder {
   std::vector<lower::pft::PftNode> pftParentStack;
   const semantics::SemanticsContext &semanticsContext;
 
-  /// functionList points to the internal or module procedure function list
-  /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
-  std::list<lower::pft::FunctionLikeUnit> *functionList{};
+  llvm::SmallVector<bool> containsStmtStack{};
+  lower::pft::ContainedUnitList *containedUnitList{};
   std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
   std::vector<lower::pft::Evaluation *> doConstructStack{};
   /// evaluationListStack is the current nested construct evaluationList state.
@@ -1099,6 +1119,7 @@ class PFTBuilder {
   llvm::DenseMap<parser::Label, lower::pft::Evaluation *> *labelEvaluationMap{};
   lower::pft::SymbolLabelMap *assignSymbolLabelMap{};
   std::map<std::string, lower::pft::Evaluation *> constructNameMap{};
+  int specificationPartLevel{};
   lower::pft::Evaluation *lastLexicalEvaluation{};
 };
 
@@ -1201,11 +1222,15 @@ class PFTDumper {
       outputStream << " -> " << eval.controlSuccessor->printIndex;
     else if (eval.isA<parser::EntryStmt>() && eval.lexicalSuccessor)
       outputStream << " -> " << eval.lexicalSuccessor->printIndex;
+    bool extraNewline = false;
     if (!eval.position.empty())
       outputStream << ": " << eval.position.ToString();
-    else if (auto *dir = eval.getIf<Fortran::parser::CompilerDirective>())
+    else if (auto *dir = eval.getIf<parser::CompilerDirective>()) {
+      extraNewline = dir->source.ToString().back() == '\n';
       outputStream << ": !" << dir->source.ToString();
-    outputStream << '\n';
+    }
+    if (!extraNewline)
+      outputStream << '\n';
     if (eval.hasNestedEvaluations()) {
       dumpEvaluationList(outputStream, *eval.evaluationList, indent + 1);
       outputStream << indentString << "<<End " << name << bang << ">>\n";
@@ -1265,13 +1290,7 @@ class PFTDumper {
       outputStream << ": " << header;
     outputStream << '\n';
     dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
-    if (!functionLikeUnit.nestedFunctions.empty()) {
-      outputStream << "\nContains\n";
-      for (const lower::pft::FunctionLikeUnit &func :
-           functionLikeUnit.nestedFunctions)
-        dumpFunctionLikeUnit(outputStream, func);
-      outputStream << "End Contains\n";
-    }
+    dumpContainedUnitList(outputStream, functionLikeUnit.containedUnitList);
     outputStream << "End " << unitKind << ' ' << name << "\n\n";
   }
 
@@ -1298,11 +1317,8 @@ class PFTDumper {
     });
     outputStream << unitKind << ' ' << name << ": " << header << '\n';
     dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
-    outputStream << "Contains\n";
-    for (const lower::pft::FunctionLikeUnit &func :
-         moduleLikeUnit.nestedFunctions)
-      dumpFunctionLikeUnit(outputStream, func);
-    outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
+    dumpContainedUnitList(outputStream, moduleLikeUnit.containedUnitList);
+    outputStream << "End " << unitKind << ' ' << name << "\n\n";
   }
 
   // Top level directives
@@ -1311,9 +1327,34 @@ class PFTDumper {
       const lower::pft::CompilerDirectiveUnit &directive) {
     outputStream << getNodeIndex(directive) << " ";
     outputStream << "CompilerDirective: !";
-    outputStream << directive.get<Fortran::parser::CompilerDirective>()
-                        .source.ToString();
-    outputStream << "\nEnd CompilerDirective\n\n";
+    bool extraNewline =
+        directive.get<parser::CompilerDirective>().source.ToString().back() ==
+        'n';
+    outputStream
+        << directive.get<parser::CompilerDirective>().source.ToString();
+    if (!extraNewline)
+      outputStream << "\n";
+    outputStream << "\n";
+  }
+
+  void dumpContainedUnitList(
+      llvm::raw_ostream &outputStream,
+      const lower::pft::ContainedUnitList &containedUnitList) {
+    if (containedUnitList.empty())
+      return;
+    outputStream << "\nContains\n";
+    for (const lower::pft::ContainedUnit &unit : containedUnitList)
+      if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&unit)) {
+        dumpFunctionLikeUnit(outputStream, *func);
+      } else if (const auto *dir =
+                     std::get_if<lower::pft::CompilerDirectiveUnit>(&unit)) {
+        outputStream << getNodeIndex(*dir) << " ";
+        dumpEvaluation(outputStream,
+                       lower::pft::Evaluation{
+                           dir->get<parser::CompilerDirective>(), dir->parent});
+        outputStream << "\n";
+      }
+    outputStream << "End Contains\n";
   }
 
   void
@@ -1321,8 +1362,8 @@ class PFTDumper {
                            const lower::pft::OpenACCDirectiveUnit &directive) {
     outputStream << getNodeIndex(directive) << " ";
     outputStream << "OpenACCDirective: !$acc ";
-    outputStream << directive.get<Fortran::parser::OpenACCRoutineConstruct>()
-                        .source.ToString();
+    outputStream
+        << directive.get<parser::OpenACCRoutineConstruct>().source.ToString();
     outpu...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/95123


More information about the flang-commits mailing list