[llvm-branch-commits] [flang] [mlir] [Flang][OpenMP] Add pass to replace allocas with device shared memory (PR #161863)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Mar 3 03:27:40 PST 2026
================
@@ -0,0 +1,196 @@
+//===- StackToShared.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to swap stack allocations on the target
+// device with device shared memory where applicable.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace omp {
+#define GEN_PASS_DEF_STACKTOSHAREDPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+
+/// When a use takes place inside an omp.parallel region and it's not as a
+/// private clause argument, or when it is a reduction argument passed to
+/// omp.parallel or a function call argument, then the defining allocation is
+/// eligible for replacement with shared memory.
+static bool allocaUseRequiresDeviceSharedMem(const OpOperand &use) {
+ Operation *owner = use.getOwner();
+ if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) {
+ if (llvm::is_contained(parallelOp.getReductionVars(), use.get()))
+ return true;
+ } else if (auto callOp = dyn_cast<CallOpInterface>(owner)) {
+ if (llvm::is_contained(callOp.getArgOperands(), use.get()))
+ return true;
+ }
+
+ // If it is used directly inside of a parallel region, it has to be replaced
+ // unless the use is a private clause.
+ if (owner->getParentOfType<omp::ParallelOp>()) {
+ if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) {
+ if (auto privateSyms =
+ cast_or_null<ArrayAttr>(owner->getAttr("private_syms"))) {
+ for (auto [var, sym] :
+ llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
+ if (var != use.get())
+ continue;
----------------
skatrak wrote:
We can do a bit better, but we do have to go through the list of arguments. I added a commit to #182856 to improve this a bit (easier to do there).
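
For readers without the full hunk, the scan under discussion can be sketched in plain C++. This is only an illustration of why a walk over the argument list is needed: the private variables and their symbols are two parallel lists, so finding the symbol for a given operand means finding its position first. It is not the change made in #182856. `Value` and `findPrivateSym` are hypothetical stand-ins, not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in: an int plays the role of an SSA value, and a
// std::string plays the role of a privatizer symbol.
using Value = int;

// Returns the symbol paired with `use` in the private clause, or
// std::nullopt if `use` is not one of the private variables. The two lists
// are assumed to have equal length, which is what llvm::zip_equal asserts
// in the quoted hunk.
std::optional<std::string>
findPrivateSym(const std::vector<Value> &privateVars,
               const std::vector<std::string> &privateSyms, Value use) {
  assert(privateVars.size() == privateSyms.size());
  // Linear scan: the lists are only related by position, so there is no
  // way to reach the symbol without locating the operand's index.
  for (std::size_t i = 0; i < privateVars.size(); ++i)
    if (privateVars[i] == use)
      return privateSyms[i];
  return std::nullopt;
}
```

Returning on the first match stops the scan early, but the worst case still visits every private argument.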
https://github.com/llvm/llvm-project/pull/161863