[Mlir-commits] [llvm] [mlir] [Flang][MLIR][OpenMP] Add explicit shared memory (de-)allocation ops (PR #161862)

Sergio Afonso llvmlistbot at llvm.org
Mon Apr 27 05:13:23 PDT 2026


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/161862

>From 34533893a8241a6d7a8a46b650c9244eb6d6f8fa Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 12 Sep 2025 15:56:04 +0100
Subject: [PATCH] [Flang][MLIR][OpenMP] Add explicit shared memory
 (de-)allocation ops

This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem`
operations to represent explicit allocations and deallocations of shared memory
across threads in a team, mirroring the existing `omp.target_allocmem` and
`omp.target_freemem`.

The `omp.alloc_shared_mem` op goes through the same Flang-specific
transformations as `omp.target_allocmem`, so that the size of the buffer can be
properly calculated when translating to LLVM IR.

The corresponding runtime functions produced for these new operations are
`__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be
created for implicit allocations (e.g. privatized and reduction variables).
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 23 ++++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 29 +++++--
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 37 +++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 70 ++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 31 ++++++--
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 79 +++++++++++++++++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 35 ++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 32 +++++++-
 .../LLVMIR/omptarget-device-shared-mem.mlir   | 42 ++++++++++
 9 files changed, 353 insertions(+), 25 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index e0093656758cd..00b951505268b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3266,6 +3266,17 @@ class OpenMPIRBuilder {
   LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr,
                                    Value *Allocator, std::string Name = "");
 
+  /// Create a runtime call for kmpc_alloc_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_alloc_shared call.
+  LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc,
+                                          Value *Size,
+                                          const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_alloc_shared.
   ///
   /// \param Loc The insert and source location description.
@@ -3277,6 +3288,18 @@ class OpenMPIRBuilder {
                                           Type *VarType,
                                           const Twine &Name = Twine(""));
 
+  /// Create a runtime call for kmpc_free_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Addr Value obtained from the corresponding kmpc_alloc_shared call.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_free_shared call.
+  LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc,
+                                         Value *Addr, Value *Size,
+                                         const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_free_shared.
   ///
   /// \param Loc The insert and source location description.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 767fd91a27d4f..ad07fcfc1957f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7930,32 +7930,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
 }
 
 CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
-                                                Type *VarType,
+                                                Value *Size,
                                                 const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  const DataLayout &DL = M.getDataLayout();
-  Value *Args[] = {Builder.getInt64(DL.getTypeAllocSize(VarType))};
+  Value *Args[] = {Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared);
   CallInst *Call = Builder.CreateCall(Fn, Args, Name);
-  Call->addRetAttr(
-      Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64)));
+  Call->addRetAttr(Attribute::getWithAlignment(
+      M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64)));
   return Call;
 }
 
+CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
+                                                Type *VarType,
+                                                const Twine &Name) {
+  return createOMPAllocShared(
+      Loc, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)), Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
-                                               Value *Addr, Type *VarType,
+                                               Value *Addr, Value *Size,
                                                const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  Value *Args[] = {
-      Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType))};
+  Value *Args[] = {Addr, Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared);
   return Builder.CreateCall(Fn, Args, Name);
 }
 
+CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
+                                               Value *Addr, Type *VarType,
+                                               const Twine &Name) {
+  return createOMPFreeShared(
+      Loc, Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)),
+      Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPInteropInit(
     const LocationDescription &Loc, Value *InteropVar,
     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 5a8077e251db8..6270e05b77780 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -922,6 +922,43 @@ class OpenMP_MapClauseSkip<
 
 def OpenMP_MapClause : OpenMP_MapClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure for memory allocation information.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_MemAllocationSizeClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+
+  let arguments = (ins
+    TypeAttr:$mem_elem_type,
+    AnySignlessInteger:$mem_array_size,
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$mem_alignment
+  );
+
+  let reqAssemblyFormat = [{
+    $mem_array_size `x` $mem_elem_type `:` `(` type($mem_array_size) `)`
+  }];
+
+  let optAssemblyFormat = [{
+    `align` `(` $mem_alignment `)`
+  }];
+
+  let description = [{
+    The `mem_elem_type` is the type of the object the memory allocation refers
+    to. It is used to calculate the size of the allocation.
+
+    The `mem_array_size` is the number of objects.
+
+    The optional `mem_alignment` is used to specify the alignment for each
+    element. If not set, the `DataLayout` defaults will be used instead.
+  }];
+}
+
+def OpenMP_MemAllocationSizeClause : OpenMP_MemAllocationSizeClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [15.8.1] `memory-order` clause set
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 858bf673537c1..9954e1f233542 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2323,6 +2323,76 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
   let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
 }
 
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
+    MemoryEffects<[MemAlloc<DefaultResource>]>
+  ], clauses = [
+    OpenMP_MemAllocationSizeClause
+  ]> {
+  let summary = "allocate storage on shared memory for objects of a given type";
+
+  let description = [{
+    Allocates memory shared across threads of a team for an object of the given
+    type. Returns a pointer representing the allocated memory. The memory is
+    uninitialized after allocation. Operations must be paired with
+    `omp.free_shared_mem` to avoid memory leaks.
+
+    ```mlir
+      // Allocate an i32 vector with %size elements and aligned to 8 bytes.
+      %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) align(8) -> !llvm.ptr
+      // ...
+      omp.free_shared_mem [%size x i32 : (i64) align(8)] %ptr_shared : !llvm.ptr
+    ```
+  }] # clausesDescription;
+
+  let results = (outs OpenMP_PointerLikeType);
+  let assemblyFormat = clausesReqAssemblyFormat # " oilist(" #
+    clausesOptAssemblyFormat # ") `->` type(results) attr-dict";
+
+  let extraClassDeclaration = [{
+    mlir::Type getAllocatedType() { return getMemElemTypeAttr().getValue(); }
+  }];
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", traits = [
+    MemoryEffects<[MemFree]>
+  ], clauses = [
+    OpenMP_MemAllocationSizeClause
+  ]> {
+  let summary = "free shared memory";
+
+  let description = [{
+    Deallocates shared memory that was previously allocated by an
+    `omp.alloc_shared_mem` operation. After this operation, the deallocated
+    memory is in an undefined state and should not be accessed.
+
+    ```mlir
+      // Example of allocating and freeing shared memory.
+      %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> !llvm.ptr
+      // ...
+      omp.free_shared_mem [%size x i32 : (i64)] %ptr_shared : !llvm.ptr
+    ```
+
+    The `heapref` operand represents the pointer to shared memory to be
+    deallocated, previously returned by `omp.alloc_shared_mem`.
+  }] # clausesDescription;
+
+  let arguments = !con(clausesArgs, (ins
+    Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
+  ));
+  let assemblyFormat = "` ` `[`" # clausesReqAssemblyFormat # " oilist(" #
+    clausesOptAssemblyFormat # ") `]` $heapref `:` type($heapref) attr-dict";
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // workdistribute Construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5c52a309544f2..ab86081b6c995 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4690,17 +4690,34 @@ LogicalResult ScanOp::verify() {
 }
 
 /// Verifies align clause in allocate directive
+LogicalResult verifyAlignment(Operation &op,
+                              std::optional<uint64_t> alignment) {
+  if (alignment.has_value()) {
+    if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
+      return op.emitError()
+             << "ALIGN value : " << alignment.value() << " must be power of 2";
+  }
+  return success();
+}
 
 LogicalResult AllocateDirOp::verify() {
-  std::optional<uint64_t> align = this->getAlign();
+  return verifyAlignment(*getOperation(), getAlign());
+}
 
-  if (align.has_value()) {
-    if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
-      return emitError() << "ALIGN value : " << align.value()
-                         << " must be power of 2";
-  }
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
 
-  return success();
+LogicalResult AllocSharedMemOp::verify() {
+  return verifyAlignment(*getOperation(), getMemAlignment());
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult FreeSharedMemOp::verify() {
+  return verifyAlignment(*getOperation(), getMemAlignment());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 06354aa17e323..e0cccbde6b442 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7816,6 +7816,47 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
   return func;
 }
 
+template <typename T>
+static llvm::Value *
+getAllocationSize(llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation, T op) {
+  llvm::DataLayout dataLayout =
+      moduleTranslation.getLLVMModule()->getDataLayout();
+  llvm::Type *llvmHeapTy =
+      moduleTranslation.convertType(op.getMemElemTypeAttr().getValue());
+
+  auto alignment = op.getMemAlignment();
+  llvm::TypeSize typeSize = llvm::alignTo(
+      dataLayout.getTypeStoreSize(llvmHeapTy),
+      alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
+
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  return builder.CreateMul(
+      allocSize,
+      builder.CreateIntCast(moduleTranslation.lookupValue(op.getMemArraySize()),
+                            builder.getInt64Ty(),
+                            /*isSigned=*/false));
+}
+
+template <>
+llvm::Value *getAllocationSize(llvm::IRBuilderBase &builder,
+                               LLVM::ModuleTranslation &moduleTranslation,
+                               omp::TargetAllocMemOp op) {
+  llvm::DataLayout dataLayout =
+      moduleTranslation.getLLVMModule()->getDataLayout();
+  llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
+  llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  for (auto typeParam : op.getTypeparams()) {
+    allocSize = builder.CreateMul(
+        allocSize,
+        builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
+                              builder.getInt64Ty(),
+                              /*isSigned=*/false));
+  }
+  return allocSize;
+}
+
 static LogicalResult
 convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
@@ -7830,14 +7871,8 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   mlir::Value deviceNum = allocMemOp.getDevice();
   llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
   // Get the allocation size.
-  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
-  mlir::Type heapTy = allocMemOp.getAllocatedType();
-  llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
-  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
-  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
-  for (auto typeParam : allocMemOp.getTypeparams())
-    allocSize =
-        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+  llvm::Value *allocSize =
+      getAllocationSize(builder, moduleTranslation, allocMemOp);
   // Create call to "omp_target_alloc" with the args as translated llvm values.
   llvm::CallInst *call =
       builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -7848,6 +7883,17 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
+                        llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
+  moduleTranslation.mapValue(allocMemOp.getResult(),
+                             ompBuilder->createOMPAllocShared(builder, size));
+  return success();
+}
+
 static LogicalResult
 convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation,
@@ -7998,6 +8044,17 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
+                       llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::Value *size = getAllocationSize(builder, moduleTranslation, freeMemOp);
+  ompBuilder->createOMPFreeShared(
+      builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
+  return success();
+}
+
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
 /// OpenMP runtime calls).
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
@@ -8215,6 +8272,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
             return convertAllocateFreeOp(*op, builder, moduleTranslation,
                                          *this);
           })
+          .Case([&](omp::AllocSharedMemOp op) {
+            return convertAllocSharedMemOp(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::FreeSharedMemOp op) {
+            return convertFreeSharedMemOp(op, builder, moduleTranslation);
+          })
           .Default([&](Operation *inst) {
             return inst->emitError()
                    << "not yet implemented: " << inst->getName();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 734d28b95b6fb..1a3bd678621b4 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3463,3 +3463,38 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
   %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
   return
 }
+
+// -----
+func.func @alloc_shared_mem_invalid_alignment1(%n: i32) -> () {
+  // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+  %0 = omp.alloc_shared_mem %n x i64 : (i32) align(-2) -> !llvm.ptr
+  return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_alignment2(%n: i32) -> () {
+  // expected-error @below {{ALIGN value : 3 must be power of 2}}
+  %0 = omp.alloc_shared_mem %n x i64 : (i32) align(3) -> !llvm.ptr
+  return
+}
+
+// -----
+func.func @free_shared_mem_invalid_array_size(%n: f32, %ptr : !llvm.ptr) -> () {
+  // expected-error @below {{invalid kind of type specified: expected builtin.integer, but found 'f32'}}
+  %0 = omp.free_shared_mem [%n x i64 : (f32)] %ptr : !llvm.ptr
+  return
+}
+
+// -----
+func.func @free_shared_mem_invalid_alignment1(%n: i32, %ptr : !llvm.ptr) -> () {
+  // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+  omp.free_shared_mem [%n x i64 : (i32) align(-2)] %ptr : !llvm.ptr
+  return
+}
+
+// -----
+func.func @free_shared_mem_invalid_alignment2(%n: i32, %ptr : !llvm.ptr) -> () {
+  // expected-error @below {{ALIGN value : 3 must be power of 2}}
+  omp.free_shared_mem [%n x i64 : (i32) align(3)] %ptr : !llvm.ptr
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d7ef7880e051d..925f6c5614899 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3923,9 +3923,37 @@ func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
 }
 
 // CHECK-LABEL: func.func @omp_target_freemem(
-// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
-func.func @omp_target_freemem(%device : i32, %ptr : i64) {
+// CHECK-SAME: %[[DEVICE:.*]]: i32) {
+func.func @omp_target_freemem(%device : i32) {
+  // CHECK: %[[PTR:.*]] = omp.target_allocmem
+  %ptr = omp.target_allocmem %device : i32, i64
   // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
   omp.target_freemem %device, %ptr : i32, i64
   return
 }
+
+// CHECK-LABEL: func.func @omp_alloc_shared_mem(
+// CHECK-SAME: %[[N:.*]]: i32) {
+func.func @omp_alloc_shared_mem(%n: i32) {
+  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x i64 : (i32) -> !llvm.ptr
+  %0 = omp.alloc_shared_mem %n x i64 : (i32) -> !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x vector<16x16xf32> : (i32) -> !llvm.ptr
+  %1 = omp.alloc_shared_mem %n x vector<16x16xf32> : (i32) -> !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr : (i32) align(16) -> !llvm.ptr
+  %2 = omp.alloc_shared_mem %n x !llvm.ptr : (i32) align(16) -> !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: func.func @omp_free_shared_mem(
+// CHECK-SAME: %[[N:.*]]: i64) {
+func.func @omp_free_shared_mem(%n: i64) {
+  // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) -> !llvm.ptr
+  %0 = omp.alloc_shared_mem %n x f32 : (i64) -> !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64)] %[[PTR]] : !llvm.ptr
+  omp.free_shared_mem [%n x f32 : (i64)] %0 : !llvm.ptr
+  // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) align(32) -> !llvm.ptr
+  %1 = omp.alloc_shared_mem %n x f32 : (i64) align(32) -> !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64) align(32)] %[[PTR]] : !llvm.ptr
+  omp.free_shared_mem [%n x f32 : (i64) align(32)] %1 : !llvm.ptr
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
new file mode 100644
index 0000000000000..cdebebc3ed233
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+  // CHECK-LABEL: define void @device_shared_mem(
+  // CHECK-SAME:  i32 %[[N0:.*]], i64 %[[N1:.*]])
+  llvm.func @device_shared_mem(%n0: i32, %n1: i64) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} {
+    // CHECK:      %[[CAST_N0:.*]] = zext i32 %[[N0]] to i64
+    // CHECK-NEXT: %[[ALLOC0_SZ:.*]] = mul i64 8, %[[CAST_N0]]
+    // CHECK-NEXT: %[[ALLOC0:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC0_SZ]])
+    %0 = omp.alloc_shared_mem %n0 x i64 : (i32) -> !llvm.ptr
+
+    // CHECK:      %[[ALLOC1_SZ:.*]] = mul i64 8, %[[N1]]
+    // CHECK-NEXT: %[[ALLOC1:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC1_SZ]])
+    %1 = omp.alloc_shared_mem %n1 x i64 : (i64) -> !llvm.ptr
+
+    // CHECK:      %[[ALLOC2_SZ:.*]] = mul i64 64, %[[N1]]
+    // CHECK-NEXT: %[[ALLOC2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC2_SZ]])
+    %2 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) -> !llvm.ptr
+
+    // CHECK:      %[[ALLOC3_SZ:.*]] = mul i64 128, %[[N1]]
+    // CHECK-NEXT: %[[ALLOC3:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC3_SZ]])
+    %3 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) align(128) -> !llvm.ptr
+
+    // CHECK:      %[[CAST_N0_1:.*]] = zext i32 %[[N0]] to i64
+    // CHECK-NEXT: %[[FREE0_SZ:.*]] = mul i64 8, %[[CAST_N0_1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 %[[FREE0_SZ]])
+    omp.free_shared_mem [%n0 x i64 : (i32)] %0 : !llvm.ptr
+
+    // CHECK:      %[[FREE1_SZ:.*]] = mul i64 8, %[[N1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 %[[FREE1_SZ]])
+    omp.free_shared_mem [%n1 x i64 : (i64)] %1 : !llvm.ptr
+
+    // CHECK:      %[[FREE2_SZ:.*]] = mul i64 64, %[[N1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 %[[FREE2_SZ]])
+    omp.free_shared_mem [%n1 x vector<16xf32> : (i64)] %2 : !llvm.ptr
+
+    // CHECK:      %[[FREE3_SZ:.*]] = mul i64 128, %[[N1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC3]], i64 %[[FREE3_SZ]])
+    omp.free_shared_mem [%n1 x vector<16xf32> : (i64) align(128)] %3 : !llvm.ptr
+    llvm.return
+  }
+}



More information about the Mlir-commits mailing list