[flang] [llvm] [mlir] [flang][OpenMP] Enable tiling (PR #143715)
Krzysztof Parzyszek via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 18 11:50:01 PDT 2025
================
@@ -1890,39 +1910,124 @@ bool OmpAttributeVisitor::Pre(const parser::DoConstruct &x) {
return true;
}
+static bool isSizesClause(const parser::OmpClause *clause) {
+ return std::holds_alternative<parser::OmpClause::Sizes>(clause->u);
+}
+
+std::int64_t OmpAttributeVisitor::SetAssociatedMaxClause(
+ llvm::SmallVector<std::int64_t> &levels,
+ llvm::SmallVector<const parser::OmpClause *> &clauses) {
+
+ // Find the tile level to know how much to reduce the level for collapse
+ std::int64_t tileLevel = 0;
+ for (auto [level, clause] : llvm::zip_equal(levels, clauses)) {
+ if (isSizesClause(clause)) {
+ tileLevel = level;
+ }
+ }
+
+ std::int64_t maxLevel = 1;
+ const parser::OmpClause *maxClause = nullptr;
+ for (auto [level, clause] : llvm::zip_equal(levels, clauses)) {
+ if (tileLevel > 0 && tileLevel < level) {
+ context_.Say(clause->source,
+ "The value of the parameter in the COLLAPSE clause must"
+ " not be larger than the number of the number of tiled loops"
+ " because collapse relies on independent loop iterations."_err_en_US);
+ return 1;
+ }
+
+ if (!isSizesClause(clause)) {
+ level = level - tileLevel;
+ }
+
+ if (level > maxLevel) {
+ maxLevel = level;
+ maxClause = clause;
+ }
+ }
+ if (maxClause)
+ SetAssociatedClause(*maxClause);
+ return maxLevel;
+}
+
+std::int64_t OmpAttributeVisitor::GetAssociatedLoopLevelFromLoopConstruct(
+ const parser::OpenMPLoopConstruct &x) {
+ llvm::SmallVector<std::int64_t> levels;
+ llvm::SmallVector<const parser::OmpClause *> clauses;
+
+ CollectAssociatedLoopLevelsFromLoopConstruct(x, levels, clauses);
+ return SetAssociatedMaxClause(levels, clauses);
+}
+
std::int64_t OmpAttributeVisitor::GetAssociatedLoopLevelFromClauses(
const parser::OmpClauseList &x) {
- std::int64_t orderedLevel{0};
- std::int64_t collapseLevel{0};
+ llvm::SmallVector<std::int64_t> levels;
+ llvm::SmallVector<const parser::OmpClause *> clauses;
- const parser::OmpClause *ordClause{nullptr};
- const parser::OmpClause *collClause{nullptr};
+ CollectAssociatedLoopLevelsFromClauses(x, levels, clauses);
+ return SetAssociatedMaxClause(levels, clauses);
+}
- for (const auto &clause : x.v) {
- if (const auto *orderedClause{
- std::get_if<parser::OmpClause::Ordered>(&clause.u)}) {
- if (const auto v{EvaluateInt64(context_, orderedClause->v)}) {
- orderedLevel = *v;
- }
- ordClause = &clause;
- }
- if (const auto *collapseClause{
- std::get_if<parser::OmpClause::Collapse>(&clause.u)}) {
- if (const auto v{EvaluateInt64(context_, collapseClause->v)}) {
- collapseLevel = *v;
- }
- collClause = &clause;
+void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromLoopConstruct(
+ const parser::OpenMPLoopConstruct &x,
+ llvm::SmallVector<std::int64_t> &levels,
+ llvm::SmallVector<const parser::OmpClause *> &clauses) {
+ const auto &beginLoopDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
+ const auto &clauseList{std::get<parser::OmpClauseList>(beginLoopDir.t)};
+
+ CollectAssociatedLoopLevelsFromClauses(clauseList, levels, clauses);
+ CollectAssociatedLoopLevelsFromInnerLoopContruct(x, levels, clauses);
+}
+
+void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromInnerLoopContruct(
+ const parser::OpenMPLoopConstruct &x,
+ llvm::SmallVector<std::int64_t> &levels,
+ llvm::SmallVector<const parser::OmpClause *> &clauses) {
+ const auto &innerOptional =
+ std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
+ x.t);
+ if (innerOptional.has_value()) {
+ CollectAssociatedLoopLevelsFromLoopConstruct(
+ innerOptional.value().value(), levels, clauses);
+ }
+}
+
+template <typename T>
+void OmpAttributeVisitor::CollectAssociatedLoopLevelFromClauseValue(
+ const parser::OmpClause &clause, llvm::SmallVector<std::int64_t> &levels,
+ llvm::SmallVector<const parser::OmpClause *> &clauses) {
+ if (const auto tclause{std::get_if<T>(&clause.u)}) {
+ std::int64_t level = 0;
+ if (const auto v{EvaluateInt64(context_, tclause->v)}) {
+ level = *v;
}
+ levels.push_back(level);
+ clauses.push_back(&clause);
}
+}
- if (orderedLevel && (!collapseLevel || orderedLevel >= collapseLevel)) {
- SetAssociatedClause(*ordClause);
- return orderedLevel;
- } else if (!orderedLevel && collapseLevel) {
- SetAssociatedClause(*collClause);
- return collapseLevel;
- } // orderedLevel < collapseLevel is an error handled in structural checks
- return 1; // default is outermost loop
+template <typename T>
+void OmpAttributeVisitor::CollectAssociatedLoopLevelFromClauseSize(
+ const parser::OmpClause &clause, llvm::SmallVector<std::int64_t> &levels,
+ llvm::SmallVector<const parser::OmpClause *> &clauses) {
+ if (const auto tclause{std::get_if<T>(&clause.u)}) {
+ levels.push_back(tclause->v.size());
+ clauses.push_back(&clause);
+ }
+}
+
+void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromClauses(
+ const parser::OmpClauseList &x, llvm::SmallVector<std::int64_t> &levels,
+ llvm::SmallVector<const parser::OmpClause *> &clauses) {
+ for (const auto &clause : x.v) {
+ CollectAssociatedLoopLevelFromClauseValue<parser::OmpClause::Ordered>(
+ clause, levels, clauses);
+ CollectAssociatedLoopLevelFromClauseValue<parser::OmpClause::Collapse>(
+ clause, levels, clauses);
+ CollectAssociatedLoopLevelFromClauseSize<parser::OmpClause::Sizes>(
+ clause, levels, clauses);
----------------
kparzysz wrote:
It would probably be easier to read if these functions were inlined here.
https://github.com/llvm/llvm-project/pull/143715
More information about the llvm-commits mailing list