[Mlir-commits] [flang] [mlir] Add ZA directives for Flang. (PR #76505)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 28 06:28:46 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Mats Petersson (Leporacanthicus)
<details>
<summary>Changes</summary>
Add directives to support functions and subroutines that use ZA state.
---
Patch is 23.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76505.diff
10 Files Affected:
- (modified) flang/include/flang/Lower/PFTBuilder.h (+8-2)
- (modified) flang/include/flang/Parser/parse-tree.h (+2-1)
- (modified) flang/lib/Lower/Bridge.cpp (+109-14)
- (modified) flang/lib/Lower/PFTBuilder.cpp (+42-21)
- (modified) flang/lib/Parser/program-parsers.cpp (+2-1)
- (modified) flang/lib/Semantics/program-tree.cpp (+5-1)
- (modified) flang/lib/Semantics/resolve-names.cpp (+28-1)
- (added) flang/test/Lower/arm-ssve-directives.f90 (+83)
- (modified) flang/test/Parser/compiler-directives.f90 (+8)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+4-2)
``````````diff
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9c6696ff79dae1..16076c11483f61 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
+/// Things that can be nested inside of a module or function
+/// TODO: add the rest
+struct FunctionLikeUnit;
+struct CompilerDirectiveUnit;
+using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;
+
/// Function-like units may contain evaluations (executable statements) and
/// nested function-like units (internal procedures and function statements).
struct FunctionLikeUnit : public ProgramUnit {
@@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
- std::list<FunctionLikeUnit> nestedFunctions;
+ std::list<NestedUnit> nestedUnits;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
@@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {
ModuleStatement beginStmt;
ModuleStatement endStmt;
- std::list<FunctionLikeUnit> nestedFunctions;
+ std::list<NestedUnit> nestedUnits;
EvaluationList evaluationList;
};
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 393e0e24ec5cbd..880f1e249d34a6 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
UNION_CLASS_BOILERPLATE(ModuleSubprogram);
std::variant<common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>,
- common::Indirection<SeparateModuleSubprogram>>
+ common::Indirection<SeparateModuleSubprogram>,
+ common::Indirection<CompilerDirective>>
u;
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7e64adc3c144c9..043c5104191813 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -53,6 +53,7 @@
#include "flang/Semantics/runtime-type-info.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
@@ -303,9 +304,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
- for (Fortran::lower::pft::FunctionLikeUnit &f :
- m.nestedFunctions)
- declareFunction(f);
+ for (Fortran::lower::pft::NestedUnit &unit :
+ m.nestedUnits) {
+ if (auto *f = std::get_if<
+ Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ declareFunction(*f);
+ }
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
@@ -322,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&]() { createIntrinsicModuleDefinitions(pft); });
// Primary translation pass.
- for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
+ std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
+ for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
- [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+ [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ processSubprogramDirective(it, units.end(), d);
+ },
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
@@ -338,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder = nullptr;
},
},
- u);
+ *it);
}
// Once all the code has been translated, create global runtime type info
@@ -387,13 +394,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- collectHostAssociatedVariables(f, escapeHost);
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ collectHostAssociatedVariables(*f, escapeHost);
+ }
funit.setHostAssociatedSymbols(escapeHost);
// Declare internal procedures
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- declareFunction(f);
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ declareFunction(*f);
+ }
}
/// Get the scope that is defining or using \p sym. The returned scope is not
@@ -4667,8 +4678,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
endNewFunction(funit);
}
funit.setActiveEntry(0);
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- lowerFunc(f); // internal procedure
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ lowerFunc(*f); // internal procedure
+ }
}
/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -4692,8 +4705,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
- for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
- lowerFunc(f);
+ for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
+ it = std::next(it)) {
+ std::visit(
+ Fortran::common::visitors{
+ [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
+ [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ processSubprogramDirective(it, mod.nestedUnits.end(), d);
+ }},
+ *it);
+ }
}
void setCurrentPosition(const Fortran::parser::CharBlock &position) {
@@ -5001,6 +5022,80 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol);
}
+ /// Process compiler directives that apply to subprograms
+ template <typename ITERATOR>
+ void
+ processSubprogramDirective(ITERATOR it, ITERATOR endIt,
+ Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
+ if (!parserDirective)
+ return;
+ auto *nvList =
+ std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
+ &parserDirective->u);
+ if (!nvList)
+ return;
+
+ // get the function the directive applies to (hopefully the next unit)
+ mlir::func::FuncOp mlirFunc;
+ it = std::next(it);
+ if (it != endIt) {
+ auto *pftFunction =
+ std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
+ if (pftFunction) {
+ Fortran::lower::CalleeInterface callee{*pftFunction, *this};
+ mlirFunc = callee.getFuncOp();
+ }
+ }
+
+ for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
+ std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();
+
+ // arm streaming sve directives
+ auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
+ if (name == "arm_streaming")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
+ else if (name == "arm_locally_streaming")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
+ else if (name == "arm_streaming_compatible")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
+ if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
+ if (!mlirFunc) {
+ // TODO: share diagnostic code with warnings elsewhere
+ // TODO: source location is printed as loc<"file.f90":line:col>
+ mlir::Location loc = genLocation(parserDirective->source);
+ llvm::errs() << loc << ": warning: ignoring directive '" << name
+ << "' because it has no associated subprogram\n";
+ continue;
+ }
+ llvm::StringRef attrName =
+ mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
+ mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+ mlirFunc->setAttr(attrName, unitAttr);
+ }
+ auto zaMode = mlir::arm_sme::ArmZaMode::Disabled;
+ if (name == "arm_new_za")
+ zaMode = mlir::arm_sme::ArmZaMode::NewZA;
+ else if (name == "arm_shared_za")
+ zaMode = mlir::arm_sme::ArmZaMode::SharedZA;
+ else if (name == "arm_preserves_za")
+ zaMode = mlir::arm_sme::ArmZaMode::PreservesZA;
+ if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) {
+ if (!mlirFunc) {
+ // TODO: share diagnostic code with warnings elsewhere
+ // TODO: source location is printed as loc<"file.f90":line:col>
+ mlir::Location loc = genLocation(parserDirective->source);
+ llvm::errs() << loc << ": warning: ignoring directive '" << name
+ << "' because it has no associated subprogram\n";
+ continue;
+ }
+ llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode);
+ mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+ mlirFunc->setAttr(attrName, unitAttr);
+ }
+ }
+ }
+
//===--------------------------------------------------------------------===//
Fortran::lower::LoweringBridge &bridge;
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 32ed539c775b82..0a08e1cf7ff47e 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -259,6 +259,12 @@ class PFTBuilder {
lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
return false;
}
+ if (auto *mod = pftParentStack.back().getIf<lower::pft::ModuleLikeUnit>()) {
+ assert(nestedUnitList && "Modules have a nested units list");
+ lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()};
+ addNestedUnit(std::move(unit));
+ return false;
+ }
return enterConstructOrDirective(directive);
}
@@ -279,7 +285,7 @@ class PFTBuilder {
bool enterModule(const A &mod) {
Fortran::lower::pft::ModuleLikeUnit &unit =
addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
- functionList = &unit.nestedFunctions;
+ nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -349,7 +355,7 @@ class PFTBuilder {
semanticsContext});
labelEvaluationMap = &unit.labelEvaluationMap;
assignSymbolLabelMap = &unit.assignSymbolLabelMap;
- functionList = &unit.nestedFunctions;
+ nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -414,14 +420,14 @@ class PFTBuilder {
if (!pftParentStack.empty()) {
pftParentStack.back().visit(common::visitors{
[&](lower::pft::FunctionLikeUnit &p) {
- functionList = &p.nestedFunctions;
+ nestedUnitList = &p.nestedUnits;
labelEvaluationMap = &p.labelEvaluationMap;
assignSymbolLabelMap = &p.assignSymbolLabelMap;
},
[&](lower::pft::ModuleLikeUnit &p) {
- functionList = &p.nestedFunctions;
+ nestedUnitList = &p.nestedUnits;
},
- [&](auto &) { functionList = nullptr; },
+ [&](auto &) { nestedUnitList = nullptr; },
});
}
}
@@ -432,11 +438,16 @@ class PFTBuilder {
return std::get<A>(pgm->getUnits().back());
}
+ template <typename A>
+ void addNestedUnit(A &&source) {
+ nestedUnitList->emplace_back(lower::pft::NestedUnit{std::move(source)});
+ }
+
template <typename A>
A &addFunction(A &&func) {
- if (functionList) {
- functionList->emplace_back(std::move(func));
- return functionList->back();
+ if (nestedUnitList) {
+ addNestedUnit(func);
+ return std::get<A>(nestedUnitList->back());
}
return addUnit(std::move(func));
}
@@ -459,7 +470,7 @@ class PFTBuilder {
/// Append an Evaluation to the end of the current list.
lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
assert(!evaluationListStack.empty() && "empty evaluation list stack");
if (!constructAndDirectiveStack.empty())
eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,7 +510,7 @@ class PFTBuilder {
/// push a new list on the stack of Evaluation lists
void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
assert(evaluationList && evaluationList->empty() &&
"evaluation list isn't correct");
evaluationListStack.emplace_back(evaluationList);
@@ -507,7 +518,7 @@ class PFTBuilder {
/// pop the current list and return to the last Evaluation list
void popEvaluationList() {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
evaluationListStack.pop_back();
}
@@ -1088,9 +1099,9 @@ class PFTBuilder {
std::vector<lower::pft::PftNode> pftParentStack;
const semantics::SemanticsContext &semanticsContext;
- /// functionList points to the internal or module procedure function list
- /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
- std::list<lower::pft::FunctionLikeUnit> *functionList{};
+ /// nestedUnitList points to the internal or module procedure unit list
+ /// of nested units (e.g. functions). It may be null.
+ std::list<lower::pft::NestedUnit> *nestedUnitList{};
std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
std::vector<lower::pft::Evaluation *> doConstructStack{};
/// evaluationListStack is the current nested construct evaluationList state.
@@ -1264,11 +1275,17 @@ class PFTDumper {
outputStream << ": " << header;
outputStream << '\n';
dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
- if (!functionLikeUnit.nestedFunctions.empty()) {
+ if (!functionLikeUnit.nestedUnits.empty()) {
outputStream << "\nContains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- functionLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
+ for (const lower::pft::NestedUnit &nested :
+ functionLikeUnit.nestedUnits) {
+ if (const auto *func =
+ std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+ dumpFunctionLikeUnit(outputStream, *func);
+ if (const auto *directive =
+ std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+ dumpCompilerDirectiveUnit(outputStream, *directive);
+ }
outputStream << "End Contains\n";
}
outputStream << "End " << unitKind << ' ' << name << "\n\n";
@@ -1298,9 +1315,13 @@ class PFTDumper {
outputStream << unitKind << ' ' << name << ": " << header << '\n';
dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
outputStream << "Contains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- moduleLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
+ for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) {
+ if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+ dumpFunctionLikeUnit(outputStream, *func);
+ if (const auto *directive =
+ std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+ dumpCompilerDirectiveUnit(outputStream, *directive);
+ }
outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
}
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index e24559bf14f7c9..ff5e58ebc721c5 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US,
// separate-module-subprogram
TYPE_PARSER(construct<ModuleSubprogram>(indirect(functionSubprogram)) ||
construct<ModuleSubprogram>(indirect(subroutineSubprogram)) ||
- construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})))
+ construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})) ||
+ construct<ModuleSubprogram>(indirect(compilerDirective)))
// R1410 module-nature -> INTRINSIC | NON_INTRINSIC
constexpr auto moduleNature{
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index bf773f3810c847..50487ea58dff72 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -112,7 +112,11 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
for (const auto &subp :
std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
common::visit(
- [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
+ common::visitors{
+ [&](const common::Indirection<parser::CompilerDirective> &) {},
+ [&](const auto &y) {
+ node.AddChild(ProgramTree::Build(y.value()));
+ }},
subp.u);
}
}
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e1cd34ddf65b6b..ef8accd4636c3a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,7 +8389,34 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
}
}
} else {
- Say(x.source, "Compiler directive was ignored"_warn_en_US);
+ bool handled{false};
+ if (const auto *nvList{
+ std::get_if<std::list<parser::CompilerDirective::NameValue>>(
+ &x.u)}) {
+ for (const parser::CompilerDirective::NameValue &nv : *nvList) {
+ std::string name{std::get<parser::Name>(nv.t).ToString()};
+ const std::initializer_list<const char *> handledAttrs{
+ "arm_streaming",
+ "arm_locally_streaming",
+ "arm_streaming_compatible",
+ "arm_shared_za",
+ "arm_new_za",
+ "arm_preserves_za",
+ };
+ if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
+ handledAttrs.end()) {
+ // exit early so that subsequent recognised attributes can't change
+ // the result
+ handled = false;
+ break;
+ }
+ // this attribute was handled
+ handled = true;
+ }
+ }
+ if (!handled) {
+ S...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/76505
More information about the Mlir-commits
mailing list