[llvm-branch-commits] [llvm] [mlir] [OpenMP][OMPIRBuilder] Use device shared memory for arg structures (PR #150925)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 23 03:54:26 PST 2026


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/150925

>From a275b82b4799eae65941f8accfc97ae5a84fe90f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 3 Jul 2025 16:47:51 +0100
Subject: [PATCH 1/8] [OpenMP][OMPIRBuilder] Use device shared memory for arg
 structures

Argument structures are created when sections of the LLVM IR corresponding to
an OpenMP construct are outlined into their own function. For this, stack
allocations are used.

This patch modifies this behavior when compiling for a target device and
outlining `parallel`-related IR, so that it uses device shared memory instead
of private stack space. This is needed in order for threads to have access to
these arguments.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  6 ++
 .../llvm/Transforms/Utils/CodeExtractor.h     | 35 ++++++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 98 +++++++++++++++++--
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   | 73 ++++++++++----
 .../LLVMIR/omptarget-parallel-llvm.mlir       | 10 +-
 5 files changed, 188 insertions(+), 34 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 681a144a09519..3b2d1377b8905 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2512,7 +2512,13 @@ class OpenMPIRBuilder {
   /// during finalization.
   struct OutlineInfo {
     using PostOutlineCBTy = std::function<void(Function &)>;
+    using CustomArgAllocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+    using CustomArgDeallocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     PostOutlineCBTy PostOutlineCB;
+    CustomArgAllocatorCBTy CustomArgAllocatorCB;
+    CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
     BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
     SmallVector<Value *, 2> ExcludeArgsFromAggregate;
     SetVector<Value *> Inputs, Outputs;
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index e117a33d3435f..d710eb7669807 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/Support/Compiler.h"
 #include <limits>
 
@@ -24,7 +25,6 @@ namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
 class AllocaInst;
-class BasicBlock;
 class BlockFrequency;
 class BlockFrequencyInfo;
 class BranchProbabilityInfo;
@@ -85,6 +85,10 @@ class CodeExtractorAnalysisCache {
 /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas as
 ///    arguments, and inserting stores to the arguments for any scalars.
 class LLVM_ABI CodeExtractor {
+  using CustomArgAllocatorCBTy = std::function<Instruction *(
+      BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+  using CustomArgDeallocatorCBTy = std::function<Instruction *(
+      BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
   using ValueSet = SetVector<Value *>;
 
   // Various bits of state computed on construction.
@@ -132,6 +136,25 @@ class LLVM_ABI CodeExtractor {
   // space.
   bool ArgsInZeroAddressSpace;
 
+  // If set, this callback will be used to allocate the arguments in the caller
+  // before passing it to the outlined function holding the extracted piece of
+  // code.
+  CustomArgAllocatorCBTy *CustomArgAllocatorCB;
+
+  // A block outside of the extraction set where previously introduced
+  // intermediate allocations can be deallocated. This is only used when a
+  // custom deallocator is specified.
+  BasicBlock *DeallocationBlock;
+
+  // If set, this callback will be used to deallocate the arguments in the
+  // caller after running the outlined function holding the extracted piece of
+  // code. It will not be called if a custom allocator isn't also present.
+  //
+  // By default, this will be done at the end of the basic block containing the
+  // call to the outlined function, except if a deallocation block is specified.
+  // In that case, that will take precedence.
+  CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
+
 public:
   /// Create a code extractor for a sequence of blocks.
   ///
@@ -147,13 +170,19 @@ class LLVM_ABI CodeExtractor {
   /// which case it will be placed in the entry block of the function from which
   /// the code is being extracted. If ArgsInZeroAddressSpace param is set to
   /// true, then the aggregate param pointer of the outlined function is
-  /// declared in zero address space.
+  /// declared in zero address space. If a CustomArgAllocatorCB callback is
+  /// specified, it will be used to allocate any structures or variable copies
+  /// needed to pass arguments to the outlined function, rather than using
+  /// regular allocas.
   CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                 bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
                 BranchProbabilityInfo *BPI = nullptr,
                 AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                 bool AllowAlloca = false, BasicBlock *AllocationBlock = nullptr,
-                std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
+                std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
+                CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
+                BasicBlock *DeallocationBlock = nullptr,
+                CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
 
   /// Perform the extraction, returning the new function.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index afe86030f00f8..db3cb5aea134f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -289,6 +289,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
   return Result;
 }
 
+/// Given a function, if it represents the entry point of a target kernel, this
+/// returns the execution mode flags associated to that kernel.
+static std::optional<omp::OMPTgtExecModeFlags>
+getTargetKernelExecMode(Function &Kernel) {
+  CallInst *TargetInitCall = nullptr;
+  for (Instruction &Inst : Kernel.getEntryBlock()) {
+    // getCalledFunction() is null for indirect calls; guard the name lookup.
+    auto *Call = dyn_cast<CallInst>(&Inst);
+    Function *Callee = Call ? Call->getCalledFunction() : nullptr;
+    if (Callee && Callee->getName() == "__kmpc_target_init") {
+      TargetInitCall = Call;
+      break;
+    }
+  }
+
+  if (!TargetInitCall)
+    return std::nullopt;
+
+  // The kernel mode lives in the kernel environment global passed as the first
+  // argument of __kmpc_target_init; see createTargetInit() for how it is set.
+  Value *InitOperand = TargetInitCall->getArgOperand(0);
+  GlobalVariable *KernelEnv = nullptr;
+  if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
+    KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
+  else
+    KernelEnv = cast<GlobalVariable>(InitOperand);
+  auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
+  auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
+  auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
+  return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -819,15 +851,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
-                            /* AggregateArgs */ true,
-                            /* BlockFrequencyInfo */ nullptr,
-                            /* BranchProbabilityInfo */ nullptr,
-                            /* AssumptionCache */ nullptr,
-                            /* AllowVarArgs */ true,
-                            /* AllowAlloca */ true,
-                            /* AllocaBlock*/ OI.OuterAllocaBB,
-                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
+    CodeExtractor Extractor(
+        Blocks, /* DominatorTree */ nullptr,
+        /* AggregateArgs */ true,
+        /* BlockFrequencyInfo */ nullptr,
+        /* BranchProbabilityInfo */ nullptr,
+        /* AssumptionCache */ nullptr,
+        /* AllowVarArgs */ true,
+        /* AllowAlloca */ true,
+        /* AllocaBlock*/ OI.OuterAllocaBB,
+        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
+        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
+        /* DeallocationBlock */ OI.ExitBB,
+        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
@@ -1729,6 +1765,50 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                              ThreadID, ToBeDeletedVec);
     };
+
+    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+        getTargetKernelExecMode(*OuterFn);
+
+    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+    // the CodeExtractor to follow its default behavior. Otherwise, we need to
+    // use device shared memory to allocate argument structures.
+    if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+      OI.CustomArgAllocatorCB = [this,
+                                 EntryBB](BasicBlock *, BasicBlock::iterator,
+                                          Type *ArgTy, const Twine &Name) {
+        // Instead of using the insertion point provided by the CodeExtractor,
+        // here we need to use the block that eventually calls the outlined
+        // function for the `parallel` construct.
+        //
+        // The reason is that the explicit deallocation call will be inserted
+        // within the outlined function, whereas the alloca insertion point
+        // might actually be located somewhere else in the caller. This becomes
+        // a problem when e.g. `parallel` is inside of a `distribute` construct,
+        // because the deallocation would be executed multiple times and the
+        // allocation just once (outside of the loop).
+        //
+        // TODO: Ideally, we'd want to do the allocation and deallocation
+        // outside of the `parallel` outlined function, hence using here the
+        // insertion point provided by the CodeExtractor. We can't do this at
+        // the moment because there is currently no way of passing an eligible
+        // insertion point for the explicit deallocation to the CodeExtractor,
+        // as that block is created (at least when nested inside of
+        // `distribute`) sometime after createParallel() completed, so it can't
+        // be stored in the OutlineInfo structure here.
+        //
+        // The current approach results in an explicit allocation and
+        // deallocation pair for each `distribute` loop iteration in that case,
+        // which is suboptimal.
+        return createOMPAllocShared(
+            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
+            Name);
+      };
+      OI.CustomArgDeallocatorCB =
+          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
+                 Type *ArgTy) -> Instruction * {
+        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
+      };
+    }
     OI.FixUpNonEntryAllocas = true;
   } else {
     // Generate OpenMP host runtime call
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index b298a8ae144d8..ebfa416943263 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -25,7 +25,6 @@
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
-#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -264,12 +263,18 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
                              BasicBlock *AllocationBlock, std::string Suffix,
-                             bool ArgsInZeroAddressSpace)
+                             bool ArgsInZeroAddressSpace,
+                             CustomArgAllocatorCBTy *CustomArgAllocatorCB,
+                             BasicBlock *DeallocationBlock,
+                             CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
       AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
+      CustomArgAllocatorCB(CustomArgAllocatorCB),
+      DeallocationBlock(DeallocationBlock),
+      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -1853,24 +1858,38 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    AllocaInst *alloca = new AllocaInst(
-        output->getType(), DL.getAllocaAddrSpace(), nullptr,
-        output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
-    params.push_back(alloca);
-    ReloadOutputs.push_back(alloca);
+    Value *OutAlloc;
+    if (CustomArgAllocatorCB)
+      OutAlloc = (*CustomArgAllocatorCB)(
+          AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
+          output->getName() + ".loc");
+    else
+      OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
+                                nullptr, output->getName() + ".loc",
+                                AllocaBlock->getFirstInsertionPt());
+
+    params.push_back(OutAlloc);
+    ReloadOutputs.push_back(OutAlloc);
   }
 
-  AllocaInst *Struct = nullptr;
+  Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                            "structArg", AllocaBlock->getFirstInsertionPt());
-    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-      auto *StructSpaceCast = new AddrSpaceCastInst(
-          Struct, PointerType ::get(Context, 0), "structArg.ascast");
-      StructSpaceCast->insertAfter(Struct->getIterator());
-      params.push_back(StructSpaceCast);
-    } else {
+    BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
+    if (CustomArgAllocatorCB) {
+      Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
+                                       "structArg");
       params.push_back(Struct);
+    } else {
+      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
+                              "structArg", StructArgIP);
+      if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+        auto *StructSpaceCast = new AddrSpaceCastInst(
+            Struct, PointerType ::get(Context, 0), "structArg.ascast");
+        StructSpaceCast->insertAfter(Struct->getIterator());
+        params.push_back(StructSpaceCast);
+      } else {
+        params.push_back(Struct);
+      }
     }
 
     unsigned AggIdx = 0;
@@ -2014,6 +2033,26 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
                                        {}, call);
 
+  // Deallocate variables that used a custom allocator.
+  if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
+    BasicBlock *DeallocBlock = codeReplacer;
+    BasicBlock::iterator DeallocIP = codeReplacer->end();
+    if (DeallocationBlock) {
+      DeallocBlock = DeallocationBlock;
+      DeallocIP = DeallocationBlock->getFirstInsertionPt();
+    }
+
+    int Index = 0;
+    for (Value *Output : outputs) {
+      if (!StructValues.contains(Output))
+        (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
+                                  ReloadOutputs[Index++], Output->getType());
+    }
+
+    if (Struct)
+      (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
+  }
+
   return call;
 }
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index cdb8dbbbc946c..6476632c58587 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -56,8 +56,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} {
 // CHECK:         %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
 // CHECK:         %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
-// CHECK:         %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5)
-// CHECK:         %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[STRUCTARG]] to ptr
 // CHECK:         %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5)
 // CHECK:         %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to ptr
 // CHECK:         store ptr %[[TMP0]], ptr %[[TMP4]], align 8
@@ -65,12 +63,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1
 // CHECK:         br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
 // CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8
+// CHECK:         %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
 // CHECK:         %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
-// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
-// CHECK:         store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
+// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
+// CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
-// CHECK:         store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
+// CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
 // CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0)
+// CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
 // CHECK: define internal void @[[FUNC1]](

>From 1861979272cca68cca2feb4cf0643d0f80e16add Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 5 Aug 2025 15:34:28 +0100
Subject: [PATCH 2/8] Address intermittent ICE triggered from the
 `OpenMPIRBuilder::finalize` method due to an invalid builder insertion point

---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp    | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 58572c8ae3d7c..8db6fc4f38566 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6906,6 +6906,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
 static LogicalResult
 convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
+                         llvm::OpenMPIRBuilder *ompBuilder,
                          LLVM::ModuleTranslation &moduleTranslation) {
   // Amend omp.declare_target by deleting the IR of the outlined functions
   // created for target regions. They cannot be filtered out from MLIR earlier
@@ -6928,6 +6929,11 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
             moduleTranslation.lookupFunction(funcOp.getName());
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
+
+        // Invalidate the builder's current insertion point, as it now points to
+        // a deleted block.
+        ompBuilder->Builder.ClearInsertionPoint();
+        ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
       }
     }
     return success();
@@ -7079,9 +7085,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
       .Case("omp.declare_target",
             [&](Attribute attr) {
               if (auto declareTargetAttr =
-                      dyn_cast<omp::DeclareTargetAttr>(attr))
+                      dyn_cast<omp::DeclareTargetAttr>(attr)) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
                 return convertDeclareTargetAttr(op, declareTargetAttr,
-                                                moduleTranslation);
+                                                ompBuilder, moduleTranslation);
+              }
               return failure();
             })
       .Case("omp.requires",

>From f7f36c5af91425b4a84cd9e3a3f317b6ead0224e Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 14 Aug 2025 12:40:09 +0100
Subject: [PATCH 3/8] Address nits

---
 .../llvm/Transforms/Utils/CodeExtractor.h     | 46 +++++++++----------
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  2 +-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   |  5 +-
 3 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index d710eb7669807..d401ec4368b51 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -98,15 +98,15 @@ class LLVM_ABI CodeExtractor {
   BranchProbabilityInfo *BPI;
   AssumptionCache *AC;
 
-  // A block outside of the extraction set where any intermediate allocations
-  // will be placed inside. If this is null, allocations will be placed in the
-  // entry block of the function.
+  /// A block outside of the extraction set where any intermediate allocations
+  /// will be placed inside. If this is null, allocations will be placed in the
+  /// entry block of the function.
   BasicBlock *AllocationBlock;
 
-  // If true, varargs functions can be extracted.
+  /// If true, varargs functions can be extracted.
   bool AllowVarArgs;
 
-  // Bits of intermediate state computed at various phases of extraction.
+  /// Bits of intermediate state computed at various phases of extraction.
   SetVector<BasicBlock *> Blocks;
 
   /// Lists of blocks that are branched from the code region to be extracted,
@@ -127,32 +127,32 @@ class LLVM_ABI CodeExtractor {
   /// 1, etc.
   SmallVector<BasicBlock *> ExtractedFuncRetVals;
 
-  // Suffix to use when creating extracted function (appended to the original
-  // function name + "."). If empty, the default is to use the entry block
-  // label, if non-empty, otherwise "extracted".
+  /// Suffix to use when creating extracted function (appended to the original
+  /// function name + "."). If empty, the default is to use the entry block
+  /// label, if non-empty, otherwise "extracted".
   std::string Suffix;
 
-  // If true, the outlined function has aggregate argument in zero address
-  // space.
+  /// If true, the outlined function has aggregate argument in zero address
+  /// space.
   bool ArgsInZeroAddressSpace;
 
-  // If set, this callback will be used to allocate the arguments in the caller
-  // before passing it to the outlined function holding the extracted piece of
-  // code.
+  /// If set, this callback will be used to allocate the arguments in the caller
+  /// before passing it to the outlined function holding the extracted piece of
+  /// code.
   CustomArgAllocatorCBTy *CustomArgAllocatorCB;
 
-  // A block outside of the extraction set where previously introduced
-  // intermediate allocations can be deallocated. This is only used when a
-  // custom deallocator is specified.
+  /// A block outside of the extraction set where previously introduced
+  /// intermediate allocations can be deallocated. This is only used when a
+  /// custom deallocator is specified.
   BasicBlock *DeallocationBlock;
 
-  // If set, this callback will be used to deallocate the arguments in the
-  // caller after running the outlined function holding the extracted piece of
-  // code. It will not be called if a custom allocator isn't also present.
-  //
-  // By default, this will be done at the end of the basic block containing the
-  // call to the outlined function, except if a deallocation block is specified.
-  // In that case, that will take precedence.
+  /// If set, this callback will be used to deallocate the arguments in the
+  /// caller after running the outlined function holding the extracted piece of
+  /// code. It will not be called if a custom allocator isn't also present.
+  ///
+  /// By default, this will be done at the end of the basic block containing the
+  /// call to the outlined function, except if a deallocation block is specified.
+  /// In that case, that will take precedence.
   CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
 
 public:
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index db3cb5aea134f..3193b3be3dadd 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -290,7 +290,7 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
 }
 
 /// Given a function, if it represents the entry point of a target kernel, this
-/// returns the execution mode flags associated to that kernel.
+/// returns the execution mode flags associated with that kernel.
 static std::optional<omp::OMPTgtExecModeFlags>
 getTargetKernelExecMode(Function &Kernel) {
   CallInst *TargetInitCall = nullptr;
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index ebfa416943263..01a2069721d55 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -274,7 +274,10 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
       Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
       CustomArgAllocatorCB(CustomArgAllocatorCB),
       DeallocationBlock(DeallocationBlock),
-      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
+      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {
+  assert((!CustomArgDeallocatorCB || CustomArgAllocatorCB) &&
+         "custom deallocator only allowed if a custom allocator is provided");
+}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.

>From 970c95d0cd62a8a0b7986c5cf9df7fcd4b4eb1e1 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 15 Aug 2025 15:45:14 +0100
Subject: [PATCH 4/8] Replace CodeExtractor callbacks with subclasses and
 simplify their creation based on OutlineInfo structures

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  23 +-
 .../llvm/Transforms/Utils/CodeExtractor.h     |  59 ++--
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 325 ++++++++++++------
 llvm/lib/Transforms/IPO/HotColdSplitting.cpp  |   1 +
 llvm/lib/Transforms/IPO/IROutliner.cpp        |   4 +-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   | 107 +++---
 .../Transforms/Utils/CodeExtractorTest.cpp    |   3 +-
 7 files changed, 311 insertions(+), 211 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3b2d1377b8905..425b54021eb36 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -33,6 +33,7 @@
 
 namespace llvm {
 class CanonicalLoopInfo;
+class CodeExtractor;
 class ScanInfo;
 struct TargetRegionEntryInfo;
 class OffloadEntriesInfoManager;
@@ -2512,30 +2513,34 @@ class OpenMPIRBuilder {
   /// during finalization.
   struct OutlineInfo {
     using PostOutlineCBTy = std::function<void(Function &)>;
-    using CustomArgAllocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
-    using CustomArgDeallocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     PostOutlineCBTy PostOutlineCB;
-    CustomArgAllocatorCBTy CustomArgAllocatorCB;
-    CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
     BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
     SmallVector<Value *, 2> ExcludeArgsFromAggregate;
     SetVector<Value *> Inputs, Outputs;
     // TODO: this should be safe to enable by default
     bool FixUpNonEntryAllocas = false;
 
+    LLVM_ABI virtual ~OutlineInfo() = default;
+
     /// Collect all blocks in between EntryBB and ExitBB in both the given
     /// vector and set.
     LLVM_ABI void collectBlocks(SmallPtrSetImpl<BasicBlock *> &BlockSet,
                                 SmallVectorImpl<BasicBlock *> &BlockVector);
 
+    /// Create a CodeExtractor instance based on the information stored in this
+    /// structure, the list of collected blocks from a previous call to
+    /// \c collectBlocks and a flag stating whether arguments must be passed in
+    /// address space 0.
+    LLVM_ABI virtual std::unique_ptr<CodeExtractor>
+    createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                        bool ArgsInZeroAddressSpace, Twine Suffix = Twine(""));
+
     /// Return the function that contains the region to be outlined.
     Function *getFunction() const { return EntryBB->getParent(); }
   };
 
   /// Collection of regions that need to be outlined during finalization.
-  SmallVector<OutlineInfo, 16> OutlineInfos;
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> OutlineInfos;
 
   /// A collection of candidate target functions that's constant allocas will
   /// attempt to be raised on a call of finalize after all currently enqueued
@@ -2550,7 +2555,9 @@ class OpenMPIRBuilder {
   std::forward_list<ScanInfo> ScanInfos;
 
   /// Add a new region that will be outlined later.
-  void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); }
+  void addOutlineInfo(std::unique_ptr<OutlineInfo> &&OI) {
+    OutlineInfos.emplace_back(std::move(OI));
+  }
 
   /// An ordered map of auto-generated variables to their unique names.
   /// It stores variables with the following names: 1) ".gomp_critical_user_" +
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index d401ec4368b51..064d15d32d6e7 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -24,6 +24,7 @@
 namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
+class AddrSpaceCastInst;
 class AllocaInst;
 class BlockFrequency;
 class BlockFrequencyInfo;
@@ -85,10 +86,6 @@ class CodeExtractorAnalysisCache {
 /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas as
 ///    arguments, and inserting stores to the arguments for any scalars.
 class LLVM_ABI CodeExtractor {
-  using CustomArgAllocatorCBTy = std::function<Instruction *(
-      BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
-  using CustomArgDeallocatorCBTy = std::function<Instruction *(
-      BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
   using ValueSet = SetVector<Value *>;
 
   // Various bits of state computed on construction.
@@ -103,6 +100,14 @@ class LLVM_ABI CodeExtractor {
   /// entry block of the function.
   BasicBlock *AllocationBlock;
 
+  /// A block outside of the extraction set where deallocations for intermediate
+  /// allocations can be placed. Not used for automatically deallocated memory
+  /// (e.g. `alloca`), which is the default.
+  ///
+  /// If it is null and needed, the end of the replacement basic block will be
+  /// used to place deallocations.
+  BasicBlock *DeallocationBlock;
+
   /// If true, varargs functions can be extracted.
   bool AllowVarArgs;
 
@@ -136,25 +141,6 @@ class LLVM_ABI CodeExtractor {
   /// space.
   bool ArgsInZeroAddressSpace;
 
-  /// If set, this callback will be used to allocate the arguments in the caller
-  /// before passing it to the outlined function holding the extracted piece of
-  /// code.
-  CustomArgAllocatorCBTy *CustomArgAllocatorCB;
-
-  /// A block outside of the extraction set where previously introduced
-  /// intermediate allocations can be deallocated. This is only used when a
-  /// custom deallocator is specified.
-  BasicBlock *DeallocationBlock;
-
-  /// If set, this callback will be used to deallocate the arguments in the
-  /// caller after running the outlined function holding the extracted piece of
-  /// code. It will not be called if a custom allocator isn't also present.
-  ///
-  /// By default, this will be done at the end of the basic block containing the
-  /// call to the outlined function, except if a deallocation block is specified.
-  /// In that case, that will take precedence.
-  CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
-
 public:
   /// Create a code extractor for a sequence of blocks.
   ///
@@ -168,21 +154,20 @@ class LLVM_ABI CodeExtractor {
   /// however code extractor won't validate whether extraction is legal. Any new
   /// allocations will be placed in the AllocationBlock, unless it is null, in
   /// which case it will be placed in the entry block of the function from which
-  /// the code is being extracted. If ArgsInZeroAddressSpace param is set to
+  /// the code is being extracted. Explicit deallocations for the aforementioned
+  /// allocations will be placed in the DeallocationBlock or the end of the
+  /// replacement block, if needed. If ArgsInZeroAddressSpace param is set to
   /// true, then the aggregate param pointer of the outlined function is
-  /// declared in zero address space. If a CustomArgAllocatorCB callback is
-  /// specified, it will be used to allocate any structures or variable copies
-  /// needed to pass arguments to the outlined function, rather than using
-  /// regular allocas.
+  /// declared in zero address space.
   CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                 bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
                 BranchProbabilityInfo *BPI = nullptr,
                 AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                 bool AllowAlloca = false, BasicBlock *AllocationBlock = nullptr,
-                std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
-                CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
                 BasicBlock *DeallocationBlock = nullptr,
-                CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
+                std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
+
+  virtual ~CodeExtractor() = default;
 
   /// Perform the extraction, returning the new function.
   ///
@@ -263,6 +248,18 @@ class LLVM_ABI CodeExtractor {
   /// region, passing it instead as a scalar.
   void excludeArgFromAggregate(Value *Arg);
 
+protected:
+  /// Allocate an intermediate variable at the specified point.
+  virtual Instruction *allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP,
+                                   Type *VarType, const Twine &Name = Twine(""),
+                                   AddrSpaceCastInst **CastedAlloc = nullptr);
+
+  /// Deallocate a previously-allocated intermediate variable at the specified
+  /// point.
+  virtual Instruction *deallocateVar(BasicBlock *BB,
+                                     BasicBlock::iterator DeallocIP, Value *Var,
+                                     Type *VarType);
+
 private:
   struct LifetimeMarkerInfo {
     bool SinkLifeStart = false;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 3193b3be3dadd..5d9deba137b50 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -488,6 +488,88 @@ enum OpenMPOffloadingRequiresDirFlags {
   LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
 };
 
+class OMPCodeExtractor : public CodeExtractor {
+public:
+  OMPCodeExtractor(OpenMPIRBuilder &OMPBuilder, ArrayRef<BasicBlock *> BBs,
+                   DominatorTree *DT = nullptr, bool AggregateArgs = false,
+                   BlockFrequencyInfo *BFI = nullptr,
+                   BranchProbabilityInfo *BPI = nullptr,
+                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
+                   bool AllowAlloca = false,
+                   BasicBlock *AllocationBlock = nullptr,
+                   BasicBlock *DeallocationBlock = nullptr,
+                   std::string Suffix = "", bool ArgsInZeroAddressSpace = false)
+      : CodeExtractor(BBs, DT, AggregateArgs, BFI, BPI, AC, AllowVarArgs,
+                      AllowAlloca, AllocationBlock, DeallocationBlock, Suffix,
+                      ArgsInZeroAddressSpace),
+        OMPBuilder(OMPBuilder) {}
+
+  virtual ~OMPCodeExtractor() = default;
+
+protected:
+  OpenMPIRBuilder &OMPBuilder;
+};
+
+class DeviceSharedMemCodeExtractor : public OMPCodeExtractor {
+public:
+  DeviceSharedMemCodeExtractor(
+      OpenMPIRBuilder &OMPBuilder, BasicBlock *AllocBlockOverride,
+      ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
+      bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+      BranchProbabilityInfo *BPI = nullptr, AssumptionCache *AC = nullptr,
+      bool AllowVarArgs = false, bool AllowAlloca = false,
+      BasicBlock *AllocationBlock = nullptr,
+      BasicBlock *DeallocationBlock = nullptr, std::string Suffix = "",
+      bool ArgsInZeroAddressSpace = false)
+      : OMPCodeExtractor(OMPBuilder, BBs, DT, AggregateArgs, BFI, BPI, AC,
+                         AllowVarArgs, AllowAlloca, AllocationBlock,
+                         DeallocationBlock, Suffix, ArgsInZeroAddressSpace),
+        AllocBlockOverride(AllocBlockOverride) {}
+  virtual ~DeviceSharedMemCodeExtractor() = default;
+
+protected:
+  virtual Instruction *
+  allocateVar(BasicBlock *, BasicBlock::iterator, Type *VarType,
+              const Twine &Name = Twine(""),
+              AddrSpaceCastInst **CastedAlloc = nullptr) override {
+    // Ignore the CastedAlloc pointer, if requested, because shared memory
+    // should not be cast to address space 0 to be passed around.
+    return OMPBuilder.createOMPAllocShared(
+        OpenMPIRBuilder::InsertPointTy(
+            AllocBlockOverride, AllocBlockOverride->getFirstInsertionPt()),
+        VarType, Name);
+  }
+
+  virtual Instruction *deallocateVar(BasicBlock *BB,
+                                     BasicBlock::iterator DeallocIP, Value *Var,
+                                     Type *VarType) override {
+    return OMPBuilder.createOMPFreeShared(
+        OpenMPIRBuilder::InsertPointTy(BB, DeallocIP), Var, VarType);
+  }
+
+private:
+  // TODO: Remove the need for this override and instead get the CodeExtractor
+  // to provide a valid insert point for explicit deallocations by correctly
+  // populating its DeallocationBlock.
+  BasicBlock *AllocBlockOverride;
+};
+
+/// Helper storing information about regions to outline using device shared
+/// memory for intermediate allocations.
+struct DeviceSharedMemOutlineInfo : public OpenMPIRBuilder::OutlineInfo {
+  OpenMPIRBuilder &OMPBuilder;
+  BasicBlock *AllocBlockOverride = nullptr;
+
+  DeviceSharedMemOutlineInfo(OpenMPIRBuilder &OMPBuilder)
+      : OMPBuilder(OMPBuilder) {}
+  virtual ~DeviceSharedMemOutlineInfo() = default;
+
+  virtual std::unique_ptr<CodeExtractor>
+  createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                      bool ArgsInZeroAddressSpace,
+                      Twine Suffix = Twine("")) override;
+};
+
 } // anonymous namespace
 
 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
@@ -829,20 +911,20 @@ static void hoistNonEntryAllocasToEntryBlock(llvm::Function *Func) {
 void OpenMPIRBuilder::finalize(Function *Fn) {
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  SmallVector<OutlineInfo, 16> DeferredOutlines;
-  for (OutlineInfo &OI : OutlineInfos) {
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> DeferredOutlines;
+  for (std::unique_ptr<OutlineInfo> &OI : OutlineInfos) {
     // Skip functions that have not finalized yet; may happen with nested
     // function generation.
-    if (Fn && OI.getFunction() != Fn) {
-      DeferredOutlines.push_back(OI);
+    if (Fn && OI->getFunction() != Fn) {
+      DeferredOutlines.push_back(std::move(OI));
       continue;
     }
 
     ParallelRegionBlockSet.clear();
     Blocks.clear();
-    OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+    OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
-    Function *OuterFn = OI.getFunction();
+    Function *OuterFn = OI->getFunction();
     CodeExtractorAnalysisCache CEAC(*OuterFn);
     // If we generate code for the target device, we need to allocate
     // struct for aggregate params in the device default alloca address space.
@@ -851,31 +933,20 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(
-        Blocks, /* DominatorTree */ nullptr,
-        /* AggregateArgs */ true,
-        /* BlockFrequencyInfo */ nullptr,
-        /* BranchProbabilityInfo */ nullptr,
-        /* AssumptionCache */ nullptr,
-        /* AllowVarArgs */ true,
-        /* AllowAlloca */ true,
-        /* AllocaBlock*/ OI.OuterAllocaBB,
-        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
-        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
-        /* DeallocationBlock */ OI.ExitBB,
-        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
+    std::unique_ptr<CodeExtractor> Extractor =
+        OI->createCodeExtractor(Blocks, ArgsInZeroAddressSpace, ".omp_par");
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
-    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
-                      << " Exit: " << OI.ExitBB->getName() << "\n");
-    assert(Extractor.isEligible() &&
+    LLVM_DEBUG(dbgs() << "Entry " << OI->EntryBB->getName()
+                      << " Exit: " << OI->ExitBB->getName() << "\n");
+    assert(Extractor->isEligible() &&
            "Expected OpenMP outlining to be possible!");
 
-    for (auto *V : OI.ExcludeArgsFromAggregate)
-      Extractor.excludeArgFromAggregate(V);
+    for (auto *V : OI->ExcludeArgsFromAggregate)
+      Extractor->excludeArgFromAggregate(V);
 
     Function *OutlinedFn =
-        Extractor.extractCodeRegion(CEAC, OI.Inputs, OI.Outputs);
+        Extractor->extractCodeRegion(CEAC, OI->Inputs, OI->Outputs);
 
     // Forward target-cpu, target-features attributes to the outlined function.
     auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
@@ -900,8 +971,8 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // made our own entry block after all.
     {
       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
-      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
-      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
+      assert(ArtificialEntry.getUniqueSuccessor() == OI->EntryBB);
+      assert(OI->EntryBB->getUniquePredecessor() == &ArtificialEntry);
       // Move instructions from the to-be-deleted ArtificialEntry to the entry
       // basic block of the parallel region. CodeExtractor generates
       // instructions to unwrap the aggregate argument and may sink
@@ -917,26 +988,27 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
 
         if (I.isTerminator()) {
           // Absorb any debug value that terminator may have
-          if (OI.EntryBB->getTerminator())
-            OI.EntryBB->getTerminator()->adoptDbgRecords(
+          if (OI->EntryBB->getTerminator())
+            OI->EntryBB->getTerminator()->adoptDbgRecords(
                 &ArtificialEntry, I.getIterator(), false);
           continue;
         }
 
-        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
+        I.moveBeforePreserving(*OI->EntryBB,
+                               OI->EntryBB->getFirstInsertionPt());
       }
 
-      OI.EntryBB->moveBefore(&ArtificialEntry);
+      OI->EntryBB->moveBefore(&ArtificialEntry);
       ArtificialEntry.eraseFromParent();
     }
-    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
+    assert(&OutlinedFn->getEntryBlock() == OI->EntryBB);
     assert(OutlinedFn && OutlinedFn->hasNUses(1));
 
     // Run a user callback, e.g. to add attributes.
-    if (OI.PostOutlineCB)
-      OI.PostOutlineCB(*OutlinedFn);
+    if (OI->PostOutlineCB)
+      OI->PostOutlineCB(*OutlinedFn);
 
-    if (OI.FixUpNonEntryAllocas)
+    if (OI->FixUpNonEntryAllocas)
       hoistNonEntryAllocasToEntryBlock(OutlinedFn);
   }
 
@@ -1756,26 +1828,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
 
-  OutlineInfo OI;
-  if (Config.isTargetDevice()) {
-    // Generate OpenMP target specific runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
-      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
-                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
-                             ThreadID, ToBeDeletedVec);
-    };
+  auto OI = [&]() -> std::unique_ptr<OutlineInfo> {
+    if (Config.isTargetDevice()) {
+      std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+          getTargetKernelExecMode(*OuterFn);
 
-    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
-        getTargetKernelExecMode(*OuterFn);
+      // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+      // the CodeExtractor to follow its default behavior. Otherwise, we need to
+      // use device shared memory to allocate argument structures.
+      if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+        auto Info = std::make_unique<DeviceSharedMemOutlineInfo>(*this);
 
-    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
-    // the CodeExtractor to follow its default behavior. Otherwise, we need to
-    // use device shared memory to allocate argument structures.
-    if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
-      OI.CustomArgAllocatorCB = [this,
-                                 EntryBB](BasicBlock *, BasicBlock::iterator,
-                                          Type *ArgTy, const Twine &Name) {
         // Instead of using the insertion point provided by the CodeExtractor,
         // here we need to use the block that eventually calls the outlined
         // function for the `parallel` construct.
@@ -1799,34 +1862,38 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
         // The current approach results in an explicit allocation and
         // deallocation pair for each `distribute` loop iteration in that case,
         // which is suboptimal.
-        return createOMPAllocShared(
-            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
-            Name);
-      };
-      OI.CustomArgDeallocatorCB =
-          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
-                 Type *ArgTy) -> Instruction * {
-        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
-      };
+        Info->AllocBlockOverride = EntryBB;
+        return Info;
+      }
     }
-    OI.FixUpNonEntryAllocas = true;
+    return std::make_unique<OutlineInfo>();
+  }();
+
+  if (Config.isTargetDevice()) {
+    // Generate OpenMP target specific runtime call
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
+      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
+                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
+                             ThreadID, ToBeDeletedVec);
+    };
   } else {
     // Generate OpenMP host runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
     };
-    OI.FixUpNonEntryAllocas = true;
   }
-
-  OI.OuterAllocaBB = OuterAllocaBlock;
-  OI.EntryBB = PRegEntryBB;
-  OI.ExitBB = PRegExitBB;
+  
+  OI->FixUpNonEntryAllocas = true;
+  OI->OuterAllocaBB = OuterAllocaBlock;
+  OI->EntryBB = PRegEntryBB;
+  OI->ExitBB = PRegExitBB;
 
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
@@ -1837,6 +1904,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ OuterAllocaBlock,
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
 
   // Find inputs to, outputs from the code region.
@@ -1861,7 +1929,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   auto PrivHelper = [&](Value &V) -> Error {
     if (&V == TIDAddr || &V == ZeroAddr) {
-      OI.ExcludeArgsFromAggregate.push_back(&V);
+      OI->ExcludeArgsFromAggregate.push_back(&V);
       return Error::success();
     }
 
@@ -2549,19 +2617,19 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
   if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = TaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
-  OI.ExitBB = TaskExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
+  OI->ExitBB = TaskExitBB;
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
 
-  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
-                      Mergeable, Priority, EventHandle, TaskAllocaBB,
-                      ToBeDeleted](Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+                       Mergeable, Priority, EventHandle, TaskAllocaBB,
+                       ToBeDeleted](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
@@ -5981,19 +6049,19 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
 
-  OutlineInfo OI;
-  OI.OuterAllocaBB = CLI->getPreheader();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->OuterAllocaBB = CLI->getPreheader();
   Function *OuterFn = CLI->getPreheader()->getParent();
 
   // Instructions which need to be deleted at the end of code generation
   SmallVector<Instruction *, 4> ToBeDeleted;
 
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Mark the body loop as region which needs to be extracted
-  OI.EntryBB = CLI->getBody();
-  OI.ExitBB = CLI->getLatch()->splitBasicBlockBefore(CLI->getLatch()->begin(),
-                                                     "omp.prelatch");
+  OI->EntryBB = CLI->getBody();
+  OI->ExitBB = CLI->getLatch()->splitBasicBlockBefore(CLI->getLatch()->begin(),
+                                                      "omp.prelatch");
 
   // Prepare loop body for extraction
   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
@@ -6013,7 +6081,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   // loop body region.
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks,
@@ -6025,6 +6093,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ CLI->getPreheader(),
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_wsloop",
                           /* AggrArgsIn0AddrSpace */ true);
 
@@ -6049,15 +6118,15 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   }
   // Make sure that loop counter variable is not merged into loop body
   // function argument structure and it is passed as separate variable
-  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
+  OI->ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
 
   // PostOutline CB is invoked when loop body function is outlined and
   // loop body is replaced by call to outlined function. We need to add
   // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
   // function will handle loop control logic.
   //
-  OI.PostOutlineCB = [=, ToBeDeletedVec =
-                             std::move(ToBeDeleted)](Function &OutlinedFn) {
+  OI->PostOutlineCB = [=, ToBeDeletedVec =
+                              std::move(ToBeDeleted)](Function &OutlinedFn) {
     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
                                 LoopType, NoLoop);
   };
@@ -8981,13 +9050,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
                                    TargetTaskAllocaBB->begin());
   InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
 
-  OutlineInfo OI;
-  OI.EntryBB = TargetTaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TargetTaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
 
   // Generate the task body which will subsequently be outlined.
@@ -9005,8 +9074,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
   // OI.ExitBlock is set to the single task body block and will get left out of
   // the outlining process. So, simply create a new empty block to which we
   // uncoditionally branch from where TaskBodyCB left off
-  OI.ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
-  emitBlock(OI.ExitBB, Builder.GetInsertBlock()->getParent(),
+  OI->ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
+  emitBlock(OI->ExitBB, Builder.GetInsertBlock()->getParent(),
             /*IsFinished=*/true);
 
   SmallVector<Value *, 2> OffloadingArraysToPrivatize;
@@ -9018,13 +9087,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
           RTArgs.SizesArray}) {
       if (V && !isa<ConstantPointerNull, GlobalVariable>(V)) {
         OffloadingArraysToPrivatize.push_back(V);
-        OI.ExcludeArgsFromAggregate.push_back(V);
+        OI->ExcludeArgsFromAggregate.push_back(V);
       }
     }
   }
-  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
-                      DeviceID, OffloadingArraysToPrivatize](
-                         Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
+                       DeviceID, OffloadingArraysToPrivatize](
+                          Function &OutlinedFn) mutable {
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
 
@@ -10996,17 +11065,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = AllocaBB;
-  OI.ExitBB = ExitBB;
-  OI.OuterAllocaBB = &OuterAllocaBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = AllocaBB;
+  OI->ExitBB = ExitBB;
+  OI->OuterAllocaBB = &OuterAllocaBB;
 
   // Insert fake values for global tid and bound tid.
   SmallVector<Instruction *, 8> ToBeDeleted;
   InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
 
   auto HostPostOutlineCB = [this, Ident,
@@ -11047,7 +11116,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   };
 
   if (!Config.isTargetDevice())
-    OI.PostOutlineCB = HostPostOutlineCB;
+    OI->PostOutlineCB = HostPostOutlineCB;
 
   addOutlineInfo(std::move(OI));
 
@@ -11086,11 +11155,10 @@ OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
   // When using target we use different runtime functions which require a
   // callback.
   if (Config.isTargetDevice()) {
-    OutlineInfo OI;
-    OI.OuterAllocaBB = OuterAllocaIP.getBlock();
-    OI.EntryBB = AllocaBB;
-    OI.ExitBB = ExitBB;
-
+    auto OI = std::make_unique<OutlineInfo>();
+    OI->OuterAllocaBB = OuterAllocaIP.getBlock();
+    OI->EntryBB = AllocaBB;
+    OI->ExitBB = ExitBB;
     addOutlineInfo(std::move(OI));
   }
   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
@@ -11152,6 +11220,39 @@ void OpenMPIRBuilder::OutlineInfo::collectBlocks(
   }
 }
 
+std::unique_ptr<CodeExtractor>
+OpenMPIRBuilder::OutlineInfo::createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                                                  bool ArgsInZeroAddressSpace,
+                                                  Twine Suffix) {
+  return std::make_unique<CodeExtractor>(Blocks, /* DominatorTree */ nullptr,
+                                         /* AggregateArgs */ true,
+                                         /* BlockFrequencyInfo */ nullptr,
+                                         /* BranchProbabilityInfo */ nullptr,
+                                         /* AssumptionCache */ nullptr,
+                                         /* AllowVarArgs */ true,
+                                         /* AllowAlloca */ true,
+                                         /* AllocationBlock*/ OuterAllocaBB,
+                                         /* DeallocationBlock */ nullptr,
+                                         /* Suffix */ Suffix.str(),
+                                         ArgsInZeroAddressSpace);
+}
+
+std::unique_ptr<CodeExtractor> DeviceSharedMemOutlineInfo::createCodeExtractor(
+    ArrayRef<BasicBlock *> Blocks, bool ArgsInZeroAddressSpace, Twine Suffix) {
+  // TODO: Initialize the DeallocationBlock with a proper pair to OuterAllocaBB.
+  return std::make_unique<DeviceSharedMemCodeExtractor>(
+      OMPBuilder, AllocBlockOverride, Blocks, /* DominatorTree */ nullptr,
+      /* AggregateArgs */ true,
+      /* BlockFrequencyInfo */ nullptr,
+      /* BranchProbabilityInfo */ nullptr,
+      /* AssumptionCache */ nullptr,
+      /* AllowVarArgs */ true,
+      /* AllowAlloca */ true,
+      /* AllocationBlock*/ OuterAllocaBB,
+      /* DeallocationBlock */ ExitBB,
+      /* Suffix */ Suffix.str(), ArgsInZeroAddressSpace);
+}
+
 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
                                          uint64_t Size, int32_t Flags,
                                          GlobalValue::LinkageTypes,
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 3d8b7cbb59630..57809017a75a4 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -721,6 +721,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
             SubRegion, &*DT, /* AggregateArgs */ false, /* BFI */ nullptr,
             /* BPI */ nullptr, AC, /* AllowVarArgs */ false,
             /* AllowAlloca */ false, /* AllocaBlock */ nullptr,
+            /* DeallocationBlock */ nullptr,
             /* Suffix */ "cold." + std::to_string(OutlinedFunctionID));
 
         if (CE.isEligible() && isSplittingBeneficial(CE, SubRegion, TTI) &&
diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp
index 574e4639b5436..7494e2b7f6b89 100644
--- a/llvm/lib/Transforms/IPO/IROutliner.cpp
+++ b/llvm/lib/Transforms/IPO/IROutliner.cpp
@@ -2818,7 +2818,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       findAddInputsOutputs(M, *OS, NotSame);
       if (!OS->IgnoreRegion)
         OutlinedRegions.push_back(OS);
@@ -2929,7 +2929,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       bool FunctionOutlined = extractSection(*OS);
       if (FunctionOutlined) {
         unsigned StartIdx = OS->Candidate->getStartIdx();
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 01a2069721d55..b6d8ac1dec2d6 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -262,22 +262,14 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
-                             BasicBlock *AllocationBlock, std::string Suffix,
-                             bool ArgsInZeroAddressSpace,
-                             CustomArgAllocatorCBTy *CustomArgAllocatorCB,
-                             BasicBlock *DeallocationBlock,
-                             CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
+                             BasicBlock *AllocationBlock,
+                             BasicBlock *DeallocationBlock, std::string Suffix,
+                             bool ArgsInZeroAddressSpace)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
-      AllowVarArgs(AllowVarArgs),
+      DeallocationBlock(DeallocationBlock), AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
-      CustomArgAllocatorCB(CustomArgAllocatorCB),
-      DeallocationBlock(DeallocationBlock),
-      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {
-  assert((!CustomArgDeallocatorCB || CustomArgAllocatorCB) &&
-         "custom deallocator only allowed if a custom allocator is provided");
-}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -451,6 +443,27 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
   return CommonExitBlock;
 }
 
+Instruction *CodeExtractor::allocateVar(BasicBlock *BB,
+                                        BasicBlock::iterator AllocIP,
+                                        Type *VarType, const Twine &Name,
+                                        AddrSpaceCastInst **CastedAlloc) {
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+  Instruction *Alloca =
+      new AllocaInst(VarType, DL.getAllocaAddrSpace(), nullptr, Name, AllocIP);
+
+  if (CastedAlloc && ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+    *CastedAlloc = new AddrSpaceCastInst(
+        Alloca, PointerType::get(BB->getContext(), 0), Name + ".ascast");
+    (*CastedAlloc)->insertAfter(Alloca->getIterator());
+  }
+  return Alloca;
+}
+
+Instruction *CodeExtractor::deallocateVar(BasicBlock *, BasicBlock::iterator,
+                                          Value *, Type *) {
+  return nullptr;
+}
+
 // Find the pair of life time markers for address 'Addr' that are either
 // defined inside the outline region or can legally be shrinkwrapped into the
 // outline region. If there are not other untracked uses of the address, return
@@ -1830,7 +1843,6 @@ CallInst *CodeExtractor::emitReplacerCall(
     std::vector<Value *> &Reloads) {
   LLVMContext &Context = oldFunction->getContext();
   Module *M = oldFunction->getParent();
-  const DataLayout &DL = M->getDataLayout();
 
   // This takes place of the original loop
   BasicBlock *codeReplacer =
@@ -1861,39 +1873,22 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    Value *OutAlloc;
-    if (CustomArgAllocatorCB)
-      OutAlloc = (*CustomArgAllocatorCB)(
-          AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
-          output->getName() + ".loc");
-    else
-      OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
-                                nullptr, output->getName() + ".loc",
-                                AllocaBlock->getFirstInsertionPt());
-
+    Value *OutAlloc =
+        allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                    output->getType(), output->getName() + ".loc");
     params.push_back(OutAlloc);
     ReloadOutputs.push_back(OutAlloc);
   }
 
   Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
-    if (CustomArgAllocatorCB) {
-      Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
-                                       "structArg");
+    AddrSpaceCastInst *StructSpaceCast = nullptr;
+    Struct = allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                         StructArgTy, "structArg", &StructSpaceCast);
+    if (StructSpaceCast)
+      params.push_back(StructSpaceCast);
+    else
       params.push_back(Struct);
-    } else {
-      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                              "structArg", StructArgIP);
-      if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-        auto *StructSpaceCast = new AddrSpaceCastInst(
-            Struct, PointerType ::get(Context, 0), "structArg.ascast");
-        StructSpaceCast->insertAfter(Struct->getIterator());
-        params.push_back(StructSpaceCast);
-      } else {
-        params.push_back(Struct);
-      }
-    }
 
     unsigned AggIdx = 0;
     for (Value *input : inputs) {
@@ -2036,26 +2031,24 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
                                        {}, call);
 
-  // Deallocate variables that used a custom allocator.
-  if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
-    BasicBlock *DeallocBlock = codeReplacer;
-    BasicBlock::iterator DeallocIP = codeReplacer->end();
-    if (DeallocationBlock) {
-      DeallocBlock = DeallocationBlock;
-      DeallocIP = DeallocationBlock->getFirstInsertionPt();
-    }
-
-    int Index = 0;
-    for (Value *Output : outputs) {
-      if (!StructValues.contains(Output))
-        (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
-                                  ReloadOutputs[Index++], Output->getType());
-    }
+  // Deallocate intermediate variables if they need explicit deallocation.
+  BasicBlock *DeallocBlock = codeReplacer;
+  BasicBlock::iterator DeallocIP = codeReplacer->end();
+  if (DeallocationBlock) {
+    DeallocBlock = DeallocationBlock;
+    DeallocIP = DeallocationBlock->getFirstInsertionPt();
+  }
 
-    if (Struct)
-      (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
+  int Index = 0;
+  for (Value *Output : outputs) {
+    if (!StructValues.contains(Output))
+      deallocateVar(DeallocBlock, DeallocIP, ReloadOutputs[Index++],
+                    Output->getType());
   }
 
+  if (Struct)
+    deallocateVar(DeallocBlock, DeallocIP, Struct, StructArgTy);
+
   return call;
 }
 
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 90f06204ec9b3..8da41318dabf0 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -711,7 +711,8 @@ TEST(CodeExtractor, OpenMPAggregateArgs) {
                    /* AssumptionCache */ nullptr,
                    /* AllowVarArgs */ true,
                    /* AllowAlloca */ true,
-                   /* AllocaBlock*/ &Func->getEntryBlock(),
+                   /* AllocationBlock*/ &Func->getEntryBlock(),
+                   /* DeallocationBlock */ nullptr,
                    /* Suffix */ ".outlined",
                    /* ArgsInZeroAddressSpace */ true);
 

>From 621d7abfa783fdb10f8d211c72353b8286890b7d Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 22 Jan 2026 16:17:28 +0000
Subject: [PATCH 5/8] Update after rebase

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 26 +++++++++++------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 5d9deba137b50..e66af8f5b58b2 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2280,15 +2280,15 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
   }
 
   llvm::CanonicalLoopInfo *CLI = result.get();
-  OutlineInfo OI;
-  OI.EntryBB = TaskloopAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
-  OI.ExitBB = TaskloopExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TaskloopAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
+  OI->ExitBB = TaskloopExitBB;
 
   // Add the thread ID argument.
   SmallVector<Instruction *> ToBeDeleted;
   // dummy instruction to be used as a fake argument
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TaskloopAllocaIP, "global.tid", false));
   Value *FakeLB = createFakeIntVal(Builder, AllocaIP, ToBeDeleted,
                                    TaskloopAllocaIP, "lb", false, true);
@@ -2298,11 +2298,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
                                      TaskloopAllocaIP, "step", false, true);
   // For Taskloop, we want to force the bounds being the first 3 inputs in the
   // aggregate struct
-  OI.Inputs.insert(FakeLB);
-  OI.Inputs.insert(FakeUB);
-  OI.Inputs.insert(FakeStep);
+  OI->Inputs.insert(FakeLB);
+  OI->Inputs.insert(FakeUB);
+  OI->Inputs.insert(FakeStep);
   if (TaskContextStructPtrVal)
-    OI.Inputs.insert(TaskContextStructPtrVal);
+    OI->Inputs.insert(TaskContextStructPtrVal);
   assert(((TaskContextStructPtrVal && DupCB) ||
           (!TaskContextStructPtrVal && !DupCB)) &&
          "Task context struct ptr and duplication callback must be both set "
@@ -2324,10 +2324,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
   }
   Value *TaskDupFn = *TaskDupFnOrErr;
 
-  OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
-                      TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
-                      IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
-                      FakeStep, Final, Mergeable, Priority,
+  OI->PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
+                       TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
+                       IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
+                       FakeStep, Final, Mergeable,Priority,
                       NumOfCollapseLoops](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&

>From 157c506e857851f36f8185b0e35c551b63d3c1f9 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 23 Jan 2026 14:23:54 +0000
Subject: [PATCH 6/8] Address formatting and ABI issues

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index e66af8f5b58b2..d2bab924535dc 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1885,7 +1885,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
     };
   }
-  
+
   OI->FixUpNonEntryAllocas = true;
   OI->OuterAllocaBB = OuterAllocaBlock;
   OI->EntryBB = PRegEntryBB;

>From 965853589b72832c251fd0d4d5f4c1b702cb1969 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Wed, 11 Feb 2026 16:00:43 +0000
Subject: [PATCH 7/8] address nit comment

---
 llvm/lib/Transforms/Utils/CodeExtractor.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index b6d8ac1dec2d6..dca6884db40c0 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -461,6 +461,7 @@ Instruction *CodeExtractor::allocateVar(BasicBlock *BB,
 
 Instruction *CodeExtractor::deallocateVar(BasicBlock *, BasicBlock::iterator,
                                           Value *, Type *) {
+  // Default alloca instructions created by allocateVar are released implicitly.
   return nullptr;
 }
 

>From 51da22d8c2fe1042b7f561ae13ca8cfb88b49240 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Mon, 23 Feb 2026 11:54:05 +0000
Subject: [PATCH 8/8] formatting

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2bab924535dc..fecfea4e7836a 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2327,8 +2327,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
   OI->PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Untied,
                        TaskloopAllocaBB, CLI, Loc, TaskDupFn, ToBeDeleted,
                        IfCond, GrainSize, NoGroup, Sched, FakeLB, FakeUB,
-                       FakeStep, Final, Mergeable,Priority,
-                      NumOfCollapseLoops](Function &OutlinedFn) mutable {
+                       FakeStep, Final, Mergeable, Priority,
+                       NumOfCollapseLoops](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");



More information about the llvm-branch-commits mailing list