[flang-commits] [flang] [mlir] [Flang][mlir] - Translation of delayed privatization for deferred target-tasks (PR #155348)
Pranav Bhandarkar via flang-commits
flang-commits at lists.llvm.org
Wed Aug 27 09:11:10 PDT 2025
================
@@ -0,0 +1,423 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare for OpenMP Offload
+// Privatization ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/OpenMPOffloadPrivatizationPrepare.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <cstdint>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+// Debug type name for this file; PDBGS() streams a "[<DEBUG_TYPE>]: " prefix
+// to llvm::dbgs().
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+#define PDBGS() (llvm::dbgs() << "[" << DEBUG_TYPE << "]: ")
+
+// Instantiate the TableGen-generated definitions for this pass (including
+// LLVM::impl::PrepareForOMPOffloadPrivatizationPassBase) inside mlir::LLVM.
+namespace mlir::LLVM {
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+} // namespace mlir::LLVM
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// OMPTargetPrepareDelayedPrivatizationPattern
+//===----------------------------------------------------------------------===//
+
+class OMPTargetPrepareDelayedPrivatizationPattern
+    : public OpRewritePattern<omp::TargetOp> {
+public:
+  using OpRewritePattern<omp::TargetOp>::OpRewritePattern;
+
+  // Match omp::TargetOp that have the following characteristics.
+  // 1. have private vars which refer to local (stack) memory
+  // 2. the target op has the nowait clause
+  // In this case, we allocate memory for the privatized variable on the heap
+  // and copy the original variable into this new heap allocation. We fix up
+  // any omp::MapInfoOp instances that may be mapping the private variable.
+  //
+  // Only private variables whose storage is produced by an llvm.alloca or an
+  // llvm.getelementptr (possibly behind llvm.addrspacecast ops) are moved.
+  // After the move the private operand is the result of a malloc call, so a
+  // rewritten target op no longer matches if the driver revisits it.
+  // All candidates are selected before any IR is changed, because a pattern
+  // that reports failure must leave the IR untouched.
+  mlir::LogicalResult
+  matchAndRewrite(omp::TargetOp targetOp,
+                  PatternRewriter &rewriter) const override {
+    if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+      return rewriter.notifyMatchFailure(
+          targetOp,
+          "targetOp does not have privateVars or does not need a target task");
+
+    SmallVector<PrivVarToRewrite> toRewrite =
+        collectPrivVarsToRewrite(targetOp);
+    if (toRewrite.empty())
+      return rewriter.notifyMatchFailure(
+          targetOp, "targetOp has no stack-allocated private vars that need "
+                    "to be moved to the heap");
+
+    ModuleOp mod = targetOp->getParentOfType<ModuleOp>();
+    OperandRange privateVars = targetOp.getPrivateVars();
+    // Private vars that are not rewritten keep their original value.
+    SmallVector<mlir::Value> newPrivVars(privateVars.begin(),
+                                         privateVars.end());
+    for (const PrivVarToRewrite &entry : toRewrite)
+      newPrivVars[entry.privVarIdx] =
+          moveMappedPrivVarToHeap(targetOp, entry, mod, rewriter);
+
+    assert(newPrivVars.size() == privateVars.size() &&
+           "The number of private variables must match before and after "
+           "transformation");
+
+    // Only the private operands change. Update them in place instead of
+    // cloning the whole target op (including its region).
+    rewriter.modifyOpInPlace(targetOp, [&]() {
+      targetOp.getPrivateVarsMutable().assign(newPrivVars);
+    });
+    return mlir::success();
+  }
+
+private:
+  // A private variable of a target op whose storage must move to the heap.
+  struct PrivVarToRewrite {
+    // Position of the variable in targetOp.getPrivateVars().
+    size_t privVarIdx;
+    // The variable's original storage (alloca/GEP, maybe addrspacecast'ed).
+    mlir::Value varPtr;
+    // The omp.map.info op that maps the variable into the target region.
+    omp::MapInfoOp mapInfoOp;
+    // True if the privatizer's data-sharing type is firstprivate.
+    bool isFirstPrivate;
+  };
+
+  bool hasPrivateVars(omp::TargetOp targetOp) const {
+    return !targetOp.getPrivateVars().empty();
+  }
+
+  bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
+    return targetOp.getNowait();
+  }
+
+  // Returns the private variables of targetOp that this pattern rewrites. A
+  // variable qualifies when:
+  // - its privatizer needs a map;
+  // - its map entry is an omp.map.info that is not captured ByCopy;
+  // - its storage passes canMoveToHeap.
+  // This function does not modify the IR.
+  SmallVector<PrivVarToRewrite>
+  collectPrivVarsToRewrite(omp::TargetOp targetOp) const {
+    SmallVector<PrivVarToRewrite> result;
+    std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+    if (!privateSyms)
+      return result;
+
+    OperandRange privateVars = targetOp.getPrivateVars();
+    for (auto [privVarIdx, privVarSymPair] :
+         llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+      mlir::Value privVar = std::get<0>(privVarSymPair);
+      mlir::Attribute privSym = std::get<1>(privVarSymPair);
+
+      omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+      if (!privatizer || !privatizer.needsMap())
+        continue;
+
+      mlir::Value mappedValue =
+          targetOp.getMappedValueForPrivateVar(privVarIdx);
+      if (!mappedValue)
+        continue;
+      auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
+      if (!mapInfoOp ||
+          mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy)
+        continue;
+
+      // Skip storage that getElemType cannot handle. This includes the
+      // malloc'd memory produced by an earlier application of this pattern.
+      if (!canMoveToHeap(privVar))
+        continue;
+
+      bool isFirstPrivate = privatizer.getDataSharingType() ==
+                            omp::DataSharingClauseType::FirstPrivate;
+      result.push_back({privVarIdx, privVar, mapInfoOp, isFirstPrivate});
+    }
+    return result;
+  }
+
+  // Returns true if varPtr is produced by an llvm.alloca or an
+  // llvm.getelementptr, possibly reached through llvm.addrspacecast ops.
+  // These are the only producers getElemType(Value) knows how to type.
+  bool canMoveToHeap(mlir::Value varPtr) const {
+    Operation *definingOp = unwrapAddrSpaceCast(varPtr.getDefiningOp());
+    return llvm::isa_and_present<LLVM::AllocaOp, LLVM::GEPOp>(definingOp);
+  }
+
+  // Moves one private variable to freshly malloc'd memory and returns the heap
+  // pointer that replaces it as a private operand of targetOp. Steps:
+  // 1. Copy the variable's current value to the heap, just before the first
+  //    op of its map chain.
+  // 2. Clone the omp.map.info ops (and var_ptr_ptr producers) that use varPtr
+  //    in front of targetOp, with varPtr replaced by the heap pointer.
+  // 3. Erase the originals.
+  mlir::Value moveMappedPrivVarToHeap(omp::TargetOp targetOp,
+                                      const PrivVarToRewrite &entry,
+                                      ModuleOp mod,
+                                      PatternRewriter &rewriter) const {
+    mlir::Value varPtr = entry.varPtr;
+    omp::MapInfoOp mapInfoOp = entry.mapInfoOp;
+    Operation *mapInfoOperation = mapInfoOp.getOperation();
+
+    // Allocate heap memory that corresponds to the type of memory
+    // pointed to by varPtr
+    // TODO: For boxchars this likely wont be a pointer.
+    // NOTE(review): nothing in this pattern frees this allocation, nor the
+    // per-member buffers allocated for firstprivate data. Confirm that
+    // ownership is handled during translation.
+    mlir::Value heapMem = allocateHeapMem(targetOp, varPtr, mod, rewriter);
+    assert(heapMem && "failed to allocate heap memory for a private variable");
+
+    // Snapshot the users of varPtr's producer. Every query below concerns
+    // the original map ops, never the copy or the clones created later.
+    Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+    SmallVector<Operation *> varPtrUsers(varPtrDefiningOp->user_begin(),
+                                         varPtrDefiningOp->user_end());
+    auto usesVarPtr = [&varPtrUsers](Operation *op) -> bool {
+      return op && llvm::is_contained(varPtrUsers, op);
+    };
+
+    // Find the earliest insertion point for the copy. This will be before
+    // the first in the list of omp::MapInfoOp instances that use varPtr.
+    // After the copy these omp::MapInfoOp instances will refer to heapMem
+    // instead.
+    SmallVector<Operation *> chainOfOps;
+    chainOfOps.push_back(mapInfoOperation);
+    for (mlir::Value member : mapInfoOp.getMembers()) {
+      Operation *memberOperation = member.getDefiningOp();
+      if (usesVarPtr(memberOperation))
+        chainOfOps.push_back(memberOperation);
+
+      auto memberMap = mlir::cast<omp::MapInfoOp>(memberOperation);
+      if (mlir::Value varPtrPtr = memberMap.getVarPtrPtr())
+        if (usesVarPtr(varPtrPtr.getDefiningOp()))
+          chainOfOps.push_back(varPtrPtr.getDefiningOp());
+    }
+    // Only the earliest op is needed, so a linear scan is enough. Use
+    // properlyDominates: dominates(op, op) is true, so dominates is not a
+    // strict ordering.
+    DominanceInfo dom;
+    Operation *firstOp = chainOfOps.front();
+    for (Operation *op : llvm::drop_begin(chainOfOps))
+      if (dom.properlyDominates(op, firstOp))
+        firstOp = op;
+
+    rewriter.setInsertionPoint(firstOp);
+    // Copy the value of the local variable into the heap-allocated location.
+    mlir::Location loc = firstOp->getLoc();
+    mlir::Type varType = getElemType(varPtr);
+    auto loadVal = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
+    rewriter.create<LLVM::StoreOp>(loc, loadVal.getResult(), heapMem);
+
+    // Clone an op at the current insertion point and redirect all of its uses
+    // to the clone. Record the pair so that, at the end, the clone's varPtr
+    // operands can be switched to heapMem and the original erased.
+    llvm::SmallVector<std::pair<Operation *, Operation *>> replRecord;
+    auto cloneAndMarkForDeletion = [&](Operation *origOp) -> Operation * {
+      Operation *clonedOp = rewriter.clone(*origOp);
+      rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+      replRecord.push_back(std::make_pair(origOp, clonedOp));
+      return clonedOp;
+    };
+
+    rewriter.setInsertionPoint(targetOp);
+    rewriter.setInsertionPoint(cloneAndMarkForDeletion(mapInfoOperation));
+
+    // Fix any members that may use varPtr to now use heapMem
+    for (mlir::Value member : mapInfoOp.getMembers()) {
+      Operation *memberOperation = member.getDefiningOp();
+      if (!usesVarPtr(memberOperation))
+        continue;
+      rewriter.setInsertionPoint(cloneAndMarkForDeletion(memberOperation));
+
+      auto memberMapInfoOp = mlir::cast<omp::MapInfoOp>(memberOperation);
+      mlir::Value varPtrPtr = memberMapInfoOp.getVarPtrPtr();
+      if (!varPtrPtr)
+        continue;
+      Operation *varPtrPtrDefOp = varPtrPtr.getDefiningOp();
+      if (entry.isFirstPrivate)
+        copyMemberDataToHeap(memberMapInfoOp, varPtrPtrDefOp, varPtr, heapMem,
+                             mod, rewriter);
+      else
+        rewriter.setInsertionPoint(cloneAndMarkForDeletion(varPtrPtrDefOp));
+    }
+
+    for (auto &repl : replRecord) {
+      Operation *origOp = repl.first;
+      Operation *clonedOp = repl.second;
+      rewriter.modifyOpInPlace(
+          clonedOp, [&]() { clonedOp->replaceUsesOfWith(varPtr, heapMem); });
+      rewriter.eraseOp(origOp);
+    }
+    return heapMem;
+  }
+
+  // Gives a firstprivate variable's heap copy its own copy of the data that a
+  // member map reaches through var_ptr_ptr. New ops are created at the
+  // rewriter's current insertion point, which the caller has set just before
+  // the cloned member map. In the case of firstprivate, we have to do the
+  // following:
+  // 1. Allocate heap memory for the underlying data.
+  // 2. Copy the original underlying data to the new memory allocated
+  //    on the heap.
+  // 3. Put this new (heap) address in the originating struct/descriptor.
+  //
+  // Consider the following sequence of omp.map.info and omp.target
+  // operations.
+  // %0 = llvm.getelementptr %19[0, 0]
+  // %1 = omp.map.info var_ptr(%19 : !llvm.ptr, i32) ...
+  //        var_ptr_ptr(%0 : !llvm.ptr) bounds(..)
+  // %2 = omp.map.info var_ptr(%19 : !llvm.ptr, !desc_type)>) ...
+  //        members(%1 : [0] : !llvm.ptr) -> !llvm.ptr
+  // omp.target nowait map_entries(%2 -> %arg5, %1 -> %arg8 : ..)
+  //        private(@privatizer %19 -> %arg9 [map_idx=1] :
+  //        !llvm.ptr) {
+  // We need to allocate memory on the heap for the underlying pointer
+  // which is stored at the var_ptr_ptr operand of %1. Then we need to
+  // copy this pointer to the new heap allocated memory location.
+  // Then, we need to store the address of the new heap location in
+  // the originating struct/descriptor. So, we generate the following
+  // (pseudo) MLIR code (Using the same names of mlir::Value instances
+  // in the example as in the code below)
+  //
+  // %dataMalloc = malloc(totalSize)
+  // %loadDataPtr = load %0 : !llvm.ptr -> !llvm.ptr
+  // memcpy(%dataMalloc, %loadDataPtr, totalSize)
+  // %newVarPtrPtrOp = llvm.getelementptr %heapMem[0, 0]
+  // llvm.store %dataMalloc, %newVarPtrPtrOp
+  // %1.cloned = omp.map.info var_ptr(%heapMem : !llvm.ptr, i32) ...
+  //               var_ptr_ptr(%newVarPtrPtrOp : !llvm.ptr)
+  // %2.cloned = omp.map.info var_ptr(%heapMem : !llvm.ptr,
+  //               !desc_type)>) ...
+  //               members(%1.cloned : [0] : !llvm.ptr)
+  //               -> !llvm.ptr
+  // omp.target nowait map_entries(%2.cloned -> %arg5,
+  //                               %1.cloned -> %arg8 : ..)
+  //               private(@privatizer %heapMem -> .. [map_idx=1] : ..) {
+  void copyMemberDataToHeap(omp::MapInfoOp memberMapInfoOp,
+                            Operation *varPtrPtrDefOp, mlir::Value varPtr,
+                            mlir::Value heapMem, ModuleOp mod,
+                            PatternRewriter &rewriter) const {
+    assert(!memberMapInfoOp.getBounds().empty() &&
+           "empty bounds on member map of firstprivate variable");
+    assert(varPtrPtrDefOp && "var_ptr_ptr of a member map has no defining op");
+    mlir::Location loc = memberMapInfoOp.getLoc();
+    mlir::Value varPtrPtr = memberMapInfoOp.getVarPtrPtr();
+    mlir::Value totalSize = getSizeInBytes(memberMapInfoOp, mod, rewriter);
+    LLVM::CallOp dataMalloc = allocateHeapMem(loc, totalSize, mod, rewriter);
+    auto loadDataPtr =
+        rewriter.create<LLVM::LoadOp>(loc, varPtrPtr.getType(), varPtrPtr);
+    rewriter.create<mlir::LLVM::MemcpyOp>(loc, dataMalloc.getResult(),
+                                          loadDataPtr.getResult(), totalSize,
+                                          /*isVolatile=*/false);
+    Operation *newVarPtrPtrOp = rewriter.clone(*varPtrPtrDefOp);
+    // NOTE(review): this redirects every use of varPtrPtr except the load,
+    // not only the cloned member map. Confirm no other code relies on the
+    // original pointer.
+    rewriter.replaceAllUsesExcept(varPtrPtr, newVarPtrPtrOp->getOpResult(0),
+                                  loadDataPtr);
+    rewriter.modifyOpInPlace(newVarPtrPtrOp, [&]() {
+      newVarPtrPtrOp->replaceUsesOfWith(varPtr, heapMem);
+    });
+    rewriter.create<LLVM::StoreOp>(loc, dataMalloc.getResult(),
+                                   newVarPtrPtrOp->getResult(0));
+  }
+
+  // Looks up the omp.private op named by privSym, starting from op. Returns a
+  // null op if no such symbol is found.
+  template <typename OpTy>
+  omp::PrivateClauseOp findPrivatizer(OpTy op, mlir::Attribute privSym) const {
+    SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+    return SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+        op, privatizerName);
+  }
+
+  template <typename OpType>
+  mlir::Type getElemType(OpType op) const {
+    return op.getElemType();
+  }
+
+  // Returns the element type of the storage behind varPtr. Only alloca and GEP
+  // producers (possibly behind addrspacecast) are supported; see
+  // canMoveToHeap.
+  mlir::Type getElemType(mlir::Value varPtr) const {
+    Operation *definingOp = unwrapAddrSpaceCast(varPtr.getDefiningOp());
+    assert((llvm::isa_and_present<LLVM::AllocaOp, LLVM::GEPOp>(definingOp)) &&
+           "getElemType in PrepareForOMPOffloadPrivatizationPass can deal only "
+           "with Alloca or GEP for now");
+    if (auto allocaOp = llvm::dyn_cast_if_present<LLVM::AllocaOp>(definingOp))
+      return getElemType(allocaOp);
+    // TODO: get rid of this because GEPOp.getElemType() is not the right thing
+    // to use.
+    if (auto gepOp = llvm::dyn_cast_if_present<LLVM::GEPOp>(definingOp))
+      return getElemType(gepOp);
+    return mlir::Type{};
+  }
+
+  // Follows a chain of llvm.addrspacecast ops back to the op that defines the
+  // underlying pointer. Returns null if op is null or the chain ends at a
+  // block argument.
+  mlir::Operation *unwrapAddrSpaceCast(Operation *op) const {
+    while (auto castOp = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
+      op = castOp.getArg().getDefiningOp();
+    return op;
+  }
+
+  // Get the (compile-time constant) size of varType as per the given
+  // DataLayout dl, rounded up to the type's ABI alignment.
+  std::int64_t getSizeInBytes(const mlir::DataLayout &dl,
+                              mlir::Type varType) const {
+    llvm::TypeSize size = dl.getTypeSize(varType);
+    // Keep the full 64-bit alignment. A narrower integer could truncate it to
+    // 0, which would make alignTo divide by zero.
+    uint64_t alignment = dl.getTypeABIAlignment(varType);
+    return llvm::alignTo(size.getFixedValue(), alignment);
+  }
+
+  // Generate code to get the size of data being mapped from the bounds of
+  // mapInfoOp. The result is the product over all bounds of
+  // (upper - lower + 1), multiplied by the aligned size of mapInfoOp's var
+  // type. Strides in the bounds are not taken into account.
+  mlir::Value getSizeInBytes(omp::MapInfoOp mapInfoOp, ModuleOp mod,
+                             PatternRewriter &rewriter) const {
+    mlir::Location loc = mapInfoOp.getLoc();
+    mlir::Type llvmInt64Ty = rewriter.getI64Type();
+    mlir::Value constOne =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Ty, 1);
+    mlir::Value elementCount = constOne;
+    // TODO: Consider using boundsOp.getExtent() if available.
+    for (mlir::Value bounds : mapInfoOp.getBounds()) {
+      auto boundsOp = mlir::cast<omp::MapBoundsOp>(bounds.getDefiningOp());
+      mlir::Value span = rewriter.create<LLVM::SubOp>(
+          loc, llvmInt64Ty, boundsOp.getUpperBound(), boundsOp.getLowerBound());
+      mlir::Value extent =
+          rewriter.create<LLVM::AddOp>(loc, llvmInt64Ty, span, constOne);
+      elementCount =
+          rewriter.create<LLVM::MulOp>(loc, llvmInt64Ty, elementCount, extent);
+    }
+    mlir::DataLayout dl(mod);
+    std::int64_t elemSize = getSizeInBytes(dl, mapInfoOp.getVarType());
+    mlir::Value elemSizeV =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Ty, elemSize);
+    return rewriter.create<LLVM::MulOp>(loc, llvmInt64Ty, elementCount,
+                                        elemSizeV);
+  }
+
+  // Returns the module's malloc declaration, creating it if necessary.
+  LLVM::LLVMFuncOp getMalloc(ModuleOp mod, PatternRewriter &rewriter) const {
+    llvm::FailureOr<mlir::LLVM::LLVMFuncOp> mallocCall =
+        LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
+    assert(llvm::succeeded(mallocCall) &&
+           "Could not find malloc in the module");
+    return mallocCall.value();
+  }
+
+  // Emits a malloc of the (aligned) size of privVar's element type, placed
+  // right before privVar's defining op, and returns the allocated pointer.
+  // The caller's insertion point is restored on return.
+  template <typename OpTy>
+  mlir::Value allocateHeapMem(OpTy targetOp, mlir::Value privVar, ModuleOp mod,
+                              PatternRewriter &rewriter) const {
+    mlir::Value varPtr = privVar;
+    Operation *definingOp = varPtr.getDefiningOp();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(definingOp);
+    LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+
+    mlir::Location loc = definingOp->getLoc();
+    mlir::Type varType = getElemType(varPtr);
+    assert(mod.getDataLayoutSpec() &&
+           "MLIR module with no datalayout spec not handled yet");
+    mlir::DataLayout dl(mod);
+    std::int64_t sizeInBytes = getSizeInBytes(dl, varType);
+    mlir::Value sizeBytes = rewriter.create<LLVM::ConstantOp>(
+        loc, mallocFn.getFunctionType().getParamType(0), sizeInBytes);
+
+    auto mallocCallOp =
+        rewriter.create<LLVM::CallOp>(loc, mallocFn, ValueRange{sizeBytes});
+    return mallocCallOp.getResult();
+  }
+
+  // Emits a malloc of `size` bytes at the current insertion point.
+  LLVM::CallOp allocateHeapMem(mlir::Location loc, mlir::Value size,
+                               ModuleOp mod, PatternRewriter &rewriter) const {
+    LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+    return rewriter.create<LLVM::CallOp>(loc, mallocFn, ValueRange{size});
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+struct PrepareForOMPOffloadPrivatizationPass
+ : public LLVM::impl::PrepareForOMPOffloadPrivatizationPassBase<
+ PrepareForOMPOffloadPrivatizationPass> {
+
+ void runOnOperation() override {
+ LLVM::LLVMFuncOp func = getOperation();
+ MLIRContext &context = getContext();
+ ModuleOp mod = func->getParentOfType<ModuleOp>();
+
+ // FunctionFilteringPass removes bounds arguments from omp.map.info
+ // operations. We require bounds else our pass asserts. But, that's only for
+ // maps in functions that are on the host. So, skip functions being compiled
+ // for the target.
+ auto offloadModuleInterface =
+ mlir::dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
+ if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice()) {
+ return;
+ }
+
+ RewritePatternSet patterns(&context);
+ patterns.add<OMPTargetPrepareDelayedPrivatizationPattern>(&context);
+
+ if (mlir::failed(
+ applyPatternsGreedily(func, std::move(patterns),
----------------
bhandarkar-pranav wrote:
This is a good suggestion. TBH, I had started off with this initially, but then I second-guessed myself owing to my belief that a formal/fancy pattern rewriter is what reviewers would prefer. Of course, that was before I even realized I'd have to deal with lit-test-related annoyances due to canonicalization.
https://github.com/llvm/llvm-project/pull/155348
More information about the flang-commits
mailing list