[flang] [llvm] [mlir] [flang][OpenMP] Enable tiling (PR #143715)

Krzysztof Parzyszek via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 18 11:50:01 PDT 2025


================
@@ -1890,39 +1910,124 @@ bool OmpAttributeVisitor::Pre(const parser::DoConstruct &x) {
   return true;
 }
 
+/// Returns true when \p clause (which must be non-null) holds the
+/// parser::OmpClause::Sizes alternative, i.e. it is a SIZES clause.
+static bool isSizesClause(const parser::OmpClause *clause) {
+  return std::get_if<parser::OmpClause::Sizes>(&clause->u) != nullptr;
+}
+
+/// Chooses the clause that determines the associated loop level and records
+/// it with SetAssociatedClause.
+///
+/// \p levels and \p clauses are parallel vectors: levels[i] is the level
+/// collected for clauses[i] (a clause's evaluated value, or the item count
+/// of a SIZES clause). When a SIZES clause is present, the levels of the
+/// other clauses are reduced by the level of the last SIZES clause before
+/// being compared.
+///
+/// \returns the largest (adjusted) level, never less than 1. If no clause
+/// yields a level greater than 1, no clause is recorded. If a non-SIZES
+/// level exceeds the tile level, a diagnostic is emitted and 1 is returned.
+/// The contents of \p levels and \p clauses are not modified.
+std::int64_t OmpAttributeVisitor::SetAssociatedMaxClause(
+    llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+
+  // Find the tile level to know how much to reduce the level for collapse.
+  // If there are several SIZES clauses, the last one wins.
+  std::int64_t tileLevel = 0;
+  for (auto [level, clause] : llvm::zip_equal(levels, clauses)) {
+    if (isSizesClause(clause)) {
+      tileLevel = level;
+    }
+  }
+
+  std::int64_t maxLevel = 1;
+  const parser::OmpClause *maxClause = nullptr;
+  for (auto [level, clause] : llvm::zip_equal(levels, clauses)) {
+    if (tileLevel > 0 && tileLevel < level) {
+      context_.Say(clause->source,
+          "The value of the parameter in the COLLAPSE clause must"
+          " not be larger than the number of tiled loops"
+          " because collapse relies on independent loop iterations."_err_en_US);
+      return 1;
+    }
+
+    // llvm::zip_equal yields tuples of references, so `level` aliases an
+    // element of `levels`. Compute the adjusted value in a local rather than
+    // assigning through `level`, which would rewrite the caller's vector.
+    const std::int64_t effectiveLevel =
+        isSizesClause(clause) ? level : level - tileLevel;
+
+    // Strict '>' keeps the first clause on ties.
+    if (effectiveLevel > maxLevel) {
+      maxLevel = effectiveLevel;
+      maxClause = clause;
+    }
+  }
+  if (maxClause)
+    SetAssociatedClause(*maxClause);
+  return maxLevel;
+}
+
+/// Returns the associated loop level for \p x. The levels and clauses are
+/// gathered by CollectAssociatedLoopLevelsFromLoopConstruct, which also
+/// descends into nested loop constructs. The determining clause is recorded
+/// by SetAssociatedMaxClause.
+std::int64_t OmpAttributeVisitor::GetAssociatedLoopLevelFromLoopConstruct(
+    const parser::OpenMPLoopConstruct &x) {
+  llvm::SmallVector<const parser::OmpClause *> clauses;
+  llvm::SmallVector<std::int64_t> levels;
+  CollectAssociatedLoopLevelsFromLoopConstruct(x, levels, clauses);
+  const std::int64_t result{SetAssociatedMaxClause(levels, clauses)};
+  return result;
+}
+
+/// Returns the associated loop level determined only by the clauses in
+/// \p x (nested loop constructs are not visited) and records the
+/// determining clause via SetAssociatedMaxClause.
 std::int64_t OmpAttributeVisitor::GetAssociatedLoopLevelFromClauses(
     const parser::OmpClauseList &x) {
-  std::int64_t orderedLevel{0};
-  std::int64_t collapseLevel{0};
+  llvm::SmallVector<std::int64_t> levels;
+  llvm::SmallVector<const parser::OmpClause *> clauses;
 
-  const parser::OmpClause *ordClause{nullptr};
-  const parser::OmpClause *collClause{nullptr};
+  CollectAssociatedLoopLevelsFromClauses(x, levels, clauses);
+  return SetAssociatedMaxClause(levels, clauses);
+}
 
-  for (const auto &clause : x.v) {
-    if (const auto *orderedClause{
-            std::get_if<parser::OmpClause::Ordered>(&clause.u)}) {
-      if (const auto v{EvaluateInt64(context_, orderedClause->v)}) {
-        orderedLevel = *v;
-      }
-      ordClause = &clause;
-    }
-    if (const auto *collapseClause{
-            std::get_if<parser::OmpClause::Collapse>(&clause.u)}) {
-      if (const auto v{EvaluateInt64(context_, collapseClause->v)}) {
-        collapseLevel = *v;
-      }
-      collClause = &clause;
+/// Appends to \p levels / \p clauses the entries for the clause list on the
+/// begin directive of \p x, followed by the entries of any loop construct
+/// nested inside \p x. Outer clauses therefore precede inner ones.
+void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromLoopConstruct(
+    const parser::OpenMPLoopConstruct &x,
+    llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+  const auto &ownClauses{std::get<parser::OmpClauseList>(
+      std::get<parser::OmpBeginLoopDirective>(x.t).t)};
+  CollectAssociatedLoopLevelsFromClauses(ownClauses, levels, clauses);
+  CollectAssociatedLoopLevelsFromInnerLoopContruct(x, levels, clauses);
+}
+
+/// If \p x contains a nested loop construct, appends its levels/clauses
+/// (recursively, via CollectAssociatedLoopLevelsFromLoopConstruct);
+/// otherwise does nothing.
+/// NOTE(review): "Contruct" in the name is a typo; renaming it requires
+/// updating the declaration as well.
+void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromInnerLoopContruct(
+    const parser::OpenMPLoopConstruct &x,
+    llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+  const auto &nested{
+      std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
+          x.t)};
+  if (!nested) {
+    return;
+  }
+  CollectAssociatedLoopLevelsFromLoopConstruct(
+      nested->value(), levels, clauses);
+}
+
+/// If \p clause holds alternative T, appends its value to \p levels and the
+/// clause itself to \p clauses; otherwise does nothing. The value comes from
+/// EvaluateInt64 and is 0 when that call yields no value.
+template <typename T>
+void OmpAttributeVisitor::CollectAssociatedLoopLevelFromClauseValue(
+    const parser::OmpClause &clause, llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+  if (const auto tclause{std::get_if<T>(&clause.u)}) {
+    std::int64_t level = 0;
+    if (const auto v{EvaluateInt64(context_, tclause->v)}) {
+      level = *v;
     }
+    levels.push_back(level);
+    clauses.push_back(&clause);
   }
+}
 
-  if (orderedLevel && (!collapseLevel || orderedLevel >= collapseLevel)) {
-    SetAssociatedClause(*ordClause);
-    return orderedLevel;
-  } else if (!orderedLevel && collapseLevel) {
-    SetAssociatedClause(*collClause);
-    return collapseLevel;
-  } // orderedLevel < collapseLevel is an error handled in structural checks
-  return 1; // default is outermost loop
+/// If \p clause holds alternative T, appends the number of elements in that
+/// alternative's `v` member to \p levels and the clause itself to
+/// \p clauses; otherwise does nothing.
+template <typename T>
+void OmpAttributeVisitor::CollectAssociatedLoopLevelFromClauseSize(
+    const parser::OmpClause &clause, llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+  const auto *tclause{std::get_if<T>(&clause.u)};
+  if (!tclause) {
+    return;
+  }
+  levels.push_back(tclause->v.size());
+  clauses.push_back(&clause);
+}
+
+void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromClauses(
+    const parser::OmpClauseList &x, llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+  for (const auto &clause : x.v) {
+    CollectAssociatedLoopLevelFromClauseValue<parser::OmpClause::Ordered>(
+        clause, levels, clauses);
+    CollectAssociatedLoopLevelFromClauseValue<parser::OmpClause::Collapse>(
+        clause, levels, clauses);
+    CollectAssociatedLoopLevelFromClauseSize<parser::OmpClause::Sizes>(
+        clause, levels, clauses);
----------------
kparzysz wrote:

It would probably be easier to read if these functions were inlined here.

https://github.com/llvm/llvm-project/pull/143715


More information about the llvm-commits mailing list