[flang-commits] [flang] RFC: WIP: add support for compiler directives which apply to functions (PR #75352)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed Dec 13 08:10:38 PST 2023


https://github.com/tblah created https://github.com/llvm/llvm-project/pull/75352

RFC: not for merge

Some directives (e.g. `noinline`) apply to the function-like statement that immediately follows them. This PR provides partial support for these kinds of directives, using some new directives for Arm streaming SVE (modeled after the Clang attributes with the same names).

The directives are currently supported for function-like statements at the top level or inside of modules.

I would like comments on the rough shape of the implementation (especially the changes to `PFTBuilder.cpp` and the decision to match directives to functions in `Bridge.cpp`).

## To Do
- [ ]  We currently cannot parse compiler directives inside of interface blocks
- [ ]  No support for nested function-like statements
- [ ]  No support for `.mod` files
- [ ]  The warning generated when a directive is not followed by a function-like statement does not use the proper diagnostics infrastructure, and it prints the source location wrapped in `loc<>` (e.g. `loc<"file.f90":line:col>`)

From 1522c61697d307ec0271a3def5f012174f02292f Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Tue, 12 Dec 2023 15:47:22 +0000
Subject: [PATCH 1/3] [Flang] WIP: Allow compiler directives for module
 procedures

---
 flang/include/flang/Parser/parse-tree.h   | 3 ++-
 flang/lib/Parser/program-parsers.cpp      | 3 ++-
 flang/lib/Semantics/program-tree.cpp      | 3 +++
 flang/test/Parser/compiler-directives.f90 | 8 ++++++++
 4 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 393e0e24ec5cbd..880f1e249d34a6 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
   UNION_CLASS_BOILERPLATE(ModuleSubprogram);
   std::variant<common::Indirection<FunctionSubprogram>,
       common::Indirection<SubroutineSubprogram>,
-      common::Indirection<SeparateModuleSubprogram>>
+      common::Indirection<SeparateModuleSubprogram>,
+      common::Indirection<CompilerDirective>>
       u;
 };
 
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index e24559bf14f7c9..ff5e58ebc721c5 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US,
 //         separate-module-subprogram
 TYPE_PARSER(construct<ModuleSubprogram>(indirect(functionSubprogram)) ||
     construct<ModuleSubprogram>(indirect(subroutineSubprogram)) ||
-    construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})))
+    construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})) ||
+    construct<ModuleSubprogram>(indirect(compilerDirective)))
 
 // R1410 module-nature -> INTRINSIC | NON_INTRINSIC
 constexpr auto moduleNature{
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index bf773f3810c847..fcb6392620b67f 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -111,6 +111,9 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
   if (subps) {
     for (const auto &subp :
         std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
+      if (std::holds_alternative<
+              common::Indirection<parser::CompilerDirective>>(subp.u))
+        continue;
       common::visit(
           [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
           subp.u);
diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90
index 88cfd0944faf0a..526f379326909e 100644
--- a/flang/test/Parser/compiler-directives.f90
+++ b/flang/test/Parser/compiler-directives.f90
@@ -22,4 +22,12 @@ module m
   !dir$ optimize : 1
   !dir$ loop count (10000)
   !dir$ loop count (1, 500, 5000, 10000)
+contains
+  !dir$ noinline
+  subroutine sb1()
+  end subroutine
+
+  !dir$ noinline
+  subroutine sb2()
+  end subroutine
 end

From 8ce3b251d582a96fd2a668a7c29980faf5101ca0 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 13 Dec 2023 11:39:39 +0000
Subject: [PATCH 2/3] WIP: [flang][Lower] support compiler directives inside
 modules in PFTBuilder

Adding an ordered list of all units inside a module makes it easier to
match directives with the functions that immediately follow them.

The change to nested units inside of functions was incidental to make
this compile, but I expect it should be similarly useful for matching
attributes with the unit that follows them inside of a function body.
---
 flang/include/flang/Lower/PFTBuilder.h | 10 +++-
 flang/lib/Lower/Bridge.cpp             | 33 +++++++++-----
 flang/lib/Lower/PFTBuilder.cpp         | 63 +++++++++++++++++---------
 3 files changed, 72 insertions(+), 34 deletions(-)

diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9c6696ff79dae1..16076c11483f61 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
 
 void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
 
+/// Things that can be nested inside of a module or function
+/// TODO: add the rest
+struct FunctionLikeUnit;
+struct CompilerDirectiveUnit;
+using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;
+
 /// Function-like units may contain evaluations (executable statements) and
 /// nested function-like units (internal procedures and function statements).
 struct FunctionLikeUnit : public ProgramUnit {
@@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
   EvaluationList evaluationList;
   LabelEvalMap labelEvaluationMap;
   SymbolLabelMap assignSymbolLabelMap;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  std::list<NestedUnit> nestedUnits;
   /// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
   /// is the primary entry point; remaining pairs are alternate entry points.
   /// The primary entry point symbol is Null for an anonymous program.
@@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {
 
   ModuleStatement beginStmt;
   ModuleStatement endStmt;
-  std::list<FunctionLikeUnit> nestedFunctions;
+  std::list<NestedUnit> nestedUnits;
   EvaluationList evaluationList;
 };
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7e64adc3c144c9..a08506bf6ebfa0 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -303,9 +303,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                      },
                      [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                        lowerModuleDeclScope(m);
-                       for (Fortran::lower::pft::FunctionLikeUnit &f :
-                            m.nestedFunctions)
-                         declareFunction(f);
+                       for (Fortran::lower::pft::NestedUnit &unit :
+                            m.nestedUnits) {
+                         if (auto *f = std::get_if<
+                                 Fortran::lower::pft::FunctionLikeUnit>(&unit))
+                           declareFunction(*f);
+                       }
                      },
                      [&](Fortran::lower::pft::BlockDataUnit &b) {
                        if (!globalOmpRequiresSymbol)
@@ -387,13 +390,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     // Compute the set of host associated entities from the nested functions.
     llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      collectHostAssociatedVariables(f, escapeHost);
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        collectHostAssociatedVariables(*f, escapeHost);
+    }
     funit.setHostAssociatedSymbols(escapeHost);
 
     // Declare internal procedures
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      declareFunction(f);
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        declareFunction(*f);
+    }
   }
 
   /// Get the scope that is defining or using \p sym. The returned scope is not
@@ -4667,8 +4674,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       endNewFunction(funit);
     }
     funit.setActiveEntry(0);
-    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
-      lowerFunc(f); // internal procedure
+    for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+        lowerFunc(*f); // internal procedure
+    }
   }
 
   /// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -4692,8 +4701,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
-      lowerFunc(f);
+    for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) {
+      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+        lowerFunc(*f);
+    }
   }
 
   void setCurrentPosition(const Fortran::parser::CharBlock &position) {
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 32ed539c775b82..0a08e1cf7ff47e 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -259,6 +259,12 @@ class PFTBuilder {
           lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
       return false;
     }
+    if (auto *mod = pftParentStack.back().getIf<lower::pft::ModuleLikeUnit>()) {
+      assert(nestedUnitList && "Modules have a nested units list");
+      lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()};
+      addNestedUnit(std::move(unit));
+      return false;
+    }
     return enterConstructOrDirective(directive);
   }
 
@@ -279,7 +285,7 @@ class PFTBuilder {
   bool enterModule(const A &mod) {
     Fortran::lower::pft::ModuleLikeUnit &unit =
         addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
-    functionList = &unit.nestedFunctions;
+    nestedUnitList = &unit.nestedUnits;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -349,7 +355,7 @@ class PFTBuilder {
                                                  semanticsContext});
     labelEvaluationMap = &unit.labelEvaluationMap;
     assignSymbolLabelMap = &unit.assignSymbolLabelMap;
-    functionList = &unit.nestedFunctions;
+    nestedUnitList = &unit.nestedUnits;
     pushEvaluationList(&unit.evaluationList);
     pftParentStack.emplace_back(unit);
     LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -414,14 +420,14 @@ class PFTBuilder {
     if (!pftParentStack.empty()) {
       pftParentStack.back().visit(common::visitors{
           [&](lower::pft::FunctionLikeUnit &p) {
-            functionList = &p.nestedFunctions;
+            nestedUnitList = &p.nestedUnits;
             labelEvaluationMap = &p.labelEvaluationMap;
             assignSymbolLabelMap = &p.assignSymbolLabelMap;
           },
           [&](lower::pft::ModuleLikeUnit &p) {
-            functionList = &p.nestedFunctions;
+            nestedUnitList = &p.nestedUnits;
           },
-          [&](auto &) { functionList = nullptr; },
+          [&](auto &) { nestedUnitList = nullptr; },
       });
     }
   }
@@ -432,11 +438,16 @@ class PFTBuilder {
     return std::get<A>(pgm->getUnits().back());
   }
 
+  template <typename A>
+  void addNestedUnit(A &&source) {
+    nestedUnitList->emplace_back(lower::pft::NestedUnit{std::move(source)});
+  }
+
   template <typename A>
   A &addFunction(A &&func) {
-    if (functionList) {
-      functionList->emplace_back(std::move(func));
-      return functionList->back();
+    if (nestedUnitList) {
+      addNestedUnit(func);
+      return std::get<A>(nestedUnitList->back());
     }
     return addUnit(std::move(func));
   }
@@ -459,7 +470,7 @@ class PFTBuilder {
 
   /// Append an Evaluation to the end of the current list.
   lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     assert(!evaluationListStack.empty() && "empty evaluation list stack");
     if (!constructAndDirectiveStack.empty())
       eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,7 +510,7 @@ class PFTBuilder {
 
   /// push a new list on the stack of Evaluation lists
   void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     assert(evaluationList && evaluationList->empty() &&
            "evaluation list isn't correct");
     evaluationListStack.emplace_back(evaluationList);
@@ -507,7 +518,7 @@ class PFTBuilder {
 
   /// pop the current list and return to the last Evaluation list
   void popEvaluationList() {
-    assert(functionList && "not in a function");
+    assert(nestedUnitList && "not in a function");
     evaluationListStack.pop_back();
   }
 
@@ -1088,9 +1099,9 @@ class PFTBuilder {
   std::vector<lower::pft::PftNode> pftParentStack;
   const semantics::SemanticsContext &semanticsContext;
 
-  /// functionList points to the internal or module procedure function list
-  /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
-  std::list<lower::pft::FunctionLikeUnit> *functionList{};
+  /// nestedUnitList points to the list of units (e.g. internal or module
+  /// procedures) nested in a FunctionLikeUnit or ModuleLikeUnit. May be null.
+  std::list<lower::pft::NestedUnit> *nestedUnitList{};
   std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
   std::vector<lower::pft::Evaluation *> doConstructStack{};
   /// evaluationListStack is the current nested construct evaluationList state.
@@ -1264,11 +1275,17 @@ class PFTDumper {
       outputStream << ": " << header;
     outputStream << '\n';
     dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
-    if (!functionLikeUnit.nestedFunctions.empty()) {
+    if (!functionLikeUnit.nestedUnits.empty()) {
       outputStream << "\nContains\n";
-      for (const lower::pft::FunctionLikeUnit &func :
-           functionLikeUnit.nestedFunctions)
-        dumpFunctionLikeUnit(outputStream, func);
+      for (const lower::pft::NestedUnit &nested :
+           functionLikeUnit.nestedUnits) {
+        if (const auto *func =
+                std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+          dumpFunctionLikeUnit(outputStream, *func);
+        if (const auto *directive =
+                std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+          dumpCompilerDirectiveUnit(outputStream, *directive);
+      }
       outputStream << "End Contains\n";
     }
     outputStream << "End " << unitKind << ' ' << name << "\n\n";
@@ -1298,9 +1315,13 @@ class PFTDumper {
     outputStream << unitKind << ' ' << name << ": " << header << '\n';
     dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
     outputStream << "Contains\n";
-    for (const lower::pft::FunctionLikeUnit &func :
-         moduleLikeUnit.nestedFunctions)
-      dumpFunctionLikeUnit(outputStream, func);
+    for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) {
+      if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+        dumpFunctionLikeUnit(outputStream, *func);
+      if (const auto *directive =
+              std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+        dumpCompilerDirectiveUnit(outputStream, *directive);
+    }
     outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
   }
 

From 0453b9c7b742fded974e7f3cf38c1d4c61d50904 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 11 Dec 2023 16:29:51 +0000
Subject: [PATCH 3/3] WIP: [flang][Lower] add attributes for arm streaming sve
 directives

---
 flang/lib/Lower/Bridge.cpp               | 76 ++++++++++++++++++++++--
 flang/lib/Semantics/resolve-names.cpp    | 26 +++++++-
 flang/test/Lower/arm-ssve-directives.f90 | 46 ++++++++++++++
 3 files changed, 141 insertions(+), 7 deletions(-)
 create mode 100644 flang/test/Lower/arm-ssve-directives.f90

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a08506bf6ebfa0..5e0b6a801a98ec 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -53,6 +53,7 @@
 #include "flang/Semantics/runtime-type-info.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
@@ -325,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         [&]() { createIntrinsicModuleDefinitions(pft); });
 
     // Primary translation pass.
-    for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
+    std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
+    for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
       std::visit(
           Fortran::common::visitors{
               [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
               [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
               [&](Fortran::lower::pft::BlockDataUnit &b) {},
-              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+                processSubprogramDirective(it, units.end(), d);
+              },
               [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
                 builder = new fir::FirOpBuilder(bridge.getModule(),
                                                 bridge.getKindMap());
@@ -341,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                 builder = nullptr;
               },
           },
-          u);
+          *it);
     }
 
     // Once all the code has been translated, create global runtime type info
@@ -4701,9 +4705,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Lower functions contained in a module.
   void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
-    for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) {
-      if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
-        lowerFunc(*f);
+    for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
+         it = std::next(it)) {
+      std::visit(
+          Fortran::common::visitors{
+              [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
+              [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+                processSubprogramDirective(it, mod.nestedUnits.end(), d);
+              }},
+          *it);
     }
   }
 
@@ -5012,6 +5022,60 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                         globalOmpRequiresSymbol);
   }
 
+  /// Process compiler directives that apply to subprograms
+  template <typename ITERATOR>
+  void
+  processSubprogramDirective(ITERATOR it, ITERATOR endIt,
+                             Fortran::lower::pft::CompilerDirectiveUnit &d) {
+    auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
+    if (!parserDirective)
+      return;
+    auto *nvList =
+        std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
+            &parserDirective->u);
+    if (!nvList)
+      return;
+
+    // get the function the directive applies to (hopefully the next unit)
+    mlir::func::FuncOp mlirFunc;
+    it = std::next(it);
+    if (it != endIt) {
+      auto *pftFunction =
+          std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
+      if (pftFunction) {
+        Fortran::lower::CalleeInterface callee{*pftFunction, *this};
+        mlirFunc = callee.getFuncOp();
+      }
+    }
+
+    for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
+      std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();
+
+      // arm streaming sve directives
+      auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
+      if (name == "arm_streaming")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
+      else if (name == "arm_locally_streaming")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
+      else if (name == "arm_streaming_compatible")
+        streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
+      if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
+        if (!mlirFunc) {
+          // TODO: share diagnostic code with warnings elsewhere
+          // TODO: source location is printed as loc<"file.f90":line:col>
+          mlir::Location loc = genLocation(parserDirective->source);
+          llvm::errs() << loc << ": warning: ignoring directive '" << name
+                       << "' because it has no associated subprogram\n";
+          continue;
+        }
+        llvm::StringRef attrName =
+            mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
+        mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+        mlirFunc->setAttr(attrName, unitAttr);
+      }
+    }
+  }
+
   //===--------------------------------------------------------------------===//
 
   Fortran::lower::LoweringBridge &bridge;
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e1cd34ddf65b6b..40fb641e085a78 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,7 +8389,31 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
       }
     }
   } else {
-    Say(x.source, "Compiler directive was ignored"_warn_en_US);
+    bool handled = false;
+    if (const auto *nvList{
+            std::get_if<std::list<parser::CompilerDirective::NameValue>>(
+                &x.u)}) {
+      for (const parser::CompilerDirective::NameValue &nv : *nvList) {
+        std::string name = std::get<parser::Name>(nv.t).ToString();
+        const std::initializer_list<const char *> handledAttrs{
+            "arm_streaming",
+            "arm_locally_streaming",
+            "arm_streaming_compatible",
+        };
+        if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
+            handledAttrs.end()) {
+          // exit early so that subsequent recognised attributes can't change
+          // the result
+          handled = false;
+          break;
+        }
+        // this attribute was handled
+        handled = true;
+      }
+    }
+    if (!handled) {
+      Say(x.source, "Compiler directive was ignored"_warn_en_US);
+    }
   }
 }
 
diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90
new file mode 100644
index 00000000000000..86fbe89920b274
--- /dev/null
+++ b/flang/test/Lower/arm-ssve-directives.f90
@@ -0,0 +1,46 @@
+! RUN: bbc -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+! check we don't warn about these attributes
+! CHECK-NOT: warning: Compiler directive was ignored
+
+! check we create the right function attributes
+
+!dir$ arm_streaming
+subroutine sub
+end subroutine sub
+! CHECK-LABEL: func.func @_QPsub()
+! CHECK-SAME:      attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine sub2
+end subroutine sub2
+! CHECK-LABEL: func.func @_QPsub2()
+! CHECK-SAME:      attributes {arm_locally_streaming}
+
+!dir$ arm_streaming_compatible
+subroutine sub3
+end subroutine sub3
+! CHECK-LABEL: func.func @_QPsub3()
+! CHECK-SAME:      attributes {arm_streaming_compatible}
+
+module m
+contains
+
+!dir$ arm_streaming
+subroutine msub
+end subroutine msub
+! CHECK-LABEL: func.func @_QMmPmsub()
+! CHECK-SAME:      attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine msub2
+end subroutine msub2
+! CHECK-LABEL: func.func @_QMmPmsub2()
+! CHECK-SAME:      attributes {arm_locally_streaming}
+
+!dir$ arm_streaming_compatible
+subroutine msub3
+end subroutine msub3
+! CHECK-LABEL: func.func @_QMmPmsub3()
+! CHECK-SAME:      attributes {arm_streaming_compatible}
+end module



More information about the flang-commits mailing list