[Mlir-commits] [mlir] d452e67 - [flang][OpenMP] Enable tiling (#143715)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 10 06:25:44 PDT 2025


Author: Jan Leyonberg
Date: 2025-09-10T09:25:40-04:00
New Revision: d452e67ee7b5d17aa040f71d8997abc1a47750e4

URL: https://github.com/llvm/llvm-project/commit/d452e67ee7b5d17aa040f71d8997abc1a47750e4
DIFF: https://github.com/llvm/llvm-project/commit/d452e67ee7b5d17aa040f71d8997abc1a47750e4.diff

LOG: [flang][OpenMP] Enable tiling (#143715)

This patch enables tiling in flang. In MLIR tiling is handled by
changing the omp.loop_nest op to be able to represent both collapse
and tiling, so the flang front-end will combine the nested constructs into
a single MLIR op. The MLIR->LLVM-IR lowering of the LoopNestOp is
enhanced to first do the tiling if present, then collapse.

Added: 
    flang/test/Parser/OpenMP/do-tile-size.f90

Modified: 
    flang/include/flang/Lower/OpenMP.h
    flang/lib/Lower/OpenMP/ClauseProcessor.cpp
    flang/lib/Lower/OpenMP/ClauseProcessor.h
    flang/lib/Lower/OpenMP/OpenMP.cpp
    flang/lib/Lower/OpenMP/Utils.cpp
    flang/lib/Lower/OpenMP/Utils.h
    flang/lib/Semantics/resolve-directives.cpp
    flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90
    flang/test/Lower/OpenMP/simd.f90
    flang/test/Lower/OpenMP/wsloop-collapse.f90
    flang/test/Lower/OpenMP/wsloop-variable.f90
    flang/test/Semantics/OpenMP/do-collapse.f90
    flang/test/Semantics/OpenMP/do-concurrent-collapse.f90
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
    mlir/test/Target/LLVMIR/openmp-llvm.mlir

Removed: 
    flang/test/Lower/OpenMP/nested-loop-transformation-construct01.f90


################################################################################
diff  --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 581c93f76d627..df01a7b82c66c 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -80,7 +80,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
 void genOpenMPSymbolProperties(AbstractConverter &converter,
                                const pft::Variable &var);
 
-int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
 void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
 void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
 bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);

diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 23f0ca14e931d..a96884f5680ba 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -273,10 +273,15 @@ bool ClauseProcessor::processCancelDirectiveName(
 
 bool ClauseProcessor::processCollapse(
     mlir::Location currentLocation, lower::pft::Evaluation &eval,
-    mlir::omp::LoopRelatedClauseOps &result,
+    mlir::omp::LoopRelatedClauseOps &loopResult,
+    mlir::omp::CollapseClauseOps &collapseResult,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
-  return collectLoopRelatedInfo(converter, currentLocation, eval, clauses,
-                                result, iv);
+
+  int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
+                                               clauses, loopResult, iv);
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
+  return numCollapse > 1;
 }
 
 bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
@@ -522,6 +527,13 @@ bool ClauseProcessor::processProcBind(
   return false;
 }
 
+bool ClauseProcessor::processTileSizes(
+    lower::pft::Evaluation &eval, mlir::omp::LoopNestOperands &result) const {
+  auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
+  collectTileSizesFromOpenMPConstruct(ompCons, result.tileSizes, semaCtx);
+  return !result.tileSizes.empty();
+}
+
 bool ClauseProcessor::processSafelen(
     mlir::omp::SafelenClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {

diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index c46bdb348a3ef..324ea3c1047a5 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -63,7 +63,8 @@ class ClauseProcessor {
       mlir::omp::CancelDirectiveNameClauseOps &result) const;
   bool
   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
-                  mlir::omp::LoopRelatedClauseOps &result,
+                  mlir::omp::LoopRelatedClauseOps &loopResult,
+                  mlir::omp::CollapseClauseOps &collapseResult,
                   llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
   bool processDevice(lower::StatementContext &stmtCtx,
                      mlir::omp::DeviceClauseOps &result) const;
@@ -98,6 +99,8 @@ class ClauseProcessor {
   bool processPriority(lower::StatementContext &stmtCtx,
                        mlir::omp::PriorityClauseOps &result) const;
   bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
+  bool processTileSizes(lower::pft::Evaluation &eval,
+                        mlir::omp::LoopNestOperands &result) const;
   bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
   bool processSchedule(lower::StatementContext &stmtCtx,
                        mlir::omp::ScheduleClauseOps &result) const;

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index def6cfff88231..0ec33e6b24dbf 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -503,7 +503,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       [[fallthrough]];
     case OMPD_distribute:
     case OMPD_distribute_simd:
-      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
       break;
 
     case OMPD_teams:
@@ -522,7 +522,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       [[fallthrough]];
     case OMPD_target_teams_distribute:
     case OMPD_target_teams_distribute_simd:
-      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
       cp.processNumTeams(stmtCtx, hostInfo->ops);
       break;
 
@@ -533,7 +533,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       cp.processNumTeams(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_loop:
-      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
       break;
 
     case OMPD_teams_workdistribute:
@@ -1569,9 +1569,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,
 
   HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
   if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
-    cp.processCollapse(loc, eval, clauseOps, iv);
+    cp.processCollapse(loc, eval, clauseOps, clauseOps, iv);
 
   clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
+  cp.processTileSizes(eval, clauseOps);
 }
 
 static void genLoopClauses(
@@ -1948,9 +1949,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
     return llvm::SmallVector<const semantics::Symbol *>(iv);
   };
 
-  auto *nestedEval =
-      getCollapsedLoopEval(eval, getCollapseValue(item->clauses));
-
+  uint64_t nestValue = getCollapseValue(item->clauses);
+  nestValue = nestValue < iv.size() ? iv.size() : nestValue;
+  auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
   return genOpWithBody<mlir::omp::LoopNestOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
                         directive)
@@ -3843,8 +3844,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
           parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
       switch (nestedDirective) {
       case llvm::omp::Directive::OMPD_tile:
-        // Emit the omp.loop_nest with annotation for tiling
-        genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
+        // Skip OMPD_tile since the tile sizes will be retrieved when
+        // generating the omp.loop_nest op.
         break;
       default: {
         unsigned version = semaCtx.langOptions().OpenMPVersion;
@@ -3957,18 +3958,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
     lower::genDeclareTargetIntGlobal(converter, var);
 }
 
-int64_t
-Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
-  for (const parser::OmpClause &clause : clauseList.v) {
-    if (const auto &collapseClause =
-            std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
-      const auto *expr = semantics::GetExpr(collapseClause->v);
-      return evaluate::ToInt64(*expr).value();
-    }
-  }
-  return 1;
-}
-
 void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
                                         const lower::pft::Variable &var) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

diff  --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index cb6dd57667824..d1d1cd68a5b44 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,6 +13,7 @@
 #include "Utils.h"
 
 #include "ClauseFinder.h"
+#include "flang/Evaluate/fold.h"
 #include "flang/Lower/OpenMP/Clauses.h"
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
@@ -24,11 +25,32 @@
 #include <flang/Parser/parse-tree.h>
 #include <flang/Parser/tools.h>
 #include <flang/Semantics/tools.h>
+#include <flang/Semantics/type.h>
 #include <flang/Utils/OpenMP.h>
 #include <llvm/Support/CommandLine.h>
 
 #include <iterator>
 
+template <typename T>
+Fortran::semantics::MaybeIntExpr
+EvaluateIntExpr(Fortran::semantics::SemanticsContext &context, const T &expr) {
+  if (Fortran::semantics::MaybeExpr maybeExpr{
+          Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
+    if (auto *intExpr{
+            Fortran::evaluate::UnwrapExpr<Fortran::semantics::SomeIntExpr>(
+                *maybeExpr)}) {
+      return std::move(*intExpr);
+    }
+  }
+  return std::nullopt;
+}
+
+template <typename T>
+std::optional<std::int64_t>
+EvaluateInt64(Fortran::semantics::SemanticsContext &context, const T &expr) {
+  return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
+}
+
 llvm::cl::opt<bool> treatIndexAsSection(
     "openmp-treat-index-as-section",
     llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
@@ -577,12 +599,64 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
   }
 }
 
-bool collectLoopRelatedInfo(
+// Helper function that finds the sizes clause in a inner OMPD_tile directive
+// and passes the sizes clause to the callback function if found.
+static void processTileSizesFromOpenMPConstruct(
+    const parser::OpenMPConstruct *ompCons,
+    std::function<void(const parser::OmpClause::Sizes *)> processFun) {
+  if (!ompCons)
+    return;
+  if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
+    const auto &nestedOptional =
+        std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
+    assert(nestedOptional.has_value() &&
+           "Expected a DoConstruct or OpenMPLoopConstruct");
+    const auto *innerConstruct =
+        std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
+            &(nestedOptional.value()));
+    if (innerConstruct) {
+      const auto &innerLoopDirective = innerConstruct->value();
+      const auto &innerBegin =
+          std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
+      const auto &innerDirective =
+          std::get<parser::OmpLoopDirective>(innerBegin.t).v;
+
+      if (innerDirective == llvm::omp::Directive::OMPD_tile) {
+        // Get the size values from parse tree and convert to a vector.
+        const auto &innerClauseList{
+            std::get<parser::OmpClauseList>(innerBegin.t)};
+        for (const auto &clause : innerClauseList.v) {
+          if (const auto tclause{
+                  std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
+            processFun(tclause);
+            break;
+          }
+        }
+      }
+    }
+  }
+}
+
+/// Populates the sizes vector with values if the given OpenMPConstruct
+/// contains a loop construct with an inner tiling construct.
+void collectTileSizesFromOpenMPConstruct(
+    const parser::OpenMPConstruct *ompCons,
+    llvm::SmallVectorImpl<int64_t> &tileSizes,
+    Fortran::semantics::SemanticsContext &semaCtx) {
+  processTileSizesFromOpenMPConstruct(
+      ompCons, [&](const parser::OmpClause::Sizes *tclause) {
+        for (auto &tval : tclause->v)
+          if (const auto v{EvaluateInt64(semaCtx, tval)})
+            tileSizes.push_back(*v);
+      });
+}
+
+int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
     lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
-  bool found = false;
+  int64_t numCollapse = 1;
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   // Collect the loops to collapse.
@@ -595,9 +669,19 @@ bool collectLoopRelatedInfo(
   if (auto *clause =
           ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
     collapseValue = evaluate::ToInt64(clause->v).value();
-    found = true;
+    numCollapse = collapseValue;
+  }
+
+  // Collect sizes from tile directive if present.
+  std::int64_t sizesLengthValue = 0l;
+  if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
+    processTileSizesFromOpenMPConstruct(
+        ompCons, [&](const parser::OmpClause::Sizes *tclause) {
+          sizesLengthValue = tclause->v.size();
+        });
   }
 
+  collapseValue = std::max(collapseValue, sizesLengthValue);
   std::size_t loopVarTypeSize = 0;
   do {
     lower::pft::Evaluation *doLoop =
@@ -631,7 +715,7 @@ bool collectLoopRelatedInfo(
 
   convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
 
-  return found;
+  return numCollapse;
 }
 
 } // namespace omp

diff  --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 88371ab8bf969..5f191d89ae205 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -159,12 +159,17 @@ void genObjectList(const ObjectList &objects,
 void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
                                      mlir::Location loc);
 
-bool collectLoopRelatedInfo(
+int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
     lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
 
+void collectTileSizesFromOpenMPConstruct(
+    const parser::OpenMPConstruct *ompCons,
+    llvm::SmallVectorImpl<int64_t> &tileSizes,
+    Fortran::semantics::SemanticsContext &semaCtx);
+
 } // namespace omp
 } // namespace lower
 } // namespace Fortran

diff  --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 43f12c2b14038..1b7718d1314d3 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -856,7 +856,23 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
   const parser::OmpClause *GetAssociatedClause() { return associatedClause; }
 
 private:
-  std::int64_t GetAssociatedLoopLevelFromClauses(const parser::OmpClauseList &);
+  /// Given a vector of loop levels and a vector of corresponding clauses find
+  /// the largest loop level and set the associated loop level to the found
+  /// maximum. This is used for error handling to ensure that the number of
+  /// affected loops is not larger that the number of available loops.
+  std::int64_t SetAssociatedMaxClause(llvm::SmallVector<std::int64_t> &,
+      llvm::SmallVector<const parser::OmpClause *> &);
+  std::int64_t GetNumAffectedLoopsFromLoopConstruct(
+      const parser::OpenMPLoopConstruct &);
+  void CollectNumAffectedLoopsFromLoopConstruct(
+      const parser::OpenMPLoopConstruct &, llvm::SmallVector<std::int64_t> &,
+      llvm::SmallVector<const parser::OmpClause *> &);
+  void CollectNumAffectedLoopsFromInnerLoopContruct(
+      const parser::OpenMPLoopConstruct &, llvm::SmallVector<std::int64_t> &,
+      llvm::SmallVector<const parser::OmpClause *> &);
+  void CollectNumAffectedLoopsFromClauses(const parser::OmpClauseList &,
+      llvm::SmallVector<std::int64_t> &,
+      llvm::SmallVector<const parser::OmpClause *> &);
 
   Symbol::Flags dataSharingAttributeFlags{Symbol::Flag::OmpShared,
       Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate,
@@ -1868,7 +1884,6 @@ bool OmpAttributeVisitor::Pre(
 bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
   const auto &beginLoopDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
   const auto &beginDir{std::get<parser::OmpLoopDirective>(beginLoopDir.t)};
-  const auto &clauseList{std::get<parser::OmpClauseList>(beginLoopDir.t)};
   switch (beginDir.v) {
   case llvm::omp::Directive::OMPD_distribute:
   case llvm::omp::Directive::OMPD_distribute_parallel_do:
@@ -1919,7 +1934,7 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
       beginDir.v == llvm::omp::Directive::OMPD_target_loop)
     IssueNonConformanceWarning(beginDir.v, beginDir.source, 52);
   ClearDataSharingAttributeObjects();
-  SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
+  SetContextAssociatedLoopLevel(GetNumAffectedLoopsFromLoopConstruct(x));
 
   if (beginDir.v == llvm::omp::Directive::OMPD_do) {
     auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
@@ -1933,7 +1948,7 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
     }
   }
   PrivatizeAssociatedLoopIndexAndCheckLoopLevel(x);
-  ordCollapseLevel = GetAssociatedLoopLevelFromClauses(clauseList) + 1;
+  ordCollapseLevel = GetNumAffectedLoopsFromLoopConstruct(x) + 1;
   return true;
 }
 
@@ -2021,44 +2036,111 @@ bool OmpAttributeVisitor::Pre(const parser::DoConstruct &x) {
   return true;
 }
 
-std::int64_t OmpAttributeVisitor::GetAssociatedLoopLevelFromClauses(
-    const parser::OmpClauseList &x) {
-  std::int64_t orderedLevel{0};
-  std::int64_t collapseLevel{0};
+static bool isSizesClause(const parser::OmpClause *clause) {
+  return std::holds_alternative<parser::OmpClause::Sizes>(clause->u);
+}
 
-  const parser::OmpClause *ordClause{nullptr};
-  const parser::OmpClause *collClause{nullptr};
+std::int64_t OmpAttributeVisitor::SetAssociatedMaxClause(
+    llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+
+  // Find the tile level to ensure that the COLLAPSE clause value
+  // does not exeed the number of tiled loops.
+  std::int64_t tileLevel = 0;
+  for (auto [level, clause] : llvm::zip_equal(levels, clauses))
+    if (isSizesClause(clause))
+      tileLevel = level;
+
+  std::int64_t maxLevel = 1;
+  const parser::OmpClause *maxClause = nullptr;
+  for (auto [level, clause] : llvm::zip_equal(levels, clauses)) {
+    if (tileLevel > 0 && tileLevel < level) {
+      context_.Say(clause->source,
+          "The value of the parameter in the COLLAPSE clause must"
+          " not be larger than the number of the number of tiled loops"
+          " because collapse currently is limited to independent loop"
+          " iterations."_err_en_US);
+      return 1;
+    }
+
+    if (level > maxLevel) {
+      maxLevel = level;
+      maxClause = clause;
+    }
+  }
+  if (maxClause)
+    SetAssociatedClause(maxClause);
+  return maxLevel;
+}
+
+std::int64_t OmpAttributeVisitor::GetNumAffectedLoopsFromLoopConstruct(
+    const parser::OpenMPLoopConstruct &x) {
+  llvm::SmallVector<std::int64_t> levels;
+  llvm::SmallVector<const parser::OmpClause *> clauses;
+
+  CollectNumAffectedLoopsFromLoopConstruct(x, levels, clauses);
+  return SetAssociatedMaxClause(levels, clauses);
+}
+
+void OmpAttributeVisitor::CollectNumAffectedLoopsFromLoopConstruct(
+    const parser::OpenMPLoopConstruct &x,
+    llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+  const auto &beginLoopDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
+  const auto &clauseList{std::get<parser::OmpClauseList>(beginLoopDir.t)};
 
+  CollectNumAffectedLoopsFromClauses(clauseList, levels, clauses);
+  CollectNumAffectedLoopsFromInnerLoopContruct(x, levels, clauses);
+}
+
+void OmpAttributeVisitor::CollectNumAffectedLoopsFromInnerLoopContruct(
+    const parser::OpenMPLoopConstruct &x,
+    llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
+
+  const auto &nestedOptional =
+      std::get<std::optional<parser::NestedConstruct>>(x.t);
+  assert(nestedOptional.has_value() &&
+      "Expected a DoConstruct or OpenMPLoopConstruct");
+  const auto *innerConstruct =
+      std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
+          &(nestedOptional.value()));
+
+  if (innerConstruct) {
+    CollectNumAffectedLoopsFromLoopConstruct(
+        innerConstruct->value(), levels, clauses);
+  }
+}
+
+void OmpAttributeVisitor::CollectNumAffectedLoopsFromClauses(
+    const parser::OmpClauseList &x, llvm::SmallVector<std::int64_t> &levels,
+    llvm::SmallVector<const parser::OmpClause *> &clauses) {
   for (const auto &clause : x.v) {
-    if (const auto *orderedClause{
+    if (const auto oclause{
             std::get_if<parser::OmpClause::Ordered>(&clause.u)}) {
-      if (const auto v{EvaluateInt64(context_, orderedClause->v)}) {
-        orderedLevel = *v;
+      std::int64_t level = 0;
+      if (const auto v{EvaluateInt64(context_, oclause->v)}) {
+        level = *v;
       }
-      ordClause = &clause;
+      levels.push_back(level);
+      clauses.push_back(&clause);
     }
-    if (const auto *collapseClause{
+
+    if (const auto cclause{
             std::get_if<parser::OmpClause::Collapse>(&clause.u)}) {
-      if (const auto v{EvaluateInt64(context_, collapseClause->v)}) {
-        collapseLevel = *v;
+      std::int64_t level = 0;
+      if (const auto v{EvaluateInt64(context_, cclause->v)}) {
+        level = *v;
       }
-      collClause = &clause;
+      levels.push_back(level);
+      clauses.push_back(&clause);
     }
-  }
 
-  if (orderedLevel && (!collapseLevel || orderedLevel >= collapseLevel)) {
-    SetAssociatedClause(ordClause);
-    return orderedLevel;
-  } else if (!orderedLevel && collapseLevel) {
-    SetAssociatedClause(collClause);
-    return collapseLevel;
-  } else {
-    SetAssociatedClause(nullptr);
+    if (const auto tclause{std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
+      levels.push_back(tclause->v.size());
+      clauses.push_back(&clause);
+    }
   }
-  // orderedLevel < collapseLevel is an error handled in structural
-  // checks
-
-  return 1; // default is outermost loop
 }
 
 // 2.15.1.1 Data-sharing Attribute Rules - Predetermined
@@ -2090,10 +2172,21 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel(
   const parser::OmpClause *clause{GetAssociatedClause()};
   bool hasCollapseClause{
       clause ? (clause->Id() == llvm::omp::OMPC_collapse) : false};
+  const parser::OpenMPLoopConstruct *innerMostLoop = &x;
+  const parser::NestedConstruct *innerMostNest = nullptr;
+  while (auto &optLoopCons{
+      std::get<std::optional<parser::NestedConstruct>>(innerMostLoop->t)}) {
+    innerMostNest = &(optLoopCons.value());
+    if (const auto *innerLoop{
+            std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
+                innerMostNest)}) {
+      innerMostLoop = &(innerLoop->value());
+    } else
+      break;
+  }
 
-  auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
-  if (optLoopCons.has_value()) {
-    if (const auto &outer{std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
+  if (innerMostNest) {
+    if (const auto &outer{std::get_if<parser::DoConstruct>(innerMostNest)}) {
       for (const parser::DoConstruct *loop{&*outer}; loop && level > 0;
           --level) {
         if (loop->IsDoConcurrent()) {
@@ -2129,7 +2222,7 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel(
       CheckAssocLoopLevel(level, GetAssociatedClause());
     } else if (const auto &loop{std::get_if<
                    common::Indirection<parser::OpenMPLoopConstruct>>(
-                   &*optLoopCons)}) {
+                   innerMostNest)}) {
       auto &beginDirective =
           std::get<parser::OmpBeginLoopDirective>(loop->value().t);
       auto &beginLoopDirective =

diff  --git a/flang/test/Lower/OpenMP/nested-loop-transformation-construct01.f90 b/flang/test/Lower/OpenMP/nested-loop-transformation-construct01.f90
deleted file mode 100644
index 17eba93a7405d..0000000000000
--- a/flang/test/Lower/OpenMP/nested-loop-transformation-construct01.f90
+++ /dev/null
@@ -1,20 +0,0 @@
-! Test to ensure TODO message is emitted for tile OpenMP 5.1 Directives when they are nested.
-
-!RUN: not %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s
-
-subroutine loop_transformation_construct
-  implicit none
-  integer :: I = 10
-  integer :: x
-  integer :: y(I)
-
-  !$omp do
-  !$omp tile
-  do i = 1, I
-    y(i) = y(i) * 5
-  end do
-  !$omp end tile
-  !$omp end do
-end subroutine
-
-!CHECK: not yet implemented: Unhandled loop directive (tile)

diff  --git a/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90 b/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90
index 2890e78e9d17f..faf8f717f6308 100644
--- a/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90
+++ b/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90
@@ -108,7 +108,7 @@ subroutine omp_do_lastprivate_collapse2(a)
   ! CHECK-NEXT: %[[UB2:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<i32>
   ! CHECK-NEXT: %[[STEP2:.*]] = arith.constant 1 : i32
   ! CHECK-NEXT: omp.wsloop private(@{{.*}} %{{.*}}#0 -> %[[A_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[I_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[J_PVT_REF:.*]] : !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>) {
-  ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]]) : i32 = (%[[LB1]], %[[LB2]]) to (%[[UB1]], %[[UB2]]) inclusive step (%[[STEP1]], %[[STEP2]]) {
+  ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]]) : i32 = (%[[LB1]], %[[LB2]]) to (%[[UB1]], %[[UB2]]) inclusive step (%[[STEP1]], %[[STEP2]]) collapse(2) {
   ! CHECK:      %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
   ! CHECK:      %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
   ! CHECK:      %[[J_PVT_DECL:.*]]:2 = hlfir.declare %[[J_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -174,7 +174,7 @@ subroutine omp_do_lastprivate_collapse3(a)
   ! CHECK-NEXT: %[[UB3:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<i32>
   ! CHECK-NEXT: %[[STEP3:.*]] = arith.constant 1 : i32
   ! CHECK-NEXT: omp.wsloop private(@{{.*}} %{{.*}}#0 -> %[[A_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[I_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[J_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[K_PVT_REF:.*]] : !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>) {
-  ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) : i32 = (%[[LB1]], %[[LB2]], %[[LB3]]) to (%[[UB1]], %[[UB2]], %[[UB3]]) inclusive step (%[[STEP1]], %[[STEP2]], %[[STEP3]]) {
+  ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) : i32 = (%[[LB1]], %[[LB2]], %[[LB3]]) to (%[[UB1]], %[[UB2]], %[[UB3]]) inclusive step (%[[STEP1]], %[[STEP2]], %[[STEP3]]) collapse(3) {
   ! CHECK:      %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
   ! CHECK:      %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
   ! CHECK:      %[[J_PVT_DECL:.*]]:2 = hlfir.declare %[[J_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)

diff  --git a/flang/test/Lower/OpenMP/simd.f90 b/flang/test/Lower/OpenMP/simd.f90
index 7655c786573e3..369b5eb072af9 100644
--- a/flang/test/Lower/OpenMP/simd.f90
+++ b/flang/test/Lower/OpenMP/simd.f90
@@ -175,7 +175,7 @@ subroutine simd_with_collapse_clause(n)
   ! CHECK-NEXT: omp.loop_nest (%[[ARG_0:.*]], %[[ARG_1:.*]]) : i32 = (
   ! CHECK-SAME:                %[[LOWER_I]], %[[LOWER_J]]) to (
   ! CHECK-SAME:                %[[UPPER_I]], %[[UPPER_J]]) inclusive step (
-  ! CHECK-SAME:                %[[STEP_I]], %[[STEP_J]]) {
+  ! CHECK-SAME:                %[[STEP_I]], %[[STEP_J]]) collapse(2) {
   !$OMP SIMD COLLAPSE(2)
   do i = 1, n
     do j = 1, n

diff  --git a/flang/test/Lower/OpenMP/wsloop-collapse.f90 b/flang/test/Lower/OpenMP/wsloop-collapse.f90
index 7ec40ab4b2f43..677c7809c397f 100644
--- a/flang/test/Lower/OpenMP/wsloop-collapse.f90
+++ b/flang/test/Lower/OpenMP/wsloop-collapse.f90
@@ -57,7 +57,7 @@ program wsloop_collapse
 !CHECK:           %[[VAL_31:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<i32>
 !CHECK:           %[[VAL_32:.*]] = arith.constant 1 : i32
 !CHECK:           omp.wsloop private(@{{.*}} %{{.*}}#0 -> %[[VAL_4:.*]], @{{.*}} %{{.*}}#0 -> %[[VAL_2:.*]], @{{.*}} %{{.*}}#0 -> %[[VAL_0:.*]] : !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>) {
-!CHECK-NEXT:        omp.loop_nest (%[[VAL_33:.*]], %[[VAL_34:.*]], %[[VAL_35:.*]]) : i32 = (%[[VAL_24]], %[[VAL_27]], %[[VAL_30]]) to (%[[VAL_25]], %[[VAL_28]], %[[VAL_31]]) inclusive step (%[[VAL_26]], %[[VAL_29]], %[[VAL_32]]) {
+!CHECK-NEXT:        omp.loop_nest (%[[VAL_33:.*]], %[[VAL_34:.*]], %[[VAL_35:.*]]) : i32 = (%[[VAL_24]], %[[VAL_27]], %[[VAL_30]]) to (%[[VAL_25]], %[[VAL_28]], %[[VAL_31]]) inclusive step (%[[VAL_26]], %[[VAL_29]], %[[VAL_32]]) collapse(3) {
   !$omp do collapse(3)
   do i = 1, a
      do j= 1, b

diff  --git a/flang/test/Lower/OpenMP/wsloop-variable.f90 b/flang/test/Lower/OpenMP/wsloop-variable.f90
index f998c84331ce4..0f4aafb10ded3 100644
--- a/flang/test/Lower/OpenMP/wsloop-variable.f90
+++ b/flang/test/Lower/OpenMP/wsloop-variable.f90
@@ -22,7 +22,7 @@ program wsloop_variable
 !CHECK:      %[[TMP6:.*]] = fir.convert %[[TMP1]] : (i32) -> i64
 !CHECK:      %[[TMP7:.*]] = fir.convert %{{.*}} : (i32) -> i64
 !CHECK:      omp.wsloop private({{.*}}) {
-!CHECK-NEXT:   omp.loop_nest (%[[ARG0:.*]], %[[ARG1:.*]]) : i64 = (%[[TMP2]], %[[TMP5]]) to (%[[TMP3]], %[[TMP6]]) inclusive step (%[[TMP4]], %[[TMP7]]) {
+!CHECK-NEXT:   omp.loop_nest (%[[ARG0:.*]], %[[ARG1:.*]]) : i64 = (%[[TMP2]], %[[TMP5]]) to (%[[TMP3]], %[[TMP6]]) inclusive step (%[[TMP4]], %[[TMP7]]) collapse(2) {
 !CHECK:          %[[ARG0_I16:.*]] = fir.convert %[[ARG0]] : (i64) -> i16
 !CHECK:          hlfir.assign %[[ARG0_I16]] to %[[STORE_IV0:.*]]#0 : i16, !fir.ref<i16>
 !CHECK:          hlfir.assign %[[ARG1]] to %[[STORE_IV1:.*]]#0 : i64, !fir.ref<i64>

diff  --git a/flang/test/Parser/OpenMP/do-tile-size.f90 b/flang/test/Parser/OpenMP/do-tile-size.f90
new file mode 100644
index 0000000000000..886ee4a2a680c
--- /dev/null
+++ b/flang/test/Parser/OpenMP/do-tile-size.f90
@@ -0,0 +1,29 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp -fopenmp-version=51 %s | FileCheck --ignore-case %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp -fopenmp-version=51 %s | FileCheck --check-prefix="PARSE-TREE" %s
+
+subroutine openmp_do_tiles(x)
+
+  integer, intent(inout)::x
+
+
+!CHECK: !$omp do
+!CHECK: !$omp tile sizes
+!$omp do
+!$omp  tile sizes(2)
+!CHECK: do
+  do x = 1, 100
+     call F1()
+!CHECK: end do
+  end do
+!CHECK: !$omp end tile
+!$omp end tile
+!$omp end do
+
+!PARSE-TREE:| | ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!PARSE-TREE:| | | OmpBeginLoopDirective
+!PARSE-TREE:| | | OpenMPLoopConstruct
+!PARSE-TREE:| | | | OmpBeginLoopDirective
+!PARSE-TREE:| | | | | OmpLoopDirective -> llvm::omp::Directive = tile
+!PARSE-TREE:| | | | | OmpClauseList -> OmpClause -> Sizes -> Scalar -> Integer -> Expr = '2_4'
+!PARSE-TREE: | | | | DoConstruct
+END subroutine openmp_do_tiles

diff  --git a/flang/test/Semantics/OpenMP/do-collapse.f90 b/flang/test/Semantics/OpenMP/do-collapse.f90
index 480bd45b79b83..ec6a3bdad3686 100644
--- a/flang/test/Semantics/OpenMP/do-collapse.f90
+++ b/flang/test/Semantics/OpenMP/do-collapse.f90
@@ -31,6 +31,7 @@ program omp_doCollapse
       end do
     end do
 
+  !ERROR: The value of the parameter in the COLLAPSE or ORDERED clause must not be larger than the number of nested loops following the construct.
   !ERROR: At most one COLLAPSE clause can appear on the SIMD directive
   !$omp simd collapse(2) collapse(1)
   do i = 1, 4

diff  --git a/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90 b/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90
index bb1929249183b..355626f6e73b9 100644
--- a/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90
+++ b/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90
@@ -1,6 +1,7 @@
 !RUN: %python %S/../test_errors.py %s %flang -fopenmp
 
 integer :: i, j
+! ERROR: DO CONCURRENT loops cannot be used with the COLLAPSE clause.
 !$omp parallel do collapse(2)
 do i = 1, 1
   ! ERROR: DO CONCURRENT loops cannot form part of a loop nest.

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index faf820dcfdb29..6a92b136ef51c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -40,7 +40,7 @@ struct DeviceTypeClauseOps {
 /// Clauses that correspond to operations other than omp.target, but might have
 /// to be evaluated outside of a parent target region.
 using HostEvaluatedOperands =
-    detail::Clauses<LoopRelatedClauseOps, NumTeamsClauseOps,
+    detail::Clauses<CollapseClauseOps, LoopRelatedClauseOps, NumTeamsClauseOps,
                     NumThreadsClauseOps, ThreadLimitClauseOps>;
 
 // TODO: Add `indirect` clause.

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 311c57fb4446c..5f40abe62a0f6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -209,6 +209,23 @@ class OpenMP_BindClauseSkip<
 
 def OpenMP_BindClause : OpenMP_BindClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [4.4.3] `collapse` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_CollapseClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+      ConfinedAttr<DefaultValuedOptionalAttr<I64Attr, "1">, [IntMinValue<1>]>
+      :$collapse_num_loops
+  );
+}
+
+def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [5.7.2] `copyprivate` clause
 //===----------------------------------------------------------------------===//
@@ -1385,6 +1402,22 @@ class OpenMP_ThreadLimitClauseSkip<
 
 def OpenMP_ThreadLimitClause : OpenMP_ThreadLimitClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [9.1.1] `sizes` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_TileSizesClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+      OptionalAttr<DenseI64ArrayAttr>:$tile_sizes
+  );
+}
+
+def OpenMP_TileSizesClause : OpenMP_TileSizesClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [12.1] `untied` clause
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 2548a8ab4aac6..830b36f440098 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -614,13 +614,18 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
 def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
     RecursiveMemoryEffects, SameVariadicOperandSize
   ], clauses = [
-    OpenMP_LoopRelatedClause
+    OpenMP_CollapseClause,
+    OpenMP_LoopRelatedClause,
+    OpenMP_TileSizesClause
   ], singleRegion = true> {
   let summary = "rectangular loop nest";
   let description = [{
-    This operation represents a collapsed rectangular loop nest. For each
-    rectangular loop of the nest represented by an instance of this operation,
-    lower and upper bounds, as well as a step variable, must be defined.
+    This operation represents a rectangular loop nest which may be collapsed
+    and/or tiled. For each rectangular loop of the nest represented by an
+    instance of this operation, lower and upper bounds, as well as a step
+    variable, must be defined. The collapse clause specifies how many loops
+    should be collapsed (1 if no collapse is done) after any tiling is
+    performed. The tile sizes are given by the tile sizes clause.
 
     The lower and upper bounds specify a half-open range: the range includes the
     lower bound but does not include the upper bound. If the `loop_inclusive`
@@ -633,7 +638,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
     `loop_steps` arguments.
 
     ```mlir
-    omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+    omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) collapse(2) tiles(5,5) {
       %a = load %arrA[%i1, %i2] : memref<?x?xf32>
       %b = load %arrB[%i1, %i2] : memref<?x?xf32>
       %sum = arith.addf %a, %b : f32

diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index c4a9fc2e556f1..460595ba9f254 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -492,8 +492,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
         // Create loop nest and populate region with contents of scf.parallel.
         auto loopOp = omp::LoopNestOp::create(
-            rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(),
-            parallelOp.getUpperBound(), parallelOp.getStep());
+            rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
+            parallelOp.getLowerBound(), parallelOp.getUpperBound(),
+            parallelOp.getStep(), /*loop_inclusive=*/false,
+            /*tile_sizes=*/nullptr);
 
         rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
                                     loopOp.getRegion().begin());

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6e43f28e8d93d..aa88b9e8eef5a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -56,6 +56,11 @@ makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
   return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
 }
 
+static DenseI64ArrayAttr
+makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef<int64_t> intArray) {
+  return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
+}
+
 namespace {
 struct MemRefPointerLikeModel
     : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -2956,10 +2961,10 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &iv : ivs)
     iv.type = loopVarType;
 
+  auto *ctx = parser.getBuilder().getContext();
   // Parse "inclusive" flag.
   if (succeeded(parser.parseOptionalKeyword("inclusive")))
-    result.addAttribute("loop_inclusive",
-                        UnitAttr::get(parser.getBuilder().getContext()));
+    result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
 
   // Parse step values.
   SmallVector<OpAsmParser::UnresolvedOperand> steps;
@@ -2967,6 +2972,35 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
     return failure();
 
+  // Parse collapse
+  int64_t value = 0;
+  if (!parser.parseOptionalKeyword("collapse") &&
+      (parser.parseLParen() || parser.parseInteger(value) ||
+       parser.parseRParen()))
+    return failure();
+  if (value > 1)
+    result.addAttribute(
+        "collapse_num_loops",
+        IntegerAttr::get(parser.getBuilder().getI64Type(), value));
+
+  // Parse tiles
+  SmallVector<int64_t> tiles;
+  auto parseTiles = [&]() -> ParseResult {
+    int64_t tile;
+    if (parser.parseInteger(tile))
+      return failure();
+    tiles.push_back(tile);
+    return success();
+  };
+
+  if (!parser.parseOptionalKeyword("tiles") &&
+      (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
+       parser.parseRParen()))
+    return failure();
+
+  if (tiles.size() > 0)
+    result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
+
   // Parse the body.
   Region *region = result.addRegion();
   if (parser.parseRegion(*region, ivs))
@@ -2990,14 +3024,23 @@ void LoopNestOp::print(OpAsmPrinter &p) {
   if (getLoopInclusive())
     p << "inclusive ";
   p << "step (" << getLoopSteps() << ") ";
+  if (int64_t numCollapse = getCollapseNumLoops())
+    if (numCollapse > 1)
+      p << "collapse(" << numCollapse << ") ";
+
+  if (const auto tiles = getTileSizes())
+    p << "tiles(" << tiles.value() << ") ";
+
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
 void LoopNestOp::build(OpBuilder &builder, OperationState &state,
                        const LoopNestOperands &clauses) {
-  LoopNestOp::build(builder, state, clauses.loopLowerBounds,
-                    clauses.loopUpperBounds, clauses.loopSteps,
-                    clauses.loopInclusive);
+  MLIRContext *ctx = builder.getContext();
+  LoopNestOp::build(builder, state, clauses.collapseNumLoops,
+                    clauses.loopLowerBounds, clauses.loopUpperBounds,
+                    clauses.loopSteps, clauses.loopInclusive,
+                    makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
 }
 
 LogicalResult LoopNestOp::verify() {
@@ -3013,6 +3056,17 @@ LogicalResult LoopNestOp::verify() {
              << "range argument type does not match corresponding IV type";
   }
 
+  uint64_t numIVs = getIVs().size();
+
+  if (const auto &numCollapse = getCollapseNumLoops())
+    if (numCollapse > numIVs)
+      return emitOpError()
+             << "collapse value is larger than the number of loops";
+
+  if (const auto &tiles = getTileSizes())
+    if (tiles.value().size() > numIVs)
+      return emitOpError() << "too few canonical loops for tile dimensions";
+
   if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
     return emitOpError() << "expects parent op to be a loop wrapper";
 

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4e26e65cf9718..2ab6bb0a73200 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3041,16 +3041,46 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
     loopInfos.push_back(*loopResult);
   }
 
-  // Collapse loops. Store the insertion point because LoopInfos may get
-  // invalidated.
   llvm::OpenMPIRBuilder::InsertPointTy afterIP =
       loopInfos.front()->getAfterIP();
 
-  // Update the stack frame created for this loop to point to the resulting loop
-  // after applying transformations.
+  // Do tiling.
+  if (const auto &tiles = loopOp.getTileSizes()) {
+    llvm::Type *ivType = loopInfos.front()->getIndVarType();
+    SmallVector<llvm::Value *> tileSizes;
+
+    for (auto tile : tiles.value()) {
+      llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
+      tileSizes.push_back(tileVal);
+    }
+
+    std::vector<llvm::CanonicalLoopInfo *> newLoops =
+        ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
+
+    // Update afterIP to get the correct insertion point after
+    // tiling.
+    llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
+    llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
+    afterIP = {afterAfterBB, afterAfterBB->begin()};
+
+    // Update the loop infos.
+    loopInfos.clear();
+    for (const auto &newLoop : newLoops)
+      loopInfos.push_back(newLoop);
+  } // Tiling done.
+
+  // Do collapse.
+  const auto &numCollapse = loopOp.getCollapseNumLoops();
+  SmallVector<llvm::CanonicalLoopInfo *> collapseLoopInfos(
+      loopInfos.begin(), loopInfos.begin() + (numCollapse));
+
+  auto newTopLoopInfo =
+      ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
+
+  assert(newTopLoopInfo && "New top loop information is missing");
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
+        frame.loopInfo = newTopLoopInfo;
         return WalkResult::interrupt();
       });
 

diff  --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index a722acbf2c347..d362bb6092419 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -6,7 +6,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
   // CHECK: omp.wsloop {
-  // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+  // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) collapse(2) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
     // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 986c3844d0bb9..763f41c5420b8 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -157,6 +157,29 @@ func.func @no_loops(%lb : index, %ub : index, %step : index) {
   }
 }
 
+// -----
+
+func.func @collapse_size(%lb : index, %ub : index, %step : index) {
+  omp.wsloop {
+    // expected-error@+1 {{collapse value is larger than the number of loops}}
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) collapse(4) {
+      omp.yield
+    }
+  }
+}
+
+// -----
+
+func.func @tiles_length(%lb : index, %ub : index, %step : index) {
+  omp.wsloop {
+    // expected-error@+1 {{op too few canonical loops for tile dimensions}}
+    omp.loop_nest (%iv) : index =  (%lb) to (%ub) step (%step) tiles(2, 4) {
+      omp.yield
+    }
+  }
+}
+
+
 // -----
 
 func.func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3c2e0a3b7cc15..60b1f61135ac2 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -376,6 +376,60 @@ func.func @omp_loop_nest_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32,
   return
 }
 
+// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse
+func.func @omp_loop_nest_pretty_multiple_collapse(%lb1 : i32, %ub1 : i32, %step1 : i32,
+    %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () {
+
+  omp.wsloop {
+    // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2)
+    omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
+      %1 = "test.payload"(%iv1) : (i32) -> (index)
+      %2 = "test.payload"(%iv2) : (i32) -> (index)
+      memref.store %iv1, %data1[%1] : memref<?xi32>
+      memref.store %iv2, %data1[%2] : memref<?xi32>
+      omp.yield
+    }
+  }
+
+  return
+}
+
+// CHECK-LABEL: omp_loop_nest_pretty_multiple_tiles
+func.func @omp_loop_nest_pretty_multiple_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32,
+    %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () {
+
+  omp.wsloop {
+    // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) tiles(5, 10)
+    omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) tiles(5, 10) {
+      %1 = "test.payload"(%iv1) : (i32) -> (index)
+      %2 = "test.payload"(%iv2) : (i32) -> (index)
+      memref.store %iv1, %data1[%1] : memref<?xi32>
+      memref.store %iv2, %data1[%2] : memref<?xi32>
+      omp.yield
+    }
+  }
+
+  return
+}
+
+// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse_tiles
+func.func @omp_loop_nest_pretty_multiple_collapse_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32,
+    %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () {
+
+  omp.wsloop {
+    // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2) tiles(5, 10)
+    omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) tiles(5, 10) {
+      %1 = "test.payload"(%iv1) : (i32) -> (index)
+      %2 = "test.payload"(%iv2) : (i32) -> (index)
+      memref.store %iv1, %data1[%1] : memref<?xi32>
+      memref.store %iv2, %data1[%2] : memref<?xi32>
+      omp.yield
+    }
+  }
+
+  return
+}
+
 // CHECK-LABEL: omp_wsloop
 func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref<i32>, %linear_var : i32, %chunk_var : i32) -> () {
 

diff  --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
index b42e387acbb11..d84641ff9c99b 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
@@ -9,7 +9,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
     %loop_lb = llvm.mlir.constant(0 : i32) : i32
     %loop_step = llvm.mlir.constant(1 : index) : i32
     omp.wsloop {
-      omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
+      omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) collapse(2) {
         %1 = llvm.add %arg1, %arg2  : i32
         %2 = llvm.mul %arg2, %loop_ub overflow<nsw>  : i32
         %3 = llvm.add %arg1, %2 :i32

diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 3f4dcd5e24c56..27210bc0890ce 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -698,7 +698,7 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) {
 // CHECK-LABEL: @simd_simple_multiple
 llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   omp.simd {
-    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) {
+    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) collapse(2) {
       %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
       // The form of the emitted IR is controlled by OpenMPIRBuilder and
       // tested there. Just check that the right metadata is added and collapsed
@@ -736,7 +736,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64
 // CHECK-LABEL: @simd_simple_multiple_simdlen
 llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   omp.simd simdlen(2) {
-    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
       %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
       // The form of the emitted IR is controlled by OpenMPIRBuilder and
       // tested there. Just check that the right metadata is added.
@@ -760,7 +760,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
 // CHECK-LABEL: @simd_simple_multiple_safelen
 llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   omp.simd safelen(2) {
-    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
       %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
       %4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
       %5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -779,7 +779,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
 // CHECK-LABEL: @simd_simple_multiple_simdlen_safelen
 llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   omp.simd simdlen(1) safelen(2) {
-    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
       %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
       %4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
       %5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -1177,7 +1177,7 @@ llvm.func @collapse_wsloop(
     // CHECK: store i32 %[[TOTAL_SUB_1]], ptr
     // CHECK: call void @__kmpc_for_static_init_4u
     omp.wsloop {
-      omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) {
+      omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
         %31 = llvm.load %20 : !llvm.ptr -> i32
         %32 = llvm.add %31, %arg0 : i32
         %33 = llvm.add %32, %arg1 : i32
@@ -1239,7 +1239,7 @@ llvm.func @collapse_wsloop_dynamic(
     // CHECK: store i32 %[[TOTAL]], ptr
     // CHECK: call void @__kmpc_dispatch_init_4u
     omp.wsloop schedule(dynamic) {
-      omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) {
+      omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
         %31 = llvm.load %20 : !llvm.ptr -> i32
         %32 = llvm.add %31, %arg0 : i32
         %33 = llvm.add %32, %arg1 : i32


        


More information about the Mlir-commits mailing list