[Mlir-commits] [flang] [mlir] [flang] Implement workdistribute construct lowering (PR #140523)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 19 06:21:45 PDT 2025


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/140523

>From 16a024c84bb31b20f07569a2598eda275c6bbbc1 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 4 Dec 2023 12:57:36 -0800
Subject: [PATCH 01/13] [Flang] Introduce lower-workdistribute pass for
 workdistribute lowering.

Co-authors: ivanradanov, skc7
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |   4 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 901 ++++++++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   4 +-
 flang/test/Fir/basic-program.fir              |   1 +
 .../OpenMP/lower-workdistribute-doloop.mlir   |  33 +
 .../lower-workdistribute-fission-target.mlir  | 112 +++
 .../OpenMP/lower-workdistribute-fission.mlir  |  71 ++
 .../OpenMP/lower-workdistribute-target.mlir   |  32 +
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |   3 +
 10 files changed, 1161 insertions(+), 1 deletion(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index e2f092024c250..bfbaa5f838e90 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+// Lowers the OpenMP `workdistribute` construct (omp.workdistribute). The pass
+// is implemented in flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp.
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e0aebd0714c8f..5454a71ff54a6 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
   SimdOnly.cpp
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..0885efc716db4
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,901 @@
+//===- LowerWorkdistribute.cpp - Lower omp.workdistribute -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if \p op is a direct fir.call whose callee symbol, looked up
+/// in the enclosing module, carries the FIR runtime attribute
+/// (fir::FIROpsDialect::getFirRuntimeAttrName()).
+///
+/// Returns false for non-call ops, indirect calls, ops that are not nested in
+/// a module, and calls whose callee symbol cannot be found in the module.
+static bool isRuntimeCall(Operation *op) {
+  auto callOp = dyn_cast<fir::CallOp>(op);
+  if (!callOp)
+    return false;
+  // Indirect calls have no callee symbol.
+  auto callee = callOp.getCallee();
+  if (!callee)
+    return false;
+  auto module = op->getParentOfType<ModuleOp>();
+  if (!module)
+    return false;
+  // lookupSymbol returns null when the callee is not defined or declared in
+  // this module; treat that as "not a runtime call" rather than crashing.
+  Operation *func = module.lookupSymbol(*callee);
+  return func && func->hasAttr(fir::FIROpsDialect::getFirRuntimeAttrName());
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+///
+/// An op is parallelized only if none of its results are used and it is
+/// either an unordered fir.do_loop (these come from array syntax) or a call
+/// to a Fortran runtime function (see isRuntimeCall). Nothing else is
+/// parallelized.
+static bool shouldParallelize(Operation *op) {
+  // Only ops whose results are all unused are candidates.
+  if (!op->use_empty())
+    return false;
+  // Unordered loops are parallelized; ordered (or unmarked) loops are not.
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    return unordered && *unordered;
+  }
+  return isRuntimeCall(op);
+}
+
+/// Returns the op of type T if it is the only operation (apart from the block
+/// terminator) in the single block of the single region of \p op; returns
+/// null otherwise.
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  Region &region = op->getRegion(0);
+  if (!region.hasOneBlock())
+    return nullptr;
+  Block &block = region.front();
+  // Block::front() on an empty block is undefined behavior.
+  if (block.empty())
+    return nullptr;
+  auto nested = dyn_cast<T>(&block.front());
+  if (!nested)
+    return nullptr;
+  // The nested op must be immediately followed by the terminator.
+  if (nested->getNextNode() != block.getTerminator())
+    return nullptr;
+  return nested;
+}
+
+/// Split a teams/workdistribute nest so that every parallelizable op (see
+/// shouldParallelize) gets its own teams/workdistribute nest, while all other
+/// ops are hoisted in front of the original teams op.
+///
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+///
+/// Ops in the teams region that precede the workdistribute op are hoisted in
+/// front of the teams op as well. All new ops are inserted before the
+/// original teams op, which itself is left in place (its workdistribute body
+/// may end up holding only its terminator). Returns true if the IR changed;
+/// on malformed input an error is emitted and false is returned.
+
+static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams");
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks");
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks");
+    return false;
+  }
+
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute)
+      break;
+    if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelize ops before first workdistribute");
+      return false;
+    }
+    rewriter.setInsertionPoint(teams);
+    rewriter.clone(op, irMapping);
+    teamsHoisted.push_back(&op);
+    changed = true;
+  }
+  // Redirect uses to the hoisted clones and erase the originals, in reverse
+  // order so each op's users are erased before the op itself.
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    // Hoist every op up to the next parallelizable one in front of teams.
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator)
+        break;
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      }
+      rewriter.clone(op, mapping);
+      hoisted.push_back(&op);
+      changed = true;
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    // A parallelizable op that is the only op left stays in the original nest.
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      // Give the parallelizable op its own teams/workdistribute nest.
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      // Map the original teams block arguments to the new block's arguments
+      // so the moved op does not keep referring to the original block.
+      IRMapping teamsArgMapping;
+      for (auto arg : teamsBlock->getArguments())
+        teamsArgMapping.map(
+            arg, newTeamsBlock->addArgument(arg.getType(), arg.getLoc()));
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize, teamsArgMapping);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// If fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+
+/// Create an omp.parallel op at the current insertion point (marked composite
+/// when \p composite is set) with a single block ending in omp.terminator,
+/// and leave the builder positioned just before that terminator.
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallel = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallel.setComposite(composite);
+  rewriter.createBlock(&parallel.getRegion());
+  auto terminator = rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPoint(terminator);
+}
+
+/// Create an omp.distribute op with default clauses at the current insertion
+/// point (marked composite when \p composite is set) and move the builder to
+/// the start of its newly created, empty region block.
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands clauses;
+  auto distribute = rewriter.create<mlir::omp::DistributeOp>(loc, clauses);
+  distribute.setComposite(composite);
+  Block *body = rewriter.createBlock(&distribute.getRegion());
+  rewriter.setInsertionPointToStart(body);
+}
+
+/// Fill \p loopNestClauseOps with the lower bound, upper bound and step of
+/// \p loop, describing a one-level loop nest. The bounds are marked inclusive,
+/// matching the inclusive upper bound of fir.do_loop.
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  Value lowerBound = loop.getLowerBound();
+  Value upperBound = loop.getUpperBound();
+  Value step = loop.getStep();
+  loopNestClauseOps.loopLowerBounds.push_back(lowerBound);
+  loopNestClauseOps.loopUpperBounds.push_back(upperBound);
+  loopNestClauseOps.loopSteps.push_back(step);
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+/// Create an omp.wsloop (marked composite when \p composite is set) at the
+/// current insertion point, holding an omp.loop_nest built from \p clauseOps
+/// whose body is a copy of \p doLoop's region. If the copied terminator is a
+/// fir.result, it is replaced by an omp.yield.
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Copy the loop's region (blocks and their arguments) into the loop nest.
+  // Values defined outside the loop are referenced directly; no extra value
+  // mapping is supplied.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Replace the fir.result terminator of the do loop by an omp.yield.
+  if (isa<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+}
+
+/// If \p workdistribute contains exactly one parallelizable fir.do_loop (plus
+/// its terminator), build a composite omp.parallel / omp.distribute /
+/// omp.wsloop / omp.loop_nest nest for it in front of the workdistribute op
+/// and erase the workdistribute op. Returns true if the rewrite was applied.
+static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  if (!doLoop || !shouldParallelize(doLoop))
+    return false;
+  // Loops with reduction operands are not handled by this lowering.
+  assert(doLoop.getReduceOperands().empty());
+  auto wdLoc = workdistribute->getLoc();
+  genParallelOp(wdLoc, rewriter, /*composite=*/true);
+  genDistributeOp(wdLoc, rewriter, /*composite=*/true);
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+  genWsLoopOp(rewriter, doLoop, loopNestClauseOps, /*composite=*/true);
+  workdistribute.erase();
+  return true;
+}
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// A()
+/// B()
+///
+/// Applies only when the workdistribute op is the sole op (besides the
+/// terminator) of the teams region. Its body ops are moved in front of the
+/// teams op, the workdistribute op is erased, and the teams op is erased once
+/// only its terminator remains. Returns true if the rewrite was applied.
+
+static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  Block &wdBlock = workdistributeOp.getRegion().front();
+  // Splice every op except the terminator out of the workdistribute body and
+  // place it right before the teams op in the enclosing block.
+  teamsOp->getBlock()->getOperations().splice(
+      teamsOp->getIterator(), wdBlock.getOperations(), wdBlock.begin(),
+      wdBlock.getTerminator()->getIterator());
+  workdistributeOp.erase();
+  // Drop the teams op if nothing but its terminator is left.
+  Block &teamsBlock = teamsOp.getRegion().front();
+  if (teamsBlock.getOperations().size() == 1 &&
+      teamsBlock.getTerminator() != nullptr)
+    teamsOp.erase();
+  return true;
+}
+
+/// Result of splitTargetData: the replacement omp.target op and the
+/// omp.target_data op that encloses it.
+struct SplitTargetResult {
+  /// New target op, created inside dataOp's region; it took over the original
+  /// target op's region.
+  omp::TargetOp targetOp;
+  /// Target data op created in front of the original target op.
+  omp::TargetDataOp dataOp;
+};
+
+/// If multiple workdistribute ops are nested in a target region, we will need
+/// to split the target region, but we want to preserve the data semantics of
+/// the original data region and avoid unnecessary data movement at each of
+/// the subkernels - so we split the target region into a target_data{target}
+/// nest where only the outer one moves the data.
+///
+/// By-ref maps are moved to the new omp.target_data op, and the inner
+/// omp.target op gets clones of them with map type OMP_MAP_NONE. By-copy maps
+/// are only cloned onto the inner target op, with their original map type.
+/// The original target op's region is moved into the new target op, and the
+/// original op is replaced. Returns std::nullopt (leaving the IR unchanged)
+/// when the target op has no map operands.
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
+                                                 RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << DEBUG_TYPE << " target region has no data maps\n");
+    return std::nullopt;
+  }
+
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (auto opr : targetOp.getMapVars()) {
+    auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+    mapInfos.push_back(mapInfo);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+
+  for (auto mapInfo : mapInfos) {
+    auto originalMapType = static_cast<llvm::omp::OpenMPOffloadMappingFlags>(
+        mapInfo.getMapType());
+    auto originalCaptureType = mapInfo.getMapCaptureType();
+    llvm::omp::OpenMPOffloadMappingFlags newMapType;
+    mlir::omp::VariableCaptureKind newCaptureType;
+
+    if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+      // By-copy values keep their original map on the inner target op.
+      newMapType = originalMapType;
+      newCaptureType = originalCaptureType;
+    } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+      // The outer target_data op moves the data; the inner target op only
+      // refers to it (OMP_MAP_NONE).
+      newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+      newCaptureType = originalCaptureType;
+      outerMapInfos.push_back(mapInfo);
+    } else {
+      llvm_unreachable("Unhandled case");
+    }
+    auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+    innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            newMapType)));
+    innerMapInfo.setMapCaptureType(newCaptureType);
+    innerMapInfos.push_back(innerMapInfo.getResult());
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  auto device = targetOp.getDevice();
+  auto ifExpr = targetOp.getIfExpr();
+  auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+  auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+
+  rewriter.replaceOp(targetOp, newTargetOp);
+  return SplitTargetResult{newTargetOp, targetDataOp};
+}
+
+/// Find the first omp.teams or omp.parallel op directly in the body of
+/// \p targetOp. Returns it together with two flags: whether it is the first
+/// op of the block, and whether it is the last op before the terminator.
+/// Returns std::nullopt when the target op has no region, when the first op
+/// is the only op besides the terminator, or when no such op is found.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  Region &region = targetOp.getRegion();
+  if (region.empty())
+    return std::nullopt;
+  Block &block = region.front();
+  for (Operation &candidate : block) {
+    const bool isFirst = &candidate == &block.front();
+    const bool isLast = candidate.getNextNode() == block.getTerminator();
+    // A single op in the body is already isolated; nothing to split.
+    if (isFirst && isLast)
+      return std::nullopt;
+    if (!isa<omp::TeamsOp, omp::ParallelOp>(candidate))
+      continue;
+    return std::make_tuple(&candidate, isFirst, isLast);
+  }
+  return std::nullopt;
+}
+
+/// The two omp.map_info ops created by allocateTempOmpVar for one temporary.
+struct TempOmpVar {
+  /// `from` maps the temporary with OMP_MAP_FROM, `to` with OMP_MAP_TO.
+  omp::MapInfoOp from, to;
+};
+
+/// Returns true for pointer-like types: FIR references and LLVM pointers.
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType, LLVM::LLVMPointerType>(ty);
+}
+
+/// Type of the block argument through which a cached value of type \p ty is
+/// accessed inside a split target region. Pointer-like types (see isPtr) use
+/// an opaque LLVM pointer; other types use a FIR LLVM pointer to \p ty.
+static Type getPtrTypeForOmp(Type ty) {
+  if (!isPtr(ty))
+    return fir::LLVMPointerType::get(ty);
+  return LLVM::LLVMPointerType::get(ty.getContext());
+}
+
+/// Create, at the current insertion point, a temporary that can hold a value
+/// of type \p ty, plus two by-reference omp.map_info ops for it: one with
+/// OMP_MAP_FROM and one with OMP_MAP_TO.
+///
+/// For pointer-like \p ty (see isPtr) the temporary is an llvm.alloca of one
+/// opaque LLVM pointer; otherwise it is a fir.alloca of \p ty.
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    // NOTE(review): the map_info var type below becomes i32 here although the
+    // alloca holds a pointer - confirm this is intended.
+    allocType = intTy;
+  } else {
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo};
+}
+
+/// Returns true if \p v has a user at or after \p split in the body of the
+/// omp.target op that directly contains \p split. A user nested in a region
+/// is attributed to its ancestor op in the target body. Returns false when
+/// \p split is null.
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto *user : v.getUsers()) {
+    // Walk up to the ancestor of the user that lives in the target body.
+    while (user->getBlock() != targetBlock)
+      user = user->getParentOp();
+    // The split op itself counts as "outside" (it is not before itself).
+    if (!user->isBeforeInBlock(split))
+      return true;
+  }
+  return false;
+}
+
+/// Returns true if \p op may be re-executed in the target regions created by
+/// splitting: fir.declare ops always qualify; other ops qualify only if they
+/// implement MemoryEffectOpInterface and report no memory effects. Ops that
+/// do not implement the interface are rejected. \p splitBefore is unused.
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+  auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!memInterface)
+    return false;
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  memInterface.getEffects(effects);
+  return effects.empty();
+}
+
+/// The target ops produced by isolateOp.
+struct SplitResult {
+  /// Holds clones of the ops before the split point.
+  omp::TargetOp preTargetOp;
+  /// Holds the isolated op, plus the reloaded/recomputed values it needs.
+  omp::TargetOp isolatedTargetOp;
+  /// Holds the ops after the isolated op; null unless splitAfter was set.
+  omp::TargetOp postTargetOp;
+};
+
+/// Walk the def chain of \p v inside \p targetOp. Defining ops that are in
+/// \p nonRecomputable are added to \p toCache and the walk stops there. All
+/// other defining ops are added to \p toRecompute and their operands are
+/// visited in turn. Block arguments (of the target region) end the walk.
+static void collectNonRecomputableDeps(Value v, omp::TargetOp targetOp,
+                                       SetVector<Operation *> &nonRecomputable,
+                                       SetVector<Operation *> &toCache,
+                                       SetVector<Operation *> &toRecompute) {
+  Operation *op = v.getDefiningOp();
+  if (!op) {
+    assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+    return;
+  }
+  if (nonRecomputable.contains(op)) {
+    toCache.insert(op);
+    return;
+  }
+  // SetVector::insert returns false if op was visited before, in which case
+  // its operands have already been walked. Without this check, shared
+  // operands are re-walked once per path, which can grow exponentially.
+  if (!toRecompute.insert(op))
+    return;
+  for (Value opr : op->getOperands())
+    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+}
+
+/// Populate \p newTargetBlock (the fresh entry block of a split target op) so
+/// that ops from the original \p targetBlock can be cloned into it through
+/// \p mapping. Ops are created at the rewriter's current insertion point.
+///  - One block argument is added per original block argument and mapped.
+///  - One extra block argument (typed by getPtrTypeForOmp) is added per value
+///    in \p allocs. The value is loaded back from it, with a fir.convert when
+///    the original type is a non-LLVM pointer, and the original value is
+///    mapped to the loaded one.
+///  - Every op before \p splitBefore that is in \p toRecompute is cloned, in
+///    original order.
+static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
+                                    MLIRContext &ctx, IRMapping &mapping,
+                                    Operation *splitBefore, Block *targetBlock,
+                                    Block *newTargetBlock,
+                                    SmallVector<Value> &allocs,
+                                    SetVector<Operation *> &toRecompute) {
+  // Mirror the original block arguments.
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    mapping.map(originalArg, newArg);
+  }
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  // Reload each cached value from the block argument appended for it; the
+  // argument order must match the order of the extra map operands.
+  for (auto original : allocs) {
+    Value newArg = newTargetBlock->addArgument(
+        getPtrTypeForOmp(original.getType()), original.getLoc());
+    Value restored;
+    if (isPtr(original.getType())) {
+      restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+      if (!isa<LLVM::LLVMPointerType>(original.getType()))
+        restored =
+            rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+    } else {
+      restored = rewriter.create<fir::LoadOp>(loc, newArg);
+    }
+    mapping.map(original, restored);
+  }
+  // Re-execute the side-effect-free producers that are needed after the split.
+  for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
+    if (toRecompute.contains(&*it))
+      rewriter.clone(*it, mapping);
+  }
+}
+
+/// Split the omp.target op that directly contains \p splitBeforeOp into up to
+/// three target ops, all inserted before it, then erase the original:
+///  - preTargetOp holds clones of every op before \p splitBeforeOp. Values
+///    needed at/after the split that cannot be recomputed (see
+///    isRecomputableAfterFission) are stored into temporaries mapped with
+///    OMP_MAP_FROM.
+///  - isolatedTargetOp reloads those temporaries (mapped with OMP_MAP_TO),
+///    recomputes the recomputable dependencies and clones \p splitBeforeOp.
+///  - postTargetOp is created only if \p splitAfter is set. It reloads and
+///    recomputes the same way, then clones every op after \p splitBeforeOp,
+///    including the original terminator.
+/// When \p splitAfter is set, the results of \p splitBeforeOp must be unused
+/// (asserted).
+static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
+                             RewriterBase &rewriter) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  MLIRContext &ctx = *targetOp.getContext();
+  // Redundant in practice: cast<> above already asserts on a mismatch.
+  assert(targetOp);
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  rewriter.setInsertionPoint(targetOp);
+
+  // All new target ops start from the original map operands; one entry per
+  // cached value is appended below.
+  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+
+  // Collect values defined before the split point that are used at or after
+  // it, and the ops before the split point that cannot be recomputed.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    for (auto res : it->getResults()) {
+      if (usedOutsideSplit(res, splitBeforeOp))
+        requiredVals.push_back(res);
+    }
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp))
+      nonRecomputable.insert(&*it);
+  }
+
+  // Partition the producers of the required values into ops whose results
+  // must be cached (toCache) and ops that will be re-executed (toRecompute).
+  for (auto requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+
+  // Give every result of a cached op a temporary (created before targetOp).
+  // allocs keeps the original values in the same order as the appended map
+  // operands, which is also the order of the extra block arguments.
+  for (Operation *op : toCache) {
+    for (auto res : op->getResults()) {
+      auto alloc =
+          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      allocs.push_back(res);
+      preMapOperands.push_back(alloc.from);
+      postMapOperands.push_back(alloc.to);
+    }
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+
+  // Pre target op: clone everything before the split point.
+  auto preTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = preTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    preMapping.map(originalArg, newArg);
+  }
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++)
+    rewriter.clone(*it, preMapping);
+
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  // Store each cached value into its temporary, reached through a block
+  // argument appended after the original ones. Non-LLVM pointers are first
+  // converted to an opaque LLVM pointer.
+  for (auto original : allocs) {
+    Value toStore = preMapping.lookup(original);
+    auto newArg = preTargetBlock->addArgument(
+        getPtrTypeForOmp(original.getType()), original.getLoc());
+    if (isPtr(original.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  rewriter.setInsertionPoint(targetOp);
+
+  // Isolated target op: reload cached values, recompute the rest, then clone
+  // the op being isolated.
+  auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+
+  auto *isolatedTargetBlock =
+      rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                           isolatedTargetOp.getRegion().begin(), {}, {});
+
+  IRMapping isolatedMapping;
+  reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
+                          targetBlock, isolatedTargetBlock, allocs,
+                          toRecompute);
+  rewriter.clone(*splitBeforeOp, isolatedMapping);
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  omp::TargetOp postTargetOp = nullptr;
+
+  // Post target op: same reload/recompute preamble, then clone everything
+  // after the isolated op (this includes the original terminator).
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = rewriter.create<omp::TargetOp>(
+        targetOp.getLoc(), targetOp.getAllocateVars(),
+        targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+        targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+        targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+        targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+        targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+        targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+        postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+        targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+        targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+    auto *postTargetBlock = rewriter.createBlock(
+        &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+    IRMapping postMapping;
+    reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
+                            targetBlock, postTargetBlock, allocs, toRecompute);
+
+    // Results of the isolated op are not carried over to the post region.
+    assert(splitBeforeOp->getNumResults() == 0 ||
+           llvm::all_of(splitBeforeOp->getResults(),
+                        [](Value result) { return result.use_empty(); }));
+
+    for (auto it = std::next(splitBeforeOp->getIterator());
+         it != targetBlock->end(); it++)
+      rewriter.clone(*it, postMapping);
+  }
+
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+/// Create an llvm.mlir.constant of type i32 holding \p value at the
+/// rewriter's current insertion point.
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(value));
+}
+
+/// Type of the device-number operand appended to the `_omp` variants of
+/// runtime functions: a 32-bit signless integer.
+static Type getOmpDeviceType(MLIRContext *c) {
+  return IntegerType::get(c, /*width=*/32);
+}
+
+/// Replace `targetOp` by running its body on the host.
+///
+/// The single block of the target region is cloned in front of `targetOp`.
+/// Each block argument is remapped to the `var_ptr` of its omp.map.info
+/// operand, and `targetOp` is then erased. While cloning:
+///  - calls accepted by isRuntimeCall() are redirected to a `<callee>_omp`
+///    function that takes the device number as an extra trailing argument.
+///    Its declaration is inserted at the start of the module body on first
+///    use.
+///  - cloned fir.allocmem / fir.freemem are afterwards rewritten into
+///    omp::TargetAllocMemOp / omp::TargetFreeMemOp on that device, with
+///    fir.convert bridging between the i64 handle and the heap type.
+/// The device number is the target's `device` operand if present, otherwise
+/// an i32 constant 0.
+static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  // On the host, each region argument is replaced by the variable pointer
+  // that was mapped into the target.
+  IRMapping mapping;
+  for (auto map :
+       zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) {
+    Value mapInfo = std::get<0>(map);
+    BlockArgument arg = std::get<1>(map);
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Operation *> opsToReplace;
+  // Device number passed to runtime calls and device memory ops. It defaults
+  // to device 0 when the target has no device operand.
+  Value device = targetOp.getDevice();
+  if (!device) {
+    device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+  }
+  // Clone every op of the region except its terminator.
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    if (isRuntimeCall(op)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto module = runtimeCall->getParentOfType<ModuleOp>();
+      auto callee =
+          cast<func::FuncOp>(module.lookupSymbol(runtimeCall.getCalleeAttr()));
+      std::string newCalleeName = (callee.getName() + "_omp").str();
+      mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+      func::FuncOp newCallee =
+          cast_or_null<func::FuncOp>(module.lookupSymbol(newCalleeName));
+      if (!newCallee) {
+        // Declare `<callee>_omp` with the callee's signature plus a trailing
+        // device-number argument, and copy the callee's attributes.
+        SmallVector<Type> argTypes(callee.getFunctionType().getInputs());
+        argTypes.push_back(getOmpDeviceType(rewriter.getContext()));
+        newCallee = moduleBuilder.create<func::FuncOp>(
+            callee->getLoc(), newCalleeName,
+            FunctionType::get(rewriter.getContext(), argTypes,
+                              callee.getFunctionType().getResults()));
+        // The argument attribute array must hold one entry per argument.
+        // Pad it with an empty dictionary for the appended device argument;
+        // copying the callee's array verbatim would leave it one entry short
+        // and fail function verification.
+        if (ArrayAttr argAttrs = callee.getArgAttrsAttr()) {
+          SmallVector<Attribute> newArgAttrs(argAttrs.begin(), argAttrs.end());
+          newArgAttrs.push_back(rewriter.getDictionaryAttr({}));
+          newCallee.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrs));
+        }
+        if (callee.getResAttrs())
+          newCallee.setResAttrsAttr(*callee.getResAttrs());
+        newCallee.setSymVisibility(callee.getSymVisibility());
+        newCallee->setDiscardableAttrs(callee->getDiscardableAttrDictionary());
+      }
+      SmallVector<Value> operands = runtimeCall.getOperands();
+      operands.push_back(device);
+      // Build the call with the original operands, then clone it through
+      // `mapping` so operands defined inside the region are remapped. The
+      // temporary call is erased right after.
+      auto tmpCall = rewriter.create<fir::CallOp>(
+          runtimeCall.getLoc(), runtimeCall.getResultTypes(),
+          SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr,
+          runtimeCall.getFastmathAttr());
+      Operation *newCall = rewriter.clone(*tmpCall, mapping);
+      mapping.map(&*it, newCall);
+      rewriter.eraseOp(tmpCall);
+    } else {
+      Operation *clonedOp = rewriter.clone(*op, mapping);
+      if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+        opsToReplace.push_back(clonedOp);
+      for (unsigned i = 0; i < op->getNumResults(); ++i) {
+        mapping.map(op->getResult(i), clonedOp->getResult(i));
+      }
+    }
+  }
+  for (Operation *op : opsToReplace) {
+    // Replace host heap (de)allocations with their device-aware omp
+    // counterparts targeting `device`.
+    if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+      rewriter.setInsertionPoint(allocOp);
+      auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+          allocOp.getLoc(), rewriter.getI64Type(), device,
+          allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+          allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+          allocOp.getShape());
+      // Convert the i64 handle back to the original heap type for users.
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          allocOp.getLoc(), allocOp.getResult().getType(),
+          ompAllocmemOp.getResult());
+      rewriter.replaceOp(allocOp, firConvertOp.getResult());
+    } else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+      rewriter.setInsertionPoint(freeOp);
+      // omp::TargetFreeMemOp takes the allocation as an i64.
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+      rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+                                            firConvertOp.getResult());
+      rewriter.eraseOp(freeOp);
+    }
+  }
+  rewriter.eraseOp(targetOp);
+}
+
+/// Recursively split `targetOp` around the op chosen by getNestedOpToIsolate.
+///
+/// If nothing needs isolating, the whole target is moved to the host.
+/// Otherwise isolateOp splits the target:
+///  - when a split in front of the chosen op is requested, the leading part
+///    (`preTargetOp`) is moved to the host;
+///  - when a split after it is requested, the trailing part (`postTargetOp`)
+///    is processed recursively.
+void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  auto isolation = getNestedOpToIsolate(targetOp);
+  if (!isolation) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    moveToHost(targetOp, rewriter);
+    return;
+  }
+
+  Operation *opToIsolate = std::get<0>(*isolation);
+  const bool splitBefore = !std::get<1>(*isolation);
+  const bool splitAfter = !std::get<2>(*isolation);
+
+  if (splitBefore) {
+    // Split in front of `opToIsolate`; the leading part runs on the host.
+    auto res = isolateOp(opToIsolate, splitAfter, rewriter);
+    moveToHost(res.preTargetOp, rewriter);
+    // The trailing part, if split off, may contain further ops to isolate.
+    if (splitAfter)
+      fissionTarget(res.postTargetOp, rewriter);
+  } else if (splitAfter) {
+    // No split is needed in front: split right after `opToIsolate` and
+    // recurse into the trailing part.
+    auto res = isolateOp(opToIsolate->getNextNode(), splitAfter, rewriter);
+    fissionTarget(res.postTargetOp, rewriter);
+  }
+}
+
+/// Module pass that lowers `omp.workdistribute` constructs.
+///
+/// It runs three rewrites over the module: FissionWorkdistribute,
+/// WorkdistributeDoLower and TeamsWorkdistributeToSingleOp. If any of them
+/// changed the IR, each omp.target that originally enclosed a workdistribute
+/// construct is split with splitTargetData, and its device part is fissioned
+/// with fissionTarget.
+class LowerWorkdistributePass
+    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    auto moduleOp = getOperation();
+
+    // Record the omp.target ops that enclose a workdistribute construct
+    // before anything is rewritten. Only these need target fission. Other
+    // omp.target ops in the module must not be split, or moved to the host
+    // by fissionTarget, just because an unrelated workdistribute was lowered.
+    // NOTE(review): this relies on the rewrites below never erasing or
+    // replacing an enclosing omp.target.
+    SetVector<omp::TargetOp> targetOpsToProcess;
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      if (auto targetOp = workdistribute->getParentOfType<omp::TargetOp>())
+        targetOpsToProcess.insert(targetOp);
+    });
+
+    bool changed = false;
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |= FissionWorkdistribute(workdistribute);
+    });
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |= WorkdistributeDoLower(workdistribute);
+    });
+    moduleOp->walk([&](mlir::omp::TeamsOp teams) {
+      changed |= TeamsWorkdistributeToSingleOp(teams);
+    });
+
+    if (!changed)
+      return;
+
+    IRRewriter rewriter(&context);
+    for (omp::TargetOp targetOp : targetOpsToProcess) {
+      // If splitTargetData produces no result, the target is left as is.
+      if (auto res = splitTargetData(targetOp, rewriter))
+        fissionTarget(res->targetOp, rewriter);
+    }
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 7c2777baebef1..f17cf31d8930a 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -292,8 +292,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
     addNestedPassToAllTopLevelOperations<PassConstructor>(
         pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP != EnableOpenMP::None)
+  if (enableOpenMP != EnableOpenMP::None) {
     pm.addPass(flangomp::createLowerWorkshare());
+    pm.addPass(flangomp::createLowerWorkdistribute());
+  }
   if (enableOpenMP == EnableOpenMP::Simd)
     pm.addPass(flangomp::createSimdOnlyPass());
 }
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 0a31397efb332..9aa3ccadf6385 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -69,6 +69,7 @@ func.func @_QQmain() {
 // PASSES-NEXT:     InlineHLFIRAssign
 // PASSES-NEXT:   ConvertHLFIRtoFIR
 // PASSES-NEXT:   LowerWorkshare
+// PASSES-NEXT:   LowerWorkdistribute
 // PASSES-NEXT:   CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
new file mode 100644
index 0000000000000..00d10d6264ec9
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -0,0 +1,33 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @x({{.*}})
+// CHECK:           omp.teams {
+// CHECK:             omp.parallel {
+// CHECK:               omp.distribute {
+// CHECK:                 omp.wsloop {
+// CHECK:                   omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK:                     %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:                     fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                 } {omp.composite}
+// CHECK:               } {omp.composite}
+// CHECK:               omp.terminator
+// CHECK:             } {omp.composite}
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+// Host-side omp.teams + omp.workdistribute around an unordered fir.do_loop.
+// The directives above expect it to be rewritten into a composite
+// omp.parallel / omp.distribute / omp.wsloop nest around an inclusive
+// omp.loop_nest.
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  omp.teams {
+    omp.workdistribute { 
+      fir.do_loop %iv = %lb to %ub step %step unordered {
+        %zero = arith.constant 0 : index
+        fir.store %zero to %addr : !fir.ref<index>
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
new file mode 100644
index 0000000000000..19bdb9ce10fbd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -0,0 +1,112 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @x(
+// CHECK:           %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK:           fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK:           %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK:           fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK:           %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK:           fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK:           %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:           %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:           %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:           %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:           %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:           %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK:             %[[VAL_11:.*]] = fir.alloca index
+// CHECK:             %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_14:.*]] = fir.alloca index
+// CHECK:             %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_17:.*]] = fir.alloca index
+// CHECK:             %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK:             %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_23:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK:             %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK:             %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK:             %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:             %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array<i32: 1, 0, 1>, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap<index>
+// CHECK:             fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK:               %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index
+// CHECK:               omp.teams {
+// CHECK:                 omp.parallel {
+// CHECK:                   omp.distribute {
+// CHECK:                     omp.wsloop {
+// CHECK:                       omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to (%[[VAL_39]]) inclusive step (%[[VAL_40]]) {
+// CHECK:                         fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap<index>
+// CHECK:                         omp.yield
+// CHECK:                       }
+// CHECK:                     } {omp.composite}
+// CHECK:                   } {omp.composite}
+// CHECK:                   omp.terminator
+// CHECK:                 } {omp.composite}
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK:             %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK:             %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK:             %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index
+// CHECK:             fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap<index>
+// CHECK:             %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:             "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap<index>) -> ()
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// omp.target whose teams/workdistribute body mixes a device loop with scalar
+// heap work. Per the directives above:
+//  - the target is wrapped in omp.target_data;
+//  - the fir.allocmem becomes an omp_target_allocmem emitted on the host
+//    before the device loop;
+//  - values used on both sides of the split travel through
+//    __flang_workdistribute_to/from temporaries;
+//  - the trailing store and fir.freemem run on the host after the omp.target.
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load %ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..c562b7009664d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,71 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_fission_workdistribute(
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK:           fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           omp.teams {
+// CHECK:             omp.parallel {
+// CHECK:               omp.distribute {
+// CHECK:                 omp.wsloop {
+// CHECK:                   omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                     %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                     %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK:                     %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                     fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                 } {omp.composite}
+// CHECK:               } {omp.composite}
+// CHECK:               omp.terminator
+// CHECK:             } {omp.composite}
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK:           }
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK:           return
+// CHECK:         }
+// Host-side omp.teams (no omp.target) whose workdistribute body mixes
+// parallelizable and sequential work. Expected result (directives above):
+//  - the leading scalar store is emitted before omp.teams;
+//  - only the unordered fir.do_loop becomes a teams / parallel / distribute /
+//    wsloop nest;
+//  - the two calls, the ordered fir.do_loop and the final load/store follow
+//    it outside omp.teams.
+module {
+func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
+  return
+}
+func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
+  return
+}
+func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
+  %c0_idx = arith.constant 0 : index
+  %c1_idx = arith.constant 1 : index
+  %c9_idx = arith.constant 9 : index
+  %float_val = arith.constant 5.0 : f32
+  omp.teams   {
+    omp.workdistribute   {
+      fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
+      fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
+        %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
+        %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
+      }
+      fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
+      fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
+      fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
+        %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
+      }
+      %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
+      fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
+      omp.terminator  
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
new file mode 100644
index 0000000000000..d96068b26ca2f
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
@@ -0,0 +1,32 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_nested_derived_type_map_operand_and_block_addition(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK:           %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+// CHECK:           %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK:           %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+// CHECK:           %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"}
+// CHECK:           %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK:           %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+// CHECK:           %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
+// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+// CHECK:           omp.target map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// omp.target without any omp.workdistribute. The expected output above
+// mirrors the input, so the pass must leave this function unchanged.
+func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {        
+  %0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+  %2 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+  %4 = fir.coordinate_of %2, i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+  %5 = omp.map.info var_ptr(%4 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%n%i"}
+  %7 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+  %9 = fir.coordinate_of %7, r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+  %10 = omp.map.info var_ptr(%9 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%n%r"}
+  %11 = omp.map.info var_ptr(%0 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+  omp.target map_entries(%11 -> %arg1 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4e26e65cf9718..b8b9a0338e9ac 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5400,6 +5400,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
+  if (targetOp.getHostEvalVars().empty())
+    numLoops = 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);

>From 0f80e2847d879b85195b25d32bd30154a9906ae5 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 25 Jul 2025 22:20:12 +0530
Subject: [PATCH 02/13] Fix hoisting declare ops out of omp.target

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 34 +++++++++++++++++--
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 0885efc716db4..3f78727450f31 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -434,7 +434,7 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
                               newTargetOp.getRegion().begin());
 
-  rewriter.replaceOp(targetOp, newTargetOp);
+  rewriter.replaceOp(targetOp, targetDataOp);
   return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
 }
 
@@ -807,11 +807,30 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       rewriter.eraseOp(tmpCall);
     } else {
       Operation *clonedOp = rewriter.clone(*op, mapping);
-      if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
-        opsToReplace.push_back(clonedOp);
       for (unsigned i = 0; i < op->getNumResults(); ++i) {
         mapping.map(op->getResult(i), clonedOp->getResult(i));
       }
+      // fir.declare changes its type when hoisting it out of omp.target to
+      // omp.target_data Introduce a load, if original declareOp input is not of
+      // reference type, but cloned delcareOp input is reference type.
+      if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+        auto originalDeclareOp = cast<fir::DeclareOp>(op);
+        Type originalInType = originalDeclareOp.getMemref().getType();
+        Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+        fir::ReferenceType originalRefType =
+            dyn_cast<fir::ReferenceType>(originalInType);
+        fir::ReferenceType clonedRefType =
+            dyn_cast<fir::ReferenceType>(clonedInType);
+        if (!originalRefType && clonedRefType) {
+          Type clonedEleTy = clonedRefType.getElementType();
+          if (clonedEleTy == originalDeclareOp.getType()) {
+            opsToReplace.push_back(clonedOp);
+          }
+        }
+      }
+      if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+        opsToReplace.push_back(clonedOp);
     }
   }
   for (Operation *op : opsToReplace) {
@@ -833,6 +852,15 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
                                             firConvertOp.getResult());
       rewriter.eraseOp(freeOp);
+    } else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      Type clonedEleTy = clonedRefType.getElementType();
+      rewriter.setInsertionPoint(op);
+      Value loadedValue = rewriter.create<fir::LoadOp>(
+          clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+      clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
     }
   }
   rewriter.eraseOp(targetOp);

>From 9af51bc49d7ed9e6e1368d892a456265b908f976 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sun, 3 Aug 2025 15:26:54 +0530
Subject: [PATCH 03/13] Handle case when private maps are present in omp.target

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 133 +++++++++++++-----
 1 file changed, 95 insertions(+), 38 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 3f78727450f31..e61240e8aa443 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -566,22 +566,60 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
                                toRecompute);
 }
 
-static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
-                                    MLIRContext &ctx, IRMapping &mapping,
-                                    Operation *splitBefore, Block *targetBlock,
-                                    Block *newTargetBlock,
-                                    SmallVector<Value> &allocs,
-                                    SetVector<Operation *> &toRecompute) {
-  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
-    auto originalArg = targetBlock->getArgument(i);
+static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
+                                  omp::TargetOp &targetOp, Block *targetBlock,
+                                  Block *newTargetBlock,
+                                  SmallVector<Value> &mapOperands,
+                                  SmallVector<Value> &allocs,
+                                  IRMapping &irMapping) {
+  // Map operands are assumed to be the original map vars followed by allocs.
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < mapOperands.size(); ++i) {
+    Value originalValue;
+    BlockArgument newArg;
+    // Leading operands mirror the original block's map-var arguments.
+    if (i < originalMapVarsSize) {
+      originalValue = targetBlock->getArgument(i);
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    }
+    // Remaining operands are `allocs`, typed via getPtrTypeForOmp.
+    else {
+      originalValue = allocs[i - originalMapVarsSize];
+      newArg = newTargetBlock->addArgument(
+          getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc());
+    }
+    irMapping.map(originalValue, newArg);
+  }
+  // Mirror private-var args (after the map-var args in the original block).
+  unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size();
+  for (unsigned i = 0; i < originalPrivateVarsSize; ++i) {
+    auto originalArg = targetBlock->getArgument(originalMapVarsSize + i);
     auto newArg = newTargetBlock->addArgument(originalArg.getType(),
                                               originalArg.getLoc());
-    mapping.map(originalArg, newArg);
+    irMapping.map(originalArg, newArg);
   }
-  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
-  for (auto original : allocs) {
-    Value newArg = newTargetBlock->addArgument(
-        getPtrTypeForOmp(original.getType()), original.getLoc());
+  return;
+}
+
+static void reloadCacheAndRecompute(
+    Location loc, RewriterBase &rewriter, Operation *splitBefore,
+    omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
+    SmallVector<Value> &mapOperands, SmallVector<Value> &allocs,
+    SetVector<Operation *> &toRecompute, IRMapping &irMapping) {
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, newTargetBlock,
+                        mapOperands, allocs, irMapping);
+  // Handle the load operations for the allocs.
+  rewriter.setInsertionPointToStart(newTargetBlock);
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  // Create Stores for allocs.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value original = allocs[i];
+    // Get the new block argument for this specific allocated value.
+    Value newArg = newTargetBlock->getArgument(originalMapVarsSize + i);
+
     Value restored;
     if (isPtr(original.getType())) {
       restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
@@ -591,18 +629,18 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
     } else {
       restored = rewriter.create<fir::LoadOp>(loc, newArg);
     }
-    mapping.map(original, restored);
+    irMapping.map(original, restored);
   }
+
   for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
     if (toRecompute.contains(&*it))
-      rewriter.clone(*it, mapping);
+      rewriter.clone(*it, irMapping);
   }
 }
 
 static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                              RewriterBase &rewriter) {
   auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
-  MLIRContext &ctx = *targetOp.getContext();
   assert(targetOp);
   auto loc = targetOp.getLoc();
   auto *targetBlock = &targetOp.getRegion().front();
@@ -657,22 +695,29 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
   auto *preTargetBlock = rewriter.createBlock(
       &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
   IRMapping preMapping;
-  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
-    auto originalArg = targetBlock->getArgument(i);
-    auto newArg = preTargetBlock->addArgument(originalArg.getType(),
-                                              originalArg.getLoc());
-    preMapping.map(originalArg, newArg);
-  }
-  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++)
-    rewriter.clone(*it, preMapping);
 
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
+                        preMapOperands, allocs, preMapping);
+
+  // Handle the store operations for the allocs.
+  rewriter.setInsertionPointToStart(preTargetBlock);
   auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
 
-  for (auto original : allocs) {
-    Value toStore = preMapping.lookup(original);
-    auto newArg = preTargetBlock->addArgument(
-        getPtrTypeForOmp(original.getType()), original.getLoc());
-    if (isPtr(original.getType())) {
+  // Clone the original operations.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    rewriter.clone(*it, preMapping);
+  }
+
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  // Create Stores for allocs.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value originalResult = allocs[i];
+    Value toStore = preMapping.lookup(originalResult);
+    // Get the new block argument for this specific allocated value.
+    Value newArg = preTargetBlock->getArgument(originalMapVarsSize + i);
+
+    if (isPtr(originalResult.getType())) {
       if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
         toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
       rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
@@ -701,9 +746,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                            isolatedTargetOp.getRegion().begin(), {}, {});
 
   IRMapping isolatedMapping;
-  reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
-                          targetBlock, isolatedTargetBlock, allocs,
-                          toRecompute);
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                          isolatedTargetBlock, postMapOperands, allocs,
+                          toRecompute, isolatedMapping);
   rewriter.clone(*splitBeforeOp, isolatedMapping);
   rewriter.create<omp::TerminatorOp>(loc);
 
@@ -725,8 +770,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
     auto *postTargetBlock = rewriter.createBlock(
         &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
     IRMapping postMapping;
-    reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
-                            targetBlock, postTargetBlock, allocs, toRecompute);
+    reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                            postTargetBlock, postMapOperands, allocs,
+                            toRecompute, postMapping);
 
     assert(splitBeforeOp->getNumResults() == 0 ||
            llvm::all_of(splitBeforeOp->getResults(),
@@ -755,15 +801,24 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
   Block *targetBlock = &targetOp.getRegion().front();
   assert(targetBlock == &targetOp.getRegion().back());
   IRMapping mapping;
-  for (auto map :
-       zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) {
-    Value mapInfo = std::get<0>(map);
-    BlockArgument arg = std::get<1>(map);
+  for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
+    Value mapInfo = targetOp.getMapVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[i];
     Operation *op = mapInfo.getDefiningOp();
     assert(op);
     auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    // map the block argument to the host-side variable pointer
     mapping.map(arg, mapInfoOp.getVarPtr());
   }
+  unsigned mapSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
+    Value privateVar = targetOp.getPrivateVars()[i];
+    // The mapping should link the device-side variable to the host-side one.
+    BlockArgument arg = targetBlock->getArguments()[mapSize + i];
+    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    mapping.map(arg, privateVar);
+  }
+
   rewriter.setInsertionPoint(targetOp);
   SmallVector<Operation *> opsToReplace;
   Value device = targetOp.getDevice();
@@ -813,6 +868,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       // fir.declare changes its type when hoisting it out of omp.target to
       // omp.target_data Introduce a load, if original declareOp input is not of
       // reference type, but cloned delcareOp input is reference type.
+
       if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
         auto originalDeclareOp = cast<fir::DeclareOp>(op);
         Type originalInType = originalDeclareOp.getMemref().getType();
@@ -833,6 +889,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
         opsToReplace.push_back(clonedOp);
     }
   }
+
   for (Operation *op : opsToReplace) {
     if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
       rewriter.setInsertionPoint(allocOp);

>From 919268b12666b7c066ae49c92b6570f7b9d17809 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 4 Aug 2025 11:52:27 +0530
Subject: [PATCH 04/13] Add comments/description for functions.

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 197 +++++++++++-------
 1 file changed, 121 insertions(+), 76 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index e61240e8aa443..ece64a1ba1d4d 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -1,4 +1,5 @@
-//===- LowerWorkshare.cpp - special cases for bufferization -------===//
+//===- LowerWorkdistribute.cpp - Lower omp.workdistribute -----------------===//
+//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,6 +9,16 @@
 //
 // This file implements the lowering and optimisations of omp.workdistribute.
 //
+// Fortran array statements are lowered to FIR as unordered fir.do_loop ops.
+// The lower-workdistribute pass mainly identifies unordered fir.do_loop ops
+// nested as target{teams{workdistribute{fir.do_loop unordered}}} and
+// lowers them to target{teams{parallel{wsloop{loop_nest}}}}.
+// All other ops are hoisted out of the target region.
+// Heap allocations on the target are replaced with omp.target_allocmem and
+// deallocations with omp.target_freemem, issued from the host. Calls to the
+// "Assign" runtime function are replaced with their omp variants: a call to
+// @_FortranAAssign hoisted out of the target becomes @_FortranAAssign_omp.
+//
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -49,6 +60,8 @@ using namespace mlir;
 
 namespace {
 
+// The isRuntimeCall function is a utility designed to determine
+// if a given operation is a call to a Fortran-specific runtime function.
 static bool isRuntimeCall(Operation *op) {
   if (auto callOp = dyn_cast<fir::CallOp>(op)) {
     auto callee = callOp.getCallee();
@@ -61,8 +74,8 @@ static bool isRuntimeCall(Operation *op) {
   return false;
 }
 
-/// This is the single source of truth about whether we should parallelize an
-/// operation nested in an omp.execute region.
+// This is the single source of truth about whether we should parallelize an
+// operation nested in an omp.execute region.
 static bool shouldParallelize(Operation *op) {
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
@@ -74,13 +87,16 @@ static bool shouldParallelize(Operation *op) {
       return false;
     return *unordered;
   }
-  if (isRuntimeCall(op)) {
+  if (isRuntimeCall(op) &&
+      (op->getName().getStringRef() == "_FortranAAssign")) {
     return true;
   }
-  // We cannot parallise anything else
+  // We cannot parallise anything else.
   return false;
 }
 
+// The getPerfectlyNested function is a generic utility for finding
+// a single, "perfectly nested" operation within a parent operation.
 template <typename T>
 static T getPerfectlyNested(Operation *op) {
   if (op->getNumRegions() != 1)
@@ -96,33 +112,37 @@ static T getPerfectlyNested(Operation *op) {
   return nullptr;
 }
 
-/// If B() and D() are parallelizable,
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///     C()
-///     D()
-///     E()
-///   }
-/// }
-///
-/// becomes
-///
-/// A()
-/// omp.teams {
-///   omp.workdistribute {
-///     B()
-///   }
-/// }
-/// C()
-/// omp.teams {
-///   omp.workdistribute {
-///     D()
-///   }
-/// }
-/// E()
+// FissionWorkdistribute method finds the parallelizable ops
+// within teams {workdistribute} region and moves them to their
+// own teams{workdistribute} region.
+//
+// If B() and D() are parallelizable,
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//     C()
+//     D()
+//     E()
+//   }
+// }
+//
+// becomes
+//
+// A()
+// omp.teams {
+//   omp.workdistribute {
+//     B()
+//   }
+// }
+// C()
+// omp.teams {
+//   omp.workdistribute {
+//     D()
+//   }
+// }
+// E()
 
 static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
@@ -215,29 +235,6 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   return changed;
 }
 
-/// If fir.do_loop is present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     fir.do_loop unoredered {
-///       ...
-///     }
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// omp.teams {
-///   omp.parallel {
-///     omp.distribute {
-///     omp.wsloop {
-///       omp.loop_nest
-///         ...
-///       }
-///     }
-///   }
-/// }
-
 static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
   parallelOp.setComposite(composite);
@@ -295,6 +292,33 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
   return;
 }
 
+// WorkdistributeDoLower finds the unordered fir.do_loop perfectly nested in
+// teams {workdistribute {fir.do_loop unordered}} and lowers it to
+// teams {parallel {distribute {wsloop {loop_nest}}}}.
+//
+// If an unordered fir.do_loop is present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     fir.do_loop unordered {
+//       ...
+//     }
+//   }
+// }
+//
+// Then, it is lowered to
+//
+// omp.teams {
+//   omp.parallel {
+//     omp.distribute {
+//       omp.wsloop {
+//         omp.loop_nest
+//           ...
+//       }
+//     }
+//   }
+// }
+
 static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
   auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
@@ -312,20 +336,23 @@ static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   return false;
 }
 
-/// If A() and B () are present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// A()
-/// B()
-///
+// TeamsWorkdistributeToSingleOp method hoists all the ops inside
+// teams {workdistribute{}} before teams op.
+//
+// If A() and B() are present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//   }
+// }
+//
+// Then, it is lowered to
+//
+// A()
+// B()
+//
 
 static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
@@ -358,11 +385,11 @@ struct SplitTargetResult {
   omp::TargetDataOp dataOp;
 };
 
-/// If multiple workdistribute are nested in a target regions, we will need to
-/// split the target region, but we want to preserve the data semantics of the
-/// original data region and avoid unnecessary data movement at each of the
-/// subkernels - we split the target region into a target_data{target}
-/// nest where only the outer one moves the data
+// If multiple workdistribute are nested in a target regions, we will need to
+// split the target region, but we want to preserve the data semantics of the
+// original data region and avoid unnecessary data movement at each of the
+// subkernels - we split the target region into a target_data{target}
+// nest where only the outer one moves the data
 std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
                                                  RewriterBase &rewriter) {
   auto loc = targetOp->getLoc();
@@ -438,6 +465,10 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
 }
 
+// getNestedOpToIsolate identifies a specific teams parallel op within the
+// body of an omp::TargetOp that should be "isolated". It returns a tuple of
+// (op, whether op is the first op in the target block, whether op is the
+// last op in the target block).
 static std::optional<std::tuple<Operation *, bool, bool>>
 getNestedOpToIsolate(omp::TargetOp targetOp) {
   if (targetOp.getRegion().empty())
@@ -638,6 +669,15 @@ static void reloadCacheAndRecompute(
   }
 }
 
+// isolateOp method rewrites a omp.target_data { omp.target } in to
+// omp.target_data {
+//      // preTargetOp region contains ops before splitBeforeOp.
+//      omp.target {}
+//      // isolatedTargetOp region contains splitBeforeOp,
+//      omp.target {}
+//      // postTargetOp region contains ops after splitBeforeOp.
+//      omp.target {}
+// }
 static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                              RewriterBase &rewriter) {
   auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
@@ -796,6 +836,10 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
 
 static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }
 
+// moveToHost method clones all the ops from target region outside of it.
+// It hoists runtime functions and replaces them with omp vesions.
+// Also hoists and replaces fir.allocmem with omp.target_allocmem and
+// fir.freemem with omp.target_freemem
 static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
   Block *targetBlock = &targetOp.getRegion().front();
@@ -815,7 +859,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     Value privateVar = targetOp.getPrivateVars()[i];
     // The mapping should link the device-side variable to the host-side one.
     BlockArgument arg = targetBlock->getArguments()[mapSize + i];
-    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    // Map the device-side copy (arg) to the host-side value (privateVar).
     mapping.map(arg, privateVar);
   }
 
@@ -868,7 +912,6 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       // fir.declare changes its type when hoisting it out of omp.target to
       // omp.target_data Introduce a load, if original declareOp input is not of
       // reference type, but cloned delcareOp input is reference type.
-
       if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
         auto originalDeclareOp = cast<fir::DeclareOp>(op);
         Type originalInType = originalDeclareOp.getMemref().getType();
@@ -890,6 +933,8 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     }
   }
 
+  // Replace fir.allocmem with omp.target_allocmem,
+  // fir.freemem with omp.target_freemem.
   for (Operation *op : opsToReplace) {
     if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
       rewriter.setInsertionPoint(allocOp);

>From 5b4832ea546fc6bc2f3716f4ab36ca4789de852e Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 28 Aug 2025 16:53:33 +0530
Subject: [PATCH 05/13] update moveToHost implementation

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 259 +++++++++++++-----
 .../lower-workdistribute-fission-target.mlir  |  42 +--
 2 files changed, 214 insertions(+), 87 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index ece64a1ba1d4d..8ead3d57eca98 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -834,13 +834,140 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
   return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
 }
 
-static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }
+// Materializes `val` as an i64 LLVM dialect constant (no 32-bit truncation).
+static mlir::LLVM::ConstantOp
+genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int64_t val) {
+  mlir::Type i64Ty = rewriter.getI64Type();
+  mlir::IntegerAttr attr = rewriter.getI64IntegerAttr(val);
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i64Ty, attr);
+}
+
+// Returns `boxDesc`, or a fir.load of it when it has a fir.ref type.
+static Value loadBoxIfReference(fir::FirOpBuilder &builder, Location loc,
+                                Value boxDesc) {
+  if (isa<fir::ReferenceType>(boxDesc.getType()))
+    return fir::LoadOp::create(builder, loc, boxDesc);
+  return boxDesc;
+}
+
+// Returns the raw base address of the data described by a (ref to a) box.
+static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
+                                         Location loc, Value boxDesc) {
+  Value box = loadBoxIfReference(builder, loc, boxDesc);
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetBaseAddress");
+  // View the box as !fir.box<!fir.array<?xi8>> to get an address of i8 data.
+  auto i8Type = builder.getI8Type();
+  auto unknownArrayType =
+      fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i8Type);
+  auto i8BoxType = fir::BoxType::get(unknownArrayType);
+  auto typedBox = fir::ConvertOp::create(builder, loc, i8BoxType, box);
+  return fir::BoxAddrOp::create(builder, loc, typedBox);
+}
+
+// Returns the total number of elements described by a (ref to a) box, as i64.
+static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
+                                           Location loc, Value boxDesc) {
+  Value box = loadBoxIfReference(builder, loc, boxDesc);
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetTotalElements");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box);
+}
+
+// Returns the element size described by a (ref to a) box, as i64.
+static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
+                                     Value boxDesc) {
+  Value box = loadBoxIfReference(builder, loc, boxDesc);
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetEleSize");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
+}
+
+// Returns total elements * element size for a (ref to a) box, as i64.
+static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
+                                             Location loc, Value boxDesc) {
+  // Load once so the helpers below do not each emit their own fir.load.
+  Value box = loadBoxIfReference(builder, loc, boxDesc);
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetDataSizeInBytes");
+  Value eleSize = genDescriptorGetEleSize(builder, loc, box);
+  Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
+  return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
+}
+
+// Returns omp_get_mapped_ptr(hostPtr, deviceNum), or hostPtr if that is null.
+static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
+                                               mlir::Location loc,
+                                               mlir::Value hostPtr,
+                                               mlir::Value deviceNum,
+                                               mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto i32Type = builder.getI32Type();
+  auto funcName = "omp_get_mapped_ptr";
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    auto funcType =
+        mlir::FunctionType::get(context, {voidPtrType, i32Type}, {voidPtrType});
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  // Convert hostPtr once: it is both the call argument and the fallback.
+  mlir::Value hostVoidPtr =
+      fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr);
+  llvm::SmallVector<mlir::Value> args;
+  args.push_back(hostVoidPtr);
+  args.push_back(fir::ConvertOp::create(builder, loc, i32Type, deviceNum));
+  auto callOp = fir::CallOp::create(builder, loc, funcOp, args);
+  auto mappedPtr = callOp.getResult(0);
+  auto isNull = builder.genIsNullAddr(loc, mappedPtr);
+  return arith::SelectOp::create(builder, loc, isNull, hostVoidPtr, mappedPtr);
+}
+
+// Emits omp_target_memcpy, passing `device` as both dst and src device.
+static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
+                                   mlir::Location loc, mlir::Value dst,
+                                   mlir::Value src, mlir::Value length,
+                                   mlir::Value dstOffset, mlir::Value srcOffset,
+                                   mlir::Value device, mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  // int omp_target_memcpy(void *dst, const void *src, size_t length,
+  //                       size_t dst_offset, size_t src_offset,
+  //                       int dst_device, int src_device)
+  auto funcName = "omp_target_memcpy";
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit
+  auto i32Type = builder.getI32Type();
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+  if (!funcOp) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+    llvm::SmallVector<mlir::Type> argTypes = {
+        voidPtrType, voidPtrType, sizeTType, sizeTType,
+        sizeTType,   i32Type,     i32Type};
+    auto funcType = mlir::FunctionType::get(context, argTypes, {i32Type});
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  // NOTE(review): the int status returned by omp_target_memcpy is ignored.
+  llvm::SmallVector<mlir::Value> args{dst,       src,    length, dstOffset,
+                                      srcOffset, device, device};
+  fir::CallOp::create(builder, loc, funcOp, args);
+}
 
 // moveToHost method clones all the ops from target region outside of it.
 // It hoists runtime functions and replaces them with omp vesions.
 // Also hoists and replaces fir.allocmem with omp.target_allocmem and
 // fir.freemem with omp.target_freemem
-static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
+static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
+                       mlir::ModuleOp module) {
   OpBuilder::InsertionGuard guard(rewriter);
   Block *targetBlock = &targetOp.getRegion().front();
   assert(targetBlock == &targetOp.getRegion().back());
@@ -859,7 +986,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     Value privateVar = targetOp.getPrivateVars()[i];
     // The mapping should link the device-side variable to the host-side one.
     BlockArgument arg = targetBlock->getArguments()[mapSize + i];
-    // Map the device-side copy (arg) to the host-side value (privateVar).
+    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
     mapping.map(arg, privateVar);
   }
 
@@ -872,69 +999,43 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
   for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
        it != end; ++it) {
     auto *op = &*it;
-    if (isRuntimeCall(op)) {
-      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
-      auto module = runtimeCall->getParentOfType<ModuleOp>();
-      auto callee =
-          cast<func::FuncOp>(module.lookupSymbol(runtimeCall.getCalleeAttr()));
-      std::string newCalleeName = (callee.getName() + "_omp").str();
-      mlir::OpBuilder moduleBuilder(module.getBodyRegion());
-      func::FuncOp newCallee =
-          cast_or_null<func::FuncOp>(module.lookupSymbol(newCalleeName));
-      if (!newCallee) {
-        SmallVector<Type> argTypes(callee.getFunctionType().getInputs());
-        argTypes.push_back(getOmpDeviceType(rewriter.getContext()));
-        newCallee = moduleBuilder.create<func::FuncOp>(
-            callee->getLoc(), newCalleeName,
-            FunctionType::get(rewriter.getContext(), argTypes,
-                              callee.getFunctionType().getResults()));
-        if (callee.getArgAttrs())
-          newCallee.setArgAttrsAttr(*callee.getArgAttrs());
-        if (callee.getResAttrs())
-          newCallee.setResAttrsAttr(*callee.getResAttrs());
-        newCallee.setSymVisibility(callee.getSymVisibility());
-        newCallee->setDiscardableAttrs(callee->getDiscardableAttrDictionary());
-      }
-      SmallVector<Value> operands = runtimeCall.getOperands();
-      operands.push_back(device);
-      auto tmpCall = rewriter.create<fir::CallOp>(
-          runtimeCall.getLoc(), runtimeCall.getResultTypes(),
-          SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr,
-          runtimeCall.getFastmathAttr());
-      Operation *newCall = rewriter.clone(*tmpCall, mapping);
-      mapping.map(&*it, newCall);
-      rewriter.eraseOp(tmpCall);
-    } else {
-      Operation *clonedOp = rewriter.clone(*op, mapping);
-      for (unsigned i = 0; i < op->getNumResults(); ++i) {
-        mapping.map(op->getResult(i), clonedOp->getResult(i));
-      }
-      // fir.declare changes its type when hoisting it out of omp.target to
-      // omp.target_data Introduce a load, if original declareOp input is not of
-      // reference type, but cloned delcareOp input is reference type.
-      if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
-        auto originalDeclareOp = cast<fir::DeclareOp>(op);
-        Type originalInType = originalDeclareOp.getMemref().getType();
-        Type clonedInType = clonedDeclareOp.getMemref().getType();
-
-        fir::ReferenceType originalRefType =
-            dyn_cast<fir::ReferenceType>(originalInType);
-        fir::ReferenceType clonedRefType =
-            dyn_cast<fir::ReferenceType>(clonedInType);
-        if (!originalRefType && clonedRefType) {
-          Type clonedEleTy = clonedRefType.getElementType();
-          if (clonedEleTy == originalDeclareOp.getType()) {
-            opsToReplace.push_back(clonedOp);
-          }
+    Operation *clonedOp = rewriter.clone(*op, mapping);
+    for (unsigned i = 0; i < op->getNumResults(); ++i) {
+      mapping.map(op->getResult(i), clonedOp->getResult(i));
+    }
+    // fir.declare changes its type when hoisting it out of omp.target to
+    // omp.target_data Introduce a load, if original declareOp input is not of
+    // reference type, but cloned delcareOp input is reference type.
+
+    if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+      auto originalDeclareOp = cast<fir::DeclareOp>(op);
+      Type originalInType = originalDeclareOp.getMemref().getType();
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+      fir::ReferenceType originalRefType =
+          dyn_cast<fir::ReferenceType>(originalInType);
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      if (!originalRefType && clonedRefType) {
+        Type clonedEleTy = clonedRefType.getElementType();
+        if (clonedEleTy == originalDeclareOp.getType()) {
+          opsToReplace.push_back(clonedOp);
         }
       }
+    }
       if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
         opsToReplace.push_back(clonedOp);
-    }
+      if (isRuntimeCall(clonedOp)) {
+        fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+        if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
+            "_FortranAAssign") {
+          opsToReplace.push_back(clonedOp);
+        } else {
+          llvm_unreachable("Unhandled runtime call hoisting.");
+        }
+      }
   }
 
-  // Replace fir.allocmem with omp.target_allocmem,
-  // fir.freemem with omp.target_freemem.
   for (Operation *op : opsToReplace) {
     if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
       rewriter.setInsertionPoint(allocOp);
@@ -963,16 +1064,40 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       Value loadedValue = rewriter.create<fir::LoadOp>(
           clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
       clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
+    } else if (isRuntimeCall(op)) {
+      rewriter.setInsertionPoint(op);
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      SmallVector<Value> operands = runtimeCall.getOperands();
+      mlir::Location loc = runtimeCall.getLoc();
+      fir::FirOpBuilder builder{rewriter, op};
+      assert(operands.size() == 4);
+      Value sourceFile{operands[2]}, sourceLine{operands[3]};
+
+      auto fromBaseAddr =
+          genDescriptorGetBaseAddress(builder, loc, operands[1]);
+      auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]);
+      auto dataSizeInBytes =
+          genDescriptorGetDataSizeInBytes(builder, loc, operands[1]);
+
+      Value toPtr =
+          genOmpGetMappedPtrIfPresent(builder, loc, toBaseAddr, device, module);
+      Value fromPtr = genOmpGetMappedPtrIfPresent(builder, loc, fromBaseAddr,
+                                                  device, module);
+      Value zero = genI64Constant(loc, rewriter, 0);
+      genOmpTargetMemcpyCall(builder, loc, toPtr, fromPtr, dataSizeInBytes,
+                             zero, zero, device, module);
+      rewriter.eraseOp(op);
     }
   }
   rewriter.eraseOp(targetOp);
 }
 
-void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
+void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
+                   mlir::ModuleOp module) {
   auto tuple = getNestedOpToIsolate(targetOp);
   if (!tuple) {
     LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
-    moveToHost(targetOp, rewriter);
+    moveToHost(targetOp, rewriter, module);
     return;
   }
 
@@ -982,18 +1107,18 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
 
   if (splitBefore && splitAfter) {
     auto res = isolateOp(toIsolate, splitAfter, rewriter);
-    moveToHost(res.preTargetOp, rewriter);
-    fissionTarget(res.postTargetOp, rewriter);
+    moveToHost(res.preTargetOp, rewriter, module);
+    fissionTarget(res.postTargetOp, rewriter, module);
     return;
   }
   if (splitBefore) {
     auto res = isolateOp(toIsolate, splitAfter, rewriter);
-    moveToHost(res.preTargetOp, rewriter);
+    moveToHost(res.preTargetOp, rewriter, module);
     return;
   }
   if (splitAfter) {
     auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter);
-    fissionTarget(res.postTargetOp, rewriter);
+    fissionTarget(res.postTargetOp, rewriter, module);
     return;
   }
 }
@@ -1023,7 +1148,7 @@ class LowerWorkdistributePass
       for (auto targetOp : targetOps) {
         auto res = splitTargetData(targetOp, rewriter);
         if (res)
-          fissionTarget(res->targetOp, rewriter);
+          fissionTarget(res->targetOp, rewriter, moduleOp);
       }
     }
   }
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
index 19bdb9ce10fbd..25ef34f81b492 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -14,7 +14,7 @@
 // CHECK:           %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
 // CHECK:           %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
 // CHECK:           %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
-// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
 // CHECK:           omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
 // CHECK:             %[[VAL_11:.*]] = fir.alloca index
 // CHECK:             %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
@@ -28,29 +28,30 @@
 // CHECK:             %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
 // CHECK:             %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
 // CHECK:             %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
-// CHECK:             %[[VAL_23:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
 // CHECK:             %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
 // CHECK:             %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
-// CHECK:             %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
-// CHECK:             %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:             %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array<i32: 1, 0, 1>, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap<index>
+// CHECK:             %[[VAL_27:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"}
+// CHECK:             %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap<index>
 // CHECK:             fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
 // CHECK:             fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
 // CHECK:             fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
-// CHECK:             fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
-// CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
-// CHECK:               %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr<index>
+// CHECK:             fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_31:.*]], %[[VAL_8]] -> %[[VAL_32:.*]], %[[VAL_9]] -> %[[VAL_33:.*]], %[[VAL_10]] -> %[[VAL_34:.*]], %[[VAL_13]] -> %[[VAL_35:.*]], %[[VAL_16]] -> %[[VAL_36:.*]], %[[VAL_19]] -> %[[VAL_37:.*]], %[[VAL_22]] -> %[[VAL_38:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
 // CHECK:               %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr<index>
 // CHECK:               %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<index>
-// CHECK:               %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<!fir.heap<index>>
-// CHECK:               %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index
+// CHECK:               %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : index
 // CHECK:               omp.teams {
 // CHECK:                 omp.parallel {
 // CHECK:                   omp.distribute {
 // CHECK:                     omp.wsloop {
-// CHECK:                       omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to (%[[VAL_39]]) inclusive step (%[[VAL_40]]) {
-// CHECK:                         fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap<index>
+// CHECK:                       omp.loop_nest (%[[VAL_44:.*]]) : index = (%[[VAL_39]]) to (%[[VAL_40]]) inclusive step (%[[VAL_41]]) {
+// CHECK:                         fir.store %[[VAL_43]] to %[[VAL_42]] : !fir.heap<index>
 // CHECK:                         omp.yield
 // CHECK:                       }
 // CHECK:                     } {omp.composite}
@@ -61,14 +62,15 @@
 // CHECK:               }
 // CHECK:               omp.terminator
 // CHECK:             }
-// CHECK:             %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
-// CHECK:             %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
-// CHECK:             %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
-// CHECK:             %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
-// CHECK:             %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index
-// CHECK:             fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap<index>
-// CHECK:             %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:             "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap<index>) -> ()
+// CHECK:             %[[VAL_45:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:             %[[VAL_46:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK:             %[[VAL_47:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK:             %[[VAL_48:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK:             %[[VAL_49:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             %[[VAL_50:.*]] = arith.addi %[[VAL_47]], %[[VAL_47]] : index
+// CHECK:             fir.store %[[VAL_46]] to %[[VAL_49]] : !fir.heap<index>
+// CHECK:             %[[VAL_51:.*]] = fir.convert %[[VAL_49]] : (!fir.heap<index>) -> i64
+// CHECK:             omp.target_freemem %[[VAL_45]], %[[VAL_51]] : i32, i64
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           return

>From 4efe1390a520d3cff909b1801121a778bbde90a6 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 2 Sep 2025 21:37:39 +0530
Subject: [PATCH 06/13] clang-format code

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp   | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 8ead3d57eca98..fe07070f008c9 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -1023,17 +1023,17 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
         }
       }
     }
-      if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+    if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+      opsToReplace.push_back(clonedOp);
+    if (isRuntimeCall(clonedOp)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
+          "_FortranAAssign") {
         opsToReplace.push_back(clonedOp);
-      if (isRuntimeCall(clonedOp)) {
-        fir::CallOp runtimeCall = cast<fir::CallOp>(op);
-        if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
-            "_FortranAAssign") {
-          opsToReplace.push_back(clonedOp);
-        } else {
-          llvm_unreachable("Unhandled runtime call hoisting.");
-        }
+      } else {
+        llvm_unreachable("Unhandled runtime call hoisting.");
       }
+    }
   }
 
   for (Operation *op : opsToReplace) {

>From fbf5dada26f79f3125d39d8966516e3d9ac30584 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 10 Sep 2025 11:41:51 +0530
Subject: [PATCH 07/13] Remove OpenMP-to-LLVM-IR translation changes; moved to new PR #157717

---
 .../Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b8b9a0338e9ac..4e26e65cf9718 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5400,9 +5400,6 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
-  if (targetOp.getHostEvalVars().empty())
-    numLoops = 0;
-
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);

>From 0436298f3fff36ced2c730e7c3c17088f6d7ddda Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 10 Sep 2025 15:39:22 +0530
Subject: [PATCH 08/13] Fix CI errors

---
 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index fe07070f008c9..7fc59eee2ca2d 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -540,7 +540,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
   auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
   auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
   return TempOmpVar{mapInfoFrom, mapInfoTo};
-};
+}
 
 static bool usedOutsideSplit(Value v, Operation *split) {
   if (!split)
@@ -1071,8 +1071,6 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
       mlir::Location loc = runtimeCall.getLoc();
       fir::FirOpBuilder builder{rewriter, op};
       assert(operands.size() == 4);
-      Value sourceFile{operands[2]}, sourceLine{operands[3]};
-
       auto fromBaseAddr =
           genDescriptorGetBaseAddress(builder, loc, operands[1]);
       auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]);

>From 8a465251ed4593f134f12fc728d81c840f4152e5 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 15 Sep 2025 18:48:46 +0530
Subject: [PATCH 09/13] Use host_eval on target in host execution.

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 743 +++++++++++++-----
 .../lower-workdistribute-fission-host.mlir    | 117 +++
 .../lower-workdistribute-fission-target.mlir  |   3 +
 3 files changed, 647 insertions(+), 216 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 7fc59eee2ca2d..aa0e1f3416114 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -16,8 +16,7 @@
 // It hoists all the other ops outside target region.
 // Relaces heap allocation on target with omp.target_allocmem and
 // deallocation with omp.target_freemem from host. Also replaces
-// runtime function "Assign" with equivalent omp function. ex. @_FortranAAssign
-// on target, once hoisted outside target is replaced with @_FortranAAssign_omp.
+// runtime function "Assign" with omp.target_memcpy.
 //
 //===----------------------------------------------------------------------===//
 
@@ -75,7 +74,7 @@ static bool isRuntimeCall(Operation *op) {
 }
 
 // This is the single source of truth about whether we should parallelize an
-// operation nested in an omp.execute region.
+// operation nested in an omp.workdistribute region.
 static bool shouldParallelize(Operation *op) {
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
@@ -87,6 +86,7 @@ static bool shouldParallelize(Operation *op) {
       return false;
     return *unordered;
   }
+  // True if the op is a runtime call to Assign
   if (isRuntimeCall(op) &&
       (op->getName().getStringRef() == "_FortranAAssign")) {
     return true;
@@ -235,6 +235,7 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   return changed;
 }
 
+// Generate omp.parallel operation with an empty region.
 static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
   parallelOp.setComposite(composite);
@@ -243,6 +244,7 @@ static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   return;
 }
 
+// Generate omp.distribute operation with an empty region.
 static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
   mlir::omp::DistributeOperands distributeClauseOps;
   auto distributeOp =
@@ -253,6 +255,7 @@ static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
   return;
 }
 
+// Generate loop nest clause operands from fir.do_loop operation.
 static void
 genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
                      mlir::omp::LoopNestOperands &loopNestClauseOps) {
@@ -264,6 +267,7 @@ genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
   loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
 }
 
+// Generate omp.wsloop operation with an empty region and
 static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
                         const mlir::omp::LoopNestOperands &clauseOps,
                         bool composite) {
@@ -286,7 +290,6 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
   if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
     rewriter.setInsertionPoint(terminatorOp);
     rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
-    // rewriter.erase(terminatorOp);
     terminatorOp->erase();
   }
   return;
@@ -319,12 +322,22 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
 //   }
 // }
 
-static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+static bool
+WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
+                      SetVector<omp::TargetOp> &targetOpsToProcess) {
   OpBuilder rewriter(workdistribute);
   auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
   auto wdLoc = workdistribute->getLoc();
   if (doLoop && shouldParallelize(doLoop)) {
     assert(doLoop.getReduceOperands().empty());
+
+    // Record the target ops to process later
+    if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp) {
+        targetOpsToProcess.insert(targetOp);
+      }
+    }
     genParallelOp(wdLoc, rewriter, true);
     genDistributeOp(wdLoc, rewriter, true);
     mlir::omp::LoopNestOperands loopNestClauseOps;
@@ -353,7 +366,7 @@ static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
 // A()
 // B()
 //
-
+// If only the terminator remains in teams after hoisting, we erase teams op.
 static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
   if (!workdistributeOp)
@@ -380,25 +393,20 @@ static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   return true;
 }
 
-struct SplitTargetResult {
-  omp::TargetOp targetOp;
-  omp::TargetDataOp dataOp;
-};
-
 // If multiple workdistribute are nested in a target regions, we will need to
 // split the target region, but we want to preserve the data semantics of the
 // original data region and avoid unnecessary data movement at each of the
 // subkernels - we split the target region into a target_data{target}
 // nest where only the outer one moves the data
-std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
-                                                 RewriterBase &rewriter) {
+std::optional<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
+                                             RewriterBase &rewriter) {
   auto loc = targetOp->getLoc();
   if (targetOp.getMapVars().empty()) {
     LLVM_DEBUG(llvm::dbgs()
                << DEBUG_TYPE << " target region has no data maps\n");
     return std::nullopt;
   }
-
+  // Collect all the mapinfo ops
   SmallVector<omp::MapInfoOp> mapInfos;
   for (auto opr : targetOp.getMapVars()) {
     auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
@@ -408,14 +416,15 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   rewriter.setInsertionPoint(targetOp);
   SmallVector<Value> innerMapInfos;
   SmallVector<Value> outerMapInfos;
-
+  // Create new mapinfo ops for the inner target region
   for (auto mapInfo : mapInfos) {
     auto originalMapType =
         (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
     auto originalCaptureType = mapInfo.getMapCaptureType();
     llvm::omp::OpenMPOffloadMappingFlags newMapType;
     mlir::omp::VariableCaptureKind newCaptureType;
-
+    // For bycopy, we keep the same map type and capture type
+    // For byref, we change the map type to none and keep the capture type
     if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
       newMapType = originalMapType;
       newCaptureType = originalCaptureType;
@@ -441,12 +450,13 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   auto ifExpr = targetOp.getIfExpr();
   auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
   auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  // Create the target data op
   auto targetDataOp = rewriter.create<omp::TargetDataOp>(
       loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
   auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
   rewriter.create<mlir::omp::TerminatorOp>(loc);
   rewriter.setInsertionPointToStart(taregtDataBlock);
-
+  // Create the inner target op
   auto newTargetOp = rewriter.create<omp::TargetOp>(
       targetOp.getLoc(), targetOp.getAllocateVars(),
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
@@ -460,9 +470,8 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
       targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
   rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
                               newTargetOp.getRegion().begin());
-
   rewriter.replaceOp(targetOp, targetDataOp);
-  return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
+  return newTargetOp;
 }
 
 // getNestedOpToIsolate function is designed to identify a specific teams
@@ -480,20 +489,23 @@ getNestedOpToIsolate(omp::TargetOp targetOp) {
     if (first && last)
       return std::nullopt;
 
-    if (isa<omp::TeamsOp, omp::ParallelOp>(&op))
+    if (isa<omp::TeamsOp>(&op))
       return {{&op, first, last}};
   }
   return std::nullopt;
 }
 
+// Temporary structure to hold the two mapinfo ops
 struct TempOmpVar {
   omp::MapInfoOp from, to;
 };
 
+// isPtr checks if the type is a pointer or reference type.
 static bool isPtr(Type ty) {
   return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
 }
 
+// getPtrTypeForOmp returns an LLVM pointer type for the given type.
 static Type getPtrTypeForOmp(Type ty) {
   if (isPtr(ty))
     return LLVM::LLVMPointerType::get(ty.getContext());
@@ -501,6 +513,7 @@ static Type getPtrTypeForOmp(Type ty) {
     return fir::LLVMPointerType::get(ty);
 }
 
+// allocateTempOmpVar allocates a temporary variable for OpenMP mapping
 static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
                                      RewriterBase &rewriter) {
   MLIRContext &ctx = *ty.getContext();
@@ -542,6 +555,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
   return TempOmpVar{mapInfoFrom, mapInfoTo};
 }
 
+// usedOutsideSplit checks if a value is used outside the split operation.
 static bool usedOutsideSplit(Value v, Operation *split) {
   if (!split)
     return false;
@@ -557,6 +571,7 @@ static bool usedOutsideSplit(Value v, Operation *split) {
   return false;
 };
 
+// isRecomputableAfterFission checks if an operation can be recomputed
 static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
   if (isa<fir::DeclareOp>(op))
     return true;
@@ -572,12 +587,7 @@ static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
   return false;
 }
 
-struct SplitResult {
-  omp::TargetOp preTargetOp;
-  omp::TargetOp isolatedTargetOp;
-  omp::TargetOp postTargetOp;
-};
-
+// collectNonRecomputableDeps collects dependencies that cannot be recomputed
 static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
                                        SetVector<Operation *> &nonRecomputable,
                                        SetVector<Operation *> &toCache,
@@ -597,20 +607,40 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
                                toRecompute);
 }
 
+// Add host_eval, map and private block args to newTargetBlock; map old->new.
 static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
                                   omp::TargetOp &targetOp, Block *targetBlock,
                                   Block *newTargetBlock,
+                                  SmallVector<Value> &hostEvalVars,
                                   SmallVector<Value> &mapOperands,
                                   SmallVector<Value> &allocs,
                                   IRMapping &irMapping) {
-  // Map `map_operands` to block arguments.
+  // FIRST: host_eval args (original ones first, then any appended operands).
+  unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < hostEvalVars.size(); ++i) {
+    Value originalValue;
+    BlockArgument newArg;
+    if (i < originalHostEvalVarsSize) {
+      originalValue = targetBlock->getArgument(i); // Host_eval args come first
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    } else {
+      originalValue = hostEvalVars[i];
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    }
+    irMapping.map(originalValue, newArg);
+  }
+
+  // SECOND: Map `map_operands` (offset by the original host_eval count).
   unsigned originalMapVarsSize = targetOp.getMapVars().size();
   for (unsigned i = 0; i < mapOperands.size(); ++i) {
     Value originalValue;
     BlockArgument newArg;
     // Map the new arguments from the original block.
     if (i < originalMapVarsSize) {
-      originalValue = targetBlock->getArgument(i);
+      originalValue = targetBlock->getArgument(originalHostEvalVarsSize +
+                                               i); // Offset by host_eval count
       newArg = newTargetBlock->addArgument(originalValue.getType(),
                                            originalValue.getLoc());
     }
@@ -622,10 +652,12 @@ static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
     }
     irMapping.map(originalValue, newArg);
   }
-  // Map `private_vars` to block arguments.
+
+  // THIRD: Map `private_vars`, which follow the host_eval and map args.
   unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size();
   for (unsigned i = 0; i < originalPrivateVarsSize; ++i) {
-    auto originalArg = targetBlock->getArgument(originalMapVarsSize + i);
+    auto originalArg = targetBlock->getArgument(originalHostEvalVarsSize +
+                                                originalMapVarsSize + i);
     auto newArg = newTargetBlock->addArgument(originalArg.getType(),
                                               originalArg.getLoc());
     irMapping.map(originalArg, newArg);
@@ -633,24 +665,25 @@ static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
   return;
 }
 
+// Reload cached values from newTargetBlock args and re-clone toRecompute ops.
 static void reloadCacheAndRecompute(
     Location loc, RewriterBase &rewriter, Operation *splitBefore,
     omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
-    SmallVector<Value> &mapOperands, SmallVector<Value> &allocs,
-    SetVector<Operation *> &toRecompute, IRMapping &irMapping) {
-  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, newTargetBlock,
-                        mapOperands, allocs, irMapping);
+    SmallVector<Value> &hostEvalVars, SmallVector<Value> &mapOperands,
+    SmallVector<Value> &allocs, SetVector<Operation *> &toRecompute,
+    IRMapping &irMapping) {
   // Handle the load operations for the allocs.
   rewriter.setInsertionPointToStart(newTargetBlock);
   auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
 
   unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  unsigned hostEvalVarsSize = hostEvalVars.size();
   // Create Stores for allocs.
   for (unsigned i = 0; i < allocs.size(); ++i) {
     Value original = allocs[i];
     // Get the new block argument for this specific allocated value.
-    Value newArg = newTargetBlock->getArgument(originalMapVarsSize + i);
-
+    Value newArg =
+        newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i);
     Value restored;
     if (isPtr(original.getType())) {
       restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
@@ -662,171 +695,66 @@ static void reloadCacheAndRecompute(
     }
     irMapping.map(original, restored);
   }
-
+  // Re-clone, in block order, the toRecompute ops that precede splitBefore.
   for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
     if (toRecompute.contains(&*it))
       rewriter.clone(*it, irMapping);
   }
 }
 
-// isolateOp method rewrites a omp.target_data { omp.target } in to
-// omp.target_data {
-//      // preTargetOp region contains ops before splitBeforeOp.
-//      omp.target {}
-//      // isolatedTargetOp region contains splitBeforeOp,
-//      omp.target {}
-//      // postTargetOp region contains ops after splitBeforeOp.
-//      omp.target {}
-// }
-static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
-                             RewriterBase &rewriter) {
-  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
-  assert(targetOp);
-  auto loc = targetOp.getLoc();
-  auto *targetBlock = &targetOp.getRegion().front();
-  rewriter.setInsertionPoint(targetOp);
-
-  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
-  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
-
-  SmallVector<Value> requiredVals;
-  SetVector<Operation *> toCache;
-  SetVector<Operation *> toRecompute;
-  SetVector<Operation *> nonRecomputable;
-  SmallVector<Value> allocs;
-
-  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
-       it++) {
-    for (auto res : it->getResults()) {
-      if (usedOutsideSplit(res, splitBeforeOp))
-        requiredVals.push_back(res);
+// Given a teamsOp, walk down the nested regions and return the innermost
+// LoopNestOp, or a null op when any level is missing. Expected nesting:
+// teams -> parallel -> distribute -> wsloop -> loop_nest (first of each)
+static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) {
+  if (teamsOp.getRegion().empty())
+    return nullptr;
+  // Ensure the teams region has a single block.
+  if (teamsOp.getRegion().getBlocks().size() != 1)
+    return nullptr;
+  // Find parallel op inside teams
+  mlir::omp::ParallelOp parallelOp = nullptr;
+  for (auto &op : teamsOp.getRegion().front()) {
+    if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) {
+      parallelOp = parallel;
+      break;
     }
-    if (!isRecomputableAfterFission(&*it, splitBeforeOp))
-      nonRecomputable.insert(&*it);
   }
+  if (!parallelOp)
+    return nullptr;
 
-  for (auto requiredVal : requiredVals)
-    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
-                               toRecompute);
-
-  for (Operation *op : toCache) {
-    for (auto res : op->getResults()) {
-      auto alloc =
-          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
-      allocs.push_back(res);
-      preMapOperands.push_back(alloc.from);
-      postMapOperands.push_back(alloc.to);
+  // Find distribute op inside parallel
+  mlir::omp::DistributeOp distributeOp = nullptr;
+  for (auto &op : parallelOp.getRegion().front()) {
+    if (auto distribute = dyn_cast<mlir::omp::DistributeOp>(op)) {
+      distributeOp = distribute;
+      break;
     }
   }
+  if (!distributeOp)
+    return nullptr;
 
-  rewriter.setInsertionPoint(targetOp);
-
-  auto preTargetOp = rewriter.create<omp::TargetOp>(
-      targetOp.getLoc(), targetOp.getAllocateVars(),
-      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
-      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
-      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
-      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
-      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
-      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
-      preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
-      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
-      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
-  auto *preTargetBlock = rewriter.createBlock(
-      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
-  IRMapping preMapping;
-
-  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
-                        preMapOperands, allocs, preMapping);
-
-  // Handle the store operations for the allocs.
-  rewriter.setInsertionPointToStart(preTargetBlock);
-  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
-
-  // Clone the original operations.
-  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
-       it++) {
-    rewriter.clone(*it, preMapping);
-  }
-
-  unsigned originalMapVarsSize = targetOp.getMapVars().size();
-  // Create Stores for allocs.
-  for (unsigned i = 0; i < allocs.size(); ++i) {
-    Value originalResult = allocs[i];
-    Value toStore = preMapping.lookup(originalResult);
-    // Get the new block argument for this specific allocated value.
-    Value newArg = preTargetBlock->getArgument(originalMapVarsSize + i);
-
-    if (isPtr(originalResult.getType())) {
-      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
-        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
-      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
-    } else {
-      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+  // Find wsloop op inside distribute
+  mlir::omp::WsloopOp wsloopOp = nullptr;
+  for (auto &op : distributeOp.getRegion().front()) {
+    if (auto wsloop = dyn_cast<mlir::omp::WsloopOp>(op)) {
+      wsloopOp = wsloop;
+      break;
     }
   }
-  rewriter.create<omp::TerminatorOp>(loc);
-
-  rewriter.setInsertionPoint(targetOp);
-
-  auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
-      targetOp.getLoc(), targetOp.getAllocateVars(),
-      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
-      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
-      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
-      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
-      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
-      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
-      postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
-      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
-      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
-
-  auto *isolatedTargetBlock =
-      rewriter.createBlock(&isolatedTargetOp.getRegion(),
-                           isolatedTargetOp.getRegion().begin(), {}, {});
-
-  IRMapping isolatedMapping;
-  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
-                          isolatedTargetBlock, postMapOperands, allocs,
-                          toRecompute, isolatedMapping);
-  rewriter.clone(*splitBeforeOp, isolatedMapping);
-  rewriter.create<omp::TerminatorOp>(loc);
-
-  omp::TargetOp postTargetOp = nullptr;
+  if (!wsloopOp)
+    return nullptr;
 
-  if (splitAfter) {
-    rewriter.setInsertionPoint(targetOp);
-    postTargetOp = rewriter.create<omp::TargetOp>(
-        targetOp.getLoc(), targetOp.getAllocateVars(),
-        targetOp.getAllocatorVars(), targetOp.getBareAttr(),
-        targetOp.getDependKindsAttr(), targetOp.getDependVars(),
-        targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
-        targetOp.getHostEvalVars(), targetOp.getIfExpr(),
-        targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
-        targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
-        postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
-        targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
-        targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
-    auto *postTargetBlock = rewriter.createBlock(
-        &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
-    IRMapping postMapping;
-    reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
-                            postTargetBlock, postMapOperands, allocs,
-                            toRecompute, postMapping);
-
-    assert(splitBeforeOp->getNumResults() == 0 ||
-           llvm::all_of(splitBeforeOp->getResults(),
-                        [](Value result) { return result.use_empty(); }));
-
-    for (auto it = std::next(splitBeforeOp->getIterator());
-         it != targetBlock->end(); it++)
-      rewriter.clone(*it, postMapping);
+  // Find loop_nest op inside wsloop
+  for (auto &op : wsloopOp.getRegion().front()) {
+    if (auto loopNest = dyn_cast<mlir::omp::LoopNestOp>(op)) {
+      return loopNest;
+    }
   }
 
-  rewriter.eraseOp(targetOp);
-  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+  return nullptr;
 }
 
+// Generate an LLVM constant operation of i32 type holding `value`.
 static mlir::LLVM::ConstantOp
 genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
   mlir::Type i32Ty = rewriter.getI32Type();
@@ -834,6 +762,7 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
   return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
 }
 
+// Generate an LLVM constant of i64 type; note `value` is taken as an `int`.
 static mlir::LLVM::ConstantOp
 genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
   mlir::Type i64Ty = rewriter.getI64Type();
@@ -841,6 +770,9 @@ genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
   return rewriter.create<mlir::LLVM::ConstantOp>(loc, i64Ty, attr);
 }
 
+// Given a box descriptor, extract the base address of the data it describes.
+// If `boxDesc` is a reference to a box, it is loaded first.
+// The base address is returned as an i8* pointer.
 static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
                                          Location loc, Value boxDesc) {
   Value box = boxDesc;
@@ -858,6 +790,9 @@ static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
   return rawAddr;
 }
 
+// Given a box descriptor, extract the total number of elements in the array it
+// describes. If the box descriptor is a reference, load it first.
+// The count is produced by fir::BoxTotalElementsOp as an i64 value.
 static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
                                            Location loc, Value boxDesc) {
   Value box = boxDesc;
@@ -870,6 +805,9 @@ static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
   return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box);
 }
 
+// Given a box descriptor, extract the size of each element in the array it
+// describes. If the box descriptor is a reference, load it first.
+// The element size is produced by fir::BoxEleSizeOp as an i64 value.
 static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
                                      Value boxDesc) {
   Value box = boxDesc;
@@ -882,6 +820,10 @@ static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
   return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
 }
 
+// Given a box descriptor, compute the total size in bytes of the data it
+// describes as (total number of elements) * (element size), using a plain
+// arith.muli with no overflow check. If the box descriptor is a reference,
+// load it first. The total size in bytes is returned as an i64 value.
 static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
                                              Location loc, Value boxDesc) {
   Value box = boxDesc;
@@ -895,6 +837,11 @@ static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
   return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
 }
 
+// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to
+// retrieve the device pointer corresponding to a given host pointer and device
+// number. If the runtime finds no mapping (NULL), the host pointer is returned.
+// Signature (OpenMP 5.1):
+//   void *omp_get_mapped_ptr(const void *ptr, int device_num);
 static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
                                                mlir::Location loc,
                                                mlir::Value hostPtr,
@@ -930,15 +877,18 @@ static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
   return result;
 }
 
+// Generate a call to the OpenMP runtime function `omp_target_memcpy` to
+// copy memory between host and device or between devices. Only one `device`
+// is taken; presumably it serves as both dst_device and src_device. Signature:
+//   int omp_target_memcpy(void *dst, const void *src, size_t length,
+//                         size_t dst_offset, size_t src_offset,
+//                         int dst_device, int src_device);
 static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
                                    mlir::Location loc, mlir::Value dst,
                                    mlir::Value src, mlir::Value length,
                                    mlir::Value dstOffset, mlir::Value srcOffset,
                                    mlir::Value device, mlir::ModuleOp module) {
   auto *context = builder.getContext();
-  // int omp_target_memcpy(void *dst, const void *src, size_t length,
-  //                       size_t dst_offset, size_t src_offset,
-  //                       int dst_device, int src_device)
   auto funcName = "omp_target_memcpy";
   auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
   auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit
@@ -962,30 +912,48 @@ static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
   return;
 }
 
+// Host-side loop bounds of the isolated loop nest, passed via host_eval.
+struct HostEvalVars {
+  SmallVector<Value> lbs;   // lower bound of each loop
+  SmallVector<Value> ubs;   // upper bound of each loop
+  SmallVector<Value> steps; // step of each loop
+};
+
 // moveToHost method clones all the ops from target region outside of it.
 // It hoists runtime functions and replaces them with omp vesions.
 // Also hoists and replaces fir.allocmem with omp.target_allocmem and
 // fir.freemem with omp.target_freemem
 static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
-                       mlir::ModuleOp module) {
+                       mlir::ModuleOp module,
+                       struct HostEvalVars &hostEvalVars) {
   OpBuilder::InsertionGuard guard(rewriter);
   Block *targetBlock = &targetOp.getRegion().front();
   assert(targetBlock == &targetOp.getRegion().back());
   IRMapping mapping;
+  // Map the leading host_eval block args to the host_eval operands.
+  unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
+    Value hostEvalVar = targetOp.getHostEvalVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[i];
+    mapping.map(arg, hostEvalVar);
+  }
+  // Map map_vars block args (after host_eval) to host variable pointers.
   for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
     Value mapInfo = targetOp.getMapVars()[i];
-    BlockArgument arg = targetBlock->getArguments()[i];
+    BlockArgument arg = targetBlock->getArguments()[hostEvalVarCount + i];
     Operation *op = mapInfo.getDefiningOp();
     assert(op);
     auto mapInfoOp = cast<omp::MapInfoOp>(op);
     // map the block argument to the host-side variable pointer
     mapping.map(arg, mapInfoOp.getVarPtr());
   }
+  // Map private block args (after host_eval and map args) to private vars.
   unsigned mapSize = targetOp.getMapVars().size();
   for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
     Value privateVar = targetOp.getPrivateVars()[i];
     // The mapping should link the device-side variable to the host-side one.
-    BlockArgument arg = targetBlock->getArguments()[mapSize + i];
+    BlockArgument arg =
+        targetBlock->getArguments()[hostEvalVarCount + mapSize + i];
     // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
     mapping.map(arg, privateVar);
   }
@@ -993,20 +961,22 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
   rewriter.setInsertionPoint(targetOp);
   SmallVector<Operation *> opsToReplace;
   Value device = targetOp.getDevice();
+
   if (!device) {
     device = genI32Constant(targetOp.getLoc(), rewriter, 0);
   }
+  // Clone every op except the terminator, in order, before targetOp.
   for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
        it != end; ++it) {
     auto *op = &*it;
     Operation *clonedOp = rewriter.clone(*op, mapping);
+    // Map the results of the original op to the cloned op.
     for (unsigned i = 0; i < op->getNumResults(); ++i) {
       mapping.map(op->getResult(i), clonedOp->getResult(i));
     }
     // fir.declare changes its type when hoisting it out of omp.target to
     // omp.target_data Introduce a load, if original declareOp input is not of
     // reference type, but cloned delcareOp input is reference type.
-
     if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
       auto originalDeclareOp = cast<fir::DeclareOp>(op);
       Type originalInType = originalDeclareOp.getMemref().getType();
@@ -1023,8 +993,10 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
         }
       }
     }
+    // Queue cloned allocmem/freemem ops for replacement below.
     if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
       opsToReplace.push_back(clonedOp);
+    // Check for runtime calls to be replaced.
     if (isRuntimeCall(clonedOp)) {
       fir::CallOp runtimeCall = cast<fir::CallOp>(op);
       if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
@@ -1035,7 +1007,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
       }
     }
   }
-
+  // Rewrite queued ops: fir.allocmem -> omp.target_allocmem, etc.
   for (Operation *op : opsToReplace) {
     if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
       rewriter.setInsertionPoint(allocOp);
@@ -1048,14 +1020,20 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
           allocOp.getLoc(), allocOp.getResult().getType(),
           ompAllocmemOp.getResult());
       rewriter.replaceOp(allocOp, firConvertOp.getResult());
-    } else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+    }
+    // Replace fir.freemem with omp.target_freemem.
+    else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
       rewriter.setInsertionPoint(freeOp);
       auto firConvertOp = rewriter.create<fir::ConvertOp>(
           freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
       rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
                                             firConvertOp.getResult());
       rewriter.eraseOp(freeOp);
-    } else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+    }
+    // fir.declare changes its type when hoisted out of omp.target into
+    // omp.target_data: introduce a load if the original declareOp input is
+    // not of reference type but the cloned declareOp input is.
+    else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
       Type clonedInType = clonedDeclareOp.getMemref().getType();
       fir::ReferenceType clonedRefType =
           dyn_cast<fir::ReferenceType>(clonedInType);
@@ -1064,7 +1042,9 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
       Value loadedValue = rewriter.create<fir::LoadOp>(
           clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
       clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
-    } else if (isRuntimeCall(op)) {
+    }
+    // Replace runtime calls with omp versions.
+    else if (isRuntimeCall(op)) {
       rewriter.setInsertionPoint(op);
       fir::CallOp runtimeCall = cast<fir::CallOp>(op);
       SmallVector<Value> operands = runtimeCall.getOperands();
@@ -1087,40 +1067,370 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
       rewriter.eraseOp(op);
     }
   }
+
+  // Remap the collected loop bounds/steps to their clones outside targetOp.
+  for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+    hostEvalVars.lbs[i] = mapping.lookup(hostEvalVars.lbs[i]);
+    hostEvalVars.ubs[i] = mapping.lookup(hostEvalVars.ubs[i]);
+    hostEvalVars.steps[i] = mapping.lookup(hostEvalVars.steps[i]);
+  }
+  // Finally erase the original targetOp; its body has been cloned out.
+  rewriter.eraseOp(targetOp);
+}
+
+// Result of isolateOp: the target ops created around the split op.
+struct SplitResult {
+  omp::TargetOp preTargetOp{};      // ops that precede the split op
+  omp::TargetOp isolatedTargetOp{}; // the split op itself
+  omp::TargetOp postTargetOp{};     // ops that follow the split op
+};
+
+// Decide how values produced before `splitBeforeOp` reach the later target
+// ops: gather requiredVals (results used past the split) and nonRecomputable
+// ops, then let collectNonRecomputableDeps fill toCache / toRecompute. Each
+// result of a toCache op gets a temporary whose from/to halves are appended to
+// the pre/post map operands, with the cached value recorded in `allocs`.
+static void computeAllocsCacheRecomputable(
+    omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter,
+    SmallVector<Value> &preMapOperands, SmallVector<Value> &postMapOperands,
+    SmallVector<Value> &allocs, SmallVector<Value> &requiredVals,
+    SetVector<Operation *> &nonRecomputable, SetVector<Operation *> &toCache,
+    SetVector<Operation *> &toRecompute) {
+  Block &body = targetOp.getRegion().front();
+  auto splitIt = splitBeforeOp->getIterator();
+  // Scan every op that precedes the split point.
+  for (auto it = body.begin(); it != splitIt; ++it) {
+    Operation &op = *it;
+    for (auto result : op.getResults())
+      if (usedOutsideSplit(result, splitBeforeOp))
+        requiredVals.push_back(result);
+    if (!isRecomputableAfterFission(&op, splitBeforeOp))
+      nonRecomputable.insert(&op);
+  }
+
+  // Partition the dependencies of every required value.
+  for (Value requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+  // One temporary per cached result, threaded through both map operand lists.
+  Location loc = targetOp.getLoc();
+  for (Operation *cachedOp : toCache) {
+    for (auto result : cachedOp->getResults()) {
+      auto tmp = allocateTempOmpVar(loc, result.getType(), rewriter);
+      allocs.push_back(result);
+      preMapOperands.push_back(tmp.from);
+      postMapOperands.push_back(tmp.to);
+    }
+  }
+}
+
+// genPreTargetOp creates the target op holding the ops before splitBeforeOp,
+// stores each value in `allocs` to its cache block argument and, when not
+// compiling for the device, records the teams loop bounds in hostEvalVars.
+static omp::TargetOp
+genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
+               SmallVector<Value> &allocs, Operation *splitBeforeOp,
+               RewriterBase &rewriter, struct HostEvalVars &hostEvalVars,
+               bool isTargetDevice) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()};
+  // The pre target op keeps the original host_eval operands unchanged.
+  omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
+                        preHostEvalVars, preMapOperands, allocs, preMapping);
+
+  // Emit the cloned ops and the alloc stores into the new block.
+  rewriter.setInsertionPointToStart(preTargetBlock);
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  // Clone the ops that precede the split point.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++)
+    rewriter.clone(*it, preMapping);
+
+  unsigned originalHostEvalVarsSize = preHostEvalVars.size();
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  // Create Stores for allocs.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value originalResult = allocs[i];
+    Value toStore = preMapping.lookup(originalResult);
+    // Get the new block argument for this specific allocated value.
+    Value newArg = preTargetBlock->getArgument(originalHostEvalVarsSize +
+                                               originalMapVarsSize + i);
+    // Create the store operation.
+    if (isPtr(originalResult.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  // Record host_eval loop bounds only if splitBeforeOp is a teams op wrapping
+  // a loop nest; a plain cast<> would assert on any other kind of split op.
+  omp::LoopNestOp loopNestOp;
+  if (auto teamsOp = dyn_cast<omp::TeamsOp>(splitBeforeOp))
+    loopNestOp = getLoopNestFromTeams(teamsOp);
+  if (loopNestOp && !isTargetDevice) {
+    for (size_t i = 0; i < loopNestOp.getLoopLowerBounds().size(); ++i) {
+      Value lb = loopNestOp.getLoopLowerBounds()[i];
+      Value ub = loopNestOp.getLoopUpperBounds()[i];
+      Value step = loopNestOp.getLoopSteps()[i];
+
+      hostEvalVars.lbs.push_back(preMapping.lookup(lb));
+      hostEvalVars.ubs.push_back(preMapping.lookup(ub));
+      hostEvalVars.steps.push_back(preMapping.lookup(step));
+    }
+  }
+
+  return preTargetOp;
+}
+
+// genIsolatedTargetOp creates the target op that holds only splitBeforeOp.
+// Cached values are reloaded and recomputable ops re-cloned first. On the
+// host, the loop bounds in hostEvalVars are appended as host_eval operands
+// and the cloned loop_nest is rewired to the matching block arguments.
+static omp::TargetOp
+genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
+                    Operation *splitBeforeOp, RewriterBase &rewriter,
+                    SmallVector<Value> &allocs,
+                    SetVector<Operation *> &toRecompute,
+                    struct HostEvalVars &hostEvalVars, bool isTargetDevice) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()};
+  // On the host, append the loop bounds as host_eval: lbs, then ubs, steps.
+  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+    for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+      isolatedHostEvalVars.push_back(hostEvalVars.lbs[i]);
+    }
+    for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i) {
+      isolatedHostEvalVars.push_back(hostEvalVars.ubs[i]);
+    }
+    for (size_t i = 0; i < hostEvalVars.steps.size(); ++i) {
+      isolatedHostEvalVars.push_back(hostEvalVars.steps[i]);
+    }
+  }
+  // Create the isolated target op
+  omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  auto *isolatedTargetBlock =
+      rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                           isolatedTargetOp.getRegion().begin(), {}, {});
+  IRMapping isolatedMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock,
+                        isolatedTargetBlock, isolatedHostEvalVars,
+                        postMapOperands, allocs, isolatedMapping);
+  // Handle the load operations for the allocs and recompute ops.
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                          isolatedTargetBlock, isolatedHostEvalVars,
+                          postMapOperands, allocs, toRecompute,
+                          isolatedMapping);
+
+  // Clone the split op itself, keeping a handle so that its loop nest can be
+  // rewired to the host_eval block arguments below.
+  Operation *clonedSplitOp = rewriter.clone(*splitBeforeOp, isolatedMapping);
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  // update the loop bounds in the isolatedTargetOp if we have host_eval vars
+  // and we are not generating code for the target device.
+  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+    // The bounds are expected to come from splitBeforeOp (see genPreTargetOp),
+    // so its clone must be a teams op wrapping the same loop nest.
+    auto teamsOp = dyn_cast<omp::TeamsOp>(clonedSplitOp);
+    assert(teamsOp && "No teamsOp found in isolated target region");
+    // Get the loopNestOp inside the teamsOp
+    auto loopNestOp = getLoopNestFromTeams(teamsOp);
+    assert(loopNestOp && "No loop_nest found in isolated teams op");
+    // Get the BlockArgs related to host_eval vars and update loop_nest bounds
+    // to them
+    unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+    unsigned index = originalHostEvalVarsSize;
+    // Replace loop bounds with the block arguments passed down via host_eval
+    SmallVector<Value> lbs, ubs, steps;
+
+    // Collect new lb/ub/step values from target block args
+    for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i)
+      lbs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i)
+      ubs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.steps.size(); ++i)
+      steps.push_back(isolatedTargetBlock->getArgument(index++));
+
+    // Reset the loop bounds
+    loopNestOp.getLoopLowerBoundsMutable().assign(lbs);
+    loopNestOp.getLoopUpperBoundsMutable().assign(ubs);
+    loopNestOp.getLoopStepsMutable().assign(steps);
+  }
+
+  return isolatedTargetOp;
+}
+
+// genPostTargetOp builds the omp.target that follows the split point: it holds
+// every op after splitBeforeOp from the original target region. The new op
+// copies the original clauses (with postMapOperands as its map operands),
+// receives fresh block arguments, reloads cached values / recomputes
+// toRecompute ops, and finally clones the trailing ops (incl. terminator).
+static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
+                                     Operation *splitBeforeOp,
+                                     SmallVector<Value> &postMapOperands,
+                                     RewriterBase &rewriter,
+                                     SmallVector<Value> &allocs,
+                                     SetVector<Operation *> &toRecompute) {
+  Location loc = targetOp.getLoc();
+  Block *origBlock = &targetOp.getRegion().front();
+  SmallVector<Value> hostEval{targetOp.getHostEvalVars()};
+
+  // Same clauses as the original target, except for the map operands.
+  auto postTarget = rewriter.create<omp::TargetOp>(
+      loc, targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
+      targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
+      targetOp.getDependVars(), targetOp.getDevice(),
+      targetOp.getHasDeviceAddrVars(), hostEval, targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  // Create the (initially argument-less) entry block of the post target.
+  Block *postBlock = rewriter.createBlock(
+      &postTarget.getRegion(), postTarget.getRegion().begin(), {}, {});
+  IRMapping mapping;
+  createBlockArgsAndMap(loc, rewriter, targetOp, origBlock, postBlock,
+                        hostEval, postMapOperands, allocs, mapping);
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, origBlock,
+                          postBlock, hostEval, postMapOperands, allocs,
+                          toRecompute, mapping);
+
+  // The split op's results must have no uses.
+  assert(llvm::none_of(splitBeforeOp->getResults(),
+                       [](Value res) { return !res.use_empty(); }));
+  // Clone the ops following the split point (terminator included).
+  for (Operation *op = splitBeforeOp->getNextNode(); op;
+       op = op->getNextNode())
+    rewriter.clone(*op, mapping);
+  return postTarget;
+}
+
+// isolateOp splits the omp.target enclosing splitBeforeOp (nested in an
+// omp.target_data) around that op:
+// omp.target_data {
+//      // preTargetOp region contains ops before splitBeforeOp; it is then
+//      // moved to the host via moveToHost.
+//      omp.target {}
+//      // isolatedTargetOp region contains splitBeforeOp.
+//      omp.target {}
+//      // postTargetOp region contains ops after splitBeforeOp (only built
+//      // when splitAfter is true; otherwise SplitResult holds a null op).
+//      omp.target {}
+// }
+// Values crossing a split point are mapped, cached through allocs, or
+// recomputed, as decided by computeAllocsCacheRecomputable.
+static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
+                             RewriterBase &rewriter, mlir::ModuleOp module,
+                             bool isTargetDevice) {
+  // cast<> already asserts that the parent is an omp.target.
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  rewriter.setInsertionPoint(targetOp);
+
+  // Both new targets start from the original map operands; the analysis
+  // below may extend them with the values that must cross the split.
+  SmallVector<Value> preMapOperands(targetOp.getMapVars());
+  SmallVector<Value> postMapOperands(targetOp.getMapVars());
+
+  // Analysis results: which values need an alloc (cache) and which ops can
+  // simply be recomputed on the other side of the split.
+  SmallVector<Value> requiredVals, allocs;
+  SetVector<Operation *> toCache, toRecompute, nonRecomputable;
+  struct HostEvalVars hostEvalVars;
+  computeAllocsCacheRecomputable(
+      targetOp, splitBeforeOp, rewriter, preMapOperands, postMapOperands,
+      allocs, requiredVals, nonRecomputable, toCache, toRecompute);
+
+  // (1) Ops before the split: build the pre target and move it to the host.
+  rewriter.setInsertionPoint(targetOp);
+  auto preTargetOp =
+      genPreTargetOp(targetOp, preMapOperands, allocs, splitBeforeOp, rewriter,
+                     hostEvalVars, isTargetDevice);
+  moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+
+  // (2) The split op itself, in its own target.
+  rewriter.setInsertionPoint(targetOp);
+  omp::TargetOp isolatedTargetOp =
+      genIsolatedTargetOp(targetOp, postMapOperands, splitBeforeOp, rewriter,
+                          allocs, toRecompute, hostEvalVars, isTargetDevice);
+
+  // (3) Ops after the split, only when requested.
+  omp::TargetOp postTargetOp;
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = genPostTargetOp(targetOp, splitBeforeOp, postMapOperands,
+                                   rewriter, allocs, toRecompute);
+  }
+
+  // NOTE(review): preTargetOp is returned after moveToHost(); fissionTarget
+  // only reads postTargetOp. Confirm moveToHost keeps preTargetOp alive
+  // before dereferencing SplitResult::preTargetOp elsewhere.
+  // The original target has been fully replaced; erase it.
   rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
 }
 
-void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
-                   mlir::ModuleOp module) {
+// Recursively fission targetOp; with nothing left to isolate, move it to host.
+static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
+                          mlir::ModuleOp module, bool isTargetDevice) {
   auto tuple = getNestedOpToIsolate(targetOp);
   if (!tuple) {
     LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
-    moveToHost(targetOp, rewriter, module);
+    struct HostEvalVars hostEvalVars;
+    moveToHost(targetOp, rewriter, module, hostEvalVars);
     return;
   }
-
   Operation *toIsolate = std::get<0>(*tuple);
   bool splitBefore = !std::get<1>(*tuple);
   bool splitAfter = !std::get<2>(*tuple);
 
   if (splitBefore && splitAfter) {
-    auto res = isolateOp(toIsolate, splitAfter, rewriter);
-    moveToHost(res.preTargetOp, rewriter, module);
-    fissionTarget(res.postTargetOp, rewriter, module);
-    return;
-  }
-  if (splitBefore) {
-    auto res = isolateOp(toIsolate, splitAfter, rewriter);
-    moveToHost(res.preTargetOp, rewriter, module);
-    return;
-  }
-  if (splitAfter) {
-    auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter);
-    fissionTarget(res.postTargetOp, rewriter, module);
+    auto res =
+        isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+    fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice);
     return;
+  } else {
+    // Only a split both before and after toIsolate is supported so far.
+    llvm::report_fatal_error("Unhandled case in fissionTarget");
   }
 }
 
+// Pass to lower omp.workdistribute ops.
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
@@ -1128,25 +1438,26 @@ class LowerWorkdistributePass
     MLIRContext &context = getContext();
     auto moduleOp = getOperation();
     bool changed = false;
+    SetVector<omp::TargetOp> targetOpsToProcess;
     moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
       changed |= FissionWorkdistribute(workdistribute);
     });
     moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      changed |= WorkdistributeDoLower(workdistribute);
+      changed |= WorkdistributeDoLower(workdistribute, targetOpsToProcess);
     });
     moduleOp->walk([&](mlir::omp::TeamsOp teams) {
       changed |= TeamsWorkdistributeToSingleOp(teams);
     });
 
     if (changed) {
-      SmallVector<omp::TargetOp> targetOps;
-      moduleOp->walk(
-          [&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
+      bool isTargetDevice =
+          llvm::cast<mlir::omp::OffloadModuleInterface>(*moduleOp)
+              .getIsTargetDevice();
       IRRewriter rewriter(&context);
-      for (auto targetOp : targetOps) {
+      for (auto targetOp : targetOpsToProcess) {
         auto res = splitTargetData(targetOp, rewriter);
         if (res)
-          fissionTarget(res->targetOp, rewriter, moduleOp);
+          fissionTarget(*res, rewriter, moduleOp, isTargetDevice);
       }
     }
   }
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
new file mode 100644
index 0000000000000..b4c9598a78f0e
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
@@ -0,0 +1,117 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+// Test lowering of workdistribute after fission on host device.
+
+// CHECK-LABEL:   func.func @x(
+// CHECK:           %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK:           fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK:           %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK:           fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK:           %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK:           fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK:           %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:           %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:           %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:           %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:           %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:           %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK:             %[[VAL_11:.*]] = fir.alloca index
+// CHECK:             %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_14:.*]] = fir.alloca index
+// CHECK:             %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_17:.*]] = fir.alloca index
+// CHECK:             %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK:             %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK:             %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK:             %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK:             %[[VAL_27:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"}
+// CHECK:             %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap<index>
+// CHECK:             fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             omp.target host_eval(%[[VAL_24]] -> %[[VAL_31:.*]], %[[VAL_25]] -> %[[VAL_32:.*]], %[[VAL_26]] -> %[[VAL_33:.*]] : index, index, index) map_entries(%[[VAL_7]] -> %[[VAL_34:.*]], %[[VAL_8]] -> %[[VAL_35:.*]], %[[VAL_9]] -> %[[VAL_36:.*]], %[[VAL_10]] -> %[[VAL_37:.*]], %[[VAL_13]] -> %[[VAL_38:.*]], %[[VAL_16]] -> %[[VAL_39:.*]], %[[VAL_19]] -> %[[VAL_40:.*]], %[[VAL_22]] -> %[[VAL_41:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK:               %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index
+// CHECK:               omp.teams {
+// CHECK:                 omp.parallel {
+// CHECK:                   omp.distribute {
+// CHECK:                     omp.wsloop {
+// CHECK:                       omp.loop_nest (%[[VAL_47:.*]]) : index = (%[[VAL_31]]) to (%[[VAL_32]]) inclusive step (%[[VAL_33]]) {
+// CHECK:                         fir.store %[[VAL_46]] to %[[VAL_45]] : !fir.heap<index>
+// CHECK:                         omp.yield
+// CHECK:                       }
+// CHECK:                     } {omp.composite}
+// CHECK:                   } {omp.composite}
+// CHECK:                   omp.terminator
+// CHECK:                 } {omp.composite}
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_48:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:             %[[VAL_49:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK:             %[[VAL_50:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK:             %[[VAL_51:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK:             %[[VAL_52:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             %[[VAL_53:.*]] = arith.addi %[[VAL_50]], %[[VAL_50]] : index
+// CHECK:             fir.store %[[VAL_49]] to %[[VAL_52]] : !fir.heap<index>
+// CHECK:             %[[VAL_54:.*]] = fir.convert %[[VAL_52]] : (!fir.heap<index>) -> i64
+// CHECK:             omp.target_freemem %[[VAL_48]], %[[VAL_54]] : i32, i64
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false} {
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+  // Map the loop bounds (to) and the destination address (tofrom).
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+  // One target region: teams { workdistribute { allocmem, do_loop, store, freemem } }.
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load %ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+    // %20 is defined outside omp.teams and used inside the do_loop body.
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
index 25ef34f81b492..6e82efb308328 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -1,4 +1,5 @@
 // RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+// Test lowering of workdistribute after fission on target device.
 
 // CHECK-LABEL:   func.func @x(
 // CHECK:           %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
@@ -76,6 +77,7 @@
 // CHECK:           return
 // CHECK:         }
 
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
 func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
   %lb_ref = fir.alloca index {bindc_name = "lb"}
   fir.store %lb to %lb_ref : !fir.ref<index>
@@ -112,3 +114,4 @@ func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
   }
   return
 }
+}

>From 421fa7d0828edd4214846f2a009b811090eed14f Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 15 Sep 2025 18:57:33 +0530
Subject: [PATCH 10/13] Fix CI errors

---
 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index aa0e1f3416114..e839db17150f9 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -569,7 +569,7 @@ static bool usedOutsideSplit(Value v, Operation *split) {
       return true;
   }
   return false;
-};
+}
 
 // isRecomputableAfterFission checks if an operation can be recomputed
 static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {

>From da0b4e2d32b7c711ff98e540aefa4409e1dcfc01 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 16 Sep 2025 23:10:32 +0530
Subject: [PATCH 11/13] Handle lowering of scalar assignments to arrays

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 338 ++++++++++++++++--
 1 file changed, 302 insertions(+), 36 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index e839db17150f9..4a91b074b7cd1 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -76,6 +76,17 @@ static bool isRuntimeCall(Operation *op) {
 // This is the single source of truth about whether we should parallelize an
 // operation nested in an omp.workdistribute region.
 static bool shouldParallelize(Operation *op) {
+  // True if the op is a runtime call to Assign
+  if (isRuntimeCall(op)) {
+    fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+    if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
+        "_FortranAAssign") {
+      return true;
+    }
+  }
+  // We cannot parallelize ops with side effects.
+  // Parallelizable operations should not produce
+  // values that other operations depend on
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
     return false;
@@ -86,11 +97,6 @@ static bool shouldParallelize(Operation *op) {
       return false;
     return *unordered;
   }
-  // True if the op is a runtime call to Assign
-  if (isRuntimeCall(op) &&
-      (op->getName().getStringRef() == "_FortranAAssign")) {
-    return true;
-  }
   // We cannot parallise anything else.
   return false;
 }
@@ -268,6 +274,7 @@ genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
 }
 
 // Generate omp.wsloop operation with an empty region and
+// clone the body of fir.do_loop operation inside the loop nest region.
 static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
                         const mlir::omp::LoopNestOperands &clauseOps,
                         bool composite) {
@@ -349,6 +356,221 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
   return false;
 }
 
+// Returns true when `type` is !fir.ref<!fir.box<!fir.array<...>>>: a
+// reference to a box descriptor whose element type is a fir.array
+// (fir::SequenceType). Every other type yields false.
+static bool isEnclosedTypeRefToBoxArray(Type type) {
+  // Peel the outer fir.ref; non-references are rejected.
+  auto refTy = dyn_cast<fir::ReferenceType>(type);
+  if (!refTy)
+    return false;
+
+  // The referenced value must itself be a fir.box descriptor.
+  auto boxTy = dyn_cast<fir::BoxType>(refTy.getEleTy());
+  if (!boxTy)
+    return false;
+
+  // Finally, the boxed entity has to be an array (sequence) type.
+  Type boxed = boxTy.getEleTy();
+  return isa<fir::SequenceType>(boxed);
+}
+
+// Returns true when `type` is a !fir.box whose element type is not a
+// fir.array (fir::SequenceType). Non-box types always yield false.
+static bool isEnclosedTypeBoxScalar(Type type) {
+  auto boxTy = dyn_cast<fir::BoxType>(type);
+  // Only box descriptors are considered.
+  if (!boxTy)
+    return false;
+  // Any non-sequence element type (e.g. a derived-type record) counts as
+  // "scalar" here.
+  return !isa<fir::SequenceType>(boxTy.getEleTy());
+}
+
+// Decide whether a _FortranAAssign call copies a scalar into an array.
+// Operand 0 is the destination and operand 1 the source; both are expected
+// to be produced by fir.convert ops whose inputs carry the original
+// descriptor types. Emits an "Unimplemented" error and returns false when
+// either operand is not defined by a fir.convert.
+static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) {
+  // Need at least the destination and the source arguments.
+  if (callOp.getNumOperands() < 2)
+    return false;
+  auto dstCvt = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>();
+  auto srcCvt = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>();
+  if (!dstCvt || !srcCvt) {
+    emitError(callOp->getLoc(),
+              "Unimplemented: FortranAssign to OpenMP lowering\n");
+    return false;
+  }
+  // Inspect the pre-conversion types: the destination must be a reference
+  // to a boxed array and the source a boxed non-array value.
+  Type dstType = dstCvt.getValue().getType();
+  Type srcType = srcCvt.getValue().getType();
+  return isEnclosedTypeRefToBoxArray(dstType) &&
+         isEnclosedTypeBoxScalar(srcType);
+}
+
+// Decompose a zero-based flat index into one-based per-dimension indices
+// for the array described by `arrayBox` (a !fir.box<!fir.array<...>>).
+// The last dimension varies fastest (row-major decomposition): for
+// d = rank-1 .. 0, idx[d] = rem % extent[d] + 1 and rem /= extent[d].
+// Example for extents (2,4):
+//   flatIdx:  0     1     2     3     4     5     6     7
+//   Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4)
+// NOTE(review): Fortran arrays are column-major, so letting the first
+// dimension vary fastest would make consecutive flat indices touch
+// contiguous memory; either order is a valid bijection for a broadcast.
+static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder,
+                                                Location loc, Value flatIdx,
+                                                Value arrayBox) {
+  auto seqType = cast<fir::SequenceType>(
+      cast<fir::BoxType>(arrayBox.getType()).getEleTy());
+  int rank = seqType.getDimension();
+
+  // Query each extent via fir.box_dims (result #1 is the extent).
+  SmallVector<Value> extents;
+  extents.reserve(rank);
+  for (int dim = 0; dim < rank; ++dim) {
+    Value dimIdx = arith::ConstantIndexOp::create(builder, loc, dim);
+    extents.push_back(
+        fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx).getResult(1));
+  }
+
+  // `remaining` holds the quotient still to be split across dimensions.
+  SmallVector<Value> indices(rank);
+  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+  Value remaining = flatIdx;
+  // Peel dimensions from the innermost (last) one outwards.
+  for (int dim = rank - 1; dim >= 0; --dim) {
+    Value zeroBased =
+        arith::RemSIOp::create(builder, loc, remaining, extents[dim]);
+    // Convert to a one-based index.
+    indices[dim] = arith::AddIOp::create(builder, loc, zeroBased, one);
+    // The outermost dimension needs no further division.
+    if (dim == 0)
+      continue;
+    remaining = arith::DivSIOp::create(builder, loc, remaining, extents[dim]);
+  }
+
+  return indices;
+}
+
+// Emit IR computing the element count of the array described by `arrayBox`
+// (a !fir.box<!fir.array<...>>): extent(1) * extent(2) * ... * extent(rank).
+// Extents are read with fir.box_dims; a null Value is returned for rank 0.
+static Value CalculateTotalElements(OpBuilder &builder, Location loc,
+                                    Value arrayBox) {
+  auto boxTy = cast<fir::BoxType>(arrayBox.getType());
+  unsigned numDims = cast<fir::SequenceType>(boxTy.getEleTy()).getDimension();
+
+  Value product; // Running product; stays null until the first extent.
+  for (unsigned dim = 0; dim < numDims; ++dim) {
+    Value dimIdx = arith::ConstantIndexOp::create(builder, loc, dim);
+    Value extent =
+        fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx).getResult(1);
+    // The first extent seeds the product; later ones are multiplied in.
+    if (dim == 0)
+      product = extent;
+    else
+      product = arith::MulIOp::create(builder, loc, product, extent);
+  }
+  return product;
+}
+
+// Replace a scalar-to-array _FortranAAssign call with an unordered fir.do_loop
+// storing the scalar into every destination element. The descriptor and
+// scalar are loaded before `teamsOp`; the loop is emitted in place of `callOp`
+// (a direct child of `workdistribute`), which the caller erases afterwards.
+static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
+                                       omp::TeamsOp teamsOp,
+                                       omp::WorkdistributeOp workdistribute,
+                                       fir::CallOp callOp) {
+  assert(callOp->getParentOp() == workdistribute.getOperation() &&
+         "Assign call must be nested directly in the workdistribute op");
+  auto destConvert = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>();
+  auto srcConvert = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>();
+  Value destBox = destConvert.getValue();
+  Value srcBox = srcConvert.getValue();
+
+  builder.setInsertionPoint(teamsOp);
+  // Load destination array box and source scalar
+  auto arrayBox = builder.create<fir::LoadOp>(loc, destBox);
+  auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox);
+  auto scalar = builder.create<fir::LoadOp>(loc, scalarValue);
+
+  // fir.do_loop bounds are inclusive: iterate flat indices 0..total-1. Using
+  // `total` as the bound would add an iteration (rem by 0 for empty arrays).
+  auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value totalElems = CalculateTotalElements(builder, loc, arrayBox);
+  Value lastIdx = builder.create<arith::SubIOp>(loc, totalElems, c1);
+
+  // Emit the loop where the call was so surrounding ops keep their order.
+  builder.setInsertionPoint(callOp);
+  auto doLoop = fir::DoLoopOp::create(builder, loc, c0, lastIdx, c1, true);
+  builder.setInsertionPointToStart(doLoop.getBody());
+  SmallVector<Value> indices = convertFlatToMultiDim(
+      builder, loc, doLoop.getInductionVar(), arrayBox);
+  auto elemPtr = fir::ArrayCoorOp::create(
+      builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox,
+      nullptr, nullptr, ValueRange{indices}, ValueRange{});
+  builder.create<fir::StoreOp>(loc, scalar, elemPtr);
+}
+
+// WorkdistributeRuntimeCallLower scans the ops directly nested in a
+// teams { workdistribute { ... } } construct and rewrites the first
+// _FortranAAssign runtime call that assigns a scalar to an array into an
+// unordered fir.do_loop (via replaceWithUnorderedDoLoop), recording the
+// enclosing omp.target in `targetOpsToProcess`. Returns true only when a
+// call was rewritten; malformed nests emit an error and return false.
+// Other runtime calls are not handled currently.
+static bool
+WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
+                               SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder builder(workdistribute);
+  Location loc = workdistribute->getLoc();
+
+  // Structural preconditions: a teams parent and single-block regions.
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return false;
+  }
+
+  Block &body = workdistribute.getRegion().front();
+  Operation *terminator = body.getTerminator();
+  // Calls are only rewritten when teams sits directly inside an omp.target.
+  auto targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+
+  for (Operation &op :
+       llvm::make_range(body.begin(), terminator->getIterator())) {
+    if (!isRuntimeCall(&op))
+      continue;
+    builder.setInsertionPoint(&op);
+    auto call = cast<fir::CallOp>(op);
+    auto callee = (*call.getCallee()).getRootReference().getValue();
+    if (callee != "_FortranAAssign")
+      continue;
+    // Evaluate the shape check first: it may emit a diagnostic.
+    if (!isFortranAssignSrcScalarAndDestArray(call) || !targetOp)
+      continue;
+    targetOpsToProcess.insert(targetOp);
+    replaceWithUnorderedDoLoop(builder, loc, teams, workdistribute, call);
+    op.erase();
+    return true;
+  }
+  return false;
+}
+
 // TeamsWorkdistributeToSingleOp method hoists all the ops inside
 // teams {workdistribute{}} before teams op.
 //
@@ -367,13 +589,24 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
 // B()
 //
 // If only the terminator remains in teams after hoisting, we erase teams op.
-static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
+static bool
+TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
+                              SetVector<omp::TargetOp> &targetOpsToProcess) {
   auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
   if (!workdistributeOp)
     return false;
   // Get the block containing teamsOp (the parent block).
   Block *parentBlock = teamsOp->getBlock();
   Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+  // Record the target ops to process later
+  for (auto &op : workdistributeBlock.getOperations()) {
+    if (shouldParallelize(&op)) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp) {
+        targetOpsToProcess.insert(targetOp);
+      }
+    }
+  }
   auto insertPoint = Block::iterator(teamsOp);
   // Get the range of operations to move (excluding the terminator).
   auto workdistributeBegin = workdistributeBlock.begin();
@@ -762,14 +995,6 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
   return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
 }
 
-// Generate LLVM constant operations for i64 type.
-static mlir::LLVM::ConstantOp
-genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
-  mlir::Type i64Ty = rewriter.getI64Type();
-  mlir::IntegerAttr attr = rewriter.getI64IntegerAttr(value);
-  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i64Ty, attr);
-}
-
 // Given a box descriptor, extract the base address of the data it describes.
 // If the box descriptor is a reference, load it first.
 // The base address is returned as an i8* pointer.
@@ -912,6 +1137,46 @@ static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
   return;
 }
 
+// Generate code to replace a Fortran array assignment call with OpenMP
+// runtime calls to perform the equivalent operation on the device.
+// This involves extracting the source and destination pointers from the
+// Fortran array descriptors, retrieving their mapped device pointers (if any),
+// and invoking `omp_target_memcpy` to copy the data on the device.
+static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder,
+                                           mlir::Location loc,
+                                           fir::CallOp callOp,
+                                           mlir::Value device,
+                                           mlir::ModuleOp module) {
+  assert(callOp.getNumResults() == 0 &&
+         "Expected _FortranAAssign to have no results");
+  assert(callOp.getNumOperands() >= 2 &&
+         "Expected _FortranAAssign to have at least two operands");
+
+  // Extract the source and destination pointers from the call operands.
+  mlir::Value dest = callOp.getOperand(0);
+  mlir::Value src = callOp.getOperand(1);
+
+  // Get the base addresses of the source and destination arrays.
+  mlir::Value srcBase = genDescriptorGetBaseAddress(builder, loc, src);
+  mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest);
+
+  // Get the total size in bytes of the data to be copied.
+  mlir::Value dataSize = genDescriptorGetDataSizeInBytes(builder, loc, src);
+
+  // Retrieve the mapped device pointers for source and destination.
+  // If no mapping exists, the original host pointer is used.
+  Value destPtr =
+      genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module);
+  Value srcPtr =
+      genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module);
+  Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
+                                                builder.getI64IntegerAttr(0));
+  // Generate the call to omp_target_memcpy to perform the data copy on the
+  // device.
+  genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, dataSize, zero, zero,
+                         device, module);
+}
+
 // Struct to hold the host eval vars corresponding to loop bounds and steps
 struct HostEvalVars {
   SmallVector<Value> lbs;
@@ -1045,26 +1310,21 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
     }
     // Replace runtime calls with omp versions.
     else if (isRuntimeCall(op)) {
-      rewriter.setInsertionPoint(op);
       fir::CallOp runtimeCall = cast<fir::CallOp>(op);
-      SmallVector<Value> operands = runtimeCall.getOperands();
-      mlir::Location loc = runtimeCall.getLoc();
-      fir::FirOpBuilder builder{rewriter, op};
-      assert(operands.size() == 4);
-      auto fromBaseAddr =
-          genDescriptorGetBaseAddress(builder, loc, operands[1]);
-      auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]);
-      auto dataSizeInBytes =
-          genDescriptorGetDataSizeInBytes(builder, loc, operands[1]);
-
-      Value toPtr =
-          genOmpGetMappedPtrIfPresent(builder, loc, toBaseAddr, device, module);
-      Value fromPtr = genOmpGetMappedPtrIfPresent(builder, loc, fromBaseAddr,
-                                                  device, module);
-      Value zero = genI64Constant(loc, rewriter, 0);
-      genOmpTargetMemcpyCall(builder, loc, toPtr, fromPtr, dataSizeInBytes,
-                             zero, zero, device, module);
-      rewriter.eraseOp(op);
+      if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
+          "_FortranAAssign") {
+        rewriter.setInsertionPoint(op);
+        fir::FirOpBuilder builder{rewriter, op};
+
+        mlir::Location loc = runtimeCall.getLoc();
+        genFortranAssignOmpReplacement(builder, loc, runtimeCall, device,
+                                       module);
+        rewriter.eraseOp(op);
+      } else {
+        llvm_unreachable("Unhandled runtime call hoisting.");
+      }
+    } else {
+      llvm_unreachable("Unhandled op hoisting.");
     }
   }
 
@@ -1424,8 +1684,11 @@ static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
         isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
     fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice);
     return;
+  }
+  if (splitBefore) {
+    isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+    return;
   } else {
-    llvm::errs() << "Unhandled case in fissionTarget\n";
     llvm::report_fatal_error("Unhandled case in fissionTarget");
   }
 }
@@ -1442,13 +1705,16 @@ class LowerWorkdistributePass
     moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
       changed |= FissionWorkdistribute(workdistribute);
     });
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |=
+          WorkdistributeRuntimeCallLower(workdistribute, targetOpsToProcess);
+    });
     moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
       changed |= WorkdistributeDoLower(workdistribute, targetOpsToProcess);
     });
     moduleOp->walk([&](mlir::omp::TeamsOp teams) {
-      changed |= TeamsWorkdistributeToSingleOp(teams);
+      changed |= TeamsWorkdistributeToSingleOp(teams, targetOpsToProcess);
     });
-
     if (changed) {
       bool isTargetDevice =
           llvm::cast<mlir::omp::OffloadModuleInterface>(*moduleOp)

>From 0a81be8b1b3ee7b310ba8ed807f7eff12b7066fc Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 18 Sep 2025 11:58:26 +0530
Subject: [PATCH 12/13] Add tests for scalar assign

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  |  27 ++---
 .../Lower/OpenMP/workdistribute-saxpy-1d.f90  |  23 ++++
 .../Lower/OpenMP/workdistribute-saxpy-2d.f90  |  26 +++++
 .../Lower/OpenMP/workdistribute-saxpy-3d.f90  |  27 +++++
 ...workdistribute-saxpy-and-scalar-assign.f90 |  33 ++++++
 .../OpenMP/workdistribute-saxpy-two-2d.f90    |  38 ++++++
 .../OpenMP/workdistribute-scalar-assign.f90   |  20 ++++
 ...-workdistribute-runtime-assign-scalar.mlir | 108 ++++++++++++++++++
 8 files changed, 289 insertions(+), 13 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 4a91b074b7cd1..88836c9323cef 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -358,17 +358,15 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
 
 // Check if the enclosed type in fir.ref is fir.box and fir.box encloses array
 static bool isEnclosedTypeRefToBoxArray(Type type) {
-  // Step 1: Check if it's a reference type
+  // Check if it's a reference type
   if (auto refType = dyn_cast<fir::ReferenceType>(type)) {
-    // Step 2: Get the referenced type (should be fir.box)
+    // Get the referenced type (should be fir.box)
     auto referencedType = refType.getEleTy();
-
-    // Step 3: Check if referenced type is a box
+    // Check if referenced type is a box
     if (auto boxType = dyn_cast<fir::BoxType>(referencedType)) {
-      // Step 4: Get the boxed type and check if it's an array
+      // Get the boxed type and check if it's an array
       auto boxedType = boxType.getEleTy();
-
-      // Step 5: Check if boxed type is a sequence (array)
+      // Check if boxed type is a sequence (array)
       return isa<fir::SequenceType>(boxedType);
     }
   }
@@ -377,11 +375,11 @@ static bool isEnclosedTypeRefToBoxArray(Type type) {
 
 // Check if the enclosed type in fir.box is scalar (not array)
 static bool isEnclosedTypeBoxScalar(Type type) {
-  // Step 1: Check if it's a box type
+  // Check if it's a box type
   if (auto boxType = dyn_cast<fir::BoxType>(type)) {
-    // Step 2: Get the boxed type
+    // Get the boxed type
     auto boxedType = boxType.getEleTy();
-    // Step 3: Check if boxed type is NOT a sequence (array)
+    // Check if boxed type is NOT a sequence (array)
     return !isa<fir::SequenceType>(boxedType);
   }
   return false;
@@ -743,7 +741,7 @@ static Type getPtrTypeForOmp(Type ty) {
   if (isPtr(ty))
     return LLVM::LLVMPointerType::get(ty.getContext());
   else
-    return fir::LLVMPointerType::get(ty);
+    return fir::ReferenceType::get(ty);
 }
 
 // allocateTempOmpVar allocates a temporary variable for OpenMP mapping
@@ -806,6 +804,8 @@ static bool usedOutsideSplit(Value v, Operation *split) {
 
 // isRecomputableAfterFission checks if an operation can be recomputed
 static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  // If the op has side effects, it cannot be recomputed.
+  // We consider fir.declare as having no side effects.
   if (isa<fir::DeclareOp>(op))
     return true;
 
@@ -1161,7 +1161,7 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder,
   mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest);
 
   // Get the total size in bytes of the data to be copied.
-  mlir::Value dataSize = genDescriptorGetDataSizeInBytes(builder, loc, src);
+  mlir::Value srcDataSize = genDescriptorGetDataSizeInBytes(builder, loc, src);
 
   // Retrieve the mapped device pointers for source and destination.
   // If no mapping exists, the original host pointer is used.
@@ -1171,9 +1171,10 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder,
       genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module);
   Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                 builder.getI64IntegerAttr(0));
+
   // Generate the call to omp_target_memcpy to perform the data copy on the
   // device.
-  genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, dataSize, zero, zero,
+  genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, srcDataSize, zero, zero,
                          device, module);
 }
 
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90
new file mode 100644
index 0000000000000..95c3f37f4720e
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90
@@ -0,0 +1,23 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+
+  ! CHECK: omp.target_data
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+
+  !$omp target teams workdistribute
+  y = a * x + y
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90
new file mode 100644
index 0000000000000..70e82780edb1a
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90
@@ -0,0 +1,26 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute(a, x, y, rows, cols)
+  use iso_fortran_env
+  implicit none
+
+  integer, intent(in) :: rows, cols
+  real(kind=real32) :: a
+  real(kind=real32), dimension(rows, cols) :: x, y
+
+  ! CHECK: omp.target_data
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+
+  !$omp target teams workdistribute
+  y = a * x + y
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90
new file mode 100644
index 0000000000000..d6fa300eaff99
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90
@@ -0,0 +1,27 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute(a, x, y, rows, cols, depth)
+  use iso_fortran_env
+  implicit none
+
+  integer, intent(in) :: rows, cols, depth
+  real(kind=real32) :: a
+  real(kind=real32), dimension(rows, cols, depth) :: x, y
+
+  ! CHECK: omp.target_data
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+  ! CHECK: fir.do_loop
+
+  !$omp target teams workdistribute
+  y = a * x + y
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90
new file mode 100644
index 0000000000000..6b6dc3e3a184f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90
@@ -0,0 +1,33 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  !$omp target teams workdistribute
+
+  ! CHECK: omp.target_data
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+
+  y = a * x + y
+
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+
+  y = 2.0_real32
+
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90
new file mode 100644
index 0000000000000..2229ccf34e920
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90
@@ -0,0 +1,38 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute(a, x, y, rows, cols)
+  use iso_fortran_env
+  implicit none
+
+  integer, intent(in) :: rows, cols
+  real(kind=real32) :: a
+  real(kind=real32), dimension(rows, cols) :: x, y
+
+  !$omp target teams workdistribute
+
+  ! CHECK: omp.target_data
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+
+  y = a * x + y
+  
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+  
+  y = a * y + x
+
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+
diff --git a/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90
new file mode 100644
index 0000000000000..af94559dfa8cf
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90
@@ -0,0 +1,20 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute_scalar_assign
+subroutine target_teams_workdistribute_scalar_assign()
+  integer :: aa(10)
+
+  ! CHECK: omp.target_data
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  
+  !$omp target teams workdistribute
+  aa = 20
+  !$omp end target teams workdistribute
+
+end subroutine target_teams_workdistribute_scalar_assign
+
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir
new file mode 100644
index 0000000000000..03d5d71df0a82
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir
@@ -0,0 +1,108 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// Test lowering of workdistribute for a scalar assignment within a target teams workdistribute region.
+// The test checks that the scalar assignment is correctly lowered to wsloop and loop_nest operations.
+
+// Example Fortran code:
+// !$omp target teams workdistribute
+// y = 3.0_real32
+// !$omp end target teams workdistribute
+
+
+// CHECK-LABEL:   func.func @x(
+// CHECK:             omp.target {{.*}} {
+// CHECK:               omp.teams {
+// CHECK:                 omp.parallel {
+// CHECK:                   omp.distribute {
+// CHECK:                     omp.wsloop {
+// CHECK:                       omp.loop_nest (%[[VAL_73:.*]]) : index = (%[[VAL_66:.*]]) to (%[[VAL_72:.*]]) inclusive step (%[[VAL_67:.*]]) {
+// CHECK:                         %[[VAL_74:.*]] = arith.constant 0 : index
+// CHECK:                         %[[VAL_75:.*]]:3 = fir.box_dims %[[VAL_64:.*]], %[[VAL_74]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// CHECK:                         %[[VAL_76:.*]] = arith.constant 1 : index
+// CHECK:                         %[[VAL_77:.*]]:3 = fir.box_dims %[[VAL_64]], %[[VAL_76]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// CHECK:                         %[[VAL_78:.*]] = arith.constant 1 : index
+// CHECK:                         %[[VAL_79:.*]] = arith.remsi %[[VAL_73]], %[[VAL_77]]#1 : index
+// CHECK:                         %[[VAL_80:.*]] = arith.addi %[[VAL_79]], %[[VAL_78]] : index
+// CHECK:                         %[[VAL_81:.*]] = arith.divsi %[[VAL_73]], %[[VAL_77]]#1 : index
+// CHECK:                         %[[VAL_82:.*]] = arith.remsi %[[VAL_81]], %[[VAL_75]]#1 : index
+// CHECK:                         %[[VAL_83:.*]] = arith.addi %[[VAL_82]], %[[VAL_78]] : index
+// CHECK:                         %[[VAL_84:.*]] = fir.array_coor %[[VAL_64]] %[[VAL_83]], %[[VAL_80]] : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
+// CHECK:                         fir.store %[[VAL_65:.*]] to %[[VAL_84]] : !fir.ref<f32>
+// CHECK:                         omp.yield
+// CHECK:                       }
+// CHECK:                     } {omp.composite}
+// CHECK:                   } {omp.composite}
+// CHECK:                   omp.terminator
+// CHECK:                 } {omp.composite}
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+// CHECK:         func.func private @_FortranAAssign(!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) attributes {fir.runtime}
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+func.func @x(%arr : !fir.ref<!fir.array<?x?xf32>>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c78 = arith.constant 78 : index
+    %cst = arith.constant 3.000000e+00 : f32
+    %0 = fir.alloca i32
+    %1 = fir.alloca i32
+    %c10 = arith.constant 10 : index
+    %c20 = arith.constant 20 : index
+    %194 = arith.subi %c10, %c1 : index
+    %195 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%194 : index) extent(%c10 : index) stride(%c1 : index) start_idx(%c1 : index)
+    %196 = arith.subi %c20, %c1 : index
+    %197 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%196 : index) extent(%c20 : index) stride(%c1 : index) start_idx(%c1 : index)
+    %198 = omp.map.info var_ptr(%arr : !fir.ref<!fir.array<?x?xf32>>, f32) map_clauses(implicit, tofrom) capture(ByRef) bounds(%195, %197) -> !fir.ref<!fir.array<?x?xf32>> {name = "y"}
+    %199 = omp.map.info var_ptr(%1 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = ""}
+    %200 = omp.map.info var_ptr(%0 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = ""}
+    omp.target map_entries(%198 -> %arg5, %199 -> %arg6, %200 -> %arg7 : !fir.ref<!fir.array<?x?xf32>>, !fir.ref<i32>, !fir.ref<i32>) {
+      %c0_0 = arith.constant 0 : index
+      %201 = fir.load %arg7 : !fir.ref<i32>
+      %202 = fir.load %arg6 : !fir.ref<i32>
+      %203 = fir.convert %202 : (i32) -> i64
+      %204 = fir.convert %201 : (i32) -> i64
+      %205 = fir.convert %204 : (i64) -> index
+      %206 = arith.cmpi sgt, %205, %c0_0 : index
+      %207 = fir.convert %203 : (i64) -> index
+      %208 = arith.cmpi sgt, %207, %c0_0 : index
+      %209 = arith.select %208, %207, %c0_0 : index
+      %210 = arith.select %206, %205, %c0_0 : index
+      %211 = fir.shape %210, %209 : (index, index) -> !fir.shape<2>
+      %212 = fir.declare %arg5(%211) {uniq_name = "_QFFaxpy_array_workdistributeEy"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<?x?xf32>>
+      %213 = fir.embox %212(%211) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
+      omp.teams {
+        %214 = fir.alloca !fir.box<!fir.array<?x?xf32>> {pinned}
+        omp.workdistribute {
+          %215 = fir.alloca f32
+          %216 = fir.embox %215 : (!fir.ref<f32>) -> !fir.box<f32>
+          %217 = fir.shape %210, %209 : (index, index) -> !fir.shape<2>
+          %218 = fir.embox %212(%217) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
+          fir.store %218 to %214 : !fir.ref<!fir.box<!fir.array<?x?xf32>>>
+          %219 = fir.address_of(@_QQclXf9c642d28e5bba1f07fa9a090b72f4fc) : !fir.ref<!fir.char<1,78>>
+          %c39_i32 = arith.constant 39 : i32
+          %220 = fir.convert %214 : (!fir.ref<!fir.box<!fir.array<?x?xf32>>>) -> !fir.ref<!fir.box<none>>
+          %221 = fir.convert %216 : (!fir.box<f32>) -> !fir.box<none>
+          %222 = fir.convert %219 : (!fir.ref<!fir.char<1,78>>) -> !fir.ref<i8>
+          fir.call @_FortranAAssign(%220, %221, %222, %c39_i32) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) -> ()
+          omp.terminator
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    return
+}
+
+func.func private @_FortranAAssign(!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) attributes {fir.runtime}
+
+fir.global linkonce @_QQclXf9c642d28e5bba1f07fa9a090b72f4fc constant : !fir.char<1,78> {
+  %0 = fir.string_lit "File: /work/github/skc7/llvm-project/build_fomp_reldebinfo/saxpy_tests/\00"(78) : !fir.char<1,78>
+  fir.has_value %0 : !fir.char<1,78>
+}
+}

>From 95ebe8cc1420a14e402c5a3f175593e407deff67 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Sep 2025 18:50:33 +0530
Subject: [PATCH 13/13] Fix scalar assign bug and fix CI tests

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 51 +++++++++++++++----
 .../lower-workdistribute-fission-host.mlir    |  8 +--
 .../lower-workdistribute-fission-target.mlir  |  9 ++--
 3 files changed, 49 insertions(+), 19 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 88836c9323cef..ff62457e0a7da 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -487,11 +487,30 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
   Value destBox = destConvert.getValue();
   Value srcBox = srcConvert.getValue();
 
+  // Get the defining alloca op of destBox (srcBox is used as-is below).
+  auto destAlloca = destBox.getDefiningOp<fir::AllocaOp>();
+
+  if (!destAlloca) {
+    emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n");
+    return;
+  }
+
+  // Use the value stored by the first fir.store user found on the alloca.
+  for (auto user : destAlloca->getUsers()) {
+    if (auto storeOp = dyn_cast<fir::StoreOp>(user)) {
+      destBox = storeOp.getValue();
+      break;
+    }
+  }
+
   builder.setInsertionPoint(teamsOp);
-  // Load destination array box and source scalar
-  auto arrayBox = builder.create<fir::LoadOp>(loc, destBox);
+  // Load destination array box (if it's a reference)
+  Value arrayBox = destBox;
+  if (isa<fir::ReferenceType>(destBox.getType()))
+    arrayBox = builder.create<fir::LoadOp>(loc, destBox);
+
   auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox);
-  auto scalar = builder.create<fir::LoadOp>(loc, scalarValue);
+  Value scalar = builder.create<fir::LoadOp>(loc, scalarValue);
 
   // Calculate total number of elements (flattened)
   auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -543,9 +562,8 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
   bool changed = false;
   omp::TargetOp targetOp;
   // Get the target op parent of teams
-  if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
-    targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
-  }
+  targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+  SmallVector<Operation *> opsToErase;
   for (auto &op : workdistribute.getOps()) {
     if (&op == terminator) {
       break;
@@ -560,12 +578,15 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
           targetOpsToProcess.insert(targetOp);
           replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute,
                                      runtimeCall);
-          op.erase();
-          return true;
+          opsToErase.push_back(&op);
+          changed = true;
         }
       }
     }
   }
+  for (auto *op : opsToErase) {
+    op->erase();
+  }
   return changed;
 }
 
@@ -911,7 +932,7 @@ static void reloadCacheAndRecompute(
 
   unsigned originalMapVarsSize = targetOp.getMapVars().size();
   unsigned hostEvalVarsSize = hostEvalVars.size();
-  // Create Stores for allocs.
+  // Create load operations for each allocated variable.
   for (unsigned i = 0; i < allocs.size(); ++i) {
     Value original = allocs[i];
     // Get the new block argument for this specific allocated value.
@@ -1196,6 +1217,12 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
   Block *targetBlock = &targetOp.getRegion().front();
   assert(targetBlock == &targetOp.getRegion().back());
   IRMapping mapping;
+
+  auto targetDataOp = cast<omp::TargetDataOp>(targetOp->getParentOp());
+  if (!targetDataOp) {
+    llvm_unreachable("Expected target op to be inside target_data op");
+    return;
+  }
   // create mapping for host_eval_vars
   unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
   for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
@@ -1361,12 +1388,14 @@ static void computeAllocsCacheRecomputable(
        it++) {
     // Check if any of the results are used outside the split point.
     for (auto res : it->getResults()) {
-      if (usedOutsideSplit(res, splitBeforeOp))
+      if (usedOutsideSplit(res, splitBeforeOp)) {
         requiredVals.push_back(res);
+      }
     }
     // If the op is not recomputable, add it to the nonRecomputable set.
-    if (!isRecomputableAfterFission(&*it, splitBeforeOp))
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp)) {
       nonRecomputable.insert(&*it);
+    }
   }
   // For each required value, collect its dependencies.
   for (auto requiredVal : requiredVals)
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
index b4c9598a78f0e..04e60ca8bbf37 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
@@ -42,10 +42,10 @@
 // CHECK:             fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
 // CHECK:             fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
 // CHECK:             omp.target host_eval(%[[VAL_24]] -> %[[VAL_31:.*]], %[[VAL_25]] -> %[[VAL_32:.*]], %[[VAL_26]] -> %[[VAL_33:.*]] : index, index, index) map_entries(%[[VAL_7]] -> %[[VAL_34:.*]], %[[VAL_8]] -> %[[VAL_35:.*]], %[[VAL_9]] -> %[[VAL_36:.*]], %[[VAL_10]] -> %[[VAL_37:.*]], %[[VAL_13]] -> %[[VAL_38:.*]], %[[VAL_16]] -> %[[VAL_39:.*]], %[[VAL_19]] -> %[[VAL_40:.*]], %[[VAL_22]] -> %[[VAL_41:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
-// CHECK:               %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr<index>
-// CHECK:               %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.llvm_ptr<index>
-// CHECK:               %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr<index>
-// CHECK:               %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<index>
+// CHECK:               %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.ref<index>
+// CHECK:               %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.ref<index>
+// CHECK:               %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.ref<!fir.heap<index>>
 // CHECK:               %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index
 // CHECK:               omp.teams {
 // CHECK:                 omp.parallel {
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
index 6e82efb308328..062eb701b52ef 100644
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -42,10 +42,10 @@
 // CHECK:             fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
 // CHECK:             fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
 // CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_31:.*]], %[[VAL_8]] -> %[[VAL_32:.*]], %[[VAL_9]] -> %[[VAL_33:.*]], %[[VAL_10]] -> %[[VAL_34:.*]], %[[VAL_13]] -> %[[VAL_35:.*]], %[[VAL_16]] -> %[[VAL_36:.*]], %[[VAL_19]] -> %[[VAL_37:.*]], %[[VAL_22]] -> %[[VAL_38:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
-// CHECK:               %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr<index>
-// CHECK:               %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<index>
-// CHECK:               %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<index>
-// CHECK:               %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.ref<index>
+// CHECK:               %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.ref<index>
+// CHECK:               %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.ref<index>
+// CHECK:               %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<!fir.heap<index>>
 // CHECK:               %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : index
 // CHECK:               omp.teams {
 // CHECK:                 omp.parallel {
@@ -77,6 +77,7 @@
 // CHECK:           return
 // CHECK:         }
 
+
 module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
 func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
   %lb_ref = fir.alloca index {bindc_name = "lb"}



More information about the Mlir-commits mailing list