[llvm-branch-commits] [llvm] [mlir] [OpenMP][OMPIRBuilder] Use device shared memory for arg structures (PR #150925)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Apr 27 04:45:52 PDT 2026


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/150925

>From 810d980c62c69f11977dc935b2af012d938021d8 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 3 Jul 2025 16:47:51 +0100
Subject: [PATCH] [OpenMP][OMPIRBuilder] Use device shared memory for arg
 structures

Argument structures are created when sections of the LLVM IR corresponding to
an OpenMP construct are outlined into their own function. These structures are
currently allocated on the stack.

This patch modifies this behavior when compiling for a target device and
outlining `parallel`-related IR from a Generic-mode kernel, so that argument
structures are allocated in device shared memory instead of private stack
space. This is needed so that all threads executing the outlined region can
access these arguments, not only the thread that allocated them.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  17 +-
 .../llvm/Transforms/Utils/CodeExtractor.h     |  50 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 374 +++++++++++++-----
 llvm/lib/Transforms/IPO/HotColdSplitting.cpp  |   1 +
 llvm/lib/Transforms/IPO/IROutliner.cpp        |   4 +-
 llvm/lib/Transforms/IPO/PartialInlining.cpp   |   8 +-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   |  72 +++-
 .../Transforms/Utils/CodeExtractorTest.cpp    |   7 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  13 +-
 .../LLVMIR/omptarget-parallel-llvm.mlir       |  10 +-
 10 files changed, 412 insertions(+), 144 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 0ae77f823bc88..15223915a74d2 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -34,6 +34,7 @@
 
 namespace llvm {
 class CanonicalLoopInfo;
+class CodeExtractor;
 class ScanInfo;
 struct TargetRegionEntryInfo;
 class OffloadEntriesInfoManager;
@@ -2556,17 +2557,27 @@ class OpenMPIRBuilder {
     // TODO: this should be safe to enable by default
     bool FixUpNonEntryAllocas = false;
 
+    LLVM_ABI virtual ~OutlineInfo() = default;
+
     /// Collect all blocks in between EntryBB and ExitBB in both the given
     /// vector and set.
     LLVM_ABI void collectBlocks(SmallPtrSetImpl<BasicBlock *> &BlockSet,
                                 SmallVectorImpl<BasicBlock *> &BlockVector);
 
+    /// Create a CodeExtractor instance based on the information stored in this
+    /// structure, the list of collected blocks from a previous call to
+    /// \c collectBlocks and a flag stating whether arguments must be passed in
+    /// address space 0.
+    LLVM_ABI virtual std::unique_ptr<CodeExtractor>
+    createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                        bool ArgsInZeroAddressSpace, Twine Suffix = Twine(""));
+
     /// Return the function that contains the region to be outlined.
     Function *getFunction() const { return EntryBB->getParent(); }
   };
 
   /// Collection of regions that need to be outlined during finalization.
-  SmallVector<OutlineInfo, 16> OutlineInfos;
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> OutlineInfos;
 
   /// A collection of candidate target functions that's constant allocas will
   /// attempt to be raised on a call of finalize after all currently enqueued
@@ -2581,7 +2592,9 @@ class OpenMPIRBuilder {
   std::forward_list<ScanInfo> ScanInfos;
 
   /// Add a new region that will be outlined later.
-  void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); }
+  void addOutlineInfo(std::unique_ptr<OutlineInfo> OI) {
+    OutlineInfos.push_back(std::move(OI)); // Takes ownership of the region.
+  }
 
   /// An ordered map of auto-generated variables to their unique names.
   /// It stores variables with the following names: 1) ".gomp_critical_user_" +
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 66c9491182b16..fdc9f9c82b0ad 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -17,14 +17,15 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/Support/Compiler.h"
 #include <limits>
 
 namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
+class AddrSpaceCastInst;
 class AllocaInst;
-class BasicBlock;
 class BlockFrequency;
 class BlockFrequencyInfo;
 class BranchProbabilityInfo;
@@ -94,15 +95,23 @@ class LLVM_ABI CodeExtractor {
   BranchProbabilityInfo *BPI;
   AssumptionCache *AC;
 
-  // A block outside of the extraction set where any intermediate allocations
-  // will be placed inside. If this is null, allocations will be placed in the
-  // entry block of the function.
+  /// A block outside of the extraction set where any intermediate allocations
+  /// will be placed inside. If this is null, allocations will be placed in the
+  /// entry block of the function.
   BasicBlock *AllocationBlock;
 
-  // If true, varargs functions can be extracted.
+  /// A block outside of the extraction set where deallocations for intermediate
+  /// allocations can be placed inside. Not used for automatically deallocated
+  /// memory (e.g. `alloca`), which is the default.
+  ///
+  /// If it is null and needed, the end of the replacement basic block will be
+  /// used to place deallocations.
+  BasicBlock *DeallocationBlock;
+
+  /// If true, varargs functions can be extracted.
   bool AllowVarArgs;
 
-  // Bits of intermediate state computed at various phases of extraction.
+  /// Bits of intermediate state computed at various phases of extraction.
   SetVector<BasicBlock *> Blocks;
 
   /// Lists of blocks that are branched from the code region to be extracted,
@@ -123,13 +132,13 @@ class LLVM_ABI CodeExtractor {
   /// 1, etc.
   SmallVector<BasicBlock *> ExtractedFuncRetVals;
 
-  // Suffix to use when creating extracted function (appended to the original
-  // function name + "."). If empty, the default is to use the entry block
-  // label, if non-empty, otherwise "extracted".
+  /// Suffix to use when creating extracted function (appended to the original
+  /// function name + "."). If empty, the default is to use the entry block
+  /// label, if non-empty, otherwise "extracted".
   std::string Suffix;
 
-  // If true, the outlined function has aggregate argument in zero address
-  // space.
+  /// If true, the outlined function has aggregate argument in zero address
+  /// space.
   bool ArgsInZeroAddressSpace;
 
   // If true, the outlined function always return void even when there is only
@@ -152,7 +161,9 @@ class LLVM_ABI CodeExtractor {
   /// however code extractor won't validate whether extraction is legal. Any new
   /// allocations will be placed in the AllocationBlock, unless it is null, in
   /// which case it will be placed in the entry block of the function from which
-  /// the code is being extracted. If ArgsInZeroAddressSpace param is set to
+  /// the code is being extracted. Explicit deallocations for the aforementioned
+  /// allocations will be placed in the DeallocationBlock or the end of the
+  /// replacement block, if needed. If ArgsInZeroAddressSpace param is set to
   /// true, then the aggregate param pointer of the outlined function is
   /// declared in zero address space. If VoidReturnWithSingleOutput is set to
   /// true, then the return type of the outlined function is set void even if
@@ -162,9 +173,12 @@ class LLVM_ABI CodeExtractor {
                 BranchProbabilityInfo *BPI = nullptr,
                 AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                 bool AllowAlloca = false, BasicBlock *AllocationBlock = nullptr,
+                BasicBlock *DeallocationBlock = nullptr,
                 std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
                 bool VoidReturnWithSingleOutput = true);
 
+  virtual ~CodeExtractor() = default;
+
   /// Perform the extraction, returning the new function.
   ///
   /// Returns zero when called on a CodeExtractor instance where isEligible
@@ -244,6 +258,18 @@ class LLVM_ABI CodeExtractor {
   /// region, passing it instead as a scalar.
   void excludeArgFromAggregate(Value *Arg);
 
+protected:
+  /// Allocate an intermediate variable at the specified point.
+  virtual Instruction *allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP,
+                                   Type *VarType, const Twine &Name = Twine(""),
+                                   AddrSpaceCastInst **CastedAlloc = nullptr);
+
+  /// Deallocate a previously-allocated intermediate variable at the specified
+  /// point.
+  virtual Instruction *deallocateVar(BasicBlock *BB,
+                                     BasicBlock::iterator DeallocIP, Value *Var,
+                                     Type *VarType);
+
 private:
   struct LifetimeMarkerInfo {
     bool SinkLifeStart = false;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 0cd5355fdf74a..ade1e2c8c3dc6 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -294,6 +294,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
   return Result;
 }
 
+/// Given a function, if it represents the entry point of a target kernel, this
+/// returns its execution mode flags. Otherwise, it returns std::nullopt.
+static std::optional<omp::OMPTgtExecModeFlags>
+getTargetKernelExecMode(Function &Kernel) {
+  CallInst *TargetInitCall = nullptr;
+  for (Instruction &Inst : Kernel.getEntryBlock()) {
+    if (auto *Call = dyn_cast<CallInst>(&Inst)) {
+      Function *Callee = Call->getCalledFunction(); // Null if indirect/asm.
+      if (Callee && Callee->getName() == "__kmpc_target_init") {
+        TargetInitCall = Call;
+        break;
+      }
+    }
+  }
+
+  if (!TargetInitCall)
+    return std::nullopt;
+
+  // The execution mode is stored in the kernel environment global passed as
+  // the first argument to __kmpc_target_init (see createTargetInit()).
+  Value *InitOperand = TargetInitCall->getArgOperand(0);
+  GlobalVariable *KernelEnv = nullptr;
+  if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
+    KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
+  else
+    KernelEnv = cast<GlobalVariable>(InitOperand);
+  auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
+  auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
+  auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
+  return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -459,6 +491,88 @@ enum OpenMPOffloadingRequiresDirFlags {
   LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
 };
 
+/// CodeExtractor specialization giving derived OpenMP-specific extractors
+/// access to the OpenMPIRBuilder passed at construction.
+class OMPCodeExtractor : public CodeExtractor {
+public:
+  OMPCodeExtractor(OpenMPIRBuilder &OMPBuilder, ArrayRef<BasicBlock *> BBs,
+                   DominatorTree *DT = nullptr, bool AggregateArgs = false,
+                   BlockFrequencyInfo *BFI = nullptr,
+                   BranchProbabilityInfo *BPI = nullptr,
+                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
+                   bool AllowAlloca = false,
+                   BasicBlock *AllocationBlock = nullptr,
+                   BasicBlock *DeallocationBlock = nullptr,
+                   std::string Suffix = "", bool ArgsInZeroAddressSpace = false)
+      : CodeExtractor(BBs, DT, AggregateArgs, BFI, BPI, AC, AllowVarArgs,
+                      AllowAlloca, AllocationBlock, DeallocationBlock,
+                      std::move(Suffix), ArgsInZeroAddressSpace),
+        OMPBuilder(OMPBuilder) {}
+
+protected:
+  OpenMPIRBuilder &OMPBuilder;
+};
+
+class DeviceSharedMemCodeExtractor : public OMPCodeExtractor {
+public:
+  DeviceSharedMemCodeExtractor(
+      OpenMPIRBuilder &OMPBuilder, BasicBlock *AllocBlockOverride,
+      ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
+      bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+      BranchProbabilityInfo *BPI = nullptr, AssumptionCache *AC = nullptr,
+      bool AllowVarArgs = false, bool AllowAlloca = false,
+      BasicBlock *AllocationBlock = nullptr,
+      BasicBlock *DeallocationBlock = nullptr, std::string Suffix = "",
+      bool ArgsInZeroAddressSpace = false)
+      : OMPCodeExtractor(OMPBuilder, BBs, DT, AggregateArgs, BFI, BPI, AC,
+                         AllowVarArgs, AllowAlloca, AllocationBlock,
+                         DeallocationBlock, std::move(Suffix),
+                         ArgsInZeroAddressSpace),
+        AllocBlockOverride(AllocBlockOverride) {}
+
+protected:
+  Instruction *allocateVar(BasicBlock *, BasicBlock::iterator, Type *VarType,
+                           const Twine &Name,
+                           AddrSpaceCastInst **CastedAlloc) override {
+    assert(AllocBlockOverride && "Expected a block for shared allocations");
+    // Shared memory is never cast to address space 0; report no cast made.
+    if (CastedAlloc)
+      *CastedAlloc = nullptr;
+    return OMPBuilder.createOMPAllocShared(
+        OpenMPIRBuilder::InsertPointTy(
+            AllocBlockOverride, AllocBlockOverride->getFirstInsertionPt()),
+        VarType, Name);
+  }
+
+  Instruction *deallocateVar(BasicBlock *BB, BasicBlock::iterator DeallocIP,
+                             Value *Var, Type *VarType) override {
+    return OMPBuilder.createOMPFreeShared(
+        OpenMPIRBuilder::InsertPointTy(BB, DeallocIP), Var, VarType);
+  }
+
+private:
+  // TODO: Remove the need for this override and instead get the CodeExtractor
+  // to provide a valid insert point for explicit deallocations by correctly
+  // populating its DeallocationBlock.
+  BasicBlock *AllocBlockOverride;
+};
+
+/// Helper storing information about regions to outline using device shared
+/// memory for intermediate allocations.
+struct DeviceSharedMemOutlineInfo : public OpenMPIRBuilder::OutlineInfo {
+  OpenMPIRBuilder &OMPBuilder;
+  /// Allocation block override for the extractor; set in createParallel().
+  BasicBlock *AllocBlockOverride = nullptr;
+
+  explicit DeviceSharedMemOutlineInfo(OpenMPIRBuilder &OMPBuilder)
+      : OMPBuilder(OMPBuilder) {}
+
+  std::unique_ptr<CodeExtractor>
+  createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                      bool ArgsInZeroAddressSpace,
+                      Twine Suffix = Twine("")) override;
+};
+
 } // anonymous namespace
 
 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
@@ -800,20 +914,20 @@ static void hoistNonEntryAllocasToEntryBlock(llvm::Function *Func) {
 void OpenMPIRBuilder::finalize(Function *Fn) {
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  SmallVector<OutlineInfo, 16> DeferredOutlines;
-  for (OutlineInfo &OI : OutlineInfos) {
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> DeferredOutlines;
+  for (std::unique_ptr<OutlineInfo> &OI : OutlineInfos) {
     // Skip functions that have not finalized yet; may happen with nested
     // function generation.
-    if (Fn && OI.getFunction() != Fn) {
-      DeferredOutlines.push_back(OI);
+    if (Fn && OI->getFunction() != Fn) {
+      DeferredOutlines.push_back(std::move(OI));
       continue;
     }
 
     ParallelRegionBlockSet.clear();
     Blocks.clear();
-    OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+    OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
-    Function *OuterFn = OI.getFunction();
+    Function *OuterFn = OI->getFunction();
     CodeExtractorAnalysisCache CEAC(*OuterFn);
     // If we generate code for the target device, we need to allocate
     // struct for aggregate params in the device default alloca address space.
@@ -822,27 +936,20 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
-                            /* AggregateArgs */ true,
-                            /* BlockFrequencyInfo */ nullptr,
-                            /* BranchProbabilityInfo */ nullptr,
-                            /* AssumptionCache */ nullptr,
-                            /* AllowVarArgs */ true,
-                            /* AllowAlloca */ true,
-                            /* AllocaBlock*/ OI.OuterAllocaBB,
-                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
+    std::unique_ptr<CodeExtractor> Extractor =
+        OI->createCodeExtractor(Blocks, ArgsInZeroAddressSpace, ".omp_par");
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
-    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
-                      << " Exit: " << OI.ExitBB->getName() << "\n");
-    assert(Extractor.isEligible() &&
+    LLVM_DEBUG(dbgs() << "Entry " << OI->EntryBB->getName()
+                      << " Exit: " << OI->ExitBB->getName() << "\n");
+    assert(Extractor->isEligible() &&
            "Expected OpenMP outlining to be possible!");
 
-    for (auto *V : OI.ExcludeArgsFromAggregate)
-      Extractor.excludeArgFromAggregate(V);
+    for (auto *V : OI->ExcludeArgsFromAggregate)
+      Extractor->excludeArgFromAggregate(V);
 
     Function *OutlinedFn =
-        Extractor.extractCodeRegion(CEAC, OI.Inputs, OI.Outputs);
+        Extractor->extractCodeRegion(CEAC, OI->Inputs, OI->Outputs);
 
     // Forward target-cpu, target-features attributes to the outlined function.
     auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
@@ -867,8 +974,8 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // made our own entry block after all.
     {
       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
-      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
-      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
+      assert(ArtificialEntry.getUniqueSuccessor() == OI->EntryBB);
+      assert(OI->EntryBB->getUniquePredecessor() == &ArtificialEntry);
       // Move instructions from the to-be-deleted ArtificialEntry to the entry
       // basic block of the parallel region. CodeExtractor generates
       // instructions to unwrap the aggregate argument and may sink
@@ -884,25 +991,26 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
 
         if (I.isTerminator()) {
           // Absorb any debug value that terminator may have
-          if (Instruction *TI = OI.EntryBB->getTerminatorOrNull())
+          if (Instruction *TI = OI->EntryBB->getTerminatorOrNull())
             TI->adoptDbgRecords(&ArtificialEntry, I.getIterator(), false);
           continue;
         }
 
-        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
+        I.moveBeforePreserving(*OI->EntryBB,
+                               OI->EntryBB->getFirstInsertionPt());
       }
 
-      OI.EntryBB->moveBefore(&ArtificialEntry);
+      OI->EntryBB->moveBefore(&ArtificialEntry);
       ArtificialEntry.eraseFromParent();
     }
-    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
+    assert(&OutlinedFn->getEntryBlock() == OI->EntryBB);
     assert(OutlinedFn && OutlinedFn->hasNUses(1));
 
     // Run a user callback, e.g. to add attributes.
-    if (OI.PostOutlineCB)
-      OI.PostOutlineCB(*OutlinedFn);
+    if (OI->PostOutlineCB)
+      OI->PostOutlineCB(*OutlinedFn);
 
-    if (OI.FixUpNonEntryAllocas)
+    if (OI->FixUpNonEntryAllocas)
       hoistNonEntryAllocasToEntryBlock(OutlinedFn);
   }
 
@@ -1724,33 +1832,72 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
 
-  OutlineInfo OI;
+  auto OI = [&]() -> std::unique_ptr<OutlineInfo> {
+    if (Config.isTargetDevice()) {
+      std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+          getTargetKernelExecMode(*OuterFn);
+
+      // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+      // the CodeExtractor to follow its default behavior. Otherwise, we need to
+      // use device shared memory to allocate argument structures.
+      if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+        auto Info = std::make_unique<DeviceSharedMemOutlineInfo>(*this);
+
+        // Instead of using the insertion point provided by the CodeExtractor,
+        // here we need to use the block that eventually calls the outlined
+        // function for the `parallel` construct.
+        //
+        // The reason is that the explicit deallocation call will be inserted
+        // within the outlined function, whereas the alloca insertion point
+        // might actually be located somewhere else in the caller. This becomes
+        // a problem when e.g. `parallel` is inside of a `distribute` construct,
+        // because the deallocation would be executed multiple times and the
+        // allocation just once (outside of the loop).
+        //
+        // TODO: Ideally, we'd want to do the allocation and deallocation
+        // outside of the `parallel` outlined function, hence using here the
+        // insertion point provided by the CodeExtractor. We can't do this at
+        // the moment because there is currently no way of passing an eligible
+        // insertion point for the explicit deallocation to the CodeExtractor,
+        // as that block is created (at least when nested inside of
+        // `distribute`) sometime after createParallel() completed, so it can't
+        // be stored in the OutlineInfo structure here.
+        //
+        // The current approach results in an explicit allocation and
+        // deallocation pair for each `distribute` loop iteration in that case,
+        // which is suboptimal.
+        Info->AllocBlockOverride = EntryBB;
+        return Info;
+      }
+    }
+    return std::make_unique<OutlineInfo>();
+  }();
+
   if (Config.isTargetDevice()) {
     // Generate OpenMP target specific runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
       targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                              ThreadID, ToBeDeletedVec);
     };
-    OI.FixUpNonEntryAllocas = true;
   } else {
     // Generate OpenMP host runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
     };
-    OI.FixUpNonEntryAllocas = true;
   }
 
-  OI.OuterAllocaBB = OuterAllocaBlock;
-  OI.EntryBB = PRegEntryBB;
-  OI.ExitBB = PRegExitBB;
+  OI->FixUpNonEntryAllocas = true;
+  OI->OuterAllocaBB = OuterAllocaBlock;
+  OI->EntryBB = PRegEntryBB;
+  OI->ExitBB = PRegExitBB;
 
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
@@ -1761,6 +1908,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ OuterAllocaBlock,
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
 
   // Find inputs to, outputs from the code region.
@@ -1785,7 +1933,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   auto PrivHelper = [&](Value &V) -> Error {
     if (&V == TIDAddr || &V == ZeroAddr) {
-      OI.ExcludeArgsFromAggregate.push_back(&V);
+      OI->ExcludeArgsFromAggregate.push_back(&V);
       return Error::success();
     }
 
@@ -2139,15 +2287,15 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
   }
 
   llvm::CanonicalLoopInfo *CLI = result.get();
-  OutlineInfo OI;
-  OI.EntryBB = TaskloopAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
-  OI.ExitBB = TaskloopExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TaskloopAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
+  OI->ExitBB = TaskloopExitBB;
 
   // Add the thread ID argument.
   SmallVector<Instruction *> ToBeDeleted;
   // dummy instruction to be used as a fake argument
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TaskloopAllocaIP, "global.tid", false));
   Value *FakeLB = createFakeIntVal(Builder, AllocaIP, ToBeDeleted,
                                    TaskloopAllocaIP, "lb", false, true);
@@ -2157,11 +2305,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
                                      TaskloopAllocaIP, "step", false, true);
   // For Taskloop, we want to force the bounds being the first 3 inputs in the
   // aggregate struct
-  OI.Inputs.insert(FakeLB);
-  OI.Inputs.insert(FakeUB);
-  OI.Inputs.insert(FakeStep);
+  OI->Inputs.insert(FakeLB);
+  OI->Inputs.insert(FakeUB);
+  OI->Inputs.insert(FakeStep);
   if (TaskContextStructPtrVal)
-    OI.Inputs.insert(TaskContextStructPtrVal);
+    OI->Inputs.insert(TaskContextStructPtrVal);
   assert(((TaskContextStructPtrVal && DupCB) ||
           (!TaskContextStructPtrVal && !DupCB)) &&
          "Task context struct ptr and duplication callback must be both set "
@@ -2183,11 +2331,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
   }
   Value *TaskDupFn = *TaskDupFnOrErr;
 
-  OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
-                      TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
-                      IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
-                      FakeStep, FakeSharedsTy, Final, Mergeable, Priority,
-                      NumOfCollapseLoops](Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
+                       TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
+                       IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
+                       FakeStep, FakeSharedsTy, Final, Mergeable, Priority,
+                       NumOfCollapseLoops](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
@@ -2489,19 +2637,20 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
   if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = TaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
-  OI.ExitBB = TaskExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
+  OI->ExitBB = TaskExitBB;
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
 
-  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
-                      Affinities, Mergeable, Priority, EventHandle,
-                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+                       Affinities, Mergeable, Priority, EventHandle,
+                       TaskAllocaBB,
+                       ToBeDeleted](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
@@ -5993,19 +6142,19 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   }
   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize, Flag);
 
-  OutlineInfo OI;
-  OI.OuterAllocaBB = CLI->getPreheader();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->OuterAllocaBB = CLI->getPreheader();
   Function *OuterFn = CLI->getPreheader()->getParent();
 
   // Instructions which need to be deleted at the end of code generation
   SmallVector<Instruction *, 4> ToBeDeleted;
 
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Mark the body loop as region which needs to be extracted
-  OI.EntryBB = CLI->getBody();
-  OI.ExitBB = CLI->getLatch()->splitBasicBlockBefore(CLI->getLatch()->begin(),
-                                                     "omp.prelatch");
+  OI->EntryBB = CLI->getBody();
+  OI->ExitBB = CLI->getLatch()->splitBasicBlockBefore(CLI->getLatch()->begin(),
+                                                      "omp.prelatch");
 
   // Prepare loop body for extraction
   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
@@ -6025,7 +6174,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   // loop body region.
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks,
@@ -6037,6 +6186,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ CLI->getPreheader(),
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_wsloop",
                           /* AggrArgsIn0AddrSpace */ true);
 
@@ -6061,15 +6211,15 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   }
   // Make sure that loop counter variable is not merged into loop body
   // function argument structure and it is passed as separate variable
-  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
+  OI->ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
 
   // PostOutline CB is invoked when loop body function is outlined and
   // loop body is replaced by call to outlined function. We need to add
   // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
   // function will handle loop control logic.
   //
-  OI.PostOutlineCB = [=, ToBeDeletedVec =
-                             std::move(ToBeDeleted)](Function &OutlinedFn) {
+  OI->PostOutlineCB = [=, ToBeDeletedVec =
+                              std::move(ToBeDeleted)](Function &OutlinedFn) {
     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
                                 LoopType, NoLoop);
   };
@@ -9071,13 +9221,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
                                    TargetTaskAllocaBB->begin());
   InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
 
-  OutlineInfo OI;
-  OI.EntryBB = TargetTaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TargetTaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
 
   // Generate the task body which will subsequently be outlined.
@@ -9095,8 +9245,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
   // OI.ExitBlock is set to the single task body block and will get left out of
   // the outlining process. So, simply create a new empty block to which we
   // uncoditionally branch from where TaskBodyCB left off
-  OI.ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
-  emitBlock(OI.ExitBB, Builder.GetInsertBlock()->getParent(),
+  OI->ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
+  emitBlock(OI->ExitBB, Builder.GetInsertBlock()->getParent(),
             /*IsFinished=*/true);
 
   SmallVector<Value *, 2> OffloadingArraysToPrivatize;
@@ -9108,13 +9258,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
           RTArgs.SizesArray}) {
       if (V && !isa<ConstantPointerNull, GlobalVariable>(V)) {
         OffloadingArraysToPrivatize.push_back(V);
-        OI.ExcludeArgsFromAggregate.push_back(V);
+        OI->ExcludeArgsFromAggregate.push_back(V);
       }
     }
   }
-  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
-                      DeviceID, OffloadingArraysToPrivatize](
-                         Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
+                       DeviceID, OffloadingArraysToPrivatize](
+                          Function &OutlinedFn) mutable {
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
 
@@ -11104,17 +11254,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = AllocaBB;
-  OI.ExitBB = ExitBB;
-  OI.OuterAllocaBB = &OuterAllocaBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = AllocaBB;
+  OI->ExitBB = ExitBB;
+  OI->OuterAllocaBB = &OuterAllocaBB;
 
   // Insert fake values for global tid and bound tid.
   SmallVector<Instruction *, 8> ToBeDeleted;
   InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
 
   auto HostPostOutlineCB = [this, Ident,
@@ -11155,7 +11305,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   };
 
   if (!Config.isTargetDevice())
-    OI.PostOutlineCB = HostPostOutlineCB;
+    OI->PostOutlineCB = HostPostOutlineCB;
 
   addOutlineInfo(std::move(OI));
 
@@ -11194,11 +11344,10 @@ OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
   // When using target we use different runtime functions which require a
   // callback.
   if (Config.isTargetDevice()) {
-    OutlineInfo OI;
-    OI.OuterAllocaBB = OuterAllocaIP.getBlock();
-    OI.EntryBB = AllocaBB;
-    OI.ExitBB = ExitBB;
-
+    auto OI = std::make_unique<OutlineInfo>();
+    OI->OuterAllocaBB = OuterAllocaIP.getBlock();
+    OI->EntryBB = AllocaBB;
+    OI->ExitBB = ExitBB;
     addOutlineInfo(std::move(OI));
   }
   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
@@ -11260,6 +11409,39 @@ void OpenMPIRBuilder::OutlineInfo::collectBlocks(
   }
 }
 
+std::unique_ptr<CodeExtractor>
+OpenMPIRBuilder::OutlineInfo::createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                                                  bool ArgsInZeroAddressSpace,
+                                                  Twine Suffix) {
+  return std::make_unique<CodeExtractor>(Blocks, /* DominatorTree */ nullptr,
+                                         /* AggregateArgs */ true,
+                                         /* BlockFrequencyInfo */ nullptr,
+                                         /* BranchProbabilityInfo */ nullptr,
+                                         /* AssumptionCache */ nullptr,
+                                         /* AllowVarArgs */ true,
+                                         /* AllowAlloca */ true,
+                                         /* AllocationBlock*/ OuterAllocaBB,
+                                         /* DeallocationBlock */ nullptr,
+                                         /* Suffix */ Suffix.str(),
+                                         ArgsInZeroAddressSpace);
+}
+
+std::unique_ptr<CodeExtractor> DeviceSharedMemOutlineInfo::createCodeExtractor(
+    ArrayRef<BasicBlock *> Blocks, bool ArgsInZeroAddressSpace, Twine Suffix) {
+  // TODO: Replace ExitBB with a deallocation block paired to OuterAllocaBB.
+  return std::make_unique<DeviceSharedMemCodeExtractor>(
+      OMPBuilder, AllocBlockOverride, Blocks, /* DominatorTree */ nullptr,
+      /* AggregateArgs */ true,
+      /* BlockFrequencyInfo */ nullptr,
+      /* BranchProbabilityInfo */ nullptr,
+      /* AssumptionCache */ nullptr,
+      /* AllowVarArgs */ true,
+      /* AllowAlloca */ true,
+      /* AllocationBlock*/ OuterAllocaBB,
+      /* DeallocationBlock */ ExitBB,
+      /* Suffix */ Suffix.str(), ArgsInZeroAddressSpace);
+}
+
 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
                                          uint64_t Size, int32_t Flags,
                                          GlobalValue::LinkageTypes,
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index b3e9796ca6789..6b2ecf2277cdf 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -721,6 +721,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
             SubRegion, &*DT, /* AggregateArgs */ false, /* BFI */ nullptr,
             /* BPI */ nullptr, AC, /* AllowVarArgs */ false,
             /* AllowAlloca */ false, /* AllocaBlock */ nullptr,
+            /* DeallocationBlock */ nullptr,
             /* Suffix */ "cold." + std::to_string(OutlinedFunctionID),
             /* ArgsInZeroAddressSpace */ false,
             /* VoidReturnWithSingleOutput */ false);
diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp
index 41c036496487c..2bdfe55592c70 100644
--- a/llvm/lib/Transforms/IPO/IROutliner.cpp
+++ b/llvm/lib/Transforms/IPO/IROutliner.cpp
@@ -2789,7 +2789,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       findAddInputsOutputs(M, *OS, NotSame);
       if (!OS->IgnoreRegion)
         OutlinedRegions.push_back(OS);
@@ -2900,7 +2900,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       bool FunctionOutlined = extractSection(*OS);
       if (FunctionOutlined) {
         unsigned StartIdx = OS->Candidate->getStartIdx();
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 8b4b2b5c70844..5ad98d04a5725 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -1105,8 +1105,8 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() {
                      ClonedFuncBFI.get(), &BPI,
                      LookupAC(*RegionInfo.EntryBlock->getParent()),
                      /* AllowVarargs */ false, /* AllowAlloca */ false,
-                     /* AllocaBlock */ nullptr, /* Suffix */ "",
-                     /* ArgsInZeroAddressSpace */ false,
+                     /* AllocaBlock */ nullptr, /* DeallocationBlock */ nullptr,
+                     /* Suffix */ "", /* ArgsInZeroAddressSpace */ false,
                      /* VoidReturnWithSingleOutput */ false);
 
     CE.findInputsOutputs(Inputs, Outputs, Sinks);
@@ -1189,8 +1189,8 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() {
       CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false,
                     ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc),
                     /* AllowVarargs */ true, /* AllowAlloca */ false,
-                    /* AllocaBlock */ nullptr, /* Suffix */ "",
-                    /* ArgsInZeroAddressSpace */ false,
+                    /* AllocaBlock */ nullptr, /* DeallocationBlock */ nullptr,
+                    /* Suffix */ "", /* ArgsInZeroAddressSpace */ false,
                     /* VoidReturnWithSingleOutput */ false)
           .extractCodeRegion(CEAC);
 
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 28db49daf2efa..ed3d4039f2e4a 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -25,7 +25,6 @@
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
-#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -263,12 +262,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
-                             BasicBlock *AllocationBlock, std::string Suffix,
+                             BasicBlock *AllocationBlock,
+                             BasicBlock *DeallocationBlock, std::string Suffix,
                              bool ArgsInZeroAddressSpace,
                              bool VoidReturnWithSingleOutput)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
-      AllowVarArgs(AllowVarArgs),
+      DeallocationBlock(DeallocationBlock), AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
       Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
       VoidReturnWithSingleOutput(VoidReturnWithSingleOutput) {}
@@ -445,6 +445,28 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
   return CommonExitBlock;
 }
 
+Instruction *CodeExtractor::allocateVar(BasicBlock *BB,
+                                        BasicBlock::iterator AllocIP,
+                                        Type *VarType, const Twine &Name,
+                                        AddrSpaceCastInst **CastedAlloc) {
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+  Instruction *Alloca =
+      new AllocaInst(VarType, DL.getAllocaAddrSpace(), nullptr, Name, AllocIP);
+
+  if (CastedAlloc && ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+    *CastedAlloc = new AddrSpaceCastInst(
+        Alloca, PointerType::get(BB->getContext(), 0), Name + ".ascast");
+    (*CastedAlloc)->insertAfter(Alloca->getIterator());
+  }
+  return Alloca;
+}
+
+Instruction *CodeExtractor::deallocateVar(BasicBlock *, BasicBlock::iterator,
+                                          Value *, Type *) {
+  // Default alloca instructions created by allocateVar are released implicitly.
+  return nullptr;
+}
+
 // Find the pair of life time markers for address 'Addr' that are either
 // defined inside the outline region or can legally be shrinkwrapped into the
 // outline region. If there are not other untracked uses of the address, return
@@ -1834,7 +1856,6 @@ CallInst *CodeExtractor::emitReplacerCall(
     std::vector<Value *> &Reloads) {
   LLVMContext &Context = oldFunction->getContext();
   Module *M = oldFunction->getParent();
-  const DataLayout &DL = M->getDataLayout();
 
   // This takes place of the original loop
   BasicBlock *codeReplacer =
@@ -1865,25 +1886,22 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    AllocaInst *alloca = new AllocaInst(
-        output->getType(), DL.getAllocaAddrSpace(), nullptr,
-        output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
-    params.push_back(alloca);
-    ReloadOutputs.push_back(alloca);
+    Value *OutAlloc =
+        allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                    output->getType(), output->getName() + ".loc");
+    params.push_back(OutAlloc);
+    ReloadOutputs.push_back(OutAlloc);
   }
 
-  AllocaInst *Struct = nullptr;
+  Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                            "structArg", AllocaBlock->getFirstInsertionPt());
-    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-      auto *StructSpaceCast = new AddrSpaceCastInst(
-          Struct, PointerType ::get(Context, 0), "structArg.ascast");
-      StructSpaceCast->insertAfter(Struct->getIterator());
+    AddrSpaceCastInst *StructSpaceCast = nullptr;
+    Struct = allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                         StructArgTy, "structArg", &StructSpaceCast);
+    if (StructSpaceCast)
       params.push_back(StructSpaceCast);
-    } else {
+    else
       params.push_back(Struct);
-    }
 
     unsigned AggIdx = 0;
     for (Value *input : inputs) {
@@ -2026,6 +2044,24 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
                                        {}, call);
 
+  // Deallocate intermediate variables if they need explicit deallocation.
+  BasicBlock *DeallocBlock = codeReplacer;
+  BasicBlock::iterator DeallocIP = codeReplacer->end();
+  if (DeallocationBlock) {
+    DeallocBlock = DeallocationBlock;
+    DeallocIP = DeallocationBlock->getFirstInsertionPt();
+  }
+
+  int Index = 0;
+  for (Value *Output : outputs) {
+    if (!StructValues.contains(Output))
+      deallocateVar(DeallocBlock, DeallocIP, ReloadOutputs[Index++],
+                    Output->getType());
+  }
+
+  if (Struct)
+    deallocateVar(DeallocBlock, DeallocIP, Struct, StructArgTy);
+
   return call;
 }
 
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 269d427d382b6..93b1a070a880a 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -178,8 +178,8 @@ TEST(CodeExtractor, InputOutputReturnMonitoring) {
   CodeExtractor CE(Candidates, /* DT */ nullptr, /* AggregateArgs */ false,
                    /* BFI */ nullptr, /* BPI */ nullptr, /* AC */ nullptr,
                    /* AllowVarargs */ false, /* AllowAlloca */ false,
-                   /* AllocaBlock */ nullptr, /* Suffix */ "",
-                   /* ArgsInZeroAddressSpace */ false,
+                   /* AllocaBlock */ nullptr, /* DeallocationBlock */ nullptr,
+                   /* Suffix */ "", /* ArgsInZeroAddressSpace */ false,
                    /* VoidReturnWithSingleOutput */ false);
   EXPECT_TRUE(CE.isEligible());
 
@@ -779,7 +779,8 @@ TEST(CodeExtractor, OpenMPAggregateArgs) {
                    /* AssumptionCache */ nullptr,
                    /* AllowVarArgs */ true,
                    /* AllowAlloca */ true,
-                   /* AllocaBlock*/ &Func->getEntryBlock(),
+                   /* AllocationBlock*/ &Func->getEntryBlock(),
+                   /* DeallocationBlock */ nullptr,
                    /* Suffix */ ".outlined",
                    /* ArgsInZeroAddressSpace */ true);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 896e0b62d9821..88c370252d2bb 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7497,6 +7497,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
 static LogicalResult
 convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
+                         llvm::OpenMPIRBuilder *ompBuilder,
                          LLVM::ModuleTranslation &moduleTranslation) {
   // Amend omp.declare_target by deleting the IR of the outlined functions
   // created for target regions. They cannot be filtered out from MLIR earlier
@@ -7519,6 +7520,11 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
             moduleTranslation.lookupFunction(funcOp.getName());
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
+
+        // Invalidate the builder's current insertion point, since it may now
+        // point into a block of the function that was just deleted.
+        ompBuilder->Builder.ClearInsertionPoint();
+        ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
       }
     }
     return success();
@@ -7690,9 +7696,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
       .Case("omp.declare_target",
             [&](Attribute attr) {
               if (auto declareTargetAttr =
-                      dyn_cast<omp::DeclareTargetAttr>(attr))
+                      dyn_cast<omp::DeclareTargetAttr>(attr)) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
                 return convertDeclareTargetAttr(op, declareTargetAttr,
-                                                moduleTranslation);
+                                                ompBuilder, moduleTranslation);
+              }
               return failure();
             })
       .Case("omp.requires",
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 403e9117cc345..72652bc14cde4 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -56,8 +56,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK-SAME: ptr %[[TMP0:.*]], ptr %[[TMP:.*]]) #{{[0-9]+}} {
 // CHECK:         %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
 // CHECK:         %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
-// CHECK:         %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5)
-// CHECK:         %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[STRUCTARG]] to ptr
 // CHECK:         %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5)
 // CHECK:         %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to ptr
 // CHECK:         store ptr %[[TMP0]], ptr %[[TMP4]], align 8
@@ -65,12 +63,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1
 // CHECK:         br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
 // CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8
+// CHECK:         %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
 // CHECK:         %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
-// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
-// CHECK:         store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
+// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
+// CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
-// CHECK:         store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
+// CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
 // CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0)
+// CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
 // CHECK: define internal void @[[FUNC1]](



More information about the llvm-branch-commits mailing list