[flang-commits] [flang] [mlir] Add ZA directives for Flang. (PR #76505)

via flang-commits flang-commits at lists.llvm.org
Thu Dec 28 06:28:46 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Mats Petersson (Leporacanthicus)

<details>
<summary>Changes</summary>

Add directives to support functions and subroutines that use ZA state.

---

Patch is 23.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76505.diff


10 Files Affected:

- (modified) flang/include/flang/Lower/PFTBuilder.h (+8-2) 
- (modified) flang/include/flang/Parser/parse-tree.h (+2-1) 
- (modified) flang/lib/Lower/Bridge.cpp (+109-14) 
- (modified) flang/lib/Lower/PFTBuilder.cpp (+42-21) 
- (modified) flang/lib/Parser/program-parsers.cpp (+2-1) 
- (modified) flang/lib/Semantics/program-tree.cpp (+5-1) 
- (modified) flang/lib/Semantics/resolve-names.cpp (+28-1) 
- (added) flang/test/Lower/arm-ssve-directives.f90 (+83) 
- (modified) flang/test/Parser/compiler-directives.f90 (+8) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+4-2) 


``````````diff
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9c6696ff79dae1..16076c11483f61 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
 
 void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
 
+/// Things that can be nested inside of a module or function
+/// TODO: add the rest
+struct FunctionLikeUnit;
+struct CompilerDirectiveUnit;
+using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;
+
 /// Function-like units may contain evaluations (executable statements) and
 /// nested function-like units (internal procedures and function statements).
 struct FunctionLikeUnit : public ProgramUnit {
@@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
   EvaluationList evaluationList;
   LabelEvalMap labelEvaluationMap;
   SymbolLabelMap assignSymbolLabelMap;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  std::list<NestedUnit> nestedUnits;
   /// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
   /// is the primary entry point; remaining pairs are alternate entry points.
   /// The primary entry point symbol is Null for an anonymous program.
@@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {
 
   ModuleStatement beginStmt;
   ModuleStatement endStmt;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  std::list<NestedUnit> nestedUnits;
   EvaluationList evaluationList;
 };
 
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 393e0e24ec5cbd..880f1e249d34a6 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
   UNION_CLASS_BOILERPLATE(ModuleSubprogram);
   std::variant<common::Indirection<FunctionSubprogram>,
       common::Indirection<SubroutineSubprogram>,
-      common::Indirection<SeparateModuleSubprogram>>
+      common::Indirection<SeparateModuleSubprogram>,
+      common::Indirection<CompilerDirective>>
       u;
 };
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7e64adc3c144c9..043c5104191813 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -53,6 +53,7 @@
 #include "flang/Semantics/runtime-type-info.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
@@ -303,9 +304,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                      },
                      [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                        lowerModuleDeclScope(m);
-                       for (Fortran::lower::pft::FunctionLikeUnit &f :
-                            m.nestedFunctions)
-                         declareFunction(f);
+                       for (Fortran::lower::pft::NestedUnit &unit :
+                            m.nestedUnits) {
+                         if (auto *f = std::get_if<
+                                 Fortran::lower::pft::FunctionLikeUnit>(&unit))
+                           declareFunction(*f);
+                       }
                      },
                      [&](Fortran::lower::pft::BlockDataUnit &b) {
                        if (!globalOmpRequiresSymbol)
@@ -322,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         [&]() { createIntrinsicModuleDefinitions(pft); });
 
     // Primary translation pass.
-    for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
+    std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
+    for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
       std::visit(
           Fortran::common::visitors{
               [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
               [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
               [&](Fortran::lower::pft::BlockDataUnit &b) {},
-              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+                processSubprogramDirective(it, units.end(), d);
+              },
               [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
                 builder = new fir::FirOpBuilder(bridge.getModule(),
                                                 bridge.getKindMap());
@@ -338,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                 builder = nullptr;
               },
           },
-          u);
+          *it);
     }
 
     // Once all the code has been translated, create global runtime type info
@@ -387,13 +394,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     // Compute the set of host associated entities from the nested functions.
     llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      collectHostAssociatedVariables(f, escapeHost);
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        collectHostAssociatedVariables(*f, escapeHost);
+    }
     funit.setHostAssociatedSymbols(escapeHost);
 
     // Declare internal procedures
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      declareFunction(f);
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        declareFunction(*f);
+    }
   }
 
   /// Get the scope that is defining or using \p sym. The returned scope is not
@@ -4667,8 +4678,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       endNewFunction(funit);
     }
     funit.setActiveEntry(0);
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      lowerFunc(f); // internal procedure
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        lowerFunc(*f); // internal procedure
+    }
   }
 
   /// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -4692,8 +4705,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
-      lowerFunc(f);
+    for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
+         it = std::next(it)) {
+      std::visit(
+          Fortran::common::visitors{
+              [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
+              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+                processSubprogramDirective(it, mod.nestedUnits.end(), d);
+              }},
+          *it);
+    }
   }
 
   void setCurrentPosition(const Fortran::parser::CharBlock &position) {
@@ -5001,6 +5022,80 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                         globalOmpRequiresSymbol);
   }
 
+  /// Process compiler directives that apply to subprograms
+  template <typename ITERATOR>
+  void
+  processSubprogramDirective(ITERATOR it, ITERATOR endIt,
+                             Fortran::lower::pft::CompilerDirectiveUnit &d) {
+    auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
+    if (!parserDirective)
+      return;
+    auto *nvList =
+        std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
+            &parserDirective->u);
+    if (!nvList)
+      return;
+
+    // get the function the directive applies to (hopefully the next unit)
+    mlir::func::FuncOp mlirFunc;
+    it = std::next(it);
+    if (it != endIt) {
+      auto *pftFunction =
+          std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
+      if (pftFunction) {
+        Fortran::lower::CalleeInterface callee{*pftFunction, *this};
+        mlirFunc = callee.getFuncOp();
+      }
+    }
+
+    for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
+      std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();
+
+      // arm streaming sve directives
+      auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
+      if (name == "arm_streaming")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
+      else if (name == "arm_locally_streaming")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
+      else if (name == "arm_streaming_compatible")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
+      if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
+        if (!mlirFunc) {
+          // TODO: share diagnostic code with warnings elsewhere
+          // TODO: source location is printed as loc<"file.f90":line:col>
+          mlir::Location loc = genLocation(parserDirective->source);
+          llvm::errs() << loc << ": warning: ignoring directive '" << name
+                       << "' because it has no associated subprogram\n";
+          continue;
+        }
+        llvm::StringRef attrName =
+            mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
+        mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+        mlirFunc->setAttr(attrName, unitAttr);
+      }
+      auto zaMode = mlir::arm_sme::ArmZaMode::Disabled;
+      if (name == "arm_new_za")
+        zaMode = mlir::arm_sme::ArmZaMode::NewZA;
+      else if (name == "arm_shared_za")
+        zaMode = mlir::arm_sme::ArmZaMode::SharedZA;
+      else if (name == "arm_preserves_za")
+        zaMode = mlir::arm_sme::ArmZaMode::PreservesZA;
+      if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) {
+        if (!mlirFunc) {
+          // TODO: share diagnostic code with warnings elsewhere
+          // TODO: source location is printed as loc<"file.f90":line:col>
+          mlir::Location loc = genLocation(parserDirective->source);
+          llvm::errs() << loc << ": warning: ignoring directive '" << name
+                       << "' because it has no associated subprogram\n";
+          continue;
+        }
+        llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode);
+        mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+        mlirFunc->setAttr(attrName, unitAttr);
+      }
+    }
+  }
+
   //===--------------------------------------------------------------------===//
 
   Fortran::lower::LoweringBridge &bridge;
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 32ed539c775b82..0a08e1cf7ff47e 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -259,6 +259,12 @@ class PFTBuilder {
           lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
       return false;
     }
+    if (auto *mod = pftParentStack.back().getIf<lower::pft::ModuleLikeUnit>()) {
+      assert(nestedUnitList && "Modules have a nested units list");
+      lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()};
+      addNestedUnit(std::move(unit));
+      return false;
+    }
     return enterConstructOrDirective(directive);
   }
 
@@ -279,7 +285,7 @@ class PFTBuilder {
   bool enterModule(const A &mod) {
     Fortran::lower::pft::ModuleLikeUnit &unit =
         addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
-    functionList = &unit.nestedFunctions;
+    nestedUnitList = &unit.nestedUnits;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -349,7 +355,7 @@ class PFTBuilder {
                                                  semanticsContext});
     labelEvaluationMap = &unit.labelEvaluationMap;
     assignSymbolLabelMap = &unit.assignSymbolLabelMap;
-    functionList = &unit.nestedFunctions;
+    nestedUnitList = &unit.nestedUnits;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -414,14 +420,14 @@ class PFTBuilder {
     if (!pftParentStack.empty()) {
       pftParentStack.back().visit(common::visitors{
           [&](lower::pft::FunctionLikeUnit &p) {
-            functionList = &p.nestedFunctions;
+            nestedUnitList = &p.nestedUnits;
             labelEvaluationMap = &p.labelEvaluationMap;
             assignSymbolLabelMap = &p.assignSymbolLabelMap;
           },
           [&](lower::pft::ModuleLikeUnit &p) {
-            functionList = &p.nestedFunctions;
+            nestedUnitList = &p.nestedUnits;
           },
-          [&](auto &) { functionList = nullptr; },
+          [&](auto &) { nestedUnitList = nullptr; },
       });
     }
   }
@@ -432,11 +438,16 @@ class PFTBuilder {
     return std::get<A>(pgm->getUnits().back());
   }
 
+  template <typename A>
+  void addNestedUnit(A &&source) {
+    nestedUnitList->emplace_back(lower::pft::NestedUnit{std::move(source)});
+  }
+
   template <typename A>
   A &addFunction(A &&func) {
-    if (functionList) {
-      functionList->emplace_back(std::move(func));
-      return functionList->back();
+    if (nestedUnitList) {
+      addNestedUnit(func);
+      return std::get<A>(nestedUnitList->back());
     }
     return addUnit(std::move(func));
   }
@@ -459,7 +470,7 @@ class PFTBuilder {
 
   /// Append an Evaluation to the end of the current list.
   lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     assert(!evaluationListStack.empty() && "empty evaluation list stack");
     if (!constructAndDirectiveStack.empty())
       eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,7 +510,7 @@ class PFTBuilder {
 
   /// push a new list on the stack of Evaluation lists
   void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     assert(evaluationList && evaluationList->empty() &&
            "evaluation list isn't correct");
     evaluationListStack.emplace_back(evaluationList);
@@ -507,7 +518,7 @@ class PFTBuilder {
 
   /// pop the current list and return to the last Evaluation list
   void popEvaluationList() {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     evaluationListStack.pop_back();
   }
 
@@ -1088,9 +1099,9 @@ class PFTBuilder {
   std::vector<lower::pft::PftNode> pftParentStack;
   const semantics::SemanticsContext &semanticsContext;
 
-  /// functionList points to the internal or module procedure function list
-  /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
-  std::list<lower::pft::FunctionLikeUnit> *functionList{};
+  /// nestedUnitList points to the internal or module procedure unit list
+  /// of nested units (e.g. functions). It may be null.
+  std::list<lower::pft::NestedUnit> *nestedUnitList{};
   std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
   std::vector<lower::pft::Evaluation *> doConstructStack{};
   /// evaluationListStack is the current nested construct evaluationList state.
@@ -1264,11 +1275,17 @@ class PFTDumper {
       outputStream << ": " << header;
     outputStream << '\n';
     dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
-    if (!functionLikeUnit.nestedFunctions.empty()) {
+    if (!functionLikeUnit.nestedUnits.empty()) {
       outputStream << "\nContains\n";
-      for (const lower::pft::FunctionLikeUnit &func :
-           functionLikeUnit.nestedFunctions)
-        dumpFunctionLikeUnit(outputStream, func);
+      for (const lower::pft::NestedUnit &nested :
+           functionLikeUnit.nestedUnits) {
+        if (const auto *func =
+                std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+          dumpFunctionLikeUnit(outputStream, *func);
+        if (const auto *directive =
+                std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+          dumpCompilerDirectiveUnit(outputStream, *directive);
+      }
       outputStream << "End Contains\n";
     }
     outputStream << "End " << unitKind << ' ' << name << "\n\n";
@@ -1298,9 +1315,13 @@ class PFTDumper {
     outputStream << unitKind << ' ' << name << ": " << header << '\n';
     dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
     outputStream << "Contains\n";
-    for (const lower::pft::FunctionLikeUnit &func :
-         moduleLikeUnit.nestedFunctions)
-      dumpFunctionLikeUnit(outputStream, func);
+    for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) {
+      if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+        dumpFunctionLikeUnit(outputStream, *func);
+      if (const auto *directive =
+              std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+        dumpCompilerDirectiveUnit(outputStream, *directive);
+    }
     outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
   }
 
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index e24559bf14f7c9..ff5e58ebc721c5 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US,
 //         separate-module-subprogram
 TYPE_PARSER(construct<ModuleSubprogram>(indirect(functionSubprogram)) ||
     construct<ModuleSubprogram>(indirect(subroutineSubprogram)) ||
-    construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})))
+    construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})) ||
+    construct<ModuleSubprogram>(indirect(compilerDirective)))
 
 // R1410 module-nature -> INTRINSIC | NON_INTRINSIC
 constexpr auto moduleNature{
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index bf773f3810c847..50487ea58dff72 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -112,7 +112,11 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
     for (const auto &subp :
         std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
       common::visit(
-          [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
+          common::visitors{
+              [&](const common::Indirection<parser::CompilerDirective> &) {},
+              [&](const auto &y) {
+                node.AddChild(ProgramTree::Build(y.value()));
+              }},
           subp.u);
     }
   }
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e1cd34ddf65b6b..ef8accd4636c3a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,7 +8389,34 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
       }
     }
   } else {
-    Say(x.source, "Compiler directive was ignored"_warn_en_US);
+    bool handled{false};
+    if (const auto *nvList{
+            std::get_if<std::list<parser::CompilerDirective::NameValue>>(
+                &x.u)}) {
+      for (const parser::CompilerDirective::NameValue &nv : *nvList) {
+        std::string name{std::get<parser::Name>(nv.t).ToString()};
+        const std::initializer_list<const char *> handledAttrs{
+            "arm_streaming",
+            "arm_locally_streaming",
+            "arm_streaming_compatible",
+            "arm_shared_za",
+            "arm_new_za",
+            "arm_preserves_za",
+        };
+        if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
+            handledAttrs.end()) {
+          // exit early so that subsequent recognised attributes can't change
+          // the result
+          handled = false;
+          break;
+        }
+        // this attribute was handled
+        handled = true;
+      }
+    }
+    if (!handled) {
+      S...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/76505


More information about the flang-commits mailing list