[Mlir-commits] [flang] [mlir] [Flang][mlir] - Translation of delayed privatization for deferred target-tasks (PR #155348)

Tom Eccles llvmlistbot at llvm.org
Thu Oct 2 10:22:36 PDT 2025


================
@@ -0,0 +1,401 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+// Instantiate the tablegen-generated pass definition inside mlir::omp; the
+// class below derives from omp::impl::PrepareForOMPOffloadPrivatizationPassBase,
+// which this include is expected to provide.
+namespace mlir::omp {
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+} // namespace mlir::omp
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+    : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+          PrepareForOMPOffloadPrivatizationPass> {
+
+  /// Moves private variables of deferred target tasks (omp.target with
+  /// `nowait`) out of the stack and into the heap. Skipped entirely when the
+  /// enclosing module is compiled for the target device.
+  ///
+  /// For every private variable whose privatizer needs a map and whose map is
+  /// not captured by copy, this:
+  ///  1. allocates heap memory for the mapped type (allocateHeapMem),
+  ///  2. initializes it by calling the privatizer's init region and, for
+  ///     firstprivate, its copy region (each outlined into an llvm.func by
+  ///     createFuncOpForRegion and called),
+  ///  3. rewrites the omp.map.info ops that referenced the original variable
+  ///     (the parent map, affected members and var_ptr_ptr producers) so they
+  ///     reference the heap copy,
+  ///  4. replaces the omp.target op with a clone whose private vars use the
+  ///     heap copy.
+  /// Emits an error and signals pass failure when heap allocation fails.
+  void runOnOperation() override {
+    ModuleOp mod = getOperation()->getParentOfType<ModuleOp>();
+
+    // FunctionFilteringPass removes bounds arguments from omp.map.info
+    // operations. We require bounds else our pass asserts. But, that's only for
+    // maps in functions that are on the host. So, skip functions being compiled
+    // for the target. dyn_cast_if_present tolerates a missing parent module,
+    // on which a plain dyn_cast would assert.
+    auto offloadModuleInterface =
+        dyn_cast_if_present<omp::OffloadModuleInterface>(mod.getOperation());
+    if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+      return;
+
+    WalkResult walkResult = getOperation()->walk([&](omp::TargetOp targetOp) {
+      if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+        return WalkResult::advance();
+      IRRewriter rewriter(&getContext());
+      ModuleOp parentModule = targetOp->getParentOfType<ModuleOp>();
+      OperandRange privateVars = targetOp.getPrivateVars();
+      SmallVector<mlir::Value> newPrivVars;
+
+      newPrivVars.reserve(privateVars.size());
+      std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+      for (auto [privVarIdx, privVarSymPair] :
+           llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+        Value privVar = std::get<0>(privVarSymPair);
+        Attribute privSym = std::get<1>(privVarSymPair);
+
+        omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+        if (!privatizer.needsMap()) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+        bool isFirstPrivate = privatizer.getDataSharingType() ==
+                              omp::DataSharingClauseType::FirstPrivate;
+
+        Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+        Operation *mapInfoOperation = mappedValue.getDefiningOp();
+        auto mapInfoOp = cast<omp::MapInfoOp>(mapInfoOperation);
+
+        if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+
+        // mapInfoOp is erased further down. Read everything needed from it
+        // now so that nothing is accessed through a dangling op afterwards.
+        Location mapInfoLoc = mapInfoOp.getLoc();
+        auto memberValues = llvm::to_vector(mapInfoOp.getMembers());
+
+        // Allocate heap memory that corresponds to the type of memory
+        // pointed to by varPtr
+        // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
+        // should have mapped the pointer to the boxchar so use that as varPtr.
+        Value varPtr = privVar;
+        Type varType = mapInfoOp.getVarType();
+        bool isPrivatizedByValue =
+            !isa<LLVM::LLVMPointerType>(privVar.getType());
+        if (isPrivatizedByValue)
+          varPtr = mapInfoOp.getVarPtr();
+
+        assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
+        Value heapMem =
+            allocateHeapMem(targetOp, varPtr, varType, parentModule, rewriter);
+        if (!heapMem) {
+          targetOp.emitError(
+              "Unable to allocate heap memory when trying to move "
+              "a private variable out of the stack and into the "
+              "heap for use by a deferred target task");
+          // Every step below uses heapMem; stop instead of dereferencing a
+          // null value. The caller turns this into a pass failure.
+          return WalkResult::interrupt();
+        }
+        // TODO(review): nothing in runOnOperation frees heapMem; presumably
+        // the deallocation has to be emitted where the deferred task ends.
+
+        // The types of private vars should match before and after the
+        // transformation. In particular, if the type is a pointer,
+        // simply record the newly allocated malloc location as the
+        // new private variable. If, however, the type is not a pointer
+        // then, we need to load the value from the newly allocated
+        // location. We'll insert that load later after we have updated
+        // the malloc'd location with the contents of the original
+        // variable.
+        if (!isPrivatizedByValue)
+          newPrivVars.push_back(heapMem);
+
+        // Ops that reference varPtr directly (works for both op results and
+        // block arguments).
+        DenseSet<Operation *> users;
+        users.insert(varPtr.user_begin(), varPtr.user_end());
+        auto usesVarPtr = [&users](Operation *op) -> bool {
+          return op && users.contains(op);
+        };
+
+        // Ops to rewrite so they use heapMem instead of varPtr, collected in
+        // def-before-use order. For every member, this adds the producer of
+        // its var_ptr_ptr (if that producer references varPtr), then the
+        // member itself (if it, or that producer, references varPtr; a member
+        // whose producer is rewritten has to move along with it to keep
+        // dominance). The parent map comes last. SetVector drops duplicates,
+        // e.g. a var_ptr_ptr producer shared by several members.
+        llvm::SetVector<Operation *> opsToRewrite;
+        for (Value member : memberValues) {
+          Operation *memberOperation = member.getDefiningOp();
+          auto memberMap = cast<omp::MapInfoOp>(memberOperation);
+          Operation *varPtrPtrDefOp = nullptr;
+          if (Value varPtrPtr = memberMap.getVarPtrPtr())
+            varPtrPtrDefOp = varPtrPtr.getDefiningOp();
+          bool producerUsesVarPtr = usesVarPtr(varPtrPtrDefOp);
+          if (producerUsesVarPtr)
+            opsToRewrite.insert(varPtrPtrDefOp);
+          if (producerUsesVarPtr || usesVarPtr(memberOperation))
+            opsToRewrite.insert(memberOperation);
+        }
+        opsToRewrite.insert(mapInfoOperation);
+
+        // The copy of the original variable into the heap is emitted before
+        // the earliest of the parent map and the ops that reference varPtr.
+        SmallVector<Operation *> chainOfOps;
+        chainOfOps.push_back(mapInfoOperation);
+        for (Operation *op : opsToRewrite)
+          if (op != mapInfoOperation && usesVarPtr(op))
+            chainOfOps.push_back(op);
+
+        DominanceInfo dom;
+        // llvm::sort needs an irreflexive comparator: dominates(op, op) is
+        // true, properlyDominates(op, op) is false.
+        llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+          return dom.properlyDominates(l, r);
+        });
+
+        Operation *firstOp = chainOfOps.front();
+        Location loc = firstOp->getLoc();
+        rewriter.setInsertionPoint(firstOp);
+
+        // Create a llvm.func for 'region' that is marked always_inline and call
+        // it.
+        auto createAlwaysInlineFuncAndCallIt =
+            [&](Region &region, llvm::StringRef funcName,
+                llvm::ArrayRef<Value> args) -> Value {
+          assert(!region.empty() && "region cannot be empty");
+          LLVM::LLVMFuncOp func = createFuncOpForRegion(
+              loc, parentModule, region, funcName, rewriter);
+          auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+          return call.getResult();
+        };
+
+        Value moldArg, newArg;
+        if (isPrivatizedByValue) {
+          moldArg = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
+          newArg = rewriter.create<LLVM::LoadOp>(loc, varType, heapMem);
+        } else {
+          moldArg = varPtr;
+          newArg = heapMem;
+        }
+
+        Value initializedVal;
+        if (!privatizer.getInitRegion().empty())
+          initializedVal = createAlwaysInlineFuncAndCallIt(
+              privatizer.getInitRegion(),
+              llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
+              {moldArg, newArg});
+        else
+          initializedVal = newArg;
+
+        if (isFirstPrivate && !privatizer.getCopyRegion().empty())
+          initializedVal = createAlwaysInlineFuncAndCallIt(
+              privatizer.getCopyRegion(),
+              llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
+              {moldArg, initializedVal});
+
+        if (isPrivatizedByValue)
+          (void)rewriter.create<LLVM::StoreOp>(loc, initializedVal, heapMem);
+
+        // clone origOp, replace all uses of varPtr with heapMem and
+        // erase origOp.
+        auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
+          Operation *clonedOp = rewriter.clone(*origOp);
+          rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+          rewriter.modifyOpInPlace(clonedOp, [&]() {
+            clonedOp->replaceUsesOfWith(varPtr, heapMem);
+          });
+          rewriter.eraseOp(origOp);
+          return clonedOp;
+        };
+
+        // Now that the heap-allocated copy is set up, rewrite the collected
+        // ops to use it. Each clone is inserted immediately before targetOp,
+        // after the clones created before it, so the def-before-use order of
+        // opsToRewrite carries over to the clones. The original ops are
+        // erased and not accessed again.
+        rewriter.setInsertionPoint(targetOp);
+        for (Operation *op : opsToRewrite)
+          cloneModifyAndErase(op);
+
+        // If the type of the private variable is not a pointer,
+        // which is typically the case with !fir.boxchar types, then
+        // we need to ensure that the new private variable is also
+        // not a pointer. Insert a load from heapMem right before
+        // targetOp.
+        if (isPrivatizedByValue) {
+          rewriter.setInsertionPoint(targetOp);
+          auto newPrivVar =
+              rewriter.create<LLVM::LoadOp>(mapInfoLoc, varType, heapMem);
+          newPrivVars.push_back(newPrivVar);
+        }
+      }
+      assert(newPrivVars.size() == privateVars.size() &&
+             "The number of private variables must match before and after "
+             "transformation");
+
+      rewriter.setInsertionPoint(targetOp);
+      Operation *newOp = rewriter.clone(*targetOp.getOperation());
+      omp::TargetOp newTargetOp = cast<omp::TargetOp>(newOp);
+      rewriter.modifyOpInPlace(newTargetOp, [&]() {
+        newTargetOp.getPrivateVarsMutable().assign(newPrivVars);
+      });
+      rewriter.replaceOp(targetOp, newTargetOp);
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
+      signalPassFailure();
+  }
----------------
tblah wrote:

For CPU, I put it at the end of the outlined task function. That happens a bit before inlining, but maybe you could put it at the end of the omp.task region.

https://github.com/llvm/llvm-project/pull/155348


More information about the Mlir-commits mailing list