[flang-commits] [flang] [mlir] Add ZA directives for Flang. (PR #76505)
Mats Petersson via flang-commits
flang-commits at lists.llvm.org
Thu Dec 28 06:28:18 PST 2023
https://github.com/Leporacanthicus created https://github.com/llvm/llvm-project/pull/76505
Add directives to support functions and subroutines that use ZA state.
>From 1522c61697d307ec0271a3def5f012174f02292f Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Tue, 12 Dec 2023 15:47:22 +0000
Subject: [PATCH 1/6] [Flang] WIP: Allow compiler directives for module
procedures
---
flang/include/flang/Parser/parse-tree.h | 3 ++-
flang/lib/Parser/program-parsers.cpp | 3 ++-
flang/lib/Semantics/program-tree.cpp | 3 +++
flang/test/Parser/compiler-directives.f90 | 8 ++++++++
4 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 393e0e24ec5cbd..880f1e249d34a6 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
UNION_CLASS_BOILERPLATE(ModuleSubprogram);
std::variant<common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>,
- common::Indirection<SeparateModuleSubprogram>>
+ common::Indirection<SeparateModuleSubprogram>,
+ common::Indirection<CompilerDirective>>
u;
};
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index e24559bf14f7c9..ff5e58ebc721c5 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US,
// separate-module-subprogram
TYPE_PARSER(construct<ModuleSubprogram>(indirect(functionSubprogram)) ||
construct<ModuleSubprogram>(indirect(subroutineSubprogram)) ||
- construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})))
+ construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})) ||
+ construct<ModuleSubprogram>(indirect(compilerDirective)))
// R1410 module-nature -> INTRINSIC | NON_INTRINSIC
constexpr auto moduleNature{
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index bf773f3810c847..fcb6392620b67f 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -111,6 +111,9 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
if (subps) {
for (const auto &subp :
std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
+ if (std::holds_alternative<
+ common::Indirection<parser::CompilerDirective>>(subp.u))
+ continue;
common::visit(
[&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
subp.u);
diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90
index 88cfd0944faf0a..526f379326909e 100644
--- a/flang/test/Parser/compiler-directives.f90
+++ b/flang/test/Parser/compiler-directives.f90
@@ -22,4 +22,12 @@ module m
!dir$ optimize : 1
!dir$ loop count (10000)
!dir$ loop count (1, 500, 5000, 10000)
+contains
+ !dir$ noinline
+ subroutine sb1()
+ end subroutine
+
+ !dir$ noinline
+ subroutine sb2()
+ end subroutine
end
>From 8ce3b251d582a96fd2a668a7c29980faf5101ca0 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 13 Dec 2023 11:39:39 +0000
Subject: [PATCH 2/6] WIP: [flang][Lower] support compiler directives inside
modules in PFTBuilder
Adding a list of all units inside a module in order makes it easier to
match directives with functions that immediately follow them.
The change to nested units inside of functions was incidental to make
this compile, but I expect it should be similarly useful for matching
attributes with the unit that follows them inside of a function body.
---
flang/include/flang/Lower/PFTBuilder.h | 10 +++-
flang/lib/Lower/Bridge.cpp | 33 +++++++++-----
flang/lib/Lower/PFTBuilder.cpp | 63 +++++++++++++++++---------
3 files changed, 72 insertions(+), 34 deletions(-)
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9c6696ff79dae1..16076c11483f61 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
+/// Things that can be nested inside of a module or function
+/// TODO: add the rest
+struct FunctionLikeUnit;
+struct CompilerDirectiveUnit;
+using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;
+
/// Function-like units may contain evaluations (executable statements) and
/// nested function-like units (internal procedures and function statements).
struct FunctionLikeUnit : public ProgramUnit {
@@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
- std::list<FunctionLikeUnit> nestedFunctions;
+ std::list<NestedUnit> nestedUnits;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
@@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {
ModuleStatement beginStmt;
ModuleStatement endStmt;
- std::list<FunctionLikeUnit> nestedFunctions;
+ std::list<NestedUnit> nestedUnits;
EvaluationList evaluationList;
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7e64adc3c144c9..a08506bf6ebfa0 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -303,9 +303,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
- for (Fortran::lower::pft::FunctionLikeUnit &f :
- m.nestedFunctions)
- declareFunction(f);
+ for (Fortran::lower::pft::NestedUnit &unit :
+ m.nestedUnits) {
+ if (auto *f = std::get_if<
+ Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ declareFunction(*f);
+ }
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
@@ -387,13 +390,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- collectHostAssociatedVariables(f, escapeHost);
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ collectHostAssociatedVariables(*f, escapeHost);
+ }
funit.setHostAssociatedSymbols(escapeHost);
// Declare internal procedures
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- declareFunction(f);
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ declareFunction(*f);
+ }
}
/// Get the scope that is defining or using \p sym. The returned scope is not
@@ -4667,8 +4674,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
endNewFunction(funit);
}
funit.setActiveEntry(0);
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- lowerFunc(f); // internal procedure
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ lowerFunc(*f); // internal procedure
+ }
}
/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -4692,8 +4701,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
- for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
- lowerFunc(f);
+ for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ lowerFunc(*f);
+ }
}
void setCurrentPosition(const Fortran::parser::CharBlock &position) {
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 32ed539c775b82..0a08e1cf7ff47e 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -259,6 +259,12 @@ class PFTBuilder {
lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
return false;
}
+ if (auto *mod = pftParentStack.back().getIf<lower::pft::ModuleLikeUnit>()) {
+ assert(nestedUnitList && "Modules have a nested units list");
+ lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()};
+ addNestedUnit(std::move(unit));
+ return false;
+ }
return enterConstructOrDirective(directive);
}
@@ -279,7 +285,7 @@ class PFTBuilder {
bool enterModule(const A &mod) {
Fortran::lower::pft::ModuleLikeUnit &unit =
addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
- functionList = &unit.nestedFunctions;
+ nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -349,7 +355,7 @@ class PFTBuilder {
semanticsContext});
labelEvaluationMap = &unit.labelEvaluationMap;
assignSymbolLabelMap = &unit.assignSymbolLabelMap;
- functionList = &unit.nestedFunctions;
+ nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -414,14 +420,14 @@ class PFTBuilder {
if (!pftParentStack.empty()) {
pftParentStack.back().visit(common::visitors{
[&](lower::pft::FunctionLikeUnit &p) {
- functionList = &p.nestedFunctions;
+ nestedUnitList = &p.nestedUnits;
labelEvaluationMap = &p.labelEvaluationMap;
assignSymbolLabelMap = &p.assignSymbolLabelMap;
},
[&](lower::pft::ModuleLikeUnit &p) {
- functionList = &p.nestedFunctions;
+ nestedUnitList = &p.nestedUnits;
},
- [&](auto &) { functionList = nullptr; },
+ [&](auto &) { nestedUnitList = nullptr; },
});
}
}
@@ -432,11 +438,16 @@ class PFTBuilder {
return std::get<A>(pgm->getUnits().back());
}
+ template <typename A>
+ void addNestedUnit(A &&source) {
+ nestedUnitList->emplace_back(lower::pft::NestedUnit{std::move(source)});
+ }
+
template <typename A>
A &addFunction(A &&func) {
- if (functionList) {
- functionList->emplace_back(std::move(func));
- return functionList->back();
+ if (nestedUnitList) {
+ addNestedUnit(func);
+ return std::get<A>(nestedUnitList->back());
}
return addUnit(std::move(func));
}
@@ -459,7 +470,7 @@ class PFTBuilder {
/// Append an Evaluation to the end of the current list.
lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
assert(!evaluationListStack.empty() && "empty evaluation list stack");
if (!constructAndDirectiveStack.empty())
eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,7 +510,7 @@ class PFTBuilder {
/// push a new list on the stack of Evaluation lists
void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
assert(evaluationList && evaluationList->empty() &&
"evaluation list isn't correct");
evaluationListStack.emplace_back(evaluationList);
@@ -507,7 +518,7 @@ class PFTBuilder {
/// pop the current list and return to the last Evaluation list
void popEvaluationList() {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
evaluationListStack.pop_back();
}
@@ -1088,9 +1099,9 @@ class PFTBuilder {
std::vector<lower::pft::PftNode> pftParentStack;
const semantics::SemanticsContext &semanticsContext;
- /// functionList points to the internal or module procedure function list
- /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
- std::list<lower::pft::FunctionLikeUnit> *functionList{};
+ /// nestedUnitList points to the internal or module procedure unit list
+ /// of nested units (e.g. functions). It may be null.
+ std::list<lower::pft::NestedUnit> *nestedUnitList{};
std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
std::vector<lower::pft::Evaluation *> doConstructStack{};
/// evaluationListStack is the current nested construct evaluationList state.
@@ -1264,11 +1275,17 @@ class PFTDumper {
outputStream << ": " << header;
outputStream << '\n';
dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
- if (!functionLikeUnit.nestedFunctions.empty()) {
+ if (!functionLikeUnit.nestedUnits.empty()) {
outputStream << "\nContains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- functionLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
+ for (const lower::pft::NestedUnit &nested :
+ functionLikeUnit.nestedUnits) {
+ if (const auto *func =
+ std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+ dumpFunctionLikeUnit(outputStream, *func);
+ if (const auto *directive =
+ std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+ dumpCompilerDirectiveUnit(outputStream, *directive);
+ }
outputStream << "End Contains\n";
}
outputStream << "End " << unitKind << ' ' << name << "\n\n";
@@ -1298,9 +1315,13 @@ class PFTDumper {
outputStream << unitKind << ' ' << name << ": " << header << '\n';
dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
outputStream << "Contains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- moduleLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
+ for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) {
+ if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+ dumpFunctionLikeUnit(outputStream, *func);
+ if (const auto *directive =
+ std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+ dumpCompilerDirectiveUnit(outputStream, *directive);
+ }
outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
}
>From 0453b9c7b742fded974e7f3cf38c1d4c61d50904 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 11 Dec 2023 16:29:51 +0000
Subject: [PATCH 3/6] WIP: [flang][Lower] add attributes for arm streaming sve
directives
---
flang/lib/Lower/Bridge.cpp | 76 ++++++++++++++++++++++--
flang/lib/Semantics/resolve-names.cpp | 26 +++++++-
flang/test/Lower/arm-ssve-directives.f90 | 46 ++++++++++++++
3 files changed, 141 insertions(+), 7 deletions(-)
create mode 100644 flang/test/Lower/arm-ssve-directives.f90
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a08506bf6ebfa0..5e0b6a801a98ec 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -53,6 +53,7 @@
#include "flang/Semantics/runtime-type-info.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
@@ -325,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&]() { createIntrinsicModuleDefinitions(pft); });
// Primary translation pass.
- for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
+ std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
+ for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
- [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+ [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ processSubprogramDirective(it, units.end(), d);
+ },
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
@@ -341,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder = nullptr;
},
},
- u);
+ *it);
}
// Once all the code has been translated, create global runtime type info
@@ -4701,9 +4705,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
- for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) {
- if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
- lowerFunc(*f);
+ for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
+ it = std::next(it)) {
+ std::visit(
+ Fortran::common::visitors{
+ [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
+ [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ processSubprogramDirective(it, mod.nestedUnits.end(), d);
+ }},
+ *it);
}
}
@@ -5012,6 +5022,60 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol);
}
+ /// Process compiler directives that apply to subprograms
+ template <typename ITERATOR>
+ void
+ processSubprogramDirective(ITERATOR it, ITERATOR endIt,
+ Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
+ if (!parserDirective)
+ return;
+ auto *nvList =
+ std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
+ &parserDirective->u);
+ if (!nvList)
+ return;
+
+ // get the function the directive applies to (hopefully the next unit)
+ mlir::func::FuncOp mlirFunc;
+ it = std::next(it);
+ if (it != endIt) {
+ auto *pftFunction =
+ std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
+ if (pftFunction) {
+ Fortran::lower::CalleeInterface callee{*pftFunction, *this};
+ mlirFunc = callee.getFuncOp();
+ }
+ }
+
+ for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
+ std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();
+
+ // arm streaming sve directives
+ auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
+ if (name == "arm_streaming")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
+ else if (name == "arm_locally_streaming")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
+ else if (name == "arm_streaming_compatible")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
+ if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
+ if (!mlirFunc) {
+ // TODO: share diagnostic code with warnings elsewhere
+ // TODO: source location is printed as loc<"file.f90":line:col>
+ mlir::Location loc = genLocation(parserDirective->source);
+ llvm::errs() << loc << ": warning: ignoring directive '" << name
+ << "' because it has no associated subprogram\n";
+ continue;
+ }
+ llvm::StringRef attrName =
+ mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
+ mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+ mlirFunc->setAttr(attrName, unitAttr);
+ }
+ }
+ }
+
//===--------------------------------------------------------------------===//
Fortran::lower::LoweringBridge &bridge;
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e1cd34ddf65b6b..40fb641e085a78 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,7 +8389,31 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
}
}
} else {
- Say(x.source, "Compiler directive was ignored"_warn_en_US);
+ bool handled = false;
+ if (const auto *nvList{
+ std::get_if<std::list<parser::CompilerDirective::NameValue>>(
+ &x.u)}) {
+ for (const parser::CompilerDirective::NameValue &nv : *nvList) {
+ std::string name = std::get<parser::Name>(nv.t).ToString();
+ const std::initializer_list<const char *> handledAttrs{
+ "arm_streaming",
+ "arm_locally_streaming",
+ "arm_streaming_compatible",
+ };
+ if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
+ handledAttrs.end()) {
+ // exit early so that subsequent recognised attributes can't change
+ // the result
+ handled = false;
+ break;
+ }
+ // this attribute was handled
+ handled = true;
+ }
+ }
+ if (!handled) {
+ Say(x.source, "Compiler directive was ignored"_warn_en_US);
+ }
}
}
diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90
new file mode 100644
index 00000000000000..86fbe89920b274
--- /dev/null
+++ b/flang/test/Lower/arm-ssve-directives.f90
@@ -0,0 +1,46 @@
+! RUN: bbc -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+! check we don't warn about these attributes
+! CHECK-NOT: warning: Compiler directive was ignored
+
+! check we create the right function attributes
+
+!dir$ arm_streaming
+subroutine sub
+end subroutine sub
+! CHECK-LABEL: func.func @_QPsub()
+! CHECK-SAME: attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine sub2
+end subroutine sub2
+! CHECK-LABEL: func.func @_QPsub2()
+! CHECK-SAME: attributes {arm_locally_streaming}
+
+!dir$ arm_streaming_compatible
+subroutine sub3
+end subroutine sub3
+! CHECK-LABEL: func.func @_QPsub3()
+! CHECK-SAME: attributes {arm_streaming_compatible}
+
+module m
+contains
+
+!dir$ arm_streaming
+subroutine msub
+end subroutine msub
+! CHECK-LABEL: func.func @_QMmPmsub()
+! CHECK-SAME: attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine msub2
+end subroutine msub2
+! CHECK-LABEL: func.func @_QMmPmsub2()
+! CHECK-SAME: attributes {arm_locally_streaming}
+
+!dir$ arm_streaming_compatible
+subroutine msub3
+end subroutine msub3
+! CHECK-LABEL: func.func @_QMmPmsub3()
+! CHECK-SAME: attributes {arm_streaming_compatible}
+end module
>From 69384d37424e3008bf46ce283bda8561374af44b Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 13 Dec 2023 17:23:41 +0000
Subject: [PATCH 4/6] fixup! [Flang] WIP: Allow compiler directives for module
procedures
Use the call to visit instead of a separate std::holds_alternative
---
flang/lib/Semantics/program-tree.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index fcb6392620b67f..50487ea58dff72 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -111,11 +111,12 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
if (subps) {
for (const auto &subp :
std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
- if (std::holds_alternative<
- common::Indirection<parser::CompilerDirective>>(subp.u))
- continue;
common::visit(
- [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
+ common::visitors{
+ [&](const common::Indirection<parser::CompilerDirective> &) {},
+ [&](const auto &y) {
+ node.AddChild(ProgramTree::Build(y.value()));
+ }},
subp.u);
}
}
>From 7e9b552a3bae17d1c5eb19fc20a57e68c0c1480e Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 13 Dec 2023 17:26:10 +0000
Subject: [PATCH 5/6] fixup! WIP: [flang][Lower] add attributes for arm
streaming sve directives
Use braced initialization
---
flang/lib/Semantics/resolve-names.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 40fb641e085a78..e619e3e0961ec2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,12 +8389,12 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
}
}
} else {
- bool handled = false;
+ bool handled{false};
if (const auto *nvList{
std::get_if<std::list<parser::CompilerDirective::NameValue>>(
&x.u)}) {
for (const parser::CompilerDirective::NameValue &nv : *nvList) {
- std::string name = std::get<parser::Name>(nv.t).ToString();
+ std::string name{std::get<parser::Name>(nv.t).ToString()};
const std::initializer_list<const char *> handledAttrs{
"arm_streaming",
"arm_locally_streaming",
>From 7eb4a3d0a1cb2a33d65375deac4f3177cb85d7ee Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 28 Dec 2023 14:07:49 +0000
Subject: [PATCH 6/6] [flang][Lower] Add the ZA mode directives to support arm
streaming sve
This adds the arm_new_za, arm_shared_za and arm_preserves_za
directives.
Also adds two new enum values in the MLIR definitions for ArmZaMode.
---
flang/lib/Lower/Bridge.cpp | 20 ++++++++++
flang/lib/Semantics/resolve-names.cpp | 3 ++
flang/test/Lower/arm-ssve-directives.f90 | 37 +++++++++++++++++++
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 6 ++-
4 files changed, 64 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5e0b6a801a98ec..043c5104191813 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -5073,6 +5073,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
mlirFunc->setAttr(attrName, unitAttr);
}
+ auto zaMode = mlir::arm_sme::ArmZaMode::Disabled;
+ if (name == "arm_new_za")
+ zaMode = mlir::arm_sme::ArmZaMode::NewZA;
+ else if (name == "arm_shared_za")
+ zaMode = mlir::arm_sme::ArmZaMode::SharedZA;
+ else if (name == "arm_preserves_za")
+ zaMode = mlir::arm_sme::ArmZaMode::PreservesZA;
+ if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) {
+ if (!mlirFunc) {
+ // TODO: share diagnostic code with warnings elsewhere
+ // TODO: source location is printed as loc<"file.f90":line:col>
+ mlir::Location loc = genLocation(parserDirective->source);
+ llvm::errs() << loc << ": warning: ignoring directive '" << name
+ << "' because it has no associated subprogram\n";
+ continue;
+ }
+ llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode);
+ mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+ mlirFunc->setAttr(attrName, unitAttr);
+ }
}
}
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e619e3e0961ec2..ef8accd4636c3a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8399,6 +8399,9 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
"arm_streaming",
"arm_locally_streaming",
"arm_streaming_compatible",
+ "arm_shared_za",
+ "arm_new_za",
+ "arm_preserves_za",
};
if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
handledAttrs.end()) {
diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90
index 86fbe89920b274..dd4644f336b6b3 100644
--- a/flang/test/Lower/arm-ssve-directives.f90
+++ b/flang/test/Lower/arm-ssve-directives.f90
@@ -23,6 +23,24 @@ end subroutine sub3
! CHECK-LABEL: func.func @_QPsub3()
! CHECK-SAME: attributes {arm_streaming_compatible}
+!dir$ arm_new_za
+subroutine sub4
+end subroutine sub4
+! CHECK-LABEL: func.func @_QPsub4()
+! CHECK-SAME: attributes {arm_new_za}
+
+!dir$ arm_shared_za
+subroutine sub5
+end subroutine sub5
+! CHECK-LABEL: func.func @_QPsub5()
+! CHECK-SAME: attributes {arm_shared_za}
+
+!dir$ arm_preserves_za
+subroutine sub6
+end subroutine sub6
+! CHECK-LABEL: func.func @_QPsub6()
+! CHECK-SAME: attributes {arm_preserves_za}
+
module m
contains
@@ -43,4 +61,23 @@ subroutine msub3
end subroutine msub3
! CHECK-LABEL: func.func @_QMmPmsub3()
! CHECK-SAME: attributes {arm_streaming_compatible}
+
+!dir$ arm_new_za
+subroutine msub4
+end subroutine msub4
+! CHECK-LABEL: func.func @_QMmPmsub4()
+! CHECK-SAME: attributes {arm_new_za}
+
+!dir$ arm_shared_za
+subroutine msub5
+end subroutine msub5
+! CHECK-LABEL: func.func @_QMmPmsub5()
+! CHECK-SAME: attributes {arm_shared_za}
+
+!dir$ arm_preserves_za
+subroutine msub6
+end subroutine msub6
+! CHECK-LABEL: func.func @_QMmPmsub6()
+! CHECK-SAME: attributes {arm_preserves_za}
+
end module
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 4266ac5b0c8cf6..57f9ac007bae94 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -28,13 +28,15 @@ def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode
let genSpecializedAttr = 0;
}
-// TODO: Add other ZA modes.
-// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
[
I32EnumAttrCase<"Disabled", 0, "disabled">,
// A function's ZA state is created on entry and destroyed on exit.
I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
+ // A function that preserves ZA state.
+ I32EnumAttrCase<"PreservesZA", 2, "arm_preserves_za">,
+ // A function that uses ZA state as input and/or output
+ I32EnumAttrCase<"SharedZA", 3, "arm_shared_za">,
]>{
let cppNamespace = "mlir::arm_sme";
let genSpecializedAttr = 0;
More information about the flang-commits
mailing list