[Mlir-commits] [flang] [llvm] [mlir] [Flang] [OpenMP] [MLIR] Add lowering support for OMP ALLOCATE directives and its clauses (PR #187167)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 18:14:14 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Raghu Maddhipatla (raghavendhra)

<details>
<summary>Changes</summary>

This patch primarily focuses on:

- Lowering to LLVM IR by generating the appropriate __kmpc_alloc() and __kmpc_aligned_alloc() calls, and generating the matching __kmpc_free() calls before the terminator of the enclosing block.
- Handling array variables used in the OMP ALLOCATE directive.
- A slight change to the MLIR definition of the allocator clause: the operand now accepts any integer type, and its type is printed in the assembly format.
- Adding test cases for the different usages of the OMP ALLOCATE directive and its ALIGN and ALLOCATOR clauses.

---

Patch is 25.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/187167.diff


14 Files Affected:

- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+19) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+2) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+38-2) 
- (removed) flang/test/Lower/OpenMP/Todo/omp-declarative-allocate-align.f90 (-10) 
- (removed) flang/test/Lower/OpenMP/Todo/omp-declarative-allocate.f90 (-10) 
- (added) flang/test/Lower/OpenMP/omp-declarative-allocate-align.f90 (+47) 
- (added) flang/test/Lower/OpenMP/omp-declarative-allocate.f90 (+19) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+14-1) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+17) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+2-2) 
- (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+13) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+107) 
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+28) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+8-8) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0bd25e36b6468..7008df74161c6 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -325,6 +325,25 @@ static void collectIteratorIVs(
 // ClauseProcessor unique clauses
 //===----------------------------------------------------------------------===//
 
+bool ClauseProcessor::processAlign(
+    mlir::omp::AlignClauseOps &result) const {
+  if (auto *clause = findUniqueClause<omp::clause::Align>()) {
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+    const std::optional<std::int64_t> align = evaluate::ToInt64(clause->v);
+    result.align = firOpBuilder.getI64IntegerAttr(*align);
+    return true;
+  }
+  return false;
+}
+
+bool ClauseProcessor::processAllocator(lower::StatementContext &stmtCtx, mlir::omp::AllocatorClauseOps &result) const {
+  if (auto *clause = findUniqueClause<omp::clause::Allocator>()) {
+    result.allocator = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+    return true;
+  }
+  return false;
+}
+
 bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const {
   return markClauseOccurrence<omp::clause::OmpxBare>(result.bare);
 }
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 52e69c1796876..62436e2a19173 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -57,6 +57,8 @@ class ClauseProcessor {
       : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
 
   // 'Unique' clauses: They can appear at most once in the clause list.
+  bool processAlign(mlir::omp::AlignClauseOps &result) const;
+  bool processAllocator(lower::StatementContext &stmtCtx, mlir::omp::AllocatorClauseOps &result) const;
   bool processBare(mlir::omp::BareClauseOps &result) const;
   bool processBind(mlir::omp::BindClauseOps &result) const;
   bool processCancelDirectiveName(
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ae5f6f50bda09..9784765fe033e 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1493,6 +1493,21 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter,
 // Code generation functions for clauses
 //===----------------------------------------------------------------------===//
 
+static void genAllocateClauses(lower::AbstractConverter &converter,
+                            semantics::SemanticsContext &semaCtx,
+                            lower::StatementContext &stmtCtx,
+                            const ObjectList &objects,
+                            const List<Clause> &clauses, mlir::Location loc,
+                            llvm::SmallVectorImpl<mlir::Value> &operandRange,
+                            mlir::omp::AllocateDirOperands &clauseOps) {
+  if (!objects.empty())
+    genObjectList(objects, converter, operandRange);
+
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processAlign(clauseOps);
+  cp.processAllocator(stmtCtx, clauseOps);
+}
+
 static void genCancelClauses(lower::AbstractConverter &converter,
                              semantics::SemanticsContext &semaCtx,
                              const List<Clause> &clauses, mlir::Location loc,
@@ -1899,6 +1914,17 @@ static void genWsloopClauses(
 //===----------------------------------------------------------------------===//
 // Code generation functions for leaf constructs
 //===----------------------------------------------------------------------===//
+static mlir::omp::AllocateDirOp
+genAllocateDirOp(lower::AbstractConverter &converter,
+           semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval,
+           mlir::Location loc, const ObjectList &objects,  const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+  llvm::SmallVector<mlir::Value> operandRange;
+  mlir::omp::AllocateDirOperands clauseOps;
+  genAllocateClauses(converter, semaCtx, stmtCtx, objects, item->clauses, loc,
+                  operandRange, clauseOps);
+
+  return mlir::omp::AllocateDirOp::create(converter.getFirOpBuilder(), loc, operandRange, clauseOps.align, clauseOps.allocator);
+}
 
 static mlir::omp::BarrierOp
 genBarrierOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
@@ -3796,8 +3822,18 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
                    semantics::SemanticsContext &semaCtx,
                    lower::pft::Evaluation &eval,
                    const parser::OmpAllocateDirective &allocate) {
-  if (!semaCtx.langOptions().OpenMPSimd)
-    TODO(converter.getCurrentLocation(), "OmpAllocateDirective");
+  lower::StatementContext stmtCtx;
+  ObjectList objects = makeObjects((allocate.BeginDir().Arguments()), semaCtx);
+  const auto &clauseList = (allocate.BeginDir().Clauses());
+  List<Clause> clauses = makeClauses(clauseList, semaCtx);
+  mlir::Location loc = converter.genLocation(allocate.source);
+
+  ConstructQueue queue{buildConstructQueue(
+      converter.getFirOpBuilder().getModule(), semaCtx, eval, allocate.source,
+      llvm::omp::Directive::OMPD_allocate, clauses)};
+
+  genAllocateDirOp(converter, semaCtx, stmtCtx, eval, loc, objects,
+             queue, queue.begin());
 }
 
 static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
diff --git a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate-align.f90 b/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate-align.f90
deleted file mode 100644
index fec146ac70313..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate-align.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-! This test checks lowering of OpenMP allocate Directive with align clause.
-
-! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=51 %s 2>&1 | FileCheck %s
-
-program main
-  integer :: x
-
-  ! CHECK: not yet implemented: OmpAllocateDirective
-  !$omp allocate(x) align(32)
-end
diff --git a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate.f90 b/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate.f90
deleted file mode 100644
index 7cae8051fda77..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate.f90
+++ /dev/null
@@ -1,10 +0,0 @@
-! This test checks lowering of OpenMP allocate Directive.
-
-! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s
-
-program main
-  integer :: x, y
-
-  ! CHECK: not yet implemented: OmpAllocateDirective
-  !$omp allocate(x, y)
-end
diff --git a/flang/test/Lower/OpenMP/omp-declarative-allocate-align.f90 b/flang/test/Lower/OpenMP/omp-declarative-allocate-align.f90
new file mode 100644
index 0000000000000..50c6ab1f64002
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declarative-allocate-align.f90
@@ -0,0 +1,47 @@
+! This test checks lowering of OpenMP allocate Directive with align and allocator
+! clauses to LLVM IR. Verifies code generation for:
+!   - align(16) only (null allocator)
+!   - allocator(omp_default_mem_alloc) only (no align)
+!   - align(64) allocator(omp_cgroup_mem_alloc) (both clauses, array variable)
+!   - align(32) allocator(3) (both clauses, multiple variables)
+
+! RUN: %flang_fc1 -emit-llvm %openmp_flags -fopenmp-version=51 %s -o - 2>&1 | FileCheck %s
+
+program main
+  use omp_lib
+  integer :: x, y
+  integer :: z(10)
+  character c
+  real(kind = 16) :: r
+  complex cmplx
+  !$omp allocate(x) align(16)
+  !$omp allocate(y) allocator(omp_default_mem_alloc)
+  !$omp allocate(z) align(64) allocator(omp_cgroup_mem_alloc)
+  !$omp allocate(c, r, cmplx) align(32) allocator(3)
+  x = 1
+  y = 2
+  z = x + y
+  print *, "z : ", z
+end program
+
+! CHECK: define void @_QQmain()
+! CHECK: call i32 @__kmpc_global_thread_num(
+
+! CHECK: call ptr @__kmpc_aligned_alloc(i32 {{.*}}, i64 16, i64 {{.*}}, ptr null)
+! CHECK: call ptr @__kmpc_alloc(i32 {{.*}}, i64 {{.*}}, ptr inttoptr (i64 1 to ptr))
+! CHECK: call ptr @__kmpc_aligned_alloc(i32 {{.*}}, i64 64, i64 {{.*}}, ptr inttoptr (i64 6 to ptr))
+! CHECK: call ptr @__kmpc_aligned_alloc(i32 {{.*}}, i64 32, i64 {{.*}}, ptr inttoptr (i32 3 to ptr))
+! CHECK: call ptr @__kmpc_aligned_alloc(i32 {{.*}}, i64 32, i64 {{.*}}, ptr inttoptr (i32 3 to ptr))
+! CHECK: call ptr @__kmpc_aligned_alloc(i32 {{.*}}, i64 32, i64 {{.*}}, ptr inttoptr (i32 3 to ptr))
+
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr inttoptr (i32 3 to ptr))
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr inttoptr (i32 3 to ptr))
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr inttoptr (i32 3 to ptr))
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr inttoptr (i64 6 to ptr))
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr inttoptr (i64 1 to ptr))
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr null)
+! CHECK: ret void
+
+! CHECK: declare noalias ptr @__kmpc_aligned_alloc(i32, i64, i64, ptr)
+! CHECK: declare noalias ptr @__kmpc_alloc(i32, i64, ptr)
+! CHECK: declare void @__kmpc_free(i32, ptr, ptr)
diff --git a/flang/test/Lower/OpenMP/omp-declarative-allocate.f90 b/flang/test/Lower/OpenMP/omp-declarative-allocate.f90
new file mode 100644
index 0000000000000..7c8047ebf7f53
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-declarative-allocate.f90
@@ -0,0 +1,19 @@
+! This test checks lowering of OpenMP allocate Directive to LLVM IR.
+! Verifies code generation for default (no align, null allocator) case.
+
+! RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s
+
+program main
+  integer :: x, y
+  !$omp allocate(x, y)
+end program
+
+! CHECK: define void @_QQmain()
+! CHECK: call i32 @__kmpc_global_thread_num(
+! CHECK: call ptr @__kmpc_alloc(i32 {{.*}}, i64 8, ptr null)
+! CHECK: call ptr @__kmpc_alloc(i32 {{.*}}, i64 8, ptr null)
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr null)
+! CHECK: call void @__kmpc_free(i32 {{.*}}, ptr {{.*}}, ptr null)
+! CHECK: ret void
+! CHECK: declare noalias ptr @__kmpc_alloc(i32, i64, ptr)
+! CHECK: declare void @__kmpc_free(i32, ptr, ptr)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 042c3c75e9cb8..f18075c146408 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3180,7 +3180,7 @@ class OpenMPIRBuilder {
                                                   llvm::IntegerType *IntPtrTy,
                                                   bool BranchtoEnd = true);
 
-  /// Create a runtime call for kmpc_Alloc
+  /// Create a runtime call for kmpc_alloc
   ///
   /// \param Loc The insert and source location description.
   /// \param Size Size of allocated memory space
@@ -3191,6 +3191,19 @@ class OpenMPIRBuilder {
   LLVM_ABI CallInst *createOMPAlloc(const LocationDescription &Loc, Value *Size,
                                     Value *Allocator, std::string Name = "");
 
+  /// Create a runtime call for kmpc_align_alloc
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Align Align value
+  /// \param Size Size of allocated memory space
+  /// \param Allocator Allocator information instruction
+  /// \param Name Name of call Instruction for OMP_Align_Alloc
+  ///
+  /// \returns CallInst to the OMP_Align_Alloc call
+  LLVM_ABI CallInst *createOMPAlignedAlloc(const LocationDescription &Loc,
+                                    Value *Align, Value *Size, Value *Allocator,
+                                    std::string Name = "");
+
   /// Create a runtime call for kmpc_free
   ///
   /// \param Loc The insert and source location description.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 8148e113195cc..ef88289803512 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7619,6 +7619,23 @@ CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
   return createRuntimeFunctionCall(Fn, Args, Name);
 }
 
+CallInst *OpenMPIRBuilder::createOMPAlignedAlloc(const LocationDescription &Loc,
+                                          Value *Align, Value *Size, Value *Allocator,
+                                          std::string Name) {
+  IRBuilder<>::InsertPointGuard IPG(Builder);
+  updateToLocation(Loc);
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+  Value *ThreadId = getOrCreateThreadID(Ident);
+  Value *Args[] = {ThreadId, Align, Size, Allocator};
+
+  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_aligned_alloc);
+
+  return Builder.CreateCall(Fn, Args, Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
                                          Value *Addr, Value *Allocator,
                                          std::string Name) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 23c2fbdfd7368..42aa9c7d4f4f0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -146,11 +146,11 @@ class OpenMP_AllocatorClauseSkip<
                     extraClassDeclaration> {
 
   let arguments = (ins
-    Optional<I64>:$allocator
+    Optional<AnyInteger>:$allocator
   );
 
   let optAssemblyFormat = [{
-    `allocator` `(` $allocator `)`
+    `allocator` `(` $allocator `:` type($allocator) `)`
   }];
 
   let description = [{
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index c67bb57985bd0..f073081002719 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -272,6 +272,11 @@ class ModuleTranslation {
   /// constructed.
   llvm::OpenMPIRBuilder *getOpenMPBuilder();
 
+  /// Registers a pending __kmpc_free call for the given block. These are
+  /// emitted before the block's terminator during block conversion.
+  void registerPendingOmpAllocateFree(Block *block, llvm::Value *ptr,
+                                      llvm::Value *allocator);
+
   /// Returns the LLVM module in which the IR is being constructed.
   llvm::Module *getLLVMModule() { return llvmModule.get(); }
 
@@ -401,6 +406,9 @@ class ModuleTranslation {
                                  llvm::IRBuilderBase &builder,
                                  bool recordInsertions);
 
+  /// Emits pending __kmpc_free calls for the block, before its terminator.
+  void emitPendingOmpAllocateFrees(Block &bb, llvm::IRBuilderBase &builder);
+
   /// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
   /// TBAATagAttr.
   llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const;
@@ -509,6 +517,11 @@ class ModuleTranslation {
   /// block.
   DenseMap<BlockAddressAttr, llvm::BasicBlock *> blockAddressToLLVMMapping;
 
+  /// Pending __kmpc_free calls per block, emitted before the terminator.
+  DenseMap<Block *,
+           llvm::SmallVector<std::pair<llvm::Value *, llvm::Value *>>>
+      pendingOmpAllocateFrees;
+
   /// Stack of user-specified state elements, useful when translating operations
   /// with regions.
   StateStack stack;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 37b1a37c2e1a5..0ce9c55f8ec95 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4595,6 +4595,22 @@ static Operation *getGlobalOpFromValue(Value value) {
   return nullptr;
 }
 
+static Value getBaseValueForTypeLookup(Value value) {
+  while (Operation *op = value.getDefiningOp()) {
+    if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
+      value = addrCast.getOperand();
+    else if (op->getName().getIdentifier()) {
+      if (op->getNumOperands() > 0)
+        value = op->getOperand(0);
+      else
+        break;
+    } else {
+      break;
+    }
+  }
+  return value;
+}
+
 static llvm::SmallString<64>
 getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
                              llvm::OpenMPIRBuilder &ompBuilder) {
@@ -7430,6 +7446,94 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation) {
+  auto allocateDirOp = cast<omp::AllocateDirOp>(opInst);
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
+  SmallVector<Value> vars = allocateDirOp.getVarList();
+  std::optional<int64_t> alignAttr = allocateDirOp.getAlign();
+
+  llvm::Value *allocator;
+  if (auto allocatorVar = allocateDirOp.getAllocator()) {
+    allocator = moduleTranslation.lookupValue(allocatorVar);
+    if (allocator->getType()->isIntegerTy())
+      allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy());
+    else if (allocator->getType()->isPointerTy())
+      allocator =
+          builder.CreatePointerBitCastOrAddrSpaceCast(allocator, builder.getPtrTy());
+  } else {
+    allocator = llvm::ConstantPointerNull::get(builder.getPtrTy());
+  }
+
+  SmallVector<std::pair<llvm::CallInst *, llvm::Value *>> allocatedVars;
+
+  for (Value var : vars) {
+    llvm::Type *llvmVarTy = moduleTranslation.convertType(var.getType());
+
+    // Opaque pointers lose element type. Trace to GlobalOp for type
+    // Falls back to llvmVarTy when not from a global.
+    llvm::Type *typeToInspect = llvmVarTy;
+    if (llvmVarTy->isPointerTy()) {
+      Value baseVar = getBaseValueForTypeLookup(var);
+      if (Operation *globalOp = getGlobalOpFromValue(baseVar)) {
+        if (auto gop = dyn_cast<LLVM::GlobalOp>(globalOp))
+          typeToInspect =
+              moduleTranslation.convertType(gop.getGlobalType());
+      }
+    }
+
+    llvm::Value *size;
+    if (auto arrTy = llvm::dyn_cast<llvm::ArrayType>(typeToInspect)) {
+      llvm::Value *elementCount = builder.getInt64(1);
+      llvm::Type *currentType = arrTy;
+      while (auto nestedArrTy = llvm::dyn_cast<llvm::ArrayType>(currentType)) {
+        elementCount = builder.CreateMul(
+            elementCount, builder.getInt64(nestedArrTy->getNumElements()));
+        currentType = nestedArrTy->getElementType();
+      }
+      uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType);
+      size = builder.CreateMul(elementCount,
+                              builder.getInt64(elemSizeInBits / 8));
+    } else {
+      size = builder.getInt64(
+          dataLayout.getTypeStoreSize(typeToInspect).getFixedValue());
+    }
+
+    uint64_t alignValue =
+        alignAttr ? alignAttr.value()
+                  : dataLayout.getABITypeAlign(typeToInspect).value();
+    llvm::Value *alignConst = builder.getInt64(alignValue);
+    // Align the size: ((size + align - 1) / align) * align
+    size = builder.CreateAdd(size, builder.getInt64(alignValue - 1), "", true);
+    size = builder.CreateUDiv(size, alignConst);
+    size = builder.CreateMul(size, alignConst, "", true);
+
+    std::string allocName =
+        ompBuilder->createPlatformSpecificName({".void.addr"});
+    llvm::CallInst *allocCall;
+    if (alignAttr.has_value()) {
+      allocCall = ompBuilder->createOMPAlignedAlloc(
+          ompLoc, builder.getInt64(alignAttr.value()), size, allocator, allocName);
+    } else {
+      allocCall = ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName);
+    }
+    allocatedVars.push_back({allocCall, allocator});
+  }
+
+  // Register __kmpc_free calls to be emitted before the block terminator.
+  Block *block = allocateDirOp->getBlock();
+  for (auto &alloc : allocatedVars)
+    moduleTranslation.registerPendingOmpAllocateFree(block, alloc.first,
+                                                    alloc.second);
+
+  return success();
+}
+
 static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
                              ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/187167


More information about the Mlir-commits mailing list