[flang-commits] [flang] [flang][OpenMP] Map `teams loop` to `teams distribute` when required. (PR #127489)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Mon Feb 17 21:05:10 PST 2025


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/127489

>From 1011020cc1e6dcd059baaa3eac56d93c78278a0b Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 17 Feb 2025 04:10:59 -0600
Subject: [PATCH] [flang][OpenMP] Map `teams loop` to `teams distribute` when
 required.

This extends support for generic `loop` rewriting by:
1. Preventing nesting multiple worksharing loops inside each other. This
   is checked by walking the `teams loop` region searching for any
   `loop` directive whose `bind` modifier is `parallel`.
2. Preventing conversion to a worksharing loop if calls to unknown
   functions are found in the `loop` directive's body.

We walk the `teams loop` body to identify either of the above 2
conditions. If either of them is found to be true, we map the `loop`
directive to `distribute`.
---
 .../OpenMP/GenericLoopConversion.cpp          | 57 +++++++++++++++++-
 flang/test/Lower/OpenMP/loop-directive.f90    | 59 +++++++++++++++++++
 2 files changed, 113 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
index d2581e3ad0a0a..15e5b0f2f019b 100644
--- a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
@@ -56,7 +56,10 @@ class GenericLoopConversionPattern
           "not yet implemented: Combined `parallel loop` directive");
       break;
     case GenericLoopCombinedInfo::TeamsLoop:
-      rewriteToDistributeParallelDo(loopOp, rewriter);
+      if (teamsLoopCanBeParallelFor(loopOp))
+        rewriteToDistributeParallelDo(loopOp, rewriter);
+      else
+        rewriteToDistrbute(loopOp, rewriter);
       break;
     }
 
@@ -97,8 +100,6 @@ class GenericLoopConversionPattern
     if (!loopOp.getReductionVars().empty())
       return todo("reduction");
 
-    // TODO For `teams loop`, check similar constrains to what is checked
-    // by `TeamsLoopChecker` in SemaOpenMP.cpp.
     return mlir::success();
   }
 
@@ -118,6 +119,56 @@ class GenericLoopConversionPattern
     return result;
   }
 
+  /// Checks whether a `teams loop` construct can be rewritten to `teams
+  /// distribute parallel do` or it has to be converted to `teams distribute`.
+  /// Returns false if the region nests a standalone `loop bind(parallel)` or
+  /// contains a call whose callee is not an `omp_`-prefixed symbol.
+  /// Mirrors the checks done by `TeamsLoopChecker` in clang's SemaOpenMP.cpp.
+  static bool teamsLoopCanBeParallelFor(mlir::omp::LoopOp loopOp) {
+    bool canBeParallelFor = true;
+    loopOp.walk([&](mlir::omp::LoopOp nestedLoopOp) {
+      // Skip `loopOp` itself; only loops nested inside it are of interest.
+      if (nestedLoopOp == loopOp)
+        return mlir::WalkResult::advance();
+      GenericLoopCombinedInfo combinedInfo =
+          findGenericLoopCombineInfo(nestedLoopOp);
+
+      // Worksharing loops cannot be nested inside each other. Therefore, if the
+      // current `loop` directive nests another `loop` whose `bind` modifier is
+      // `parallel`, this `loop` directive cannot be mapped to `distribute
+      // parallel for` but rather only to `distribute`.
+      if (combinedInfo == GenericLoopCombinedInfo::Standalone &&
+          nestedLoopOp.getBindKind() &&
+          *nestedLoopOp.getBindKind() == mlir::omp::ClauseBindKind::Parallel)
+        canBeParallelFor = false;
+
+      // TODO check for combined `parallel loop` when we support it.
+
+      return canBeParallelFor ? mlir::WalkResult::advance()
+                              : mlir::WalkResult::interrupt();
+    });
+
+    loopOp.walk([&](mlir::CallOpInterface callOp) {
+      // Calls to non-OpenMP API runtime functions inhibit transformation to
+      // `teams distribute parallel do` since the called functions might have
+      // nested parallelism themselves. Indirect calls are treated the same.
+      bool isOpenMPAPI = false;
+      mlir::CallInterfaceCallable callable = callOp.getCallableForCallee();
+
+      if (auto callableSymRef = mlir::dyn_cast<mlir::SymbolRefAttr>(callable))
+        isOpenMPAPI =
+            callableSymRef.getRootReference().strref().starts_with("omp_");
+
+      if (!isOpenMPAPI)
+        canBeParallelFor = false;
+
+      return canBeParallelFor ? mlir::WalkResult::advance()
+                              : mlir::WalkResult::interrupt();
+    });
+
+    return canBeParallelFor;
+  }
+
   void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp,
                              mlir::ConversionPatternRewriter &rewriter) const {
     using namespace mlir::omp;
diff --git a/flang/test/Lower/OpenMP/loop-directive.f90 b/flang/test/Lower/OpenMP/loop-directive.f90
index 785f732e1b4f5..4cc136fac198a 100644
--- a/flang/test/Lower/OpenMP/loop-directive.f90
+++ b/flang/test/Lower/OpenMP/loop-directive.f90
@@ -179,3 +179,62 @@ subroutine test_standalone_bind_parallel
     c(i) = a(i) * b(i)
   end do
 end subroutine
+
+! CHECK-LABEL: func.func @_QPteams_loop_cannot_be_parallel_for
+subroutine teams_loop_cannot_be_parallel_for
+  implicit none
+  integer :: iter, iter2, val(20)
+  val = 0
+  ! CHECK: omp.teams {
+  ! A nested `loop bind(parallel)` must keep the outer `teams loop` free of a worksharing loop.
+  ! Verify the outer `loop` directive was mapped to only `distribute`.
+  ! CHECK-NOT: omp.parallel {{.*}}
+  ! CHECK:     omp.distribute {{.*}} {
+  ! CHECK-NOT:   omp.wsloop
+  ! CHECK:       omp.loop_nest {{.*}} {
+
+  ! Verify the inner `loop` directive was mapped to a worksharing loop.
+  ! CHECK:         omp.wsloop {{.*}} {
+  ! CHECK:           omp.loop_nest {{.*}} {
+  ! CHECK:           }
+  ! CHECK:         }
+
+  ! CHECK:       }
+  ! CHECK:     }
+
+  ! CHECK: }
+  !$omp target teams loop map(tofrom:val)
+  DO iter = 1, 5
+    !$omp loop bind(parallel)
+    DO iter2 = 1, 5
+      val(iter+iter2) = iter+iter2
+    END DO
+  END DO
+end subroutine
+! Empty helper subroutine; it is called from inside a `teams loop` body in the test that follows.
+subroutine foo()
+end subroutine
+
+! CHECK-LABEL: func.func @_QPteams_loop_cannot_be_parallel_for_2
+subroutine teams_loop_cannot_be_parallel_for_2
+  implicit none
+  integer :: iter, iter2, val(20)
+  val = 0
+  ! The loop body calls `foo`, which is not an `omp_`-prefixed OpenMP API routine.
+  ! CHECK: omp.teams {
+
+  ! Verify the `loop` directive was mapped to only `distribute`.
+  ! CHECK-NOT: omp.parallel {{.*}}
+  ! CHECK:     omp.distribute {{.*}} {
+  ! CHECK-NOT:   omp.wsloop
+  ! CHECK:       omp.loop_nest {{.*}} {
+  ! CHECK:         fir.call @_QPfoo
+  ! CHECK:       }
+  ! CHECK:     }
+
+  ! CHECK: }
+  !$omp target teams loop map(tofrom:val)
+  DO iter = 1, 5
+    call foo()
+  END DO
+end subroutine



More information about the flang-commits mailing list