[flang-commits] [flang] [llvm] [mlir] [WIP] Implement workdistribute construct (PR #140523)

via flang-commits flang-commits at lists.llvm.org
Tue Jun 10 06:18:13 PDT 2025


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/140523

>From e0dff6afb7aa31330aa0516effb7a0f65df5315f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov2 at llnl.gov>
Date: Mon, 4 Dec 2023 12:57:36 -0800
Subject: [PATCH 01/17] Add coexecute directives

---
 llvm/include/llvm/Frontend/OpenMP/OMP.td | 45 ++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 0af4b436649a3..752486a8105b6 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -682,6 +682,8 @@ def OMP_CancellationPoint : Directive<"cancellation point"> {
   let association = AS_None;
   let category = CA_Executable;
 }
+def OMP_Coexecute : Directive<"coexecute"> {}
+def OMP_EndCoexecute : Directive<"end coexecute"> {}
 def OMP_Critical : Directive<"critical"> {
   let allowedOnceClauses = [
     VersionedClause<OMPC_Hint>,
@@ -2198,6 +2200,33 @@ def OMP_TargetTeams : Directive<"target teams"> {
   let leafConstructs = [OMP_Target, OMP_Teams];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+  ];
+
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMX_Bare>,
+  ];
+}
 def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2484,6 +2513,22 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> {
   let leafConstructs = [OMP_TaskLoop, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>
+  ];
+}
 def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,

>From 8b1b36f5e716b8186d98b0d5c47c0fdf649ae67b Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 May 2025 11:01:45 +0530
Subject: [PATCH 02/17] [OpenMP] Fix Coexecute definitions

---
 llvm/include/llvm/Frontend/OpenMP/OMP.td | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 752486a8105b6..7f450b43c2e36 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -682,8 +682,15 @@ def OMP_CancellationPoint : Directive<"cancellation point"> {
   let association = AS_None;
   let category = CA_Executable;
 }
-def OMP_Coexecute : Directive<"coexecute"> {}
-def OMP_EndCoexecute : Directive<"end coexecute"> {}
+def OMP_Coexecute : Directive<"coexecute"> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndCoexecute : Directive<"end coexecute"> {
+  let leafConstructs = OMP_Coexecute.leafConstructs;
+  let association = OMP_Coexecute.association;
+  let category = OMP_Coexecute.category;
+}
 def OMP_Critical : Directive<"critical"> {
   let allowedOnceClauses = [
     VersionedClause<OMPC_Hint>,
@@ -2224,8 +2231,10 @@ def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
     VersionedClause<OMPC_NumTeams>,
     VersionedClause<OMPC_ThreadLimit>,
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
-    VersionedClause<OMPC_OMX_Bare>,
+    VersionedClause<OMPC_OMPX_Bare>,
   ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute];
+  let category = CA_Executable;
 }
 def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
   let allowedClauses = [
@@ -2528,6 +2537,8 @@ def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
     VersionedClause<OMPC_NumTeams>,
     VersionedClause<OMPC_ThreadLimit>
   ];
+  let leafConstructs = [OMP_Target, OMP_Teams];
+  let category = CA_Executable;
 }
 def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedClauses = [

>From 9b8d66a45e602375ec779e6c5bdd43232644f9a2 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov2 at llnl.gov>
Date: Mon, 4 Dec 2023 12:58:10 -0800
Subject: [PATCH 03/17] Add omp.coexecute op

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a79fbf77a268..8061aa0209cc9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -325,6 +325,41 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Coexecute Construct
+//===----------------------------------------------------------------------===//
+
+def CoexecuteOp : OpenMP_Op<"coexecute"> {
+  let summary = "coexecute directive";
+  let description = [{
+    The coexecute construct specifies that the teams of the enclosing teams
+    construct shall cooperate to execute the computation in this region.
+    There is no implicit barrier at the end as specified in the standard.
+
+    TODO
+    We should probably change the default behaviour to have a barrier unless
+    nowait is specified, see below snippet.
+
+    ```
+    !$omp target teams
+        !$omp coexecute
+                tmp = matmul(x, y)
+        !$omp end coexecute
+        a = tmp(1, 1) ! no implicit barrier: the matmul may not have completed!
+    !$omp end target teams
+    ```
+
+  }];
+
+  let arguments = (ins UnitAttr:$nowait);
+
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    oilist(`nowait` $nowait) $region attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // 2.8.2 Single Construct
 //===----------------------------------------------------------------------===//

>From 7ecec06e00230649446c77c970160d4814a90e07 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov2 at llnl.gov>
Date: Mon, 4 Dec 2023 17:50:41 -0800
Subject: [PATCH 04/17] Initial frontend support for coexecute

---
 .../include/flang/Semantics/openmp-directive-sets.h | 13 +++++++++++++
 flang/lib/Lower/OpenMP/OpenMP.cpp                   | 12 ++++++++++++
 flang/lib/Parser/openmp-parsers.cpp                 |  5 ++++-
 flang/lib/Semantics/resolve-directives.cpp          |  6 ++++++
 4 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index dd610c9702c28..5c316e030c63f 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
+    Directive::OMPD_target_teams_coexecute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -187,9 +188,16 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
+        Directive::OMPD_target_teams_coexecute,
     } | topTeamsSet,
 };
 
+static const OmpDirectiveSet allCoexecuteSet{
+    Directive::OMPD_coexecute,
+    Directive::OMPD_teams_coexecute,
+    Directive::OMPD_target_teams_coexecute,
+};
+
 //===----------------------------------------------------------------------===//
 // Directive sets for groups of multiple directives
 //===----------------------------------------------------------------------===//
@@ -230,6 +238,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
+    Directive::OMPD_target_teams_coexecute,
+    Directive::OMPD_teams_coexecute,
+    Directive::OMPD_coexecute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -294,6 +305,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
+        Directive::OMPD_coexecute,
     } | allDoSet,
 };
 
@@ -376,6 +388,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
+    Directive::OMPD_coexecute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 61bbc709872fd..b0c65c8e37988 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2670,6 +2670,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::omp::CoexecuteOp
+genCoexecuteOp(Fortran::lower::AbstractConverter &converter,
+               Fortran::lower::pft::Evaluation &eval,
+               mlir::Location currentLocation,
+               const Fortran::parser::OmpClauseList &clauseList) {
+  return genOpWithBody<mlir::omp::CoexecuteOp>(
+      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation for atomic operations
 //===----------------------------------------------------------------------===//
@@ -3929,6 +3938,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
                        item);
     break;
+  case llvm::omp::Directive::OMPD_coexecute:
+    newOp = genCoexecuteOp(converter, eval, currentLocation, beginClauseList);
+    break;
   case llvm::omp::Directive::OMPD_tile:
   case llvm::omp::Directive::OMPD_unroll: {
     unsigned version = semaCtx.langOptions().OpenMPVersion;
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 52d3a5844c969..591b1642baed3 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1344,12 +1344,15 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
+        "TARGET TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_coexecute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
+        "TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_teams_coexecute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
-        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare))))
+        "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
+        "COEXECUTE" >> pure(llvm::omp::Directive::OMPD_coexecute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 9fa7bc8964854..ae297f204356a 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1617,6 +1617,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_coexecute:
+  case llvm::omp::Directive::OMPD_teams_coexecute:
+  case llvm::omp::Directive::OMPD_target_teams_coexecute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
@@ -1650,6 +1653,9 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
+  case llvm::omp::Directive::OMPD_coexecute:
+  case llvm::omp::Directive::OMPD_teams_coexecute:
+  case llvm::omp::Directive::OMPD_target_teams_coexecute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
   case llvm::omp::Directive::OMPD_target_parallel: {

>From ca0cc44c621fde89f1889fb328e66755ca3f5e3a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 May 2025 15:09:45 +0530
Subject: [PATCH 05/17] [OpenMP] Fixes for coexecute definitions

---
 .../flang/Semantics/openmp-directive-sets.h   |  1 +
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 13 ++--
 flang/test/Lower/OpenMP/coexecute.f90         | 59 +++++++++++++++++++
 llvm/include/llvm/Frontend/OpenMP/OMP.td      | 33 +++++------
 4 files changed, 83 insertions(+), 23 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/coexecute.f90

diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 5c316e030c63f..43f4e642b3d86 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -173,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_teams_coexecute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index b0c65c8e37988..80612bd05ad97 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2671,12 +2671,13 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 }
 
 static mlir::omp::CoexecuteOp
-genCoexecuteOp(Fortran::lower::AbstractConverter &converter,
-               Fortran::lower::pft::Evaluation &eval,
-               mlir::Location currentLocation,
-               const Fortran::parser::OmpClauseList &clauseList) {
+genCoexecuteOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+            mlir::Location loc, const ConstructQueue &queue,
+            ConstructQueue::const_iterator item) {
   return genOpWithBody<mlir::omp::CoexecuteOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList);
+    OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                      llvm::omp::Directive::OMPD_coexecute), queue, item);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3939,7 +3940,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                        item);
     break;
   case llvm::omp::Directive::OMPD_coexecute:
-    newOp = genCoexecuteOp(converter, eval, currentLocation, beginClauseList);
+    newOp = genCoexecuteOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_tile:
   case llvm::omp::Directive::OMPD_unroll: {
diff --git a/flang/test/Lower/OpenMP/coexecute.f90 b/flang/test/Lower/OpenMP/coexecute.f90
new file mode 100644
index 0000000000000..b14f71f9bbbfa
--- /dev/null
+++ b/flang/test/Lower/OpenMP/coexecute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_coexecute
+subroutine target_teams_coexecute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp target teams coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams coexecute
+end subroutine target_teams_coexecute
+
+! CHECK-LABEL: func @_QPteams_coexecute
+subroutine teams_coexecute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp teams coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams coexecute
+end subroutine teams_coexecute
+
+! CHECK-LABEL: func @_QPtarget_teams_coexecute_m
+subroutine target_teams_coexecute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp target
+  !$omp teams
+  !$omp coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end coexecute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_coexecute_m
+
+! CHECK-LABEL: func @_QPteams_coexecute_m
+subroutine teams_coexecute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.coexecute
+  !$omp teams
+  !$omp coexecute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end coexecute
+  !$omp end teams
+end subroutine teams_coexecute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 7f450b43c2e36..3f02b6534816f 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -2209,29 +2209,28 @@ def OMP_TargetTeams : Directive<"target teams"> {
 }
 def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
   let allowedClauses = [
-    VersionedClause<OMPC_If>,
-    VersionedClause<OMPC_Map>,
-    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Allocate>,
     VersionedClause<OMPC_Depend>,
     VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_IsDevicePtr>,
     VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
     VersionedClause<OMPC_Reduction>,
-    VersionedClause<OMPC_Allocate>,
-    VersionedClause<OMPC_UsesAllocators, 50>,
     VersionedClause<OMPC_Shared>,
-    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
   ];
-
   let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
     VersionedClause<OMPC_Device>,
     VersionedClause<OMPC_NoWait>,
-    VersionedClause<OMPC_DefaultMap>,
-    VersionedClause<OMPC_Default>,
     VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_ThreadLimit>,
     VersionedClause<OMPC_OMPX_DynCGroupMem>,
     VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
   ];
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute];
   let category = CA_Executable;
@@ -2524,20 +2523,20 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> {
 }
 def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
   let allowedClauses = [
-    VersionedClause<OMPC_Private>,
-    VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_Shared>,
-    VersionedClause<OMPC_Reduction>,
     VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
     VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
   ];
   let allowedOnceClauses = [
     VersionedClause<OMPC_Default>,
     VersionedClause<OMPC_If, 52>,
     VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_ThreadLimit>
+    VersionedClause<OMPC_ThreadLimit>,
   ];
-  let leafConstructs = [OMP_Target, OMP_Teams];
+  let leafConstructs = [OMP_Teams, OMP_Coexecute];
   let category = CA_Executable;
 }
 def OMP_TeamsDistribute : Directive<"teams distribute"> {

>From 8077858a88a2ffac2b7d726c1ae5d1f1edb64b67 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 May 2025 14:48:52 +0530
Subject: [PATCH 06/17] [OpenMP] Use workdistribute instead of coexecute

---
 .../flang/Semantics/openmp-directive-sets.h   |  24 ++---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  15 ++-
 flang/lib/Parser/openmp-parsers.cpp           |   6 +-
 flang/lib/Semantics/resolve-directives.cpp    |  12 +--
 flang/test/Lower/OpenMP/coexecute.f90         |  59 ----------
 flang/test/Lower/OpenMP/workdistribute.f90    |  59 ++++++++++
 llvm/include/llvm/Frontend/OpenMP/OMP.td      | 101 ++++++++++--------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  28 ++---
 8 files changed, 152 insertions(+), 152 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/coexecute.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute.f90

diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 43f4e642b3d86..7ced6ed9b44d6 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -143,7 +143,7 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
     Directive::OMPD_target_teams_loop,
-    Directive::OMPD_target_teams_coexecute,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -173,7 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
-    Directive::OMPD_teams_coexecute,
+    Directive::OMPD_teams_workdistribute,
 };
 
 static const OmpDirectiveSet bottomTeamsSet{
@@ -189,14 +189,14 @@ static const OmpDirectiveSet allTeamsSet{
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_target_teams_distribute_simd,
         Directive::OMPD_target_teams_loop,
-        Directive::OMPD_target_teams_coexecute,
+        Directive::OMPD_target_teams_workdistribute,
     } | topTeamsSet,
 };
 
-static const OmpDirectiveSet allCoexecuteSet{
-    Directive::OMPD_coexecute,
-    Directive::OMPD_teams_coexecute,
-    Directive::OMPD_target_teams_coexecute,
+static const OmpDirectiveSet allWorkdistributeSet{
+    Directive::OMPD_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_target_teams_workdistribute,
 };
 
 //===----------------------------------------------------------------------===//
@@ -239,9 +239,9 @@ static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_taskgroup,
     Directive::OMPD_teams,
     Directive::OMPD_workshare,
-    Directive::OMPD_target_teams_coexecute,
-    Directive::OMPD_teams_coexecute,
-    Directive::OMPD_coexecute,
+    Directive::OMPD_target_teams_workdistribute,
+    Directive::OMPD_teams_workdistribute,
+    Directive::OMPD_workdistribute,
 };
 
 static const OmpDirectiveSet loopConstructSet{
@@ -306,7 +306,7 @@ static const OmpDirectiveSet workShareSet{
         Directive::OMPD_scope,
         Directive::OMPD_sections,
         Directive::OMPD_single,
-        Directive::OMPD_coexecute,
+        Directive::OMPD_workdistribute,
     } | allDoSet,
 };
 
@@ -389,7 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
 };
 
 static const OmpDirectiveSet nestedTeamsAllowedSet{
-    Directive::OMPD_coexecute,
+    Directive::OMPD_workdistribute,
     Directive::OMPD_distribute,
     Directive::OMPD_distribute_parallel_do,
     Directive::OMPD_distribute_parallel_do_simd,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 80612bd05ad97..42d04bceddb12 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2670,14 +2670,14 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
-static mlir::omp::CoexecuteOp
-genCoexecuteOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+static mlir::omp::WorkdistributeOp
+genWorkdistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
             mlir::Location loc, const ConstructQueue &queue,
             ConstructQueue::const_iterator item) {
-  return genOpWithBody<mlir::omp::CoexecuteOp>(
+  return genOpWithBody<mlir::omp::WorkdistributeOp>(
     OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
-                      llvm::omp::Directive::OMPD_coexecute), queue, item);
+                      llvm::omp::Directive::OMPD_workdistribute), queue, item);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3939,16 +3939,15 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
                        item);
     break;
-  case llvm::omp::Directive::OMPD_coexecute:
-    newOp = genCoexecuteOp(converter, symTable, semaCtx, eval, loc, queue, item);
-    break;
   case llvm::omp::Directive::OMPD_tile:
   case llvm::omp::Directive::OMPD_unroll: {
     unsigned version = semaCtx.langOptions().OpenMPVersion;
     TODO(loc, "Unhandled loop directive (" +
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
-  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
                            queue, item);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 591b1642baed3..5b5ee257edd1f 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1344,15 +1344,15 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
-        "TARGET TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_coexecute),
+        "TARGET TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
-        "TEAMS COEXECUTE" >> pure(llvm::omp::Directive::OMPD_teams_coexecute),
+        "TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
         "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
-        "COEXECUTE" >> pure(llvm::omp::Directive::OMPD_coexecute))))
+        "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
 
 TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
     sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index ae297f204356a..4636508ac144d 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -1617,9 +1617,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_taskgroup:
   case llvm::omp::Directive::OMPD_teams:
-  case llvm::omp::Directive::OMPD_coexecute:
-  case llvm::omp::Directive::OMPD_teams_coexecute:
-  case llvm::omp::Directive::OMPD_target_teams_coexecute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
@@ -1653,9 +1653,9 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
   case llvm::omp::Directive::OMPD_target:
   case llvm::omp::Directive::OMPD_task:
   case llvm::omp::Directive::OMPD_teams:
-  case llvm::omp::Directive::OMPD_coexecute:
-  case llvm::omp::Directive::OMPD_teams_coexecute:
-  case llvm::omp::Directive::OMPD_target_teams_coexecute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_teams_workdistribute:
+  case llvm::omp::Directive::OMPD_target_teams_workdistribute:
   case llvm::omp::Directive::OMPD_parallel_workshare:
   case llvm::omp::Directive::OMPD_target_teams:
   case llvm::omp::Directive::OMPD_target_parallel: {
diff --git a/flang/test/Lower/OpenMP/coexecute.f90 b/flang/test/Lower/OpenMP/coexecute.f90
deleted file mode 100644
index b14f71f9bbbfa..0000000000000
--- a/flang/test/Lower/OpenMP/coexecute.f90
+++ /dev/null
@@ -1,59 +0,0 @@
-! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
-
-! CHECK-LABEL: func @_QPtarget_teams_coexecute
-subroutine target_teams_coexecute()
-  ! CHECK: omp.target
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp target teams coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end target teams coexecute
-end subroutine target_teams_coexecute
-
-! CHECK-LABEL: func @_QPteams_coexecute
-subroutine teams_coexecute()
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp teams coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end teams coexecute
-end subroutine teams_coexecute
-
-! CHECK-LABEL: func @_QPtarget_teams_coexecute_m
-subroutine target_teams_coexecute_m()
-  ! CHECK: omp.target
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp target
-  !$omp teams
-  !$omp coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end coexecute
-  !$omp end teams
-  !$omp end target
-end subroutine target_teams_coexecute_m
-
-! CHECK-LABEL: func @_QPteams_coexecute_m
-subroutine teams_coexecute_m()
-  ! CHECK: omp.teams
-  ! CHECK: omp.coexecute
-  !$omp teams
-  !$omp coexecute
-  ! CHECK: fir.call
-  call f1()
-  ! CHECK: omp.terminator
-  ! CHECK: omp.terminator
-  !$omp end coexecute
-  !$omp end teams
-end subroutine teams_coexecute_m
diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90
new file mode 100644
index 0000000000000..924205bb72e5e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m
+subroutine target_teams_workdistribute_m()
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+  !$omp end target
+end subroutine target_teams_workdistribute_m
+
+! CHECK-LABEL: func @_QPteams_workdistribute_m
+subroutine teams_workdistribute_m()
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams
+  !$omp workdistribute
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_m
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 3f02b6534816f..c88a3049450de 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -1292,6 +1292,15 @@ def OMP_EndWorkshare : Directive<"end workshare"> {
   let category = OMP_Workshare.category;
   let languages = [L_Fortran];
 }
+def OMP_Workdistribute : Directive<"workdistribute"> {
+  let association = AS_Block;
+  let category = CA_Executable;
+}
+def OMP_EndWorkdistribute : Directive<"end workdistribute"> {
+  let leafConstructs = OMP_Workdistribute.leafConstructs;
+  let association = OMP_Workdistribute.association;
+  let category = OMP_Workdistribute.category;
+}
 
 //===----------------------------------------------------------------------===//
 // Definitions of OpenMP compound directives
@@ -2207,34 +2216,6 @@ def OMP_TargetTeams : Directive<"target teams"> {
   let leafConstructs = [OMP_Target, OMP_Teams];
   let category = CA_Executable;
 }
-def OMP_TargetTeamsCoexecute : Directive<"target teams coexecute"> {
-  let allowedClauses = [
-    VersionedClause<OMPC_Allocate>,
-    VersionedClause<OMPC_Depend>,
-    VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_HasDeviceAddr, 51>,
-    VersionedClause<OMPC_If>,
-    VersionedClause<OMPC_IsDevicePtr>,
-    VersionedClause<OMPC_Map>,
-    VersionedClause<OMPC_OMPX_Attribute>,
-    VersionedClause<OMPC_Private>,
-    VersionedClause<OMPC_Reduction>,
-    VersionedClause<OMPC_Shared>,
-    VersionedClause<OMPC_UsesAllocators, 50>,
-  ];
-  let allowedOnceClauses = [
-    VersionedClause<OMPC_Default>,
-    VersionedClause<OMPC_DefaultMap>,
-    VersionedClause<OMPC_Device>,
-    VersionedClause<OMPC_NoWait>,
-    VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_OMPX_DynCGroupMem>,
-    VersionedClause<OMPC_OMPX_Bare>,
-    VersionedClause<OMPC_ThreadLimit>,
-  ];
-  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Coexecute];
-  let category = CA_Executable;
-}
 def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2457,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd :
   let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
   let category = CA_Executable;
 }
+def OMP_TargetTeamsWorkdistribute : Directive<"target teams workdistribute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_Depend>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_HasDeviceAddr, 51>,
+    VersionedClause<OMPC_If>,
+    VersionedClause<OMPC_IsDevicePtr>,
+    VersionedClause<OMPC_Map>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+    VersionedClause<OMPC_UsesAllocators, 50>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_DefaultMap>,
+    VersionedClause<OMPC_Device>,
+    VersionedClause<OMPC_NoWait>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_OMPX_DynCGroupMem>,
+    VersionedClause<OMPC_OMPX_Bare>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
 def OMP_target_teams_loop : Directive<"target teams loop"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2521,24 +2530,6 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> {
   let leafConstructs = [OMP_TaskLoop, OMP_Simd];
   let category = CA_Executable;
 }
-def OMP_TeamsCoexecute : Directive<"teams coexecute"> {
-  let allowedClauses = [
-    VersionedClause<OMPC_Allocate>,
-    VersionedClause<OMPC_FirstPrivate>,
-    VersionedClause<OMPC_OMPX_Attribute>,
-    VersionedClause<OMPC_Private>,
-    VersionedClause<OMPC_Reduction>,
-    VersionedClause<OMPC_Shared>,
-  ];
-  let allowedOnceClauses = [
-    VersionedClause<OMPC_Default>,
-    VersionedClause<OMPC_If, 52>,
-    VersionedClause<OMPC_NumTeams>,
-    VersionedClause<OMPC_ThreadLimit>,
-  ];
-  let leafConstructs = [OMP_Teams, OMP_Coexecute];
-  let category = CA_Executable;
-}
 def OMP_TeamsDistribute : Directive<"teams distribute"> {
   let allowedClauses = [
     VersionedClause<OMPC_Allocate>,
@@ -2726,3 +2717,21 @@ def OMP_teams_loop : Directive<"teams loop"> {
   let leafConstructs = [OMP_Teams, OMP_loop];
   let category = CA_Executable;
 }
+def OMP_TeamsWorkdistribute : Directive<"teams workdistribute"> {
+  let allowedClauses = [
+    VersionedClause<OMPC_Allocate>,
+    VersionedClause<OMPC_FirstPrivate>,
+    VersionedClause<OMPC_OMPX_Attribute>,
+    VersionedClause<OMPC_Private>,
+    VersionedClause<OMPC_Reduction>,
+    VersionedClause<OMPC_Shared>,
+  ];
+  let allowedOnceClauses = [
+    VersionedClause<OMPC_Default>,
+    VersionedClause<OMPC_If, 52>,
+    VersionedClause<OMPC_NumTeams>,
+    VersionedClause<OMPC_ThreadLimit>,
+  ];
+  let leafConstructs = [OMP_Teams, OMP_Workdistribute];
+  let category = CA_Executable;
+}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8061aa0209cc9..5e3ab0e908d21 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -326,38 +326,30 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 }
 
 //===----------------------------------------------------------------------===//
-// Coexecute Construct
+// workdistribute Construct
 //===----------------------------------------------------------------------===//
 
-def CoexecuteOp : OpenMP_Op<"coexecute"> {
-  let summary = "coexecute directive";
+def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+  let summary = "workdistribute directive";
   let description = [{
-    The coexecute construct specifies that the teams from the teams directive
-    this is nested in shall cooperate to execute the computation in this region.
-    There is no implicit barrier at the end as specified in the standard.
-
-    TODO
-    We should probably change the defaut behaviour to have a barrier unless
-    nowait is specified, see below snippet.
+    workdistribute divides execution of the enclosed structured block into
+    separate units of work, each executed only once by each
+    initial thread in the league.
 
     ```
     !$omp target teams
-        !$omp coexecute
+        !$omp workdistribute
                 tmp = matmul(x, y)
-        !$omp end coexecute
+        !$omp end workdistribute
         a = tmp(0, 0) ! there is no implicit barrier! the matmul hasnt completed!
-    !$omp end target teams coexecute
+    !$omp end target teams
     ```
 
   }];
 
-  let arguments = (ins UnitAttr:$nowait);
-
   let regions = (region AnyRegion:$region);
 
-  let assemblyFormat = [{
-    oilist(`nowait` $nowait) $region attr-dict
-  }];
+  let assemblyFormat = "$region attr-dict";
 }
 
 //===----------------------------------------------------------------------===//

>From 085062f9ebac1079a720f614498c0b124eda8a51 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 May 2025 16:17:14 +0530
Subject: [PATCH 07/17] [OpenMP] workdistribute trivial lowering

Lowering logic inspired by ivanradanov's coexecute lowering
f56da1a207df4a40776a8570122a33f047074a3c
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |   4 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 101 ++++++++++++++++++
 .../OpenMP/lower-workdistribute.mlir          |  52 +++++++++
 4 files changed, 158 insertions(+)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 704faf0ccd856..743b6d381ed42 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e31543328a9f9..cd746834741f9 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
 
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..75c9d2b0d494e
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,101 @@
+//===- LowerWorkdistribute.cpp - Lower omp.workdistribute -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of omp.workdistribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Value.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+struct WorkdistributeToSingle : public mlir::OpRewritePattern<mlir::omp::WorkdistributeOp> {
+using OpRewritePattern::OpRewritePattern;
+mlir::LogicalResult
+    matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute,
+                       mlir::PatternRewriter &rewriter) const override {
+        auto loc = workdistribute->getLoc();
+        auto teams = llvm::dyn_cast<mlir::omp::TeamsOp>(workdistribute->getParentOp());
+        if (!teams) {
+            mlir::emitError(loc, "workdistribute not nested in teams\n");
+            return mlir::failure();
+        }
+        if (workdistribute.getRegion().getBlocks().size() != 1) {
+            mlir::emitError(loc, "workdistribute with multiple blocks\n");
+            return mlir::failure();
+        }
+        if (teams.getRegion().getBlocks().size() != 1) {
+            mlir::emitError(loc, "teams with multiple blocks\n");
+           return mlir::failure();
+        }
+        if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
+            mlir::emitError(loc, "teams with multiple nested ops\n");
+            return mlir::failure();
+        }
+        mlir::Block *workdistributeBlock = &workdistribute.getRegion().front();
+        rewriter.eraseOp(workdistributeBlock->getTerminator());
+        rewriter.inlineBlockBefore(workdistributeBlock, teams);
+        rewriter.eraseOp(teams);
+        return mlir::success();
+    }
+};
+
+class LowerWorkdistributePass
+    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+  void runOnOperation() override {
+    mlir::MLIRContext &context = getContext();
+    mlir::RewritePatternSet patterns(&context);
+    mlir::GreedyRewriteConfig config;
+    // prevent the pattern driver form merging blocks
+    config.setRegionSimplificationLevel(
+        mlir::GreedySimplifyRegionLevel::Disabled);
+  
+    patterns.insert<WorkdistributeToSingle>(&context);
+    mlir::Operation *op = getOperation();
+    if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) {
+      mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+      signalPassFailure();
+    }
+  }
+};
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute.mlir
new file mode 100644
index 0000000000000..34c8c3f01976d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute.mlir
@@ -0,0 +1,52 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @_QPtarget_simple() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
+// CHECK:           %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK:           %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
+// CHECK:           %[[VAL_4:.*]] = fir.zero_bits !fir.heap<i32>
+// CHECK:           %[[VAL_5:.*]] = fir.embox %[[VAL_4]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+// CHECK:           fir.store %[[VAL_5]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.heap<i32>>>
+// CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_3]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+// CHECK:           hlfir.assign %[[VAL_0]] to %[[VAL_2]]#0 : i32, !fir.ref<i32>
+// CHECK:           %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_2]]#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+// CHECK:           omp.target map_entries(%[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %[[VAL_6]]#0 -> %[[VAL_9:.*]] : !fir.ref<!fir.box<!fir.heap<i32>>>) {
+// CHECK:             %[[VAL_10:.*]] = arith.constant 10 : i32
+// CHECK:             %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_9]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<i32>
+// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32
+// CHECK:             hlfir.assign %[[VAL_14]] to %[[VAL_12]]#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @_QPtarget_simple() {
+    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
+    %1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %2 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
+    %3 = fir.zero_bits !fir.heap<i32>
+    %4 = fir.embox %3 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+    fir.store %4 to %2 : !fir.ref<!fir.box<!fir.heap<i32>>>
+    %5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+    %c2_i32 = arith.constant 2 : i32
+    hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref<i32>
+    %6 = omp.map.info var_ptr(%1#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
+    omp.target map_entries(%6 -> %arg0 : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref<!fir.box<!fir.heap<i32>>>){
+        omp.teams {
+            omp.workdistribute {
+                %11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+                %12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
+                %c10_i32 = arith.constant 10 : i32
+                %13 = fir.load %11#0 : !fir.ref<i32>
+                %14 = arith.addi %c10_i32, %13 : i32
+                hlfir.assign %14 to %12#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
+                omp.terminator
+            }
+            omp.terminator
+        }
+        omp.terminator
+    }
+    return
+}
\ No newline at end of file

>From c9b63efe85f7aed781a4a0fd7d0888b595f2a520 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 May 2025 19:29:33 +0530
Subject: [PATCH 08/17] [Flang][OpenMP] Add workdistribute lower pass to
 pipeline

---
 flang/lib/Optimizer/Passes/Pipelines.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 77751908e35be..15983f80c1e4b 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -278,8 +278,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
     addNestedPassToAllTopLevelOperations<PassConstructor>(
         pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP)
+  if (enableOpenMP) {
     pm.addPass(flangomp::createLowerWorkshare());
+    pm.addPass(flangomp::createLowerWorkdistribute());
+  }
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed

>From 048c3f22d55248a21e53ee3f4be2c0b07b500039 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 15 May 2025 16:39:21 +0530
Subject: [PATCH 09/17] [Flang][OpenMP] Add FissionWorkdistribute lowering.

Fission logic inspired from ivanradanov implementation :
c97eca4010e460aac5a3d795614ca0980bce4565
---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 233 ++++++++++++++----
 .../OpenMP/lower-workdistribute-fission.mlir  |  60 +++++
 ...ir => lower-workdistribute-to-single.mlir} |   2 +-
 3 files changed, 243 insertions(+), 52 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
 rename flang/test/Transforms/OpenMP/{lower-workdistribute.mlir => lower-workdistribute-to-single.mlir} (99%)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 75c9d2b0d494e..f799202be2645 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -10,31 +10,26 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <flang/Optimizer/Builder/FIRBuilder.h>
-#include <flang/Optimizer/Dialect/FIROps.h>
-#include <flang/Optimizer/Dialect/FIRType.h>
-#include <flang/Optimizer/HLFIR/HLFIROps.h>
-#include <flang/Optimizer/OpenMP/Passes.h>
-#include <llvm/ADT/BreadthFirstIterator.h>
-#include <llvm/ADT/STLExtras.h>
-#include <llvm/ADT/SmallVectorExtras.h>
-#include <llvm/ADT/iterator_range.h>
-#include <llvm/Support/ErrorHandling.h>
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
-#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
-#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
 #include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
 #include <mlir/IR/IRMapping.h>
-#include <mlir/IR/OpDefinition.h>
 #include <mlir/IR/PatternMatch.h>
-#include <mlir/IR/Value.h>
-#include <mlir/IR/Visitors.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
+#include <optional>
 #include <variant>
 
 namespace flangomp {
@@ -48,52 +43,188 @@ using namespace mlir;
 
 namespace {
 
-struct WorkdistributeToSingle : public mlir::OpRewritePattern<mlir::omp::WorkdistributeOp> {
-using OpRewritePattern::OpRewritePattern;
-mlir::LogicalResult
-    matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute,
-                       mlir::PatternRewriter &rewriter) const override {
-        auto loc = workdistribute->getLoc();
-        auto teams = llvm::dyn_cast<mlir::omp::TeamsOp>(workdistribute->getParentOp());
-        if (!teams) {
-            mlir::emitError(loc, "workdistribute not nested in teams\n");
-            return mlir::failure();
-        }
-        if (workdistribute.getRegion().getBlocks().size() != 1) {
-            mlir::emitError(loc, "workdistribute with multiple blocks\n");
-            return mlir::failure();
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+  // Currently we cannot parallelize operations with results that have uses.
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax.
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  // Direct calls are parallelized only when the callee resolves to a symbol
+  // carrying the FIR runtime attribute; unresolved callees are rejected.
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    // TODO need to insert a check here whether it is a call we can actually
+    // parallelize currently.
+    return func && func->hasAttr(fir::FIROpsDialect::getFirRuntimeAttrName());
+  }
+  // We cannot parallelize anything else.
+  return false;
+}
+
+struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+    using OpRewritePattern::OpRewritePattern;
+    LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                    PatternRewriter &rewriter) const override {
+        auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+        if (!workdistributeOp || !workdistributeOp.getRegion().hasOneBlock()) {
+            LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No single-block workdistribute nested\n");
+            return failure();
         }
-        if (teams.getRegion().getBlocks().size() != 1) {
-            mlir::emitError(loc, "teams with multiple blocks\n");
-           return mlir::failure();
+        // Warn before erasing teamsOp: workdistributeOp is nested inside it.
+        workdistributeOp.emitWarning("unable to parallelize workdistribute");
+        Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+        rewriter.eraseOp(workdistributeBlock->getTerminator());
+        rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+        rewriter.eraseOp(teamsOp);
+        return success();
+    }
+};
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+struct FissionWorkdistribute
+    : public OpRewritePattern<omp::WorkdistributeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult
+  matchAndRewrite(omp::WorkdistributeOp workdistribute,
+                  PatternRewriter &rewriter) const override {
+    auto loc = workdistribute->getLoc();
+    auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+    if (!teams) {
+      emitError(loc, "workdistribute not nested in teams\n");
+      return failure();
+    }
+    if (workdistribute.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "workdistribute with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "teams with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
+      emitError(loc, "teams with multiple nested ops\n");
+      return failure();
+    }
+
+    auto *teamsBlock = &teams.getRegion().front();
+
+    // While we have unhandled operations in the original workdistribute
+    auto *workdistributeBlock = &workdistribute.getRegion().front();
+    auto *terminator = workdistributeBlock->getTerminator();
+    bool changed = false;
+    while (&workdistributeBlock->front() != terminator) {
+      rewriter.setInsertionPoint(teams);
+      IRMapping mapping;
+      llvm::SmallVector<Operation *> hoisted;
+      Operation *parallelize = nullptr;
+      for (auto &op : workdistribute.getOps()) {
+        if (&op == terminator) {
+          break;
         }
-        if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
-            mlir::emitError(loc, "teams with multiple nested ops\n");
-            return mlir::failure();
+        if (shouldParallelize(&op)) {
+          parallelize = &op;
+          break;
+        } else {
+          rewriter.clone(op, mapping);
+          hoisted.push_back(&op);
+          changed = true;
         }
-        mlir::Block *workdistributeBlock = &workdistribute.getRegion().front();
-        rewriter.eraseOp(workdistributeBlock->getTerminator());
-        rewriter.inlineBlockBefore(workdistributeBlock, teams);
-        rewriter.eraseOp(teams);
-        return mlir::success();
+      }
+
+      for (auto *op : hoisted)
+        rewriter.replaceOp(op, mapping.lookup(op));
+
+      if (parallelize && hoisted.empty() &&
+          parallelize->getNextNode() == terminator)
+        break;
+      if (parallelize) {
+        auto newTeams = rewriter.cloneWithoutRegions(teams);
+        auto *newTeamsBlock = rewriter.createBlock(
+            &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+        for (auto arg : teamsBlock->getArguments())
+          newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+        auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+        rewriter.create<omp::TerminatorOp>(loc);
+        rewriter.createBlock(&newWorkdistribute.getRegion(),
+                            newWorkdistribute.getRegion().begin(), {}, {});
+        auto *cloned = rewriter.clone(*parallelize);
+        rewriter.replaceOp(parallelize, cloned);
+        rewriter.create<omp::TerminatorOp>(loc);
+        changed = true;
+      }
     }
+    return success(changed);
+  }
 };
 
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
   void runOnOperation() override {
-    mlir::MLIRContext &context = getContext();
-    mlir::RewritePatternSet patterns(&context);
-    mlir::GreedyRewriteConfig config;
+    MLIRContext &context = getContext();
+    RewritePatternSet patterns(&context);
+    GreedyRewriteConfig config;
     // prevent the pattern driver form merging blocks
     config.setRegionSimplificationLevel(
-        mlir::GreedySimplifyRegionLevel::Disabled);
+        GreedySimplifyRegionLevel::Disabled);
   
-    patterns.insert<WorkdistributeToSingle>(&context);
-    mlir::Operation *op = getOperation();
-    if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) {
-      mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+    Operation *op = getOperation();
+    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
       signalPassFailure();
     }
   }
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..ea03a10dd3d44
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,60 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_fission_workdistribute({{.*}}) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK:           fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered {
+// CHECK:             %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK:             %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:           }
+// CHECK:           fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK:           }
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK:           return
+// CHECK:         }
+module {
+func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
+  return
+}
+func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
+  return
+}
+func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
+  %c0_idx = arith.constant 0 : index
+  %c1_idx = arith.constant 1 : index
+  %c9_idx = arith.constant 9 : index
+  %float_val = arith.constant 5.0 : f32
+  omp.teams   {
+    omp.workdistribute   {
+      fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
+      fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
+        %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
+        %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
+      }
+      fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
+      fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
+      fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
+        %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
+      }
+      %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
+      fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
+      omp.terminator  
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
similarity index 99%
rename from flang/test/Transforms/OpenMP/lower-workdistribute.mlir
rename to flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
index 34c8c3f01976d..0cc2aeded2532 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
@@ -49,4 +49,4 @@ func.func @_QPtarget_simple() {
         omp.terminator
     }
     return
-}
\ No newline at end of file
+}

>From 5b30d3dcb80cb4cef546f5bfdf3aa389f527d07d Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sun, 18 May 2025 12:37:53 +0530
Subject: [PATCH 10/17] [OpenMP][Flang] Lower teams workdistribute do_loop to
 wsloop.

Logic inspired by ivanradanov's commit
5682e9ea7fcba64693f7cfdc0f1970fab2d7d4ae
---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 177 +++++++++++++++---
 .../OpenMP/lower-workdistribute-doloop.mlir   |  28 +++
 .../OpenMP/lower-workdistribute-fission.mlir  |  22 ++-
 3 files changed, 193 insertions(+), 34 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index f799202be2645..de208a8190650 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -6,18 +6,22 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements the lowering of omp.workdistribute.
+// This file implements the lowering and optimisations of omp.workdistribute.
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
@@ -29,6 +33,7 @@
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <variant>
 
@@ -87,25 +92,6 @@ static bool shouldParallelize(Operation *op) {
     return false;
 }
 
-struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
-    using OpRewritePattern::OpRewritePattern;
-    LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
-                                    PatternRewriter &rewriter) const override {
-        auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-        if (!workdistributeOp) {
-            LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-            return failure();
-        }
-      
-        Block *workdistributeBlock = &workdistributeOp.getRegion().front();
-        rewriter.eraseOp(workdistributeBlock->getTerminator());
-        rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-        rewriter.eraseOp(teamsOp);
-        workdistributeOp.emitWarning("unable to parallelize coexecute");
-        return success();
-    }
-};
-
 /// If B() and D() are parallelizable,
 ///
 /// omp.teams {
@@ -210,22 +196,161 @@ struct FissionWorkdistribute
   }
 };
 
+static void
+genLoopNestClauseOps(mlir::Location loc,
+                     mlir::PatternRewriter &rewriter,
+                     fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+static void
+genWsLoopOp(mlir::PatternRewriter &rewriter,
+            fir::DoLoopOp doLoop,
+            const mlir::omp::LoopNestOperands &clauseOps) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase fir.result op of do loop and create yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    rewriter.eraseOp(terminatorOp);
+  }
+  return;
+}
+
+/// If fir.do_loop id present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unoredered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, its lowered to 
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     omp.parallel {
+///       omp.wsloop {
+///         omp.loop_nest
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+
+struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto teamsLoc = teamsOp->getLoc();
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    assert(teamsOp.getReductionVars().empty());
+
+    auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistributeOp);
+    if (doLoop && shouldParallelize(doLoop)) {
+
+      auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
+      rewriter.createBlock(&parallelOp.getRegion());
+      rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
+
+      mlir::omp::LoopNestOperands loopNestClauseOps;
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
+                           loopNestClauseOps);
+
+      genWsLoopOp(rewriter, doLoop, loopNestClauseOps);
+      rewriter.setInsertionPoint(doLoop);
+      rewriter.eraseOp(doLoop);
+      return success();
+    }
+    return failure();
+  }
+};
+
+
+/// If A() and B () are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// Then, its lowered to
+///
+/// A()
+/// B()
+///
+
+struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                  PatternRewriter &rewriter) const override {
+      auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+      if (!workdistributeOp) {
+          LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+          return failure();
+      }      
+      Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+      rewriter.eraseOp(workdistributeBlock->getTerminator());
+      rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+      rewriter.eraseOp(teamsOp);
+      return success();
+  }
+};
+
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
   void runOnOperation() override {
     MLIRContext &context = getContext();
-    RewritePatternSet patterns(&context);
     GreedyRewriteConfig config;
     // prevent the pattern driver form merging blocks
     config.setRegionSimplificationLevel(
         GreedySimplifyRegionLevel::Disabled);
-  
-    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+    
     Operation *op = getOperation();
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
-      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
-      signalPassFailure();
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
+    }
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
     }
   }
 };
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
new file mode 100644
index 0000000000000..666bdb3ced647
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -0,0 +1,28 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @x({{.*}})
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:           omp.parallel {
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK:                 fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  omp.teams {
+    omp.workdistribute { 
+      fir.do_loop %iv = %lb to %ub step %step unordered {
+        %zero = arith.constant 0 : index
+        fir.store %zero to %addr : !fir.ref<index>
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
\ No newline at end of file
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
index ea03a10dd3d44..cf50d135d01ec 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -6,20 +6,26 @@
 // CHECK:           %[[VAL_2:.*]] = arith.constant 9 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
 // CHECK:           fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
-// CHECK:           fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered {
-// CHECK:             %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK:             %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
-// CHECK:             %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK:             fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:           omp.parallel {
+// CHECK:             omp.wsloop {
+// CHECK:               omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                 %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                 %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK:                 %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                 fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:                 omp.yield
+// CHECK:               }
+// CHECK:             }
+// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
 // CHECK:           fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
 // CHECK:           fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
-// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
 // CHECK:             fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
 // CHECK:           }
-// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
-// CHECK:           fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2]] : !fir.ref<f32>
+// CHECK:           fir.store %[[VAL_10]] to %[[ARG3]] : !fir.ref<f32>
 // CHECK:           return
 // CHECK:         }
 module {

>From df65bd53111948abf6f9c2e1e0b8e27aa5e01946 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 19 May 2025 15:33:53 +0530
Subject: [PATCH 11/17] Apply clang-format

---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  18 +--
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 108 +++++++++---------
 flang/lib/Parser/openmp-parsers.cpp           |   6 +-
 .../OpenMP/lower-workdistribute-doloop.mlir   |   2 +-
 4 files changed, 67 insertions(+), 67 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 42d04bceddb12..ebf0710ab4feb 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2670,14 +2670,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
-static mlir::omp::WorkdistributeOp
-genWorkdistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-            mlir::Location loc, const ConstructQueue &queue,
-            ConstructQueue::const_iterator item) {
+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
   return genOpWithBody<mlir::omp::WorkdistributeOp>(
-    OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
-                      llvm::omp::Directive::OMPD_workdistribute), queue, item);
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workdistribute),
+      queue, item);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3946,7 +3947,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                   llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
   }
   case llvm::omp::Directive::OMPD_workdistribute:
-    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+                                item);
     break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index de208a8190650..f75d4d1988fd2 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -14,15 +14,16 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
-#include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
 #include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
 #include <mlir/Dialect/Utils/IndexingUtils.h>
@@ -33,7 +34,6 @@
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
-#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <variant>
 
@@ -66,30 +66,30 @@ static T getPerfectlyNested(Operation *op) {
 /// This is the single source of truth about whether we should parallelize an
 /// operation nested in an omp.workdistribute region.
 static bool shouldParallelize(Operation *op) {
-    // Currently we cannot parallelize operations with results that have uses
-    if (llvm::any_of(op->getResults(),
-                     [](OpResult v) -> bool { return !v.use_empty(); }))
+  // Currently we cannot parallelize operations with results that have uses
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
       return false;
-    // We will parallelize unordered loops - these come from array syntax
-    if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
-      auto unordered = loop.getUnordered();
-      if (!unordered)
-        return false;
-      return *unordered;
-    }
-    if (auto callOp = dyn_cast<fir::CallOp>(op)) {
-      auto callee = callOp.getCallee();
-      if (!callee)
-        return false;
-      auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
-      // TODO need to insert a check here whether it is a call we can actually
-      // parallelize currently
-      if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
-        return true;
+    return *unordered;
+  }
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
       return false;
-    }
-    // We cannot parallise anything else
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    // TODO need to insert a check here whether it is a call we can actually
+    // parallelize currently
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
     return false;
+  }
+  // We cannot parallise anything else
+  return false;
 }
 
 /// If B() and D() are parallelizable,
@@ -120,12 +120,10 @@ static bool shouldParallelize(Operation *op) {
 /// }
 /// E()
 
-struct FissionWorkdistribute
-    : public OpRewritePattern<omp::WorkdistributeOp> {
+struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
   using OpRewritePattern::OpRewritePattern;
-  LogicalResult
-  matchAndRewrite(omp::WorkdistributeOp workdistribute,
-                  PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute,
+                                PatternRewriter &rewriter) const override {
     auto loc = workdistribute->getLoc();
     auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
     if (!teams) {
@@ -185,7 +183,7 @@ struct FissionWorkdistribute
         auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
         rewriter.create<omp::TerminatorOp>(loc);
         rewriter.createBlock(&newWorkdistribute.getRegion(),
-                            newWorkdistribute.getRegion().begin(), {}, {});
+                             newWorkdistribute.getRegion().begin(), {}, {});
         auto *cloned = rewriter.clone(*parallelize);
         rewriter.replaceOp(parallelize, cloned);
         rewriter.create<omp::TerminatorOp>(loc);
@@ -197,8 +195,7 @@ struct FissionWorkdistribute
 };
 
 static void
-genLoopNestClauseOps(mlir::Location loc,
-                     mlir::PatternRewriter &rewriter,
+genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
                      fir::DoLoopOp loop,
                      mlir::omp::LoopNestOperands &loopNestClauseOps) {
   assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -209,10 +206,8 @@ genLoopNestClauseOps(mlir::Location loc,
   loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
 }
 
-static void
-genWsLoopOp(mlir::PatternRewriter &rewriter,
-            fir::DoLoopOp doLoop,
-            const mlir::omp::LoopNestOperands &clauseOps) {
+static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps) {
 
   auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
   rewriter.createBlock(&wsloopOp.getRegion());
@@ -236,7 +231,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
   return;
 }
 
-/// If fir.do_loop id present inside teams workdistribute
+/// If fir.do_loop is present inside teams workdistribute
 ///
 /// omp.teams {
 ///   omp.workdistribute {
@@ -246,7 +241,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
 ///   }
 /// }
 ///
-/// Then, its lowered to 
+/// Then, its lowered to
 ///
 /// omp.teams {
 ///   omp.workdistribute {
@@ -277,7 +272,8 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
 
       auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
       rewriter.createBlock(&parallelOp.getRegion());
-      rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
+      rewriter.setInsertionPoint(
+          rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
 
       mlir::omp::LoopNestOperands loopNestClauseOps;
       genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
@@ -292,7 +288,6 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
   }
 };
 
-
 /// If A() and B () are present inside teams workdistribute
 ///
 /// omp.teams {
@@ -311,17 +306,17 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
 struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
-                                  PatternRewriter &rewriter) const override {
-      auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-      if (!workdistributeOp) {
-          LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-          return failure();
-      }      
-      Block *workdistributeBlock = &workdistributeOp.getRegion().front();
-      rewriter.eraseOp(workdistributeBlock->getTerminator());
-      rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-      rewriter.eraseOp(teamsOp);
-      return success();
+                                PatternRewriter &rewriter) const override {
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+    rewriter.eraseOp(workdistributeBlock->getTerminator());
+    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+    rewriter.eraseOp(teamsOp);
+    return success();
   }
 };
 
@@ -332,13 +327,13 @@ class LowerWorkdistributePass
     MLIRContext &context = getContext();
     GreedyRewriteConfig config;
     // prevent the pattern driver form merging blocks
-    config.setRegionSimplificationLevel(
-        GreedySimplifyRegionLevel::Disabled);
-    
+    config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
+
     Operation *op = getOperation();
     {
       RewritePatternSet patterns(&context);
-      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
+      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(
+          &context);
       if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
         emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
         signalPassFailure();
@@ -346,7 +341,8 @@ class LowerWorkdistributePass
     }
     {
       RewritePatternSet patterns(&context);
-      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
+      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(
+          &context);
       if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
         emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
         signalPassFailure();
@@ -354,4 +350,4 @@ class LowerWorkdistributePass
     }
   }
 };
-}
+} // namespace
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 5b5ee257edd1f..dc25adfe28c1d 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1344,12 +1344,14 @@ TYPE_PARSER(
         "SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
         "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
         "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
-        "TARGET TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
+        "TARGET TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
         "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
         "TARGET" >> pure(llvm::omp::Directive::OMPD_target),
         "TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
         "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
-        "TEAMS WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_workdistribute),
+        "TEAMS WORKDISTRIBUTE" >>
+            pure(llvm::omp::Directive::OMPD_teams_workdistribute),
         "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
         "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
         "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
index 666bdb3ced647..9fb970246b90c 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -25,4 +25,4 @@ func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<
     omp.terminator
   }
   return
-}
\ No newline at end of file
+}

>From 60351b6a73ed19de8531ac63336e17be7536cf48 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 27 May 2025 16:24:26 +0530
Subject: [PATCH 12/17] update to workdistribute lowering

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 194 ++++++++++--------
 .../OpenMP/lower-workdistribute-doloop.mlir   |  19 +-
 .../OpenMP/lower-workdistribute-fission.mlir  |  31 +--
 3 files changed, 139 insertions(+), 105 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index f75d4d1988fd2..c9c7827ace217 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -48,25 +48,21 @@ using namespace mlir;
 
 namespace {
 
-template <typename T>
-static T getPerfectlyNested(Operation *op) {
-  if (op->getNumRegions() != 1)
-    return nullptr;
-  auto &region = op->getRegion(0);
-  if (region.getBlocks().size() != 1)
-    return nullptr;
-  auto *block = &region.front();
-  auto *firstOp = &block->front();
-  if (auto nested = dyn_cast<T>(firstOp))
-    if (firstOp->getNextNode() == block->getTerminator())
-      return nested;
-  return nullptr;
+/// Returns true if \p op is a `fir.call` with a symbolic callee that resolves,
+/// in the enclosing module, to a symbol carrying the FIR runtime attribute.
+/// Calls without a symbolic callee, ops outside a module, and callees that
+/// cannot be resolved are not treated as runtime calls.
+static bool isRuntimeCall(Operation *op) {
+  auto callOp = dyn_cast<fir::CallOp>(op);
+  if (!callOp)
+    return false;
+  // A call without a symbolic callee (e.g. an indirect call) is not a runtime
+  // call.
+  auto callee = callOp.getCallee();
+  if (!callee)
+    return false;
+  auto module = op->getParentOfType<ModuleOp>();
+  if (!module)
+    return false;
+  // lookupSymbol returns null when the symbol is not found; guard against
+  // dereferencing it.
+  Operation *func = module.lookupSymbol(*callee);
+  return func && func->hasAttr(fir::FIROpsDialect::getFirRuntimeAttrName());
 }
 
 /// This is the single source of truth about whether we should parallelize an
-/// operation nested in an omp.workdistribute region.
+/// operation nested in an omp.workdistribute region.
 static bool shouldParallelize(Operation *op) {
-  // Currently we cannot parallelize operations with results that have uses
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
     return false;
@@ -77,21 +73,28 @@ static bool shouldParallelize(Operation *op) {
       return false;
     return *unordered;
   }
-  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
-    auto callee = callOp.getCallee();
-    if (!callee)
-      return false;
-    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
-    // TODO need to insert a check here whether it is a call we can actually
-    // parallelize currently
-    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
-      return true;
-    return false;
+  if (isRuntimeCall(op)) {
+    return true;
   }
   // We cannot parallise anything else
   return false;
 }
 
+/// Returns the op of type T that is perfectly nested in \p op: \p op must have
+/// a single region holding a single block whose first operation is a T that
+/// is immediately followed by the block terminator. Returns null otherwise.
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1 || !op->getRegion(0).hasOneBlock())
+    return nullptr;
+  Block &body = op->getRegion(0).front();
+  Operation *candidate = &body.front();
+  T nested = dyn_cast<T>(candidate);
+  // Only the terminator may follow the nested op; the terminator is queried
+  // only once the candidate is known to be a T.
+  if (!nested || candidate->getNextNode() != body.getTerminator())
+    return nullptr;
+  return nested;
+}
+
 /// If B() and D() are parallelizable,
 ///
 /// omp.teams {
@@ -138,17 +141,33 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
       emitError(loc, "teams with multiple blocks\n");
       return failure();
     }
-    if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
-      emitError(loc, "teams with multiple nested ops\n");
-      return failure();
-    }
 
     auto *teamsBlock = &teams.getRegion().front();
+    bool changed = false;
+    // Move the ops inside teams and before workdistribute outside.
+    IRMapping irMapping;
+    llvm::SmallVector<Operation *> teamsHoisted;
+    for (auto &op : teams.getOps()) {
+      if (&op == workdistribute) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        emitError(loc,
+                  "teams has parallelize ops before first workdistribute\n");
+        return failure();
+      } else {
+        rewriter.setInsertionPoint(teams);
+        rewriter.clone(op, irMapping);
+        teamsHoisted.push_back(&op);
+        changed = true;
+      }
+    }
+    for (auto *op : teamsHoisted)
+      rewriter.replaceOp(op, irMapping.lookup(op));
 
     // While we have unhandled operations in the original workdistribute
     auto *workdistributeBlock = &workdistribute.getRegion().front();
     auto *terminator = workdistributeBlock->getTerminator();
-    bool changed = false;
     while (&workdistributeBlock->front() != terminator) {
       rewriter.setInsertionPoint(teams);
       IRMapping mapping;
@@ -194,9 +213,51 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
   }
 };
 
+/// If an unordered fir.do_loop is perfectly nested inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+
+/// Emits an omp.parallel op at the current insertion point and sets its
+/// composite flag to \p composite. It gives the op a body block ending in
+/// omp.terminator and leaves \p rewriter positioned right before that
+/// terminator, so the caller can fill in the body.
+static void genParallelOp(Location loc, PatternRewriter &rewriter,
+                          bool composite) {
+  auto parallel = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallel.setComposite(composite);
+  rewriter.createBlock(&parallel.getRegion());
+  auto blockEnd = rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPoint(blockEnd);
+}
+
+/// Emits an omp.distribute op with default-constructed clause operands at the
+/// current insertion point and sets its composite flag to \p composite. It
+/// then moves \p rewriter to the start of the op's newly created body block.
+static void genDistributeOp(Location loc, PatternRewriter &rewriter,
+                            bool composite) {
+  mlir::omp::DistributeOperands defaultClauses;
+  auto distribute =
+      rewriter.create<mlir::omp::DistributeOp>(loc, defaultClauses);
+  distribute.setComposite(composite);
+  Block *body = rewriter.createBlock(&distribute.getRegion());
+  rewriter.setInsertionPointToStart(body);
+}
+
 static void
-genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
-                     fir::DoLoopOp loop,
+genLoopNestClauseOps(mlir::PatternRewriter &rewriter, fir::DoLoopOp loop,
                      mlir::omp::LoopNestOperands &loopNestClauseOps) {
   assert(loopNestClauseOps.loopLowerBounds.empty() &&
          "Loop nest bounds were already emitted!");
@@ -207,9 +268,11 @@ genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
 }
 
 static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
-                        const mlir::omp::LoopNestOperands &clauseOps) {
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
 
   auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
   rewriter.createBlock(&wsloopOp.getRegion());
 
   auto loopNestOp =
@@ -231,57 +294,20 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
   return;
 }
 
-/// If fir.do_loop is present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     fir.do_loop unoredered {
-///       ...
-///     }
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     omp.parallel {
-///       omp.wsloop {
-///         omp.loop_nest
-///           ...
-///         }
-///       }
-///     }
-///   }
-/// }
-
-struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
+struct WorkdistributeDoLower : public OpRewritePattern<omp::WorkdistributeOp> {
   using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+  LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute,
                                 PatternRewriter &rewriter) const override {
-    auto teamsLoc = teamsOp->getLoc();
-    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-    if (!workdistributeOp) {
-      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-      return failure();
-    }
-    assert(teamsOp.getReductionVars().empty());
-
-    auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistributeOp);
+    auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+    auto wdLoc = workdistribute->getLoc();
     if (doLoop && shouldParallelize(doLoop)) {
-
-      auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
-      rewriter.createBlock(&parallelOp.getRegion());
-      rewriter.setInsertionPoint(
-          rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
-
+      assert(doLoop.getReduceOperands().empty());
+      genParallelOp(wdLoc, rewriter, true);
+      genDistributeOp(wdLoc, rewriter, true);
       mlir::omp::LoopNestOperands loopNestClauseOps;
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
-                           loopNestClauseOps);
-
-      genWsLoopOp(rewriter, doLoop, loopNestClauseOps);
-      rewriter.setInsertionPoint(doLoop);
-      rewriter.eraseOp(doLoop);
+      genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+      genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+      rewriter.eraseOp(workdistribute);
       return success();
     }
     return failure();
@@ -315,7 +341,7 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
     Block *workdistributeBlock = &workdistributeOp.getRegion().front();
     rewriter.eraseOp(workdistributeBlock->getTerminator());
     rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-    rewriter.eraseOp(teamsOp);
+    rewriter.eraseOp(workdistributeOp);
     return success();
   }
 };
@@ -332,8 +358,7 @@ class LowerWorkdistributePass
     Operation *op = getOperation();
     {
       RewritePatternSet patterns(&context);
-      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(
-          &context);
+      patterns.insert<FissionWorkdistribute, WorkdistributeDoLower>(&context);
       if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
         emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
         signalPassFailure();
@@ -341,8 +366,7 @@ class LowerWorkdistributePass
     }
     {
       RewritePatternSet patterns(&context);
-      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(
-          &context);
+      patterns.insert<TeamsWorkdistributeToSingle>(&context);
       if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
         emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
         signalPassFailure();
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
index 9fb970246b90c..f8351bb64e6e8 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -2,13 +2,18 @@
 
 // CHECK-LABEL:   func.func @x({{.*}})
 // CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
-// CHECK:           omp.parallel {
-// CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
-// CHECK:                 fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
-// CHECK:                 omp.yield
-// CHECK:               }
-// CHECK:             }
+// CHECK:           omp.teams {
+// CHECK:             omp.parallel {
+// CHECK:               omp.distribute {
+// CHECK:                 omp.wsloop {
+// CHECK:                   omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK:                     fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                 } {omp.composite}
+// CHECK:               } {omp.composite}
+// CHECK:               omp.terminator
+// CHECK:             } {omp.composite}
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
index cf50d135d01ec..c562b7009664d 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -1,21 +1,26 @@
 // RUN: fir-opt --lower-workdistribute %s | FileCheck %s
 
-// CHECK-LABEL:   func.func @test_fission_workdistribute({{.*}}) {
+// CHECK-LABEL:   func.func @test_fission_workdistribute(
 // CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 9 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
 // CHECK:           fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
-// CHECK:           omp.parallel {
-// CHECK:             omp.wsloop {
-// CHECK:               omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
-// CHECK:                 %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK:                 %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
-// CHECK:                 %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK:                 fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
-// CHECK:                 omp.yield
-// CHECK:               }
-// CHECK:             }
+// CHECK:           omp.teams {
+// CHECK:             omp.parallel {
+// CHECK:               omp.distribute {
+// CHECK:                 omp.wsloop {
+// CHECK:                   omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                     %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                     %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK:                     %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                     fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                 } {omp.composite}
+// CHECK:               } {omp.composite}
+// CHECK:               omp.terminator
+// CHECK:             } {omp.composite}
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
@@ -24,8 +29,8 @@
 // CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
 // CHECK:             fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
 // CHECK:           }
-// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2]] : !fir.ref<f32>
-// CHECK:           fir.store %[[VAL_10]] to %[[ARG3]] : !fir.ref<f32>
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
 // CHECK:           return
 // CHECK:         }
 module {

>From fdc6938dff8456cf5864cc40b999e9855943e70b Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 28 May 2025 21:41:25 +0530
Subject: [PATCH 13/17] Fix basic-program.fir test.

---
 flang/test/Fir/basic-program.fir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 7ac8b92f48953..a611629eeb280 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -69,6 +69,7 @@ func.func @_QQmain() {
 // PASSES-NEXT:     InlineHLFIRAssign
 // PASSES-NEXT:   ConvertHLFIRtoFIR
 // PASSES-NEXT:   LowerWorkshare
+// PASSES-NEXT:   LowerWorkdistribute
 // PASSES-NEXT:   CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd

>From 6ecc39ff1d9aa80bf5be8b7a5144bc672c2d074e Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 30 May 2025 15:12:46 +0530
Subject: [PATCH 14/17] Wrap omp.target with omp.target_data

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 88 +++++++++++++++++++
 .../OpenMP/lower-workdistribute-target.mlir   | 36 ++++++++
 .../lower-workdistribute-to-single.mlir       | 52 -----------
 3 files changed, 124 insertions(+), 52 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
 delete mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index c9c7827ace217..6509cc5014dd7 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -346,6 +346,85 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
   }
 };
 
+/// Scans the body of \p targetOp for the first omp.teams or omp.parallel op.
+/// Returns that op together with two flags: whether it is the first op of the
+/// block, and whether it is immediately followed by the terminator. Returns
+/// std::nullopt when no such op is found. It also returns std::nullopt as
+/// soon as it reaches an op that is both first and last, i.e. the block's
+/// only non-terminator op.
+/// NOTE(review): no caller of this helper is visible in this patch.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  Block &body = targetOp.getRegion().front();
+  for (Operation &candidate : body) {
+    const bool isFirst = &candidate == &body.front();
+    const bool isLast = candidate.getNextNode() == body.getTerminator();
+    // A lone op in the target body: give up.
+    if (isFirst && isLast)
+      return std::nullopt;
+    if (isa<omp::TeamsOp, omp::ParallelOp>(&candidate))
+      return std::make_tuple(&candidate, isFirst, isLast);
+  }
+  return std::nullopt;
+}
+
+/// Pair of ops reported by splitTargetData.
+/// NOTE(review): confirm that targetOp refers to the live (cloned) omp.target
+/// and not to the original op that splitTargetData erases.
+struct SplitTargetResult {
+  /// The omp.target op after the split.
+  omp::TargetOp targetOp;
+  /// The omp.target_data op created to wrap the target op.
+  omp::TargetDataOp dataOp;
+};
+
+/// If multiple workdistribute constructs are nested in a target region, the
+/// target region needs to be split. The split should preserve the data
+/// semantics of the original region and avoid unnecessary data movement for
+/// each sub-kernel. To do that, the target op is rewritten into a
+/// target_data { target } nest in which only the outer omp.target_data maps
+/// the ByRef-captured data.
+///
+/// Returns std::nullopt and leaves the IR untouched in two cases: the target
+/// op has no map operands, or a map operand is not produced by an omp.map.info
+/// op. Otherwise \p targetOp is erased, and the result holds the cloned
+/// omp.target and the new omp.target_data that wraps it.
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
+                                                 RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << DEBUG_TYPE << " target region has no data maps\n");
+    return std::nullopt;
+  }
+
+  // Collect all map_entries and remember the ByRef-captured ones, which are
+  // also mapped by the enclosing omp.target_data. This runs before any IR
+  // change, so bailing out leaves the IR unchanged.
+  SmallVector<mlir::Value> byRefMapInfos;
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (mlir::Value mapVar : targetOp.getMapVars()) {
+    auto mapInfo = mapVar.getDefiningOp<omp::MapInfoOp>();
+    if (!mapInfo) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE
+                              << " map operand not defined by omp.map.info\n");
+      return std::nullopt;
+    }
+    mapInfos.push_back(mapInfo);
+    if (mapInfo.getMapCaptureType() == omp::VariableCaptureKind::ByRef)
+      byRefMapInfos.push_back(mapVar);
+  }
+
+  // Create the new omp.target_data op with the collected map_entries.
+  rewriter.setInsertionPoint(targetOp);
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, targetOp.getDevice(), targetOp.getIfExpr(),
+      mlir::ValueRange{byRefMapInfos}, targetOp.getHasDeviceAddrVars(),
+      targetOp.getIsDevicePtrVars());
+
+  Block *targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+
+  // Clone the map.info ops inside the omp.target_data region so that the
+  // cloned omp.target refers to the clones instead of the outer map.info ops.
+  IRMapping mapping;
+  for (omp::MapInfoOp mapInfo : mapInfos)
+    rewriter.clone(*mapInfo, mapping);
+  // Clone omp.target from the existing targetOp inside the target_data region.
+  auto newTargetOp = cast<omp::TargetOp>(rewriter.clone(*targetOp, mapping));
+
+  // Erase the original target op, then those of its map.info ops that no
+  // longer have uses (the ones not forwarded to the omp.target_data).
+  rewriter.eraseOp(targetOp);
+  for (omp::MapInfoOp mapInfo : mapInfos)
+    if (mapInfo->use_empty())
+      rewriter.eraseOp(mapInfo);
+
+  // Report the live clone; the original targetOp has just been erased.
+  return SplitTargetResult{newTargetOp, targetDataOp};
+}
+
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
@@ -372,6 +451,15 @@ class LowerWorkdistributePass
         signalPassFailure();
       }
     }
+    {
+      SmallVector<omp::TargetOp> targetOps;
+      op->walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
+      IRRewriter rewriter(&context);
+      for (auto targetOp : targetOps) {
+        auto res = splitTargetData(targetOp, rewriter);
+      }
+    }
+
   }
 };
 } // namespace
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
new file mode 100644
index 0000000000000..e6ca98d3bf596
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
@@ -0,0 +1,36 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_nested_derived_type_map_operand_and_block_addition(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK:           %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+// CHECK:           %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK:           %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+// CHECK:           %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"}
+// CHECK:           %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK:           %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+// CHECK:           %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
+// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+// CHECK:           omp.target_data map_entries(%[[VAL_10]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK:             %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+// CHECK:             omp.target map_entries(%[[VAL_11]] -> %[[VAL_12:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {        
+  %0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+  %2 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+  %4 = fir.coordinate_of %2, i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+  %5 = omp.map.info var_ptr(%4 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%n%i"}
+  %7 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+  %9 = fir.coordinate_of %7, r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+  %10 = omp.map.info var_ptr(%9 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%n%r"}
+  %11 = omp.map.info var_ptr(%0 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+  omp.target map_entries(%11 -> %arg1 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
deleted file mode 100644
index 0cc2aeded2532..0000000000000
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
-
-// CHECK-LABEL:   func.func @_QPtarget_simple() {
-// CHECK:           %[[VAL_0:.*]] = arith.constant 2 : i32
-// CHECK:           %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
-// CHECK:           %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK:           %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
-// CHECK:           %[[VAL_4:.*]] = fir.zero_bits !fir.heap<i32>
-// CHECK:           %[[VAL_5:.*]] = fir.embox %[[VAL_4]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
-// CHECK:           fir.store %[[VAL_5]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.heap<i32>>>
-// CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_3]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
-// CHECK:           hlfir.assign %[[VAL_0]] to %[[VAL_2]]#0 : i32, !fir.ref<i32>
-// CHECK:           %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_2]]#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
-// CHECK:           omp.target map_entries(%[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %[[VAL_6]]#0 -> %[[VAL_9:.*]] : !fir.ref<!fir.box<!fir.heap<i32>>>) {
-// CHECK:             %[[VAL_10:.*]] = arith.constant 10 : i32
-// CHECK:             %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK:             %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_9]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
-// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<i32>
-// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32
-// CHECK:             hlfir.assign %[[VAL_14]] to %[[VAL_12]]#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
-// CHECK:             omp.terminator
-// CHECK:           }
-// CHECK:           return
-// CHECK:         }
-func.func @_QPtarget_simple() {
-    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
-    %1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-    %2 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
-    %3 = fir.zero_bits !fir.heap<i32>
-    %4 = fir.embox %3 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
-    fir.store %4 to %2 : !fir.ref<!fir.box<!fir.heap<i32>>>
-    %5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
-    %c2_i32 = arith.constant 2 : i32
-    hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref<i32>
-    %6 = omp.map.info var_ptr(%1#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
-    omp.target map_entries(%6 -> %arg0 : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref<!fir.box<!fir.heap<i32>>>){
-        omp.teams {
-            omp.workdistribute {
-                %11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-                %12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
-                %c10_i32 = arith.constant 10 : i32
-                %13 = fir.load %11#0 : !fir.ref<i32>
-                %14 = arith.addi %c10_i32, %13 : i32
-                hlfir.assign %14 to %12#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
-                omp.terminator
-            }
-            omp.terminator
-        }
-        omp.terminator
-    }
-    return
-}

>From 432a1308e6f6b52ee0fc9f47d312da186e26d38e Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 3 Jun 2025 15:47:08 +0530
Subject: [PATCH 15/17] Add fission of target region Logic inspired from
 ivanradanov llvm branch: flang_workdistribute_iwomp_2024 commit:
 a77451505dbd728a7a339f6c7c4c1382c709c502

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 437 +++++++++++++++++-
 .../lower-workdistribute-fission-target.mlir  | 104 +++++
 2 files changed, 521 insertions(+), 20 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 6509cc5014dd7..8f2de92cfd186 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -34,6 +34,7 @@
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include <optional>
 #include <variant>
 
@@ -346,21 +347,6 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
   }
 };
 
-static std::optional<std::tuple<Operation *, bool, bool>>
-getNestedOpToIsolate(omp::TargetOp targetOp) {
-  auto *targetBlock = &targetOp.getRegion().front();
-  for (auto &op : *targetBlock) {
-    bool first = &op == &*targetBlock->begin();
-    bool last = op.getNextNode() == targetBlock->getTerminator();
-    if (first && last)
-      return std::nullopt;
-
-    if (isa<omp::TeamsOp, omp::ParallelOp>(&op))
-      return {{&op, first, last}};
-  }
-  return std::nullopt;
-}
-
 struct SplitTargetResult {
   omp::TargetOp targetOp;
   omp::TargetDataOp dataOp;
@@ -371,8 +357,7 @@ struct SplitTargetResult {
 /// original data region and avoid unnecessary data movement at each of the
 /// subkernels - we split the target region into a target_data{target}
 /// nest where only the outer one moves the data
-std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
-                                                 RewriterBase &rewriter) {
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp, RewriterBase &rewriter) {
 
   auto loc = targetOp->getLoc();
   if (targetOp.getMapVars().empty()) {
@@ -391,7 +376,6 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   }
 
   // Create the new omp.target_data op with these collected map_entries
-  auto targetLoc = targetOp.getLoc();
   rewriter.setInsertionPoint(targetOp);
   auto device = targetOp.getDevice();
   auto ifExpr = targetOp.getIfExpr();
@@ -422,8 +406,420 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
     if (mapInfoRes.getUsers().empty()) 
       rewriter.eraseOp(mapInfo);
   }
-  return SplitTargetResult{targetOp, targetDataOp};
-}                                                  
+  return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
+}
+
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto &op : *targetBlock) {
+    bool first = &op == &*targetBlock->begin();
+    bool last = op.getNextNode() == targetBlock->getTerminator();
+    if (first && last)
+      return std::nullopt;
+
+    if (isa<omp::TeamsOp, omp::ParallelOp>(&op))
+      return {{&op, first, last}};
+  }
+  return std::nullopt;
+}
+
+struct TempOmpVar {
+  omp::MapInfoOp from, to;
+};
+
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+static Type getPtrTypeForOmp(Type ty) {
+  if (isPtr(ty))
+    return LLVM::LLVMPointerType::get(ty.getContext());
+  else
+    return fir::LLVMPointerType::get(ty);
+}
+
+static TempOmpVar
+allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) {
+  MLIRContext& ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    // Keep allocType == llvmPtrTy: the map var_type must span the full pointer.
+  } else {
+    // Non-pointer values are cached in a fir.alloca slot of their own type.
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+      loc, alloc.getType(), alloc,
+      TypeAttr::get(allocType),
+      rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), mappingFlags),
+      rewriter.getAttr<omp::VariableCaptureKindAttr>(
+          omp::VariableCaptureKind::ByRef),
+      /*varPtrPtr=*/Value{},
+      /*members=*/SmallVector<Value>{},
+      /*member_index=*/mlir::ArrayAttr{},
+      /*bounds=*/ValueRange(),
+      /*mapperId=*/mlir::FlatSymbolRefAttr(),
+      /*name=*/rewriter.getStringAttr(name),
+      rewriter.getBoolAttr(false));
+  };
+  uint64_t mapFrom = static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo = static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo};
+}
+
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto *user : v.getUsers()) {
+    // Walk up to the ancestor of `user` that lives directly in targetBlock.
+    while (user->getBlock() != targetBlock)
+      user = user->getParentOp();
+    if (!user->isBeforeInBlock(split))
+      return true;
+  }
+  return false;
+}
+
+static bool isOpToBeCached(Operation *op) {
+  if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {  
+    Value memref = loadOp.getMemref();  
+    if (auto blockArg = dyn_cast<BlockArgument>(memref)) {  
+      // 'op' is an operation within the targetOp that 'splitBefore' is also in.
+      Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp();  
+      // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp.  
+      // This implies the load is from a variable directly mapped into the target region.  
+      if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&  
+          !parentOpOfLoadBlock->getRegions().empty()) {  
+        Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front();  
+        if (blockArg.getOwner() == targetOpEntryBlock) {  
+          // This load is from a direct argument of the target op.  
+          // It's safe to recompute.
+          return false;  
+        }  
+      }  
+    }  
+  }
+  return true;
+}
+
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {  
+    Value memref = loadOp.getMemref();  
+    if (auto blockArg = dyn_cast<BlockArgument>(memref)) {  
+      // 'op' is an operation within the targetOp that 'splitBefore' is also in.
+      Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp();  
+      // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp.  
+      // This implies the load is from a variable directly mapped into the target region.  
+      if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&  
+          !parentOpOfLoadBlock->getRegions().empty()) {  
+        Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front();  
+        if (blockArg.getOwner() == targetOpEntryBlock) {  
+          // This load is from a direct argument of the target op.  
+          // It's safe to recompute.
+          return true;  
+        }  
+      }  
+    }  
+  } 
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface) {
+    return false;
+  }
+  interface.getEffects(effects);
+  if (effects.empty())
+    return true;
+  return false;
+}
+
+struct SplitResult {
+  omp::TargetOp preTargetOp;
+  omp::TargetOp isolatedTargetOp;
+  omp::TargetOp postTargetOp;
+};
+
+static void collectNonRecomputableDeps(Value v, omp::TargetOp targetOp,
+                                       SetVector<Operation *> &nonRecomputable,
+                                       SetVector<Operation *> &toCache,
+                                       SetVector<Operation *> &toRecompute) {
+  Operation *op = v.getDefiningOp();
+  if (!op) {
+    assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+    return;
+  }
+  if (nonRecomputable.contains(op)) {
+    toCache.insert(op);
+    return;
+  }
+  if (!toRecompute.insert(op))
+    return; // Already visited: its operands were walked on the first visit.
+  for (Value opr : op->getOperands())
+    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, toRecompute);
+}
+
+
+static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
+                        MLIRContext& ctx,
+                        IRMapping &mapping, Operation *splitBefore,
+                        Block *targetBlock, Block *newTargetBlock,
+                        SmallVector<Value>& allocs,
+                        SetVector<Operation *>& toRecompute) {
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    mapping.map(originalArg, newArg);
+  }
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  for (auto original : allocs) {
+    Value newArg = newTargetBlock->addArgument(
+      getPtrTypeForOmp(original.getType()), original.getLoc());
+    Value restored;
+    if (isPtr(original.getType())) {
+      restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+      if (!isa<LLVM::LLVMPointerType>(original.getType()))
+        restored = rewriter.create<UnrealizedConversionCastOp>(loc, original.getType(), ValueRange(restored))
+                           .getResult(0);
+    } 
+    else {
+        restored = rewriter.create<fir::LoadOp>(loc, newArg);
+    }
+    mapping.map(original, restored);
+  }
+  for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
+    if (toRecompute.contains(&*it))
+      rewriter.clone(*it, mapping);
+  }
+}
+
+static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
+                              RewriterBase &rewriter) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  MLIRContext& ctx = *targetOp.getContext();
+  assert(targetOp);
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  rewriter.setInsertionPoint(targetOp);
+   
+  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++) {
+    for (auto res : it->getResults()) {
+      if (usedOutsideSplit(res, splitBeforeOp))
+        requiredVals.push_back(res);
+    }
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp))
+        nonRecomputable.insert(&*it);
+  }
+
+  for (auto requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, toRecompute);
+  
+  for (Operation *op : toCache) {
+    for (auto res : op->getResults()) {
+      auto alloc = allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      allocs.push_back(res);
+      preMapOperands.push_back(alloc.from);
+      postMapOperands.push_back(alloc.to);
+    }
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+
+  auto preTargetOp = rewriter.create<omp::TargetOp>(
+        targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
+        targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
+        targetOp.getDependVars(), targetOp.getDevice(),
+        targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
+        targetOp.getIfExpr(), targetOp.getInReductionVars(),
+        targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+        targetOp.getIsDevicePtrVars(), preMapOperands,
+        targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+        targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(),
+        targetOp.getPrivateMapsAttr()); 
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = preTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    preMapping.map(originalArg, newArg);
+  }
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++)
+    rewriter.clone(*it, preMapping);
+
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+
+  for (auto original : allocs) {
+    Value toStore = preMapping.lookup(original);
+    auto newArg = preTargetBlock->addArgument(
+        getPtrTypeForOmp(original.getType()), original.getLoc());
+    if (isPtr(original.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<UnrealizedConversionCastOp>(loc, llvmPtrTy,
+                                                           ValueRange(toStore))
+                      .getResult(0);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  rewriter.setInsertionPoint(targetOp);
+
+  auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
+      targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
+      targetOp.getDependVars(), targetOp.getDevice(),
+      targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), postMapOperands,
+      targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr()); 
+
+  auto *isolatedTargetBlock =
+        rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                             isolatedTargetOp.getRegion().begin(), {}, {});
+
+  IRMapping isolatedMapping;
+  reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
+                          targetBlock, isolatedTargetBlock,
+                          allocs, toRecompute);
+  rewriter.clone(*splitBeforeOp, isolatedMapping);
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  omp::TargetOp postTargetOp = nullptr;
+  
+  if (splitAfter) {
+      rewriter.setInsertionPoint(targetOp);
+    postTargetOp = rewriter.create<omp::TargetOp>(
+        targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
+        targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
+        targetOp.getDependVars(), targetOp.getDevice(),
+        targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
+        targetOp.getIfExpr(), targetOp.getInReductionVars(),
+        targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+        targetOp.getIsDevicePtrVars(), postMapOperands,
+        targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+        targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(),
+        targetOp.getPrivateMapsAttr()); 
+    auto *postTargetBlock = rewriter.createBlock(
+          &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+    IRMapping postMapping;
+    reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp, 
+                            targetBlock, postTargetBlock,
+                            allocs, toRecompute);
+
+    assert(splitBeforeOp->getNumResults() == 0 ||
+             llvm::all_of(splitBeforeOp->getResults(),
+                          [](Value result) { return result.use_empty(); }));
+
+    for (auto it = std::next(splitBeforeOp->getIterator());
+         it != targetBlock->end(); it++)
+      rewriter.clone(*it, postMapping);
+  }
+
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  IRMapping mapping;
+  for (auto map :
+       zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) {
+    Value mapInfo = std::get<0>(map);
+    BlockArgument arg = std::get<1>(map);
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Operation *> opsToMove;
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    auto allocOp = dyn_cast<fir::AllocMemOp>(op);
+    auto freeOp = dyn_cast<fir::FreeMemOp>(op);
+    fir::CallOp runtimeCall = nullptr;
+    if (isRuntimeCall(op))
+      runtimeCall = cast<fir::CallOp>(op);
+
+    if (allocOp || freeOp || runtimeCall)
+        continue;
+    opsToMove.push_back(op);
+  }
+  // Move ops before targetOp and erase from region
+  for (Operation *op : opsToMove)
+    rewriter.clone(*op, mapping);
+  
+  rewriter.eraseOp(targetOp);
+}
+
+/// Recursively split \p targetOp so each teams/parallel op gets its own target.
+void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  auto tuple = getNestedOpToIsolate(targetOp);
+  if (!tuple) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    //moveToHost(targetOp, rewriter);
+    return;
+  }
+
+  Operation *toIsolate = std::get<0>(*tuple);
+  bool splitBefore = !std::get<1>(*tuple);
+  bool splitAfter = !std::get<2>(*tuple);
+  // getNestedOpToIsolate never reports an op that is both first and last.
+  if (!splitBefore && !splitAfter)
+    return;
+
+  // Always split right before `toIsolate`. When it is already the first op,
+  // the pre-target built by isolateOp holds no ops, so it is dropped. Splitting
+  // before an arbitrary successor instead would break when that successor's
+  // results are used later (isolateOp asserts they are unused).
+  auto res = isolateOp(toIsolate, splitAfter, rewriter);
+  //moveToHost(res.preTargetOp, rewriter);
+  if (!splitBefore)
+    rewriter.eraseOp(res.preTargetOp);
+
+  // Ops after the isolated one may contain further teams/parallel ops, so
+  // keep splitting the post-target recursively.
+  if (splitAfter)
+    fissionTarget(res.postTargetOp, rewriter);
+}
 
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
@@ -457,6 +853,7 @@ class LowerWorkdistributePass
       IRRewriter rewriter(&context);
       for (auto targetOp : targetOps) {
         auto res = splitTargetData(targetOp, rewriter);
+        if (res) fissionTarget(res->targetOp, rewriter);
       }
     }
 
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
new file mode 100644
index 0000000000000..ed6c641f2e934
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -0,0 +1,104 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @x
+// CHECK:           %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK:           fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK:           %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK:           fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK:           %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK:           fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK:           %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:           %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:           %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:           %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK:             %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:             %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:             %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:             %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:             %[[VAL_11:.*]] = fir.alloca !fir.heap<index>
+// CHECK:             %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_14:.*]], %[[VAL_8]] -> %[[VAL_15:.*]], %[[VAL_9]] -> %[[VAL_16:.*]], %[[VAL_10]] -> %[[VAL_17:.*]], %[[VAL_12]] -> %[[VAL_18:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK:               %[[VAL_19:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_20:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK:               %[[VAL_21:.*]] = fir.load %[[VAL_15]] : !fir.ref<index>
+// CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_16]] : !fir.ref<index>
+// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_21]] : index
+// CHECK:               %[[VAL_24:.*]] = fir.allocmem index, %[[VAL_19]] {uniq_name = "dev_buf"}
+// CHECK:               fir.store %[[VAL_24]] to %[[VAL_18]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_25:.*]], %[[VAL_8]] -> %[[VAL_26:.*]], %[[VAL_9]] -> %[[VAL_27:.*]], %[[VAL_10]] -> %[[VAL_28:.*]], %[[VAL_13]] -> %[[VAL_29:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK:               %[[VAL_30:.*]] = fir.load %[[VAL_29]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_31:.*]] = fir.load %[[VAL_25]] : !fir.ref<index>
+// CHECK:               %[[VAL_32:.*]] = fir.load %[[VAL_26]] : !fir.ref<index>
+// CHECK:               %[[VAL_33:.*]] = fir.load %[[VAL_27]] : !fir.ref<index>
+// CHECK:               %[[VAL_34:.*]] = arith.addi %[[VAL_32]], %[[VAL_32]] : index
+// CHECK:               omp.teams {
+// CHECK:                 omp.parallel {
+// CHECK:                   omp.distribute {
+// CHECK:                     omp.wsloop {
+// CHECK:                       omp.loop_nest (%[[VAL_35:.*]]) : index = (%[[VAL_31]]) to (%[[VAL_32]]) inclusive step (%[[VAL_33]]) {
+// CHECK:                         fir.store %[[VAL_34]] to %[[VAL_30]] : !fir.heap<index>
+// CHECK:                         omp.yield
+// CHECK:                       }
+// CHECK:                     } {omp.composite}
+// CHECK:                   } {omp.composite}
+// CHECK:                   omp.terminator
+// CHECK:                 } {omp.composite}
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_36:.*]], %[[VAL_8]] -> %[[VAL_37:.*]], %[[VAL_9]] -> %[[VAL_38:.*]], %[[VAL_10]] -> %[[VAL_39:.*]], %[[VAL_13]] -> %[[VAL_40:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK:               %[[VAL_41:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_42:.*]] = fir.load %[[VAL_36]] : !fir.ref<index>
+// CHECK:               %[[VAL_43:.*]] = fir.load %[[VAL_37]] : !fir.ref<index>
+// CHECK:               %[[VAL_44:.*]] = fir.load %[[VAL_38]] : !fir.ref<index>
+// CHECK:               %[[VAL_45:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index
+// CHECK:               fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap<index>
+// CHECK:               fir.freemem %[[VAL_41]] : !fir.heap<index>
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %arg0, %ub_map -> %arg1, %step_map -> %arg2, %addr_map -> %arg3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %arg0 : !fir.ref<index>
+    %ub_val = fir.load %arg1 : !fir.ref<index>
+    %step_val = fir.load %arg2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}

>From f858541c715643642ca89bda134d3de9449c656d Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 6 Jun 2025 14:20:08 +0530
Subject: [PATCH 16/17] Use fir.convert instead of unrealized conversion cast

---
 flang/lib/Lower/OpenMP/OpenMP.cpp                      | 10 ++++++++++
 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp     |  7 ++-----
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp       |  3 +++
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ebf0710ab4feb..be14d10c5914b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -561,6 +561,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       });
       break;
 
+    case OMPD_teams_workdistribute:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_workdistribute:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
+      });
+      break;
+
     case OMPD_teams_distribute:
     case OMPD_teams_distribute_simd:
       cp.processThreadLimit(stmtCtx, hostInfo.ops);
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 8f2de92cfd186..6d6de47f7741e 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -597,8 +597,7 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
     if (isPtr(original.getType())) {
       restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
       if (!isa<LLVM::LLVMPointerType>(original.getType()))
-        restored = rewriter.create<UnrealizedConversionCastOp>(loc, original.getType(), ValueRange(restored))
-                           .getResult(0);
+        restored = rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
     } 
     else {
         restored = rewriter.create<fir::LoadOp>(loc, newArg);
@@ -684,9 +683,7 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
         getPtrTypeForOmp(original.getType()), original.getLoc());
     if (isPtr(original.getType())) {
       if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
-        toStore = rewriter.create<UnrealizedConversionCastOp>(loc, llvmPtrTy,
-                                                           ValueRange(toStore))
-                      .getResult(0);
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
       rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
     } else {
       rewriter.create<fir::StoreOp>(loc, toStore, newArg);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 010c46358f7df..40d178e7ea1ff 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5047,6 +5047,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
+  if (targetOp.getHostEvalVars().empty())
+    numLoops = 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);

>From 79dd250a44ec1f2b82781ded2cd206d9758fcf60 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 10 Jun 2025 18:46:57 +0530
Subject: [PATCH 17/17] [Flang] Add fir omp target alloc and free ops

This commit is cherry-picked from ivanradanov's commit
be860ac8baf24b8405e6f396c75d7f0d26375de5
---
 .../include/flang/Optimizer/Dialect/FIROps.td | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 458b780806144..102bfe80c318d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -517,6 +517,67 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> {
   let assemblyFormat = "type($intype) attr-dict";
 }
 
+def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
+    [MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
+  let summary = "allocate storage on an openmp device for an object of a given type";
+
+  let description = [{
+    Creates a heap memory reference on OpenMP device `device` suitable for
+    storing a value of the given type, T.  The heap reference returned has
+    type `!fir.heap<T>`; the memory object is in an undefined state.  Pair
+    each `omp_target_allocmem` with an `omp_target_freemem` to avoid leaks.
+
+    ```
+      %0 = fir.omp_target_allocmem !fir.array<10 x f32>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<AnyIntegerType>:$device,
+    TypeAttr:$in_type,
+    OptionalAttr<StrAttr>:$uniq_name,
+    OptionalAttr<StrAttr>:$bindc_name,
+    Variadic<AnyIntegerType>:$typeparams,
+    Variadic<AnyIntegerType>:$shape
+  );
+  let results = (outs fir_HeapType);
+
+  let extraClassDeclaration = [{
+    mlir::Type getAllocatedType();
+    bool hasLenParams() { return !getTypeparams().empty(); }
+    bool hasShapeOperands() { return !getShape().empty(); }
+    unsigned numLenParams() { return getTypeparams().size(); }
+    operand_range getLenParams() { return getTypeparams(); }
+    unsigned numShapeOperands() { return getShape().size(); }
+    operand_range getShapeOperands() { return getShape(); }
+    static mlir::Type getRefTy(mlir::Type ty);
+  }];
+}
+
+def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
+  [MemoryEffects<[MemFree]>]> {
+  let summary = "free a heap object allocated on an openmp device";
+
+  let description = [{
+    Deallocates a heap memory reference on OpenMP device `device` that was
+    allocated by an `omp_target_allocmem`.  The memory object is placed in an
+    undefined state after `fir.omp_target_freemem`.  Optimizations may treat
+    the loading of an object in the undefined state as undefined behavior.
+    This includes aliasing references, such as the result of a `fir.embox`.
+
+    ```
+      %21 = fir.omp_target_allocmem !fir.type<ZT(p:i32){field:i32}>
+      ...
+      fir.omp_target_freemem %21 : !fir.heap<!fir.type<ZT>>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyIntegerType:$device,
+    Arg<fir_HeapType, "", [MemFree]>:$heapref
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Terminator operations
 //===----------------------------------------------------------------------===//



More information about the flang-commits mailing list