[llvm-branch-commits] [llvm] [mlir] [OpenMP][OMPIRBuilder] Use device shared memory for arg structures (PR #150925)
    Sergio Afonso via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Fri Aug 15 07:59:48 PDT 2025
    
    
  
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/150925
>From 688b61435b38e8632ab81e9aa94fadb5aa5ad7f1 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 3 Jul 2025 16:47:51 +0100
Subject: [PATCH 1/4] [OpenMP][OMPIRBuilder] Use device shared memory for arg
 structures
Argument structures are created when sections of the LLVM IR corresponding to
an OpenMP construct are outlined into their own function. For this, stack
allocations are used.
This patch modifies this behavior when compiling for a target device and
outlining `parallel`-related IR, so that it uses device shared memory instead
of private stack space. This is needed in order for threads to have access to
these arguments.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  6 ++
 .../llvm/Transforms/Utils/CodeExtractor.h     | 34 ++++++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 98 +++++++++++++++++--
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   | 73 ++++++++++----
 .../LLVMIR/omptarget-parallel-llvm.mlir       | 10 +-
 5 files changed, 187 insertions(+), 34 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 0fb664aa5f888..90740e0f4fad0 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2244,7 +2244,13 @@ class OpenMPIRBuilder {
   /// during finalization.
   struct OutlineInfo {
     using PostOutlineCBTy = std::function<void(Function &)>;
+    using CustomArgAllocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+    using CustomArgDeallocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     PostOutlineCBTy PostOutlineCB;
+    CustomArgAllocatorCBTy CustomArgAllocatorCB;
+    CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
     BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
     SmallVector<Value *, 2> ExcludeArgsFromAggregate;
 
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index 407eb50d2c7a3..d72f697cda992 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/Support/Compiler.h"
 #include <limits>
 
@@ -24,7 +25,6 @@ namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
 class AllocaInst;
-class BasicBlock;
 class BlockFrequency;
 class BlockFrequencyInfo;
 class BranchProbabilityInfo;
@@ -85,6 +85,10 @@ class CodeExtractorAnalysisCache {
   /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
   ///    as arguments, and inserting stores to the arguments for any scalars.
   class CodeExtractor {
+    using CustomArgAllocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
+    using CustomArgDeallocatorCBTy = std::function<Instruction *(
+        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     using ValueSet = SetVector<Value *>;
 
     // Various bits of state computed on construction.
@@ -133,6 +137,25 @@ class CodeExtractorAnalysisCache {
     // space.
     bool ArgsInZeroAddressSpace;
 
+    // If set, this callback will be used to allocate the arguments in the
+    // caller before passing it to the outlined function holding the extracted
+    // piece of code.
+    CustomArgAllocatorCBTy *CustomArgAllocatorCB;
+
+    // A block outside of the extraction set where previously introduced
+    // intermediate allocations can be deallocated. This is only used when an
+    // custom deallocator is specified.
+    BasicBlock *DeallocationBlock;
+
+    // If set, this callback will be used to deallocate the arguments in the
+    // caller after running the outlined function holding the extracted piece of
+    // code. It will not be called if a custom allocator isn't also present.
+    //
+    // By default, this will be done at the end of the basic block containing
+    // the call to the outlined function, except if a deallocation block is
+    // specified. In that case, that will take precedence.
+    CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
+
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -149,7 +172,9 @@ class CodeExtractorAnalysisCache {
     /// the function from which the code is being extracted.
     /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
     /// param pointer of the outlined function is declared in zero address
-    /// space.
+    /// space. If a CustomArgAllocatorCB callback is specified, it will be used
+    /// to allocate any structures or variable copies needed to pass arguments
+    /// to the outlined function, rather than using regular allocas.
     LLVM_ABI
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
@@ -157,7 +182,10 @@ class CodeExtractorAnalysisCache {
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
+                  CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
+                  BasicBlock *DeallocationBlock = nullptr,
+                  CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
 
     /// Perform the extraction, returning the new function.
     ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 75af18518ff84..7025c9d3fce6d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -280,6 +280,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
   return Result;
 }
 
+/// Given a function, if it represents the entry point of a target kernel, this
+/// returns the execution mode flags associated to that kernel.
+static std::optional<omp::OMPTgtExecModeFlags>
+getTargetKernelExecMode(Function &Kernel) {
+  CallInst *TargetInitCall = nullptr;
+  for (Instruction &Inst : Kernel.getEntryBlock()) {
+    // getCalledFunction() is null for indirect calls; those cannot match.
+    auto *Call = dyn_cast<CallInst>(&Inst);
+    Function *Callee = Call ? Call->getCalledFunction() : nullptr;
+    if (Callee && Callee->getName() == "__kmpc_target_init") {
+      TargetInitCall = Call;
+      break;
+    }
+  }
+  if (!TargetInitCall)
+    return std::nullopt;
+
+  // Get the kernel mode information from the global variable associated to the
+  // first argument to the call to __kmpc_target_init. Refer to
+  // createTargetInit() to see how this is initialized.
+  Value *InitOperand = TargetInitCall->getArgOperand(0);
+  GlobalVariable *KernelEnv = nullptr;
+  if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
+    KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
+  else
+    KernelEnv = cast<GlobalVariable>(InitOperand);
+  auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
+  auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
+  auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
+  return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
@@ -714,15 +746,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
-                            /* AggregateArgs */ true,
-                            /* BlockFrequencyInfo */ nullptr,
-                            /* BranchProbabilityInfo */ nullptr,
-                            /* AssumptionCache */ nullptr,
-                            /* AllowVarArgs */ true,
-                            /* AllowAlloca */ true,
-                            /* AllocaBlock*/ OI.OuterAllocaBB,
-                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
+    CodeExtractor Extractor(
+        Blocks, /* DominatorTree */ nullptr,
+        /* AggregateArgs */ true,
+        /* BlockFrequencyInfo */ nullptr,
+        /* BranchProbabilityInfo */ nullptr,
+        /* AssumptionCache */ nullptr,
+        /* AllowVarArgs */ true,
+        /* AllowAlloca */ true,
+        /* AllocaBlock*/ OI.OuterAllocaBB,
+        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
+        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
+        /* DeallocationBlock */ OI.ExitBB,
+        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
@@ -1625,6 +1661,50 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                              ThreadID, ToBeDeletedVec);
     };
+
+    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+        getTargetKernelExecMode(*OuterFn);
+
+    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+    // the CodeExtractor to follow its default behavior. Otherwise, we need to
+    // use device shared memory to allocate argument structures.
+    if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+      OI.CustomArgAllocatorCB = [this,
+                                 EntryBB](BasicBlock *, BasicBlock::iterator,
+                                          Type *ArgTy, const Twine &Name) {
+        // Instead of using the insertion point provided by the CodeExtractor,
+        // here we need to use the block that eventually calls the outlined
+        // function for the `parallel` construct.
+        //
+        // The reason is that the explicit deallocation call will be inserted
+        // within the outlined function, whereas the alloca insertion point
+        // might actually be located somewhere else in the caller. This becomes
+        // a problem when e.g. `parallel` is inside of a `distribute` construct,
+        // because the deallocation would be executed multiple times and the
+        // allocation just once (outside of the loop).
+        //
+        // TODO: Ideally, we'd want to do the allocation and deallocation
+        // outside of the `parallel` outlined function, hence using here the
+        // insertion point provided by the CodeExtractor. We can't do this at
+        // the moment because there is currently no way of passing an eligible
+        // insertion point for the explicit deallocation to the CodeExtractor,
+        // as that block is created (at least when nested inside of
+        // `distribute`) sometime after createParallel() completed, so it can't
+        // be stored in the OutlineInfo structure here.
+        //
+        // The current approach results in an explicit allocation and
+        // deallocation pair for each `distribute` loop iteration in that case,
+        // which is suboptimal.
+        return createOMPAllocShared(
+            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
+            Name);
+      };
+      OI.CustomArgDeallocatorCB =
+          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
+                 Type *ArgTy) -> Instruction * {
+        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
+      };
+    }
   } else {
     // Generate OpenMP host runtime call
     OI.PostOutlineCB = [=, ToBeDeletedVec =
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index bbd1ed6a3ab2d..70e0c3a4cad7d 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -25,7 +25,6 @@
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
-#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -265,12 +264,18 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
                              BasicBlock *AllocationBlock, std::string Suffix,
-                             bool ArgsInZeroAddressSpace)
+                             bool ArgsInZeroAddressSpace,
+                             CustomArgAllocatorCBTy *CustomArgAllocatorCB,
+                             BasicBlock *DeallocationBlock,
+                             CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
       AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
+      CustomArgAllocatorCB(CustomArgAllocatorCB),
+      DeallocationBlock(DeallocationBlock),
+      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -1850,24 +1855,38 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    AllocaInst *alloca = new AllocaInst(
-        output->getType(), DL.getAllocaAddrSpace(), nullptr,
-        output->getName() + ".loc", AllocaBlock->getFirstInsertionPt());
-    params.push_back(alloca);
-    ReloadOutputs.push_back(alloca);
+    Value *OutAlloc;
+    if (CustomArgAllocatorCB)
+      OutAlloc = (*CustomArgAllocatorCB)(
+          AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
+          output->getName() + ".loc");
+    else
+      OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
+                                nullptr, output->getName() + ".loc",
+                                AllocaBlock->getFirstInsertionPt());
+
+    params.push_back(OutAlloc);
+    ReloadOutputs.push_back(OutAlloc);
   }
 
-  AllocaInst *Struct = nullptr;
+  Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                            "structArg", AllocaBlock->getFirstInsertionPt());
-    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-      auto *StructSpaceCast = new AddrSpaceCastInst(
-          Struct, PointerType ::get(Context, 0), "structArg.ascast");
-      StructSpaceCast->insertAfter(Struct->getIterator());
-      params.push_back(StructSpaceCast);
-    } else {
+    BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
+    if (CustomArgAllocatorCB) {
+      Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
+                                       "structArg");
       params.push_back(Struct);
+    } else {
+      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
+                              "structArg", StructArgIP);
+      if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+        auto *StructSpaceCast = new AddrSpaceCastInst(
+            Struct, PointerType ::get(Context, 0), "structArg.ascast");
+        StructSpaceCast->insertAfter(Struct->getIterator());
+        params.push_back(StructSpaceCast);
+      } else {
+        params.push_back(Struct);
+      }
     }
 
     unsigned AggIdx = 0;
@@ -2011,6 +2030,26 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
                                        {}, call);
 
+  // Deallocate variables that used a custom allocator.
+  if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
+    BasicBlock *DeallocBlock = codeReplacer;
+    BasicBlock::iterator DeallocIP = codeReplacer->end();
+    if (DeallocationBlock) {
+      DeallocBlock = DeallocationBlock;
+      DeallocIP = DeallocationBlock->getFirstInsertionPt();
+    }
+
+    int Index = 0;
+    for (Value *Output : outputs) {
+      if (!StructValues.contains(Output))
+        (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
+                                  ReloadOutputs[Index++], Output->getType());
+    }
+
+    if (Struct)
+      (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
+  }
+
   return call;
 }
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 60c6fa4dd8f1e..504e39c96f008 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -56,8 +56,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} {
 // CHECK:         %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5)
 // CHECK:         %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr
-// CHECK:         %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5)
-// CHECK:         %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[STRUCTARG]] to ptr
 // CHECK:         %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5)
 // CHECK:         %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to ptr
 // CHECK:         store ptr %[[TMP0]], ptr %[[TMP4]], align 8
@@ -65,12 +63,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1
 // CHECK:         br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]]
 // CHECK:         %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8
+// CHECK:         %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
 // CHECK:         %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr))
-// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
-// CHECK:         store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8
+// CHECK:         %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0
+// CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
-// CHECK:         store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8
+// CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
 // CHECK:         call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
+// CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
 // CHECK: define internal void @[[FUNC1]](
>From 4da11cfd7dd40b2b8bef36669605024b5a532ed1 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 5 Aug 2025 15:34:28 +0100
Subject: [PATCH 2/4] Address intermittent ICE triggered from the
 `OpenMPIRBuilder::finalize` method due to an invalid builder insertion point
---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp    | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 84830e0e971ad..539e62a402add 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5872,6 +5872,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
 static LogicalResult
 convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
+                         llvm::OpenMPIRBuilder *ompBuilder,
                          LLVM::ModuleTranslation &moduleTranslation) {
   // Amend omp.declare_target by deleting the IR of the outlined functions
   // created for target regions. They cannot be filtered out from MLIR earlier
@@ -5894,6 +5895,11 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
             moduleTranslation.lookupFunction(funcOp.getName());
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
+
+        // Invalidate the builder's current insertion point, as it now points to
+        // a deleted block.
+        ompBuilder->Builder.ClearInsertionPoint();
+        ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc());
       }
     }
     return success();
@@ -6338,9 +6344,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
       .Case("omp.declare_target",
             [&](Attribute attr) {
               if (auto declareTargetAttr =
-                      dyn_cast<omp::DeclareTargetAttr>(attr))
+                      dyn_cast<omp::DeclareTargetAttr>(attr)) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
                 return convertDeclareTargetAttr(op, declareTargetAttr,
-                                                moduleTranslation);
+                                                ompBuilder, moduleTranslation);
+              }
               return failure();
             })
       .Case("omp.requires",
>From 06a2570f2b2438b611361d229de95d22a8997d44 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 14 Aug 2025 12:40:09 +0100
Subject: [PATCH 3/4] Address nits
---
 .../llvm/Transforms/Utils/CodeExtractor.h     | 46 +++++++++----------
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  2 +-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   |  5 +-
 3 files changed, 28 insertions(+), 25 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index d72f697cda992..ffefa624f6976 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -98,15 +98,15 @@ class CodeExtractorAnalysisCache {
     BranchProbabilityInfo *BPI;
     AssumptionCache *AC;
 
-    // A block outside of the extraction set where any intermediate
-    // allocations will be placed inside. If this is null, allocations
-    // will be placed in the entry block of the function.
+    /// A block outside of the extraction set where any intermediate
+    /// allocations will be placed inside. If this is null, allocations
+    /// will be placed in the entry block of the function.
     BasicBlock *AllocationBlock;
 
-    // If true, varargs functions can be extracted.
+    /// If true, varargs functions can be extracted.
     bool AllowVarArgs;
 
-    // Bits of intermediate state computed at various phases of extraction.
+    /// Bits of intermediate state computed at various phases of extraction.
     SetVector<BasicBlock *> Blocks;
 
     /// Lists of blocks that are branched from the code region to be extracted,
@@ -128,32 +128,32 @@ class CodeExtractorAnalysisCache {
     /// returns 1, etc.
     SmallVector<BasicBlock *> ExtractedFuncRetVals;
 
-    // Suffix to use when creating extracted function (appended to the original
-    // function name + "."). If empty, the default is to use the entry block
-    // label, if non-empty, otherwise "extracted".
+    /// Suffix to use when creating extracted function (appended to the original
+    /// function name + "."). If empty, the default is to use the entry block
+    /// label, if non-empty, otherwise "extracted".
     std::string Suffix;
 
-    // If true, the outlined function has aggregate argument in zero address
-    // space.
+    /// If true, the outlined function has aggregate argument in zero address
+    /// space.
     bool ArgsInZeroAddressSpace;
 
-    // If set, this callback will be used to allocate the arguments in the
-    // caller before passing it to the outlined function holding the extracted
-    // piece of code.
+    /// If set, this callback will be used to allocate the arguments in the
+    /// caller before passing it to the outlined function holding the extracted
+    /// piece of code.
     CustomArgAllocatorCBTy *CustomArgAllocatorCB;
 
-    // A block outside of the extraction set where previously introduced
-    // intermediate allocations can be deallocated. This is only used when an
-    // custom deallocator is specified.
+    /// A block outside of the extraction set where previously introduced
+    /// intermediate allocations can be deallocated. This is only used when a
+    /// custom deallocator is specified.
     BasicBlock *DeallocationBlock;
 
-    // If set, this callback will be used to deallocate the arguments in the
-    // caller after running the outlined function holding the extracted piece of
-    // code. It will not be called if a custom allocator isn't also present.
-    //
-    // By default, this will be done at the end of the basic block containing
-    // the call to the outlined function, except if a deallocation block is
-    // specified. In that case, that will take precedence.
+    /// If set, this callback will be used to deallocate the arguments in the
+    /// caller after running the outlined function holding the extracted piece
+    /// of code. It will not be called if a custom allocator isn't also present.
+    ///
+    /// By default, this will be done at the end of the basic block containing
+    /// the call to the outlined function, except if a deallocation block is
+    /// specified. In that case, that will take precedence.
     CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
 
   public:
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 7025c9d3fce6d..c573c782d3958 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -281,7 +281,7 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
 }
 
 /// Given a function, if it represents the entry point of a target kernel, this
-/// returns the execution mode flags associated to that kernel.
+/// returns the execution mode flags associated with that kernel.
 static std::optional<omp::OMPTgtExecModeFlags>
 getTargetKernelExecMode(Function &Kernel) {
   CallInst *TargetInitCall = nullptr;
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 70e0c3a4cad7d..fc81307a70722 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -275,7 +275,10 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
       Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
       CustomArgAllocatorCB(CustomArgAllocatorCB),
       DeallocationBlock(DeallocationBlock),
-      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {}
+      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {
+  assert((!CustomArgDeallocatorCB || CustomArgAllocatorCB) &&
+         "custom deallocator only allowed if a custom allocator is provided");
+}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
>From be955670bb7a2b7320bcbf05b6edb8300f01d0c6 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 15 Aug 2025 15:45:14 +0100
Subject: [PATCH 4/4] Replace CodeExtractor callbacks with subclasses and
 simplify their creation based on OutlineInfo structures
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  23 +-
 .../llvm/Transforms/Utils/CodeExtractor.h     |  59 ++--
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 317 ++++++++++++------
 llvm/lib/Transforms/IPO/HotColdSplitting.cpp  |   1 +
 llvm/lib/Transforms/IPO/IROutliner.cpp        |   4 +-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   | 107 +++---
 .../Transforms/Utils/CodeExtractorTest.cpp    |   3 +-
 7 files changed, 309 insertions(+), 205 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 90740e0f4fad0..36be9bfd999a1 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -31,6 +31,7 @@
 
 namespace llvm {
 class CanonicalLoopInfo;
+class CodeExtractor;
 class ScanInfo;
 struct TargetRegionEntryInfo;
 class OffloadEntriesInfoManager;
@@ -2244,27 +2245,31 @@ class OpenMPIRBuilder {
   /// during finalization.
   struct OutlineInfo {
     using PostOutlineCBTy = std::function<void(Function &)>;
-    using CustomArgAllocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
-    using CustomArgDeallocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     PostOutlineCBTy PostOutlineCB;
-    CustomArgAllocatorCBTy CustomArgAllocatorCB;
-    CustomArgDeallocatorCBTy CustomArgDeallocatorCB;
     BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
     SmallVector<Value *, 2> ExcludeArgsFromAggregate;
 
+    LLVM_ABI virtual ~OutlineInfo() = default;
+
     /// Collect all blocks in between EntryBB and ExitBB in both the given
     /// vector and set.
     LLVM_ABI void collectBlocks(SmallPtrSetImpl<BasicBlock *> &BlockSet,
                                 SmallVectorImpl<BasicBlock *> &BlockVector);
 
+    /// Create a CodeExtractor instance based on the information stored in this
+    /// structure, the list of collected blocks from a previous call to
+    /// \c collectBlocks and a flag stating whether arguments must be passed in
+    /// address space 0.
+    LLVM_ABI virtual std::unique_ptr<CodeExtractor>
+    createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                        bool ArgsInZeroAddressSpace, Twine Suffix = Twine(""));
+
     /// Return the function that contains the region to be outlined.
     Function *getFunction() const { return EntryBB->getParent(); }
   };
 
   /// Collection of regions that need to be outlined during finalization.
-  SmallVector<OutlineInfo, 16> OutlineInfos;
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> OutlineInfos;
 
   /// A collection of candidate target functions that's constant allocas will
   /// attempt to be raised on a call of finalize after all currently enqueued
@@ -2279,7 +2284,9 @@ class OpenMPIRBuilder {
   std::forward_list<ScanInfo> ScanInfos;
 
   /// Add a new region that will be outlined later.
-  void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); }
+  void addOutlineInfo(std::unique_ptr<OutlineInfo> &&OI) {
+    OutlineInfos.emplace_back(std::move(OI));
+  }
 
   /// An ordered map of auto-generated variables to their unique names.
   /// It stores variables with the following names: 1) ".gomp_critical_user_" +
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index ffefa624f6976..b3bea96039172 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -24,6 +24,7 @@
 namespace llvm {
 
 template <typename PtrType> class SmallPtrSetImpl;
+class AddrSpaceCastInst;
 class AllocaInst;
 class BlockFrequency;
 class BlockFrequencyInfo;
@@ -85,10 +86,6 @@ class CodeExtractorAnalysisCache {
   /// 3) Add allocas for any scalar outputs, adding all of the outputs' allocas
   ///    as arguments, and inserting stores to the arguments for any scalars.
   class CodeExtractor {
-    using CustomArgAllocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Type *, const Twine &)>;
-    using CustomArgDeallocatorCBTy = std::function<Instruction *(
-        BasicBlock *, BasicBlock::iterator, Value *, Type *)>;
     using ValueSet = SetVector<Value *>;
 
     // Various bits of state computed on construction.
@@ -103,6 +100,14 @@ class CodeExtractorAnalysisCache {
     /// will be placed in the entry block of the function.
     BasicBlock *AllocationBlock;
 
+    /// A block outside of the extraction set where deallocations for
+    /// intermediate allocations can be placed. Not used for
+    /// automatically deallocated memory (e.g. `alloca`), which is the default.
+    ///
+    /// If it is null and needed, the end of the replacement basic block will be
+    /// used to place deallocations.
+    BasicBlock *DeallocationBlock;
+
     /// If true, varargs functions can be extracted.
     bool AllowVarArgs;
 
@@ -137,25 +142,6 @@ class CodeExtractorAnalysisCache {
     /// space.
     bool ArgsInZeroAddressSpace;
 
-    /// If set, this callback will be used to allocate the arguments in the
-    /// caller before passing it to the outlined function holding the extracted
-    /// piece of code.
-    CustomArgAllocatorCBTy *CustomArgAllocatorCB;
-
-    /// A block outside of the extraction set where previously introduced
-    /// intermediate allocations can be deallocated. This is only used when a
-    /// custom deallocator is specified.
-    BasicBlock *DeallocationBlock;
-
-    /// If set, this callback will be used to deallocate the arguments in the
-    /// caller after running the outlined function holding the extracted piece
-    /// of code. It will not be called if a custom allocator isn't also present.
-    ///
-    /// By default, this will be done at the end of the basic block containing
-    /// the call to the outlined function, except if a deallocation block is
-    /// specified. In that case, that will take precedence.
-    CustomArgDeallocatorCBTy *CustomArgDeallocatorCB;
-
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -169,12 +155,12 @@ class CodeExtractorAnalysisCache {
     /// however code extractor won't validate whether extraction is legal.
     /// Any new allocations will be placed in the AllocationBlock, unless
     /// it is null, in which case it will be placed in the entry block of
-    /// the function from which the code is being extracted.
+    /// the function from which the code is being extracted. Explicit
+    /// deallocations for the aforementioned allocations will be placed in the
+    /// DeallocationBlock or the end of the replacement block, if needed.
     /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
     /// param pointer of the outlined function is declared in zero address
-    /// space. If a CustomArgAllocatorCB callback is specified, it will be used
-    /// to allocate any structures or variable copies needed to pass arguments
-    /// to the outlined function, rather than using regular allocas.
+    /// space.
     LLVM_ABI
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
@@ -182,10 +168,10 @@ class CodeExtractorAnalysisCache {
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false,
-                  CustomArgAllocatorCBTy *CustomArgAllocatorCB = nullptr,
                   BasicBlock *DeallocationBlock = nullptr,
-                  CustomArgDeallocatorCBTy *CustomArgDeallocatorCB = nullptr);
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
+
+    LLVM_ABI virtual ~CodeExtractor() = default;
 
     /// Perform the extraction, returning the new function.
     ///
@@ -271,6 +257,19 @@ class CodeExtractorAnalysisCache {
     /// region, passing it instead as a scalar.
     LLVM_ABI void excludeArgFromAggregate(Value *Arg);
 
+  protected:
+    /// Allocate an intermediate variable at the specified point.
+    LLVM_ABI virtual Instruction *
+    allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP, Type *VarType,
+                const Twine &Name = Twine(""),
+                AddrSpaceCastInst **CastedAlloc = nullptr);
+
+    /// Deallocate a previously-allocated intermediate variable at the specified
+    /// point.
+    LLVM_ABI virtual Instruction *deallocateVar(BasicBlock *BB,
+                                                BasicBlock::iterator DeallocIP,
+                                                Value *Var, Type *VarType);
+
   private:
     struct LifetimeMarkerInfo {
       bool SinkLifeStart = false;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index c573c782d3958..83cb21b54394b 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -467,6 +467,88 @@ enum OpenMPOffloadingRequiresDirFlags {
   LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
 };
 
+class OMPCodeExtractor : public CodeExtractor {
+public:
+  OMPCodeExtractor(OpenMPIRBuilder &OMPBuilder, ArrayRef<BasicBlock *> BBs,
+                   DominatorTree *DT = nullptr, bool AggregateArgs = false,
+                   BlockFrequencyInfo *BFI = nullptr,
+                   BranchProbabilityInfo *BPI = nullptr,
+                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
+                   bool AllowAlloca = false,
+                   BasicBlock *AllocationBlock = nullptr,
+                   BasicBlock *DeallocationBlock = nullptr,
+                   std::string Suffix = "", bool ArgsInZeroAddressSpace = false)
+      : CodeExtractor(BBs, DT, AggregateArgs, BFI, BPI, AC, AllowVarArgs,
+                      AllowAlloca, AllocationBlock, DeallocationBlock, Suffix,
+                      ArgsInZeroAddressSpace),
+        OMPBuilder(OMPBuilder) {}
+
+  virtual ~OMPCodeExtractor() = default;
+
+protected:
+  OpenMPIRBuilder &OMPBuilder;
+};
+
+class DeviceSharedMemCodeExtractor : public OMPCodeExtractor {
+public:
+  DeviceSharedMemCodeExtractor(
+      OpenMPIRBuilder &OMPBuilder, BasicBlock *AllocBlockOverride,
+      ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
+      bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+      BranchProbabilityInfo *BPI = nullptr, AssumptionCache *AC = nullptr,
+      bool AllowVarArgs = false, bool AllowAlloca = false,
+      BasicBlock *AllocationBlock = nullptr,
+      BasicBlock *DeallocationBlock = nullptr, std::string Suffix = "",
+      bool ArgsInZeroAddressSpace = false)
+      : OMPCodeExtractor(OMPBuilder, BBs, DT, AggregateArgs, BFI, BPI, AC,
+                         AllowVarArgs, AllowAlloca, AllocationBlock,
+                         DeallocationBlock, Suffix, ArgsInZeroAddressSpace),
+        AllocBlockOverride(AllocBlockOverride) {}
+  virtual ~DeviceSharedMemCodeExtractor() = default;
+
+protected:
+  virtual Instruction *
+  allocateVar(BasicBlock *, BasicBlock::iterator, Type *VarType,
+              const Twine &Name = Twine(""),
+              AddrSpaceCastInst **CastedAlloc = nullptr) override {
+    // Ignore the CastedAlloc pointer, if requested, because shared memory
+    // should not be cast to address space 0 to be passed around.
+    return OMPBuilder.createOMPAllocShared(
+        OpenMPIRBuilder::InsertPointTy(
+            AllocBlockOverride, AllocBlockOverride->getFirstInsertionPt()),
+        VarType, Name);
+  }
+
+  virtual Instruction *deallocateVar(BasicBlock *BB,
+                                     BasicBlock::iterator DeallocIP, Value *Var,
+                                     Type *VarType) override {
+    return OMPBuilder.createOMPFreeShared(
+        OpenMPIRBuilder::InsertPointTy(BB, DeallocIP), Var, VarType);
+  }
+
+private:
+  // TODO: Remove the need for this override and instead get the CodeExtractor
+  // to provide a valid insert point for explicit deallocations by correctly
+  // populating its DeallocationBlock.
+  BasicBlock *AllocBlockOverride;
+};
+
+/// Helper storing information about regions to outline using device shared
+/// memory for intermediate allocations.
+struct DeviceSharedMemOutlineInfo : public OpenMPIRBuilder::OutlineInfo {
+  OpenMPIRBuilder &OMPBuilder;
+  BasicBlock *AllocBlockOverride = nullptr;
+
+  DeviceSharedMemOutlineInfo(OpenMPIRBuilder &OMPBuilder)
+      : OMPBuilder(OMPBuilder) {}
+  virtual ~DeviceSharedMemOutlineInfo() = default;
+
+  virtual std::unique_ptr<CodeExtractor>
+  createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                      bool ArgsInZeroAddressSpace,
+                      Twine Suffix = Twine("")) override;
+};
+
 } // anonymous namespace
 
 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
@@ -724,20 +806,20 @@ static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
 void OpenMPIRBuilder::finalize(Function *Fn) {
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  SmallVector<OutlineInfo, 16> DeferredOutlines;
-  for (OutlineInfo &OI : OutlineInfos) {
+  SmallVector<std::unique_ptr<OutlineInfo>, 16> DeferredOutlines;
+  for (std::unique_ptr<OutlineInfo> &OI : OutlineInfos) {
     // Skip functions that have not finalized yet; may happen with nested
     // function generation.
-    if (Fn && OI.getFunction() != Fn) {
-      DeferredOutlines.push_back(OI);
+    if (Fn && OI->getFunction() != Fn) {
+      DeferredOutlines.push_back(std::move(OI));
       continue;
     }
 
     ParallelRegionBlockSet.clear();
     Blocks.clear();
-    OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+    OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
-    Function *OuterFn = OI.getFunction();
+    Function *OuterFn = OI->getFunction();
     CodeExtractorAnalysisCache CEAC(*OuterFn);
     // If we generate code for the target device, we need to allocate
     // struct for aggregate params in the device default alloca address space.
@@ -746,30 +828,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(
-        Blocks, /* DominatorTree */ nullptr,
-        /* AggregateArgs */ true,
-        /* BlockFrequencyInfo */ nullptr,
-        /* BranchProbabilityInfo */ nullptr,
-        /* AssumptionCache */ nullptr,
-        /* AllowVarArgs */ true,
-        /* AllowAlloca */ true,
-        /* AllocaBlock*/ OI.OuterAllocaBB,
-        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
-        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
-        /* DeallocationBlock */ OI.ExitBB,
-        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
+    std::unique_ptr<CodeExtractor> Extractor =
+        OI->createCodeExtractor(Blocks, ArgsInZeroAddressSpace, ".omp_par");
 
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
-    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
-                      << " Exit: " << OI.ExitBB->getName() << "\n");
-    assert(Extractor.isEligible() &&
+    LLVM_DEBUG(dbgs() << "Entry " << OI->EntryBB->getName()
+                      << " Exit: " << OI->ExitBB->getName() << "\n");
+    assert(Extractor->isEligible() &&
            "Expected OpenMP outlining to be possible!");
 
-    for (auto *V : OI.ExcludeArgsFromAggregate)
-      Extractor.excludeArgFromAggregate(V);
+    for (auto *V : OI->ExcludeArgsFromAggregate)
+      Extractor->excludeArgFromAggregate(V);
 
-    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);
+    Function *OutlinedFn = Extractor->extractCodeRegion(CEAC);
 
     // Forward target-cpu, target-features attributes to the outlined function.
     auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
@@ -794,8 +865,8 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // made our own entry block after all.
     {
       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
-      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
-      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
+      assert(ArtificialEntry.getUniqueSuccessor() == OI->EntryBB);
+      assert(OI->EntryBB->getUniquePredecessor() == &ArtificialEntry);
       // Move instructions from the to-be-deleted ArtificialEntry to the entry
       // basic block of the parallel region. CodeExtractor generates
       // instructions to unwrap the aggregate argument and may sink
@@ -811,24 +882,25 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
 
         if (I.isTerminator()) {
           // Absorb any debug value that terminator may have
-          if (OI.EntryBB->getTerminator())
-            OI.EntryBB->getTerminator()->adoptDbgRecords(
+          if (OI->EntryBB->getTerminator())
+            OI->EntryBB->getTerminator()->adoptDbgRecords(
                 &ArtificialEntry, I.getIterator(), false);
           continue;
         }
 
-        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
+        I.moveBeforePreserving(*OI->EntryBB,
+                               OI->EntryBB->getFirstInsertionPt());
       }
 
-      OI.EntryBB->moveBefore(&ArtificialEntry);
+      OI->EntryBB->moveBefore(&ArtificialEntry);
       ArtificialEntry.eraseFromParent();
     }
-    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
+    assert(&OutlinedFn->getEntryBlock() == OI->EntryBB);
     assert(OutlinedFn && OutlinedFn->hasNUses(1));
 
     // Run a user callback, e.g. to add attributes.
-    if (OI.PostOutlineCB)
-      OI.PostOutlineCB(*OutlinedFn);
+    if (OI->PostOutlineCB)
+      OI->PostOutlineCB(*OutlinedFn);
   }
 
   // Remove work items that have been completed.
@@ -1652,26 +1724,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
 
-  OutlineInfo OI;
-  if (Config.isTargetDevice()) {
-    // Generate OpenMP target specific runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
-      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
-                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
-                             ThreadID, ToBeDeletedVec);
-    };
+  auto OI = [&]() -> std::unique_ptr<OutlineInfo> {
+    if (Config.isTargetDevice()) {
+      std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+          getTargetKernelExecMode(*OuterFn);
 
-    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
-        getTargetKernelExecMode(*OuterFn);
+      // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+      // the CodeExtractor to follow its default behavior. Otherwise, we need to
+      // use device shared memory to allocate argument structures.
+      if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
+        auto Info = std::make_unique<DeviceSharedMemOutlineInfo>(*this);
 
-    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
-    // the CodeExtractor to follow its default behavior. Otherwise, we need to
-    // use device shared memory to allocate argument structures.
-    if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) {
-      OI.CustomArgAllocatorCB = [this,
-                                 EntryBB](BasicBlock *, BasicBlock::iterator,
-                                          Type *ArgTy, const Twine &Name) {
         // Instead of using the insertion point provided by the CodeExtractor,
         // here we need to use the block that eventually calls the outlined
         // function for the `parallel` construct.
@@ -1695,32 +1758,37 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
         // The current approach results in an explicit allocation and
         // deallocation pair for each `distribute` loop iteration in that case,
         // which is suboptimal.
-        return createOMPAllocShared(
-            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
-            Name);
-      };
-      OI.CustomArgDeallocatorCB =
-          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
-                 Type *ArgTy) -> Instruction * {
-        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
-      };
+        Info->AllocBlockOverride = EntryBB;
+        return Info;
+      }
     }
+    return std::make_unique<OutlineInfo>();
+  }();
+
+  if (Config.isTargetDevice()) {
+    // Generate OpenMP target specific runtime call
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
+      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
+                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
+                             ThreadID, ToBeDeletedVec);
+    };
   } else {
     // Generate OpenMP host runtime call
-    OI.PostOutlineCB = [=, ToBeDeletedVec =
-                               std::move(ToBeDeleted)](Function &OutlinedFn) {
+    OI->PostOutlineCB = [=, ToBeDeletedVec =
+                                std::move(ToBeDeleted)](Function &OutlinedFn) {
       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
     };
   }
 
-  OI.OuterAllocaBB = OuterAllocaBlock;
-  OI.EntryBB = PRegEntryBB;
-  OI.ExitBB = PRegExitBB;
+  OI->OuterAllocaBB = OuterAllocaBlock;
+  OI->EntryBB = PRegEntryBB;
+  OI->ExitBB = PRegExitBB;
 
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
@@ -1731,6 +1799,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ OuterAllocaBlock,
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
 
   // Find inputs to, outputs from the code region.
@@ -1755,7 +1824,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
 
   auto PrivHelper = [&](Value &V) -> Error {
     if (&V == TIDAddr || &V == ZeroAddr) {
-      OI.ExcludeArgsFromAggregate.push_back(&V);
+      OI->ExcludeArgsFromAggregate.push_back(&V);
       return Error::success();
     }
 
@@ -2032,19 +2101,19 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
   if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = TaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
-  OI.ExitBB = TaskExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
+  OI->ExitBB = TaskExitBB;
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
 
-  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
-                      Mergeable, Priority, EventHandle, TaskAllocaBB,
-                      ToBeDeleted](Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+                       Mergeable, Priority, EventHandle, TaskAllocaBB,
+                       ToBeDeleted](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
@@ -5133,19 +5202,19 @@ OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
 
-  OutlineInfo OI;
-  OI.OuterAllocaBB = CLI->getPreheader();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->OuterAllocaBB = CLI->getPreheader();
   Function *OuterFn = CLI->getPreheader()->getParent();
 
   // Instructions which need to be deleted at the end of code generation
   SmallVector<Instruction *, 4> ToBeDeleted;
 
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Mark the body loop as region which needs to be extracted
-  OI.EntryBB = CLI->getBody();
-  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
-                                               "omp.prelatch", true);
+  OI->EntryBB = CLI->getBody();
+  OI->ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
+                                                "omp.prelatch", true);
 
   // Prepare loop body for extraction
   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
@@ -5165,7 +5234,7 @@ OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
   // loop body region.
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
   SmallVector<BasicBlock *, 32> Blocks;
-  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  OI->collectBlocks(ParallelRegionBlockSet, Blocks);
 
   CodeExtractorAnalysisCache CEAC(*OuterFn);
   CodeExtractor Extractor(Blocks,
@@ -5177,6 +5246,7 @@ OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
                           /* AllowVarArgs */ true,
                           /* AllowAlloca */ true,
                           /* AllocationBlock */ CLI->getPreheader(),
+                          /* DeallocationBlock */ nullptr,
                           /* Suffix */ ".omp_wsloop",
                           /* AggrArgsIn0AddrSpace */ true);
 
@@ -5201,15 +5271,15 @@ OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
   }
   // Make sure that loop counter variable is not merged into loop body
   // function argument structure and it is passed as separate variable
-  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
+  OI->ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
 
   // PostOutline CB is invoked when loop body function is outlined and
   // loop body is replaced by call to outlined function. We need to add
   // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
   // function will handle loop control logic.
   //
-  OI.PostOutlineCB = [=, ToBeDeletedVec =
-                             std::move(ToBeDeleted)](Function &OutlinedFn) {
+  OI->PostOutlineCB = [=, ToBeDeletedVec =
+                              std::move(ToBeDeleted)](Function &OutlinedFn) {
     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
                                 LoopType);
   };
@@ -8023,13 +8093,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
                                    TargetTaskAllocaBB->begin());
   InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
 
-  OutlineInfo OI;
-  OI.EntryBB = TargetTaskAllocaBB;
-  OI.OuterAllocaBB = AllocaIP.getBlock();
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = TargetTaskAllocaBB;
+  OI->OuterAllocaBB = AllocaIP.getBlock();
 
   // Add the thread ID argument.
   SmallVector<Instruction *, 4> ToBeDeleted;
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
 
   // Generate the task body which will subsequently be outlined.
@@ -8047,8 +8117,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
   // OI.ExitBlock is set to the single task body block and will get left out of
   // the outlining process. So, simply create a new empty block to which we
   // uncoditionally branch from where TaskBodyCB left off
-  OI.ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
-  emitBlock(OI.ExitBB, Builder.GetInsertBlock()->getParent(),
+  OI->ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
+  emitBlock(OI->ExitBB, Builder.GetInsertBlock()->getParent(),
             /*IsFinished=*/true);
 
   SmallVector<Value *, 2> OffloadingArraysToPrivatize;
@@ -8060,13 +8130,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
           RTArgs.SizesArray}) {
       if (V && !isa<ConstantPointerNull, GlobalVariable>(V)) {
         OffloadingArraysToPrivatize.push_back(V);
-        OI.ExcludeArgsFromAggregate.push_back(V);
+        OI->ExcludeArgsFromAggregate.push_back(V);
       }
     }
   }
-  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
-                      DeviceID, OffloadingArraysToPrivatize](
-                         Function &OutlinedFn) mutable {
+  OI->PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
+                       DeviceID, OffloadingArraysToPrivatize](
+                          Function &OutlinedFn) mutable {
     assert(OutlinedFn.hasOneUse() &&
            "there must be a single user for the outlined function");
 
@@ -10026,17 +10096,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.EntryBB = AllocaBB;
-  OI.ExitBB = ExitBB;
-  OI.OuterAllocaBB = &OuterAllocaBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->EntryBB = AllocaBB;
+  OI->ExitBB = ExitBB;
+  OI->OuterAllocaBB = &OuterAllocaBB;
 
   // Insert fake values for global tid and bound tid.
   SmallVector<Instruction *, 8> ToBeDeleted;
   InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
-  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+  OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal(
       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
 
   auto HostPostOutlineCB = [this, Ident,
@@ -10076,7 +10146,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   };
 
   if (!Config.isTargetDevice())
-    OI.PostOutlineCB = HostPostOutlineCB;
+    OI->PostOutlineCB = HostPostOutlineCB;
 
   addOutlineInfo(std::move(OI));
 
@@ -10112,10 +10182,10 @@ OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
   if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
     return Err;
 
-  OutlineInfo OI;
-  OI.OuterAllocaBB = OuterAllocaIP.getBlock();
-  OI.EntryBB = AllocaBB;
-  OI.ExitBB = ExitBB;
+  auto OI = std::make_unique<OutlineInfo>();
+  OI->OuterAllocaBB = OuterAllocaIP.getBlock();
+  OI->EntryBB = AllocaBB;
+  OI->ExitBB = ExitBB;
 
   addOutlineInfo(std::move(OI));
   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
@@ -10175,6 +10245,39 @@ void OpenMPIRBuilder::OutlineInfo::collectBlocks(
   }
 }
 
+std::unique_ptr<CodeExtractor>
+OpenMPIRBuilder::OutlineInfo::createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
+                                                  bool ArgsInZeroAddressSpace,
+                                                  Twine Suffix) {
+  return std::make_unique<CodeExtractor>(Blocks, /* DominatorTree */ nullptr,
+                                         /* AggregateArgs */ true,
+                                         /* BlockFrequencyInfo */ nullptr,
+                                         /* BranchProbabilityInfo */ nullptr,
+                                         /* AssumptionCache */ nullptr,
+                                         /* AllowVarArgs */ true,
+                                         /* AllowAlloca */ true,
+                                         /* AllocationBlock*/ OuterAllocaBB,
+                                         /* DeallocationBlock */ nullptr,
+                                         /* Suffix */ Suffix.str(),
+                                         ArgsInZeroAddressSpace);
+}
+
+std::unique_ptr<CodeExtractor> DeviceSharedMemOutlineInfo::createCodeExtractor(
+    ArrayRef<BasicBlock *> Blocks, bool ArgsInZeroAddressSpace, Twine Suffix) {
+  // TODO: Initialize the DeallocationBlock with a proper pair to OuterAllocaBB.
+  return std::make_unique<DeviceSharedMemCodeExtractor>(
+      OMPBuilder, AllocBlockOverride, Blocks, /* DominatorTree */ nullptr,
+      /* AggregateArgs */ true,
+      /* BlockFrequencyInfo */ nullptr,
+      /* BranchProbabilityInfo */ nullptr,
+      /* AssumptionCache */ nullptr,
+      /* AllowVarArgs */ true,
+      /* AllowAlloca */ true,
+      /* AllocationBlock*/ OuterAllocaBB,
+      /* DeallocationBlock */ ExitBB,
+      /* Suffix */ Suffix.str(), ArgsInZeroAddressSpace);
+}
+
 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
                                          uint64_t Size, int32_t Flags,
                                          GlobalValue::LinkageTypes,
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 3d8b7cbb59630..57809017a75a4 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -721,6 +721,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
             SubRegion, &*DT, /* AggregateArgs */ false, /* BFI */ nullptr,
             /* BPI */ nullptr, AC, /* AllowVarArgs */ false,
             /* AllowAlloca */ false, /* AllocaBlock */ nullptr,
+            /* DeallocationBlock */ nullptr,
             /* Suffix */ "cold." + std::to_string(OutlinedFunctionID));
 
         if (CE.isEligible() && isSplittingBeneficial(CE, SubRegion, TTI) &&
diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp
index c57981ae4ca0d..77c1d98116747 100644
--- a/llvm/lib/Transforms/IPO/IROutliner.cpp
+++ b/llvm/lib/Transforms/IPO/IROutliner.cpp
@@ -2829,7 +2829,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       findAddInputsOutputs(M, *OS, NotSame);
       if (!OS->IgnoreRegion)
         OutlinedRegions.push_back(OS);
@@ -2940,7 +2940,7 @@ unsigned IROutliner::doOutline(Module &M) {
       OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
       OS->CE = new (ExtractorAllocator.Allocate())
           CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
-                        false, nullptr, "outlined");
+                        false, nullptr, nullptr, "outlined");
       bool FunctionOutlined = extractSection(*OS);
       if (FunctionOutlined) {
         unsigned StartIdx = OS->Candidate->getStartIdx();
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index fc81307a70722..3339f5e4fea7d 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -263,22 +263,14 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
-                             BasicBlock *AllocationBlock, std::string Suffix,
-                             bool ArgsInZeroAddressSpace,
-                             CustomArgAllocatorCBTy *CustomArgAllocatorCB,
-                             BasicBlock *DeallocationBlock,
-                             CustomArgDeallocatorCBTy *CustomArgDeallocatorCB)
+                             BasicBlock *AllocationBlock,
+                             BasicBlock *DeallocationBlock, std::string Suffix,
+                             bool ArgsInZeroAddressSpace)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
-      AllowVarArgs(AllowVarArgs),
+      DeallocationBlock(DeallocationBlock), AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace),
-      CustomArgAllocatorCB(CustomArgAllocatorCB),
-      DeallocationBlock(DeallocationBlock),
-      CustomArgDeallocatorCB(CustomArgDeallocatorCB) {
-  assert((!CustomArgDeallocatorCB || CustomArgAllocatorCB) &&
-         "custom deallocator only allowed if a custom allocator is provided");
-}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -452,6 +444,27 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
   return CommonExitBlock;
 }
 
+Instruction *CodeExtractor::allocateVar(BasicBlock *BB,
+                                        BasicBlock::iterator AllocIP,
+                                        Type *VarType, const Twine &Name,
+                                        AddrSpaceCastInst **CastedAlloc) {
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+  Instruction *Alloca =
+      new AllocaInst(VarType, DL.getAllocaAddrSpace(), nullptr, Name, AllocIP);
+
+  if (CastedAlloc && ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+    *CastedAlloc = new AddrSpaceCastInst(
+        Alloca, PointerType::get(BB->getContext(), 0), Name + ".ascast");
+    (*CastedAlloc)->insertAfter(Alloca->getIterator());
+  }
+  return Alloca;
+}
+
+Instruction *CodeExtractor::deallocateVar(BasicBlock *, BasicBlock::iterator,
+                                          Value *, Type *) {
+  return nullptr;
+}
+
 // Find the pair of life time markers for address 'Addr' that are either
 // defined inside the outline region or can legally be shrinkwrapped into the
 // outline region. If there are not other untracked uses of the address, return
@@ -1827,7 +1840,6 @@ CallInst *CodeExtractor::emitReplacerCall(
     std::vector<Value *> &Reloads) {
   LLVMContext &Context = oldFunction->getContext();
   Module *M = oldFunction->getParent();
-  const DataLayout &DL = M->getDataLayout();
 
   // This takes place of the original loop
   BasicBlock *codeReplacer =
@@ -1858,39 +1870,22 @@ CallInst *CodeExtractor::emitReplacerCall(
     if (StructValues.contains(output))
       continue;
 
-    Value *OutAlloc;
-    if (CustomArgAllocatorCB)
-      OutAlloc = (*CustomArgAllocatorCB)(
-          AllocaBlock, AllocaBlock->getFirstInsertionPt(), output->getType(),
-          output->getName() + ".loc");
-    else
-      OutAlloc = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
-                                nullptr, output->getName() + ".loc",
-                                AllocaBlock->getFirstInsertionPt());
-
+    Value *OutAlloc =
+        allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                    output->getType(), output->getName() + ".loc");
     params.push_back(OutAlloc);
     ReloadOutputs.push_back(OutAlloc);
   }
 
   Instruction *Struct = nullptr;
   if (!StructValues.empty()) {
-    BasicBlock::iterator StructArgIP = AllocaBlock->getFirstInsertionPt();
-    if (CustomArgAllocatorCB) {
-      Struct = (*CustomArgAllocatorCB)(AllocaBlock, StructArgIP, StructArgTy,
-                                       "structArg");
+    AddrSpaceCastInst *StructSpaceCast = nullptr;
+    Struct = allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(),
+                         StructArgTy, "structArg", &StructSpaceCast);
+    if (StructSpaceCast)
+      params.push_back(StructSpaceCast);
+    else
       params.push_back(Struct);
-    } else {
-      Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
-                              "structArg", StructArgIP);
-      if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
-        auto *StructSpaceCast = new AddrSpaceCastInst(
-            Struct, PointerType ::get(Context, 0), "structArg.ascast");
-        StructSpaceCast->insertAfter(Struct->getIterator());
-        params.push_back(StructSpaceCast);
-      } else {
-        params.push_back(Struct);
-      }
-    }
 
     unsigned AggIdx = 0;
     for (Value *input : inputs) {
@@ -2033,26 +2028,24 @@ CallInst *CodeExtractor::emitReplacerCall(
   insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart,
                                        {}, call);
 
-  // Deallocate variables that used a custom allocator.
-  if (CustomArgAllocatorCB && CustomArgDeallocatorCB) {
-    BasicBlock *DeallocBlock = codeReplacer;
-    BasicBlock::iterator DeallocIP = codeReplacer->end();
-    if (DeallocationBlock) {
-      DeallocBlock = DeallocationBlock;
-      DeallocIP = DeallocationBlock->getFirstInsertionPt();
-    }
-
-    int Index = 0;
-    for (Value *Output : outputs) {
-      if (!StructValues.contains(Output))
-        (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP,
-                                  ReloadOutputs[Index++], Output->getType());
-    }
+  // Deallocate intermediate variables if they need explicit deallocation.
+  BasicBlock *DeallocBlock = codeReplacer;
+  BasicBlock::iterator DeallocIP = codeReplacer->end();
+  if (DeallocationBlock) {
+    DeallocBlock = DeallocationBlock;
+    DeallocIP = DeallocationBlock->getFirstInsertionPt();
+  }
 
-    if (Struct)
-      (*CustomArgDeallocatorCB)(DeallocBlock, DeallocIP, Struct, StructArgTy);
+  int Index = 0;
+  for (Value *Output : outputs) {
+    if (!StructValues.contains(Output))
+      deallocateVar(DeallocBlock, DeallocIP, ReloadOutputs[Index++],
+                    Output->getType());
   }
 
+  if (Struct)
+    deallocateVar(DeallocBlock, DeallocIP, Struct, StructArgTy);
+
   return call;
 }
 
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 9ea8de3da1e5b..6fd266a815dcf 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -711,7 +711,8 @@ TEST(CodeExtractor, OpenMPAggregateArgs) {
                    /* AssumptionCache */ nullptr,
                    /* AllowVarArgs */ true,
                    /* AllowAlloca */ true,
-                   /* AllocaBlock*/ &Func->getEntryBlock(),
+                   /* AllocationBlock*/ &Func->getEntryBlock(),
+                   /* DeallocationBlock */ nullptr,
                    /* Suffix */ ".outlined",
                    /* ArgsInZeroAddressSpace */ true);
 
    
    
More information about the llvm-branch-commits
mailing list