[llvm-branch-commits] [flang] [llvm] [mlir] [Flang][MLIR][OpenMP] Add explicit shared memory (de-)allocation ops (PR #161862)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 23 05:49:52 PST 2026


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/161862

>From 5d3807333c6d2b57d042dcb0b5ae27deeb024908 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 12 Sep 2025 15:56:04 +0100
Subject: [PATCH 1/3] [Flang][MLIR][OpenMP] Add explicit shared memory
 (de-)allocation ops

This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem`
operations to represent explicit allocations and deallocations of shared memory
across threads in a team, mirroring the existing `omp.target_allocmem` and
`omp.target_freemem`.

The `omp.alloc_shared_mem` op goes through the same Flang-specific
transformations as `omp.target_allocmem`, so that the size of the buffer can be
properly calculated when translating to LLVM IR.

The corresponding runtime functions produced for these new operations are
`__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be
created for implicit allocations (e.g. privatized and reduction variables).
---
 flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 42 +++++++-----
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 23 +++++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 29 ++++++---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 61 ++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 11 ++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 64 ++++++++++++++++---
 mlir/test/Dialect/OpenMP/invalid.mlir         | 21 ++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 31 ++++++++-
 8 files changed, 249 insertions(+), 33 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 3e1fe1d2b1613..13214a9e51161 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -222,36 +222,47 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
   return converter.convertType(firType);
 }
 
-// FIR Op specific conversion for TargetAllocMemOp
-struct TargetAllocMemOpConversion
-    : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
-  using OpenMPFIROpConversion::OpenMPFIROpConversion;
+// FIR Op specific conversion for allocation operations
+template <typename T>
+struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
+  using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;
 
   llvm::LogicalResult
-  matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+  matchAndRewrite(T allocmemOp,
+                  typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Type heapTy = allocmemOp.getAllocatedType();
     mlir::Location loc = allocmemOp.getLoc();
-    auto ity = lowerTy().indexType();
+    auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
     mlir::Type dataTy = fir::unwrapRefType(heapTy);
-    mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+    mlir::Type llvmObjectTy =
+        convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
     if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
-      TODO(loc, "omp.target_allocmem codegen of derived type with length "
-                "parameters");
+      TODO(loc, allocmemOp->getName().getStringRef() +
+                    " codegen of derived type with length parameters");
     mlir::Value size = fir::computeElementDistance(
-        loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
+        loc, llvmObjectTy, ity, rewriter,
+        OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
     if (auto scaleSize = fir::genAllocationScaleSize(
             loc, allocmemOp.getInType(), ity, rewriter))
       size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
-    for (mlir::Value opnd : adaptor.getOperands().drop_front())
+    for (mlir::Value opnd : adaptor.getTypeparams())
+      size = mlir::LLVM::MulOp::create(
+          rewriter, loc, ity, size,
+          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+                      opnd));
+    for (mlir::Value opnd : adaptor.getShape())
       size = mlir::LLVM::MulOp::create(
           rewriter, loc, ity, size,
-          integerCast(lowerTy(), loc, rewriter, ity, opnd));
-    auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+                      opnd));
+    auto mallocTyWidth =
+        OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
     auto mallocTy =
         mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
     if (mallocTyWidth != ity.getIntOrFloatBitWidth())
-      size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
+      size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
+                         mallocTy, size);
     rewriter.modifyOpInPlace(allocmemOp, [&]() {
       allocmemOp.setInType(rewriter.getI8Type());
       allocmemOp.getTypeparamsMutable().clear();
@@ -281,6 +292,7 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
   patterns.add<MapInfoOpConversion>(converter);
   patterns.add<PrivateClauseOpConversion>(converter);
-  patterns.add<TargetAllocMemOpConversion>(converter);
   patterns.add<DeclareMapperOpConversion>(converter);
+  patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
+               AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ca77f4c302c2b..865be4d1f1c93 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3215,6 +3215,17 @@ class OpenMPIRBuilder {
   LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr,
                                    Value *Allocator, std::string Name = "");
 
+  /// Create a runtime call for kmpc_alloc_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_alloc_shared call.
+  LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc,
+                                          Value *Size,
+                                          const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_alloc_shared.
   ///
   /// \param Loc The insert and source location description.
@@ -3226,6 +3237,18 @@ class OpenMPIRBuilder {
                                           Type *VarType,
                                           const Twine &Name = Twine(""));
 
+  /// Create a runtime call for kmpc_free_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Addr Value obtained from the corresponding kmpc_alloc_shared call.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_free_shared call.
+  LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc,
+                                         Value *Addr, Value *Size,
+                                         const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_free_shared.
   ///
   /// \param Loc The insert and source location description.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 022097bea40e6..fb768c2fe443c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7796,32 +7796,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
 }
 
 CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
-                                                Type *VarType,
+                                                Value *Size,
                                                 const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  const DataLayout &DL = M.getDataLayout();
-  Value *Args[] = {Builder.getInt64(DL.getTypeAllocSize(VarType))};
+  Value *Args[] = {Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared);
   CallInst *Call = Builder.CreateCall(Fn, Args, Name);
-  Call->addRetAttr(
-      Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64)));
+  Call->addRetAttr(Attribute::getWithAlignment(
+      M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64)));
   return Call;
 }
 
+CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
+                                                Type *VarType,
+                                                const Twine &Name) {
+  return createOMPAllocShared(
+      Loc, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)), Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
-                                               Value *Addr, Type *VarType,
+                                               Value *Addr, Value *Size,
                                                const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  Value *Args[] = {
-      Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType))};
+  Value *Args[] = {Addr, Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared);
   return Builder.CreateCall(Fn, Args, Name);
 }
 
+CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
+                                               Value *Addr, Type *VarType,
+                                               const Twine &Name) {
+  return createOMPFreeShared(
+      Loc, Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)),
+      Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPInteropInit(
     const LocationDescription &Loc, Value *InteropVar,
     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 28122f4d2ae89..fe8f01d64691a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2235,6 +2235,67 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
   let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
 }
 
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
+    AttrSizedOperandSegments
+  ], clauses = [
+    OpenMP_HeapAllocClause
+  ]> {
+  let summary = "allocate storage on shared memory for an object of a given type";
+
+  let description = [{
+    Allocates memory shared across threads of a team for an object of the given
+    type. Returns a pointer representing the allocated memory. The memory is
+    uninitialized after allocation. Operations must be paired with
+    `omp.free_shared` to avoid memory leaks.
+
+    ```mlir
+      // Allocate a static 3x3 integer vector.
+      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      // ...
+      omp.free_shared_mem %ptr_shared : !llvm.ptr
+    ```
+  }] # clausesDescription;
+
+  let results = (outs OpenMP_PointerLikeType);
+  let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
+  let summary = "free shared memory";
+
+  let description = [{
+    Deallocates shared memory that was previously allocated by an
+    `omp.alloc_shared_mem` operation. After this operation, the deallocated
+    memory is in an undefined state and should not be accessed.
+    It is crucial to ensure that all accesses to the memory region are completed
+    before `omp.free_shared_mem` is called to avoid undefined behavior.
+
+    ```mlir
+      // Example of allocating and freeing shared memory.
+      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      // ...
+      omp.free_shared_mem %ptr_shared : !llvm.ptr
+    ```
+
+    The `heapref` operand represents the pointer to shared memory to be
+    deallocated, previously returned by `omp.alloc_shared_mem`.
+  }];
+
+  let arguments = (ins
+    Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
+  );
+  let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // workdistribute Construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cdd3c495ffd44..75dbeced94467 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4424,6 +4424,17 @@ LogicalResult AllocateDirOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult FreeSharedMemOp::verify() {
+  return getHeapref().getDefiningOp<AllocSharedMemOp>()
+             ? success()
+             : emitOpError() << "'heapref' operand must be defined by an "
+                                "'omp.alloc_shared_memory' op";
+}
+
 //===----------------------------------------------------------------------===//
 // WorkdistributeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a4394dd3fa6e6..e3caae2cac947 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7202,6 +7202,25 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
   return func;
 }
 
+static llvm::Value *
+getAllocationSize(llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
+                  OperandRange typeparams, OperandRange shape) {
+  llvm::DataLayout dataLayout =
+      moduleTranslation.getLLVMModule()->getDataLayout();
+  llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
+  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  for (auto typeParam : typeparams) {
+    allocSize = builder.CreateMul(
+        allocSize,
+        builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
+                              builder.getInt64Ty(),
+                              /*isSigned=*/false));
+  }
+  return allocSize;
+}
+
 static LogicalResult
 convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
@@ -7216,14 +7235,9 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   mlir::Value deviceNum = allocMemOp.getDevice();
   llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
   // Get the allocation size.
-  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
-  mlir::Type heapTy = allocMemOp.getAllocatedType();
-  llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
-  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
-  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
-  for (auto typeParam : allocMemOp.getTypeparams())
-    allocSize =
-        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+  llvm::Value *allocSize = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
   // Create call to "omp_target_alloc" with the args as translated llvm values.
   llvm::CallInst *call =
       builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -7234,6 +7248,19 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
+                        llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::Value *size = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  moduleTranslation.mapValue(allocMemOp.getResult(),
+                             ompBuilder->createOMPAllocShared(builder, size));
+  return success();
+}
+
 static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
                                         llvm::Module *llvmModule) {
   llvm::Type *ptrTy = builder.getPtrTy(0);
@@ -7269,6 +7296,21 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
+                       llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  auto allocMemOp =
+      freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
+  llvm::Value *size = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  ompBuilder->createOMPFreeShared(
+      builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
+  return success();
+}
+
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
 /// OpenMP runtime calls).
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
@@ -7464,6 +7506,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
           .Case([&](omp::TargetFreeMemOp) {
             return convertTargetFreeMemOp(*op, builder, moduleTranslation);
           })
+          .Case([&](omp::AllocSharedMemOp op) {
+            return convertAllocSharedMemOp(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::FreeSharedMemOp op) {
+            return convertFreeSharedMemOp(op, builder, moduleTranslation);
+          })
           .Default([&](Operation *inst) {
             return inst->emitError()
                    << "not yet implemented: " << inst->getName();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 52c5d861a0bd7..3cdc26765098e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3181,3 +3181,24 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
   %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
   return
 }
+
+// -----
+func.func @alloc_shared_mem_invalid_uniq_name() -> () {
+  // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
+  %0 = omp.alloc_shared_mem i64 {uniq_name=2}
+  return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_bindc_name() -> () {
+  // expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
+  %0 = omp.alloc_shared_mem i64 {bindc_name=2}
+  return
+}
+
+// -----
+func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () {
+  // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}}
+  omp.free_shared_mem %ptr : !llvm.ptr
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 1f24796ad71f0..c95128d17e5e9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3568,9 +3568,36 @@ func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
 }
 
 // CHECK-LABEL: func.func @omp_target_freemem(
-// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
-func.func @omp_target_freemem(%device : i32, %ptr : i64) {
+// CHECK-SAME: %[[DEVICE:.*]]: i32) {
+func.func @omp_target_freemem(%device : i32) {
+  // CHECK: %[[PTR:.*]] = omp.target_allocmem
+  %ptr = omp.target_allocmem %device : i32, i64
   // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
   omp.target_freemem %device, %ptr : i32, i64
   return
 }
+
+// CHECK-LABEL: func.func @omp_alloc_shared_mem(
+// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
+func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) {
+  // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr
+  %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} : !llvm.ptr
+  %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} : !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) : !llvm.ptr
+  %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : !llvm.ptr
+  %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] : !llvm.ptr
+  %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y : !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: func.func @omp_free_shared_mem() {
+func.func @omp_free_shared_mem() {
+  // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem
+  %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+  // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr
+  omp.free_shared_mem %0 : !llvm.ptr
+  return
+}

>From e1744f3f323882fb7f167de4a91e6653be7f6d9f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Thu, 5 Feb 2026 13:06:23 +0000
Subject: [PATCH 2/3] simplify omp.alloc_shared_mem

---
 flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 42 +++++++-----------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 37 ++++++++++++----
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 23 +++++++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 44 +++++++++++++------
 mlir/test/Dialect/OpenMP/invalid.mlir         | 19 +++++---
 mlir/test/Dialect/OpenMP/ops.mlir             | 35 +++++++--------
 .../LLVMIR/omptarget-device-shared-mem.mlir   | 42 ++++++++++++++++++
 7 files changed, 160 insertions(+), 82 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir

diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 13214a9e51161..3e1fe1d2b1613 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -222,47 +222,36 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
   return converter.convertType(firType);
 }
 
-// FIR Op specific conversion for allocation operations
-template <typename T>
-struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
-  using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
+    : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
+  using OpenMPFIROpConversion::OpenMPFIROpConversion;
 
   llvm::LogicalResult
-  matchAndRewrite(T allocmemOp,
-                  typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
+  matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Type heapTy = allocmemOp.getAllocatedType();
     mlir::Location loc = allocmemOp.getLoc();
-    auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
+    auto ity = lowerTy().indexType();
     mlir::Type dataTy = fir::unwrapRefType(heapTy);
-    mlir::Type llvmObjectTy =
-        convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
+    mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
     if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
-      TODO(loc, allocmemOp->getName().getStringRef() +
-                    " codegen of derived type with length parameters");
+      TODO(loc, "omp.target_allocmem codegen of derived type with length "
+                "parameters");
     mlir::Value size = fir::computeElementDistance(
-        loc, llvmObjectTy, ity, rewriter,
-        OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
+        loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
     if (auto scaleSize = fir::genAllocationScaleSize(
             loc, allocmemOp.getInType(), ity, rewriter))
       size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
-    for (mlir::Value opnd : adaptor.getTypeparams())
-      size = mlir::LLVM::MulOp::create(
-          rewriter, loc, ity, size,
-          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
-                      opnd));
-    for (mlir::Value opnd : adaptor.getShape())
+    for (mlir::Value opnd : adaptor.getOperands().drop_front())
       size = mlir::LLVM::MulOp::create(
           rewriter, loc, ity, size,
-          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
-                      opnd));
-    auto mallocTyWidth =
-        OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
+          integerCast(lowerTy(), loc, rewriter, ity, opnd));
+    auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
     auto mallocTy =
         mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
     if (mallocTyWidth != ity.getIntOrFloatBitWidth())
-      size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
-                         mallocTy, size);
+      size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
     rewriter.modifyOpInPlace(allocmemOp, [&]() {
       allocmemOp.setInType(rewriter.getI8Type());
       allocmemOp.getTypeparamsMutable().clear();
@@ -292,7 +281,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
   patterns.add<MapInfoOpConversion>(converter);
   patterns.add<PrivateClauseOpConversion>(converter);
+  patterns.add<TargetAllocMemOpConversion>(converter);
   patterns.add<DeclareMapperOpConversion>(converter);
-  patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
-               AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
 }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index fe8f01d64691a..ec95360fe09c4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2240,11 +2240,15 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
 //===----------------------------------------------------------------------===//
 
 def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
-    AttrSizedOperandSegments
-  ], clauses = [
-    OpenMP_HeapAllocClause
+    MemoryEffects<[MemAlloc<DefaultResource>]>
   ]> {
-  let summary = "allocate storage on shared memory for an object of a given type";
+  let summary = "allocate storage on shared memory for objects of a given type";
+
+  let arguments = (ins
+    TypeAttr:$elem_type,
+    AnySignlessInteger:$array_size,
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$alignment
+  );
 
   let description = [{
     Allocates memory shared across threads of a team for an object of the given
@@ -2253,15 +2257,30 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
     `omp.free_shared` to avoid memory leaks.
 
     ```mlir
-      // Allocate a static 3x3 integer vector.
-      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      // Allocate an i32 vector with %size elements and aligned to 8 bytes.
+      %ptr_shared = omp.alloc_shared_mem %size x i32 {alignment = 8} : (i64) -> (!llvm.ptr)
       // ...
       omp.free_shared_mem %ptr_shared : !llvm.ptr
     ```
-  }] # clausesDescription;
+
+    The `elem_type` is the type of the object for which memory is being
+    allocated.
+
+    The `array_size` is the number of objects to allocate memory for.
+
+    The optional `alignment` is used to specify the alignment for each element.
+    If not set, the `DataLayout` defaults will be used instead.
+  }];
 
   let results = (outs OpenMP_PointerLikeType);
-  let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
+  let assemblyFormat = [{
+    $array_size `x` $elem_type attr-dict `:` `(` type($array_size) `)` `->` type(results)
+  }];
+
+  let extraClassDeclaration = [{
+    mlir::Type getAllocatedType() { return getElemTypeAttr().getValue(); }
+  }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2280,7 +2299,7 @@ def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
 
     ```mlir
       // Example of allocating and freeing shared memory.
-      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> (!llvm.ptr)
       // ...
       omp.free_shared_mem %ptr_shared : !llvm.ptr
     ```
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 75dbeced94467..fb0d931c6c9e4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4411,17 +4411,26 @@ LogicalResult ScanOp::verify() {
 }
 
 /// Verifies align clause in allocate directive
+LogicalResult verifyAlignment(Operation &op,
+                              std::optional<uint64_t> alignment) {
+  if (alignment.has_value()) {
+    if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
+      return op.emitError()
+             << "ALIGN value : " << alignment.value() << " must be power of 2";
+  }
+  return success();
+}
 
 LogicalResult AllocateDirOp::verify() {
-  std::optional<uint64_t> align = this->getAlign();
+  return verifyAlignment(*getOperation(), getAlign());
+}
 
-  if (align.has_value()) {
-    if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
-      return emitError() << "ALIGN value : " << align.value()
-                         << " must be power of 2";
-  }
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
 
-  return success();
+LogicalResult AllocSharedMemOp::verify() {
+  return verifyAlignment(*getOperation(), getAlignment());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e3caae2cac947..94637b571aec0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7204,14 +7204,14 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
 
 static llvm::Value *
 getAllocationSize(llvm::IRBuilderBase &builder,
-                  LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
-                  OperandRange typeparams, OperandRange shape) {
+                  LLVM::ModuleTranslation &moduleTranslation,
+                  omp::TargetAllocMemOp op) {
   llvm::DataLayout dataLayout =
       moduleTranslation.getLLVMModule()->getDataLayout();
-  llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
-  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+  llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
+  llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
   llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
-  for (auto typeParam : typeparams) {
+  for (auto typeParam : op.getTypeparams()) {
     allocSize = builder.CreateMul(
         allocSize,
         builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
@@ -7221,6 +7221,27 @@ getAllocationSize(llvm::IRBuilderBase &builder,
   return allocSize;
 }
 
+static llvm::Value *
+getAllocationSize(llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation,
+                  omp::AllocSharedMemOp op) {
+  llvm::DataLayout dataLayout =
+      moduleTranslation.getLLVMModule()->getDataLayout();
+  llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
+
+  auto alignment = op.getAlignment();
+  llvm::TypeSize typeSize = llvm::alignTo(
+      dataLayout.getTypeStoreSize(llvmHeapTy),
+      alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
+
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  return builder.CreateMul(
+      allocSize,
+      builder.CreateIntCast(moduleTranslation.lookupValue(op.getArraySize()),
+                            builder.getInt64Ty(),
+                            /*isSigned=*/false));
+}
+
 static LogicalResult
 convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
@@ -7235,9 +7256,8 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   mlir::Value deviceNum = allocMemOp.getDevice();
   llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
   // Get the allocation size.
-  llvm::Value *allocSize = getAllocationSize(
-      builder, moduleTranslation, allocMemOp.getAllocatedType(),
-      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  llvm::Value *allocSize =
+      getAllocationSize(builder, moduleTranslation, allocMemOp);
   // Create call to "omp_target_alloc" with the args as translated llvm values.
   llvm::CallInst *call =
       builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -7253,9 +7273,7 @@ convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
                         llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  llvm::Value *size = getAllocationSize(
-      builder, moduleTranslation, allocMemOp.getAllocatedType(),
-      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
   moduleTranslation.mapValue(allocMemOp.getResult(),
                              ompBuilder->createOMPAllocShared(builder, size));
   return success();
@@ -7303,9 +7321,7 @@ convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto allocMemOp =
       freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
-  llvm::Value *size = getAllocationSize(
-      builder, moduleTranslation, allocMemOp.getAllocatedType(),
-      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
   ompBuilder->createOMPFreeShared(
       builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
   return success();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 3cdc26765098e..a14cbc9a2cb77 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3183,16 +3183,23 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
 }
 
 // -----
-func.func @alloc_shared_mem_invalid_uniq_name() -> () {
-  // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
-  %0 = omp.alloc_shared_mem i64 {uniq_name=2}
+func.func @alloc_shared_mem_invalid_alignment1(%n: i32) -> () {
+  // expected-error @below {{op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+  %0 = omp.alloc_shared_mem %n x i64 {alignment=-2} : (i32) -> !llvm.ptr
   return
 }
 
 // -----
-func.func @alloc_shared_mem_invalid_bindc_name() -> () {
-  // expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
-  %0 = omp.alloc_shared_mem i64 {bindc_name=2}
+func.func @alloc_shared_mem_invalid_alignment2(%n: i32) -> () {
+  // expected-error @below {{ALIGN value : 3 must be power of 2}}
+  %0 = omp.alloc_shared_mem %n x i64 {alignment=3} : (i32) -> !llvm.ptr
+  return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_array_size(%n: f32) -> () {
+  // expected-error @below {{invalid kind of type specified: expected builtin.integer, but found 'f32'}}
+  %0 = omp.alloc_shared_mem %n x i64 : (f32) -> !llvm.ptr
   return
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c95128d17e5e9..b1f51d2b77d8b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3578,25 +3578,22 @@ func.func @omp_target_freemem(%device : i32) {
 }
 
 // CHECK-LABEL: func.func @omp_alloc_shared_mem(
-// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
-func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) {
-  // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr
-  %0 = omp.alloc_shared_mem i64 : !llvm.ptr
-  // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} : !llvm.ptr
-  %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} : !llvm.ptr
-  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) : !llvm.ptr
-  %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : !llvm.ptr
-  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : !llvm.ptr
-  %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr
-  // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] : !llvm.ptr
-  %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y : !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: func.func @omp_free_shared_mem() {
-func.func @omp_free_shared_mem() {
-  // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem
-  %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+// CHECK-SAME: %[[N:.*]]: i32) {
+func.func @omp_alloc_shared_mem(%n: i32) {
+  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x i64 : (i32) -> !llvm.ptr
+  %0 = omp.alloc_shared_mem %n x i64 : (i32) -> !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x vector<16x16xf32> : (i32) -> !llvm.ptr
+  %1 = omp.alloc_shared_mem %n x vector<16x16xf32> : (i32) -> !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr {alignment = 16 : i64} : (i32) -> !llvm.ptr
+  %2 = omp.alloc_shared_mem %n x !llvm.ptr {alignment = 16} : (i32) -> !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: func.func @omp_free_shared_mem(
+// CHECK-SAME: %[[N:.*]]: i64) {
+func.func @omp_free_shared_mem(%n: i64) {
+  // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) -> !llvm.ptr
+  %0 = omp.alloc_shared_mem %n x f32 : (i64) -> !llvm.ptr
   // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr
   omp.free_shared_mem %0 : !llvm.ptr
   return
diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
new file mode 100644
index 0000000000000..72b0a2daadfc3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+  // CHECK-LABEL: define void @device_shared_mem(
+  // CHECK-SAME:  i32 %[[N0:.*]], i64 %[[N1:.*]])
+  llvm.func @device_shared_mem(%n0: i32, %n1: i64) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} {
+    // CHECK:      %[[CAST_N0:.*]] = zext i32 %[[N0]] to i64
+    // CHECK-NEXT: %[[ALLOC0_SZ:.*]] = mul i64 8, %[[CAST_N0]]
+    // CHECK-NEXT: %[[ALLOC0:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC0_SZ]])
+    %0 = omp.alloc_shared_mem %n0 x i64 : (i32) -> !llvm.ptr
+
+    // CHECK:      %[[ALLOC1_SZ:.*]] = mul i64 8, %[[N1]]
+    // CHECK-NEXT: %[[ALLOC1:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC1_SZ]])
+    %1 = omp.alloc_shared_mem %n1 x i64 : (i64) -> !llvm.ptr
+
+    // CHECK:      %[[ALLOC2_SZ:.*]] = mul i64 64, %[[N1]]
+    // CHECK-NEXT: %[[ALLOC2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC2_SZ]])
+    %2 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) -> !llvm.ptr
+
+    // CHECK:      %[[ALLOC3_SZ:.*]] = mul i64 128, %[[N1]]
+    // CHECK-NEXT: %[[ALLOC3:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC3_SZ]])
+    %3 = omp.alloc_shared_mem %n1 x vector<16xf32> {alignment = 128} : (i64) -> !llvm.ptr
+
+    // CHECK:      %[[CAST_N0_1:.*]] = zext i32 %[[N0]] to i64
+    // CHECK-NEXT: %[[FREE0_SZ:.*]] = mul i64 8, %[[CAST_N0_1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 %[[FREE0_SZ]])
+    omp.free_shared_mem %0 : !llvm.ptr
+
+    // CHECK:      %[[FREE1_SZ:.*]] = mul i64 8, %[[N1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 %[[FREE1_SZ]])
+    omp.free_shared_mem %1 : !llvm.ptr
+
+    // CHECK:      %[[FREE2_SZ:.*]] = mul i64 64, %[[N1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 %[[FREE2_SZ]])
+    omp.free_shared_mem %2 : !llvm.ptr
+
+    // CHECK:      %[[FREE3_SZ:.*]] = mul i64 128, %[[N1]]
+    // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC3]], i64 %[[FREE3_SZ]])
+    omp.free_shared_mem %3 : !llvm.ptr
+    llvm.return
+  }
+}

>From 881dd50236d8f67e2b37fc53458792a92186724e Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Tue, 17 Feb 2026 10:49:35 +0000
Subject: [PATCH 3/3] address review comments: make omp.free_shared_mem
 self-contained, update alignment handling for shared memory allocations

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 37 ++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 50 ++++++++----------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  7 +--
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 51 +++++++++----------
 mlir/test/Dialect/OpenMP/invalid.mlir         | 23 ++++++---
 mlir/test/Dialect/OpenMP/ops.mlir             | 12 +++--
 .../LLVMIR/omptarget-device-shared-mem.mlir   | 10 ++--
 7 files changed, 112 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1ffb01be4707a..9540cbcbe83d0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -908,6 +908,43 @@ class OpenMP_MapClauseSkip<
 
 def OpenMP_MapClause : OpenMP_MapClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure for memory allocation information.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_MemAllocationSizeClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+
+  let arguments = (ins
+    TypeAttr:$mem_elem_type,
+    AnySignlessInteger:$mem_array_size,
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$mem_alignment
+  );
+
+  let reqAssemblyFormat = [{
+    $mem_array_size `x` $mem_elem_type `:` `(` type($mem_array_size) `)`
+  }];
+
+  let optAssemblyFormat = [{
+    `align` `(` $mem_alignment `)`
+  }];
+
+  let description = [{
+    The `mem_elem_type` is the type of the object the memory allocation refers
+    to. It is used to calculate the size of the allocation.
+
+    The `mem_array_size` is the number of objects of type `mem_elem_type`.
+
+    The optional `mem_alignment` is used to specify the alignment for each
+    element. If not set, the `DataLayout` defaults will be used instead.
+  }];
+}
+
+def OpenMP_MemAllocationSizeClause : OpenMP_MemAllocationSizeClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [15.8.1] `memory-order` clause set
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ec95360fe09c4..8661700ec1f01 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2241,15 +2241,11 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
 
 def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
     MemoryEffects<[MemAlloc<DefaultResource>]>
+  ], clauses = [
+    OpenMP_MemAllocationSizeClause
   ]> {
   let summary = "allocate storage on shared memory for objects of a given type";
 
-  let arguments = (ins
-    TypeAttr:$elem_type,
-    AnySignlessInteger:$array_size,
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$alignment
-  );
-
   let description = [{
     Allocates memory shared across threads of a team for an object of the given
     type. Returns a pointer representing the allocated memory. The memory is
@@ -2258,27 +2254,18 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
 
     ```mlir
       // Allocate an i32 vector with %size elements and aligned to 8 bytes.
-      %ptr_shared = omp.alloc_shared_mem %size x i32 {alignment = 8} : (i64) -> (!llvm.ptr)
+      %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) align(8) -> !llvm.ptr
       // ...
-      omp.free_shared_mem %ptr_shared : !llvm.ptr
+      omp.free_shared_mem [%size x i32 : (i64) align(8)] %ptr_shared : !llvm.ptr
     ```
-
-    The `elem_type` is the type of the object for which memory is being
-    allocated.
-
-    The `array_size` is the number of objects to allocate memory for.
-
-    The optional `alignment` is used to specify the alignment for each element.
-    If not set, the `DataLayout` defaults will be used instead.
-  }];
+  }] # clausesDescription;
 
   let results = (outs OpenMP_PointerLikeType);
-  let assemblyFormat = [{
-    $array_size `x` $elem_type attr-dict `:` `(` type($array_size) `)` `->` type(results)
-  }];
+  let assemblyFormat = clausesReqAssemblyFormat # " oilist(" #
+    clausesOptAssemblyFormat # ") `->` type(results) attr-dict";
 
   let extraClassDeclaration = [{
-    mlir::Type getAllocatedType() { return getElemTypeAttr().getValue(); }
+    mlir::Type getAllocatedType() { return getMemElemTypeAttr().getValue(); }
   }];
   let hasVerifier = 1;
 }
@@ -2287,31 +2274,34 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
 // FreeSharedMemOp
 //===----------------------------------------------------------------------===//
 
-def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
+def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", traits = [
+    MemoryEffects<[MemFree]>
+  ], clauses = [
+    OpenMP_MemAllocationSizeClause
+  ]> {
   let summary = "free shared memory";
 
   let description = [{
     Deallocates shared memory that was previously allocated by an
     `omp.alloc_shared_mem` operation. After this operation, the deallocated
     memory is in an undefined state and should not be accessed.
-    It is crucial to ensure that all accesses to the memory region are completed
-    before `omp.alloc_shared_mem` is called to avoid undefined behavior.
 
     ```mlir
       // Example of allocating and freeing shared memory.
-      %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> (!llvm.ptr)
+      %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> !llvm.ptr
       // ...
-      omp.free_shared_mem %ptr_shared : !llvm.ptr
+      omp.free_shared_mem [%size x i32 : (i64)] %ptr_shared : !llvm.ptr
     ```
 
     The `heapref` operand represents the pointer to shared memory to be
     deallocated, previously returned by `omp.alloc_shared_mem`.
-  }];
+  }] # clausesDescription;
 
-  let arguments = (ins
+  let arguments = !con(clausesArgs, (ins
     Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
-  );
-  let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
+  ));
+  let assemblyFormat = "` ` `[`" # clausesReqAssemblyFormat # " oilist(" #
+    clausesOptAssemblyFormat # ") `]` $heapref `:` type($heapref) attr-dict";
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fb0d931c6c9e4..8cea85bee15b7 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4430,7 +4430,7 @@ LogicalResult AllocateDirOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocSharedMemOp::verify() {
-  return verifyAlignment(*getOperation(), getAlignment());
+  return verifyAlignment(*getOperation(), getMemAlignment());
 }
 
 //===----------------------------------------------------------------------===//
@@ -4438,10 +4438,7 @@ LogicalResult AllocSharedMemOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult FreeSharedMemOp::verify() {
-  return getHeapref().getDefiningOp<AllocSharedMemOp>()
-             ? success()
-             : emitOpError() << "'heapref' operand must be defined by an "
-                                "'omp.alloc_shared_memory' op";
+  return verifyAlignment(*getOperation(), getMemAlignment());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 94637b571aec0..dfe714085b8af 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7202,10 +7202,32 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
   return func;
 }
 
+template <typename T>
 static llvm::Value *
 getAllocationSize(llvm::IRBuilderBase &builder,
-                  LLVM::ModuleTranslation &moduleTranslation,
-                  omp::TargetAllocMemOp op) {
+                  LLVM::ModuleTranslation &moduleTranslation, T op) {
+  llvm::DataLayout dataLayout =
+      moduleTranslation.getLLVMModule()->getDataLayout();
+  llvm::Type *llvmHeapTy =
+      moduleTranslation.convertType(op.getMemElemTypeAttr().getValue());
+
+  auto alignment = op.getMemAlignment();
+  llvm::TypeSize typeSize = llvm::alignTo(
+      dataLayout.getTypeStoreSize(llvmHeapTy),
+      alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
+
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  return builder.CreateMul(
+      allocSize,
+      builder.CreateIntCast(moduleTranslation.lookupValue(op.getMemArraySize()),
+                            builder.getInt64Ty(),
+                            /*isSigned=*/false));
+}
+
+template <>
+llvm::Value *getAllocationSize(llvm::IRBuilderBase &builder,
+                               LLVM::ModuleTranslation &moduleTranslation,
+                               omp::TargetAllocMemOp op) {
   llvm::DataLayout dataLayout =
       moduleTranslation.getLLVMModule()->getDataLayout();
   llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
@@ -7221,27 +7243,6 @@ getAllocationSize(llvm::IRBuilderBase &builder,
   return allocSize;
 }
 
-static llvm::Value *
-getAllocationSize(llvm::IRBuilderBase &builder,
-                  LLVM::ModuleTranslation &moduleTranslation,
-                  omp::AllocSharedMemOp op) {
-  llvm::DataLayout dataLayout =
-      moduleTranslation.getLLVMModule()->getDataLayout();
-  llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
-
-  auto alignment = op.getAlignment();
-  llvm::TypeSize typeSize = llvm::alignTo(
-      dataLayout.getTypeStoreSize(llvmHeapTy),
-      alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
-
-  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
-  return builder.CreateMul(
-      allocSize,
-      builder.CreateIntCast(moduleTranslation.lookupValue(op.getArraySize()),
-                            builder.getInt64Ty(),
-                            /*isSigned=*/false));
-}
-
 static LogicalResult
 convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
@@ -7319,9 +7320,7 @@ convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
                        llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation) {
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  auto allocMemOp =
-      freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
-  llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
+  llvm::Value *size = getAllocationSize(builder, moduleTranslation, freeMemOp);
   ompBuilder->createOMPFreeShared(
       builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
   return success();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index a14cbc9a2cb77..0cab769049bac 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3184,28 +3184,35 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
 
 // -----
 func.func @alloc_shared_mem_invalid_alignment1(%n: i32) -> () {
-  // expected-error @below {{op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
-  %0 = omp.alloc_shared_mem %n x i64 {alignment=-2} : (i32) -> !llvm.ptr
+  // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+  %0 = omp.alloc_shared_mem %n x i64 : (i32) align(-2) -> !llvm.ptr
   return
 }
 
 // -----
 func.func @alloc_shared_mem_invalid_alignment2(%n: i32) -> () {
   // expected-error @below {{ALIGN value : 3 must be power of 2}}
-  %0 = omp.alloc_shared_mem %n x i64 {alignment=3} : (i32) -> !llvm.ptr
+  %0 = omp.alloc_shared_mem %n x i64 : (i32) align(3) -> !llvm.ptr
   return
 }
 
 // -----
-func.func @alloc_shared_mem_invalid_array_size(%n: f32) -> () {
+func.func @free_shared_mem_invalid_array_size(%n: f32, %ptr : !llvm.ptr) -> () {
   // expected-error @below {{invalid kind of type specified: expected builtin.integer, but found 'f32'}}
-  %0 = omp.alloc_shared_mem %n x i64 : (f32) -> !llvm.ptr
+  %0 = omp.free_shared_mem [%n x i64 : (f32)] %ptr : !llvm.ptr
   return
 }
 
 // -----
-func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () {
-  // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}}
-  omp.free_shared_mem %ptr : !llvm.ptr
+func.func @free_shared_mem_invalid_alignment1(%n: i32, %ptr : !llvm.ptr) -> () {
+  // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+  omp.free_shared_mem [%n x i64 : (i32) align(-2)] %ptr : !llvm.ptr
+  return
+}
+
+// -----
+func.func @free_shared_mem_invalid_alignment2(%n: i32, %ptr : !llvm.ptr) -> () {
+  // expected-error @below {{ALIGN value : 3 must be power of 2}}
+  omp.free_shared_mem [%n x i64 : (i32) align(3)] %ptr : !llvm.ptr
   return
 }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b1f51d2b77d8b..1f9544301b184 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3584,8 +3584,8 @@ func.func @omp_alloc_shared_mem(%n: i32) {
   %0 = omp.alloc_shared_mem %n x i64 : (i32) -> !llvm.ptr
   // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x vector<16x16xf32> : (i32) -> !llvm.ptr
   %1 = omp.alloc_shared_mem %n x vector<16x16xf32> : (i32) -> !llvm.ptr
-  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr {alignment = 16 : i64} : (i32) -> !llvm.ptr
-  %2 = omp.alloc_shared_mem %n x !llvm.ptr {alignment = 16} : (i32) -> !llvm.ptr
+  // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr : (i32) align(16) -> !llvm.ptr
+  %2 = omp.alloc_shared_mem %n x !llvm.ptr : (i32) align(16) -> !llvm.ptr
   return
 }
 
@@ -3594,7 +3594,11 @@ func.func @omp_alloc_shared_mem(%n: i32) {
 func.func @omp_free_shared_mem(%n: i64) {
   // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) -> !llvm.ptr
   %0 = omp.alloc_shared_mem %n x f32 : (i64) -> !llvm.ptr
-  // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr
-  omp.free_shared_mem %0 : !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64)] %[[PTR]] : !llvm.ptr
+  omp.free_shared_mem [%n x f32 : (i64)] %0 : !llvm.ptr
+  // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) align(32) -> !llvm.ptr
+  %1 = omp.alloc_shared_mem %n x f32 : (i64) align(32) -> !llvm.ptr
+  // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64) align(32)] %[[PTR]] : !llvm.ptr
+  omp.free_shared_mem [%n x f32 : (i64) align(32)] %1 : !llvm.ptr
   return
 }
diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
index 72b0a2daadfc3..cdebebc3ed233 100644
--- a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
@@ -19,24 +19,24 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 
     // CHECK:      %[[ALLOC3_SZ:.*]] = mul i64 128, %[[N1]]
     // CHECK-NEXT: %[[ALLOC3:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC3_SZ]])
-    %3 = omp.alloc_shared_mem %n1 x vector<16xf32> {alignment = 128} : (i64) -> !llvm.ptr
+    %3 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) align(128) -> !llvm.ptr
 
     // CHECK:      %[[CAST_N0_1:.*]] = zext i32 %[[N0]] to i64
     // CHECK-NEXT: %[[FREE0_SZ:.*]] = mul i64 8, %[[CAST_N0_1]]
     // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 %[[FREE0_SZ]])
-    omp.free_shared_mem %0 : !llvm.ptr
+    omp.free_shared_mem [%n0 x i64 : (i32)] %0 : !llvm.ptr
 
     // CHECK:      %[[FREE1_SZ:.*]] = mul i64 8, %[[N1]]
     // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 %[[FREE1_SZ]])
-    omp.free_shared_mem %1 : !llvm.ptr
+    omp.free_shared_mem [%n1 x i64 : (i64)] %1 : !llvm.ptr
 
     // CHECK:      %[[FREE2_SZ:.*]] = mul i64 64, %[[N1]]
     // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 %[[FREE2_SZ]])
-    omp.free_shared_mem %2 : !llvm.ptr
+    omp.free_shared_mem [%n1 x vector<16xf32> : (i64)] %2 : !llvm.ptr
 
     // CHECK:      %[[FREE3_SZ:.*]] = mul i64 128, %[[N1]]
     // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC3]], i64 %[[FREE3_SZ]])
-    omp.free_shared_mem %3 : !llvm.ptr
+    omp.free_shared_mem [%n1 x vector<16xf32> : (i64) align(128)] %3 : !llvm.ptr
     llvm.return
   }
 }



More information about the llvm-branch-commits mailing list