[flang-commits] [mlir] [flang] Add ZA directives for Flang. (PR #76505)

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Fri Dec 29 06:29:12 PST 2023


https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/76505

>From 1522c61697d307ec0271a3def5f012174f02292f Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Tue, 12 Dec 2023 15:47:22 +0000
Subject: [PATCH 1/7] [Flang] WIP: Allow compiler directives for module
 procedures

---
 flang/include/flang/Parser/parse-tree.h   | 3 ++-
 flang/lib/Parser/program-parsers.cpp      | 3 ++-
 flang/lib/Semantics/program-tree.cpp      | 3 +++
 flang/test/Parser/compiler-directives.f90 | 8 ++++++++
 4 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 393e0e24ec5cbd..880f1e249d34a6 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
   UNION_CLASS_BOILERPLATE(ModuleSubprogram);
   std::variant<common::Indirection<FunctionSubprogram>,
       common::Indirection<SubroutineSubprogram>,
-      common::Indirection<SeparateModuleSubprogram>>
+      common::Indirection<SeparateModuleSubprogram>,
+      common::Indirection<CompilerDirective>>
       u;
 };
 
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index e24559bf14f7c9..ff5e58ebc721c5 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US,
 //         separate-module-subprogram
 TYPE_PARSER(construct<ModuleSubprogram>(indirect(functionSubprogram)) ||
     construct<ModuleSubprogram>(indirect(subroutineSubprogram)) ||
-    construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})))
+    construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})) ||
+    construct<ModuleSubprogram>(indirect(compilerDirective)))
 
 // R1410 module-nature -> INTRINSIC | NON_INTRINSIC
 constexpr auto moduleNature{
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index bf773f3810c847..fcb6392620b67f 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -111,6 +111,9 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
   if (subps) {
     for (const auto &subp :
         std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
+      if (std::holds_alternative<
+              common::Indirection<parser::CompilerDirective>>(subp.u))
+        continue;
       common::visit(
           [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
           subp.u);
diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90
index 88cfd0944faf0a..526f379326909e 100644
--- a/flang/test/Parser/compiler-directives.f90
+++ b/flang/test/Parser/compiler-directives.f90
@@ -22,4 +22,12 @@ module m
   !dir$ optimize : 1
   !dir$ loop count (10000)
   !dir$ loop count (1, 500, 5000, 10000)
+contains
+  !dir$ noinline
+  subroutine sb1()
+  end subroutine
+
+  !dir$ noinline
+  subroutine sb2()
+  end subroutine
 end

>From 8ce3b251d582a96fd2a668a7c29980faf5101ca0 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 13 Dec 2023 11:39:39 +0000
Subject: [PATCH 2/7] WIP: [flang][Lower] support compiler directives inside
 modules in PFTBuilder

Adding a list of all units inside a module in order makes it easier to
match directives with functions that immediately follow them.

The change to nested units inside of functions was incidental to make
this compile, but I expect it should be similarly useful for matching
attributes with the unit that follows them inside of a function body.
---
 flang/include/flang/Lower/PFTBuilder.h | 10 +++-
 flang/lib/Lower/Bridge.cpp             | 33 +++++++++-----
 flang/lib/Lower/PFTBuilder.cpp         | 63 +++++++++++++++++---------
 3 files changed, 72 insertions(+), 34 deletions(-)

diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9c6696ff79dae1..16076c11483f61 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
 
 void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
 
+/// Things that can be nested inside of a module or function
+/// TODO: add the rest
+struct FunctionLikeUnit;
+struct CompilerDirectiveUnit;
+using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;
+
 /// Function-like units may contain evaluations (executable statements) and
 /// nested function-like units (internal procedures and function statements).
 struct FunctionLikeUnit : public ProgramUnit {
@@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
   EvaluationList evaluationList;
   LabelEvalMap labelEvaluationMap;
   SymbolLabelMap assignSymbolLabelMap;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  std::list<NestedUnit> nestedUnits;
   /// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
   /// is the primary entry point; remaining pairs are alternate entry points.
   /// The primary entry point symbol is Null for an anonymous program.
@@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {
 
   ModuleStatement beginStmt;
   ModuleStatement endStmt;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  std::list<NestedUnit> nestedUnits;
   EvaluationList evaluationList;
 };
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7e64adc3c144c9..a08506bf6ebfa0 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -303,9 +303,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                      },
                      [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                        lowerModuleDeclScope(m);
-                       for (Fortran::lower::pft::FunctionLikeUnit &f :
-                            m.nestedFunctions)
-                         declareFunction(f);
+                       for (Fortran::lower::pft::NestedUnit &unit :
+                            m.nestedUnits) {
+                         if (auto *f = std::get_if<
+                                 Fortran::lower::pft::FunctionLikeUnit>(&unit))
+                           declareFunction(*f);
+                       }
                      },
                      [&](Fortran::lower::pft::BlockDataUnit &b) {
                        if (!globalOmpRequiresSymbol)
@@ -387,13 +390,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     // Compute the set of host associated entities from the nested functions.
     llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      collectHostAssociatedVariables(f, escapeHost);
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        collectHostAssociatedVariables(*f, escapeHost);
+    }
     funit.setHostAssociatedSymbols(escapeHost);
 
     // Declare internal procedures
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      declareFunction(f);
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        declareFunction(*f);
+    }
   }
 
   /// Get the scope that is defining or using \p sym. The returned scope is not
@@ -4667,8 +4674,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       endNewFunction(funit);
     }
     funit.setActiveEntry(0);
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      lowerFunc(f); // internal procedure
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        lowerFunc(*f); // internal procedure
+    }
   }
 
   /// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -4692,8 +4701,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
-      lowerFunc(f);
+    for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+        lowerFunc(*f);
+    }
   }
 
   void setCurrentPosition(const Fortran::parser::CharBlock &position) {
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 32ed539c775b82..0a08e1cf7ff47e 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -259,6 +259,12 @@ class PFTBuilder {
           lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
       return false;
     }
+    if (auto *mod = pftParentStack.back().getIf<lower::pft::ModuleLikeUnit>()) {
+      assert(nestedUnitList && "Modules have a nested units list");
+      lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()};
+      addNestedUnit(std::move(unit));
+      return false;
+    }
     return enterConstructOrDirective(directive);
   }
 
@@ -279,7 +285,7 @@ class PFTBuilder {
   bool enterModule(const A &mod) {
     Fortran::lower::pft::ModuleLikeUnit &unit =
         addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
-    functionList = &unit.nestedFunctions;
+    nestedUnitList = &unit.nestedUnits;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -349,7 +355,7 @@ class PFTBuilder {
                                                  semanticsContext});
     labelEvaluationMap = &unit.labelEvaluationMap;
     assignSymbolLabelMap = &unit.assignSymbolLabelMap;
-    functionList = &unit.nestedFunctions;
+    nestedUnitList = &unit.nestedUnits;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -414,14 +420,14 @@ class PFTBuilder {
     if (!pftParentStack.empty()) {
       pftParentStack.back().visit(common::visitors{
           [&](lower::pft::FunctionLikeUnit &p) {
-            functionList = &p.nestedFunctions;
+            nestedUnitList = &p.nestedUnits;
             labelEvaluationMap = &p.labelEvaluationMap;
             assignSymbolLabelMap = &p.assignSymbolLabelMap;
           },
           [&](lower::pft::ModuleLikeUnit &p) {
-            functionList = &p.nestedFunctions;
+            nestedUnitList = &p.nestedUnits;
           },
-          [&](auto &) { functionList = nullptr; },
+          [&](auto &) { nestedUnitList = nullptr; },
       });
     }
   }
@@ -432,11 +438,16 @@ class PFTBuilder {
     return std::get<A>(pgm->getUnits().back());
   }
 
+  template <typename A>
+  void addNestedUnit(A &&source) {
+    nestedUnitList->emplace_back(std::forward<A>(source));
+  }
+
   template <typename A>
   A &addFunction(A &&func) {
-    if (functionList) {
-      functionList->emplace_back(std::move(func));
-      return functionList->back();
+    if (nestedUnitList) {
+      addNestedUnit(std::move(func));
+      return std::get<A>(nestedUnitList->back());
     }
     return addUnit(std::move(func));
   }
@@ -459,7 +470,7 @@ class PFTBuilder {
 
   /// Append an Evaluation to the end of the current list.
   lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     assert(!evaluationListStack.empty() && "empty evaluation list stack");
     if (!constructAndDirectiveStack.empty())
       eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,7 +510,7 @@ class PFTBuilder {
 
   /// push a new list on the stack of Evaluation lists
   void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     assert(evaluationList && evaluationList->empty() &&
            "evaluation list isn't correct");
     evaluationListStack.emplace_back(evaluationList);
@@ -507,7 +518,7 @@ class PFTBuilder {
 
   /// pop the current list and return to the last Evaluation list
   void popEvaluationList() {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     evaluationListStack.pop_back();
   }
 
@@ -1088,9 +1099,9 @@ class PFTBuilder {
   std::vector<lower::pft::PftNode> pftParentStack;
   const semantics::SemanticsContext &semanticsContext;
 
-  /// functionList points to the internal or module procedure function list
-  /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
-  std::list<lower::pft::FunctionLikeUnit> *functionList{};
+  /// nestedUnitList points to the internal or module procedure unit list
+  /// of nested units (e.g. functions). It may be null.
+  std::list<lower::pft::NestedUnit> *nestedUnitList{};
   std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
   std::vector<lower::pft::Evaluation *> doConstructStack{};
   /// evaluationListStack is the current nested construct evaluationList state.
@@ -1264,11 +1275,17 @@ class PFTDumper {
       outputStream << ": " << header;
     outputStream << '\n';
     dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
-    if (!functionLikeUnit.nestedFunctions.empty()) {
+    if (!functionLikeUnit.nestedUnits.empty()) {
       outputStream << "\nContains\n";
-      for (const lower::pft::FunctionLikeUnit &func :
-           functionLikeUnit.nestedFunctions)
-        dumpFunctionLikeUnit(outputStream, func);
+      for (const lower::pft::NestedUnit &nested :
+           functionLikeUnit.nestedUnits) {
+        if (const auto *func =
+                std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+          dumpFunctionLikeUnit(outputStream, *func);
+        if (const auto *directive =
+                std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+          dumpCompilerDirectiveUnit(outputStream, *directive);
+      }
       outputStream << "End Contains\n";
     }
     outputStream << "End " << unitKind << ' ' << name << "\n\n";
@@ -1298,9 +1315,13 @@ class PFTDumper {
     outputStream << unitKind << ' ' << name << ": " << header << '\n';
     dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
     outputStream << "Contains\n";
-    for (const lower::pft::FunctionLikeUnit &func :
-         moduleLikeUnit.nestedFunctions)
-      dumpFunctionLikeUnit(outputStream, func);
+    for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) {
+      if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+        dumpFunctionLikeUnit(outputStream, *func);
+      if (const auto *directive =
+              std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+        dumpCompilerDirectiveUnit(outputStream, *directive);
+    }
     outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
   }
 

>From 0453b9c7b742fded974e7f3cf38c1d4c61d50904 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 11 Dec 2023 16:29:51 +0000
Subject: [PATCH 3/7] WIP: [flang][Lower] add attributes for arm streaming sve
 directives

---
 flang/lib/Lower/Bridge.cpp               | 76 ++++++++++++++++++++++--
 flang/lib/Semantics/resolve-names.cpp    | 26 +++++++-
 flang/test/Lower/arm-ssve-directives.f90 | 46 ++++++++++++++
 3 files changed, 141 insertions(+), 7 deletions(-)
 create mode 100644 flang/test/Lower/arm-ssve-directives.f90

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a08506bf6ebfa0..5e0b6a801a98ec 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -53,6 +53,7 @@
 #include "flang/Semantics/runtime-type-info.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
@@ -325,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         [&]() { createIntrinsicModuleDefinitions(pft); });
 
     // Primary translation pass.
-    for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
+    std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
+    for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
       std::visit(
           Fortran::common::visitors{
               [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
               [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
               [&](Fortran::lower::pft::BlockDataUnit &b) {},
-              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+                processSubprogramDirective(it, units.end(), d);
+              },
               [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
                 builder = new fir::FirOpBuilder(bridge.getModule(),
                                                 bridge.getKindMap());
@@ -341,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                 builder = nullptr;
               },
           },
-          u);
+          *it);
     }
 
     // Once all the code has been translated, create global runtime type info
@@ -4701,9 +4705,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) {
-      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
-        lowerFunc(*f);
+    for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
+         it = std::next(it)) {
+      std::visit(
+          Fortran::common::visitors{
+              [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
+              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+                processSubprogramDirective(it, mod.nestedUnits.end(), d);
+              }},
+          *it);
     }
   }
 
@@ -5012,6 +5022,60 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                         globalOmpRequiresSymbol);
   }
 
+  /// Attach subprogram directives (e.g. arm_streaming) to the next subprogram.
+  template <typename ITERATOR>
+  void
+  processSubprogramDirective(ITERATOR it, ITERATOR endIt,
+                             Fortran::lower::pft::CompilerDirectiveUnit &d) {
+    auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
+    if (!parserDirective)
+      return;
+    auto *nvList =
+        std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
+            &parserDirective->u);
+    if (!nvList)
+      return;
+
+    // Skip stacked directives: all of them apply to the next subprogram.
+    do
+      it = std::next(it);
+    while (it != endIt && std::holds_alternative<
+                              Fortran::lower::pft::CompilerDirectiveUnit>(*it));
+    mlir::func::FuncOp mlirFunc;
+    if (auto *pftFunction = it == endIt ? nullptr : std::get_if<
+            Fortran::lower::pft::FunctionLikeUnit>(&*it)) {
+      Fortran::lower::CalleeInterface callee{*pftFunction, *this};
+      mlirFunc = callee.getFuncOp();
+    }
+
+    for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
+      std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();
+
+      // arm streaming sve directives
+      auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
+      if (name == "arm_streaming")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
+      else if (name == "arm_locally_streaming")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
+      else if (name == "arm_streaming_compatible")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
+      if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
+        if (!mlirFunc) {
+          // TODO: share diagnostic code with warnings elsewhere
+          // TODO: source location is printed as loc<"file.f90":line:col>
+          mlir::Location loc = genLocation(parserDirective->source);
+          llvm::errs() << loc << ": warning: ignoring directive '" << name
+                       << "' because it has no associated subprogram\n";
+          continue;
+        }
+        llvm::StringRef attrName =
+            mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
+        mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+        mlirFunc->setAttr(attrName, unitAttr);
+      }
+    }
+  }
+
   //===--------------------------------------------------------------------===//
 
   Fortran::lower::LoweringBridge &bridge;
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e1cd34ddf65b6b..40fb641e085a78 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,7 +8389,31 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
       }
     }
   } else {
-    Say(x.source, "Compiler directive was ignored"_warn_en_US);
+    bool handled = false;
+    if (const auto *nvList{
+            std::get_if<std::list<parser::CompilerDirective::NameValue>>(
+                &x.u)}) {
+      for (const parser::CompilerDirective::NameValue &nv : *nvList) {
+        std::string name = std::get<parser::Name>(nv.t).ToString();
+        const std::initializer_list<const char *> handledAttrs{
+            "arm_streaming",
+            "arm_locally_streaming",
+            "arm_streaming_compatible",
+        };
+        if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
+            handledAttrs.end()) {
+          // exit early so that subsequent recognised attributes can't change
+          // the result
+          handled = false;
+          break;
+        }
+        // this attribute was handled
+        handled = true;
+      }
+    }
+    if (!handled) {
+      Say(x.source, "Compiler directive was ignored"_warn_en_US);
+    }
   }
 }
 
diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90
new file mode 100644
index 00000000000000..86fbe89920b274
--- /dev/null
+++ b/flang/test/Lower/arm-ssve-directives.f90
@@ -0,0 +1,46 @@
+! RUN: bbc -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+! check we don't warn about these attributes
+! CHECK-NOT: warning: Compiler directive was ignored
+
+! check we create the right function attributes
+
+!dir$ arm_streaming
+subroutine sub
+end subroutine sub
+! CHECK-LABEL: func.func @_QPsub()
+! CHECK-SAME:      attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine sub2
+end subroutine sub2
+! CHECK-LABEL: func.func @_QPsub2()
+! CHECK-SAME:      attributes {arm_locally_streaming}
+
+!dir$ arm_streaming_compatible
+subroutine sub3
+end subroutine sub3
+! CHECK-LABEL: func.func @_QPsub3()
+! CHECK-SAME:      attributes {arm_streaming_compatible}
+
+module m
+contains
+
+!dir$ arm_streaming
+subroutine msub
+end subroutine msub
+! CHECK-LABEL: func.func @_QMmPmsub()
+! CHECK-SAME:      attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine msub2
+end subroutine msub2
+! CHECK-LABEL: func.func @_QMmPmsub2()
+! CHECK-SAME:      attributes {arm_locally_streaming}
+
+!dir$ arm_streaming_compatible
+subroutine msub3
+end subroutine msub3
+! CHECK-LABEL: func.func @_QMmPmsub3()
+! CHECK-SAME:      attributes {arm_streaming_compatible}
+end module

>From 69384d37424e3008bf46ce283bda8561374af44b Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 13 Dec 2023 17:23:41 +0000
Subject: [PATCH 4/7] fixup! [Flang] WIP: Allow compiler directives for module
 procedures

Use the call to visit instead of a separate std::holds_alternative
---
 flang/lib/Semantics/program-tree.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index fcb6392620b67f..50487ea58dff72 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -111,11 +111,12 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
   if (subps) {
     for (const auto &subp :
         std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
-      if (std::holds_alternative<
-              common::Indirection<parser::CompilerDirective>>(subp.u))
-        continue;
       common::visit(
-          [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
+          common::visitors{
+              [&](const common::Indirection<parser::CompilerDirective> &) {},
+              [&](const auto &y) {
+                node.AddChild(ProgramTree::Build(y.value()));
+              }},
           subp.u);
     }
   }

>From 7e9b552a3bae17d1c5eb19fc20a57e68c0c1480e Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 13 Dec 2023 17:26:10 +0000
Subject: [PATCH 5/7] fixup! WIP: [flang][Lower] add attributes for arm
 streaming sve directives

Use braced initialization
---
 flang/lib/Semantics/resolve-names.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 40fb641e085a78..e619e3e0961ec2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,12 +8389,12 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
       }
     }
   } else {
-    bool handled = false;
+    bool handled{false};
     if (const auto *nvList{
             std::get_if<std::list<parser::CompilerDirective::NameValue>>(
                 &x.u)}) {
       for (const parser::CompilerDirective::NameValue &nv : *nvList) {
-        std::string name = std::get<parser::Name>(nv.t).ToString();
+        std::string name{std::get<parser::Name>(nv.t).ToString()};
         const std::initializer_list<const char *> handledAttrs{
             "arm_streaming",
             "arm_locally_streaming",

>From 7eb4a3d0a1cb2a33d65375deac4f3177cb85d7ee Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 28 Dec 2023 14:07:49 +0000
Subject: [PATCH 6/7] [flang][Lower] Add the ZA mode directives to support arm
 streaming sve

This adds the arm_new_za, arm_shared_za and arm_preserves_za
directives.

Also adds two new enum values in the MLIR definitions for ArmZaMode.
---
 flang/lib/Lower/Bridge.cpp                    | 20 ++++++++++
 flang/lib/Semantics/resolve-names.cpp         |  3 ++
 flang/test/Lower/arm-ssve-directives.f90      | 37 +++++++++++++++++++
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  6 ++-
 4 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5e0b6a801a98ec..043c5104191813 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -5073,6 +5073,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
         mlirFunc->setAttr(attrName, unitAttr);
       }
+      auto zaMode = mlir::arm_sme::ArmZaMode::Disabled;
+      if (name == "arm_new_za")
+        zaMode = mlir::arm_sme::ArmZaMode::NewZA;
+      else if (name == "arm_shared_za")
+        zaMode = mlir::arm_sme::ArmZaMode::SharedZA;
+      else if (name == "arm_preserves_za")
+        zaMode = mlir::arm_sme::ArmZaMode::PreservesZA;
+      if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) {
+        if (!mlirFunc) {
+          // TODO: share diagnostic code with warnings elsewhere
+          // TODO: source location is printed as loc<"file.f90":line:col>
+          mlir::Location loc = genLocation(parserDirective->source);
+          llvm::errs() << loc << ": warning: ignoring directive '" << name
+                       << "' because it has no associated subprogram\n";
+          continue;
+        }
+        llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode);
+        mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+        mlirFunc->setAttr(attrName, unitAttr);
+      }
     }
   }
 
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e619e3e0961ec2..ef8accd4636c3a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8399,6 +8399,9 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
             "arm_streaming",
             "arm_locally_streaming",
             "arm_streaming_compatible",
+            "arm_shared_za",
+            "arm_new_za",
+            "arm_preserves_za",
         };
         if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
             handledAttrs.end()) {
diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90
index 86fbe89920b274..dd4644f336b6b3 100644
--- a/flang/test/Lower/arm-ssve-directives.f90
+++ b/flang/test/Lower/arm-ssve-directives.f90
@@ -23,6 +23,24 @@ end subroutine sub3
 ! CHECK-LABEL: func.func @_QPsub3()
 ! CHECK-SAME:      attributes {arm_streaming_compatible}
 
+!dir$ arm_new_za
+subroutine sub4
+end subroutine sub4
+! CHECK-LABEL: func.func @_QPsub4()
+! CHECK-SAME:      attributes {arm_new_za}
+
+!dir$ arm_shared_za
+subroutine sub5
+end subroutine sub5
+! CHECK-LABEL: func.func @_QPsub5()
+! CHECK-SAME:      attributes {arm_shared_za}
+
+!dir$ arm_preserves_za
+subroutine sub6
+end subroutine sub6
+! CHECK-LABEL: func.func @_QPsub6()
+! CHECK-SAME:      attributes {arm_preserves_za}
+
 module m
 contains
 
@@ -43,4 +61,23 @@ subroutine msub3
 end subroutine msub3
 ! CHECK-LABEL: func.func @_QMmPmsub3()
 ! CHECK-SAME:      attributes {arm_streaming_compatible}
+
+!dir$ arm_new_za
+subroutine msub4
+end subroutine msub4
+! CHECK-LABEL: func.func @_QMmPmsub4()
+! CHECK-SAME:      attributes {arm_new_za}
+
+!dir$ arm_shared_za
+subroutine msub5
+end subroutine msub5
+! CHECK-LABEL: func.func @_QMmPmsub5()
+! CHECK-SAME:      attributes {arm_shared_za}
+
+!dir$ arm_preserves_za
+subroutine msub6
+end subroutine msub6
+! CHECK-LABEL: func.func @_QMmPmsub6()
+! CHECK-SAME:      attributes {arm_preserves_za}
+
 end module
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 4266ac5b0c8cf6..57f9ac007bae94 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -28,13 +28,15 @@ def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode
   let genSpecializedAttr = 0;
 }
 
-// TODO: Add other ZA modes.
-// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
 def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
     [
       I32EnumAttrCase<"Disabled", 0, "disabled">,
       // A function's ZA state is created on entry and destroyed on exit.
       I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
+      // A function that preserves ZA state.
+      I32EnumAttrCase<"PreservesZA", 2, "arm_preserves_za">,
+      // A function that uses ZA state as input and/or output
+      I32EnumAttrCase<"SharedZA", 3, "arm_shared_za">,
     ]>{
   let cppNamespace = "mlir::arm_sme";
   let genSpecializedAttr = 0;

>From 6ab4020dbae116f29f95a308251e6cc56a97de2e Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Fri, 29 Dec 2023 14:16:27 +0000
Subject: [PATCH 7/7] Add docs for directives

---
 flang/docs/Directives.md | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/flang/docs/Directives.md b/flang/docs/Directives.md
index c8a2c087dfad14..9bd52f11a4fb20 100644
--- a/flang/docs/Directives.md
+++ b/flang/docs/Directives.md
@@ -29,3 +29,28 @@ A list of non-standard directives supported by Flang
     end
   end interface
 ```
+
+## ARM Streaming SVE directives
+
+These directives are added to support Arm-specific instructions. Each of
+these directives applies to the subroutine or function that follows it. They
+correspond to the attributes provided in C and C++ for the same purpose.
+See https://arm-software.github.io/acle/main/acle.html#controlling-the-use-of-streaming-mode for more in-depth details. (In the following, "function" is used
+to mean both subroutine and function.)
+
+### Directives relating to ARM Streaming mode
+
+* `!dir$ arm_streaming` - The function is intended to be used in streaming
+  mode.
+* `!dir$ arm_streaming_compatible` - The function can work both in streaming
+  mode and non-streaming mode.
+* `!dir$ arm_locally_streaming` - The function will enter streaming mode on
+  entry, and return to non-streaming mode when returning.
+
+### Directives relating to ZA
+
+* `!dir$ arm_shared_za` - A function that uses ZA for input or output.
+* `!dir$ arm_new_za` - A function that has ZA state created and destroyed within
+  the function.
+* `!dir$ arm_preserves_za` - Optimisation hint for the compiler that the
+  function either doesn't alter, or saves and restores the ZA state.



More information about the flang-commits mailing list