[llvm-branch-commits] [flang] [mlir] [Flang][OpenMP] Add pass to replace allocas with device shared memory (PR #161863)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 23 05:57:14 PST 2026


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/161863

>From d1d2114a45bf4c5a1edee0d05a7cf531ac715426 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 16 Sep 2025 13:45:43 +0100
Subject: [PATCH 1/3] [Flang][OpenMP] Add pass to replace allocas with device
 shared memory

This patch introduces a new Flang OpenMP MLIR pass, only run for target device
modules, that identifies `fir.alloca` operations that should use device shared
memory and replaces them with pairs of `omp.alloc_shared_mem` and
`omp.free_shared_mem` operations.

This works in conjunction with the MLIR to LLVM IR translation pass's handling of
privatization, mapping and reductions in the OpenMP dialect, so that the right
memory space is selected for each allocation based on where it is made and
where it is used.

This pass, in particular, handles explicit stack allocations in MLIR, whereas
the aforementioned translation pass takes care of implicit ones represented by
entry block arguments.
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |  17 ++
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 +
 flang/lib/Optimizer/OpenMP/StackToShared.cpp  | 162 +++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   4 +-
 .../Transforms/OpenMP/stack-to-shared.mlir    | 215 ++++++++++++++++++
 5 files changed, 398 insertions(+), 1 deletion(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/StackToShared.cpp
 create mode 100644 flang/test/Transforms/OpenMP/stack-to-shared.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 1b7da0da3721b..d612bfdedecac 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -145,4 +145,21 @@ def AutomapToTargetDataPass
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
+def StackToSharedPass : Pass<"omp-stack-to-shared", "mlir::func::FuncOp"> {
+  let summary = "Replaces stack allocations with shared memory.";
+  let description = [{
+    `fir.alloca` operations defining values in a target region and then used
+    inside of an `omp.parallel` region are replaced by this pass with
+    `omp.alloc_shared_mem` and `omp.free_shared_mem`. This is also done for
+    top-level function `fir.alloca`s used in the same way when the parent
+    function is a target device function.
+
+    This ensures that explicit private allocations, intended to be shared across
+    threads, use the proper memory space on a target device while supporting the
+    case of parallel regions indirectly reached from within a target region via
+    function calls.
+  }];
+  let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
 #endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index eb4930fb2f6a7..f68e47f2a01e7 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -13,6 +13,7 @@ add_flang_library(FlangOpenMPTransforms
   LowerWorkshare.cpp
   LowerNontemporal.cpp
   SimdOnly.cpp
+  StackToShared.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/StackToShared.cpp b/flang/lib/Optimizer/OpenMP/StackToShared.cpp
new file mode 100644
index 0000000000000..e666e2ed8f9b9
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/StackToShared.cpp
@@ -0,0 +1,162 @@
+//===- StackToShared.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to swap stack allocations on the target
+// device with device shared memory where applicable.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
+
+namespace flangomp {
+#define GEN_PASS_DEF_STACKTOSHAREDPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+using namespace mlir;
+
+namespace {
+class StackToSharedPass
+    : public flangomp::impl::StackToSharedPassBase<StackToSharedPass> {
+public:
+  StackToSharedPass() = default;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OpBuilder builder(context);
+
+    func::FuncOp funcOp = getOperation();
+    auto offloadIface = funcOp->getParentOfType<omp::OffloadModuleInterface>();
+    if (!offloadIface || !offloadIface.getIsTargetDevice())
+      return;
+
+    funcOp->walk([&](fir::AllocaOp allocaOp) {
+      if (!shouldReplaceAlloca(*allocaOp))
+        return;
+
+      // Replace fir.alloca with omp.alloc_shared_mem.
+      builder.setInsertionPoint(allocaOp);
+      auto sharedAllocOp = omp::AllocSharedMemOp::create(
+          builder, allocaOp->getLoc(), allocaOp.getResult().getType(),
+          allocaOp.getInType(), allocaOp.getUniqNameAttr(),
+          allocaOp.getBindcNameAttr(), allocaOp.getTypeparams(),
+          allocaOp.getShape());
+      allocaOp.replaceAllUsesWith(sharedAllocOp.getOperation());
+      allocaOp.erase();
+
+      // Create a new omp.free_shared_mem for the allocated buffer prior to
+      // exiting the region.
+      Block *allocaBlock = sharedAllocOp->getBlock();
+      DominanceInfo domInfo;
+      for (Block &block : sharedAllocOp->getParentRegion()->getBlocks()) {
+        Operation *terminator = block.getTerminator();
+        if (!terminator->hasSuccessors() &&
+            domInfo.dominates(allocaBlock, &block)) {
+          builder.setInsertionPoint(terminator);
+          omp::FreeSharedMemOp::create(builder, sharedAllocOp.getLoc(),
+                                       sharedAllocOp);
+        }
+      }
+    });
+  }
+
+private:
+  // TODO: Refactor the logic in `shouldReplaceAlloca` and `checkAllocaUses` to
+  // be reusable by the MLIR to LLVM IR translation stage, as something very
+  // similar is also implemented there to choose between allocas and device
+  // shared memory allocations when processing OpenMP reductions, mapping and
+  // privatization.
+
+  // Decide whether to replace a fir.alloca with a pair of device shared memory
+  // allocation/deallocation pair based on the location of the allocation and
+  // its uses.
+  //
+  // In summary, it should be done whenever the allocation is placed outside any
+  // parallel regions and inside either a target device function or a generic
+  // kernel, while being used inside of a parallel region.
+  bool shouldReplaceAlloca(Operation &op) {
+    auto targetOp = op.getParentOfType<omp::TargetOp>();
+
+    // It must be inside of a generic omp.target or in a target device function,
+    // and not inside of omp.parallel.
+    if (auto parallelOp = op.getParentOfType<omp::ParallelOp>()) {
+      if (!targetOp || !targetOp->isProperAncestor(parallelOp))
+        return false;
+    }
+
+    if (targetOp) {
+      if (targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) !=
+          mlir::omp::TargetExecMode::generic)
+        return false;
+    } else {
+      auto declTargetIface = dyn_cast<mlir::omp::DeclareTargetInterface>(
+          *op.getParentOfType<func::FuncOp>());
+      if (!declTargetIface || !declTargetIface.isDeclareTarget() ||
+          declTargetIface.getDeclareTargetDeviceType() ==
+              mlir::omp::DeclareTargetDeviceType::host)
+        return false;
+    }
+
+    return checkAllocaUses(op.getUses());
+  }
+
+  // When a use takes place inside an omp.parallel region and it's not as a
+  // private clause argument, or when it is a reduction argument passed to
+  // omp.parallel, then the defining allocation is eligible for replacement with
+  // shared memory.
+  //
+  // Only one of the uses needs to meet these conditions to return true.
+  bool checkAllocaUses(const Operation::use_range &uses) {
+    auto checkUse = [&](const OpOperand &use) {
+      Operation *owner = use.getOwner();
+      auto moduleOp = owner->getParentOfType<ModuleOp>();
+      if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) {
+        if (llvm::is_contained(parallelOp.getReductionVars(), use.get()))
+          return true;
+      } else if (owner->getParentOfType<omp::ParallelOp>()) {
+        // If it is used directly inside of a parallel region, it has to be
+        // replaced unless the use is a private clause.
+        if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) {
+          if (auto privateSyms = llvm::cast_or_null<ArrayAttr>(
+                  owner->getAttr("private_syms"))) {
+            for (auto [var, sym] :
+                 llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
+              if (var != use.get())
+                continue;
+
+              auto privateOp = cast<omp::PrivateClauseOp>(
+                  moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym)));
+              return privateOp.getDataSharingType() !=
+                     omp::DataSharingClauseType::Private;
+            }
+          }
+        }
+        return true;
+      }
+      return false;
+    };
+
+    // Check direct uses and also follow hlfir.declare uses.
+    for (const OpOperand &use : uses) {
+      if (auto declareOp = dyn_cast<hlfir::DeclareOp>(use.getOwner())) {
+        if (checkAllocaUses(declareOp->getUses()))
+          return true;
+      } else if (checkUse(use)) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 9b73d587ee7bc..08840a30c21c6 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -354,8 +354,10 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
   pm.addPass(flangomp::createDeleteUnreachableTargetsPass());
 
   pm.addPass(flangomp::createGenericLoopConversionPass());
-  if (opts.isTargetDevice)
+  if (opts.isTargetDevice) {
+    pm.addPass(flangomp::createStackToSharedPass());
     pm.addPass(flangomp::createFunctionFilteringPass());
+  }
 }
 
 void createDebugPasses(mlir::PassManager &pm,
diff --git a/flang/test/Transforms/OpenMP/stack-to-shared.mlir b/flang/test/Transforms/OpenMP/stack-to-shared.mlir
new file mode 100644
index 0000000000000..a7842048a8411
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/stack-to-shared.mlir
@@ -0,0 +1,215 @@
+// RUN: fir-opt --split-input-file --omp-stack-to-shared %s | FileCheck %s
+
+module attributes {omp.is_target_device = true} {
+  omp.declare_reduction @add_reduction_i32 : i32 init {
+  ^bb0(%arg0: i32):
+    %c0_i32 = arith.constant 0 : i32
+    omp.yield(%c0_i32 : i32)
+  } combiner {
+  ^bb0(%arg0: i32, %arg1: i32):
+    %0 = arith.addi %arg0, %arg1 : i32
+    omp.yield(%0 : i32)
+  }
+
+  omp.private {type = private} @privatizer_i32 : i32
+  omp.private {type = firstprivate} @firstprivatizer_i32 : i32 copy {
+  ^bb0(%arg0: i32, %arg1: i32):
+    omp.yield(%arg0 : i32)
+  }
+
+  // Verify that target device functions are searched for allocas shared across
+  // threads of a parallel region.
+  //
+  // Also ensure that all fir.alloca information is adequately forwarded to the
+  // new allocation, that uses of the allocation through hlfir.declare are
+  // detected and that only the expected types of uses (parallel reduction and
+  // non-private uses inside of a parallel region) are replaced.
+  // CHECK-LABEL: func.func @standalone_func
+  func.func @standalone_func(%lb: i32, %ub: i32, %step: i32) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} {
+    // CHECK: %[[ALLOC_0:.*]] = omp.alloc_shared_mem i32 {uniq_name = "x"} : !fir.ref<i32>
+    %0 = fir.alloca i32 {uniq_name = "x"}
+    %c = arith.constant 1 : index
+    // CHECK: %[[ALLOC_1:.*]] = omp.alloc_shared_mem !fir.char<1,?>(%[[C:.*]] : index), %[[C]] {bindc_name = "y", uniq_name = "y"} : !fir.ref<!fir.char<1,?>>
+    %1 = fir.alloca !fir.char<1,?>(%c : index), %c {bindc_name = "y", uniq_name = "y"}
+    // CHECK: %{{.*}}:2 = hlfir.declare %[[ALLOC_1]] typeparams %[[C]] {uniq_name = "y"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+    %decl:2 = hlfir.declare %1 typeparams %c {uniq_name = "y"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+    // CHECK: %{{.*}} = fir.alloca i32 {uniq_name = "z"}
+    %2 = fir.alloca i32 {uniq_name = "z"}
+    // CHECK: %[[ALLOC_2:.*]] = omp.alloc_shared_mem i32 {uniq_name = "a"} : !fir.ref<i32>
+    %3 = fir.alloca i32 {uniq_name = "a"}
+    // CHECK: %{{.*}} = fir.alloca i32 {uniq_name = "b"}
+    %4 = fir.alloca i32 {uniq_name = "b"}
+    omp.parallel reduction(@add_reduction_i32 %0 -> %arg0 : !fir.ref<i32>) {
+      // CHECK: %{{.*}} = fir.alloca i32 {uniq_name = "c"}
+      %5 = fir.alloca i32 {uniq_name = "c"}
+      %6:2 = fir.unboxchar %decl#0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+      omp.wsloop private(@privatizer_i32 %2 -> %arg1, @firstprivatizer_i32 %3 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
+        omp.loop_nest (%arg3) : i32 = (%lb) to (%ub) inclusive step (%step) {
+          %7 = fir.load %5 : !fir.ref<i32>
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    %5 = fir.load %4 : !fir.ref<i32>
+    // CHECK: omp.free_shared_mem %[[ALLOC_0]] : !fir.ref<i32>
+    // CHECK-NEXT: omp.free_shared_mem %[[ALLOC_1]] : !fir.ref<!fir.char<1,?>>
+    // CHECK-NEXT: omp.free_shared_mem %[[ALLOC_2]] : !fir.ref<i32>
+    // CHECK-NEXT: return
+    return
+  }
+
+  // Verify that generic target regions are searched for allocas shared across
+  // threads of a parallel region.
+  // CHECK-LABEL: func.func @target_generic
+  func.func @target_generic() {
+    // CHECK: omp.target
+    omp.target {
+      %c = arith.constant 0 : i32
+      // CHECK: %[[ALLOC_0:.*]] = omp.alloc_shared_mem i32 {uniq_name = "x"} : !fir.ref<i32>
+      %0 = fir.alloca i32 {uniq_name = "x"}
+      // CHECK: omp.teams
+      omp.teams {
+        // CHECK: %[[ALLOC_1:.*]] = omp.alloc_shared_mem i32 {uniq_name = "y"} : !fir.ref<i32>
+        %1 = fir.alloca i32 {uniq_name = "y"}
+        // CHECK: omp.distribute
+        omp.distribute {
+          omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
+            // CHECK: %[[ALLOC_2:.*]] = omp.alloc_shared_mem i32 {uniq_name = "z"} : !fir.ref<i32>
+            %2 = fir.alloca i32 {uniq_name = "z"}
+            // CHECK: omp.parallel
+            omp.parallel {
+              %3 = fir.load %0 : !fir.ref<i32>
+              %4 = fir.load %1 : !fir.ref<i32>
+              %5 = fir.load %2 : !fir.ref<i32>
+              // CHECK: omp.terminator
+              omp.terminator
+            }
+            // CHECK: omp.free_shared_mem %[[ALLOC_2]] : !fir.ref<i32>
+            // CHECK: omp.yield
+            omp.yield
+          }
+        }
+        // CHECK: omp.free_shared_mem %[[ALLOC_1]] : !fir.ref<i32>
+        // CHECK: omp.terminator
+        omp.terminator
+      }
+      // CHECK: omp.free_shared_mem %[[ALLOC_0]] : !fir.ref<i32>
+      // CHECK: omp.terminator
+      omp.terminator
+    }
+    // CHECK: return
+    return
+  }
+
+  // Make sure that uses not shared across threads on a parallel region inside
+  // of target are not incorrectly detected as such if there's another parallel
+  // region in the host wrapping the whole target region.
+  // CHECK-LABEL: func.func @target_generic_in_parallel
+  func.func @target_generic_in_parallel() {
+    // CHECK-NOT: omp.alloc_shared_mem
+    // CHECK-NOT: omp.free_shared_mem
+    omp.parallel {
+      omp.target {
+        %c = arith.constant 0 : i32
+        %0 = fir.alloca i32 {uniq_name = "x"}
+        omp.teams {
+          %1 = fir.alloca i32 {uniq_name = "y"}
+          omp.distribute {
+            omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
+              %3 = fir.load %0 : !fir.ref<i32>
+              %4 = fir.load %1 : !fir.ref<i32>
+              omp.parallel {
+                omp.terminator
+              }
+              omp.yield
+            }
+          }
+          omp.terminator
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    // CHECK: return
+    return
+  }
+
+  // Ensure that allocations within SPMD target regions are not replaced with
+  // device shared memory regardless of use.
+  // CHECK-LABEL: func.func @target_spmd
+  func.func @target_spmd() {
+    // CHECK-NOT: omp.alloc_shared_mem
+    // CHECK-NOT: omp.free_shared_mem
+    omp.target {
+      %c = arith.constant 0 : i32
+      %0 = fir.alloca i32 {uniq_name = "x"}
+      omp.teams {
+        %1 = fir.alloca i32 {uniq_name = "y"}
+        omp.parallel {
+          %2 = fir.alloca i32 {uniq_name = "z"}
+          %3 = fir.load %0 : !fir.ref<i32>
+          %4 = fir.load %1 : !fir.ref<i32>
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
+                %5 = fir.load %2 : !fir.ref<i32>
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    // CHECK: return
+    return
+  }
+}
+
+// -----
+
+// No transformations must be done when targeting the host device.
+// CHECK-LABEL: func.func @host_standalone
+func.func @host_standalone() {
+  // CHECK-NOT: omp.alloc_shared_mem
+  // CHECK-NOT: omp.free_shared_mem
+  %0 = fir.alloca i32 {uniq_name = "x"}
+  omp.parallel {
+    %1 = fir.load %0 : !fir.ref<i32>
+    omp.terminator
+  }
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: func.func @host_target
+func.func @host_target() {
+  // CHECK-NOT: omp.alloc_shared_mem
+  // CHECK-NOT: omp.free_shared_mem
+  omp.target {
+    %c = arith.constant 0 : i32
+    %0 = fir.alloca i32 {uniq_name = "x"}
+    omp.teams {
+      %1 = fir.alloca i32 {uniq_name = "y"}
+      omp.distribute {
+        omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
+          %2 = fir.alloca i32 {uniq_name = "z"}
+          omp.parallel {
+            %3 = fir.load %0 : !fir.ref<i32>
+            %4 = fir.load %1 : !fir.ref<i32>
+            %5 = fir.load %2 : !fir.ref<i32>
+            omp.terminator
+          }
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  // CHECK: return
+  return
+}

>From 2800483d067fcdc1f8285749a7b6c12196c5aec2 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Thu, 5 Feb 2026 11:52:25 +0000
Subject: [PATCH 2/3] move stack-to-shared pass to the omp dialect

---
 .../include/flang/Optimizer/OpenMP/Passes.td  |  17 --
 .../include/flang/Optimizer/Support/InitFIR.h |   2 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   1 -
 flang/lib/Optimizer/OpenMP/StackToShared.cpp  | 162 -------------
 flang/lib/Optimizer/Passes/Pipelines.cpp      |  11 +-
 flang/test/Fir/basic-program.fir              |   2 +
 .../Transforms/OpenMP/stack-to-shared.mlir    | 215 ------------------
 mlir/docs/Passes.md                           |   4 +
 .../mlir/Dialect/OpenMP/Transforms/Passes.h   |   6 +-
 .../mlir/Dialect/OpenMP/Transforms/Passes.td  |  18 ++
 mlir/lib/Dialect/OpenMP/CMakeLists.txt        |  20 +-
 mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt     |  18 ++
 .../Dialect/OpenMP/Transforms/CMakeLists.txt  |  10 +
 .../OpenMP/Transforms/StackToShared.cpp       | 188 +++++++++++++++
 mlir/test/Dialect/OpenMP/stack-to-shared.mlir | 149 ++++++++++++
 15 files changed, 405 insertions(+), 418 deletions(-)
 delete mode 100644 flang/lib/Optimizer/OpenMP/StackToShared.cpp
 delete mode 100644 flang/test/Transforms/OpenMP/stack-to-shared.mlir
 create mode 100644 mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp
 create mode 100644 mlir/test/Dialect/OpenMP/stack-to-shared.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index d612bfdedecac..1b7da0da3721b 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -145,21 +145,4 @@ def AutomapToTargetDataPass
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
-def StackToSharedPass : Pass<"omp-stack-to-shared", "mlir::func::FuncOp"> {
-  let summary = "Replaces stack allocations with shared memory.";
-  let description = [{
-    `fir.alloca` operations defining values in a target region and then used
-    inside of an `omp.parallel` region are replaced by this pass with
-    `omp.alloc_shared_mem` and `omp.free_shared_mem`. This is also done for
-    top-level function `fir.alloca`s used in the same way when the parent
-    function is a target device function.
-
-    This ensures that explicit private allocations, intended to be shared across
-    threads, use the proper memory space on a target device while supporting the
-    case of parallel regions indirectly reached from within a target region via
-    function calls.
-  }];
-  let dependentDialects = ["mlir::omp::OpenMPDialect"];
-}
-
 #endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index d77d82feddd84..6051dbb07fad7 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/InitAllDialects.h"
@@ -106,6 +107,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
 /// but is a smaller set since we aren't using many of the passes found there.
 inline void registerMLIRPassesForFortranTools() {
   mlir::acc::registerOpenACCPasses();
+  mlir::omp::registerOpenMPPasses();
   mlir::registerCanonicalizerPass();
   mlir::registerCSEPass();
   mlir::affine::registerAffineLoopFusionPass();
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index f68e47f2a01e7..eb4930fb2f6a7 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -13,7 +13,6 @@ add_flang_library(FlangOpenMPTransforms
   LowerWorkshare.cpp
   LowerNontemporal.cpp
   SimdOnly.cpp
-  StackToShared.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/StackToShared.cpp b/flang/lib/Optimizer/OpenMP/StackToShared.cpp
deleted file mode 100644
index e666e2ed8f9b9..0000000000000
--- a/flang/lib/Optimizer/OpenMP/StackToShared.cpp
+++ /dev/null
@@ -1,162 +0,0 @@
-//===- StackToShared.cpp -------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements transforms to swap stack allocations on the target
-// device with device shared memory where applicable.
-//
-//===----------------------------------------------------------------------===//
-
-#include "flang/Optimizer/Dialect/FIROps.h"
-#include "flang/Optimizer/HLFIR/HLFIROps.h"
-#include "flang/Optimizer/OpenMP/Passes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
-
-namespace flangomp {
-#define GEN_PASS_DEF_STACKTOSHAREDPASS
-#include "flang/Optimizer/OpenMP/Passes.h.inc"
-} // namespace flangomp
-
-using namespace mlir;
-
-namespace {
-class StackToSharedPass
-    : public flangomp::impl::StackToSharedPassBase<StackToSharedPass> {
-public:
-  StackToSharedPass() = default;
-
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    OpBuilder builder(context);
-
-    func::FuncOp funcOp = getOperation();
-    auto offloadIface = funcOp->getParentOfType<omp::OffloadModuleInterface>();
-    if (!offloadIface || !offloadIface.getIsTargetDevice())
-      return;
-
-    funcOp->walk([&](fir::AllocaOp allocaOp) {
-      if (!shouldReplaceAlloca(*allocaOp))
-        return;
-
-      // Replace fir.alloca with omp.alloc_shared_mem.
-      builder.setInsertionPoint(allocaOp);
-      auto sharedAllocOp = omp::AllocSharedMemOp::create(
-          builder, allocaOp->getLoc(), allocaOp.getResult().getType(),
-          allocaOp.getInType(), allocaOp.getUniqNameAttr(),
-          allocaOp.getBindcNameAttr(), allocaOp.getTypeparams(),
-          allocaOp.getShape());
-      allocaOp.replaceAllUsesWith(sharedAllocOp.getOperation());
-      allocaOp.erase();
-
-      // Create a new omp.free_shared_mem for the allocated buffer prior to
-      // exiting the region.
-      Block *allocaBlock = sharedAllocOp->getBlock();
-      DominanceInfo domInfo;
-      for (Block &block : sharedAllocOp->getParentRegion()->getBlocks()) {
-        Operation *terminator = block.getTerminator();
-        if (!terminator->hasSuccessors() &&
-            domInfo.dominates(allocaBlock, &block)) {
-          builder.setInsertionPoint(terminator);
-          omp::FreeSharedMemOp::create(builder, sharedAllocOp.getLoc(),
-                                       sharedAllocOp);
-        }
-      }
-    });
-  }
-
-private:
-  // TODO: Refactor the logic in `shouldReplaceAlloca` and `checkAllocaUses` to
-  // be reusable by the MLIR to LLVM IR translation stage, as something very
-  // similar is also implemented there to choose between allocas and device
-  // shared memory allocations when processing OpenMP reductions, mapping and
-  // privatization.
-
-  // Decide whether to replace a fir.alloca with a pair of device shared memory
-  // allocation/deallocation pair based on the location of the allocation and
-  // its uses.
-  //
-  // In summary, it should be done whenever the allocation is placed outside any
-  // parallel regions and inside either a target device function or a generic
-  // kernel, while being used inside of a parallel region.
-  bool shouldReplaceAlloca(Operation &op) {
-    auto targetOp = op.getParentOfType<omp::TargetOp>();
-
-    // It must be inside of a generic omp.target or in a target device function,
-    // and not inside of omp.parallel.
-    if (auto parallelOp = op.getParentOfType<omp::ParallelOp>()) {
-      if (!targetOp || !targetOp->isProperAncestor(parallelOp))
-        return false;
-    }
-
-    if (targetOp) {
-      if (targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) !=
-          mlir::omp::TargetExecMode::generic)
-        return false;
-    } else {
-      auto declTargetIface = dyn_cast<mlir::omp::DeclareTargetInterface>(
-          *op.getParentOfType<func::FuncOp>());
-      if (!declTargetIface || !declTargetIface.isDeclareTarget() ||
-          declTargetIface.getDeclareTargetDeviceType() ==
-              mlir::omp::DeclareTargetDeviceType::host)
-        return false;
-    }
-
-    return checkAllocaUses(op.getUses());
-  }
-
-  // When a use takes place inside an omp.parallel region and it's not as a
-  // private clause argument, or when it is a reduction argument passed to
-  // omp.parallel, then the defining allocation is eligible for replacement with
-  // shared memory.
-  //
-  // Only one of the uses needs to meet these conditions to return true.
-  bool checkAllocaUses(const Operation::use_range &uses) {
-    auto checkUse = [&](const OpOperand &use) {
-      Operation *owner = use.getOwner();
-      auto moduleOp = owner->getParentOfType<ModuleOp>();
-      if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) {
-        if (llvm::is_contained(parallelOp.getReductionVars(), use.get()))
-          return true;
-      } else if (owner->getParentOfType<omp::ParallelOp>()) {
-        // If it is used directly inside of a parallel region, it has to be
-        // replaced unless the use is a private clause.
-        if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) {
-          if (auto privateSyms = llvm::cast_or_null<ArrayAttr>(
-                  owner->getAttr("private_syms"))) {
-            for (auto [var, sym] :
-                 llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
-              if (var != use.get())
-                continue;
-
-              auto privateOp = cast<omp::PrivateClauseOp>(
-                  moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym)));
-              return privateOp.getDataSharingType() !=
-                     omp::DataSharingClauseType::Private;
-            }
-          }
-        }
-        return true;
-      }
-      return false;
-    };
-
-    // Check direct uses and also follow hlfir.declare uses.
-    for (const OpOperand &use : uses) {
-      if (auto declareOp = dyn_cast<hlfir::DeclareOp>(use.getOwner())) {
-        if (checkAllocaUses(declareOp->getUses()))
-          return true;
-      } else if (checkUse(use)) {
-        return true;
-      }
-    }
-
-    return false;
-  }
-};
-} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 08840a30c21c6..a4e2ecc296107 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -10,6 +10,7 @@
 /// common to flang and the test tools.
 
 #include "flang/Optimizer/Passes/Pipelines.h"
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
 #include "llvm/Support/CommandLine.h"
 
 /// Force setting the no-alias attribute on fuction arguments when possible.
@@ -354,10 +355,8 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
   pm.addPass(flangomp::createDeleteUnreachableTargetsPass());
 
   pm.addPass(flangomp::createGenericLoopConversionPass());
-  if (opts.isTargetDevice) {
-    pm.addPass(flangomp::createStackToSharedPass());
+  if (opts.isTargetDevice)
     pm.addPass(flangomp::createFunctionFilteringPass());
-  }
 }
 
 void createDebugPasses(mlir::PassManager &pm,
@@ -426,6 +425,12 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
   }
 
   fir::addFIRToLLVMPass(pm, config);
+
+  // Convert applicable OpenMP stack allocations to shared memory allocations
+  // for GPU targets. This pass must run after any alloca-generating passes to
+  // ensure all are adequately accounted for.
+  if (config.EnableOpenMP && !config.EnableOpenMPSimd)
+    pm.addPass(mlir::omp::createStackToSharedPass());
 }
 
 /// Create a pass pipeline for lowering from MLIR to LLVM IR
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 5f84395b36037..1e26b388267b6 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -178,5 +178,7 @@ func.func @_QQmain() {
 // PASSES-NEXT:  LowerNontemporalPass
 // PASSES-NEXT: FIRToLLVMLowering
 // PASSES-NEXT: ReconcileUnrealizedCasts
+// PASSES-NEXT: 'llvm.func' Pipeline
+// PASSES-NEXT:  StackToSharedPass
 // PASSES-NEXT: PrepareForOMPOffloadPrivatizationPass
 // PASSES-NEXT: LLVMIRLoweringPass
diff --git a/flang/test/Transforms/OpenMP/stack-to-shared.mlir b/flang/test/Transforms/OpenMP/stack-to-shared.mlir
deleted file mode 100644
index a7842048a8411..0000000000000
--- a/flang/test/Transforms/OpenMP/stack-to-shared.mlir
+++ /dev/null
@@ -1,215 +0,0 @@
-// RUN: fir-opt --split-input-file --omp-stack-to-shared %s | FileCheck %s
-
-module attributes {omp.is_target_device = true} {
-  omp.declare_reduction @add_reduction_i32 : i32 init {
-  ^bb0(%arg0: i32):
-    %c0_i32 = arith.constant 0 : i32
-    omp.yield(%c0_i32 : i32)
-  } combiner {
-  ^bb0(%arg0: i32, %arg1: i32):
-    %0 = arith.addi %arg0, %arg1 : i32
-    omp.yield(%0 : i32)
-  }
-
-  omp.private {type = private} @privatizer_i32 : i32
-  omp.private {type = firstprivate} @firstprivatizer_i32 : i32 copy {
-  ^bb0(%arg0: i32, %arg1: i32):
-    omp.yield(%arg0 : i32)
-  }
-
-  // Verify that target device functions are searched for allocas shared across
-  // threads of a parallel region.
-  //
-  // Also ensure that all fir.alloca information is adequately forwarded to the
-  // new allocation, that uses of the allocation through hlfir.declare are
-  // detected and that only the expected types of uses (parallel reduction and
-  // non-private uses inside of a parallel region) are replaced.
-  // CHECK-LABEL: func.func @standalone_func
-  func.func @standalone_func(%lb: i32, %ub: i32, %step: i32) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} {
-    // CHECK: %[[ALLOC_0:.*]] = omp.alloc_shared_mem i32 {uniq_name = "x"} : !fir.ref<i32>
-    %0 = fir.alloca i32 {uniq_name = "x"}
-    %c = arith.constant 1 : index
-    // CHECK: %[[ALLOC_1:.*]] = omp.alloc_shared_mem !fir.char<1,?>(%[[C:.*]] : index), %[[C]] {bindc_name = "y", uniq_name = "y"} : !fir.ref<!fir.char<1,?>>
-    %1 = fir.alloca !fir.char<1,?>(%c : index), %c {bindc_name = "y", uniq_name = "y"}
-    // CHECK: %{{.*}}:2 = hlfir.declare %[[ALLOC_1]] typeparams %[[C]] {uniq_name = "y"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
-    %decl:2 = hlfir.declare %1 typeparams %c {uniq_name = "y"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
-    // CHECK: %{{.*}} = fir.alloca i32 {uniq_name = "z"}
-    %2 = fir.alloca i32 {uniq_name = "z"}
-    // CHECK: %[[ALLOC_2:.*]] = omp.alloc_shared_mem i32 {uniq_name = "a"} : !fir.ref<i32>
-    %3 = fir.alloca i32 {uniq_name = "a"}
-    // CHECK: %{{.*}} = fir.alloca i32 {uniq_name = "b"}
-    %4 = fir.alloca i32 {uniq_name = "b"}
-    omp.parallel reduction(@add_reduction_i32 %0 -> %arg0 : !fir.ref<i32>) {
-      // CHECK: %{{.*}} = fir.alloca i32 {uniq_name = "c"}
-      %5 = fir.alloca i32 {uniq_name = "c"}
-      %6:2 = fir.unboxchar %decl#0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
-      omp.wsloop private(@privatizer_i32 %2 -> %arg1, @firstprivatizer_i32 %3 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
-        omp.loop_nest (%arg3) : i32 = (%lb) to (%ub) inclusive step (%step) {
-          %7 = fir.load %5 : !fir.ref<i32>
-          omp.yield
-        }
-      }
-      omp.terminator
-    }
-    %5 = fir.load %4 : !fir.ref<i32>
-    // CHECK: omp.free_shared_mem %[[ALLOC_0]] : !fir.ref<i32>
-    // CHECK-NEXT: omp.free_shared_mem %[[ALLOC_1]] : !fir.ref<!fir.char<1,?>>
-    // CHECK-NEXT: omp.free_shared_mem %[[ALLOC_2]] : !fir.ref<i32>
-    // CHECK-NEXT: return
-    return
-  }
-
-  // Verify that generic target regions are searched for allocas shared across
-  // threads of a parallel region.
-  // CHECK-LABEL: func.func @target_generic
-  func.func @target_generic() {
-    // CHECK: omp.target
-    omp.target {
-      %c = arith.constant 0 : i32
-      // CHECK: %[[ALLOC_0:.*]] = omp.alloc_shared_mem i32 {uniq_name = "x"} : !fir.ref<i32>
-      %0 = fir.alloca i32 {uniq_name = "x"}
-      // CHECK: omp.teams
-      omp.teams {
-        // CHECK: %[[ALLOC_1:.*]] = omp.alloc_shared_mem i32 {uniq_name = "y"} : !fir.ref<i32>
-        %1 = fir.alloca i32 {uniq_name = "y"}
-        // CHECK: omp.distribute
-        omp.distribute {
-          omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
-            // CHECK: %[[ALLOC_2:.*]] = omp.alloc_shared_mem i32 {uniq_name = "z"} : !fir.ref<i32>
-            %2 = fir.alloca i32 {uniq_name = "z"}
-            // CHECK: omp.parallel
-            omp.parallel {
-              %3 = fir.load %0 : !fir.ref<i32>
-              %4 = fir.load %1 : !fir.ref<i32>
-              %5 = fir.load %2 : !fir.ref<i32>
-              // CHECK: omp.terminator
-              omp.terminator
-            }
-            // CHECK: omp.free_shared_mem %[[ALLOC_2]] : !fir.ref<i32>
-            // CHECK: omp.yield
-            omp.yield
-          }
-        }
-        // CHECK: omp.free_shared_mem %[[ALLOC_1]] : !fir.ref<i32>
-        // CHECK: omp.terminator
-        omp.terminator
-      }
-      // CHECK: omp.free_shared_mem %[[ALLOC_0]] : !fir.ref<i32>
-      // CHECK: omp.terminator
-      omp.terminator
-    }
-    // CHECK: return
-    return
-  }
-
-  // Make sure that uses not shared across threads on a parallel region inside
-  // of target are not incorrectly detected as such if there's another parallel
-  // region in the host wrapping the whole target region.
-  // CHECK-LABEL: func.func @target_generic_in_parallel
-  func.func @target_generic_in_parallel() {
-    // CHECK-NOT: omp.alloc_shared_mem
-    // CHECK-NOT: omp.free_shared_mem
-    omp.parallel {
-      omp.target {
-        %c = arith.constant 0 : i32
-        %0 = fir.alloca i32 {uniq_name = "x"}
-        omp.teams {
-          %1 = fir.alloca i32 {uniq_name = "y"}
-          omp.distribute {
-            omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
-              %3 = fir.load %0 : !fir.ref<i32>
-              %4 = fir.load %1 : !fir.ref<i32>
-              omp.parallel {
-                omp.terminator
-              }
-              omp.yield
-            }
-          }
-          omp.terminator
-        }
-        omp.terminator
-      }
-      omp.terminator
-    }
-    // CHECK: return
-    return
-  }
-
-  // Ensure that allocations within SPMD target regions are not replaced with
-  // device shared memory regardless of use.
-  // CHECK-LABEL: func.func @target_spmd
-  func.func @target_spmd() {
-    // CHECK-NOT: omp.alloc_shared_mem
-    // CHECK-NOT: omp.free_shared_mem
-    omp.target {
-      %c = arith.constant 0 : i32
-      %0 = fir.alloca i32 {uniq_name = "x"}
-      omp.teams {
-        %1 = fir.alloca i32 {uniq_name = "y"}
-        omp.parallel {
-          %2 = fir.alloca i32 {uniq_name = "z"}
-          %3 = fir.load %0 : !fir.ref<i32>
-          %4 = fir.load %1 : !fir.ref<i32>
-          omp.distribute {
-            omp.wsloop {
-              omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
-                %5 = fir.load %2 : !fir.ref<i32>
-                omp.yield
-              }
-            } {omp.composite}
-          } {omp.composite}
-          omp.terminator
-        } {omp.composite}
-        omp.terminator
-      }
-      omp.terminator
-    }
-    // CHECK: return
-    return
-  }
-}
-
-// -----
-
-// No transformations must be done when targeting the host device.
-// CHECK-LABEL: func.func @host_standalone
-func.func @host_standalone() {
-  // CHECK-NOT: omp.alloc_shared_mem
-  // CHECK-NOT: omp.free_shared_mem
-  %0 = fir.alloca i32 {uniq_name = "x"}
-  omp.parallel {
-    %1 = fir.load %0 : !fir.ref<i32>
-    omp.terminator
-  }
-  // CHECK: return
-  return
-}
-
-// CHECK-LABEL: func.func @host_target
-func.func @host_target() {
-  // CHECK-NOT: omp.alloc_shared_mem
-  // CHECK-NOT: omp.free_shared_mem
-  omp.target {
-    %c = arith.constant 0 : i32
-    %0 = fir.alloca i32 {uniq_name = "x"}
-    omp.teams {
-      %1 = fir.alloca i32 {uniq_name = "y"}
-      omp.distribute {
-        omp.loop_nest (%arg0) : i32 = (%c) to (%c) inclusive step (%c) {
-          %2 = fir.alloca i32 {uniq_name = "z"}
-          omp.parallel {
-            %3 = fir.load %0 : !fir.ref<i32>
-            %4 = fir.load %1 : !fir.ref<i32>
-            %5 = fir.load %2 : !fir.ref<i32>
-            omp.terminator
-          }
-          omp.yield
-        }
-      }
-      omp.terminator
-    }
-    omp.terminator
-  }
-  // CHECK: return
-  return
-}
diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md
index 9df32666415bb..f3d8a75c65840 100644
--- a/mlir/docs/Passes.md
+++ b/mlir/docs/Passes.md
@@ -72,6 +72,10 @@ This document describes the available MLIR passes and their contracts.
 
 [include "MemRefPasses.md"]
 
+## 'omp' Dialect Passes
+
+[include "OpenMPPasses.md"]
+
 ## 'shard' Dialect Passes
 
 [include "ShardPasses.md"]
diff --git a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h
index 21b6d1f466558..ddbe662be69fc 100644
--- a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h
@@ -13,6 +13,10 @@
 
 namespace mlir {
 
+namespace LLVM {
+class LLVMFuncOp;
+} // namespace LLVM
+
 namespace omp {
 
 /// Generate the code for registering conversion passes.
@@ -23,4 +27,4 @@ namespace omp {
 } // namespace omp
 } // namespace mlir
 
-#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
+#endif // MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
index 1fde7e08ab433..498b8a4812caa 100644
--- a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
@@ -23,4 +23,22 @@ def PrepareForOMPOffloadPrivatizationPass : Pass<"omp-offload-privatization-prep
     }];
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
+
+def StackToSharedPass : Pass<"omp-stack-to-shared", "mlir::LLVM::LLVMFuncOp"> {
+  let summary = "Replaces stack allocations target devices with shared memory.";
+  let description = [{
+    `llvm.alloca` operations defining values in a non-SPMD target region and
+    then potentially used inside of an `omp.parallel` region are replaced by
+    this pass with `omp.alloc_shared_mem` and `omp.free_shared_mem`. This is
+    also done for top-level function `llvm.alloca`s used in the same way when
+    the parent function is a target device function.
+
+    This ensures that explicit private allocations, intended to be shared across
+    threads, use the proper memory space on a target device while supporting the
+    case of parallel regions indirectly reached from within a target region via
+    function calls.
+  }];
+  let dependentDialects = ["mlir::omp::OpenMPDialect"]; // Pass creates omp ops.
+}
+
 #endif // MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index f3c02da458508..9f57627c321fb 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,20 +1,2 @@
+add_subdirectory(IR)
 add_subdirectory(Transforms)
-
-add_mlir_dialect_library(MLIROpenMPDialect
-  IR/OpenMPDialect.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
-
-  DEPENDS
-  omp_gen
-  MLIROpenMPOpsIncGen
-  MLIROpenMPOpsInterfacesIncGen
-  MLIROpenMPTypeInterfacesIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMDialect
-  MLIRFuncDialect
-  MLIROpenACCMPCommon
-  )
diff --git a/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..05923032d9077
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIROpenMPDialect
+  OpenMPDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
+
+  DEPENDS
+  omp_gen
+  MLIROpenMPOpsIncGen
+  MLIROpenMPOpsInterfacesIncGen
+  MLIROpenMPTypeInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRFuncDialect
+  MLIROpenACCMPCommon
+  )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
index b9b8eda9ed51b..b00ca178dd9df 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -1,14 +1,24 @@
 add_mlir_dialect_library(MLIROpenMPTransforms
   OpenMPOffloadPrivatizationPrepare.cpp
+  StackToShared.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
 
   DEPENDS
+  omp_gen
   MLIROpenMPPassIncGen
+  MLIROpenMPOpsIncGen
+  MLIROpenMPOpsInterfacesIncGen
+  MLIROpenMPTypeInterfacesIncGen
 
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRFuncDialect
   MLIRLLVMDialect
+  MLIROpenACCMPCommon
   MLIROpenMPDialect
   MLIRPass
+  MLIRSupport
   MLIRTransforms
   )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp b/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp
new file mode 100644
index 0000000000000..bd2c8c0747b7d
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp
@@ -0,0 +1,188 @@
+//===- StackToShared.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to swap stack allocations on the target
+// device with device shared memory where applicable.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace omp {
+#define GEN_PASS_DEF_STACKTOSHAREDPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+
+/// Returns true if \p use makes the defining allocation eligible for shared
+/// memory: a reduction variable of omp.parallel, a call argument, or any use
+/// nested in omp.parallel other than as a `private`-type (not firstprivate)
+/// privatization clause variable.
+static bool allocaUseRequiresDeviceSharedMem(const OpOperand &use) {
+  Operation *owner = use.getOwner();
+  if (auto parallelOp = dyn_cast<omp::ParallelOp>(owner)) {
+    if (llvm::is_contained(parallelOp.getReductionVars(), use.get()))
+      return true;
+  } else if (auto callOp = dyn_cast<CallOpInterface>(owner)) {
+    if (llvm::is_contained(callOp.getArgOperands(), use.get()))
+      return true;
+  }
+
+  // If it is used directly inside of a parallel region, it has to be replaced
+  // unless it is passed to a `private`-type (not firstprivate) clause.
+  if (owner->getParentOfType<omp::ParallelOp>()) {
+    if (auto argIface = dyn_cast<omp::BlockArgOpenMPOpInterface>(owner)) {
+      if (auto privateSyms =
+              cast_or_null<ArrayAttr>(owner->getAttr("private_syms"))) {
+        for (auto [var, sym] :
+             llvm::zip_equal(argIface.getPrivateVars(), privateSyms)) {
+          if (var != use.get())
+            continue;
+
+          auto moduleOp = owner->getParentOfType<ModuleOp>();
+          auto privateOp = cast<omp::PrivateClauseOp>(
+              moduleOp.lookupSymbol(cast<SymbolRefAttr>(sym)));
+          return privateOp.getDataSharingType() !=
+                 omp::DataSharingClauseType::Private;
+        }
+      }
+    }
+    return true; // Any other use nested in omp.parallel.
+  }
+  return false;
+}
+
+static bool shouldReplaceAllocaWithUses(const Operation::use_range &uses) {
+  // Check direct uses and recurse into llvm.addrspacecast/llvm.getelementptr.
+  for (const OpOperand &use : uses) {
+    Operation *owner = use.getOwner();
+    if (llvm::isa<LLVM::AddrSpaceCastOp, LLVM::GEPOp>(owner)) {
+      if (shouldReplaceAllocaWithUses(owner->getUses()))
+        return true;
+    } else if (allocaUseRequiresDeviceSharedMem(use)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+// TODO: Refactor the logic in `shouldReplaceAllocaWithDeviceSharedMem`,
+// `shouldReplaceAllocaWithUses` and `allocaUseRequiresDeviceSharedMem` to
+// be reusable by the MLIR to LLVM IR translation stage, as something very
+// similar is also implemented there to choose between allocas and device
+// shared memory allocations when processing OpenMP reductions, mapping and
+// privatization.
+bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) {
+  auto offloadIface = op.getParentOfType<omp::OffloadModuleInterface>();
+  if (!offloadIface || !offloadIface.getIsTargetDevice())
+    return false;
+
+  auto targetOp = op.getParentOfType<omp::TargetOp>();
+
+  // It must be inside of a generic omp.target or in a target device function,
+  // and not inside of an omp.parallel nested in that omp.target or function.
+  if (auto parallelOp = op.getParentOfType<omp::ParallelOp>()) {
+    if (!targetOp || targetOp->isProperAncestor(parallelOp))
+      return false;
+  }
+
+  if (targetOp) { // Only generic-mode kernels qualify.
+    if (targetOp.getKernelExecFlags(targetOp.getInnermostCapturedOmpOp()) !=
+        omp::TargetExecMode::generic)
+      return false;
+  } else { // Otherwise: non-host declare target functions only.
+    auto declTargetIface = op.getParentOfType<omp::DeclareTargetInterface>();
+    if (!declTargetIface || !declTargetIface.isDeclareTarget() ||
+        declTargetIface.getDeclareTargetDeviceType() ==
+            omp::DeclareTargetDeviceType::host)
+      return false;
+  }
+
+  return shouldReplaceAllocaWithUses(op.getUses());
+}
+
+void insertDeviceSharedMemDeallocation(OpBuilder &builder, Value allocVal) {
+  Block *allocaBlock = allocVal.getParentBlock();
+  DominanceInfo domInfo;
+  for (Block &block : allocVal.getParentRegion()->getBlocks()) {
+    Operation *terminator = block.getTerminator(); // Free before region exits.
+    if (!terminator->hasSuccessors() &&
+        domInfo.dominates(allocaBlock, &block)) {
+      builder.setInsertionPoint(terminator);
+      omp::FreeSharedMemOp::create(builder, allocVal.getLoc(), allocVal);
+    }
+  }
+}
+
+namespace {
+class StackToSharedPass
+    : public omp::impl::StackToSharedPassBase<StackToSharedPass> {
+public:
+  StackToSharedPass() = default;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OpBuilder builder(context);
+
+    LLVM::LLVMFuncOp funcOp = getOperation();
+    auto offloadIface = funcOp->getParentOfType<omp::OffloadModuleInterface>();
+    if (!offloadIface || !offloadIface.getIsTargetDevice())
+      return;
+
+    llvm::SmallVector<Operation *> toBeDeleted;
+    funcOp->walk([&](LLVM::AllocaOp allocaOp) {
+      if (!shouldReplaceAllocaWithDeviceSharedMem(*allocaOp))
+        return;
+      // Replace llvm.alloca with omp.alloc_shared_mem.
+      Type resultType = allocaOp.getResult().getType();
+
+      // TODO: The handling of non-default address spaces might need to be
+      // improved. This currently only handles the case where an alloca to
+      // non-default address space must only be used by a single addrspacecast
+      // to default address space.
+      bool nonDefaultAddrSpace = false;
+      if (auto llvmPtrType = dyn_cast<LLVM::LLVMPointerType>(resultType))
+        nonDefaultAddrSpace = llvmPtrType.getAddressSpace() != 0;
+
+      builder.setInsertionPoint(allocaOp);
+      auto sharedAllocOp = omp::AllocSharedMemOp::create(
+          builder, allocaOp->getLoc(), LLVM::LLVMPointerType::get(context),
+          allocaOp.getElemTypeAttr(), allocaOp.getArraySize(),
+          allocaOp.getAlignmentAttr());
+      if (nonDefaultAddrSpace) {
+        assert(allocaOp->hasOneUse() && "alloca must have only one use");
+        auto asCastOp =
+            cast<LLVM::AddrSpaceCastOp>(*allocaOp->getUsers().begin());
+        asCastOp.replaceAllUsesWith(sharedAllocOp.getOperation());
+        // Erase both ops only after the walk: the walk has not visited the cast
+        // yet, the alloca is still used by the cast, and allocaOp must stay
+        // valid until this callback returns. The else branch defers likewise.
+        toBeDeleted.push_back(asCastOp);
+        toBeDeleted.push_back(allocaOp);
+      } else {
+        allocaOp.replaceAllUsesWith(sharedAllocOp.getOperation());
+        toBeDeleted.push_back(allocaOp);
+      }
+
+      // Create a new omp.free_shared_mem for the allocated buffer prior to
+      // exiting the region.
+      insertDeviceSharedMemDeallocation(builder, sharedAllocOp.getResult());
+    });
+    for (Operation *op : toBeDeleted)
+      op->erase();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/OpenMP/stack-to-shared.mlir b/mlir/test/Dialect/OpenMP/stack-to-shared.mlir
new file mode 100644
index 0000000000000..81b03acd4d368
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/stack-to-shared.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt --omp-stack-to-shared %s | FileCheck %s
+
+module attributes {omp.is_target_device = true} {
+
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+  %2 = llvm.load %arg3 : !llvm.ptr -> f32
+  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
+  omp.yield
+}
+omp.private {type = private} @privatizer_i32 : i32
+omp.private {type = firstprivate} @firstprivatizer_f32 : f32 copy {
+^bb0(%arg0: f32, %arg1: f32):
+  omp.yield(%arg0 : f32)
+}
+
+llvm.func @foo(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>}
+
+// CHECK-LABEL: llvm.func @device_func(
+// CHECK-SAME:  %[[N:.*]]: i64, %[[COND:.*]]: i1)
+llvm.func @device_func(%arg0: i64, %cond: i1) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} {
+  // CHECK: %[[ALLOC0:.*]] = omp.alloc_shared_mem %[[N]] x i64 : (i64) -> !llvm.ptr
+  %0 = llvm.alloca %arg0 x i64 : (i64) -> !llvm.ptr
+  // CHECK: %[[ALLOC1:.*]] = omp.alloc_shared_mem %[[N]] x f32 {alignment = 128 : i64} : (i64) -> !llvm.ptr
+  %1 = llvm.alloca %arg0 x f32 {alignment = 128} : (i64) -> !llvm.ptr
+  // CHECK: %[[ALLOC2:.*]] = omp.alloc_shared_mem %[[N]] x vector<16xf32> : (i64) -> !llvm.ptr
+  %2 = llvm.alloca %arg0 x vector<16xf32> : (i64) -> !llvm.ptr
+  // CHECK: %[[ALLOC3:.*]] = omp.alloc_shared_mem %[[N]] x i32 : (i64) -> !llvm.ptr
+  %3 = llvm.alloca %arg0 x i32 : (i64) -> !llvm.ptr<5>
+  %4 = llvm.addrspacecast %3 : !llvm.ptr<5> to !llvm.ptr
+
+  // CHECK: %[[ALLOC4:.*]] = llvm.alloca %[[N]] x i32 : (i64) -> !llvm.ptr
+  %5 = llvm.alloca %arg0 x i32 : (i64) -> !llvm.ptr
+  // CHECK: %[[ALLOC5:.*]] = llvm.alloca %[[N]] x i32 : (i64) -> !llvm.ptr
+  %6 = llvm.alloca %arg0 x i32 : (i64) -> !llvm.ptr
+  // CHECK: llvm.cond_br %[[COND]], ^[[IF:.*]], ^[[ELSE:.*]]
+  llvm.cond_br %cond, ^if, ^else
+
+// CHECK: ^[[IF]]:
+^if:
+  // CHECK: omp.parallel reduction(@add_f32 %[[ALLOC0]] -> %{{.*}} : !llvm.ptr)
+  omp.parallel reduction(@add_f32 %0 -> %arg1 : !llvm.ptr) {
+    // CHECK: %{{.*}} = llvm.load %[[ALLOC2]]
+    %7 = llvm.load %2 : !llvm.ptr -> vector<16xf32>
+    // CHECK: %{{.*}} = llvm.alloca
+    %8 = llvm.alloca %arg0 x i32 : (i64) -> !llvm.ptr
+    // CHECK: omp.wsloop private(@privatizer_i32 %[[ALLOC4]] -> %{{.*}}, @firstprivatizer_f32 %[[ALLOC1]] -> %{{.*}} : !llvm.ptr, !llvm.ptr)
+    omp.wsloop private(@privatizer_i32 %5 -> %arg2, @firstprivatizer_f32 %1 -> %arg3 : !llvm.ptr, !llvm.ptr) {
+      omp.loop_nest (%arg4) : i64 = (%arg0) to (%arg0) inclusive step (%arg0) {
+        llvm.call @foo(%arg1) : (!llvm.ptr) -> ()
+        llvm.call @foo(%8) : (!llvm.ptr) -> ()
+        llvm.call @foo(%arg2) : (!llvm.ptr) -> ()
+        llvm.call @foo(%arg3) : (!llvm.ptr) -> ()
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  // CHECK: llvm.br ^[[EXIT:.*]]
+  llvm.br ^exit
+
+// CHECK: ^[[ELSE]]:
+^else:
+  // CHECK: llvm.call @foo(%[[ALLOC3]]) : (!llvm.ptr) -> ()
+  llvm.call @foo(%4) : (!llvm.ptr) -> ()
+  // CHECK: %{{.*}} = llvm.load %[[ALLOC5]]
+  %8 = llvm.load %6 : !llvm.ptr -> i32
+  // CHECK: llvm.br ^[[EXIT]]
+  llvm.br ^exit
+
+// CHECK: ^[[EXIT]]:
+^exit:
+  // CHECK: omp.free_shared_mem %[[ALLOC0]] : !llvm.ptr
+  // CHECK: omp.free_shared_mem %[[ALLOC1]] : !llvm.ptr
+  // CHECK: omp.free_shared_mem %[[ALLOC2]] : !llvm.ptr
+  // CHECK: omp.free_shared_mem %[[ALLOC3]] : !llvm.ptr
+  // CHECK-NOT: omp.free_shared_mem
+  // CHECK: llvm.return
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @host_func(
+// CHECK-SAME:  %[[N:.*]]: i64)
+llvm.func @host_func(%arg0: i64) {
+  // CHECK: %[[ALLOC0:.*]] = llvm.alloca %[[N]] x i32 : (i64) -> !llvm.ptr
+  %0 = llvm.alloca %arg0 x i32 : (i64) -> !llvm.ptr
+  // CHECK: omp.parallel
+  omp.parallel {
+    // CHECK: llvm.call @foo(%[[ALLOC0]]) : (!llvm.ptr) -> ()
+    llvm.call @foo(%0) : (!llvm.ptr) -> ()
+    // CHECK: omp.target
+    omp.target {
+      %c0 = llvm.mlir.constant(1 : i64) : i64
+      // CHECK: %[[ALLOC1:.*]] = omp.alloc_shared_mem %{{.*}}
+      %1 = llvm.alloca %c0 x i32 : (i64) -> !llvm.ptr
+      // CHECK-NEXT: llvm.call @foo(%[[ALLOC1]]) : (!llvm.ptr) -> ()
+      llvm.call @foo(%1) : (!llvm.ptr) -> ()
+      // CHECK-NEXT: omp.free_shared_mem %[[ALLOC1]] : !llvm.ptr
+      // CHECK-NEXT: omp.terminator
+      omp.terminator
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @target_spmd(
+llvm.func @target_spmd() {
+  // CHECK-NOT: omp.alloc_shared_mem
+  // CHECK-NOT: omp.free_shared_mem
+  omp.target {
+    %c = llvm.mlir.constant(1 : i64) : i64
+    %0 = llvm.alloca %c x i32 : (i64) -> !llvm.ptr
+    omp.teams {
+      %1 = llvm.alloca %c x i32 : (i64) -> !llvm.ptr
+      omp.parallel {
+        %2 = llvm.alloca %c x i32 : (i64) -> !llvm.ptr
+        %3 = llvm.load %0 : !llvm.ptr -> i32
+        %4 = llvm.load %1 : !llvm.ptr -> i32
+        omp.distribute {
+          omp.wsloop {
+            omp.loop_nest (%arg0) : i64 = (%c) to (%c) inclusive step (%c) {
+              %5 = llvm.load %2 : !llvm.ptr -> i32
+              omp.yield
+            }
+          } {omp.composite}
+        } {omp.composite}
+        omp.terminator
+      } {omp.composite}
+      omp.terminator
+    }
+    omp.terminator
+  }
+  // CHECK: return
+  llvm.return
+}
+
+}

>From bcb62dbe66b782b1818ad3ae3f420d735c14d389 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Tue, 17 Feb 2026 15:26:22 +0000
Subject: [PATCH 3/3] update after rebase and address review comments

---
 .../mlir/Dialect/OpenMP/Transforms/Passes.td  | 10 ++++-----
 .../OpenMP/Transforms/StackToShared.cpp       | 22 +++++++++++++------
 mlir/test/Dialect/OpenMP/stack-to-shared.mlir | 14 ++++++------
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
index 498b8a4812caa..31e2040dd8a16 100644
--- a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
@@ -27,11 +27,11 @@ def PrepareForOMPOffloadPrivatizationPass : Pass<"omp-offload-privatization-prep
 def StackToSharedPass : Pass<"omp-stack-to-shared", "mlir::LLVM::LLVMFuncOp"> {
   let summary = "Replaces stack allocations target devices with shared memory.";
   let description = [{
-    `llvm.alloca` operations defining values in a non-SPMD target region and
-    then potentially used inside of an `omp.parallel` region are replaced by
-    this pass with `omp.alloc_shared_mem` and `omp.free_shared_mem`. This is
-    also done for top-level function `llvm.alloca`s used in the same way when
-    the parent function is a target device function.
+    This pass replaces `llvm.alloca` operations located in a non-SPMD target
+    region and then potentially used inside of an `omp.parallel` region with
+    `omp.alloc_shared_mem` and `omp.free_shared_mem`. This is also done for
+    top-level function `llvm.alloca`s used in the same way when the parent
+    function is a target device function.
 
     This ensures that explicit private allocations, intended to be shared across
     threads, use the proper memory space on a target device while supporting the
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp b/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp
index bd2c8c0747b7d..0edccf53a2031 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/StackToShared.cpp
@@ -85,7 +85,7 @@ static bool shouldReplaceAllocaWithUses(const Operation::use_range &uses) {
 // similar is also implemented there to choose between allocas and device
 // shared memory allocations when processing OpenMP reductions, mapping and
 // privatization.
-bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) {
+static bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) {
   auto offloadIface = op.getParentOfType<omp::OffloadModuleInterface>();
   if (!offloadIface || !offloadIface.getIsTargetDevice())
     return false;
@@ -114,7 +114,11 @@ bool shouldReplaceAllocaWithDeviceSharedMem(Operation &op) {
   return shouldReplaceAllocaWithUses(op.getUses());
 }
 
-void insertDeviceSharedMemDeallocation(OpBuilder &builder, Value allocVal) {
+static void insertDeviceSharedMemDeallocation(OpBuilder &builder,
+                                              TypeAttr elemType,
+                                              Value arraySize,
+                                              IntegerAttr alignment,
+                                              Value allocVal) {
   Block *allocaBlock = allocVal.getParentBlock();
   DominanceInfo domInfo;
   for (Block &block : allocVal.getParentRegion()->getBlocks()) {
@@ -122,7 +126,8 @@ void insertDeviceSharedMemDeallocation(OpBuilder &builder, Value allocVal) {
     if (!terminator->hasSuccessors() &&
         domInfo.dominates(allocaBlock, &block)) {
       builder.setInsertionPoint(terminator);
-      omp::FreeSharedMemOp::create(builder, allocVal.getLoc(), allocVal);
+      omp::FreeSharedMemOp::create(builder, allocVal.getLoc(), elemType,
+                                   arraySize, alignment, allocVal);
     }
   }
 }
@@ -151,8 +156,8 @@ class StackToSharedPass
 
       // TODO: The handling of non-default address spaces might need to be
       // improved. This currently only handles the case where an alloca to
-      // non-default address space must only be used by a single addrspacecast
-      // to default address space.
+      // non-default address space is only used by a single addrspacecast to
+      // default address space.
       bool nonDefaultAddrSpace = false;
       if (auto llvmPtrType = dyn_cast<LLVM::LLVMPointerType>(resultType))
         nonDefaultAddrSpace = llvmPtrType.getAddressSpace() != 0;
@@ -163,7 +168,8 @@ class StackToSharedPass
           allocaOp.getElemTypeAttr(), allocaOp.getArraySize(),
           allocaOp.getAlignmentAttr());
       if (nonDefaultAddrSpace) {
-        assert(allocaOp->hasOneUse() && "alloca must have only one use");
+        assert(allocaOp->hasOneUse() && " unsupported non-default address "
+                                        "space alloca with multiple uses");
         auto asCastOp =
             cast<LLVM::AddrSpaceCastOp>(*allocaOp->getUsers().begin());
         asCastOp.replaceAllUsesWith(sharedAllocOp.getOperation());
@@ -179,7 +185,9 @@ class StackToSharedPass
 
       // Create a new omp.free_shared_mem for the allocated buffer prior to
       // exiting the region.
-      insertDeviceSharedMemDeallocation(builder, sharedAllocOp.getResult());
+      insertDeviceSharedMemDeallocation(
+          builder, allocaOp.getElemTypeAttr(), allocaOp.getArraySize(),
+          allocaOp.getAlignmentAttr(), sharedAllocOp.getResult());
     });
     for (Operation *op : toBeDeleted)
       op->erase();
diff --git a/mlir/test/Dialect/OpenMP/stack-to-shared.mlir b/mlir/test/Dialect/OpenMP/stack-to-shared.mlir
index 81b03acd4d368..d14528e4f396a 100644
--- a/mlir/test/Dialect/OpenMP/stack-to-shared.mlir
+++ b/mlir/test/Dialect/OpenMP/stack-to-shared.mlir
@@ -32,7 +32,7 @@ llvm.func @foo(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declareta
 llvm.func @device_func(%arg0: i64, %cond: i1) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} {
   // CHECK: %[[ALLOC0:.*]] = omp.alloc_shared_mem %[[N]] x i64 : (i64) -> !llvm.ptr
   %0 = llvm.alloca %arg0 x i64 : (i64) -> !llvm.ptr
-  // CHECK: %[[ALLOC1:.*]] = omp.alloc_shared_mem %[[N]] x f32 {alignment = 128 : i64} : (i64) -> !llvm.ptr
+  // CHECK: %[[ALLOC1:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) align(128) -> !llvm.ptr
   %1 = llvm.alloca %arg0 x f32 {alignment = 128} : (i64) -> !llvm.ptr
   // CHECK: %[[ALLOC2:.*]] = omp.alloc_shared_mem %[[N]] x vector<16xf32> : (i64) -> !llvm.ptr
   %2 = llvm.alloca %arg0 x vector<16xf32> : (i64) -> !llvm.ptr
@@ -81,10 +81,10 @@ llvm.func @device_func(%arg0: i64, %cond: i1) attributes {omp.declare_target = #
 
 // CHECK: ^[[EXIT]]:
 ^exit:
-  // CHECK: omp.free_shared_mem %[[ALLOC0]] : !llvm.ptr
-  // CHECK: omp.free_shared_mem %[[ALLOC1]] : !llvm.ptr
-  // CHECK: omp.free_shared_mem %[[ALLOC2]] : !llvm.ptr
-  // CHECK: omp.free_shared_mem %[[ALLOC3]] : !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x i64 : (i64)] %[[ALLOC0]] : !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64) align(128)] %[[ALLOC1]] : !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x vector<16xf32> : (i64)] %[[ALLOC2]] : !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x i32 : (i64)] %[[ALLOC3]] : !llvm.ptr
   // CHECK-NOT: omp.free_shared_mem
   // CHECK: llvm.return
   llvm.return
@@ -102,11 +102,11 @@ llvm.func @host_func(%arg0: i64) {
     // CHECK: omp.target
     omp.target {
       %c0 = llvm.mlir.constant(1 : i64) : i64
-      // CHECK: %[[ALLOC1:.*]] = omp.alloc_shared_mem %{{.*}}
+      // CHECK: %[[ALLOC1:.*]] = omp.alloc_shared_mem [[ALLOC1_SIZE:.*]] -> !llvm.ptr
       %1 = llvm.alloca %c0 x i32 : (i64) -> !llvm.ptr
       // CHECK-NEXT: llvm.call @foo(%[[ALLOC1]]) : (!llvm.ptr) -> ()
       llvm.call @foo(%1) : (!llvm.ptr) -> ()
-      // CHECK-NEXT: omp.free_shared_mem %[[ALLOC1]] : !llvm.ptr
+      // CHECK-NEXT: omp.free_shared_mem [[[ALLOC1_SIZE]]] %[[ALLOC1]] : !llvm.ptr
       // CHECK-NEXT: omp.terminator
       omp.terminator
     }



More information about the llvm-branch-commits mailing list