[flang-commits] [flang] [flang][OpenMP] Implement GetGeneratedNestDepthWithReason (PR #191718)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Sun Apr 12 08:15:04 PDT 2026


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/191718

For a loop-nest-generating construct this function returns the number of loops in the generated loop nest.

A loop-nest-transformation construct can be thought of as replacing N nested loops with K nested loops, where
  N = GetAffectedNestDepthWithReason
  K = GetGeneratedNestDepthWithReason

>From 124455574867b7c26c13c8c60abf5af62482a2b3 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Sun, 12 Apr 2026 09:02:28 -0500
Subject: [PATCH] [flang][OpenMP] Implement GetGeneratedNestDepthWithReason

For a loop-nest-generating construct this function returns the number of
loops in the generated loop nest.

A loop-nest-transformation construct can be thought of as replacing N
nested loops with K nested loops, where
  N = GetAffectedNestDepthWithReason
  K = GetGeneratedNestDepthWithReason
---
 flang/include/flang/Semantics/openmp-utils.h |  5 ++
 flang/lib/Semantics/openmp-utils.cpp         | 91 ++++++++++++++------
 2 files changed, 72 insertions(+), 24 deletions(-)

diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index b449988abe9c8..252c88c967750 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -218,6 +218,11 @@ WithReason<int64_t> GetHeightWithReason(
 std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
     const parser::OmpDirectiveSpecification &spec, unsigned version,
     SemanticsContext *semaCtx = nullptr);
+/// Return the depth of the generated nest(s):
+///   {generated-depth, is-perfect-nest}
+std::pair<WithReason<int64_t>, bool> GetGeneratedNestDepthWithReason(
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx = nullptr);
 /// Return the range of the affected nests in the sequence:
 ///   {first, count}.
 /// If the range is "the whole sequence", the return value will be {1, -1}.
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index c901d00695093..7a5b31b1c805e 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -767,37 +767,27 @@ WithReason<int64_t> GetHeightWithReason(
         "This construct is not a DO-loop or a loop-transformation construct"_because_en_US);
     return {0, reason};
   }
-  switch (spec.DirName().v) {
-  case llvm::omp::Directive::OMPD_flatten:
-    if (auto &&value{GetArgumentValueWithReason(
-            spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)}) {
-      // FLATTEN DEPTH(n) replaces n loops with 1.
-      return {int64_t(1) - *value.value, std::move(value.reason)};
-    } else {
-      Reason reason;
-      reason.Say(spec.DirName().source, MsgClauseAbsentAssume,
-          GetUpperName(llvm::omp::Clause::OMPC_depth, version), "a depth of 2");
-      return {-1, std::move(reason)};
-    }
+
+  switch (spec.DirId()) {
+  // These generate loop sequences.
   case llvm::omp::Directive::OMPD_fuse:
   case llvm::omp::Directive::OMPD_split:
     return {0, Reason()};
+  case llvm::omp::Directive::OMPD_flatten:
   case llvm::omp::Directive::OMPD_interchange:
   case llvm::omp::Directive::OMPD_nothing:
   case llvm::omp::Directive::OMPD_reverse:
-    return {0, Reason()};
   case llvm::omp::Directive::OMPD_stripe:
   case llvm::omp::Directive::OMPD_tile:
-    return GetNumArgumentsWithReason(
-        spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx);
-  case llvm::omp::Directive::OMPD_unroll:
-    if (isFullUnroll) {
-      Reason reason;
-      reason.Say(spec.DirName().source, MsgConstructDoesNotResult,
-          "Fully unrolled loop", "a loop nest");
-      return {-1, std::move(reason)};
+  case llvm::omp::Directive::OMPD_unroll: {
+    auto [cons, _1]{GetAffectedNestDepthWithReason(spec, version, semaCtx)};
+    auto [prod, _2]{GetGeneratedNestDepthWithReason(spec, version, semaCtx)};
+    if (cons && prod) {
+      return WithReason<int64_t>{*prod.value - *cons.value,
+          Reason().Append(cons.reason).Append(prod.reason)};
     }
-    return {0, Reason()};
+    return {};
+  }
   default:
     llvm_unreachable("Expecting loop-transforming construct");
   }
@@ -1000,6 +990,18 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
 
   if (IsLoopTransforming(dir)) {
     switch (dir) {
+    case llvm::omp::Directive::OMPD_flatten:
+      if (auto &&value{GetArgumentValueWithReason(
+              spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)}) {
+        // FLATTEN DEPTH(n) replaces n loops with 1.
+        return {std::move(value), true};
+      } else {
+        Reason reason;
+        reason.Say(spec.DirName().source, MsgClauseAbsentAssume,
+            GetUpperName(llvm::omp::Clause::OMPC_depth, version), "a depth of 2");
+        return {{2, std::move(reason)}, true};
+      }
+      break;
     case llvm::omp::Directive::OMPD_interchange: {
       // Get the length of the argument list to PERMUTATION.
       if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_permutation)) {
@@ -1015,6 +1017,8 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
           spec.source, MsgClauseAbsentAssume, name, "a permutation (2, 1)");
       return {{2, std::move(reason)}, true};
     }
+    case llvm::omp::Directive::OMPD_nothing:
+      return {WithReason<int64_t>(0), false};
     case llvm::omp::Directive::OMPD_stripe:
     case llvm::omp::Directive::OMPD_tile: {
       // Get the length of the argument list to SIZES.
@@ -1035,10 +1039,9 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
       return {{1, std::move(reason)}, true};
     }
     case llvm::omp::Directive::OMPD_reverse:
+    case llvm::omp::Directive::OMPD_split:
     case llvm::omp::Directive::OMPD_unroll:
       return {WithReason<int64_t>(1), false};
-    // TODO: case llvm::omp::Directive::OMPD_flatten:
-    // TODO: case llvm::omp::Directive::OMPD_split:
     default:
       break;
     }
@@ -1047,6 +1050,46 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
   return {{}, false};
 }
 
+/// Return the depth of the generated nest(s):
+///   {generated-depth, is-perfect-nest}
+std::pair<WithReason<int64_t>, bool> GetGeneratedNestDepthWithReason(
+    const parser::OmpDirectiveSpecification &spec, unsigned version,
+    SemanticsContext *semaCtx) {
+  llvm::omp::Directive dir{spec.DirId()};
+  if (!IsLoopTransforming(dir)) {
+    return {{}, false};
+  }
+  // Only evaluate the affected depth (and its clause arguments) when needed.
+  auto affected{[&]() {
+    return GetAffectedNestDepthWithReason(spec, version, semaCtx).first;
+  }};
+  switch (dir) {
+  case llvm::omp::Directive::OMPD_flatten: // Produces a single loop.
+    return {WithReason<int64_t>(1), true};
+  case llvm::omp::Directive::OMPD_fuse:
+  case llvm::omp::Directive::OMPD_split:
+    return {{}, false}; // These result in loop sequences.
+  case llvm::omp::Directive::OMPD_interchange:
+  case llvm::omp::Directive::OMPD_nothing:
+  case llvm::omp::Directive::OMPD_reverse:
+    return {affected(), true}; // Same depth as the affected nest.
+  case llvm::omp::Directive::OMPD_stripe:
+  case llvm::omp::Directive::OMPD_tile: // Two loops per affected loop.
+    if (auto depth{affected()}) {
+      return {
+          WithReason<int64_t>(2 * *depth.value, std::move(depth.reason)), true};
+    }
+    return {{}, true};
+  case llvm::omp::Directive::OMPD_unroll: // Full: no loop; partial: one.
+    if (IsFullUnroll(spec)) {
+      return {WithReason<int64_t>(0), false};
+    }
+    return {WithReason<int64_t>(1), true};
+  default:
+    return {{}, false};
+  }
+}
+
 /// Return the range of the affected nests in the sequence:
 ///   {first, count}
 WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(



More information about the flang-commits mailing list