[flang-commits] [flang] [mlir] [Flang][OpenMP] Implement workdistribute construct lowering (PR #140523)

via flang-commits flang-commits at lists.llvm.org
Wed Oct 8 05:04:40 PDT 2025


================
@@ -0,0 +1,1817 @@
+//===- LowerWorkdistribute.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+// Fortran array statements are lowered to fir as fir.do_loop unordered.
+// lower-workdistribute pass works mainly on identifying fir.do_loop unordered
+// that is nested in target{teams{workdistribute{fir.do_loop unordered}}} and
+// lowers it to target{teams{parallel{distribute{wsloop{loop_nest}}}}}.
+// It hoists all the other ops outside target region.
+// Replaces heap allocation on target with omp.target_allocmem and
+// deallocation with omp.target_freemem from host. Also replaces
+// runtime function "Assign" with omp_target_memcpy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+// Returns true when `op` is a fir.call whose callee resolves, in the
+// enclosing module, to a symbol carrying the FIR runtime attribute.
+// Calls without a callee symbol, ops outside of a module and callees that
+// cannot be resolved are all reported as non-runtime calls.
+static bool isRuntimeCall(Operation *op) {
+  auto callOp = dyn_cast<fir::CallOp>(op);
+  if (!callOp)
+    return false;
+  auto callee = callOp.getCallee();
+  if (!callee)
+    return false;
+  auto module = op->getParentOfType<ModuleOp>();
+  if (!module)
+    return false;
+  // The symbol lookup yields null for an unresolved reference; check it
+  // before querying the attribute instead of dereferencing a null pointer.
+  Operation *func = module.lookupSymbol(*callee);
+  return func && func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName());
+}
+
+// Decides whether an operation nested in an omp.workdistribute region is to
+// be parallelized. This predicate is the single source of truth for that
+// decision.
+static bool shouldParallelize(Operation *op) {
+  // A runtime call to Assign is always parallelized.
+  if (isRuntimeCall(op)) {
+    auto assignCall = cast<fir::CallOp>(op);
+    auto calleeName = (*assignCall.getCallee()).getRootReference().getValue();
+    if (calleeName == "_FortranAAssign")
+      return true;
+  }
+  // An operation producing a value that other operations consume is not
+  // parallelized.
+  for (OpResult result : op->getResults())
+    if (!result.use_empty())
+      return false;
+  // Unordered loops - these come from array syntax - are parallelized.
+  // Nothing else is.
+  auto loop = dyn_cast<fir::DoLoopOp>(op);
+  if (!loop)
+    return false;
+  auto unordered = loop.getUnordered();
+  return unordered && *unordered;
+}
+
+// Returns the operation of type T that is "perfectly nested" in `op`, or a
+// null T otherwise. The nesting is perfect when `op` has exactly one region
+// with exactly one block, the first operation of that block is a T, and the
+// operation following it is the block terminator.
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  Region &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  Block &block = region.front();
+  // An empty block has neither a first operation nor a terminator to
+  // inspect.
+  if (block.empty())
+    return nullptr;
+  Operation *firstOp = &block.front();
+  auto nested = dyn_cast<T>(firstOp);
+  if (nested && firstOp->getNextNode() == block.getTerminator())
+    return nested;
+  return nullptr;
+}
+
+// VerifyTargetTeamsWorkdistribute checks that
+// omp.target { teams { workdistribute { ... } } } is well formed:
+// the workdistribute must be nested directly in a teams op, and both the
+// workdistribute and the teams regions must consist of a single block.
+// When the teams op is nested in omp.target, every fir.call directly inside
+// the workdistribute must be the runtime call _FortranAAssign; any other
+// call has no lowering implemented yet.
+// Emits an error on the workdistribute op and returns false on failure.
+static bool
+VerifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    workdistribute.emitError() << "workdistribute not nested in teams\n";
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    workdistribute.emitError() << "workdistribute with multiple blocks\n";
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    workdistribute.emitError() << "teams with multiple blocks\n";
+    return false;
+  }
+  // The call restrictions below only apply inside omp.target.
+  omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+  if (!targetOp)
+    return true;
+
+  for (auto &op : workdistribute.getOps()) {
+    auto callOp = dyn_cast<fir::CallOp>(op);
+    if (!callOp)
+      continue;
+    if (!isRuntimeCall(&op)) {
+      workdistribute.emitError() << "Non-runtime fir.call lowering not "
+                                    "supported in workdistribute yet.";
+      return false;
+    }
+    auto funcName = (*callOp.getCallee()).getRootReference().getValue();
+    // _FortranAAssign is handled. Other runtime calls are not supported
+    // in omp.workdistribute yet.
+    if (funcName == "_FortranAAssign")
+      continue;
+    workdistribute.emitError()
+        << "Runtime call " << funcName
+        << " lowering not supported for workdistribute yet.";
+    return false;
+  }
+  return true;
+}
+
+// FissionWorkdistribute method finds the parallelizable ops
+// within teams {workdistribute} region and moves them to their
+// own teams{workdistribute} region.
+//
+// If B() and D() are parallelizable,
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//     C()
+//     D()
+//     E()
+//   }
+// }
+//
+// becomes
+//
+// A()
+// omp.teams {
+//   omp.workdistribute {
+//     B()
+//   }
+// }
+// C()
+// omp.teams {
+//   omp.workdistribute {
+//     D()
+//   }
+// }
+// E()
+//
+// Returns true when the IR was modified. Returns false, after emitting an
+// error, when the workdistribute is not nested in teams or when a
+// parallelizable op precedes the workdistribute inside teams.
+
+static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return false;
+  }
+
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute)
+      break;
+    if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelize ops before first workdistribute\n");
+      return false;
+    }
+    rewriter.setInsertionPoint(teams);
+    rewriter.clone(op, irMapping);
+    teamsHoisted.push_back(&op);
+    changed = true;
+  }
+  // Erase the originals in reverse order so that users go before their
+  // definitions, after redirecting remaining uses to the hoisted clones.
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    // Clone the leading non-parallelizable ops in front of teams and stop at
+    // the first parallelizable op, if there is one.
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator)
+        break;
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      }
+      rewriter.clone(op, mapping);
+      hoisted.push_back(&op);
+      changed = true;
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    // The parallelizable op is the only op left in the workdistribute:
+    // it stays where it is and the fission is complete.
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      // Build a fresh teams{workdistribute{}} nest in front of the original
+      // teams op and move the parallelizable op into it.
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+// Creates an omp.parallel op whose region holds only a terminator and leaves
+// the builder positioned right before that terminator.
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallel = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallel.setComposite(composite);
+  rewriter.createBlock(&parallel.getRegion());
+  auto terminator = rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPoint(terminator);
+}
+
+// Creates an omp.distribute op with default clause operands and an empty
+// block, and leaves the builder positioned at the start of that block.
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands clauseOps;
+  auto distribute = rewriter.create<mlir::omp::DistributeOp>(loc, clauseOps);
+  distribute.setComposite(composite);
+  Block *body = rewriter.createBlock(&distribute.getRegion());
+  rewriter.setInsertionPointToStart(body);
+}
+
+// Fills the loop nest clause operands from a fir.do_loop: its lower bound,
+// upper bound and step are appended, and the bounds are marked inclusive.
+// The operands must not hold any bounds yet.
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  mlir::omp::LoopNestOperands &ops = loopNestClauseOps;
+  assert(ops.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  ops.loopLowerBounds.push_back(loop.getLowerBound());
+  ops.loopUpperBounds.push_back(loop.getUpperBound());
+  ops.loopSteps.push_back(loop.getStep());
+  ops.loopInclusive = rewriter.getUnitAttr();
+}
+
+// Creates an omp.wsloop op wrapping an omp.loop_nest op built from
+// `clauseOps`, and clones the body of the fir.do_loop into the loop nest
+// region. A fir.result terminator in the cloned body becomes omp.yield.
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+  Location loopLoc = doLoop.getLoc();
+  auto wsloop = rewriter.create<mlir::omp::WsloopOp>(loopLoc);
+  wsloop.setComposite(composite);
+  rewriter.createBlock(&wsloop.getRegion());
+
+  auto loopNest = rewriter.create<mlir::omp::LoopNestOp>(loopLoc, clauseOps);
+
+  // Copy the body of the do loop to the front of the loop nest region.
+  Region &nestRegion = loopNest.getRegion();
+  rewriter.cloneRegionBefore(doLoop.getRegion(), nestRegion,
+                             nestRegion.begin());
+  mlir::Operation *bodyTerminator = nestRegion.back().getTerminator();
+
+  // Only a fir.result terminator is swapped for a yield.
+  if (!isa<fir::ResultOp>(bodyTerminator))
+    return;
+  rewriter.setInsertionPoint(bodyTerminator);
+  rewriter.create<mlir::omp::YieldOp>(loopLoc);
+  bodyTerminator->erase();
+}
+
+// WorkdistributeDoLower looks for a fir.do_loop unordered that is perfectly
+// nested in teams {workdistribute{fir.do_loop unordered}} and lowers it to
+// teams {parallel { distribute {wsloop {loop_nest}}}}.
+//
+// Given
+//
+// omp.teams {
+//   omp.workdistribute {
+//     fir.do_loop unordered {
+//       ...
+//     }
+//   }
+// }
+//
+// the result is
+//
+// omp.teams {
+//    omp.parallel {
+//      omp.distribute {
+//        omp.wsloop {
+//          omp.loop_nest
+//            ...
+//          }
+//        }
+//      }
+//   }
+// }
+//
+// An enclosing omp.target, if any, is recorded in `targetOpsToProcess`.
+// Returns true when the workdistribute op was lowered and erased.
+
+static bool
+WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
+                      SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  auto wdLoc = workdistribute->getLoc();
+  if (!doLoop || !shouldParallelize(doLoop))
+    return false;
+  assert(doLoop.getReduceOperands().empty());
+
+  // Remember the enclosing target op so that it is processed later.
+  auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (teamsOp) {
+    if (auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp()))
+      targetOpsToProcess.insert(targetOp);
+  }
+
+  // Emit the parallel, distribute, wsloop and loop_nest ops, each nested in
+  // the previous one.
+  const bool composite = true;
+  genParallelOp(wdLoc, rewriter, composite);
+  genDistributeOp(wdLoc, rewriter, composite);
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+  genWsLoopOp(rewriter, doLoop, loopNestClauseOps, composite);
+  workdistribute.erase();
+  return true;
+}
+
+// Returns true for a fir.ref whose element type is a fir.box that in turn
+// encloses a sequence (array) type.
+static bool isEnclosedTypeRefToBoxArray(Type type) {
+  auto refType = dyn_cast<fir::ReferenceType>(type);
+  if (!refType)
+    return false;
+  // The referenced type has to be a box ...
+  auto boxType = dyn_cast<fir::BoxType>(refType.getEleTy());
+  if (!boxType)
+    return false;
+  // ... and the boxed type has to be an array.
+  return isa<fir::SequenceType>(boxType.getEleTy());
+}
+
+// Returns true for a fir.box whose boxed type is not a sequence (array)
+// type, i.e. a box around a scalar.
+static bool isEnclosedTypeBoxScalar(Type type) {
+  auto boxType = dyn_cast<fir::BoxType>(type);
+  if (!boxType)
+    return false;
+  return !isa<fir::SequenceType>(boxType.getEleTy());
+}
+
+// Returns true when the FortranAAssign call assigns a scalar source
+// (operand 1) to an array destination (operand 0). Both operands are
+// expected to be produced by fir.convert; otherwise an error is emitted and
+// false is returned.
+static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) {
+  if (callOp.getNumOperands() < 2)
+    return false;
+  auto destConvert = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>();
+  auto srcConvert = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>();
+  if (!srcConvert || !destConvert) {
+    emitError(callOp->getLoc(),
+              "Unimplemented: FortranAssign to OpenMP lowering\n");
+    return false;
+  }
+  // Classify the operands by their types before the conversion.
+  Type srcType = srcConvert.getValue().getType();
+  Type destType = destConvert.getValue().getType();
+  if (!isEnclosedTypeBoxScalar(srcType))
+    return false;
+  return isEnclosedTypeRefToBoxArray(destType);
+}
+
+// Emits the ops that turn a zero-based flat index into one-based
+// multi-dimensional indices for an array box; the last dimension varies
+// fastest.
+// Example: 2D array with shape (2,4)
+//         Col 1  Col 2  Col 3  Col 4
+// Row 1:  (1,1)  (1,2)  (1,3)  (1,4)
+// Row 2:  (2,1)  (2,2)  (2,3)  (2,4)
+//
+// extents: (2,4)
+//
+// flatIdx:  0     1     2     3     4     5     6     7
+// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4)
+static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder,
+                                                Location loc, Value flatIdx,
+                                                Value arrayBox) {
+  // The rank comes from the sequence type inside the box.
+  auto seqType = cast<fir::SequenceType>(
+      cast<fir::BoxType>(arrayBox.getType()).getEleTy());
+  int rank = seqType.getDimension();
+
+  // Read the extent of every dimension from the box.
+  SmallVector<Value> extents;
+  for (int dim = 0; dim < rank; ++dim) {
+    auto dimIdx = arith::ConstantIndexOp::create(builder, loc, dim);
+    auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx);
+    extents.push_back(boxDims.getResult(1));
+  }
+
+  SmallVector<Value> indices(rank);
+  Value remaining = flatIdx;
+  auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+  // Peel off one dimension at a time, starting from the last one: the
+  // remainder is the zero-based index in that dimension and the quotient
+  // carries over to the next dimension.
+  for (int dim = rank - 1; dim >= 0; --dim) {
+    Value zeroBased =
+        builder.create<arith::RemSIOp>(loc, remaining, extents[dim]);
+    indices[dim] = builder.create<arith::AddIOp>(loc, zeroBased, one);
+    if (dim == 0)
+      continue;
+    remaining = builder.create<arith::DivSIOp>(loc, remaining, extents[dim]);
+  }
+
+  return indices;
+}
+
+// Emits the ops computing the number of elements of the array box as the
+// product of its extents: extent(1) * extent(2) * ... * extent(n).
+// Returns a null Value when the rank is zero.
+static Value CalculateTotalElements(OpBuilder &builder, Location loc,
+                                    Value arrayBox) {
+  auto seqType = cast<fir::SequenceType>(
+      cast<fir::BoxType>(arrayBox.getType()).getEleTy());
+  int rank = seqType.getDimension();
+
+  Value product = nullptr;
+  for (int dim = 0; dim < rank; ++dim) {
+    auto dimIdx = arith::ConstantIndexOp::create(builder, loc, dim);
+    auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx);
+    Value extent = boxDims.getResult(1);
+    // The first extent seeds the product, the others are multiplied in.
+    if (dim == 0)
+      product = extent;
+    else
+      product = builder.create<arith::MulIOp>(loc, product, extent);
+  }
+  return product;
+}
+
+// Replaces a scalar-to-array FortranAAssign runtime call with an unordered
+// fir.do_loop that stores the scalar into every element of the array.
+//
+// The destination box, the scalar value and the element count are
+// materialized right before `teamsOp`; the loop itself is created at the
+// start of the `workdistribute` block. The loop runs over a flat index
+// 0 .. totalElems-1 which is expanded to per-dimension indices in the body.
+// The runtime call itself is left in place.
+// Both call operands are expected to be produced by fir.convert, and the
+// converted destination by fir.alloca; if the latter does not hold an error
+// is emitted and nothing is generated.
+static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
+                                       omp::TeamsOp teamsOp,
+                                       omp::WorkdistributeOp workdistribute,
+                                       fir::CallOp callOp) {
+  auto destConvert = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>();
+  auto srcConvert = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>();
+
+  Value destBox = destConvert.getValue();
+  Value srcBox = srcConvert.getValue();
+
+  // get defining alloca op of destBox
+  auto destAlloca = destBox.getDefiningOp<fir::AllocaOp>();
+
+  if (!destAlloca) {
+    emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n");
+    return;
+  }
+
+  // Use the value of the first store found among the users of the alloca
+  // as the destination box.
+  for (auto user : destAlloca->getUsers()) {
+    if (auto storeOp = dyn_cast<fir::StoreOp>(user)) {
+      destBox = storeOp.getValue();
+      break;
+    }
+  }
+
+  builder.setInsertionPoint(teamsOp);
+  // Load destination array box (if it's a reference)
+  Value arrayBox = destBox;
+  if (isa<fir::ReferenceType>(destBox.getType()))
+    arrayBox = builder.create<fir::LoadOp>(loc, destBox);
+
+  auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox);
+  Value scalar = builder.create<fir::LoadOp>(loc, scalarValue);
+
+  // Calculate total number of elements (flattened)
+  auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value totalElems = CalculateTotalElements(builder, loc, arrayBox);
+  // The upper bound of fir.do_loop is inclusive, so the last flat index is
+  // totalElems - 1. Using totalElems itself would run one iteration too
+  // many, and for a zero-size array that single iteration would take a
+  // remainder by a zero extent.
+  Value lastIdx = builder.create<arith::SubIOp>(loc, totalElems, c1);
+
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  builder.setInsertionPointToStart(workdistributeBlock);
+  // Create single unordered loop for flattened array
+  auto doLoop = fir::DoLoopOp::create(builder, loc, c0, lastIdx, c1, true);
+  Block *loopBlock = &doLoop.getRegion().front();
+  builder.setInsertionPointToStart(doLoop.getBody());
+
+  auto flatIdx = loopBlock->getArgument(0);
+  SmallVector<Value> indices =
+      convertFlatToMultiDim(builder, loc, flatIdx, arrayBox);
+  // Address the element with fir.array_coor using the per-dimension indices.
+  auto elemPtr = fir::ArrayCoorOp::create(
+      builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox,
+      nullptr, nullptr, ValueRange{indices}, ValueRange{});
+
+  builder.create<fir::StoreOp>(loc, scalar, elemPtr);
+}
+
+// WorkdistributeRuntimeCallLower scans the runtime calls nested in
+// teams {workdistribute{}}. A FortranAAssign call with a scalar source and
+// an array destination that sits inside omp.target is replaced by an
+// unordered do loop and erased; its target op is recorded in
+// `targetOpsToProcess`. Other runtime calls are not handled currently.
+// Returns true when at least one call was replaced.
+static bool
+WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
+                               SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return false;
+  }
+  Block &wdBlock = workdistribute.getRegion().front();
+  Operation *wdTerminator = wdBlock.getTerminator();
+  // The target op enclosing teams, or null when there is none.
+  omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+
+  SmallVector<Operation *> replacedCalls;
+  for (auto &op : workdistribute.getOps()) {
+    if (&op == wdTerminator)
+      break;
+    if (!isRuntimeCall(&op))
+      continue;
+    rewriter.setInsertionPoint(&op);
+    auto runtimeCall = cast<fir::CallOp>(op);
+    auto calleeName = (*runtimeCall.getCallee()).getRootReference().getValue();
+    if (calleeName != "_FortranAAssign")
+      continue;
+    if (!isFortranAssignSrcScalarAndDestArray(runtimeCall) || !targetOp)
+      continue;
+    // Remember the target op so that it is processed later.
+    targetOpsToProcess.insert(targetOp);
+    replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute,
+                               runtimeCall);
+    replacedCalls.push_back(&op);
+  }
+  // The replaced runtime calls are erased once the walk is over.
+  for (Operation *call : replacedCalls)
+    call->erase();
+  return !replacedCalls.empty();
+}
+
+// TeamsWorkdistributeToSingleOp hoists every op of a workdistribute that is
+// perfectly nested in teams in front of the teams op.
+//
+// With A() and B() inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//   }
+// }
+//
+// the result is
+//
+// A()
+// B()
+//
+// The teams op is erased when only its terminator is left after hoisting.
+// Returns false when teams does not perfectly nest a workdistribute.
+static bool
+TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
+                              SetVector<omp::TargetOp> &targetOpsToProcess) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  // The block that holds teamsOp receives the hoisted operations.
+  Block *destBlock = teamsOp->getBlock();
+  Block &wdBlock = workdistributeOp.getRegion().front();
+  // Remember the enclosing target op, if any, as soon as one of the ops is
+  // parallelizable.
+  bool hasParallelizable =
+      llvm::any_of(wdBlock.getOperations(),
+                   [](Operation &op) { return shouldParallelize(&op); });
+  if (hasParallelizable) {
+    if (auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp()))
+      targetOpsToProcess.insert(targetOp);
+  }
+  // Splice everything but the terminator right in front of teamsOp.
+  destBlock->getOperations().splice(Block::iterator(teamsOp),
+                                    wdBlock.getOperations(), wdBlock.begin(),
+                                    wdBlock.getTerminator()->getIterator());
+  // The workdistribute op now only holds its terminator.
+  workdistributeOp.erase();
+  // Drop the teams op too when nothing but its terminator remains.
+  Block &teamsBlock = teamsOp.getRegion().front();
+  bool onlyTerminatorLeft = teamsBlock.getOperations().size() == 1 &&
+                            teamsBlock.getTerminator() != nullptr;
+  if (onlyTerminatorLeft)
+    teamsOp.erase();
+  return true;
+}
+
+// When several workdistribute constructs are nested in one target region the
+// region has to be split. To keep the data semantics of the original region
+// and to avoid moving data at every subkernel, the target op is turned into
+// a target_data{target} nest in which only the outer op moves the data.
+// Returns the new inner target op, or std::nullopt when the target op has
+// no map operands.
+std::optional<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
+                                             RewriterBase &rewriter) {
+  using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << DEBUG_TYPE << " target region has no data maps\n");
+    return std::nullopt;
+  }
+  // Gather the MapInfoOp behind every map operand.
+  SmallVector<omp::MapInfoOp> originalMaps;
+  for (Value mapVar : targetOp.getMapVars())
+    originalMaps.push_back(cast<omp::MapInfoOp>(mapVar.getDefiningOp()));
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+  // Clone each map for the inner target op. A by-copy map keeps its map
+  // type. A by-ref map gets map type "none" on the clone while the original
+  // map is handed to the outer target_data op. The capture kind is kept in
+  // both cases.
+  for (omp::MapInfoOp original : originalMaps) {
+    auto captureKind = original.getMapCaptureType();
+    MapFlags innerFlags;
+    if (captureKind == mlir::omp::VariableCaptureKind::ByCopy) {
+      innerFlags = static_cast<MapFlags>(original.getMapType());
+    } else if (captureKind == mlir::omp::VariableCaptureKind::ByRef) {
+      innerFlags = MapFlags::OMP_MAP_NONE;
+      outerMapInfos.push_back(original);
+    } else {
+      llvm_unreachable("Unhandled case");
+    }
+    auto innerMap = cast<omp::MapInfoOp>(rewriter.clone(*original));
+    innerMap.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<std::underlying_type_t<MapFlags>>(innerFlags)));
+    innerMap.setMapCaptureType(captureKind);
+    innerMapInfos.push_back(innerMap.getResult());
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  // Outer target_data op: reuses the device, if, has_device_addr and
+  // is_device_ptr operands of the original target op.
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, targetOp.getDevice(), targetOp.getIfExpr(), outerMapInfos,
+      targetOp.getHasDeviceAddrVars(), targetOp.getIsDevicePtrVars());
+  Block *targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+  // Inner target op: same operands as the original one, except for the maps.
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  // The body of the original target op moves into the new one.
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+  rewriter.replaceOp(targetOp, targetDataOp);
+  return newTargetOp;
+}
+
+// getNestedOpToIsolate finds the first omp.teams op in the body of an
+// omp::TargetOp that should be "isolated". The result is a tuple of the op,
+// whether it is the first op of the target block, and whether it is the last
+// op before the terminator of the target block. std::nullopt is returned
+// when the region is empty, when the first op is also the last one, or when
+// there is no teams op.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  Block &targetBlock = targetOp.getRegion().front();
+  for (Operation &candidate : targetBlock) {
+    bool isFirst = &candidate == &targetBlock.front();
+    bool isLast = candidate.getNextNode() == targetBlock.getTerminator();
+    // A single op in front of the terminator leaves nothing to isolate.
+    if (isFirst && isLast)
+      return std::nullopt;
+
+    if (!isa<omp::TeamsOp>(&candidate))
+      continue;
+    return std::make_tuple(&candidate, isFirst, isLast);
+  }
+  return std::nullopt;
+}
+
+// Pair of mapinfo ops describing one temporary: its "from" map and its
+// "to" map.
+struct TempOmpVar {
+  omp::MapInfoOp from;
+  omp::MapInfoOp to;
+};
+
+// isPtr tells whether the type is a fir reference type or an LLVM pointer
+// type.
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType, LLVM::LLVMPointerType>(ty);
+}
+
+// getPtrTypeForOmp maps a pointer-like type (see isPtr) to the LLVM pointer
+// type and wraps any other type in a fir reference type.
+static Type getPtrTypeForOmp(Type ty) {
+  if (!isPtr(ty))
+    return fir::ReferenceType::get(ty);
+  return LLVM::LLVMPointerType::get(ty.getContext());
+}
+
+// allocateTempOmpVar creates a stack temporary for a value of type `ty` and
+// two by-ref mapinfo ops for it, one with the FROM and one with the TO map
+// flag. A pointer-like `ty` (see isPtr) gets an llvm.alloca of one LLVM
+// pointer, any other type a fir.alloca of `ty`.
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  using MapFlags = llvm::omp::OpenMPOffloadMappingFlags;
+  MLIRContext &ctx = *ty.getContext();
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  Value alloc;
+  // Type recorded as the variable type on the mapinfo ops.
+  Type mappedType;
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    // The slot holds one LLVM pointer; the mapinfo ops record the i32 type.
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+    mappedType = intTy;
+  } else {
+    mappedType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, mappedType);
+  }
+  // Builds a by-ref mapinfo op for the temporary with the given flags.
+  auto makeMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(mappedType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  // The "from" map is created before the "to" map.
+  uint64_t fromFlags =
+      static_cast<std::underlying_type_t<MapFlags>>(MapFlags::OMP_MAP_FROM);
+  uint64_t toFlags =
+      static_cast<std::underlying_type_t<MapFlags>>(MapFlags::OMP_MAP_TO);
+  auto fromMap = makeMapInfo(fromFlags, "__flang_workdistribute_from");
+  auto toMap = makeMapInfo(toFlags, "__flang_workdistribute_to");
+  return TempOmpVar{fromMap, toMap};
+}
+
+// usedOutsideSplit tells whether `v` has a user that does not come before
+// `split` in the block of the target op enclosing `split`. Users nested in
+// deeper regions are judged by their ancestor in that block. A null `split`
+// yields false.
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+  Block *targetBlock = &targetOp.getRegion().front();
+  return llvm::any_of(v.getUsers(), [&](Operation *user) {
+    // Climb to the ancestor that sits directly in the target block.
+    Operation *ancestor = user;
+    while (ancestor->getBlock() != targetBlock)
+      ancestor = ancestor->getParentOp();
+    return !ancestor->isBeforeInBlock(split);
+  });
+}
+
+// isRecomputableAfterFission checks if an operation can be recomputed
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  // If the op has side effects, it cannot be recomputed.
+  // We consider fir.declare as having no side effects.
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
----------------
skc7 wrote:

Updated.

https://github.com/llvm/llvm-project/pull/140523


More information about the flang-commits mailing list