[Mlir-commits] [flang] [mlir] [Flang][OpenMP] Implement workdistribute construct lowering (PR #140523)

Krzysztof Parzyszek llvmlistbot at llvm.org
Fri Oct 17 09:23:07 PDT 2025


================
@@ -0,0 +1,1859 @@
+//===- LowerWorkdistribute.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimization of omp.workdistribute.
+//
+// Fortran array statements are lowered to FIR as unordered fir.do_loop ops.
+// The lower-workdistribute pass identifies unordered fir.do_loop ops nested
+// in target{teams{workdistribute{fir.do_loop unordered}}} and lowers them to
+// target{teams{parallel{distribute{wsloop{loop_nest}}}}}. All other ops are
+// hoisted out of the target region. Heap allocations inside the target
+// region are replaced with omp.target_allocmem and the matching
+// deallocations with omp.target_freemem, both issued from the host. Calls
+// to the runtime function "Assign" are replaced with omp_target_memcpy.
+//
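+// For example (illustrative IR, simplified; terminators omitted), the pass
+// rewrites
+//
+//   omp.target {
+//     omp.teams {
+//       omp.workdistribute {
+//         fir.do_loop %i = %lb to %ub step %st unordered { ... }
+//       }
+//     }
+//   }
+//
+// into
+//
+//   omp.target {
+//     omp.teams {
+//       omp.parallel {
+//         omp.distribute {
+//           omp.wsloop {
+//             omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%st) {
+//               ...
+//             }
+//           }
+//         }
+//       }
+//     }
+//   }
+//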
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+/// This string identifies the Fortran runtime function FortranAAssign.
+static constexpr llvm::StringRef FortranAssignStr = "_FortranAAssign";
+
+/// Returns true if the given operation is a call to a Fortran
+/// runtime function.
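+/// For example (illustrative FIR, not taken from this patch), the following
+/// is a runtime call because the @_FortranAAssign declaration carries the
+/// fir.runtime attribute:
+///   fir.call @_FortranAAssign(%box, %src, %file, %line)
+///       : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) -> ()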
+static bool isRuntimeCall(Operation *op) {
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func && func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+  }
+  return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+/// Parallelize here refers to dividing into units of work.
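+/// For example (illustrative), this op is parallelized:
+///   fir.do_loop %i = %c1 to %n step %c1 unordered { ... }
+/// whereas an ordered fir.do_loop, or any op whose results have uses,
+/// is not.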
+static bool shouldParallelize(Operation *op) {
+  // True if the op is a runtime call to Assign
+  if (isRuntimeCall(op)) {
+    fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+    auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+    if (funcName == FortranAssignStr) {
+      return true;
+    }
+  }
+  // We can only parallelize ops whose results have no uses: a
+  // parallelized op cannot define a single value for operations
+  // that follow it.
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  // We cannot parallelize anything else.
+  return false;
+}
+
+/// Generic utility that returns the operation of type T "perfectly nested"
+/// in \p op, i.e. the only operation (besides the terminator) in op's
+/// single-block region, or nullptr if there is no such operation.
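+/// For example (illustrative), getPerfectlyNested<omp::WorkdistributeOp>
+/// returns the workdistribute op in:
+///   omp.teams {
+///     omp.workdistribute { ... }
+///     omp.terminator
+///   }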
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// Verifies that omp.target { teams { workdistribute { ... } } } is well
+/// formed and fails for runtime calls whose lowering is not implemented yet.
+static LogicalResult
+verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return failure();
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return failure();
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return failure();
+  }
+
+  bool foundWorkdistribute = false;
+  for (auto &op : teams.getOps()) {
+    if (isa<omp::WorkdistributeOp>(op)) {
+      if (foundWorkdistribute) {
+        emitError(loc, "teams has multiple workdistribute ops.\n");
+        return failure();
+      }
+      foundWorkdistribute = true;
+      continue;
+    }
+    // Identify any omp dialect ops present before/after workdistribute.
+    if (op.getDialect() && isa<omp::OpenMPDialect>(op.getDialect()) &&
+        !isa<omp::TerminatorOp>(op)) {
+      emitError(loc, "teams has omp ops other than workdistribute. Lowering "
+                     "not implemented yet.\n");
+      return failure();
+    }
+  }
+
+  omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+  // Nothing else to verify if the parent is not omp.target.
+  if (!targetOp)
+    return success();
+
+  for (auto &op : workdistribute.getOps()) {
+    if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+      if (isRuntimeCall(&op)) {
+        auto funcName = (*callOp.getCallee()).getRootReference().getValue();
+        // _FortranAAssign is handled. Other runtime calls are not supported
+        // in omp.workdistribute yet.
+        if (funcName == FortranAssignStr)
+          continue;
+        emitError(loc, "runtime call " + funcName +
+                           " lowering not supported for workdistribute yet");
+        return failure();
+      }
+    }
+  }
+  return success();
+}
+
+/// Finds the parallelizable ops within a teams{workdistribute} region and
+/// moves each of them into its own teams{workdistribute} region.
+///
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+static FailureOr<bool>
+fissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute) {
+      break;
+    }
+    if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelize ops before first workdistribute\n");
+      return failure();
+    } else {
+      rewriter.setInsertionPoint(teams);
+      rewriter.clone(op, irMapping);
+      teamsHoisted.push_back(&op);
+      changed = true;
+    }
+  }
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      } else {
+        rewriter.clone(op, mapping);
+        hoisted.push_back(&op);
+        changed = true;
+      }
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// Generate omp.parallel operation with an empty region.
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallelOp.setComposite(composite);
+  rewriter.createBlock(&parallelOp.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+  return;
+}
+
+/// Generate omp.distribute operation with an empty region.
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands distributeClauseOps;
+  auto distributeOp =
+      rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+  distributeOp.setComposite(composite);
+  auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+  rewriter.setInsertionPointToStart(distributeBlock);
+  return;
+}
+
+/// Generate loop nest clause operands from fir.do_loop operation.
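+/// For example (illustrative), for
+///   fir.do_loop %i = %lb to %ub step %st unordered
+/// this collects %lb, %ub and %st as the loop bounds and step, and marks the
+/// bounds inclusive because fir.do_loop iterates up to and including %ub.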
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+/// Generate omp.wsloop operation with an empty region and
+/// clone the body of fir.do_loop operation inside the loop nest region.
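+/// The generated IR has the following shape (illustrative):
+///   omp.wsloop {
+///     omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%st) {
+///       ... // cloned fir.do_loop body
+///       omp.yield
+///     }
+///   }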
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body into the loop nest region.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase fir.result op of do loop and create yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+  return;
----------------
kparzysz wrote:

Unnecessary

https://github.com/llvm/llvm-project/pull/140523

