[Mlir-commits] [flang] [mlir] [flang] Implement workdistribute construct lowering (PR #140523)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 25 09:51:34 PDT 2025


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/140523

>From d253390cf40feb805f326dd811f56b54183c8885 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 4 Dec 2023 12:57:36 -0800
Subject: [PATCH 1/2] [Flang] Introduce lower-workdistribute pass for
 workdistribute lowering.

Co-authors: ivanradanov, skc7
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |   4 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 901 ++++++++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   4 +-
 flang/test/Fir/basic-program.fir              |   1 +
 .../OpenMP/lower-workdistribute-doloop.mlir   |  33 +
 .../lower-workdistribute-fission-target.mlir  | 112 +++
 .../OpenMP/lower-workdistribute-fission.mlir  |  71 ++
 .../OpenMP/lower-workdistribute-target.mlir   |  32 +
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |   3 +
 10 files changed, 1161 insertions(+), 1 deletion(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 704faf0ccd856..743b6d381ed42 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e31543328a9f9..cd746834741f9 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
 
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..0885efc716db4
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,901 @@
+//===- LowerWorkdistribute.cpp - Lower omp.workdistribute -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if \p op is a fir.call whose callee symbol carries the FIR
+/// runtime attribute.
+static bool isRuntimeCall(Operation *op) {
+  auto callOp = dyn_cast<fir::CallOp>(op);
+  if (!callOp || !callOp.getCallee())
+    return false;
+  // A callee that cannot be resolved in the module is not a runtime call.
+  auto *func =
+      op->getParentOfType<ModuleOp>().lookupSymbol(*callOp.getCallee());
+  return func && func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName());
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.teams / omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+  // An op whose results are still used cannot be parallelized on its own.
+  for (OpResult result : op->getResults())
+    if (!result.use_empty())
+      return false;
+  // We will parallelize unordered loops - these come from array syntax. A
+  // fir.do_loop that is not marked unordered is never parallelized.
+  if (auto doLoop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unorderedFlag = doLoop.getUnordered();
+    if (unorderedFlag)
+      return *unorderedFlag;
+    return false;
+  }
+  // Runtime calls are the only other ops we parallelize; we cannot
+  // parallelize anything else.
+  return isRuntimeCall(op);
+}
+
+/// Returns the op of type T that is the sole non-terminator operation in the
+/// single block of the single region of \p op, or null otherwise.
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1 || op->getRegion(0).getBlocks().size() != 1)
+    return nullptr;
+  Block &body = op->getRegion(0).front();
+  Operation &head = body.front();
+  auto candidate = dyn_cast<T>(&head);
+  // The candidate must be immediately followed by the block terminator.
+  if (candidate && head.getNextNode() == body.getTerminator())
+    return candidate;
+  return nullptr;
+}
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return false;
+  }
+  // From here on, `changed` tracks whether the IR was modified.
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Clone the ops that precede workdistribute inside teams to before teams.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute) {
+      break;
+    }
+    if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelize ops before first workdistribute\n");
+      return false;
+    } else {
+      rewriter.setInsertionPoint(teams);
+      rewriter.clone(op, irMapping);
+      teamsHoisted.push_back(&op);
+      changed = true;
+    }
+  }
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // Loop while the original workdistribute still has unhandled operations.
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      } else {
+        rewriter.clone(op, mapping);
+        hoisted.push_back(&op);
+        changed = true;
+      }
+    }
+    // Redirect uses to the clones placed before teams, then drop the originals.
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+    // A lone parallelizable op left in workdistribute is the final state.
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// If an unordered fir.do_loop is the only op inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// then it is lowered to the following nest, in which omp.parallel,
+/// omp.distribute and omp.wsloop are marked composite:
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest { ... }
+///       }
+///     }
+///   }
+/// }
+
+/// Creates omp.parallel; the builder ends up before its body's terminator.
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallel = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallel.setComposite(composite);
+  rewriter.createBlock(&parallel.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+}
+
+/// Creates omp.distribute; the builder ends up at the start of its new body.
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands clauseOps;
+  auto distribute = rewriter.create<mlir::omp::DistributeOp>(loc, clauseOps);
+  distribute.setComposite(composite);
+  // Subsequent ops are emitted into the body block of omp.distribute.
+  Block *body = rewriter.createBlock(&distribute.getRegion());
+  rewriter.setInsertionPointToStart(body);
+}
+
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); // inclusive bound
+}
+
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+  // Emit omp.wsloop wrapping an omp.loop_nest that receives doLoop's body.
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body into the region of the loop nest construct; the
+  // original fir.do_loop is left untouched here.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Replace the cloned fir.result terminator with an omp.yield.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    // `rewriter` is a plain OpBuilder, so the op is erased directly.
+    terminatorOp->erase();
+  }
+  return;
+}
+
+/// Lowers workdistribute{unordered do_loop} to parallel/distribute/wsloop.
+static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  if (!doLoop || !shouldParallelize(doLoop))
+    return false;
+  assert(doLoop.getReduceOperands().empty());
+  auto wdLoc = workdistribute->getLoc();
+  genParallelOp(wdLoc, rewriter, true);
+  genDistributeOp(wdLoc, rewriter, true);
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+  genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+  workdistribute.erase();
+  return true;
+}
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// A()
+/// B()
+///
+
+static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  // The block that contains teamsOp receives the hoisted operations.
+  Block *parentBlock = teamsOp->getBlock();
+  Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+  auto insertPoint = Block::iterator(teamsOp);
+  // Range of operations to move: the whole body except the terminator.
+  auto workdistributeBegin = workdistributeBlock.begin();
+  auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator();
+  // Splice (move, not clone) that range to right before teamsOp.
+  parentBlock->getOperations().splice(insertPoint,
+                                      workdistributeBlock.getOperations(),
+                                      workdistributeBegin, workdistributeEnd);
+  // Erase workdistributeOp, which now holds only its terminator.
+  workdistributeOp.erase();
+  Block &teamsBlock = *teamsOp.getRegion().begin();
+  // If only the terminator remains in teams, erase the teams op as well.
+  if (teamsBlock.getOperations().size() == 1 &&
+      teamsBlock.getTerminator() != nullptr) {
+    teamsOp.erase();
+  }
+  return true;
+}
+
+struct SplitTargetResult {
+  omp::TargetOp targetOp;   // new target op, created inside dataOp's region
+  omp::TargetDataOp dataOp; // enclosing target_data op carrying the outer maps
+};
+
+/// If multiple workdistribute ops are nested in a target region, we need to
+/// split the target region, but we want to preserve the data semantics of the
+/// original data region and avoid unnecessary data movement at each of the
+/// subkernels - so we split the target region into a target_data{target}
+/// nest where only the outer one moves the data.
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
+                                                 RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << DEBUG_TYPE << " target region has no data maps\n");
+    return std::nullopt;
+  }
+
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (auto opr : targetOp.getMapVars()) {
+    auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+    mapInfos.push_back(mapInfo);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+
+  for (auto mapInfo : mapInfos) {
+    auto originalMapType =
+        (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+    auto originalCaptureType = mapInfo.getMapCaptureType();
+    llvm::omp::OpenMPOffloadMappingFlags newMapType;
+    mlir::omp::VariableCaptureKind newCaptureType;
+    // ByCopy: inner map only. ByRef: outer map, inner clone is OMP_MAP_NONE.
+    if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+      newMapType = originalMapType;
+      newCaptureType = originalCaptureType;
+    } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+      newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+      newCaptureType = originalCaptureType;
+      outerMapInfos.push_back(mapInfo);
+    } else {
+      llvm_unreachable("Unhandled case");
+    }
+    auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+    innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            newMapType)));
+    innerMapInfo.setMapCaptureType(newCaptureType);
+    innerMapInfos.push_back(innerMapInfo.getResult());
+  }
+  // Create the outer target_data op that carries the ByRef maps.
+  rewriter.setInsertionPoint(targetOp);
+  auto device = targetOp.getDevice();
+  auto ifExpr = targetOp.getIfExpr();
+  auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+  auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(taregtDataBlock);
+  // Recreate the target op inside target_data, using the inner map infos.
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+  // The original region was moved above; the old target op can be replaced.
+  rewriter.replaceOp(targetOp, newTargetOp);
+  return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
+}
+
+/// Finds the first nested omp.teams/omp.parallel and its first/last flags.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  Block &body = targetOp.getRegion().front();
+  for (Operation &candidate : body) {
+    const bool isFirst = &candidate == &body.front();
+    const bool isLast = candidate.getNextNode() == body.getTerminator();
+    if (isFirst && isLast)
+      return std::nullopt;
+    if (isa<omp::TeamsOp, omp::ParallelOp>(&candidate))
+      return {{&candidate, isFirst, isLast}};
+  }
+  return std::nullopt;
+}
+
+struct TempOmpVar {
+  omp::MapInfoOp from, to; // OMP_MAP_FROM / OMP_MAP_TO maps of one temporary
+};
+
+static bool isPtr(Type ty) { // true for fir.ref and LLVM pointer types
+  return isa<fir::ReferenceType, LLVM::LLVMPointerType>(ty);
+}
+
+/// Pointer type used to pass a cached value of type \p ty to a target region.
+static Type getPtrTypeForOmp(Type ty) {
+  if (!isPtr(ty))
+    return fir::LLVMPointerType::get(ty);
+  return LLVM::LLVMPointerType::get(ty.getContext());
+}
+
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  if (isPtr(ty)) { // pointer-like values live in an llvm.alloca of one pointer
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    allocType = intTy; // NOTE(review): map var type is i32, not ptr - confirm
+  } else { // any other value lives in a fir.alloca of its own type
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo}; // both maps share `alloc`
+};
+
+/// Returns true if \p v has a user located at or after \p split in the target
+/// block (users in nested regions count via their ancestor in that block).
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  Block *targetBlock =
+      &cast<omp::TargetOp>(split->getParentOp()).getRegion().front();
+  return llvm::any_of(v.getUsers(), [&](Operation *user) {
+    // Climb to the ancestor of the user that lives directly in targetBlock.
+    while (user->getBlock() != targetBlock)
+      user = user->getParentOp();
+    return !user->isBeforeInBlock(split);
+  });
+}
+
+/// Returns true if \p op is a fir.declare or implements the memory effect
+/// interface and reports no effects. \p splitBefore is currently unused.
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+  auto effectIface = dyn_cast<MemoryEffectOpInterface>(op);
+  // Ops with unknown memory effects are conservatively not recomputable.
+  if (!effectIface)
+    return false;
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  effectIface.getEffects(effects);
+  // Only ops that report no memory effects are treated as recomputable.
+  return effects.empty();
+}
+
+struct SplitResult {
+  omp::TargetOp preTargetOp;      // holds the ops that preceded the split op
+  omp::TargetOp isolatedTargetOp; // holds the split op itself
+  omp::TargetOp postTargetOp;     // ops after the split op; null if not split
+};
+
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+                                       SetVector<Operation *> &nonRecomputable,
+                                       SetVector<Operation *> &toCache,
+                                       SetVector<Operation *> &toRecompute) {
+  Operation *def = v.getDefiningOp();
+  if (!def) { // Block arguments need neither caching nor recomputation.
+    assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+    return;
+  }
+  if (nonRecomputable.contains(def)) {
+    // A cached value is reloaded later, so the walk stops at this op.
+    toCache.insert(def);
+  } else {
+    toRecompute.insert(def);
+    for (Value operand : def->getOperands())
+      collectNonRecomputableDeps(operand, targetOp, nonRecomputable, toCache,
+                                 toRecompute);
+  }
+}
+
+static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
+                                    MLIRContext &ctx, IRMapping &mapping,
+                                    Operation *splitBefore, Block *targetBlock,
+                                    Block *newTargetBlock,
+                                    SmallVector<Value> &allocs,
+                                    SetVector<Operation *> &toRecompute) {
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) { // mirror args
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    mapping.map(originalArg, newArg);
+  }
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  for (auto original : allocs) { // one extra pointer argument per cached value
+    Value newArg = newTargetBlock->addArgument(
+        getPtrTypeForOmp(original.getType()), original.getLoc());
+    Value restored;
+    if (isPtr(original.getType())) {
+      restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+      if (!isa<LLVM::LLVMPointerType>(original.getType())) // back to FIR type
+        restored =
+            rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+    } else {
+      restored = rewriter.create<fir::LoadOp>(loc, newArg);
+    }
+    mapping.map(original, restored); // later clones use the reloaded value
+  }
+  for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
+    if (toRecompute.contains(&*it)) // re-emit only ops marked for recompute
+      rewriter.clone(*it, mapping);
+  }
+}
+
+static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
+                             RewriterBase &rewriter) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  MLIRContext &ctx = *targetOp.getContext();
+  assert(targetOp); // NOTE(review): targetOp is already dereferenced above
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  rewriter.setInsertionPoint(targetOp);
+
+  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+  // Collect values needed at/after the split and ops that are not recomputable.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    for (auto res : it->getResults()) {
+      if (usedOutsideSplit(res, splitBeforeOp))
+        requiredVals.push_back(res);
+    }
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp))
+      nonRecomputable.insert(&*it);
+  }
+  // Classify the producers of the required values into cache vs. recompute.
+  for (auto requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+  // Allocate one mapped temporary per result of every op that must be cached.
+  for (Operation *op : toCache) {
+    for (auto res : op->getResults()) {
+      auto alloc =
+          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      allocs.push_back(res);
+      preMapOperands.push_back(alloc.from);
+      postMapOperands.push_back(alloc.to);
+    }
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  // Pre-split target: ops before the split, plus stores of the cached values.
+  auto preTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = preTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    preMapping.map(originalArg, newArg);
+  }
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++)
+    rewriter.clone(*it, preMapping);
+
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+  // Store every cached value through its extra pointer block argument.
+  for (auto original : allocs) {
+    Value toStore = preMapping.lookup(original);
+    auto newArg = preTargetBlock->addArgument(
+        getPtrTypeForOmp(original.getType()), original.getLoc());
+    if (isPtr(original.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  rewriter.setInsertionPoint(targetOp);
+  // Isolated target: reload/recompute what is needed, then the split op itself.
+  auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+
+  auto *isolatedTargetBlock =
+      rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                           isolatedTargetOp.getRegion().begin(), {}, {});
+
+  IRMapping isolatedMapping;
+  reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
+                          targetBlock, isolatedTargetBlock, allocs,
+                          toRecompute);
+  rewriter.clone(*splitBeforeOp, isolatedMapping);
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  omp::TargetOp postTargetOp = nullptr;
+  // Optional post-split target: everything that followed the split op.
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = rewriter.create<omp::TargetOp>(
+        targetOp.getLoc(), targetOp.getAllocateVars(),
+        targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+        targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+        targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+        targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+        targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+        targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+        postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+        targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+        targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+    auto *postTargetBlock = rewriter.createBlock(
+        &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+    IRMapping postMapping;
+    reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
+                            targetBlock, postTargetBlock, allocs, toRecompute);
+    // The split op's results are not remapped, so they must be unused.
+    assert(splitBeforeOp->getNumResults() == 0 ||
+           llvm::all_of(splitBeforeOp->getResults(),
+                        [](Value result) { return result.use_empty(); }));
+    // This range includes the original terminator, which is cloned too.
+    for (auto it = std::next(splitBeforeOp->getIterator());
+         it != targetBlock->end(); it++)
+      rewriter.clone(*it, postMapping);
+  }
+
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+/// Creates an LLVM dialect constant of type i32 holding \p value.
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(value));
+}
+
+static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); } // i32
+
+static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  IRMapping mapping;
+  for (auto map :
+       zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) {
+    Value mapInfo = std::get<0>(map);
+    BlockArgument arg = std::get<1>(map);
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Operation *> opsToReplace;
+  Value device = targetOp.getDevice();
+  if (!device) {
+    device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+  }
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    if (isRuntimeCall(op)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto module = runtimeCall->getParentOfType<ModuleOp>();
+      auto callee =
+          cast<func::FuncOp>(module.lookupSymbol(runtimeCall.getCalleeAttr()));
+      std::string newCalleeName = (callee.getName() + "_omp").str();
+      mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+      func::FuncOp newCallee =
+          cast_or_null<func::FuncOp>(module.lookupSymbol(newCalleeName));
+      if (!newCallee) {
+        SmallVector<Type> argTypes(callee.getFunctionType().getInputs());
+        argTypes.push_back(getOmpDeviceType(rewriter.getContext()));
+        newCallee = moduleBuilder.create<func::FuncOp>(
+            callee->getLoc(), newCalleeName,
+            FunctionType::get(rewriter.getContext(), argTypes,
+                              callee.getFunctionType().getResults()));
+        if (callee.getArgAttrs())
+          newCallee.setArgAttrsAttr(*callee.getArgAttrs());
+        if (callee.getResAttrs())
+          newCallee.setResAttrsAttr(*callee.getResAttrs());
+        newCallee.setSymVisibility(callee.getSymVisibility());
+        newCallee->setDiscardableAttrs(callee->getDiscardableAttrDictionary());
+      }
+      SmallVector<Value> operands = runtimeCall.getOperands();
+      operands.push_back(device);
+      auto tmpCall = rewriter.create<fir::CallOp>(
+          runtimeCall.getLoc(), runtimeCall.getResultTypes(),
+          SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr,
+          runtimeCall.getFastmathAttr());
+      Operation *newCall = rewriter.clone(*tmpCall, mapping);
+      mapping.map(&*it, newCall);
+      rewriter.eraseOp(tmpCall);
+    } else {
+      Operation *clonedOp = rewriter.clone(*op, mapping);
+      if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+        opsToReplace.push_back(clonedOp);
+      for (unsigned i = 0; i < op->getNumResults(); ++i) {
+        mapping.map(op->getResult(i), clonedOp->getResult(i));
+      }
+    }
+  }
+  for (Operation *op : opsToReplace) {
+    if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+      rewriter.setInsertionPoint(allocOp);
+      auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+          allocOp.getLoc(), rewriter.getI64Type(), device,
+          allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+          allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+          allocOp.getShape());
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          allocOp.getLoc(), allocOp.getResult().getType(),
+          ompAllocmemOp.getResult());
+      rewriter.replaceOp(allocOp, firConvertOp.getResult());
+    } else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+      rewriter.setInsertionPoint(freeOp);
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+      rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+                                            firConvertOp.getResult());
+      rewriter.eraseOp(freeOp);
+    }
+  }
+  rewriter.eraseOp(targetOp);
+}
+
+void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  auto tuple = getNestedOpToIsolate(targetOp);
+  if (!tuple) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    moveToHost(targetOp, rewriter);
+    return;
+  }
+
+  Operation *toIsolate = std::get<0>(*tuple);
+  bool splitBefore = !std::get<1>(*tuple);
+  bool splitAfter = !std::get<2>(*tuple);
+
+  if (splitBefore && splitAfter) {
+    auto res = isolateOp(toIsolate, splitAfter, rewriter);
+    moveToHost(res.preTargetOp, rewriter);
+    fissionTarget(res.postTargetOp, rewriter);
+    return;
+  }
+  if (splitBefore) {
+    auto res = isolateOp(toIsolate, splitAfter, rewriter);
+    moveToHost(res.preTargetOp, rewriter);
+    return;
+  }
+  if (splitAfter) {
+    auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter);
+    fissionTarget(res.postTargetOp, rewriter);
+    return;
+  }
+}
+
+class LowerWorkdistributePass
+    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    auto moduleOp = getOperation();
+    bool changed = false;
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |= FissionWorkdistribute(workdistribute);
+    });
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |= WorkdistributeDoLower(workdistribute);
+    });
+    moduleOp->walk([&](mlir::omp::TeamsOp teams) {
+      changed |= TeamsWorkdistributeToSingleOp(teams);
+    });
+
+    if (changed) {
+      SmallVector<omp::TargetOp> targetOps;
+      moduleOp->walk(
+          [&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
+      IRRewriter rewriter(&context);
+      for (auto targetOp : targetOps) {
+        auto res = splitTargetData(targetOp, rewriter);
+        if (res)
+          fissionTarget(res->targetOp, rewriter);
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index ca8e820608688..4b4e4f7ef19aa 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -294,8 +294,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
     addNestedPassToAllTopLevelOperations<PassConstructor>(
         pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP)
+  if (enableOpenMP) {
     pm.addPass(flangomp::createLowerWorkshare());
+    pm.addPass(flangomp::createLowerWorkdistribute());
+  }
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index c9fe53bf093a1..3a94ea8a476a1 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -69,6 +69,7 @@ func.func @_QQmain() {
 // PASSES-NEXT:     InlineHLFIRAssign
 // PASSES-NEXT:   ConvertHLFIRtoFIR
 // PASSES-NEXT:   LowerWorkshare
+// PASSES-NEXT:   LowerWorkdistribute
 // PASSES-NEXT:   CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
new file mode 100644
index 0000000000000..00d10d6264ec9
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -0,0 +1,33 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @x({{.*}})
+// CHECK:           omp.teams {
+// CHECK:             omp.parallel {
+// CHECK:               omp.distribute {
+// CHECK:                 omp.wsloop {
+// CHECK:                   omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK:                     %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:                     fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                 } {omp.composite}
+// CHECK:               } {omp.composite}
+// CHECK:               omp.terminator
+// CHECK:             } {omp.composite}
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  omp.teams {
+    omp.workdistribute { 
+      fir.do_loop %iv = %lb to %ub step %step unordered {
+        %zero = arith.constant 0 : index
+        fir.store %zero to %addr : !fir.ref<index>
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
new file mode 100644
index 0000000000000..19bdb9ce10fbd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -0,0 +1,112 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @x(
+// CHECK:           %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK:           fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK:           %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK:           fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK:           %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK:           fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK:           %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:           %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:           %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:           %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK:           %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK:           %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK:           omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK:             %[[VAL_11:.*]] = fir.alloca index
+// CHECK:             %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_14:.*]] = fir.alloca index
+// CHECK:             %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_17:.*]] = fir.alloca index
+// CHECK:             %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK:             %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK:             %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK:             %[[VAL_23:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK:             %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK:             %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK:             %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:             %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array<i32: 1, 0, 1>, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap<index>
+// CHECK:             fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK:             fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK:               %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<index>
+// CHECK:               %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK:               %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index
+// CHECK:               omp.teams {
+// CHECK:                 omp.parallel {
+// CHECK:                   omp.distribute {
+// CHECK:                     omp.wsloop {
+// CHECK:                       omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to (%[[VAL_39]]) inclusive step (%[[VAL_40]]) {
+// CHECK:                         fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap<index>
+// CHECK:                         omp.yield
+// CHECK:                       }
+// CHECK:                     } {omp.composite}
+// CHECK:                   } {omp.composite}
+// CHECK:                   omp.terminator
+// CHECK:                 } {omp.composite}
+// CHECK:                 omp.terminator
+// CHECK:               }
+// CHECK:               omp.terminator
+// CHECK:             }
+// CHECK:             %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK:             %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK:             %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK:             %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK:             %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index
+// CHECK:             fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap<index>
+// CHECK:             %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:             "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap<index>) -> ()
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load %ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..c562b7009664d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,71 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_fission_workdistribute(
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK:           fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           omp.teams {
+// CHECK:             omp.parallel {
+// CHECK:               omp.distribute {
+// CHECK:                 omp.wsloop {
+// CHECK:                   omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK:                     %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                     %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK:                     %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:                     fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK:                     omp.yield
+// CHECK:                   }
+// CHECK:                 } {omp.composite}
+// CHECK:               } {omp.composite}
+// CHECK:               omp.terminator
+// CHECK:             } {omp.composite}
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK:           fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK:             %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:             fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK:           }
+// CHECK:           %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK:           fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK:           return
+// CHECK:         }
+module {
+func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
+  return
+}
+func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
+  return
+}
+func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
+  %c0_idx = arith.constant 0 : index
+  %c1_idx = arith.constant 1 : index
+  %c9_idx = arith.constant 9 : index
+  %float_val = arith.constant 5.0 : f32
+  omp.teams   {
+    omp.workdistribute   {
+      fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
+      fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
+        %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
+        %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
+      }
+      fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
+      fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
+      fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
+        %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
+      }
+      %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
+      fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
+      omp.terminator  
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
new file mode 100644
index 0000000000000..d96068b26ca2f
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
@@ -0,0 +1,32 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_nested_derived_type_map_operand_and_block_addition(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK:           %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+// CHECK:           %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK:           %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+// CHECK:           %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"}
+// CHECK:           %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK:           %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+// CHECK:           %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
+// CHECK:           %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+// CHECK:           omp.target map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK:             omp.terminator
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {        
+  %0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+  %2 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+  %4 = fir.coordinate_of %2, i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+  %5 = omp.map.info var_ptr(%4 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%n%i"}
+  %7 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+  %9 = fir.coordinate_of %7, r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+  %10 = omp.map.info var_ptr(%9 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%n%r"}
+  %11 = omp.map.info var_ptr(%0 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+  omp.target map_entries(%11 -> %arg1 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da39b19388627..01ee398641ee5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5369,6 +5369,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
+  if (targetOp.getHostEvalVars().empty())
+    numLoops = 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);

>From 4907705c6319ae135185af58836ef9c6a0705e67 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 25 Jul 2025 22:20:12 +0530
Subject: [PATCH 2/2] Fix hoisting declare ops out of omp.target

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 34 +++++++++++++++++--
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 0885efc716db4..3f78727450f31 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -434,7 +434,7 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
                               newTargetOp.getRegion().begin());
 
-  rewriter.replaceOp(targetOp, newTargetOp);
+  rewriter.replaceOp(targetOp, targetDataOp);
   return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
 }
 
@@ -807,11 +807,30 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       rewriter.eraseOp(tmpCall);
     } else {
       Operation *clonedOp = rewriter.clone(*op, mapping);
-      if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
-        opsToReplace.push_back(clonedOp);
       for (unsigned i = 0; i < op->getNumResults(); ++i) {
         mapping.map(op->getResult(i), clonedOp->getResult(i));
       }
+      // fir.declare changes its type when hoisting it out of omp.target to
+      // omp.target_data. Introduce a load if the original declareOp input is
+      // not of reference type but the cloned declareOp input is.
+      if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+        auto originalDeclareOp = cast<fir::DeclareOp>(op);
+        Type originalInType = originalDeclareOp.getMemref().getType();
+        Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+        fir::ReferenceType originalRefType =
+            dyn_cast<fir::ReferenceType>(originalInType);
+        fir::ReferenceType clonedRefType =
+            dyn_cast<fir::ReferenceType>(clonedInType);
+        if (!originalRefType && clonedRefType) {
+          Type clonedEleTy = clonedRefType.getElementType();
+          if (clonedEleTy == originalDeclareOp.getType()) {
+            opsToReplace.push_back(clonedOp);
+          }
+        }
+      }
+      if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+        opsToReplace.push_back(clonedOp);
     }
   }
   for (Operation *op : opsToReplace) {
@@ -833,6 +852,15 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
                                             firConvertOp.getResult());
       rewriter.eraseOp(freeOp);
+    } else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      Type clonedEleTy = clonedRefType.getElementType();
+      rewriter.setInsertionPoint(op);
+      Value loadedValue = rewriter.create<fir::LoadOp>(
+          clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+      clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
     }
   }
   rewriter.eraseOp(targetOp);



More information about the Mlir-commits mailing list