[llvm-branch-commits] [mlir] [MLIR][OpenMP] Refactor omp.target_allocmem to allow reuse, NFC (PR #161861)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Oct 3 08:23:58 PDT 2025
https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/161861
This patch moves tablegen definitions that could be used for all kinds of heap allocations out of `omp.target_allocmem` and into a new `OpenMP_HeapAllocClause` that can be reused.
Descriptions are updated to follow the format of most other operations and the custom verifier for `omp.target_allocmem` is removed as it only made a redundant check on its result type.
From 1eccb260d6d32f2870f8056580e02dec7f7fd19f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 12 Sep 2025 11:26:40 +0100
Subject: [PATCH] [MLIR][OpenMP] Refactor omp.target_allocmem to allow reuse,
NFC
This patch moves tablegen definitions that could be used for all kinds of heap
allocations out of `omp.target_allocmem` and into a new
`OpenMP_HeapAllocClause` that can be reused.
Descriptions are updated to follow the format of most other operations and the
custom verifier for `omp.target_allocmem` is removed as it only made a
redundant check on its result type.
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 53 ++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 80 ++++-----
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 153 ++++++------------
mlir/test/Dialect/OpenMP/invalid.mlir | 14 ++
mlir/test/Dialect/OpenMP/ops.mlir | 24 +++
5 files changed, 176 insertions(+), 148 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1eda5e4bc1618..3b6ecceb1dfb3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -20,6 +20,7 @@
#define OPENMP_CLAUSES
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -547,6 +548,58 @@ class OpenMP_HasDeviceAddrClauseSkip<
def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold heap allocation information.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_HeapAllocClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let traits = [
+ MemoryEffects<[MemAlloc<DefaultResource>]>
+ ];
+
+ let arguments = (ins
+ TypeAttr:$in_type,
+ OptionalAttr<StrAttr>:$uniq_name,
+ OptionalAttr<StrAttr>:$bindc_name,
+ Variadic<IntLikeType>:$typeparams,
+ Variadic<IntLikeType>:$shape
+ );
+
+ // The custom parser doesn't parse `uniq_name` and `bindc_name`. This is
+ // handled by the attr-dict, which must be present in the operation's
+ // `assemblyFormat`.
+ let reqAssemblyFormat = [{
+ custom<HeapAllocClause>($in_type, $typeparams, type($typeparams), $shape,
+ type($shape))
+ }];
+
+ let extraClassDeclaration = [{
+ mlir::Type getAllocatedType() { return getInTypeAttr().getValue(); }
+ }];
+
+ let description = [{
+ The `in_type` is the type of the object for which memory is being allocated.
+ For arrays, this can be a static or dynamic array type.
+
+ The optional `uniq_name` is a unique name for the allocated memory.
+
+ The optional `bindc_name` is a name used for C interoperability.
+
+ The `typeparams` are runtime type parameters for polymorphic or
+ parameterized types. These are typically integer values that define aspects
+ of a type not fixed at compile time.
+
+ The `shape` holds runtime shape operands for dynamic arrays. Each operand is
+ an integer value representing the extent of a specific dimension.
+ }];
+}
+
+def OpenMP_HeapAllocClause : OpenMP_HeapAllocClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [5.4.7] `inclusive` clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9003fb2ef7959..8b206f58c7733 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2128,59 +2128,45 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", [AttrSizedOperandSegments], clause
// TargetAllocMemOp
//===----------------------------------------------------------------------===//
-def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
- [MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
+def TargetAllocMemOp : OpenMP_Op<"target_allocmem", traits = [
+ AttrSizedOperandSegments
+ ], clauses = [
+ OpenMP_HeapAllocClause
+ ]> {
let summary = "allocate storage on an openmp device for an object of a given type";
let description = [{
- Allocates memory on the specified OpenMP device for an object of the given type.
- Returns an integer value representing the device pointer to the allocated memory.
- The memory is uninitialized after allocation. Operations must be paired with
- `omp.target_freemem` to avoid memory leaks.
-
- * `$device`: The integer ID of the OpenMP device where the memory will be allocated.
- * `$in_type`: The type of the object for which memory is being allocated.
- For arrays, this can be a static or dynamic array type.
- * `$uniq_name`: An optional unique name for the allocated memory.
- * `$bindc_name`: An optional name used for C interoperability.
- * `$typeparams`: Runtime type parameters for polymorphic or parameterized types.
- These are typically integer values that define aspects of a type not fixed at compile time.
- * `$shape`: Runtime shape operands for dynamic arrays.
- Each operand is an integer value representing the extent of a specific dimension.
-
- ```mlir
- // Allocate a static 3x3 integer vector on device 0
- %device_0 = arith.constant 0 : i32
- %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
- // ... use %ptr_static ...
- omp.target_freemem %device_0, %ptr_static : i32, i64
-
- // Allocate a dynamic 2D Fortran array (fir.array) on device 1
- %device_1 = arith.constant 1 : i32
- %rows = arith.constant 10 : index
- %cols = arith.constant 20 : index
- %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
- // ... use %ptr_dynamic ...
- omp.target_freemem %device_1, %ptr_dynamic : i32, i64
- ```
- }];
+ Allocates memory on the specified OpenMP device for an object of the given
+ type. Returns an integer value representing the device pointer to the
+ allocated memory. The memory is uninitialized after allocation. Operations
+ must be paired with `omp.target_freemem` to avoid memory leaks.
- let arguments = (ins
- Arg<AnyInteger>:$device,
- TypeAttr:$in_type,
- OptionalAttr<StrAttr>:$uniq_name,
- OptionalAttr<StrAttr>:$bindc_name,
- Variadic<IntLikeType>:$typeparams,
- Variadic<IntLikeType>:$shape
- );
- let results = (outs I64);
+ ```mlir
+ // Allocate a static 3x3 integer vector on device 0
+ %device_0 = arith.constant 0 : i32
+ %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
+ // ... use %ptr_static ...
+ omp.target_freemem %device_0, %ptr_static : i32, i64
+
+ // Allocate a dynamic 2D Fortran array (fir.array) on device 1
+ %device_1 = arith.constant 1 : i32
+ %rows = arith.constant 10 : index
+ %cols = arith.constant 20 : index
+ %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
+ // ... use %ptr_dynamic ...
+ omp.target_freemem %device_1, %ptr_dynamic : i32, i64
+ ```
- let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
+ The `device` is an integer ID of the OpenMP device where the memory will be
+ allocated.
+ }] # clausesDescription;
- let extraClassDeclaration = [{
- mlir::Type getAllocatedType();
- }];
+ let arguments = !con((ins Arg<AnyInteger>:$device), clausesArgs);
+ let results = (outs I64);
+
+ // Override inherited assembly format to include `device`.
+ let assemblyFormat = " $device `:` type($device) `,` "
+ # clausesReqAssemblyFormat # " attr-dict";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 8640c4ba0b757..fabb1b8c173a2 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -797,6 +797,58 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for Heap Alloc Clause
+//===----------------------------------------------------------------------===//
+
+/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+static ParseResult parseHeapAllocClause(
+ OpAsmParser &parser, TypeAttr &inTypeAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &typeparams,
+ SmallVectorImpl<Type> &typeparamsTypes,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &shape,
+ SmallVectorImpl<Type> &shapeTypes) {
+ mlir::Type inType;
+ if (parser.parseType(inType))
+ return mlir::failure();
+ inTypeAttr = TypeAttr::get(inType);
+
+ if (!parser.parseOptionalLParen()) {
+ // parse the LEN params of the derived type. (<params> : <types>)
+ if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen())
+ return failure();
+ }
+
+ if (!parser.parseOptionalComma()) {
+ // parse size to scale by, vector of n dimensions of type index
+ if (parser.parseOperandList(shape, OpAsmParser::Delimiter::None))
+ return failure();
+
+ // TODO: This overrides the actual types of the operands, which might cause
+ // issues when they don't match. At the moment this is done in place of
+ // making the corresponding operand type `Variadic<Index>` because index
+ // types are lowered to I64 prior to LLVM IR translation.
+ shapeTypes.append(shape.size(), IndexType::get(parser.getContext()));
+ }
+
+ return success();
+}
+
+static void printHeapAllocClause(OpAsmPrinter &p, Operation *op,
+ TypeAttr inType, ValueRange typeparams,
+ TypeRange typeparamsTypes, ValueRange shape,
+ TypeRange shapeTypes) {
+ p << inType;
+ if (!typeparams.empty()) {
+ p << '(' << typeparams << " : " << typeparamsTypes << ')';
+ }
+ for (auto sh : shape) {
+ p << ", ";
+ p.printOperand(sh);
+ }
+}
+
//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -4109,107 +4161,6 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// TargetAllocMemOp
-//===----------------------------------------------------------------------===//
-
-mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
- return getInTypeAttr().getValue();
-}
-
-/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
-/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
-/// attr-dict-without-keyword
-static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
- auto &builder = parser.getBuilder();
- bool hasOperands = false;
- std::int32_t typeparamsSize = 0;
-
- // Parse device number as a new operand
- mlir::OpAsmParser::UnresolvedOperand deviceOperand;
- mlir::Type deviceType;
- if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
- return mlir::failure();
- if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
- return mlir::failure();
- if (parser.parseComma())
- return mlir::failure();
-
- mlir::Type intype;
- if (parser.parseType(intype))
- return mlir::failure();
- result.addAttribute("in_type", mlir::TypeAttr::get(intype));
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
- llvm::SmallVector<mlir::Type> typeVec;
- if (!parser.parseOptionalLParen()) {
- // parse the LEN params of the derived type. (<params> : <types>)
- if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
- parser.parseColonTypeList(typeVec) || parser.parseRParen())
- return mlir::failure();
- typeparamsSize = operands.size();
- hasOperands = true;
- }
- std::int32_t shapeSize = 0;
- if (!parser.parseOptionalComma()) {
- // parse size to scale by, vector of n dimensions of type index
- if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
- return mlir::failure();
- shapeSize = operands.size() - typeparamsSize;
- auto idxTy = builder.getIndexType();
- for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
- typeVec.push_back(idxTy);
- hasOperands = true;
- }
- if (hasOperands &&
- parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
- result.operands))
- return mlir::failure();
-
- mlir::Type restype = builder.getIntegerType(64);
- if (!restype) {
- parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
- return mlir::failure();
- }
- llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
- result.addAttribute("operandSegmentSizes",
- builder.getDenseI32ArrayAttr(segmentSizes));
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.addTypeToList(restype, result.types))
- return mlir::failure();
- return mlir::success();
-}
-
-mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
- return parseTargetAllocMemOp(parser, result);
-}
-
-void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
- p << " ";
- p.printOperand(getDevice());
- p << " : ";
- p << getDevice().getType();
- p << ", ";
- p << getInType();
- if (!getTypeparams().empty()) {
- p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
- }
- for (auto sh : getShape()) {
- p << ", ";
- p.printOperand(sh);
- }
- p.printOptionalAttrDict((*this)->getAttrs(),
- {"in_type", "operandSegmentSizes"});
-}
-
-llvm::LogicalResult omp::TargetAllocMemOp::verify() {
- mlir::Type outType = getType();
- if (!mlir::dyn_cast<IntegerType>(outType))
- return emitOpError("must be a integer type");
- return mlir::success();
-}
-
//===----------------------------------------------------------------------===//
// WorkdistributeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index af24d969064ab..0cc4b522db466 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3139,3 +3139,17 @@ func.func @invalid_workdistribute() -> () {
}
return
}
+
+// -----
+func.func @target_allocmem_invalid_uniq_name(%device : i32) -> () {
+// expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
+ %0 = omp.target_allocmem %device : i32, i64 {uniq_name=2}
+ return
+}
+
+// -----
+func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
+// expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
+ %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index cbd863f88fd1f..9e7287178ff66 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3321,3 +3321,27 @@ func.func @omp_workdistribute() {
}
return
}
+
+// CHECK-LABEL: func.func @omp_target_allocmem(
+// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
+func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, i64
+ %0 = omp.target_allocmem %device : i32, i64
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"}
+ %1 = omp.target_allocmem %device : i32, vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"}
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32)
+ %2 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32)
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr, %[[X]], %[[Y]]
+ %3 = omp.target_allocmem %device : i32, !llvm.ptr, %x, %y
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]]
+ %4 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y
+ return
+}
+
+// CHECK-LABEL: func.func @omp_target_freemem(
+// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
+func.func @omp_target_freemem(%device : i32, %ptr : i64) {
+ // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
+ omp.target_freemem %device, %ptr : i32, i64
+ return
+}
More information about the llvm-branch-commits
mailing list