[Mlir-commits] [flang] [mlir] [flang] Implement workdistribute construct lowering (PR #140523)
llvmlistbot at llvm.org
Fri Jul 25 09:51:34 PDT 2025
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/140523
>From d253390cf40feb805f326dd811f56b54183c8885 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 4 Dec 2023 12:57:36 -0800
Subject: [PATCH 1/2] [Flang] Introduce lower-workdistribute pass for
workdistribute lowering.
Co-authors: ivanradanov, skc7
---
.../include/flang/Optimizer/OpenMP/Passes.td | 4 +
flang/lib/Optimizer/OpenMP/CMakeLists.txt | 1 +
.../Optimizer/OpenMP/LowerWorkdistribute.cpp | 901 ++++++++++++++++++
flang/lib/Optimizer/Passes/Pipelines.cpp | 4 +-
flang/test/Fir/basic-program.fir | 1 +
.../OpenMP/lower-workdistribute-doloop.mlir | 33 +
.../lower-workdistribute-fission-target.mlir | 112 +++
.../OpenMP/lower-workdistribute-fission.mlir | 71 ++
.../OpenMP/lower-workdistribute-target.mlir | 32 +
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 +
10 files changed, 1161 insertions(+), 1 deletion(-)
create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 704faf0ccd856..743b6d381ed42 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
let summary = "Lower workshare construct";
}
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+ let summary = "Lower workdistribute construct";
+}
+
def GenericLoopConversionPass
: Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e31543328a9f9..cd746834741f9 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
+ LowerWorkdistribute.cpp
LowerWorkshare.cpp
LowerNontemporal.cpp
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..0885efc716db4
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,901 @@
+//===- LowerWorkdistribute.cpp - lower omp.workdistribute ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimizations of omp.workdistribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+static bool isRuntimeCall(Operation *op) {
+ if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+ auto callee = callOp.getCallee();
+ if (!callee)
+ return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func && func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+ }
+ return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+ if (llvm::any_of(op->getResults(),
+ [](OpResult v) -> bool { return !v.use_empty(); }))
+ return false;
+ // We will parallelize unordered loops - these come from array syntax
+ if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+ auto unordered = loop.getUnordered();
+ if (!unordered)
+ return false;
+ return *unordered;
+ }
+ if (isRuntimeCall(op)) {
+ return true;
+ }
+  // We cannot parallelize anything else.
+ return false;
+}
+
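+/// Returns the op of type T that is perfectly nested in \p op: \p op must
+/// have a single-block region whose only non-terminator operation is the
+/// returned op.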
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+ if (op->getNumRegions() != 1)
+ return nullptr;
+  auto &region = op->getRegion(0);
+ if (region.getBlocks().size() != 1)
+ return nullptr;
+  auto *block = &region.front();
+ auto *firstOp = &block->front();
+ if (auto nested = dyn_cast<T>(firstOp))
+ if (firstOp->getNextNode() == block->getTerminator())
+ return nested;
+ return nullptr;
+}
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+ OpBuilder rewriter(workdistribute);
+ auto loc = workdistribute->getLoc();
+ auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+ if (!teams) {
+    emitError(loc, "workdistribute not nested in teams");
+ return false;
+ }
+ if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks");
+ return false;
+ }
+ if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks");
+ return false;
+ }
+
+ auto *teamsBlock = &teams.getRegion().front();
+ bool changed = false;
+  // Hoist the ops that appear inside teams but before workdistribute out of
+  // the teams region.
+ IRMapping irMapping;
+ llvm::SmallVector<Operation *> teamsHoisted;
+ for (auto &op : teams.getOps()) {
+ if (&op == workdistribute) {
+ break;
+ }
+ if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelizable ops before first workdistribute");
+ return false;
+ } else {
+ rewriter.setInsertionPoint(teams);
+ rewriter.clone(op, irMapping);
+ teamsHoisted.push_back(&op);
+ changed = true;
+ }
+ }
+ for (auto *op : llvm::reverse(teamsHoisted)) {
+ op->replaceAllUsesWith(irMapping.lookup(op));
+ op->erase();
+ }
+
+ // While we have unhandled operations in the original workdistribute
+ auto *workdistributeBlock = &workdistribute.getRegion().front();
+ auto *terminator = workdistributeBlock->getTerminator();
+ while (&workdistributeBlock->front() != terminator) {
+ rewriter.setInsertionPoint(teams);
+ IRMapping mapping;
+ llvm::SmallVector<Operation *> hoisted;
+ Operation *parallelize = nullptr;
+ for (auto &op : workdistribute.getOps()) {
+ if (&op == terminator) {
+ break;
+ }
+ if (shouldParallelize(&op)) {
+ parallelize = &op;
+ break;
+ } else {
+ rewriter.clone(op, mapping);
+ hoisted.push_back(&op);
+ changed = true;
+ }
+ }
+
+ for (auto *op : llvm::reverse(hoisted)) {
+ op->replaceAllUsesWith(mapping.lookup(op));
+ op->erase();
+ }
+
+ if (parallelize && hoisted.empty() &&
+ parallelize->getNextNode() == terminator)
+ break;
+ if (parallelize) {
+ auto newTeams = rewriter.cloneWithoutRegions(teams);
+ auto *newTeamsBlock = rewriter.createBlock(
+ &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+ for (auto arg : teamsBlock->getArguments())
+ newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+ auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+ rewriter.create<omp::TerminatorOp>(loc);
+ rewriter.createBlock(&newWorkdistribute.getRegion(),
+ newWorkdistribute.getRegion().begin(), {}, {});
+ auto *cloned = rewriter.clone(*parallelize);
+ parallelize->replaceAllUsesWith(cloned);
+ parallelize->erase();
+ rewriter.create<omp::TerminatorOp>(loc);
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+/// If an unordered fir.do_loop is perfectly nested inside teams
+/// workdistribute,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest
+///           ...
+///       }
+///     }
+///   }
+/// }
+
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallelOp.setComposite(composite);
+  rewriter.createBlock(&parallelOp.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+}
+
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+ mlir::omp::DistributeOperands distributeClauseOps;
+ auto distributeOp =
+ rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+ distributeOp.setComposite(composite);
+ auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+ rewriter.setInsertionPointToStart(distributeBlock);
+}
+
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+ mlir::omp::LoopNestOperands &loopNestClauseOps) {
+ assert(loopNestClauseOps.loopLowerBounds.empty() &&
+ "Loop nest bounds were already emitted!");
+ loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+ loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+ loopNestClauseOps.loopSteps.push_back(loop.getStep());
+ loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+ const mlir::omp::LoopNestOperands &clauseOps,
+ bool composite) {
+
+ auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+ wsloopOp.setComposite(composite);
+ rewriter.createBlock(&wsloopOp.getRegion());
+
+ auto loopNestOp =
+ rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+ // Clone the loop's body inside the loop nest construct using the
+ // mapped values.
+ rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+ loopNestOp.getRegion().begin());
+ Block *clonedBlock = &loopNestOp.getRegion().back();
+ mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase the fir.result op of the do loop and create an omp.yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+}
+
+static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+ OpBuilder rewriter(workdistribute);
+ auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+ auto wdLoc = workdistribute->getLoc();
+ if (doLoop && shouldParallelize(doLoop)) {
+ assert(doLoop.getReduceOperands().empty());
+ genParallelOp(wdLoc, rewriter, true);
+ genDistributeOp(wdLoc, rewriter, true);
+ mlir::omp::LoopNestOperands loopNestClauseOps;
+ genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+ genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+ workdistribute.erase();
+ return true;
+ }
+ return false;
+}
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// A()
+/// B()
+///
+
+static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
+ auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+ if (!workdistributeOp)
+ return false;
+ // Get the block containing teamsOp (the parent block).
+ Block *parentBlock = teamsOp->getBlock();
+ Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+ auto insertPoint = Block::iterator(teamsOp);
+ // Get the range of operations to move (excluding the terminator).
+ auto workdistributeBegin = workdistributeBlock.begin();
+ auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator();
+ // Move the operations from workdistribute block to before teamsOp.
+ parentBlock->getOperations().splice(insertPoint,
+ workdistributeBlock.getOperations(),
+ workdistributeBegin, workdistributeEnd);
+ // Erase the now-empty workdistributeOp.
+ workdistributeOp.erase();
+ Block &teamsBlock = *teamsOp.getRegion().begin();
+ // Check if only the terminator remains and erase teams op.
+ if (teamsBlock.getOperations().size() == 1 &&
+ teamsBlock.getTerminator() != nullptr) {
+ teamsOp.erase();
+ }
+ return true;
+}
+
+struct SplitTargetResult {
+ omp::TargetOp targetOp;
+ omp::TargetDataOp dataOp;
+};
+
+/// If multiple workdistribute constructs are nested in a target region, we
+/// need to split the target region; to preserve the data semantics of the
+/// original region and avoid unnecessary data movement at each of the
+/// subkernels, we split the target region into a target_data{target} nest
+/// where only the outer target_data moves the data.
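+///
+/// A schematic sketch (the map clause spellings below are illustrative, not
+/// the exact printed form):
+///
+/// omp.target map_entries(%byref, %bycopy) { ... }
+///
+/// becomes
+///
+/// omp.target_data map_entries(%byref) {
+///   omp.target map_entries(%byref_nomap, %bycopy) { ... }
+/// }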
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
+ RewriterBase &rewriter) {
+ auto loc = targetOp->getLoc();
+ if (targetOp.getMapVars().empty()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << DEBUG_TYPE << " target region has no data maps\n");
+ return std::nullopt;
+ }
+
+ SmallVector<omp::MapInfoOp> mapInfos;
+ for (auto opr : targetOp.getMapVars()) {
+ auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+ mapInfos.push_back(mapInfo);
+ }
+
+ rewriter.setInsertionPoint(targetOp);
+ SmallVector<Value> innerMapInfos;
+ SmallVector<Value> outerMapInfos;
+
+ for (auto mapInfo : mapInfos) {
+ auto originalMapType =
+ (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+ auto originalCaptureType = mapInfo.getMapCaptureType();
+ llvm::omp::OpenMPOffloadMappingFlags newMapType;
+ mlir::omp::VariableCaptureKind newCaptureType;
+
+ if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+ newMapType = originalMapType;
+ newCaptureType = originalCaptureType;
+ } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+ newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+ newCaptureType = originalCaptureType;
+ outerMapInfos.push_back(mapInfo);
+ } else {
+ llvm_unreachable("Unhandled case");
+ }
+ auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+ innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+ rewriter.getIntegerType(64, false),
+ static_cast<
+ std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ newMapType)));
+ innerMapInfo.setMapCaptureType(newCaptureType);
+ innerMapInfos.push_back(innerMapInfo.getResult());
+ }
+
+ rewriter.setInsertionPoint(targetOp);
+ auto device = targetOp.getDevice();
+ auto ifExpr = targetOp.getIfExpr();
+ auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+ auto devicePtrVars = targetOp.getIsDevicePtrVars();
+ auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+ loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+
+ auto newTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+ newTargetOp.getRegion().begin());
+
+ rewriter.replaceOp(targetOp, newTargetOp);
+ return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
+}
+
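+/// Returns the first omp.teams or omp.parallel op nested directly in the
+/// target region, together with flags indicating whether it is the first
+/// and/or the last op before the terminator. Returns std::nullopt if no such
+/// op exists or if it is the only op in the block.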
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+ if (targetOp.getRegion().empty())
+ return std::nullopt;
+ auto *targetBlock = &targetOp.getRegion().front();
+ for (auto &op : *targetBlock) {
+ bool first = &op == &*targetBlock->begin();
+ bool last = op.getNextNode() == targetBlock->getTerminator();
+ if (first && last)
+ return std::nullopt;
+
+ if (isa<omp::TeamsOp, omp::ParallelOp>(&op))
+ return {{&op, first, last}};
+ }
+ return std::nullopt;
+}
+
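+/// Pair of omp.map.info ops used to pass a host-cached value between the
+/// split target regions: `from` maps the temporary out of the producing
+/// kernel and `to` maps it back into the consuming one.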
+struct TempOmpVar {
+ omp::MapInfoOp from, to;
+};
+
+static bool isPtr(Type ty) {
+ return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+static Type getPtrTypeForOmp(Type ty) {
+  if (isPtr(ty))
+    return LLVM::LLVMPointerType::get(ty.getContext());
+  return fir::LLVMPointerType::get(ty);
+}
+
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+ RewriterBase &rewriter) {
+ MLIRContext &ctx = *ty.getContext();
+ Value alloc;
+ Type allocType;
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+ if (isPtr(ty)) {
+ Type intTy = rewriter.getI32Type();
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+ allocType = llvmPtrTy;
+ alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+ allocType = intTy;
+ } else {
+ allocType = ty;
+ alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+ }
+ auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+ return rewriter.create<omp::MapInfoOp>(
+ loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+ rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+ mappingFlags),
+ rewriter.getAttr<omp::VariableCaptureKindAttr>(
+ omp::VariableCaptureKind::ByRef),
+ /*varPtrPtr=*/Value{},
+ /*members=*/SmallVector<Value>{},
+ /*member_index=*/mlir::ArrayAttr{},
+ /*bounds=*/ValueRange(),
+ /*mapperId=*/mlir::FlatSymbolRefAttr(),
+ /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+ };
+ uint64_t mapFrom =
+ static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+ uint64_t mapTo =
+ static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+ auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+ auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+ return TempOmpVar{mapInfoFrom, mapInfoTo};
+}
+
+static bool usedOutsideSplit(Value v, Operation *split) {
+ if (!split)
+ return false;
+ auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+ auto *targetBlock = &targetOp.getRegion().front();
+ for (auto *user : v.getUsers()) {
+ while (user->getBlock() != targetBlock) {
+ user = user->getParentOp();
+ }
+ if (!user->isBeforeInBlock(split))
+ return true;
+ }
+ return false;
+}
+
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+ if (isa<fir::DeclareOp>(op))
+ return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  auto interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface)
+    return false;
+  interface.getEffects(effects);
+  return effects.empty();
+}
+
+struct SplitResult {
+ omp::TargetOp preTargetOp;
+ omp::TargetOp isolatedTargetOp;
+ omp::TargetOp postTargetOp;
+};
+
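+/// Walks the operands that \p v depends on inside \p targetOp. Ops marked
+/// non-recomputable are cached (passed between the split kernels through
+/// temporary omp variables); everything else is recorded for recomputation
+/// in the split regions.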
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+ SetVector<Operation *> &nonRecomputable,
+ SetVector<Operation *> &toCache,
+ SetVector<Operation *> &toRecompute) {
+ Operation *op = v.getDefiningOp();
+ if (!op) {
+ assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+ return;
+ }
+ if (nonRecomputable.contains(op)) {
+ toCache.insert(op);
+ return;
+ }
+ toRecompute.insert(op);
+ for (auto opr : op->getOperands())
+ collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+ toRecompute);
+}
+
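+/// Rebuilds the environment of a split-off target region: block arguments
+/// are remapped, cached values are reloaded from the temporary omp
+/// variables, and recomputable ops preceding \p splitBefore are cloned.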
+static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
+ MLIRContext &ctx, IRMapping &mapping,
+ Operation *splitBefore, Block *targetBlock,
+ Block *newTargetBlock,
+ SmallVector<Value> &allocs,
+ SetVector<Operation *> &toRecompute) {
+ for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+ auto originalArg = targetBlock->getArgument(i);
+ auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+ originalArg.getLoc());
+ mapping.map(originalArg, newArg);
+ }
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+ for (auto original : allocs) {
+ Value newArg = newTargetBlock->addArgument(
+ getPtrTypeForOmp(original.getType()), original.getLoc());
+ Value restored;
+ if (isPtr(original.getType())) {
+ restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+ if (!isa<LLVM::LLVMPointerType>(original.getType()))
+ restored =
+ rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+ } else {
+ restored = rewriter.create<fir::LoadOp>(loc, newArg);
+ }
+ mapping.map(original, restored);
+ }
+ for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
+ if (toRecompute.contains(&*it))
+ rewriter.clone(*it, mapping);
+ }
+}
+
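+/// Splits the target region containing \p splitBeforeOp into a "pre" target
+/// op (everything before the split point), an isolated target op holding
+/// \p splitBeforeOp itself and, if \p splitAfter is set, a "post" target op
+/// with the remaining ops. Values crossing a split are cached or recomputed.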
+static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
+ RewriterBase &rewriter) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  assert(targetOp);
+  MLIRContext &ctx = *targetOp.getContext();
+  auto loc = targetOp.getLoc();
+ auto *targetBlock = &targetOp.getRegion().front();
+ rewriter.setInsertionPoint(targetOp);
+
+ auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+ auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+ SmallVector<Value> requiredVals;
+ SetVector<Operation *> toCache;
+ SetVector<Operation *> toRecompute;
+ SetVector<Operation *> nonRecomputable;
+ SmallVector<Value> allocs;
+
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+ it++) {
+ for (auto res : it->getResults()) {
+ if (usedOutsideSplit(res, splitBeforeOp))
+ requiredVals.push_back(res);
+ }
+ if (!isRecomputableAfterFission(&*it, splitBeforeOp))
+ nonRecomputable.insert(&*it);
+ }
+
+ for (auto requiredVal : requiredVals)
+ collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+ toRecompute);
+
+ for (Operation *op : toCache) {
+ for (auto res : op->getResults()) {
+ auto alloc =
+ allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+ allocs.push_back(res);
+ preMapOperands.push_back(alloc.from);
+ postMapOperands.push_back(alloc.to);
+ }
+ }
+
+ rewriter.setInsertionPoint(targetOp);
+
+ auto preTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ auto *preTargetBlock = rewriter.createBlock(
+ &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+ IRMapping preMapping;
+ for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+ auto originalArg = targetBlock->getArgument(i);
+ auto newArg = preTargetBlock->addArgument(originalArg.getType(),
+ originalArg.getLoc());
+ preMapping.map(originalArg, newArg);
+ }
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++)
+ rewriter.clone(*it, preMapping);
+
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+ for (auto original : allocs) {
+ Value toStore = preMapping.lookup(original);
+ auto newArg = preTargetBlock->addArgument(
+ getPtrTypeForOmp(original.getType()), original.getLoc());
+ if (isPtr(original.getType())) {
+ if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+ toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+ rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+ } else {
+ rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+ }
+ }
+ rewriter.create<omp::TerminatorOp>(loc);
+
+ rewriter.setInsertionPoint(targetOp);
+
+ auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+
+ auto *isolatedTargetBlock =
+ rewriter.createBlock(&isolatedTargetOp.getRegion(),
+ isolatedTargetOp.getRegion().begin(), {}, {});
+
+ IRMapping isolatedMapping;
+ reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
+ targetBlock, isolatedTargetBlock, allocs,
+ toRecompute);
+ rewriter.clone(*splitBeforeOp, isolatedMapping);
+ rewriter.create<omp::TerminatorOp>(loc);
+
+ omp::TargetOp postTargetOp = nullptr;
+
+ if (splitAfter) {
+ rewriter.setInsertionPoint(targetOp);
+ postTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ auto *postTargetBlock = rewriter.createBlock(
+ &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+ IRMapping postMapping;
+ reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
+ targetBlock, postTargetBlock, allocs, toRecompute);
+
+ assert(splitBeforeOp->getNumResults() == 0 ||
+ llvm::all_of(splitBeforeOp->getResults(),
+ [](Value result) { return result.use_empty(); }));
+
+ for (auto it = std::next(splitBeforeOp->getIterator());
+ it != targetBlock->end(); it++)
+ rewriter.clone(*it, postMapping);
+ }
+
+ rewriter.eraseOp(targetOp);
+ return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+ mlir::Type i32Ty = rewriter.getI32Type();
+ mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
+ return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
+}
+
+static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }
+
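+/// Clones the body of \p targetOp to the host, just before the target op,
+/// and erases the target op. Runtime calls are redirected to "_omp"-suffixed
+/// variants taking the device as an extra argument; fir.allocmem and
+/// fir.freemem are rewritten to omp.target_allocmem / omp.target_freemem.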
+static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block *targetBlock = &targetOp.getRegion().front();
+ assert(targetBlock == &targetOp.getRegion().back());
+ IRMapping mapping;
+ for (auto map :
+ zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) {
+ Value mapInfo = std::get<0>(map);
+ BlockArgument arg = std::get<1>(map);
+ Operation *op = mapInfo.getDefiningOp();
+ assert(op);
+ auto mapInfoOp = cast<omp::MapInfoOp>(op);
+ mapping.map(arg, mapInfoOp.getVarPtr());
+ }
+ rewriter.setInsertionPoint(targetOp);
+ SmallVector<Operation *> opsToReplace;
+ Value device = targetOp.getDevice();
+ if (!device) {
+ device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+ }
+ for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+ it != end; ++it) {
+ auto *op = &*it;
+ if (isRuntimeCall(op)) {
+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+ auto module = runtimeCall->getParentOfType<ModuleOp>();
+ auto callee =
+ cast<func::FuncOp>(module.lookupSymbol(runtimeCall.getCalleeAttr()));
+ std::string newCalleeName = (callee.getName() + "_omp").str();
+ mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+ func::FuncOp newCallee =
+ cast_or_null<func::FuncOp>(module.lookupSymbol(newCalleeName));
+ if (!newCallee) {
+ SmallVector<Type> argTypes(callee.getFunctionType().getInputs());
+ argTypes.push_back(getOmpDeviceType(rewriter.getContext()));
+ newCallee = moduleBuilder.create<func::FuncOp>(
+ callee->getLoc(), newCalleeName,
+ FunctionType::get(rewriter.getContext(), argTypes,
+ callee.getFunctionType().getResults()));
+ if (callee.getArgAttrs())
+ newCallee.setArgAttrsAttr(*callee.getArgAttrs());
+ if (callee.getResAttrs())
+ newCallee.setResAttrsAttr(*callee.getResAttrs());
+ newCallee.setSymVisibility(callee.getSymVisibility());
+ newCallee->setDiscardableAttrs(callee->getDiscardableAttrDictionary());
+ }
+ SmallVector<Value> operands = runtimeCall.getOperands();
+ operands.push_back(device);
+ auto tmpCall = rewriter.create<fir::CallOp>(
+ runtimeCall.getLoc(), runtimeCall.getResultTypes(),
+ SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr,
+ runtimeCall.getFastmathAttr());
+ Operation *newCall = rewriter.clone(*tmpCall, mapping);
+ mapping.map(&*it, newCall);
+ rewriter.eraseOp(tmpCall);
+ } else {
+ Operation *clonedOp = rewriter.clone(*op, mapping);
+ if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+ opsToReplace.push_back(clonedOp);
+ for (unsigned i = 0; i < op->getNumResults(); ++i) {
+ mapping.map(op->getResult(i), clonedOp->getResult(i));
+ }
+ }
+ }
+ for (Operation *op : opsToReplace) {
+ if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+ rewriter.setInsertionPoint(allocOp);
+ auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+ allocOp.getLoc(), rewriter.getI64Type(), device,
+ allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+ allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+ allocOp.getShape());
+ auto firConvertOp = rewriter.create<fir::ConvertOp>(
+ allocOp.getLoc(), allocOp.getResult().getType(),
+ ompAllocmemOp.getResult());
+ rewriter.replaceOp(allocOp, firConvertOp.getResult());
+ } else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+ rewriter.setInsertionPoint(freeOp);
+ auto firConvertOp = rewriter.create<fir::ConvertOp>(
+ freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+ rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+ firConvertOp.getResult());
+ rewriter.eraseOp(freeOp);
+ }
+ }
+ rewriter.eraseOp(targetOp);
+}
+
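+/// Recursively fissions \p targetOp around the nested teams/parallel op
+/// found by getNestedOpToIsolate: sequential code before (and after) it is
+/// moved to the host via moveToHost, while the op itself is kept in its own
+/// target region.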
+void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
+ auto tuple = getNestedOpToIsolate(targetOp);
+ if (!tuple) {
+ LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+ moveToHost(targetOp, rewriter);
+ return;
+ }
+
+ Operation *toIsolate = std::get<0>(*tuple);
+ bool splitBefore = !std::get<1>(*tuple);
+ bool splitAfter = !std::get<2>(*tuple);
+
+ if (splitBefore && splitAfter) {
+ auto res = isolateOp(toIsolate, splitAfter, rewriter);
+ moveToHost(res.preTargetOp, rewriter);
+ fissionTarget(res.postTargetOp, rewriter);
+ return;
+ }
+ if (splitBefore) {
+ auto res = isolateOp(toIsolate, splitAfter, rewriter);
+ moveToHost(res.preTargetOp, rewriter);
+ return;
+ }
+ if (splitAfter) {
+ auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter);
+ fissionTarget(res.postTargetOp, rewriter);
+ return;
+ }
+}
+
+class LowerWorkdistributePass
+ : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ auto moduleOp = getOperation();
+ bool changed = false;
+ moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+ changed |= FissionWorkdistribute(workdistribute);
+ });
+ moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+ changed |= WorkdistributeDoLower(workdistribute);
+ });
+ moduleOp->walk([&](mlir::omp::TeamsOp teams) {
+ changed |= TeamsWorkdistributeToSingleOp(teams);
+ });
+
+ if (changed) {
+ SmallVector<omp::TargetOp> targetOps;
+ moduleOp->walk(
+ [&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
+ IRRewriter rewriter(&context);
+ for (auto targetOp : targetOps) {
+ auto res = splitTargetData(targetOp, rewriter);
+ if (res)
+ fissionTarget(res->targetOp, rewriter);
+ }
+ }
+ }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index ca8e820608688..4b4e4f7ef19aa 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -294,8 +294,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
addNestedPassToAllTopLevelOperations<PassConstructor>(
pm, hlfir::createInlineHLFIRAssign);
pm.addPass(hlfir::createConvertHLFIRtoFIR());
- if (enableOpenMP)
+ if (enableOpenMP) {
pm.addPass(flangomp::createLowerWorkshare());
+ pm.addPass(flangomp::createLowerWorkdistribute());
+ }
}
/// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index c9fe53bf093a1..3a94ea8a476a1 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -69,6 +69,7 @@ func.func @_QQmain() {
// PASSES-NEXT: InlineHLFIRAssign
// PASSES-NEXT: ConvertHLFIRtoFIR
// PASSES-NEXT: LowerWorkshare
+// PASSES-NEXT: LowerWorkdistribute
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
new file mode 100644
index 0000000000000..00d10d6264ec9
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -0,0 +1,33 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @x({{.*}})
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+ omp.teams {
+ omp.workdistribute {
+ fir.do_loop %iv = %lb to %ub step %step unordered {
+ %zero = arith.constant 0 : index
+ fir.store %zero to %addr : !fir.ref<index>
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
new file mode 100644
index 0000000000000..19bdb9ce10fbd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -0,0 +1,112 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @x(
+// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK: %[[VAL_11:.*]] = fir.alloca index
+// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_14:.*]] = fir.alloca index
+// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_17:.*]] = fir.alloca index
+// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_23:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array<i32: 1, 0, 1>, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap<index>
+// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK: %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr<index>
+// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr<index>
+// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<index>
+// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to (%[[VAL_39]]) inclusive step (%[[VAL_40]]) {
+// CHECK: fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index
+// CHECK: fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap<index>
+// CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap<index>) -> ()
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+ %lb_ref = fir.alloca index {bindc_name = "lb"}
+ fir.store %lb to %lb_ref : !fir.ref<index>
+ %ub_ref = fir.alloca index {bindc_name = "ub"}
+ fir.store %ub to %ub_ref : !fir.ref<index>
+ %step_ref = fir.alloca index {bindc_name = "step"}
+ fir.store %step to %step_ref : !fir.ref<index>
+
+ %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+ %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+ %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+ %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+ omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+ %lb_val = fir.load %ARG0 : !fir.ref<index>
+ %ub_val = fir.load %ARG1 : !fir.ref<index>
+ %step_val = fir.load %ARG2 : !fir.ref<index>
+ %one = arith.constant 1 : index
+
+ %20 = arith.addi %ub_val, %ub_val : index
+ omp.teams {
+ omp.workdistribute {
+ %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+ fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+ fir.store %20 to %dev_mem : !fir.heap<index>
+ }
+ fir.store %lb_val to %dev_mem : !fir.heap<index>
+ fir.freemem %dev_mem : !fir.heap<index>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..c562b7009664d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,71 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_fission_workdistribute(
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK: }
+// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK: return
+// CHECK: }
+module {
+func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
+ return
+}
+func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
+ return
+}
+func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
+ %c0_idx = arith.constant 0 : index
+ %c1_idx = arith.constant 1 : index
+ %c9_idx = arith.constant 9 : index
+ %float_val = arith.constant 5.0 : f32
+ omp.teams {
+ omp.workdistribute {
+ fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
+ fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
+ %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+ %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
+ %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+ fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
+ }
+ fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
+ fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
+ fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
+ %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+ fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
+ }
+ %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
+ fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
new file mode 100644
index 0000000000000..d96068b26ca2f
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
@@ -0,0 +1,32 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_nested_derived_type_map_operand_and_block_addition(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK: %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+// CHECK: %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK: %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"}
+// CHECK: %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+// CHECK: omp.target map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+ %0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+ %2 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+ %4 = fir.coordinate_of %2, i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+ %5 = omp.map.info var_ptr(%4 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%n%i"}
+ %7 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+ %9 = fir.coordinate_of %7, r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+ %10 = omp.map.info var_ptr(%9 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%n%r"}
+ %11 = omp.map.info var_ptr(%0 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+ omp.target map_entries(%11 -> %arg1 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+ omp.terminator
+ }
+ return
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da39b19388627..01ee398641ee5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5369,6 +5369,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
+ if (targetOp.getHostEvalVars().empty())
+ numLoops = 0;
+
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
>From 4907705c6319ae135185af58836ef9c6a0705e67 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 25 Jul 2025 22:20:12 +0530
Subject: [PATCH 2/2] Fix hoisting declare ops out of omp.target
---
.../Optimizer/OpenMP/LowerWorkdistribute.cpp | 34 +++++++++++++++++--
1 file changed, 31 insertions(+), 3 deletions(-)
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 0885efc716db4..3f78727450f31 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -434,7 +434,7 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
newTargetOp.getRegion().begin());
- rewriter.replaceOp(targetOp, newTargetOp);
+ rewriter.replaceOp(targetOp, targetDataOp);
return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
}
@@ -807,11 +807,30 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
rewriter.eraseOp(tmpCall);
} else {
Operation *clonedOp = rewriter.clone(*op, mapping);
- if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
- opsToReplace.push_back(clonedOp);
for (unsigned i = 0; i < op->getNumResults(); ++i) {
mapping.map(op->getResult(i), clonedOp->getResult(i));
}
+      // fir.declare changes its input type when hoisted out of omp.target to
+      // omp.target_data. Introduce a load if the original declareOp input is
+      // not of reference type but the cloned declareOp input is.
+ if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+ auto originalDeclareOp = cast<fir::DeclareOp>(op);
+ Type originalInType = originalDeclareOp.getMemref().getType();
+ Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+ fir::ReferenceType originalRefType =
+ dyn_cast<fir::ReferenceType>(originalInType);
+ fir::ReferenceType clonedRefType =
+ dyn_cast<fir::ReferenceType>(clonedInType);
+ if (!originalRefType && clonedRefType) {
+ Type clonedEleTy = clonedRefType.getElementType();
+ if (clonedEleTy == originalDeclareOp.getType()) {
+ opsToReplace.push_back(clonedOp);
+ }
+ }
+ }
+ if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+ opsToReplace.push_back(clonedOp);
}
}
for (Operation *op : opsToReplace) {
@@ -833,6 +852,15 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
firConvertOp.getResult());
rewriter.eraseOp(freeOp);
+ } else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+ Type clonedInType = clonedDeclareOp.getMemref().getType();
+ fir::ReferenceType clonedRefType =
+ dyn_cast<fir::ReferenceType>(clonedInType);
+ Type clonedEleTy = clonedRefType.getElementType();
+ rewriter.setInsertionPoint(op);
+ Value loadedValue = rewriter.create<fir::LoadOp>(
+ clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+ clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
}
}
rewriter.eraseOp(targetOp);