[flang-commits] [flang] [mlir] [Flang][mlir] - Translation of delayed privatization for deferred target-tasks (PR #155348)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Oct 2 05:46:16 PDT 2025


================
@@ -0,0 +1,401 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+    : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+          PrepareForOMPOffloadPrivatizationPass> {
+
+  /// Moves the storage of mapped private variables of deferred target tasks
+  /// (omp.target ops for which isTargetTaskDeferred holds; deferred target
+  /// tasks come from `nowait`) out of the original variable and into heap
+  /// memory, so the task privatizes from a heap copy.
+  ///
+  /// For every private variable whose privatizer needs a map and whose map is
+  /// not captured ByCopy, this:
+  ///  1. allocates heap memory for the mapped type (allocateHeapMem),
+  ///  2. initializes it by calling outlined copies of the privatizer's init
+  ///     region and, for firstprivate, its copy region,
+  ///  3. re-creates the omp.map.info op (plus any member maps and var_ptr_ptr
+  ///     producers that referenced the original variable) so that they refer
+  ///     to the heap copy, and
+  ///  4. replaces the omp.target op with a clone using the new private vars.
+  ///
+  /// Modules compiled for the target device are skipped.
+  ///
+  /// TODO: the privatizer's dealloc (cleanup) region is not inlined for the
+  /// heap copy, so work done by the init region (e.g. allocations of
+  /// ALLOCATABLEs, construction of derived types) is never undone here.
+  void runOnOperation() override {
+    // getParentOfType only inspects ancestors, so check whether the anchor op
+    // is itself the module before walking upwards.
+    Operation *anchorOp = getOperation();
+    auto mod = dyn_cast<ModuleOp>(anchorOp);
+    if (!mod)
+      mod = anchorOp->getParentOfType<ModuleOp>();
+
+    // FunctionFilteringPass removes bounds arguments from omp.map.info
+    // operations. We require bounds else our pass asserts. But, that's only for
+    // maps in functions that are on the host. So, skip functions being compiled
+    // for the target.
+    auto offloadModuleInterface =
+        llvm::dyn_cast_if_present<omp::OffloadModuleInterface>(
+            mod.getOperation());
+    if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+      return;
+
+    anchorOp->walk([&](omp::TargetOp targetOp) {
+      if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+        return;
+      IRRewriter rewriter(&getContext());
+      // Module passed to the helpers that allocate heap memory and outline
+      // the privatizer regions.
+      ModuleOp targetMod = targetOp->getParentOfType<ModuleOp>();
+      OperandRange privateVars = targetOp.getPrivateVars();
+      SmallVector<mlir::Value> newPrivVars;
+
+      newPrivVars.reserve(privateVars.size());
+      std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+      for (auto [privVarIdx, privVarSymPair] :
+           llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+        Value privVar = std::get<0>(privVarSymPair);
+        Attribute privSym = std::get<1>(privVarSymPair);
+
+        omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+        if (!privatizer.needsMap()) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+        bool isFirstPrivate = privatizer.getDataSharingType() ==
+                              omp::DataSharingClauseType::FirstPrivate;
+
+        Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+        Operation *mapInfoOperation = mappedValue.getDefiningOp();
+        auto mapInfoOp = cast<omp::MapInfoOp>(mapInfoOperation);
+
+        if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+
+        // Allocate heap memory that corresponds to the type of memory
+        // pointed to by varPtr.
+        // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
+        // should have mapped the pointer to the boxchar so use that as varPtr.
+        Value varPtr = privVar;
+        Type varType = mapInfoOp.getVarType();
+        bool isPrivatizedByValue =
+            !isa<LLVM::LLVMPointerType>(privVar.getType());
+        if (isPrivatizedByValue)
+          varPtr = mapInfoOp.getVarPtr();
+
+        assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
+        Value heapMem =
+            allocateHeapMem(targetOp, varPtr, varType, targetMod, rewriter);
+        if (!heapMem) {
+          // Everything below uses heapMem; bail out of this target op rather
+          // than continuing with a null value.
+          targetOp.emitError(
+              "Unable to allocate heap memory when trying to move "
+              "a private variable out of the stack and into the "
+              "heap for use by a deferred target task");
+          signalPassFailure();
+          return;
+        }
+
+        // Snapshot what is needed from the map before rewriting it: the
+        // rewrite below erases the original map op (and member map ops), so
+        // mapInfoOp must not be accessed after that point.
+        Location mapInfoLoc = mapInfoOp.getLoc();
+        SmallVector<Operation *> memberOps;
+        DenseSet<Operation *> seenMemberOps;
+        for (Value member : mapInfoOp.getMembers()) {
+          Operation *memberOp = member.getDefiningOp();
+          if (seenMemberOps.insert(memberOp).second)
+            memberOps.push_back(memberOp);
+        }
+
+        // Returns the op producing a member map's var_ptr_ptr operand, or
+        // null when there is no such operand or it is a block argument.
+        auto getVarPtrPtrDefOp = [](Operation *memberOp) -> Operation * {
+          Value varPtrPtr = cast<omp::MapInfoOp>(memberOp).getVarPtrPtr();
+          return varPtrPtr ? varPtrPtr.getDefiningOp() : nullptr;
+        };
+
+        // Ops that currently use varPtr directly; these are the ones that
+        // must be redirected to heapMem.
+        DenseSet<Operation *> users;
+        users.insert(varPtr.user_begin(), varPtr.user_end());
+        auto usesVarPtr = [&users](Operation *op) -> bool {
+          return users.count(op);
+        };
+
+        // Collect the map op plus every member map / var_ptr_ptr producer
+        // that uses varPtr. The copy into heapMem is inserted before the
+        // earliest of them.
+        SmallVector<Operation *> chainOfOps;
+        chainOfOps.push_back(mapInfoOperation);
+        for (Operation *memberOp : memberOps) {
+          if (usesVarPtr(memberOp))
+            chainOfOps.push_back(memberOp);
+          Operation *varPtrPtrDefOp = getVarPtrPtrDefOp(memberOp);
+          if (usesVarPtr(varPtrPtrDefOp))
+            chainOfOps.push_back(varPtrPtrDefOp);
+        }
+
+        // properlyDominates is irreflexive, as llvm::sort requires of its
+        // comparator (dominates(op, op) returns true).
+        DominanceInfo dom;
+        llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+          return dom.properlyDominates(l, r);
+        });
+
+        Operation *firstOp = chainOfOps.front();
+        Location loc = firstOp->getLoc();
+        rewriter.setInsertionPoint(firstOp);
+
+        // Create a llvm.func for 'region' that is marked always_inline and call
+        // it.
+        auto createAlwaysInlineFuncAndCallIt =
+            [&](Region &region, llvm::StringRef funcName,
+                llvm::ArrayRef<Value> args) -> Value {
+          assert(!region.empty() && "region cannot be empty");
+          LLVM::LLVMFuncOp func =
+              createFuncOpForRegion(loc, targetMod, region, funcName, rewriter);
+          auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+          return call.getResult();
+        };
+
+        // Privatizer regions take the variable by value when it is
+        // privatized by value, so load both the original and the heap copy.
+        Value moldArg, newArg;
+        if (isPrivatizedByValue) {
+          moldArg = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
+          newArg = rewriter.create<LLVM::LoadOp>(loc, varType, heapMem);
+        } else {
+          moldArg = varPtr;
+          newArg = heapMem;
+        }
+
+        Value initializedVal;
+        if (!privatizer.getInitRegion().empty())
+          initializedVal = createAlwaysInlineFuncAndCallIt(
+              privatizer.getInitRegion(),
+              llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
+              {moldArg, newArg});
+        else
+          initializedVal = newArg;
+
+        if (isFirstPrivate && !privatizer.getCopyRegion().empty())
+          initializedVal = createAlwaysInlineFuncAndCallIt(
+              privatizer.getCopyRegion(),
+              llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
+              {moldArg, initializedVal});
+
+        if (isPrivatizedByValue)
+          (void)rewriter.create<LLVM::StoreOp>(loc, initializedVal, heapMem);
+
+        // Clone origOp at the current insertion point, redirect all uses of
+        // origOp to the clone, replace varPtr with heapMem in the clone and
+        // erase origOp. origOp must not be touched afterwards.
+        auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
+          Operation *clonedOp = rewriter.clone(*origOp);
+          rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+          rewriter.modifyOpInPlace(clonedOp, [&]() {
+            clonedOp->replaceUsesOfWith(varPtr, heapMem);
+          });
+          rewriter.eraseOp(origOp);
+          return clonedOp;
+        };
+
+        // Rewrite the map ops that referenced the original variable so they
+        // use heapMem. The parent map is re-created right before targetOp.
+        // Member maps (and var_ptr_ptr producers that use varPtr) are
+        // re-created right before the new parent map, producer first, so
+        // every clone dominates its users. A member is re-created whenever it
+        // or its var_ptr_ptr producer uses varPtr, matching chainOfOps above.
+        rewriter.setInsertionPoint(targetOp);
+        Operation *newMapInfoOperation = cloneModifyAndErase(mapInfoOperation);
+        rewriter.setInsertionPoint(newMapInfoOperation);
+        for (Operation *memberOp : memberOps) {
+          // Read the producer before memberOp is erased. If a previous member
+          // shared it, this is already the clone, which is not in `users`.
+          Operation *varPtrPtrDefOp = getVarPtrPtrDefOp(memberOp);
+          bool producerUsesVarPtr = usesVarPtr(varPtrPtrDefOp);
+          if (!producerUsesVarPtr && !usesVarPtr(memberOp))
+            continue;
+          if (producerUsesVarPtr)
+            cloneModifyAndErase(varPtrPtrDefOp);
+          cloneModifyAndErase(memberOp);
+        }
+
+        // The types of private vars must match before and after the
+        // transformation: pointer-typed private vars use heapMem directly,
+        // while by-value ones (typically !fir.boxchar) load from heapMem
+        // right before targetOp, after the copy above has filled it.
+        Value newPrivVar = heapMem;
+        if (isPrivatizedByValue) {
+          rewriter.setInsertionPoint(targetOp);
+          newPrivVar =
+              rewriter.create<LLVM::LoadOp>(mapInfoLoc, varType, heapMem);
+        }
+        newPrivVars.push_back(newPrivVar);
+      }
+      assert(newPrivVars.size() == privateVars.size() &&
+             "The number of private variables must match before and after "
+             "transformation");
+
+      rewriter.setInsertionPoint(targetOp);
+      Operation *newOp = rewriter.clone(*targetOp.getOperation());
+      omp::TargetOp newTargetOp = cast<omp::TargetOp>(newOp);
+      rewriter.modifyOpInPlace(newTargetOp, [&]() {
+        newTargetOp.getPrivateVarsMutable().assign(newPrivVars);
+      });
+      rewriter.replaceOp(targetOp, newTargetOp);
+    });
+  }
----------------
tblah wrote:

I think this is still missing an inlined call to the privatizer's cleanup (dealloc) region. That region undoes the work done by the init region (e.g. deallocating ALLOCATABLEs, running destructors for derived types).

https://github.com/llvm/llvm-project/pull/155348


More information about the flang-commits mailing list