[flang] [llvm] [mlir] [WIP] Implement workdistribute construct (PR #140523)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 19 07:19:07 PDT 2025


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/140523

>From e0dff6afb7aa31330aa0516effb7a0f65df5315f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov2 at llnl.gov>
Date: Mon, 4 Dec 2023 12:57:36 -0800
Subject: [PATCH 01/11] Add coexecute directives

---
 llvm/include/llvm/Frontend/OpenMP/OMP.td | 45 ++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 0af4b436649a3..752486a8105b6 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -682,6 +682,8 @@ def OMP_CancellationPoint : Directive<"cancellation point"> {
   let association = AS_None;
   let category = CA_Executable;
 }
+def OMP_Coexecute : Directive<"coexecute"> {}
+def OMP_EndCoexecute : Directive<"end coexecute"> {}
 def OMP_Critical : Directive<"critical"> {
   let allowedOnceClauses = [
     VersionedClause<OMPC_Hint>,
@@ -2198,6 +2200,33 @@ def OMP_TargetTeams : Directive<"target teams"> {
   let leafConstructs = [OMP_Target, OMP_Teams];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+  ];
+
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMX_Bare>,
+  ];
+}
 def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2484,6 +2513,22 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> {
   let leafConstructs = [OMP_TaskLoop, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>
+  ];
+}
 def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,

>From 8b1b36f5e716b8186d98b0d5c47c0fdf649ae67b Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 May 2025 11:01:45 +0530
Subject: [PATCH 02/11] [OpenMP] Fix Coexecute definitions

---
 llvm/include/llvm/Frontend/OpenMP/OMP.td | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 752486a8105b6..7f450b43c2e36 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -682,8 +682,15 @@ def OMP_CancellationPoint : Directive<"cancellation point"> {
   let association = AS_None;
   let category = CA_Executable;
 }
-def OMP_Coexecute : Directive<"coexecute"> {}
-def OMP_EndCoexecute : Directive<"end coexecute"> {}
+def OMP_Coexecute : Directive<"coexecute"> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndCoexecute : Directive<"end coexecute"> {
+  let leafConstructs = OMP_Coexecute.leafConstructs;
+  let association = OMP_Coexecute.association;
+  let category = OMP_Coexecute.category;
+}
 def OMP_Critical : Directive<"critical"> {
   let allowedOnceClauses = [
     VersionedClause<OMPC_Hint>,
@@ -2224,8 +2231,10 @@ def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
     VersionedClause<OMPC_NumTeams>,
     VersionedClause<OMPC_ThreadLimit>,
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
-    VersionedClause<OMPC_OMX_Bare>,
+    VersionedClause<OMPC_OMPX_Bare>,
   ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute];
+  let category = CA_Executable;
 }
 def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
   let allowedClauses = [
@@ -2528,6 +2537,8 @@ def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
     VersionedClause<OMPC_NumTeams>,
     VersionedClause<OMPC_ThreadLimit>
   ];
+  let leafConstructs = [OMP_Target, OMP_Teams];
+  let category = CA_Executable;
 }
 def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedClauses = [

>From 9b8d66a45e602375ec779e6c5bdd43232644f9a2 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov2 at llnl.gov>
Date: Mon, 4 Dec 2023 12:58:10 -0800
Subject: [PATCH 03/11] Add omp.coexecute op

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a79fbf77a268..8061aa0209cc9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -325,6 +325,41 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Coexecute Construct
+//===----------------------------------------------------------------------===//
+
+def CoexecuteOp : OpenMP_Op<"coexecute"> {
+  let summary = "coexecute directive";
+  let description = [{
+    The coexecute construct specifies that the teams created by the enclosing
+    teams construct shall cooperate to execute the computation in this region.
+    There is no implicit barrier at the end, as specified in the standard.
+
+    TODO
+    We should probably change the default behaviour to have a barrier unless
+    nowait is specified; see the snippet below.
+
+    ```
+    !$omp target teams
+        !$omp coexecute
+                tmp = matmul(x, y)
+        !$omp end coexecute
+        a = tmp(0, 0) ! there is no implicit barrier! the matmul may not have completed!
+    !$omp end target teams
+    ```
+
+  }];
+
+  let arguments = (ins UnitAttr:$nowait);
+
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    oilist(`nowait` $nowait) $region attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // 2.8.2 Single Construct
 //===----------------------------------------------------------------------===//

>From 7ecec06e00230649446c77c970160d4814a90e07 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov2 at llnl.gov>
Date: Mon, 4 Dec 2023 17:50:41 -0800
Subject: [PATCH 04/11] Initial frontend support for coexecute

---
 .../include/flang/Semantics/openmp-directive-sets.h | 13 +++++++++++++
 flang/lib/Lower/OpenMP/OpenMP.cpp                   | 12 ++++++++++++
 flang/lib/Parser/openmp-parsers.cpp                 |  5 ++++-
 flang/lib/Semantics/resolve-directives.cpp          |  6 ++++++
 4 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index dd610c9702c28..5c316e030c63f 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
+    Directive::OMPD_target_teams_coexecute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -187,9 +188,16 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
+        Directive::OMPD_target_teams_coexecute,
     } | topTeamsSet,
 };
 
+static const OmpDirectiveSet allCoexecuteSet{
+    Directive::OMPD_coexecute,
+    Directive::OMPD_teams_coexecute,
+    Directive::OMPD_target_teams_coexecute,
+};
+
 //===----------------------------------------------------------------------===//
 // Directive sets for groups of multiple directives
 //===----------------------------------------------------------------------===//
@@ -230,6 +238,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
+    Directive::OMPD_target_teams_coexecute,
+    Directive::OMPD_teams_coexecute,
+    Directive::OMPD_coexecute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -294,6 +305,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
+        Directive::OMPD_coexecute,
     } | allDoSet,
 };
 
@@ -376,6 +388,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
+    Directive::OMPD_coexecute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 61bbc709872fd..b0c65c8e37988 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2670,6 +2670,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::omp::CoexecuteOp
+genCoexecuteOp(Fortran::lower::AbstractConverter &converter,
+               Fortran::lower::pft::Evaluation &eval,
+               mlir::Location currentLocation,
+               const Fortran::parser::OmpClauseList &clauseList) {
+  return genOpWithBody<mlir::omp::CoexecuteOp>(
+      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation for atomic operations
 //===----------------------------------------------------------------------===//
@@ -3929,6 +3938,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
                        item);
     break;
+  case llvm::omp::Directive::OMPD_coexecute:
+    newOp = genCoexecuteOp(converter, eval, currentLocation, beginClauseList);
+    break;
   case llvm::omp::Directive::OMPD_tile:
   case llvm::omp::Directive::OMPD_unroll: {
     unsigned version = semaCtx.langOptions().OpenMPVersion;
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 52d3a5844c969..591b1642baed3 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1344,12 +1344,15 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
+        "TARGET TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_coexecute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
+        "TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_teams_coexecute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
-        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare))))
+        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
+        "COEXECUTE" >> pure(llvm::omp::Directive::OMPD_coexecute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 9fa7bc8964854..ae297f204356a 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1617,6 +1617,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_coexecute:
+  case llvm::omp::Directive::OMPD_teams_coexecute:
+  case llvm::omp::Directive::OMPD_target_teams_coexecute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
@@ -1650,6 +1653,9 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_coexecute:
+  case llvm::omp::Directive::OMPD_teams_coexecute:
+  case llvm::omp::Directive::OMPD_target_teams_coexecute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
   case llvm::omp::Directive::OMPD_target_parallel: {

>From ca0cc44c621fde89f1889fb328e66755ca3f5e3a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 May 2025 15:09:45 +0530
Subject: [PATCH 05/11] [OpenMP] Fixes for coexecute definitions

---
 .../flang/Semantics/openmp-directive-sets.h   |  1 +
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 13 ++--
 flang/test/Lower/OpenMP/coexecute.f90         | 59 +++++++++++++++++++
 llvm/include/llvm/Frontend/OpenMP/OMP.td      | 33 +++++------
 4 files changed, 83 insertions(+), 23 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/coexecute.f90

diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 5c316e030c63f..43f4e642b3d86 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -173,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_teams_coexecute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index b0c65c8e37988..80612bd05ad97 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2671,12 +2671,13 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 }
 
 static mlir::omp::CoexecuteOp
-genCoexecuteOp(Fortran::lower::AbstractConverter &converter,
-               Fortran::lower::pft::Evaluation &eval,
-               mlir::Location currentLocation,
-               const Fortran::parser::OmpClauseList &clauseList) {
+genCoexecuteOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+            mlir::Location loc, const ConstructQueue &queue,
+            ConstructQueue::const_iterator item) {
   return genOpWithBody<mlir::omp::CoexecuteOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList);
+    OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                      llvm::omp::Directive::OMPD_coexecute), queue, item);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3939,7 +3940,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                        item);
     break;
   case llvm::omp::Directive::OMPD_coexecute:
-    newOp = genCoexecuteOp(converter, eval, currentLocation, beginClauseList);
+    newOp = genCoexecuteOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_tile:
   case llvm::omp::Directive::OMPD_unroll: {
diff --git a/flang/test/Lower/OpenMP/coexecute.f90 b/flang/test/Lower/OpenMP/coexecute.f90
new file mode 100644
index 0000000000000..b14f71f9bbbfa
--- /dev/null
+++ b/flang/test/Lower/OpenMP/coexecute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_coexecute
+subroutine target_teams_coexecute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp target teams coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams coexecute
+end subroutine target_teams_coexecute
+
+! CHECK-LABEL: func @_QPteams_coexecute
+subroutine teams_coexecute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp teams coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams coexecute
+end subroutine teams_coexecute
+
+! CHECK-LABEL: func @_QPtarget_teams_coexecute_m
+subroutine target_teams_coexecute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp target
+  !$omp teams
+  !$omp coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end coexecute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_coexecute_m
+
+! CHECK-LABEL: func @_QPteams_coexecute_m
+subroutine teams_coexecute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp teams
+  !$omp coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end coexecute
+  !$omp end teams
+end subroutine teams_coexecute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 7f450b43c2e36..3f02b6534816f 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -2209,29 +2209,28 @@ def OMP_TargetTeams : Directive<"target teams"> {
 }
 def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
   let allowedClauses = [
-    VersionedClause<OMPC_If>,
-    VersionedClause<OMPC_Map>,
-    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Allocate>,
     VersionedClause<OMPC_Depend>,
     VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_IsDevicePtr>,
     VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
     VersionedClause<OMPC_Reduction>,
-    VersionedClause<OMPC_Allocate>,
-    VersionedClause<OMPC_UsesAllocators, 50>,
     VersionedClause<OMPC_Shared>,
-    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
   ];
-
   let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
     VersionedClause<OMPC_Device>,
     VersionedClause<OMPC_NoWait>,
-    VersionedClause<OMPC_DefaultMap>,
-    VersionedClause<OMPC_Default>,
     VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_ThreadLimit>,
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
   ];
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute];
   let category = CA_Executable;
@@ -2524,20 +2523,20 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> {
 }
 def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
   let allowedClauses = [
-    VersionedClause<OMPC_Private>,
-    VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_Shared>,
-    VersionedClause<OMPC_Reduction>,
     VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
     VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
   ];
   let allowedOnceClauses = [
     VersionedClause<OMPC_Default>,
     VersionedClause<OMPC_If, 52>,
     VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_ThreadLimit>
+    VersionedClause<OMPC_ThreadLimit>,
   ];
-  let leafConstructs = [OMP_Target, OMP_Teams];
+  let leafConstructs = [OMP_Teams, OMP_Coexecute];
   let category = CA_Executable;
 }
 def OMP_TeamsDistribute : Directive<"teams distribute"> {

>From 8077858a88a2ffac2b7d726c1ae5d1f1edb64b67 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 May 2025 14:48:52 +0530
Subject: [PATCH 06/11] [OpenMP] Use workdistribute instead of coexecute

---
 .../flang/Semantics/openmp-directive-sets.h   |  24 ++---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  15 ++-
 flang/lib/Parser/openmp-parsers.cpp           |   6 +-
 flang/lib/Semantics/resolve-directives.cpp    |  12 +--
 flang/test/Lower/OpenMP/coexecute.f90         |  59 ----------
 flang/test/Lower/OpenMP/workdistribute.f90    |  59 ++++++++++
 llvm/include/llvm/Frontend/OpenMP/OMP.td      | 101 ++++++++++--------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  28 ++---
 8 files changed, 152 insertions(+), 152 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/coexecute.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute.f90

diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 43f4e642b3d86..7ced6ed9b44d6 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,7 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
-    Directive::OMPD_target_teams_coexecute,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -173,7 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
-    Directive::OMPD_teams_coexecute,
+    Directive::OMPD_teams_workdistribute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
@@ -189,14 +189,14 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
-        Directive::OMPD_target_teams_coexecute,
+        Directive::OMPD_target_teams_workdistribute,
     } | topTeamsSet,
 };
 
-static const OmpDirectiveSet allCoexecuteSet{
-    Directive::OMPD_coexecute,
-    Directive::OMPD_teams_coexecute,
-    Directive::OMPD_target_teams_coexecute,
+static const OmpDirectiveSet allWorkdistributeSet{
+    Directive::OMPD_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 //===----------------------------------------------------------------------===//
@@ -239,9 +239,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
-    Directive::OMPD_target_teams_coexecute,
-    Directive::OMPD_teams_coexecute,
-    Directive::OMPD_coexecute,
+    Directive::OMPD_target_teams_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_workdistribute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -306,7 +306,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
-        Directive::OMPD_coexecute,
+        Directive::OMPD_workdistribute,
     } | allDoSet,
 };
 
@@ -389,7 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
-    Directive::OMPD_coexecute,
+    Directive::OMPD_workdistribute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 80612bd05ad97..42d04bceddb12 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2670,14 +2670,14 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
-static mlir::omp::CoexecuteOp
-genCoexecuteOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+static mlir::omp::WorkdistributeOp
+genWorkdistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
             mlir::Location loc, const ConstructQueue &queue,
             ConstructQueue::const_iterator item) {
-  return genOpWithBody<mlir::omp::CoexecuteOp>(
+  return genOpWithBody<mlir::omp::WorkdistributeOp>(
     OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
-                      llvm::omp::Directive::OMPD_coexecute), queue, item);
+                      llvm::omp::Directive::OMPD_workdistribute), queue, item);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3939,16 +3939,15 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
                        item);
     break;
-  case llvm::omp::Directive::OMPD_coexecute:
-    newOp = genCoexecuteOp(converter, symTable, semaCtx, eval, loc, queue, item);
-    break;
   case llvm::omp::Directive::OMPD_tile:
   case llvm::omp::Directive::OMPD_unroll: {
     unsigned version = semaCtx.langOptions().OpenMPVersion;
     TODO(loc, "Unhandled loop directive (" +
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
-  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
                            queue, item);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 591b1642baed3..5b5ee257edd1f 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1344,15 +1344,15 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
-        "TARGET TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_coexecute),
+        "TARGET TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
-        "TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_teams_coexecute),
+        "TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
         "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
-        "COEXECUTE" >> pure(llvm::omp::Directive::OMPD_coexecute))))
+        "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index ae297f204356a..4636508ac144d 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1617,9 +1617,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
-  case llvm::omp::Directive::OMPD_coexecute:
-  case llvm::omp::Directive::OMPD_teams_coexecute:
-  case llvm::omp::Directive::OMPD_target_teams_coexecute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
@@ -1653,9 +1653,9 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
-  case llvm::omp::Directive::OMPD_coexecute:
-  case llvm::omp::Directive::OMPD_teams_coexecute:
-  case llvm::omp::Directive::OMPD_target_teams_coexecute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
   case llvm::omp::Directive::OMPD_target_parallel: {
diff --git a/flang/test/Lower/OpenMP/coexecute.f90 b/flang/test/Lower/OpenMP/coexecute.f90
deleted file mode 100644
index b14f71f9bbbfa..0000000000000
--- a/flang/test/Lower/OpenMP/coexecute.f90
+++ /dev/null
@@ -1,59 +0,0 @@
-! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
-
-! CHECK-LABEL: func @_QPtarget_teams_coexecute
-subroutine target_teams_coexecute()
-  ! CHECK: omp.target
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp target teams coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end target teams coexecute
-end subroutine target_teams_coexecute
-
-! CHECK-LABEL: func @_QPteams_coexecute
-subroutine teams_coexecute()
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp teams coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end teams coexecute
-end subroutine teams_coexecute
-
-! CHECK-LABEL: func @_QPtarget_teams_coexecute_m
-subroutine target_teams_coexecute_m()
-  ! CHECK: omp.target
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp target
-  !$omp teams
-  !$omp coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end coexecute
-  !$omp end teams
-  !$omp end target
-end subroutine target_teams_coexecute_m
-
-! CHECK-LABEL: func @_QPteams_coexecute_m
-subroutine teams_coexecute_m()
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp teams
-  !$omp coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end coexecute
-  !$omp end teams
-end subroutine teams_coexecute_m
diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90
new file mode 100644
index 0000000000000..924205bb72e5e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m
+subroutine target_teams_workdistribute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_workdistribute_m
+
+! CHECK-LABEL: func @_QPteams_workdistribute_m
+subroutine teams_workdistribute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 3f02b6534816f..c88a3049450de 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1292,6 +1292,15 @@ def OMP_EndWorkshare : Directive<"end workshare"> {
   let category = OMP_Workshare.category;
   let languages = [L_Fortran];
 }
+def OMP_Workdistribute : Directive<"workdistribute"> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndWorkdistribute : Directive<"end workdistribute"> {
+  let leafConstructs = OMP_Workdistribute.leafConstructs;
+  let association = OMP_Workdistribute.association;
+  let category = OMP_Workdistribute.category;
+}
 
 //===----------------------------------------------------------------------===//
 // Definitions of OpenMP compound directives
@@ -2207,34 +2216,6 @@ def OMP_TargetTeams : Directive<"target teams"> {
   let leafConstructs = [OMP_Target, OMP_Teams];
   let category = CA_Executable;
 }
-def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
-  let allowedClauses = [
-    VersionedClause<OMPC_Allocate>,
-    VersionedClause<OMPC_Depend>,
-    VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_HasDeviceAddr, 51>,
-    VersionedClause<OMPC_If>,
-    VersionedClause<OMPC_IsDevicePtr>,
-    VersionedClause<OMPC_Map>,
-    VersionedClause<OMPC_OMPX_Attribute>,
-    VersionedClause<OMPC_Private>,
-    VersionedClause<OMPC_Reduction>,
-    VersionedClause<OMPC_Shared>,
-    VersionedClause<OMPC_UsesAllocators, 50>,
-  ];
-  let allowedOnceClauses = [
-    VersionedClause<OMPC_Default>,
-    VersionedClause<OMPC_DefaultMap>,
-    VersionedClause<OMPC_Device>,
-    VersionedClause<OMPC_NoWait>,
-    VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_OMPX_DynCGroupMem>,
-    VersionedClause<OMPC_OMPX_Bare>,
-    VersionedClause<OMPC_ThreadLimit>,
-  ];
-  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute];
-  let category = CA_Executable;
-}
 def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2457,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd :
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsWorkdistribute : Directive<"target teams workdistribute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_target_teams_loop : Directive<"target teams loop"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2521,24 +2530,6 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> {
   let leafConstructs = [OMP_TaskLoop, OMP_Simd];
   let category = CA_Executable;
 }
-def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
-  let allowedClauses = [
-    VersionedClause<OMPC_Allocate>,
-    VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_OMPX_Attribute>,
-    VersionedClause<OMPC_Private>,
-    VersionedClause<OMPC_Reduction>,
-    VersionedClause<OMPC_Shared>,
-  ];
-  let allowedOnceClauses = [
-    VersionedClause<OMPC_Default>,
-    VersionedClause<OMPC_If, 52>,
-    VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_ThreadLimit>,
-  ];
-  let leafConstructs = [OMP_Teams, OMP_Coexecute];
-  let category = CA_Executable;
-}
 def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2726,3 +2717,21 @@ def OMP_teams_loop : Directive<"teams loop"> {
   let leafConstructs = [OMP_Teams, OMP_loop];
   let category = CA_Executable;
 }
+def OMP_TeamsWorkdistribute : Directive<"teams workdistribute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8061aa0209cc9..5e3ab0e908d21 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -326,38 +326,30 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 }
 
 //===----------------------------------------------------------------------===//
-// Coexecute Construct
+// workdistribute Construct
 //===----------------------------------------------------------------------===//
 
-def CoexecuteOp : OpenMP_Op<"coexecute"> {
-  let summary = "coexecute directive";
+def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+  let summary = "workdistribute directive";
   let description = [{
-    The coexecute construct specifies that the teams from the teams directive
-    this is nested in shall cooperate to execute the computation in this region.
-    There is no implicit barrier at the end as specified in the standard.
-
-    TODO
-    We should probably change the defaut behaviour to have a barrier unless
-    nowait is specified, see below snippet.
+    workdistribute divides execution of the enclosed structured block into
+    separate units of work, each executed only once by each
+    initial thread in the league.
 
     ```
     !$omp target teams
-        !$omp coexecute
+        !$omp workdistribute
                 tmp = matmul(x, y)
-        !$omp end coexecute
+        !$omp end workdistribute
         a = tmp(0, 0) ! there is no implicit barrier! the matmul hasnt completed!
-    !$omp end target teams coexecute
+    !$omp end target teams workdistribute
     ```
 
   }];
 
-  let arguments = (ins UnitAttr:$nowait);
-
   let regions = (region AnyRegion:$region);
 
-  let assemblyFormat = [{
-    oilist(`nowait` $nowait) $region attr-dict
-  }];
+  let assemblyFormat = "$region attr-dict";
 }
 
 //===----------------------------------------------------------------------===//

>From 085062f9ebac1079a720f614498c0b124eda8a51 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 May 2025 16:17:14 +0530
Subject: [PATCH 07/11] [OpenMP] workdistribute trivial lowering

Lowering logic inspired by ivanradanov's coexecute lowering, commit
f56da1a207df4a40776a8570122a33f047074a3c
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |   4 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 101 ++++++++++++++++++
 .../OpenMP/lower-workdistribute.mlir          |  52 +++++++++
 4 files changed, 158 insertions(+)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 704faf0ccd856..743b6d381ed42 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e31543328a9f9..cd746834741f9 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
 
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..75c9d2b0d494e
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,101 @@
+//===- LowerWorkdistribute.cpp - Lower omp.workdistribute -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of omp.workdistribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Value.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+struct WorkdistributeToSingle : public mlir::OpRewritePattern<mlir::omp::WorkdistributeOp> {
+using OpRewritePattern::OpRewritePattern;
+mlir::LogicalResult
+    matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute,
+                       mlir::PatternRewriter &rewriter) const override {
+        auto loc = workdistribute->getLoc();
+        auto teams = llvm::dyn_cast<mlir::omp::TeamsOp>(workdistribute->getParentOp());
+        if (!teams) {
+            mlir::emitError(loc, "workdistribute not nested in teams\n");
+            return mlir::failure();
+        }
+        if (workdistribute.getRegion().getBlocks().size() != 1) {
+            mlir::emitError(loc, "workdistribute with multiple blocks\n");
+            return mlir::failure();
+        }
+        if (teams.getRegion().getBlocks().size() != 1) {
+            mlir::emitError(loc, "teams with multiple blocks\n");
+           return mlir::failure();
+        }
+        if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
+            mlir::emitError(loc, "teams with multiple nested ops\n");
+            return mlir::failure();
+        }
+        mlir::Block *workdistributeBlock = &workdistribute.getRegion().front();
+        rewriter.eraseOp(workdistributeBlock->getTerminator());
+        rewriter.inlineBlockBefore(workdistributeBlock, teams);
+        rewriter.eraseOp(teams);
+        return mlir::success();
+    }
+};
+
+class LowerWorkdistributePass
+    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+  void runOnOperation() override {
+    mlir::MLIRContext &context = getContext();
+    mlir::RewritePatternSet patterns(&context);
+    mlir::GreedyRewriteConfig config;
+    // prevent the pattern driver form merging blocks
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
+  
+    patterns.insert<WorkdistributeToSingle>(&context);
+    mlir::Operation *op = getOperation();
+    if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) {
+      mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+      signalPassFailure();
+    }
+  }
+};
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute.mlir
new file mode 100644
index 0000000000000..34c8c3f01976d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute.mlir
@@ -0,0 +1,52 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @_QPtarget_simple() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
+// CHECK:           %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK:           %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
+// CHECK:           %[[VAL_4:.*]] = fir.zero_bits !fir.heap<i32>
+// CHECK:           %[[VAL_5:.*]] = fir.embox %[[VAL_4]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+// CHECK:           fir.store %[[VAL_5]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.heap<i32>>>
+// CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_3]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+// CHECK:           hlfir.assign %[[VAL_0]] to %[[VAL_2]]#0 : i32, !fir.ref<i32>
+// CHECK:           %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_2]]#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+// CHECK:           omp.target map_entries(%[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %[[VAL_6]]#0 -> %[[VAL_9:.*]] : !fir.ref<!fir.box<!fir.heap<i32>>>) {
+// CHECK:             %[[VAL_10:.*]] = arith.constant 10 : i32
+// CHECK:             %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_9]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<i32>
+// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32
+// CHECK:             hlfir.assign %[[VAL_14]] to %[[VAL_12]]#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @_QPtarget_simple() {
+    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
+    %1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %2 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
+    %3 = fir.zero_bits !fir.heap<i32>
+    %4 = fir.embox %3 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+    fir.store %4 to %2 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    %5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+    %c2_i32 = arith.constant 2 : i32
+    hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref<i32>
+    %6 = omp.map.info var_ptr(%1#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+    omp.target map_entries(%6 -> %arg0 : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref<!fir.box<!fir.heap<i32>>>){
+        omp.teams {
+            omp.workdistribute {
+                %11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+                %12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+                %c10_i32 = arith.constant 10 : i32
+                %13 = fir.load %11#0 : !fir.ref<i32>
+                %14 = arith.addi %c10_i32, %13 : i32
+                hlfir.assign %14 to %12#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
+                omp.terminator
+            }
+            omp.terminator
+        }
+        omp.terminator
+    }
+    return
+}
\ No newline at end of file

>From c9b63efe85f7aed781a4a0fd7d0888b595f2a520 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 May 2025 19:29:33 +0530
Subject: [PATCH 08/11] [Flang][OpenMP] Add workdistribute lower pass to
 pipeline

---
 flang/lib/Optimizer/Passes/Pipelines.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 77751908e35be..15983f80c1e4b 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -278,8 +278,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
     addNestedPassToAllTopLevelOperations<PassConstructor>(
         pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP)
+  if (enableOpenMP) {
     pm.addPass(flangomp::createLowerWorkshare());
+    pm.addPass(flangomp::createLowerWorkdistribute());
+  }
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed

>From 048c3f22d55248a21e53ee3f4be2c0b07b500039 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 15 May 2025 16:39:21 +0530
Subject: [PATCH 09/11] [Flang][OpenMP] Add FissionWorkdistribute lowering.

Fission logic inspired by ivanradanov's implementation, commit
c97eca4010e460aac5a3d795614ca0980bce4565
---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 233 ++++++++++++++----
 .../OpenMP/lower-workdistribute-fission.mlir  |  60 +++++
 ...ir => lower-workdistribute-to-single.mlir} |   2 +-
 3 files changed, 243 insertions(+), 52 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
 rename flang/test/Transforms/OpenMP/{lower-workdistribute.mlir => lower-workdistribute-to-single.mlir} (99%)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 75c9d2b0d494e..f799202be2645 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -10,31 +10,26 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <flang/Optimizer/Builder/FIRBuilder.h>
-#include <flang/Optimizer/Dialect/FIROps.h>
-#include <flang/Optimizer/Dialect/FIRType.h>
-#include <flang/Optimizer/HLFIR/HLFIROps.h>
-#include <flang/Optimizer/OpenMP/Passes.h>
-#include <llvm/ADT/BreadthFirstIterator.h>
-#include <llvm/ADT/STLExtras.h>
-#include <llvm/ADT/SmallVectorExtras.h>
-#include <llvm/ADT/iterator_range.h>
-#include <llvm/Support/ErrorHandling.h>
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
-#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
-#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
 #include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
 #include <mlir/IR/IRMapping.h>
-#include <mlir/IR/OpDefinition.h>
 #include <mlir/IR/PatternMatch.h>
-#include <mlir/IR/Value.h>
-#include <mlir/IR/Visitors.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
+#include <optional>
 #include <variant>
 
 namespace flangomp {
@@ -48,52 +43,188 @@ using namespace mlir;
 
 namespace {
 
-struct WorkdistributeToSingle : public mlir::OpRewritePattern<mlir::omp::WorkdistributeOp> {
-using OpRewritePattern::OpRewritePattern;
-mlir::LogicalResult
-    matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute,
-                       mlir::PatternRewriter &rewriter) const override {
-        auto loc = workdistribute->getLoc();
-        auto teams = llvm::dyn_cast<mlir::omp::TeamsOp>(workdistribute->getParentOp());
-        if (!teams) {
-            mlir::emitError(loc, "workdistribute not nested in teams\n");
-            return mlir::failure();
-        }
-        if (workdistribute.getRegion().getBlocks().size() != 1) {
-            mlir::emitError(loc, "workdistribute with multiple blocks\n");
-            return mlir::failure();
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+    // Currently we cannot parallelize operations with results that have uses
+    if (llvm::any_of(op->getResults(),
+                     [](OpResult v) -> bool { return !v.use_empty(); }))
+      return false;
+    // We will parallelize unordered loops - these come from array syntax
+    if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+      auto unordered = loop.getUnordered();
+      if (!unordered)
+        return false;
+      return *unordered;
+    }
+    if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+      auto callee = callOp.getCallee();
+      if (!callee)
+        return false;
+      auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+      // TODO need to insert a check here whether it is a call we can actually
+      // parallelize currently
+      if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+        return true;
+      return false;
+    }
+    // We cannot parallise anything else
+    return false;
+}
+
+struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+    using OpRewritePattern::OpRewritePattern;
+    LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                    PatternRewriter &rewriter) const override {
+        auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+        if (!workdistributeOp) {
+            LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+            return failure();
         }
-        if (teams.getRegion().getBlocks().size() != 1) {
-            mlir::emitError(loc, "teams with multiple blocks\n");
-           return mlir::failure();
+      
+        Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+        rewriter.eraseOp(workdistributeBlock->getTerminator());
+        rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+        rewriter.eraseOp(teamsOp);
+        workdistributeOp.emitWarning("unable to parallelize coexecute");
+        return success();
+    }
+};
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+struct FissionWorkdistribute
+    : public OpRewritePattern<omp::WorkdistributeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult
+  matchAndRewrite(omp::WorkdistributeOp workdistribute,
+                  PatternRewriter &rewriter) const override {
+    auto loc = workdistribute->getLoc();
+    auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+    if (!teams) {
+      emitError(loc, "workdistribute not nested in teams\n");
+      return failure();
+    }
+    if (workdistribute.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "workdistribute with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "teams with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
+      emitError(loc, "teams with multiple nested ops\n");
+      return failure();
+    }
+
+    auto *teamsBlock = &teams.getRegion().front();
+
+    // While we have unhandled operations in the original workdistribute
+    auto *workdistributeBlock = &workdistribute.getRegion().front();
+    auto *terminator = workdistributeBlock->getTerminator();
+    bool changed = false;
+    while (&workdistributeBlock->front() != terminator) {
+      rewriter.setInsertionPoint(teams);
+      IRMapping mapping;
+      llvm::SmallVector<Operation *> hoisted;
+      Operation *parallelize = nullptr;
+      for (auto &op : workdistribute.getOps()) {
+        if (&op == terminator) {
+          break;
         }
-        if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
-            mlir::emitError(loc, "teams with multiple nested ops\n");
-            return mlir::failure();
+        if (shouldParallelize(&op)) {
+          parallelize = &op;
+          break;
+        } else {
+          rewriter.clone(op, mapping);
+          hoisted.push_back(&op);
+          changed = true;
         }
-        mlir::Block *workdistributeBlock = &workdistribute.getRegion().front();
-        rewriter.eraseOp(workdistributeBlock->getTerminator());
-        rewriter.inlineBlockBefore(workdistributeBlock, teams);
-        rewriter.eraseOp(teams);
-        return mlir::success();
+      }
+
+      for (auto *op : hoisted)
+        rewriter.replaceOp(op, mapping.lookup(op));
+
+      if (parallelize && hoisted.empty() &&
+          parallelize->getNextNode() == terminator)
+        break;
+      if (parallelize) {
+        auto newTeams = rewriter.cloneWithoutRegions(teams);
+        auto *newTeamsBlock = rewriter.createBlock(
+            &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+        for (auto arg : teamsBlock->getArguments())
+          newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+        auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+        rewriter.create<omp::TerminatorOp>(loc);
+        rewriter.createBlock(&newWorkdistribute.getRegion(),
+                            newWorkdistribute.getRegion().begin(), {}, {});
+        auto *cloned = rewriter.clone(*parallelize);
+        rewriter.replaceOp(parallelize, cloned);
+        rewriter.create<omp::TerminatorOp>(loc);
+        changed = true;
+      }
     }
+    return success(changed);
+  }
 };
 
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
   void runOnOperation() override {
-    mlir::MLIRContext &context = getContext();
-    mlir::RewritePatternSet patterns(&context);
-    mlir::GreedyRewriteConfig config;
+    MLIRContext &context = getContext();
+    RewritePatternSet patterns(&context);
+    GreedyRewriteConfig config;
     // prevent the pattern driver form merging blocks
     config.setRegionSimplificationLevel(
-        mlir::GreedySimplifyRegionLevel::Disabled);
+        GreedySimplifyRegionLevel::Disabled);
   
-    patterns.insert<WorkdistributeToSingle>(&context);
-    mlir::Operation *op = getOperation();
-    if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) {
-      mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+    Operation *op = getOperation();
+    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
       signalPassFailure();
     }
   }
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..ea03a10dd3d44
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,60 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// The workdistribute body below mixes an unordered fir.do_loop and a call to a
+// function carrying the fir.runtime attribute with ordinary stores, an ordered
+// fir.do_loop, a regular call and a load/store pair. The CHECK lines expect
+// every op to be kept in its original order with the omp.teams and
+// omp.workdistribute wrappers removed.
+
+// CHECK-LABEL:   func.func @test_fission_workdistribute({{.*}}) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK:           fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered {
+// CHECK:             %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK:             %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:           }
+// CHECK:           fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK:           }
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK:           return
+// CHECK:         }
+module {
+func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
+  return
+}
+func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
+  return
+}
+func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
+  %c0_idx = arith.constant 0 : index
+  %c1_idx = arith.constant 1 : index
+  %c9_idx = arith.constant 9 : index
+  %float_val = arith.constant 5.0 : f32
+  omp.teams   {
+    omp.workdistribute   {
+      fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
+      fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
+        %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
+        %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
+      }
+      fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
+      fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
+      fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
+        %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
+      }
+      %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
+      fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
+      omp.terminator  
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
similarity index 99%
rename from flang/test/Transforms/OpenMP/lower-workdistribute.mlir
rename to flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
index 34c8c3f01976d..0cc2aeded2532 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
@@ -49,4 +49,4 @@ func.func @_QPtarget_simple() {
         omp.terminator
     }
     return
-}
\ No newline at end of file
+}

>From 5b30d3dcb80cb4cef546f5bfdf3aa389f527d07d Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sun, 18 May 2025 12:37:53 +0530
Subject: [PATCH 10/11] [OpenMP][Flang] Lower teams workdistribute do_loop to
 wsloop.

Logic inspired by ivanradanov's commit
5682e9ea7fcba64693f7cfdc0f1970fab2d7d4ae
---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 177 +++++++++++++++---
 .../OpenMP/lower-workdistribute-doloop.mlir   |  28 +++
 .../OpenMP/lower-workdistribute-fission.mlir  |  22 ++-
 3 files changed, 193 insertions(+), 34 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index f799202be2645..de208a8190650 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -6,18 +6,22 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements the lowering of omp.workdistribute.
+// This file implements the lowering and optimisations of omp.workdistribute.
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
@@ -29,6 +33,7 @@
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <variant>
 
@@ -87,25 +92,6 @@ static bool shouldParallelize(Operation *op) {
     return false;
 }
 
-struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
-    using OpRewritePattern::OpRewritePattern;
-    LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
-                                    PatternRewriter &rewriter) const override {
-        auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-        if (!workdistributeOp) {
-            LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-            return failure();
-        }
-      
-        Block *workdistributeBlock = &workdistributeOp.getRegion().front();
-        rewriter.eraseOp(workdistributeBlock->getTerminator());
-        rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-        rewriter.eraseOp(teamsOp);
-        workdistributeOp.emitWarning("unable to parallelize coexecute");
-        return success();
-    }
-};
-
 /// If B() and D() are parallelizable,
 ///
 /// omp.teams {
@@ -210,22 +196,161 @@ struct FissionWorkdistribute
   }
 };
 
+/// Fill \p loopNestClauseOps with the single-level lower bound, upper bound
+/// and step of \p loop, and mark the resulting loop nest as inclusive (the
+/// upper bound is part of the iteration space). \p loc is currently unused.
+static void
+genLoopNestClauseOps(mlir::Location loc,
+                     mlir::PatternRewriter &rewriter,
+                     fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  mlir::Value lb = loop.getLowerBound();
+  mlir::Value ub = loop.getUpperBound();
+  mlir::Value stride = loop.getStep();
+  loopNestClauseOps.loopLowerBounds.push_back(lb);
+  loopNestClauseOps.loopUpperBounds.push_back(ub);
+  loopNestClauseOps.loopSteps.push_back(stride);
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+/// Emit, at the rewriter's current insertion point, an omp.wsloop wrapping an
+/// omp.loop_nest built from \p clauseOps. The body of \p doLoop is cloned
+/// into the loop nest, and the cloned fir.result terminator is replaced with
+/// omp.yield. \p doLoop itself is not modified; erasing it is left to the
+/// caller. On return, the insertion point is just after the new omp.wsloop.
+static void
+genWsLoopOp(mlir::PatternRewriter &rewriter,
+            fir::DoLoopOp doLoop,
+            const mlir::omp::LoopNestOperands &clauseOps) {
+  mlir::Location loc = doLoop.getLoc();
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(loc);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp = rewriter.create<mlir::omp::LoopNestOp>(loc, clauseOps);
+
+  // Clone the loop body into the loop nest. The cloned block keeps the
+  // do_loop's block arguments, which serve as the loop_nest induction
+  // variables (assumes the loop carries no iter_args -- TODO verify).
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  mlir::Operation *terminatorOp =
+      loopNestOp.getRegion().back().getTerminator();
+
+  // omp.loop_nest must be terminated by omp.yield rather than fir.result.
+  if (isa<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(loc);
+    rewriter.eraseOp(terminatorOp);
+  }
+
+  // Do not leave the insertion point anchored on the erased fir.result.
+  rewriter.setInsertionPointAfter(wsloopOp);
+}
+
+/// If an unordered fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, it's lowered to
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     omp.parallel {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+///
+/// The loop is only lowered when it is perfectly nested in the workdistribute
+/// region and shouldParallelize() accepts it.
+
+struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    assert(teamsOp.getReductionVars().empty());
+
+    auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistributeOp);
+    if (!doLoop || !shouldParallelize(doLoop))
+      return failure();
+
+    mlir::Location loopLoc = doLoop.getLoc();
+    // Build the replacement in place of the loop, inside omp.workdistribute.
+    // Otherwise it would be created at the rewriter's starting insertion
+    // point (the matched omp.teams op), i.e. outside the teams region.
+    rewriter.setInsertionPoint(doLoop);
+    auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopLoc);
+    rewriter.createBlock(&parallelOp.getRegion());
+    // The wsloop is emitted in front of the parallel region's terminator.
+    rewriter.setInsertionPoint(
+        rewriter.create<mlir::omp::TerminatorOp>(loopLoc));
+
+    mlir::omp::LoopNestOperands loopNestClauseOps;
+    genLoopNestClauseOps(loopLoc, rewriter, doLoop, loopNestClauseOps);
+    genWsLoopOp(rewriter, doLoop, loopNestClauseOps);
+
+    rewriter.eraseOp(doLoop);
+    return success();
+  }
+};
+
+
+/// When an omp.workdistribute is perfectly nested in an omp.teams
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// the workdistribute body is moved in front of the teams op and both
+/// constructs are erased, leaving
+///
+/// A()
+/// B()
+///
+
+struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    omp::WorkdistributeOp nested =
+        getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (nested) {
+      Block &body = nested.getRegion().front();
+      // Drop omp.terminator so only the payload ops are hoisted.
+      rewriter.eraseOp(body.getTerminator());
+      rewriter.inlineBlockBefore(&body, teamsOp);
+      rewriter.eraseOp(teamsOp);
+      return success();
+    }
+    LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+    return failure();
+  }
+};
+
+/// Lowers omp.workdistribute using two consecutive greedy pattern-rewrite runs.
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
   void runOnOperation() override {
     MLIRContext &context = getContext();
-    RewritePatternSet patterns(&context);
     GreedyRewriteConfig config;
     // prevent the pattern driver form merging blocks
     config.setRegionSimplificationLevel(
         GreedySimplifyRegionLevel::Disabled);
-  
-    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+    
     Operation *op = getOperation();
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
-      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
-      signalPassFailure();
+    // First run: split workdistribute regions (FissionWorkdistribute) and
+    // lower perfectly nested parallelizable fir.do_loop ops.
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
+    }
+    // Second run: lower any remaining loops, then hoist the body of every
+    // perfectly nested teams/workdistribute pair in place of the omp.teams op.
+    // NOTE(review): this run still executes if the first run failed above.
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
     }
   }
 };
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
new file mode 100644
index 0000000000000..666bdb3ced647
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -0,0 +1,28 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// An unordered fir.do_loop that is the only op in teams/workdistribute is
+// expected to become omp.parallel + omp.wsloop + omp.loop_nest with inclusive
+// bounds, and the omp.teams/omp.workdistribute wrappers are removed.
+
+// CHECK-LABEL:   func.func @x({{.*}})
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:           omp.parallel {
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK:                 fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  omp.teams {
+    omp.workdistribute { 
+      fir.do_loop %iv = %lb to %ub step %step unordered {
+        %zero = arith.constant 0 : index
+        fir.store %zero to %addr : !fir.ref<index>
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
\ No newline at end of file
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
index ea03a10dd3d44..cf50d135d01ec 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -6,20 +6,26 @@
 // CHECK:           %[[VAL_2:.*]] = arith.constant 9 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
 // CHECK:           fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
-// CHECK:           fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered {
-// CHECK:             %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK:             %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
-// CHECK:             %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK:             fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:           omp.parallel {
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                 %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK:                 %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                 fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:             }
+// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
 // CHECK:           fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
 // CHECK:           fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
-// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
 // CHECK:             fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
 // CHECK:           }
-// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
-// CHECK:           fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2]] : !fir.ref<f32>
+// CHECK:           fir.store %[[VAL_10]] to %[[ARG3]] : !fir.ref<f32>
 // CHECK:           return
 // CHECK:         }
 module {

>From df65bd53111948abf6f9c2e1e0b8e27aa5e01946 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 19 May 2025 15:33:53 +0530
Subject: [PATCH 11/11] clang format

---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  18 +--
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 108 +++++++++---------
 flang/lib/Parser/openmp-parsers.cpp           |   6 +-
 .../OpenMP/lower-workdistribute-doloop.mlir   |   2 +-
 4 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 42d04bceddb12..ebf0710ab4feb 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2670,14 +2670,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
-static mlir::omp::WorkdistributeOp
-genWorkdistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-            mlir::Location loc, const ConstructQueue &queue,
-            ConstructQueue::const_iterator item) {
+/// Create an omp.workdistribute op whose body is generated from \p eval via
+/// genOpWithBody, tagged with the OMPD_workdistribute directive.
+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
   return genOpWithBody<mlir::omp::WorkdistributeOp>(
-    OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
-                      llvm::omp::Directive::OMPD_workdistribute), queue, item);
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workdistribute),
+      queue, item);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3946,7 +3947,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
   case llvm::omp::Directive::OMPD_workdistribute:
-    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+                                item);
     break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index de208a8190650..f75d4d1988fd2 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -14,15 +14,16 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
-#include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
 #include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
 #include <mlir/Dialect/Utils/IndexingUtils.h>
@@ -33,7 +34,6 @@
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
-#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <variant>
 
@@ -66,30 +66,30 @@ static T getPerfectlyNested(Operation *op) {
 /// This is the single source of truth about whether we should parallelize an
 /// operation nested in an omp.workdistribute region.
 static bool shouldParallelize(Operation *op) {
-    // Currently we cannot parallelize operations with results that have uses
-    if (llvm::any_of(op->getResults(),
-                     [](OpResult v) -> bool { return !v.use_empty(); }))
+  // Currently we cannot parallelize operations with results that have uses
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
       return false;
-    // We will parallelize unordered loops - these come from array syntax
-    if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
-      auto unordered = loop.getUnordered();
-      if (!unordered)
-        return false;
-      return *unordered;
-    }
-    if (auto callOp = dyn_cast<fir::CallOp>(op)) {
-      auto callee = callOp.getCallee();
-      if (!callee)
-        return false;
-      auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
-      // TODO need to insert a check here whether it is a call we can actually
-      // parallelize currently
-      if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
-        return true;
+    return *unordered;
+  }
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
       return false;
-    }
-    // We cannot parallise anything else
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    // Callees that cannot be resolved in the enclosing module are not
+    // parallelized. TODO: check whether this is a call we can parallelize.
+    if (func && func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
     return false;
+  }
+  // We cannot parallelize anything else
+  return false;
 }
 
 /// If B() and D() are parallelizable,
@@ -120,12 +120,10 @@ static bool shouldParallelize(Operation *op) {
 /// }
 /// E()
 
-struct FissionWorkdistribute
-    : public OpRewritePattern<omp::WorkdistributeOp> {
+struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
   using OpRewritePattern::OpRewritePattern;
-  LogicalResult
-  matchAndRewrite(omp::WorkdistributeOp workdistribute,
-                  PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute,
+                                PatternRewriter &rewriter) const override {
     auto loc = workdistribute->getLoc();
     auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
     if (!teams) {
@@ -185,7 +183,7 @@ struct FissionWorkdistribute
         auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
         rewriter.create<omp::TerminatorOp>(loc);
         rewriter.createBlock(&newWorkdistribute.getRegion(),
-                            newWorkdistribute.getRegion().begin(), {}, {});
+                             newWorkdistribute.getRegion().begin(), {}, {});
         auto *cloned = rewriter.clone(*parallelize);
         rewriter.replaceOp(parallelize, cloned);
         rewriter.create<omp::TerminatorOp>(loc);
@@ -197,8 +195,7 @@ struct FissionWorkdistribute
 };
 
 static void
-genLoopNestClauseOps(mlir::Location loc,
-                     mlir::PatternRewriter &rewriter,
+genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
                      fir::DoLoopOp loop,
                      mlir::omp::LoopNestOperands &loopNestClauseOps) {
   assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -209,10 +206,8 @@ genLoopNestClauseOps(mlir::Location loc,
   loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
 }
 
-static void
-genWsLoopOp(mlir::PatternRewriter &rewriter,
-            fir::DoLoopOp doLoop,
-            const mlir::omp::LoopNestOperands &clauseOps) {
+static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps) {
 
   auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
   rewriter.createBlock(&wsloopOp.getRegion());
@@ -236,7 +231,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
   return;
 }
 
-/// If fir.do_loop id present inside teams workdistribute
+/// If fir.do_loop is present inside teams workdistribute
 ///
 /// omp.teams {
 ///   omp.workdistribute {
@@ -246,7 +241,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
 ///   }
 /// }
 ///
-/// Then, its lowered to 
+/// Then, it's lowered to
 ///
 /// omp.teams {
 ///   omp.workdistribute {
@@ -277,7 +272,8 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
 
       auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
       rewriter.createBlock(&parallelOp.getRegion());
-      rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
+      rewriter.setInsertionPoint(
+          rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
 
       mlir::omp::LoopNestOperands loopNestClauseOps;
       genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
@@ -292,7 +288,6 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
   }
 };
 
-
 /// If A() and B () are present inside teams workdistribute
 ///
 /// omp.teams {
@@ -311,17 +306,17 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
 struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
-                                  PatternRewriter &rewriter) const override {
-      auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-      if (!workdistributeOp) {
-          LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-          return failure();
-      }      
-      Block *workdistributeBlock = &workdistributeOp.getRegion().front();
-      rewriter.eraseOp(workdistributeBlock->getTerminator());
-      rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-      rewriter.eraseOp(teamsOp);
-      return success();
+                                PatternRewriter &rewriter) const override {
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    // Drop the workdistribute terminator and move the remaining body ops in
+    // front of omp.teams. Then erase the teams op, together with the emptied
+    // workdistribute nested in it.
+    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+    rewriter.eraseOp(workdistributeBlock->getTerminator());
+    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+    rewriter.eraseOp(teamsOp);
+    return success();
   }
 };
 
@@ -332,13 +327,13 @@ class LowerWorkdistributePass
     MLIRContext &context = getContext();
     GreedyRewriteConfig config;
     // prevent the pattern driver form merging blocks
-    config.setRegionSimplificationLevel(
-        GreedySimplifyRegionLevel::Disabled);
-    
+    config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
+
     Operation *op = getOperation();
     {
       RewritePatternSet patterns(&context);
-      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
+      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(
+          &context);
       if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
         emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
         signalPassFailure();
@@ -346,7 +341,8 @@ class LowerWorkdistributePass
     }
     {
       RewritePatternSet patterns(&context);
-      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
+      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(
+          &context);
       if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
         emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
         signalPassFailure();
@@ -354,4 +350,4 @@ class LowerWorkdistributePass
     }
   }
 };
-}
+} // namespace
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 5b5ee257edd1f..dc25adfe28c1d 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1344,12 +1344,14 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
-        "TARGET TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
+        "TARGET TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
-        "TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_workdistribute),
+        "TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
         "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
         "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
index 666bdb3ced647..9fb970246b90c 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -25,4 +25,4 @@ func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<
     omp.terminator
   }
   return
-}
\ No newline at end of file
+}



More information about the llvm-commits mailing list