[llvm-branch-commits] [flang] [llvm] [mlir] [Flang][MLIR][OpenMP] Add explicit shared memory (de-)allocation ops (PR #161862)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 23 05:49:52 PST 2026
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/161862
>From 5d3807333c6d2b57d042dcb0b5ae27deeb024908 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 12 Sep 2025 15:56:04 +0100
Subject: [PATCH 1/3] [Flang][MLIR][OpenMP] Add explicit shared memory
(de-)allocation ops
This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem`
operations to represent explicit allocations and deallocations of shared memory
across threads in a team, mirroring the existing `omp.target_allocmem` and
`omp.target_freemem`.
The `omp.alloc_shared_mem` op goes through the same Flang-specific
transformations as `omp.target_allocmem`, so that the size of the buffer can be
properly calculated when translating to LLVM IR.
The corresponding runtime functions produced for these new operations are
`__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be
created for implicit allocations (e.g. privatized and reduction variables).
---
flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 42 +++++++-----
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 23 +++++++
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 29 ++++++---
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 61 ++++++++++++++++++
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 11 ++++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 64 ++++++++++++++++---
mlir/test/Dialect/OpenMP/invalid.mlir | 21 ++++++
mlir/test/Dialect/OpenMP/ops.mlir | 31 ++++++++-
8 files changed, 249 insertions(+), 33 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 3e1fe1d2b1613..13214a9e51161 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -222,36 +222,47 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
return converter.convertType(firType);
}
-// FIR Op specific conversion for TargetAllocMemOp
-struct TargetAllocMemOpConversion
- : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
- using OpenMPFIROpConversion::OpenMPFIROpConversion;
+// FIR Op specific conversion for allocation operations
+template <typename T>
+struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
+ using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;
llvm::LogicalResult
- matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+ matchAndRewrite(T allocmemOp,
+ typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Type heapTy = allocmemOp.getAllocatedType();
mlir::Location loc = allocmemOp.getLoc();
- auto ity = lowerTy().indexType();
+ auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
mlir::Type dataTy = fir::unwrapRefType(heapTy);
- mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+ mlir::Type llvmObjectTy =
+ convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
- TODO(loc, "omp.target_allocmem codegen of derived type with length "
- "parameters");
+ TODO(loc, allocmemOp->getName().getStringRef() +
+ " codegen of derived type with length parameters");
mlir::Value size = fir::computeElementDistance(
- loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
+ loc, llvmObjectTy, ity, rewriter,
+ OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
if (auto scaleSize = fir::genAllocationScaleSize(
loc, allocmemOp.getInType(), ity, rewriter))
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
- for (mlir::Value opnd : adaptor.getOperands().drop_front())
+ for (mlir::Value opnd : adaptor.getTypeparams())
+ size = mlir::LLVM::MulOp::create(
+ rewriter, loc, ity, size,
+ integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+ opnd));
+ for (mlir::Value opnd : adaptor.getShape())
size = mlir::LLVM::MulOp::create(
rewriter, loc, ity, size,
- integerCast(lowerTy(), loc, rewriter, ity, opnd));
- auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+ integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+ opnd));
+ auto mallocTyWidth =
+ OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
auto mallocTy =
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
- size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
+ size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
+ mallocTy, size);
rewriter.modifyOpInPlace(allocmemOp, [&]() {
allocmemOp.setInType(rewriter.getI8Type());
allocmemOp.getTypeparamsMutable().clear();
@@ -281,6 +292,7 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
patterns.add<MapInfoOpConversion>(converter);
patterns.add<PrivateClauseOpConversion>(converter);
- patterns.add<TargetAllocMemOpConversion>(converter);
patterns.add<DeclareMapperOpConversion>(converter);
+ patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
+ AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ca77f4c302c2b..865be4d1f1c93 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3215,6 +3215,17 @@ class OpenMPIRBuilder {
LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr,
Value *Allocator, std::string Name = "");
+ /// Create a runtime call for kmpc_alloc_shared.
+ ///
+ /// \param Loc The insert and source location description.
+ /// \param Size Size of allocated memory space.
+ /// \param Name Name of call Instruction.
+ ///
+ /// \returns CallInst to the kmpc_alloc_shared call.
+ LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc,
+ Value *Size,
+ const Twine &Name = Twine(""));
+
/// Create a runtime call for kmpc_alloc_shared.
///
/// \param Loc The insert and source location description.
@@ -3226,6 +3237,18 @@ class OpenMPIRBuilder {
Type *VarType,
const Twine &Name = Twine(""));
+ /// Create a runtime call for kmpc_free_shared.
+ ///
+ /// \param Loc The insert and source location description.
+ /// \param Addr Value obtained from the corresponding kmpc_alloc_shared call.
+ /// \param Size Size of allocated memory space.
+ /// \param Name Name of call Instruction.
+ ///
+ /// \returns CallInst to the kmpc_free_shared call.
+ LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc,
+ Value *Addr, Value *Size,
+ const Twine &Name = Twine(""));
+
/// Create a runtime call for kmpc_free_shared.
///
/// \param Loc The insert and source location description.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 022097bea40e6..fb768c2fe443c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7796,32 +7796,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
}
CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
- Type *VarType,
+ Value *Size,
const Twine &Name) {
IRBuilder<>::InsertPointGuard IPG(Builder);
updateToLocation(Loc);
- const DataLayout &DL = M.getDataLayout();
- Value *Args[] = {Builder.getInt64(DL.getTypeAllocSize(VarType))};
+ Value *Args[] = {Size};
Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared);
CallInst *Call = Builder.CreateCall(Fn, Args, Name);
- Call->addRetAttr(
- Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64)));
+ Call->addRetAttr(Attribute::getWithAlignment(
+ M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64)));
return Call;
}
+CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
+ Type *VarType,
+ const Twine &Name) {
+ return createOMPAllocShared(
+ Loc, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)), Name);
+}
+
CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
- Value *Addr, Type *VarType,
+ Value *Addr, Value *Size,
const Twine &Name) {
IRBuilder<>::InsertPointGuard IPG(Builder);
updateToLocation(Loc);
- Value *Args[] = {
- Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType))};
+ Value *Args[] = {Addr, Size};
Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared);
return Builder.CreateCall(Fn, Args, Name);
}
+CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
+ Value *Addr, Type *VarType,
+ const Twine &Name) {
+ return createOMPFreeShared(
+ Loc, Addr, Builder.getInt64(M.getDataLayout().getTypeAllocSize(VarType)),
+ Name);
+}
+
CallInst *OpenMPIRBuilder::createOMPInteropInit(
const LocationDescription &Loc, Value *InteropVar,
omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 28122f4d2ae89..fe8f01d64691a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2235,6 +2235,67 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
}
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
+ AttrSizedOperandSegments
+ ], clauses = [
+ OpenMP_HeapAllocClause
+ ]> {
+ let summary = "allocate storage on shared memory for an object of a given type";
+
+ let description = [{
+ Allocates memory shared across threads of a team for an object of the given
+ type. Returns a pointer representing the allocated memory. The memory is
+ uninitialized after allocation. Operations must be paired with
+ `omp.free_shared_mem` to avoid memory leaks.
+
+ ```mlir
+ // Allocate a static 3x3 integer vector.
+ %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+ // ...
+ omp.free_shared_mem %ptr_shared : !llvm.ptr
+ ```
+ }] # clausesDescription;
+
+ let results = (outs OpenMP_PointerLikeType);
+ let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
+ let summary = "free shared memory";
+
+ let description = [{
+ Deallocates shared memory that was previously allocated by an
+ `omp.alloc_shared_mem` operation. After this operation, the deallocated
+ memory is in an undefined state and should not be accessed.
+ It is crucial to ensure that all accesses to the memory region are completed
+ before `omp.free_shared_mem` is called to avoid undefined behavior.
+
+ ```mlir
+ // Example of allocating and freeing shared memory.
+ %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+ // ...
+ omp.free_shared_mem %ptr_shared : !llvm.ptr
+ ```
+
+ The `heapref` operand represents the pointer to shared memory to be
+ deallocated, previously returned by `omp.alloc_shared_mem`.
+ }];
+
+ let arguments = (ins
+ Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
+ );
+ let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// workdistribute Construct
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cdd3c495ffd44..75dbeced94467 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4424,6 +4424,17 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult FreeSharedMemOp::verify() {
+ return getHeapref().getDefiningOp<AllocSharedMemOp>()
+ ? success()
+ : emitOpError() << "'heapref' operand must be defined by an "
+ "'omp.alloc_shared_memory' op";
+}
+
//===----------------------------------------------------------------------===//
// WorkdistributeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a4394dd3fa6e6..e3caae2cac947 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7202,6 +7202,25 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
return func;
}
+static llvm::Value *
+getAllocationSize(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
+ OperandRange typeparams, OperandRange shape) {
+ llvm::DataLayout dataLayout =
+ moduleTranslation.getLLVMModule()->getDataLayout();
+ llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
+ llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+ for (auto typeParam : typeparams) {
+ allocSize = builder.CreateMul(
+ allocSize,
+ builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
+ builder.getInt64Ty(),
+ /*isSigned=*/false));
+ }
+ return allocSize;
+}
+
static LogicalResult
convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -7216,14 +7235,9 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
mlir::Value deviceNum = allocMemOp.getDevice();
llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
// Get the allocation size.
- llvm::DataLayout dataLayout = llvmModule->getDataLayout();
- mlir::Type heapTy = allocMemOp.getAllocatedType();
- llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
- llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
- llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
- for (auto typeParam : allocMemOp.getTypeparams())
- allocSize =
- builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+ llvm::Value *allocSize = getAllocationSize(
+ builder, moduleTranslation, allocMemOp.getAllocatedType(),
+ allocMemOp.getTypeparams(), allocMemOp.getShape());
// Create call to "omp_target_alloc" with the args as translated llvm values.
llvm::CallInst *call =
builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -7234,6 +7248,19 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
+static LogicalResult
+convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
+ llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ llvm::Value *size = getAllocationSize(
+ builder, moduleTranslation, allocMemOp.getAllocatedType(),
+ allocMemOp.getTypeparams(), allocMemOp.getShape());
+ moduleTranslation.mapValue(allocMemOp.getResult(),
+ ompBuilder->createOMPAllocShared(builder, size));
+ return success();
+}
+
static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
llvm::Module *llvmModule) {
llvm::Type *ptrTy = builder.getPtrTy(0);
@@ -7269,6 +7296,21 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
+static LogicalResult
+convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
+ llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ auto allocMemOp =
+ freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
+ llvm::Value *size = getAllocationSize(
+ builder, moduleTranslation, allocMemOp.getAllocatedType(),
+ allocMemOp.getTypeparams(), allocMemOp.getShape());
+ ompBuilder->createOMPFreeShared(
+ builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
+ return success();
+}
+
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
@@ -7464,6 +7506,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
.Case([&](omp::TargetFreeMemOp) {
return convertTargetFreeMemOp(*op, builder, moduleTranslation);
})
+ .Case([&](omp::AllocSharedMemOp op) {
+ return convertAllocSharedMemOp(op, builder, moduleTranslation);
+ })
+ .Case([&](omp::FreeSharedMemOp op) {
+ return convertFreeSharedMemOp(op, builder, moduleTranslation);
+ })
.Default([&](Operation *inst) {
return inst->emitError()
<< "not yet implemented: " << inst->getName();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 52c5d861a0bd7..3cdc26765098e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3181,3 +3181,24 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
%0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
return
}
+
+// -----
+func.func @alloc_shared_mem_invalid_uniq_name() -> () {
+ // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
+ %0 = omp.alloc_shared_mem i64 {uniq_name=2}
+ return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_bindc_name() -> () {
+ // expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
+ %0 = omp.alloc_shared_mem i64 {bindc_name=2}
+ return
+}
+
+// -----
+func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () {
+ // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}}
+ omp.free_shared_mem %ptr : !llvm.ptr
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 1f24796ad71f0..c95128d17e5e9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3568,9 +3568,36 @@ func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
}
// CHECK-LABEL: func.func @omp_target_freemem(
-// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
-func.func @omp_target_freemem(%device : i32, %ptr : i64) {
+// CHECK-SAME: %[[DEVICE:.*]]: i32) {
+func.func @omp_target_freemem(%device : i32) {
+ // CHECK: %[[PTR:.*]] = omp.target_allocmem
+ %ptr = omp.target_allocmem %device : i32, i64
// CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
omp.target_freemem %device, %ptr : i32, i64
return
}
+
+// CHECK-LABEL: func.func @omp_alloc_shared_mem(
+// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
+func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) {
+ // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr
+ %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+ // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} : !llvm.ptr
+ %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} : !llvm.ptr
+ // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) : !llvm.ptr
+ %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : !llvm.ptr
+ // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : !llvm.ptr
+ %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr
+ // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] : !llvm.ptr
+ %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y : !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: func.func @omp_free_shared_mem() {
+func.func @omp_free_shared_mem() {
+ // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem
+ %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+ // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr
+ omp.free_shared_mem %0 : !llvm.ptr
+ return
+}
>From e1744f3f323882fb7f167de4a91e6653be7f6d9f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Thu, 5 Feb 2026 13:06:23 +0000
Subject: [PATCH 2/3] simplify omp.alloc_shared_mem
---
flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 42 +++++++-----------
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 37 ++++++++++++----
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 23 +++++++---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 44 +++++++++++++------
mlir/test/Dialect/OpenMP/invalid.mlir | 19 +++++---
mlir/test/Dialect/OpenMP/ops.mlir | 35 +++++++--------
.../LLVMIR/omptarget-device-shared-mem.mlir | 42 ++++++++++++++++++
7 files changed, 160 insertions(+), 82 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 13214a9e51161..3e1fe1d2b1613 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -222,47 +222,36 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
return converter.convertType(firType);
}
-// FIR Op specific conversion for allocation operations
-template <typename T>
-struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
- using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;
+// FIR Op specific conversion for TargetAllocMemOp
+struct TargetAllocMemOpConversion
+ : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
+ using OpenMPFIROpConversion::OpenMPFIROpConversion;
llvm::LogicalResult
- matchAndRewrite(T allocmemOp,
- typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
+ matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Type heapTy = allocmemOp.getAllocatedType();
mlir::Location loc = allocmemOp.getLoc();
- auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
+ auto ity = lowerTy().indexType();
mlir::Type dataTy = fir::unwrapRefType(heapTy);
- mlir::Type llvmObjectTy =
- convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
+ mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
- TODO(loc, allocmemOp->getName().getStringRef() +
- " codegen of derived type with length parameters");
+ TODO(loc, "omp.target_allocmem codegen of derived type with length "
+ "parameters");
mlir::Value size = fir::computeElementDistance(
- loc, llvmObjectTy, ity, rewriter,
- OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
+ loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
if (auto scaleSize = fir::genAllocationScaleSize(
loc, allocmemOp.getInType(), ity, rewriter))
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
- for (mlir::Value opnd : adaptor.getTypeparams())
- size = mlir::LLVM::MulOp::create(
- rewriter, loc, ity, size,
- integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
- opnd));
- for (mlir::Value opnd : adaptor.getShape())
+ for (mlir::Value opnd : adaptor.getOperands().drop_front())
size = mlir::LLVM::MulOp::create(
rewriter, loc, ity, size,
- integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
- opnd));
- auto mallocTyWidth =
- OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
+ integerCast(lowerTy(), loc, rewriter, ity, opnd));
+ auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
auto mallocTy =
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
- size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
- mallocTy, size);
+ size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
rewriter.modifyOpInPlace(allocmemOp, [&]() {
allocmemOp.setInType(rewriter.getI8Type());
allocmemOp.getTypeparamsMutable().clear();
@@ -292,7 +281,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
patterns.add<MapInfoOpConversion>(converter);
patterns.add<PrivateClauseOpConversion>(converter);
+ patterns.add<TargetAllocMemOpConversion>(converter);
patterns.add<DeclareMapperOpConversion>(converter);
- patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
- AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index fe8f01d64691a..ec95360fe09c4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2240,11 +2240,15 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
//===----------------------------------------------------------------------===//
def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
- AttrSizedOperandSegments
- ], clauses = [
- OpenMP_HeapAllocClause
+ MemoryEffects<[MemAlloc<DefaultResource>]>
]> {
- let summary = "allocate storage on shared memory for an object of a given type";
+ let summary = "allocate storage on shared memory for objects of a given type";
+
+ let arguments = (ins
+ TypeAttr:$elem_type,
+ AnySignlessInteger:$array_size,
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$alignment
+ );
let description = [{
Allocates memory shared across threads of a team for an object of the given
@@ -2253,15 +2257,30 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
`omp.free_shared` to avoid memory leaks.
```mlir
- // Allocate a static 3x3 integer vector.
- %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+ // Allocate an i32 vector with %size elements and aligned to 8 bytes.
+ %ptr_shared = omp.alloc_shared_mem %size x i32 {alignment = 8} : (i64) -> !llvm.ptr
// ...
omp.free_shared_mem %ptr_shared : !llvm.ptr
```
- }] # clausesDescription;
+
+ The `elem_type` is the type of the object for which memory is being
+ allocated.
+
+ The `array_size` is the number of objects to allocate memory for.
+
+ The optional `alignment` is used to specify the alignment for each element.
+ If not set, the `DataLayout` defaults will be used instead.
+ }];
let results = (outs OpenMP_PointerLikeType);
- let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
+ let assemblyFormat = [{
+ $array_size `x` $elem_type attr-dict `:` `(` type($array_size) `)` `->` type(results)
+ }];
+
+ let extraClassDeclaration = [{
+ mlir::Type getAllocatedType() { return getElemTypeAttr().getValue(); }
+ }];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -2280,7 +2299,7 @@ def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
```mlir
// Example of allocating and freeing shared memory.
- %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+ %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> !llvm.ptr
// ...
omp.free_shared_mem %ptr_shared : !llvm.ptr
```
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 75dbeced94467..fb0d931c6c9e4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4411,17 +4411,26 @@ LogicalResult ScanOp::verify() {
}
/// Verifies align clause in allocate directive
+LogicalResult verifyAlignment(Operation &op,
+ std::optional<uint64_t> alignment) {
+ if (alignment.has_value()) {
+ if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
+ return op.emitError()
+ << "ALIGN value : " << alignment.value() << " must be power of 2";
+ }
+ return success();
+}
LogicalResult AllocateDirOp::verify() {
- std::optional<uint64_t> align = this->getAlign();
+ return verifyAlignment(*getOperation(), getAlign());
+}
- if (align.has_value()) {
- if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
- return emitError() << "ALIGN value : " << align.value()
- << " must be power of 2";
- }
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
- return success();
+LogicalResult AllocSharedMemOp::verify() {
+ return verifyAlignment(*getOperation(), getAlignment());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e3caae2cac947..94637b571aec0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7204,14 +7204,14 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
static llvm::Value *
getAllocationSize(llvm::IRBuilderBase &builder,
- LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
- OperandRange typeparams, OperandRange shape) {
+ LLVM::ModuleTranslation &moduleTranslation,
+ omp::TargetAllocMemOp op) {
llvm::DataLayout dataLayout =
moduleTranslation.getLLVMModule()->getDataLayout();
- llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
- llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+ llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
+ llvm::TypeSize typeSize = dataLayout.getTypeAllocSize(llvmHeapTy);
llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
- for (auto typeParam : typeparams) {
+ for (auto typeParam : op.getTypeparams()) {
allocSize = builder.CreateMul(
allocSize,
builder.CreateIntCast(moduleTranslation.lookupValue(typeParam),
@@ -7221,6 +7221,27 @@ getAllocationSize(llvm::IRBuilderBase &builder,
return allocSize;
}
+static llvm::Value *
+getAllocationSize(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ omp::AllocSharedMemOp op) {
+ llvm::DataLayout dataLayout =
+ moduleTranslation.getLLVMModule()->getDataLayout();
+ llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
+
+ auto alignment = op.getAlignment();
+ llvm::TypeSize typeSize = llvm::alignTo(
+ dataLayout.getTypeStoreSize(llvmHeapTy),
+ alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
+
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+ return builder.CreateMul(
+ allocSize,
+ builder.CreateIntCast(moduleTranslation.lookupValue(op.getArraySize()),
+ builder.getInt64Ty(),
+ /*isSigned=*/false));
+}
+
static LogicalResult
convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -7235,9 +7256,8 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
mlir::Value deviceNum = allocMemOp.getDevice();
llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
// Get the allocation size.
- llvm::Value *allocSize = getAllocationSize(
- builder, moduleTranslation, allocMemOp.getAllocatedType(),
- allocMemOp.getTypeparams(), allocMemOp.getShape());
+ llvm::Value *allocSize =
+ getAllocationSize(builder, moduleTranslation, allocMemOp);
// Create call to "omp_target_alloc" with the args as translated llvm values.
llvm::CallInst *call =
builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -7253,9 +7273,7 @@ convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- llvm::Value *size = getAllocationSize(
- builder, moduleTranslation, allocMemOp.getAllocatedType(),
- allocMemOp.getTypeparams(), allocMemOp.getShape());
+ llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
moduleTranslation.mapValue(allocMemOp.getResult(),
ompBuilder->createOMPAllocShared(builder, size));
return success();
@@ -7303,9 +7321,7 @@ convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto allocMemOp =
freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
- llvm::Value *size = getAllocationSize(
- builder, moduleTranslation, allocMemOp.getAllocatedType(),
- allocMemOp.getTypeparams(), allocMemOp.getShape());
+ llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
ompBuilder->createOMPFreeShared(
builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
return success();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 3cdc26765098e..a14cbc9a2cb77 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3183,16 +3183,23 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
}
// -----
-func.func @alloc_shared_mem_invalid_uniq_name() -> () {
- // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
- %0 = omp.alloc_shared_mem i64 {uniq_name=2}
+func.func @alloc_shared_mem_invalid_alignment1(%n: i32) -> () {
+ // expected-error @below {{op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+ %0 = omp.alloc_shared_mem %n x i64 {alignment=-2} : (i32) -> !llvm.ptr
return
}
// -----
-func.func @alloc_shared_mem_invalid_bindc_name() -> () {
- // expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
- %0 = omp.alloc_shared_mem i64 {bindc_name=2}
+func.func @alloc_shared_mem_invalid_alignment2(%n: i32) -> () {
+ // expected-error @below {{ALIGN value : 3 must be power of 2}}
+ %0 = omp.alloc_shared_mem %n x i64 {alignment=3} : (i32) -> !llvm.ptr
+ return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_array_size(%n: f32) -> () {
+ // expected-error @below {{invalid kind of type specified: expected builtin.integer, but found 'f32'}}
+ %0 = omp.alloc_shared_mem %n x i64 : (f32) -> !llvm.ptr
return
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c95128d17e5e9..b1f51d2b77d8b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3578,25 +3578,22 @@ func.func @omp_target_freemem(%device : i32) {
}
// CHECK-LABEL: func.func @omp_alloc_shared_mem(
-// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
-func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) {
- // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr
- %0 = omp.alloc_shared_mem i64 : !llvm.ptr
- // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} : !llvm.ptr
- %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} : !llvm.ptr
- // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) : !llvm.ptr
- %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : !llvm.ptr
- // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : !llvm.ptr
- %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr
- // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] : !llvm.ptr
- %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y : !llvm.ptr
- return
-}
-
-// CHECK-LABEL: func.func @omp_free_shared_mem() {
-func.func @omp_free_shared_mem() {
- // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem
- %0 = omp.alloc_shared_mem i64 : !llvm.ptr
+// CHECK-SAME: %[[N:.*]]: i32) {
+func.func @omp_alloc_shared_mem(%n: i32) {
+ // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x i64 : (i32) -> !llvm.ptr
+ %0 = omp.alloc_shared_mem %n x i64 : (i32) -> !llvm.ptr
+ // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x vector<16x16xf32> : (i32) -> !llvm.ptr
+ %1 = omp.alloc_shared_mem %n x vector<16x16xf32> : (i32) -> !llvm.ptr
+ // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr {alignment = 16 : i64} : (i32) -> !llvm.ptr
+ %2 = omp.alloc_shared_mem %n x !llvm.ptr {alignment = 16} : (i32) -> !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: func.func @omp_free_shared_mem(
+// CHECK-SAME: %[[N:.*]]: i64) {
+func.func @omp_free_shared_mem(%n: i64) {
+ // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) -> !llvm.ptr
+ %0 = omp.alloc_shared_mem %n x f32 : (i64) -> !llvm.ptr
// CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr
omp.free_shared_mem %0 : !llvm.ptr
return
diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
new file mode 100644
index 0000000000000..72b0a2daadfc3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+ // CHECK-LABEL: define void @device_shared_mem(
+ // CHECK-SAME: i32 %[[N0:.*]], i64 %[[N1:.*]])
+ llvm.func @device_shared_mem(%n0: i32, %n1: i64) attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>} {
+ // CHECK: %[[CAST_N0:.*]] = zext i32 %[[N0]] to i64
+ // CHECK-NEXT: %[[ALLOC0_SZ:.*]] = mul i64 8, %[[CAST_N0]]
+ // CHECK-NEXT: %[[ALLOC0:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC0_SZ]])
+ %0 = omp.alloc_shared_mem %n0 x i64 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[ALLOC1_SZ:.*]] = mul i64 8, %[[N1]]
+ // CHECK-NEXT: %[[ALLOC1:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC1_SZ]])
+ %1 = omp.alloc_shared_mem %n1 x i64 : (i64) -> !llvm.ptr
+
+ // CHECK: %[[ALLOC2_SZ:.*]] = mul i64 64, %[[N1]]
+ // CHECK-NEXT: %[[ALLOC2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC2_SZ]])
+ %2 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) -> !llvm.ptr
+
+ // CHECK: %[[ALLOC3_SZ:.*]] = mul i64 128, %[[N1]]
+ // CHECK-NEXT: %[[ALLOC3:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC3_SZ]])
+ %3 = omp.alloc_shared_mem %n1 x vector<16xf32> {alignment = 128} : (i64) -> !llvm.ptr
+
+ // CHECK: %[[CAST_N0_1:.*]] = zext i32 %[[N0]] to i64
+ // CHECK-NEXT: %[[FREE0_SZ:.*]] = mul i64 8, %[[CAST_N0_1]]
+ // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 %[[FREE0_SZ]])
+ omp.free_shared_mem %0 : !llvm.ptr
+
+ // CHECK: %[[FREE1_SZ:.*]] = mul i64 8, %[[N1]]
+ // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 %[[FREE1_SZ]])
+ omp.free_shared_mem %1 : !llvm.ptr
+
+ // CHECK: %[[FREE2_SZ:.*]] = mul i64 64, %[[N1]]
+ // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 %[[FREE2_SZ]])
+ omp.free_shared_mem %2 : !llvm.ptr
+
+ // CHECK: %[[FREE3_SZ:.*]] = mul i64 128, %[[N1]]
+ // CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC3]], i64 %[[FREE3_SZ]])
+ omp.free_shared_mem %3 : !llvm.ptr
+ llvm.return
+ }
+}
>From 881dd50236d8f67e2b37fc53458792a92186724e Mon Sep 17 00:00:00 2001
From: Sergio Afonso <Sergio.AfonsoFumero at amd.com>
Date: Tue, 17 Feb 2026 10:49:35 +0000
Subject: [PATCH 3/3] address review comments: make omp.free_shared_mem
self-contained, update alignment handling for shared memory allocations
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 37 ++++++++++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 50 ++++++++----------
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 7 +--
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 51 +++++++++----------
mlir/test/Dialect/OpenMP/invalid.mlir | 23 ++++++---
mlir/test/Dialect/OpenMP/ops.mlir | 12 +++--
.../LLVMIR/omptarget-device-shared-mem.mlir | 10 ++--
7 files changed, 112 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1ffb01be4707a..9540cbcbe83d0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -908,6 +908,43 @@ class OpenMP_MapClauseSkip<
def OpenMP_MapClause : OpenMP_MapClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure for memory allocation information.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_MemAllocationSizeClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+
+ let arguments = (ins
+ TypeAttr:$mem_elem_type,
+ AnySignlessInteger:$mem_array_size,
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$mem_alignment
+ );
+
+ let reqAssemblyFormat = [{
+ $mem_array_size `x` $mem_elem_type `:` `(` type($mem_array_size) `)`
+ }];
+
+ let optAssemblyFormat = [{
+ `align` `(` $mem_alignment `)`
+ }];
+
+ let description = [{
+ The `mem_elem_type` is the type of the object the memory allocation refers
+ to. It is used to calculate the size of the allocation.
+
+ The `mem_array_size` is the number of objects.
+
+ The optional `mem_alignment` is used to specify the alignment for each
+ element. If not set, the `DataLayout` defaults will be used instead.
+ }];
+}
+
+def OpenMP_MemAllocationSizeClause : OpenMP_MemAllocationSizeClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [15.8.1] `memory-order` clause set
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ec95360fe09c4..8661700ec1f01 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2241,15 +2241,11 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
MemoryEffects<[MemAlloc<DefaultResource>]>
+ ], clauses = [
+ OpenMP_MemAllocationSizeClause
]> {
let summary = "allocate storage on shared memory for objects of a given type";
- let arguments = (ins
- TypeAttr:$elem_type,
- AnySignlessInteger:$array_size,
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$alignment
- );
-
let description = [{
Allocates memory shared across threads of a team for an object of the given
type. Returns a pointer representing the allocated memory. The memory is
@@ -2258,27 +2254,18 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
```mlir
// Allocate an i32 vector with %size elements and aligned to 8 bytes.
- %ptr_shared = omp.alloc_shared_mem %size x i32 {alignment = 8} : (i64) -> (!llvm.ptr)
+ %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) align(8) -> !llvm.ptr
// ...
- omp.free_shared_mem %ptr_shared : !llvm.ptr
+ omp.free_shared_mem [%size x i32 : (i64) align(8)] %ptr_shared : !llvm.ptr
```
-
- The `elem_type` is the type of the object for which memory is being
- allocated.
-
- The `array_size` is the number of objects to allocate memory for.
-
- The optional `alignment` is used to specify the alignment for each element.
- If not set, the `DataLayout` defaults will be used instead.
- }];
+ }] # clausesDescription;
let results = (outs OpenMP_PointerLikeType);
- let assemblyFormat = [{
- $array_size `x` $elem_type attr-dict `:` `(` type($array_size) `)` `->` type(results)
- }];
+ let assemblyFormat = clausesReqAssemblyFormat # " oilist(" #
+ clausesOptAssemblyFormat # ") `->` type(results) attr-dict";
let extraClassDeclaration = [{
- mlir::Type getAllocatedType() { return getElemTypeAttr().getValue(); }
+ mlir::Type getAllocatedType() { return getMemElemTypeAttr().getValue(); }
}];
let hasVerifier = 1;
}
@@ -2287,31 +2274,34 @@ def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
// FreeSharedMemOp
//===----------------------------------------------------------------------===//
-def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
+def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", traits = [
+ MemoryEffects<[MemFree]>
+ ], clauses = [
+ OpenMP_MemAllocationSizeClause
+ ]> {
let summary = "free shared memory";
let description = [{
Deallocates shared memory that was previously allocated by an
`omp.alloc_shared_mem` operation. After this operation, the deallocated
memory is in an undefined state and should not be accessed.
- It is crucial to ensure that all accesses to the memory region are completed
- before `omp.alloc_shared_mem` is called to avoid undefined behavior.
```mlir
// Example of allocating and freeing shared memory.
- %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> (!llvm.ptr)
+ %ptr_shared = omp.alloc_shared_mem %size x i32 : (i64) -> !llvm.ptr
// ...
- omp.free_shared_mem %ptr_shared : !llvm.ptr
+ omp.free_shared_mem [%size x i32 : (i64)] %ptr_shared : !llvm.ptr
```
The `heapref` operand represents the pointer to shared memory to be
deallocated, previously returned by `omp.alloc_shared_mem`.
- }];
+ }] # clausesDescription;
- let arguments = (ins
+ let arguments = !con(clausesArgs, (ins
Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
- );
- let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
+ ));
+ let assemblyFormat = "` ` `[`" # clausesReqAssemblyFormat # " oilist(" #
+ clausesOptAssemblyFormat # ") `]` $heapref `:` type($heapref) attr-dict";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fb0d931c6c9e4..8cea85bee15b7 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4430,7 +4430,7 @@ LogicalResult AllocateDirOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AllocSharedMemOp::verify() {
- return verifyAlignment(*getOperation(), getAlignment());
+ return verifyAlignment(*getOperation(), getMemAlignment());
}
//===----------------------------------------------------------------------===//
@@ -4438,10 +4438,7 @@ LogicalResult AllocSharedMemOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult FreeSharedMemOp::verify() {
- return getHeapref().getDefiningOp<AllocSharedMemOp>()
- ? success()
- : emitOpError() << "'heapref' operand must be defined by an "
- "'omp.alloc_shared_memory' op";
+ return verifyAlignment(*getOperation(), getMemAlignment());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 94637b571aec0..dfe714085b8af 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -7202,10 +7202,32 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
return func;
}
+template <typename T>
static llvm::Value *
getAllocationSize(llvm::IRBuilderBase &builder,
- LLVM::ModuleTranslation &moduleTranslation,
- omp::TargetAllocMemOp op) {
+ LLVM::ModuleTranslation &moduleTranslation, T op) {
+ llvm::DataLayout dataLayout =
+ moduleTranslation.getLLVMModule()->getDataLayout();
+ llvm::Type *llvmHeapTy =
+ moduleTranslation.convertType(op.getMemElemTypeAttr().getValue());
+
+ auto alignment = op.getMemAlignment();
+ llvm::TypeSize typeSize = llvm::alignTo(
+ dataLayout.getTypeStoreSize(llvmHeapTy),
+ alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
+
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+ return builder.CreateMul(
+ allocSize,
+ builder.CreateIntCast(moduleTranslation.lookupValue(op.getMemArraySize()),
+ builder.getInt64Ty(),
+ /*isSigned=*/false));
+}
+
+template <>
+llvm::Value *getAllocationSize(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ omp::TargetAllocMemOp op) {
llvm::DataLayout dataLayout =
moduleTranslation.getLLVMModule()->getDataLayout();
llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
@@ -7221,27 +7243,6 @@ getAllocationSize(llvm::IRBuilderBase &builder,
return allocSize;
}
-static llvm::Value *
-getAllocationSize(llvm::IRBuilderBase &builder,
- LLVM::ModuleTranslation &moduleTranslation,
- omp::AllocSharedMemOp op) {
- llvm::DataLayout dataLayout =
- moduleTranslation.getLLVMModule()->getDataLayout();
- llvm::Type *llvmHeapTy = moduleTranslation.convertType(op.getAllocatedType());
-
- auto alignment = op.getAlignment();
- llvm::TypeSize typeSize = llvm::alignTo(
- dataLayout.getTypeStoreSize(llvmHeapTy),
- alignment ? *alignment : dataLayout.getABITypeAlign(llvmHeapTy).value());
-
- llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
- return builder.CreateMul(
- allocSize,
- builder.CreateIntCast(moduleTranslation.lookupValue(op.getArraySize()),
- builder.getInt64Ty(),
- /*isSigned=*/false));
-}
-
static LogicalResult
convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -7319,9 +7320,7 @@ convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- auto allocMemOp =
- freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
- llvm::Value *size = getAllocationSize(builder, moduleTranslation, allocMemOp);
+ llvm::Value *size = getAllocationSize(builder, moduleTranslation, freeMemOp);
ompBuilder->createOMPFreeShared(
builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
return success();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index a14cbc9a2cb77..0cab769049bac 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3184,28 +3184,35 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
// -----
func.func @alloc_shared_mem_invalid_alignment1(%n: i32) -> () {
- // expected-error @below {{op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
- %0 = omp.alloc_shared_mem %n x i64 {alignment=-2} : (i32) -> !llvm.ptr
+ // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+ %0 = omp.alloc_shared_mem %n x i64 : (i32) align(-2) -> !llvm.ptr
return
}
// -----
func.func @alloc_shared_mem_invalid_alignment2(%n: i32) -> () {
// expected-error @below {{ALIGN value : 3 must be power of 2}}
- %0 = omp.alloc_shared_mem %n x i64 {alignment=3} : (i32) -> !llvm.ptr
+ %0 = omp.alloc_shared_mem %n x i64 : (i32) align(3) -> !llvm.ptr
return
}
// -----
-func.func @alloc_shared_mem_invalid_array_size(%n: f32) -> () {
+func.func @free_shared_mem_invalid_array_size(%n: f32, %ptr : !llvm.ptr) -> () {
// expected-error @below {{invalid kind of type specified: expected builtin.integer, but found 'f32'}}
- %0 = omp.alloc_shared_mem %n x i64 : (f32) -> !llvm.ptr
+ %0 = omp.free_shared_mem [%n x i64 : (f32)] %ptr : !llvm.ptr
return
}
// -----
-func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () {
- // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}}
- omp.free_shared_mem %ptr : !llvm.ptr
+func.func @free_shared_mem_invalid_alignment1(%n: i32, %ptr : !llvm.ptr) -> () {
+ // expected-error @below {{op attribute 'mem_alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
+ omp.free_shared_mem [%n x i64 : (i32) align(-2)] %ptr : !llvm.ptr
+ return
+}
+
+// -----
+func.func @free_shared_mem_invalid_alignment2(%n: i32, %ptr : !llvm.ptr) -> () {
+ // expected-error @below {{ALIGN value : 3 must be power of 2}}
+ omp.free_shared_mem [%n x i64 : (i32) align(3)] %ptr : !llvm.ptr
return
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b1f51d2b77d8b..1f9544301b184 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3584,8 +3584,8 @@ func.func @omp_alloc_shared_mem(%n: i32) {
%0 = omp.alloc_shared_mem %n x i64 : (i32) -> !llvm.ptr
// CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x vector<16x16xf32> : (i32) -> !llvm.ptr
%1 = omp.alloc_shared_mem %n x vector<16x16xf32> : (i32) -> !llvm.ptr
- // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr {alignment = 16 : i64} : (i32) -> !llvm.ptr
- %2 = omp.alloc_shared_mem %n x !llvm.ptr {alignment = 16} : (i32) -> !llvm.ptr
+ // CHECK: %{{.*}} = omp.alloc_shared_mem %[[N]] x !llvm.ptr : (i32) align(16) -> !llvm.ptr
+ %2 = omp.alloc_shared_mem %n x !llvm.ptr : (i32) align(16) -> !llvm.ptr
return
}
@@ -3594,7 +3594,11 @@ func.func @omp_alloc_shared_mem(%n: i32) {
func.func @omp_free_shared_mem(%n: i64) {
// CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) -> !llvm.ptr
%0 = omp.alloc_shared_mem %n x f32 : (i64) -> !llvm.ptr
- // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr
- omp.free_shared_mem %0 : !llvm.ptr
+ // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64)] %[[PTR]] : !llvm.ptr
+ omp.free_shared_mem [%n x f32 : (i64)] %0 : !llvm.ptr
+ // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem %[[N]] x f32 : (i64) align(32) -> !llvm.ptr
+ %1 = omp.alloc_shared_mem %n x f32 : (i64) align(32) -> !llvm.ptr
+ // CHECK: omp.free_shared_mem [%[[N]] x f32 : (i64) align(32)] %[[PTR]] : !llvm.ptr
+ omp.free_shared_mem [%n x f32 : (i64) align(32)] %1 : !llvm.ptr
return
}
diff --git a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
index 72b0a2daadfc3..cdebebc3ed233 100644
--- a/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-device-shared-mem.mlir
@@ -19,24 +19,24 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: %[[ALLOC3_SZ:.*]] = mul i64 128, %[[N1]]
// CHECK-NEXT: %[[ALLOC3:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 %[[ALLOC3_SZ]])
- %3 = omp.alloc_shared_mem %n1 x vector<16xf32> {alignment = 128} : (i64) -> !llvm.ptr
+ %3 = omp.alloc_shared_mem %n1 x vector<16xf32> : (i64) align(128) -> !llvm.ptr
// CHECK: %[[CAST_N0_1:.*]] = zext i32 %[[N0]] to i64
// CHECK-NEXT: %[[FREE0_SZ:.*]] = mul i64 8, %[[CAST_N0_1]]
// CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC0]], i64 %[[FREE0_SZ]])
- omp.free_shared_mem %0 : !llvm.ptr
+ omp.free_shared_mem [%n0 x i64 : (i32)] %0 : !llvm.ptr
// CHECK: %[[FREE1_SZ:.*]] = mul i64 8, %[[N1]]
// CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC1]], i64 %[[FREE1_SZ]])
- omp.free_shared_mem %1 : !llvm.ptr
+ omp.free_shared_mem [%n1 x i64 : (i64)] %1 : !llvm.ptr
// CHECK: %[[FREE2_SZ:.*]] = mul i64 64, %[[N1]]
// CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC2]], i64 %[[FREE2_SZ]])
- omp.free_shared_mem %2 : !llvm.ptr
+ omp.free_shared_mem [%n1 x vector<16xf32> : (i64)] %2 : !llvm.ptr
// CHECK: %[[FREE3_SZ:.*]] = mul i64 128, %[[N1]]
// CHECK-NEXT: call void @__kmpc_free_shared(ptr %[[ALLOC3]], i64 %[[FREE3_SZ]])
- omp.free_shared_mem %3 : !llvm.ptr
+ omp.free_shared_mem [%n1 x vector<16xf32> : (i64) align(128)] %3 : !llvm.ptr
llvm.return
}
}
More information about the llvm-branch-commits
mailing list