[flang-commits] [flang] RFC: WIP: add support for compiler directives which apply to functions (PR #75352)
via flang-commits
flang-commits at lists.llvm.org
Wed Dec 13 08:11:07 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-parser
Author: Tom Eccles (tblah)
<details>
<summary>Changes</summary>
RFC: not for merge
Some directives (e.g. `noinline`) apply to the following function-like statement. This PR provides partial support for these kinds of directives, using some new directives for Arm streaming SVE (modeled after the clang attributes with the same names).
The directives are currently supported for function-like statements at the top level or inside modules.
I would like comments on the rough shape of the implementation (especially the changes to `PFTBuilder.cpp` and the decision to match directives to functions in `Bridge.cpp`).
## To Do
- [ ] We currently cannot parse compiler directives inside of interface blocks
- [ ] No support for nested function-like statements
- [ ] No support for `.mod` files
- [ ] The warning emitted when a directive is not followed by a function-like statement does not use the proper diagnostics infrastructure, and it prints the source location wrapped in `loc<>`
---
Patch is 20.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75352.diff
9 Files Affected:
- (modified) flang/include/flang/Lower/PFTBuilder.h (+8-2)
- (modified) flang/include/flang/Parser/parse-tree.h (+2-1)
- (modified) flang/lib/Lower/Bridge.cpp (+89-14)
- (modified) flang/lib/Lower/PFTBuilder.cpp (+42-21)
- (modified) flang/lib/Parser/program-parsers.cpp (+2-1)
- (modified) flang/lib/Semantics/program-tree.cpp (+3)
- (modified) flang/lib/Semantics/resolve-names.cpp (+25-1)
- (added) flang/test/Lower/arm-ssve-directives.f90 (+46)
- (modified) flang/test/Parser/compiler-directives.f90 (+8)
``````````diff
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9c6696ff79dae1..16076c11483f61 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
+/// Things that can be nested inside of a module or function
+/// TODO: add the rest
+struct FunctionLikeUnit;
+struct CompilerDirectiveUnit;
+using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;
+
/// Function-like units may contain evaluations (executable statements) and
/// nested function-like units (internal procedures and function statements).
struct FunctionLikeUnit : public ProgramUnit {
@@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
- std::list<FunctionLikeUnit> nestedFunctions;
+ std::list<NestedUnit> nestedUnits;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
@@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {
ModuleStatement beginStmt;
ModuleStatement endStmt;
- std::list<FunctionLikeUnit> nestedFunctions;
+ std::list<NestedUnit> nestedUnits;
EvaluationList evaluationList;
};
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 393e0e24ec5cbd..880f1e249d34a6 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
UNION_CLASS_BOILERPLATE(ModuleSubprogram);
std::variant<common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>,
- common::Indirection<SeparateModuleSubprogram>>
+ common::Indirection<SeparateModuleSubprogram>,
+ common::Indirection<CompilerDirective>>
u;
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7e64adc3c144c9..5e0b6a801a98ec 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -53,6 +53,7 @@
#include "flang/Semantics/runtime-type-info.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
@@ -303,9 +304,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
- for (Fortran::lower::pft::FunctionLikeUnit &f :
- m.nestedFunctions)
- declareFunction(f);
+ for (Fortran::lower::pft::NestedUnit &unit :
+ m.nestedUnits) {
+ if (auto *f = std::get_if<
+ Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ declareFunction(*f);
+ }
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
@@ -322,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&]() { createIntrinsicModuleDefinitions(pft); });
// Primary translation pass.
- for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
+ std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
+ for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
- [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+ [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ processSubprogramDirective(it, units.end(), d);
+ },
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
@@ -338,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder = nullptr;
},
},
- u);
+ *it);
}
// Once all the code has been translated, create global runtime type info
@@ -387,13 +394,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- collectHostAssociatedVariables(f, escapeHost);
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ collectHostAssociatedVariables(*f, escapeHost);
+ }
funit.setHostAssociatedSymbols(escapeHost);
// Declare internal procedures
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- declareFunction(f);
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ declareFunction(*f);
+ }
}
/// Get the scope that is defining or using \p sym. The returned scope is not
@@ -4667,8 +4678,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
endNewFunction(funit);
}
funit.setActiveEntry(0);
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- lowerFunc(f); // internal procedure
+ for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
+ lowerFunc(*f); // internal procedure
+ }
}
/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -4692,8 +4705,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
- for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
- lowerFunc(f);
+ for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
+ it = std::next(it)) {
+ std::visit(
+ Fortran::common::visitors{
+ [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
+ [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ processSubprogramDirective(it, mod.nestedUnits.end(), d);
+ }},
+ *it);
+ }
}
void setCurrentPosition(const Fortran::parser::CharBlock &position) {
@@ -5001,6 +5022,60 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol);
}
+ /// Process compiler directives that apply to subprograms
+ template <typename ITERATOR>
+ void
+ processSubprogramDirective(ITERATOR it, ITERATOR endIt,
+ Fortran::lower::pft::CompilerDirectiveUnit &d) {
+ auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
+ if (!parserDirective)
+ return;
+ auto *nvList =
+ std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
+ &parserDirective->u);
+ if (!nvList)
+ return;
+
+ // get the function the directive applies to (hopefully the next unit)
+ mlir::func::FuncOp mlirFunc;
+ it = std::next(it);
+ if (it != endIt) {
+ auto *pftFunction =
+ std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
+ if (pftFunction) {
+ Fortran::lower::CalleeInterface callee{*pftFunction, *this};
+ mlirFunc = callee.getFuncOp();
+ }
+ }
+
+ for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
+ std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();
+
+ // arm streaming sve directives
+ auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
+ if (name == "arm_streaming")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
+ else if (name == "arm_locally_streaming")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
+ else if (name == "arm_streaming_compatible")
+ streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
+ if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
+ if (!mlirFunc) {
+ // TODO: share diagnostic code with warnings elsewhere
+ // TODO: source location is printed as loc<"file.f90":line:col>
+ mlir::Location loc = genLocation(parserDirective->source);
+ llvm::errs() << loc << ": warning: ignoring directive '" << name
+ << "' because it has no associated subprogram\n";
+ continue;
+ }
+ llvm::StringRef attrName =
+ mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
+ mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
+ mlirFunc->setAttr(attrName, unitAttr);
+ }
+ }
+ }
+
//===--------------------------------------------------------------------===//
Fortran::lower::LoweringBridge &bridge;
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 32ed539c775b82..0a08e1cf7ff47e 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -259,6 +259,12 @@ class PFTBuilder {
lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
return false;
}
+ if (auto *mod = pftParentStack.back().getIf<lower::pft::ModuleLikeUnit>()) {
+ assert(nestedUnitList && "Modules have a nested units list");
+ lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()};
+ addNestedUnit(std::move(unit));
+ return false;
+ }
return enterConstructOrDirective(directive);
}
@@ -279,7 +285,7 @@ class PFTBuilder {
bool enterModule(const A &mod) {
Fortran::lower::pft::ModuleLikeUnit &unit =
addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
- functionList = &unit.nestedFunctions;
+ nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -349,7 +355,7 @@ class PFTBuilder {
semanticsContext});
labelEvaluationMap = &unit.labelEvaluationMap;
assignSymbolLabelMap = &unit.assignSymbolLabelMap;
- functionList = &unit.nestedFunctions;
+ nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -414,14 +420,14 @@ class PFTBuilder {
if (!pftParentStack.empty()) {
pftParentStack.back().visit(common::visitors{
[&](lower::pft::FunctionLikeUnit &p) {
- functionList = &p.nestedFunctions;
+ nestedUnitList = &p.nestedUnits;
labelEvaluationMap = &p.labelEvaluationMap;
assignSymbolLabelMap = &p.assignSymbolLabelMap;
},
[&](lower::pft::ModuleLikeUnit &p) {
- functionList = &p.nestedFunctions;
+ nestedUnitList = &p.nestedUnits;
},
- [&](auto &) { functionList = nullptr; },
+ [&](auto &) { nestedUnitList = nullptr; },
});
}
}
@@ -432,11 +438,16 @@ class PFTBuilder {
return std::get<A>(pgm->getUnits().back());
}
+ template <typename A>
+ void addNestedUnit(A &&source) {
+ nestedUnitList->emplace_back(lower::pft::NestedUnit{std::move(source)});
+ }
+
template <typename A>
A &addFunction(A &&func) {
- if (functionList) {
- functionList->emplace_back(std::move(func));
- return functionList->back();
+ if (nestedUnitList) {
+ addNestedUnit(func);
+ return std::get<A>(nestedUnitList->back());
}
return addUnit(std::move(func));
}
@@ -459,7 +470,7 @@ class PFTBuilder {
/// Append an Evaluation to the end of the current list.
lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
assert(!evaluationListStack.empty() && "empty evaluation list stack");
if (!constructAndDirectiveStack.empty())
eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,7 +510,7 @@ class PFTBuilder {
/// push a new list on the stack of Evaluation lists
void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
assert(evaluationList && evaluationList->empty() &&
"evaluation list isn't correct");
evaluationListStack.emplace_back(evaluationList);
@@ -507,7 +518,7 @@ class PFTBuilder {
/// pop the current list and return to the last Evaluation list
void popEvaluationList() {
- assert(functionList && "not in a function");
+ assert(nestedUnitList && "not in a function");
evaluationListStack.pop_back();
}
@@ -1088,9 +1099,9 @@ class PFTBuilder {
std::vector<lower::pft::PftNode> pftParentStack;
const semantics::SemanticsContext &semanticsContext;
- /// functionList points to the internal or module procedure function list
- /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
- std::list<lower::pft::FunctionLikeUnit> *functionList{};
+ /// nestedUnitList points to the internal or module procedure unit list
+ /// of nested units (e.g. functions). It may be null.
+ std::list<lower::pft::NestedUnit> *nestedUnitList{};
std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
std::vector<lower::pft::Evaluation *> doConstructStack{};
/// evaluationListStack is the current nested construct evaluationList state.
@@ -1264,11 +1275,17 @@ class PFTDumper {
outputStream << ": " << header;
outputStream << '\n';
dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
- if (!functionLikeUnit.nestedFunctions.empty()) {
+ if (!functionLikeUnit.nestedUnits.empty()) {
outputStream << "\nContains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- functionLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
+ for (const lower::pft::NestedUnit &nested :
+ functionLikeUnit.nestedUnits) {
+ if (const auto *func =
+ std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+ dumpFunctionLikeUnit(outputStream, *func);
+ if (const auto *directive =
+ std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+ dumpCompilerDirectiveUnit(outputStream, *directive);
+ }
outputStream << "End Contains\n";
}
outputStream << "End " << unitKind << ' ' << name << "\n\n";
@@ -1298,9 +1315,13 @@ class PFTDumper {
outputStream << unitKind << ' ' << name << ": " << header << '\n';
dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
outputStream << "Contains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- moduleLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
+ for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) {
+ if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&nested))
+ dumpFunctionLikeUnit(outputStream, *func);
+ if (const auto *directive =
+ std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
+ dumpCompilerDirectiveUnit(outputStream, *directive);
+ }
outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
}
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index e24559bf14f7c9..ff5e58ebc721c5 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US,
// separate-module-subprogram
TYPE_PARSER(construct<ModuleSubprogram>(indirect(functionSubprogram)) ||
construct<ModuleSubprogram>(indirect(subroutineSubprogram)) ||
- construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})))
+ construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})) ||
+ construct<ModuleSubprogram>(indirect(compilerDirective)))
// R1410 module-nature -> INTRINSIC | NON_INTRINSIC
constexpr auto moduleNature{
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index bf773f3810c847..fcb6392620b67f 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -111,6 +111,9 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
if (subps) {
for (const auto &subp :
std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
+ if (std::holds_alternative<
+ common::Indirection<parser::CompilerDirective>>(subp.u))
+ continue;
common::visit(
[&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
subp.u);
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e1cd34ddf65b6b..40fb641e085a78 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8389,7 +8389,31 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
}
}
} else {
- Say(x.source, "Compiler directive was ignored"_warn_en_US);
+ bool handled = false;
+ if (const auto *nvList{
+ std::get_if<std::list<parser::CompilerDirective::NameValue>>(
+ &x.u)}) {
+ for (const parser::CompilerDirective::NameValue &nv : *nvList) {
+ std::string name = std::get<parser::Name>(nv.t).ToString();
+ const std::initializer_list<const char *> handledAttrs{
+ "arm_streaming",
+ "arm_locally_streaming",
+ "arm_streaming_compatible",
+ };
+ if (std::find(handledAttrs.begin(), handledAttrs.end(), name) ==
+ handledAttrs.end()) {
+ // exit early so that subsequent recognised attributes can't change
+ // the result
+ handled = false;
+ break;
+ }
+ // this attribute was handled
+ handled = true;
+ }
+ }
+ if (!handled) {
+ Say(x.source, "Compiler directive was ignored"_warn_en_US);
+ }
}
}
diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90
new file mode 100644
index 00000000000000..86fbe89920b274
--- /dev/null
+++ b/flang/test/Lower/arm-ssve-directives.f90
@@ -0,0 +1,46 @@
+! RUN: bbc -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+! check we don't warn about these attributes
+! CHECK-NOT: warning: Compiler directive was ignored
+
+! check we create the right function attributes
+
+!dir$ arm_streaming
+subroutine sub
+end subroutine sub
+! CHECK-LABEL: func.func @_QPsub()
+! CHECK-SAME: attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine sub2
+end subroutine sub2
+! CHECK-LABEL: func.func @_QPsub2()
+! CHECK-SAME: attributes {arm_locally_streaming}
+
+!dir$ arm_streaming_compatible
+subroutine sub3
+end subroutine sub3
+! CHECK-LABEL: func.func @_QPsub3()
+! CHECK-SAME: attributes {arm_streaming_compatible}
+
+module m
+contains
+
+!dir$ arm_streaming
+subroutine msub
+end subroutine msub
+! CHECK-LABEL: func.func @_QMmPmsub()
+! CHECK-SAME: attributes {arm_streaming}
+
+!dir$ arm_locally_streaming
+subroutine msub2
+end subroutin...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/75352
More information about the flang-commits
mailing list